Fix dense backward test #92

Open · wants to merge 5 commits into base: abokovoi/upstream
@@ -938,8 +938,6 @@ Tensor {{ embedding_cuda_op }}(
   Tensor grad_output_mean;
   if (static_cast<PoolingMode>(pooling_mode) == PoolingMode::MEAN) {
     grad_output_mean = at::empty_like(grad_output_reshaped);
-    {%- if not dense or not vbe %}
-
 #ifdef FBGEMM_GPU_MEMCHECK
     const auto func_name1 = "grad_mean{{ vdesc }}_kernel";
 #endif
@@ -965,7 +963,6 @@ Tensor {{ embedding_cuda_op }}(
     );

     C10_CUDA_KERNEL_LAUNCH_CHECK();
-    {%- endif %} // if not dense or not vbe

     grad_output_accessor = MAKE_PTA_WITH_NAME("{{ embedding_cuda_op }}.2", grad_output_mean, grad_t, 2, 64);
   }
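Removing the `{%- if not dense or not vbe %}` / `{%- endif %}` guard means the grad_mean{{ vdesc }}_kernel launch is now emitted for the one generated variant the guard previously excluded (dense combined with VBE), so MEAN pooling gets the same gradient rescaling on every code path. As a rough, plain-PyTorch sketch of what that rescaling step computes (the real kernel works on the reshaped gradient and offsets; the function name and shapes below are purely illustrative, not the kernel's actual API):

```python
import torch

def grad_mean_reference(grad_output: torch.Tensor, offsets: torch.Tensor) -> torch.Tensor:
    """Illustrative only: divide each bag's output gradient by its bag length,
    i.e. turn the SUM-pooling gradient into the MEAN-pooling gradient.
    grad_output is [B, D]; offsets is [B + 1] bag boundaries."""
    lengths = (offsets[1:] - offsets[:-1]).clamp(min=1).to(grad_output.dtype)
    return grad_output / lengths.unsqueeze(1)

# Example: three bags of lengths 2, 0, and 4 with embedding dimension 8.
grad_output = torch.randn(3, 8)
offsets = torch.tensor([0, 2, 2, 6])
grad_output_mean = grad_mean_reference(grad_output, offsets)
```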
5 changes: 1 addition & 4 deletions fbgemm_gpu/include/fbgemm_gpu/utils/tensor_utils.h
@@ -296,17 +296,14 @@ std::string tensor_on_same_gpu_if_not_optional_check(

 inline at::Tensor aligned_grad_output_tensor_for_cuda_backwards(
     const at::Tensor& grad_output) {
-  auto aligned_grad_output = grad_output;
+  auto aligned_grad_output = at::empty_like(grad_output).copy_(grad_output);
   // FIXME: to support aligned memory access in Vec4T load/store function
   // 16 for FP32 and 8 for FP16
   if (grad_output.dim() > 1 &&
       (reinterpret_cast<uint64_t>(grad_output.data_ptr()) % 16 != 0 ||
        grad_output.stride(1) != 1 || grad_output.stride(0) % 4 != 0)) {
     aligned_grad_output = grad_output.contiguous();
   }
-  if (reinterpret_cast<uint64_t>(grad_output.data_ptr()) % 16 != 0) {
-    aligned_grad_output = at::empty_like(grad_output).copy_(grad_output);
-  }
   return aligned_grad_output;
 }
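The header change starts `aligned_grad_output` from a freshly allocated copy instead of aliasing `grad_output`, and drops the trailing misalignment fallback. A fresh `at::empty_like(...).copy_(...)` tensor is contiguous and, with PyTorch's caching allocator, aligned well past the 16 bytes the Vec4T load/store path needs. A small Python-side sketch of the alignment issue (assumes a CUDA device is available; the slicing only exists to manufacture a misaligned, non-contiguous view):

```python
import torch

# A sliced view can start mid-allocation, so its data_ptr() need not be 16-byte
# aligned and its strides need not be contiguous.
grad_output = torch.randn(8, 130, device="cuda")[:, 1:]     # starts 4 bytes into the buffer
aligned = torch.empty_like(grad_output).copy_(grad_output)  # fresh, contiguous allocation

print(grad_output.data_ptr() % 16)  # 4 for this FP32 view
print(aligned.data_ptr() % 16)      # 0
print(aligned.is_contiguous())      # True
```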

32 changes: 14 additions & 18 deletions fbgemm_gpu/test/tbe/training/backward_dense_test.py
@@ -9,7 +9,6 @@

# pyre-ignore-all-errors[56]

-import os
 import unittest

import hypothesis.strategies as st
@@ -37,19 +36,12 @@

 if open_source:
     # pyre-ignore[21]
-    from test_utils import (
-        additional_decorators,
-        gradcheck,
-        optests,
-        skipIfRocm,
-        use_cpu_strategy,
-    )
+    from test_utils import additional_decorators, gradcheck, optests, use_cpu_strategy
 else:
     from fbgemm_gpu.test.test_utils import (
         additional_decorators,
         gradcheck,
         optests,
-        skipIfRocm,
         use_cpu_strategy,
     )

@@ -59,11 +51,6 @@

 @optests.generate_opcheck_tests(fast=True, additional_decorators=additional_decorators)
 class BackwardDenseTest(unittest.TestCase):
-    @unittest.skipIf(
-        os.getenv("GITHUB_ENV") is not None,
-        "This test is currently running into illegal memmory access issues in OSS, and is being investigated; please see https://github.com/pytorch/pytorch/issues/141904.",
-    )
-    @skipIfRocm("Currently runs into memory access issues")
     @given(
         T=st.integers(min_value=1, max_value=3),
         D=st.integers(min_value=2, max_value=128),
@@ -330,17 +317,26 @@ def test_backward_dense( # noqa C901
         )
         y.sum().backward()
         indice_weight_grad_mask = per_sample_weights.grad.clone().cpu()
+        if not use_cpu:
+            torch.cuda.synchronize()

+        acc_B = 0
         for t in range(T_):
+            B = Bs[t]
+            table_indice_weight_grad_mask = indice_weight_grad_mask[
+                acc_B : acc_B + B * L
+            ]
+            table_indice_weight_grad_all = indice_weight_grad_all[acc_B : acc_B + B * L]
+            acc_B += B * L
             if feature_requires_grad[t]:
                 torch.testing.assert_close(
-                    indice_weight_grad_mask.view(T_, B, L)[t],
-                    indice_weight_grad_all.view(T_, B, L)[t],
+                    table_indice_weight_grad_mask,
+                    table_indice_weight_grad_all,
                 )
             else:
                 torch.testing.assert_close(
-                    indice_weight_grad_mask.view(T_, B, L)[t],
-                    torch.zeros_like(indice_weight_grad_mask.view(T_, B, L)[t]),
+                    table_indice_weight_grad_mask,
+                    torch.zeros_like(table_indice_weight_grad_mask),
                 )

         per_sample_weights = to_device(xw.contiguous().view(-1), use_cpu)
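On the test side, the rewrite does two things: it synchronizes the device on the GPU path before the gradient comparisons, and it checks the per-sample-weight gradients table by table by slicing the flattened gradient with a running offset, rather than a single `view(T_, B, L)` that assumes every table shares one batch size. A minimal sketch of that slicing logic with toy shapes (the tensor and variable names here are illustrative, not the test's actual data):

```python
import torch

# Toy setup: three tables with per-table batch sizes and L samples per bag.
Bs, L = [2, 4, 3], 5
grad_flat = torch.randn(sum(B * L for B in Bs))  # flattened per-sample-weight grads

acc_B = 0
for B in Bs:
    table_grad = grad_flat[acc_B : acc_B + B * L]  # this table's [B * L] slice
    acc_B += B * L
    assert table_grad.numel() == B * L
```

The added `torch.cuda.synchronize()` simply guarantees that all backward kernels have finished before the host-side `torch.testing.assert_close` comparisons run.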