Skip to content

Commit 329c20b

Browse files
feat: add callback support to VQA optimization functions
1 parent 8d17a29 commit 329c20b

File tree

1 file changed

+40
-5
lines changed

1 file changed

+40
-5
lines changed

mpqp/execution/vqa/vqa.py

+40-5
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,15 @@
2323
[OptimizableFunc, Optional[OptimizerInput], Optional[OptimizerOptions]],
2424
tuple[float, OptimizerInput],
2525
]
26+
OptimizerCallback = Optional[
27+
Union[
28+
Callable[[OptimizeResult], None],
29+
Callable[
30+
[Union[list[float], npt.NDArray[np.float32], tuple[float, ...]]], None
31+
],
32+
]
33+
]
34+
2635

2736
# TODO: all those functions with almost or exactly the same signature look like
2837
# a code smell to me.
@@ -47,6 +56,7 @@ def minimize(
4756
init_params: Optional[OptimizerInput] = None,
4857
nb_params: Optional[int] = None,
4958
optimizer_options: Optional[dict[str, Any]] = None,
59+
callback: Optional[OptimizerCallback] = None,
5060
) -> tuple[float, OptimizerInput]:
5161
"""This function runs an optimization on the parameters of the circuit, in order to
5262
minimize the measured expectation value of observables associated with the given circuit.
@@ -71,6 +81,7 @@ def minimize(
7181
optimizer_options: Options used to configure the VQA optimizer (maximum
7282
iterations, convergence threshold, etc...). These options are passed
7383
as is to the minimizer.
84+
callback: A callable called after each iteration.
7485
7586
Returns:
7687
The optimal value reached and the parameters corresponding to this value.
@@ -117,11 +128,25 @@ def minimize(
117128
if device is None:
118129
raise ValueError("A device is needed to optimize a circuit")
119130
optimizer = _minimize_remote if device.is_remote() else _minimize_local
120-
return optimizer(optimizable, method, device, init_params, nb_params)
131+
return optimizer(
132+
optimizable,
133+
method,
134+
device,
135+
init_params,
136+
nb_params,
137+
optimizer_options,
138+
callback,
139+
)
121140
else:
122141
# TODO: find a way to know if the job is remote or local from the function
123142
return _minimize_local(
124-
optimizable, method, device, init_params, nb_params, optimizer_options
143+
optimizable,
144+
method,
145+
device,
146+
init_params,
147+
nb_params,
148+
optimizer_options,
149+
callback,
125150
)
126151

127152

@@ -133,6 +158,7 @@ def _minimize_remote(
133158
init_params: Optional[OptimizerInput] = None,
134159
nb_params: Optional[int] = None,
135160
optimizer_options: Optional[dict[str, Any]] = None,
161+
callback: Optional[OptimizerCallback] = None,
136162
) -> tuple[float, OptimizerInput]:
137163
"""This function runs an optimization on the parameters of the circuit, to
138164
minimize the expectation value of the measure of the circuit by it's
@@ -158,6 +184,7 @@ def _minimize_remote(
158184
optimizer_options: Options used to configure the VQA optimizer (maximum
159185
iterations, convergence threshold, etc...). These options are passed
160186
as is to the minimizer.
187+
callback: A callable called after each iteration.
161188
162189
Returns:
163190
The optimal value reached and the parameters used to reach this value.
@@ -175,6 +202,7 @@ def _minimize_local(
175202
init_params: Optional[OptimizerInput] = None,
176203
nb_params: Optional[int] = None,
177204
optimizer_options: Optional[dict[str, Any]] = None,
205+
callback: Optional[OptimizerCallback] = None,
178206
) -> tuple[float, OptimizerInput]:
179207
"""This function runs an optimization on the parameters of the circuit, to
180208
minimize the expectation value of the measure of the circuit by it's
@@ -200,6 +228,7 @@ def _minimize_local(
200228
optimizer_options: Options used to configure the VQA optimizer (maximum
201229
iterations, convergence threshold, etc...). These options are passed
202230
as is to the minimizer.
231+
callback: A callable called after each iteration.
203232
204233
Returns:
205234
the optimal value reached and the parameters used to reach this value.
@@ -208,11 +237,11 @@ def _minimize_local(
208237
if device is None:
209238
raise ValueError("A device is needed to optimize a circuit")
210239
return _minimize_local_circ(
211-
optimizable, device, method, init_params, optimizer_options
240+
optimizable, device, method, init_params, optimizer_options, callback
212241
)
213242
else:
214243
return _minimize_local_func(
215-
optimizable, method, init_params, nb_params, optimizer_options
244+
optimizable, method, init_params, nb_params, optimizer_options, callback
216245
)
217246

218247

@@ -223,6 +252,7 @@ def _minimize_local_circ(
223252
method: Optimizer | OptimizerCallable,
224253
init_params: Optional[OptimizerInput] = None,
225254
optimizer_options: Optional[dict[str, Any]] = None,
255+
callback: Optional[OptimizerCallback] = None,
226256
) -> tuple[float, OptimizerInput]:
227257
"""This function runs an optimization on the parameters of the circuit, to
228258
minimize the expectation value of the measure of the circuit by it's
@@ -244,6 +274,7 @@ def _minimize_local_circ(
244274
optimizer_options: Options used to configure the VQA optimizer (maximum
245275
iterations, convergence threshold, etc...). These options are passed
246276
as is to the minimizer.
277+
callback: A callable called after each iteration.
247278
248279
Returns:
249280
The optimal value reached and the parameters used to reach this value.
@@ -264,7 +295,7 @@ def eval_circ(params: OptimizerInput):
264295
).expectation_value
265296

266297
return _minimize_local_func(
267-
eval_circ, method, init_params, len(variables), optimizer_options
298+
eval_circ, method, init_params, len(variables), optimizer_options, callback
268299
)
269300

270301

@@ -275,6 +306,7 @@ def _minimize_local_func(
275306
init_params: Optional[OptimizerInput] = None,
276307
nb_params: Optional[int] = None,
277308
optimizer_options: Optional[OptimizerOptions] = None,
309+
callback: Optional[OptimizerCallback] = None,
278310
) -> tuple[float, OptimizerInput]:
279311
"""This function runs an optimization on the parameters of the circuit, to
280312
minimize the expectation value of the measure of the circuit by it's
@@ -298,6 +330,8 @@ def _minimize_local_func(
298330
optimizer_options: Options used to configure the VQA optimizer (maximum
299331
iterations, convergence threshold, etc...). These options are passed
300332
as is to the minimizer.
333+
callback: A callable called after each iteration.
334+
301335
302336
Returns:
303337
The optimal value reached and the parameters used to reach this value.
@@ -317,6 +351,7 @@ def _minimize_local_func(
317351
x0=np.array(init_params),
318352
method=method.name.lower(),
319353
options=optimizer_options,
354+
callback=callback,
320355
)
321356
return res.fun, res.x
322357
else:

0 commit comments

Comments
 (0)