Enable OCP FP8 for Latest Archs #111

Open
wants to merge 1 commit into
base: rocm-jaxlib-v0.4.35-qa

Conversation

ScXfjiang

No description provided.

@ScXfjiang ScXfjiang force-pushed the rocm-jaxlib-v0.4.35-qa_enable_ocp_fp8_in_gemm_rewriter branch from 93c9277 to 4d3c1f5 Compare February 20, 2025 19:22
@ScXfjiang ScXfjiang marked this pull request as draft February 20, 2025 19:33
@ScXfjiang ScXfjiang force-pushed the rocm-jaxlib-v0.4.35-qa_enable_ocp_fp8_in_gemm_rewriter branch 6 times, most recently from c6320d5 to 61a5683 Compare February 25, 2025 17:47
@ScXfjiang ScXfjiang force-pushed the rocm-jaxlib-v0.4.35-qa_enable_ocp_fp8_in_gemm_rewriter branch 2 times, most recently from 538aed9 to d20c818 Compare March 5, 2025 19:11
@ScXfjiang ScXfjiang marked this pull request as ready for review March 5, 2025 23:25
@ScXfjiang ScXfjiang force-pushed the rocm-jaxlib-v0.4.35-qa_enable_ocp_fp8_in_gemm_rewriter branch from 646ad3f to a3a64ca Compare March 7, 2025 00:31
@ScXfjiang ScXfjiang force-pushed the rocm-jaxlib-v0.4.35-qa_enable_ocp_fp8_in_gemm_rewriter branch from 9ed9fef to 187ca2e Compare March 7, 2025 15:07
}

// hipBlasLt requires setting the a/b scale pointer (even a dummy one),
// otherwise no algorithms can be found for "a/b scaling". This is to be

Not true. What about d_scale and c_scale?

float elem_a =
__half2float(__nv_cvt_fp8_to_halfraw(buffer_a[idx], __NV_E4M3));
float elem_b =
__half2float(__nv_cvt_fp8_to_halfraw(buffer_b[idx], __NV_E4M3));
#else // TENSORFLOW_USE_ROCM && TF_ROCM_VERSION >= 60300
float elem_a =
__half2float(__hip_cvt_fp8_to_halfraw(buffer_a[idx], __HIP_E4M3));

Don't we have an API that converts directly to float?
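For reference while this is sorted out, here is a minimal host-side sketch of what the conversion computes for the OCP E4M3 encoding (1 sign bit, 4 exponent bits, 3 mantissa bits, bias 7, no infinities, NaN only at S.1111.111). `DecodeE4M3ToFloat` is a hypothetical helper written for illustration; it is not part of XLA, CUDA, or HIP, and it does not answer whether either vendor header offers a direct FP8-to-float intrinsic. It covers only the OCP format, not the FNUZ variant.

```cpp
#include <cmath>
#include <cstdint>
#include <limits>

// Hypothetical helper for illustration only (not an XLA, CUDA, or HIP API).
// Decodes one OCP FP8 E4M3 byte: 1 sign, 4 exponent, 3 mantissa bits, bias 7.
inline float DecodeE4M3ToFloat(uint8_t bits) {
  const float sign = (bits & 0x80) ? -1.0f : 1.0f;
  const int exponent = (bits >> 3) & 0xF;
  const int mantissa = bits & 0x7;

  // OCP E4M3 has no infinities; only S.1111.111 encodes NaN.
  if (exponent == 0xF && mantissa == 0x7) {
    return std::numeric_limits<float>::quiet_NaN();
  }
  if (exponent == 0) {
    // Zero or subnormal: (mantissa / 8) * 2^-6 == mantissa * 2^-9.
    return sign * std::ldexp(static_cast<float>(mantissa), -9);
  }
  // Normal: (1 + mantissa / 8) * 2^(exponent - 7)
  //      == (8 + mantissa) * 2^(exponent - 10).
  return sign * std::ldexp(static_cast<float>(8 + mantissa), exponent - 10);
}
```

A decoder like this could serve as a test oracle for the half-then-float path in the diff, since every E4M3 value is exactly representable in both half and float.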
