diff --git a/crates/burn-import/src/burn/node/base.rs b/crates/burn-import/src/burn/node/base.rs index 166c0dbd42..95cc681b94 100644 --- a/crates/burn-import/src/burn/node/base.rs +++ b/crates/burn-import/src/burn/node/base.rs @@ -9,13 +9,13 @@ use super::{ dropout::DropoutNode, expand::ExpandNode, floor::FloorNode, gather::GatherNode, gather_elements::GatherElementsNode, global_avg_pool::GlobalAvgPoolNode, layer_norm::LayerNormNode, linear::LinearNode, mask_where::WhereNode, matmul::MatmulNode, - matmul_integer::MatMulIntegerNode, - max_pool1d::MaxPool1dNode, max_pool2d::MaxPool2dNode, mean::MeanNode, one_hot::OneHotNode, - pad::PadNode, prelu::PReluNode, random_normal::RandomNormalNode, - random_normal_like::RandomNormalLikeNode, random_uniform::RandomUniformNode, - random_uniform_like::RandomUniformLikeNode, range::RangeNode, reshape::ReshapeNode, - resize::ResizeNode, slice::SliceNode, split::SplitNode, squeeze::SqueezeNode, sum::SumNode, - tile::TileNode, top_k::TopKNode, trilu::TriluNode, unary::UnaryNode, unsqueeze::UnsqueezeNode, + matmul_integer::MatMulIntegerNode, max_pool1d::MaxPool1dNode, max_pool2d::MaxPool2dNode, + mean::MeanNode, one_hot::OneHotNode, pad::PadNode, prelu::PReluNode, + random_normal::RandomNormalNode, random_normal_like::RandomNormalLikeNode, + random_uniform::RandomUniformNode, random_uniform_like::RandomUniformLikeNode, + range::RangeNode, reshape::ReshapeNode, resize::ResizeNode, slice::SliceNode, split::SplitNode, + squeeze::SqueezeNode, sum::SumNode, tile::TileNode, top_k::TopKNode, trilu::TriluNode, + unary::UnaryNode, unsqueeze::UnsqueezeNode, }; use crate::burn::{BurnImports, Scope, Type}; use burn::record::PrecisionSettings; diff --git a/crates/burn-import/src/burn/node/matmul_integer.rs b/crates/burn-import/src/burn/node/matmul_integer.rs index 8b61ca318d..d87e9cb79e 100644 --- a/crates/burn-import/src/burn/node/matmul_integer.rs +++ b/crates/burn-import/src/burn/node/matmul_integer.rs @@ -16,7 +16,13 @@ pub struct MatMulIntegerNode { } impl MatMulIntegerNode { - pub fn new(lhs: TensorType, rhs: TensorType, output: TensorType, a_zero_point: Option, b_zero_point: Option) -> Self { + pub fn new( + lhs: TensorType, + rhs: TensorType, + output: TensorType, + a_zero_point: Option, + b_zero_point: Option, + ) -> Self { // Validate tensor types - using Int for quantized tensors if lhs.kind != TensorKind::Int || rhs.kind != TensorKind::Int { panic!("MatMulInteger is only implemented for integer tensors"); @@ -40,7 +46,13 @@ impl MatMulIntegerNode { } } - Self { lhs, rhs, output, a_zero_point, b_zero_point } + Self { + lhs, + rhs, + output, + a_zero_point, + b_zero_point, + } } } @@ -151,7 +163,12 @@ mod tests { )); graph.register_input_output( - vec!["tensor1".to_string(), "tensor2".to_string(), "a_zero_point".to_string(), "b_zero_point".to_string()], + vec![ + "tensor1".to_string(), + "tensor2".to_string(), + "a_zero_point".to_string(), + "b_zero_point".to_string(), + ], vec!["tensor3".to_string()], ); @@ -193,4 +210,4 @@ mod tests { assert_tokens(graph.codegen(), expected); } -} \ No newline at end of file +}