diff --git a/.github/workflows/tests-conda.yml b/.github/workflows/tests-conda.yml
index 7c77f0a7..ddb610cc 100644
--- a/.github/workflows/tests-conda.yml
+++ b/.github/workflows/tests-conda.yml
@@ -22,7 +22,8 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        os: [ubuntu-latest, macos-latest, windows-2019]
+        # os: [ubuntu-latest, macos-latest, windows-2019]
+        os: ["macos-latest"]
         python: ["3.10", "3.11", "3.12"]
         env: ["latest"]
         include:
@@ -31,19 +32,19 @@
             extra: >-
               pandas=1.5
               geopandas=0.12
-          # minimal environment without optional dependencies
-          - os: "ubuntu-latest"
-            python: "3.9"
-            env: "minimal"
-          # environment for older Windows libgdal to make sure gdal_i.lib is
-          # properly detected
-          - os: "windows-2019"
-            python: "3.10"
-            env: "libgdal3.5.1"
-          # environment with nightly wheels
-          - os: "ubuntu-latest"
-            python: "3.11"
-            env: "nightly-deps"
+          # # minimal environment without optional dependencies
+          # - os: "ubuntu-latest"
+          #   python: "3.9"
+          #   env: "minimal"
+          # # environment for older Windows libgdal to make sure gdal_i.lib is
+          # # properly detected
+          # - os: "windows-2019"
+          #   python: "3.10"
+          #   env: "libgdal3.5.1"
+          # # environment with nightly wheels
+          # - os: "ubuntu-latest"
+          #   python: "3.11"
+          #   env: "nightly-deps"
 
     steps:
       - name: Checkout repo
@@ -66,8 +67,8 @@
           echo "GDAL_VERSION=$(gdalinfo --version | cut -c 6-10)" >> $GITHUB_ENV
 
       - name: Install pyogrio
-        run: pip install -e .
+        run: pip install -e . -v
 
       - name: Test
         run: |
-          pytest -v --color=yes -r s pyogrio/tests
+          pytest -v -s --color=yes -r s pyogrio/tests/test_geopandas_io.py::test_read_where_range
diff --git a/pyogrio/_io.pyx b/pyogrio/_io.pyx
index d0a412ba..8dac72b4 100644
--- a/pyogrio/_io.pyx
+++ b/pyogrio/_io.pyx
@@ -18,6 +18,7 @@ from libc.string cimport strlen
 from libc.math cimport isnan
 
 cimport cython
+from cpython.pycapsule cimport PyCapsule_New, PyCapsule_GetPointer
 import numpy as np
 
 from pyogrio._ogr cimport *
@@ -84,6 +85,25 @@ DTYPE_OGR_FIELD_TYPES = {
 }
 
 
+
+cdef void pycapsule_array_stream_deleter(object stream_capsule) noexcept:
+    cdef ArrowArrayStream* stream = <ArrowArrayStream*>PyCapsule_GetPointer(
+        stream_capsule, 'arrow_array_stream'
+    )
+    # Do not invoke the deleter on a used/moved capsule
+    if stream.release != NULL:
+        stream.release(stream)
+
+    free(stream)
+
+
+cdef object alloc_c_stream(ArrowArrayStream** c_stream):
+    c_stream[0] = <ArrowArrayStream*>malloc(sizeof(ArrowArrayStream))
+    # Ensure the capsule destructor doesn't call a random release pointer
+    c_stream[0].release = NULL
+    return PyCapsule_New(c_stream[0], 'arrow_array_stream', &pycapsule_array_stream_deleter)
+
+
 cdef int start_transaction(OGRDataSourceH ogr_dataset, int force) except 1:
     cdef int err = GDALDatasetStartTransaction(ogr_dataset, force)
     if err == OGRERR_FAILURE:
@@ -1238,12 +1258,16 @@ def ogr_read(
         }
 
     finally:
+        if fields_c != NULL:
+            CSLDestroy(fields_c)
+            fields_c = NULL
+
         if dataset_options != NULL:
             CSLDestroy(dataset_options)
             dataset_options = NULL
 
         if ogr_dataset != NULL:
-            if sql is not None:
+            if sql is not None and ogr_layer != NULL:
                 GDALDatasetReleaseResultSet(ogr_dataset, ogr_layer)
 
             GDALClose(ogr_dataset)
@@ -1285,9 +1309,10 @@ def ogr_open_arrow(
     cdef char **fields_c = NULL
     cdef const char *field_c = NULL
     cdef char **options = NULL
-    cdef ArrowArrayStream stream
-    cdef ArrowSchema schema
+    cdef ArrowArrayStream *stream
 
+    # this block prevents compilation of remaining code in this function, which
+    # fails for GDAL < 3.6.0 because OGR_L_GetArrowStream is undefined
     IF CTE_GDAL_VERSION < (3, 6, 0):
         raise RuntimeError("Need GDAL>=3.6 for Arrow support")
@@ -1300,12 +1325,6 @@
     if fids is not None:
         raise ValueError("reading by FID is not supported for Arrow")
 
-    IF CTE_GDAL_VERSION < (3, 8, 0):
-        if skip_features:
-            raise ValueError(
-                "specifying 'skip_features' is not supported for Arrow for GDAL<3.8.0"
-            )
-
     if skip_features < 0:
         raise ValueError("'skip_features' must be >= 0")
@@ -1387,27 +1406,31 @@
             options = CSLSetNameValue(options, "INCLUDE_FID", "NO")
 
         if batch_size > 0:
+            batch_size_b = str(batch_size).encode('UTF-8')
+            batch_size_c = batch_size_b
             options = CSLSetNameValue(
                 options,
                 "MAX_FEATURES_IN_BATCH",
-                str(batch_size).encode('UTF-8')
+                batch_size_c
             )
 
-        # Default to geoarrow metadata encoding
-        IF CTE_GDAL_VERSION >= (3, 8, 0):
-            options = CSLSetNameValue(
-                options,
-                "GEOMETRY_METADATA_ENCODING",
-                "GEOARROW".encode('UTF-8')
-            )
+        # Default to geoarrow metadata encoding (only used for GDAL >= 3.8.0)
+        options = CSLSetNameValue(
+            options,
+            "GEOMETRY_METADATA_ENCODING",
+            "GEOARROW"
+        )
 
         # make sure layer is read from beginning
         OGR_L_ResetReading(ogr_layer)
 
-        if not OGR_L_GetArrowStream(ogr_layer, &stream, options):
+        # allocate the stream struct and wrap in capsule to ensure clean-up on error
+        capsule = alloc_c_stream(&stream)
+
+        if not OGR_L_GetArrowStream(ogr_layer, stream, options):
             raise RuntimeError("Failed to open ArrowArrayStream from Layer")
 
-        stream_ptr = &stream
+        stream_ptr = stream
 
         if skip_features:
             # only supported for GDAL >= 3.8.0; have to do this after getting
@@ -1417,7 +1440,6 @@
         # stream has to be consumed before the Dataset is closed
         import pyarrow as pa
         reader = pa.RecordBatchStreamReader._import_from_c(stream_ptr)
-
         meta = {
             'crs': crs,
             'encoding': encoding,
@@ -1434,7 +1456,11 @@
             # Mark reader as closed to prevent reading batches
             reader.close()
 
-        CSLDestroy(options)
+        # `stream` will be freed through `capsule` destructor
+
+        if options != NULL:
+            CSLDestroy(options)
+
         if fields_c != NULL:
             CSLDestroy(fields_c)
             fields_c = NULL
@@ -1444,12 +1470,15 @@
             dataset_options = NULL
 
         if ogr_dataset != NULL:
-            if sql is not None:
+            if sql is not None and ogr_layer != NULL:
                 GDALDatasetReleaseResultSet(ogr_dataset, ogr_layer)
 
             GDALClose(ogr_dataset)
             ogr_dataset = NULL
 
+        print("done with ogr_open_arrow finally block")
+
+
 def ogr_read_bounds(
     str path,
     object layer=None,
diff --git a/pyogrio/_ogr.pxd b/pyogrio/_ogr.pxd
index 35fbd29a..3bca864b 100644
--- a/pyogrio/_ogr.pxd
+++ b/pyogrio/_ogr.pxd
@@ -196,12 +196,15 @@ cdef extern from "arrow_bridge.h":
     struct ArrowArrayStream:
         int (*get_schema)(ArrowArrayStream* stream, ArrowSchema* out)
+        void (*release)(ArrowArrayStream*) noexcept nogil
 
 
 cdef extern from "ogr_api.h":
     int OGRGetDriverCount()
     OGRSFDriverH OGRGetDriver(int)
 
+    bint OGRGetGEOSVersion(int *pnMajor, int *pnMinor, int *pnPatch)
+
     OGRDataSourceH OGR_Dr_Open(OGRSFDriverH driver, const char *path, int bupdate)
     const char* OGR_Dr_GetName(OGRSFDriverH driver)
diff --git a/pyogrio/_ogr.pyx b/pyogrio/_ogr.pyx
index 55d19080..c39a9293 100644
--- a/pyogrio/_ogr.pyx
+++ b/pyogrio/_ogr.pyx
@@ -42,22 +42,14 @@ def get_gdal_version_string():
     return get_string(version)
 
 
-IF CTE_GDAL_VERSION >= (3, 4, 0):
-
-    cdef extern from "ogr_api.h":
-        bint OGRGetGEOSVersion(int *pnMajor, int *pnMinor, int *pnPatch)
-
-
 def get_gdal_geos_version():
     cdef int major, minor, revision
 
-    IF CTE_GDAL_VERSION >= (3, 4, 0):
-        if not OGRGetGEOSVersion(&major, &minor, &revision):
-            return None
-        return (major, minor, revision)
-    ELSE:
+    if not OGRGetGEOSVersion(&major, &minor, &revision):
         return None
+    return (major, minor, revision)
+
 
 def set_gdal_config_options(dict options):
     for name, value in options.items():
diff --git a/pyogrio/raw.py b/pyogrio/raw.py
index 6499867a..a1a31da2 100644
--- a/pyogrio/raw.py
+++ b/pyogrio/raw.py
@@ -260,9 +260,12 @@ def read_arrow(
     gdal_version = get_gdal_version()
 
+    # also checking skip_features here because of special handling for GDAL < 3.8.0
+    # otherwise it is properly checked in ogr_open_arrow instead
     if skip_features < 0:
         raise ValueError("'skip_features' must be >= 0")
 
+    # max_features support is shimmed here so it must be validated here
     if max_features is not None and max_features < 0:
         raise ValueError("'max_features' must be >= 0")
@@ -402,6 +405,12 @@ def open_arrow(
     if not HAS_ARROW_API:
         raise RuntimeError("pyarrow and GDAL>= 3.6 required to read using arrow")
 
+    gdal_version = get_gdal_version()
+    if skip_features and gdal_version < (3, 8, 0):
+        raise ValueError(
+            "specifying 'skip_features' is not supported for open_arrow for GDAL<3.8.0"
+        )
+
     path, buffer = get_vsi_path(path_or_buffer)
     dataset_kwargs = _preprocess_options_key_value(kwargs) if kwargs else {}
diff --git a/pyogrio/tests/test_arrow.py b/pyogrio/tests/test_arrow.py
index 02e28dea..8d6841f2 100644
--- a/pyogrio/tests/test_arrow.py
+++ b/pyogrio/tests/test_arrow.py
@@ -171,7 +171,7 @@ def test_open_arrow_skip_features_unsupported(naturalearth_lowres, skip_features
     GDAL < 3.8.0"""
     with pytest.raises(
         ValueError,
-        match="specifying 'skip_features' is not supported for Arrow for GDAL<3.8.0",
+        match="specifying 'skip_features' is not supported for open_arrow for GDAL<3.8.0",
     ):
         with open_arrow(naturalearth_lowres, skip_features=skip_features) as (
             meta,
diff --git a/pyogrio/tests/test_geopandas_io.py b/pyogrio/tests/test_geopandas_io.py
index 317af6a8..cd97c85f 100644
--- a/pyogrio/tests/test_geopandas_io.py
+++ b/pyogrio/tests/test_geopandas_io.py
@@ -317,11 +317,13 @@ def test_read_fid_as_index_only(naturalearth_lowres, use_arrow):
     assert len(df.columns) == 0
 
 
-def test_read_where(naturalearth_lowres_all_ext, use_arrow):
+def test_read_where_empty(naturalearth_lowres_all_ext, use_arrow):
     # empty filter should return full set of records
     df = read_dataframe(naturalearth_lowres_all_ext, use_arrow=use_arrow, where="")
     assert len(df) == 177
 
+
+def test_read_where_equals(naturalearth_lowres_all_ext, use_arrow):
     # should return singular item
     df = read_dataframe(
         naturalearth_lowres_all_ext, use_arrow=use_arrow, where="iso_a3 = 'CAN'"
@@ -329,6 +331,8 @@
     assert len(df) == 1
     assert df.iloc[0].iso_a3 == "CAN"
 
+
+def test_read_where_in(naturalearth_lowres_all_ext, use_arrow):
     df = read_dataframe(
         naturalearth_lowres_all_ext,
         use_arrow=use_arrow,
@@ -337,6 +341,11 @@
     assert len(df) == 3
     assert len(set(df.iso_a3.unique()).difference(["CAN", "USA", "MEX"])) == 0
 
+
+def test_read_where_range(naturalearth_lowres_all_ext, use_arrow):
+    if naturalearth_lowres_all_ext.suffix not in {".geojsonl", ".gpkg"}:
+        pytest.skip("only test geojsonl or gpkg")
+
     # should return items within range
     df = read_dataframe(
         naturalearth_lowres_all_ext,