Skip to content

Commit

Permalink
Update TransformerTutorial.ipynb
Browse files Browse the repository at this point in the history
  • Loading branch information
ThomasMGeo committed Feb 14, 2025
1 parent 95c703f commit 7df95a2
Showing 1 changed file with 0 additions and 107 deletions.
107 changes: 0 additions & 107 deletions TransformerTutorial/TransformerTutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1136,113 +1136,6 @@
}
]
},
{
"cell_type": "code",
"source": [
"def visualize_prediction_sequence(model, dataset, dataloader, device, sample_idx=0):\n",
" \"\"\"\n",
" Creates a visualization comparing model predictions with actual values.\n",
"\n",
" Args:\n",
" model: Trained WeatherTransformer model\n",
" dataset: WeatherForecastDataset instance\n",
" dataloader: DataLoader instance\n",
" device: torch device\n",
" sample_idx: Index of the sample to visualize\n",
" \"\"\"\n",
" model.eval()\n",
"\n",
" # Get predictions\n",
" inputs, targets = dataset[sample_idx]\n",
" with torch.no_grad():\n",
" predictions = model(inputs.unsqueeze(0).to(device))\n",
"\n",
" # Denormalize data\n",
" inputs = inputs.cpu().numpy() * dataset.std + dataset.mean\n",
" targets = targets.cpu().numpy() * dataset.std + dataset.mean\n",
" predictions = predictions.cpu().numpy()[0] * dataset.std + dataset.mean\n",
"\n",
" # Setup visualization\n",
" fig = plt.figure(figsize=(20, 12))\n",
" n_input, n_output = inputs.shape[-1], targets.shape[-1]\n",
" total_steps = n_input + 2 + n_output\n",
"\n",
" # Configure map and temperature display\n",
" vmin, vmax = min(inputs.min(), targets.min(), predictions.min()), max(inputs.max(), targets.max(), predictions.max())\n",
" # Balanced number of contour levels for detail while maintaining clarity\n",
" levels = np.linspace(vmin, vmax, 10)\n",
"\n",
" for t in range(total_steps):\n",
" ax = fig.add_subplot(4, 6, t + 1, projection=ccrs.PlateCarree())\n",
" ax.coastlines(resolution='50m', color='black', linewidth=0.5)\n",
" ax.add_feature(cfeature.BORDERS, linewidth=0.5)\n",
" ax.set_extent([\n",
" float(dataset.data.lon.min()), float(dataset.data.lon.max()),\n",
" float(dataset.data.lat.min()), float(dataset.data.lat.max())\n",
" ])\n",
"\n",
" if t < n_input: # Historical data\n",
" data = inputs[:, :, t]\n",
" hours_ago = (n_input - t) * 6\n",
" title = f'Training Sequence\\nt-{hours_ago}h\\n{data.mean():.1f}K'\n",
" # Add contours to historical data\n",
" ax.contour(dataset.data.lon, dataset.data.lat, data,\n",
" levels=levels, colors='black', alpha=0.3,\n",
" transform=ccrs.PlateCarree())\n",
"\n",
" elif t < n_input + 2: # GAP frames\n",
" weight = (t - n_input) / 2\n",
" data = (1 - weight) * inputs[:, :, -1] + weight * predictions[:, :, 0]\n",
" title = f'GAP\\nt+{(t-n_input)*6}h'\n",
" # Add interpolated contours in gap\n",
" ax.contour(dataset.data.lon, dataset.data.lat, data,\n",
" levels=levels, colors='black', alpha=0.3,\n",
" transform=ccrs.PlateCarree())\n",
"\n",
" else: # Predictions vs truth\n",
" pred_idx = t - (n_input + 2)\n",
" pred_data = predictions[:, :, pred_idx]\n",
" true_data = targets[:, :, pred_idx]\n",
" data = pred_data\n",
"\n",
" # Calculate metrics\n",
" rmse = np.sqrt(np.mean((pred_data - true_data) ** 2))\n",
" corr = np.corrcoef(pred_data.flatten(), true_data.flatten())[0, 1]\n",
"\n",
" title = (f'Predictions\\nt+{(pred_idx+2)*6}h\\n'\n",
" f'True: {true_data.mean():.1f}K\\n'\n",
" f'Pred: {pred_data.mean():.1f}K\\n'\n",
" f'RMSE: {rmse:.2f}K')\n",
"\n",
" # Add truth contours\n",
" ax.contour(dataset.data.lon, dataset.data.lat, true_data,\n",
" levels=levels, colors='black', alpha=0.3,\n",
" transform=ccrs.PlateCarree())\n",
"\n",
" # Plot data\n",
" mesh = ax.pcolormesh(dataset.data.lon, dataset.data.lat, data,\n",
" transform=ccrs.PlateCarree(),\n",
" cmap='YlGnBu_r', vmin=vmin, vmax=vmax)\n",
" ax.set_title(title, fontsize=8, pad=4)\n",
" ax.set_xticks([])\n",
" ax.set_yticks([])\n",
"\n",
" # Add colorbar and title\n",
" cbar_ax = fig.add_axes([0.92, 0.15, 0.02, 0.7])\n",
" fig.colorbar(mesh, cax=cbar_ax, label='Temperature (K)')\n",
" plt.suptitle(f'Temperature Prediction Sequence - Sample {sample_idx}',\n",
" y=0.95, fontsize=16)\n",
"\n",
" plt.subplots_adjust(top=0.9, bottom=0.1, left=0.05, right=0.9,\n",
" wspace=0.1, hspace=0.2)\n",
" plt.show()"
],
"metadata": {
"id": "Z2rySkWQTvlC"
},
"execution_count": 18,
"outputs": []
},
{
"cell_type": "code",
"source": [
Expand Down

0 comments on commit 7df95a2

Please sign in to comment.