Commit a1c4012

Author: Zhuoyang Ye

[Fix] Encoding for density matrix. An example in mnist_new_noise.py

1 parent: b48c966

21 files changed: +453 -22 lines

examples/PauliSumOp/pauli_sum_op_noise.py

Whitespace-only changes.

examples/amplitude_encoding_mnist/mnist_example.py

+13 lines

@@ -100,10 +100,23 @@ def forward(self, x, use_qiskit=False):
         bsz = x.shape[0]
         x = F.avg_pool2d(x, 6).view(bsz, 16)
 
+
+        print("Shape 1:")
+        print(self.q_device.states.shape)
         self.encoder(self.q_device, x)
         self.q_layer(self.q_device)
+
+
+
+        print("X shape before measurement")
+        print(x.shape)
+
         x = self.measure(self.q_device)
 
+
+        print("X shape after measurement")
+        print(x.shape)
+
         x = x.reshape(bsz, 2, 2).sum(-1).squeeze()
         x = F.log_softmax(x, dim=1)
 
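The prints added above trace tensor shapes through encoding and measurement. For context, amplitude encoding normalizes the 16-dimensional pooled image into the 2^4 amplitudes of a 4-wire register; on a density-matrix device the encoded pure state |psi> is stored as rho = |psi><psi|. A minimal standalone sketch of that mapping follows; amplitude_encode_density is a hypothetical helper for illustration, not the tq.AmplitudeEncoder API.

import torch

def amplitude_encode_density(x: torch.Tensor) -> torch.Tensor:
    # Normalize each row so the squared amplitudes sum to 1,
    # then form the pure-state density matrix rho = |psi><psi|.
    psi = x / (x.norm(dim=-1, keepdim=True) + 1e-12)
    psi = psi.to(torch.complex64)
    return torch.einsum("bi,bj->bij", psi, psi.conj())

x = torch.rand(4, 16)              # a (bsz, 16) pooled MNIST batch
rho = amplitude_encode_density(x)
print(rho.shape)                   # torch.Size([4, 16, 16])
print(torch.diagonal(rho, dim1=-2, dim2=-1).sum(-1).real)  # traces ~= 1.0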
examples/amplitude_encoding_mnist/mnist_new_noise.py (new file)

+223 lines
"""
MIT License

Copyright (c) 2020-present TorchQuantum Authors

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""

import torch
import torch.nn.functional as F
import torch.optim as optim
import argparse

import torchquantum as tq
import torchquantum.functional as tqf

from torchquantum.dataset import MNIST
from torch.optim.lr_scheduler import CosineAnnealingLR

import random
import numpy as np


class QFCModel(tq.QuantumModule):
    class QLayer(tq.QuantumModule):
        def __init__(self):
            super().__init__()
            self.n_wires = 4
            self.random_layer = tq.RandomLayer(
                n_ops=50, wires=list(range(self.n_wires))
            )

            # gates with trainable parameters
            self.rx0 = tq.RX(has_params=True, trainable=True)
            self.ry0 = tq.RY(has_params=True, trainable=True)
            self.rz0 = tq.RZ(has_params=True, trainable=True)
            self.crx0 = tq.CRX(has_params=True, trainable=True)

        @tq.static_support
        def forward(self, q_device: tq.NoiseDevice):
            """
            1. To convert tq QuantumModule to qiskit or run in the static
            mode, need to:
                (1) add @tq.static_support before the forward
                (2) make sure to add
                    static=self.static_mode and
                    parent_graph=self.graph
                    to all the tqf functions, such as tqf.hadamard below
            """
            self.q_device = q_device

            self.random_layer(self.q_device)

            # some trainable gates (instantiated ahead of time)
            self.rx0(self.q_device, wires=0)
            self.ry0(self.q_device, wires=1)
            self.rz0(self.q_device, wires=3)
            self.crx0(self.q_device, wires=[0, 2])

            # add some more non-parameterized gates (add on-the-fly)
            tqf.hadamard(
                self.q_device, wires=3, static=self.static_mode, parent_graph=self.graph
            )
            tqf.sx(
                self.q_device, wires=2, static=self.static_mode, parent_graph=self.graph
            )
            tqf.cnot(
                self.q_device,
                wires=[3, 0],
                static=self.static_mode,
                parent_graph=self.graph,
            )

    def __init__(self):
        super().__init__()
        self.n_wires = 4
        self.q_device = tq.NoiseDevice(
            n_wires=self.n_wires,
            noise_model=tq.NoiseModel(kraus_dict={"Bitflip": 0.08, "Phaseflip": 0.08}),
        )
        self.encoder = tq.AmplitudeEncoder()

        self.q_layer = self.QLayer()
        self.measure = tq.MeasureAll_density(tq.PauliZ)

    def forward(self, x, use_qiskit=False):
        bsz = x.shape[0]
        x = F.avg_pool2d(x, 6).view(bsz, 16)
        self.encoder(self.q_device, x)
        self.q_layer(self.q_device)
        x = self.measure(self.q_device)
        x = x.reshape(bsz, 2, 2).sum(-1).squeeze()
        x = F.log_softmax(x, dim=1)
        return x


def train(dataflow, model, device, optimizer):
    for feed_dict in dataflow["train"]:
        inputs = feed_dict["image"].to(device)
        targets = feed_dict["digit"].to(device)

        outputs = model(inputs)
        loss = F.nll_loss(outputs, targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print(f"loss: {loss.item()}", end="\r")


def valid_test(dataflow, split, model, device, qiskit=False):
    target_all = []
    output_all = []
    with torch.no_grad():
        for feed_dict in dataflow[split]:
            inputs = feed_dict["image"].to(device)
            targets = feed_dict["digit"].to(device)

            outputs = model(inputs, use_qiskit=qiskit)

            target_all.append(targets)
            output_all.append(outputs)
        target_all = torch.cat(target_all, dim=0)
        output_all = torch.cat(output_all, dim=0)

    _, indices = output_all.topk(1, dim=1)
    masks = indices.eq(target_all.view(-1, 1).expand_as(indices))
    size = target_all.shape[0]
    corrects = masks.sum().item()
    accuracy = corrects / size
    loss = F.nll_loss(output_all, target_all).item()

    print(f"{split} set accuracy: {accuracy}")
    print(f"{split} set loss: {loss}")


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--static", action="store_true", help="compute with static mode"
    )
    parser.add_argument("--pdb", action="store_true", help="debug with pdb")
    parser.add_argument(
        "--wires-per-block", type=int, default=2, help="wires per block in static mode"
    )
    parser.add_argument(
        "--epochs", type=int, default=5, help="number of training epochs"
    )

    args = parser.parse_args()

    if args.pdb:
        import pdb

        pdb.set_trace()

    seed = 0
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    dataset = MNIST(
        root="./mnist_data",
        train_valid_split_ratio=[0.9, 0.1],
        digits_of_interest=[3, 6],
        n_test_samples=75,
    )
    dataflow = dict()

    for split in dataset:
        sampler = torch.utils.data.RandomSampler(dataset[split])
        dataflow[split] = torch.utils.data.DataLoader(
            dataset[split],
            batch_size=256,
            sampler=sampler,
            num_workers=8,
            pin_memory=True,
        )

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    model = QFCModel().to(device)

    n_epochs = args.epochs
    optimizer = optim.Adam(model.parameters(), lr=5e-3, weight_decay=1e-4)
    scheduler = CosineAnnealingLR(optimizer, T_max=n_epochs)

    if args.static:
        # optionally switch to static mode, which can bring a training speedup
        model.q_layer.static_on(wires_per_block=args.wires_per_block)

    for epoch in range(1, n_epochs + 1):
        # train
        print(f"Epoch {epoch}:")
        train(dataflow, model, device, optimizer)
        print(optimizer.param_groups[0]["lr"])

        # valid
        valid_test(dataflow, "valid", model, device)
        scheduler.step()

    # test
    valid_test(dataflow, "test", model, device, qiskit=False)


if __name__ == "__main__":
    main()
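The NoiseModel above takes kraus_dict={"Bitflip": 0.08, "Phaseflip": 0.08}. In standard Kraus form, a bit flip with probability p maps rho to (1 - p) rho + p X rho X, and a phase flip maps rho to (1 - p) rho + p Z rho Z. Below is a minimal single-qubit sketch of those textbook channels; it is illustrative only and does not reflect how tq.NoiseModel applies Kraus operators internally.

import torch

X = torch.tensor([[0, 1], [1, 0]], dtype=torch.complex64)
Z = torch.tensor([[1, 0], [0, -1]], dtype=torch.complex64)

def bit_flip(rho: torch.Tensor, p: float) -> torch.Tensor:
    # rho -> (1 - p) rho + p X rho X
    return (1 - p) * rho + p * (X @ rho @ X)

def phase_flip(rho: torch.Tensor, p: float) -> torch.Tensor:
    # rho -> (1 - p) rho + p Z rho Z
    return (1 - p) * rho + p * (Z @ rho @ Z)

# a phase flip damps the coherences of |+><+| by a factor 1 - 2p
plus = torch.tensor([1, 1], dtype=torch.complex64) / 2 ** 0.5
rho = torch.outer(plus, plus.conj())
print(phase_flip(rho, 0.08))  # off-diagonals: 0.5 * (1 - 2 * 0.08) = 0.42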

examples/amplitude_encoding_mnist/mnist_new.py

+1 line

@@ -171,3 +171,4 @@ def train_tq(model, device, train_dl, epochs, loss_fn, optimizer):
 
 print("--Training--")
 train_losses = train_tq(model, device, train_dl, 1, loss_fn, optimizer)
+