From 00231a6b8fc2016006c0639af8689236c0137e96 Mon Sep 17 00:00:00 2001 From: William Ayd Date: Fri, 15 Nov 2024 11:13:53 -0500 Subject: [PATCH] Fix issue trying to enable logs (#393) --- src/pantab/reader.cpp | 5 ++++- src/pantab/writer.cpp | 5 ++++- tests/conftest.py | 2 +- tests/test_reader.py | 12 ++++++++++++ tests/test_writer.py | 11 +++++++++++ 5 files changed, 32 insertions(+), 3 deletions(-) diff --git a/src/pantab/reader.cpp b/src/pantab/reader.cpp index 2b88eafc..bc3268e0 100644 --- a/src/pantab/reader.cpp +++ b/src/pantab/reader.cpp @@ -583,8 +583,11 @@ auto read_from_hyper_query( std::unordered_map &&process_params, size_t chunk_size) -> nb::capsule { - if (!process_params.count("log_config")) + if (!process_params.count("log_config")) { process_params["log_config"] = ""; + } else { + process_params.erase("log_config"); + } if (!process_params.count("default_database_version")) process_params["default_database_version"] = "2"; diff --git a/src/pantab/writer.cpp b/src/pantab/writer.cpp index b5b0e027..0d03d456 100644 --- a/src/pantab/writer.cpp +++ b/src/pantab/writer.cpp @@ -733,8 +733,11 @@ void write_to_hyper( geo_set.insert(colstr); } - if (!process_params.count("log_config")) + if (!process_params.count("log_config")) { process_params["log_config"] = ""; + } else { + process_params.erase("log_config"); + } if (!process_params.count("default_database_version")) process_params["default_database_version"] = "2"; diff --git a/tests/conftest.py b/tests/conftest.py index b0d25731..3fbf2d22 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -457,7 +457,7 @@ def roundtripped(request): @pytest.fixture -def tmp_hyper(tmp_path): +def tmp_hyper(tmp_path) -> pathlib.Path: """A temporary file name to write / read a Hyper extract from.""" return tmp_path / "test.hyper" diff --git a/tests/test_reader.py b/tests/test_reader.py index 4dbdf1eb..0eadbe7d 100644 --- a/tests/test_reader.py +++ b/tests/test_reader.py @@ -254,3 +254,15 @@ def test_read_batches_without_capsule(tmp_hyper, compat, return_type): pt.frame_from_hyper( tmp_hyper, table="test", return_type=return_type, chunk_size=2 ) + + +def test_reader_can_enable_logging(tmp_hyper): + df = pd.DataFrame(list(range(10)), columns=["nums"]).astype("int8") + pt.frame_to_hyper(df, tmp_hyper, table="test") + + log_dir = tmp_hyper.parent + params = {"log_config": "enable_me", "log_dir": str(log_dir)} + pt.frame_from_hyper(tmp_hyper, table="test", process_params=params) + + assert (log_dir / "hyperd.log").exists() + (log_dir / "hyperd.log").unlink() diff --git a/tests/test_writer.py b/tests/test_writer.py index 52a122a1..bd7731bd 100644 --- a/tests/test_writer.py +++ b/tests/test_writer.py @@ -515,3 +515,14 @@ def test_writer_invalid_process_params_raises(tmp_hyper): msg = r"No internal setting named 'not_a_real_parameter'" with pytest.raises(RuntimeError, match=msg): pt.frame_to_hyper(frame, tmp_hyper, table="test", process_params=params) + + +def test_writer_can_enable_logging(tmp_hyper): + tbl = pa.table({"int": pa.array(range(4), type=pa.int16())}) + + log_dir = tmp_hyper.parent + params = {"log_config": "enable_me", "log_dir": str(log_dir)} + pt.frame_to_hyper(tbl, tmp_hyper, table="test", process_params=params) + + assert (log_dir / "hyperd.log").exists() + (log_dir / "hyperd.log").unlink()