Skip to content

Commit

Permalink
Fix simple regression batch targets (#2379)
Browse files Browse the repository at this point in the history
  • Loading branch information
wangjiawen2013 authored Oct 17, 2024
1 parent 359ad08 commit 296c526
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion examples/simple-regression/src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ impl<B: Backend> RegressionModel<B> {
}

pub fn forward_step(&self, item: DiabetesBatch<B>) -> RegressionOutput<B> {
let targets: Tensor<B, 2> = item.targets.unsqueeze();
let targets: Tensor<B, 2> = item.targets.unsqueeze_dim(1);
let output: Tensor<B, 2> = self.forward(item.inputs);

let loss = MseLoss::new().forward(output.clone(), targets.clone(), Mean);
Expand Down

0 comments on commit 296c526

Please sign in to comment.