
Commit 74ecd0d

fix: use jit for analog hamiltonian example
1 parent 5d15aac commit 74ecd0d

1 file changed: +33 -36 lines changed

examples/analog_hamiltonian_simulation/06_Analog_Hamiltonian_simulation_with_PennyLane.ipynb (+33 -36)
@@ -142,7 +142,7 @@
     "Following the convention of PennyLane, we have used megahertz (MHz) as the unit for the amplitude and detuning. Also, we have limited the maximum duration of the program, `time_max`, to be 4$\\mu s$, and introduced the Rydberg interaction constant $C_6$. These constants are specific to QuEra's Aquila, and they can be queried from Braket SDK, as shown in [this AHS example notebook](https://github.com/aws/amazon-braket-examples/blob/main/examples/analog_hamiltonian_simulation/01_Introduction_to_Aquila.ipynb). With these constants, the differentiable driving fields for the Rydberg atom device can be defined via [PennyLane's JAX interface](https://pennylane.ai/qml/demos/tutorial_jax_transformations)."
    ]
   },
-  {
+  {
    "cell_type": "code",
    "execution_count": 4,
    "id": "7ec45e20-ed22-4f1d-9540-db456794e5ed",
@@ -153,18 +153,19 @@
    "source": [
     "import jax\n",
     "import jax.numpy as jnp\n",
+    "from functools import partial\n",
     "\n",
     "# Set to float64 precision and remove jax CPU/GPU warning\n",
     "jax.config.update(\"jax_enable_x64\", True)\n",
     "jax.config.update(\"jax_platform_name\", \"cpu\")\n",
     "jax.config.update(\"jax_debug_nans\", True)\n",
     "\n",
-    "\n",
+    "@jax.jit\n",
     "def ryd_amplitude(p, t):\n",
     "    \"\"\"Parametrized function for amplitude\"\"\"\n",
     "    return p[0] * (1 - (1 - jnp.sin(jnp.pi * t / time_max) ** 2) ** (p[1] / 2))\n",
     "\n",
-    "\n",
+    "@jax.jit\n",
     "def ryd_detuning(p, t):\n",
     "    \"\"\"Parametrized function for detuning\"\"\"\n",
     "    return p[0] * jnp.arctan(p[1] * (t - time_max / 2)) / (jnp.pi / 2)"
@@ -416,7 +417,7 @@
     "With the cost Hamiltonian defined, we can create a variational AHS program, which evolves the Rydberg system followed by measuring the cost function $H_\\text{cost}$.\n"
    ]
   },
-  {
+  {
    "cell_type": "code",
    "execution_count": 12,
    "id": "94cb8760-1ee5-4ed3-8f3d-cd64f6103e12",
@@ -429,9 +430,9 @@
     "dev = qml.device(\"default.qubit.jax\", wires=range(len(coords)))\n",
     "\n",
     "\n",
-    "# Define the qnode that evolves the Rydberg system, followed by calculating the cost function\n",
+    "@partial(jax.jit, static_argnums=(1,))\n",
     "@qml.qnode(dev, interface=\"jax\")\n",
-    "def program_cost(detuning_param, amplitude_param=amplitude_param, ts=ts):\n",
+    "def program_cost(detuning_param, amplitude_param=amplitude_param):\n",
     "    qml.evolve(H_ryd)([amplitude_param, detuning_param], ts)\n",
     "    return qml.expval(H_cost)"
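The `@partial(jax.jit, static_argnums=(1,))` decorator marks the second argument as a compile-time constant rather than a traced value. A toy illustration of what that means (not the notebook's code; `scaled_power` and its arguments are made up here):

    import jax
    import jax.numpy as jnp
    from functools import partial

    # Hypothetical example: `power` is static, so each new value triggers a fresh trace/compile.
    @partial(jax.jit, static_argnums=(1,))
    def scaled_power(x, power):
        return jnp.sum(x ** power)

    x = jnp.arange(4.0)
    print(scaled_power(x, 2))  # compiled for power=2
    print(scaled_power(x, 3))  # recompiled for power=3; the power=2 version stays cached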
@@ -511,14 +512,10 @@
     "import optax\n",
     "\n",
     "n_epochs = 10\n",
-    "\n",
-    "# The following block creates a constant schedule of the learning rate\n",
-    "# that increases from 0.1 to 0.5 after 10 epochs\n",
-    "schedule0 = optax.constant_schedule(1e-1)\n",
-    "schedule1 = optax.constant_schedule(5e-1)\n",
-    "schedule = optax.join_schedules([schedule0, schedule1], [10])\n",
-    "optimizer = optax.nadam(learning_rate=schedule)\n",
-    "opt_state = optimizer.init(theta)"
+    "schedule = optax.join_schedules(\n",
+    "    [optax.constant_schedule(1e-1), optax.constant_schedule(5e-1)], [10]\n",
+    ")\n",
+    "optimizer = optax.nadam(learning_rate=schedule)"
    ]
   },
   {
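The refactored cell keeps the original behaviour: two constant schedules joined at step 10, so the learning rate is 0.1 for the first 10 steps and 0.5 afterwards. A small standalone check, separate from the notebook:

    import optax

    # Constant learning rate of 0.1 until step 10, then 0.5
    schedule = optax.join_schedules(
        [optax.constant_schedule(1e-1), optax.constant_schedule(5e-1)], [10]
    )
    print(schedule(0), schedule(9), schedule(10))  # 0.1, 0.1, 0.5

    optimizer = optax.nadam(learning_rate=schedule)  # the schedule is evaluated at each update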
@@ -542,36 +539,36 @@
      "output_type": "stream",
      "text": [
       "epoch expectation\n",
-      "0 -1.4220786805245293\n",
-      "1 -1.4440255545399934\n",
-      "2 -1.4682669156813517\n",
-      "3 -1.494929936649468\n",
-      "4 -1.52414541861693\n",
-      "5 -1.5561124234490418\n",
-      "6 -1.590274680470023\n",
-      "7 -1.6261105967628833\n",
-      "8 -1.6618436930215126\n",
-      "9 -1.695197967500683\n",
-      "The final parameter for detuning = [12.268508609205842, 0.4353973115231312]\n"
+      "0 -1.422078680524523\n",
+      "1 -1.4552870814551753\n",
+      "2 -1.48565657822032\n",
+      "3 -1.5177182172197772\n",
+      "4 -1.551923494981983\n",
+      "5 -1.5883691608696957\n",
+      "6 -1.626260941139706\n",
+      "7 -1.6636605058546716\n",
+      "8 -1.6978874284223098\n",
+      "9 -1.7255362493082869\n",
+      "The final parameter for detuning = [12.208998163228085, 0.32605421934464207]\n"
      ]
     }
    ],
    "source": [
-    "energy = np.zeros(n_epochs + 1)\n",
-    "energy[0] = program_cost(theta)\n",
-    "gradients = np.zeros(n_epochs)\n",
+    "@jax.jit\n",
+    "def optimization_step(theta, opt_state):\n",
+    "    val, grad = jax.value_and_grad(program_cost)(theta)\n",
+    "    updates, opt_state = optimizer.update(grad, opt_state)\n",
+    "    theta = optax.apply_updates(theta, updates)\n",
+    "    return val, grad, theta, opt_state\n",
+    "\n",
+    "theta = detuning_param\n",
+    "opt_state = optimizer.init(theta)\n",
     "\n",
     "## Optimization loop\n",
     "print(\"epoch expectation\")\n",
     "for n in range(n_epochs):\n",
-    "    val, grad_program = value_and_grad(theta)\n",
-    "    updates, opt_state = optimizer.update(grad_program, opt_state)\n",
-    "    theta = optax.apply_updates(theta, updates)\n",
-    "\n",
-    "    energy[n + 1] = val\n",
-    "    gradients[n] = np.mean(np.abs(grad_program))\n",
-    "\n",
-    "    print(n, \" \", val)\n",
+    "    val, grad, theta, opt_state = optimization_step(theta, opt_state)\n",
+    "    print(f\"{n} {val}\")\n",
     "\n",
     "print(f\"The final parameter for detuning = {[float(i) for i in theta]}\")"
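The new training loop folds the gradient computation and the optax update into a single jitted step, which is the main source of the speed-up. The pattern can be sketched on its own with a stand-in cost function (the quadratic `toy_cost` below is made up; the notebook uses `program_cost`):

    import jax
    import jax.numpy as jnp
    import optax

    def toy_cost(theta):
        # Stand-in for the notebook's program_cost
        return jnp.sum((theta - 1.0) ** 2)

    optimizer = optax.nadam(learning_rate=1e-1)

    @jax.jit
    def optimization_step(theta, opt_state):
        # value_and_grad returns the cost and its gradient in one pass;
        # the whole step (gradient + optimizer update) is compiled together.
        val, grad = jax.value_and_grad(toy_cost)(theta)
        updates, opt_state = optimizer.update(grad, opt_state)
        theta = optax.apply_updates(theta, updates)
        return val, grad, theta, opt_state

    theta = jnp.array([12.0, 0.5])   # arbitrary starting point
    opt_state = optimizer.init(theta)
    for n in range(10):
        val, grad, theta, opt_state = optimization_step(theta, opt_state)
        print(n, float(val))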
