Commit b297a9e
Merge pull request #7 from megagonlabs/feature/setup_model

add setup_model.py

hiroshi-matsuda-rit authored Dec 5, 2021
2 parents 9ef017e + f38d44e commit b297a9e

Showing 2 changed files with 97 additions and 17 deletions.
83 changes: 66 additions & 17 deletions ginza_transformers/layers/hf_shim_custom.py
@@ -1,5 +1,4 @@
 import sys
-from typing import Any
 from io import BytesIO
 from pathlib import Path
 import srsly
@@ -15,6 +14,17 @@
 from transformers import AutoModel, AutoConfig, AutoTokenizer
 
 
+def override_hf_shims_to_bytes():
+    assert hf_shim.HFShim.to_bytes is not HFShimCustom.to_bytes
+    origin = hf_shim.HFShim.to_bytes
+    hf_shim.HFShim.to_bytes = HFShimCustom.to_bytes
+    return origin
+
+def recover_hf_shims_to_bytes(origin):
+    assert hf_shim.HFShim.to_bytes is HFShimCustom.to_bytes
+    hf_shim.HFShim.to_bytes = origin
+
+
 def override_hf_shims_from_bytes():
     assert hf_shim.HFShim.from_bytes is not HFShimCustom.from_bytes
     origin = hf_shim.HFShim.from_bytes
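Note: each override_* helper asserts the patch is not already in place, swaps the HFShimCustom method into hf_shim.HFShim, and returns the original method so the matching recover_* helper can restore it. A minimal sketch of the intended calling pattern (the nlp and model_path names are illustrative; setup_model.py below uses exactly this shape):

    origin = override_hf_shims_to_bytes()
    try:
        nlp.to_disk(model_path)  # serialization now routes through HFShimCustom.to_bytes
    finally:
        recover_hf_shims_to_bytes(origin)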
@@ -28,6 +38,44 @@ def recover_hf_shims_from_bytes(origin):
 
 class HFShimCustom(HFShim):
 
+    def to_bytes(self):
+        config = {}
+        tok_dict = {}
+        # weights_bytes = {}
+        tok_cfg = {}
+        trf_cfg = {}
+        hf_model = self._hfmodel
+        if hf_model.transformer is not None:
+            tok_dict = {}
+            config = hf_model.transformer.config.to_dict()
+            tokenizer = hf_model.tokenizer
+            with make_tempdir() as temp_dir:
+                if hasattr(tokenizer, "vocab_file"):
+                    vocab_file_name = tokenizer.vocab_files_names["vocab_file"]
+                    vocab_file_path = str((temp_dir / vocab_file_name).absolute())
+                    with open(vocab_file_path, "wb") as fileh:
+                        fileh.write(hf_model.vocab_file_contents)
+                    tokenizer.vocab_file = vocab_file_path
+                tokenizer.save_pretrained(str(temp_dir.absolute()))
+                for x in temp_dir.glob("**/*"):
+                    if x.is_file():
+                        tok_dict[x.name] = x.read_bytes()
+            filelike = BytesIO()
+            torch.save(self._model.state_dict(), filelike)
+            filelike.seek(0)
+            # weights_bytes = filelike.getvalue()
+        else:
+            tok_cfg = hf_model._init_tokenizer_config
+            trf_cfg = hf_model._init_transformer_config
+        msg = {
+            "config": config,
+            # "state": weights_bytes,
+            "tokenizer": tok_dict,
+            "_init_tokenizer_config": tok_cfg,
+            "_init_transformer_config": trf_cfg,
+        }
+        return srsly.msgpack_dumps(msg)
+
     def from_bytes(self, bytes_data):
         msg = srsly.msgpack_loads(bytes_data)
         config_dict = msg["config"]
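Note: the commented-out weights_bytes lines are the point of this override. The msgpack payload keeps the transformer config and the tokenizer files but deliberately omits the "state" entry, so the serialized pipeline no longer embeds the transformer weights. A quick sanity check of the payload shape (the shim variable is hypothetical; the keys are taken from the msg dict above):

    import srsly

    msg = srsly.msgpack_loads(shim.to_bytes())
    assert "state" not in msg  # weights intentionally omitted
    assert set(msg) == {"config", "tokenizer", "_init_tokenizer_config", "_init_transformer_config"}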
@@ -62,34 +110,35 @@ def from_bytes(self, bytes_data):
                 with open(vocab_file_path, "rb") as fileh:
                     vocab_file_contents = fileh.read()
 
-            try:
+            ops = get_current_ops()
+            if ops.device_type == "cpu":
+                map_location = "cpu"
+            else:  # pragma: no cover
+                device_id = torch.cuda.current_device()
+                map_location = f"cuda:{device_id}"
+
+            if "state" in msg:
                 transformer = AutoModel.from_config(config)
-            except OSError as e:
+                filelike = BytesIO(msg["state"])
+                filelike.seek(0)
+                transformer.load_state_dict(torch.load(filelike, map_location=map_location))
+            else:
                 try:
-                    transformer = AutoModel.from_pretrained(config["_name_or_path"], local_files_only=True)
+                    transformer = AutoModel.from_pretrained(config._name_or_path, local_files_only=True)
                 except OSError as e2:
-                    print("trying to download model from huggingface hub:", config["_name_or_path"], "...", file=sys.stderr)
-                    transformer = AutoModel.from_pretrained(config["_name_or_path"])
+                    print("trying to download model from huggingface hub:", config._name_or_path, "...", file=sys.stderr)
+                    transformer = AutoModel.from_pretrained(config._name_or_path)
                     print("succeded", file=sys.stderr)
 
+            transformer.to(map_location)
+            self._model = transformer
             self._hfmodel = HFObjects(
                 tokenizer,
                 transformer,
                 vocab_file_contents,
                 SimpleFrozenDict(),
                 SimpleFrozenDict(),
             )
-            self._model = transformer
-            filelike = BytesIO(msg["state"])
-            filelike.seek(0)
-            ops = get_current_ops()
-            if ops.device_type == "cpu":
-                map_location = "cpu"
-            else:  # pragma: no cover
-                device_id = torch.cuda.current_device()
-                map_location = f"cuda:{device_id}"
-            self._model.load_state_dict(torch.load(filelike, map_location=map_location))
-            self._model.to(map_location)
         else:
             self._hfmodel = HFObjects(
                 None,
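Note: from_bytes now branches on the presence of the "state" key instead of catching OSError around AutoModel.from_config. If embedded weights exist they are loaded exactly as before; otherwise the model is resolved by config._name_or_path, first from the local Hugging Face cache (local_files_only=True), then by downloading from the Hub. The fallback in isolation, as a standalone sketch (the model name here is an assumed example, not taken from this diff):

    from transformers import AutoModel

    name = "megagonlabs/transformers-ud-japanese-electra-base-discriminator"  # assumed example
    try:
        model = AutoModel.from_pretrained(name, local_files_only=True)  # local cache only
    except OSError:
        model = AutoModel.from_pretrained(name)  # falls back to downloading from the Hub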
31 changes: 31 additions & 0 deletions ginza_transformers/setup_model.py
@@ -0,0 +1,31 @@
+import sys
+
+import spacy
+
+from .layers.hf_shim_custom import override_hf_shims_to_bytes, recover_hf_shims_to_bytes
+
+
+def main():
+    org_spacy_model_path = sys.argv[1]
+    dst_spacy_model_path = sys.argv[2]
+    transformers_model_name = sys.argv[3]
+
+    nlp = spacy.load(org_spacy_model_path)
+    transformer = nlp.get_pipe("transformer")
+    for i, node in enumerate(transformer.model.walk()):
+        if node.shims:
+            break
+    else:
+        assert False
+    node.shims[0]._hfmodel.transformer.config._name_or_path = transformers_model_name
+    node.shims[0]._hfmodel.tokenizer.save_pretrained(transformers_model_name)
+    node.shims[0]._hfmodel.transformer.save_pretrained(transformers_model_name)
+    origin = override_hf_shims_to_bytes()
+    try:
+        nlp.to_disk(dst_spacy_model_path)
+    finally:
+        recover_hf_shims_to_bytes(origin)
+
+
+if __name__ == "__main__":
+    main()
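Usage: setup_model.py rewrites a trained spaCy pipeline so its transformer weights live outside the pipeline package. It walks the transformer component's thinc model to find the first node carrying a shim, stamps the new name into config._name_or_path, saves the tokenizer and transformer to a local directory of that name, and then re-serializes the pipeline with the weight-omitting to_bytes patched in. A hypothetical invocation (all three arguments are illustrative placeholders):

    python -m ginza_transformers.setup_model \
        /path/to/original_model \
        /path/to/output_model \
        your-org/your-transformers-model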