
Commit

last updates/fixes before paper
kmheckel committed Feb 23, 2024
1 parent 97ff1f9 commit c0d43df
Showing 14 changed files with 3,615 additions and 9,165 deletions.
2 changes: 1 addition & 1 deletion docs/conf.py
@@ -14,7 +14,7 @@
project = 'Spyx'
copyright = '2023, Kade Heckel'
author = 'Kade Heckel'
-release = 'v0.1.18'
+release = 'v0.1.19'

# -- General configuration ---------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
22 changes: 14 additions & 8 deletions docs/examples/surrogate_gradient/SurrogateGradientTutorial.ipynb
@@ -153,7 +153,7 @@
},
"outputs": [],
"source": [
"shd_dl = spyx.data.SHD_loader(256,128,128)"
"shd_dl = spyx.loaders.SHD_loader(256,128,128)"
]
},
{
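Context for the hunk above: the SHD dataloader moves from spyx.data to spyx.loaders in this release. A minimal usage sketch under that assumption (the explicit submodule import and the meaning of the three positional arguments are not confirmed by this diff):

    # Sketch only: assumes spyx v0.1.19, where the SHD loader lives in spyx.loaders.
    import spyx.loaders

    # Positional arguments copied unchanged from the tutorial cell above.
    shd_dl = spyx.loaders.SHD_loader(256, 128, 128)

    # Later cells iterate the loader per epoch (e.g. dl.test_epoch() in the test hunk further down).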
@@ -259,9 +259,9 @@
"\n",
" # Haiku has the ability to stack multiple layers/recurrent modules as one entity\n",
" core = hk.DeepRNN([\n",
" snn.LIF((64,), activation=spyx.axn.Axon(spyx.axn.triangular())), #LIF neuron layer with triangular activation\n",
" snn.LIF((64,), activation=spyx.axn.triangular()), #LIF neuron layer with triangular activation\n",
" hk.Linear(64, with_bias=False),\n",
" snn.LIF((64,), activation=spyx.axn.Axon(spyx.axn.triangular())),\n",
" snn.LIF((64,), activation=spyx.axn.triangular()),\n",
" hk.Linear(20, with_bias=False),\n",
" snn.LI((20,)) # Non-spiking final layer\n",
" ])\n",
@@ -312,6 +312,9 @@
" # We use optax for our optimizer.\n",
" opt = optax.lion(learning_rate=schedule)\n",
"\n",
" Loss = spyx.fn.integral_crossentropy()\n",
" Acc = spyx.fn.integral_accuracy()\n",
"\n",
" # create and initialize the optimizer\n",
" opt_state = opt.init(params)\n",
" grad_params = params\n",
@@ -321,7 +324,7 @@
" def net_eval(weights, events, targets):\n",
" readout = SNN.apply(weights, events)\n",
" traces, V_f = readout\n",
" return spyx.fn.integral_crossentropy(traces, targets)\n",
" return Loss(traces, targets)\n",
"\n",
" # Use JAX to create a function that calculates the loss and the gradient!\n",
" surrogate_grad = jax.value_and_grad(net_eval)\n",
@@ -357,8 +360,8 @@
" # unpack the final layer outputs and end state of each SNN layer\n",
" traces, V_f = readout\n",
" # compute accuracy, predictions, and loss\n",
" acc, pred = spyx.fn.integral_accuracy(traces, targets)\n",
" loss = spyx.fn.integral_crossentropy(traces, targets)\n",
" acc, pred = Acc(traces, targets)\n",
" loss = Loss(traces, targets)\n",
" # we return the parameters here because of how jax.scan is structured.\n",
" return grad_params, jnp.array([acc, loss])\n",
"\n",
@@ -421,14 +424,17 @@
"source": [
"def test_gd(SNN, params, dl):\n",
"\n",
" Loss = spyx.fn.integral_crossentropy()\n",
" Acc = spyx.fn.integral_accuracy()\n",
"\n",
" @jax.jit\n",
" def test_step(params, data):\n",
" events, targets = data\n",
" events = jnp.unpackbits(events, axis=1)\n",
" readout = SNN.apply(params, events)\n",
" traces, V_f = readout\n",
" acc, pred = spyx.fn.integral_accuracy(traces, targets)\n",
" loss = spyx.fn.integral_crossentropy(traces, targets)\n",
" acc, pred = Acc(traces, targets)\n",
" loss = Loss(traces, targets)\n",
" return params, [acc, loss, pred, targets]\n",
"\n",
" test_data = dl.test_epoch()\n",
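The test_gd hunk follows the same pattern on the evaluation side: Loss and Acc are instantiated once from the factories and reused inside the jitted test_step, which also unpacks the bit-packed spike events. A condensed sketch using the calls visible in the diff; the closing scan over the test set is collapsed in this view, so that line is an assumption consistent with the jax.scan comment in the training hunk:

    import jax
    import jax.numpy as jnp
    import spyx.fn

    def test_gd(SNN, params, dl):
        # Metric helpers are factories now: instantiate once, call on (traces, targets).
        Loss = spyx.fn.integral_crossentropy()
        Acc = spyx.fn.integral_accuracy()

        @jax.jit
        def test_step(params, data):
            events, targets = data
            events = jnp.unpackbits(events, axis=1)   # spike events arrive bit-packed
            traces, V_f = SNN.apply(params, events)
            acc, pred = Acc(traces, targets)
            loss = Loss(traces, targets)
            return params, [acc, loss, pred, targets]

        test_data = dl.test_epoch()
        # Assumed closing step (collapsed in this diff): scan test_step over the test batches.
        _, metrics = jax.lax.scan(test_step, params, test_data)
        return metrics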
6,448 changes: 312 additions & 6,136 deletions docs/examples/surrogate_gradient/shd_sg_neuron_model_comparison.ipynb

Large diffs are not rendered by default.

2,367 changes: 1,996 additions & 371 deletions docs/examples/surrogate_gradient/shd_sg_surrogate_comparison.ipynb

Large diffs are not rendered by default.

