[MinMax] Embedding nodes as input nodes for inference graph #3320

Open
wants to merge 6 commits into develop from dl/shape_of_sub_emb_fix

Conversation

@daniil-lyakhov daniil-lyakhov commented Feb 28, 2025

Reopen of #2862

Changes

  • Embedding nodes are used as input nodes for the inference graph, so embedding nodes are now included in the inference_nncf_graph
  • inference_nncf_graph is used to identify weighted nodes
  • The PT/FX MinMax get_weight_nodes method is updated to work with the inference graph (see the sketch after this list)
  • Constant folding is removed from the OpenVINOQuantizer and the FX nncf.quantize implementation
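
A minimal sketch of how the weight-node lookup on the constant-free inference graph could look, reconstructed from the diff hunks quoted later in this conversation; the imports and the loop structure are assumptions for illustration, not the PR's verbatim implementation:

from typing import List

from nncf.common.graph import NNCFGraph
from nncf.common.graph import NNCFNode
import nncf.torch.graph.operator_metatypes as om


def get_weight_nodes_in_inference_graph(
    inference_nncf_graph: NNCFGraph, mat_mul_metatypes: List[om.PTOperatorMetatype]
) -> List[NNCFNode]:
    # The inference graph contains no constant nodes, so a matmul-like node with
    # fewer input edges than expected weight ports is consuming a constant
    # (weight) on the missing port and is treated as a weighted node.
    weight_nodes = []
    for node in inference_nncf_graph.get_all_nodes():
        is_matmul_like = node.metatype in mat_mul_metatypes
        has_constant_input = len(inference_nncf_graph.get_input_edges(node)) < len(
            node.metatype.weight_port_ids
        )
        if is_matmul_like and has_constant_input:
            weight_nodes.append(node)
    return weight_nodes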

Reason for changes

Related tickets

163025

Tests

  • tests/cross_fw/test_templates/test_quantizer_config.py is updated with a shape_of/constant embedding model and a conv model with constant branches
  • TorchFX reference graphs for ViT and Swin were updated: constant branches are present in the quantized graph, but they do not contain quantizers
  • conformance test post_training_quantization/625/ passed

@github-actions github-actions bot added the NNCF PT, NNCF OpenVINO, NNCF ONNX, NNCF PTQ, and experimental labels Feb 28, 2025
@daniil-lyakhov daniil-lyakhov force-pushed the dl/shape_of_sub_emb_fix branch from 9689dca to 0a2f240 Compare February 28, 2025 11:36
@daniil-lyakhov daniil-lyakhov marked this pull request as ready for review February 28, 2025 13:28
@daniil-lyakhov daniil-lyakhov requested a review from a team as a code owner February 28, 2025 13:28
@@ -50,6 +51,7 @@
from nncf.torch.quantization.layers import BaseQuantizer
from nncf.torch.quantization.layers import PTQuantizerSpec
from nncf.torch.quantization.layers import get_scale_shape
from nncf.torch.utils import get_weight_nodes_in_inference_grpah

Suggested change
from nncf.torch.utils import get_weight_nodes_in_inference_grpah
from nncf.torch.utils import get_weight_nodes_in_inference_graph

@@ -467,3 +470,46 @@ def get_model_dtype(model: torch.nn.Module) -> torch.dtype:
# The model had no parameters at all, assume FP32
dtype = torch.float32
return dtype


def get_weight_nodes_in_inference_grpah(

Suggested change
def get_weight_nodes_in_inference_grpah(
def get_weight_nodes_in_inference_graph(

@@ -86,7 +86,6 @@ def quantize_impl(
advanced_parameters=advanced_parameters,
)

# To make it easier for bias correction algorithms.

Why was this comment removed?


# Inference graph does not contain constants, so
# any missed input edge means it is a constant branch.
return node.metatype in [om.PTMatMulMetatype, om.PTAddmmMetatype] and len(

Please use variables to make it more readable

# any missed input edge means it is a constant branch.
return node.metatype in [om.PTMatMulMetatype, om.PTAddmmMetatype] and len(
inference_nncf_graph.get_input_edges(node)
) < len(node.metatype.weight_port_ids)
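
For illustration, one way the quoted predicate could be split into named variables (a sketch of the requested refactor, not the change committed in this PR):

# The inference graph has no constant nodes, so a missing input edge on a
# matmul-like node indicates a constant (weight) branch.
is_matmul_like = node.metatype in [om.PTMatMulMetatype, om.PTAddmmMetatype]
num_input_edges = len(inference_nncf_graph.get_input_edges(node))
num_weight_ports = len(node.metatype.weight_port_ids)
return is_matmul_like and num_input_edges < num_weight_ports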

Does this work for the possibly missed inputs that are determined in get_nodes_with_missed_input_edges?



def get_weight_nodes_in_inference_grpah(
inference_nncf_graph: NNCFGraph, mat_mul_metatypes: List[om.PTOperatorMetatype]

Looks like mat_mul_metatypes always receives the same value, so it should not be passed as an argument; use a reusable constant variable instead.
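
For example, the metatype list could live in a module-level constant and be dropped from the signature (a sketch; the constant name is a hypothetical choice, not from this PR):

# Hypothetical module-level constant replacing the mat_mul_metatypes argument.
MATMUL_METATYPES = [om.PTMatMulMetatype, om.PTAddmmMetatype]


def get_weight_nodes_in_inference_graph(inference_nncf_graph: NNCFGraph) -> List[NNCFNode]:
    # ... same body as before, using MATMUL_METATYPES instead of the argument ...
    ...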

Co-authored-by: Alexander Dokuchaev <alexander.dokuchaev@intel.com>
@github-actions github-actions bot added the API (Public API-impacting changes) label Mar 17, 2025