Skip to content

Commit d718968

Browse files
Merge pull request #28 from ColibrITD-SAS/feat-capture-braket-hardprint
capture and handle hard print warnings from Braket
2 parents d2f0158 + a28f1d0 commit d718968

File tree

6 files changed

+116
-47
lines changed

6 files changed

+116
-47
lines changed

mpqp/execution/result.py

-16
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,3 @@
1-
######################################
2-
# Copyright(C) 2021 - 2023 ColibrITD
3-
#
4-
# Developers :
5-
# - Hamza JAFFALI < hamza.jaffali@colibritd.com >
6-
# - Karla BAUMANN < karla.baumann@colibritd.com >
7-
# - Henri de BOUTRAY < henri.de.boutray@colibritd.com >
8-
#
9-
# Version : 0.1
10-
#
11-
# This file is part of QUICK.
12-
#
13-
# QUICK can not be copied and / or distributed without the express
14-
# permission of ColibrITD
15-
#
16-
######################################
171
from __future__ import annotations
182

193
import math

mpqp/qasm/qasm_to_braket.py

+26-6
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
"""File regrouping all features for translating QASM code to Amazon Braket objects."""
22

3+
import io
4+
import warnings
5+
from logging import StreamHandler, getLogger
6+
37
from braket.ir.openqasm import Program
48
from braket.circuits import Circuit
59
from typeguard import typechecked
@@ -30,16 +34,14 @@ def qasm3_to_braket_Program(qasm3_str: str) -> Program:
3034

3135
@typechecked
3236
def qasm3_to_braket_Circuit(qasm3_str: str) -> Circuit:
33-
"""
34-
Converting a OpenQASM 3.0 code into a Braket Circuit
37+
"""Converting a OpenQASM 3.0 code into a Braket Circuit.
3538
3639
Args:
3740
qasm3_str: A string representing the OpenQASM 3.0 code.
3841
3942
Returns:
4043
A Circuit equivalent to the QASM code in parameter.
4144
"""
42-
4345
# PROBLEM: import and standard gates are not supported by Braket
4446
# NOTE: however custom OpenQASM 3 gates declaration is supported by Braket,
4547
# SOLUTION: the idea is then to hard import the standard lib and other files into the qasm string's header before
@@ -49,12 +51,30 @@ def qasm3_to_braket_Circuit(qasm3_str: str) -> Circuit:
4951
# SOLUTION: import a specific qasm file with U and gphase redefined with the supported Braket SDK gates, and by
5052
# removing from this import file the already handled gates
5153

52-
# we remove any include of stdgates.inc and replace it with custom include
5354
qasm3_str = qasm3_str.replace("stdgates.inc", "braket_custom_include.inc")
5455

5556
after_stdgates_included = open_qasm_hard_includes(qasm3_str, set())
56-
# NOTE : gphase is a already used in Braket and thus cannot be redefined as a native gate in OpenQASM.
57-
# We used ggphase instead
57+
58+
braket_warning_message = (
59+
"This program uses OpenQASM language features that may not be supported"
60+
" on QPUs or on-demand simulators."
61+
)
62+
63+
braket_logger = getLogger()
64+
logger_output_stream = io.StringIO()
65+
stream_handler = StreamHandler(logger_output_stream)
66+
braket_logger.addHandler(stream_handler)
5867

5968
circuit = Circuit.from_ir(after_stdgates_included)
69+
70+
braket_logger.removeHandler(stream_handler)
71+
log_lines = logger_output_stream.getvalue().split("\n")
72+
for message in log_lines:
73+
if message == braket_warning_message:
74+
warnings.warn(
75+
"\n" + braket_warning_message, UnsupportedBraketFeaturesWarning
76+
)
77+
else:
78+
braket_logger.warning(message)
79+
6080
return circuit

pyproject.toml

+8-1
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,11 @@ build-backend = "setuptools.build_meta"
55

66
[tool.setuptools_scm]
77
version_scheme = "post-release"
8-
local_scheme = "no-local-version"
8+
local_scheme = "no-local-version"
9+
10+
[tool.isort]
11+
multi_line_output = 3
12+
include_trailing_comma = true
13+
force_grid_wrap = 0
14+
line_length = 88
15+
profile = "black"

tests/example/test_demonstrations.py

+42-15
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,25 @@
1+
from typing import Any, Callable
2+
3+
import numpy as np
14
import pytest
5+
from braket.devices import LocalSimulator
6+
27
from mpqp import QCircuit
3-
from mpqp.core.instruction.measurement import Observable, ExpectationMeasure
4-
from mpqp.execution.devices import AWSDevice, ATOSDevice, IBMDevice
8+
from mpqp.core.instruction.measurement import ExpectationMeasure, Observable
9+
from mpqp.execution import run
10+
from mpqp.execution.devices import ATOSDevice, AvailableDevice, AWSDevice, IBMDevice
511
from mpqp.gates import *
612
from mpqp.measures import BasisMeasure
7-
from mpqp.execution import run
8-
from braket.devices import LocalSimulator
9-
import numpy as np
1013
from mpqp.qasm.qasm_to_braket import qasm3_to_braket_Circuit
14+
from mpqp.tools.errors import UnsupportedBraketFeaturesWarning
15+
16+
17+
def warn_guard(device: AvailableDevice, run: Callable[[], Any]):
18+
if isinstance(device, AWSDevice):
19+
with pytest.warns(UnsupportedBraketFeaturesWarning):
20+
return run()
21+
else:
22+
return run()
1123

1224

1325
def test_sample_demo():
@@ -34,7 +46,7 @@ def test_sample_demo():
3446
circuit.add(BasisMeasure([0, 1, 2, 3], shots=2000))
3547

3648
# Run the circuit on a selected device
37-
run(
49+
runner = lambda: run(
3850
circuit,
3951
[
4052
IBMDevice.AER_SIMULATOR,
@@ -44,6 +56,8 @@ def test_sample_demo():
4456
],
4557
)
4658

59+
warn_guard(AWSDevice.BRAKET_LOCAL_SIMULATOR, runner)
60+
4761
assert True
4862

4963

@@ -68,7 +82,7 @@ def test_statevector_demo():
6882
circuit.add(Rz(3.14, 0))
6983

7084
# when no measure in the circuit, must run in statevector mode
71-
run(
85+
runner = lambda: run(
7286
circuit,
7387
[
7488
IBMDevice.AER_SIMULATOR_STATEVECTOR,
@@ -78,11 +92,13 @@ def test_statevector_demo():
7892
],
7993
)
8094

95+
warn_guard(AWSDevice.BRAKET_LOCAL_SIMULATOR, runner)
96+
8197
# same when we add a BasisMeasure with 0 shots
8298
circuit.add(BasisMeasure([0, 1, 2, 3], shots=0))
8399

84100
# Run the circuit on a selected device
85-
run(
101+
runner = lambda: run(
86102
circuit,
87103
[
88104
IBMDevice.AER_SIMULATOR_STATEVECTOR,
@@ -92,6 +108,8 @@ def test_statevector_demo():
92108
],
93109
)
94110

111+
warn_guard(AWSDevice.BRAKET_LOCAL_SIMULATOR, runner)
112+
95113
assert True
96114

97115

@@ -121,7 +139,7 @@ def test_observable_demo():
121139
assert True
122140

123141

124-
def test_aws_executions():
142+
def test_aws_qasm_executions():
125143
device = LocalSimulator()
126144

127145
qasm_str = """OPENQASM 3.0;
@@ -133,11 +151,12 @@ def test_aws_executions():
133151
c[0] = measure q[0];
134152
c[1] = measure q[1];"""
135153

136-
circuit = qasm3_to_braket_Circuit(qasm_str)
137-
154+
runner = lambda: qasm3_to_braket_Circuit(qasm_str)
155+
circuit = warn_guard(AWSDevice.BRAKET_LOCAL_SIMULATOR, runner)
138156
device.run(circuit, shots=100).result()
139157

140-
#####################################################
158+
159+
def test_aws_mpqp_executions():
141160

142161
# Declaration of the circuit with the right size
143162
circuit = QCircuit(4)
@@ -161,7 +180,9 @@ def test_aws_executions():
161180
# Add measurement
162181
circuit.add(BasisMeasure([0, 1, 2, 3], shots=2000))
163182

164-
run(circuit, AWSDevice.BRAKET_LOCAL_SIMULATOR)
183+
runner = lambda: run(circuit, AWSDevice.BRAKET_LOCAL_SIMULATOR)
184+
185+
warn_guard(AWSDevice.BRAKET_LOCAL_SIMULATOR, runner)
165186

166187
#####################################################
167188

@@ -185,7 +206,10 @@ def test_aws_executions():
185206
circuit.add(ExpectationMeasure([0, 1], observable=obs, shots=0))
186207

187208
# Running the computation on myQLM and on Braket simulator, then retrieving the results
188-
run(circuit, [AWSDevice.BRAKET_LOCAL_SIMULATOR, ATOSDevice.MYQLM_PYLINALG])
209+
runner = lambda: run(
210+
circuit, [AWSDevice.BRAKET_LOCAL_SIMULATOR, ATOSDevice.MYQLM_PYLINALG]
211+
)
212+
warn_guard(AWSDevice.BRAKET_LOCAL_SIMULATOR, runner)
189213

190214
#####################################################
191215

@@ -195,7 +219,10 @@ def test_aws_executions():
195219
)
196220

197221
# Running the computation on myQLM and on Aer simulator, then retrieving the results
198-
run(circuit, [AWSDevice.BRAKET_LOCAL_SIMULATOR, ATOSDevice.MYQLM_PYLINALG])
222+
runner = lambda: run(
223+
circuit, [AWSDevice.BRAKET_LOCAL_SIMULATOR, ATOSDevice.MYQLM_PYLINALG]
224+
)
225+
warn_guard(AWSDevice.BRAKET_LOCAL_SIMULATOR, runner)
199226

200227

201228
def test_all_native_gates():

tests/execution/test_vqa.py

+14-5
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from typing import Any
2+
23
import numpy as np
34
import pytest
45
from sympy import Expr
@@ -8,11 +9,12 @@
89
ExpectationMeasure,
910
Observable,
1011
)
11-
from mpqp.gates import *
12-
from mpqp.execution.devices import AWSDevice, AvailableDevice, IBMDevice, ATOSDevice
13-
from mpqp.execution.vqa import minimize, Optimizer
14-
from mpqp.execution.vqa.vqa import OptimizableFunc
12+
from mpqp.execution.devices import ATOSDevice, AvailableDevice, AWSDevice, IBMDevice
1513
from mpqp.execution.runner import _run_single # pyright: ignore[reportPrivateUsage]
14+
from mpqp.execution.vqa import Optimizer, minimize
15+
from mpqp.execution.vqa.vqa import OptimizableFunc
16+
from mpqp.gates import *
17+
from mpqp.tools.errors import UnsupportedBraketFeaturesWarning
1618

1719
# the symbols function is a bit wacky, so some manual type definition is needed here
1820
theta: Expr = symbols("θ") # type: ignore
@@ -41,8 +43,15 @@ def with_local_devices(args: tuple[Any, ...]):
4143
),
4244
)
4345
def test_optimizer_circuit(circ: QCircuit, minimum: float, device: AvailableDevice):
44-
try:
46+
def run():
4547
assert minimize(circ, Optimizer.BFGS, device)[0] - minimum < 0.05
48+
49+
try:
50+
if isinstance(device, AWSDevice):
51+
with pytest.warns(UnsupportedBraketFeaturesWarning):
52+
run()
53+
else:
54+
run()
4655
except (ValueError, NotImplementedError) as err:
4756
if "not handled" not in str(err):
4857
raise

tests/qasm/test_qasm_to_braket.py

+26-4
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1-
from braket.circuits import Operator, Circuit
2-
from braket.circuits.gates import H, CNot
31
import pytest
2+
from braket.circuits import Circuit, Operator
3+
from braket.circuits.gates import CNot, H
44

55
from mpqp.qasm.qasm_to_braket import qasm3_to_braket_Circuit
6+
from mpqp.tools.errors import UnsupportedBraketFeaturesWarning
67

78

89
@pytest.mark.parametrize(
@@ -12,6 +13,19 @@
1213
"""OPENQASM 3.0;""",
1314
[],
1415
),
16+
],
17+
)
18+
def test_qasm3_to_braket_Circuit(qasm_code: str, braket_operators: list[Operator]):
19+
circ = qasm3_to_braket_Circuit(qasm_code)
20+
21+
assert isinstance(circ, Circuit)
22+
for circ_instr, expected_operator in zip(circ.instructions, braket_operators):
23+
assert circ_instr.operator == expected_operator
24+
25+
26+
@pytest.mark.parametrize(
27+
"qasm_code, braket_operators",
28+
[
1529
(
1630
"""OPENQASM 3.0;
1731
include 'stdgates.inc';
@@ -27,8 +41,16 @@
2741
),
2842
],
2943
)
30-
def test_qasm3_to_braket_Circuit(qasm_code: str, braket_operators: list[Operator]):
31-
circ = qasm3_to_braket_Circuit(qasm_code)
44+
def test_qasm3_to_braket_Circuit_warning(
45+
qasm_code: str, braket_operators: list[Operator]
46+
):
47+
warning = (
48+
"This program uses OpenQASM language features that may not be supported"
49+
" on QPUs or on-demand simulators."
50+
)
51+
with pytest.warns(UnsupportedBraketFeaturesWarning, match=warning):
52+
circ = qasm3_to_braket_Circuit(qasm_code)
53+
3254
assert isinstance(circ, Circuit)
3355
for circ_instr, expected_operator in zip(circ.instructions, braket_operators):
3456
assert circ_instr.operator == expected_operator

0 commit comments

Comments
 (0)