 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from pathlib import Path
+
 import torch

+from examples.common.sample_config import SampleConfig
+from examples.torch.common.example_logger import logger
 from nncf.api.compression import CompressionAlgorithmController
 from nncf.torch.exporter import count_tensors
 from nncf.torch.exporter import generate_input_names_list
 from nncf.torch.exporter import get_export_args


-def export_model(ctrl: CompressionAlgorithmController, save_path: str, no_strip_on_export: bool) -> None:
+def export_model(ctrl: CompressionAlgorithmController, config: SampleConfig) -> None:
     """
-    Export compressed model. Supported only 'onnx' format.
+    Export the compressed model to ONNX or OpenVINO format.

     :param controller: The compression controller.
-    :param save_path: Path to save onnx file.
-    :param no_strip_on_export: Set to skip strip model before export.
+    :param config: The sample config.
     """
-
-    model = ctrl.model if no_strip_on_export else ctrl.strip()
-
+    model = ctrl.model if config.no_strip_on_export else ctrl.strip()
     model = model.eval().cpu()

     export_args = get_export_args(model, device="cpu")
     input_names = generate_input_names_list(count_tensors(export_args))

-    with torch.no_grad():
-        torch.onnx.export(model, export_args, save_path, input_names=input_names)
+    input_tensor_list = []
+    input_shape_list = []
+    for info in model.nncf.input_infos.elements:
+        input_shape = tuple([1] + info.shape[1:])
+        input_tensor_list.append(torch.rand(input_shape))
+        input_shape_list.append(input_shape)
+
+    if len(input_tensor_list) == 1:
+        input_tensor_list = input_tensor_list[0]
+        input_shape_list = input_shape_list[0]
+
+    model_path = Path(config.export_model_path)
+    model_path.parent.mkdir(exist_ok=True, parents=True)
+    extension = model_path.suffix
+
+    if extension == ".onnx":
+        with torch.no_grad():
+            torch.onnx.export(model, input_tensor_list, model_path, input_names=input_names)
+    elif extension == ".xml":
+        import openvino as ov
+        from openvino.tools.mo import convert_model
+
+        if config.export_to_ir_via_onnx:
+            model_onnx_path = model_path.with_suffix(".onnx")
+            with torch.no_grad():
+                torch.onnx.export(model, input_tensor_list, model_onnx_path, input_names=input_names)
+            ov_model = convert_model(model_onnx_path)
+        else:
+            ov_model = convert_model(model, example_input=input_tensor_list, input_shape=input_shape_list)
+        # Rename input nodes
+        for input_node, input_name in zip(ov_model.inputs, input_names):
+            input_node.node.set_friendly_name(input_name)
+        ov.save_model(ov_model, model_path)
+    else:
+        raise ValueError(f"--export-model-path argument should have suffix `.xml` or `.onnx` but got {extension}")
+    logger.info(f"Saved to {model_path}")
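For context, here is a minimal sketch of how the reworked export_model could be driven end to end. Only the export_model_path, no_strip_on_export, and export_to_ir_via_onnx attributes are taken from this diff; the create_compressed_model setup, the NNCFConfig contents, the module path of export_model, and the way the SampleConfig is populated are assumptions for illustration, not part of this commit:

    # Hypothetical driver; a sketch under the assumptions stated above.
    import torchvision.models as models

    from examples.common.sample_config import SampleConfig
    from examples.torch.common.export import export_model  # assumed location of the function above
    from nncf import NNCFConfig
    from nncf.torch import create_compressed_model

    # Wrap a stock model so that ctrl is a CompressionAlgorithmController.
    nncf_config = NNCFConfig.from_dict(
        {
            "input_info": {"sample_size": [1, 3, 224, 224]},
            "compression": {"algorithm": "quantization"},
        }
    )
    ctrl, compressed_model = create_compressed_model(models.resnet18(), nncf_config)

    # SampleConfig is dict-like in the NNCF examples; attribute assignment is assumed here.
    config = SampleConfig()
    config.export_model_path = "out/resnet18_int8.xml"  # ".xml" -> OpenVINO IR, ".onnx" -> ONNX
    config.no_strip_on_export = False  # strip compression wrappers before export
    config.export_to_ir_via_onnx = False  # convert to IR directly rather than through ONNX

    export_model(ctrl, config)

With a ".xml" suffix the OpenVINO branch is taken, converting either directly from the PyTorch module or, when export_to_ir_via_onnx is set, through an intermediate ONNX file.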