Commit
resolve review comments
MaxDall committed Jan 23, 2024
1 parent c5e6d60 commit e716fa3
Showing 3 changed files with 83 additions and 89 deletions.
36 changes: 20 additions & 16 deletions src/fundus/scraping/common_crawl/html.py
@@ -8,14 +8,14 @@
from fundus.logging import basic_logger
from fundus.publishers.base_objects import PublisherEnum
from fundus.scraping.filter import URLFilter
-from fundus.scraping.html import HTML, WarcSource
+from fundus.scraping.html import HTML, WarcSource, _default_header


class CCNewsSource:
def __init__(self, *publishers: PublisherEnum, warc_path: str, headers: Optional[Dict[str, str]] = None):
self.publishers = publishers
self.warc_path = warc_path
-self.headers = headers or {}
+self.headers = headers or _default_header

self._publisher_mapping: Dict[str, PublisherEnum] = {
urlparse(publisher.domain).netloc: publisher for publisher in publishers
@@ -24,37 +24,41 @@ def __init__(self, *publishers: PublisherEnum, warc_path: str, headers: Optional
def fetch(self, url_filter: Optional[URLFilter] = None) -> Iterator[HTML]:
domains = list(self._publisher_mapping)

-def extract_body(record: WarcRecord) -> str:
-raw_body = record.reader.read()
+def extract_content(record: WarcRecord) -> str:
+warc_body: bytes = record.reader.read()
try:
-return str(raw_body, encoding=record.http_charset)
+return str(warc_body, encoding=record.http_charset)
except (UnicodeDecodeError, TypeError):
-return guess_bytes(raw_body)[0]
+basic_logger.warning(
+f"Couldn't decode record {record.record_id!r} from {target_url!r} "
+f"using charset {record.http_charset!r}."
+)
+return guess_bytes(warc_body)[0]

with requests.Session() as session:
stream = session.get(self.warc_path, stream=True, headers=self.headers).raw

for warc_record in ArchiveIterator(stream, record_types=WarcRecordType.response, verify_digests=True):
-target_uri = str(warc_record.headers["WARC-Target-URI"])
+target_url = str(warc_record.headers["WARC-Target-URI"])

-if url_filter is not None and url_filter(target_uri):
-basic_logger.debug(f"Skipped WARC record with target URI {target_uri!r} because of URL filter")
+if url_filter is not None and url_filter(target_url):
+basic_logger.debug(f"Skipped WARC record with target URI {target_url!r} because of URL filter")
continue
-elif (netloc := urlparse(target_uri).netloc) in domains:
+elif (netloc := urlparse(target_url).netloc) in domains:
publisher = self._publisher_mapping[netloc]

-if publisher.url_filter is not None and publisher.url_filter(target_uri):
+if publisher.url_filter is not None and publisher.url_filter(target_url):
basic_logger.debug(
f"Skipped WARC record with target URI {target_uri!r} because of "
f"Skipped WARC record with target URI {target_url!r} because of "
f"publisher specific URL filter"
)
continue

-body = extract_body(warc_record)
+content = extract_content(warc_record)
html = HTML(
-requested_url=target_uri,
-responded_url=target_uri,
-content=body,
+requested_url=target_url,
+responded_url=target_url,
+content=content,
crawl_date=warc_record.record_date,
source=WarcSource(
publisher=publisher.publisher_name,
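The renamed extract_content helper above decodes the WARC payload with the record's declared HTTP charset and only falls back to charset guessing, logging a warning first, when that declaration is missing or wrong. A minimal, self-contained sketch of the same decode-with-fallback pattern; it assumes guess_bytes is ftfy's helper returning a (text, encoding) tuple, since the import is not visible in this hunk:

    from typing import Optional

    from ftfy import guess_bytes  # assumed source of guess_bytes; returns (text, encoding)


    def decode_payload(payload: bytes, declared_charset: Optional[str]) -> str:
        """Decode bytes with the declared charset, falling back to guessing."""
        try:
            # TypeError covers a missing charset (None), UnicodeDecodeError a wrong one.
            return str(payload, encoding=declared_charset)
        except (UnicodeDecodeError, TypeError):
            text, _encoding = guess_bytes(payload)
            return text


    print(decode_payload("häßlich".encode("latin-1"), "latin-1"))  # decodes with the declared charset
    print(decode_payload("häßlich".encode("latin-1"), None))       # falls back to guessing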
134 changes: 62 additions & 72 deletions src/fundus/scraping/common_crawl/pipeline.py
Expand Up @@ -4,15 +4,15 @@
import os
import re
from datetime import datetime
-from functools import lru_cache, partial
+from functools import lru_cache, partial, wraps
from multiprocessing import Manager
from multiprocessing.context import TimeoutError
from multiprocessing.pool import MapResult, Pool, ThreadPool
from queue import Empty, Queue
from typing import (
Any,
Callable,
-Iterable,
+Generic,
Iterator,
List,
Literal,
@@ -43,7 +43,7 @@


# noinspection PyPep8Naming
-class dill_wrapper(Callable[P, _T]): # type: ignore[misc]
+class dill_wrapper(Generic[P, _T]):
def __init__(self, target: Callable[P, _T]):
"""Wraps function in dill serialization.
@@ -62,27 +62,49 @@ def __call__(self, *args: P.args, **kwargs: P.kwargs) -> _T:
return self._deserialize()(*args, **kwargs)
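For context on the dill_wrapper above: multiprocessing.Pool pickles the task callable before shipping it to worker processes, and the standard pickle module refuses lambdas and locally defined closures, while dill can serialize them. A small illustration of that difference (not fundus code):

    import pickle

    import dill

    adder = lambda x: x + 1  # lambdas are a classic example of a callable pickle rejects

    try:
        pickle.dumps(adder)
    except (pickle.PicklingError, AttributeError) as error:
        print(f"pickle refuses the lambda: {error}")

    # dill serializes it fine; the restored object is callable again.
    payload: bytes = dill.dumps(adder)
    restored = dill.loads(payload)
    print(restored(41))  # 42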


+def queue_wrapper(queue: Queue[_T], target: Callable[P, Iterator[_T]]) -> Callable[P, None]:
+"""Wraps the target callable to add its results to the queue instead of returning them directly.
+Args:
+queue: (Queue[_T]) The buffer queue.
+target: (Callable[P, Iterator[_T]]) A target callable.
+Returns:
+(Callable[P, None]) The wrapped target.
+"""
+
+@wraps(target)
+def wrapper(*args: P.args, **kwargs: P.kwargs) -> None:
+for obj in target(*args, **kwargs):
+queue.put(obj)
+
+return wrapper
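A short usage sketch of the queue_wrapper added above (the producer function and values are made up, and everything runs single-threaded here just to show the data flow): the generator's items land on the queue as a side effect of calling the wrapper, and the wrapper itself returns nothing.

    from queue import Queue
    from typing import Iterator

    results: Queue = Queue()


    def produce(n: int) -> Iterator[int]:
        yield from range(n)


    fill_queue = queue_wrapper(results, produce)  # queue_wrapper as defined in this module
    fill_queue(3)

    while not results.empty():
        print(results.get())  # 0, 1, 2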


class CCNewsCrawler:
def __init__(
self,
*publishers: PublisherEnum,
-processes: Optional[int] = None,
+processes: int = -1,
server_address: str = "https://data.commoncrawl.org/",
):
"""Initializes a crawler for the CC-NEWS dataset.
Args:
*publishers (PublisherEnum): The publishers to crawl.
-processes: Number of process to use for crawling. If None, use os.cpu_count(); if -1 omit multiprocessing
-entirely. Defaults to None
+processes: Number of additional process to use for crawling.
+If -1, the number of processes is set to `os.cpu_count()`.
+If `os.cpu_count()` is not available, the number of processes is set to 0.
+If 0, only the main process is used. Defaults to -1.
server_address: The CC-NEWS dataset server address. Defaults to 'https://data.commoncrawl.org/'.
"""
self.publishers = publishers
-self.processes = processes or os.cpu_count() or -1
+self.processes = os.cpu_count() or 0 if processes == -1 else processes
self.server_address = server_address
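One nuance worth spelling out for the new processes default: the un-parenthesized expression above parses as (os.cpu_count() or 0) if processes == -1 else processes, because or binds tighter than the conditional expression. A small sketch of the resulting behavior (the core count is an assumption for the example):

    import os


    def resolve_processes(processes: int) -> int:
        """Mirror of the constructor logic: -1 means one worker per core, 0 means main process only."""
        return (os.cpu_count() or 0) if processes == -1 else processes


    # Assuming an 8-core machine:
    # resolve_processes(-1) -> 8   one worker process per core
    # resolve_processes(0)  -> 0   no worker processes; crawling runs in the main process
    # resolve_processes(4)  -> 4   exactly four worker processes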

-def _get_list_of_warc_path(self, start: datetime, end: datetime) -> List[str]:
-date_pattern: Pattern[str] = re.compile(r"CC-NEWS-(?P<date>\d{14})-\d{5}")
+def _get_warc_paths(self, start: datetime, end: datetime) -> List[str]:
+# https://regex101.com/r/yDX3G6/1
+date_pattern: Pattern[str] = re.compile(r"CC-NEWS-(?P<date>\d{14})-")

if start >= end:
raise ValueError("Start date has to be < end date.")
@@ -100,23 +122,27 @@ def _get_list_of_warc_path(self, start: datetime, end: datetime) -> List[str]:

def load_paths(url: str) -> List[str]:
with requests.Session() as session:
-paths = gzip.decompress(session.get(url).content).decode("utf-8").split()
-bar.update()
-return paths
+return gzip.decompress(session.get(url).content).decode("utf-8").split()

-with ThreadPool(processes=len(urls)) as pool, tqdm(total=len(urls), desc="Load WARC paths", leave=False) as bar:
-warc_paths = more_itertools.flatten(pool.map(load_paths, urls))
+# running two threads per core
+max_number_of_threads = 2 * (os.cpu_count() or 1)
+
+with ThreadPool(processes=min(len(urls), max_number_of_threads)) as pool:
+warc_paths = more_itertools.flatten(
+list(
+tqdm(pool.imap_unordered(load_paths, urls), total=len(urls), desc="Loading WARC paths", leave=False)
+)
+)
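The listing download above caps the thread pool at two threads per core and wraps imap_unordered in tqdm, so the progress bar advances as each gzipped path listing finishes, in whatever order the responses arrive. A reduced sketch of that pattern with a dummy loader in place of the HTTP call:

    import os
    from multiprocessing.pool import ThreadPool
    from typing import List

    import more_itertools
    from tqdm import tqdm


    def load(chunk: int) -> List[int]:
        # stand-in for downloading and splitting one gzipped WARC path listing
        return [chunk * 10 + i for i in range(3)]


    chunks = list(range(5))
    max_threads = 2 * (os.cpu_count() or 1)

    with ThreadPool(processes=min(len(chunks), max_threads)) as pool:
        nested = list(tqdm(pool.imap_unordered(load, chunks), total=len(chunks), leave=False))

    print(sorted(more_itertools.flatten(nested)))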

start_strf = start.strftime("%Y%m%d%H%M%S")
end_strf = end.strftime("%Y%m%d%H%M%S")

def filter_warc_path_by_date(path: str) -> bool:
-match: Optional[Match[str]] = date_pattern.search(path)
+match: Optional[re.Match[str]] = date_pattern.search(path)
if match is None:
raise AssertionError(f"Invalid WARC path {path!r}")
return start_strf <= match["date"] <= end_strf
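The filter above works because both the strftime("%Y%m%d%H%M%S") bounds and the 14-digit timestamp captured from the WARC path are fixed-width digit strings, so plain string comparison orders them exactly like numbers. A quick worked example with a made-up path shaped like the CC-NEWS listing entries:

    import re
    from datetime import datetime

    date_pattern = re.compile(r"CC-NEWS-(?P<date>\d{14})-")

    path = "crawl-data/CC-NEWS/2024/01/CC-NEWS-20240115123000-00123.warc.gz"  # hypothetical entry

    start_strf = datetime(2024, 1, 1).strftime("%Y%m%d%H%M%S")  # '20240101000000'
    end_strf = datetime(2024, 1, 31).strftime("%Y%m%d%H%M%S")   # '20240131000000'

    match = date_pattern.search(path)
    assert match is not None
    print(start_strf <= match["date"] <= end_strf)  # True: lexicographic order == chronological order here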


return sorted(
[self.server_address + warc_path for warc_path in filter(filter_warc_path_by_date, warc_paths)],
reverse=True,
@@ -135,35 +161,18 @@ def _fetch_articles(
yield from scraper.scrape(error_handling, extraction_filter, url_filter)

@staticmethod
-def _queue_wrapper(queue: Queue[_T], target: Callable[P, Iterator[_T]]) -> Callable[P, None]:
-"""Wraps the target callable to add its results to the queue instead of returning them directly.
-Args:
-queue: (Queue[_T]) The buffer queue.
-target: (Callable[P, Iterator[_T]]) A target callable.
-Returns:
-(Callable[P, None]) The wrapped target.
-"""
-
-@functools.wraps(target)
-def wrapper(*args: P.args, **kwargs: P.kwargs) -> None:
-for obj in target(*args, **kwargs):
-queue.put(obj)
-
-return wrapper

@staticmethod
-def _single_crawl(warc_paths: List[str], target: Callable[[str], Iterator[Article]]) -> Iterator[Article]:
+def _single_crawl(warc_paths: List[str], article_task: Callable[[str], Iterator[Article]]) -> Iterator[Article]:
for warc_path in warc_paths:
-yield from target(warc_path)
+yield from article_task(warc_path)

-def _parallel_crawl(self, warc_paths: List[str], target: Callable[[str], Iterator[Article]]) -> Iterator[Article]:
+def _parallel_crawl(
+self, warc_paths: List[str], article_task: Callable[[str], Iterator[Article]]
+) -> Iterator[Article]:
with Manager() as manager, Pool(processes=min(self.processes, len(warc_paths))) as pool:
article_queue: Queue[Article] = manager.Queue()
-wrapped_target: Callable[[str], None] = self._queue_wrapper(article_queue, target)
-serialized_target = dill_wrapper(wrapped_target)
-yield from _PoolResult(pool.map_async(serialized_target, warc_paths), article_queue)
+wrapped_article_task: Callable[[str], None] = queue_wrapper(article_queue, article_task)
+serialized_article_task = dill_wrapper(wrapped_article_task)
+yield from pool_queue_iter(pool.map_async(serialized_article_task, warc_paths), article_queue)

def crawl(
self,
@@ -211,7 +220,7 @@ def crawl(

if max_articles == 0:
return

if max_articles is None:
max_articles = -1

@@ -227,7 +236,7 @@ def build_extraction_filter() -> Optional[ExtractionFilter]:
else:
return only_complete

-warc_paths = self._get_list_of_warc_path(start, end)
+warc_paths = self._get_warc_paths(start, end)
response_cache: Set[str] = set()

article_task: Callable[[str], Iterator[Article]] = partial(
Expand All @@ -238,7 +247,7 @@ def build_extraction_filter() -> Optional[ExtractionFilter]:
url_filter=url_filter,
)

-if self.processes == -1:
+if self.processes == 0:
article_iter = self._single_crawl(warc_paths, article_task)
else:
article_iter = self._parallel_crawl(warc_paths, article_task)
Expand All @@ -251,32 +260,13 @@ def build_extraction_filter() -> Optional[ExtractionFilter]:
break


-class _PoolResult(Iterable[_T]):
-def __init__(self, result: MapResult[Any], queue: Queue[_T]):
-"""Utility class to iterate a pool queue.
-Exhaust the queue given with <queue>. If <queue> raises Empty and the pool finished,
-raise StopIteration, otherwise continue to wait for the next result from <queue>.
-Args:
-result (MapResult[Any]): A MapResult returned by a Pool.map/.map_async call to use as a handle.
-queue (Queue[_T): The queue to exhaust.
-"""
-self._handle = result
-self._queue = queue
-
-def __next__(self) -> _T:
-while True:
+def pool_queue_iter(handle: MapResult[Any], queue: Queue[_T]) -> Iterator[_T]:
+while True:
+try:
+yield queue.get(timeout=0.1)
+except Empty:
try:
-return self._queue.get(timeout=0.1)
-except Empty:
-try:
-self._handle.get(timeout=0.1)
-except TimeoutError:
-continue
-else:
-break
-raise StopIteration
-
-def __iter__(self) -> Iterator[_T]:
-return self
+handle.get(timeout=0.1)
+except TimeoutError:
+continue
+return
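Taken together, queue_wrapper, dill_wrapper and pool_queue_iter let _parallel_crawl stream articles out of the worker processes while map_async is still running: workers push onto a managed queue, and the parent drains it, using the MapResult handle only to notice when all paths have been processed. A stripped-down sketch of the same pattern with a toy worker; the import path is assumed from the file location shown above:

    from multiprocessing import Manager, Pool
    from typing import Iterator

    from fundus.scraping.common_crawl.pipeline import dill_wrapper, pool_queue_iter, queue_wrapper


    def squares(n: int) -> Iterator[int]:
        """Stand-in for the per-WARC-path article generator."""
        for i in range(n):
            yield i * i


    if __name__ == "__main__":
        with Manager() as manager, Pool(processes=2) as pool:
            queue = manager.Queue()
            # results go onto the queue as a side effect; the pool's return values are ignored
            task = dill_wrapper(queue_wrapper(queue, squares))
            handle = pool.map_async(task, [3, 4, 5])
            for value in pool_queue_iter(handle, queue):
                print(value)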
2 changes: 1 addition & 1 deletion src/fundus/scraping/common_crawl/scraper.py
@@ -31,7 +31,7 @@ def scrape(
if error_handling == "raise":
error_message = f"Run into an error processing article '{html.requested_url}'"
basic_logger.error(error_message)
-err.args = (f"{err}\n\n{error_message},)
+err.args = (f"{err}\n\n{error_message}",)
raise err
elif error_handling == "catch":
yield Article(html=html, exception=err)
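The one-character fix above closes the f-string; the surrounding pattern folds extra context into the exception's args before re-raising the same object, so the original traceback survives but the message now names the failing article URL. A tiny, self-contained illustration with made-up values:

    def parse(url: str) -> None:
        # stand-in for attribute extraction failing for a specific article
        raise ValueError("missing headline")


    try:
        try:
            parse("https://example.com/article")
        except ValueError as err:
            error_message = "Run into an error processing article 'https://example.com/article'"
            err.args = (f"{err}\n\n{error_message}",)
            raise err
    except ValueError as final:
        print(final)  # original message plus the appended context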
