added tests of remaining transformers
epinzur committed Feb 5, 2025
1 parent c589f11 commit e7ee9ec
Showing 20 changed files with 345 additions and 166 deletions.
4 changes: 3 additions & 1 deletion packages/langchain-graph-retriever/pyproject.toml
@@ -82,6 +82,7 @@ simsimd = "simsimd"
 spacy = "spacy"
 testcontainers = "testcontainers"
 tqdm = "tqdm"
+types-beautifulsoup4 = "types_beautifulsoup4"
 typing-extensions = "typing_extensions"
 
 [tool.deptry.per_rule_ignores]
@@ -96,7 +97,7 @@ astra = [
     "httpx>=0.28.1",
     "langchain-astradb>=0.5.3",
 ]
-beautifulsoup4 = [
+html = [
     "beautifulsoup4>=4.12.3",
 ]
 cassandra = [
@@ -139,5 +140,6 @@ dev = [
     "simsimd>=6.2.1",
     "testcontainers>=4.9.0",
     "langchain-tests>=0.3.8",
+    "types-beautifulsoup4>=4.12.0.20250204",
 ]
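With this hunk, BeautifulSoup support moves from a `beautifulsoup4` extra to an `html` extra, and the `types-beautifulsoup4` stubs join the dev dependencies for type checking. Assuming the distribution name matches the package path above, the extra would now be installed as `pip install "langchain-graph-retriever[html]"`.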

@@ -15,7 +15,7 @@
 )
 
 from langchain_graph_retriever._conversion import doc_to_content
-from langchain_graph_retriever.document_transformers.metadata_denormalizer import (
+from langchain_graph_retriever.transformers.metadata_denormalizer import (
     MetadataDenormalizer,
 )


This file was deleted.

@@ -1,15 +1,16 @@
 from collections.abc import Sequence
 from typing import Any
 
+from gliner import GLiNER  # type: ignore
 from langchain_core.documents import BaseDocumentTransformer, Document
 from typing_extensions import override
 
 
-class GLiNEREntityExtractor(BaseDocumentTransformer):
+class GLiNERTransformer(BaseDocumentTransformer):
     """
-    Add metadata to documents about named entities using `GLiNER`_.
+    Add metadata to documents about named entities using `GLiNER`.
 
-    `GLiNER`_ is a Named Entity Recognition (NER) model capable of identifying any
+    `GLiNER` is a Named Entity Recognition (NER) model capable of identifying any
     entity type using a bidirectional transformer encoder (BERT-like).
 
     Preliminaries
@@ -20,9 +21,9 @@ class GLiNEREntityExtractor(BaseDocumentTransformer):
     Note that ``bs4`` is also installed to support the WebBaseLoader in the example,
     but not needed by the GLiNEREntityExtractor itself.
 
-    .. code-block:: bash
-
-        pip install -q langchain_community bs4 gliner
+    ```
+    pip install -q langchain_community bs4 gliner
+    ```
 
     Example
     -------
@@ -62,8 +63,9 @@ class GLiNEREntityExtractor(BaseDocumentTransformer):
         A prefix to add to metadata keys outputted by the extractor.
         This will be prepended to the label, with the value (or values) holding the
         generated keywords for that entity kind.
-    model : str, default "urchade/gliner_mediumv2.1"
-        The GLiNER model to use.
+    model : str | GLiNER, default "urchade/gliner_mediumv2.1"
+        The GLiNER model to use. Pass the name of the model to load
+        or pass an instantiated GLiNER model instance.
     """  # noqa: E501

@@ -73,18 +75,14 @@ def __init__(
         *,
         batch_size: int = 8,
         metadata_key_prefix: str = "",
-        model: str = "urchade/gliner_mediumv2.1",
+        model: Any = "urchade/gliner_mediumv2.1",
     ):
-        try:
-            from gliner import GLiNER  # type: ignore
-
+        if isinstance(model, GLiNER):
+            self._model = model
+        elif isinstance(model, str):
             self._model = GLiNER.from_pretrained(model)
-        except ImportError:
-            raise ImportError(
-                "gliner is required for the GLiNEREntityExtractor. "
-                "Please install it with `pip install gliner`."
-            ) from None
+        else:
+            raise ValueError(f"Invalid model: {model}")
 
         self._batch_size = batch_size
         self._labels = labels
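The reworked constructor accepts either a model name to load or an already-instantiated `GLiNER` model. A minimal usage sketch, assuming the module lives at `langchain_graph_retriever.transformers.gliner` (the import path and entity labels here are illustrative, not taken from this diff):

```
from gliner import GLiNER
from langchain_core.documents import Document

from langchain_graph_retriever.transformers.gliner import GLiNERTransformer  # assumed path

# Load the model once and share it, or just pass the name and let the
# transformer call GLiNER.from_pretrained() itself.
model = GLiNER.from_pretrained("urchade/gliner_mediumv2.1")
transformer = GLiNERTransformer(
    labels=["person", "organization"],  # hypothetical entity labels
    metadata_key_prefix="entity_",
    model=model,
)

docs = transformer.transform_documents(
    [Document(page_content="Ada Lovelace collaborated with Charles Babbage.")]
)
# Per the docstring, entities land under prefixed keys, e.g. metadata["entity_person"].
```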
@@ -1,20 +1,16 @@
 from __future__ import annotations
 
 from collections.abc import Sequence
-from typing import TYPE_CHECKING, Any
+from typing import Any
 from urllib.parse import urldefrag, urljoin, urlparse
 
-from langchain_core._api import beta
+from bs4 import BeautifulSoup
+from bs4.element import Tag
 from langchain_core.documents import BaseDocumentTransformer, Document
 from typing_extensions import override
 
-if TYPE_CHECKING:
-    from bs4 import BeautifulSoup  # type: ignore
-    from bs4.element import Tag  # type: ignore
-
 
-@beta()
-class HtmlHyperlinkExtractor(BaseDocumentTransformer):
+class HtmlHyperlinkTransformer(BaseDocumentTransformer):
     """
     Extract hyperlinks from HTML content.
@@ -207,14 +203,6 @@ def __init__(
         metadata_key: str = "hyperlink",
         drop_fragments: bool = True,
     ):
-        try:
-            from bs4 import BeautifulSoup  # noqa:F401
-        except ImportError as e:
-            raise ImportError(
-                "BeautifulSoup4 is required for HtmlHyperlinkExtractor. "
-                "Please install it with `pip install beautifulsoup4`."
-            ) from e
-
         self._url_metadata_key = url_metadata_key
         self._metadata_key = metadata_key
         self._drop_fragments = drop_fragments
@@ -224,20 +212,25 @@ def _parse_url(link: Tag, page_url: str, drop_fragments: bool = True) -> str | None:
         href = link.get("href")
         if href is None:
             return None
+        if isinstance(href, list) and len(href) == 1:
+            href = href[0]
+        if not isinstance(href, str):
+            return None
 
         url = urlparse(href)
         if url.scheme not in ["http", "https", ""]:
             return None
 
         # Join the HREF with the page_url to convert relative paths to absolute.
-        url = str(urljoin(page_url, href))
+        joined_url = str(urljoin(page_url, href))
 
         # Fragments would be useful if we chunked a page based on section.
         # Then, each chunk would have a different URL based on the fragment.
         # Since we aren't doing that yet, they just "break" links. So, drop
         # the fragment.
         if drop_fragments:
-            return urldefrag(url).url
-        return url
+            return urldefrag(joined_url).url
+        return joined_url

     @staticmethod
     def _parse_urls(
@@ -247,7 +240,7 @@ def _parse_urls(
         urls: set[str] = set()
 
         for link in soup_links:
-            parsed_url = HtmlHyperlinkExtractor._parse_url(
+            parsed_url = HtmlHyperlinkTransformer._parse_url(
                 link, page_url=page_url, drop_fragments=drop_fragments
             )
             # Remove self links and entries for any 'a' tag that failed to parse
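For context, the rewritten `_parse_url` leans entirely on the standard library to resolve relative links and strip fragments; its behavior is easy to verify in isolation:

```
from urllib.parse import urldefrag, urljoin

page_url = "https://example.com/docs/intro.html"

# Relative hrefs are resolved against the page URL, including ".." segments.
assert urljoin(page_url, "../api.html") == "https://example.com/api.html"

# With drop_fragments=True, section anchors collapse to a single target,
# so "guide.html#setup" and "guide.html" point at the same URL.
joined = urljoin(page_url, "guide.html#setup")
assert urldefrag(joined).url == "https://example.com/docs/guide.html"
```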
@@ -1,11 +1,12 @@
 from collections.abc import Sequence
 from typing import Any
 
+from keybert import KeyBERT  # type: ignore
 from langchain_core.documents import BaseDocumentTransformer, Document
 from typing_extensions import override
 
 
-class KeybertKeywordExtractor(BaseDocumentTransformer):
+class KeyBERTTransformer(BaseDocumentTransformer):
     """
     Add metadata to documents about keywords using `KeyBERT <https://maartengr.github.io/KeyBERT/>`_.
@@ -71,18 +72,14 @@ def __init__(
         *,
         batch_size: int = 8,
         metadata_key: str = "keywords",
-        model: str = "all-MiniLM-L6-v2",
+        model: str | KeyBERT = "all-MiniLM-L6-v2",
     ):
-        try:
-            import keybert  # type: ignore
-
-            self._kw_model = keybert.KeyBERT(model=model)
-        except ImportError:
-            raise ImportError(
-                "keybert is required for the KeybertLinkExtractor. "
-                "Please install it with `pip install keybert`."
-            ) from None
-
+        if isinstance(model, KeyBERT):
+            self._kw_model = model
+        elif isinstance(model, str):
+            self._kw_model = KeyBERT(model=model)
+        else:
+            raise ValueError(f"Invalid model: {model}")
 
         self._batch_size = batch_size
         self._metadata_key = metadata_key
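As with the GLiNER change, `model` now accepts either a name or an existing `KeyBERT` instance, which makes it easy to share one model across transformers. A sketch, assuming the module lives at `langchain_graph_retriever.transformers.keybert` (the import path is a guess, not taken from this diff):

```
from keybert import KeyBERT

from langchain_graph_retriever.transformers.keybert import KeyBERTTransformer  # assumed path

# Construct from a model name (the default is "all-MiniLM-L6-v2") ...
by_name = KeyBERTTransformer(model="all-MiniLM-L6-v2")

# ... or hand over an already-instantiated KeyBERT model to avoid a second load.
shared = KeyBERT(model="all-MiniLM-L6-v2")
by_instance = KeyBERTTransformer(model=shared, metadata_key="keywords")
```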