1
+ from typing import Any , Callable
2
+
3
+ import numpy as np
1
4
import pytest
5
+ from braket .devices import LocalSimulator
6
+
2
7
from mpqp import QCircuit
3
- from mpqp .core .instruction .measurement import Observable , ExpectationMeasure
4
- from mpqp .execution .devices import AWSDevice , ATOSDevice , IBMDevice
8
+ from mpqp .core .instruction .measurement import ExpectationMeasure , Observable
9
+ from mpqp .execution import run
10
+ from mpqp .execution .devices import ATOSDevice , AvailableDevice , AWSDevice , IBMDevice
5
11
from mpqp .gates import *
6
12
from mpqp .measures import BasisMeasure
7
- from mpqp .execution import run
8
- from braket .devices import LocalSimulator
9
- import numpy as np
10
13
from mpqp .qasm .qasm_to_braket import qasm3_to_braket_Circuit
14
+ from mpqp .tools .errors import UnsupportedBraketFeaturesWarning
15
+
16
+
17
+ def warn_guard (device : AvailableDevice , run : Callable [[], Any ]):
18
+ if isinstance (device , AWSDevice ):
19
+ with pytest .warns (UnsupportedBraketFeaturesWarning ):
20
+ return run ()
21
+ else :
22
+ return run ()
11
23
12
24
13
25
def test_sample_demo ():
@@ -34,7 +46,7 @@ def test_sample_demo():
34
46
circuit .add (BasisMeasure ([0 , 1 , 2 , 3 ], shots = 2000 ))
35
47
36
48
# Run the circuit on a selected device
37
- run (
49
+ runner = lambda : run (
38
50
circuit ,
39
51
[
40
52
IBMDevice .AER_SIMULATOR ,
@@ -44,6 +56,8 @@ def test_sample_demo():
44
56
],
45
57
)
46
58
59
+ warn_guard (AWSDevice .BRAKET_LOCAL_SIMULATOR , runner )
60
+
47
61
assert True
48
62
49
63
@@ -68,7 +82,7 @@ def test_statevector_demo():
68
82
circuit .add (Rz (3.14 , 0 ))
69
83
70
84
# when no measure in the circuit, must run in statevector mode
71
- run (
85
+ runner = lambda : run (
72
86
circuit ,
73
87
[
74
88
IBMDevice .AER_SIMULATOR_STATEVECTOR ,
@@ -78,11 +92,13 @@ def test_statevector_demo():
78
92
],
79
93
)
80
94
95
+ warn_guard (AWSDevice .BRAKET_LOCAL_SIMULATOR , runner )
96
+
81
97
# same when we add a BasisMeasure with 0 shots
82
98
circuit .add (BasisMeasure ([0 , 1 , 2 , 3 ], shots = 0 ))
83
99
84
100
# Run the circuit on a selected device
85
- run (
101
+ runner = lambda : run (
86
102
circuit ,
87
103
[
88
104
IBMDevice .AER_SIMULATOR_STATEVECTOR ,
@@ -92,6 +108,8 @@ def test_statevector_demo():
92
108
],
93
109
)
94
110
111
+ warn_guard (AWSDevice .BRAKET_LOCAL_SIMULATOR , runner )
112
+
95
113
assert True
96
114
97
115
@@ -121,7 +139,7 @@ def test_observable_demo():
121
139
assert True
122
140
123
141
124
- def test_aws_executions ():
142
+ def test_aws_qasm_executions ():
125
143
device = LocalSimulator ()
126
144
127
145
qasm_str = """OPENQASM 3.0;
@@ -133,11 +151,12 @@ def test_aws_executions():
133
151
c[0] = measure q[0];
134
152
c[1] = measure q[1];"""
135
153
136
- circuit = qasm3_to_braket_Circuit (qasm_str )
137
-
154
+ runner = lambda : qasm3_to_braket_Circuit (qasm_str )
155
+ circuit = warn_guard ( AWSDevice . BRAKET_LOCAL_SIMULATOR , runner )
138
156
device .run (circuit , shots = 100 ).result ()
139
157
140
- #####################################################
158
+
159
+ def test_aws_mpqp_executions ():
141
160
142
161
# Declaration of the circuit with the right size
143
162
circuit = QCircuit (4 )
@@ -161,7 +180,9 @@ def test_aws_executions():
161
180
# Add measurement
162
181
circuit .add (BasisMeasure ([0 , 1 , 2 , 3 ], shots = 2000 ))
163
182
164
- run (circuit , AWSDevice .BRAKET_LOCAL_SIMULATOR )
183
+ runner = lambda : run (circuit , AWSDevice .BRAKET_LOCAL_SIMULATOR )
184
+
185
+ warn_guard (AWSDevice .BRAKET_LOCAL_SIMULATOR , runner )
165
186
166
187
#####################################################
167
188
@@ -185,7 +206,10 @@ def test_aws_executions():
185
206
circuit .add (ExpectationMeasure ([0 , 1 ], observable = obs , shots = 0 ))
186
207
187
208
# Running the computation on myQLM and on Braket simulator, then retrieving the results
188
- run (circuit , [AWSDevice .BRAKET_LOCAL_SIMULATOR , ATOSDevice .MYQLM_PYLINALG ])
209
+ runner = lambda : run (
210
+ circuit , [AWSDevice .BRAKET_LOCAL_SIMULATOR , ATOSDevice .MYQLM_PYLINALG ]
211
+ )
212
+ warn_guard (AWSDevice .BRAKET_LOCAL_SIMULATOR , runner )
189
213
190
214
#####################################################
191
215
@@ -195,7 +219,10 @@ def test_aws_executions():
195
219
)
196
220
197
221
# Running the computation on myQLM and on Aer simulator, then retrieving the results
198
- run (circuit , [AWSDevice .BRAKET_LOCAL_SIMULATOR , ATOSDevice .MYQLM_PYLINALG ])
222
+ runner = lambda : run (
223
+ circuit , [AWSDevice .BRAKET_LOCAL_SIMULATOR , ATOSDevice .MYQLM_PYLINALG ]
224
+ )
225
+ warn_guard (AWSDevice .BRAKET_LOCAL_SIMULATOR , runner )
199
226
200
227
201
228
def test_all_native_gates ():
0 commit comments