From 1f7b6bee00f5c17c207783f5a8052c13106d1de0 Mon Sep 17 00:00:00 2001
From: Maxim Ziatdinov <34245227+ziatdinovmax@users.noreply.github.com>
Date: Mon, 24 Feb 2025 16:25:37 -0800
Subject: [PATCH] Add example with Partially Bayesian Transformer
---
.../PartialBayesianTransformer_esol.ipynb | 943 ++++++++++++++++++
1 file changed, 943 insertions(+)
create mode 100644 examples/PartialBayesianTransformer_esol.ipynb
diff --git a/examples/PartialBayesianTransformer_esol.ipynb b/examples/PartialBayesianTransformer_esol.ipynb
new file mode 100644
index 0000000..b92f9a8
--- /dev/null
+++ b/examples/PartialBayesianTransformer_esol.ipynb
@@ -0,0 +1,943 @@
+{
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "metadata": {
+ "colab": {
+ "provenance": [],
+ "machine_shape": "hm",
+ "gpuType": "A100",
+ "authorship_tag": "ABX9TyOQTGOXN/O7bazRtzBSfsXD",
+ "include_colab_link": true
+ },
+ "kernelspec": {
+ "name": "python3",
+ "display_name": "Python 3"
+ },
+ "language_info": {
+ "name": "python"
+ },
+ "accelerator": "GPU"
+ },
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "view-in-github",
+ "colab_type": "text"
+ },
+ "source": [
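+ "<a href=\"https://colab.research.google.com/github/ziatdinovmax/NeuroBayes/blob/main/examples/PartialBayesianTransformer_esol.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"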
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# Partially Bayesian Transformer - Molecular Property Prediction from SMILES\n",
+ "\n",
+ "*Prepared by Maxim Ziatdinov (February 2025)*\n",
+ "\n",
+ "This notebook demonstrates the application of Partially Bayesian Transformers to predict molecular solubility from SMILES representations using the NeuroBayes framework. It starts with a deterministic transformer model for the ESOL dataset, then converts it into partially Bayesian models by making different components probabilistic - specifically the token embedding layer, final dense layer, and attention layer."
+ ],
+ "metadata": {
+ "id": "UIsh9MppMDBx"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Install NeuroBayes:"
+ ],
+ "metadata": {
+ "id": "4o119a2WNZ37"
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "e-uTgq6lvyRJ",
+ "outputId": "102b4fa2-3351-4764-be46-18a5c9b5f7de"
+ },
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Collecting git+https://github.com/ziatdinovmax/NeuroBayes\n",
+ " Cloning https://github.com/ziatdinovmax/NeuroBayes to /tmp/pip-req-build-350_dh23\n",
+ " Running command git clone --filter=blob:none --quiet https://github.com/ziatdinovmax/NeuroBayes /tmp/pip-req-build-350_dh23\n",
+ " Resolved https://github.com/ziatdinovmax/NeuroBayes to commit 668d10d3130362e74dba9221bcdc887b76b1951d\n",
+ " Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
+ "Collecting jax<=0.4.31,>=0.4.0 (from NeuroBayes==0.0.12)\n",
+ " Downloading jax-0.4.31-py3-none-any.whl.metadata (22 kB)\n",
+ "Collecting jaxlib<=0.4.31,>=0.4.0 (from NeuroBayes==0.0.12)\n",
+ " Downloading jaxlib-0.4.31-cp311-cp311-manylinux2014_x86_64.whl.metadata (983 bytes)\n",
+ "Collecting numpyro>=0.13.0 (from NeuroBayes==0.0.12)\n",
+ " Downloading numpyro-0.17.0-py3-none-any.whl.metadata (37 kB)\n",
+ "Requirement already satisfied: flax>=0.8.4 in /usr/local/lib/python3.11/dist-packages (from NeuroBayes==0.0.12) (0.10.3)\n",
+ "Requirement already satisfied: numpy>=1.23.2 in /usr/local/lib/python3.11/dist-packages (from flax>=0.8.4->NeuroBayes==0.0.12) (1.26.4)\n",
+ "Requirement already satisfied: msgpack in /usr/local/lib/python3.11/dist-packages (from flax>=0.8.4->NeuroBayes==0.0.12) (1.1.0)\n",
+ "Requirement already satisfied: optax in /usr/local/lib/python3.11/dist-packages (from flax>=0.8.4->NeuroBayes==0.0.12) (0.2.4)\n",
+ "Requirement already satisfied: orbax-checkpoint in /usr/local/lib/python3.11/dist-packages (from flax>=0.8.4->NeuroBayes==0.0.12) (0.6.4)\n",
+ "Requirement already satisfied: tensorstore in /usr/local/lib/python3.11/dist-packages (from flax>=0.8.4->NeuroBayes==0.0.12) (0.1.71)\n",
+ "Requirement already satisfied: rich>=11.1 in /usr/local/lib/python3.11/dist-packages (from flax>=0.8.4->NeuroBayes==0.0.12) (13.9.4)\n",
+ "Requirement already satisfied: typing_extensions>=4.2 in /usr/local/lib/python3.11/dist-packages (from flax>=0.8.4->NeuroBayes==0.0.12) (4.12.2)\n",
+ "Requirement already satisfied: PyYAML>=5.4.1 in /usr/local/lib/python3.11/dist-packages (from flax>=0.8.4->NeuroBayes==0.0.12) (6.0.2)\n",
+ "Requirement already satisfied: treescope>=0.1.7 in /usr/local/lib/python3.11/dist-packages (from flax>=0.8.4->NeuroBayes==0.0.12) (0.1.9)\n",
+ "Requirement already satisfied: ml-dtypes>=0.2.0 in /usr/local/lib/python3.11/dist-packages (from jax<=0.4.31,>=0.4.0->NeuroBayes==0.0.12) (0.4.1)\n",
+ "Requirement already satisfied: opt-einsum in /usr/local/lib/python3.11/dist-packages (from jax<=0.4.31,>=0.4.0->NeuroBayes==0.0.12) (3.4.0)\n",
+ "Requirement already satisfied: scipy>=1.10 in /usr/local/lib/python3.11/dist-packages (from jax<=0.4.31,>=0.4.0->NeuroBayes==0.0.12) (1.13.1)\n",
+ "Requirement already satisfied: multipledispatch in /usr/local/lib/python3.11/dist-packages (from numpyro>=0.13.0->NeuroBayes==0.0.12) (1.0.0)\n",
+ "Requirement already satisfied: tqdm in /usr/local/lib/python3.11/dist-packages (from numpyro>=0.13.0->NeuroBayes==0.0.12) (4.67.1)\n",
+ "Requirement already satisfied: markdown-it-py>=2.2.0 in /usr/local/lib/python3.11/dist-packages (from rich>=11.1->flax>=0.8.4->NeuroBayes==0.0.12) (3.0.0)\n",
+ "Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /usr/local/lib/python3.11/dist-packages (from rich>=11.1->flax>=0.8.4->NeuroBayes==0.0.12) (2.18.0)\n",
+ "Requirement already satisfied: absl-py>=0.7.1 in /usr/local/lib/python3.11/dist-packages (from optax->flax>=0.8.4->NeuroBayes==0.0.12) (1.4.0)\n",
+ "Requirement already satisfied: chex>=0.1.87 in /usr/local/lib/python3.11/dist-packages (from optax->flax>=0.8.4->NeuroBayes==0.0.12) (0.1.88)\n",
+ "Requirement already satisfied: etils[epy] in /usr/local/lib/python3.11/dist-packages (from optax->flax>=0.8.4->NeuroBayes==0.0.12) (1.12.0)\n",
+ "Requirement already satisfied: nest_asyncio in /usr/local/lib/python3.11/dist-packages (from orbax-checkpoint->flax>=0.8.4->NeuroBayes==0.0.12) (1.6.0)\n",
+ "Requirement already satisfied: protobuf in /usr/local/lib/python3.11/dist-packages (from orbax-checkpoint->flax>=0.8.4->NeuroBayes==0.0.12) (4.25.6)\n",
+ "Requirement already satisfied: humanize in /usr/local/lib/python3.11/dist-packages (from orbax-checkpoint->flax>=0.8.4->NeuroBayes==0.0.12) (4.11.0)\n",
+ "Requirement already satisfied: toolz>=0.9.0 in /usr/local/lib/python3.11/dist-packages (from chex>=0.1.87->optax->flax>=0.8.4->NeuroBayes==0.0.12) (0.12.1)\n",
+ "Requirement already satisfied: mdurl~=0.1 in /usr/local/lib/python3.11/dist-packages (from markdown-it-py>=2.2.0->rich>=11.1->flax>=0.8.4->NeuroBayes==0.0.12) (0.1.2)\n",
+ "Requirement already satisfied: fsspec in /usr/local/lib/python3.11/dist-packages (from etils[epath,epy]->orbax-checkpoint->flax>=0.8.4->NeuroBayes==0.0.12) (2024.10.0)\n",
+ "Requirement already satisfied: importlib_resources in /usr/local/lib/python3.11/dist-packages (from etils[epath,epy]->orbax-checkpoint->flax>=0.8.4->NeuroBayes==0.0.12) (6.5.2)\n",
+ "Requirement already satisfied: zipp in /usr/local/lib/python3.11/dist-packages (from etils[epath,epy]->orbax-checkpoint->flax>=0.8.4->NeuroBayes==0.0.12) (3.21.0)\n",
+ "Downloading jax-0.4.31-py3-none-any.whl (2.0 MB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.0/2.0 MB\u001b[0m \u001b[31m30.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hDownloading jaxlib-0.4.31-cp311-cp311-manylinux2014_x86_64.whl (88.1 MB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m88.1/88.1 MB\u001b[0m \u001b[31m24.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hDownloading numpyro-0.17.0-py3-none-any.whl (360 kB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m360.8/360.8 kB\u001b[0m \u001b[31m21.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hBuilding wheels for collected packages: NeuroBayes\n",
+ " Building wheel for NeuroBayes (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
+ " Created wheel for NeuroBayes: filename=NeuroBayes-0.0.12-py3-none-any.whl size=70460 sha256=cf31aa2fe9178eba0000e6d45ad05fa218692d940c8e21e90a5a1aae49a1381c\n",
+ " Stored in directory: /tmp/pip-ephem-wheel-cache-xdg9khda/wheels/52/f4/85/0e45c139ca1e184850945237cf19f7edb55c1ff7afbd82a351\n",
+ "Successfully built NeuroBayes\n",
+ "Installing collected packages: jaxlib, jax, numpyro, NeuroBayes\n",
+ " Attempting uninstall: jaxlib\n",
+ " Found existing installation: jaxlib 0.4.33\n",
+ " Uninstalling jaxlib-0.4.33:\n",
+ " Successfully uninstalled jaxlib-0.4.33\n",
+ " Attempting uninstall: jax\n",
+ " Found existing installation: jax 0.4.33\n",
+ " Uninstalling jax-0.4.33:\n",
+ " Successfully uninstalled jax-0.4.33\n",
+ "Successfully installed NeuroBayes-0.0.12 jax-0.4.31 jaxlib-0.4.31 numpyro-0.17.0\n"
+ ]
+ }
+ ],
+ "source": [
+ "!pip install git+https://github.com/ziatdinovmax/NeuroBayes"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Download dataset:"
+ ],
+ "metadata": {
+ "id": "Jxh7l0oeMgbc"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "!wget https://raw.githubusercontent.com/deepchem/deepchem/master/datasets/delaney-processed.csv"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "scfe_zjFwl88",
+ "outputId": "928289dd-0dff-4060-d04f-93db70abeb2d"
+ },
+ "execution_count": 2,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "--2025-02-24 18:29:02-- https://raw.githubusercontent.com/deepchem/deepchem/master/datasets/delaney-processed.csv\n",
+ "Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.111.133, 185.199.110.133, 185.199.109.133, ...\n",
+ "Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.111.133|:443... connected.\n",
+ "HTTP request sent, awaiting response... 200 OK\n",
+ "Length: 96699 (94K) [text/plain]\n",
+ "Saving to: ‘delaney-processed.csv’\n",
+ "\n",
+ "\rdelaney-processed.c 0%[ ] 0 --.-KB/s \rdelaney-processed.c 100%[===================>] 94.43K --.-KB/s in 0.02s \n",
+ "\n",
+ "2025-02-24 18:29:02 (4.06 MB/s) - ‘delaney-processed.csv’ saved [96699/96699]\n",
+ "\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Check that GPU is available:"
+ ],
+ "metadata": {
+ "id": "9TKlZJrxH5XF"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "!nvidia-smi"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "1CoIPWp3H4Vm",
+ "outputId": "92c49564-13f4-4972-83d2-e23d35f58756"
+ },
+ "execution_count": 3,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Mon Feb 24 18:29:02 2025 \n",
+ "+-----------------------------------------------------------------------------------------+\n",
+ "| NVIDIA-SMI 550.54.15 Driver Version: 550.54.15 CUDA Version: 12.4 |\n",
+ "|-----------------------------------------+------------------------+----------------------+\n",
+ "| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |\n",
+ "| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |\n",
+ "| | | MIG M. |\n",
+ "|=========================================+========================+======================|\n",
+ "| 0 NVIDIA A100-SXM4-40GB Off | 00000000:00:04.0 Off | 0 |\n",
+ "| N/A 29C P0 45W / 400W | 0MiB / 40960MiB | 0% Default |\n",
+ "| | | Disabled |\n",
+ "+-----------------------------------------+------------------------+----------------------+\n",
+ " \n",
+ "+-----------------------------------------------------------------------------------------+\n",
+ "| Processes: |\n",
+ "| GPU GI CI PID Type Process name GPU Memory |\n",
+ "| ID ID Usage |\n",
+ "|=========================================================================================|\n",
+ "| No running processes found |\n",
+ "+-----------------------------------------------------------------------------------------+\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Import necessary libraries:"
+ ],
+ "metadata": {
+ "id": "3oQdV6SUMjHq"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "import neurobayes as nb\n",
+ "\n",
+ "import numpy as np\n",
+ "import matplotlib.pyplot as plt\n",
+ "import pandas as pd\n",
+ "from sklearn.model_selection import train_test_split\n",
+ "from sklearn.metrics import r2_score"
+ ],
+ "metadata": {
+ "id": "7oxHwcB3x5Ym"
+ },
+ "execution_count": 1,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Preprocess data:"
+ ],
+ "metadata": {
+ "id": "KZU2yj8vwjjA"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "class SMILESTokenizer:\n",
+ " def __init__(self):\n",
+ " self.chars = set(' ()[]{}=-#@+/\\\\.%$NH?OSCFIBrclnop0123456789')\n",
+ " self.char_to_idx = {char: idx for idx, char in enumerate(self.chars)}\n",
+ " self.idx_to_char = {idx: char for char, idx in self.char_to_idx.items()}\n",
+ " self.vocab_size = len(self.chars)\n",
+ "\n",
+ " def encode(self, smiles):\n",
+ " return [self.char_to_idx.get(c, self.char_to_idx['?']) for c in smiles]\n",
+ "\n",
+ " def decode(self, tokens):\n",
+ " return ''.join([self.idx_to_char[t] for t in tokens])\n",
+ "\n",
+ "\n",
+ "def load_esol_data(file_path):\n",
+ " \"\"\"Load ESOL dataset from CSV\"\"\"\n",
+ " df = pd.read_csv(file_path)\n",
+ "\n",
+ " smiles_col = [col for col in df.columns if 'SMILES' in col.upper()][0]\n",
+ " solubility_col = [col for col in df.columns if 'SOLUBILITY' in col.upper()][0]\n",
+ "\n",
+ " return df[smiles_col].values, df[solubility_col].values\n",
+ "\n",
+ "def prepare_dataset(smiles_list, solubility_values, max_length=128):\n",
+ " \"\"\"Prepare entire dataset\"\"\"\n",
+ " tokenizer = SMILESTokenizer()\n",
+ " encoded = [tokenizer.encode(s) for s in smiles_list]\n",
+ " padded = [seq + [0] * (max_length - len(seq)) if len(seq) < max_length\n",
+ " else seq[:max_length] for seq in encoded]\n",
+ " return {\n",
+ " 'input_ids': np.array(padded, dtype=np.int32),\n",
+ " 'solubility': np.array(solubility_values, dtype=np.float32)\n",
+ " }\n",
+ "\n",
+ "smiles_list, solubility_values = load_esol_data('delaney-processed.csv')\n",
+ "\n",
+ "X_train, X_test, y_train, y_test = train_test_split(\n",
+ " smiles_list, solubility_values, test_size=0.5, random_state=1\n",
+ ")\n",
+ "\n",
+ "train_data = prepare_dataset(X_train.tolist(), y_train.tolist())\n",
+ "test_data = prepare_dataset(X_test.tolist(), y_test.tolist())"
+ ],
+ "metadata": {
+ "id": "hKkBrNTcwnFn"
+ },
+ "execution_count": 2,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Initialize and train a deterministic transformer model:"
+ ],
+ "metadata": {
+ "id": "RHvxayUPzXDg"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "tokenizer = SMILESTokenizer()\n",
+ "model = nb.flax_nets.FlaxTransformer(\n",
+ " vocab_size=tokenizer.vocab_size,\n",
+ " d_model=16,\n",
+ " nhead=4,\n",
+ " num_layers=2,\n",
+ " dim_feedforward=64)\n",
+ "\n",
+ "nn_model = nb.DeterministicNN(\n",
+ " architecture=model,\n",
+ " input_shape=128,\n",
+ " learning_rate=1e-3,\n",
+ " map=False,\n",
+ " swa_config={'schedule': 'constant', 'start_pct': 0.90},\n",
+ " collect_gradients=False\n",
+ ")\n",
+ "\n",
+ "\n",
+ "nn_model.train(\n",
+ " train_data['input_ids'],\n",
+ " train_data['solubility'],\n",
+ " epochs=50,\n",
+ " batch_size=32\n",
+ ")"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "l38QDaHXv2NL",
+ "outputId": "a180fe00-096e-4eb2-e8ec-3ab6323e8aed"
+ },
+ "execution_count": 3,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stderr",
+ "text": [
+ "Training Progress: 100%|██████████| 50/50 [00:17<00:00, 2.80it/s, Epoch 50/50, LR: 0.001000, Loss: 0.2789 ]\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Make a prediction with the trained deterministic transformer:"
+ ],
+ "metadata": {
+ "id": "LvibYUtXWpiL"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# Get predictions on test set\n",
+ "predictions = nn_model.predict(test_data['input_ids']).squeeze()\n",
+ "\n",
+ "# Compute RMSE and R^2\n",
+ "test_rmse = np.sqrt(np.mean((test_data['solubility'] - predictions) ** 2))\n",
+ "test_r2 = r2_score(test_data['solubility'], predictions)\n",
+ "\n",
+ "print(\"\\nModel Performance:\")\n",
+ "print(f\"Test RMSE: {test_rmse:.4f}\")\n",
+ "print(f\"Test R²: {test_r2:.4f}\")\n",
+ "\n",
+ "# Show example predictions\n",
+ "print(\"\\nExample Predictions:\")\n",
+ "for i in range(min(5, len(X_test))):\n",
+ " pred = float(nn_model.predict(test_data['input_ids'][i:i+1])[0, 0]) # Access the scalar value\n",
+ " print(f\"SMILES: {X_test[i]}\")\n",
+ " print(f\"True Solubility: {y_test[i]:.2f}\")\n",
+ " print(f\"Predicted Solubility: {pred:.2f}\\n\")"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "dVgPGKO00MaI",
+ "outputId": "280ed6e2-1bb1-4c80-9c5a-3c592f0e9217"
+ },
+ "execution_count": 4,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "\n",
+ "Model Performance:\n",
+ "Test RMSE: 0.5376\n",
+ "Test R²: 0.8925\n",
+ "\n",
+ "Example Predictions:\n",
+ "SMILES: CCCCCCCCCC(=O)OC\n",
+ "True Solubility: -3.32\n",
+ "Predicted Solubility: -2.62\n",
+ "\n",
+ "SMILES: Cc1cc(ccc1NS(=O)(=O)C(F)(F)F)S(=O)(=O)c2ccccc2\n",
+ "True Solubility: -4.95\n",
+ "Predicted Solubility: -4.49\n",
+ "\n",
+ "SMILES: Nc1nc[nH]n1\n",
+ "True Solubility: -0.67\n",
+ "Predicted Solubility: -0.31\n",
+ "\n",
+ "SMILES: Clc1ccc(c(Cl)c1)c2cc(Cl)c(Cl)c(Cl)c2Cl \n",
+ "True Solubility: -7.34\n",
+ "Predicted Solubility: -7.07\n",
+ "\n",
+ "SMILES: CCOC(=O)c1ccccc1\n",
+ "True Solubility: -2.77\n",
+ "Predicted Solubility: -2.88\n",
+ "\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Next, we convert our deterministic Transformer model into a partially Bayesian Transformer. We will test three types of layers: token embedding, attention, and final (fully connected) layer. Since making even a single layer probabilistic introduces substantial computational overhead, we will select only a subset of weights from that layer for Bayesian treatment. There are multiple ways to select this subset of probabilistic weights - based on magnitude, variance, and changes in gradient. At this point, it's not entirely clear to me which method should be used in which situation. For this tutorial notebook, we'll stick to 'variance', but feel free to explore other options, including completely random selection.\n",
+ "\n",
+ "If you see the sampling seriously struggling (very small step size and acceptance probability close to zero) after more than 20% of warmup steps, I recommend stopping it and changing the probabilistic weights selection strategy."
+ ],
+ "metadata": {
+ "id": "UU4PBhBoWxev"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "NeuroBayes has a utility function for printing each layer name and specs, which makes the selection and assignment process easier:"
+ ],
+ "metadata": {
+ "id": "lReOsn-jjua2"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "nb.print_layer_configs(nn_model.model)"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "RUXStModj3iY",
+ "outputId": "0e797533-bfe8-4d8b-8558-c1078c6e1a10"
+ },
+ "execution_count": 5,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "\n",
+ "================================================================================\n",
+ "Model Architecture: FlaxTransformer\n",
+ "================================================================================\n",
+ "\n",
+ "layer_name layer_type features num_embeddings\n",
+ "-------------------------------------------------------\n",
+ "TokenEmbed embedding 16 42 \n",
+ "PosEmbed embedding 16 1024 \n",
+ "Block0_Attention attention 16 \n",
+ "Block0_LayerNorm1 layernorm - \n",
+ "Block0_MLP_dense1 fc 64 \n",
+ "Block0_MLP_dense2 fc 16 \n",
+ "Block0_LayerNorm2 layernorm - \n",
+ "Block1_Attention attention 16 \n",
+ "Block1_LayerNorm1 layernorm - \n",
+ "Block1_MLP_dense1 fc 64 \n",
+ "Block1_MLP_dense2 fc 16 \n",
+ "Block1_LayerNorm2 layernorm - \n",
+ "FinalDense1 fc 64 \n",
+ "FinalDense2 fc 1 \n",
+ "\n",
+ "================================================================================\n",
+ "Total layers: 14\n",
+ "================================================================================\n",
+ "\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Select probabilistic layer(s), a corresponding subset of probabilistic weights, and train a partially Bayesian Transformer:"
+ ],
+ "metadata": {
+ "id": "D4NDkdZVj0YK"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# Specify probabilistic layers\n",
+ "probabilistic_layers = ['TokenEmbed']\n",
+ "\n",
+ "# Select a subset of weights for the selected probabilistic layers\n",
+ "prob_indices = nb.select_bayesian_components(\n",
+ " nn_model,\n",
+ " layer_names=probabilistic_layers,\n",
+ " method='variance',\n",
+ " num_pairs_per_layer=100\n",
+ ")\n",
+ "\n",
+ "# Initialize a partially Bayesian Transformer\n",
+ "pbnn_model = nb.PartialBayesianTransformer(\n",
+ " nn_model.model,\n",
+ " deterministic_weights=nn_model.get_params(),\n",
+ " probabilistic_layer_names=probabilistic_layers,\n",
+ " probabilistic_neurons=prob_indices\n",
+ ")\n",
+ "\n",
+ "# Train\n",
+ "pbnn_model.fit(\n",
+ " train_data[\"input_ids\"],\n",
+ " train_data['solubility'],\n",
+ " num_warmup=1000,\n",
+ " num_samples=1000,\n",
+ ")"
+ ],
+ "metadata": {
+ "id": "5CF0IxasA-5A",
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "outputId": "4818ca8b-5f09-47e9-e0ff-ee801349dd1a"
+ },
+ "execution_count": 6,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stderr",
+ "text": [
+ "sample: 100%|██████████| 2000/2000 [10:14<00:00, 3.25it/s, 63 steps of size 8.25e-02. acc. prob=0.91]\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Compute prediction accuracy and uncertainty metrics:"
+ ],
+ "metadata": {
+ "id": "fGZMz7XrOCyI"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# Get predictions on test set\n",
+ "predictions, variance = pbnn_model.predict(test_data['input_ids'])\n",
+ "predictions = predictions.squeeze()\n",
+ "variance = variance.squeeze()\n",
+ "\n",
+ "# Compute RMSE and R^2\n",
+ "test_rmse_emb = np.sqrt(np.mean((test_data['solubility'] - predictions) ** 2))\n",
+ "test_r2_emb = r2_score(test_data['solubility'], predictions)\n",
+ "\n",
+ "print(\"Model Performance:\")\n",
+ "print(f\"Test RMSE: {test_rmse_emb:.4f}\")\n",
+ "print(f\"Test R²: {test_r2_emb:.4f}\")\n",
+ "\n",
+ "# Compute negative log predictive density and coverage probability\n",
+ "test_nlpd_emb = nb.utils.nlpd(test_data['solubility'], predictions, variance)\n",
+ "test_coverage_emb = nb.utils.coverage(test_data['solubility'], predictions, variance)\n",
+ "\n",
+ "print(f\"NLPD: {test_nlpd_emb:.4f}\")\n",
+ "print(f\"Coverage probability: {test_coverage_emb:.4f}\")"
+ ],
+ "metadata": {
+ "id": "pjKelSZG35xi",
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "outputId": "3a40d473-5363-416f-dc12-97922b0c9baf"
+ },
+ "execution_count": 7,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Model Performance:\n",
+ "Test RMSE: 0.5399\n",
+ "Test R²: 0.8915\n",
+ "NLPD: 0.8274\n",
+ "Coverage probability: 0.9096\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Now let's do the same for the first of the two final dense layers (`FinalDense1`) in the architecture:"
+ ],
+ "metadata": {
+ "id": "OQ3lWS7WVtzZ"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "probabilistic_layers = ['FinalDense1']\n",
+ "\n",
+ "prob_indices = nb.select_bayesian_components(\n",
+ " nn_model,\n",
+ " layer_names=probabilistic_layers,\n",
+ " method='variance',\n",
+ " num_pairs_per_layer=100\n",
+ ")\n",
+ "\n",
+ "pbnn_model = nb.PartialBayesianTransformer(\n",
+ " nn_model.model,\n",
+ " deterministic_weights=nn_model.get_params(),\n",
+ " probabilistic_layer_names=probabilistic_layers,\n",
+ " probabilistic_neurons=prob_indices\n",
+ ")\n",
+ "\n",
+ "pbnn_model.fit(\n",
+ " train_data[\"input_ids\"],\n",
+ " train_data['solubility'],\n",
+ " num_warmup=1000,\n",
+ " num_samples=1000,\n",
+ ")"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "ZOaFYIy9-Hz6",
+ "outputId": "94cab34e-e94c-466a-aded-996508e13471"
+ },
+ "execution_count": 8,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stderr",
+ "text": [
+ "sample: 100%|██████████| 2000/2000 [09:50<00:00, 3.39it/s, 127 steps of size 2.92e-02. acc. prob=0.89]\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# Get predictions on test set\n",
+ "predictions, variance = pbnn_model.predict(test_data['input_ids'])\n",
+ "predictions = predictions.squeeze()\n",
+ "variance = variance.squeeze()\n",
+ "\n",
+ "# Compute RMSE and R^2\n",
+ "test_rmse_fin = np.sqrt(np.mean((test_data['solubility'] - predictions) ** 2))\n",
+ "test_r2_fin = r2_score(test_data['solubility'], predictions)\n",
+ "print(\"\\nModel Performance:\")\n",
+ "print(f\"Test RMSE: {test_rmse_fin:.4f}\")\n",
+ "print(f\"Test R²: {test_r2_fin:.4f}\")\n",
+ "\n",
+ "\n",
+ "# Compute negative log predictive density and coverage probability\n",
+ "test_nlpd_fin = nb.utils.nlpd(test_data['solubility'], predictions, variance)\n",
+ "test_coverage_fin = nb.utils.coverage(test_data['solubility'], predictions, variance)\n",
+ "print(f\"NLPD: {test_nlpd_fin:.4f}\")\n",
+ "print(f\"Coverage probability: {test_coverage_fin:.4f}\")"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "U__IKNPsf0JB",
+ "outputId": "217198be-1650-409f-e675-c87a889a6cf1"
+ },
+ "execution_count": 9,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "\n",
+ "Model Performance:\n",
+ "Test RMSE: 0.5389\n",
+ "Test R²: 0.8920\n",
+ "NLPD: 0.8323\n",
+ "Coverage probability: 0.9149\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Now let's make one of the attention layers partially probabilistic. We will also apply 'data thinning', i.e. select only a subset of training data points, to speed up the computations:"
+ ],
+ "metadata": {
+ "id": "scS4sMXzd4Ek"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "probabilistic_layers = ['Block0_Attention']\n",
+ "\n",
+ "prob_indices = nb.select_bayesian_components(\n",
+ " nn_model,\n",
+ " probabilistic_layers,\n",
+ " method='variance',\n",
+ " num_pairs_per_layer=32)\n",
+ "\n",
+ "pbnn_model = nb.PartialBayesianTransformer(\n",
+ " nn_model.model,\n",
+ " deterministic_weights=nn_model.get_params(),\n",
+ " probabilistic_layer_names=probabilistic_layers,\n",
+ " probabilistic_neurons=prob_indices\n",
+ ")\n",
+ "\n",
+ "pbnn_model.fit(\n",
+ " train_data[\"input_ids\"][::10], # data thinning\n",
+ " train_data['solubility'][::10],\n",
+ " num_warmup=1000,\n",
+ " num_samples=1000,\n",
+ ")"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "M-_r3wPld9wA",
+ "outputId": "4fa209fb-545c-4cc0-9d38-32868c5af755"
+ },
+ "execution_count": 10,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stderr",
+ "text": [
+ "sample: 100%|██████████| 2000/2000 [05:10<00:00, 6.43it/s, 127 steps of size 3.00e-02. acc. prob=0.94]\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Get predictive metrics:"
+ ],
+ "metadata": {
+ "id": "anHmlcSIet6Z"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# Get predictions on test set\n",
+ "predictions, variance = pbnn_model.predict(test_data['input_ids'])\n",
+ "predictions = predictions.squeeze()\n",
+ "variance = variance.squeeze()\n",
+ "\n",
+ "# Compute RMSE and R^2\n",
+ "test_rmse_att = np.sqrt(np.mean((test_data['solubility'] - predictions) ** 2))\n",
+ "test_r2_att = r2_score(test_data['solubility'], predictions)\n",
+ "print(\"\\nModel Performance:\")\n",
+ "print(f\"Test RMSE: {test_rmse_att:.4f}\")\n",
+ "print(f\"Test R²: {test_r2_att:.4f}\")\n",
+ "\n",
+ "\n",
+ "# Compute negative log predictive density and coverage probability\n",
+ "test_nlpd_att = nb.utils.nlpd(test_data['solubility'], predictions, variance)\n",
+ "test_coverage_att = nb.utils.coverage(test_data['solubility'], predictions, variance)\n",
+ "print(f\"NLPD: {test_nlpd_att:.4f}\")\n",
+ "print(f\"Coverage probability: {test_coverage_att:.4f}\")"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "FFrNKZzqd9yf",
+ "outputId": "a2d39157-75b6-4a85-a0d0-78115a3a21fc"
+ },
+ "execution_count": 11,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "\n",
+ "Model Performance:\n",
+ "Test RMSE: 0.5456\n",
+ "Test R²: 0.8892\n",
+ "NLPD: 0.8067\n",
+ "Coverage probability: 0.9291\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Finally, let's compare predictive metrics for the three cases, as well as against a Gaussian process baseline (computed in this separate [notebook](https://github.com/ziatdinovmax/NeuroBayes/blob/main/examples/GPyTorch_esol.ipynb)):"
+ ],
+ "metadata": {
+ "id": "gUHLNSV0d90q"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# GP results\n",
+ "test_rmse_gp = 0.6887\n",
+ "test_nlpd_gp = 0.9575\n",
+ "test_coverage_gp = 0.9291\n",
+ "\n",
+ "# Labels for the bars\n",
+ "labels = [\"Partially Bayesian\\nToken Embedding\",\n",
+ " \"Partially Bayesian\\nFinal Layer\",\n",
+ " \"Partially Bayesian\\nAttention\",\n",
+ " \"Gaussian\\nProcess\"]\n",
+ "\n",
+ "# RMSE, NLPD, and Coverage values\n",
+ "rmse_values = [test_rmse_emb, test_rmse_fin, test_rmse_att, test_rmse_gp]\n",
+ "nlpd_values = [test_nlpd_emb, test_nlpd_fin, test_nlpd_att, test_nlpd_gp]\n",
+ "coverage_values = [test_coverage_emb, test_coverage_fin, test_coverage_att, test_coverage_gp]\n",
+ "\n",
+ "# Colors for each model\n",
+ "colors = ['green', 'red', 'blue', 'purple']\n",
+ "\n",
+ "# Create figure and axes with more width for the labels\n",
+ "plt.figure(figsize=(20, 6))\n",
+ "_, (ax0, ax1, ax2) = plt.subplots(1, 3, figsize=(20, 6))\n",
+ "\n",
+ "# Function to create a bar plot with value labels on top of each bar\n",
+ "def create_bar_plot(ax, values, title, ylabel, colors, labels):\n",
+ " bars = ax.bar(np.arange(len(values)), values, color=colors)\n",
+ " ax.set_xticks(np.arange(len(values)))\n",
+ " ax.set_xticklabels(labels, rotation=45, ha='right')\n",
+ " ax.set_ylabel(ylabel)\n",
+ " ax.set_title(title)\n",
+ " # Add value labels on top of bars\n",
+ " for bar in bars:\n",
+ " height = bar.get_height()\n",
+ " ax.text(bar.get_x() + bar.get_width()/2., height,\n",
+ " f'{height:.3f}',\n",
+ " ha='center', va='bottom')\n",
+ " return bars\n",
+ "\n",
+ "# Create the three plots\n",
+ "create_bar_plot(ax0, rmse_values,\n",
+ " \"Root Mean Squared Error\\n(lower is better)\",\n",
+ " \"Root Mean Squared Error\",\n",
+ " colors, labels)\n",
+ "\n",
+ "create_bar_plot(ax1, nlpd_values,\n",
+ " \"Negative Log Predictive Density\\n(lower is better)\",\n",
+ " \"Negative Log Predictive Density\",\n",
+ " colors, labels)\n",
+ "\n",
+ "create_bar_plot(ax2, coverage_values,\n",
+ " \"Coverage Probability\\n(higher is better)\",\n",
+ " \"Coverage Probability\",\n",
+ " colors, labels)\n",
+ "\n",
+ "# Set y-axis limits for NLPD and Coverage as in original\n",
+ "ax1.set_ylim(0.7, 1.1)\n",
+ "ax2.set_ylim(0.7, 1.0)\n",
+ "\n",
+ "# Adjust layout to prevent label overlap\n",
+ "plt.tight_layout()\n",
+ "plt.show()"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 464
+ },
+ "id": "PPfJvSQymyBL",
+ "outputId": "fe6bd733-f99d-4613-bd17-dc1543c8cfb7"
+ },
+ "execution_count": 12,
+ "outputs": [
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {}
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ ""
+ ],
+ "image/png": "\n"
+ },
+ "metadata": {}
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [],
+ "metadata": {
+ "id": "oV6FH65VvPYC"
+ },
+ "execution_count": 12,
+ "outputs": []
+ }
+ ]
+}
\ No newline at end of file