Merge pull request #347 from flairNLP/remove-asyncio-from-ccnews
Replace `asyncio` with thread-based solution for WARC-path download
MaxDall authored Feb 2, 2024
2 parents 678f303 + d8b64ee commit 5f5ad7e
Showing 1 changed file with 17 additions and 21 deletions.
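The change replaces per-URL coroutines with a thread pool around blocking `requests` calls. As a standalone illustration of the pattern the new code adopts (a minimal sketch, not the Fundus API itself: the Common Crawl URLs and the pool size of 4 are assumptions for the example):

import gzip
from multiprocessing.pool import ThreadPool
from typing import List

import requests
from tqdm import tqdm

# Illustrative URLs (assumption): two monthly CC-NEWS path listings.
urls = [
    "https://data.commoncrawl.org/crawl-data/CC-NEWS/2024/01/warc.paths.gz",
    "https://data.commoncrawl.org/crawl-data/CC-NEWS/2024/02/warc.paths.gz",
]

with tqdm(total=len(urls), desc="Loading WARC Paths", leave=False) as bar:

    def load_paths(url: str) -> List[str]:
        # Download the gzip listing, decompress it, and split into one path per line.
        with requests.Session() as session:
            paths = gzip.decompress(session.get(url).content).decode("utf-8").split()
            bar.update()  # tick once per finished URL
            return paths

    # Never start more threads than there are URLs to fetch.
    with ThreadPool(processes=min(len(urls), 4)) as pool:
        nested_warc_paths = pool.map(load_paths, urls)

The `tqdm` bar is shared across worker threads; `update()` acquires tqdm's internal lock, so progress ticks stay consistent.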
38 changes: 17 additions & 21 deletions src/fundus/scraping/common_crawl/pipeline.py
@@ -1,19 +1,17 @@
 from __future__ import annotations
 
-import asyncio
 import gzip
 import os
 import re
 from datetime import datetime
 from functools import lru_cache, partial, wraps
 from multiprocessing import Manager
 from multiprocessing.context import TimeoutError
-from multiprocessing.pool import MapResult, Pool
+from multiprocessing.pool import MapResult, Pool, ThreadPool
 from queue import Empty, Queue
 from typing import (
     Any,
     Callable,
-    Coroutine,
     Generic,
     Iterator,
     List,
@@ -27,11 +25,11 @@
     cast,
 )
 
-import aiohttp
 import dill
 import more_itertools
 import requests
 from dateutil.rrule import MONTHLY, rrule
-from tqdm.asyncio import tqdm
+from tqdm import tqdm
 from typing_extensions import ParamSpec
 
 from fundus.publishers.base_objects import PublisherEnum
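The import change is the crux: `multiprocessing.pool.ThreadPool` exposes the same `map`/`imap` interface as `Pool` but runs workers as threads in the current process, which suits I/O-bound downloads. A minimal sketch contrasting the two (the `square` task is an illustrative stand-in):

from multiprocessing.pool import Pool, ThreadPool

def square(x: int) -> int:
    # Stand-in for a worker task; trivially cheap here.
    return x * x

if __name__ == "__main__":
    # ThreadPool: threads in this process; no pickling of the function needed.
    with ThreadPool(processes=4) as tpool:
        print(tpool.map(square, range(8)))

    # Pool: separate processes; arguments and results are pickled.
    with Pool(processes=4) as ppool:
        print(ppool.map(square, range(8)))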
@@ -147,24 +145,22 @@ def _get_warc_paths(self, start: datetime, end: datetime) -> List[str]:
             f"{self.server_address}crawl-data/CC-NEWS/{date.strftime('%Y/%m')}/warc.paths.gz" for date in date_sequence
         ]
 
-        async def load_warc_paths_from(url: str) -> List[str]:
-            async with aiohttp.ClientSession(raise_for_status=True) as session:
-                async with session.get(url) as response:
-                    return gzip.decompress(await response.read()).decode("utf-8").split()
+        with tqdm(total=len(urls), desc="Loading WARC Paths", leave=False) as bar:
+
+            def load_paths(url: str) -> List[str]:
+                with requests.Session() as session:
+                    paths = gzip.decompress(session.get(url).content).decode("utf-8").split()
+                    bar.update()
+                    return paths
 
-        load_warc_paths: Coroutine[Any, Any, List[List[str]]] = tqdm.gather(
-            *[load_warc_paths_from(url) for url in urls],
-            total=len(urls),
-            desc="Loading WARC paths",
-            leave=False,
-        )
-
-        try:
-            event_loop = asyncio.get_running_loop()
-        except RuntimeError:
-            nested_warc_paths = asyncio.run(load_warc_paths)
-        else:
-            nested_warc_paths = event_loop.run_until_complete(load_warc_paths)
+            if self.processes == 0:
+                nested_warc_paths = [load_paths(url) for url in urls]
+            else:
+                # use two threads per process, default two threads per core
+                max_number_of_threads = self.processes * 2
+
+                with ThreadPool(processes=min(len(urls), max_number_of_threads)) as pool:
+                    nested_warc_paths = pool.map(load_paths, urls)
 
         warc_paths: Iterator[str] = more_itertools.flatten(nested_warc_paths)
 

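Design note: the removed code had to detect whether an event loop was already running and pick between `asyncio.run` and `run_until_complete`; the thread-based version sidesteps loop management entirely, and `self.processes == 0` still gives a plain sequential fallback. Each worker returns one list of paths per month, so the pool yields a `List[List[str]]` that the method flattens lazily. A sketch of that last step, continuing from the hypothetical `nested_warc_paths` above:

from typing import Iterator

import more_itertools

# Flatten List[List[str]] into one lazy stream of WARC paths.
warc_paths: Iterator[str] = more_itertools.flatten(nested_warc_paths)
# Paths look like "crawl-data/CC-NEWS/<year>/<month>/CC-NEWS-....warc.gz" (illustrative).
print(next(warc_paths, None))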