diff --git a/crates/burn-import/src/burn/node/base.rs b/crates/burn-import/src/burn/node/base.rs
index ca108763b2..166c0dbd42 100644
--- a/crates/burn-import/src/burn/node/base.rs
+++ b/crates/burn-import/src/burn/node/base.rs
@@ -9,6 +9,7 @@ use super::{
     dropout::DropoutNode, expand::ExpandNode, floor::FloorNode, gather::GatherNode,
     gather_elements::GatherElementsNode, global_avg_pool::GlobalAvgPoolNode,
     layer_norm::LayerNormNode, linear::LinearNode, mask_where::WhereNode, matmul::MatmulNode,
+    matmul_integer::MatMulIntegerNode,
     max_pool1d::MaxPool1dNode, max_pool2d::MaxPool2dNode, mean::MeanNode, one_hot::OneHotNode,
     pad::PadNode, prelu::PReluNode, random_normal::RandomNormalNode,
     random_normal_like::RandomNormalLikeNode, random_uniform::RandomUniformNode,
@@ -107,6 +108,7 @@ pub enum Node<PS: PrecisionSettings> {
     LayerNorm(LayerNormNode<PS>),
     Linear(LinearNode<PS>),
     Matmul(MatmulNode),
+    MatmulInteger(MatMulIntegerNode),
     MaxPool1d(MaxPool1dNode),
     MaxPool2d(MaxPool2dNode),
     Mean(MeanNode),
@@ -163,6 +165,7 @@ macro_rules! match_all {
             Node::LayerNorm(node) => $func(node),
             Node::Linear(node) => $func(node),
             Node::Matmul(node) => $func(node),
+            Node::MatmulInteger(node) => $func(node),
             Node::MaxPool1d(node) => $func(node),
             Node::MaxPool2d(node) => $func(node),
             Node::Mean(node) => $func(node),
diff --git a/crates/burn-import/src/burn/node/matmul_integer.rs b/crates/burn-import/src/burn/node/matmul_integer.rs
new file mode 100644
index 0000000000..8b61ca318d
--- /dev/null
+++ b/crates/burn-import/src/burn/node/matmul_integer.rs
@@ -0,0 +1,196 @@
+use core::cmp::Ordering;
+
+use super::{Node, NodeCodegen};
+use crate::burn::{Scope, TensorKind, TensorType, ToTokens, Type};
+use burn::record::PrecisionSettings;
+use proc_macro2::TokenStream;
+use quote::quote;
+
+#[derive(Debug, Clone)]
+pub struct MatMulIntegerNode {
+    pub lhs: TensorType,
+    pub rhs: TensorType,
+    pub output: TensorType,
+    pub a_zero_point: Option<TensorType>,
+    pub b_zero_point: Option<TensorType>,
+}
+
+impl MatMulIntegerNode {
+    pub fn new(lhs: TensorType, rhs: TensorType, output: TensorType, a_zero_point: Option<TensorType>, b_zero_point: Option<TensorType>) -> Self {
+        // Validate tensor types - using Int for quantized tensors
+        if lhs.kind != TensorKind::Int || rhs.kind != TensorKind::Int {
+            panic!("MatMulInteger is only implemented for integer tensors");
+        }
+
+        // Output is typically an Int32 tensor in ONNX
+        if output.kind != TensorKind::Int {
+            panic!("MatMulInteger output must be an integer tensor");
+        }
+
+        // Validate zero points if provided
+        if let Some(a_zero) = &a_zero_point {
+            if a_zero.kind != TensorKind::Int {
+                panic!("A zero point must be an integer tensor");
+            }
+        }
+
+        if let Some(b_zero) = &b_zero_point {
+            if b_zero.kind != TensorKind::Int {
+                panic!("B zero point must be an integer tensor");
+            }
+        }
+
+        Self { lhs, rhs, output, a_zero_point, b_zero_point }
+    }
+}
+
+impl<PS: PrecisionSettings> NodeCodegen<PS> for MatMulIntegerNode {
+    fn output_types(&self) -> Vec<Type> {
+        vec![Type::Tensor(self.output.clone())]
+    }
+
+    fn input_types(&self) -> Vec<Type> {
+        let mut input_types = vec![
+            Type::Tensor(self.lhs.clone()),
+            Type::Tensor(self.rhs.clone()),
+        ];
+        if let Some(a_zero_point) = &self.a_zero_point {
+            input_types.push(Type::Tensor(a_zero_point.clone()));
+        }
+        if let Some(b_zero_point) = &self.b_zero_point {
+            input_types.push(Type::Tensor(b_zero_point.clone()));
+        }
+        input_types
+    }
+
+    fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream {
+        let lhs = scope.tensor_use_owned(&self.lhs, node_position);
+        let rhs = scope.tensor_use_owned(&self.rhs, node_position);
+        let output = &self.output.name;
+
+        let a_zero_point = if let Some(a_zero_point) = &self.a_zero_point {
+            scope.tensor_use_owned(a_zero_point, node_position)
+        } else {
+            quote! { 0 }
+        };
+
+        let b_zero_point = if let Some(b_zero_point) = &self.b_zero_point {
+            scope.tensor_use_owned(b_zero_point, node_position)
+        } else {
+            quote! { 0 }
+        };
+
+        let lhs_dim = self.lhs.dim;
+        let rhs_dim = self.rhs.dim;
+
+        // Support broadcasting for missing dimensions
+        match lhs_dim.cmp(&rhs_dim) {
+            Ordering::Greater => {
+                let axes = (0..lhs_dim - rhs_dim)
+                    .map(|i| if i % 2 == 0 { 0 } else { -1 })
+                    .collect::<Vec<i64>>();
+                let axes = axes.to_tokens();
+
+                if rhs_dim == 1 {
+                    let squeeze_dim = lhs_dim - 1;
+                    quote! {
+                        let #output = (#lhs - #a_zero_point).matmul((#rhs.unsqueeze_dims(&#axes) - #b_zero_point)).squeeze(#squeeze_dim);
+                    }
+                } else {
+                    quote! {
+                        let #output = (#lhs - #a_zero_point).matmul((#rhs.unsqueeze_dims(&#axes) - #b_zero_point));
+                    }
+                }
+            }
+            Ordering::Less => {
+                let axes = [0i64].repeat(rhs_dim - lhs_dim).to_tokens();
+
+                if lhs_dim == 1 {
+                    let squeeze_dim = rhs_dim - 2;
+                    quote! {
+                        let #output = (#lhs.unsqueeze_dims(&#axes) - #a_zero_point).matmul((#rhs - #b_zero_point)).squeeze(#squeeze_dim);
+                    }
+                } else {
+                    quote! {
+                        let #output = (#lhs.unsqueeze_dims(&#axes) - #a_zero_point).matmul((#rhs - #b_zero_point));
+                    }
+                }
+            }
+            Ordering::Equal => quote! {
+                let #output = (#lhs - #a_zero_point).matmul((#rhs - #b_zero_point));
+            },
+        }
+    }
+
+    fn into_node(self) -> Node<PS> {
+        Node::MatmulInteger(self)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use burn::record::FullPrecisionSettings;
+
+    use super::*;
+    use crate::burn::{
+        graph::BurnGraph,
+        node::{matmul_integer::MatMulIntegerNode, test::assert_tokens},
+        TensorType,
+    };
+
+    #[test]
+    fn test_codegen_matmul_integer() {
+        let mut graph = BurnGraph::<FullPrecisionSettings>::default();
+
+        graph.register(MatMulIntegerNode::new(
+            TensorType::new_int("tensor1", 4),
+            TensorType::new_int("tensor2", 4),
+            TensorType::new_int("tensor3", 4),
+            Some(TensorType::new_int("a_zero_point", 1)),
+            Some(TensorType::new_int("b_zero_point", 1)),
+        ));
+
+        graph.register_input_output(
+            vec!["tensor1".to_string(), "tensor2".to_string(), "a_zero_point".to_string(), "b_zero_point".to_string()],
+            vec!["tensor3".to_string()],
+        );
+
+        let expected = quote! {
+            use burn::tensor::Int;
+            use burn::{
+                module::Module,
+                tensor::{backend::Backend, Tensor},
+            };
+
+            #[derive(Module, Debug)]
+            pub struct Model<B: Backend> {
+                phantom: core::marker::PhantomData<B>,
+                device: burn::module::Ignored<B::Device>,
+            }
+
+            impl<B: Backend> Model<B> {
+                #[allow(unused_variables)]
+                pub fn new(device: &B::Device) -> Self {
+                    Self {
+                        phantom: core::marker::PhantomData,
+                        device: burn::module::Ignored(device.clone()),
+                    }
+                }
+
+                #[allow(clippy::let_and_return, clippy::approx_constant)]
+                pub fn forward(
+                    &self,
+                    tensor1: Tensor<B, 4, Int>,
+                    tensor2: Tensor<B, 4, Int>,
+                    a_zero_point: Tensor<B, 1, Int>,
+                    b_zero_point: Tensor<B, 1, Int>,
+                ) -> Tensor<B, 4, Int> {
+                    let tensor3 = (tensor1 - a_zero_point).matmul((tensor2 - b_zero_point));
+                    tensor3
+                }
+            }
+        };
+
+        assert_tokens(graph.codegen(), expected);
+    }
+}
\ No newline at end of file
diff --git a/crates/burn-import/src/burn/node/mod.rs b/crates/burn-import/src/burn/node/mod.rs
index 39154ed979..b0da4bb993 100644
--- a/crates/burn-import/src/burn/node/mod.rs
+++ b/crates/burn-import/src/burn/node/mod.rs
@@ -25,6 +25,7 @@ pub(crate) mod layer_norm;
 pub(crate) mod linear;
 pub(crate) mod mask_where;
 pub(crate) mod matmul;
+pub(crate) mod matmul_integer;
 pub(crate) mod max_pool1d;
 pub(crate) mod max_pool2d;
 pub(crate) mod mean;