@@ -143,6 +143,9 @@ class OPENVINO_API Constant : public Op {
143
143
case Type_t::u64:
144
144
fill_data<Type_t::u64>(value);
145
145
break ;
146
+ case Type_t::nf4:
147
+ fill_data<Type_t::nf4>(value);
148
+ break ;
146
149
case Type_t::undefined:
147
150
case Type_t::dynamic:
148
151
OPENVINO_THROW (" unsupported type" );
@@ -408,7 +411,7 @@ class OPENVINO_API Constant : public Op {
408
411
template <element::Type_t Type,
409
412
typename StorageDataType = fundamental_type_for<Type>,
410
413
typename std::enable_if<Type != element::Type_t::u1 && Type != element::Type_t::u4 &&
411
- Type != element::Type_t::i4,
414
+ Type != element::Type_t::i4 && Type != element::Type_t::nf4 ,
412
415
bool >::type = true >
413
416
StorageDataType get_element_value (size_t index) const {
414
417
return get_data_ptr<Type>()[index ];
@@ -428,6 +431,13 @@ class OPENVINO_API Constant : public Op {
428
431
return (get_data_ptr<uint8_t >()[index / 2 ] >> (index % 2 ? 0 : 4 )) & 0x0F ;
429
432
}
430
433
434
+ template <element::Type_t Type,
435
+ typename StorageDataType = fundamental_type_for<Type>,
436
+ typename std::enable_if<Type == element::Type_t::nf4, bool >::type = true >
437
+ StorageDataType get_element_value (size_t index) const {
438
+ return (get_data_ptr<uint8_t >()[index / 2 ] >> (index % 2 ? 4 : 0 )) & 0x0F ;
439
+ }
440
+
431
441
template <element::Type_t Type,
432
442
typename StorageDataType = fundamental_type_for<Type>,
433
443
typename std::enable_if<Type == element::Type_t::i4, bool >::type = true >
@@ -554,7 +564,7 @@ class OPENVINO_API Constant : public Op {
554
564
typename T,
555
565
typename StorageDataType = fundamental_type_for<Type>,
556
566
typename std::enable_if<Type != element::Type_t::u1 && Type != element::Type_t::u4 &&
557
- Type != element::Type_t::i4,
567
+ Type != element::Type_t::i4 && Type != element::Type_t::nf4 ,
558
568
bool >::type = true >
559
569
void fill_data (const T& value) {
560
570
#ifdef __clang__
@@ -607,7 +617,9 @@ class OPENVINO_API Constant : public Op {
607
617
template <element::Type_t Type,
608
618
typename T,
609
619
typename StorageDataType = fundamental_type_for<Type>,
610
- typename std::enable_if<Type == element::Type_t::u4 || Type == element::Type_t::i4, bool >::type = true >
620
+ typename std::enable_if<Type == element::Type_t::u4 || Type == element::Type_t::i4 ||
621
+ Type == element::Type_t::nf4,
622
+ bool >::type = true >
611
623
void fill_data (const T& value) {
612
624
uint8_t v = value_in_range<Type>(value);
613
625
v &= 0x0F ;
@@ -640,8 +652,8 @@ class OPENVINO_API Constant : public Op {
640
652
template <element::Type_t Type,
641
653
typename T,
642
654
typename StorageDataType = fundamental_type_for<Type>,
643
- typename std::enable_if<Type != element::Type_t::u1 && Type != element::Type_t::u4 &&
644
- Type != element::Type_t::i4,
655
+ typename std::enable_if<Type != element::Type_t::nf4 && Type != element::Type_t::u1 &&
656
+ Type != element::Type_t::u4 && Type != element::Type_t:: i4,
645
657
bool >::type = true >
646
658
void write_buffer (const std::vector<T>& source) {
647
659
auto p = get_data_ptr_nc<Type>();
@@ -670,6 +682,50 @@ class OPENVINO_API Constant : public Op {
670
682
}
671
683
}
672
684
685
+ template <element::Type_t Type,
686
+ typename T,
687
+ typename StorageDataType = fundamental_type_for<Type>,
688
+ typename std::enable_if<Type == element::Type_t::nf4 && std::is_integral<T>::value, bool >::type = true >
689
+ void write_buffer (const std::vector<T>& source) {
690
+ auto p = get_data_ptr_nc<Type>();
691
+ size_t i = 0 ;
692
+ for (; i < source.size () / 2 ; i++) {
693
+ const auto v1 = value_in_range<Type>(source[i * 2 ]) & 0x0F ;
694
+ const auto v2 = value_in_range<Type>(source[i * 2 + 1 ]) & 0x0F ;
695
+ const auto v = (v2 << 4 ) | v1;
696
+ p[i] = static_cast <StorageDataType>(v);
697
+ }
698
+ if (source.size () % 2 ) {
699
+ const auto v = value_in_range<Type>(source[i * 2 ]) & 0x0F ;
700
+ p[i] = static_cast <StorageDataType>(v);
701
+ }
702
+ }
703
+
704
+ template <element::Type_t Type,
705
+ typename T,
706
+ typename StorageDataType = fundamental_type_for<Type>,
707
+ typename std::enable_if<Type == element::Type_t::nf4 &&
708
+ (std::is_floating_point<T>::value || std::is_same<T, bfloat16>::value ||
709
+ std::is_same<T, float16>::value),
710
+ bool >::type = true >
711
+ void write_buffer (const std::vector<T>& source) {
712
+ auto p = get_data_ptr_nc<Type>();
713
+ size_t i = 0 ;
714
+ for (; i < source.size () / 2 ; i++) {
715
+ const auto idx1 = ConvertNF4::quantize (static_cast <float >(source[i * 2 ]));
716
+ const auto idx2 = ConvertNF4::quantize (static_cast <float >(source[i * 2 + 1 ]));
717
+ const auto v1 = value_in_range<Type>(idx1) & 0x0F ;
718
+ const auto v2 = value_in_range<Type>(idx2) & 0x0F ;
719
+ const auto v = (v2 << 4 ) | v1;
720
+ p[i] = static_cast <StorageDataType>(v);
721
+ }
722
+ if (source.size () % 2 ) {
723
+ const auto idx1 = ConvertNF4::quantize (static_cast <float >(source[i * 2 ]));
724
+ const auto v = value_in_range<Type>(idx1) & 0x0F ;
725
+ p[i] = static_cast <StorageDataType>(v);
726
+ }
727
+ }
728
+
673
729
template <element::Type_t Type,
674
730
typename T,
675
731
typename StorageDataType = fundamental_type_for<Type>,
@@ -755,6 +811,9 @@ class OPENVINO_API Constant : public Op {
755
811
case Type_t::u64:
756
812
write_buffer<Type_t::u64>(source);
757
813
break ;
814
+ case Type_t::nf4:
815
+ write_buffer<Type_t::nf4>(source);
816
+ break ;
758
817
case element::Type_t::undefined:
759
818
case element::Type_t::dynamic:
760
819
OPENVINO_THROW (" unsupported type" );
@@ -765,7 +824,9 @@ class OPENVINO_API Constant : public Op {
765
824
}
766
825
template <ov::element::Type_t Type,
767
826
typename ValueT,
768
- typename std::enable_if<Type == ov::element::Type_t::u4, bool >::type = true >
827
+ typename std::enable_if<Type == ov::element::Type_t::u4 || Type == ov::element::Type_t::u4 ||
828
+ Type == ov::element::Type_t::nf4,
829
+ bool >::type = true >
769
830
static ov::fundamental_type_for<Type> value_in_range (const ValueT& value) {
770
831
const auto result = ov::fundamental_type_for<Type>(value);
771
832
OPENVINO_ASSERT (0 <= result && result <= 15 , " assigned value out of range u4 values" );
0 commit comments