Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 6ec426d

Browse files
Committed on Feb 25, 2025
add prepare_jit_inputs
1 parent 6bbd00b commit 6ec426d

File tree

1 file changed

+34
-1
lines changed

1 file changed

+34
-1
lines changed
 

‎optimum/intel/ipex/modeling_base.py

+34-1
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,14 @@
3636
GenerationConfig,
3737
GenerationMixin,
3838
PretrainedConfig,
39+
PreTrainedModel,
3940
)
4041
from transformers.dynamic_module_utils import get_class_from_dynamic_module
4142
from transformers.generation.candidate_generator import _crop_past_key_values
4243
from transformers.modeling_outputs import CausalLMOutputWithPast
4344
from transformers.models.auto.auto_factory import _get_model_class as get_model_class
4445

46+
from optimum.exporters import TasksManager
4547
from optimum.modeling_base import OptimizedModel
4648
from optimum.utils import NormalizedConfigManager
4749

@@ -51,8 +53,9 @@
5153
_IPEX_MINIMUM_VERSION_FOR_PATCHING,
5254
_patch_model,
5355
)
54-
from ..generation.modeling import prepare_jit_inputs
56+
from ..utils.constant import _TASK_ALIASES
5557
from ..utils.import_utils import is_ipex_version, is_transformers_version
58+
from ..utils.modeling_utils import recursive_to_device
5659

5760

5861
logger = logging.getLogger(__name__)
@@ -73,6 +76,36 @@ def _is_patched_with_ipex(model, task, use_cache: bool = True):
7376
return model.config.model_type in _IPEX_SUPPORT_MODEL_TYPES
7477

7578

79+
def get_float_type(model_dtype: torch.dtype):
    """Map a torch dtype to the short precision tag used by the ONNX export config.

    Returns ``"bf16"`` for ``torch.bfloat16``, ``"fp16"`` for ``torch.float16``,
    and ``"fp32"`` for any other dtype.
    """
    # Dict lookup with a default replaces the if/elif/else chain.
    dtype_to_tag = {torch.bfloat16: "bf16", torch.float16: "fp16"}
    return dtype_to_tag.get(model_dtype, "fp32")
86+
87+
88+
def prepare_jit_inputs(model: PreTrainedModel, task: str, use_cache: bool = False):
    """Build example inputs suitable for JIT-tracing *model* on the given *task*.

    Dummy inputs are generated from the model's ONNX export config, filtered
    down to the parameter names accepted by the model's forward signature,
    and moved onto the model's device.
    """
    task = _TASK_ALIASES.get(task, task)
    # Fall back to __call__ when the model exposes no explicit forward method.
    forward = model.forward if hasattr(model, "forward") else model.__call__
    forward_params = inspect.signature(forward).parameters
    config_ctor = TasksManager.get_exporter_config_constructor(model=model, exporter="onnx", task=task)
    float_dtype = get_float_type(model.dtype)
    if "text-generation" not in task:
        onnx_config = config_ctor(model.config)
    else:
        # Generation tasks additionally configure past key/values and the float dtype.
        onnx_config = config_ctor(
            model.config, use_past=use_cache, use_past_in_inputs=use_cache, float_dtype=float_dtype
        )

    dummy_inputs = onnx_config.generate_dummy_inputs(framework="pt")

    # Keep only inputs the forward signature accepts; drop names that are
    # missing or explicitly None in the generated dummy inputs.
    jit_inputs = {}
    for name in forward_params:
        value = dummy_inputs.get(name, None)
        if value is not None:
            jit_inputs[name] = recursive_to_device(value, model.device)
    return jit_inputs
107+
108+
76109
class IPEXModel(OptimizedModel):
77110
auto_model_class = AutoModel
78111
export_feature = "feature-extraction"

0 commit comments

Comments
 (0)
Please sign in to comment.