Commit: Finish review comments

MaxDall committed Jan 30, 2024
1 parent 2ab29c2 commit b6b8eae
Showing 3 changed files with 102 additions and 74 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -37,7 +37,7 @@ dependencies = [
"requests~=2.28.2",
"tqdm~=4.66.1",
"fastwarc~=0.14.5",
"ftfy~=6.1.3",
"chardet~=5.2.0",
"dill~=0.3.7"
]

81 changes: 50 additions & 31 deletions src/fundus/scraping/common_crawl/html.py
@@ -1,9 +1,9 @@
from typing import Dict, Iterator, Optional
from urllib.parse import urlparse

import chardet
import requests
from fastwarc import ArchiveIterator, WarcRecord, WarcRecordType
from ftfy import guess_bytes

from fundus.logging import basic_logger
from fundus.publishers.base_objects import PublisherEnum
@@ -22,18 +22,31 @@ def __init__(self, *publishers: PublisherEnum, warc_path: str, headers: Optional
}

def fetch(self, url_filter: Optional[URLFilter] = None) -> Iterator[HTML]:
domains = list(self._publisher_mapping)

def extract_content(record: WarcRecord) -> str:
def extract_content(record: WarcRecord) -> Optional[str]:
warc_body: bytes = record.reader.read()

try:
return str(warc_body, encoding=record.http_charset)
except (UnicodeDecodeError, TypeError):
basic_logger.warning(
f"Couldn't decode record {record.record_id!r} from {target_url!r} "
f"using charset {record.http_charset!r}."
)
return guess_bytes(warc_body)[0]
encoding: Optional[str] = chardet.detect(warc_body)["encoding"]

if encoding is not None:
basic_logger.debug(
f"Try decoding record {record.record_id!r} from {target_url!r} using "
f"detected encoding {encoding}."
)

try:
return str(warc_body, encoding=encoding)
except UnicodeDecodeError:
basic_logger.warning(
f"Couldn't decode record {record.record_id!r} from {target_url!r} with "
f"original charset {record.http_charset} using detected charset {encoding}."
)
else:
basic_logger.warning(f"Couldn't detect charset for record {record.record_id!r} from {target_url!r}")

return None

with requests.Session() as session:
stream = session.get(self.warc_path, stream=True, headers=self.headers).raw
@@ -44,27 +57,33 @@ def extract_content(record: WarcRecord) -> str:
if url_filter is not None and url_filter(target_url):
basic_logger.debug(f"Skipped WARC record with target URI {target_url!r} because of URL filter")
continue
elif (netloc := urlparse(target_url).netloc) in domains:
publisher = self._publisher_mapping[netloc]

if publisher.url_filter is not None and publisher.url_filter(target_url):
basic_logger.debug(
f"Skipped WARC record with target URI {target_url!r} because of "
f"publisher specific URL filter"
)
continue

content = extract_content(warc_record)
html = HTML(
requested_url=target_url,
responded_url=target_url,
content=content,
crawl_date=warc_record.record_date,
source=WarcSource(
publisher=publisher.publisher_name,
warc_path=self.warc_path,
warc_headers=dict(warc_record.headers),
http_headers=dict(warc_record.http_headers),
),
publisher_domain: str = urlparse(target_url).netloc

if publisher_domain not in self._publisher_mapping:
continue

publisher = self._publisher_mapping[publisher_domain]

if publisher.url_filter is not None and publisher.url_filter(target_url):
basic_logger.debug(
f"Skipped WARC record with target URI {target_url!r} because of "
f"publisher specific URL filter"
)
yield html
continue

if (content := extract_content(warc_record)) is None:
continue

yield HTML(
requested_url=target_url,
responded_url=target_url,
content=content,
crawl_date=warc_record.record_date,
source=WarcSource(
publisher=publisher.publisher_name,
warc_path=self.warc_path,
warc_headers=dict(warc_record.headers),
http_headers=dict(warc_record.http_headers),
),
)
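
For context on the change above: the rewritten extract_content first tries the record's declared HTTP charset and only then asks chardet to guess. Below is a minimal standalone sketch of that fallback pattern; the function name and the sample input are illustrative assumptions, not part of this commit.

from typing import Optional

import chardet


def decode_with_fallback(raw: bytes, declared_charset: Optional[str]) -> Optional[str]:
    """Decode raw bytes, preferring the declared charset and falling back to chardet."""
    if declared_charset is not None:
        try:
            return raw.decode(declared_charset)
        except (UnicodeDecodeError, LookupError):
            pass  # declared charset is wrong or unknown; fall through to detection

    # chardet.detect() returns a dict with "encoding" and "confidence" keys;
    # "encoding" is None when detection fails.
    detected: Optional[str] = chardet.detect(raw)["encoding"]
    if detected is None:
        return None
    try:
        return raw.decode(detected)
    except UnicodeDecodeError:
        return None


# Example: bytes declared with the wrong charset still decode via detection
# (the exact guess depends on chardet for such a short input).
print(decode_with_fallback("Überschrift über dem Artikel".encode("latin-1"), "utf-8"))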
93 changes: 51 additions & 42 deletions src/fundus/scraping/common_crawl/pipeline.py
@@ -1,17 +1,19 @@
from __future__ import annotations

import asyncio
import gzip
import os
import re
from datetime import datetime
from functools import lru_cache, partial, wraps
from multiprocessing import Manager
from multiprocessing.context import TimeoutError
from multiprocessing.pool import MapResult, Pool, ThreadPool
from multiprocessing.pool import MapResult, Pool
from queue import Empty, Queue
from typing import (
Any,
Callable,
Coroutine,
Generic,
Iterator,
List,
@@ -25,11 +27,11 @@
cast,
)

import aiohttp
import dill
import more_itertools
import requests
from dateutil.rrule import MONTHLY, rrule
from tqdm import tqdm
from tqdm.asyncio import tqdm
from typing_extensions import ParamSpec

from fundus.publishers.base_objects import PublisherEnum
@@ -81,6 +83,31 @@ def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None:
return wrapper


def pool_queue_iter(handle: MapResult[Any], queue: Queue[_T]) -> Iterator[_T]:
"""Utility function to iterate exhaustively over a pool queue.
The underlying iterator of this function repeatedly exhausts the given queue.
Then, if the queue is empty and all the pool's jobs have finished, the iterator returns.
Otherwise, it waits for the queue to be populated with the next result from the pool.
Args:
handle (MapResult[Any]): A handle to the MapResult of the underlying multiprocessing pool.
queue (Queue[_T]): The pool queue.
Returns:
Iterator[_T]: The iterator over the queue as it is populated.
"""
while True:
try:
yield queue.get(timeout=0.1)
except Empty:
try:
handle.get(timeout=0.1)
except TimeoutError:
continue
return
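
A minimal usage sketch for pool_queue_iter follows. The worker function and all names below are illustrative, not taken from this commit, and starmap_async is simply one way to obtain a MapResult handle; the submission call used by the crawler itself is not shown in this hunk. Jobs stream their results into a shared manager queue while the handle tells the iterator when every job has finished.

from multiprocessing import Manager, Pool
from queue import Queue


def produce(path: str, queue: "Queue[str]") -> None:
    # Stand-in for a worker such as _fetch_articles: push results as they are produced.
    for index in range(3):
        queue.put(f"{path}-result-{index}")


if __name__ == "__main__":
    warc_paths = ["warc-a", "warc-b"]
    with Manager() as manager, Pool(processes=2) as pool:
        result_queue: "Queue[str]" = manager.Queue()
        handle = pool.starmap_async(produce, [(path, result_queue) for path in warc_paths])
        for item in pool_queue_iter(handle, result_queue):
            print(item)

This mirrors how _parallel_crawl pairs a Manager queue with a process pool further down in this file.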


class CCNewsCrawler:
def __init__(
self,
@@ -98,7 +125,7 @@ def __init__(
If 0, only the main process is used. Defaults to -1.
server_address: The CC-NEWS dataset server address. Defaults to 'https://data.commoncrawl.org/'.
"""
self.publishers = publishers
self.publishers = tuple(more_itertools.collapse(publishers))
self.processes = os.cpu_count() or 0 if processes == -1 else processes
self.server_address = server_address

@@ -120,19 +147,26 @@ def _get_warc_paths(self, start: datetime, end: datetime) -> List[str]:
f"{self.server_address}crawl-data/CC-NEWS/{date.strftime('%Y/%m')}/warc.paths.gz" for date in date_sequence
]

def load_paths(url: str) -> List[str]:
with requests.Session() as session:
return gzip.decompress(session.get(url).content).decode("utf-8").split()
async def load_warc_paths_from(url: str) -> List[str]:
async with aiohttp.ClientSession(raise_for_status=True) as session:
async with session.get(url) as response:
return gzip.decompress(await response.read()).decode("utf-8").split()

# running two threads per core
max_number_of_threads = 2 * (os.cpu_count() or 1)
load_warc_paths: Coroutine[Any, Any, List[List[str]]] = tqdm.gather(
*[load_warc_paths_from(url) for url in urls],
total=len(urls),
desc="Loading WARC paths",
leave=False,
)

with ThreadPool(processes=min(len(urls), max_number_of_threads)) as pool:
warc_paths = more_itertools.flatten(
list(
tqdm(pool.imap_unordered(load_paths, urls), total=len(urls), desc="Loading WARC paths", leave=False)
)
)
try:
event_loop = asyncio.get_running_loop()
except RuntimeError:
nested_warc_paths = asyncio.run(load_warc_paths)
else:
nested_warc_paths = event_loop.run_until_complete(load_warc_paths)

warc_paths: Iterator[str] = more_itertools.flatten(nested_warc_paths)

start_strf = start.strftime("%Y%m%d%H%M%S")
end_strf = end.strftime("%Y%m%d%H%M%S")
@@ -151,7 +185,7 @@ def filter_warc_path_by_date(path: str) -> bool:
@staticmethod
def _fetch_articles(
warc_path: str,
publishers: Tuple[PublisherEnum],
publishers: Tuple[PublisherEnum, ...],
error_handling: Literal["suppress", "catch", "raise"],
extraction_filter: Optional[ExtractionFilter] = None,
url_filter: Optional[URLFilter] = None,
@@ -170,7 +204,7 @@ def _parallel_crawl(
) -> Iterator[Article]:
# Although one might think this task is IO-bound because we're downloading a bunch of files, it is actually
# process-bound. The reason is that we stream the data and process it on the fly rather than downloading all
# files and processing them afterwards. Therefore, we utilize multiprocessing here instead of multithreading.
# files and processing them afterward. Therefore, we utilize multiprocessing here instead of multithreading.
with Manager() as manager, Pool(processes=min(self.processes, len(warc_paths))) as pool:
article_queue: Queue[Article] = manager.Queue()

@@ -268,28 +302,3 @@ def build_extraction_filter() -> Optional[ExtractionFilter]:
yield article
if article_idx == max_articles:
break


def pool_queue_iter(handle: MapResult[Any], queue: Queue[_T]) -> Iterator[_T]:
"""Utility function to iterate exhaustively over a pool queue.
The underlying iterator of this function repeatedly exhausts the given queue.
Then, if the queue is empty and all the pool's jobs have finished, the iterator returns.
Otherwise, it waits for the queue to be populated with the next result from the pool.
Args:
handle (MapResult[Any]): A handle to the MapResult of the underlying multiprocessing pool.
queue (Queue[_T]): The pool queue.
Returns:
Iterator[_T]: The iterator over the queue as it is populated.
"""
while True:
try:
yield queue.get(timeout=0.1)
except Empty:
try:
handle.get(timeout=0.1)
except TimeoutError:
continue
return
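
For reference, the asynchronous path-loading introduced above can be exercised on its own roughly as follows. The month used and the helper around asyncio.gather are assumptions for illustration; the real method also filters the returned paths by the requested date range.

import asyncio
import gzip
from typing import List

import aiohttp
import more_itertools


async def load_warc_paths_from(url: str) -> List[str]:
    # Each CC-NEWS month publishes a gzipped, newline-separated list of WARC file paths.
    async with aiohttp.ClientSession(raise_for_status=True) as session:
        async with session.get(url) as response:
            return gzip.decompress(await response.read()).decode("utf-8").split()


async def load_all(urls: List[str]) -> List[str]:
    nested = await asyncio.gather(*[load_warc_paths_from(url) for url in urls])
    return list(more_itertools.flatten(nested))


if __name__ == "__main__":
    # URL pattern taken from _get_warc_paths: {server_address}crawl-data/CC-NEWS/{YYYY/MM}/warc.paths.gz
    urls = ["https://data.commoncrawl.org/crawl-data/CC-NEWS/2024/01/warc.paths.gz"]
    print(len(asyncio.run(load_all(urls))))

The commit itself wraps the same coroutines in tqdm.asyncio's gather to display a progress bar and flattens the nested lists with more_itertools.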
