Skip to content

Commit b6f2e75

Browse files
authored
[Common]Fix Test for SDPA and Concat Unified Scales Test (#3207)
### Changes Changes the model from having the same input tensor to the QKV to having different input for each ### Reason for changes The former model was causing an error with openvino
1 parent eab6b46 commit b6f2e75

File tree

4 files changed

+26
-16
lines changed

4 files changed

+26
-16
lines changed

tests/openvino/native/test_unified_scales.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,16 @@
1717

1818
class TestUnifiedScales(TemplateTestUnifiedScales):
1919
def get_backend_specific_model(self, model: torch.nn.Module) -> ov.Model:
20-
input_shape = model.INPUT_SHAPE
20+
q_input_shape = model.Q_INPUT_SHAPE
21+
kv_input_shape = model.KV_INPUT_SHAPE
22+
2123
backend_model = ov.convert_model(
2224
model,
2325
example_input=(
24-
torch.randn(input_shape),
25-
torch.randn(input_shape),
26+
torch.ones(q_input_shape),
27+
torch.ones(q_input_shape),
28+
torch.ones(kv_input_shape),
29+
torch.ones(kv_input_shape),
2630
),
2731
)
2832

tests/torch/fx/test_unified_scales.py

+7-5
Original file line numberDiff line numberDiff line change
@@ -11,19 +11,21 @@
1111

1212
import torch
1313

14-
from nncf.torch.nncf_network import NNCFNetwork
1514
from tests.cross_fw.test_templates.test_unified_scales import TemplateTestUnifiedScales
1615
from tests.torch.fx.helpers import get_torch_fx_model_q_transformed
1716

1817

1918
class TestUnifiedScales(TemplateTestUnifiedScales):
20-
def get_backend_specific_model(self, model: torch.nn.Module) -> NNCFNetwork:
21-
input_shape = model.INPUT_SHAPE
19+
def get_backend_specific_model(self, model: torch.nn.Module) -> torch.fx.GraphModule:
20+
q_input_shape = model.Q_INPUT_SHAPE
21+
kv_input_shape = model.KV_INPUT_SHAPE
2222
backend_model = get_torch_fx_model_q_transformed(
2323
model,
2424
(
25-
torch.randn(input_shape),
26-
torch.randn(input_shape),
25+
torch.ones(q_input_shape),
26+
torch.ones(q_input_shape),
27+
torch.ones(kv_input_shape),
28+
torch.ones(kv_input_shape),
2729
),
2830
)
2931

tests/torch/quantization/test_unified_scales.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -718,12 +718,15 @@ def test_unified_scales_with_shared_nodes():
718718

719719
class TestUnifiedScales(TemplateTestUnifiedScales):
720720
def get_backend_specific_model(self, model: torch.nn.Module) -> NNCFNetwork:
721-
input_shape = model.INPUT_SHAPE
721+
q_input_shape = model.Q_INPUT_SHAPE
722+
kv_input_shape = model.KV_INPUT_SHAPE
722723
backend_model = wrap_model(
723724
model,
724725
(
725-
torch.randn(input_shape),
726-
torch.randn(input_shape),
726+
torch.ones(q_input_shape),
727+
torch.ones(q_input_shape),
728+
torch.ones(kv_input_shape),
729+
torch.ones(kv_input_shape),
727730
),
728731
trace_parameters=True,
729732
)

tests/torch/test_models/synthetic.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -665,16 +665,17 @@ def forward(self, x):
665665

666666

667667
class ConcatSDPABlock(torch.nn.Module):
668-
INPUT_SHAPE = (2, 10, 6)
668+
Q_INPUT_SHAPE = [2, 10, 6]
669+
KV_INPUT_SHAPE = [2, 10, 12]
669670

670671
def __init__(self):
671672
super().__init__()
672673

673-
def forward(self, x, y):
674+
def forward(self, x, y, z, w):
674675
concatenated_input = torch.cat((x, y), dim=-1)
675676
query = concatenated_input
676-
key = concatenated_input
677-
value = concatenated_input
678-
attn_output = torch.nn.functional.scaled_dot_product_attention(query, key, value, dropout_p=0.2)
677+
key = z
678+
value = w
679+
attn_output = torch.nn.functional.scaled_dot_product_attention(query, key, value)
679680

680681
return attn_output

0 commit comments

Comments
 (0)