diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 51fa214f73..6c00d4bf0f 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -147,6 +147,10 @@ jobs:
           - rust: prev
             toolchain: ${{ needs.prepare-checks.outputs.rust-prev-version }}
     steps:
+      # disable incremental compilation (reduces artifact size)
+      - name: Set CI Profile
+        run: echo "CARGO_PROFILE_TEST_INCREMENTAL=false" >> $GITHUB_ENV
+      # --------------------------------------------------------------------------------
       - name: Setup Rust
         uses: tracel-ai/github-actions/setup-rust@v1
         with:
diff --git a/Cargo.lock b/Cargo.lock
index 9796224469..3b9f26803d 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -944,6 +944,7 @@ version = "0.17.0"
 dependencies = [
  "proc-macro2",
  "quote",
+ "syn 2.0.98",
 ]

 [[package]]
diff --git a/burn-book/src/quantization.md b/burn-book/src/quantization.md
index 5b9995caab..2061b4a2d9 100644
--- a/burn-book/src/quantization.md
+++ b/burn-book/src/quantization.md
@@ -45,12 +45,12 @@ tensors and can collect their statistics, such as the min and max value when usi
 ```rust , ignore
 # use burn::module::Quantizer;
-# use burn::tensor::quantization::{MinMaxCalibration, QuantizationScheme, QuantizationType};
+# use burn::tensor::quantization::{Calibration, QuantizationMode, QuantizationScheme, QuantizationType};
 #
 // Quantization config
 let mut quantizer = Quantizer {
-    calibration: MinMaxCalibration {},
-    scheme: QuantizationScheme::PerTensorSymmetric(QuantizationType::QInt8),
+    calibration: Calibration::MinMax,
+    scheme: QuantizationScheme::PerTensor(QuantizationMode::Symmetric, QuantizationType::QInt8),
 };

 // Quantize the weights
@@ -95,9 +95,9 @@ _quantization-time_ (weights are static), but activations require more attention
 To compute the quantization parameters, Burn supports the following `Calibration` methods.

-| Method              | Description                                                                       |
-| :------------------ | :-------------------------------------------------------------------------------- |
-| `MinMaxCalibration` | Computes the quantization range mapping based on the running min and max values.  |
+| Method   | Description                                                                       |
+| :------- | :-------------------------------------------------------------------------------- |
+| `MinMax` | Computes the quantization range mapping based on the running min and max values.  |

 ### Quantization Scheme
@@ -116,7 +116,23 @@ channel with per-channel quantization (commonly used with CNNs).

 Burn currently supports the following `QuantizationScheme` variants.

-| Variant              | Description                                                                                                     |
-| :------------------- | :-------------------------------------------------------------------------------------------------------------- |
-| `PerTensorAffine`    | Computes the quantization parameters for the whole tensor and applies an affine range mapping with zero point.  |
-| `PerTensorSymmetric` | Computes the quantization parameters for the whole tensor and applies a scale range mapping centered around 0.  |
+| Variant                        | Description                                                                                                                                                               |
+| :----------------------------- | :----------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
+| `PerTensor(mode, type)`        | Applies a single set of quantization parameters to the entire tensor. The `mode` defines how values are transformed, and `type` represents the target quantization type.  |
+| `PerBlock(mode, type, layout)` | Applies quantization parameters to individual blocks within the tensor. The `layout` defines how the tensor is partitioned.                                               |
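For example, the two scheme families can be constructed as follows (an illustrative sketch based on the variants documented above, not part of the patch itself):

```rust , ignore
use burn::tensor::quantization::{
    BlockLayout, QuantizationMode, QuantizationScheme, QuantizationType,
};

// A single set of quantization parameters for the whole tensor.
let per_tensor =
    QuantizationScheme::PerTensor(QuantizationMode::Affine, QuantizationType::QInt8);

// One set of parameters per contiguous block of 8 elements.
let per_block = QuantizationScheme::PerBlock(
    QuantizationMode::Symmetric,
    QuantizationType::QInt8,
    BlockLayout::Flat(8),
);
```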
+
+#### Quantization Mode
+
+| Mode        | Description                                                           |
+| ----------- | --------------------------------------------------------------------- |
+| `Affine`    | Maps values using an affine transformation with a zero point offset.  |
+| `Symmetric` | Maps values using a scale factor for a range centered around zero.    |
+
+---
+
+#### Block Layout
+
+| Layout             | Description                                               |
+| ------------------ | --------------------------------------------------------- |
+| `Flat(block_size)` | Divides the tensor into linear 1D blocks of fixed size.   |
+| `Grid(m, n)`       | Divides the tensor into 2D blocks of `m` x `n` elements.  |
diff --git a/crates/burn-candle/src/backend.rs b/crates/burn-candle/src/backend.rs
index f8b1f10d94..9cc5d83306 100644
--- a/crates/burn-candle/src/backend.rs
+++ b/crates/burn-candle/src/backend.rs
@@ -2,7 +2,7 @@ use std::marker::PhantomData;

 use burn_tensor::{
     backend::{Backend, DeviceId, DeviceOps},
-    quantization::{QTensorPrimitive, QuantizationStrategy},
+    quantization::QTensorPrimitive,
     Device,
 };
 use candle_core::{backend::BackendDevice, DeviceLocation};
diff --git a/crates/burn-candle/src/ops/qtensor.rs b/crates/burn-candle/src/ops/qtensor.rs
index f4c2e96f04..ee440f7063 100644
--- a/crates/burn-candle/src/ops/qtensor.rs
+++ b/crates/burn-candle/src/ops/qtensor.rs
@@ -3,7 +3,7 @@ use std::ops::Range;
 use burn_tensor::{
     backend::Backend,
     ops::{FloatTensor, IntTensor, QTensorOps, QuantizedTensor},
-    quantization::{QuantizationParametersPrimitive, QuantizationScheme, QuantizationStrategy},
+    quantization::{QuantizationParametersPrimitive, QuantizationScheme},
     DType, Device, Shape, TensorData,
 };
diff --git a/crates/burn-candle/src/tensor.rs b/crates/burn-candle/src/tensor.rs
index d038624764..8927ed4eb7 100644
--- a/crates/burn-candle/src/tensor.rs
+++ b/crates/burn-candle/src/tensor.rs
@@ -1,5 +1,5 @@
 use burn_tensor::{
-    quantization::{QTensorPrimitive, QuantizationScheme, QuantizationStrategy},
+    quantization::{QTensorPrimitive, QuantizationScheme},
     DType, Element, Shape, TensorData, TensorMetadata,
 };
diff --git a/crates/burn-core/src/module/base.rs b/crates/burn-core/src/module/base.rs
index 03ab4def79..99c79fee0a 100644
--- a/crates/burn-core/src/module/base.rs
+++ b/crates/burn-core/src/module/base.rs
@@ -5,7 +5,7 @@ use crate::{
 };
 use alloc::vec::Vec;
 pub use burn_derive::Module;
-use burn_tensor::{ops::Device, quantization::Calibration, Bool, Int, Tensor};
+use burn_tensor::{ops::Device, Bool, Int, Tensor};

 /// Type alias to `Vec` which supports `no_std` environments, but automatically using
 /// the `alloc` crate.
@@ -204,7 +204,7 @@ pub trait Module<B: Backend>: Clone + Send + core::fmt::Debug {
     }

     /// Quantize the weights of the module.
-    fn quantize_weights<C: Calibration>(self, quantizer: &mut Quantizer<C>) -> Self {
+    fn quantize_weights(self, quantizer: &mut Quantizer) -> Self {
         self.map(quantizer)
     }
 }
diff --git a/crates/burn-core/src/module/quantize.rs b/crates/burn-core/src/module/quantize.rs
index e800a27967..75af12781d 100644
--- a/crates/burn-core/src/module/quantize.rs
+++ b/crates/burn-core/src/module/quantize.rs
@@ -7,16 +7,16 @@ use burn_tensor::{
 use crate::module::{ModuleMapper, ParamId};

 /// Describes how to quantize a module.
-pub struct Quantizer<C: Calibration> {
+pub struct Quantizer {
     /// The calibration method used in quantization.
-    pub calibration: C,
+    pub calibration: Calibration,
     /// The quantization scheme.
     pub scheme: QuantizationScheme,
 }

-impl<B: Backend, C: Calibration> ModuleMapper<B> for Quantizer<C> {
+impl<B: Backend> ModuleMapper<B> for Quantizer {
     fn map_float<const D: usize>(&mut self, _id: ParamId, tensor: Tensor<B, D>) -> Tensor<B, D> {
-        let range = self.calibration.compute_range(&tensor);
+        let range = self.scheme.compute_range(&tensor, &self.calibration);
         let qparams = self.scheme.compute_q_params(range);
         tensor.quantize(&self.scheme, qparams)
     }
 }
diff --git a/crates/burn-cubecl/src/kernel/quantization/dequantize.rs b/crates/burn-cubecl/src/kernel/quantization/dequantize.rs
index c318baa276..6be8eb3c52 100644
--- a/crates/burn-cubecl/src/kernel/quantization/dequantize.rs
+++ b/crates/burn-cubecl/src/kernel/quantization/dequantize.rs
@@ -1,7 +1,8 @@
 use crate::tensor::CubeTensor;
-use crate::FloatElement;
 use crate::{CubeElement, CubeRuntime};
-use burn_tensor::quantization::{QuantizationScheme, QuantizationType};
+use burn_tensor::quantization::{
+    BlockLayout, QuantizationMode, QuantizationScheme, QuantizationType,
+};
 use burn_tensor::DType;
 use cubecl::calculate_cube_count_elemwise;
 use cubecl::prelude::*;
@@ -9,17 +10,19 @@ use cubecl::prelude::*;
 use super::{QParams, QTensor};

 #[cube]
-pub(crate) fn dequantize_affine_int8<F: Float>(
-    value: Line<i32>,
-    scale: f32,
-    offset: i32,
-) -> Line<F> {
+fn dequantize_affine_int8<F: Float>(value: Line<i32>, scale: f32, offset: i32) -> Line<F> {
     // x = scale * (x_q - offset)
     Line::cast_from(scale) * Line::cast_from(value - Line::cast_from(offset))
 }

 #[cube]
-pub(crate) fn extract_i8(value: u32, offset: u32) -> i32 {
+fn dequantize_symmetric_int8<F: Float>(value: Line<i32>, scale: f32) -> Line<F> {
+    // x = scale * x_q
+    Line::cast_from(scale) * Line::cast_from(value)
+}
+
+#[cube]
+fn extract_i8(value: u32, offset: u32) -> i32 {
     // Extract 8-bit segment
     let value = (value >> offset) & 0xFF;
     // Check if the value is negative by inspecting the MSB and subtract 256 if it is
@@ -29,7 +32,7 @@ pub(crate) fn extract_i8(value: u32, offset: u32) -> i32 {
 }

 #[cube]
-pub(crate) fn extract_i8s(value: u32) -> Line<i32> {
+fn unpack_i8s(value: u32) -> Line<i32> {
     let mut line = Line::empty(4);
     // Extract each 8-bit segment
     line[0] = extract_i8(value, 0);
@@ -41,7 +44,7 @@ pub(crate) fn extract_i8s(value: u32) -> Line<i32> {
 }

 #[cube(launch_unchecked)]
-pub(crate) fn dequantize_per_tensor_affine_int8_kernel<F: Float>(
+fn dequantize_per_tensor_affine_int8_kernel<F: Float>(
     input: &QTensor,
     output: &mut Tensor<Line<F>>,
     #[comptime] scheme: QuantizationScheme,
@@ -51,17 +54,17 @@
         terminate!();
     }

-    let qparams = QParams::new(scheme);
-    let (scale, offset) = qparams.values(input);
+    let qparams = QParams::new(scheme, 0u32);
+    let (scale, offset) = qparams.values(input, ABSOLUTE_POS);

     let value = input[ABSOLUTE_POS];

     // Input line size is fixed to 1
     if comptime!(output.line_size() == 4) {
-        output[ABSOLUTE_POS] = dequantize_affine_int8(extract_i8s(value[0]), scale, offset);
+        output[ABSOLUTE_POS] = dequantize_affine_int8(unpack_i8s(value[0]), scale, offset);
     } else {
         // For very small inputs where number of elements < 4, the output line size is 1
-        let out = dequantize_affine_int8::<F>(extract_i8s(value[0]), scale, offset);
+        let out = dequantize_affine_int8::<F>(unpack_i8s(value[0]), scale, offset);

         #[unroll]
         for j in 0..out.size() {
@@ -70,15 +73,9 @@
     }
 }

-#[cube]
-pub(crate) fn dequantize_symmetric_int8<F: Float>(value: Line<i32>, scale: f32) -> Line<F> {
-    // x = scale * x_q
-    Line::cast_from(scale) * Line::cast_from(value)
-}
-
 // Would have wrapped symmetric with the same affine kernel but cube doesn't support Option for offset.
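Before the symmetric kernel below, it is worth pinning down the packing convention these helpers implement: four `i8` values per `u32`, least-significant byte first. A standalone plain-Rust sketch of the same logic (editorial illustration, not code from this patch):

```rust
/// Extract the 8-bit segment starting at bit `offset` and sign-extend it,
/// mirroring the `extract_i8` cube function above.
fn extract_i8(value: u32, offset: u32) -> i32 {
    let v = (value >> offset) & 0xFF;
    // MSB set means the byte encodes a negative value: subtract 256.
    if v >= 128 {
        v as i32 - 256
    } else {
        v as i32
    }
}

/// Pack four `i8` values into one `u32`, least-significant byte first,
/// mirroring `pack_i8s_to_u32s` in quantize.rs.
fn pack_i8s(values: [i8; 4]) -> u32 {
    values
        .iter()
        .enumerate()
        .fold(0u32, |packed, (i, v)| packed | (((*v as u32) & 0xFF) << (8 * i)))
}

fn main() {
    let packed = pack_i8s([-128, -1, 0, 127]);
    assert_eq!(extract_i8(packed, 0), -128);
    assert_eq!(extract_i8(packed, 8), -1);
    assert_eq!(extract_i8(packed, 16), 0);
    assert_eq!(extract_i8(packed, 24), 127);
}
```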
#[cube(launch_unchecked)] -pub(crate) fn dequantize_per_tensor_symmetric_int8_kernel( +fn dequantize_per_tensor_symmetric_int8_kernel( input: &QTensor, output: &mut Tensor>, #[comptime] scheme: QuantizationScheme, @@ -88,17 +85,17 @@ pub(crate) fn dequantize_per_tensor_symmetric_int8_kernel( terminate!(); } - let qparams = QParams::new(scheme); - let (scale, _) = qparams.values(input); + let qparams = QParams::new(scheme, 0u32); + let (scale, _) = qparams.values(input, ABSOLUTE_POS); let value = input[ABSOLUTE_POS]; // Input line size is fixed to 1 if comptime!(output.line_size() == 4) { - output[ABSOLUTE_POS] = dequantize_symmetric_int8(extract_i8s(value[0]), scale); + output[ABSOLUTE_POS] = dequantize_symmetric_int8(unpack_i8s(value[0]), scale); } else { // For very small inputs where number of elements < 4, the output line size is 1 - let out = dequantize_symmetric_int8::(extract_i8s(value[0]), scale); + let out = dequantize_symmetric_int8::(unpack_i8s(value[0]), scale); #[unroll] for j in 0..out.size() { @@ -107,7 +104,70 @@ pub(crate) fn dequantize_per_tensor_symmetric_int8_kernel( } } -pub(crate) fn dequantize_per_tensor(tensor: CubeTensor) -> CubeTensor +#[cube(launch_unchecked)] +fn dequantize_per_block_symmetric_int8_kernel( + input: &QTensor, + output: &mut Tensor>, + #[comptime] scheme: QuantizationScheme, + #[comptime] num_blocks: u32, +) { + // Last num_blocks positions contains the qparams + if ABSOLUTE_POS >= input.len() - num_blocks { + terminate!(); + } + + let qparams = QParams::new(scheme, num_blocks); + let (scale, _) = qparams.values(input, ABSOLUTE_POS); + + let value = input[ABSOLUTE_POS]; + + // Input line size is fixed to 1 + if comptime!(output.line_size() == 4) { + output[ABSOLUTE_POS] = dequantize_symmetric_int8(unpack_i8s(value[0]), scale); + } else { + // For very small inputs where number of elements < 4, the output line size is 1 + let out = dequantize_symmetric_int8::(unpack_i8s(value[0]), scale); + + #[unroll] + for j in 0..out.size() { + output[ABSOLUTE_POS + j] = Line::cast_from(out[j]); + } + } +} + +#[cube(launch_unchecked)] +fn dequantize_per_block_affine_int8_kernel( + input: &QTensor, + output: &mut Tensor>, + #[comptime] scheme: QuantizationScheme, + #[comptime] num_blocks: u32, +) { + // Last 2 * num_blocks positions contain the qparams + if ABSOLUTE_POS >= input.len() - 2 * num_blocks { + terminate!(); + } + + let qparams = QParams::new(scheme, num_blocks); + let (scale, offset) = qparams.values(input, ABSOLUTE_POS); + + let value = input[ABSOLUTE_POS]; + + // Input line size is fixed to 1 + if comptime!(output.line_size() == 4) { + output[ABSOLUTE_POS] = dequantize_affine_int8(unpack_i8s(value[0]), scale, offset); + } else { + // For very small inputs where number of elements < 4, the output line size is 1 + let out = dequantize_affine_int8::(unpack_i8s(value[0]), scale, offset); + + #[unroll] + for j in 0..out.size() { + output[ABSOLUTE_POS + j] = Line::cast_from(out[j]); + } + } +} + +/// Convert the tensor back to a higher precision data type. 
+pub fn dequantize(tensor: CubeTensor) -> CubeTensor where R: CubeRuntime, F: CubeElement, @@ -134,7 +194,7 @@ where if let DType::QFloat(scheme) = tensor.dtype { match scheme { - QuantizationScheme::PerTensorAffine(QuantizationType::QInt8) => { + QuantizationScheme::PerTensor(QuantizationMode::Affine, QuantizationType::QInt8) => { unsafe { dequantize_per_tensor_affine_int8_kernel::launch_unchecked::( &client, @@ -146,7 +206,7 @@ where ) }; } - QuantizationScheme::PerTensorSymmetric(QuantizationType::QInt8) => { + QuantizationScheme::PerTensor(QuantizationMode::Symmetric, QuantizationType::QInt8) => { unsafe { dequantize_per_tensor_symmetric_int8_kernel::launch_unchecked::( &client, @@ -158,17 +218,45 @@ where ) }; } + QuantizationScheme::PerBlock( + QuantizationMode::Affine, + QuantizationType::QInt8, + BlockLayout::Flat(block_size), + ) => { + let num_blocks = num_out_elems as u32 / block_size; + unsafe { + dequantize_per_block_affine_int8_kernel::launch_unchecked::( + &client, + cube_count, + cube_dim, + tensor.as_array_arg::(line_size_in), + output.as_tensor_arg::(line_size_out), + scheme, + num_blocks, + ) + }; + } + QuantizationScheme::PerBlock( + QuantizationMode::Symmetric, + QuantizationType::QInt8, + BlockLayout::Flat(block_size), + ) => { + let num_blocks = num_out_elems as u32 / block_size; + unsafe { + dequantize_per_block_symmetric_int8_kernel::launch_unchecked::( + &client, + cube_count, + cube_dim, + tensor.as_array_arg::(line_size_in), + output.as_tensor_arg::(line_size_out), + scheme, + num_blocks, + ) + }; + } + _ => panic!("Unsupported scheme for dequantize {scheme:?}"), } } output } - -/// Convert the tensor back to a higher precision data type. -pub fn dequantize(tensor: CubeTensor) -> CubeTensor -where - R: CubeRuntime, - F: FloatElement, -{ - dequantize_per_tensor::(tensor) -} diff --git a/crates/burn-cubecl/src/kernel/quantization/qtensor.rs b/crates/burn-cubecl/src/kernel/quantization/qtensor.rs index 26d9f65091..d6c56d5aab 100644 --- a/crates/burn-cubecl/src/kernel/quantization/qtensor.rs +++ b/crates/burn-cubecl/src/kernel/quantization/qtensor.rs @@ -1,6 +1,6 @@ #![allow(missing_docs)] // cube derive macros -use burn_tensor::quantization::QuantizationScheme; +use burn_tensor::quantization::{BlockLayout, QuantizationMode, QuantizationScheme}; use cubecl::prelude::*; /// Quantization parameters. @@ -8,6 +8,8 @@ use cubecl::prelude::*; pub struct QParams { #[cube(comptime)] scheme: QuantizationScheme, + #[cube(comptime)] + num_blocks: u32, } /// Quantized tensor representation. @@ -16,34 +18,72 @@ pub type QTensor = Array>; #[cube] impl QParams { /// Create a new quantization parameters instance. - pub fn new(scheme: QuantizationScheme) -> Self { - QParams { scheme } + pub fn new(scheme: QuantizationScheme, #[comptime] num_blocks: u32) -> Self { + QParams { scheme, num_blocks } } /// Get the quantization parameters values. 
- pub fn values(&self, tensor: &QTensor) -> (f32, i32) { + pub fn values(&self, tensor: &QTensor, value_pos: u32) -> (f32, i32) { let len = tensor.len(); match comptime!(self.scheme) { - QuantizationScheme::PerTensorAffine(_) => match comptime!(tensor.line_size()) { - // For line size of 1, scale is the last value in the buffer - 1 => ( - f32::bitcast_from(tensor[len - 1][tensor.line_size() - 1]), - i32::cast_from(tensor[len - 2][tensor.line_size() - 1]), - ), - // For any other line size > 1, scale and zero-point offset are the last two elements - _ => { - let values = tensor[len - 1]; - ( - f32::bitcast_from(values[tensor.line_size() - 1]), - i32::cast_from(values[tensor.line_size() - 2]), - ) + QuantizationScheme::PerTensor(QuantizationMode::Affine, _) => { + match comptime!(tensor.line_size()) { + // For line size of 1, scale is the last value in the buffer + 1 => ( + f32::bitcast_from(tensor[len - 1][0]), + i32::cast_from(tensor[len - 2][0]), + ), + // For any other line size > 1, scale and zero-point offset are the last two elements + _ => { + let values = tensor[len - 1]; + ( + f32::bitcast_from(values[tensor.line_size() - 1]), + i32::cast_from(values[tensor.line_size() - 2]), + ) + } } - }, + } // Symmetric quantization only contains the scaling factor as the last element - QuantizationScheme::PerTensorSymmetric(_) => ( + QuantizationScheme::PerTensor(QuantizationMode::Symmetric, _) => ( f32::bitcast_from(tensor[len - 1][tensor.line_size() - 1]), 0, ), + // For affine quantization, there are 2 parameters per block + // The (scale, offset) parameters are stored contiguously by parameter type + // [offset, offset, offset, ..., scale, scale, scale, ...] + // (but we might want to store them with each block in the future?) + QuantizationScheme::PerBlock( + QuantizationMode::Affine, + _dtype, + BlockLayout::Flat(block_size), + ) => { + // For each position in the quantized tensor, there are 4 packed values. + // The block size must be a factor of 4, so at least [4, 8, ...] values are contained in a single block + let line_size = tensor.line_size(); + let block_idx = value_pos * 4 / block_size; + + let scale = + tensor[len - (self.num_blocks - block_idx) / line_size][block_idx % line_size]; + let offset = tensor[len - (2 * self.num_blocks - block_idx / line_size)] + [block_idx % line_size]; + + (f32::bitcast_from(scale), i32::cast_from(offset)) + } + QuantizationScheme::PerBlock( + QuantizationMode::Symmetric, + _dtype, + BlockLayout::Flat(block_size), + ) => { + // For each position in the quantized tensor, there are 4 packed values. + // The block size must be a factor of 4, so at least [4, 8, ...] 
values are contained in a single block + let line_size = tensor.line_size(); + let block_idx = value_pos * 4 / block_size; + + let scale = + tensor[len - (self.num_blocks - block_idx) / line_size][block_idx % line_size]; + (f32::bitcast_from(scale), 0) + } + _ => comptime!(unimplemented!()), } } } diff --git a/crates/burn-cubecl/src/kernel/quantization/quantize.rs b/crates/burn-cubecl/src/kernel/quantization/quantize.rs index 49565e9ca0..4fde25f82e 100644 --- a/crates/burn-cubecl/src/kernel/quantization/quantize.rs +++ b/crates/burn-cubecl/src/kernel/quantization/quantize.rs @@ -1,12 +1,28 @@ use crate::tensor::CubeTensor; -use crate::FloatElement; use crate::{CubeElement, CubeRuntime, IntElement}; -use burn_tensor::quantization::{QuantizationScheme, QuantizationType}; +use burn_tensor::quantization::{ + BlockLayout, QuantizationMode, QuantizationScheme, QuantizationType, +}; +use burn_tensor::Shape; use cubecl::calculate_cube_count_elemwise; use cubecl::prelude::*; #[cube] -pub(crate) fn quantize_affine_int8( +fn pack_i8s_to_u32s(value: Line) -> u32 { + // NOTE: assuming line size of 4 + let line_size = value.size(); + let mut v_packed = 0; + + #[unroll] + for i in 0..line_size { + // Shift and combine into u32 + v_packed |= (value[i] & 0xFF) << (8 * i); + } + v_packed +} + +#[cube] +fn quantize_affine_int8( value: Line, scale: f32, offset: i32, @@ -24,8 +40,53 @@ pub(crate) fn quantize_affine_int8( ) } +#[cube] +fn quantize_symmetric_int8( + value: Line, + scale: f32, + range_min: F, + range_max: F, +) -> Line { + // x_q = clamp(round(x / scale), a, b) + // NOTE: we add 256 before casting to unsigned to correctly represent negative values + Line::cast_from( + Line::clamp( + Line::round(value / Line::cast_from(scale)), + Line::new(range_min), + Line::new(range_max), + ) + Line::cast_from(comptime!(256f32)), + ) +} + +#[cube] +fn quantize_affine_int8_packed( + input: Line, + scale: f32, + offset: i32, + range_min: f32, + range_max: f32, +) -> u32 { + // Assuming a line size of 4 (equal to the number of values packed) + let value = quantize_affine_int8::(input, scale, offset, range_min, range_max); + // Shift and combine into u32 + pack_i8s_to_u32s(value) +} + +#[cube] +fn quantize_symmetric_int8_packed( + input: Line, + scale: f32, + range_min: f32, + range_max: f32, +) -> u32 { + // Assuming a line size of 4 (equal to the number of values packed) + let value = quantize_symmetric_int8::(input, scale, range_min, range_max); + // Shift and combine into u32 + pack_i8s_to_u32s(value) +} + #[cube(launch_unchecked)] -pub(crate) fn quantize_per_tensor_affine_int8_kernel( +fn quantize_per_tensor_affine_int8_kernel( input: &Tensor>, scale: &Tensor, offset: &Tensor, @@ -52,67 +113,25 @@ pub(crate) fn quantize_per_tensor_affine_int8_kernel( terminate!(); } - let line_size = comptime!(input.line_size()); - if comptime!(line_size == 4) { - // Assuming a line size of 4 (equal to the number of values packed) - let value = - quantize_affine_int8::(input[ABSOLUTE_POS], scale, offset, range_min, range_max); - // Shift and combine into u32 - output[ABSOLUTE_POS] = pack_i8s_to_u32s(value); + if comptime!(input.line_size() == 4) { + output[ABSOLUTE_POS] = + quantize_affine_int8_packed(input[ABSOLUTE_POS], scale, offset, range_min, range_max); } else { - let mut v_packed = 0; + // line size 1 let num_packed = comptime!(4); + let mut values = Line::::empty(num_packed); #[unroll] for i in 0..num_packed { - let v = quantize_affine_int8::( - input[ABSOLUTE_POS + i], - scale, - offset, - range_min, - range_max, - ); 
- // Shift and combine into u32 - v_packed |= (v[0] & 0xFF) << (8 * i); + values[i] = input[ABSOLUTE_POS + i][0]; } - output[ABSOLUTE_POS] = v_packed; + output[ABSOLUTE_POS] = + quantize_affine_int8_packed(values, scale, offset, range_min, range_max); } } -#[cube] -pub(crate) fn quantize_symmetric_int8( - value: Line, - scale: f32, - range_min: F, - range_max: F, -) -> Line { - // x_q = clamp(round(x / scale), a, b) - // NOTE: we add 256 before casting to unsigned to correctly represent negative values - Line::cast_from( - Line::clamp( - Line::round(value / Line::cast_from(scale)), - Line::new(range_min), - Line::new(range_max), - ) + Line::cast_from(comptime!(256f32)), - ) -} - -#[cube] -pub(crate) fn pack_i8s_to_u32s(value: Line) -> u32 { - // NOTE: assuming line size of 4 - let line_size = value.size(); - let mut v_packed = 0; - - #[unroll] - for i in 0..line_size { - // Shift and combine into u32 - v_packed |= (value[i] & 0xFF) << (8 * i); - } - v_packed -} - // Would have wrapped symmetric with the same affine kernel but cube doesn't support Option for offset. #[cube(launch_unchecked)] -pub(crate) fn quantize_per_tensor_symmetric_int8_kernel( +fn quantize_per_tensor_symmetric_int8_kernel( input: &Tensor>, scale: &Tensor, range_min: f32, @@ -131,128 +150,268 @@ pub(crate) fn quantize_per_tensor_symmetric_int8_kernel( terminate!(); } - let line_size = comptime!(input.line_size()); - if comptime!(line_size == 4) { - // Assuming a vectorization factor of 4 (equal to the number of values packed) - let value = - quantize_symmetric_int8::(input[ABSOLUTE_POS], scale, range_min, range_max); - // Shift and combine into u32 - output[ABSOLUTE_POS] = pack_i8s_to_u32s(value); + if comptime!(input.line_size() == 4) { + output[ABSOLUTE_POS] = + quantize_symmetric_int8_packed(input[ABSOLUTE_POS], scale, range_min, range_max); } else { + // line size 1 let num_packed = comptime!(4); - let mut v_packed = 0; + let mut values = Line::::empty(num_packed); #[unroll] for i in 0..num_packed { - let v = quantize_symmetric_int8::( - input[ABSOLUTE_POS + i], - scale, - range_min, - range_max, - ); - // Shift and combine into u32 - v_packed |= (v[0] & 0xFF) << (8 * i); + values[i] = input[ABSOLUTE_POS + i][0]; } - output[ABSOLUTE_POS] = v_packed; + output[ABSOLUTE_POS] = quantize_symmetric_int8_packed(values, scale, range_min, range_max); + } +} + +#[cube(launch_unchecked)] +fn quantize_per_block_flat_symmetric_int8_kernel( + input: &Tensor>, + scale: &Tensor, + range_min: f32, + range_max: f32, + block_size: u32, + output: &mut Array, + #[comptime] num_blocks: u32, +) { + if ABSOLUTE_POS >= output.len() { + terminate!(); + } + + // Cast the scale to u32 and write the value in the output + if ABSOLUTE_POS >= output.len() - num_blocks { + let scale_idx = num_blocks - (output.len() - ABSOLUTE_POS); + output[ABSOLUTE_POS] = u32::bitcast_from(scale[scale_idx]); + terminate!(); + } + + let line_size = comptime!(input.line_size()); + let block_idx = (ABSOLUTE_POS * line_size) / block_size; + let scale = scale[block_idx]; + if comptime!(line_size == 4) { + output[ABSOLUTE_POS] = + quantize_symmetric_int8_packed(input[ABSOLUTE_POS], scale, range_min, range_max); + } +} + +#[cube(launch_unchecked)] +fn quantize_per_block_flat_affine_int8_kernel( + input: &Tensor>, + scale: &Tensor, + offset: &Tensor, + range_min: f32, + range_max: f32, + block_size: u32, + output: &mut Array, + #[comptime] num_blocks: u32, +) { + if ABSOLUTE_POS >= output.len() { + terminate!(); + } + + // Cast the scale to u32 and write the value in the 
output + if ABSOLUTE_POS >= output.len() - num_blocks { + let scale_idx = num_blocks - (output.len() - ABSOLUTE_POS); + output[ABSOLUTE_POS] = u32::bitcast_from(scale[scale_idx]); + terminate!(); + } + + // Cast the offset to u32 and write the value in the output + if ABSOLUTE_POS >= output.len() - 2 * num_blocks { + let offset_idx = 2 * num_blocks - (output.len() - ABSOLUTE_POS); + output[ABSOLUTE_POS] = u32::bitcast_from(offset[offset_idx]); + terminate!(); + } + + let line_size = comptime!(input.line_size()); + let block_idx = (ABSOLUTE_POS * line_size) / block_size; + let scale = scale[block_idx]; + let offset = offset[block_idx]; + if comptime!(line_size == 4) { + output[ABSOLUTE_POS] = + quantize_affine_int8_packed(input[ABSOLUTE_POS], scale, offset, range_min, range_max); } } -pub(crate) fn quantize_per_tensor( +fn create_quantized_output( + client: ComputeClient, + num_input_elems: usize, + device: R::Device, + shape: Shape, + scheme: QuantizationScheme, +) -> CubeTensor { + // Output tensor contains 4x less elements (four int8 values packed in a single u32) + let output_elems_size = usize::div_ceil(num_input_elems, 4) * core::mem::size_of::(); + + // Scale and offset (optional) qparams are also packed in the tensor data + let qparams_size = match &scheme { + QuantizationScheme::PerTensor(mode, ..) => match mode { + QuantizationMode::Affine => core::mem::size_of::() + core::mem::size_of::(), + QuantizationMode::Symmetric => core::mem::size_of::(), + }, + QuantizationScheme::PerBlock(mode, _, layout) => { + let num_blocks = match layout { + BlockLayout::Flat(block_size) => num_input_elems / *block_size as usize, + BlockLayout::Grid(m, n) => num_input_elems / (m * n) as usize, + }; + + match mode { + QuantizationMode::Affine => { + (core::mem::size_of::() + core::mem::size_of::()) * num_blocks + } + QuantizationMode::Symmetric => core::mem::size_of::() * num_blocks, + } + } + }; + + let handle = client.empty(output_elems_size + qparams_size); + CubeTensor::new_contiguous( + client, + device, + shape, + handle, + burn_tensor::DType::QFloat(scheme), + ) +} + +/// Convert the tensor to a lower precision data type based on the quantization scheme and parameters. 
+pub fn quantize( tensor: CubeTensor, + scheme: &QuantizationScheme, scale: CubeTensor, offset: Option>, - scheme: QuantizationScheme, ) -> CubeTensor where R: CubeRuntime, F: CubeElement, I: IntElement, { - let ndims = tensor.shape.num_dims(); - let num_elems = tensor.shape.num_elements(); let client = tensor.client.clone(); // Output tensor contains 4x less elements (four int8 values packed in a single u32) - let output_num_elems = usize::div_ceil(num_elems, 4) * core::mem::size_of::(); + let num_elems = tensor.shape.num_elements(); // Force vectorization to process 4 quantized values packed for 1 output value let line_size: u8 = if num_elems < 4 { 1 } else { 4 }; let cube_dim = CubeDim::default(); let cube_count = calculate_cube_count_elemwise(num_elems / line_size as usize, cube_dim); - let dummy_array = vec![1; ndims]; - if let Some(offset) = offset { - // Scale and offset qparams are also packed in the tensor dat - let handle = client - .empty(output_num_elems + core::mem::size_of::() + core::mem::size_of::()); - let output = CubeTensor::new_contiguous( - client.clone(), - tensor.device.clone(), - tensor.shape.clone(), - handle, - burn_tensor::DType::QFloat(scheme), - ); - - unsafe { - quantize_per_tensor_affine_int8_kernel::launch_unchecked::( - &client, - cube_count, - cube_dim, - tensor.as_tensor_arg::(line_size), - // Ignore shape and stride - TensorArg::from_raw_parts::(&scale.handle, &dummy_array, &dummy_array, 1), - TensorArg::from_raw_parts::(&offset.handle, &dummy_array, &dummy_array, 1), - ScalarArg::new(i8::MIN as f32), - ScalarArg::new(i8::MAX as f32), - output.as_array_arg::(1), - ) - }; - output - } else { - // Scale qparam is also packed in the tensor data - let handle = client.empty(output_num_elems + core::mem::size_of::()); - let output = CubeTensor::new_contiguous( - client.clone(), - tensor.device.clone(), - tensor.shape.clone(), - handle, - burn_tensor::DType::QFloat(scheme), - ); - - unsafe { - quantize_per_tensor_symmetric_int8_kernel::launch_unchecked::( - &client, - cube_count, - cube_dim, - tensor.as_tensor_arg::(line_size), - // Ignore shape and stride - TensorArg::from_raw_parts::(&scale.handle, &dummy_array, &dummy_array, 1), - ScalarArg::new(-i8::MAX as f32), - ScalarArg::new(i8::MAX as f32), - output.as_array_arg::(1), - ) - }; - - output - } -} + let output = create_quantized_output( + client.clone(), + num_elems, + tensor.device.clone(), + tensor.shape.clone(), + *scheme, + ); -/// Convert the tensor to a lower precision data type based on the quantization scheme and parameters. 
-pub fn quantize( - tensor: CubeTensor, - scheme: &QuantizationScheme, - scale: CubeTensor, - offset: Option>, -) -> CubeTensor -where - R: CubeRuntime, - F: FloatElement, - I: IntElement, -{ match scheme { - QuantizationScheme::PerTensorAffine(dtype) - | QuantizationScheme::PerTensorSymmetric(dtype) => match dtype { - QuantizationType::QInt8 => { - quantize_per_tensor::(tensor, scale, offset, *scheme) + QuantizationScheme::PerTensor(mode, QuantizationType::QInt8) => { + let ndims = tensor.shape.num_dims(); + let dummy_array = vec![1; ndims]; + + match mode { + QuantizationMode::Affine => { + unsafe { + quantize_per_tensor_affine_int8_kernel::launch_unchecked::( + &client, + cube_count, + cube_dim, + tensor.as_tensor_arg::(line_size), + // Ignore shape and stride + TensorArg::from_raw_parts::( + &scale.handle, + &dummy_array, + &dummy_array, + 1, + ), + TensorArg::from_raw_parts::( + &offset.expect("Should have offset").handle, + &dummy_array, + &dummy_array, + 1, + ), + ScalarArg::new(i8::MIN as f32), + ScalarArg::new(i8::MAX as f32), + output.as_array_arg::(1), + ) + }; + } + QuantizationMode::Symmetric => { + unsafe { + quantize_per_tensor_symmetric_int8_kernel::launch_unchecked::( + &client, + cube_count, + cube_dim, + tensor.as_tensor_arg::(line_size), + // Ignore shape and stride + TensorArg::from_raw_parts::( + &scale.handle, + &dummy_array, + &dummy_array, + 1, + ), + ScalarArg::new(-i8::MAX as f32), + ScalarArg::new(i8::MAX as f32), + output.as_array_arg::(1), + ) + }; + } } - }, + } + QuantizationScheme::PerBlock( + mode, + QuantizationType::QInt8, + BlockLayout::Flat(block_size), + ) => { + if line_size != 4 { + panic!("Per-block quantization is only supported for a line size of 4, got {line_size} ({num_elems} elements)") + } + + if block_size % line_size as u32 != 0 { + panic!("Block size must be a factor of {line_size}, got {block_size}") + } + + let num_blocks = num_elems as u32 / block_size; + match mode { + QuantizationMode::Affine => { + unsafe { + quantize_per_block_flat_affine_int8_kernel::launch_unchecked::( + &client, + cube_count, + cube_dim, + tensor.as_tensor_arg::(line_size), + scale.as_tensor_arg::(1), + offset.expect("Should have offset").as_tensor_arg::(1), + ScalarArg::new(i8::MIN as f32), + ScalarArg::new(i8::MAX as f32), + ScalarArg::new(*block_size), + output.as_array_arg::(1), + num_blocks, + ) + }; + } + QuantizationMode::Symmetric => { + unsafe { + quantize_per_block_flat_symmetric_int8_kernel::launch_unchecked::( + &client, + cube_count, + cube_dim, + tensor.as_tensor_arg::(line_size), + scale.as_tensor_arg::(1), + ScalarArg::new(-i8::MAX as f32), + ScalarArg::new(i8::MAX as f32), + ScalarArg::new(*block_size), + output.as_array_arg::(1), + num_blocks, + ) + }; + } + } + } + QuantizationScheme::PerBlock(.., BlockLayout::Grid(..)) => { + panic!("Per-block quantization is not supported for grid layout") + } } + + output } diff --git a/crates/burn-cubecl/src/ops/qtensor.rs b/crates/burn-cubecl/src/ops/qtensor.rs index b61a46195e..02bb745d3f 100644 --- a/crates/burn-cubecl/src/ops/qtensor.rs +++ b/crates/burn-cubecl/src/ops/qtensor.rs @@ -2,7 +2,9 @@ use std::ops::Range; use burn_tensor::{ ops::{FloatTensor, IntTensor, QTensorOps, QuantizedTensor}, - quantization::{QuantizationParametersPrimitive, QuantizationScheme, QuantizationType}, + quantization::{ + BlockLayout, QuantizationParametersPrimitive, QuantizationScheme, QuantizationType, + }, DType, Device, Shape, TensorData, }; @@ -40,12 +42,21 @@ where fn q_from_data(data: TensorData, device: &Device) -> 
QuantizedTensor { match data.dtype { DType::QFloat(scheme) => match scheme { - QuantizationScheme::PerTensorAffine(QuantizationType::QInt8) - | QuantizationScheme::PerTensorSymmetric(QuantizationType::QInt8) => { + QuantizationScheme::PerTensor(_mode, QuantizationType::QInt8) + | QuantizationScheme::PerBlock( + _mode, + QuantizationType::QInt8, + BlockLayout::Flat(..), + ) => { // TensorData quantized representation is the same, with multiple quantized values // packed into u32 and quantization parameters appended to the bytes new_qtensor(data.as_bytes(), data.shape.clone(), scheme, device) } + QuantizationScheme::PerBlock( + _mode, + QuantizationType::QInt8, + BlockLayout::Grid(..), + ) => panic!("Per-block quantization is not supported for grid layout"), }, _ => panic!( "Invalid dtype (expected DType::QFloat, got {:?})", @@ -54,6 +65,8 @@ where } } + // TODO: quantize_dynamic (we can compute min-max on the fly and scale, especially when not per-tensor) + fn quantize( tensor: FloatTensor, scheme: &QuantizationScheme, @@ -82,6 +95,7 @@ where let tensor = kernel::into_contiguous(tensor); let bytes = tensor.client.read_one_async(tensor.handle.binding()).await; + // We use the same internal representation TensorData::from_bytes(bytes, tensor.shape, tensor.dtype) } diff --git a/crates/burn-cubecl/src/tests/mod.rs b/crates/burn-cubecl/src/tests/mod.rs index c59602a17d..50878c0e92 100644 --- a/crates/burn-cubecl/src/tests/mod.rs +++ b/crates/burn-cubecl/src/tests/mod.rs @@ -127,6 +127,7 @@ macro_rules! testgen_jit { burn_tensor::testgen_calibration!(); burn_tensor::testgen_scheme!(); burn_tensor::testgen_quantize!(); + burn_tensor::testgen_q_data!(); } } diff --git a/crates/burn-cubecl/src/tests/quantization.rs b/crates/burn-cubecl/src/tests/quantization.rs index 77cf8dbb9b..4856ace85e 100644 --- a/crates/burn-cubecl/src/tests/quantization.rs +++ b/crates/burn-cubecl/src/tests/quantization.rs @@ -2,13 +2,14 @@ mod tests { use super::*; use burn_tensor::{ - quantization::{QuantizationScheme, QuantizationType}, + quantization::{BlockLayout, QuantizationMode, QuantizationScheme, QuantizationType}, Tensor, }; #[test] fn should_quantize_dequantize_symmetric_single() { - let scheme = QuantizationScheme::PerTensorSymmetric(QuantizationType::QInt8); + let scheme = + QuantizationScheme::PerTensor(QuantizationMode::Symmetric, QuantizationType::QInt8); let input = Tensor::::from_floats([-1.8], &Default::default()); let input_ref = Tensor::::from_data(input.to_data(), &Default::default()); @@ -26,7 +27,8 @@ mod tests { #[test] fn should_quantize_dequantize_affine_single() { - let scheme = QuantizationScheme::PerTensorAffine(QuantizationType::QInt8); + let scheme = + QuantizationScheme::PerTensor(QuantizationMode::Affine, QuantizationType::QInt8); let input = Tensor::::from_floats([-1.8], &Default::default()); let input_ref = Tensor::::from_data(input.to_data(), &Default::default()); @@ -44,7 +46,8 @@ mod tests { #[test] fn should_quantize_dequantize_symmetric_multiple() { - let scheme = QuantizationScheme::PerTensorSymmetric(QuantizationType::QInt8); + let scheme = + QuantizationScheme::PerTensor(QuantizationMode::Symmetric, QuantizationType::QInt8); let input = Tensor::::from_floats([-1.8, -1.0, 0.0, 0.5, 0.0], &Default::default()); let input_ref = @@ -63,7 +66,8 @@ mod tests { #[test] fn should_quantize_dequantize_affine_multiple() { - let scheme = QuantizationScheme::PerTensorAffine(QuantizationType::QInt8); + let scheme = + QuantizationScheme::PerTensor(QuantizationMode::Affine, 
QuantizationType::QInt8); let input = Tensor::::from_floats([-1.8, -1.0, 0.0, 0.5, 0.0], &Default::default()); let input_ref = @@ -79,4 +83,61 @@ mod tests { output.to_data().assert_approx_eq(&output_ref.to_data(), 3); } + + #[test] + fn should_quantize_dequantize_per_block_symmetric() { + // block_size > line_size + let scheme = QuantizationScheme::PerBlock( + QuantizationMode::Symmetric, + QuantizationType::QInt8, + BlockLayout::Flat(8), + ); + let input = Tensor::::from_floats( + [ + [-1.8, -1.0, 0.0, 0.5, -0.8, 1.2, 0.25, 0.5], + [-0.08, 0.12, 0.025, 0.05, 0.2, 0.3, 0.4, 0.5], + ], + &Default::default(), + ); + let input_ref = + Tensor::::from_data(input.to_data(), &Default::default()); + + let output = input.quantize_dynamic(&scheme); + let output_ref = input_ref.quantize_dynamic(&scheme); + + output.to_data().assert_eq(&output_ref.to_data(), false); + + let output = output.dequantize(); + let output_ref = output_ref.dequantize(); + + output.to_data().assert_approx_eq(&output_ref.to_data(), 3); + } + + #[test] + fn should_quantize_dequantize_per_block_affine() { + let scheme = QuantizationScheme::PerBlock( + QuantizationMode::Affine, + QuantizationType::QInt8, + BlockLayout::Flat(4), + ); + let input = Tensor::::from_floats( + [ + [-1.8, -1.0, 0.0, 0.5, -0.8, 1.2, 0.25, 0.5], + [0.5, 0.25, 1.2, -0.8, 0.2, 0.3, 0.4, 0.5], + ], + &Default::default(), + ); + let input_ref = + Tensor::::from_data(input.to_data(), &Default::default()); + + let output = input.quantize_dynamic(&scheme); + let output_ref = input_ref.quantize_dynamic(&scheme); + + output.to_data().assert_eq(&output_ref.to_data(), false); + + let output = output.dequantize(); + let output_ref = output_ref.dequantize(); + + output.to_data().assert_approx_eq(&output_ref.to_data(), 3); + } } diff --git a/crates/burn-cuda/Cargo.toml b/crates/burn-cuda/Cargo.toml index e5b2c730c1..d12d9a6b5b 100644 --- a/crates/burn-cuda/Cargo.toml +++ b/crates/burn-cuda/Cargo.toml @@ -13,7 +13,7 @@ version.workspace = true [features] autotune = ["burn-cubecl/autotune"] -default = ["fusion", "autotune", "burn-cubecl/default", "cubecl/default"] +default = ["std", "fusion", "autotune", "burn-cubecl/default", "cubecl/default"] doc = ["burn-cubecl/doc"] fusion = ["burn-fusion", "burn-cubecl/fusion"] std = ["burn-cubecl/std", "cubecl/std"] diff --git a/crates/burn-import/onnx-tests/tests/split/split.py b/crates/burn-import/onnx-tests/tests/split/split.py index e53d88b902..19775a8de5 100644 --- a/crates/burn-import/onnx-tests/tests/split/split.py +++ b/crates/burn-import/onnx-tests/tests/split/split.py @@ -16,7 +16,7 @@ def forward(self, x): def main(): - # Set seed for reproducability + # Set seed for reproducibility torch.manual_seed(42) torch.set_printoptions(precision=8) diff --git a/crates/burn-ndarray/src/ops/qtensor.rs b/crates/burn-ndarray/src/ops/qtensor.rs index c5dd4756e0..9054ca4c44 100644 --- a/crates/burn-ndarray/src/ops/qtensor.rs +++ b/crates/burn-ndarray/src/ops/qtensor.rs @@ -1,10 +1,12 @@ +use alloc::vec; use core::ops::Range; use burn_tensor::{ ops::{FloatTensor, IntTensor, QTensorOps, QuantizedTensor}, quantization::{ - AffineQuantization, QParams, QuantizationParametersPrimitive, QuantizationScheme, - QuantizationStrategy, QuantizationType, QuantizedBytes, SymmetricQuantization, + AffineQuantization, QParams, QuantizationMode, QuantizationParametersPrimitive, + QuantizationScheme, QuantizationStrategy, QuantizationType, QuantizedBytes, + SymmetricQuantization, }, DType, ElementConversion, Shape, TensorData, TensorMetadata, }; @@ 
-45,14 +47,33 @@ impl QTensorOps { + QuantizationScheme::PerTensor(mode, QuantizationType::QInt8) + | QuantizationScheme::PerBlock(mode, QuantizationType::QInt8, _) => { + // We should probably check that `Q` matches i8.. but it's the only valid type now let (values, qparams) = q_bytes.into_vec_i8(); + let data = TensorData::new(values, shape); - let data = TensorData::new(values, shape).convert::(); - let qparams = QParams { - scale: qparams.scale, - offset: qparams.offset.map(|x| x.elem::()), + let qparams = match mode { + QuantizationMode::Affine => qparams + .scale + .into_iter() + .zip( + qparams + .offset + .unwrap() + .into_iter() + .map(|x| Some(x.elem::())), + ) + .map(|(scale, offset)| QParams { scale, offset }) + .collect(), + QuantizationMode::Symmetric => qparams + .scale + .into_iter() + .map(|scale| QParams { + scale, + offset: None, + }) + .collect(), }; NdArrayQTensor { @@ -75,40 +96,87 @@ impl QTensorOps, ) -> QuantizedTensor { + // Implement with ndarray instead of QuantizationStrategy? let (strategy, qparams) = match scheme { - QuantizationScheme::PerTensorAffine(dtype) => match dtype { - QuantizationType::QInt8 => { - let scale = into_data_f(qparams.scale).iter().next().unwrap(); - let offset = into_data(qparams.offset.unwrap()) - .iter::() - .next() - .unwrap(); - ( - QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init( - scale, - offset.elem(), - )), - QParams { - scale, - offset: Some(offset), - }, - ) - } - }, - QuantizationScheme::PerTensorSymmetric(dtype) => match dtype { - QuantizationType::QInt8 => { - let scale = into_data_f(qparams.scale).iter().next().unwrap(); - ( - QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init( - scale, - )), - QParams { - scale, - offset: None, - }, - ) - } - }, + QuantizationScheme::PerTensor(QuantizationMode::Affine, QuantizationType::QInt8) => { + let scale = into_data_f(qparams.scale).iter().next().unwrap(); + let offset = into_data(qparams.offset.unwrap()) + .iter::() + .next() + .unwrap(); + ( + QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init( + scale, + offset.elem(), + )), + vec![QParams { + scale, + offset: Some(offset), + }], + ) + } + QuantizationScheme::PerTensor(QuantizationMode::Symmetric, QuantizationType::QInt8) => { + let scale = into_data_f(qparams.scale).iter().next().unwrap(); + ( + QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init( + scale, + )), + vec![QParams { + scale, + offset: None, + }], + ) + } + QuantizationScheme::PerBlock( + QuantizationMode::Affine, + QuantizationType::QInt8, + layout, + ) => { + let scale = into_data_f(qparams.scale); + let offset = into_data(qparams.offset.unwrap()); + let (strategy, qparams) = scale + .iter() + .zip(offset.iter::()) + .map(|(s, o)| { + ( + AffineQuantization::init(s, o.elem()), + QParams { + scale: s, + offset: Some(o), + }, + ) + }) + .unzip(); + + ( + QuantizationStrategy::PerBlockAffineInt8(strategy, *layout), + qparams, + ) + } + QuantizationScheme::PerBlock( + QuantizationMode::Symmetric, + QuantizationType::QInt8, + layout, + ) => { + let scale = into_data_f(qparams.scale); + let (strategy, qparams) = scale + .iter() + .map(|s| { + ( + SymmetricQuantization::init(s), + QParams { + scale: s, + offset: None, + }, + ) + }) + .unzip(); + + ( + QuantizationStrategy::PerBlockSymmetricInt8(strategy, *layout), + qparams, + ) + } }; let shape = tensor.shape(); diff --git a/crates/burn-ndarray/src/tensor.rs b/crates/burn-ndarray/src/tensor.rs index 54f0485a3a..85e6893907 100644 --- 
a/crates/burn-ndarray/src/tensor.rs +++ b/crates/burn-ndarray/src/tensor.rs @@ -2,12 +2,13 @@ use core::mem; use burn_tensor::{ quantization::{ - AffineQuantization, QParams, QTensorPrimitive, QuantizationScheme, QuantizationStrategy, - QuantizationType, SymmetricQuantization, + AffineQuantization, QParams, QTensorPrimitive, QuantizationMode, QuantizationScheme, + QuantizationStrategy, QuantizationType, SymmetricQuantization, }, DType, Element, Shape, TensorData, TensorMetadata, }; +use alloc::vec::Vec; use ndarray::{ArcArray, ArrayD, IxDyn}; use crate::element::QuantElement; @@ -338,24 +339,48 @@ pub struct NdArrayQTensor { /// The quantization scheme. pub scheme: QuantizationScheme, /// The quantization parameters. - pub qparams: QParams, + pub qparams: Vec>, } impl NdArrayQTensor { /// Returns the quantization strategy, including quantization parameters, for the given tensor. pub fn strategy(&self) -> QuantizationStrategy { match self.scheme { - QuantizationScheme::PerTensorAffine(QuantizationType::QInt8) => { + QuantizationScheme::PerTensor(QuantizationMode::Affine, QuantizationType::QInt8) => { QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init( - self.qparams.scale, - self.qparams.offset.unwrap().elem(), + self.qparams[0].scale, + self.qparams[0].offset.unwrap().elem(), )) } - QuantizationScheme::PerTensorSymmetric(QuantizationType::QInt8) => { + QuantizationScheme::PerTensor(QuantizationMode::Symmetric, QuantizationType::QInt8) => { QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init( - self.qparams.scale, + self.qparams[0].scale, )) } + QuantizationScheme::PerBlock( + QuantizationMode::Affine, + QuantizationType::QInt8, + layout, + ) => QuantizationStrategy::PerBlockAffineInt8( + self.qparams + .iter() + .map(|qparams| { + AffineQuantization::init(qparams.scale, qparams.offset.unwrap().elem()) + }) + .collect(), + layout, + ), + QuantizationScheme::PerBlock( + QuantizationMode::Symmetric, + QuantizationType::QInt8, + layout, + ) => QuantizationStrategy::PerBlockSymmetricInt8( + self.qparams + .iter() + .map(|qparams| SymmetricQuantization::init(qparams.scale)) + .collect(), + layout, + ), } } } @@ -452,7 +477,8 @@ mod tests { let device = Default::default(); let tensor = B::float_from_data(TensorData::from([-1.8f32, -1.0, 0.0, 0.5]), &device); - let scheme = QuantizationScheme::PerTensorAffine(QuantizationType::QInt8); + let scheme = + QuantizationScheme::PerTensor(QuantizationMode::Affine, QuantizationType::QInt8); let qparams = QuantizationParametersPrimitive { scale: B::float_from_data(TensorData::from([scale]), &device), offset: Some(B::int_from_data(TensorData::from([offset as i64]), &device)), diff --git a/crates/burn-tch/Cargo.toml b/crates/burn-tch/Cargo.toml index 69b0bee004..03b003b3ca 100644 --- a/crates/burn-tch/Cargo.toml +++ b/crates/burn-tch/Cargo.toml @@ -12,7 +12,8 @@ repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-tch" version.workspace = true [features] -default = [] +default = ["std"] +std = [] doc = ["tch/doc-only"] [dependencies] diff --git a/crates/burn-tch/src/ops/qtensor.rs b/crates/burn-tch/src/ops/qtensor.rs index 47c7e9a404..67eca50ce3 100644 --- a/crates/burn-tch/src/ops/qtensor.rs +++ b/crates/burn-tch/src/ops/qtensor.rs @@ -3,8 +3,8 @@ use std::ops::Range; use burn_tensor::{ ops::{FloatTensor, IntTensor, QTensorOps, QuantizedTensor}, quantization::{ - QParams, QuantizationParametersPrimitive, QuantizationScheme, QuantizationType, - QuantizedBytes, + QParams, QuantizationMode, 
QuantizationParametersPrimitive, QuantizationScheme, + QuantizationType, QuantizedBytes, }, DType, Shape, TensorData, TensorMetadata, }; @@ -25,14 +25,16 @@ fn quantize( } match scheme { - QuantizationScheme::PerTensorAffine(QuantizationType::QInt8) => tensor.quantize_per_tensor( - qparams.scale.elem(), - qparams.offset.unwrap().elem(), - tch::Kind::QInt8, - ), - QuantizationScheme::PerTensorSymmetric(QuantizationType::QInt8) => { + QuantizationScheme::PerTensor(QuantizationMode::Affine, QuantizationType::QInt8) => tensor + .quantize_per_tensor( + qparams.scale.elem(), + qparams.offset.unwrap().elem(), + tch::Kind::QInt8, + ), + QuantizationScheme::PerTensor(QuantizationMode::Symmetric, QuantizationType::QInt8) => { tensor.quantize_per_tensor(qparams.scale.elem(), 0, tch::Kind::QInt8) } + QuantizationScheme::PerBlock(_mode, _dtype, _block_layout) => unimplemented!(), } } @@ -46,23 +48,33 @@ impl QTensorOps for LibTorch { // So for now we have to load the dequantized values to quantize them back since the dequantization // methods take the values provided when quantizing. match data.dtype { - DType::QFloat(scheme) => { - let num_elements = data.num_elements(); - let q_bytes = QuantizedBytes { - bytes: data.into_bytes(), - scheme, - num_elements, - }; - - let (values, qparams) = q_bytes.dequantize(); - let tensor = tch::Tensor::from_slice(&values).to(device); - let tensor = quantize(tensor.reshape(shape_tch.dims), &scheme, &qparams); - - TchQTensor { - qtensor: TchTensor::new(tensor), - scheme, + DType::QFloat(scheme) => match scheme { + QuantizationScheme::PerTensor(_, _) => { + let shape = data.shape.clone(); + let num_elements = data.num_elements(); + let q_bytes = QuantizedBytes { + bytes: data.into_bytes(), + scheme, + num_elements, + }; + + let (values, qparams) = q_bytes.dequantize(&shape); + let qparams = QParams { + scale: qparams.scale[0], + offset: qparams.offset.map(|x| x[0]), + }; + let tensor = tch::Tensor::from_slice(&values).to(device); + let tensor = quantize(tensor.reshape(shape_tch.dims), &scheme, &qparams); + + TchQTensor { + qtensor: TchTensor::new(tensor), + scheme, + } } - } + QuantizationScheme::PerBlock(..) => { + panic!("Per-block quantization is not supported by tch") + } + }, _ => panic!( "Invalid dtype (expected DType::QFloat, got {:?})", data.dtype @@ -82,20 +94,23 @@ impl QTensorOps for LibTorch { } let qtensor = match scheme { - QuantizationScheme::PerTensorAffine(dtype) => match dtype { - QuantizationType::QInt8 => tensor.tensor.quantize_per_tensor_tensor_qparams( + QuantizationScheme::PerTensor(QuantizationMode::Affine, QuantizationType::QInt8) => { + tensor.tensor.quantize_per_tensor_tensor_qparams( &qparams.scale.tensor, &qparams.offset.unwrap().tensor, tch::Kind::QInt8, - ), - }, - QuantizationScheme::PerTensorSymmetric(_) => { + ) + } + QuantizationScheme::PerTensor(QuantizationMode::Symmetric, QuantizationType::QInt8) => { tensor.tensor.quantize_per_tensor_tensor_qparams( &qparams.scale.tensor, &tch::Tensor::zeros_like(&qparams.scale.tensor), tch::Kind::QInt8, ) } + QuantizationScheme::PerBlock(..) 
=> { + panic!("Tch does not support per-block quantization") + } }; TchQTensor { @@ -109,21 +124,22 @@ impl QTensorOps for LibTorch { scheme: &QuantizationScheme, ) -> QuantizedTensor { let qtensor = match &scheme { - QuantizationScheme::PerTensorAffine(dtype) => match dtype { + QuantizationScheme::PerTensor(QuantizationMode::Affine, QuantizationType::QInt8) => { // Notes on `reduce_range`: // https://github.com/pytorch/pytorch/issues/93140 // https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection - QuantizationType::QInt8 => tensor + tensor .tensor - .quantize_per_tensor_dynamic(tch::Kind::QInt8, /*reduce_range*/ false), - }, - QuantizationScheme::PerTensorSymmetric(dtype) => { + .quantize_per_tensor_dynamic(tch::Kind::QInt8, /*reduce_range*/ false) + } + QuantizationScheme::PerTensor(QuantizationMode::Symmetric, QuantizationType::QInt8) => { log::warn!("LibTorch backend does not support symmetric per-tensor scheme for dynamic quantization, reverting to the default per-tensor affine quantization"); - match dtype { - QuantizationType::QInt8 => tensor - .tensor - .quantize_per_tensor_dynamic(tch::Kind::QInt8, /*reduce_range*/ false), - } + tensor + .tensor + .quantize_per_tensor_dynamic(tch::Kind::QInt8, /*reduce_range*/ false) + } + QuantizationScheme::PerBlock(..) => { + panic!("Tch does not support per-block quantization") } }; diff --git a/crates/burn-tch/src/tensor.rs b/crates/burn-tch/src/tensor.rs index e4908a5070..9a33ab5ca5 100644 --- a/crates/burn-tch/src/tensor.rs +++ b/crates/burn-tch/src/tensor.rs @@ -1,8 +1,8 @@ use crate::{LibTorchDevice, TchElement}; use burn_tensor::{ quantization::{ - AffineQuantization, QTensorPrimitive, QuantizationScheme, QuantizationStrategy, - QuantizationType, SymmetricQuantization, + AffineQuantization, QTensorPrimitive, QuantizationMode, QuantizationScheme, + QuantizationStrategy, QuantizationType, SymmetricQuantization, }, DType, Shape, TensorData, TensorMetadata, }; @@ -331,24 +331,22 @@ impl TchQTensor { /// Returns the quantization strategy, including quantization parameters, for the given tensor. pub fn strategy(&self) -> QuantizationStrategy { match &self.scheme { - QuantizationScheme::PerTensorAffine(dtype) => match dtype { - QuantizationType::QInt8 => { - let scale = self.qtensor.tensor.q_scale(); - let offset = self.qtensor.tensor.q_zero_point(); - QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init( - scale as f32, - offset as i8, - )) - } - }, - QuantizationScheme::PerTensorSymmetric(dtype) => match dtype { - QuantizationType::QInt8 => { - let scale = self.qtensor.tensor.q_scale(); - QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init( - scale as f32, - )) - } - }, + QuantizationScheme::PerTensor(QuantizationMode::Affine, QuantizationType::QInt8) => { + let scale = self.qtensor.tensor.q_scale(); + let offset = self.qtensor.tensor.q_zero_point(); + QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init( + scale as f32, + offset as i8, + )) + } + QuantizationScheme::PerTensor(QuantizationMode::Symmetric, QuantizationType::QInt8) => { + let scale = self.qtensor.tensor.q_scale(); + QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init( + scale as f32, + )) + } + // Only per-tensor and per-channel are supported + QuantizationScheme::PerBlock(..) 
=> unreachable!(), } } } @@ -375,7 +373,7 @@ mod tests { use super::*; use burn_tensor::ops::QTensorOps; - use burn_tensor::quantization::QuantizationParametersPrimitive; + use burn_tensor::quantization::{QuantizationMode, QuantizationParametersPrimitive}; use burn_tensor::{Distribution, Tensor, TensorPrimitive}; use rand::prelude::StdRng; use rand::SeedableRng; @@ -440,7 +438,8 @@ mod tests { fn should_support_qtensor_strategy() { let tensor = TchTensor::from_data::(TensorData::from([-1.8, -1.0, 0.0, 0.5]), tch::Device::Cpu); - let scheme = QuantizationScheme::PerTensorAffine(QuantizationType::QInt8); + let scheme = + QuantizationScheme::PerTensor(QuantizationMode::Affine, QuantizationType::QInt8); let qparams = QuantizationParametersPrimitive::> { scale: TchTensor::from_data::(TensorData::from([0.009_019_608]), tch::Device::Cpu), offset: Some(TchTensor::from_data::( diff --git a/crates/burn-tensor-testgen/Cargo.toml b/crates/burn-tensor-testgen/Cargo.toml index 50447bfbad..189bab58fd 100644 --- a/crates/burn-tensor-testgen/Cargo.toml +++ b/crates/burn-tensor-testgen/Cargo.toml @@ -14,3 +14,4 @@ proc-macro = true [dependencies] proc-macro2 = { workspace = true } quote = { workspace = true } +syn = { workspace = true } diff --git a/crates/burn-tensor-testgen/src/lib.rs b/crates/burn-tensor-testgen/src/lib.rs index ba72c66558..b0827d60fd 100644 --- a/crates/burn-tensor-testgen/src/lib.rs +++ b/crates/burn-tensor-testgen/src/lib.rs @@ -1,6 +1,136 @@ use proc_macro::TokenStream; use quote::{format_ident, quote}; +use syn::parse::{Parse, ParseStream}; +use syn::punctuated::Punctuated; +use syn::token::Comma; +use syn::{parse_macro_input, Attribute, Expr, ItemFn, Lit, Meta, MetaNameValue}; + +// Define a structure to parse the attribute arguments +struct AttributeArgs { + args: Punctuated, +} + +impl Parse for AttributeArgs { + fn parse(input: ParseStream) -> syn::Result { + Ok(AttributeArgs { + args: Punctuated::parse_terminated(input)?, + }) + } +} + +#[allow(clippy::test_attr_in_doctest)] +/// **This is only meaningful when the `reason` is specific and clear.** +/// +/// A proc macro attribute that adds panic handling to test functions. +/// +/// # Usage +/// ```rust, ignore +/// #[might_panic(reason = "expected panic message prefix")] +/// #[test] +/// fn test_that_might_panic() { +/// // test code that might panic (with acceptable reason) +/// } +/// ``` +/// +/// # Behavior +/// - If the test does not panic, it passes. +/// - If the test panics with a message starting with the expected prefix, the failure is ignored. +/// - If the test panics with a different message, the test fails. +/// +/// # Note +/// This proc macro uses [`std::panic::catch_unwind`]. As such, it does not work in a no-std environment. +/// Make sure it is feature gated when an `"std"` feature is available. +#[proc_macro_attribute] +pub fn might_panic(args: TokenStream, input: TokenStream) -> TokenStream { + // Parse the attribute arguments + let args = parse_macro_input!(args as AttributeArgs); + let input_fn = parse_macro_input!(input as ItemFn); + + // Extract the expected panic reason + let mut expected_reason = None; + for arg in args.args.iter() { + if let Meta::NameValue(MetaNameValue { path, value, .. 
}) = arg { + if path.is_ident("reason") { + if let Expr::Lit(lit) = value { + if let Lit::Str(ref lit_str) = lit.lit { + expected_reason = Some(lit_str.value()); + } + } + } + } + } + + let expected_reason = match expected_reason { + Some(reason) => reason, + None => { + return syn::Error::new( + proc_macro2::Span::call_site(), + "The #[might_panic] attribute requires a 'reason' parameter", + ) + .to_compile_error() + .into(); + } + }; + + let fn_name = &input_fn.sig.ident; + let fn_vis = &input_fn.vis; + let fn_generics = &input_fn.sig.generics; + let fn_block = &input_fn.block; + let fn_attrs = input_fn + .attrs + .iter() + .filter(|attr| !attr.path().is_ident("test")) + .collect::>(); + + // Create a wrapped test function + let wrapper_name = format_ident!("{}_might_panic", fn_name); + + let expanded = quote! { + #(#fn_attrs)* + #fn_vis fn #fn_name #fn_generics() { + #fn_block + } + + #[test] + #fn_vis fn #wrapper_name #fn_generics() { + use std::panic::{self, AssertUnwindSafe}; + + let expected_reason = #expected_reason; + let result = panic::catch_unwind(AssertUnwindSafe(|| { + #fn_name(); + })); + + match result { + Ok(_) => { + // Test passed without panic - this is OK + } + Err(e) => { + // Convert the panic payload to a string + let panic_msg = if let Some(s) = e.downcast_ref::() { + s.to_string() + } else if let Some(s) = e.downcast_ref::<&str>() { + s.to_string() + } else { + "Unknown panic".to_string() + }; + + // Check if the panic message starts with the expected reason + if !panic_msg.starts_with(expected_reason) { + panic!( + "Test '{}' marked as 'might_panic' failed. Expected reason: '{}'", + stringify!(#fn_name), + expected_reason + ); + } + } + } + } + }; + + expanded.into() +} + #[allow(missing_docs)] #[proc_macro_attribute] pub fn testgen(attr: TokenStream, item: TokenStream) -> TokenStream { diff --git a/crates/burn-tensor/src/lib.rs b/crates/burn-tensor/src/lib.rs index 68a293e5bb..4cd8c843f4 100644 --- a/crates/burn-tensor/src/lib.rs +++ b/crates/burn-tensor/src/lib.rs @@ -17,6 +17,10 @@ mod tensor; #[allow(missing_docs)] pub mod tests; +#[cfg(feature = "export_tests")] +// Re-export the might_panic proc macro for easy access +pub use burn_tensor_testgen::might_panic; + pub use half::{bf16, f16}; pub(crate) use tensor::check::macros::check; pub use tensor::*; diff --git a/crates/burn-tensor/src/tensor/api/float.rs b/crates/burn-tensor/src/tensor/api/float.rs index b50d0d0596..75ccfcf381 100644 --- a/crates/burn-tensor/src/tensor/api/float.rs +++ b/crates/burn-tensor/src/tensor/api/float.rs @@ -316,6 +316,9 @@ where /// # Returns /// /// The quantized tensor. + /// + /// # Notes + /// This uses [min-max calibration](crate::quantization::Calibration::MinMax). 
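+    ///
+    /// # Example
+    ///
+    /// A minimal usage sketch (assuming `tensor` is a float tensor already in scope;
+    /// the types come from `crate::quantization`):
+    ///
+    /// ```rust, ignore
+    /// let scheme = QuantizationScheme::PerTensor(QuantizationMode::Symmetric, QuantizationType::QInt8);
+    /// let x_q = tensor.quantize_dynamic(&scheme);
+    /// ```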
pub fn quantize_dynamic(self, scheme: &QuantizationScheme) -> Tensor { Tensor::new(TensorPrimitive::QFloat(B::quantize_dynamic( self.primitive.tensor(), diff --git a/crates/burn-tensor/src/tensor/data.rs b/crates/burn-tensor/src/tensor/data.rs index aa88894d60..346324c0e1 100644 --- a/crates/burn-tensor/src/tensor/data.rs +++ b/crates/burn-tensor/src/tensor/data.rs @@ -8,9 +8,7 @@ use bytemuck::{checked::CheckedCastError, AnyBitPattern}; use half::{bf16, f16}; use crate::{ - quantization::{ - Quantization, QuantizationScheme, QuantizationStrategy, QuantizationType, QuantizedBytes, - }, + quantization::{QuantizationScheme, QuantizationStrategy, QuantizationType, QuantizedBytes}, tensor::bytes::Bytes, DType, Distribution, Element, ElementConversion, }; @@ -261,8 +259,8 @@ impl TensorData { // bool is a byte value equal to either 0 or 1 DType::Bool => Box::new(self.bytes.iter().map(|e| e.elem::())), DType::QFloat(scheme) => match scheme { - QuantizationScheme::PerTensorAffine(QuantizationType::QInt8) - | QuantizationScheme::PerTensorSymmetric(QuantizationType::QInt8) => { + QuantizationScheme::PerTensor(_mode, QuantizationType::QInt8) + | QuantizationScheme::PerBlock(_mode, QuantizationType::QInt8, ..) => { // Quantized int8 values let q_bytes = QuantizedBytes { bytes: self.bytes.clone(), @@ -454,18 +452,8 @@ impl TensorData { DType::F32, "Only f32 data type can be quantized" ); - match &quantization { - QuantizationStrategy::PerTensorAffineInt8(strategy) => TensorData::quantized( - strategy.quantize(self.as_slice().unwrap()), - self.shape, - quantization, - ), - QuantizationStrategy::PerTensorSymmetricInt8(strategy) => TensorData::quantized( - strategy.quantize(self.as_slice().unwrap()), - self.shape, - quantization, - ), - } + let values = quantization.quantize(self.as_slice().unwrap(), &self.shape); + TensorData::quantized(values, self.shape, quantization) } /// Dequantizes the data according to its quantization scheme. 
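For reference, `TensorData::quantize` now delegates to the consolidated
`QuantizationStrategy::quantize`, which takes the tensor shape so that per-block
layouts can map values to their blocks. A minimal sketch of the new call (types from
`burn_tensor::quantization`; the scales are illustrative, taken from the tests below):

```rust, ignore
let strategy = QuantizationStrategy::PerBlockSymmetricInt8(
    vec![
        SymmetricQuantization::init(0.014_173_228),
        SymmetricQuantization::init(0.009_448_819),
    ],
    BlockLayout::Flat(4),
);
// Two flat blocks of 4 elements, each quantized with its own scale
let q: Vec<i8> = strategy.quantize(&[-1.8, -1.0, 0.0, 0.5, -0.8, 1.2, 0.25, 0.5], &[2, 4]);
```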
@@ -478,7 +466,7 @@ impl TensorData { num_elements, }; - let values = q_bytes.dequantize().0; + let values = q_bytes.dequantize(&self.shape).0; Ok(Self::new(values, self.shape)) } else { Err(DataError::TypeMismatch(format!( @@ -549,13 +537,19 @@ impl TensorData { }; match (q, q_other) { ( - QuantizationScheme::PerTensorAffine(QuantizationType::QInt8), - QuantizationScheme::PerTensorAffine(QuantizationType::QInt8), - ) - | ( - QuantizationScheme::PerTensorSymmetric(QuantizationType::QInt8), - QuantizationScheme::PerTensorSymmetric(QuantizationType::QInt8), - ) => self.assert_eq_elem::(other), + QuantizationScheme::PerTensor(mode, QuantizationType::QInt8), + QuantizationScheme::PerTensor(mode_other, QuantizationType::QInt8), + ) if mode == mode_other => self.assert_eq_elem::(other), + ( + QuantizationScheme::PerBlock(mode, QuantizationType::QInt8, layout), + QuantizationScheme::PerBlock( + mode_other, + QuantizationType::QInt8, + layout_other, + ), + ) if mode == mode_other && layout == layout_other => { + self.assert_eq_elem::(other) + } _ => panic!("Quantization schemes differ ({:?} != {:?})", q, q_other), } } @@ -838,8 +832,8 @@ impl core::fmt::Display for TensorData { DType::U8 => format!("{:?}", self.as_slice::().unwrap()), DType::Bool => format!("{:?}", self.as_slice::().unwrap()), DType::QFloat(scheme) => match scheme { - QuantizationScheme::PerTensorAffine(QuantizationType::QInt8) - | QuantizationScheme::PerTensorSymmetric(QuantizationType::QInt8) => { + QuantizationScheme::PerTensor(_mode, QuantizationType::QInt8) + | QuantizationScheme::PerBlock(_mode, QuantizationType::QInt8, ..) => { format!("{:?} {scheme:?}", self.try_as_slice::().unwrap()) } }, diff --git a/crates/burn-tensor/src/tensor/element/base.rs b/crates/burn-tensor/src/tensor/element/base.rs index 3c2b70f5ba..d645c0cedc 100644 --- a/crates/burn-tensor/src/tensor/element/base.rs +++ b/crates/burn-tensor/src/tensor/element/base.rs @@ -332,10 +332,10 @@ impl DType { DType::U8 => core::mem::size_of::(), DType::Bool => core::mem::size_of::(), DType::QFloat(scheme) => match scheme { - QuantizationScheme::PerTensorAffine(qtype) - | QuantizationScheme::PerTensorSymmetric(qtype) => match qtype { - QuantizationType::QInt8 => core::mem::size_of::(), - }, + QuantizationScheme::PerTensor(_mode, QuantizationType::QInt8) + | QuantizationScheme::PerBlock(_mode, QuantizationType::QInt8, ..) => { + core::mem::size_of::() + } }, } } diff --git a/crates/burn-tensor/src/tensor/ops/qtensor.rs b/crates/burn-tensor/src/tensor/ops/qtensor.rs index 781ed7c6eb..6b6e5d18af 100644 --- a/crates/burn-tensor/src/tensor/ops/qtensor.rs +++ b/crates/burn-tensor/src/tensor/ops/qtensor.rs @@ -3,7 +3,9 @@ use core::{future::Future, ops::Range}; use crate::{ backend::Backend, - quantization::{QTensorPrimitive, QuantizationParametersPrimitive, QuantizationScheme}, + quantization::{ + Calibration, QTensorPrimitive, QuantizationParametersPrimitive, QuantizationScheme, + }, Device, Shape, TensorData, TensorMetadata, }; @@ -65,8 +67,7 @@ pub trait QTensorOps { /// Dynamically convert the tensor to a lower precision data type based on the quantization scheme. 
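+    ///
+    /// The default implementation computes the `(min, max)` range with min-max
+    /// calibration, derives the quantization parameters from it, and then calls
+    /// `Self::quantize`.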
fn quantize_dynamic(tensor: FloatTensor, scheme: &QuantizationScheme) -> QuantizedTensor { // Dynamically compute min/max tensor range and qparams before quantizing - let min = B::float_min(tensor.clone()); - let max = B::float_max(tensor.clone()); + let (min, max) = scheme.compute_range_primitive::(tensor.clone(), &Calibration::MinMax); let qparams = scheme.compute_q_params_primitive(min, max); Self::quantize(tensor, scheme, qparams) } diff --git a/crates/burn-tensor/src/tensor/quantization/bytes.rs b/crates/burn-tensor/src/tensor/quantization/bytes.rs index 6d880cc923..f67971f929 100644 --- a/crates/burn-tensor/src/tensor/quantization/bytes.rs +++ b/crates/burn-tensor/src/tensor/quantization/bytes.rs @@ -4,8 +4,9 @@ use crate::{Bytes, Element}; use alloc::vec::Vec; use super::{ - pack_i8s_to_u32s, unpack_u32s_to_i8s, AffineQuantization, QParams, Quantization, - QuantizationScheme, QuantizationStrategy, QuantizationType, SymmetricQuantization, + pack_i8s_to_u32s, unpack_u32s_to_i8s, AffineQuantization, BlockLayout, QParams, + QuantizationMode, QuantizationScheme, QuantizationStrategy, QuantizationType, + SymmetricQuantization, }; /// Quantized data bytes representation. @@ -31,9 +32,10 @@ impl QuantizedBytes { pub fn new(value: Vec, strategy: QuantizationStrategy) -> Self { let mut bytes: Bytes; let num_elements = value.len(); + let scheme = strategy.scheme(); match strategy { - QuantizationStrategy::PerTensorAffineInt8(q) => { + QuantizationStrategy::PerTensorAffineInt8(quant) => { if TypeId::of::() == TypeId::of::() { // Re-interpret `Vec` as `Vec` with `Vec::from_raw_parts` let u32s = pack_i8s_to_u32s(bytemuck::allocation::cast_vec(value)); @@ -42,13 +44,13 @@ impl QuantizedBytes { panic!("Invalid quantized type"); } // Scale is always stored as f32 and zero-point offset as i32 - let offset = q.offset as i32; - let scale_bytes = bytemuck::bytes_of(&q.scale); + let offset = quant.offset as i32; + let scale_bytes = bytemuck::bytes_of(&quant.scale); let offset_bytes = bytemuck::bytes_of(&offset); bytes.extend_from_byte_slice_aligned(offset_bytes, align_of::()); bytes.extend_from_byte_slice_aligned(scale_bytes, align_of::()); } - QuantizationStrategy::PerTensorSymmetricInt8(q) => { + QuantizationStrategy::PerTensorSymmetricInt8(quant) => { if TypeId::of::() == TypeId::of::() { // Re-interpret `Vec` as `Vec` with `Vec::from_raw_parts` let u32s = pack_i8s_to_u32s(bytemuck::allocation::cast_vec(value)); @@ -56,44 +58,82 @@ impl QuantizedBytes { } else { panic!("Invalid quantized type"); } - let scale_bytes = bytemuck::bytes_of(&q.scale); + let scale_bytes = bytemuck::bytes_of(&quant.scale); bytes.extend_from_byte_slice_aligned(scale_bytes, align_of::()); } + QuantizationStrategy::PerBlockAffineInt8(quant, _layout) => { + if TypeId::of::() == TypeId::of::() { + // Re-interpret `Vec` as `Vec` with `Vec::from_raw_parts` + let u32s = pack_i8s_to_u32s(bytemuck::allocation::cast_vec(value)); + bytes = Bytes::from_elems(u32s); + } else { + panic!("Invalid quantized type"); + } + let mut scale_bytes = Vec::with_capacity(quant.len() * size_of::()); + let mut offset_bytes = Vec::with_capacity(quant.len() * size_of::()); + for q in quant { + scale_bytes.extend_from_slice(bytemuck::bytes_of(&q.scale)); + offset_bytes.extend_from_slice(bytemuck::bytes_of(&(q.offset as i32))); + } + bytes.extend_from_byte_slice_aligned(offset_bytes.as_slice(), align_of::()); + bytes.extend_from_byte_slice_aligned(scale_bytes.as_slice(), align_of::()); + } + QuantizationStrategy::PerBlockSymmetricInt8(quant, _layout) => { + 
if TypeId::of::<E>() == TypeId::of::<i8>() {
+                    // Re-interpret `Vec<i8>` as `Vec<u32>` with `Vec::from_raw_parts`
+                    let u32s = pack_i8s_to_u32s(bytemuck::allocation::cast_vec(value));
+                    bytes = Bytes::from_elems(u32s);
+                } else {
+                    panic!("Invalid quantized type");
+                }
+                let mut scale_bytes = Vec::with_capacity(quant.len() * size_of::<f32>());
+                for q in quant {
+                    scale_bytes.extend_from_slice(bytemuck::bytes_of(&q.scale));
+                }
+                bytes.extend_from_byte_slice_aligned(scale_bytes.as_slice(), align_of::<f32>());
+            }
         }
 
         Self {
             bytes,
-            scheme: strategy.scheme(),
+            scheme,
             num_elements,
         }
     }
 
     /// Returns the int8 quantized values with the quantization parameters.
-    pub fn into_vec_i8(self) -> (Vec<i8>, QParams<f32, i8>) {
+    pub fn into_vec_i8(self) -> (Vec<i8>, QParams<Vec<f32>, Vec<i8>>) {
         let numel = self.num_elements;
         let scheme = self.scheme;
-        let (values, qparams) = self.split_values_off();
+        let (values, (qparams, num_params)) = self.split_values_off();
 
         let values = unpack_u32s_to_i8s(values, numel);
 
         // Quantization parameters are added at the end of the tensor data.
-        // As such, the last bytes always correspond to the scale parameter.
-        // If the quantization scheme includes an offset (zero-point) parameter, it is next to last.
+        // As such, the last bytes always correspond to the scale parameter(s).
+        // If the quantization scheme includes zero-point offsets, they precede the scale bytes.
+        // For example, per-block quantization can have multiple parameters for a single tensor:
+        // [offset, offset, offset, ..., scale, scale, scale, ...]
         let scale_size = core::mem::size_of::<f32>(); // scale is stored as f32
-        let qparams_bytes = bytemuck::cast_slice(&qparams);
+        let qparams_bytes: &[u8] = bytemuck::cast_slice(&qparams);
         let total_bytes = qparams_bytes.len();
-        let scale = *bytemuck::checked::from_bytes(&qparams_bytes[total_bytes - scale_size..]);
-
-        let offset = match scheme {
-            QuantizationScheme::PerTensorAffine(_) => {
-                let offset_size = core::mem::size_of::<i32>(); // zero-point offset is stored as i32
-                Some(*bytemuck::checked::from_bytes::<i32>(
-                    &qparams_bytes
-                        [total_bytes - scale_size - offset_size..total_bytes - scale_size],
-                ) as i8)
-            }
-            QuantizationScheme::PerTensorSymmetric(_) => None,
-        };
+
+        let scales_size = scale_size * num_params;
+
+        let scale = bytemuck::cast_slice(&qparams_bytes[total_bytes - scales_size..]).to_vec();
+        let mut offset = None;
+
+        if scheme.mode() == QuantizationMode::Affine {
+            // Offsets come right before the scales: [offset, ..., offset, scale, ..., scale]
+            let offset_size = core::mem::size_of::<i32>(); // zero-point offset is stored as i32
+            let offsets_size = offset_size * num_params;
+            let offsets: &[i32] = bytemuck::cast_slice(
+                &qparams_bytes[total_bytes - scales_size - offsets_size..total_bytes - scales_size],
+            );
+
+            offset = Some(offsets.iter().map(|&x| x as i8).collect());
+        }
 
         (values, QParams { scale, offset })
     }
 
     /// Splits the quantized values of the tensor from the quantization parameters.
     ///
     /// Returns the packed values and a newly allocated vector containing the quantization parameters.
-    fn split_values_off(self) -> (Vec<u32>, Vec<u32>) {
+    fn split_values_off(self) -> (Vec<u32>, (Vec<u32>, usize)) {
         // The bytes can be created either from packed u32 or existing bytes with the same representation.
let mut values = match self.bytes.align() { 1 => { @@ -120,33 +160,79 @@ impl QuantizedBytes { _ => unreachable!(), }; - let scale_size = 1; // f32 scale is the same number of bytes as u32 + let (num_params, mode) = match self.scheme { + QuantizationScheme::PerTensor(mode, ..) => (1, mode), + QuantizationScheme::PerBlock(mode, _, layout) => { + let num_blocks = match layout { + BlockLayout::Flat(block_size) => self.num_elements / block_size as usize, + BlockLayout::Grid(m, n) => self.num_elements / (m * n) as usize, + }; + (num_blocks, mode) + } + }; + + let scale_size = num_params; // f32 scale is the same number of bytes as u32 let mut values_end = values.len() - scale_size; - if let QuantizationScheme::PerTensorAffine(QuantizationType::QInt8) = self.scheme { - values_end -= 1; // zero-point offset is stored as i32 (same number of bytes as u32) + if mode == QuantizationMode::Affine { + values_end -= num_params; // zero-point offset is stored as i32 (same number of bytes as u32) } let qparams = values.split_off(values_end); - (values, qparams) + (values, (qparams, num_params)) } /// Dequantizes the data according to its quantization scheme. - pub fn dequantize(self) -> (Vec, QParams) { + pub fn dequantize(self, shape: &[usize]) -> (Vec, QParams, Vec>) { match self.scheme { - QuantizationScheme::PerTensorAffine(QuantizationType::QInt8) => { + QuantizationScheme::PerTensor(QuantizationMode::Affine, QuantizationType::QInt8) => { + let (values, qparams) = self.into_vec_i8(); + let strategy = QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init( + qparams.scale[0], + qparams.offset.as_ref().unwrap()[0], + )); + (strategy.dequantize(&values, shape), qparams) + } + QuantizationScheme::PerTensor(QuantizationMode::Symmetric, QuantizationType::QInt8) => { + let (values, qparams) = self.into_vec_i8(); + let strategy = QuantizationStrategy::PerTensorSymmetricInt8( + SymmetricQuantization::init(qparams.scale[0]), + ); + (strategy.dequantize(&values, shape), qparams) + } + QuantizationScheme::PerBlock( + QuantizationMode::Affine, + QuantizationType::QInt8, + layout, + ) => { let (values, qparams) = self.into_vec_i8(); - let strategy = AffineQuantization::::init( - qparams.scale, - qparams.offset.unwrap(), + let strategy = QuantizationStrategy::PerBlockAffineInt8( + qparams + .scale + .iter() + .zip(qparams.offset.as_ref().unwrap().iter()) + .map(|(&s, &o)| AffineQuantization::init(s, o)) + .collect(), + layout, ); - (strategy.dequantize(&values), qparams) + (strategy.dequantize(&values, shape), qparams) } - QuantizationScheme::PerTensorSymmetric(QuantizationType::QInt8) => { + QuantizationScheme::PerBlock( + QuantizationMode::Symmetric, + QuantizationType::QInt8, + layout, + ) => { let (values, qparams) = self.into_vec_i8(); - let strategy = SymmetricQuantization::::init(qparams.scale); - (strategy.dequantize(&values), qparams) + let strategy = QuantizationStrategy::PerBlockSymmetricInt8( + qparams + .scale + .iter() + .map(|&s| SymmetricQuantization::init(s)) + .collect(), + layout, + ); + (strategy.dequantize(&values, shape), qparams) } } } @@ -184,11 +270,12 @@ unsafe fn reinterpret_vec(mut input: Vec) -> Vec { #[cfg(test)] mod tests { + use super::*; use alloc::vec; #[test] - fn should_pack_unpack_quantization_parameters_symmetric() { + fn should_pack_unpack_quantization_parameters_per_tensor_symmetric() { // Quantized [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]] let scale = 0.03937008; let values = vec![0i8, 25, 51, 76, 102, 127]; @@ -200,14 +287,14 @@ mod tests { let (q_values, qparams) = 
q_bytes.into_vec_i8(); - assert_eq!(qparams.scale, scale); + assert_eq!(qparams.scale, vec![scale]); assert_eq!(qparams.offset, None); assert_eq!(q_values, values); } #[test] - fn should_pack_unpack_quantization_parameters_affine() { + fn should_pack_unpack_quantization_parameters_per_tensor_affine() { let scale = 0.019607844; let offset = -128; // Quantized [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]] @@ -219,8 +306,33 @@ mod tests { let (q_values, qparams) = q_bytes.into_vec_i8(); + assert_eq!(qparams.scale, vec![scale]); + assert_eq!(qparams.offset, Some(vec![offset])); + + assert_eq!(q_values, values); + } + + #[test] + fn should_pack_unpack_quantization_parameters_per_block_symmetric() { + // Quantized 2x2 blocks: [[-1.8, -1.0, 0.0, 0.5], [-0.8, 1.2, 0.25, 0.5]] + let scale = vec![0.014_173_228, 0.009_448_819]; + let values = vec![-127i8, -71, -85, 127, 0, 35, 26, 53]; + + let q_bytes = QuantizedBytes::new( + values.clone(), + QuantizationStrategy::PerBlockSymmetricInt8( + vec![ + SymmetricQuantization::init(scale[0]), + SymmetricQuantization::init(scale[1]), + ], + BlockLayout::Grid(2, 2), + ), + ); + + let (q_values, qparams) = q_bytes.into_vec_i8(); + assert_eq!(qparams.scale, scale); - assert_eq!(qparams.offset, Some(offset)); + assert_eq!(qparams.offset, None); assert_eq!(q_values, values); } diff --git a/crates/burn-tensor/src/tensor/quantization/calibration.rs b/crates/burn-tensor/src/tensor/quantization/calibration.rs index c8060f6547..97a44fbcf5 100644 --- a/crates/burn-tensor/src/tensor/quantization/calibration.rs +++ b/crates/burn-tensor/src/tensor/quantization/calibration.rs @@ -3,32 +3,14 @@ use crate::{backend::Backend, Tensor}; /// The observed input calibration range. #[derive(Clone, Debug)] pub struct CalibrationRange { - /// Minimum observed value. + /// Minimum observed value(s). pub min: Tensor, - /// Maximum observed value. + /// Maximum observed value(s). pub max: Tensor, } /// Calibration method used to compute the quantization range mapping. -pub trait Calibration { - /// Compute the input tensor range. - fn compute_range( - &self, - tensor: &Tensor, - ) -> CalibrationRange; -} - -/// Computes the per-tensor quantization range mapping based on the min and max values. -pub struct MinMaxCalibration {} - -impl Calibration for MinMaxCalibration { - fn compute_range( - &self, - tensor: &Tensor, - ) -> CalibrationRange { - let min = tensor.clone().min(); - let max = tensor.clone().max(); - - CalibrationRange { min, max } - } +pub enum Calibration { + /// Computes quantization range mapping based on the min and max values. + MinMax, } diff --git a/crates/burn-tensor/src/tensor/quantization/scheme.rs b/crates/burn-tensor/src/tensor/quantization/scheme.rs index 71cc6de81c..761b9e6123 100644 --- a/crates/burn-tensor/src/tensor/quantization/scheme.rs +++ b/crates/burn-tensor/src/tensor/quantization/scheme.rs @@ -2,9 +2,11 @@ use serde::{Deserialize, Serialize}; -use crate::{backend::Backend, Tensor, TensorPrimitive}; +use crate::{backend::Backend, Shape, Tensor, TensorMetadata, TensorPrimitive}; -use super::{CalibrationRange, QuantizationParameters, QuantizationParametersPrimitive}; +use super::{ + Calibration, CalibrationRange, QuantizationParameters, QuantizationParametersPrimitive, +}; #[cfg(feature = "cubecl")] use cubecl::prelude::*; @@ -17,18 +19,34 @@ pub enum QuantizationType { QInt8, } +// CubeType not implemented for usize +/// Block quantization layout. 
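+///
+/// For example, an `8x8` tensor with `Flat(16)` is quantized as 4 blocks of 16
+/// contiguous values, while `Grid(4, 4)` quantizes each of its 4 `4x4` tiles with
+/// its own parameters.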
+#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)] +#[cfg_attr(feature = "cubecl", derive(CubeType, PartialOrd, Ord))] +pub enum BlockLayout { + /// The tensor is split into linear segments of N elements. + Flat(u32), + /// The tensor is split into segments of M x N elements. + Grid(u32, u32), +} +/// Quantization mode. +#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)] +#[cfg_attr(feature = "cubecl", derive(PartialOrd, Ord))] +pub enum QuantizationMode { + /// Affine or asymmetric quantization. + Affine, + /// Symmetric or scale quantization. + Symmetric, +} + /// Quantization scheme. #[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)] #[cfg_attr(feature = "cubecl", derive(PartialOrd, Ord))] pub enum QuantizationScheme { - /// Per-tensor affine/asymmetric quantization. - PerTensorAffine(QuantizationType), - /// Per-tensor symmetric quantization. - PerTensorSymmetric(QuantizationType), - // /// Per-channel affine/asymmetric quantization. - // PerChannelAffine, - // /// Per-channel symmetric quantization. - // PerChannelSymmetric, + /// Per-tensor quantization. + PerTensor(QuantizationMode, QuantizationType), + /// Per-block quantization. + PerBlock(QuantizationMode, QuantizationType, BlockLayout), } #[cfg(feature = "cubecl")] @@ -47,49 +65,165 @@ impl cubecl::frontend::Init for QuantizationScheme { } impl QuantizationScheme { + /// Get the [quantization mode](QuantizationMode) + pub fn mode(&self) -> QuantizationMode { + match self { + QuantizationScheme::PerTensor(mode, ..) | QuantizationScheme::PerBlock(mode, ..) => { + *mode + } + } + } + + /// Compute the quantization range mapping. + pub fn compute_range( + &self, + tensor: &Tensor, + calibration: &Calibration, + ) -> CalibrationRange { + let (min, max) = match &tensor.primitive { + TensorPrimitive::Float(tensor) => { + self.compute_range_primitive::(tensor.clone(), calibration) + } + TensorPrimitive::QFloat(_) => unreachable!(), + }; + + CalibrationRange { + min: Tensor::from_primitive(TensorPrimitive::Float(min)), + max: Tensor::from_primitive(TensorPrimitive::Float(max)), + } + } + + pub(crate) fn compute_range_primitive( + &self, + tensor: B::FloatTensorPrimitive, + calibration: &Calibration, + ) -> (B::FloatTensorPrimitive, B::FloatTensorPrimitive) { + match calibration { + Calibration::MinMax => match self { + QuantizationScheme::PerTensor(_, _) => { + (B::float_min(tensor.clone()), B::float_max(tensor)) + } + QuantizationScheme::PerBlock(.., layout) => match layout { + // For per-block quantization, we can compute the (min, max) range with pooling + BlockLayout::Flat(block_size) => { + let block_size = *block_size as usize; + // Tensor shape must be divisible by block size + let shape = tensor.shape(); + let numel = shape.num_elements(); + assert_eq!( + numel % block_size, 0, + "Cannot compute per-block quantization range with block size {block_size} and tensor of shape {shape:?}" + ); + let num_blocks = numel / block_size; + + let tensor = B::float_reshape(tensor, Shape::new([num_blocks, block_size])); + let min = B::float_reshape( + B::float_min_dim(tensor.clone(), 1), + Shape::new([num_blocks]), + ); + let max = + B::float_reshape(B::float_max_dim(tensor, 1), Shape::new([num_blocks])); + // Tensors with shape [b * num_blocks] + (min, max) + } + BlockLayout::Grid(m, n) => { + let (m, n) = (*m as usize, *n as usize); + let shape = tensor.shape(); + let (b, h, w) = match shape.num_dims() { + 2 => { + let [h, w] = shape.dims(); + (1, h, w) + } + 3 => { + 
let [b, h, w] = shape.dims(); // leading batch dim + (b, h, w) + } + _ => unimplemented!( + "Per-block grid quantization is only supported for 2D or 3D tensors" + ), + }; + // For optimized dynamic quantization, we probably want a custom kernel that computes the + // (min, max) range to quantize each block on-the-fly. + // For static quantization, it doesn't really matter. + assert!( + h % m == 0 && w % n == 0, + "Cannot compute per-block quantization range with block grid [{m}, {n}] and tensor of shape {shape:?}" + ); + let num_blocks_h = h / m; + let num_blocks_w = w / n; + + // Max and min pooling + let reshaped = B::float_reshape(tensor, Shape::new([b, 1, h, w])); + let max = B::max_pool2d(reshaped.clone(), [m, n], [m, n], [0, 0], [1, 1]); + let min = B::float_neg(B::max_pool2d( + B::float_neg(reshaped), + [m, n], + [m, n], + [0, 0], + [0, 0], + )); + + // Tensors with shape [b * num_blocks_h * num_blocks_w] + let out_shape = Shape::new([b * num_blocks_h * num_blocks_w]); + ( + B::float_reshape(min, out_shape.clone()), + B::float_reshape(max, out_shape), + ) + } + }, + }, + } + } + /// Compute the quantization parameters. pub fn compute_q_params( &self, range: CalibrationRange, ) -> QuantizationParameters { + // Quantization parameters are computed element-wise based on the calibration range, + // so it's the same operations for per-tensor and per-block (just that the latter has + // more parameters) match self { - QuantizationScheme::PerTensorAffine(dtype) => match dtype { - QuantizationType::QInt8 => { - // Quantized range `[a, b]` - let a = i8::MIN as i32; - let b = i8::MAX as i32; - - // We extend the `[min, max]` interval to ensure that it contains 0. - // Otherwise, we would not meet the requirement that 0 be an exactly - // representable value (zero-point). - let zero = Tensor::zeros_like(&range.min); - let min = range.min.min_pair(zero); - let zero = Tensor::zeros_like(&range.max); - let max = range.max.max_pair(zero); - - // If scale is 0 (most likely due to a tensor full of zeros), we arbitrarily adjust the - // scale to 0.1 to avoid division by zero. - let scale = max.sub(min.clone()).div_scalar(b - a); - let scale = scale.clone().mask_fill(scale.equal_elem(0.), 0.1); - let offset = Some(-(min.div(scale.clone()).sub_scalar(a)).int()); - QuantizationParameters { scale, offset } - } - }, - QuantizationScheme::PerTensorSymmetric(dtype) => match dtype { - QuantizationType::QInt8 => { - // Quantized range `[a, b]` - let b = i8::MAX as i32; - let a = -b; - - // Compute scale to convert an input value in range `[-alpha, alpha]` - let values_range = range.min.abs().max_pair(range.max.abs()).mul_scalar(2); - - QuantizationParameters { - scale: values_range.div_scalar(b - a), - offset: None, - } + QuantizationScheme::PerTensor(QuantizationMode::Affine, QuantizationType::QInt8) + | QuantizationScheme::PerBlock(QuantizationMode::Affine, QuantizationType::QInt8, ..) => + { + // Quantized range `[a, b]` + let a = i8::MIN as i32; + let b = i8::MAX as i32; + + // We extend the `[min, max]` interval to ensure that it contains 0. + // Otherwise, we would not meet the requirement that 0 be an exactly + // representable value (zero-point). + let zero = Tensor::zeros_like(&range.min); + let min = range.min.min_pair(zero); + let zero = Tensor::zeros_like(&range.max); + let max = range.max.max_pair(zero); + + // If scale is 0 (most likely due to a tensor full of zeros), we arbitrarily adjust the + // scale to 0.1 to avoid division by zero. 
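+                // e.g. for a tensor with values in [-1.8, 0.5]:
+                // scale = (0.5 - (-1.8)) / 255 ≈ 0.00902 and offset = -(-1.8 / 0.00902 - (-128)) ≈ 72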
+ let scale = max.sub(min.clone()).div_scalar(b - a); + let scale = scale.clone().mask_fill(scale.equal_elem(0.), 0.1); + let offset = Some(-(min.div(scale.clone()).sub_scalar(a)).int()); + QuantizationParameters { scale, offset } + } + QuantizationScheme::PerTensor(QuantizationMode::Symmetric, QuantizationType::QInt8) + | QuantizationScheme::PerBlock( + QuantizationMode::Symmetric, + QuantizationType::QInt8, + .., + ) => { + // Quantized range `[a, b]` + let b = i8::MAX as i32; + let a = -b; + + // Compute scale to convert an input value in range `[-alpha, alpha]` + let values_range = range.min.abs().max_pair(range.max.abs()).mul_scalar(2); + + QuantizationParameters { + scale: values_range.div_scalar(b - a), + offset: None, } - }, + } } } diff --git a/crates/burn-tensor/src/tensor/quantization/strategy.rs b/crates/burn-tensor/src/tensor/quantization/strategy.rs index 73f1b1c0b0..e2b051f8d8 100644 --- a/crates/burn-tensor/src/tensor/quantization/strategy.rs +++ b/crates/burn-tensor/src/tensor/quantization/strategy.rs @@ -1,22 +1,147 @@ -use core::{ - hash::{Hash, Hasher}, - marker::PhantomData, -}; - -use alloc::vec::Vec; +use alloc::{vec, vec::Vec}; use burn_common::{iter_slice_par, run_par}; -use num_traits::{Float, PrimInt}; +use core::marker::PhantomData; +use num_traits::{Float, PrimInt, Signed}; use serde::{Deserialize, Serialize}; -use super::{QuantizationScheme, QuantizationType}; +use crate::{Element, ElementConversion}; + +use super::{BlockLayout, QuantizationMode, QuantizationScheme, QuantizationType}; /// Quantization strategy. -#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub enum QuantizationStrategy { /// Per-tensor `int8` affine/asymmetric quantization. PerTensorAffineInt8(AffineQuantization), /// Per-tensor `int8` symmetric quantization. PerTensorSymmetricInt8(SymmetricQuantization), + /// Per-block `int8` affine/asymmetric quantization. + PerBlockAffineInt8(Vec>, BlockLayout), + /// Per-block `int8` symmetric quantization. + PerBlockSymmetricInt8(Vec>, BlockLayout), +} + +impl QuantizationStrategy { + /// Quantize the values to a lower precision data type. + pub fn quantize(&self, values: &[f32], shape: &[usize]) -> Vec { + match self { + QuantizationStrategy::PerTensorAffineInt8(strategy) => strategy.quantize(values), + QuantizationStrategy::PerTensorSymmetricInt8(strategy) => strategy.quantize(values), + QuantizationStrategy::PerBlockAffineInt8(strategy, layout) => match layout { + BlockLayout::Flat(block_size) => { + apply_per_block(values, *block_size as usize, |block_id, v| { + strategy[block_id].quantize(v) + }) + } + BlockLayout::Grid(m, n) => { + apply_per_block_grid(values, shape, *m as usize, *n as usize, |block_id, v| { + strategy[block_id].quantize_one(v) + }) + } + }, + QuantizationStrategy::PerBlockSymmetricInt8(strategy, layout) => match layout { + BlockLayout::Flat(block_size) => { + apply_per_block(values, *block_size as usize, |block_id, v| { + strategy[block_id].quantize(v) + }) + } + BlockLayout::Grid(m, n) => { + apply_per_block_grid(values, shape, *m as usize, *n as usize, |block_id, v| { + strategy[block_id].quantize_one(v) + }) + } + }, + } + } + + /// Dequantize the values to a higher precision data type. 
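+    ///
+    /// The `shape` is only used by per-block grid layouts to map values back to their
+    /// blocks; per-tensor and flat per-block strategies ignore it.
+    ///
+    /// A round-trip sketch (illustrative scale):
+    ///
+    /// ```rust, ignore
+    /// let strategy = QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.1));
+    /// let q = strategy.quantize(&[1.0, -0.5], &[2]);
+    /// let d = strategy.dequantize(&q, &[2]); // ~[1.0, -0.5]
+    /// ```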
+ pub fn dequantize(&self, values: &[i8], shape: &[usize]) -> Vec { + match self { + QuantizationStrategy::PerTensorAffineInt8(strategy) => strategy.dequantize(values), + QuantizationStrategy::PerTensorSymmetricInt8(strategy) => strategy.dequantize(values), + QuantizationStrategy::PerBlockAffineInt8(strategy, layout) => match layout { + BlockLayout::Flat(block_size) => { + apply_per_block(values, *block_size as usize, |block_id, v| { + strategy[block_id].dequantize(v) + }) + } + BlockLayout::Grid(m, n) => { + apply_per_block_grid(values, shape, *m as usize, *n as usize, |block_id, v| { + strategy[block_id].dequantize_one(v) + }) + } + }, + QuantizationStrategy::PerBlockSymmetricInt8(strategy, layout) => match layout { + BlockLayout::Flat(block_size) => { + apply_per_block(values, *block_size as usize, |block_id, v| { + strategy[block_id].dequantize(v) + }) + } + BlockLayout::Grid(m, n) => { + apply_per_block_grid(values, shape, *m as usize, *n as usize, |block_id, v| { + strategy[block_id].dequantize_one(v) + }) + } + }, + } + } +} + +fn apply_per_block_grid O>( + values: &[I], + shape: &[usize], + m: usize, + n: usize, + transform: F, +) -> Vec { + let (b, height, width) = match shape.len() { + 2 => (1, shape[0], shape[1]), + 3 => (shape[0], shape[1], shape[2]), + _ => unimplemented!("Per-block grid quantization is only supported for 2D or 3D tensors"), + }; + assert!( + height % m == 0 && width % n == 0, + "Invalid per-block quantization with block grid [{m}, {n}] and tensor of shape {shape:?}" + ); + let mut output = vec![0.elem::(); values.len()]; + + let mut block_id = 0; + // TODO: parallel + for ih in (0..b * height).step_by(m) { + for iw in (0..width).step_by(n) { + // block height + for bh in 0..m { + let start_idx = (ih + bh) * width + iw; + // block width + for bw in 0..n { + let elem_idx = start_idx + bw; + let x_q = transform(block_id, values[elem_idx]); + output[elem_idx] = x_q; + } + } + block_id += 1; + } + } + output +} + +fn apply_per_block Vec>( + values: &[I], + block_size: usize, + transform: F, +) -> Vec { + let numel = values.len(); + assert_eq!( + numel % block_size, + 0, + "Invalid per-block quantization with block size {block_size} and {numel} values" + ); + // TODO: parallel chunks + values + .chunks(block_size) + .enumerate() + .flat_map(|(block_id, block)| transform(block_id, block)) + .collect() } impl QuantizationStrategy { @@ -24,11 +149,21 @@ impl QuantizationStrategy { pub fn scheme(&self) -> QuantizationScheme { match self { QuantizationStrategy::PerTensorAffineInt8(_) => { - QuantizationScheme::PerTensorAffine(QuantizationType::QInt8) + QuantizationScheme::PerTensor(QuantizationMode::Affine, QuantizationType::QInt8) } QuantizationStrategy::PerTensorSymmetricInt8(_) => { - QuantizationScheme::PerTensorSymmetric(QuantizationType::QInt8) + QuantizationScheme::PerTensor(QuantizationMode::Symmetric, QuantizationType::QInt8) } + QuantizationStrategy::PerBlockSymmetricInt8(_, layout) => QuantizationScheme::PerBlock( + QuantizationMode::Symmetric, + QuantizationType::QInt8, + *layout, + ), + QuantizationStrategy::PerBlockAffineInt8(_, layout) => QuantizationScheme::PerBlock( + QuantizationMode::Affine, + QuantizationType::QInt8, + *layout, + ), } } } @@ -36,12 +171,18 @@ impl QuantizationStrategy { /// Quantization scheme to convert elements of a higher precision data type `E` to a lower precision /// data type `Q` and vice-versa. pub trait Quantization { + /// Returns the quantization range `[a, b]`. 
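+    ///
+    /// For `i8`, this is `[-128, 127]` in the affine scheme and `[-127, 127]` in the
+    /// symmetric scheme, which keeps the symmetric range centered on zero.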
+ fn range() -> (Q, Q); /// Create a new quantization scheme for an input range `[alpha, beta]`. fn new(alpha: E, beta: E) -> Self; /// Convert the values to a lower precision data type. fn quantize(&self, values: &[E]) -> Vec; + /// Convert a single value to a lower precision data type. + fn quantize_one(&self, value: E) -> Q; /// Convert the values back to a higher precision data type. fn dequantize(&self, values: &[Q]) -> Vec; + /// Convert a single value back to a higher precision data type. + fn dequantize_one(&self, value: Q) -> E; } /// Affine quantization scheme. @@ -81,9 +222,9 @@ impl for AffineQuantization { fn new(alpha: E, beta: E) -> Self { - // Q range `[a, b]` - let a = E::from(Q::min_value()).unwrap(); - let b = E::from(Q::max_value()).unwrap(); + let (a, b) = Self::range(); + let a = E::from(a).unwrap(); + let b = E::from(b).unwrap(); // We extend the `[alpha, beta]` interval to ensure that it contains 0. // Otherwise, we would not meet the requirement that 0 be an exactly @@ -102,47 +243,57 @@ impl } fn quantize(&self, values: &[E]) -> Vec { - // Quantized range `[a, b]` - let a = E::from(Q::min_value()).unwrap(); - let b = E::from(Q::max_value()).unwrap(); - - // x_q = clamp(round(x / scale + offset), a, b) - let z = E::from(self.offset).unwrap(); run_par!(|| { iter_slice_par!(values) - .map(|x| Q::from(x.div(self.scale).add(z).round().clamp(a, b)).unwrap()) + .map(|x| self.quantize_one(*x)) .collect() }) } fn dequantize(&self, values: &[Q]) -> Vec { - // x = scale * (x_q - offset) run_par!(|| { iter_slice_par!(values) - .map(|x_q| { - self.scale - * (E::from( - A::from(*x_q) - .unwrap() - .saturating_sub(A::from(self.offset).unwrap()), - ) - .unwrap()) - }) + .map(|x_q| self.dequantize_one(*x_q)) .collect() }) } + + fn quantize_one(&self, value: E) -> Q { + let (a, b) = Self::range(); + let a = E::from(a).unwrap(); + let b = E::from(b).unwrap(); + + // x_q = clamp(round(x / scale + offset), a, b) + let z = E::from(self.offset).unwrap(); + Q::from(value.div(self.scale).add(z).round().clamp(a, b)).unwrap() + } + + fn dequantize_one(&self, value: Q) -> E { + // x = scale * (x_q - offset) + self.scale + * (E::from( + A::from(value) + .unwrap() + .saturating_sub(A::from(self.offset).unwrap()), + ) + .unwrap()) + } + + fn range() -> (Q, Q) { + (Q::min_value(), Q::max_value()) + } } /// Symmetric quantization scheme. #[derive(Debug, Clone, Copy, Serialize, Deserialize)] -pub struct SymmetricQuantization { +pub struct SymmetricQuantization { /// The scaling factor. pub scale: E, /// The quantized type. _q: PhantomData, } -impl SymmetricQuantization { +impl SymmetricQuantization { /// Initialize a symmetric quantization scheme with the given parameters. pub fn init(scale: E) -> Self { Self { @@ -152,18 +303,13 @@ impl SymmetricQuantization Quantization +impl Quantization for SymmetricQuantization { fn new(alpha: E, beta: E) -> Self { - assert!( - !Q::min_value().is_zero(), - "Symmetric quantization is only valid for signed integers." 
- ); - - // Quantized range `[a, b]` - let b = E::from(Q::max_value()).unwrap(); - let a = b.neg(); + let (a, b) = Self::range(); + let a = E::from(a).unwrap(); + let b = E::from(b).unwrap(); // Compute scale to convert a floating point value in range `[-alpha, alpha]` to the quantized range let alpha = alpha.abs().max(beta.abs()); @@ -175,57 +321,31 @@ impl Quantization } fn quantize(&self, values: &[E]) -> Vec { - // Quantized range [a, b] - let b = E::from(Q::max_value()).unwrap(); - let a = b.neg(); - - // x_q = clamp(round(x / scale), a, b) - values - .iter() - .map(|x| Q::from(x.div(self.scale).round().clamp(a, b)).unwrap()) - .collect() + values.iter().map(|x| self.quantize_one(*x)).collect() } fn dequantize(&self, values: &[Q]) -> Vec { - // x = scale * x_q - values - .iter() - .map(|x_q| self.scale * E::from(*x_q).unwrap()) - .collect() + values.iter().map(|x_q| self.dequantize_one(*x_q)).collect() } -} -// Masks for the parts of the IEEE 754 float -const SIGN_MASK: u64 = 0x8000000000000000u64; -const EXP_MASK: u64 = 0x7ff0000000000000u64; -const MAN_MASK: u64 = 0x000fffffffffffffu64; - -#[inline] -/// Used for hashing. Input must not be zero or NaN. -/// Adapted from: https://github.com/reem/rust-ordered-float/blob/master/src/lib.rs -fn raw_double_bits(f: &F) -> u64 { - let (man, exp, sign) = f.integer_decode(); - let exp_u64 = exp as u16 as u64; - let sign_u64 = (sign > 0) as u64; - (man & MAN_MASK) | ((exp_u64 << 52) & EXP_MASK) | ((sign_u64 << 63) & SIGN_MASK) -} + fn quantize_one(&self, value: E) -> Q { + let (a, b) = Self::range(); + let a = E::from(a).unwrap(); + let b = E::from(b).unwrap(); -#[inline(always)] -fn canonicalize_signed_zero(x: T) -> T { - // -0.0 + 0.0 == +0.0 under IEEE754 roundTiesToEven rounding mode, - // which Rust guarantees. Thus by adding a positive zero we - // canonicalize signed zero without any branches in one instruction. - x + T::zero() -} + // x_q = clamp(round(x / scale), a, b) + Q::from(value.div(self.scale).round().clamp(a, b)).unwrap() + } -impl Hash - for AffineQuantization -{ - fn hash(&self, state: &mut H) { - // Hash raw bits. - let bits = raw_double_bits(&canonicalize_signed_zero(self.scale)); - bits.hash(state); - self.offset.hash(state); + fn dequantize_one(&self, value: Q) -> E { + // x = scale * x_q + self.scale * E::from(value).unwrap() + } + + fn range() -> (Q, Q) { + // Only implemented for symmetric *signed* at this time + let b = Q::max_value(); + (b.neg(), b) } } @@ -242,26 +362,19 @@ impl Eq { } -impl Hash for SymmetricQuantization { - fn hash(&self, state: &mut H) { - // Hash raw bits. 
- let bits = raw_double_bits(&canonicalize_signed_zero(self.scale)); - bits.hash(state); - } -} - -impl PartialEq for SymmetricQuantization { +impl PartialEq + for SymmetricQuantization +{ fn eq(&self, other: &Self) -> bool { self.scale == other.scale } } -impl Eq for SymmetricQuantization {} +impl Eq for SymmetricQuantization {} #[cfg(test)] mod tests { use super::*; - use alloc::vec; #[test] fn test_int8_affine_quantization() { @@ -313,4 +426,158 @@ mod tests { assert_eq!(d, expected_d); } + + #[test] + fn test_int8_symmetric_quantization_per_block_flat() { + let x: [f32; 8] = [-1.8, -1.0, 0.0, 0.5, -1.8, -1.0, 0.0, 0.5]; + let shape = &[2, 4]; + let expected_q = vec![-127, -71, 0, 35, -127, -71, 0, 35]; + let expected_d = vec![ + -1.8, -1.0062993, 0.0, 0.496063, -1.8, -1.0062993, 0.0, 0.496063, + ]; + + let symmetric = SymmetricQuantization::::new(-1.8, 0.5); + let strategy = QuantizationStrategy::PerBlockSymmetricInt8( + vec![symmetric.clone(), symmetric], + BlockLayout::Flat(4), + ); + + let q: Vec = strategy.quantize(&x, shape); + assert_eq!(q, expected_q); + + let d = symmetric.dequantize(&expected_q); + + assert_eq!(d, expected_d); + } + + #[test] + fn test_int8_affine_quantization_per_block_flat() { + let x = vec![ + [-1.8, -1.0, 0.0, 0.5], + [-0.8, 1.2, 0.25, 0.5], + [-8., 12., 2.5, 5.], + [0.2, 0.3, 0.4, 0.5], + ] + .concat(); + let shape = &[2, 8]; + let expected_q = vec![ + [-128i8, -40, 71, 126], + [-128, 127, 6, 38], + [-128, 127, 6, 38], + [-26, 25, 76, 127], + ] + .concat(); + let expected_d = vec![ + [-1.794902, -1.0011765, 0.0, 0.49607843], + [-0.8000001, 1.2, 0.2509804, 0.5019608], + [-8.0, 12.0, 2.509804, 5.019608], + [0.20000002, 0.3, 0.40000004, 0.5], + ] + .concat(); + + // Affine quantization for each block with range min/max + let per_block_strategy = vec![ + AffineQuantization::::new(-1.8, 0.5), + AffineQuantization::::new(-0.8, 1.2), + AffineQuantization::::new(-8., 12.), + AffineQuantization::::new(0.2, 0.5), + ]; + let strategy = + QuantizationStrategy::PerBlockAffineInt8(per_block_strategy, BlockLayout::Flat(4)); + + let q: Vec = strategy.quantize(&x, shape); + assert_eq!(q, expected_q); + + let d = strategy.dequantize(&expected_q, shape); + + assert_eq!(d, expected_d); + } + + #[test] + fn test_int8_symmetric_quantization_per_block_grid() { + let x: [f32; 8] = [-1.8, -1.0, 0.0, 0.5, 0.5, 0.0, -1.0, -1.8]; + let shape = &[2, 4]; + let expected_q = vec![-127, -71, 0, 35, 35, 0, -71, -127]; + let expected_d = vec![ + -1.8, -1.0062993, 0.0, 0.496063, 0.496063, 0.0, -1.0062993, -1.8, + ]; + + let symmetric = SymmetricQuantization::::new(-1.8, 0.5); + let strategy = QuantizationStrategy::PerBlockSymmetricInt8( + vec![symmetric.clone(), symmetric], + BlockLayout::Grid(2, 2), + ); + + let q: Vec = strategy.quantize(&x, shape); + assert_eq!(q, expected_q); + + let d = strategy.dequantize(&expected_q, shape); + + assert_eq!(d, expected_d); + } + + #[test] + fn test_int8_symmetric_quantization_per_block_grid_3d() { + let shape = &[2, 4, 4]; + let x = vec![ + // 2x2 blocks: [[-1.8, -1.0, 0.0, 0.5], [-0.8, 1.2, 0.25, 0.5]] + [-1.8, -1.0, -0.8, 1.2], + [0.0, 0.5, 0.25, 0.5], + // 2x2 blocks: [[-0.08, 0.12, 0.025, 0.05], [0.2, 0.3, 0.4, 0.5]] + [-0.08, 0.12, 0.2, 0.3], + [0.025, 0.05, 0.4, 0.5], + // 2x2 blocks: [[0.01, 0.03, 0.02, 0.06], [4.0, 3.0, 2.0, 1.0]] + [0.01, 0.03, 4.0, 3.0], + [0.02, 0.06, 2.0, 1.0], + // 2x2 blocks: [[0.4, 0.3, 0.2, 0.1], [0.5, 0.0, -1.0, -1.8]] + [0.4, 0.3, 0.5, 0.0], + [0.2, 0.1, -1.0, -1.8], + ] + .concat(); // easier to visualize with a vec 
of rows + let expected_q = vec![ + [-127, -71, -85, 127], + [0, 35, 26, 53], + [-85, 127, 51, 76], + [26, 53, 102, 127], + [21, 64, 127, 95], + [42, 127, 64, 32], + [127, 95, 35, 0], + [64, 32, -71, -127], + ] + .concat(); + let expected_d = vec![ + [-1.8, -1.0062993, -0.8031496, 1.2], + [0.0, 0.496063, 0.24566929, 0.5007874], + [-0.08031496, 0.12, 0.2007874, 0.2992126], + [0.024566928, 0.05007874, 0.4015748, 0.5], + [0.009921259, 0.03023622, 4.0, 2.992126], + [0.019842518, 0.06, 2.015748, 1.007874], + [0.4, 0.2992126, 0.496063, 0.0], + [0.2015748, 0.1007874, -1.0062993, -1.8], + ] + .concat(); + + // Symmetric quantization for each block with range min/max + let per_block_strategy = vec![ + SymmetricQuantization::::new(-1.8, 0.5), + SymmetricQuantization::::new(-0.8, 1.2), + SymmetricQuantization::::new(-0.08, 0.12), + SymmetricQuantization::::new(0.2, 0.5), + SymmetricQuantization::::new(0.01, 0.06), + SymmetricQuantization::::new(1.0, 4.0), + SymmetricQuantization::::new(0.1, 0.4), + SymmetricQuantization::::new(-1.8, 0.5), + ]; + let strategy = QuantizationStrategy::PerBlockSymmetricInt8( + per_block_strategy, + BlockLayout::Grid(2, 2), + ); + + let q: Vec = strategy.quantize(&x, shape); + assert_eq!(q, expected_q); + + let d = strategy.dequantize(&expected_q, shape); + + assert_eq!(d, expected_d); + } } diff --git a/crates/burn-tensor/src/tests/mod.rs b/crates/burn-tensor/src/tests/mod.rs index ee9aec9fe8..1dfab07b6f 100644 --- a/crates/burn-tensor/src/tests/mod.rs +++ b/crates/burn-tensor/src/tests/mod.rs @@ -84,7 +84,7 @@ macro_rules! testgen_quantization { use burn_tensor::{ backend::Backend, - quantization::{QuantizationScheme, QuantizationType}, + quantization::{QuantizationMode, QuantizationScheme, QuantizationType}, Tensor, TensorData, }; @@ -101,13 +101,19 @@ macro_rules! testgen_quantization { /// Creates a quantized int8 tensor from the floating point data using per-tensor symmetric quantization. pub fn int8_symmetric>(floats: F) -> Tensor { Tensor::from_floats(floats, &Default::default()).quantize_dynamic( - &QuantizationScheme::PerTensorSymmetric(QuantizationType::QInt8), + &QuantizationScheme::PerTensor( + QuantizationMode::Symmetric, + QuantizationType::QInt8, + ), ) } /// Creates a quantized int8 tensor from the floating point data using per-tensor affine quantization. pub fn int8_affine>(floats: F) -> Tensor { Tensor::from_floats(floats, &Default::default()).quantize_dynamic( - &QuantizationScheme::PerTensorAffine(QuantizationType::QInt8), + &QuantizationScheme::PerTensor( + QuantizationMode::Affine, + QuantizationType::QInt8, + ), ) } } @@ -118,6 +124,7 @@ macro_rules! 
testgen_quantization { burn_tensor::testgen_calibration!(); burn_tensor::testgen_scheme!(); burn_tensor::testgen_quantize!(); + burn_tensor::testgen_q_data!(); // test ops burn_tensor::testgen_q_abs!(); diff --git a/crates/burn-tensor/src/tests/quantization/calibration.rs b/crates/burn-tensor/src/tests/quantization/calibration.rs index 8140be4f28..dab3a39590 100644 --- a/crates/burn-tensor/src/tests/quantization/calibration.rs +++ b/crates/burn-tensor/src/tests/quantization/calibration.rs @@ -2,16 +2,20 @@ mod tests { use super::*; use burn_tensor::{ - quantization::{Calibration, MinMaxCalibration, QuantizationType}, + quantization::{ + BlockLayout, Calibration, QuantizationMode, QuantizationScheme, QuantizationType, + }, Tensor, TensorData, }; + // NOTE: The scheme variant fields are not important for calibration, only the "main" variant (e.g., per-tensor) #[test] - fn min_max_calibration_range() { + fn min_max_calibration_range_per_tensor() { let tensor = TestTensor::<1>::from_floats([-1.8, -1.0, 0.0, 0.5], &Default::default()); - let calibration = MinMaxCalibration {}; + let scheme = + QuantizationScheme::PerTensor(QuantizationMode::Affine, QuantizationType::QInt8); - let range = calibration.compute_range(&tensor); + let range = scheme.compute_range(&tensor, &Calibration::MinMax); range .min @@ -22,4 +26,106 @@ mod tests { .into_data() .assert_eq(&TensorData::from([0.5]), false); } + + #[test] + fn min_max_calibration_range_per_block_flat_all() { + let tensor = TestTensor::<2>::from_floats( + [[-1.8, -1.0, 0.0, 0.5], [1., 2., 3., 4.]], + &Default::default(), + ); + let scheme = QuantizationScheme::PerBlock( + QuantizationMode::Affine, + QuantizationType::QInt8, + BlockLayout::Flat(8), + ); + + let range = scheme.compute_range(&tensor, &Calibration::MinMax); + + range + .min + .into_data() + .assert_eq(&TensorData::from([-1.8]), false); + range + .max + .into_data() + .assert_eq(&TensorData::from([4.]), false); + } + + #[test] + fn min_max_calibration_range_per_block_flat_row() { + let tensor = TestTensor::<2>::from_floats( + [[-1.8, -1.0, 0.0, 0.5], [1., 2., 3., 4.]], + &Default::default(), + ); + let scheme = QuantizationScheme::PerBlock( + QuantizationMode::Affine, + QuantizationType::QInt8, + BlockLayout::Flat(4), + ); + + let range = scheme.compute_range(&tensor, &Calibration::MinMax); + range + .min + .into_data() + .assert_eq(&TensorData::from([-1.8, 1.]), false); + range + .max + .into_data() + .assert_eq(&TensorData::from([0.5, 4.]), false); + } + + #[test] + #[should_panic(expected = "Cannot compute per-block quantization range")] + fn min_max_calibration_range_per_block_flat_invalid() { + let tensor = TestTensor::<2>::from_floats( + [[-1.8, -1.0, 0.0, 0.5], [1., 2., 3., 4.]], + &Default::default(), + ); + let scheme = QuantizationScheme::PerBlock( + QuantizationMode::Affine, + QuantizationType::QInt8, + BlockLayout::Flat(3), + ); + let _ = scheme.compute_range(&tensor, &Calibration::MinMax); + } + + #[test] + fn min_max_calibration_range_per_block_grid() { + let tensor = TestTensor::<2>::from_floats( + [[-1.8, -1.0, 0.0, 0.5], [1., 2., 3., 4.]], + &Default::default(), + ); + let scheme = QuantizationScheme::PerBlock( + QuantizationMode::Affine, + QuantizationType::QInt8, + BlockLayout::Grid(2, 2), + ); + + let range = scheme.compute_range(&tensor, &Calibration::MinMax); + + range + .min + .into_data() + .assert_eq(&TensorData::from([-1.8, 0.]), false); + range + .max + .into_data() + .assert_eq(&TensorData::from([2., 4.]), false); + } + + #[test] + #[should_panic(expected = 
"Cannot compute per-block quantization range")] + fn min_max_calibration_range_per_block_grid_invalid() { + let tensor = TestTensor::<2>::from_floats( + [[-1.8, -1.0, 0.0, 0.5], [1., 2., 3., 4.]], + &Default::default(), + ); + let scheme = QuantizationScheme::PerBlock( + QuantizationMode::Affine, + QuantizationType::QInt8, + BlockLayout::Grid(3, 3), + ); + + let _ = scheme.compute_range(&tensor, &Calibration::MinMax); + } } diff --git a/crates/burn-tensor/src/tests/quantization/data.rs b/crates/burn-tensor/src/tests/quantization/data.rs new file mode 100644 index 0000000000..6432650594 --- /dev/null +++ b/crates/burn-tensor/src/tests/quantization/data.rs @@ -0,0 +1,109 @@ +#[burn_tensor_testgen::testgen(q_data)] +mod tests { + use super::*; + use alloc::{vec, vec::Vec}; + use burn_tensor::quantization::{ + AffineQuantization, BlockLayout, QuantizationStrategy, SymmetricQuantization, + }; + use burn_tensor::{Tensor, TensorData}; + + // NOTE: we mark the per-block tests as `might_panic` since backends are not strictly + // required to support this quantization scheme. + // Also std feature gated (until `catch_unwind` is stable in core). + #[cfg(feature = "std")] + use burn_tensor::might_panic; + + #[test] + fn should_support_per_tensor_affine_int8() { + let data = TensorData::quantized( + vec![-128i8, -39, 72, 127], + [4], + QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.009_019_608, 72)), + ); + let tensor = TestTensor::<1>::from_data(data.clone(), &Default::default()); + + tensor.into_data().assert_eq(&data, true); + } + + #[test] + fn should_support_per_tensor_symmetric_int8() { + let data = TensorData::quantized( + vec![-127i8, -71, 0, 35], + [4], + QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init( + 0.014_173_228, + )), + ); + let tensor = TestTensor::<1>::from_data(data.clone(), &Default::default()); + + tensor.into_data().assert_eq(&data, true); + } + + #[cfg(feature = "std")] + #[might_panic(reason = "Per-block quantization is not supported")] + #[test] + fn should_support_per_block_flat() { + // Per-block qparams + let data = TensorData::quantized( + vec![ + [-127i8, -71, 0, 35, -56, 85, 18, 35], + [-20, 30, 6, 13, 51, 76, 102, 127], + ] + .concat(), + [2, 8], + QuantizationStrategy::PerBlockAffineInt8( + vec![ + AffineQuantization::init(0.009019608, 71), + AffineQuantization::init(0.007843138, -26), + AffineQuantization::init(0.00078431366, -25), + AffineQuantization::init(0.0019607844, -128), + ], + BlockLayout::Flat(4), + ), + ); + let tensor = TestTensor::<2>::from_data(data.clone(), &Default::default()); + + tensor.into_data().assert_eq(&data, true); + } + + #[cfg(feature = "std")] + #[might_panic(reason = "Per-block quantization is not supported")] + #[test] + fn should_support_per_block_grid() { + // Per-block qparams + let scales: [f32; 8] = [ + 0.014173228, + 0.009448819, + 0.0009448819, + 0.003937008, + 0.00047244094, + 0.031496063, + 0.0031496063, + 0.014173228, + ]; + let data = TensorData::quantized( + vec![ + [-127i8, -71, -85, 127], + [0, 35, 26, 53], + [-85, 127, 51, 76], + [26, 53, 102, 127], + [21, 64, 127, 95], + [42, 127, 64, 32], + [127, 95, 35, 0], + [64, 32, -71, -127], + ] + .concat(), + [8, 4], + QuantizationStrategy::PerBlockSymmetricInt8( + scales + .iter() + .map(|&s| SymmetricQuantization::init(s)) + .collect(), + BlockLayout::Grid(2, 2), + ), + ); + let tensor = TestTensor::<2>::from_data(data.clone(), &Default::default()); + + tensor.into_data().assert_eq(&data, true); + } +} diff --git 
a/crates/burn-tensor/src/tests/quantization/mod.rs b/crates/burn-tensor/src/tests/quantization/mod.rs index bc9bec8673..61663b34ea 100644 --- a/crates/burn-tensor/src/tests/quantization/mod.rs +++ b/crates/burn-tensor/src/tests/quantization/mod.rs @@ -1,3 +1,4 @@ mod calibration; +mod data; mod ops; mod scheme; diff --git a/crates/burn-tensor/src/tests/quantization/ops/quantize.rs b/crates/burn-tensor/src/tests/quantization/ops/quantize.rs index 5bc6c0cf5f..9fe119dca0 100644 --- a/crates/burn-tensor/src/tests/quantization/ops/quantize.rs +++ b/crates/burn-tensor/src/tests/quantization/ops/quantize.rs @@ -1,14 +1,22 @@ #[burn_tensor_testgen::testgen(quantize)] + mod tests { use super::*; - use burn_tensor::ops::QTensorOps; + use alloc::{vec, vec::Vec}; use burn_tensor::quantization::{ - AffineQuantization, QParams, QuantizationParameters, QuantizationScheme, - QuantizationStrategy, QuantizationType, QuantizedBytes, SymmetricQuantization, + AffineQuantization, BlockLayout, QParams, QuantizationMode, QuantizationParameters, + QuantizationScheme, QuantizationStrategy, QuantizationType, QuantizedBytes, + SymmetricQuantization, }; use burn_tensor::{DType, Tensor, TensorData}; - fn get_q_params(data: TensorData) -> QParams { + // NOTE: we mark the per-block tests as `might_panic` since backends are not strictly + // required to support this quantization scheme. + // Also std feature gated (until `catch_unwind` is stable in core). + #[cfg(feature = "std")] + use burn_tensor::might_panic; + + fn get_q_params(data: TensorData) -> QParams, Vec> { let num_elements = data.num_elements(); let scheme = if let DType::QFloat(scheme) = data.dtype { scheme @@ -27,14 +35,16 @@ mod tests { fn should_support_quantize_affine_int8() { let device = Default::default(); let tensor = TestTensor::<1>::from_floats([-1.8, -1.0, 0.0, 0.5], &device); - let scheme = QuantizationScheme::PerTensorAffine(QuantizationType::QInt8); + let scheme = + QuantizationScheme::PerTensor(QuantizationMode::Affine, QuantizationType::QInt8); let qparams = QuantizationParameters { scale: Tensor::from_floats([0.009_019_608], &device), offset: Some(Tensor::from_ints([72], &device)), }; - let x_q = tensor.quantize(&scheme, qparams).into_data(); + let x_q = tensor.clone().quantize(&scheme, qparams); + let x_q_data = x_q.to_data(); let expected = TensorData::quantized( vec![-128i8, -39, 72, 127], [4], @@ -42,27 +52,37 @@ mod tests { ); // Values equality - x_q.assert_eq(&expected, true); + x_q_data.assert_eq(&expected, true); // Quantization parameters check - let qparams = get_q_params(x_q); + let qparams = get_q_params(x_q_data); let expected = get_q_params(expected); + assert_eq!(qparams.scale.len(), 1); assert_eq!(qparams.scale, expected.scale); + assert_eq!(qparams.offset.as_ref().map(|x| x.len()), Some(1)); assert_eq!(qparams.offset, expected.offset); + + // Dequantize + let x = x_q.dequantize(); + + // Precision 2 for dequantization errors + x.into_data().assert_approx_eq(&tensor.into_data(), 2); } #[test] fn should_support_quantize_symmetric_int8() { let device = Default::default(); let tensor = TestTensor::<1>::from_floats([-1.8, -1.0, 0.0, 0.5], &device); - let scheme = QuantizationScheme::PerTensorSymmetric(QuantizationType::QInt8); + let scheme = + QuantizationScheme::PerTensor(QuantizationMode::Symmetric, QuantizationType::QInt8); let qparams = QuantizationParameters { scale: Tensor::from_floats([0.014_173_228], &device), offset: None, }; - let x_q = tensor.quantize(&scheme, qparams).into_data(); + let x_q = 
     #[test]
     fn should_support_quantize_symmetric_int8() {
         let device = Default::default();
         let tensor = TestTensor::<1>::from_floats([-1.8, -1.0, 0.0, 0.5], &device);
-        let scheme = QuantizationScheme::PerTensorSymmetric(QuantizationType::QInt8);
+        let scheme =
+            QuantizationScheme::PerTensor(QuantizationMode::Symmetric, QuantizationType::QInt8);
 
         let qparams = QuantizationParameters {
             scale: Tensor::from_floats([0.014_173_228], &device),
             offset: None,
         };
 
-        let x_q = tensor.quantize(&scheme, qparams).into_data();
+        let x_q = tensor.clone().quantize(&scheme, qparams);
+        let x_q_data = x_q.to_data();
 
         let expected = TensorData::quantized(
             vec![-127i8, -71, 0, 35],
             [4],
@@ -72,33 +92,21 @@ mod tests {
         );
 
         // Values equality
-        x_q.assert_eq(&expected, true);
+        x_q_data.assert_eq(&expected, true);
 
         // Quantization parameters check
-        let qparams = get_q_params(x_q);
+        let qparams = get_q_params(x_q_data);
         let expected = get_q_params(expected);
+        assert_eq!(qparams.scale.len(), 1);
         assert_eq!(qparams.scale, expected.scale);
+        assert_eq!(qparams.offset, None);
         assert_eq!(qparams.offset, expected.offset);
-    }
-
-    #[test]
-    fn should_support_dequantize() {
-        let device = Default::default();
-        // Quantized [-1.8, -1.0, 0.0, 0.5]
-        let data = TensorData::quantized(
-            vec![-127i8, -71, 0, 35],
-            [4],
-            QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(
-                0.014_173_228,
-            )),
-        );
-        let x_q = TestTensor::<1>::from_data(data, &device);
+        // Dequantize
         let x = x_q.dequantize();
 
         // Precision 2 for dequantization errors
-        x.into_data()
-            .assert_approx_eq(&TensorData::from([-1.8, -1.0, 0.0, 0.5]), 2);
+        x.into_data().assert_approx_eq(&tensor.into_data(), 2);
     }
 
     #[test]
@@ -107,7 +115,8 @@
         // NOTE: we use fully representable values since different backend implementations could differ slightly
         // due to rounding discrepancies
         let tensor = TestTensor::<1>::from_floats([5., 0., 4., -10.], &device);
-        let scheme = QuantizationScheme::PerTensorAffine(QuantizationType::QInt8);
+        let scheme =
+            QuantizationScheme::PerTensor(QuantizationMode::Affine, QuantizationType::QInt8);
 
         let x_q = tensor.quantize_dynamic(&scheme);
 
@@ -119,4 +128,240 @@
 
         x_q.into_data().assert_eq(&expected, false);
     }
+
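The symmetric per-tensor test above drops the zero point entirely. A minimal sketch of the convention that reproduces its numbers, assuming scale = max(|min|, |max|) / 127 (the helper is hypothetical; Burn's exact rounding may differ at the margins):

```rust
// Symmetric int8: one scale, no offset, so 0.0 always maps exactly to 0.
fn symmetric_scale_i8(min: f32, max: f32) -> f32 {
    min.abs().max(max.abs()) / 127.0
}

fn quantize_symmetric_i8(x: f32, scale: f32) -> i8 {
    (x / scale).round().clamp(-127.0, 127.0) as i8
}

fn main() {
    // Range [-1.8, 0.5] => scale = 1.8 / 127 ≈ 0.014_173_228, as in the test.
    let scale = symmetric_scale_i8(-1.8, 0.5);
    let qs: Vec<i8> = [-1.8f32, -1.0, 0.0, 0.5]
        .iter()
        .map(|&x| quantize_symmetric_i8(x, scale))
        .collect();
    assert_eq!(qs, vec![-127, -71, 0, 35]); // matches `expected` above
}
```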
+    #[cfg(feature = "std")]
+    #[might_panic(reason = "Per-block quantization is not supported")]
+    #[test]
+    fn should_support_quantize_per_block_symmetric_int8() {
+        let device = Default::default();
+        let tensor = TestTensor::<2>::from_floats(
+            [
+                [-1.8, -1.0, 0.0, 0.5],
+                [-0.8, 1.2, 0.25, 0.5],
+                [-0.08, 0.12, 0.025, 0.05],
+                [0.2, 0.3, 0.4, 0.5],
+                [0.1, 0.3, 0.2, 0.6],
+                [4.0, 3.0, 2.0, 1.0],
+                [0.4, 0.3, 0.2, 0.1],
+                [0.5, 0.0, -1.0, -1.8],
+            ],
+            &device,
+        );
+        let scheme = QuantizationScheme::PerBlock(
+            QuantizationMode::Symmetric,
+            QuantizationType::QInt8,
+            BlockLayout::Flat(4),
+        );
+
+        // Per-block qparams
+        let scales: [f32; 8] = [
+            0.014173228,
+            0.009448819,
+            0.0009448819,
+            0.003937008,
+            0.0047244094,
+            0.031496063,
+            0.0031496063,
+            0.014173228,
+        ];
+        let qparams = QuantizationParameters {
+            scale: Tensor::from_floats(scales, &device),
+            offset: None,
+        };
+
+        let x_q = tensor.clone().quantize(&scheme, qparams);
+
+        let x_q_data = x_q.to_data();
+        let expected = TensorData::quantized(
+            vec![
+                [-127i8, -71, 0, 35],
+                [-85, 127, 26, 53],
+                [-85, 127, 26, 53],
+                [51, 76, 102, 127],
+                [21, 64, 42, 127],
+                [127, 95, 64, 32],
+                [127, 95, 64, 32],
+                [35, 0, -71, -127],
+            ]
+            .concat(),
+            [8, 4],
+            QuantizationStrategy::PerBlockSymmetricInt8(
+                scales
+                    .iter()
+                    .map(|&s| SymmetricQuantization::init(s))
+                    .collect(),
+                BlockLayout::Flat(4),
+            ),
+        );
+
+        // Values equality
+        x_q_data.assert_eq(&expected, true);
+
+        // Quantization parameters check
+        let qparams = get_q_params(x_q_data);
+        let expected = get_q_params(expected);
+        assert_eq!(qparams.scale.len(), 8);
+        assert_eq!(qparams.scale, expected.scale);
+        assert_eq!(qparams.offset, None);
+        assert_eq!(qparams.offset, expected.offset);
+
+        // Dequantize
+        let x = x_q.dequantize();
+
+        // Precision 2 for dequantization errors
+        x.into_data().assert_approx_eq(&tensor.into_data(), 2);
+    }
+
+    #[cfg(feature = "std")]
+    #[might_panic(reason = "Per-block quantization is not supported")]
+    #[test]
+    fn should_support_quantize_per_block_affine_int8() {
+        let device = Default::default();
+        let tensor = TestTensor::<2>::from_floats(
+            [
+                [-1.8, -1.0, 0.0, 0.5, -0.8, 1.2, 0.25, 0.5],
+                [-8., 12., 2.5, 5., 0.2, 0.3, 0.4, 0.5],
+            ],
+            &device,
+        );
+        let scheme = QuantizationScheme::PerBlock(
+            QuantizationMode::Affine,
+            QuantizationType::QInt8,
+            BlockLayout::Flat(4),
+        );
+
+        // Per-block qparams
+        let scales: [f32; 4] = [0.009019608, 0.007843138, 0.078431366, 0.0019607844];
+        let offsets: [i8; 4] = [71, -26, -26, -128];
+        let qparams = QuantizationParameters {
+            scale: Tensor::from_floats(scales, &device),
+            offset: Some(Tensor::from_ints(offsets, &device)),
+        };
+
+        let x_q = tensor.clone().quantize(&scheme, qparams);
+
+        let x_q_data = x_q.to_data();
+        let expected = TensorData::quantized(
+            vec![
+                [-128i8, -40, 71, 126],
+                [-128, 127, 6, 38],
+                [-128, 127, 6, 38],
+                [-26, 25, 76, 127],
+            ]
+            .concat(),
+            [2, 8],
+            QuantizationStrategy::PerBlockAffineInt8(
+                scales
+                    .iter()
+                    .zip(offsets.iter())
+                    .map(|(&s, &o)| AffineQuantization::init(s, o))
+                    .collect(),
+                BlockLayout::Flat(4),
+            ),
+        );
+
+        // Values equality
+        x_q_data.assert_eq(&expected, true);
+
+        // Quantization parameters check
+        let qparams = get_q_params(x_q_data);
+        let expected = get_q_params(expected);
+        assert_eq!(qparams.scale.len(), 4);
+        assert_eq!(qparams.scale, expected.scale);
+        assert_eq!(qparams.offset.as_ref().unwrap().len(), 4);
+        assert_eq!(qparams.offset, expected.offset);
+
+        // Dequantize
+        let x = x_q.dequantize();
+
+        // Precision 2 for dequantization errors
+        x.into_data().assert_approx_eq(&tensor.into_data(), 2);
+    }
+
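`BlockLayout::Grid(m, n)` tiles a 2D tensor into `m` x `n` blocks; the grid test below orders blocks row-major by band, as its inline comments show. A standalone sketch of that index mapping (`grid_block_index` is a hypothetical helper, not Burn API):

```rust
// Which block a Grid(m, n) layout assigns to element (row, col) of a
// row-major 2D tensor, with blocks numbered left-to-right, band by band.
fn grid_block_index(row: usize, col: usize, cols: usize, m: usize, n: usize) -> usize {
    let blocks_per_band = cols / n; // tensor width measured in blocks
    (row / m) * blocks_per_band + (col / n)
}

fn main() {
    // An [8, 4] tensor under Grid(2, 2): 2 blocks per band of 2 rows,
    // 8 blocks total, hence the 8 scales in the grid test below.
    for row in 0..8 {
        let blocks: Vec<usize> = (0..4).map(|col| grid_block_index(row, col, 4, 2, 2)).collect();
        println!("row {row}: blocks {blocks:?}");
    }
}
```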
+    #[cfg(feature = "std")]
+    #[might_panic(reason = "Per-block quantization is not supported")]
+    #[test]
+    fn should_support_quantize_per_block_grid_symmetric_int8() {
+        let device = Default::default();
+        let tensor = TestTensor::<2>::from_floats(
+            [
+                // 2x2 blocks: [[-1.8, -1.0, 0.0, 0.5], [-0.8, 1.2, 0.25, 0.5]]
+                [-1.8, -1.0, -0.8, 1.2],
+                [0.0, 0.5, 0.25, 0.5],
+                // 2x2 blocks: [[-0.8, 1.2, 0.25, 0.5], [0.2, 0.3, 0.4, 0.5]]
+                [-0.8, 1.2, 0.2, 0.3],
+                [0.25, 0.5, 0.4, 0.5],
+                // 2x2 blocks: [[0.1, 0.3, 0.2, 0.6], [4.0, 3.0, 2.0, 1.0]]
+                [0.1, 0.3, 4.0, 3.0],
+                [0.2, 0.6, 2.0, 1.0],
+                // 2x2 blocks: [[0.4, 0.3, 0.2, 0.1], [0.5, 0.0, -1.0, -1.8]]
+                [0.4, 0.3, 0.5, 0.0],
+                [0.2, 0.1, -1.0, -1.8],
+            ],
+            &device,
+        );
+        let scheme = QuantizationScheme::PerBlock(
+            QuantizationMode::Symmetric,
+            QuantizationType::QInt8,
+            BlockLayout::Grid(2, 2),
+        );
+
+        // Per-block qparams
+        let scales: [f32; 8] = [
+            0.014173228,
+            0.009448819,
+            0.009448819,
+            0.003937008,
+            0.0047244094,
+            0.031496063,
+            0.0031496063,
+            0.014173228,
+        ];
+        let qparams = QuantizationParameters {
+            scale: Tensor::from_floats(scales, &device),
+            offset: None,
+        };
+
+        let x_q = tensor.clone().quantize(&scheme, qparams);
+
+        let x_q_data = x_q.to_data();
+        let expected = TensorData::quantized(
+            vec![
+                [-127i8, -71, -85, 127],
+                [0, 35, 26, 53],
+                [-85, 127, 51, 76],
+                [26, 53, 102, 127],
+                [21, 64, 127, 95],
+                [42, 127, 64, 32],
+                [127, 95, 35, 0],
+                [64, 32, -71, -127],
+            ]
+            .concat(),
+            [8, 4],
+            QuantizationStrategy::PerBlockSymmetricInt8(
+                scales
+                    .iter()
+                    .map(|&s| SymmetricQuantization::init(s))
+                    .collect(),
+                BlockLayout::Grid(2, 2),
+            ),
+        );
+
+        // Values equality
+        x_q_data.assert_eq(&expected, true);
+
+        // Quantization parameters check
+        let qparams = get_q_params(x_q_data);
+        let expected = get_q_params(expected);
+        assert_eq!(qparams.scale.len(), 8);
+        assert_eq!(qparams.scale, expected.scale);
+        assert_eq!(qparams.offset, None);
+        assert_eq!(qparams.offset, expected.offset);
+
+        // Dequantize
+        let x = x_q.dequantize();
+
+        // Precision 2 for dequantization errors
+        x.into_data().assert_approx_eq(&tensor.into_data(), 2);
+    }
 }
diff --git a/crates/burn-tensor/src/tests/quantization/scheme.rs b/crates/burn-tensor/src/tests/quantization/scheme.rs
index a305d25f43..07a5549887 100644
--- a/crates/burn-tensor/src/tests/quantization/scheme.rs
+++ b/crates/burn-tensor/src/tests/quantization/scheme.rs
@@ -2,14 +2,17 @@
 mod tests {
     use super::*;
     use burn_tensor::{
-        quantization::{CalibrationRange, QuantizationScheme, QuantizationType},
+        quantization::{
+            BlockLayout, CalibrationRange, QuantizationMode, QuantizationScheme, QuantizationType,
+        },
         Tensor, TensorData,
     };
 
     #[test]
     fn per_tensor_affine_int8() {
         let device = Default::default();
-        let scheme = QuantizationScheme::PerTensorAffine(QuantizationType::QInt8);
+        let scheme =
+            QuantizationScheme::PerTensor(QuantizationMode::Affine, QuantizationType::QInt8);
         let range = CalibrationRange {
             min: TestTensor::<1>::from_floats([-1.8], &device),
             max: TestTensor::<1>::from_floats([0.5], &device),
@@ -31,10 +34,11 @@
     #[test]
     fn per_tensor_symmetric_int8() {
         let device = Default::default();
-        let scheme = QuantizationScheme::PerTensorSymmetric(QuantizationType::QInt8);
+        let scheme =
+            QuantizationScheme::PerTensor(QuantizationMode::Symmetric, QuantizationType::QInt8);
         let range = CalibrationRange {
-            min: TestTensor::<1>::from_floats([-1.8], &device),
-            max: TestTensor::<1>::from_floats([0.5], &device),
+            min: TestTensor::<1>::from_floats([0.5], &device),
+            max: TestTensor::<1>::from_floats([1.8], &device),
         };
 
         let qparams = scheme.compute_q_params(range);
@@ -45,4 +49,52 @@
             .assert_approx_eq(&TensorData::from([0.014_173_228]), 8);
         assert!(qparams.offset.is_none());
     }
+
+    #[test]
+    fn per_block_affine_int8() {
+        let device = Default::default();
+        let scheme = QuantizationScheme::PerBlock(
+            QuantizationMode::Affine,
+            QuantizationType::QInt8,
+            BlockLayout::Flat(3), // layout doesn't matter when computing qparams
+        );
+        let range = CalibrationRange {
+            min: TestTensor::<1>::from_floats([-1.8, -2.0, 0.5], &device),
+            max: TestTensor::<1>::from_floats([0.5, 1.5, 1.8], &device),
+        };
+
+        let qparams = scheme.compute_q_params(range);
+
+        qparams.scale.into_data().assert_approx_eq(
+            &TensorData::from([0.009_019_608, 0.013_725_490, 0.007_058_823_4]),
+            8,
+        );
+        qparams
+            .offset
+            .unwrap()
+            .into_data()
+            .assert_eq(&TensorData::from([71, 17, -128]), false);
+    }
+
+    #[test]
+    fn per_block_symmetric_int8() {
+        let device = Default::default();
+        let scheme = QuantizationScheme::PerBlock(
+            QuantizationMode::Symmetric,
+            QuantizationType::QInt8,
+            BlockLayout::Flat(3), // layout doesn't matter when computing qparams
+        );
+        let range = CalibrationRange {
+            min: TestTensor::<1>::from_floats([-1.8, -2.0, 0.5], &device),
+            max: TestTensor::<1>::from_floats([0.5, 1.5, 1.8], &device),
+        };
+
+        let qparams = scheme.compute_q_params(range);
+
+        qparams.scale.into_data().assert_approx_eq(
+            &TensorData::from([0.014_173_228, 0.015_748_031, 0.014_173_228]),
+            8,
+        );
+        assert!(qparams.offset.is_none());
+    }
 }
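A back-of-envelope check for `per_block_affine_int8` above. The assumed convention (hedged; inferred from the expected values, not from Burn's source): each block's range is widened to include 0.0, the scale is (max - min) / 255, and the zero point is -128 - min/scale truncated toward zero, which is how the expected offsets [71, 17, -128] line up.

```rust
// Hypothetical helper reproducing the per-block affine qparams above.
fn affine_qparams_i8(min: f32, max: f32) -> (f32, i8) {
    let (min, max) = (min.min(0.0), max.max(0.0)); // zero must be representable
    let scale = (max - min) / 255.0; // 255 = i8::MAX - i8::MIN
    let offset = (-128.0 - min / scale) as i8; // `as` truncates toward zero
    (scale, offset)
}

fn main() {
    let mins = [-1.8f32, -2.0, 0.5];
    let maxs = [0.5f32, 1.5, 1.8];
    for (&min, &max) in mins.iter().zip(maxs.iter()) {
        let (scale, offset) = affine_qparams_i8(min, max);
        println!("scale = {scale}, offset = {offset}");
    }
    // Prints scales ≈ [0.009019608, 0.013725490, 0.0070588234]
    // and offsets [71, 17, -128], matching the expected qparams above.
}
```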
diff --git a/crates/burn-vision/src/backends/cpu/morphology/mod.rs b/crates/burn-vision/src/backends/cpu/morphology/mod.rs
index c23eca3389..f6f7747f28 100644
--- a/crates/burn-vision/src/backends/cpu/morphology/mod.rs
+++ b/crates/burn-vision/src/backends/cpu/morphology/mod.rs
@@ -110,10 +110,8 @@ pub fn morph>(
         DType::U8 => morph_typed::<B, u8>(data, shape, kernel, op, iter, btype, bvalue, &device),
         DType::Bool => morph_bool::<B>(data, shape, kernel, op, iter, btype, bvalue, &device),
         DType::QFloat(scheme) => match scheme {
-            QuantizationScheme::PerTensorAffine(QuantizationType::QInt8) => {
-                morph_typed::<B, i8>(data, shape, kernel, op, iter, btype, bvalue, &device)
-            }
-            QuantizationScheme::PerTensorSymmetric(QuantizationType::QInt8) => {
+            QuantizationScheme::PerTensor(_mode, QuantizationType::QInt8)
+            | QuantizationScheme::PerBlock(_mode, QuantizationType::QInt8, ..) => {
                 morph_typed::<B, i8>(data, shape, kernel, op, iter, btype, bvalue, &device)
             }
         },
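The burn-vision hunk above shows the payoff of moving the mode into `PerTensor`/`PerBlock`: two separate match arms collapse into one or-pattern that ignores the mode. A self-contained sketch of that dispatch shape, using local stand-in enums rather than Burn's actual types:

```rust
// Stand-in enums mirroring the refactored scheme shape (not Burn's types).
#[allow(dead_code)]
enum Mode { Affine, Symmetric }
#[allow(dead_code)]
enum Layout { Flat(usize), Grid(usize, usize) }
enum QType { QInt8 }
enum Scheme {
    PerTensor(Mode, QType),
    PerBlock(Mode, QType, Layout),
}

fn element_size_bits(scheme: &Scheme) -> usize {
    match scheme {
        // Affine and symmetric int8 both store one byte per element, so the
        // mode binds to `_mode` and `..` skips the layout, as in the patch.
        Scheme::PerTensor(_mode, QType::QInt8)
        | Scheme::PerBlock(_mode, QType::QInt8, ..) => 8,
    }
}

fn main() {
    let per_block = Scheme::PerBlock(Mode::Symmetric, QType::QInt8, Layout::Flat(4));
    let per_tensor = Scheme::PerTensor(Mode::Affine, QType::QInt8);
    assert_eq!(element_size_bits(&per_block), 8);
    assert_eq!(element_size_bits(&per_tensor), 8);
}
```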