Skip to content

Commit 44500eb

Browse files
authoredJun 9, 2023
Fix TS model for BLOOM architecture (#344)
1 parent 571f6c3 commit 44500eb

File tree

4 files changed

+10
-3
lines changed

4 files changed

+10
-3
lines changed
 

‎optimum/intel/generation/modeling.py

+8
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131

3232
from ..utils.constant import _TASK_ALIASES
3333
from ..utils.import_utils import is_torch_version, is_transformers_version
34+
from ..utils.modeling_utils import _prepare_attn_mask, _prepare_decoder_attention_mask
3435

3536

3637
if is_transformers_version("<", "4.25.0"):
@@ -266,6 +267,13 @@ def _from_transformers(
266267
}
267268

268269
model = TasksManager.get_model_from_task(task, model_id, **model_kwargs)
270+
271+
if model.config.model_type == "bloom":
272+
model.transformer._prepare_attn_mask = _prepare_attn_mask
273+
274+
if model.config.model_type == "llama":
275+
model.model._prepare_decoder_attention_mask = _prepare_decoder_attention_mask
276+
269277
traced_model = jit_trace(model, task, use_cache)
270278
save_dir = TemporaryDirectory()
271279
save_dir_path = Path(save_dir.name)

‎optimum/intel/openvino/modeling_decoder.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,8 @@
3030
from optimum.utils import NormalizedConfigManager
3131

3232
from ..utils.import_utils import is_transformers_version
33+
from ..utils.modeling_utils import _prepare_attn_mask, _prepare_decoder_attention_mask
3334
from .modeling import _TOKENIZER_FOR_DOC, INPUTS_DOCSTRING, MODEL_START_DOCSTRING, OVModel
34-
from .modeling_utils import _prepare_attn_mask, _prepare_decoder_attention_mask
3535
from .utils import ONNX_WEIGHTS_NAME
3636

3737

File renamed without changes.

‎tests/generation/test_modeling.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,7 @@ def test_compare_to_transformers(self, model_arch):
6666
with torch.no_grad():
6767
trfs_outputs = trfs_model(**tokens)
6868
# Compare outputs with original transformers model
69-
atol = 1e-1 if model_arch == "bloom" else 1e-4
70-
self.assertTrue(torch.allclose(outputs.logits, trfs_outputs.logits, atol=atol))
69+
self.assertTrue(torch.allclose(outputs.logits, trfs_outputs.logits, atol=1e-4))
7170
# Compare outputs with loaded model
7271
with tempfile.TemporaryDirectory() as tmpdirname:
7372
model.save_pretrained(tmpdirname)

0 commit comments

Comments
 (0)
Please sign in to comment.