From 2454eb75d1627076b7bb29fd33e8d959a9784252 Mon Sep 17 00:00:00 2001 From: "baptiste.bellot" Date: Tue, 4 Feb 2025 18:09:32 +0100 Subject: [PATCH] feat: Sparse State Preparation Signed-off-by: baptiste.bellot --- .../sparse_state_preparation.ipynb | 566 ++++++++++++++++++ 1 file changed, 566 insertions(+) create mode 100644 community/paper_implementation/sparse_state_preparation/sparse_state_preparation.ipynb diff --git a/community/paper_implementation/sparse_state_preparation/sparse_state_preparation.ipynb b/community/paper_implementation/sparse_state_preparation/sparse_state_preparation.ipynb new file mode 100644 index 00000000..cb0b7742 --- /dev/null +++ b/community/paper_implementation/sparse_state_preparation/sparse_state_preparation.ipynb @@ -0,0 +1,566 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# An Efficient Algorithm for Sparse Quantum State Preparation\n", + "\n", + "Implementation of the [paper from Niels Gleinig & Torsten Hoefler](https://htor.inf.ethz.ch/publications/img/quantum_dac.pdf) using Classiq's python SDK." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "# Classiq related (algorithmic part)\n", + "from classiq import *\n", + "import numpy as np\n", + "from typing import List, Tuple\n", + "from classiq.qmod import control, unitary\n", + "from classiq.execution import ClassiqBackendPreferences, ExecutionPreferences\n", + "from classiq.synthesis import set_execution_preferences, SerializedQuantumProgram\n", + "\n", + "\n", + "# Qiskit related (end of computation)\n", + "from qiskit import QuantumCircuit\n", + "from qiskit.circuit.controlledgate import ControlledGate\n", + "from qiskit.quantum_info import Operator\n", + "from qiskit.circuit.library import *\n", + "from math import log2, ceil\n", + "from qiskit.qasm2 import dumps\n", + "import qiskit_aer\n", + "import qiskit" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "def gate_matrix(alpha: float, beta: float) -> List[float]:\n", + " \"\"\" Implementation of the gate G as described in the paper (II.A)\n", + "\n", + " Args:\n", + " alpha (float): Probability of ket 0\n", + " beta (float): Probability of ket 1\n", + "\n", + " Returns:\n", + " List[float]: List representing the G matrix\n", + " \"\"\"\n", + " alpha = np.sqrt(alpha)\n", + " beta = np.sqrt(beta)\n", + " a11 = np.sin(beta)\n", + " a12 = np.exp(1j * alpha) * np.cos(beta)\n", + " a21 = np.exp(- 1j * alpha) * np.cos(beta)\n", + " a22 = - np.sin(beta)\n", + "\n", + " return np.array([[a11, a12], [a21, a22]]).tolist()" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "def state_to_bitwise(state: List[float]) -> List[List[int]]:\n", + " \"\"\" Converts an array of coefficients into a set of bitwise indexes, removing empty ones\n", + "\n", + " Args:\n", + " state (List[float]): State to convert\n", + "\n", + " Returns:\n", + " List[List[int]]: List of bitwise indexes\n", + " \"\"\"\n", + " size = 1\n", + " length = 2\n", + " state_len = len(state)\n", + " while state_len > length:\n", + " size += 1\n", + " length *= 2\n", + " result = []\n", + " for i in range(state_len):\n", + " if state[i] != 0:\n", + " bitwise = []\n", + " while i > 0:\n", + " bitwise.append(i%2)\n", + " i = i//2\n", + " while len(bitwise) < size:\n", + " bitwise.append(0)\n", + " bitwise.reverse()\n", + " result.append(bitwise)\n", + " return result" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "<>:2: SyntaxWarning: invalid escape sequence '\\i'\n", + "<>:2: SyntaxWarning: invalid escape sequence '\\i'\n", + "/tmp/ipykernel_176095/2247158695.py:2: SyntaxWarning: invalid escape sequence '\\i'\n", + " \"\"\" Finds a bit such that it splits T into two sets as unequal as possible but neither are empty and:\n" + ] + } + ], + "source": [ + "def optimal_split(T: List[str]) -> Tuple[int, List[str], List[str]]:\n", + " \"\"\" Finds a bit such that it splits T into two sets as unequal as possible but neither are empty and:\n", + " .. math::\n", + " b \\in \\{1, 2, ..., n\\}\n", + " T_0 := \\{x \\in T | x[b] == 0\\}\n", + " T_1 := \\{x \\in T | x[b] == 1\\}\n", + " \n", + " Args:\n", + " T (List[str]): Set T\n", + "\n", + " Raises:\n", + " IndexError: State should not be empy\n", + " RuntimeError: No split for the given state\n", + "\n", + " Returns:\n", + " Tuple[int, List[str], List[str]]: b, T0, T1 as desribed above\n", + " \"\"\"\n", + " if T == []:\n", + " raise IndexError(\"State should not be empy\")\n", + " append_T0 = False\n", + " append_T1 = False\n", + " for bit_nb in range(len(T[0])):\n", + " T0, T1 = [], []\n", + " for state in T:\n", + " if state[bit_nb] == 0:\n", + " T0.append(state)\n", + " append_T0 = True\n", + " else:\n", + " T1.append(state)\n", + " append_T1 = True\n", + " if append_T0 and append_T1:\n", + " return bit_nb, T0, T1\n", + " raise RuntimeError(\"No split for the given state\")" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "def bitwise_to_int(bitwise : List[int]) -> int:\n", + " \"\"\" Converts a list of bits to its decimal representation\n", + "\n", + " Args:\n", + " bitwise (List[int]): List of bits\n", + "\n", + " Returns:\n", + " int: Decimal representation of bitwise\n", + " \"\"\"\n", + " res = 0\n", + " for i in bitwise:\n", + " res *= 2\n", + " res += i\n", + " return res" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "def build_T_prime(T: List[str], dif_qubits: List[int], dif_values: List[int]) -> List[str]:\n", + " \"\"\" Builds T' according to the paper:\n", + " Let T' ⊂ S denote the set of strings that have the values in dif_values on the bits dif_qubits;\n", + "\n", + " Args:\n", + " T (List[str]): State s represented in bits\n", + " dif_qubits (List[int]): Stack of bits b ∈ {1, 2, . . . , n} that will hold in the end the bits that we use as control for the “merging” step\n", + " dif_values (List[int]): Stack of boolean values\n", + "\n", + " Returns:\n", + " List[str]: List of bits T' as described above\n", + " \"\"\"\n", + " T_prime = []\n", + " for state in T:\n", + " matches = True\n", + " for qubit, value in zip(dif_qubits, dif_values):\n", + " matches = int(state[qubit]) == value\n", + " if not matches:\n", + " break\n", + " if matches:\n", + " T_prime.append(state)\n", + " return T_prime" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "def translate_circuit(full_circuit: list[QuantumProgram], state: List[int]) -> QuantumCircuit:\n", + " \"\"\" Translates the circuit from a list of QuantumProgram of Classiq to a QuantumCircuit in Qiskit.\n", + " It reverses the order of gates while inversing them and adding the required X gates (from Algorithm 2)\n", + "\n", + " Args:\n", + " full_circuit (list[QuantumProgram]): List of Classiq QuantumPrograms to concatenate\n", + " state (List[int]): List of amplitudes required for the state\n", + "\n", + " Returns:\n", + " QuantumCircuit: Final Qiskit circuit for sparse state preparation\n", + " \"\"\"\n", + "\n", + " nb_qubits : int = ceil(log2(len(state)))\n", + " gates = []\n", + "\n", + " for circ in full_circuit:\n", + "\n", + " for gate in circ.debug_info:\n", + " if \"XGate\" in gate.name:\n", + " target = list(map(lambda x: nb_qubits - 1 - x, next(reg for reg in gate.registers if reg.name == \"TARGET\").qubit_indexes_absolute))\n", + " gates.append((XGate(), target))\n", + "\n", + " for child in gate.children:\n", + " if child.generated_function.name == \"control\":\n", + " ctrl = list(map(lambda x: nb_qubits - 1 - x, next(reg for reg in child.registers if reg.name == \"control_group\").qubit_indexes_absolute))\n", + " target = list(map(lambda x: nb_qubits - 1 - x, next(reg for reg in child.registers if reg.name == \"TARGET\").qubit_indexes_absolute))\n", + " strip_and_replace = lambda x: complex(x.replace(\" \", \"\").replace(\"*I\", \"j\").strip(']['))\n", + " i = np.array(list(map(strip_and_replace, child.parameters[0].value.strip('][').split(',')))).reshape(2, 2).tolist()\n", + "\n", + " UC = UnitaryGate(i).control(1)\n", + " gates.append((UC, [ctrl, target]))\n", + "\n", + " elif \"CXGate\" in child.generated_function.name:\n", + " ctrl = list(map(lambda x: nb_qubits - 1 - x, next(reg for reg in child.registers if reg.name == \"CTRL\").qubit_indexes_absolute))\n", + " target = list(map(lambda x: nb_qubits - 1 - x, next(reg for reg in child.registers if reg.name == \"TARGET\").qubit_indexes_absolute))\n", + "\n", + " gates.append((CXGate(), [ctrl, target]))\n", + "\n", + " T = state_to_bitwise(state)\n", + " for i in range(len(T[0])):\n", + " if T[0][i] == 1:\n", + " gates.append((XGate(), [i]))\n", + "\n", + " # reverse QC and inverse gates\n", + " rev_qc = QuantumCircuit(nb_qubits)\n", + "\n", + " for i in range(len(gates) - 1, -1, -1):\n", + " gate, trgts = gates[i]\n", + " if isinstance(gate, ControlledGate) and not isinstance(gate, CXGate):\n", + " rev_qc.append(Operator(gate.inverse()), trgts)\n", + " else:\n", + " rev_qc.append(gate.inverse(), trgts)\n", + "\n", + " return rev_qc" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "def algo1_classic_part(state : list[float]) -> Tuple[int, list[int], list[int], list[int]]:\n", + " \"\"\" Classical part of the algorithm (not requiring any Classiq) for easier understanding\n", + " Basically lines 1 to 28 of Algorithm 1\n", + " Called by main\n", + "\n", + " Args:\n", + " state (list[float]): State we want to prepare\n", + "\n", + " Returns:\n", + " Tuple[int, list[int], list[int], list[int]]: dif, x1, x2, dif_qubits as they are needed for the quantum part of the algorithm\n", + " \"\"\"\n", + " dif_qubits = []\n", + " dif_values = []\n", + " T = state_to_bitwise(state)\n", + " T_copy = T.copy()\n", + " while len(T) > 1:\n", + " bit_nb, T0, T1 = optimal_split(T)\n", + " dif_qubits.append(bit_nb)\n", + " if (len(T0) < len(T1)):\n", + " T = T0\n", + " dif_values.append(0)\n", + " else:\n", + " T = T1\n", + " dif_values.append(1)\n", + " \n", + " dif = dif_qubits.pop()\n", + " dif_values.pop()\n", + " x1 = T[0]\n", + " T_prime = build_T_prime(T_copy, dif_qubits, dif_values)\n", + " T_prime.remove(x1)\n", + "\n", + " while len(T_prime) > 1:\n", + " bit_nb, T0, T1 = optimal_split(T_prime)\n", + " dif_qubits.append(bit_nb)\n", + " if (len(T0) < len(T1)):\n", + " T_prime = T0\n", + " dif_values.append(0)\n", + " else:\n", + " T_prime = T1\n", + " dif_values.append(1)\n", + " \n", + " x2 = T_prime[0]\n", + " return (dif, x1, x2, dif_qubits)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "<>:3: SyntaxWarning: invalid escape sequence '\\ '\n", + "<>:3: SyntaxWarning: invalid escape sequence '\\ '\n", + "/tmp/ipykernel_176095/3025597175.py:3: SyntaxWarning: invalid escape sequence '\\ '\n", + " \"\"\" Quantum part of Algorithm 1\n" + ] + } + ], + "source": [ + "@qfunc(generative=True)\n", + "def algo1(quantum_circuit : QArray[QBit], dif : int, x1 : list[int], x2 : list[int], dif_qubits : list[int], state : list[float]):\n", + " \"\"\" Quantum part of Algorithm 1\n", + " Called by main\n", + "\n", + " Args:\n", + " quantum_circuit (QArray[QBit]): Quantum circuit as required by Classiq\n", + " dif (int): last value appended to dif_qubit\n", + " x1 (list[int]): Single element of T\n", + " x2 (list[int]): Single element of T'\n", + " dif_qubits (list[int]): Stack of bits\n", + " state (list[float]): State to prepare /!\\ Global variable\n", + " \"\"\"\n", + "\n", + " qdif=quantum_circuit[dif]\n", + " size_n = len(x1)\n", + "\n", + " if x1[dif] != 1:\n", + " X(qdif)\n", + "\n", + " for b in range(size_n):\n", + " if b != dif and x1[b] != x2[b]:\n", + " CX(ctrl=qdif, target=quantum_circuit[b])\n", + "\n", + " for b in dif_qubits:\n", + " if x2[b] != 1:\n", + " X(quantum_circuit[b])\n", + "\n", + " if dif == 0:\n", + " target_qbit : QArray[QBit] = QArray(\"traget_qbit\", length=1)\n", + " control_group : QArray[QBit] = QArray(\"control_group\",length=size_n-1)\n", + " bind(quantum_circuit, [target_qbit, control_group])\n", + "\n", + " control(ctrl=control_group, stmt_block=lambda: unitary(gate_matrix(state[bitwise_to_int(x1)], state[bitwise_to_int(x2)]), target=target_qbit))\n", + "\n", + " bind([target_qbit, control_group], quantum_circuit)\n", + " elif dif == size_n - 1:\n", + " target_qbit : QArray[QBit] = QArray(\"traget_qbit\", length=1)\n", + " control_group : QArray[QBit] = QArray(\"control_group\",length=size_n-1)\n", + " bind(quantum_circuit, [control_group, target_qbit])\n", + "\n", + " control(ctrl=control_group, stmt_block=lambda: unitary(gate_matrix(state[bitwise_to_int(x1)], state[bitwise_to_int(x2)]), target=target_qbit))\n", + "\n", + " bind([control_group, target_qbit], quantum_circuit)\n", + " else:\n", + " target_qbit : QArray[QBit] = QArray(\"traget_qbit\", length=1)\n", + " before_target : QArray[QBit] = QArray(\"before\",length=max(0,dif-1))\n", + " after_target : QArray[QBit] = QArray(\"after\", length=size_n-dif-1)\n", + " bind(quantum_circuit, [before_target, target_qbit, after_target])\n", + " control_group : QArray[QBit] = QArray(\"control_group\",length=size_n-1)\n", + " bind([before_target,after_target], control_group)\n", + "\n", + " control(ctrl=control_group, stmt_block=lambda: unitary(gate_matrix(state[bitwise_to_int(x1)], state[bitwise_to_int(x2)]), target=target_qbit))\n", + "\n", + " bind(control_group, [before_target, after_target])\n", + " bind([before_target, target_qbit, after_target], quantum_circuit)" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "@qfunc\n", + "def main(quantum_circuit: Output[QArray[QBit]]) -> None:\n", + " \"\"\" Basically Algorithm 2 but it **needs** to be called main\n", + " The while loop will call this every iteration\n", + "\n", + " Args:\n", + " quantum_circuit (Output[QArray[QBit]]): Quantum circuit on which the gates will be added\n", + " \"\"\"\n", + " T = state_to_bitwise(state)\n", + " size_n = len(T[0])\n", + " allocate(size_n, quantum_circuit)\n", + " if len(T)>1:\n", + " dif, x1, x2, dif_qubits = algo1_classic_part(state)\n", + " algo1(quantum_circuit, dif, x1, x2, dif_qubits, state)\n", + " kept_state = bitwise_to_int(x1)\n", + " merged_state = bitwise_to_int(x2)\n", + " state[kept_state] += state[merged_state]\n", + " state[merged_state] = 0." + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [], + "source": [ + "def algo1_iter() -> SerializedQuantumProgram :\n", + " \"\"\" Driver code for an iteration of algo1 to leverage Classiq\n", + "\n", + " Returns:\n", + " SerializedQuantumProgram: Synthetized program from Classiq\n", + " \"\"\"\n", + " quantum_program = create_model(main)\n", + " backend_preferences = ClassiqBackendPreferences(backend_name=\"simulator_statevector\")\n", + " qmod_b_load = set_execution_preferences(\n", + " quantum_program,\n", + " execution_preferences=ExecutionPreferences(\n", + " num_shots=1, backend_preferences=backend_preferences\n", + " ),\n", + " )\n", + " qprog_b_load = synthesize(qmod_b_load)\n", + " return qprog_b_load" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# State to prepare" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [], + "source": [ + "# BELL STATE\n", + "# state = [0.5, 0, 0, 0.5]\n", + "# GHZ STATE 4 Qbits\n", + "# state : List[float] = [0.5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.5]\n", + "state = [0 for _ in range(8)]\n", + "state[1] = 2 / np.sqrt(168)\n", + "state[4] = 8 / np.sqrt(168)\n", + "state[7] = 10 / np.sqrt(168)" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [], + "source": [ + "T = state_to_bitwise(state)\n", + "full_circuit : List[QuantumProgram] = []\n", + "while len(T) > 1:\n", + " qprog_b_load = algo1_iter()\n", + " full_circuit.insert(0, QuantumProgram.from_qprog(qprog_b_load))\n", + " T = state_to_bitwise(state)\n", + "qprog_b_load = algo1_iter()\n", + "full_circuit.append(QuantumProgram.from_qprog(qprog_b_load))" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'100': 1899, '110': 4008, '001': 4093}\n" + ] + }, + { + "data": { + "text/html": [ + "
           ┌───┐    ┌──────────┐                 ┌──────────┐┌───┐ ░ ┌─┐      \n",
+       "   q_0: ───┤ X ├────┤0         ├─────────────────┤0         ├┤ X ├─░─┤M├──────\n",
+       "        ┌──┴───┴───┐│          │┌───┐┌──────────┐│          │└─┬─┘ ░ └╥┘┌─┐   \n",
+       "   q_1: ┤0         ├┤  Unitary ├┤ X ├┤0         ├┤  Unitary ├──┼───░──╫─┤M├───\n",
+       "        │  Unitary ││          │└─┬─┘│  Unitary ││          │  │   ░  ║ └╥┘┌─┐\n",
+       "   q_2: ┤1         ├┤1         ├──■──┤1         ├┤1         ├──■───░──╫──╫─┤M├\n",
+       "        └──────────┘└──────────┘     └──────────┘└──────────┘      ░  ║  ║ └╥┘\n",
+       "meas: 3/══════════════════════════════════════════════════════════════╩══╩══╩═\n",
+       "                                                                      0  1  2 
" + ], + "text/plain": [ + " ┌───┐ ┌──────────┐ ┌──────────┐┌───┐ ░ ┌─┐ \n", + " q_0: ───┤ X ├────┤0 ├─────────────────┤0 ├┤ X ├─░─┤M├──────\n", + " ┌──┴───┴───┐│ │┌───┐┌──────────┐│ │└─┬─┘ ░ └╥┘┌─┐ \n", + " q_1: ┤0 ├┤ Unitary ├┤ X ├┤0 ├┤ Unitary ├──┼───░──╫─┤M├───\n", + " │ Unitary ││ │└─┬─┘│ Unitary ││ │ │ ░ ║ └╥┘┌─┐\n", + " q_2: ┤1 ├┤1 ├──■──┤1 ├┤1 ├──■───░──╫──╫─┤M├\n", + " └──────────┘└──────────┘ └──────────┘└──────────┘ ░ ║ ║ └╥┘\n", + "meas: 3/══════════════════════════════════════════════════════════════╩══╩══╩═\n", + " 0 1 2 " + ] + }, + "execution_count": 22, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "rev_qc = translate_circuit(full_circuit, state)\n", + "rev_qc.measure_all()\n", + "bc = qiskit_aer.Aer.get_backend(\"aer_simulator\")\n", + "tqc = qiskit.transpile(rev_qc, bc)\n", + "job = bc.run(tqc, shots=10000)\n", + "print(job.result().get_counts())\n", + "rev_qc.draw()" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [], + "source": [ + "qasm_str = dumps(rev_qc)\n", + "with open(\"circuit.qasm\", \"w\") as text_file:\n", + " text_file.write(qasm_str)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv_classiq", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.4" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}