Commit 9612264

Add log for UT random issues to debug
Signed-off-by: Cheng, Penghui <penghui.cheng@intel.com>
1 parent 04e5a38 commit 9612264

1 file changed (+201 -3)


test/xpu/test_ops_xpu.py (+201 -3)
@@ -1,15 +1,213 @@
 # Owner(s): ["module: intel"]
 
 
-from torch.testing._internal.common_device_type import instantiate_device_type_tests
-from torch.testing._internal.common_utils import run_tests
+import torch
+import torch._prims as prims
+import torch.utils._pytree as pytree
+import warnings
+from torch.testing._internal.common_device_type import instantiate_device_type_tests, onlyXPU, OpDTypes, ops
+from torch.testing._internal.common_utils import run_tests, slowTest, suppress_warnings
 
 try:
     from xpu_test_utils import XPUPatchForImport
 except Exception as e:
     from .xpu_test_utils import XPUPatchForImport
 with XPUPatchForImport(False):
-    from test_ops import TestCommon, TestMathBits
+    from test_ops import (
+        _ops_and_refs_with_no_numpy_ref,
+        TestCommon,
+        TestMathBits,
+    )
+
+# Tests that the cpu and gpu results are consistent
+# We add logs for the test results to help debugging; they will be removed once the test is stable.
+@onlyXPU
+@suppress_warnings
+@slowTest
+@ops(_ops_and_refs_with_no_numpy_ref, dtypes=OpDTypes.any_common_cpu_cuda_one)
+def _compare_cpu(self, device, dtype, op):
+    def to_cpu(arg):
+        if isinstance(arg, torch.Tensor):
+            return arg.to(device="cpu")
+        return arg
+    samples = op.reference_inputs(device, dtype)
+    for sample in samples:
+        cpu_sample = sample.transform(to_cpu)
+        cuda_results = op(sample.input, *sample.args, **sample.kwargs)
+        cpu_results = op(cpu_sample.input, *cpu_sample.args, **cpu_sample.kwargs)
+        # output_process_fn_grad has a very unfortunate name
+        # We use this function in linalg extensively to postprocess the inputs of functions
+        # that are not completely well-defined. Think svd and multiplying the singular vectors by -1.
+        # CPU and CUDA implementations of the SVD can return valid SVDs that are different.
+        # We use this function to compare them.
+        cuda_results = sample.output_process_fn_grad(cuda_results)
+        cpu_results = cpu_sample.output_process_fn_grad(cpu_results)
+        # Lower tolerance because we are running this as a `@slowTest`
+        # Don't want the periodic tests to fail frequently
+        try:
+            self.assertEqual(cuda_results, cpu_results, atol=1e-3, rtol=1e-3)
+        except AssertionError as e:
+            raise AssertionError(f"Failed with {sample.input}, {e} \
+\nthe results are {cuda_results} \nthe expected results are {cpu_results}.")
+
+# We add logs for the test results to help debugging; they will be removed once the test is stable.
+def _ref_test_helper(
+    self,
+    ctx,
+    device,
+    dtype,
+    op,
+    skip_zero_numel=False,
+    skip_zero_dim=False,
+    skip_bfloat=False,
+    skip_view_consistency=False,
+):
+    # NOTE: this test works by comparing the reference
+    ex = None
+    for sample in op.reference_inputs(device, dtype, requires_grad=False):
+        if (
+            isinstance(sample.input, torch.Tensor)
+            and sample.input.numel() == 0
+            and skip_zero_numel
+        ):
+            continue
+        if (
+            isinstance(sample.input, torch.Tensor)
+            and sample.input.ndim == 0
+            and skip_zero_dim
+        ):
+            continue
+
+        if skip_bfloat and (
+            (
+                isinstance(sample.input, torch.Tensor)
+                and sample.input.dtype == torch.bfloat16
+            )
+            or any(
+                isinstance(arg, torch.Tensor) and arg.dtype == torch.bfloat16
+                for arg in sample.args
+            )
+        ):
+            continue
+        with ctx():
+            ref_result = op(sample.input, *sample.args, **sample.kwargs)
+        torch_result = op.torch_opinfo(sample.input, *sample.args, **sample.kwargs)
+
+        for a, b in zip(
+            pytree.tree_leaves(ref_result), pytree.tree_leaves(torch_result)
+        ):
+            if isinstance(a, torch.Tensor) or isinstance(b, torch.Tensor):
+                prims.utils.compare_tensor_meta(a, b)
+                if (
+                    getattr(op, "validate_view_consistency", True)
+                    and not skip_view_consistency
+                ):
+                    msg = (
+                        f"The torch implementation {'returns' if b._is_view() else 'does not return'} "
+                        f"a view, while the reference {'does' if a._is_view() else 'does not'}"
+                    )
+                    try:
+                        self.assertEqual(a._is_view(), b._is_view(), msg)
+                    except AssertionError as e:
+                        raise AssertionError(f"Failed with {sample.input}, {e} \
+\nthe results are {b} \nthe expected results are {a}.")
+
+        # Computes the dtype the more precise computation would occur in
+        precise_dtype = torch.bool
+        if prims.utils.is_integer_dtype(dtype):
+            # Note: bool and integer dtypes do not have more
+            # precise dtypes -- they simply must be close
+            precise_dtype = dtype
+        if prims.utils.is_float_dtype(dtype):
+            precise_dtype = torch.double
+        if prims.utils.is_complex_dtype(dtype):
+            precise_dtype = torch.cdouble
+
+        # Checks if the results are close
+        try:
+            self.assertEqual(
+                ref_result,
+                torch_result,
+                exact_stride=False,
+                exact_device=True,
+                exact_layout=True,
+                exact_is_coalesced=True,
+            )
+        except AssertionError as e:
+            # Raises the error if the precise dtype comparison wouldn't be
+            # different
+            if dtype is precise_dtype:
+                raise AssertionError(f"Failed with {sample.input}, {e} \
+\nthe results are {torch_result} \nthe expected results are {ref_result}.")
+
+            ex = e
+
+        # Goes to next sample if these results are close
+        if not ex:
+            continue
+
+        # If the results are not close, checks that the
+        # reference is more accurate than the torch op
+        def _make_precise(x):
+            if isinstance(x, torch.dtype):
+                return precise_dtype
+            if isinstance(x, torch.Tensor) and x.dtype is dtype:
+                return x.to(precise_dtype)
+            return x
+
+        precise_sample = sample.transform(_make_precise)
+        precise_result = op.torch_opinfo(
+            precise_sample.input, *precise_sample.args, **precise_sample.kwargs
+        )
+
+        def _distance(a, b):
+            # Special-cases boolean comparisons
+            if prims.utils.is_boolean_dtype(a.dtype):
+                assert b.dtype is torch.bool
+                return (a ^ b).sum()
+
+            same = a == b
+            if prims.utils.is_float_dtype(a.dtype) or prims.utils.is_complex_dtype(
+                a.dtype
+            ):
+                same = torch.logical_or(
+                    same, torch.logical_and(torch.isnan(a), torch.isnan(b))
+                )
+
+            actual_error = torch.where(same, 0, torch.abs(a - b)).sum()
+            return actual_error
+
+        ref_distance = 0
+        for a, b in zip(
+            pytree.tree_leaves(ref_result), pytree.tree_leaves(precise_result)
+        ):
+            ref_distance = ref_distance + _distance(a, b)
+
+        torch_distance = 0
+        for a, b in zip(
+            pytree.tree_leaves(torch_result), pytree.tree_leaves(precise_result)
+        ):
+            torch_distance = torch_distance + _distance(a, b)
+
+        # TODO: consider adding some tolerance to this comparison
+        msg = (
+            f"Reference result was farther ({ref_distance}) from the precise "
+            f"computation than the torch result was ({torch_distance})!"
+        )
+        try:
+            self.assertTrue(ref_distance <= torch_distance, msg=msg)
+        except AssertionError as e:
+            raise AssertionError(f"Failed with {sample.input}, {e} \
+\nthe results are {torch_result} \nthe expected results are {precise_result}.")
+
+    # Reports numerical accuracy discrepancies
+    if ex is not None:
+        msg = "Test passed because the reference was more accurate than the torch operator."
+        warnings.warn(msg)
+
+TestCommon.test_compare_cpu = _compare_cpu
+TestCommon._ref_test_helper = _ref_test_helper
+
 instantiate_device_type_tests(TestCommon, globals(), only_for="xpu", allow_xpu=True)
 instantiate_device_type_tests(TestMathBits, globals(), only_for="xpu", allow_xpu=True)
 # in finegrand
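
For readers skimming the diff: the "log" this commit adds is simply a try/except around each assertion that re-raises the AssertionError with the failing sample input and both results appended to the message, so intermittent ("random") failures can be reproduced from CI logs. Below is a minimal standalone sketch of that pattern, not part of the commit; assert_close_with_context and the example tensor are hypothetical names used only for illustration.

import torch

def assert_close_with_context(actual, expected, sample_input, atol=1e-3, rtol=1e-3):
    # Wrap the comparison and re-raise with the sample input and both results,
    # mirroring the logging pattern applied throughout this diff.
    try:
        torch.testing.assert_close(actual, expected, atol=atol, rtol=rtol)
    except AssertionError as e:
        raise AssertionError(
            f"Failed with {sample_input}, {e}"
            f"\nthe results are {actual}\nthe expected results are {expected}."
        )

if __name__ == "__main__":
    x = torch.randn(4)
    # Passes silently; on failure the error message would include x and both outputs.
    assert_close_with_context(x.exp(), torch.exp(x), sample_input=x)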

0 commit comments
