diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py index d19805f1..81c2f720 100644 --- a/neural_lam/models/ar_model.py +++ b/neural_lam/models/ar_model.py @@ -181,7 +181,7 @@ def _create_dataarray_from_tensor( weather_dataset = WeatherDataset(datastore=self._datastore, split=split) time = np.array(time.cpu(), dtype="datetime64[ns]") da = weather_dataset.create_dataarray_from_tensor( - tensor=tensor.cpu().numpy(), time=time, category=category + tensor=tensor, time=time, category=category ) return da