From d8b64ee9af6a62d7d12d0b75b4d20e7011a54028 Mon Sep 17 00:00:00 2001
From: Max Dallabetta
Date: Thu, 1 Feb 2024 16:52:40 +0100
Subject: [PATCH] replace asyncio with a thread based solution to download
 warc paths

---
 src/fundus/scraping/common_crawl/pipeline.py | 38 +++++++++-----------
 1 file changed, 17 insertions(+), 21 deletions(-)

diff --git a/src/fundus/scraping/common_crawl/pipeline.py b/src/fundus/scraping/common_crawl/pipeline.py
index ea1e3ce24..4aa66070c 100644
--- a/src/fundus/scraping/common_crawl/pipeline.py
+++ b/src/fundus/scraping/common_crawl/pipeline.py
@@ -1,6 +1,5 @@
 from __future__ import annotations
 
-import asyncio
 import gzip
 import os
 import re
@@ -8,12 +7,11 @@
 from functools import lru_cache, partial, wraps
 from multiprocessing import Manager
 from multiprocessing.context import TimeoutError
-from multiprocessing.pool import MapResult, Pool
+from multiprocessing.pool import MapResult, Pool, ThreadPool
 from queue import Empty, Queue
 from typing import (
     Any,
     Callable,
-    Coroutine,
     Generic,
     Iterator,
     List,
@@ -27,11 +25,11 @@
     cast,
 )
 
-import aiohttp
 import dill
 import more_itertools
+import requests
 from dateutil.rrule import MONTHLY, rrule
-from tqdm.asyncio import tqdm
+from tqdm import tqdm
 from typing_extensions import ParamSpec
 
 from fundus.publishers.base_objects import PublisherEnum
@@ -147,24 +145,22 @@ def _get_warc_paths(self, start: datetime, end: datetime) -> List[str]:
             f"{self.server_address}crawl-data/CC-NEWS/{date.strftime('%Y/%m')}/warc.paths.gz"
             for date in date_sequence
         ]
-        async def load_warc_paths_from(url: str) -> List[str]:
-            async with aiohttp.ClientSession(raise_for_status=True) as session:
-                async with session.get(url) as response:
-                    return gzip.decompress(await response.read()).decode("utf-8").split()
+        with tqdm(total=len(urls), desc="Loading WARC Paths", leave=False) as bar:
 
-        load_warc_paths: Coroutine[Any, Any, List[List[str]]] = tqdm.gather(
-            *[load_warc_paths_from(url) for url in urls],
-            total=len(urls),
-            desc="Loading WARC paths",
-            leave=False,
-        )
+            def load_paths(url: str) -> List[str]:
+                with requests.Session() as session:
+                    paths = gzip.decompress(session.get(url).content).decode("utf-8").split()
+                    bar.update()
+                    return paths
 
-        try:
-            event_loop = asyncio.get_running_loop()
-        except RuntimeError:
-            nested_warc_paths = asyncio.run(load_warc_paths)
-        else:
-            nested_warc_paths = event_loop.run_until_complete(load_warc_paths)
+            if self.processes == 0:
+                nested_warc_paths = [load_paths(url) for url in urls]
+            else:
+                # use two threads per process, default two threads per core
+                max_number_of_threads = self.processes * 2
+
+                with ThreadPool(processes=min(len(urls), max_number_of_threads)) as pool:
+                    nested_warc_paths = pool.map(load_paths, urls)
 
         warc_paths: Iterator[str] = more_itertools.flatten(nested_warc_paths)
 
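
For reference, below is a minimal standalone sketch of the thread-based download this patch introduces; the WARC path fetches are pure I/O, so worker threads are sufficient here. It is not part of the patch: the CC-NEWS URLs, the `processes` value, and the final flattening are illustrative assumptions, while `load_paths` and `max_number_of_threads` mirror the names used in the diff.

# Standalone sketch (illustrative, not part of the patch): download and
# decompress CC-NEWS warc.paths.gz listings with requests + ThreadPool.
import gzip
from multiprocessing.pool import ThreadPool
from typing import List

import requests
from tqdm import tqdm

# Illustrative inputs; the patch builds the URLs from self.server_address and a
# date sequence, and takes the worker count from self.processes.
urls: List[str] = [
    "https://data.commoncrawl.org/crawl-data/CC-NEWS/2024/01/warc.paths.gz",
    "https://data.commoncrawl.org/crawl-data/CC-NEWS/2024/02/warc.paths.gz",
]
processes: int = 4

with tqdm(total=len(urls), desc="Loading WARC Paths", leave=False) as bar:

    def load_paths(url: str) -> List[str]:
        # fetch the gzipped listing, decompress it, and split into one path per line
        with requests.Session() as session:
            paths = gzip.decompress(session.get(url).content).decode("utf-8").split()
            bar.update()
            return paths

    if processes == 0:
        # no parallelism requested: load sequentially
        nested_warc_paths = [load_paths(url) for url in urls]
    else:
        # two I/O-bound threads per process, capped by the number of URLs
        max_number_of_threads = processes * 2
        with ThreadPool(processes=min(len(urls), max_number_of_threads)) as pool:
            nested_warc_paths = pool.map(load_paths, urls)

warc_paths: List[str] = [path for paths in nested_warc_paths for path in paths]
print(f"loaded {len(warc_paths)} WARC paths")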