Fix ONNX Gather codegen for Shape input (#2148)
* Fix ONNX Gather codegen for Shape input

* Remove unnecessary cast, switch to slice for ownership

hexd0t authored Aug 15, 2024
1 parent 0435721 commit 16239db
Showing 7 changed files with 165 additions and 14 deletions.
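For context on the second bullet: the generated code represents an ONNX Shape as a plain [usize; N] array, and the fix borrows that array as a slice so its values can be read out without consuming the binding. A minimal standalone sketch of that borrow (plain Rust, not part of the commit; names are illustrative):

fn main() {
    // A Shape in generated Burn code is a fixed-size array of dimension sizes.
    let shape1: [usize; 3] = [2, 3, 4];

    // `&shape1 as &[_]` borrows the array as a slice, so the values can be
    // read (e.g. converted to tensor data) without moving `shape1`; the same
    // Shape can then feed several downstream nodes.
    let dims: Vec<i64> = (&shape1 as &[_]).iter().map(|&d| d as i64).collect();

    assert_eq!(dims, vec![2, 3, 4]);
    println!("dims = {:?}, shape1 still usable: {:?}", dims, shape1);
}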
1 change: 1 addition & 0 deletions crates/burn-import/onnx-tests/build.rs
@@ -34,6 +34,7 @@ fn main() {
         .input("tests/flatten/flatten.onnx")
         .input("tests/gather/gather.onnx")
         .input("tests/gather/gather_scalar.onnx")
+        .input("tests/gather/gather_shape.onnx")
         .input("tests/gather_elements/gather_elements.onnx")
         .input("tests/gelu/gelu.onnx")
         .input("tests/global_avr_pool/global_avr_pool.onnx")
Binary file crates/burn-import/onnx-tests/tests/gather/gather_shape.onnx not shown.
69 changes: 69 additions & 0 deletions crates/burn-import/onnx-tests/tests/gather/gather_shape.py
@@ -0,0 +1,69 @@
#!/usr/bin/env python3

# Used to generate model: gather_shape.onnx

# torch doesn't easily emit Shape-into-Gather operations when exporting to
# ONNX (tensor.size() and .shape just return a tuple, not a tensor),
# hence this model is built with the onnx package directly.

import onnx
import onnx.helper


def build_model():
    return onnx.helper.make_model(
        ir_version=8,
        opset_imports=[onnx.helper.make_operatorsetid("", 16)],
        graph=onnx.helper.make_graph(
            name="main_graph",
            nodes=[
                onnx.helper.make_node(
                    "Shape",
                    inputs=["input1"],
                    outputs=["shape1"],
                    name="/Shape"
                ),
                onnx.helper.make_node(
                    "Gather",
                    inputs=["shape1", "input2"],
                    outputs=["output1"],
                    name="/Gather",
                    axis=0
                ),
            ],
            inputs=[
                onnx.helper.make_value_info(
                    name="input1",
                    type_proto=onnx.helper.make_tensor_type_proto(
                        elem_type=onnx.TensorProto.FLOAT, shape=[2, 3]
                    ),
                ),
                onnx.helper.make_value_info(
                    name="input2",
                    type_proto=onnx.helper.make_tensor_type_proto(
                        elem_type=onnx.TensorProto.INT64, shape=[1]
                    ),
                ),
            ],
            outputs=[
                onnx.helper.make_value_info(
                    name="output1",
                    type_proto=onnx.helper.make_tensor_type_proto(
                        elem_type=onnx.TensorProto.INT64, shape=[1]
                    ),
                )
            ],
        ),
    )


def main():
    onnx_model = build_model()
    file_name = "gather_shape.onnx"

    # Ensure the model is valid ONNX before saving:
    onnx.checker.check_model(onnx_model)

    onnx.save(onnx_model, file_name)


if __name__ == "__main__":
    main()
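To sanity-check the generated model outside of burn, one option is the onnx reference evaluator — a sketch, assuming onnx >= 1.13 and numpy are installed (not part of the commit): gathering index 0 from the shape of a (2, 3) input should yield [2].

import numpy as np
from onnx.reference import ReferenceEvaluator

sess = ReferenceEvaluator("gather_shape.onnx")
(result,) = sess.run(None, {
    "input1": np.zeros((2, 3), dtype=np.float32),
    "input2": np.array([0], dtype=np.int64),
})
assert result.tolist() == [2]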
16 changes: 16 additions & 0 deletions crates/burn-import/onnx-tests/tests/onnx_tests.rs
@@ -43,6 +43,7 @@ include_models!(
     flatten,
     gather,
     gather_scalar,
+    gather_shape,
     gather_elements,
     gelu,
     global_avr_pool,

@@ -463,6 +464,21 @@
         assert_eq!(output.to_data(), expected);
     }

+    #[test]
+    fn gather_shape() {
+        let model: gather_shape::Model<Backend> = gather_shape::Model::default();
+
+        let device = Default::default();
+
+        let input = Tensor::<Backend, 2>::from_floats([[1., 2., 3.], [4., 5., 6.]], &device);
+        // shape(input) = [2, 3]
+        let index = Tensor::<Backend, 1, Int>::from_ints([0], &device);
+        let output = model.forward(input, index);
+        let expected = TensorData::from([2i64]);
+
+        assert_eq!(output.to_data(), expected);
+    }
+
     #[test]
     fn gather_scalar() {
         let model: gather_scalar::Model<Backend> = gather_scalar::Model::default();
80 changes: 72 additions & 8 deletions crates/burn-import/src/burn/node/gather.rs
@@ -6,7 +6,7 @@ use quote::quote;

 #[derive(Debug, Clone, new)]
 pub struct GatherNode {
-    pub input: TensorType,
+    pub input: Type,
     pub index: Type,
     pub output: TensorType,
     pub dim: usize,

@@ -18,7 +18,7 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for GatherNode {
     }

     fn input_types(&self) -> Vec<crate::burn::Type> {
-        vec![Type::Tensor(self.input.clone()), self.index.clone()]
+        vec![self.input.clone(), self.index.clone()]
     }

     fn forward(

@@ -27,7 +27,17 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for GatherNode {
         node_position: usize,
     ) -> proc_macro2::TokenStream {
         let dim = self.dim.to_tokens();
-        let input = scope.tensor_use_owned(&self.input, node_position);
+        let input = match &self.input {
+            Type::Tensor(in_tensor) => scope.tensor_use_owned(in_tensor, node_position),
+            Type::Shape(in_shape) => {
+                let in_shape_name = &in_shape.name;
+                // To copy just the values from the shape without moving it
+                // (which could lead to ownership problems if the same Shape is used multiple times),
+                // borrow the array as a slice and use that to create the Tensor:
+                quote! { Tensor::from_data(&#in_shape_name as &[_], &*self.device) }
+            }
+            _ => panic!("Gather needs Tensor or Shape input, got {:?}!", self.input),
+        };
         let output = &self.output.name;

         match &self.index {
@@ -46,7 +56,7 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for GatherNode {
                     let #output = #input.select(#dim, #index);
                 }
             }
-            _ => panic!("Gather needs Scalar or Tensor index!"),
+            _ => panic!("Gather needs Scalar or Tensor index, got {:?}!", self.index),
         }
     }
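The quote! invocation above splices the shape's identifier into the emitted expression. A minimal sketch of that interpolation pattern in isolation (assumes the quote and proc-macro2 crates; names here are illustrative, not burn API):

use proc_macro2::{Ident, Span, TokenStream};
use quote::quote;

fn main() {
    // Stand-in for `in_shape.name` in GatherNode::forward.
    let in_shape_name = Ident::new("shape1", Span::call_site());

    // `#in_shape_name` splices the identifier into the token stream,
    // yielding the expression the codegen emits for a Shape input.
    let tokens: TokenStream = quote! {
        Tensor::from_data(&#in_shape_name as &[_], &*self.device)
    };

    // Prints the generated tokens, roughly:
    // Tensor :: from_data (& shape1 as & [_] , & * self . device)
    println!("{}", tokens);
}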

@@ -64,15 +74,15 @@ mod tests {
     use crate::burn::{
         graph::BurnGraph,
         node::{gather::GatherNode, test::assert_tokens},
-        ScalarKind, ScalarType, TensorType,
+        ScalarKind, ScalarType, ShapeType, TensorType,
     };

     #[test]
     fn test_codegen_gather() {
         let mut graph = BurnGraph::<FullPrecisionSettings>::default();

         graph.register(GatherNode::new(
-            TensorType::new_float("tensor1", 2),
+            Type::Tensor(TensorType::new_float("tensor1", 2)),
             Type::Tensor(TensorType::new_int("tensor2", 1)),
             TensorType::new_float("tensor3", 2),
             0,

@@ -122,11 +132,65 @@
     }

     #[test]
-    fn test_codegen_gather_scalar() {
+    fn test_codegen_gather_shape_input() {
         let mut graph = BurnGraph::<FullPrecisionSettings>::default();

         graph.register(GatherNode::new(
+            Type::Shape(ShapeType::new("shape1", 3)),
+            Type::Tensor(TensorType::new_int("tensor1", 1)),
+            TensorType::new_float("tensor2", 2),
+            0,
+        ));
+
+        graph.register_input_output(
+            vec!["shape1".to_string(), "tensor1".to_string()],
+            vec!["tensor2".to_string()],
+        );
+
+        let expected = quote! {
+            use burn::tensor::Int;
+            use burn::{
+                module::Module,
+                tensor::{backend::Backend, Tensor},
+            };
+
+            #[derive(Module, Debug)]
+            pub struct Model<B: Backend> {
+                phantom: core::marker::PhantomData<B>,
+                device: burn::module::Ignored<B::Device>,
+            }
+
+            impl<B: Backend> Model <B> {
+                #[allow(unused_variables)]
+                pub fn new(device: &B::Device) -> Self {
+                    Self {
+                        phantom: core::marker::PhantomData,
+                        device: burn::module::Ignored(device.clone()),
+                    }
+                }
+
+                #[allow(clippy::let_and_return, clippy::approx_constant)]
+                pub fn forward(
+                    &self,
+                    shape1: [usize; 3],
+                    tensor1: Tensor<B, 1, Int>
+                ) -> Tensor<B, 2> {
+                    let tensor2 = Tensor::from_data(&shape1 as &[_], &*self.device).select(0, tensor1);
+
+                    tensor2
+                }
+            }
+        };
+
+        assert_tokens(graph.codegen(), expected);
+    }
+
+    #[test]
+    fn test_codegen_gather_scalar_idx() {
+        let mut graph = BurnGraph::<FullPrecisionSettings>::default();
+
+        graph.register(GatherNode::new(
-            TensorType::new_float("tensor1", 2),
+            Type::Tensor(TensorType::new_float("tensor1", 2)),
             Type::Scalar(ScalarType::new("scalar1", ScalarKind::Int64)),
             TensorType::new_float("tensor2", 2),
             0,
9 changes: 5 additions & 4 deletions crates/burn-import/src/onnx/op_configuration.rs
@@ -454,9 +454,10 @@ pub fn gather_config(curr: &Node) -> usize {
     }

     // extract the shape of the input tensor
-    let tensor = match curr.inputs.first().unwrap().clone().ty {
-        ArgType::Tensor(tensor) => tensor,
-        _ => panic!("Only tensor input is valid"),
+    let input_dim = match curr.inputs.first().unwrap().clone().ty {
+        ArgType::Tensor(tensor) => tensor.dim as i64,
+        ArgType::Shape(_shape) => 1, // Shape is always 1-D
+        other => panic!("Only tensor or shape input is valid, got {:?}", other),
     };

     // extract the attributes

@@ -469,7 +470,7 @@ pub fn gather_config(curr: &Node) -> usize {

     // if dim is negative, it is counted from the end
     if dim < 0 {
-        dim += tensor.dim as i64;
+        dim += input_dim;
     }

     dim as usize
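To make the negative-axis handling concrete: gather_config now normalizes a negative ONNX axis against the input's rank, which is always 1 for Shape inputs. A hypothetical standalone distillation (function name is illustrative, not burn API):

fn normalize_gather_axis(axis: i64, input_rank: i64) -> usize {
    // A negative ONNX axis counts from the end, so shift it by the rank.
    let axis = if axis < 0 { axis + input_rank } else { axis };
    axis as usize
}

fn main() {
    assert_eq!(normalize_gather_axis(-1, 3), 2); // last axis of a rank-3 tensor
    assert_eq!(normalize_gather_axis(-1, 1), 0); // Shape input: always rank 1
    assert_eq!(normalize_gather_axis(0, 1), 0);
}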
4 changes: 2 additions & 2 deletions crates/burn-import/src/onnx/to_burn.rs
@@ -616,7 +616,7 @@ impl ParsedOnnxGraph {
     }

     fn gather_conversion(node: Node) -> GatherNode {
-        let input = TensorType::from(node.inputs.first().unwrap());
+        let input = Type::from(node.inputs.first().unwrap());
         let index = Type::from(node.inputs.get(1).unwrap());
         let output = TensorType::from(node.outputs.first().unwrap());
         let dim = gather_config(&node);
@@ -1242,7 +1242,7 @@ impl From<&OnnxArgument> for TensorType {
                 dim,
                 ..
             }) => TensorType::new_bool(arg.name.clone(), *dim),
-            _ => panic!("Can't transform scalar to tensor."),
+            _ => panic!("Can't transform {:?} to tensor.", arg.ty),
         }
     }
 }
