Use generators in record splitting instead of lists. #76

Merged · 6 commits · Jan 22, 2025
Showing changes from 5 commits
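At a high level, this PR swaps eagerly built lists of child records for generators that yield records as they are produced. A minimal, self-contained sketch of the two patterns (illustrative names only, not the project's actual API):

from typing import Iterator, List


def split_eager(lines: List[str]) -> List[dict]:
    """Old pattern: materialise every child record before returning."""
    return [{"line_number": n, "payload": line} for n, line in enumerate(lines)]


def split_lazy(lines: List[str]) -> Iterator[dict]:
    """New pattern: yield child records one at a time so downstream
    consumers can process them without holding the whole list in memory."""
    for n, line in enumerate(lines):
        yield {"line_number": n, "payload": line}


assert split_eager(["a", "b"]) == list(split_lazy(["a", "b"]))

The consumer-visible difference is laziness: nothing is parsed or attached until the iterator is actually consumed.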
8 changes: 2 additions & 6 deletions dabapush/Reader/JSONReader.py
@@ -11,7 +11,6 @@
from .Reader import FileReader



class JSONReader(FileReader):
"""Reader to read ready to read directories containing multiple json files.
It matches files in the path-tree against the pattern and reads the
@@ -36,14 +35,11 @@ def read(self) -> Iterator[Record]:
record = Record(
uuid=f"{str(file_record.uuid)}",
payload=(
parsed
if not self.config.flatten_dicts
else flatten(parsed)
parsed if not self.config.flatten_dicts else flatten(parsed)
),
source=file_record,
)
if record not in self.back_log:
yield record
yield record


class JSONReaderConfiguration(ReaderConfiguration):
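Net effect for JSONReader: the back_log membership check is gone from the reader, so every parsed file is yielded and deduplication is left to the writer (see dabapush/Writer/Writer.py below). A rough stand-alone sketch of the resulting behaviour, using plain dicts instead of Record objects:

import json
from pathlib import Path
from typing import Iterator


def read_json_files(directory: Path, pattern: str = "*.json") -> Iterator[dict]:
    """Parse and yield every matching file; no back_log filtering happens here."""
    for path in directory.rglob(pattern):
        with path.open("rt", encoding="utf8") as file:
            yield json.load(file)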
22 changes: 9 additions & 13 deletions dabapush/Reader/NDJSONReader.py
@@ -1,7 +1,7 @@
"""NDJSON Writer plug-in for dabapush"""

# pylint: disable=R,I1101
from typing import Iterator, List
from typing import Iterator

import ujson

@@ -14,10 +14,10 @@
def read_and_split(
record: Record,
flatten_records: bool = False,
) -> List[Record]:
) -> Iterator[Record]:
"""Reads a file and splits it into records by line."""
with record.payload.open("rt", encoding="utf8") as file:
children = [
children = (
Record(
uuid=f"{str(record.uuid)}:{str(line_number)}",
payload=(
@@ -28,10 +28,10 @@ def read_and_split(
source=record,
)
for line_number, line in enumerate(file)
]
record.children.extend(children)

return children
)
for child in children:
record.children.append(child)
yield child


class NDJSONReader(FileReader):
@@ -53,13 +53,9 @@ def read(self) -> Iterator[Record]:
"""reads multiple NDJSON files and emits them line by line"""

for file_record in self.records:
filtered_records = filter(
lambda x: x not in self.back_log,
file_record.split(
func=read_and_split, flatten_records=self.config.flatten_dicts
),
yield from file_record.split(
func=read_and_split, flatten_records=self.config.flatten_dicts
)
yield from filtered_records


class NDJSONReaderConfiguration(ReaderConfiguration):
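The same reshaping for NDJSON: read_and_split now returns an iterator, builds the children from a generator expression, and registers each child with its parent as it is yielded. A simplified sketch with a stand-in Node class (illustrative only, not the real Record API):

from typing import Iterable, Iterator, List


class Node:
    """Minimal stand-in for dabapush's Record, for illustration only."""

    def __init__(self, uuid: str, payload=None, source=None):
        self.uuid = uuid
        self.payload = payload
        self.source = source
        self.children: List["Node"] = []


def read_and_split_lazy(parent: Node, lines: Iterable[str]) -> Iterator[Node]:
    """Mirrors the new read_and_split: children come from a generator
    expression and are appended to the parent one by one as they are yielded."""
    children = (
        Node(uuid=f"{parent.uuid}:{line_number}", payload=line, source=parent)
        for line_number, line in enumerate(lines)
    )
    for child in children:
        parent.children.append(child)
        yield child

One practical note: in the lazy version the generator consumes lines from the open file handle, so it has to be exhausted while the file is still open.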
26 changes: 0 additions & 26 deletions dabapush/Reader/Reader.py
@@ -4,8 +4,6 @@
from pathlib import Path
from typing import Iterator

import ujson
from loguru import logger as log
from tqdm.auto import tqdm

from ..Configuration.ReaderConfiguration import ReaderConfiguration
@@ -33,12 +31,6 @@ def __init__(self, config: ReaderConfiguration):
be a subclass of ReaderConfiguration.
"""
self.config = config
self.back_log = []
# initialize file log
if not Path(".dabapush/").exists():
Path(".dabapush/").mkdir()

self.log_path = Path(f".dabapush/{config.name}.jsonl")

@abc.abstractmethod
def read(self) -> Iterator[Record]:
@@ -77,28 +69,10 @@ def read(self) -> Iterator[Record]:
@property
def records(self) -> Iterator[Record]:
"""Generator for all files matching the pattern in the read_path."""
if self.log_path.exists():
log.debug(
f"Found log file for {self.config.name} at {self.log_path}. Loading..."
)
with self.log_path.open("rt", encoding="utf8") as f:
self.back_log = [Record(**ujson.loads(_)) for _ in f.readlines()]
else:
self.log_path.touch()

yield from (
Record(
uuid=str(a),
payload=a,
event_handlers={"on_done": [self.log]},
)
for a in tqdm(list(Path(self.config.read_path).rglob(self.config.pattern)))
)

def log(self, record: Record):
"""Log the record to the persistent record log file."""
with self.log_path.open("a", encoding="utf8") as f:
for sub_record in record.walk_tree(only_leafs=True):
ujson.dump(sub_record.to_log(), f)
f.write("\n")
log.debug(f"Done with {record.uuid}")
42 changes: 20 additions & 22 deletions dabapush/Record.py
@@ -5,7 +5,7 @@
# pylint: disable=R0917, R0913

from datetime import datetime
from typing import Any, Callable, Dict, List, Literal, Optional, Self, Union
from typing import Any, Callable, Dict, Iterable, List, Literal, Optional, Self, Union
from uuid import uuid4

from loguru import logger as log
@@ -92,9 +92,9 @@ def split(
self,
key: Optional[str] = None,
id_key: Optional[str] = None,
func: Optional[Callable[[Self, ...], List[Self]]] = None,
func: Optional[Callable[[Self, ...], Iterable[Self]]] = None,
**kwargs,
) -> List[Self]:
) -> Iterable[Self]:
"""Splits the record bases on either a keyword or a function. If a function is provided,
it will be used to split the payload, even if you provide a key. If a key is provided, it
will split the payload.
@@ -134,22 +134,20 @@ def _handle_key_split_(self, id_key, key):
def _handle_key_split_(self, id_key, key):
payload = self.payload # Get the payload, the original payload
# will be set to None to free memory.
if key not in payload:
return []
if not isinstance(payload[key], list):
return []
split_payload = [
Record(
**{
"payload": value,
"uuid": value.get(id_key) if id_key else uuid4().hex,
"source": self,
}
if key in payload and isinstance(payload[key], list):
split_payload = (
Record(
**{
"payload": value,
"uuid": value.get(id_key) if id_key else uuid4().hex,
"source": self,
}
)
for value in payload[key]
)
for value in payload[key]
]
self.children.extend(split_payload)
return split_payload
for child in split_payload:
self.children.append(child)
yield child

def to_log(self) -> Dict[str, Union[str, List[Dict[str, Any]]]]:
"""Return a loggable representation of the record."""
@@ -187,10 +185,10 @@ def done(self):
# Signal parent that this record is done
self._state_ = "done"
log.debug(f"Record {self.uuid} is set as done.")
if self.source:
self.source.signal_done()
log.debug(f"Signaled parent {self.source.uuid} of record {self.uuid}.")
self.__dispatch_event__("on_done")
# if self.source:
# self.source.signal_done()
# log.debug(f"Signaled parent {self.source.uuid} of record {self.uuid}.")
# self.__dispatch_event__("on_done")
Review comment from a Member on the commented-out block above: delete?

def signal_done(self):
"""Signal that a child record is done."""
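Record._handle_key_split_ is now a generator as well: children are created from the payload list and attached to self.children only as the iterator is consumed. A self-contained sketch with a reduced MiniRecord class (names are illustrative, not the real API):

from typing import Any, Dict, Iterator, List, Optional
from uuid import uuid4


class MiniRecord:
    """Reduced stand-in for dabapush's Record; illustration only."""

    def __init__(self, payload: Any, uuid: Optional[str] = None, source=None):
        self.payload = payload
        self.uuid = uuid or uuid4().hex
        self.source = source
        self.children: List["MiniRecord"] = []

    def split_on_key(self, key: str, id_key: Optional[str] = None) -> Iterator["MiniRecord"]:
        """Mirrors the new _handle_key_split_: attach and yield each child as
        it is created, instead of returning a fully built list."""
        payload: Dict[str, Any] = self.payload
        if key in payload and isinstance(payload[key], list):
            for value in payload[key]:
                child = MiniRecord(
                    payload=value,
                    uuid=value.get(id_key) if id_key else uuid4().hex,
                    source=self,
                )
                self.children.append(child)
                yield child


parent = MiniRecord({"items": [{"id": "a"}, {"id": "b"}]})
assert parent.children == []  # nothing is split until the generator is consumed
children = list(parent.split_on_key("items", id_key="id"))
assert [c.uuid for c in children] == ["a", "b"]
assert len(parent.children) == 2

The assertions highlight the behavioural shift: with the generator version, parent.children stays empty until something actually iterates the result of split.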
36 changes: 35 additions & 1 deletion dabapush/Writer/Writer.py
@@ -6,8 +6,10 @@
"""

import abc
from pathlib import Path
from typing import Iterator, List

import ujson
from loguru import logger as log

from ..Configuration.WriterConfiguration import WriterConfiguration
@@ -27,10 +29,31 @@ def __init__(self, config: WriterConfiguration):

self.config = config
self.buffer: List[Record] = []
self.back_log: List[Record] = []
# initialize file log
if not Path(".dabapush/").exists():
Path(".dabapush/").mkdir()

self.log_path = Path(f".dabapush/{config.name}.jsonl")
if self.log_path.exists():
log.debug(
f"Found log file for {self.config.name} at {self.log_path}. Loading..."
)
with self.log_path.open("rt", encoding="utf8") as f:
self.back_log = [
Record(**ujson.loads(_)) # pylint: disable=I1101
for _ in f.readlines()
]
else:
self.log_path.touch()
self.log_file = self.log_path.open( # pylint: disable=R1732
"a", encoding="utf8"
)

def __del__(self):
"""Ensures the buffer is flushed before the object is destroyed."""
self._trigger_persist()
self.log_file.close()

def write(self, queue: Iterator[Record]) -> None:
"""Consumes items from the provided queue.
@@ -39,16 +62,20 @@ def write(self, queue: Iterator[Record]) -> None:
queue (Iterator[Record]): Items to be consumed.
"""
for item in queue:
if item in self.back_log:
continue
self.buffer.append(item)
if len(self.buffer) >= self.config.chunk_size:
self._trigger_persist()

def _trigger_persist(self):
self.persist()
log.debug(f"Persisted {self.config.chunk_size} records. Setting to done.")
log.debug(f"Persisted {len(self.buffer)} records. Setting to done.")
for record in self.buffer:
log.debug(f"Setting record {record.uuid} as done.")
record.done()
self.log(record)
self.log_file.flush()
self.buffer = []

@abc.abstractmethod
@@ -72,3 +99,10 @@ def id(self):
str: The ID of the writer.
"""
return self.config.id

def log(self, record: Record):
"""Log the record to the persistent record log file."""
ujson.dump(record.to_log(), self.log_file) # pylint: disable=I1101
self.log_file.write("\n")

log.debug(f"Done with {record.uuid}")
31 changes: 3 additions & 28 deletions tests/Reader/test_JSONReader.py
@@ -7,16 +7,18 @@

from dabapush.Reader.JSONReader import JSONReader, JSONReaderConfiguration


@pytest.fixture
def input_json_directory(isolated_test_dir):
"Pytest fixture creating a directory with 20 json files."
for idx in range(10,30):
for idx in range(10, 30):
file_path = isolated_test_dir / f"test_{idx}.json"
with file_path.open("wt") as out_file:
json.dump({"test_key": idx}, out_file)
out_file.write("\n")
return isolated_test_dir


def test_read(input_json_directory: Path): # pylint: disable=W0621
"""Should read the data from the file."""
reader = JSONReader(
@@ -28,30 +30,3 @@ def test_read(input_json_directory: Path): # pylint: disable=W0621
print(record)
assert record.processed_at
assert record.payload == {"test_key": int(record.uuid[-7:-5])}


def test_read_with_backlog(input_json_directory: Path): # pylint: disable=W0621
"""Should only read the new data."""
reader = JSONReaderConfiguration(
"test", read_path=str(input_json_directory.resolve()), pattern="*.json"
).get_instance()

def wrapper():
n = None
for n, record in enumerate(reader.read()):
record.done()
return n or 0

n = wrapper()

assert n + 1 == 20

reader2 = JSONReaderConfiguration(
"test", read_path=str(input_json_directory.resolve())
).get_instance()

records2 = list(reader2.read())
log_path = input_json_directory / ".dabapush/test.jsonl"
assert log_path.exists()
assert len(reader2.back_log) == 20
assert len(records2) == 0
32 changes: 0 additions & 32 deletions tests/Reader/test_NDJSONReader.py
@@ -30,35 +30,3 @@ def test_read(isolated_test_dir: Path, data): # pylint: disable=W0621
for n, record in enumerate(records):
assert record.processed_at
assert record.payload == data[n]


def test_read_with_backlog(isolated_test_dir: Path, data): # pylint: disable=W0621
"""Should only read the new data."""
reader = NDJSONReaderConfiguration(
"test", read_path=str(isolated_test_dir.resolve()), pattern="*.ndjson"
).get_instance()
file_path = isolated_test_dir / "test.ndjson"
with file_path.open("wt") as file:
for line in data:
json.dump(line, file)
file.write("\n")

def wrapper():
n = None
for n, record in enumerate(reader.read()):
record.done()
return n or 0

n = wrapper()

assert n + 1 == 20

reader2 = NDJSONReaderConfiguration(
"test", read_path=str(isolated_test_dir.resolve())
).get_instance()

records2 = list(reader2.read())
log_path = isolated_test_dir / ".dabapush/test.jsonl"
assert log_path.exists()
assert len(reader2.back_log) == 20
assert len(records2) == 0