|
15 | 15 | import inspect
|
16 | 16 | import logging
|
17 | 17 | import os
|
| 18 | +from collections import deque |
| 19 | +from copy import deepcopy |
| 20 | +from datasets import load_dataset |
18 | 21 | from pathlib import Path
|
19 | 22 | from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Union
|
20 | 23 |
|
|
23 | 26 | import torch
|
24 | 27 | import transformers
|
25 | 28 | from nncf import CompressWeightsMode, IgnoredScope, NNCFConfig, SensitivityMetric
|
| 29 | +from nncf.quantization.advanced_parameters import AdvancedSmoothQuantParameters |
26 | 30 | from nncf.torch import create_compressed_model, register_default_init_args, register_module
|
27 | 31 | from nncf.torch.dynamic_graph.io_handling import wrap_nncf_model_inputs_with_objwalk
|
28 | 32 | from nncf.torch.initialization import PTInitializingDataLoader
|
@@ -577,4 +581,104 @@ def _weight_only_quantization(
|
577 | 581 | # awq=config.quant_method == "awq", # TODO : remove and add it back once nncf v2.9.0
|
578 | 582 | ignored_scope=ignored_scope,
|
579 | 583 | dataset=dataset,
|
| 584 | + subset_size=config.subset_size, |
580 | 585 | )
|
| 586 | + |
| 587 | + |
def _get_operation_const_op(operation, const_port_id: int):
    """Find the Constant node feeding one input port of ``operation``.

    Starting from the producer of input ``const_port_id``, the search follows
    input 0 of each intermediate node for as long as that node's type is one
    of ``Convert``, ``FakeQuantize`` or ``Reshape``.

    Args:
        operation: Graph node exposing ``input_value(port).get_node()``.
        const_port_id: Index of the input port to inspect.

    Returns:
        The first ``Constant`` node reached, or ``None`` when the chain ends
        at a node without inputs or at a node of any other type.
    """
    # Only these op types are looked through on the way to the constant.
    passthrough_types = ("Convert", "FakeQuantize", "Reshape")

    # The walk only ever follows input 0, so it is a single chain and needs
    # no queue.
    node = operation.input_value(const_port_id).get_node()
    while True:
        node_type = node.get_type_name()
        if node_type == "Constant":
            return node
        if not node.inputs() or node_type not in passthrough_types:
            return None
        node = node.input_value(0).get_node()
| 605 | + |
| 606 | + |
def _is_embedding(node) -> bool:
    """Tell whether ``node`` reads a floating-point constant on input 0.

    Returns ``True`` only when the element type of input 0 is ``f16``,
    ``f32`` or ``f64`` and ``_get_operation_const_op`` finds a Constant
    behind that input; otherwise returns ``False``.
    """
    weight_port = 0
    float_type_names = ["f16", "f32", "f64"]

    element_type_name = node.input_value(weight_port).get_element_type().get_type_name()
    if element_type_name not in float_type_names:
        return False

    return _get_operation_const_op(node, weight_port) is not None
| 617 | + |
| 618 | + |
def _collect_ops_with_weights(model):
    """Return the friendly names of the weighted ops found in ``model``.

    An op is collected when it is either:
      * a ``MatMul`` for which ``_get_operation_const_op`` finds a constant
        behind input port 0 or input port 1, or
      * a ``Gather`` accepted by ``_is_embedding``.

    Names are returned in the order ``model.get_ops()`` yields the ops.
    """
    weighted_op_names = []
    for op in model.get_ops():
        op_type = op.get_type_name()
        if op_type == "MatMul":
            has_constant_input = any(
                _get_operation_const_op(op, const_port_id=port) for port in (0, 1)
            )
            if has_constant_input:
                weighted_op_names.append(op.get_friendly_name())
        elif op_type == "Gather" and _is_embedding(op):
            weighted_op_names.append(op.get_friendly_name())

    return weighted_op_names
| 631 | + |
| 632 | + |
def get_stable_diffusion_dataset(
    dataset_name: str, nsamples: int = 50, seed: int = 0, text_column: str = "caption"
) -> "nncf.Dataset":
    """Build an ``nncf.Dataset`` of text samples from a supported dataset.

    The ``train`` split is opened in streaming mode, shuffled with ``seed``,
    truncated to ``nsamples`` entries, and the ``text_column`` field of each
    entry is collected.

    Args:
        dataset_name: One of the supported dataset identifiers listed below.
        nsamples: Number of samples to take from the shuffled stream.
        seed: Seed passed to the dataset shuffle.
        text_column: Key of the text field read from every sample.

    Returns:
        An ``nncf.Dataset`` wrapping the list of collected text entries.

    Raises:
        ValueError: If ``dataset_name`` is not a supported dataset.
    """
    # Single source of truth for both the membership check and the message.
    supported_datasets = (
        "conceptual_captions",
        "laion/220k-GPT4Vision-captions-from-LIVIS",
        "laion/filtered-wit",
    )
    if dataset_name not in supported_datasets:
        raise ValueError(
            "You have entered a string value for dataset. You can only choose between "
            f"{list(supported_datasets)}, but we found {dataset_name}"
        )

    data = load_dataset(dataset_name, split="train", streaming=True).shuffle(seed=seed).take(nsamples)
    dataset = [batch[text_column] for batch in data]
    return nncf.Dataset(dataset)
| 650 | + |
| 651 | + |
def _hybrid_quantization(
    model: openvino.runtime.Model, quantization_config: Union[OVWeightQuantizationConfig, Dict]
):
    """Apply weight compression followed by ``nncf.quantize`` to ``model``.

    Two passes are run:
      1. ``_weight_only_quantization`` without a dataset and with the
         ``Convolution`` type added to the ignored scope.
      2. ``nncf.quantize`` on the compressed model, using the configured
         dataset and ignoring, by name, every op returned by
         ``_collect_ops_with_weights``.

    Args:
        model: Model to quantize.
        quantization_config: Configuration providing ``dataset``,
            ``ignored_scope`` (a dict or ``None``) and ``subset_size``.
            NOTE(review): the annotation also admits ``Dict``, but the values
            are read as attributes - confirm callers always pass a config
            object.

    Returns:
        The model returned by ``nncf.quantize``.

    Raises:
        TypeError: If ``quantization_config.ignored_scope`` is neither a dict
            nor ``None``.
    """
    dataset = quantization_config.dataset
    base_ignored_scope = quantization_config.ignored_scope

    # Explicit check instead of ``assert`` so it still runs under ``python -O``.
    if base_ignored_scope is not None and not isinstance(base_ignored_scope, dict):
        raise TypeError(
            f"`ignored_scope` is expected to be a dict or None, but got {type(base_ignored_scope).__name__}"
        )

    # Ignored scope for the weight-compression pass: user scope + Convolution type.
    wc_ignored_scope = {} if base_ignored_scope is None else deepcopy(base_ignored_scope)
    wc_ignored_scope["types"] = wc_ignored_scope.get("types", []) + ["Convolution"]

    # Ignored scope for the nncf.quantize pass: user scope + collected op names.
    ops_to_compress = _collect_ops_with_weights(model)
    ptq_ignored_scope = {} if base_ignored_scope is None else deepcopy(base_ignored_scope)
    ptq_ignored_scope["names"] = ptq_ignored_scope.get("names", []) + ops_to_compress

    # Weight compression runs without a dataset and with its own ignored
    # scope. The overrides are temporary: the caller's config is restored
    # afterwards, even if the compression step raises.
    quantization_config.dataset = None
    quantization_config.ignored_scope = wc_ignored_scope
    try:
        compressed_model = _weight_only_quantization(model, quantization_config)
    finally:
        quantization_config.dataset = dataset
        quantization_config.ignored_scope = base_ignored_scope

    quantized_model = nncf.quantize(
        compressed_model,
        dataset,
        model_type=nncf.ModelType.TRANSFORMER,
        ignored_scope=nncf.IgnoredScope(**ptq_ignored_scope),
        # Must be passed by keyword: a positional argument would bind to the
        # first field of AdvancedQuantizationParameters instead.
        advanced_parameters=nncf.AdvancedQuantizationParameters(
            smooth_quant_alphas=AdvancedSmoothQuantParameters(matmul=-1)
        ),
        subset_size=quantization_config.subset_size,
    )
    return quantized_model
0 commit comments