Skip to content

Commit

Permalink
Reshape bug fix (#1684)
Browse files Browse the repository at this point in the history
* Revert 1c639c8

1c639c8?diff=unified&w=0

* Refactor by @laggui

* Refactor unsqueeze
  • Loading branch information
antimora authored Apr 25, 2024
1 parent 886a1de commit a1bd14c
Showing 1 changed file with 29 additions and 14 deletions.
43 changes: 29 additions & 14 deletions crates/burn-import/src/onnx/dim_inference.rs
Original file line number Diff line number Diff line change
Expand Up @@ -195,13 +195,18 @@ fn concat_update_outputs(node: &mut Node) {

node.outputs[0].ty = ArgType::Tensor(tensor.clone());
}

fn reshape_update_outputs(node: &mut Node) {
let shape = match node.inputs.get(1) {
Some(input) => match &input.value {
Some(Data::Int64s(shape)) => Some(shape.clone()),
_ => panic!("Reshape: invalid input types"),
},
None => node.attrs.get("shape").cloned().map(|v| v.into_i64s()),
let shape = if node.inputs.len() == 2 {
match &node.inputs[1].value {
Some(value) => match value {
Data::Int64s(shape) => Some(shape.clone()),
_ => panic!("Reshape: invalid input types"),
},
None => None,
}
} else {
node.attrs.get("shape").cloned().map(|v| v.into_i64s())
};

let output = match &node.outputs[0].ty {
Expand Down Expand Up @@ -252,24 +257,34 @@ fn reduce_mean_update_outputs(node: &mut Node) {

/// Update the output tensor dimension based on the "axes" attribute or the second input
fn unsqueeze_update_output(node: &mut Node) {
let axes = match node.inputs.get(1) {
Some(input) => match &input.value {
Some(Data::Int64s(axes)) => Some(axes.clone()),
_ => panic!("Unsqueeze: invalid input types"),
},
None => node.attrs.get("axes").cloned().map(|v| v.into_i64s()),
let axes = if node.inputs.len() == 2 {
match &node.inputs[1].value {
Some(value) => match value {
Data::Int64s(axes) => Some(axes.clone()),
_ => panic!("Unsqueeze: invalid input types"),
},
None => None,
}
} else {
node.attrs.get("axes").cloned().map(|v| v.into_i64s())
};

// need output way up here to avoid borrowing issues
let input = match &node.inputs[0].ty {
ArgType::Tensor(tensor) => tensor.clone(),
ty => panic!("Unsqueeze: invalid output type ({ty:?})"),
_ => panic!("Unsqueeze: invalid output types"),
};

let output = match &node.outputs[0].ty {
ArgType::Tensor(tensor) => tensor.clone(),
_ => panic!("Unsqueeze: invalid output types"),
};

if let Some(axes) = axes {
node.outputs[0].ty = ArgType::Tensor(TensorType {
dim: input.dim + axes.len(),
shape: None, // shape is calculated at runtime
..input
..output
});
}
}
Expand Down

0 comments on commit a1bd14c

Please sign in to comment.