Skip to content

Commit 865b836

Browse files
author
Zhuoyang Ye
committed
[Rename] Rename.
1 parent ad88a79 commit 865b836

6 files changed

+91
-23
lines changed

examples/mnist/mnist_noise.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def __init__(self):
8484
self.encoder = tq.GeneralEncoder(tq.encoder_op_list_name_dict["4x4_u3_h_rx"])
8585

8686
self.q_layer = self.QLayer()
87-
self.measure = tq.MeasureAll_Density(tq.PauliZ)
87+
self.measure = tq.MeasureAll_density(tq.PauliZ)
8888

8989
def forward(self, x, use_qiskit=False):
9090
qdev = tq.NoiseDevice(

test/density/test_density_measure.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
import numpy as np
3030

3131

32-
def test_measure():
32+
def test_measure_density():
3333
n_shots = 10000
3434
qdev = tq.NoiseDevice(n_wires=3, bsz=1, record_op=True)
3535
qdev.x(wires=2) # type: ignore

test/density/test_eval_observable_density.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343
}
4444

4545

46-
def test_expval_observable():
46+
def test_expval_observable_density():
4747
# seed = 0
4848
# random.seed(seed)
4949
# np.random.seed(seed)

test/density/test_expval_joint_sampling_grouping_density.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
import random
3333

3434

35-
def test_expval_joint_sampling_grouping():
35+
def test_expval_joint_sampling_grouping_density():
3636
n_obs = 20
3737
n_wires = 4
3838
obs_all = []
@@ -59,4 +59,4 @@ def test_expval_joint_sampling_grouping():
5959

6060

6161
if __name__ == "__main__":
62-
test_expval_joint_sampling_grouping()
62+
test_expval_joint_sampling_grouping_density()

torchquantum/device/noisedevices.py

+13
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,19 @@ def get_2d_matrix(self, index):
124124
_matrix = torch.reshape(self.densities[index], [2 ** self.n_wires] * 2)
125125
return _matrix
126126

127+
128+
def get_densities_2d(self):
129+
"""Return the states in a 1d tensor."""
130+
bsz = self.densities.shape[0]
131+
return torch.reshape(self.densities, [bsz, 2**self.n_wires, 2**self.n_wires])
132+
133+
def get_density_2d(self):
134+
"""Return the state in a 1d tensor."""
135+
return torch.reshape(self.density, [2**self.n_wires,2**self.n_wires])
136+
137+
138+
139+
127140
def calc_trace(self, index):
128141
_matrix = torch.reshape(self.densities[index], [2 ** self.n_wires] * 2)
129142
return torch.trace(_matrix)

torchquantum/measurement/density_measurements.py

+73-18
Original file line numberDiff line numberDiff line change
@@ -18,16 +18,16 @@
1818
from .measurements import find_observable_groups
1919

2020
__all__ = [
21-
"expval_joint_sampling_grouping",
22-
"expval_joint_analytical",
23-
"expval_joint_sampling",
24-
"expval",
25-
"measure",
26-
"MeasureAll_Density"
21+
"expval_joint_sampling_grouping_density",
22+
"expval_joint_sampling_density",
23+
"expval_joint_analytical_density",
24+
"expval_density",
25+
"measure_density",
26+
"MeasureAll_density"
2727
]
2828

2929

30-
def measure(noisedev: tq.NoiseDevice, n_shots=1024, draw_id=None):
30+
def measure_density(noisedev: tq.NoiseDevice, n_shots=1024, draw_id=None):
3131
"""Measure the target density matrix and obtain classical bitstream distribution
3232
Args:
3333
noisedev: input tq.NoiseDevice
@@ -62,7 +62,7 @@ def measure(noisedev: tq.NoiseDevice, n_shots=1024, draw_id=None):
6262
return distri_all
6363

6464

65-
def expval_joint_sampling_grouping(
65+
def expval_joint_sampling_grouping_density(
6666
noisedev: tq.NoiseDevice,
6767
observables: List[str],
6868
n_shots_per_group=1024,
@@ -85,7 +85,7 @@ def expval_joint_sampling_grouping(
8585

8686
expval_all_obs = {}
8787
for obs_group, obs_elements in groups.items():
88-
# for each group need to clone a new qdev and its states
88+
# for each group need to clone a new qdev and its densities
8989
noisedev_clone = tq.NoiseDevice(n_wires=noisedev.n_wires, bsz=noisedev.bsz, device=noisedev.device)
9090
noisedev_clone.clone_densities(noisedev.densities)
9191

@@ -94,7 +94,7 @@ def expval_joint_sampling_grouping(
9494
rotation(noisedev_clone, wires=wire)
9595

9696
# measure
97-
distributions = measure(noisedev_clone, n_shots=n_shots_per_group)
97+
distributions = measure_density(noisedev_clone, n_shots=n_shots_per_group)
9898
# interpret the distribution for different observable elements
9999
for obs_element in obs_elements:
100100
expval_all = []
@@ -118,15 +118,70 @@ def expval_joint_sampling_grouping(
118118
return expval_all_obs
119119

120120

121-
def expval_joint_sampling(
121+
def expval_joint_sampling_density(
122122
qdev: tq.NoiseDevice,
123123
observable: str,
124124
n_shots=1024,
125125
):
126-
return
126+
"""
127+
Compute the expectation value of a joint observable from sampling
128+
the measurement bistring
129+
Args:
130+
qdev: the noise device
131+
observable: the joint observable, on the qubit 0, 1, 2, 3, etc in this order
132+
Returns:
133+
the expectation value
134+
Examples:
135+
>>> import torchquantum as tq
136+
>>> import torchquantum.functional as tqf
137+
>>> x = tq.QuantumDevice(n_wires=2)
138+
>>> tqf.hadamard(x, wires=0)
139+
>>> tqf.x(x, wires=1)
140+
>>> tqf.cnot(x, wires=[0, 1])
141+
>>> print(expval_joint_sampling(x, 'II', n_shots=8192))
142+
tensor([[0.9997]])
143+
>>> print(expval_joint_sampling(x, 'XX', n_shots=8192))
144+
tensor([[0.9991]])
145+
>>> print(expval_joint_sampling(x, 'ZZ', n_shots=8192))
146+
tensor([[-0.9980]])
147+
"""
148+
# rotation to the desired basis
149+
n_wires = qdev.n_wires
150+
paulix = op.op_name_dict["paulix"]
151+
pauliy = op.op_name_dict["pauliy"]
152+
pauliz = op.op_name_dict["pauliz"]
153+
iden = op.op_name_dict["i"]
154+
pauli_dict = {"X": paulix, "Y": pauliy, "Z": pauliz, "I": iden}
155+
156+
qdev_clone = tq.NoiseDevice(n_wires=qdev.n_wires, bsz=qdev.bsz, device=qdev.device)
157+
qdev_clone.clone_densities(qdev.densities)
158+
159+
observable = observable.upper()
160+
for wire in range(n_wires):
161+
for rotation in pauli_dict[observable[wire]]().diagonalizing_gates():
162+
rotation(qdev_clone, wires=wire)
127163

164+
mask = np.ones(len(observable), dtype=bool)
165+
mask[np.array([*observable]) == "I"] = False
166+
167+
expval_all = []
168+
# measure
169+
distributions = measure_density(qdev_clone, n_shots=n_shots)
170+
for distri in distributions:
171+
n_eigen_one = 0
172+
n_eigen_minus_one = 0
173+
for bitstring, n_count in distri.items():
174+
if np.dot(list(map(lambda x: eval(x), [*bitstring])), mask).sum() % 2 == 0:
175+
n_eigen_one += n_count
176+
else:
177+
n_eigen_minus_one += n_count
178+
179+
expval = n_eigen_one / n_shots + (-1) * n_eigen_minus_one / n_shots
180+
expval_all.append(expval)
181+
182+
return torch.tensor(expval_all, dtype=F_DTYPE)
128183

129-
def expval_joint_analytical(
184+
def expval_joint_analytical_density(
130185
noisedev: tq.NoiseDevice,
131186
observable: str,
132187
n_shots=1024
@@ -174,7 +229,7 @@ def expval_joint_analytical(
174229

175230
expval_all = []
176231
# measure
177-
distributions = measure(noisedev_clone, n_shots=n_shots)
232+
distributions = measure_density(noisedev_clone, n_shots=n_shots)
178233
for distri in distributions:
179234
n_eigen_one = 0
180235
n_eigen_minus_one = 0
@@ -190,7 +245,7 @@ def expval_joint_analytical(
190245
return torch.tensor(expval_all, dtype=F_DTYPE)
191246

192247

193-
def expval(
248+
def expval_density(
194249
noisedev: tq.NoiseDevice,
195250
wires: Union[int, List[int]],
196251
observables: Union[op.Observable, List[op.Observable]],
@@ -223,7 +278,7 @@ def expval(
223278
return torch.stack(expectations, dim=-1)
224279

225280

226-
class MeasureAll_Density(tq.QuantumModule):
281+
class MeasureAll_density(tq.QuantumModule):
227282
"""Obtain the expectation value of all the qubits."""
228283

229284
def __init__(self, obs, v_c_reg_mapping=None):
@@ -265,9 +320,9 @@ def set_v_c_reg_mapping(self, mapping):
265320
op(qdev, wires=0)
266321

267322
# measure the state on z basis
268-
print(tq.measure(qdev, n_shots=1024))
323+
print(tq.measure_density(qdev, n_shots=1024))
269324

270325
# obtain the expval on a observable
271-
expval = expval_joint_sampling(qdev, 'II', 100000)
326+
expval = expval_joint_sampling_density(qdev, 'II', 100000)
272327
# expval_ana = expval_joint_analytical(qdev, 'II')
273328
print(expval)

0 commit comments

Comments
 (0)