
Commit f77eb07

yanbing-j authored and pytorchmergebot committed
Split int4wo weight packing (pytorch#139611)

Fixes pytorch/ao#1117. This PR separates int4wo weight packing between CPU and other devices, to help implement `INT4CPULayout` in torchao based on pytorch/ao#1117 (comment).

Now, for CPU, the input `weight` of `_convert_weight_to_int4pack_for_cpu` is [n, k] int32 and the output is [n, k / 2] uint8. The input packed weight of `_weight_int4pack_mm_for_cpu` is [n, k / 2] uint8.

Pull Request resolved: pytorch#139611
Approved by: https://github.com/jerryzh168
1 parent 7691064 commit f77eb07

10 files changed: +147 -125
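
For context, here is a minimal Python sketch of how the two new CPU ops could be called, based on the schemas and checks in this diff. The concrete sizes (N=32, K=64, qGroupSize=32), the random weight values, and the all-ones scale/zero-point tensor are illustrative only, not taken from this PR's tests, and the snippet assumes a PyTorch build that includes this commit:

import torch

# Illustrative sizes: N must be divisible by 16 and K by 2 (checked in
# _convert_weight_to_int4pack_cpu); qGroupSize must be 32/64/128/256.
N, K, group_size = 32, 64, 32

# Unpacked int4 weight on CPU: [N, K] int32, one value (0..15) per element.
w_int4 = torch.randint(0, 16, (N, K), dtype=torch.int32)

# New CPU packing op: returns a [N, K / 2] uint8 tensor (two int4 per byte).
# innerKTiles is still part of the schema but is no longer validated on CPU.
packed = torch.ops.aten._convert_weight_to_int4pack_for_cpu(w_int4, 2)
assert packed.shape == (N, K // 2) and packed.dtype == torch.uint8

# Illustrative scales/zero-points: a 3D tensor with sizes [K // group_size, N, 2],
# same dtype as the activation.
scale_and_zeros = torch.ones(K // group_size, N, 2, dtype=torch.bfloat16)

# New CPU matmul op: A is [M, K] (bf16/fp16/fp32), B is the packed weight.
A = torch.randn(4, K, dtype=torch.bfloat16)
C = torch.ops.aten._weight_int4pack_mm_for_cpu(A, packed, group_size, scale_and_zeros)
assert C.shape == (4, N)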

aten/src/ATen/native/LinearAlgebra.cpp

+15 -30
@@ -32,15 +32,15 @@
 #else
 #include <ATen/ops/_addmm_activation_native.h>
 #include <ATen/ops/_compute_linear_combination_native.h>
-#include <ATen/ops/_convert_weight_to_int4pack_native.h>
+#include <ATen/ops/_convert_weight_to_int4pack_for_cpu_native.h>
 #include <ATen/ops/_int_mm_native.h>
 #include <ATen/ops/_linalg_check_errors.h>
 #include <ATen/ops/_linalg_det.h>
 #include <ATen/ops/_linalg_det_native.h>
 #include <ATen/ops/_linalg_slogdet.h>
 #include <ATen/ops/_linalg_slogdet_native.h>
 #include <ATen/ops/_unsafe_view.h>
-#include <ATen/ops/_weight_int4pack_mm_native.h>
+#include <ATen/ops/_weight_int4pack_mm_for_cpu_native.h>
 #include <ATen/ops/_weight_int8pack_mm_native.h>
 #include <ATen/ops/abs.h>
 #include <ATen/ops/addbmm_native.h>
@@ -3436,34 +3436,21 @@ Tensor _convert_weight_to_int4pack_cpu(

   TORCH_CHECK(in.dim() == 2,
       __func__, " : expect weight to be 2D tensor.");
-  TORCH_CHECK(in.dtype() == at::kByte,
-      __func__, " : expect weight to be kByte.");
-  TORCH_CHECK(innerKTiles == 2 || innerKTiles == 4 || innerKTiles == 8,
-      __func__, " : innerKTiles need to be 2, 4, or 8, got ", innerKTiles);
+  TORCH_CHECK(in.dtype() == at::kInt,
+      __func__, " : expect weight to be kInt.");

   auto weight = in.contiguous();
   auto N = weight.size(0);
-  auto K = weight.size(1) * 2;
-
-  // Create fake shapes for cpu. The meta registration in dynamo requires
-  // operator has the same output shape for each device. So creating a fake
-  // shape {N / 8, K / (16 * innerKTiles), 32, innerKTiles / 2}
-  constexpr int64_t kNTileSize = 8;
-  constexpr int64_t kKTileSize = 16;
-  auto nTiles = (N + kNTileSize - 1) / kNTileSize;
+  auto K = weight.size(1);

   TORCH_CHECK(N % 16 == 0,
       __func__, " : expect N to be dividable by 16");
-  const int64_t kSuperKTileSize = kKTileSize * innerKTiles;
-  TORCH_CHECK( K % kSuperKTileSize == 0,
-      __func__, " : epxect K to be dividable by ", kSuperKTileSize);
-  auto kSuperTiles = (K + kSuperKTileSize - 1) / kSuperKTileSize;
+  TORCH_CHECK(K % 2 == 0,
+      "_convert_weight_to_int4pack: expect K to be dividable by 2");

-  auto weight_packed = at::empty(
-      {nTiles, kSuperTiles, 32, innerKTiles / 2},
-      at::TensorOptions().dtype(at::kInt));
+  auto weight_packed = at::empty({N, K / 2}, weight.options().dtype(at::kByte));

-  weight_to_int4pack_stub(kCPU, weight_packed, weight, N, K);
+  weight_to_int4pack_stub(kCPU, weight_packed, weight);
   return weight_packed;
 }

@@ -3473,10 +3460,8 @@ Tensor _weight_int4pack_mm_cpu(
     int64_t qGroupSize,
     const Tensor& qScaleAndZeros) {

-  constexpr int64_t kNTileSize = 8;
-
   auto M = A.size(0);
-  auto N = B.size(0) * kNTileSize;
+  auto N = B.size(0);
   auto K = A.size(1);

   TORCH_CHECK(A.dtype() == kBFloat16 || A.dtype() == kHalf || A.dtype() == kFloat,
@@ -3486,12 +3471,12 @@ Tensor _weight_int4pack_mm_cpu(
   TORCH_CHECK(A.dim() == 2,
       __func__, " : expect A to be 2D tensor.");

-  TORCH_CHECK(B.dtype() == kInt,
-      __func__, " : expect B to be int32 tensor.");
+  TORCH_CHECK(B.dtype() == kByte,
+      __func__, " : expect B to be uint8 tensor.");
   TORCH_CHECK(B.is_contiguous(),
       __func__, " : expect B to be contiguous.");
-  TORCH_CHECK(B.dim() == 4,
-      __func__, " : expect B to 4d tensor.");
+  TORCH_CHECK(B.size(1) == K / 2,
+      __func__, " : expect B.size(1) to be K/2, got ", B.size(1));

   TORCH_CHECK(qGroupSize == 32 || qGroupSize == 64 || qGroupSize == 128
       || qGroupSize == 256,
@@ -3502,7 +3487,7 @@ Tensor _weight_int4pack_mm_cpu(
       __func__, ": expect qScaleAndZeros to be 3d tensor with sizes [:, ", N, ", 2]");

   auto C = at::empty({M, N}, A.options());
-  int4pack_mm_stub(kCPU, C, A, B, qGroupSize, qScaleAndZeros, N, K);
+  int4pack_mm_stub(kCPU, C, A, B, qGroupSize, qScaleAndZeros);

   return C;
 }

aten/src/ATen/native/cpu/int4mm_kernel.cpp

+40 -51
@@ -605,88 +605,77 @@ inline void tinygemm_kernel(
 //
 void weight_to_int4pack_kernel(
     const Tensor& weight_packed,
-    const Tensor& weight,
-    int N, int K) {
+    const Tensor& weight) {

   auto weight_packed_data = reinterpret_cast<uint8_t*>(weight_packed.data_ptr());
-  const auto weight_data = weight.data_ptr<uint8_t>();
+  const auto weight_data = weight.data_ptr<int32_t>();
+
+  int N = weight.size(0);
+  int K = weight.size(1);

   // 64 for avx512 and 32 for avx2/non-vectorized
   constexpr int BLOCK_N = vec::Vectorized<float>::size() * 4;
   const int NB = (N + BLOCK_N - 1) / BLOCK_N;
-  int K_div_2 = K / 2;

   // parallel on NB blocks
   at::parallel_for(0, NB, 0, [&](int begin, int end) {
     for (const auto i : c10::irange(begin, end)) {
       int nb_size = std::min(BLOCK_N, N - i * BLOCK_N);

-      const uint8_t* src = weight_data + i * BLOCK_N * K_div_2;
+      const int32_t* src = weight_data + i * BLOCK_N * K;
       uint8_t* dst = weight_packed_data + i * K * BLOCK_N / 2;
-      for (const auto k : c10::irange(K_div_2)) {
+      for (const auto k : c10::irange(K)) {
 #if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER)
         if (nb_size == BLOCK_N) {
           for (const auto d : c10::irange(16)) {
-            uint8_t val0 = src[(d + 0) * K_div_2 + k];
-            uint8_t val1 = src[(d + 16) * K_div_2 + k];
-            uint8_t val2 = src[(d + 32) * K_div_2 + k];
-            uint8_t val3 = src[(d + 48) * K_div_2 + k];
-
-            uint8_t packed02_0 = (val2 & 0xF0) | ((val0 & 0xF0) >> 4);
-            uint8_t packed13_0 = (val3 & 0xF0) | ((val1 & 0xF0) >> 4);
-            uint8_t packed02_1 = ((val2 & 0xF) << 4) | (val0 & 0xF);
-            uint8_t packed13_1 = ((val3 & 0xF) << 4) | (val1 & 0xF);
-
-            dst[k * 2 * 32 + d] = packed02_0;
-            dst[k * 2 * 32 + 16 + d] = packed13_0;
-            dst[(k * 2 + 1) * 32 + d] = packed02_1;
-            dst[(k * 2 + 1) * 32 + 16 + d] = packed13_1;
+            int32_t val0 = src[(d + 0) * K + k];
+            int32_t val1 = src[(d + 16) * K + k];
+            int32_t val2 = src[(d + 32) * K + k];
+            int32_t val3 = src[(d + 48) * K + k];
+
+            uint8_t packed02 = (((uint8_t)(val2) << 4)) | ((uint8_t)(val0));
+            uint8_t packed13 = (((uint8_t)(val3) << 4)) | ((uint8_t)(val1));
+
+            dst[k * 32 + d] = packed02;
+            dst[k * 32 + 16 + d] = packed13;
           }
         } else {
           // for nb_size 16, 32, 48
           for (int n = 0; n < nb_size; n += 2) {
-            uint8_t val0 = src[n * K_div_2 + k];
-            uint8_t val1 = src[n * K_div_2 + K_div_2 + k];
+            int32_t val0 = src[n * K + k];
+            int32_t val1 = src[n * K + K + k];

-            uint8_t packed_0 = ((val1 & 0xF0)) | ((val0 & 0xF0) >> 4);
-            uint8_t packed_1 = ((val1 & 0xF) << 4) | (val0 & 0xF);
-            dst[k * 2 * nb_size / 2 + n / 2] = packed_0;
-            dst[(k * 2 + 1) * nb_size / 2 + n / 2] = packed_1;
+            uint8_t packed = (((uint8_t)(val1) << 4)) | ((uint8_t)(val0));
+            dst[k * nb_size / 2 + n / 2] = packed;
           }
         }
 #elif defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
         if (nb_size == BLOCK_N) {
           // for nb_size 32
           for (const auto d : c10::irange(16)) {
-            uint8_t val0 = src[(d + 0) * K_div_2 + k];
-            uint8_t val1 = src[(d + 16) * K_div_2 + k];
+            int32_t val0 = src[(d + 0) * K + k];
+            int32_t val1 = src[(d + 16) * K + k];

-            uint8_t packed01_0 = ((val1 & 0xF0) | ((val0 & 0xF0) >> 4));
-            uint8_t packed01_1 = ((val1 & 0xF) << 4) | (val0 & 0xF);
-            dst[k * 2 * 16 + d] = packed01_0;
-            dst[(k * 2 + 1) * 16 + d] = packed01_1;
+            uint8_t packed01 = (((uint8_t)(val1) << 4)) | ((uint8_t)(val0));
+            dst[k * 16 + d] = packed01;
           }
         } else {
           // for nb_size 16
           for (int n = 0; n < nb_size; n += 2) {
-            int32_t val0 = src[n * K_div_2 + k];
-            int32_t val1 = src[n * K_div_2 + K_div_2 + k];
+            int32_t val0 = src[n * K + k];
+            int32_t val1 = src[n * K + K + k];

-            uint8_t packed_0 = ((val1 & 0xF0)) | ((val0 & 0xF0) >> 4);
-            uint8_t packed_1 = ((val1 & 0xF) << 4) | (val0 & 0xF);
-            dst[k * 2 * nb_size / 2 + n / 2] = packed_0;
-            dst[(k * 2 + 1) * nb_size / 2 + n / 2] = packed_1;
+            uint8_t packed = (((uint8_t)(val1) << 4)) | ((uint8_t)(val0));
+            dst[k * nb_size / 2 + n / 2] = packed;
           }
         }
 #else
         for (int n = 0; n < nb_size; n += 2) {
-          uint8_t val0 = src[n * K_div_2 + k];
-          uint8_t val1 = src[n * K_div_2 + K_div_2 + k];
+          int32_t val0 = src[n * K + k];
+          int32_t val1 = src[n * K + K + k];

-          uint8_t packed_0 = ((val1 & 0xF0)) | ((val0 & 0xF0) >> 4);
-          uint8_t packed_1 = ((val1 & 0xF) << 4) | (val0 & 0xF);
-          dst[k * 2 * nb_size / 2 + n / 2] = packed_0;
-          dst[(k * 2 + 1) * nb_size / 2 + n / 2] = packed_1;
+          uint8_t packed = (((uint8_t)(val1) << 4)) | ((uint8_t)(val0));
+          dst[k * nb_size / 2 + n / 2] = packed;
         }
 #endif
       }
@@ -700,15 +689,16 @@ void int4pack_mm_kernel_(
     const Tensor& A,
     const Tensor& B,
     int qGroupSize,
-    const Tensor& qScaleAndZeros,
-    int N, int K) {
+    const Tensor& qScaleAndZeros) {

   const auto* A_data = A.const_data_ptr<T>();
   const auto* B_data = reinterpret_cast<const uint8_t*>(B.const_data_ptr());
   auto* C_data = C.data_ptr<T>();
   const auto* S_data = qScaleAndZeros.const_data_ptr<T>();

   int M = A.size(0);
+  int N = B.size(0);
+  int K = A.size(1);

   constexpr int BLOCK_M = 4;
   // 64 for avx512 and 32 for avx2/non-vectorized
@@ -762,14 +752,13 @@ void int4pack_mm_kernel(
     const Tensor& A,
     const Tensor& B,
     int qGroupSize,
-    const Tensor& qScaleAndZeros,
-    int N, int K) {
+    const Tensor& qScaleAndZeros) {
   if (C.scalar_type() == kBFloat16) {
-    int4pack_mm_kernel_<BFloat16>(C, A, B, qGroupSize, qScaleAndZeros, N, K);
+    int4pack_mm_kernel_<BFloat16>(C, A, B, qGroupSize, qScaleAndZeros);
   } else if (C.scalar_type() == kHalf) {
-    int4pack_mm_kernel_<Half>(C, A, B, qGroupSize, qScaleAndZeros, N, K);
+    int4pack_mm_kernel_<Half>(C, A, B, qGroupSize, qScaleAndZeros);
   } else {
-    int4pack_mm_kernel_<float>(C, A, B, qGroupSize, qScaleAndZeros, N, K);
+    int4pack_mm_kernel_<float>(C, A, B, qGroupSize, qScaleAndZeros);
   }
 }
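
The rewritten kernel above collapses the old two-pass high/low-nibble packing into a single byte per pair of int4 values, `packed = (uint8_t)(val1) << 4 | (uint8_t)(val0)`. Below is a standalone Python sketch of that scalar packing step, illustrative only: the real kernel additionally interleaves rows into BLOCK_N-sized groups for the vectorized paths, so the byte order inside the [N, K / 2] buffer is not a plain row-major layout.

def pack_int4_pair(lo: int, hi: int) -> int:
    # Pack two unsigned int4 values (0..15) into one uint8: hi goes to the
    # high nibble, lo to the low nibble, mirroring (val1 << 4) | val0 above.
    return ((hi & 0xF) << 4) | (lo & 0xF)

def unpack_int4_pair(byte: int) -> tuple[int, int]:
    # Inverse of pack_int4_pair: returns (lo, hi).
    return byte & 0xF, (byte >> 4) & 0xF

assert pack_int4_pair(3, 12) == 0xC3
assert unpack_int4_pair(0xC3) == (3, 12)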

aten/src/ATen/native/cpu/int_mm_kernel.h

+2 -2
@@ -5,8 +5,8 @@

 namespace at::native {

-using weight_to_int4pack_fn = void(*)(const Tensor&, const Tensor&, int, int);
-using int4pack_mm_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, int, const Tensor&, int, int);
+using weight_to_int4pack_fn = void(*)(const Tensor&, const Tensor&);
+using int4pack_mm_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, int, const Tensor&);
 using int8pack_mm_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, const Tensor&);

 DECLARE_DISPATCH(weight_to_int4pack_fn, weight_to_int4pack_stub)

aten/src/ATen/native/native_functions.yaml

+10 -2
@@ -4149,16 +4149,24 @@

 - func: _convert_weight_to_int4pack(Tensor self, int innerKTiles) -> Tensor
   dispatch:
-    CPU: _convert_weight_to_int4pack_cpu
     CUDA: _convert_weight_to_int4pack_cuda
     MPS: _convert_weight_to_int4pack_mps

 - func: _weight_int4pack_mm(Tensor self, Tensor mat2, int qGroupSize, Tensor qScaleAndZeros) -> Tensor
   dispatch:
-    CPU: _weight_int4pack_mm_cpu
     MPS: _weight_int4pack_mm_mps
     CUDA: _weight_int4pack_mm_cuda

+# Split int4 pack weight between cpu and other devices due to
+# https://github.com/pytorch/ao/issues/1117#issuecomment-2451252756.
+- func: _convert_weight_to_int4pack_for_cpu(Tensor self, int innerKTiles) -> Tensor
+  dispatch:
+    CPU: _convert_weight_to_int4pack_cpu
+
+- func: _weight_int4pack_mm_for_cpu(Tensor self, Tensor mat2, int qGroupSize, Tensor qScaleAndZeros) -> Tensor
+  dispatch:
+    CPU: _weight_int4pack_mm_cpu
+
 - func: _weight_int8pack_mm(Tensor self, Tensor mat2, Tensor scales) -> Tensor
   dispatch:
     CPU: _weight_int8pack_mm_cpu
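
Because the CPU entries are removed from the dispatch lists of the original ops above, CPU callers (for example a torchao layout) now have to route to the new `_for_cpu` variants themselves. A hedged sketch of what such device-based selection could look like on the caller side; the wrapper name `int4_mm` and its signature are hypothetical, not part of this PR:

import torch

def int4_mm(x: torch.Tensor, packed_w: torch.Tensor, group_size: int,
            scale_and_zeros: torch.Tensor) -> torch.Tensor:
    # Illustrative device-based selection between the split int4 mm ops.
    if packed_w.device.type == "cpu":
        # CPU path: packed_w is the [N, K / 2] uint8 layout produced by
        # _convert_weight_to_int4pack_for_cpu.
        return torch.ops.aten._weight_int4pack_mm_for_cpu(
            x, packed_w, group_size, scale_and_zeros)
    # CUDA/MPS path: packed_w keeps the original tiled layout produced by
    # _convert_weight_to_int4pack.
    return torch.ops.aten._weight_int4pack_mm(
        x, packed_w, group_size, scale_and_zeros)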

test/expect/HasDecompTest.test_has_decomposition.expect

+2
@@ -57,6 +57,7 @@ aten::_convert_indices_from_coo_to_csr.out
 aten::_convert_indices_from_csr_to_coo
 aten::_convert_indices_from_csr_to_coo.out
 aten::_convert_weight_to_int4pack
+aten::_convert_weight_to_int4pack_for_cpu
 aten::_convolution
 aten::_convolution.out
 aten::_copy_from
@@ -637,6 +638,7 @@ aten::_values
 aten::_values_copy
 aten::_values_copy.out
 aten::_weight_int4pack_mm
+aten::_weight_int4pack_mm_for_cpu
 aten::_weight_int8pack_mm
 aten::_weight_norm_interface_backward
 aten::_weight_norm_interface_backward.out
