Feat/quant/per block #2849

Merged · 28 commits · Mar 3, 2025
4 changes: 4 additions & 0 deletions .github/workflows/test.yml
@@ -147,6 +147,10 @@ jobs:
          - rust: prev
            toolchain: ${{ needs.prepare-checks.outputs.rust-prev-version }}
    steps:
+      # disable incremental compilation (reduces artifact size)
+      - name: Set CI Profile
+        run: echo "CARGO_PROFILE_TEST_INCREMENTAL=false" >> $GITHUB_ENV
+      # --------------------------------------------------------------------------------
      - name: Setup Rust
        uses: tracel-ai/github-actions/setup-rust@v1
        with:
1 change: 1 addition & 0 deletions Cargo.lock (generated file; diff not rendered)
36 changes: 26 additions & 10 deletions burn-book/src/quantization.md
@@ -45,12 +45,12 @@ tensors and can collect their statistics, such as the min and max value when using

```rust , ignore
# use burn::module::Quantizer;
-# use burn::tensor::quantization::{MinMaxCalibration, QuantizationScheme, QuantizationType};
+# use burn::tensor::quantization::{Calibration, QuantizationMode, QuantizationScheme, QuantizationType};
#
// Quantization config
let mut quantizer = Quantizer {
-    calibration: MinMaxCalibration {},
-    scheme: QuantizationScheme::PerTensorSymmetric(QuantizationType::QInt8),
+    calibration: Calibration::MinMax,
+    scheme: QuantizationScheme::PerTensor(QuantizationMode::Symmetric, QuantizationType::QInt8),
};

// Quantize the weights
@@ -95,9 +95,9 @@ _quantization-time_ (weights are static), but activations require more attention

To compute the quantization parameters, Burn supports the following `Calibration` methods.

-| Method              | Description                                                                       |
-| :------------------ | :-------------------------------------------------------------------------------- |
-| `MinMaxCalibration` | Computes the quantization range mapping based on the running min and max values. |
+| Method   | Description                                                                       |
+| :------- | :-------------------------------------------------------------------------------- |
+| `MinMax` | Computes the quantization range mapping based on the running min and max values. |
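
> Editor's note: to make `MinMax` concrete, here is a minimal sketch of the range it produces, assuming a plain `f32` slice stands in for the tensor. The real implementation operates on backend tensors; `min_max_range` is a hypothetical helper, not Burn's API.

```rust
// Hypothetical helper illustrating MinMax calibration: the range is
// simply the running minimum and maximum of the observed values.
fn min_max_range(values: &[f32]) -> (f32, f32) {
    values
        .iter()
        .fold((f32::INFINITY, f32::NEG_INFINITY), |(min, max), &v| {
            (min.min(v), max.max(v))
        })
}

fn main() {
    let weights = [-1.8_f32, 0.2, 0.5, 3.4];
    let (min, max) = min_max_range(&weights);
    assert_eq!((min, max), (-1.8, 3.4));
    // A symmetric scheme would derive its scale from max(|min|, |max|),
    // e.g. 3.4 / 127.0 for 8-bit signed quantization.
}
```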

### Quantization Scheme

@@ -116,7 +116,23 @@ channel with per-channel quantization (commonly used with CNNs).

Burn currently supports the following `QuantizationScheme` variants.

-| Variant              | Description                                                                                                     |
-| :------------------- | :-------------------------------------------------------------------------------------------------------------- |
-| `PerTensorAffine`    | Computes the quantization parameters for the whole tensor and applies an affine range mapping with zero point.  |
-| `PerTensorSymmetric` | Computes the quantization parameters for the whole tensor and applies a scale range mapping centered around 0.  |
+| Variant                        | Description                                                                                                                                                               |
+| :----------------------------- | :------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
+| `PerTensor(mode, type)`        | Applies a single set of quantization parameters to the entire tensor. The `mode` defines how values are transformed, and `type` represents the target quantization type. |
+| `PerBlock(mode, type, layout)` | Applies quantization parameters to individual blocks within the tensor. The `layout` defines how the tensor is partitioned.                                              |

+#### Quantization Mode
+
+| Mode        | Description                                                           |
+| ----------- | --------------------------------------------------------------------- |
+| `Affine`    | Maps values using an affine transformation with a zero point offset.  |
+| `Symmetric` | Maps values using a scale factor for a range centered around zero.    |
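
> Editor's note: as a rough illustration of the two modes for 8-bit signed output (a sketch only; these helper names are hypothetical, not Burn's API):

```rust
// Affine: q = round(x / scale) + zero_point, clamped to the i8 range.
fn quantize_affine(x: f32, scale: f32, zero_point: i32) -> i8 {
    ((x / scale).round() as i32 + zero_point).clamp(-128, 127) as i8
}

// Symmetric: q = round(x / scale), with the range kept centered on zero
// (the -128 slot is left unused so positive and negative ranges match).
fn quantize_symmetric(x: f32, scale: f32) -> i8 {
    (x / scale).round().clamp(-127.0, 127.0) as i8
}
```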

+---
+
+#### Block Layout
+
+| Layout             | Description                                               |
+| ------------------ | --------------------------------------------------------- |
+| `Flat(block_size)` | Divides the tensor into linear 1D blocks of fixed size.   |
+| `Grid(m, n)`       | Divides the tensor into 2D blocks of `m` x `n` elements.  |
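
> Editor's note: putting the pieces together, the new variants compose as sketched below. This is based on the tables above; the exact import path and the `BlockLayout` type name are assumptions, so check the crate docs for the authoritative API.

```rust , ignore
use burn::tensor::quantization::{
    BlockLayout, QuantizationMode, QuantizationScheme, QuantizationType,
};

// One scale for the whole tensor, symmetric 8-bit.
let per_tensor =
    QuantizationScheme::PerTensor(QuantizationMode::Symmetric, QuantizationType::QInt8);

// One set of parameters per 1D block of 32 contiguous elements.
let per_block_flat = QuantizationScheme::PerBlock(
    QuantizationMode::Symmetric,
    QuantizationType::QInt8,
    BlockLayout::Flat(32),
);

// One set of parameters per 8 x 8 tile, with an affine mapping.
let per_block_grid = QuantizationScheme::PerBlock(
    QuantizationMode::Affine,
    QuantizationType::QInt8,
    BlockLayout::Grid(8, 8),
);
```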
2 changes: 1 addition & 1 deletion crates/burn-candle/src/backend.rs
@@ -2,7 +2,7 @@ use std::marker::PhantomData;

use burn_tensor::{
    backend::{Backend, DeviceId, DeviceOps},
-    quantization::{QTensorPrimitive, QuantizationStrategy},
+    quantization::QTensorPrimitive,
    Device,
};
use candle_core::{backend::BackendDevice, DeviceLocation};
2 changes: 1 addition & 1 deletion crates/burn-candle/src/ops/qtensor.rs
@@ -3,7 +3,7 @@ use std::ops::Range;
use burn_tensor::{
    backend::Backend,
    ops::{FloatTensor, IntTensor, QTensorOps, QuantizedTensor},
-    quantization::{QuantizationParametersPrimitive, QuantizationScheme, QuantizationStrategy},
+    quantization::{QuantizationParametersPrimitive, QuantizationScheme},
    DType, Device, Shape, TensorData,
};

2 changes: 1 addition & 1 deletion crates/burn-candle/src/tensor.rs
@@ -1,5 +1,5 @@
use burn_tensor::{
-    quantization::{QTensorPrimitive, QuantizationScheme, QuantizationStrategy},
+    quantization::{QTensorPrimitive, QuantizationScheme},
    DType, Element, Shape, TensorData, TensorMetadata,
};

4 changes: 2 additions & 2 deletions crates/burn-core/src/module/base.rs
@@ -5,7 +5,7 @@ use crate::{
};
use alloc::vec::Vec;
pub use burn_derive::Module;
-use burn_tensor::{ops::Device, quantization::Calibration, Bool, Int, Tensor};
+use burn_tensor::{ops::Device, Bool, Int, Tensor};

/// Type alias to `Vec<B::Device>` which supports `no_std` environments, but automatically using
/// the `alloc` crate.
@@ -204,7 +204,7 @@ pub trait Module<B: Backend>: Clone + Send + core::fmt::Debug {
    }

    /// Quantize the weights of the module.
-    fn quantize_weights<C: Calibration>(self, quantizer: &mut Quantizer<C>) -> Self {
+    fn quantize_weights(self, quantizer: &mut Quantizer) -> Self {
        self.map(quantizer)
    }
}
8 changes: 4 additions & 4 deletions crates/burn-core/src/module/quantize.rs
@@ -7,16 +7,16 @@ use burn_tensor::{
use crate::module::{ModuleMapper, ParamId};

/// Describes how to quantize a module.
-pub struct Quantizer<C: Calibration> {
+pub struct Quantizer {
    /// The calibration method used in quantization.
-    pub calibration: C,
+    pub calibration: Calibration,
    /// The quantization scheme.
    pub scheme: QuantizationScheme,
}

-impl<B: Backend, C: Calibration> ModuleMapper<B> for Quantizer<C> {
+impl<B: Backend> ModuleMapper<B> for Quantizer {
    fn map_float<const D: usize>(&mut self, _id: ParamId, tensor: Tensor<B, D>) -> Tensor<B, D> {
-        let range = self.calibration.compute_range(&tensor);
+        let range = self.scheme.compute_range(&tensor, &self.calibration);
        let qparams = self.scheme.compute_q_params(range);
        tensor.quantize(&self.scheme, qparams)
    }
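
> Editor's note: taken together with the `quantize_weights` change in `base.rs`, module quantization with the de-generified `Quantizer` now reads roughly as follows. This is a sketch assuming some `model: impl Module<B>` is in scope, mirroring the updated burn-book example.

```rust , ignore
use burn::module::{Module, Quantizer};
use burn::tensor::quantization::{
    Calibration, QuantizationMode, QuantizationScheme, QuantizationType,
};

// Calibration is now a plain enum value rather than a type parameter,
// so `Quantizer` no longer needs generics.
let mut quantizer = Quantizer {
    calibration: Calibration::MinMax,
    scheme: QuantizationScheme::PerTensor(QuantizationMode::Symmetric, QuantizationType::QInt8),
};

// For every float parameter, the mapper computes the calibration range,
// derives the quantization parameters, and quantizes the tensor.
let model = model.quantize_weights(&mut quantizer);
```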