# Owner(s): ["module: intel"]


-from torch.testing._internal.common_device_type import instantiate_device_type_tests
-from torch.testing._internal.common_utils import run_tests
+import torch
+import torch._prims as prims
+import torch.utils._pytree as pytree
+import warnings
+from torch.testing._internal.common_device_type import instantiate_device_type_tests, onlyXPU, OpDTypes, ops
+from torch.testing._internal.common_utils import run_tests, slowTest, suppress_warnings

try:
    from xpu_test_utils import XPUPatchForImport
except Exception as e:
    from .xpu_test_utils import XPUPatchForImport
with XPUPatchForImport(False):
-    from test_ops import TestCommon, TestMathBits
+    from test_ops import (
+        _ops_and_refs_with_no_numpy_ref,
+        TestCommon,
+        TestMathBits,
+    )
+
+    # Tests that the CPU and GPU results are consistent
+    # We add logs for the test results to help debugging; we will remove them once the test is stable
+    @onlyXPU
+    @suppress_warnings
+    @slowTest
+    @ops(_ops_and_refs_with_no_numpy_ref, dtypes=OpDTypes.any_common_cpu_cuda_one)
+    def _compare_cpu(self, device, dtype, op):
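+        # Moves tensor arguments to the CPU so the same sample can be replayed there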
+        def to_cpu(arg):
+            if isinstance(arg, torch.Tensor):
+                return arg.to(device="cpu")
+            return arg
+
+        samples = op.reference_inputs(device, dtype)
+        for sample in samples:
+            cpu_sample = sample.transform(to_cpu)
+            cuda_results = op(sample.input, *sample.args, **sample.kwargs)
+            cpu_results = op(cpu_sample.input, *cpu_sample.args, **cpu_sample.kwargs)
+            # output_process_fn_grad has a very unfortunate name
+            # We use this function in linalg extensively to postprocess the inputs of functions
+            # that are not completely well-defined. Think svd and multiplying the singular vectors by -1.
+            # CPU and CUDA implementations of the SVD can return valid SVDs that are different.
+            # We use this function to compare them.
+            cuda_results = sample.output_process_fn_grad(cuda_results)
+            cpu_results = cpu_sample.output_process_fn_grad(cpu_results)
+            # Lower tolerance because we are running this as a `@slowTest`
+            # Don't want the periodic tests to fail frequently
+            try:
+                self.assertEqual(cuda_results, cpu_results, atol=1e-3, rtol=1e-3)
+            except AssertionError as e:
+                raise AssertionError(
+                    f"Failed with {sample.input}, {e}\n"
+                    f"the results are {cuda_results}\n"
+                    f"the expected results are {cpu_results}."
+                )
+
+    # We add logs for the test results to help debugging; we will remove them once the test is stable
+    def _ref_test_helper(
+        self,
+        ctx,
+        device,
+        dtype,
+        op,
+        skip_zero_numel=False,
+        skip_zero_dim=False,
+        skip_bfloat=False,
+        skip_view_consistency=False,
+    ):
+        # NOTE: this test works by comparing the reference results against the
+        # torch (eager) results on the op's reference inputs
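+        # Tracks the last accuracy mismatch; if only accuracy differed, we later check
+        # that the reference was at least as accurate as the torch op before passing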
+        ex = None
+        for sample in op.reference_inputs(device, dtype, requires_grad=False):
+            if (
+                isinstance(sample.input, torch.Tensor)
+                and sample.input.numel() == 0
+                and skip_zero_numel
+            ):
+                continue
+            if (
+                isinstance(sample.input, torch.Tensor)
+                and sample.input.ndim == 0
+                and skip_zero_dim
+            ):
+                continue
+
+            if skip_bfloat and (
+                (
+                    isinstance(sample.input, torch.Tensor)
+                    and sample.input.dtype == torch.bfloat16
+                )
+                or any(
+                    isinstance(arg, torch.Tensor) and arg.dtype == torch.bfloat16
+                    for arg in sample.args
+                )
+            ):
+                continue
+            with ctx():
+                ref_result = op(sample.input, *sample.args, **sample.kwargs)
+            torch_result = op.torch_opinfo(sample.input, *sample.args, **sample.kwargs)
+
+            for a, b in zip(
+                pytree.tree_leaves(ref_result), pytree.tree_leaves(torch_result)
+            ):
+                if isinstance(a, torch.Tensor) or isinstance(b, torch.Tensor):
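+                    # compare_tensor_meta checks that shape, dtype, and other tensor
+                    # metadata agree between the reference and torch results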
+                    prims.utils.compare_tensor_meta(a, b)
+                    if (
+                        getattr(op, "validate_view_consistency", True)
+                        and not skip_view_consistency
+                    ):
+                        msg = (
+                            f"The torch implementation {'returns' if b._is_view() else 'does not return'} "
+                            f"a view, while the reference {'does' if a._is_view() else 'does not'}"
+                        )
+                        try:
+                            self.assertEqual(a._is_view(), b._is_view(), msg)
+                        except AssertionError as e:
+                            raise AssertionError(
+                                f"Failed with {sample.input}, {e}\n"
+                                f"the results are {b}\n"
+                                f"the expected results are {a}."
+                            )
+
+            # Computes the dtype the more precise computation would occur in
+            precise_dtype = torch.bool
+            if prims.utils.is_integer_dtype(dtype):
+                # Note: bool and integer dtypes do not have more
+                # precise dtypes -- they simply must be close
+                precise_dtype = dtype
+            if prims.utils.is_float_dtype(dtype):
+                precise_dtype = torch.double
+            if prims.utils.is_complex_dtype(dtype):
+                precise_dtype = torch.cdouble
+
+            # Checks if the results are close
+            try:
+                self.assertEqual(
+                    ref_result,
+                    torch_result,
+                    exact_stride=False,
+                    exact_device=True,
+                    exact_layout=True,
+                    exact_is_coalesced=True,
+                )
+            except AssertionError as e:
+                # Raises the error if retrying in the more precise dtype would
+                # not change anything (dtype is already the precise dtype)
+                if dtype is precise_dtype:
+                    raise AssertionError(
+                        f"Failed with {sample.input}, {e}\n"
+                        f"the results are {torch_result}\n"
+                        f"the expected results are {ref_result}."
+                    )
+
+                ex = e
+
+            # Goes to next sample if these results are close
+            if not ex:
+                continue
+
+            # If the results are not close, checks that the
+            # reference is more accurate than the torch op
+            def _make_precise(x):
+                if isinstance(x, torch.dtype):
+                    return precise_dtype
+                if isinstance(x, torch.Tensor) and x.dtype is dtype:
+                    return x.to(precise_dtype)
+                return x
+
+            precise_sample = sample.transform(_make_precise)
+            precise_result = op.torch_opinfo(
+                precise_sample.input, *precise_sample.args, **precise_sample.kwargs
+            )
+
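+            # Measures how far two results are apart: the summed element-wise absolute
+            # difference, where positions in which both values are NaN count as equal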
+            def _distance(a, b):
+                # Special-cases boolean comparisons
+                if prims.utils.is_boolean_dtype(a.dtype):
+                    assert b.dtype is torch.bool
+                    return (a ^ b).sum()
+
+                same = a == b
+                if prims.utils.is_float_dtype(a.dtype) or prims.utils.is_complex_dtype(
+                    a.dtype
+                ):
+                    same = torch.logical_or(
+                        same, torch.logical_and(torch.isnan(a), torch.isnan(b))
+                    )
+
+                actual_error = torch.where(same, 0, torch.abs(a - b)).sum()
+                return actual_error
+
+            ref_distance = 0
+            for a, b in zip(
+                pytree.tree_leaves(ref_result), pytree.tree_leaves(precise_result)
+            ):
+                ref_distance = ref_distance + _distance(a, b)
+
+            torch_distance = 0
+            for a, b in zip(
+                pytree.tree_leaves(torch_result), pytree.tree_leaves(precise_result)
+            ):
+                torch_distance = torch_distance + _distance(a, b)
+
+            # TODO: consider adding some tolerance to this comparison
+            msg = (
+                f"Reference result was farther ({ref_distance}) from the precise "
+                f"computation than the torch result was ({torch_distance})!"
+            )
+            try:
+                self.assertTrue(ref_distance <= torch_distance, msg=msg)
+            except AssertionError as e:
+                raise AssertionError(
+                    f"Failed with {sample.input}, {e}\n"
+                    f"the results are {torch_result}\n"
+                    f"the expected results are {precise_result}."
+                )
+
+        # Reports numerical accuracy discrepancies
+        if ex is not None:
+            msg = "Test passed because the reference was more accurate than the torch operator."
+            warnings.warn(msg)
+
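+    # Replace the upstream TestCommon methods so the XPU-instantiated test classes
+    # below pick up these patched versions with the extra failure logging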
+    TestCommon.test_compare_cpu = _compare_cpu
+    TestCommon._ref_test_helper = _ref_test_helper
+
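+# instantiate_device_type_tests generates XPU-specific copies of these test classes
+# (e.g. TestCommonXPU) in this module's namespace so the test runner can discover them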
instantiate_device_type_tests(TestCommon, globals(), only_for="xpu", allow_xpu=True)
instantiate_device_type_tests(TestMathBits, globals(), only_for="xpu", allow_xpu=True)
# in finegrand