Skip to content

Commit ce6df23

Browse files
feat: prep the variance computation on cirq
1 parent 6f8ade3 commit ce6df23

File tree

1 file changed

+9
-5
lines changed

1 file changed

+9
-5
lines changed

mpqp/execution/providers/google.py

+9-5
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from mpqp.execution.devices import GOOGLEDevice
1717
from mpqp.execution.job import Job, JobType
1818
from mpqp.execution.result import Result, Sample, StateVector
19+
from mpqp.tools.generics import flatten
1920

2021

2122
@typechecked
@@ -114,7 +115,8 @@ def run_local(job: Job) -> Result:
114115
ValueError: If the job device is not GOOGLEDevice.
115116
"""
116117
from cirq.circuits.circuit import Circuit as CirqCircuit
117-
from cirq.ops.linear_combinations import PauliSum as Cirq_PauliSum
118+
from cirq.ops.linear_combinations import PauliSum as CirqPauliSum
119+
from cirq.ops.pauli_string import PauliString as CirqPauliString
118120
from cirq.sim.sparse_simulator import Simulator
119121
from cirq.work.observable_measurement import (
120122
RepetitionsStoppingCriteria,
@@ -150,7 +152,9 @@ def run_local(job: Job) -> Result:
150152
cirq_obs = job.measure.observable.to_other_language(
151153
language=Language.CIRQ, circuit=cirq_circuit
152154
)
153-
assert type(cirq_obs) == Cirq_PauliSum
155+
assert isinstance(cirq_obs, CirqPauliSum) or isinstance(
156+
cirq_obs, CirqPauliString
157+
)
154158

155159
if job.measure.shots == 0:
156160
result_sim = simulator.simulate_expectation_values(
@@ -159,11 +163,11 @@ def run_local(job: Job) -> Result:
159163
else:
160164
result_sim = measure_observables(
161165
cirq_circuit,
162-
observables=cirq_obs, # type: ignore[reportArgumentType]
166+
observables=flatten(cirq_obs),
163167
sampler=simulator,
164168
stopping_criteria=RepetitionsStoppingCriteria(job.measure.shots),
165169
)
166-
170+
print(result_sim)
167171
return extract_result_OBSERVABLE(result_sim, job)
168172
else:
169173
raise ValueError(f"Job type {job.job_type} not handled")
@@ -313,7 +317,7 @@ def extract_result_OBSERVABLE(
313317
raise NotImplementedError("job.measure is None")
314318
for result in results:
315319
if isinstance(result, float):
316-
mean += abs(result)
320+
mean += result
317321
if isinstance(result, ObservableMeasuredResult):
318322
mean += result.mean
319323
# TODO variance not supported variance += result1.variance

0 commit comments

Comments
 (0)