Skip to content

Commit 91c80de

Browse files
Merge pull request #275 from king-p3nguin/qiskit2tq-parameterexpression
2 parents 60ace1e + cd2f3d4 commit 91c80de

File tree

3 files changed

+248
-15
lines changed

3 files changed

+248
-15
lines changed

test/plugin/test_qiskit2tq.py

+174
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
"""
2+
MIT License
3+
4+
Copyright (c) 2020-present TorchQuantum Authors
5+
6+
Permission is hereby granted, free of charge, to any person obtaining a copy
7+
of this software and associated documentation files (the "Software"), to deal
8+
in the Software without restriction, including without limitation the rights
9+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10+
copies of the Software, and to permit persons to whom the Software is
11+
furnished to do so, subject to the following conditions:
12+
13+
The above copyright notice and this permission notice shall be included in all
14+
copies or substantial portions of the Software.
15+
16+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22+
SOFTWARE.
23+
"""
24+
25+
import random
26+
27+
import numpy as np
28+
import pytest
29+
import torch
30+
import torch.optim as optim
31+
from qiskit import QuantumCircuit
32+
from qiskit.circuit import Parameter, ParameterVector
33+
from torch.optim.lr_scheduler import CosineAnnealingLR
34+
35+
import torchquantum as tq
36+
from torchquantum.plugin import qiskit2tq
37+
38+
seed = 42
39+
random.seed(seed)
40+
np.random.seed(seed)
41+
torch.manual_seed(seed)
42+
43+
44+
class TQModel(tq.QuantumModule):
45+
def __init__(self, init_params=None):
46+
super().__init__()
47+
self.n_wires = 2
48+
self.rx = tq.RX(has_params=True, trainable=True, init_params=[init_params[0]])
49+
self.u3_0 = tq.U3(has_params=True, trainable=True, init_params=init_params[1:4])
50+
self.u3_1 = tq.U3(
51+
has_params=True,
52+
trainable=True,
53+
init_params=torch.tensor(
54+
[
55+
init_params[4] + init_params[2],
56+
init_params[5] * init_params[3],
57+
init_params[6] * init_params[1],
58+
]
59+
),
60+
)
61+
self.cu3_0 = tq.CU3(
62+
has_params=True,
63+
trainable=True,
64+
init_params=torch.tensor(
65+
[
66+
torch.sin(init_params[7]),
67+
torch.abs(torch.sin(init_params[8])),
68+
torch.abs(torch.sin(init_params[9]))
69+
* torch.exp(init_params[2] + init_params[3]),
70+
]
71+
),
72+
)
73+
74+
def forward(self, q_device: tq.QuantumDevice):
75+
q_device.reset_states(1)
76+
self.rx(q_device, wires=0)
77+
self.u3_0(q_device, wires=0)
78+
self.u3_1(q_device, wires=1)
79+
self.cu3_0(q_device, wires=[0, 1])
80+
81+
82+
def get_qiskit_ansatz():
83+
ansatz = QuantumCircuit(2)
84+
ansatz_param = Parameter("Θ") # parameter
85+
ansatz.rx(ansatz_param, 0)
86+
ansatz_param_vector = ParameterVector("φ", 9) # parameter vector
87+
ansatz.u(ansatz_param_vector[0], ansatz_param_vector[1], ansatz_param_vector[2], 0)
88+
ansatz.u(
89+
ansatz_param_vector[3] + ansatz_param_vector[1], # parameter expression
90+
ansatz_param_vector[4] * ansatz_param_vector[2],
91+
ansatz_param_vector[5] / ansatz_param_vector[0],
92+
1,
93+
)
94+
ansatz.cu(
95+
np.sin(ansatz_param_vector[6]), # numpy functions
96+
np.abs(np.sin(ansatz_param_vector[7])), # nested numpy functions
97+
# complex expression
98+
np.abs(np.sin(ansatz_param_vector[8]))
99+
* np.exp(ansatz_param_vector[1] + ansatz_param_vector[2]),
100+
0.0,
101+
0,
102+
1,
103+
)
104+
return ansatz
105+
106+
107+
def train_step(target_state, device, model, optimizer):
108+
model(device)
109+
result_state = device.get_states_1d()[0]
110+
111+
# compute the state infidelity
112+
loss = 1 - torch.dot(result_state, target_state).abs() ** 2
113+
114+
optimizer.zero_grad()
115+
loss.backward()
116+
optimizer.step()
117+
118+
infidelity = loss.item()
119+
target_state_vector = target_state.detach().cpu().numpy()
120+
result_state_vector = result_state.detach().cpu().numpy()
121+
print(
122+
f"infidelity (loss): {infidelity}, \n target state : "
123+
f"{target_state_vector}, \n "
124+
f"result state : {result_state_vector}\n"
125+
)
126+
return infidelity, target_state_vector, result_state_vector
127+
128+
129+
def train(init_params, backend):
130+
device = torch.device("cpu")
131+
132+
if backend == "qiskit":
133+
ansatz = get_qiskit_ansatz()
134+
model = qiskit2tq(ansatz, initial_parameters=init_params).to(device)
135+
elif backend == "torchquantum":
136+
model = TQModel(init_params).to(device)
137+
138+
print(f"{backend} model:", model)
139+
140+
n_epochs = 10
141+
optimizer = optim.Adam(model.parameters(), lr=1e-2, weight_decay=0)
142+
scheduler = CosineAnnealingLR(optimizer, T_max=n_epochs)
143+
144+
q_device = tq.QuantumDevice(n_wires=2)
145+
target_state = torch.tensor([0, 1, 0, 0], dtype=torch.complex64)
146+
147+
result_list = []
148+
for epoch in range(1, n_epochs + 1):
149+
print(f"Epoch {epoch}, LR: {optimizer.param_groups[0]['lr']}")
150+
result_list.append(train_step(target_state, q_device, model, optimizer))
151+
scheduler.step()
152+
153+
return result_list
154+
155+
156+
@pytest.mark.parametrize(
157+
"init_params",
158+
[
159+
torch.nn.init.uniform_(torch.ones(10), -np.pi, np.pi),
160+
torch.nn.init.uniform_(torch.ones(10), -np.pi, np.pi),
161+
torch.nn.init.uniform_(torch.ones(10), -np.pi, np.pi),
162+
],
163+
)
164+
def test_qiskit2tq(init_params):
165+
qiskit_result = train(init_params, "qiskit")
166+
tq_result = train(init_params, "torchquantum")
167+
for qi_tensor, tq_tensor in zip(qiskit_result, tq_result):
168+
torch.testing.assert_close(qi_tensor[0], tq_tensor[0])
169+
torch.testing.assert_close(qi_tensor[1], tq_tensor[1])
170+
torch.testing.assert_close(qi_tensor[2], tq_tensor[2])
171+
172+
173+
if __name__ == "__main__":
174+
test_qiskit2tq(torch.nn.init.uniform_(torch.ones(10), -np.pi, np.pi))

torchquantum/layer/layers/module_from_ops.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -22,16 +22,16 @@
2222
SOFTWARE.
2323
"""
2424

25+
from typing import Iterable
26+
27+
import numpy as np
2528
import torch
2629
import torch.nn as nn
30+
from torchpack.utils.logging import logger
31+
2732
import torchquantum as tq
2833
import torchquantum.functional as tqf
29-
import numpy as np
30-
31-
32-
from typing import Iterable
3334
from torchquantum.plugin.qiskit import QISKIT_INCOMPATIBLE_FUNC_NAMES
34-
from torchpack.utils.logging import logger
3535

3636
__all__ = [
3737
"QuantumModuleFromOps",
@@ -61,6 +61,6 @@ def forward(self, q_device: tq.QuantumDevice):
6161
None
6262
6363
"""
64-
self.q_device = q_device
64+
q_device.reset_states(1)
6565
for op in self.ops:
66-
op(q_device)
66+
op(q_device, wires=op.wires)

torchquantum/plugin/qiskit/qiskit_plugin.py

+67-8
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,20 @@
2121
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
2222
SOFTWARE.
2323
"""
24+
from __future__ import annotations
2425

25-
from typing import Iterable
26+
from typing import Iterable, Optional
2627

2728
import numpy as np
2829
import qiskit
2930
import qiskit.circuit.library.standard_gates as qiskit_gate
31+
import symengine
32+
import sympy
3033
import torch
3134
from qiskit import ClassicalRegister, QuantumCircuit, transpile
32-
from qiskit.circuit import Parameter
35+
from qiskit.circuit import CircuitInstruction, Parameter
36+
from qiskit.circuit.parameter import ParameterExpression
37+
from qiskit.circuit.parametervector import ParameterVectorElement
3338
from qiskit_aer import AerSimulator
3439
from torchpack.utils.logging import logger
3540

@@ -691,7 +696,7 @@ def op_history2qiskit_expand_params(n_wires, op_history, bsz):
691696

692697
# construct a tq QuantumModule object according to the qiskit QuantumCircuit
693698
# object
694-
def qiskit2tq_Operator(circ: QuantumCircuit):
699+
def qiskit2tq_Operator(circ: QuantumCircuit, initial_parameters=None):
695700
if getattr(circ, "_layout", None) is not None:
696701
try:
697702
p2v_orig = circ._layout.final_layout.get_physical_bits().copy()
@@ -711,14 +716,23 @@ def qiskit2tq_Operator(circ: QuantumCircuit):
711716
for p in range(circ.num_qubits):
712717
p2v[p] = p
713718

719+
if initial_parameters is None:
720+
initial_parameters = torch.nn.init.uniform_(
721+
torch.ones(len(circ.parameters)), -np.pi, np.pi
722+
)
723+
724+
param_to_index = {}
725+
for i, param in enumerate(circ.parameters):
726+
param_to_index[param] = i
727+
714728
ops = []
715729
for gate in circ.data:
716730
op_name = gate[0].name
717731
wires = [circ.find_bit(qb).index for qb in gate.qubits]
718732
wires = [p2v[wire] for wire in wires]
719-
# sometimes the gate.params is ParameterExpression class
720-
init_params = (
721-
list(map(float, gate[0].params)) if len(gate[0].params) > 0 else None
733+
734+
init_params = qiskit2tq_translate_qiskit_params(
735+
gate, initial_parameters, param_to_index
722736
)
723737

724738
if op_name in [
@@ -780,8 +794,53 @@ def qiskit2tq_Operator(circ: QuantumCircuit):
780794
return ops
781795

782796

783-
def qiskit2tq(circ: QuantumCircuit):
784-
ops = qiskit2tq_Operator(circ)
797+
def qiskit2tq_translate_qiskit_params(
798+
circuit_instruction: CircuitInstruction, initial_parameters, param_to_index
799+
):
800+
parameters = []
801+
for p in circuit_instruction.operation.params:
802+
if isinstance(p, Parameter) or isinstance(p, ParameterVectorElement):
803+
parameters.append(initial_parameters[param_to_index[p]])
804+
elif isinstance(p, ParameterExpression):
805+
if len(p.parameters) == 0:
806+
parameters.append(float(p))
807+
continue
808+
809+
expr = p.sympify().simplify()
810+
if isinstance(expr, symengine.Expr): # qiskit uses symengine if available
811+
expr = expr._sympy_() # sympy.Expr
812+
813+
for free_symbol in expr.free_symbols:
814+
# replace names: theta[0] -> theta_0
815+
# ParameterVector creates symbols with brackets like theta[0]
816+
# but sympy.lambdify does not allow brackets in symbol names
817+
free_symbol.name = free_symbol.name.replace("[", "_").replace("]", "")
818+
819+
parameter_list = list(p.parameters)
820+
sympy_symbols = [param._symbol_expr for param in parameter_list]
821+
# replace names again: theta[0] -> theta_0
822+
sympy_symbols = [
823+
sympy.Symbol(str(symbol).replace("[", "_").replace("]", ""))
824+
for symbol in sympy_symbols
825+
]
826+
lam_f = sympy.lambdify(sympy_symbols, expr, modules="math")
827+
parameters.append(
828+
lam_f(
829+
*[
830+
initial_parameters[param_to_index[param]]
831+
for param in parameter_list
832+
]
833+
)
834+
)
835+
else: # non-parameterized gate
836+
parameters.append(p)
837+
return parameters
838+
839+
840+
def qiskit2tq(
841+
circ: QuantumCircuit, initial_parameters: Optional[list[torch.nn.Parameter]] = None
842+
):
843+
ops = qiskit2tq_Operator(circ, initial_parameters)
785844
return tq.QuantumModuleFromOps(ops)
786845

787846

0 commit comments

Comments
 (0)