Skip to content

Commit 7e05eea

Browse files
[minor] preliminary tests for inverse
1 parent d9b8d70 commit 7e05eea

File tree

1 file changed

+53
-0
lines changed

1 file changed

+53
-0
lines changed

test/module/inverse.py

+53
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
import torchquantum as tq
2+
from torchquantum.plugin import op_history2qiskit, qiskit2tq_op_history
3+
from torchquantum.measurement import expval_joint_analytical
4+
from qiskit import QuantumCircuit, Aer, execute
5+
from qiskit.quantum_info import Pauli
6+
import numpy as np
7+
8+
"""
9+
Testing strategy:
10+
partition on Operation: iterate through all the possible operations
11+
partition on number of gates in module: 1, >1
12+
"""
13+
14+
def compare(ops, n_wires):
15+
# construct a normal tq circuit
16+
qmod = tq.QuantumModule.from_op_history(ops)
17+
qdev = tq.QuantumDevice(n_wires=n_wires, record_op=True)
18+
qmod(qdev)
19+
20+
# turn into qiskit and inverse
21+
qiskit_circuit = op_history2qiskit(n_wires, qdev.op_history)
22+
qiskit_circuit = qiskit_circuit.inverse()
23+
24+
# inverse the tq circuit
25+
qmod = tq.QuantumModule.from_op_history(ops)
26+
qdev = tq.QuantumDevice(n_wires=n_wires, record_op=True)
27+
qmod.inverse_module()
28+
qmod(qdev)
29+
30+
qdev_ops = qiskit2tq_op_history(qiskit_circuit)
31+
32+
for tq_op, qiskit_op in zip(qdev.op_history, qdev_ops):
33+
# TODO: name-wise (but currently need to ensure, e.g., cx == cnot)
34+
if tq_op["params"] is not None and qiskit_op["params"] is not None:
35+
assert np.allclose(np.array(tq_op["params"]), np.array(qiskit_op["params"]))
36+
37+
def get_random_rotations(num_params):
38+
return 4*np.pi*np.random.rand(num_params) - 2*np.pi
39+
40+
def test_inverse():
41+
ops = [
42+
{'name': 'u3', 'wires': 0, 'trainable': True, 'params': get_random_rotations(3)},
43+
{'name': 'u3', 'wires': 1, 'trainable': True, 'params': get_random_rotations(3)},
44+
{'name': 'cx', 'wires': [0, 1]},
45+
{'name': 'cx', 'wires': [1, 0]},
46+
{'name': 'u3', 'wires': 0, 'trainable': True, 'params': get_random_rotations(3)},
47+
{'name': 'u3', 'wires': 1, 'trainable': True, 'params': get_random_rotations(3)},
48+
{'name': 'cx', 'wires': [0, 1]},
49+
{'name': 'cx', 'wires': [1, 0]},
50+
]
51+
compare(ops, 2)
52+
53+
# test_inverse()

0 commit comments

Comments
 (0)