From 296c526551e080e17eb1931c4cd7e880f9cbfb27 Mon Sep 17 00:00:00 2001 From: jiawen wang Date: Thu, 17 Oct 2024 22:51:08 +0800 Subject: [PATCH] Fix simple regression batch targets (#2379) --- examples/simple-regression/src/model.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/simple-regression/src/model.rs b/examples/simple-regression/src/model.rs index 318e1e5da0..bcef210db9 100644 --- a/examples/simple-regression/src/model.rs +++ b/examples/simple-regression/src/model.rs @@ -50,7 +50,7 @@ impl RegressionModel { } pub fn forward_step(&self, item: DiabetesBatch) -> RegressionOutput { - let targets: Tensor = item.targets.unsqueeze(); + let targets: Tensor = item.targets.unsqueeze_dim(1); let output: Tensor = self.forward(item.inputs); let loss = MseLoss::new().forward(output.clone(), targets.clone(), Mean);