Skip to content

Commit

Permalink
Extend [min, max] range to ensure zero-point (#2055)
Browse files Browse the repository at this point in the history
  • Loading branch information
laggui authored Jul 24, 2024
1 parent dea33e8 commit 64a2f12
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 22 deletions.
25 changes: 11 additions & 14 deletions crates/burn-tensor/src/tensor/quantization/scheme.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::{backend::Backend, Int, Tensor};
use crate::{backend::Backend, Tensor};

use super::{CalibrationRange, QuantizationParameters};

Expand All @@ -22,11 +22,6 @@ pub enum QuantizationScheme {
// PerChannelSymmetric,
}

/// Round the tensor to the nearest integer.
///
/// NOTE(review): `add_scalar(0.5)` followed by `int()` only implements
/// round-half-up correctly for non-negative values. The float-to-int cast
/// presumably truncates toward zero (confirm the backend's cast semantics),
/// so e.g. -1.8 + 0.5 = -1.3 truncates to -1, whereas round-to-nearest
/// gives -2. Verify callers never pass negative inputs, or handle the
/// negative case explicitly (e.g. via a floor/round tensor op).
fn round<B: Backend, const D: usize>(tensor: Tensor<B, D>) -> Tensor<B, D, Int> {
// Shift by one half, then truncate to integer (see the caveat above for
// negative inputs).
tensor.add_scalar(0.5).int()
}

impl QuantizationScheme {
/// Compute the quantization parameters.
pub fn compute_q_params<B: Backend>(
Expand All @@ -40,15 +35,17 @@ impl QuantizationScheme {
let a = i8::MIN as i32;
let b = i8::MAX as i32;

// Input range `[alpha, beta]`
let input_range = range.max.clone().sub(range.min.clone());
// We extend the `[min, max]` interval to ensure that it contains 0.
// Otherwise, we would not meet the requirement that 0 be an exactly
// representable value (zero-point).
let zero = Tensor::zeros_like(&range.min);
let min = range.min.min_pair(zero);
let zero = Tensor::zeros_like(&range.max);
let max = range.max.max_pair(zero);

QuantizationParameters {
scale: input_range.clone().div_scalar(b - a),
offset: Some(round(
(range.max.mul_scalar(a) - range.min.mul_scalar(b)).div(input_range),
)),
}
let scale = max.sub(min.clone()).div_scalar(b - a);
let offset = Some(-(min.div(scale.clone()).sub_scalar(a)).int());
QuantizationParameters { scale, offset }
}
},
QuantizationScheme::PerTensorSymmetric(dtype) => match dtype {
Expand Down
37 changes: 30 additions & 7 deletions crates/burn-tensor/src/tensor/quantization/strategy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,12 +76,16 @@ impl<E: Float, Q: PrimInt, A: PrimInt> Quantization<E, Q> for AffineQuantization
let a = E::from(Q::min_value()).unwrap();
let b = E::from(Q::max_value()).unwrap();

// We extend the `[alpha, beta]` interval to ensure that it contains 0.
// Otherwise, we would not meet the requirement that 0 be an exactly
// representable value (zero-point).
let alpha = E::min(alpha, E::zero());
let beta = E::max(beta, E::zero());

// Compute scale and offset to convert a floating point value in range `[alpha, beta]` to the quantized range
let range = beta - alpha;
Self::init(
range / (b - a),
Q::from(E::round(((beta * a) - (alpha * b)) / range)).unwrap(),
)
let scale = (beta - alpha) / (b - a);
let z = -(alpha / scale - a);
Self::init(scale, Q::from(z).unwrap())
}

fn quantize(&self, values: &[E]) -> Vec<Q> {
Expand Down Expand Up @@ -236,8 +240,8 @@ mod tests {
#[test]
fn test_int8_affine_quantization() {
let x: [f32; 4] = [-1.8, -1.0, 0.0, 0.5];
let expected_q = vec![-128, -39, 72, 127];
let expected_d = vec![-1.8039216, -1.0011765, 0.0, 0.49607843];
let expected_q = vec![-128, -40, 71, 126];
let expected_d = vec![-1.794902, -1.0011765, 0.0, 0.49607843];

let affine = AffineQuantization::<f32, i8, i32>::new(-1.8, 0.5);

Expand All @@ -249,6 +253,25 @@ mod tests {
assert_eq!(d, expected_d);
}

#[test]
fn test_affine_should_ensure_zero_point() {
    // The raw input range [1.0, 5.0] does not contain zero; the affine
    // scheme must widen it so that 0.0 maps exactly onto an integer
    // zero-point (here, the interval becomes [0.0, 5.0]).
    let values: [f32; 6] = [2.0, 1.0, 2.0, 3.0, 4.0, 5.0];
    let quantized_expected = vec![-26, -77, -26, 25, 76, 127];
    let dequantized_expected = values.to_vec();

    let affine = AffineQuantization::<f32, i8, i32>::new(1.0, 5.0);

    // After extending the interval, the lower bound 0.0 lands on i8::MIN
    // and the scale covers [0.0, 5.0] across the 255-step i8 range.
    assert_eq!(affine.offset, -128);
    assert_eq!(affine.scale, 0.019607844);

    // Round-trip: quantizing the inputs yields the expected codes, and
    // dequantizing those codes recovers the original values exactly.
    assert_eq!(affine.quantize(&values), quantized_expected);
    assert_eq!(affine.dequantize(&quantized_expected), dequantized_expected);
}

#[test]
fn test_int8_symmetric_quantization() {
let x: [f32; 4] = [-1.8, -1.0, 0.0, 0.5];
Expand Down
2 changes: 1 addition & 1 deletion crates/burn-tensor/src/tests/quantization/scheme.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ mod tests {
.offset
.unwrap()
.into_data()
.assert_eq(&TensorData::from([72]), false);
.assert_eq(&TensorData::from([71]), false);
}

#[test]
Expand Down

0 comments on commit 64a2f12

Please sign in to comment.