|
142 | 142 | "Following the convention of PennyLane, we have used megahertz (MHz) as the unit for the amplitude and detuning. Also, we have limited the maximum duration of the program, `time_max`, to be 4$\\mu s$, and introduced the Rydberg interaction constant $C_6$. These constants are specific to QuEra's Aquila, and they can be queried from Braket SDK, as shown in [this AHS example notebook](https://github.com/aws/amazon-braket-examples/blob/main/examples/analog_hamiltonian_simulation/01_Introduction_to_Aquila.ipynb). With these constants, the differentiable driving fields for the Rydberg atom device can be defined via [PennyLane's JAX interface](https://pennylane.ai/qml/demos/tutorial_jax_transformations)."
|
143 | 143 | ]
|
144 | 144 | },
|
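As an aside, here is a minimal sketch of how these device constants might be queried; the attribute names follow the linked Braket AHS introduction notebook and should be treated as assumptions to verify against the SDK version in use:

```python
# Hedged sketch: querying Aquila's hardware constants via the Braket SDK.
# Field names follow the linked AHS notebook; verify against your SDK version.
from braket.aws import AwsDevice

aquila = AwsDevice("arn:aws:braket:us-east-1::device/qpu/quera/Aquila")
rydberg = aquila.properties.paradigm.rydberg

C6 = float(rydberg.c6Coefficient)                # rad m^6 / s
time_max = float(rydberg.rydbergGlobal.timeMax)  # s; convert to us for PennyLane
```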
145 |
| - { |
| 145 | + { |
146 | 146 | "cell_type": "code",
|
147 | 147 | "execution_count": 4,
|
148 | 148 | "id": "7ec45e20-ed22-4f1d-9540-db456794e5ed",
|
|
153 | 153 | "source": [
|
154 | 154 | "import jax\n",
|
155 | 155 | "import jax.numpy as jnp\n",
|
| 156 | + "from functools import partial\n", |
156 | 157 | "\n",
|
157 | 158 | "# Set to float64 precision and remove jax CPU/GPU warning\n",
|
158 | 159 | "jax.config.update(\"jax_enable_x64\", True)\n",
|
159 | 160 | "jax.config.update(\"jax_platform_name\", \"cpu\")\n",
|
160 | 161 | "jax.config.update(\"jax_debug_nans\", True)\n",
|
161 | 162 | "\n",
|
162 |
| - "\n", |
| 163 | + "@jax.jit\n", |
163 | 164 | "def ryd_amplitude(p, t):\n",
|
164 | 165 | " \"\"\"Parametrized function for amplitude\"\"\"\n",
|
165 | 166 | " return p[0] * (1 - (1 - jnp.sin(jnp.pi * t / time_max) ** 2) ** (p[1] / 2))\n",
|
166 | 167 | "\n",
|
167 |
| - "\n", |
| 168 | + "@jax.jit\n", |
168 | 169 | "def ryd_detuning(p, t):\n",
|
169 | 170 | " \"\"\"Parametrized function for detuning\"\"\"\n",
|
170 | 171 | " return p[0] * jnp.arctan(p[1] * (t - time_max / 2)) / (jnp.pi / 2)"
|
|
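A quick sanity check of the two envelopes; this is a sketch, with `time_max = 4.0` taken from the text above and illustrative parameter values:

```python
# Hedged sketch: evaluate the parametrized envelopes on a coarse time grid,
# assuming time_max = 4.0 (in microseconds) as stated above.
time_max = 4.0
t_grid = jnp.linspace(0.0, time_max, 5)

# The amplitude vanishes at t = 0 and t = time_max, as the hardware requires.
print(ryd_amplitude(jnp.array([2.5, 4.0]), t_grid))
# The detuning sweeps antisymmetrically around t = time_max / 2.
print(ryd_detuning(jnp.array([12.0, 0.5]), t_grid))
```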
416 | 417 | "With the cost Hamiltonian defined, we can create a variational AHS program, which evolves the Rydberg system followed by measuring the cost function $H_\\text{cost}$.\n"
|
417 | 418 | ]
|
418 | 419 | },
|
419 |
| - { |
| 420 | + { |
420 | 421 | "cell_type": "code",
|
421 | 422 | "execution_count": 12,
|
422 | 423 | "id": "94cb8760-1ee5-4ed3-8f3d-cd64f6103e12",
|
|
429 | 430 | "dev = qml.device(\"default.qubit.jax\", wires=range(len(coords)))\n",
|
430 | 431 | "\n",
|
431 | 432 | "\n",
|
432 |
| - "# Define the qnode that evolves the Rydberg system, followed by calculating the cost function\n", |
| 433 | + "@partial(jax.jit, static_argnums=(1,))\n", |
433 | 434 | "@qml.qnode(dev, interface=\"jax\")\n",
|
434 |
| - "def program_cost(detuning_param, amplitude_param=amplitude_param, ts=ts):\n", |
| 435 | + "def program_cost(detuning_param, amplitude_param=amplitude_param):\n", |
435 | 436 | " qml.evolve(H_ryd)([amplitude_param, detuning_param], ts)\n",
|
436 | 437 | " return qml.expval(H_cost)"
|
437 | 438 | ]
|
|
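Because the qnode goes through the JAX interface, its value and gradient can be obtained directly; a usage sketch, where the initial detuning guess is an illustrative assumption:

```python
# Hedged usage sketch: the jitted qnode returns the cost expectation, and
# jax.grad differentiates it w.r.t. the detuning parameters.
detuning_guess = jnp.array([12.0, 0.5])  # illustrative initial values
print(program_cost(detuning_guess))
print(jax.grad(program_cost)(detuning_guess))
```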
511 | 512 | "import optax\n",
|
512 | 513 | "\n",
|
513 | 514 | "n_epochs = 10\n",
|
514 |
| - "\n", |
515 |
| - "# The following block creates a constant schedule of the learning rate\n", |
516 |
| - "# that increases from 0.1 to 0.5 after 10 epochs\n", |
517 |
| - "schedule0 = optax.constant_schedule(1e-1)\n", |
518 |
| - "schedule1 = optax.constant_schedule(5e-1)\n", |
519 |
| - "schedule = optax.join_schedules([schedule0, schedule1], [10])\n", |
520 |
| - "optimizer = optax.nadam(learning_rate=schedule)\n", |
521 |
| - "opt_state = optimizer.init(theta)" |
| 515 | + "schedule = optax.join_schedules(\n", |
| 516 | + " [optax.constant_schedule(1e-1), optax.constant_schedule(5e-1)], [10]\n", |
| 517 | + ")\n", |
| 518 | + "optimizer = optax.nadam(learning_rate=schedule)" |
522 | 519 | ]
|
523 | 520 | },
|
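The joined schedule holds the learning rate at 0.1 for the first 10 optimizer steps and 0.5 afterwards, which can be verified directly since optax schedules are callables on the step count:

```python
# Minimal check of the joined schedule.
print([float(schedule(step)) for step in (0, 9, 10, 11)])
# Expected: [0.1, 0.1, 0.5, 0.5]
```

Note that with `n_epochs = 10`, only the first stage of the schedule is actually exercised in the loop below.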
524 | 521 | {
|
|
542 | 539 | "output_type": "stream",
|
543 | 540 | "text": [
|
544 | 541 | "epoch expectation\n",
|
545 |
| - "0 -1.4220786805245293\n", |
546 |
| - "1 -1.4440255545399934\n", |
547 |
| - "2 -1.4682669156813517\n", |
548 |
| - "3 -1.494929936649468\n", |
549 |
| - "4 -1.52414541861693\n", |
550 |
| - "5 -1.5561124234490418\n", |
551 |
| - "6 -1.590274680470023\n", |
552 |
| - "7 -1.6261105967628833\n", |
553 |
| - "8 -1.6618436930215126\n", |
554 |
| - "9 -1.695197967500683\n", |
555 |
| - "The final parameter for detuning = [12.268508609205842, 0.4353973115231312]\n" |
| 542 | + "0 -1.422078680524523\n", |
| 543 | + "1 -1.4552870814551753\n", |
| 544 | + "2 -1.48565657822032\n", |
| 545 | + "3 -1.5177182172197772\n", |
| 546 | + "4 -1.551923494981983\n", |
| 547 | + "5 -1.5883691608696957\n", |
| 548 | + "6 -1.626260941139706\n", |
| 549 | + "7 -1.6636605058546716\n", |
| 550 | + "8 -1.6978874284223098\n", |
| 551 | + "9 -1.7255362493082869\n", |
| 552 | + "The final parameter for detuning = [12.208998163228085, 0.32605421934464207]\n" |
556 | 553 | ]
|
557 | 554 | }
|
558 | 555 | ],
|
559 | 556 | "source": [
|
560 |
| - "energy = np.zeros(n_epochs + 1)\n", |
561 |
| - "energy[0] = program_cost(theta)\n", |
562 |
| - "gradients = np.zeros(n_epochs)\n", |
| 557 | + "@jax.jit\n", |
| 558 | + "def optimization_step(theta, opt_state):\n", |
| 559 | + " val, grad = jax.value_and_grad(program_cost)(theta)\n", |
| 560 | + " updates, opt_state = optimizer.update(grad, opt_state)\n", |
| 561 | + " theta = optax.apply_updates(theta, updates)\n", |
| 562 | + " return val, grad, theta, opt_state\n", |
| 563 | + "\n", |
| 564 | + "theta = detuning_param\n", |
| 565 | + "opt_state = optimizer.init(theta)\n", |
563 | 566 | "\n",
|
564 | 567 | "## Optimization loop\n",
|
565 | 568 | "print(\"epoch expectation\")\n",
|
566 | 569 | "for n in range(n_epochs):\n",
|
567 |
| - " val, grad_program = value_and_grad(theta)\n", |
568 |
| - " updates, opt_state = optimizer.update(grad_program, opt_state)\n", |
569 |
| - " theta = optax.apply_updates(theta, updates)\n", |
570 |
| - "\n", |
571 |
| - " energy[n + 1] = val\n", |
572 |
| - " gradients[n] = np.mean(np.abs(grad_program))\n", |
573 |
| - "\n", |
574 |
| - " print(n, \" \", val)\n", |
| 570 | + " val, grad, theta, opt_state = optimization_step(theta, opt_state)\n", |
| 571 | + " print(f\"{n} {val}\")\n", |
575 | 572 | "\n",
|
576 | 573 | "print(f\"The final parameter for detuning = {[float(i) for i in theta]}\")"
|
577 | 574 | ]
|
|
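For reference, the per-epoch bookkeeping kept by the previous revision (energies and mean gradient magnitudes) can still be recorded around the jitted step; a sketch, assuming `numpy` is imported as `np`:

```python
# Hedged sketch: track energy and mean |gradient| per epoch, mirroring the
# arrays removed in this commit.
import numpy as np

energies = np.zeros(n_epochs)
grad_norms = np.zeros(n_epochs)
for n in range(n_epochs):
    val, grad, theta, opt_state = optimization_step(theta, opt_state)
    energies[n] = val
    grad_norms[n] = np.mean(np.abs(grad))
```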