Skip to content

Commit 497320c

Browse files
authored
Merge pull request #236 from GenericP3rson/test_branch
Test Updates + Bug Fixes (Updated Version of #206)
2 parents fefb10b + e4233ff commit 497320c

File tree

19 files changed

+201
-95
lines changed

19 files changed

+201
-95
lines changed

.github/workflows/functional_tests.yaml

+32-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ jobs:
1414
strategy:
1515
fail-fast: false
1616
matrix:
17-
python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"]
17+
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
1818

1919
steps:
2020
- uses: actions/checkout@v3
@@ -36,3 +36,34 @@ jobs:
3636
- name: Test with pytest
3737
run: |
3838
python -m pytest -m "not skip"
39+
- name: Install TorchQuantum
40+
run: |
41+
pip install --editable .
42+
- name: Test Examples
43+
run: |
44+
python3 examples/qubit_rotation/qubit_rotation.py --epochs 1
45+
python3 examples/vqe/vqe.py --epochs 1 --steps_per_epoch 1
46+
python3 examples/train_unitary_prep/train_unitary_prep.py --epochs 1
47+
python3 examples/train_state_prep/train_state_prep.py --epochs 1
48+
python3 examples/superdense_coding/superdense_coding_torchquantum.py
49+
python3 examples/regression/run_regression.py --epochs 1
50+
python3 examples/param_shift_onchip_training/param_shift.py
51+
python3 examples/mnist/mnist_2qubit_4class.py --epochs 1
52+
python3 examples/hadamard_grad/circ.py
53+
python3 examples/encoder_examples/encoder_8x2ry.py
54+
python3 examples/converter_tq_qiskit/convert.py
55+
python3 examples/amplitude_encoding_mnist/mnist_new.py --epochs 1
56+
python3 examples/amplitude_encoding_mnist/mnist_example.py --epochs 1
57+
python3 examples/PauliSumOp/pauli_sum_op.py
58+
python3 examples/regression/new_run_regression.py --epochs 1
59+
python3 examples/quanvolution/quanvolution_trainable_quantum_layer.py --epochs 1
60+
python3 examples/grover/grover_example_sudoku.py
61+
python3 examples/param_shift_onchip_training/param_shift.py
62+
python3 examples/quanvolution/quanvolution.py --epochs 1
63+
python3 examples/quantum_lstm/qlstm.py --epochs 1
64+
python3 examples/qaoa/max_cut_backprop.py --steps 1
65+
python3 examples/optimal_control/optimal_control.py --epochs 1
66+
python3 examples/optimal_control/optimal_control_gaussian.py --epochs 1
67+
python3 examples/optimal_control/optimal_control_multi_qubit.py --epochs 1
68+
python3 examples/save_load_example/save_load.py
69+
python3 examples/mnist/mnist.py --epochs 1

examples/grover/grover_example_sudoku.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
"""
2929

3030
import torchquantum as tq
31-
from torchquantum.algorithms import Grover
31+
from torchquantum.algorithm import Grover
3232

3333

3434
# To simplify the process, we can compile this set of comparisons into a list of clauses for convenience.
@@ -90,4 +90,4 @@ def XOR(input0, input1, output):
9090
print("b = ", key[1])
9191
print("c = ", key[2])
9292
print("d = ", key[3])
93-
print("")
93+
print("")

examples/mnist/mnist.py

+34-32
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,7 @@ def main():
179179
"--static", action="store_true", help="compute with " "static mode"
180180
)
181181
parser.add_argument("--pdb", action="store_true", help="debug with pdb")
182+
parser.add_argument("--qiskit-simulation", action="store_true", help="run on a real quantum computer")
182183
parser.add_argument(
183184
"--wires-per-block", type=int, default=2, help="wires per block int static mode"
184185
)
@@ -243,38 +244,39 @@ def main():
243244
# test
244245
valid_test(dataflow, "test", model, device, qiskit=False)
245246

246-
# run on Qiskit simulator and real Quantum Computers
247-
try:
248-
from qiskit import IBMQ
249-
from torchquantum.plugin import QiskitProcessor
250-
251-
# firstly perform simulate
252-
print(f"\nTest with Qiskit Simulator")
253-
processor_simulation = QiskitProcessor(use_real_qc=False)
254-
model.set_qiskit_processor(processor_simulation)
255-
valid_test(dataflow, "test", model, device, qiskit=True)
256-
257-
# then try to run on REAL QC
258-
backend_name = "ibmq_lima"
259-
print(f"\nTest on Real Quantum Computer {backend_name}")
260-
# Please specify your own hub group and project if you have the
261-
# IBMQ premium plan to access more machines.
262-
processor_real_qc = QiskitProcessor(
263-
use_real_qc=True,
264-
backend_name=backend_name,
265-
hub="ibm-q",
266-
group="open",
267-
project="main",
268-
)
269-
model.set_qiskit_processor(processor_real_qc)
270-
valid_test(dataflow, "test", model, device, qiskit=True)
271-
except ImportError:
272-
print(
273-
"Please install qiskit, create an IBM Q Experience Account and "
274-
"save the account token according to the instruction at "
275-
"'https://github.com/Qiskit/qiskit-ibmq-provider', "
276-
"then try again."
277-
)
247+
if args.qiskit_simulation:
248+
# run on Qiskit simulator and real Quantum Computers
249+
try:
250+
from qiskit import IBMQ
251+
from torchquantum.plugin import QiskitProcessor
252+
253+
# firstly perform simulate
254+
print(f"\nTest with Qiskit Simulator")
255+
processor_simulation = QiskitProcessor(use_real_qc=False)
256+
model.set_qiskit_processor(processor_simulation)
257+
valid_test(dataflow, "test", model, device, qiskit=True)
258+
259+
# then try to run on REAL QC
260+
backend_name = "ibmq_lima"
261+
print(f"\nTest on Real Quantum Computer {backend_name}")
262+
# Please specify your own hub group and project if you have the
263+
# IBMQ premium plan to access more machines.
264+
processor_real_qc = QiskitProcessor(
265+
use_real_qc=True,
266+
backend_name=backend_name,
267+
hub="ibm-q",
268+
group="open",
269+
project="main",
270+
)
271+
model.set_qiskit_processor(processor_real_qc)
272+
valid_test(dataflow, "test", model, device, qiskit=True)
273+
except ImportError:
274+
print(
275+
"Please install qiskit, create an IBM Q Experience Account and "
276+
"save the account token according to the instruction at "
277+
"'https://github.com/Qiskit/qiskit-ibmq-provider', "
278+
"then try again."
279+
)
278280

279281

280282
if __name__ == "__main__":

examples/optimal_control/optimal_control.py

+15-4
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,22 @@
2626
import torch.optim as optim
2727

2828
import torchquantum as tq
29-
import pdb
29+
import argparse
3030
import numpy as np
3131

3232
if __name__ == "__main__":
33-
pdb.set_trace()
33+
parser = argparse.ArgumentParser()
34+
parser.add_argument("--pdb", action="store_true", help="debug with pdb")
35+
parser.add_argument(
36+
"--epochs", type=int, default=1000, help="number of training epochs"
37+
)
38+
39+
args = parser.parse_args()
40+
41+
if args.pdb:
42+
import pdb
43+
pdb.set_trace()
44+
3445
# target_unitary = torch.tensor([[0, 1], [1, 0]], dtype=torch.complex64)
3546
theta = 0.6
3647
target_unitary = torch.tensor(
@@ -41,11 +52,11 @@
4152
dtype=torch.complex64,
4253
)
4354

44-
pulse = tq.QuantumPulseDirect(n_steps=4, hamil=[[0, 1], [1, 0]])
55+
pulse = tq.pulse.QuantumPulseDirect(n_steps=4, hamil=[[0, 1], [1, 0]])
4556

4657
optimizer = optim.Adam(params=pulse.parameters(), lr=5e-3)
4758

48-
for k in range(1000):
59+
for k in range(args.epochs):
4960
# loss = (abs(pulse.get_unitary() - target_unitary)**2).sum()
5061
loss = (
5162
1

examples/optimal_control/optimal_control_gaussian.py

+15-4
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,22 @@
2626
import torch.optim as optim
2727

2828
import torchquantum as tq
29-
import pdb
29+
import argparse
3030
import numpy as np
3131

3232
if __name__ == "__main__":
33-
pdb.set_trace()
33+
parser = argparse.ArgumentParser()
34+
parser.add_argument("--pdb", action="store_true", help="debug with pdb")
35+
parser.add_argument(
36+
"--epochs", type=int, default=1000, help="number of training epochs"
37+
)
38+
39+
args = parser.parse_args()
40+
41+
if args.pdb:
42+
import pdb
43+
pdb.set_trace()
44+
3445
# target_unitary = torch.tensor([[0, 1], [1, 0]], dtype=torch.complex64)
3546
theta = 1.1
3647
target_unitary = torch.tensor(
@@ -41,11 +52,11 @@
4152
dtype=torch.complex64,
4253
)
4354

44-
pulse = tq.QuantumPulseGaussian(hamil=[[0, 1], [1, 0]])
55+
pulse = tq.pulse.QuantumPulseGaussian(hamil=[[0, 1], [1, 0]])
4556

4657
optimizer = optim.Adam(params=pulse.parameters(), lr=5e-3)
4758

48-
for k in range(1000):
59+
for k in range(args.epochs):
4960
# loss = (abs(pulse.get_unitary() - target_unitary)**2).sum()
5061
loss = (
5162
1

examples/optimal_control/optimal_control_multi_qubit.py

+17-6
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,22 @@
2626
import torch.optim as optim
2727

2828
import torchquantum as tq
29-
import pdb
29+
import argparse
3030
import numpy as np
3131

3232
if __name__ == "__main__":
33-
pdb.set_trace()
33+
parser = argparse.ArgumentParser()
34+
parser.add_argument("--pdb", action="store_true", help="debug with pdb")
35+
parser.add_argument(
36+
"--epochs", type=int, default=1000, help="number of training epochs"
37+
)
38+
39+
args = parser.parse_args()
40+
41+
if args.pdb:
42+
import pdb
43+
pdb.set_trace()
44+
3445
# target_unitary = torch.tensor([[0, 1], [1, 0]], dtype=torch.complex64)
3546
theta = 0.6
3647
target_unitary = torch.tensor(
@@ -43,9 +54,9 @@
4354
dtype=torch.complex64,
4455
)
4556

46-
pulse_q0 = tq.QuantumPulseDirect(n_steps=10, hamil=[[0, 1], [1, 0]])
47-
pulse_q1 = tq.QuantumPulseDirect(n_steps=10, hamil=[[0, 1], [1, 0]])
48-
pulse_q01 = tq.QuantumPulseDirect(
57+
pulse_q0 = tq.pulse.QuantumPulseDirect(n_steps=10, hamil=[[0, 1], [1, 0]])
58+
pulse_q1 = tq.pulse.QuantumPulseDirect(n_steps=10, hamil=[[0, 1], [1, 0]])
59+
pulse_q01 = tq.pulse.QuantumPulseDirect(
4960
n_steps=10,
5061
hamil=[
5162
[1, 0, 0, 0],
@@ -62,7 +73,7 @@
6273
lr=5e-3,
6374
)
6475

65-
for k in range(1000):
76+
for k in range(args.epochs):
6677
u_0 = pulse_q0.get_unitary()
6778
u_1 = pulse_q1.get_unitary()
6879
u_01 = pulse_q01.get_unitary()

examples/qaoa/max_cut_backprop.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727

2828
import random
2929
import numpy as np
30+
import argparse
3031

3132
from torchquantum.functional import mat_dict
3233

@@ -172,6 +173,12 @@ def backprop_optimize(model, n_steps=100, lr=0.1):
172173

173174

174175
def main():
176+
parser = argparse.ArgumentParser()
177+
parser.add_argument(
178+
"--steps", type=int, default=300, help="number of steps"
179+
)
180+
args = parser.parse_args()
181+
175182
# create a input_graph
176183
input_graph = [(0, 1), (0, 3), (1, 2), (2, 3)]
177184
n_wires = 4
@@ -184,7 +191,7 @@ def main():
184191
# print("The circuit is", circ.draw(output="mpl"))
185192
# circ.draw(output="mpl")
186193
# use backprop
187-
backprop_optimize(model, n_steps=300, lr=0.01)
194+
backprop_optimize(model, n_steps=args.steps, lr=0.01)
188195
# use parameter shift rule
189196
# param_shift_optimize(model, n_steps=500, step_size=100000)
190197

examples/quantum_lstm/qlstm.py

+17-5
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
import torch.nn as nn
3232
import torchquantum as tq
3333
import torchquantum.functional as tqf
34+
import argparse
3435

3536

3637
class QLSTM(nn.Module):
@@ -358,6 +359,19 @@ def plot_history(history_classical, history_quantum):
358359
plt.show()
359360

360361
def main():
362+
parser = argparse.ArgumentParser()
363+
parser.add_argument("--pdb", action="store_true", help="debug with pdb")
364+
parser.add_argument("--display", action="store_true", help="display results with matplotlib")
365+
parser.add_argument(
366+
"--epochs", type=int, default=300, help="number of training epochs"
367+
)
368+
369+
args = parser.parse_args()
370+
371+
if args.pdb:
372+
import pdb
373+
pdb.set_trace()
374+
361375
tag_to_ix = {"DET": 0, "NN": 1, "V": 2} # Assign each tag with a unique index
362376
ix_to_tag = {i:k for k,i in tag_to_ix.items()}
363377

@@ -380,7 +394,7 @@ def main():
380394

381395
embedding_dim = 8
382396
hidden_dim = 6
383-
n_epochs = 300
397+
n_epochs = args.epochs
384398

385399
model_classical = LSTMTagger(embedding_dim,
386400
hidden_dim,
@@ -404,10 +418,8 @@ def main():
404418

405419
print_result(model_quantum, training_data, word_to_ix, ix_to_tag)
406420

407-
plot_history(history_classical, history_quantum)
421+
if args.display:
422+
plot_history(history_classical, history_quantum)
408423

409424
if __name__ == "__main__":
410-
import pdb
411-
pdb.set_trace()
412-
413425
main()

0 commit comments

Comments
 (0)