@@ -129,35 +129,13 @@ def __init__(
129
129
** kwargs ,
130
130
):
131
131
OptimizedModel .__init__ (self , model = model , config = config )
132
- device_map = kwargs .pop ("device_map" , None )
133
132
if device_map is None :
134
133
if is_torch_xpu_available (check_device = True ):
135
134
self ._device = torch .device ("xpu:0" )
136
135
elif torch .cuda .is_available ():
137
136
self ._device = torch .device ("cuda:0" )
138
137
else :
139
138
self ._device = torch .device ("cpu" )
140
- else :
141
- if isinstance (device_map , torch .device ):
142
- self ._device = device_map
143
- elif isinstance (device_map , str ):
144
- if device_map in ["auto" , "balanced" , "balanced_low_0" , "sequential" ]:
145
- raise ValueError (
146
- "When passing device_map as a string, the value needs to be a device name (e.g. cpu, xpu:0). "
147
- f"'auto', 'balanced', 'balanced_low_0', 'sequential' are not supported."
148
- )
149
- self ._device = torch .device (device_map )
150
- elif isinstance (device_map , int ):
151
- if is_torch_xpu_available (check_device = True ):
152
- self ._device = torch .device (f"xpu:{ device_map } " )
153
- elif torch .cuda .is_available ():
154
- self ._device = torch .device (f"cuda:{ device_map } " )
155
- else :
156
- self ._device = torch .device ("cpu" )
157
- else :
158
- raise ValueError (
159
- f"device_map should be either be a string, an integer or a torch.device object, but found { type (device_map )} "
160
- )
161
139
self .model .to (self ._device )
162
140
self ._dtype = self .config .torch_dtype if self .config .torch_dtype is not None else torch .float32
163
141
self .model_save_dir = model_save_dir
0 commit comments