diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu index 17365822e8..409f174a46 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu @@ -938,8 +938,6 @@ Tensor {{ embedding_cuda_op }}( Tensor grad_output_mean; if (static_cast<PoolingMode>(pooling_mode) == PoolingMode::MEAN) { grad_output_mean = at::empty_like(grad_output_reshaped); - {%- if not dense or not vbe %} - #ifdef FBGEMM_GPU_MEMCHECK const auto func_name1 = "grad_mean{{ vdesc }}_kernel"; #endif @@ -965,7 +963,6 @@ Tensor {{ embedding_cuda_op }}( ); C10_CUDA_KERNEL_LAUNCH_CHECK(); - {%- endif %} // if not dense or not vbe grad_output_accessor = MAKE_PTA_WITH_NAME("{{ embedding_cuda_op }}.2", grad_output_mean, grad_t, 2, 64); } diff --git a/fbgemm_gpu/include/fbgemm_gpu/utils/tensor_utils.h b/fbgemm_gpu/include/fbgemm_gpu/utils/tensor_utils.h index 465e0ec4ec..5d66b7730e 100644 @@ -296,7 +296,7 @@ std::string tensor_on_same_gpu_if_not_optional_check( inline at::Tensor aligned_grad_output_tensor_for_cuda_backwards( const at::Tensor& grad_output) { - auto aligned_grad_output = grad_output; + auto aligned_grad_output = at::empty_like(grad_output).copy_(grad_output); // FIXME: to support aligned memory access in Vec4T load/store function // 16 for FP32 and 8 for FP16 if (grad_output.dim() > 1 && @@ -304,9 +304,6 @@ inline at::Tensor aligned_grad_output_tensor_for_cuda_backwards( grad_output.stride(1) != 1 || grad_output.stride(0) % 4 != 0)) { aligned_grad_output = grad_output.contiguous(); } - if (reinterpret_cast<uint64_t>(grad_output.data_ptr()) % 16 != 0) { - aligned_grad_output = at::empty_like(grad_output).copy_(grad_output); - } return aligned_grad_output; } diff --git 
a/fbgemm_gpu/test/tbe/training/backward_dense_test.py b/fbgemm_gpu/test/tbe/training/backward_dense_test.py index 07320fe0d7..4dd6f7c19d 100644 --- a/fbgemm_gpu/test/tbe/training/backward_dense_test.py +++ b/fbgemm_gpu/test/tbe/training/backward_dense_test.py @@ -9,7 +9,6 @@ # pyre-ignore-all-errors[56] -import os import unittest import hypothesis.strategies as st @@ -37,19 +36,12 @@ if open_source: # pyre-ignore[21] - from test_utils import ( - additional_decorators, - gradcheck, - optests, - skipIfRocm, - use_cpu_strategy, - ) + from test_utils import additional_decorators, gradcheck, optests, use_cpu_strategy else: from fbgemm_gpu.test.test_utils import ( additional_decorators, gradcheck, optests, - skipIfRocm, use_cpu_strategy, ) @@ -59,11 +51,6 @@ @optests.generate_opcheck_tests(fast=True, additional_decorators=additional_decorators) class BackwardDenseTest(unittest.TestCase): - @unittest.skipIf( - os.getenv("GITHUB_ENV") is not None, - "This test is currently running into illegal memmory access issues in OSS, and is being investigated; please see https://github.com/pytorch/pytorch/issues/141904.", - ) - @skipIfRocm("Currently runs into memory access issues") @given( T=st.integers(min_value=1, max_value=3), D=st.integers(min_value=2, max_value=128), @@ -330,17 +317,26 @@ def test_backward_dense( # noqa C901 ) y.sum().backward() indice_weight_grad_mask = per_sample_weights.grad.clone().cpu() + if not use_cpu: + torch.cuda.synchronize() + + acc_B = 0 for t in range(T_): B = Bs[t] + table_indice_weight_grad_mask = indice_weight_grad_mask[ + acc_B : acc_B + B * L + ] + table_indice_weight_grad_all = indice_weight_grad_all[acc_B : acc_B + B * L] + acc_B += B * L if feature_requires_grad[t]: torch.testing.assert_close( - indice_weight_grad_mask.view(T_, B, L)[t], - indice_weight_grad_all.view(T_, B, L)[t], + table_indice_weight_grad_mask, + table_indice_weight_grad_all, ) else: torch.testing.assert_close( - indice_weight_grad_mask.view(T_, B, L)[t], - 
torch.zeros_like(indice_weight_grad_mask.view(T_, B, L)[t]), + table_indice_weight_grad_mask, + torch.zeros_like(table_indice_weight_grad_mask), ) per_sample_weights = to_device(xw.contiguous().view(-1), use_cpu)