@@ -605,88 +605,77 @@ inline void tinygemm_kernel(
 //
 void weight_to_int4pack_kernel(
     const Tensor& weight_packed,
-    const Tensor& weight,
-    int N, int K) {
+    const Tensor& weight) {
 
   auto weight_packed_data = reinterpret_cast<uint8_t*>(weight_packed.data_ptr());
-  const auto weight_data = weight.data_ptr<uint8_t>();
+  const auto weight_data = weight.data_ptr<int32_t>();
+
+  int N = weight.size(0);
+  int K = weight.size(1);
 
   // 64 for avx512 and 32 for avx2/non-vectorized
   constexpr int BLOCK_N = vec::Vectorized<float>::size() * 4;
   const int NB = (N + BLOCK_N - 1) / BLOCK_N;
-  int K_div_2 = K / 2;
 
   // parallel on NB blocks
   at::parallel_for(0, NB, 0, [&](int begin, int end) {
     for (const auto i : c10::irange(begin, end)) {
       int nb_size = std::min(BLOCK_N, N - i * BLOCK_N);
 
-      const uint8_t* src = weight_data + i * BLOCK_N * K_div_2;
+      const int32_t* src = weight_data + i * BLOCK_N * K;
       uint8_t* dst = weight_packed_data + i * K * BLOCK_N / 2;
-      for (const auto k : c10::irange(K_div_2)) {
+      for (const auto k : c10::irange(K)) {
 #if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER)
         if (nb_size == BLOCK_N) {
           for (const auto d : c10::irange(16)) {
-            uint8_t val0 = src[(d + 0) * K_div_2 + k];
-            uint8_t val1 = src[(d + 16) * K_div_2 + k];
-            uint8_t val2 = src[(d + 32) * K_div_2 + k];
-            uint8_t val3 = src[(d + 48) * K_div_2 + k];
-
-            uint8_t packed02_0 = (val2 & 0xF0) | ((val0 & 0xF0) >> 4);
-            uint8_t packed13_0 = (val3 & 0xF0) | ((val1 & 0xF0) >> 4);
-            uint8_t packed02_1 = ((val2 & 0xF) << 4) | (val0 & 0xF);
-            uint8_t packed13_1 = ((val3 & 0xF) << 4) | (val1 & 0xF);
-
-            dst[k * 2 * 32 + d] = packed02_0;
-            dst[k * 2 * 32 + 16 + d] = packed13_0;
-            dst[(k * 2 + 1) * 32 + d] = packed02_1;
-            dst[(k * 2 + 1) * 32 + 16 + d] = packed13_1;
+            int32_t val0 = src[(d + 0) * K + k];
+            int32_t val1 = src[(d + 16) * K + k];
+            int32_t val2 = src[(d + 32) * K + k];
+            int32_t val3 = src[(d + 48) * K + k];
+
+            uint8_t packed02 = (((uint8_t)(val2) << 4)) | ((uint8_t)(val0));
+            uint8_t packed13 = (((uint8_t)(val3) << 4)) | ((uint8_t)(val1));
+
+            dst[k * 32 + d] = packed02;
+            dst[k * 32 + 16 + d] = packed13;
           }
         } else {
           // for nb_size 16, 32, 48
           for (int n = 0; n < nb_size; n += 2) {
-            uint8_t val0 = src[n * K_div_2 + k];
-            uint8_t val1 = src[n * K_div_2 + K_div_2 + k];
+            int32_t val0 = src[n * K + k];
+            int32_t val1 = src[n * K + K + k];
 
-            uint8_t packed_0 = ((val1 & 0xF0)) | ((val0 & 0xF0) >> 4);
-            uint8_t packed_1 = ((val1 & 0xF) << 4) | (val0 & 0xF);
-            dst[k * 2 * nb_size / 2 + n / 2] = packed_0;
-            dst[(k * 2 + 1) * nb_size / 2 + n / 2] = packed_1;
+            uint8_t packed = (((uint8_t)(val1) << 4)) | ((uint8_t)(val0));
+            dst[k * nb_size / 2 + n / 2] = packed;
          }
        }
 #elif defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
         if (nb_size == BLOCK_N) {
           // for nb_size 32
           for (const auto d : c10::irange(16)) {
-            uint8_t val0 = src[(d + 0) * K_div_2 + k];
-            uint8_t val1 = src[(d + 16) * K_div_2 + k];
+            int32_t val0 = src[(d + 0) * K + k];
+            int32_t val1 = src[(d + 16) * K + k];
 
-            uint8_t packed01_0 = ((val1 & 0xF0) | ((val0 & 0xF0) >> 4));
-            uint8_t packed01_1 = ((val1 & 0xF) << 4) | (val0 & 0xF);
-            dst[k * 2 * 16 + d] = packed01_0;
-            dst[(k * 2 + 1) * 16 + d] = packed01_1;
+            uint8_t packed01 = (((uint8_t)(val1) << 4)) | ((uint8_t)(val0));
+            dst[k * 16 + d] = packed01;
           }
         } else {
           // for nb_size 16
           for (int n = 0; n < nb_size; n += 2) {
-            int32_t val0 = src[n * K_div_2 + k];
-            int32_t val1 = src[n * K_div_2 + K_div_2 + k];
+            int32_t val0 = src[n * K + k];
+            int32_t val1 = src[n * K + K + k];
 
-            uint8_t packed_0 = ((val1 & 0xF0)) | ((val0 & 0xF0) >> 4);
-            uint8_t packed_1 = ((val1 & 0xF) << 4) | (val0 & 0xF);
-            dst[k * 2 * nb_size / 2 + n / 2] = packed_0;
-            dst[(k * 2 + 1) * nb_size / 2 + n / 2] = packed_1;
+            uint8_t packed = (((uint8_t)(val1) << 4)) | ((uint8_t)(val0));
+            dst[k * nb_size / 2 + n / 2] = packed;
          }
        }
 #else
         for (int n = 0; n < nb_size; n += 2) {
-          uint8_t val0 = src[n * K_div_2 + k];
-          uint8_t val1 = src[n * K_div_2 + K_div_2 + k];
+          int32_t val0 = src[n * K + k];
+          int32_t val1 = src[n * K + K + k];
 
-          uint8_t packed_0 = ((val1 & 0xF0)) | ((val0 & 0xF0) >> 4);
-          uint8_t packed_1 = ((val1 & 0xF) << 4) | (val0 & 0xF);
-          dst[k * 2 * nb_size / 2 + n / 2] = packed_0;
-          dst[(k * 2 + 1) * nb_size / 2 + n / 2] = packed_1;
+          uint8_t packed = (((uint8_t)(val1) << 4)) | ((uint8_t)(val0));
+          dst[k * nb_size / 2 + n / 2] = packed;
         }
 #endif
       }
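
For reference, here is a minimal self-contained sketch of the nibble layout the new code above produces on the scalar fallback path (plain C++, no ATen; the toy shape and values are made up, and it assumes the whole matrix is a single block, i.e. `nb_size == N`, with weight values already in [0, 15]). Each `int32_t` element of the unpacked weight holds one int4 value; two rows of the same column share one output byte, the first in the low nibble and the second in the high nibble.

```cpp
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  // Toy shape: N x K int4 weights, one value per int32_t element (values in [0, 15]).
  constexpr int N = 4;
  constexpr int K = 2;
  const std::vector<int32_t> src = {1, 2,   // row 0
                                    3, 4,   // row 1
                                    5, 6,   // row 2
                                    7, 8};  // row 3

  // Pack: mirrors the scalar path above with nb_size == N,
  //   dst[k * nb_size / 2 + n / 2] = (uint8_t(val1) << 4) | uint8_t(val0);
  std::vector<uint8_t> packed(N * K / 2);
  for (int k = 0; k < K; ++k) {
    for (int n = 0; n < N; n += 2) {
      const int32_t val0 = src[n * K + k];        // src[n * K + k] in the kernel
      const int32_t val1 = src[(n + 1) * K + k];  // src[n * K + K + k] in the kernel
      packed[k * N / 2 + n / 2] =
          (static_cast<uint8_t>(val1) << 4) | static_cast<uint8_t>(val0);
    }
  }

  // Round trip: low nibble is row n, high nibble is row n + 1.
  for (int k = 0; k < K; ++k) {
    for (int n = 0; n < N; n += 2) {
      const uint8_t byte = packed[k * N / 2 + n / 2];
      assert((byte & 0xF) == src[n * K + k]);
      assert((byte >> 4) == src[(n + 1) * K + k]);
    }
  }

  for (uint8_t b : packed) {
    std::printf("0x%02x ", static_cast<unsigned>(b));  // prints: 0x31 0x75 0x42 0x86
  }
  std::printf("\n");
  return 0;
}
```

The AVX512 and AVX2 branches apply the same high/low-nibble rule but pair rows 32 or 16 apart within a block to match their shuffle layout.
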
@@ -700,15 +689,16 @@ void int4pack_mm_kernel_(
     const Tensor& A,
     const Tensor& B,
     int qGroupSize,
-    const Tensor& qScaleAndZeros,
-    int N, int K) {
+    const Tensor& qScaleAndZeros) {
 
   const auto* A_data = A.const_data_ptr<T>();
   const auto* B_data = reinterpret_cast<const uint8_t*>(B.const_data_ptr());
   auto* C_data = C.data_ptr<T>();
   const auto* S_data = qScaleAndZeros.const_data_ptr<T>();
 
   int M = A.size(0);
+  int N = B.size(0);
+  int K = A.size(1);
 
   constexpr int BLOCK_M = 4;
   // 64 for avx512 and 32 for avx2/non-vectorized
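
A rough illustration of why the explicit `N, K` arguments can be dropped here and in the dispatch below: the sizes are recoverable from the tensors themselves, `K` from the activation's second dimension and `N` from the packed weight's first dimension. The standalone sketch that follows is hedged: the sizes are made up, and the `[N, K / 2]` byte layout of the packed weight is an assumption for illustration (the kernels above only rely on it holding N * K / 2 packed bytes).

```cpp
#include <cassert>
#include <cstdio>

// Stand-in for a 2-D tensor shape; the concrete sizes below are made up.
struct Shape2D { long rows, cols; };

int main() {
  const Shape2D A{8, 128};       // activations: M x K
  const Shape2D B{32, 128 / 2};  // packed weights, assumed [N, K / 2] (two int4 values per byte)
  const Shape2D C{8, 32};        // output: M x N

  // Mirrors the diff: M = A.size(0), N = B.size(0), K = A.size(1).
  const long M = A.rows;
  const long N = B.rows;
  const long K = A.cols;

  assert(C.rows == M && C.cols == N);
  assert(B.rows * B.cols == N * K / 2);  // N * K / 2 packed bytes in total
  std::printf("M=%ld N=%ld K=%ld\n", M, N, K);
  return 0;
}
```
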
@@ -762,14 +752,13 @@ void int4pack_mm_kernel(
     const Tensor& A,
     const Tensor& B,
     int qGroupSize,
-    const Tensor& qScaleAndZeros,
-    int N, int K) {
+    const Tensor& qScaleAndZeros) {
   if (C.scalar_type() == kBFloat16) {
-    int4pack_mm_kernel_<BFloat16>(C, A, B, qGroupSize, qScaleAndZeros, N, K);
+    int4pack_mm_kernel_<BFloat16>(C, A, B, qGroupSize, qScaleAndZeros);
   } else if (C.scalar_type() == kHalf) {
-    int4pack_mm_kernel_<Half>(C, A, B, qGroupSize, qScaleAndZeros, N, K);
+    int4pack_mm_kernel_<Half>(C, A, B, qGroupSize, qScaleAndZeros);
   } else {
-    int4pack_mm_kernel_<float>(C, A, B, qGroupSize, qScaleAndZeros, N, K);
+    int4pack_mm_kernel_<float>(C, A, B, qGroupSize, qScaleAndZeros);
   }
 }
 