Skip to content

Commit

Permalink
Merge pull request #10 from megagonlabs/spacy_v3.4
Browse files Browse the repository at this point in the history
migrate to spacy v3.4
  • Loading branch information
hiroshi-matsuda-rit authored Aug 9, 2022
2 parents 727739d + ac9921c commit b2ce359
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 10 deletions.
13 changes: 8 additions & 5 deletions ginza_transformers/layers/hf_shim_custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,11 +124,14 @@ def from_bytes(self, bytes_data):
transformer.load_state_dict(torch.load(filelike, map_location=map_location))
else:
try:
transformer = AutoModel.from_pretrained(config._name_or_path, local_files_only=True)
except OSError as e2:
print("trying to download model from huggingface hub:", config._name_or_path, "...", file=sys.stderr)
transformer = AutoModel.from_pretrained(config._name_or_path)
print("succeded", file=sys.stderr)
transformer = AutoModel.from_pretrained(config_dict["_name_or_path"], local_files_only=True)
except OSError as e1:
try:
transformer = AutoModel.from_pretrained(config._name_or_path)
except OSError as e2:
print("trying to download model from huggingface hub:", config_dict["_name_or_path"], "...", file=sys.stderr)
transformer = AutoModel.from_pretrained(config_dict["_name_or_path"])
print("succeded", file=sys.stderr)

transformer.to(map_location)
self._model = transformer
Expand Down
6 changes: 3 additions & 3 deletions ginza_transformers/layers/transformer_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,10 @@ def huggingface_from_pretrained_custom(
try:
trf_config["return_dict"] = True
config = AutoConfig.from_pretrained(str_path, **trf_config)
transformer = AutoModel.from_pretrained(str_path, config=config)
except OSError as e:
transformer = AutoModel.from_pretrained(model_name, local_files_only=True)
except OSError as e1:
try:
transformer = AutoModel.from_pretrained(str_path, local_files_only=True)
transformer = AutoModel.from_pretrained(str_path, config=config)
except OSError as e2:
model_name = str(source)
print("trying to download model from huggingface hub:", model_name, "...", file=sys.stderr)
Expand Down
3 changes: 1 addition & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,10 @@
},
install_requires=[
"spacy-transformers>=1.1.2,<1.2.0",
"transformers<4.13.0",
],
license="MIT",
name="ginza-transformers",
packages=find_packages(include=["ginza_transformers", "ginza_transformers.layers"]),
url="https://github.com/megagonlabs/ginza-transformers",
version='0.4.1',
version='0.4.2',
)

0 comments on commit b2ce359

Please sign in to comment.