Skip to content

Commit

Permalink
formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
NewBornRustacean committed Feb 26, 2025
1 parent 2544b3e commit 5cb4e2a
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 11 deletions.
14 changes: 7 additions & 7 deletions crates/burn-import/src/burn/node/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@ use super::{
dropout::DropoutNode, expand::ExpandNode, floor::FloorNode, gather::GatherNode,
gather_elements::GatherElementsNode, global_avg_pool::GlobalAvgPoolNode,
layer_norm::LayerNormNode, linear::LinearNode, mask_where::WhereNode, matmul::MatmulNode,
matmul_integer::MatMulIntegerNode,
max_pool1d::MaxPool1dNode, max_pool2d::MaxPool2dNode, mean::MeanNode, one_hot::OneHotNode,
pad::PadNode, prelu::PReluNode, random_normal::RandomNormalNode,
random_normal_like::RandomNormalLikeNode, random_uniform::RandomUniformNode,
random_uniform_like::RandomUniformLikeNode, range::RangeNode, reshape::ReshapeNode,
resize::ResizeNode, slice::SliceNode, split::SplitNode, squeeze::SqueezeNode, sum::SumNode,
tile::TileNode, top_k::TopKNode, trilu::TriluNode, unary::UnaryNode, unsqueeze::UnsqueezeNode,
matmul_integer::MatMulIntegerNode, max_pool1d::MaxPool1dNode, max_pool2d::MaxPool2dNode,
mean::MeanNode, one_hot::OneHotNode, pad::PadNode, prelu::PReluNode,
random_normal::RandomNormalNode, random_normal_like::RandomNormalLikeNode,
random_uniform::RandomUniformNode, random_uniform_like::RandomUniformLikeNode,
range::RangeNode, reshape::ReshapeNode, resize::ResizeNode, slice::SliceNode, split::SplitNode,
squeeze::SqueezeNode, sum::SumNode, tile::TileNode, top_k::TopKNode, trilu::TriluNode,
unary::UnaryNode, unsqueeze::UnsqueezeNode,
};
use crate::burn::{BurnImports, Scope, Type};
use burn::record::PrecisionSettings;
Expand Down
25 changes: 21 additions & 4 deletions crates/burn-import/src/burn/node/matmul_integer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,13 @@ pub struct MatMulIntegerNode {
}

impl MatMulIntegerNode {
pub fn new(lhs: TensorType, rhs: TensorType, output: TensorType, a_zero_point: Option<TensorType>, b_zero_point: Option<TensorType>) -> Self {
pub fn new(
lhs: TensorType,
rhs: TensorType,
output: TensorType,
a_zero_point: Option<TensorType>,
b_zero_point: Option<TensorType>,
) -> Self {
// Validate tensor types - using Int for quantized tensors
if lhs.kind != TensorKind::Int || rhs.kind != TensorKind::Int {
panic!("MatMulInteger is only implemented for integer tensors");
Expand All @@ -40,7 +46,13 @@ impl MatMulIntegerNode {
}
}

Self { lhs, rhs, output, a_zero_point, b_zero_point }
Self {
lhs,
rhs,
output,
a_zero_point,
b_zero_point,
}
}
}

Expand Down Expand Up @@ -151,7 +163,12 @@ mod tests {
));

graph.register_input_output(
vec!["tensor1".to_string(), "tensor2".to_string(), "a_zero_point".to_string(), "b_zero_point".to_string()],
vec![
"tensor1".to_string(),
"tensor2".to_string(),
"a_zero_point".to_string(),
"b_zero_point".to_string(),
],
vec!["tensor3".to_string()],
);

Expand Down Expand Up @@ -193,4 +210,4 @@ mod tests {

assert_tokens(graph.codegen(), expected);
}
}
}

0 comments on commit 5cb4e2a

Please sign in to comment.