Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DO NOT MERGE: Refactor dataframe where tests, reduce IF conditional compilation blocks #386

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 17 additions & 16 deletions .github/workflows/tests-conda.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ jobs:
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest, macos-latest, windows-2019]
# os: [ubuntu-latest, macos-latest, windows-2019]
os: ["macos-latest"]
python: ["3.10", "3.11", "3.12"]
env: ["latest"]
include:
Expand All @@ -31,19 +32,19 @@ jobs:
extra: >-
pandas=1.5
geopandas=0.12
# minimal environment without optional dependencies
- os: "ubuntu-latest"
python: "3.9"
env: "minimal"
# environment for older Windows libgdal to make sure gdal_i.lib is
# properly detected
- os: "windows-2019"
python: "3.10"
env: "libgdal3.5.1"
# environment with nightly wheels
- os: "ubuntu-latest"
python: "3.11"
env: "nightly-deps"
# # minimal environment without optional dependencies
# - os: "ubuntu-latest"
# python: "3.9"
# env: "minimal"
# # environment for older Windows libgdal to make sure gdal_i.lib is
# # properly detected
# - os: "windows-2019"
# python: "3.10"
# env: "libgdal3.5.1"
# # environment with nightly wheels
# - os: "ubuntu-latest"
# python: "3.11"
# env: "nightly-deps"

steps:
- name: Checkout repo
Expand All @@ -66,8 +67,8 @@ jobs:
echo "GDAL_VERSION=$(gdalinfo --version | cut -c 6-10)" >> $GITHUB_ENV

- name: Install pyogrio
run: pip install -e .
run: pip install -e . -v

- name: Test
run: |
pytest -v --color=yes -r s pyogrio/tests
pytest -v -s --color=yes -r s pyogrio/tests/test_geopandas_io.py::test_read_where_range
73 changes: 51 additions & 22 deletions pyogrio/_io.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ from libc.string cimport strlen
from libc.math cimport isnan

cimport cython
from cpython.pycapsule cimport PyCapsule_New, PyCapsule_GetPointer
import numpy as np

from pyogrio._ogr cimport *
Expand Down Expand Up @@ -84,6 +85,25 @@ DTYPE_OGR_FIELD_TYPES = {
}



cdef void pycapsule_array_stream_deleter(object stream_capsule) noexcept:
    # PyCapsule destructor for capsules created by `alloc_c_stream`: called
    # when the capsule is garbage-collected, it releases the Arrow stream (if
    # still active) and frees the heap-allocated struct.
    cdef ArrowArrayStream* stream = <ArrowArrayStream*>PyCapsule_GetPointer(
        stream_capsule, 'arrow_array_stream'
    )
    # Do not invoke the deleter on a used/moved capsule
    # (a NULL `release` marks the stream as already released/moved, so calling
    # it again would be undefined behavior)
    if stream.release != NULL:
        stream.release(stream)

    # Free the struct itself, which was malloc'ed in `alloc_c_stream`
    free(stream)


cdef object alloc_c_stream(ArrowArrayStream** c_stream):
    """Allocate an ArrowArrayStream on the heap and wrap it in a PyCapsule.

    The returned capsule owns the allocation: its destructor
    (pycapsule_array_stream_deleter) releases the stream if still active and
    frees the struct, so clean-up happens even on error paths.

    Parameters
    ----------
    c_stream : ArrowArrayStream**
        Out-parameter; on return points to the newly allocated struct.

    Returns
    -------
    object
        PyCapsule named 'arrow_array_stream' owning the allocation.

    Raises
    ------
    MemoryError
        If the allocation fails.
    """
    c_stream[0] = <ArrowArrayStream*> malloc(sizeof(ArrowArrayStream))
    # malloc returns NULL on failure; writing to c_stream[0].release below
    # would then crash, so fail loudly here instead
    if c_stream[0] == NULL:
        raise MemoryError("could not allocate ArrowArrayStream")
    # Ensure the capsule destructor doesn't call a random release pointer
    c_stream[0].release = NULL
    return PyCapsule_New(c_stream[0], 'arrow_array_stream', &pycapsule_array_stream_deleter)


cdef int start_transaction(OGRDataSourceH ogr_dataset, int force) except 1:
cdef int err = GDALDatasetStartTransaction(ogr_dataset, force)
if err == OGRERR_FAILURE:
Expand Down Expand Up @@ -1238,12 +1258,16 @@ def ogr_read(
}

finally:
if fields_c != NULL:
CSLDestroy(fields_c)
fields_c = NULL

if dataset_options != NULL:
CSLDestroy(dataset_options)
dataset_options = NULL

if ogr_dataset != NULL:
if sql is not None:
if sql is not None and ogr_layer != NULL:
GDALDatasetReleaseResultSet(ogr_dataset, ogr_layer)

GDALClose(ogr_dataset)
Expand Down Expand Up @@ -1285,9 +1309,10 @@ def ogr_open_arrow(
cdef char **fields_c = NULL
cdef const char *field_c = NULL
cdef char **options = NULL
cdef ArrowArrayStream stream
cdef ArrowSchema schema
cdef ArrowArrayStream *stream

# this block prevents compilation of remaining code in this function, which
# fails for GDAL < 3.6.0 because OGR_L_GetArrowStream is undefined
IF CTE_GDAL_VERSION < (3, 6, 0):
raise RuntimeError("Need GDAL>=3.6 for Arrow support")

Expand All @@ -1300,12 +1325,6 @@ def ogr_open_arrow(
if fids is not None:
raise ValueError("reading by FID is not supported for Arrow")

IF CTE_GDAL_VERSION < (3, 8, 0):
if skip_features:
raise ValueError(
"specifying 'skip_features' is not supported for Arrow for GDAL<3.8.0"
)

if skip_features < 0:
raise ValueError("'skip_features' must be >= 0")

Expand Down Expand Up @@ -1387,27 +1406,31 @@ def ogr_open_arrow(
options = CSLSetNameValue(options, "INCLUDE_FID", "NO")

if batch_size > 0:
batch_size_b = str(batch_size).encode('UTF-8')
batch_size_c = batch_size_b
options = CSLSetNameValue(
options,
"MAX_FEATURES_IN_BATCH",
str(batch_size).encode('UTF-8')
<const char*>batch_size_c
)

# Default to geoarrow metadata encoding
IF CTE_GDAL_VERSION >= (3, 8, 0):
options = CSLSetNameValue(
options,
"GEOMETRY_METADATA_ENCODING",
"GEOARROW".encode('UTF-8')
)
# Default to geoarrow metadata encoding (only used for GDAL >= 3.8.0)
options = CSLSetNameValue(
options,
"GEOMETRY_METADATA_ENCODING",
"GEOARROW"
)

# make sure layer is read from beginning
OGR_L_ResetReading(ogr_layer)

if not OGR_L_GetArrowStream(ogr_layer, &stream, options):
# allocate the stream struct and wrap in capsule to ensure clean-up on error
capsule = alloc_c_stream(&stream)

if not OGR_L_GetArrowStream(ogr_layer, stream, options):
raise RuntimeError("Failed to open ArrowArrayStream from Layer")

stream_ptr = <uintptr_t> &stream
stream_ptr = <uintptr_t> stream

if skip_features:
# only supported for GDAL >= 3.8.0; have to do this after getting
Expand All @@ -1417,7 +1440,6 @@ def ogr_open_arrow(
# stream has to be consumed before the Dataset is closed
import pyarrow as pa
reader = pa.RecordBatchStreamReader._import_from_c(stream_ptr)

meta = {
'crs': crs,
'encoding': encoding,
Expand All @@ -1434,7 +1456,11 @@ def ogr_open_arrow(
# Mark reader as closed to prevent reading batches
reader.close()

CSLDestroy(options)
# `stream` will be freed through `capsule` destructor

if options != NULL:
CSLDestroy(options)

if fields_c != NULL:
CSLDestroy(fields_c)
fields_c = NULL
Expand All @@ -1444,12 +1470,15 @@ def ogr_open_arrow(
dataset_options = NULL

if ogr_dataset != NULL:
if sql is not None:
if sql is not None and ogr_layer != NULL:
GDALDatasetReleaseResultSet(ogr_dataset, ogr_layer)

GDALClose(ogr_dataset)
ogr_dataset = NULL

print("done with ogr_open_arrow finally block")


def ogr_read_bounds(
str path,
object layer=None,
Expand Down
3 changes: 3 additions & 0 deletions pyogrio/_ogr.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -196,12 +196,15 @@ cdef extern from "arrow_bridge.h":

struct ArrowArrayStream:
int (*get_schema)(ArrowArrayStream* stream, ArrowSchema* out)
void (*release)(ArrowArrayStream*) noexcept nogil


cdef extern from "ogr_api.h":
int OGRGetDriverCount()
OGRSFDriverH OGRGetDriver(int)

bint OGRGetGEOSVersion(int *pnMajor, int *pnMinor, int *pnPatch)

OGRDataSourceH OGR_Dr_Open(OGRSFDriverH driver, const char *path, int bupdate)
const char* OGR_Dr_GetName(OGRSFDriverH driver)

Expand Down
14 changes: 3 additions & 11 deletions pyogrio/_ogr.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -42,22 +42,14 @@ def get_gdal_version_string():
return get_string(version)


IF CTE_GDAL_VERSION >= (3, 4, 0):

cdef extern from "ogr_api.h":
bint OGRGetGEOSVersion(int *pnMajor, int *pnMinor, int *pnPatch)


def get_gdal_geos_version():
cdef int major, minor, revision

IF CTE_GDAL_VERSION >= (3, 4, 0):
if not OGRGetGEOSVersion(&major, &minor, &revision):
return None
return (major, minor, revision)
ELSE:
if not OGRGetGEOSVersion(&major, &minor, &revision):
return None

return (major, minor, revision)


def set_gdal_config_options(dict options):
for name, value in options.items():
Expand Down
9 changes: 9 additions & 0 deletions pyogrio/raw.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,9 +260,12 @@ def read_arrow(

gdal_version = get_gdal_version()

# also checking skip_features here because of special handling for GDAL < 3.8.0
# otherwise it is properly checked in ogr_open_arrow instead
if skip_features < 0:
raise ValueError("'skip_features' must be >= 0")

# max_features support is shimmed here so it must be validated here
if max_features is not None and max_features < 0:
raise ValueError("'max_features' must be >= 0")

Expand Down Expand Up @@ -402,6 +405,12 @@ def open_arrow(
if not HAS_ARROW_API:
raise RuntimeError("pyarrow and GDAL>= 3.6 required to read using arrow")

gdal_version = get_gdal_version()
if skip_features and gdal_version < (3, 8, 0):
raise ValueError(
"specifying 'skip_features' is not supported for open_arrow for GDAL<3.8.0"
)

path, buffer = get_vsi_path(path_or_buffer)

dataset_kwargs = _preprocess_options_key_value(kwargs) if kwargs else {}
Expand Down
2 changes: 1 addition & 1 deletion pyogrio/tests/test_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def test_open_arrow_skip_features_unsupported(naturalearth_lowres, skip_features
GDAL < 3.8.0"""
with pytest.raises(
ValueError,
match="specifying 'skip_features' is not supported for Arrow for GDAL<3.8.0",
match="specifying 'skip_features' is not supported for open_arrow for GDAL<3.8.0",
):
with open_arrow(naturalearth_lowres, skip_features=skip_features) as (
meta,
Expand Down
11 changes: 10 additions & 1 deletion pyogrio/tests/test_geopandas_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,18 +317,22 @@ def test_read_fid_as_index_only(naturalearth_lowres, use_arrow):
assert len(df.columns) == 0


def test_read_where(naturalearth_lowres_all_ext, use_arrow):
def test_read_where_empty(naturalearth_lowres_all_ext, use_arrow):
    """An empty 'where' filter must not exclude anything: all records load."""
    result = read_dataframe(
        naturalearth_lowres_all_ext, use_arrow=use_arrow, where=""
    )
    assert len(result) == 177


def test_read_where_equals(naturalearth_lowres_all_ext, use_arrow):
    """An equality filter on iso_a3 returns exactly the one matching record."""
    result = read_dataframe(
        naturalearth_lowres_all_ext,
        use_arrow=use_arrow,
        where="iso_a3 = 'CAN'",
    )
    assert len(result) == 1
    assert result.iloc[0].iso_a3 == "CAN"


def test_read_where_in(naturalearth_lowres_all_ext, use_arrow):
df = read_dataframe(
naturalearth_lowres_all_ext,
use_arrow=use_arrow,
Expand All @@ -337,6 +341,11 @@ def test_read_where(naturalearth_lowres_all_ext, use_arrow):
assert len(df) == 3
assert len(set(df.iso_a3.unique()).difference(["CAN", "USA", "MEX"])) == 0


def test_read_where_range(naturalearth_lowres_all_ext, use_arrow):
if naturalearth_lowres_all_ext.suffix not in {".geojsonl", ".gpkg"}:
pytest.skip("only test geojsonl or gpkg")

# should return items within range
df = read_dataframe(
naturalearth_lowres_all_ext,
Expand Down