diff --git a/src/fundus/parser/data.py b/src/fundus/parser/data.py
index 1104176e4..f50c1a111 100644
--- a/src/fundus/parser/data.py
+++ b/src/fundus/parser/data.py
@@ -199,12 +199,9 @@ def as_text_sequence(self) -> TextSequence:
         texts = [text for tl in self.df_traversal() for text in tl]
         return TextSequence(texts)
 
-    def text(self, join_on: str = "\n\n", strip_text: bool = True) -> str:
-        if strip_text:
-            striped_texts = [" ".join(text.split()) for text in self.as_text_sequence()]
-            return join_on.join(striped_texts)
-        else:
-            return join_on.join(self.as_text_sequence())
+    def text(self, join_on: str = "\n\n") -> str:
+        """Join all texts returned by `as_text_sequence` with <join_on>; the texts are used as they are."""
+        return join_on.join(self.as_text_sequence())
 
     def df_traversal(self) -> Iterable[TextSequence]:
         def recursion(o: object):
diff --git a/src/fundus/parser/utility.py b/src/fundus/parser/utility.py
index 7a9c1a335..390d9f8cc 100644
--- a/src/fundus/parser/utility.py
+++ b/src/fundus/parser/utility.py
@@ -17,7 +17,6 @@
     cast,
 )
 
-import dateutil.tz
 import lxml.html
 import more_itertools
 from dateutil import parser
@@ -27,6 +26,11 @@
 from fundus.parser.data import ArticleBody, ArticleSection, TextSequence
 
 
+def normalize_whitespace(text: str) -> str:
+    """Collapse each run of whitespace in <text> into a single space and strip both ends."""
+    return " ".join(text.split())
+
+
 @total_ordering
 @dataclass(eq=False)
 class Node:
@@ -34,8 +38,32 @@ class Node:
     node: lxml.html.HtmlElement = field(compare=False)
     _break_selector: ClassVar[XPath] = XPath("*//br")
 
-    def striped(self, chars: Optional[str] = None) -> str:
-        return str(self).strip(chars)
+    # one could replace this recursion with XPath using an expression like this:
+    # //*[not(self::script) and text()]/text(), but for whatever reason, that's actually 50-150% slower
+    # than simply using the implemented mixture below
+    def text_content(self, excluded_tags: Optional[List[str]] = None) -> str:
+        """Return the text of this node, skipping comments and the content of excluded tags.
+
+        The text is read from the copy returned by `_get_break_preserved_node`. Text following an
+        excluded element or a comment (its tail) is kept, since it belongs to the surrounding element.
+        The tail of the node itself is left out, matching lxml's `text_content()`.
+
+        :param excluded_tags: tag names whose content is dropped, e.g. ["script"]
+        :return: the concatenated text, whitespace is not normalized
+        """
+        guarded_excluded_tags: List[str] = excluded_tags or []
+
+        def _text_content(element: lxml.html.HtmlElement, with_tail: bool = True) -> str:
+            tail = (element.tail or "") if with_tail else ""
+            if element.tag in guarded_excluded_tags:
+                # only the excluded element's own content is dropped, not the text following it
+                return tail
+            text = (element.text or "") if not isinstance(element, lxml.html.HtmlComment) else ""
+            children = "".join([_text_content(child) for child in element.iterchildren()])
+            return text + children + tail
+
+        # lxml copies the tail along with the element, so exclude it explicitly for the root
+        return _text_content(self._get_break_preserved_node(), with_tail=False)
 
     def _get_break_preserved_node(self) -> lxml.html.HtmlElement:
         copied_node = copy(self.node)
@@ -55,10 +83,10 @@ def __hash__(self) -> int:
         return hash(self.position)
 
     def __str__(self) -> str:
-        return self._get_break_preserved_node().text_content()
+        return self.text_content()
 
     def __bool__(self):
-        return bool(self.striped())
+        return bool(normalize_whitespace(self.text_content()))
 
 
 class SummaryNode(Node):
@@ -106,13 +134,15 @@ def extract_nodes(selector: XPath, node_type: Type[Node]) -> List[Node]:
         first = next(instructions)
         instructions = itertools.chain([first, []], instructions)
 
-    summary = TextSequence(map(lambda x: x.striped("\n"), next(instructions)))
+    summary = TextSequence(
+        map(lambda x: normalize_whitespace(x.text_content(excluded_tags=["script"])), next(instructions))
+    )
     sections: List[ArticleSection] = []
 
     for chunk in more_itertools.chunked(instructions, 2):
         if len(chunk) == 1:
             chunk.append([])
-        texts = [list(map(lambda x: x.striped("\n"), c)) for c in chunk]
+        texts = [list(map(lambda x: normalize_whitespace(x.text_content(excluded_tags=["script"])), c)) for c in chunk]
         sections.append(ArticleSection(*map(TextSequence, texts)))
 
     return ArticleBody(summary=summary, sections=sections)