36
36
GenerationConfig ,
37
37
GenerationMixin ,
38
38
PretrainedConfig ,
39
+ PreTrainedModel ,
39
40
)
40
41
from transformers .dynamic_module_utils import get_class_from_dynamic_module
41
42
from transformers .generation .candidate_generator import _crop_past_key_values
42
43
from transformers .modeling_outputs import CausalLMOutputWithPast
43
44
from transformers .models .auto .auto_factory import _get_model_class as get_model_class
44
45
46
+ from optimum .exporters import TasksManager
45
47
from optimum .modeling_base import OptimizedModel
46
48
from optimum .utils import NormalizedConfigManager
47
49
51
53
_IPEX_MINIMUM_VERSION_FOR_PATCHING ,
52
54
_patch_model ,
53
55
)
54
- from ..generation . modeling import prepare_jit_inputs
56
+ from ..utils . constant import _TASK_ALIASES
55
57
from ..utils .import_utils import is_ipex_version , is_transformers_version
58
+ from ..utils .modeling_utils import recursive_to_device
56
59
57
60
58
61
logger = logging .getLogger (__name__ )
@@ -73,6 +76,36 @@ def _is_patched_with_ipex(model, task, use_cache: bool = True):
73
76
return model .config .model_type in _IPEX_SUPPORT_MODEL_TYPES
74
77
75
78
79
def get_float_type(model_dtype: torch.dtype) -> str:
    """Map a torch dtype to the short precision tag used by the exporter config.

    Args:
        model_dtype: the model's parameter dtype (e.g. ``model.dtype``).

    Returns:
        ``"bf16"`` for ``torch.bfloat16``, ``"fp16"`` for ``torch.float16``,
        and ``"fp32"`` for every other dtype.
    """
    # Any dtype that is not half precision falls back to plain fp32.
    precision_tags = {torch.bfloat16: "bf16", torch.float16: "fp16"}
    return precision_tags.get(model_dtype, "fp32")
87
+
88
def prepare_jit_inputs(model: PreTrainedModel, task: str, use_cache: bool = False):
    """Generate dummy inputs for JIT tracing, restricted to the model's call signature.

    Args:
        model: the model to be traced.
        task: the export task name; aliases are resolved via ``_TASK_ALIASES``.
        use_cache: whether past key/values should be part of the dummy inputs
            (only relevant for text-generation tasks).

    Returns:
        A dict mapping parameter names accepted by the model's forward/call to
        dummy tensors moved onto the model's device.
    """
    task = _TASK_ALIASES.get(task, task)
    # Prefer the explicit forward method; fall back to __call__ for plain callables.
    forward_fn = model.forward if hasattr(model, "forward") else model.__call__
    signature = inspect.signature(forward_fn)
    onnx_config_class = TasksManager.get_exporter_config_constructor(model=model, exporter="onnx", task=task)
    float_dtype = get_float_type(model.dtype)
    if "text-generation" not in task:
        onnx_config = onnx_config_class(model.config)
    else:
        # Generation tasks need past key/values wired into the dummy inputs when caching.
        onnx_config = onnx_config_class(
            model.config, use_past=use_cache, use_past_in_inputs=use_cache, float_dtype=float_dtype
        )

    dummy_inputs = onnx_config.generate_dummy_inputs(framework="pt")

    # Keep only the dummy inputs the model actually accepts, placed on its device.
    traced_inputs = {}
    for name in signature.parameters:
        value = dummy_inputs.get(name, None)
        if value is not None:
            traced_inputs[name] = recursive_to_device(value, model.device)
    return traced_inputs
107
+
108
+
76
109
class IPEXModel (OptimizedModel ):
77
110
auto_model_class = AutoModel
78
111
export_feature = "feature-extraction"
0 commit comments