Skip to content

Commit 964ec12

Browse files
fix: fixing the multi-observable local run
1 parent 1c95da1 commit 964ec12

File tree

2 files changed

+53
-76
lines changed

2 files changed

+53
-76
lines changed

mpqp/execution/providers/ibm.py

+52-75
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,6 @@
66
from typing import TYPE_CHECKING, Optional
77

88
import numpy as np
9-
from typeguard import typechecked
10-
119
from mpqp.core.circuit import QCircuit
1210
from mpqp.core.instruction.gates import TOF, CRk, Gate, Id, P, Rk, Rx, Ry, Rz, T, U
1311
from mpqp.core.instruction.gates.native_gates import NativeGate
@@ -26,6 +24,7 @@
2624
from mpqp.execution.result import BatchResult, Result, Sample, StateVector
2725
from mpqp.noise import DimensionalNoiseModel
2826
from mpqp.tools.errors import DeviceJobIncompatibleError, IBMRemoteExecutionError
27+
from typeguard import typechecked
2928

3029
if TYPE_CHECKING:
3130
from qiskit import QuantumCircuit
@@ -140,11 +139,6 @@ def compute_expectation_value(
140139
}
141140
estimator = Estimator(options=options)
142141

143-
# TODO: implement the possibility to compute several expectation values at
144-
# the same time when the circuit is the same apparently the estimator.run()
145-
# can take several circuits and observables at the same time, because
146-
# putting them all together will increase the performance
147-
148142
job.status = JobStatus.RUNNING
149143
circuits_and_observables = [(ibm_circuit, obs) for obs in qiskit_observables]
150144
job_expectation = estimator.run(circuits_and_observables)
@@ -153,15 +147,7 @@ def compute_expectation_value(
153147
if TYPE_CHECKING:
154148
assert isinstance(job.device, (IBMDevice, IBMSimulatedDevice))
155149

156-
if len(qiskit_observables) > 1:
157-
return BatchResult(
158-
[
159-
extract_result(estimator_result, job, job.device)
160-
for _ in range(len(qiskit_observables))
161-
]
162-
)
163-
else:
164-
return extract_result(estimator_result, job, job.device)
150+
return extract_result(estimator_result, job, job.device)
165151

166152

167153
@typechecked
@@ -597,10 +583,7 @@ def submit_remote_ibm(job: Job) -> tuple[str, "RuntimeJobV2"]:
597583

598584
job.id = ibm_job.job_id()
599585

600-
return (
601-
job.id,
602-
ibm_job,
603-
)
586+
return job.id, ibm_job
604587

605588

606589
@typechecked
@@ -622,17 +605,16 @@ def run_remote_ibm(job: Job) -> BatchResult | Result:
622605
ibm_result = remote_job.result()
623606
if TYPE_CHECKING:
624607
assert isinstance(job.device, IBMDevice)
625-
return extract_result(
626-
ibm_result, job, job.device
627-
) # TODO: update this to take into account the case when we have list of Observables
608+
return extract_result(ibm_result, job, job.device)
609+
# TODO: update this to take into account the case when we have list of Observables
628610

629611

630612
@typechecked
631613
def extract_result(
632614
result: "QiskitResult | EstimatorResult | PrimitiveResult[PubResult | SamplerPubResult]",
633615
job: Optional[Job],
634616
device: "IBMDevice | IBMSimulatedDevice | AZUREDevice",
635-
) -> BatchResult | Result: # TODO: [multi-obs] return BatchResult for multi observable
617+
) -> BatchResult | Result:
636618
"""Parses a result from ``IBM`` execution (remote or local) in a ``MPQP``
637619
:class:`~mpqp.execution.result.Result`.
638620
@@ -651,20 +633,18 @@ def extract_result(
651633

652634
# If this is a PubResult from primitives V2
653635
if isinstance(result, PrimitiveResult):
654-
res_data = result[0].data
655-
# res_data is a DataBin, which means all typechecking is out of the
656-
# windows for this specific object
657636
all_results = []
658-
res_data = result
659637

660-
results = res_data.evs if isinstance(res_data.evs, list) else [res_data.evs]
661-
for single_result in results:
638+
for res in result:
639+
res_data = res.data
640+
# res_data is a DataBin, which means all typechecking is out of the
641+
# windows for this specific object
662642
if hasattr(res_data, "evs"): #
663643
if job is None:
664644
job = Job(JobType.OBSERVABLE, QCircuit(0), device)
665645

666646
mean = float(
667-
single_result
647+
res_data.evs
668648
) # pyright: ignore[reportAttributeAccessIssue]
669649
error = float(
670650
res_data.stds
@@ -676,50 +656,38 @@ def extract_result(
676656
)
677657
all_results.append(Result(job, mean, error, shots))
678658

679-
if len(all_results) > 1:
680-
return BatchResult(all_results)
681-
682-
return all_results[0]
683-
# if hasattr(res_data, "evs"):
684-
# if job is None:
685-
# job = Job(JobType.OBSERVABLE, QCircuit(0), device)
686-
687-
# mean = float(res_data.evs) # pyright: ignore[reportAttributeAccessIssue]
688-
# error = float(res_data.stds) # pyright: ignore[reportAttributeAccessIssue]
689-
# shots = (
690-
# job.measure.shots
691-
# if job.device.is_simulator() and job.measure is not None
692-
# else result[0].metadata["shots"]
693-
# )
694-
# return Result(job, mean, error, shots)
695-
# If we are in sample mode
696-
else:
697-
if job is None:
698-
shots = (
699-
res_data.c.num_shots # pyright: ignore[reportAttributeAccessIssue]
700-
)
701-
nb_qubits = (
702-
res_data.c.num_bits # pyright: ignore[reportAttributeAccessIssue]
703-
)
704-
job = Job(
705-
JobType.SAMPLE,
706-
QCircuit(nb_qubits),
707-
device,
708-
BasisMeasure(list(range(nb_qubits)), shots=shots),
709-
)
710-
if TYPE_CHECKING:
711-
assert job.measure is not None
659+
else:
660+
if job is None:
661+
shots = (
662+
res_data.c.num_shots # pyright: ignore[reportAttributeAccessIssue]
663+
)
664+
nb_qubits = (
665+
res_data.c.num_bits # pyright: ignore[reportAttributeAccessIssue]
666+
)
667+
job = Job(
668+
JobType.SAMPLE,
669+
QCircuit(nb_qubits),
670+
device,
671+
BasisMeasure(list(range(nb_qubits)), shots=shots),
672+
)
673+
if TYPE_CHECKING:
674+
assert job.measure is not None
712675

713-
counts = (
714-
res_data.c.get_counts() # pyright: ignore[reportAttributeAccessIssue]
715-
)
716-
data = [
717-
Sample(
718-
bin_str=item, count=counts[item], nb_qubits=job.circuit.nb_qubits
676+
counts = (
677+
res_data.c.get_counts() # pyright: ignore[reportAttributeAccessIssue]
719678
)
720-
for item in counts
721-
]
722-
return Result(job, data, None, job.measure.shots)
679+
data = [
680+
Sample(
681+
bin_str=item,
682+
count=counts[item],
683+
nb_qubits=job.circuit.nb_qubits,
684+
)
685+
for item in counts
686+
]
687+
# Since we don't handle multiple sampling jobs, we know the first result is the only one
688+
return Result(job, data, None, job.measure.shots)
689+
690+
return BatchResult(all_results) if len(all_results) > 1 else all_results[0]
723691

724692
else:
725693

@@ -799,7 +767,7 @@ def extract_result(
799767

800768

801769
@typechecked
802-
def get_result_from_ibm_job_id(job_id: str) -> Result:
770+
def get_result_from_ibm_job_id(job_id: str) -> Result | BatchResult:
803771
"""Retrieves from IBM remote platform and parse the result of the job_id
804772
given in parameter. If the job is still running, we wait (blocking) until it
805773
is ``DONE``.
@@ -808,7 +776,7 @@ def get_result_from_ibm_job_id(job_id: str) -> Result:
808776
job_id: Id of the remote IBM job.
809777
810778
Returns:
811-
The result converted to our format.
779+
The result (or batch of result) converted to our format.
812780
"""
813781
from qiskit.providers import BackendV1, BackendV2
814782

@@ -841,6 +809,15 @@ def get_result_from_ibm_job_id(job_id: str) -> Result:
841809

842810

843811
def extract_samples(job: Job, result: QiskitResult) -> list[Sample]:
812+
"""
813+
TODO comment
814+
Args:
815+
job:
816+
result:
817+
818+
Returns:
819+
820+
"""
844821
counts = result.get_counts(0)
845822
job_data = result.data()
846823
return [

mpqp/execution/runner.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ def _run_single(
158158
device: AvailableDevice,
159159
values: dict[Expr | str, Complex],
160160
display_breakpoints: bool = True,
161-
) -> Result:
161+
) -> Result | BatchResult:
162162
"""Runs the circuit on the ``backend``. If the circuit depends on variables,
163163
the ``values`` given in parameters are used to do the substitution.
164164

0 commit comments

Comments
 (0)