From 06bfb81c5d2613d86803e4fa62ca9455cc6dc631 Mon Sep 17 00:00:00 2001 From: Hauke Schulz <43613877+observingClouds@users.noreply.github.com> Date: Sun, 26 Jan 2025 21:52:16 +0100 Subject: [PATCH] fix incorrect input type --- neural_lam/models/ar_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py index d19805f1..81c2f720 100644 --- a/neural_lam/models/ar_model.py +++ b/neural_lam/models/ar_model.py @@ -181,7 +181,7 @@ def _create_dataarray_from_tensor( weather_dataset = WeatherDataset(datastore=self._datastore, split=split) time = np.array(time.cpu(), dtype="datetime64[ns]") da = weather_dataset.create_dataarray_from_tensor( - tensor=tensor.cpu().numpy(), time=time, category=category + tensor=tensor, time=time, category=category ) return da