diff --git a/examples/simple-regression/src/model.rs b/examples/simple-regression/src/model.rs index 318e1e5da0..bcef210db9 100644 --- a/examples/simple-regression/src/model.rs +++ b/examples/simple-regression/src/model.rs @@ -50,7 +50,7 @@ impl RegressionModel { } pub fn forward_step(&self, item: DiabetesBatch) -> RegressionOutput { - let targets: Tensor = item.targets.unsqueeze(); + let targets: Tensor = item.targets.unsqueeze_dim(1); let output: Tensor = self.forward(item.inputs); let loss = MseLoss::new().forward(output.clone(), targets.clone(), Mean);