Skip to content

Commit 8c5490d

Browse files
committed
to_parquet equivalence arrow and duckdb equivalent
1 parent df57055 commit 8c5490d

File tree

2 files changed

+18
-4
lines changed

2 files changed

+18
-4
lines changed

affinity.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -224,14 +224,14 @@ def sql(self, query, **replacements):
224224
By default, df=self.df (pandas dataframe) is used.
225225
The registered views persist across queries. RAM impact TBD.
226226
"""
227-
if not replacements.get("df"):
227+
if replacements.get("df") is None:
228228
duckdb.register("df", self.df)
229229
for k, v in replacements.items():
230230
duckdb.register(k, v)
231231
return duckdb.sql(query)
232232

233233

234-
def to_parquet(self, path, engine="duckdb"):
234+
def to_parquet(self, path, engine="duckdb", **kwargs):
235235
if engine == "arrow":
236236
pq.write_table(self.arrow, path)
237237
if engine == "duckdb":
@@ -246,7 +246,7 @@ def to_parquet(self, path, engine="duckdb"):
246246
COPY (SELECT * FROM df) TO {path} (
247247
FORMAT PARQUET,
248248
KV_METADATA {{ {", ".join(kv_metadata)} }}
249-
);""")
249+
);""", **kwargs)
250250
return path
251251

252252
@property

test_affinity.py

+15-1
Original file line numberDiff line numberDiff line change
@@ -243,11 +243,12 @@ class aDataset(af.Dataset):
243243
data = aDataset(v1=[True], v2=[1/2], v3=[3])
244244
test_file_arrow = Path("test_arrow.parquet")
245245
test_file_duckdb = Path("test_duckdb.parquet")
246+
test_file_duckdb_polars = Path("test_duckdb_polars.parquet")
246247
data.to_parquet(test_file_arrow, engine="arrow")
247248
data.to_parquet(test_file_duckdb, engine="duckdb")
249+
data.to_parquet(test_file_duckdb_polars, engine="duckdb", df=data.pl)
248250
class KeyValueMetadata(af.Dataset):
249251
"""Stores results of reading Parquet metadata."""
250-
file_name = af.VectorObject()
251252
key = af.VectorObject()
252253
value = af.VectorObject()
253254
test_file_metadata_arrow = KeyValueMetadata.from_sql(
@@ -259,11 +260,24 @@ class KeyValueMetadata(af.Dataset):
259260
FROM parquet_kv_metadata('{test_file_arrow}')
260261
WHERE DECODE(key) != 'ARROW:schema'
261262
""",
263+
method="pandas",
264+
field_names="strict"
265+
)
266+
test_file_metadata_duckdb = KeyValueMetadata.from_sql(
267+
f"""
268+
SELECT
269+
DECODE(key) AS key,
270+
DECODE(value) AS value,
271+
FROM parquet_kv_metadata('{test_file_duckdb_polars}')
272+
WHERE DECODE(key) != 'ARROW:schema'
273+
""",
262274
method="polars",
263275
field_names="strict"
264276
)
277+
assert test_file_metadata_arrow == test_file_metadata_duckdb
265278
test_file_arrow.unlink()
266279
test_file_duckdb.unlink()
280+
test_file_duckdb_polars.unlink()
267281
assert all(
268282
value in test_file_metadata_arrow.value.values
269283
for value in [

0 commit comments

Comments
 (0)