14
14
import torchquantum .operator as op
15
15
from copy import deepcopy
16
16
import matplotlib .pyplot as plt
17
- from measurements import gen_bitstrings
18
- from measurements import find_observable_groups
17
+ from . measurements import gen_bitstrings
18
+ from . measurements import find_observable_groups
19
19
20
20
__all__ = [
21
21
"expval_joint_sampling_grouping" ,
22
22
"expval_joint_analytical" ,
23
23
"expval_joint_sampling" ,
24
24
"expval" ,
25
25
"measure" ,
26
+ "MeasureAll_Density"
26
27
]
27
28
28
29
@@ -194,7 +195,7 @@ def expval(
194
195
wires : Union [int , List [int ]],
195
196
observables : Union [op .Observable , List [op .Observable ]],
196
197
):
197
- all_dims = np .arange (noisedev .densities . dim () )
198
+ all_dims = np .arange (noisedev .n_wires + 1 )
198
199
if isinstance (wires , int ):
199
200
wires = [wires ]
200
201
observables = [observables ]
@@ -206,7 +207,8 @@ def expval(
206
207
207
208
# compute magnitude
208
209
state_mag = noisedev .get_probs_1d ()
209
-
210
+ bsz = state_mag .shape [0 ]
211
+ state_mag = torch .reshape (state_mag , [bsz ] + [2 ] * noisedev .n_wires )
210
212
expectations = []
211
213
for wire , observable in zip (wires , observables ):
212
214
# compute marginal magnitude
@@ -221,7 +223,7 @@ def expval(
221
223
return torch .stack (expectations , dim = - 1 )
222
224
223
225
224
- class MeasureAll (tq .QuantumModule ):
226
+ class MeasureAll_Density (tq .QuantumModule ):
225
227
"""Obtain the expectation value of all the qubits."""
226
228
227
229
def __init__ (self , obs , v_c_reg_mapping = None ):
@@ -265,11 +267,7 @@ def set_v_c_reg_mapping(self, mapping):
265
267
# measure the state on z basis
266
268
print (tq .measure (qdev , n_shots = 1024 ))
267
269
268
-
269
-
270
- '''
271
270
# obtain the expval on a observable
272
271
expval = expval_joint_sampling (qdev , 'II' , 100000 )
273
- expval_ana = expval_joint_analytical(qdev, 'II')
274
- print(expval, expval_ana)
275
- '''
272
+ # expval_ana = expval_joint_analytical(qdev, 'II')
273
+ print (expval )
0 commit comments