Skip to content

Commit 3d5fc06

Browse files
committed
chore: EstimatorResult supports multi observables
1 parent 67f2011 commit 3d5fc06

File tree

1 file changed

+30
-11
lines changed
  • mpqp/execution/providers

1 file changed

+30
-11
lines changed

mpqp/execution/providers/ibm.py

+30-11
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
from typing import TYPE_CHECKING, Optional
77

88
import numpy as np
9+
from typeguard import typechecked
10+
911
from mpqp.core.circuit import QCircuit
1012
from mpqp.core.instruction.gates import TOF, CRk, Gate, Id, P, Rk, Rx, Ry, Rz, T, U
1113
from mpqp.core.instruction.gates.native_gates import NativeGate
@@ -24,7 +26,6 @@
2426
from mpqp.execution.result import BatchResult, Result, Sample, StateVector
2527
from mpqp.noise import DimensionalNoiseModel
2628
from mpqp.tools.errors import DeviceJobIncompatibleError, IBMRemoteExecutionError
27-
from typeguard import typechecked
2829

2930
if TYPE_CHECKING:
3031
from qiskit import QuantumCircuit
@@ -594,6 +595,7 @@ def run_remote_ibm(job: Job) -> BatchResult | Result:
594595
ibm_result = remote_job.result()
595596
if TYPE_CHECKING:
596597
assert isinstance(job.device, IBMDevice)
598+
597599
return extract_result(ibm_result, job, job.device)
598600
# TODO: update this to take into account the case when we have list of Observables
599601

@@ -689,16 +691,33 @@ def extract_result(
689691
)
690692

691693
if isinstance(result, EstimatorResult):
692-
# TODO: do the same for multi observable, for loop over the result.values ?
693-
if job is None:
694-
job = Job(JobType.OBSERVABLE, QCircuit(0), device)
695-
shots = result.metadata[0]["shots"] if "shots" in result.metadata[0] else 0
696-
variance = (
697-
result.metadata[0]["variance"]
698-
if "variance" in result.metadata[0]
699-
else None
700-
)
701-
return Result(job, result.values[0], variance, shots)
694+
all_results = []
695+
696+
for res in result:
697+
res_data = res.data
698+
if hasattr(res_data, "evs"):
699+
if job is None:
700+
job = Job(JobType.OBSERVABLE, QCircuit(0), device)
701+
702+
mean = float(
703+
res_data.evs
704+
) # pyright: ignore[reportAttributeAccessIssue]
705+
error = float(
706+
res_data.stds
707+
) # pyright: ignore[reportAttributeAccessIssue]
708+
shots = (
709+
job.measure.shots
710+
if job.device.is_noisy_simulator() and job.measure is not None
711+
else result[0].metadata["shots"]
712+
)
713+
variance = (
714+
result.metadata[0]["variance"]
715+
if "variance" in result.metadata[0]
716+
else None
717+
)
718+
all_results.append(Result(job, mean, error, shots))
719+
720+
return BatchResult(all_results) if len(all_results) > 1 else all_results[0]
702721

703722
elif isinstance(
704723
result, QiskitResult

0 commit comments

Comments
 (0)