diff --git a/crates/burn-import/src/onnx/dim_inference.rs b/crates/burn-import/src/onnx/dim_inference.rs index 78a304a659..ac6bc02729 100644 --- a/crates/burn-import/src/onnx/dim_inference.rs +++ b/crates/burn-import/src/onnx/dim_inference.rs @@ -195,13 +195,18 @@ fn concat_update_outputs(node: &mut Node) { node.outputs[0].ty = ArgType::Tensor(tensor.clone()); } + fn reshape_update_outputs(node: &mut Node) { - let shape = match node.inputs.get(1) { - Some(input) => match &input.value { - Some(Data::Int64s(shape)) => Some(shape.clone()), - _ => panic!("Reshape: invalid input types"), - }, - None => node.attrs.get("shape").cloned().map(|v| v.into_i64s()), + let shape = if node.inputs.len() == 2 { + match &node.inputs[1].value { + Some(value) => match value { + Data::Int64s(shape) => Some(shape.clone()), + _ => panic!("Reshape: invalid input types"), + }, + None => None, + } + } else { + node.attrs.get("shape").cloned().map(|v| v.into_i64s()) }; let output = match &node.outputs[0].ty { @@ -252,24 +257,34 @@ fn reduce_mean_update_outputs(node: &mut Node) { /// Update the output tensor dimension based on the "axes" attribute or the second input fn unsqueeze_update_output(node: &mut Node) { - let axes = match node.inputs.get(1) { - Some(input) => match &input.value { - Some(Data::Int64s(axes)) => Some(axes.clone()), - _ => panic!("Unsqueeze: invalid input types"), - }, - None => node.attrs.get("axes").cloned().map(|v| v.into_i64s()), + let axes = if node.inputs.len() == 2 { + match &node.inputs[1].value { + Some(value) => match value { + Data::Int64s(axes) => Some(axes.clone()), + _ => panic!("Unsqueeze: invalid input types"), + }, + None => None, + } + } else { + node.attrs.get("axes").cloned().map(|v| v.into_i64s()) }; + // need output way up here to avoid borrowing issues let input = match &node.inputs[0].ty { ArgType::Tensor(tensor) => tensor.clone(), - ty => panic!("Unsqueeze: invalid output type ({ty:?})"), + _ => panic!("Unsqueeze: invalid output types"), + }; + + let output = match &node.outputs[0].ty { + ArgType::Tensor(tensor) => tensor.clone(), + _ => panic!("Unsqueeze: invalid output types"), }; if let Some(axes) = axes { node.outputs[0].ty = ArgType::Tensor(TensorType { dim: input.dim + axes.len(), shape: None, // shape is calculated at runtime - ..input + ..output }); } }