Skip to content

Commit e077d2f

Browse files
author
Zhuoyang Ye
committed
[Fix] Fix the dimension bug of expeval of density matrix.
1 parent 00e6dab commit e077d2f

File tree

3 files changed

+16
-16
lines changed

3 files changed

+16
-16
lines changed

examples/mnist/mnist_noise.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def __init__(self):
8282
self.encoder = tq.GeneralEncoder(tq.encoder_op_list_name_dict["4x4_u3_h_rx"])
8383

8484
self.q_layer = self.QLayer()
85-
self.measure = tq.MeasureAll(tq.PauliZ)
85+
self.measure = tq.MeasureAll_Density(tq.PauliZ)
8686

8787
def forward(self, x, use_qiskit=False):
8888
qdev = tq.NoiseDevice(

torchquantum/device/noisedevices.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -73,12 +73,10 @@ def __init__(
7373
self.record_op = record_op
7474
self.op_history = []
7575

76-
7776
def reset_op_history(self):
7877
"""Resets the all Operation of the quantum device"""
7978
self.op_history = []
8079

81-
8280
def print_2d(self, index):
8381
"""Print the matrix value at the given index.
8482
@@ -125,12 +123,16 @@ def get_probs_1d(self):
125123
"""Return the states in a 1d tensor."""
126124
bsz = self.densities.shape[0]
127125
densities2d = torch.reshape(self.densities, [bsz, 2 ** self.n_wires, 2 ** self.n_wires])
128-
return torch.diagonal(densities2d, offset=0, dim1=1, dim2=2)
126+
return torch.abs(torch.diagonal(densities2d, offset=0, dim1=1, dim2=2))
129127

130128
def get_prob_1d(self):
131129
"""Return the state in a 1d tensor."""
132130
density2d = torch.reshape(self.density, [2 ** self.n_wires, 2 ** self.n_wires])
133-
return torch.diagonal(density2d, offset=0, dim1=0, dim2=1)
131+
return torch.abs(torch.diagonal(density2d, offset=0, dim1=0, dim2=1))
132+
133+
def clone_densities(self, existing_densities: torch.Tensor):
134+
"""Clone the densities of the other quantum device."""
135+
self.densities = existing_densities.clone()
134136

135137

136138
for func_name, func in func_name_dict.items():

torchquantum/measurement/density_measurements.py

+9-11
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,16 @@
1414
import torchquantum.operator as op
1515
from copy import deepcopy
1616
import matplotlib.pyplot as plt
17-
from measurements import gen_bitstrings
18-
from measurements import find_observable_groups
17+
from .measurements import gen_bitstrings
18+
from .measurements import find_observable_groups
1919

2020
__all__ = [
2121
"expval_joint_sampling_grouping",
2222
"expval_joint_analytical",
2323
"expval_joint_sampling",
2424
"expval",
2525
"measure",
26+
"MeasureAll_Density"
2627
]
2728

2829

@@ -194,7 +195,7 @@ def expval(
194195
wires: Union[int, List[int]],
195196
observables: Union[op.Observable, List[op.Observable]],
196197
):
197-
all_dims = np.arange(noisedev.densities.dim())
198+
all_dims = np.arange(noisedev.n_wires+1)
198199
if isinstance(wires, int):
199200
wires = [wires]
200201
observables = [observables]
@@ -206,7 +207,8 @@ def expval(
206207

207208
# compute magnitude
208209
state_mag = noisedev.get_probs_1d()
209-
210+
bsz = state_mag.shape[0]
211+
state_mag = torch.reshape(state_mag, [bsz] + [2] * noisedev.n_wires)
210212
expectations = []
211213
for wire, observable in zip(wires, observables):
212214
# compute marginal magnitude
@@ -221,7 +223,7 @@ def expval(
221223
return torch.stack(expectations, dim=-1)
222224

223225

224-
class MeasureAll(tq.QuantumModule):
226+
class MeasureAll_Density(tq.QuantumModule):
225227
"""Obtain the expectation value of all the qubits."""
226228

227229
def __init__(self, obs, v_c_reg_mapping=None):
@@ -265,11 +267,7 @@ def set_v_c_reg_mapping(self, mapping):
265267
# measure the state on z basis
266268
print(tq.measure(qdev, n_shots=1024))
267269

268-
269-
270-
'''
271270
# obtain the expval on a observable
272271
expval = expval_joint_sampling(qdev, 'II', 100000)
273-
expval_ana = expval_joint_analytical(qdev, 'II')
274-
print(expval, expval_ana)
275-
'''
272+
# expval_ana = expval_joint_analytical(qdev, 'II')
273+
print(expval)

0 commit comments

Comments
 (0)