|
26 | 26 | from optimum.exporters import TasksManager
|
27 | 27 | from optimum.exporters.onnx.base import OnnxConfig
|
28 | 28 | from optimum.intel.utils import is_transformers_version
|
| 29 | +from optimum.intel.utils.import_utils import is_safetensors_available |
29 | 30 | from optimum.utils import is_diffusers_available
|
30 | 31 | from optimum.utils.save_utils import maybe_save_preprocessors
|
31 | 32 |
|
@@ -232,6 +233,41 @@ def save_config(config, save_dir):
|
232 | 233 | config.to_json_file(output_config_file, use_diff=True)
|
233 | 234 |
|
234 | 235 |
|
| 236 | +def deduce_diffusers_dtype(model_name_or_path, **loading_kwargs): |
| 237 | + dtype = None |
| 238 | + if is_safetensors_available(): |
| 239 | + if Path(model_name_or_path).is_dir(): |
| 240 | + path = Path(model_name_or_path) |
| 241 | + else: |
| 242 | + from diffusers import DiffusionPipeline |
| 243 | + |
| 244 | + path = DiffusionPipeline.download(model_name_or_path, **loading_kwargs) |
| 245 | + model_part_name = None |
| 246 | + if (path / "transformer").is_dir(): |
| 247 | + model_part_name = "transformer" |
| 248 | + elif (path / "unet").is_dir(): |
| 249 | + model_part_name = "unet" |
| 250 | + if model_part_name: |
| 251 | + directory = path / model_part_name |
| 252 | + safetensors_files = [ |
| 253 | + filename for filename in directory.glob("*.safetensors") if len(filename.suffixes) == 1 |
| 254 | + ] |
| 255 | + safetensors_file = None |
| 256 | + if len(safetensors_files) > 0: |
| 257 | + safetensors_file = safetensors_files.pop(0) |
| 258 | + if safetensors_file: |
| 259 | + from safetensors import safe_open |
| 260 | + |
| 261 | + with safe_open(safetensors_file, framework="pt", device="cpu") as f: |
| 262 | + if len(f.keys()) > 0: |
| 263 | + for key in f.keys(): |
| 264 | + tensor = f.get_tensor(key) |
| 265 | + if tensor.dtype.is_floating_point: |
| 266 | + dtype = tensor.dtype |
| 267 | + break |
| 268 | + return dtype |
| 269 | + |
| 270 | + |
235 | 271 | def save_preprocessors(
|
236 | 272 | preprocessors: List, config: PretrainedConfig, output: Union[str, Path], trust_remote_code: bool
|
237 | 273 | ):
|
|
0 commit comments