From 59f71d26bec64d4f89d0f70a757b0c272f97a809 Mon Sep 17 00:00:00 2001 From: Jean-Christophe Jaskula Date: Fri, 8 Sep 2023 11:29:18 -0400 Subject: [PATCH 01/43] add duration type to FreeParameterExpression --- setup.py | 2 +- src/braket/pulse/ast/approximation_parser.py | 2 +- src/braket/pulse/ast/free_parameters.py | 41 +++++++------- src/braket/pulse/pulse_sequence.py | 23 ++++---- src/braket/pulse/waveforms.py | 3 +- .../braket/circuits/test_circuit.py | 34 ++++++------ .../braket/circuits/test_gate_calibration.py | 6 +-- test/unit_tests/braket/circuits/test_gates.py | 2 +- .../braket/pulse/test_pulse_sequence.py | 54 ++++++++----------- .../unit_tests/braket/pulse/test_waveforms.py | 33 +++++------- 10 files changed, 90 insertions(+), 110 deletions(-) diff --git a/setup.py b/setup.py index 25fca8252..cb4bb778e 100644 --- a/setup.py +++ b/setup.py @@ -29,7 +29,7 @@ install_requires=[ "amazon-braket-schemas>=1.19.1", "amazon-braket-default-simulator>=1.19.1", - "oqpy~=0.2.1", + "oqpy~=0.3.1", "setuptools", "backoff", "boltons", diff --git a/src/braket/pulse/ast/approximation_parser.py b/src/braket/pulse/ast/approximation_parser.py index 69e9fafa5..d8096bf83 100644 --- a/src/braket/pulse/ast/approximation_parser.py +++ b/src/braket/pulse/ast/approximation_parser.py @@ -56,7 +56,7 @@ def __init__(self, program: Program, frames: Dict[str, Frame]): self.amplitudes = defaultdict(TimeSeries) self.frequencies = defaultdict(TimeSeries) self.phases = defaultdict(TimeSeries) - context = _ParseState(variables=dict(), frame_data=_init_frame_data(frames)) + context = _ParseState(variables={"pi": np.pi}, frame_data=_init_frame_data(frames)) self._qubit_frames_mapping: Dict[str, List[str]] = _init_qubit_frame_mapping(frames) self.visit(program.to_ast(include_externs=False), context) diff --git a/src/braket/pulse/ast/free_parameters.py b/src/braket/pulse/ast/free_parameters.py index 6cfc36d03..6bcd719f6 100644 --- a/src/braket/pulse/ast/free_parameters.py +++ b/src/braket/pulse/ast/free_parameters.py @@ -13,8 +13,9 @@ from typing import Dict, Union from openpulse import ast -from openqasm3.ast import DurationLiteral from openqasm3.visitor import QASMTransformer +from oqpy.program import Program +from oqpy.timing import OQDurationLiteral from braket.parametric.free_parameter_expression import FreeParameterExpression @@ -22,48 +23,46 @@ class _FreeParameterExpressionIdentifier(ast.Identifier): """Dummy AST node with FreeParameterExpression instance attached""" - def __init__(self, expression: FreeParameterExpression): + def __init__( + self, expression: FreeParameterExpression, type: ast.ClassicalType = ast.FloatType() + ): super().__init__(name=f"FreeParameterExpression({expression})") self._expression = expression + self.type = type @property def expression(self) -> FreeParameterExpression: return self._expression + def to_ast(self) -> ast.Identifier: + return self + class _FreeParameterTransformer(QASMTransformer): """Walk the AST and evaluate FreeParameterExpressions.""" - def __init__(self, param_values: Dict[str, float]): + def __init__(self, param_values: Dict[str, float], program: Program): self.param_values = param_values + self.program = program super().__init__() def visit__FreeParameterExpressionIdentifier( - self, identifier: ast.Identifier + self, identifier: _FreeParameterExpressionIdentifier ) -> Union[_FreeParameterExpressionIdentifier, ast.FloatLiteral]: """Visit a FreeParameterExpressionIdentifier. Args: - identifier (Identifier): The identifier. + identifier (_FreeParameterExpressionIdentifier): The identifier. Returns: Union[_FreeParameterExpressionIdentifier, FloatLiteral]: The transformed expression. """ new_value = identifier.expression.subs(self.param_values) if isinstance(new_value, FreeParameterExpression): - return _FreeParameterExpressionIdentifier(new_value) + return _FreeParameterExpressionIdentifier(new_value, identifier.type) else: - return ast.FloatLiteral(new_value) - - def visit_DurationLiteral(self, duration_literal: DurationLiteral) -> DurationLiteral: - """Visit Duration Literal. - node.value, node.unit (node.unit.name, node.unit.value) - 1 - Args: - duration_literal (DurationLiteral): The duration literal. - Returns: - DurationLiteral: The transformed duration literal. - """ - duration = duration_literal.value - if not isinstance(duration, FreeParameterExpression): - return duration_literal - return DurationLiteral(duration.subs(self.param_values), duration_literal.unit) + if isinstance(identifier.type, ast.FloatType): + return ast.FloatLiteral(new_value) + elif isinstance(identifier.type, ast.DurationType): + return OQDurationLiteral(new_value).to_ast(self.program) + else: + raise NotImplementedError(f"{identifier.type} is not a supported type.") diff --git a/src/braket/pulse/pulse_sequence.py b/src/braket/pulse/pulse_sequence.py index 3606bbd11..08d792681 100644 --- a/src/braket/pulse/pulse_sequence.py +++ b/src/braket/pulse/pulse_sequence.py @@ -46,7 +46,7 @@ class PulseSequence: def __init__(self): self._capture_v0_count = 0 - self._program = Program() + self._program = Program(simplify_constants=False) self._frames = {} self._waveforms = {} self._free_parameters = set() @@ -183,10 +183,7 @@ def delay( Returns: PulseSequence: self, with the instruction added. """ - if isinstance(duration, FreeParameterExpression): - for p in duration.expression.free_symbols: - self._free_parameters.add(FreeParameter(p.name)) - duration = OQDurationLiteral(duration) + duration = self._format_parameter_ast(duration, type_=ast.DurationType()) if not isinstance(qubits_or_frames, QubitSet): if not isinstance(qubits_or_frames, list): qubits_or_frames = [qubits_or_frames] @@ -276,9 +273,9 @@ def make_bound_pulse_sequence(self, param_values: Dict[str, float]) -> PulseSequ """ program = deepcopy(self._program) tree: ast.Program = program.to_ast(include_externs=False, ignore_needs_declaration=True) - new_tree: ast.Program = _FreeParameterTransformer(param_values).visit(tree) + new_tree: ast.Program = _FreeParameterTransformer(param_values, program).visit(tree) - new_program = Program() + new_program = Program(simplify_constants=False) new_program.declared_vars = program.declared_vars new_program.undeclared_vars = program.undeclared_vars for x in new_tree.statements: @@ -325,13 +322,19 @@ def to_ir(self) -> str: return ast_to_qasm(tree) def _format_parameter_ast( - self, parameter: Union[float, FreeParameterExpression] + self, + parameter: Union[float, FreeParameterExpression], + type_: ast.ClassicalType = ast.FloatType(), ) -> Union[float, _FreeParameterExpressionIdentifier]: if isinstance(parameter, FreeParameterExpression): for p in parameter.expression.free_symbols: self._free_parameters.add(FreeParameter(p.name)) - return _FreeParameterExpressionIdentifier(parameter) - return parameter + return _FreeParameterExpressionIdentifier(parameter, type_) + else: + if isinstance(type_, ast.FloatType): + return parameter + elif isinstance(type_, ast.DurationType): + return OQDurationLiteral(parameter) def _parse_arg_from_calibration_schema( self, argument: Dict, waveforms: Dict[Waveform], frames: Dict[Frame] diff --git a/src/braket/pulse/waveforms.py b/src/braket/pulse/waveforms.py index f298dd3e5..48da8904b 100644 --- a/src/braket/pulse/waveforms.py +++ b/src/braket/pulse/waveforms.py @@ -21,7 +21,6 @@ import numpy as np from oqpy import WaveformVar, bool_, complex128, declare_waveform_generator, duration, float64 from oqpy.base import OQPyExpression -from oqpy.timing import OQDurationLiteral from braket.parametric.free_parameter import FreeParameter from braket.parametric.free_parameter_expression import ( @@ -457,7 +456,7 @@ def _map_to_oqpy_type( ) -> Union[_FreeParameterExpressionIdentifier, OQPyExpression]: if isinstance(parameter, FreeParameterExpression): return ( - OQDurationLiteral(parameter) + _FreeParameterExpressionIdentifier(parameter, duration) if is_duration_type else _FreeParameterExpressionIdentifier(parameter) ) diff --git a/test/unit_tests/braket/circuits/test_circuit.py b/test/unit_tests/braket/circuits/test_circuit.py index 21f972b89..c49b83bfc 100644 --- a/test/unit_tests/braket/circuits/test_circuit.py +++ b/test/unit_tests/braket/circuits/test_circuit.py @@ -740,7 +740,7 @@ def test_ir_non_empty_instructions_result_types_basis_rotation_instructions(): "qubit[2] __qubits__;", "cal {", " waveform drag_gauss_wf = drag_gaussian" - + "(3000000.0ns, 400000000.0ns, 0.2, 1, false);", + + "(3.0ms, 400.0ms, 0.2, 1, false);", "}", "defcal z $0, $1 {", " set_frequency(predefined_frame_1, 6000000.0);", @@ -769,7 +769,7 @@ def test_ir_non_empty_instructions_result_types_basis_rotation_instructions(): "bit[2] __bits__;", "cal {", " waveform drag_gauss_wf = drag_gaussian" - + "(3000000.0ns, 400000000.0ns, 0.2, 1, false);", + + "(3.0ms, 400.0ms, 0.2, 1, false);", "}", "defcal z $0, $1 {", " set_frequency(predefined_frame_1, 6000000.0);", @@ -800,7 +800,7 @@ def test_ir_non_empty_instructions_result_types_basis_rotation_instructions(): "OPENQASM 3.0;", "cal {", " waveform drag_gauss_wf = drag_gaussian" - + "(3000000.0ns, 400000000.0ns, 0.2, 1, false);", + + "(3.0ms, 400.0ms, 0.2, 1, false);", "}", "defcal z $0, $1 {", " set_frequency(predefined_frame_1, 6000000.0);", @@ -835,7 +835,7 @@ def test_ir_non_empty_instructions_result_types_basis_rotation_instructions(): "qubit[5] __qubits__;", "cal {", " waveform drag_gauss_wf = drag_gaussian" - + "(3000000.0ns, 400000000.0ns, 0.2, 1, false);", + + "(3.0ms, 400.0ms, 0.2, 1, false);", "}", "defcal z $0, $1 {", " set_frequency(predefined_frame_1, 6000000.0);", @@ -866,7 +866,7 @@ def test_ir_non_empty_instructions_result_types_basis_rotation_instructions(): "qubit[2] __qubits__;", "cal {", " waveform drag_gauss_wf = drag_gaussian" - + "(3000000.0ns, 400000000.0ns, 0.2, 1, false);", + + "(3.0ms, 400.0ms, 0.2, 1, false);", "}", "defcal z $0, $1 {", " set_frequency(predefined_frame_1, 6000000.0);", @@ -899,7 +899,7 @@ def test_ir_non_empty_instructions_result_types_basis_rotation_instructions(): "qubit[5] __qubits__;", "cal {", " waveform drag_gauss_wf = drag_gaussian" - + "(3000000.0ns, 400000000.0ns, 0.2, 1, false);", + + "(3.0ms, 400.0ms, 0.2, 1, false);", "}", "defcal z $0, $1 {", " set_frequency(predefined_frame_1, 6000000.0);", @@ -933,7 +933,7 @@ def test_ir_non_empty_instructions_result_types_basis_rotation_instructions(): "qubit[7] __qubits__;", "cal {", " waveform drag_gauss_wf = drag_gaussian" - + "(3000000.0ns, 400000000.0ns, 0.2, 1, false);", + + "(3.0ms, 400.0ms, 0.2, 1, false);", "}", "defcal z $0, $1 {", " set_frequency(predefined_frame_1, 6000000.0);", @@ -965,7 +965,7 @@ def test_ir_non_empty_instructions_result_types_basis_rotation_instructions(): "qubit[2] __qubits__;", "cal {", " waveform drag_gauss_wf = drag_gaussian" - + "(3000000.0ns, 400000000.0ns, 0.2, 1, false);", + + "(3.0ms, 400.0ms, 0.2, 1, false);", "}", "defcal z $0, $1 {", " set_frequency(predefined_frame_1, 6000000.0);", @@ -1033,8 +1033,7 @@ def test_parametric_circuit_with_fixed_argument_defcal(pulse_sequence): "bit[1] __bits__;", "qubit[1] __qubits__;", "cal {", - " waveform drag_gauss_wf = drag_gaussian" - + "(3000000.0ns, 400000000.0ns, 0.2, 1, false);", + " waveform drag_gauss_wf = drag_gaussian" + "(3.0ms, 400.0ms, 0.2, 1, false);", "}", "defcal z $0, $1 {", " set_frequency(predefined_frame_1, 6000000.0);", @@ -1131,8 +1130,7 @@ def foo( "bit[1] __bits__;", "qubit[1] __qubits__;", "cal {", - " waveform drag_gauss_wf = drag_gaussian" - + "(3000000.0ns, 400000000.0ns, 0.2, 1, false);", + " waveform drag_gauss_wf = drag_gaussian" + "(3.0ms, 400.0ms, 0.2, 1, false);", "}", "defcal foo(-0.2) $0 {", " shift_phase(predefined_frame_1, -0.1);", @@ -3360,11 +3358,9 @@ def test_pulse_circuit_to_openqasm(predefined_frame_1, user_defined_frame): "bit[2] __bits__;", "cal {", " frame user_defined_frame_0 = newframe(device_port_x0, 10000000.0, 3.14);", - " waveform gauss_wf = gaussian(1000000.0ns, 700000000.0ns, 1, false);", - " waveform drag_gauss_wf = drag_gaussian(3000000.0ns, 400000000.0ns, 0.2, 1," - " false);", - " waveform drag_gauss_wf_2 = drag_gaussian(3000000.0ns, 400000000.0ns, " - "0.2, 1, false);", + " waveform gauss_wf = gaussian(1.0ms, 700.0ms, 1, false);", + " waveform drag_gauss_wf = drag_gaussian(3.0ms, 400.0ms, 0.2, 1," " false);", + " waveform drag_gauss_wf_2 = drag_gaussian(3.0ms, 400.0ms, " "0.2, 1, false);", "}", "h $0;", "cal {", @@ -3477,7 +3473,7 @@ def test_parametrized_pulse_circuit(user_defined_frame): "bit[2] __bits__;", "cal {", " frame user_defined_frame_0 = newframe(device_port_x0, 10000000.0, 3.14);", - " waveform gauss_wf = gaussian(10000.0ns, 700000000.0ns, 1, false);", + " waveform gauss_wf = gaussian(10.0us, 700.0ms, 1, false);", "}", "rx(0.5) $0;", "cal {", @@ -3502,7 +3498,7 @@ def test_parametrized_pulse_circuit(user_defined_frame): "bit[2] __bits__;", "cal {", " frame user_defined_frame_0 = newframe(device_port_x0, 10000000.0, 3.14);", - " waveform gauss_wf = gaussian(10000.0ns, 700000000.0ns, 1, false);", + " waveform gauss_wf = gaussian(10.0us, 700.0ms, 1, false);", "}", "rx(0.5) $0;", "cal {", diff --git a/test/unit_tests/braket/circuits/test_gate_calibration.py b/test/unit_tests/braket/circuits/test_gate_calibration.py index 31c2384db..4037f3faf 100644 --- a/test/unit_tests/braket/circuits/test_gate_calibration.py +++ b/test/unit_tests/braket/circuits/test_gate_calibration.py @@ -80,7 +80,7 @@ def test_to_ir(pulse_sequence): "OPENQASM 3.0;", "defcal rx(1.0) $0, $1 {", " barrier test_frame_rf;", - " delay[1000000000000.0ns] test_frame_rf;", + " delay[1000s] test_frame_rf;", "}", ] ) @@ -100,7 +100,7 @@ def test_to_ir_with_bad_key(pulse_sequence): "OPENQASM 3.0;", "defcal z $0, $1 {", " barrier test_frame_rf;", - " delay[1000000000000.0ns] test_frame_rf;", + " delay[1000s] test_frame_rf;", "}", ] ) @@ -118,7 +118,7 @@ def test_to_ir_with_key(pulse_sequence): "OPENQASM 3.0;", "defcal z $0, $1 {", " barrier test_frame_rf;", - " delay[1000000000000.0ns] test_frame_rf;", + " delay[1000s] test_frame_rf;", "}", ] ) diff --git a/test/unit_tests/braket/circuits/test_gates.py b/test/unit_tests/braket/circuits/test_gates.py index 9d5dd5580..51e73b83e 100644 --- a/test/unit_tests/braket/circuits/test_gates.py +++ b/test/unit_tests/braket/circuits/test_gates.py @@ -962,7 +962,7 @@ def to_ir(pulse_gate): [ "cal {", " set_frequency(user_frame, b + 3);", - " delay[(1000000000.0*c)ns] user_frame;", + " delay[c] user_frame;", "}", ] ) diff --git a/test/unit_tests/braket/pulse/test_pulse_sequence.py b/test/unit_tests/braket/pulse/test_pulse_sequence.py index 57ba20fbd..411074727 100644 --- a/test/unit_tests/braket/pulse/test_pulse_sequence.py +++ b/test/unit_tests/braket/pulse/test_pulse_sequence.py @@ -125,11 +125,9 @@ def test_pulse_sequence_make_bound_pulse_sequence(predefined_frame_1, predefined [ "OPENQASM 3.0;", "cal {", - " waveform gauss_wf = gaussian((1000000000.0*length_g)ns, (1000000000.0*sigma_g)ns, " - "1, false);", - " waveform drag_gauss_wf = " - "drag_gaussian((1000000000.0*length_dg)ns, (1000000000.0*sigma_dg)ns, 0.2, 1, false);", - " waveform constant_wf = constant((1000000000.0*length_c)ns, 2.0 + 0.3im);", + " waveform gauss_wf = gaussian(length_g, sigma_g, 1, false);", + " waveform drag_gauss_wf = drag_gaussian(length_dg, sigma_dg, 0.2, 1, false);", + " waveform constant_wf = constant(length_c, 2.0 + 0.3im);", " waveform arb_wf = {1.0 + 0.4im, 0, 0.3, 0.1 + 0.2im};", " bit[2] psb;", " set_frequency(predefined_frame_1, a + 2*b);", @@ -138,12 +136,9 @@ def test_pulse_sequence_make_bound_pulse_sequence(predefined_frame_1, predefined " shift_phase(predefined_frame_1, a + 2*b);", " set_scale(predefined_frame_1, a + 2*b);", " psb[0] = capture_v0(predefined_frame_1);", - ( - " delay[(1000000000.0*a + 2000000000.0*b)ns]" - " predefined_frame_1, predefined_frame_2;" - ), - " delay[(1000000000.0*a + 2000000000.0*b)ns] predefined_frame_1;", - " delay[1000000.0ns] predefined_frame_1;", + " delay[a + 2*b] predefined_frame_1, predefined_frame_2;", + " delay[a + 2*b] predefined_frame_1;", + " delay[1.0ms] predefined_frame_1;", " barrier predefined_frame_1, predefined_frame_2;", " play(predefined_frame_1, gauss_wf);", " play(predefined_frame_2, drag_gauss_wf);", @@ -173,10 +168,9 @@ def test_pulse_sequence_make_bound_pulse_sequence(predefined_frame_1, predefined [ "OPENQASM 3.0;", "cal {", - " waveform gauss_wf = gaussian(1000000.0ns, (1000000000.0*sigma_g)ns, 1, false);", - " waveform drag_gauss_wf = drag_gaussian(3000000.0ns, 400000000.0ns, 0.2, 1," - " false);", - " waveform constant_wf = constant(4000000.0ns, 2.0 + 0.3im);", + " waveform gauss_wf = gaussian(1.0ms, sigma_g, 1, false);", + " waveform drag_gauss_wf = drag_gaussian(3.0ms, 400.0ms, 0.2, 1, false);", + " waveform constant_wf = constant(4.0ms, 2.0 + 0.3im);", " waveform arb_wf = {1.0 + 0.4im, 0, 0.3, 0.1 + 0.2im};", " bit[2] psb;", " set_frequency(predefined_frame_1, a + 4);", @@ -185,9 +179,9 @@ def test_pulse_sequence_make_bound_pulse_sequence(predefined_frame_1, predefined " shift_phase(predefined_frame_1, a + 4);", " set_scale(predefined_frame_1, a + 4);", " psb[0] = capture_v0(predefined_frame_1);", - " delay[(1000000000.0*a + 4000000000.0)ns] predefined_frame_1, predefined_frame_2;", - " delay[(1000000000.0*a + 4000000000.0)ns] predefined_frame_1;", - " delay[1000000.0ns] predefined_frame_1;", + " delay[a + 4] predefined_frame_1, predefined_frame_2;", + " delay[a + 4] predefined_frame_1;", + " delay[1.0ms] predefined_frame_1;", " barrier predefined_frame_1, predefined_frame_2;", " play(predefined_frame_1, gauss_wf);", " play(predefined_frame_2, drag_gauss_wf);", @@ -206,10 +200,9 @@ def test_pulse_sequence_make_bound_pulse_sequence(predefined_frame_1, predefined [ "OPENQASM 3.0;", "cal {", - " waveform gauss_wf = gaussian(1000000.0ns, 700000000.0ns, 1, false);", - " waveform drag_gauss_wf = drag_gaussian(3000000.0ns, 400000000.0ns, 0.2, 1," - " false);", - " waveform constant_wf = constant(4000000.0ns, 2.0 + 0.3im);", + " waveform gauss_wf = gaussian(1.0ms, 700.0ms, 1, false);", + " waveform drag_gauss_wf = drag_gaussian(3.0ms, 400.0ms, 0.2, 1, false);", + " waveform constant_wf = constant(4.0ms, 2.0 + 0.3im);", " waveform arb_wf = {1.0 + 0.4im, 0, 0.3, 0.1 + 0.2im};", " bit[2] psb;", " set_frequency(predefined_frame_1, 5);", @@ -218,9 +211,9 @@ def test_pulse_sequence_make_bound_pulse_sequence(predefined_frame_1, predefined " shift_phase(predefined_frame_1, 5);", " set_scale(predefined_frame_1, 5);", " psb[0] = capture_v0(predefined_frame_1);", - " delay[5000000000.00000ns] predefined_frame_1, predefined_frame_2;", - " delay[5000000000.00000ns] predefined_frame_1;", - " delay[1000000.0ns] predefined_frame_1;", + " delay[5s] predefined_frame_1, predefined_frame_2;", + " delay[5s] predefined_frame_1;", + " delay[1.0ms] predefined_frame_1;", " barrier predefined_frame_1, predefined_frame_2;", " play(predefined_frame_1, gauss_wf);", " play(predefined_frame_2, drag_gauss_wf);", @@ -311,10 +304,9 @@ def test_pulse_sequence_to_ir(predefined_frame_1, predefined_frame_2): [ "OPENQASM 3.0;", "cal {", - " waveform gauss_wf = gaussian(1000000.0ns, 700000000.0ns, 1, false);", - " waveform drag_gauss_wf = drag_gaussian(3000000.0ns, 400000000.0ns, 0.2, 1," - " false);", - " waveform constant_wf = constant(4000000.0ns, 2.0 + 0.3im);", + " waveform gauss_wf = gaussian(1.0ms, 700.0ms, 1, false);", + " waveform drag_gauss_wf = drag_gaussian(3.0ms, 400.0ms, 0.2, 1, false);", + " waveform constant_wf = constant(4.0ms, 2.0 + 0.3im);", " waveform arb_wf = {1.0 + 0.4im, 0, 0.3, 0.1 + 0.2im};", " bit[2] psb;", " set_frequency(predefined_frame_1, 3000000000.0);", @@ -324,8 +316,8 @@ def test_pulse_sequence_to_ir(predefined_frame_1, predefined_frame_2): " set_scale(predefined_frame_1, 0.25);", " psb[0] = capture_v0(predefined_frame_1);", " delay[2.0ns] predefined_frame_1, predefined_frame_2;", - " delay[1000.0ns] predefined_frame_1;", - " delay[1000000.0ns] $0;", + " delay[1.0us] predefined_frame_1;", + " delay[1.0ms] $0;", " barrier $0, $1;", " barrier predefined_frame_1, predefined_frame_2;", " play(predefined_frame_1, gauss_wf);", diff --git a/test/unit_tests/braket/pulse/test_waveforms.py b/test/unit_tests/braket/pulse/test_waveforms.py index b42eacc0b..86f8253c1 100644 --- a/test/unit_tests/braket/pulse/test_waveforms.py +++ b/test/unit_tests/braket/pulse/test_waveforms.py @@ -74,7 +74,7 @@ def test_constant_waveform(): assert wf.iq == iq assert wf.id == id - _assert_wf_qasm(wf, "waveform const_wf_x = constant(4000000.0ns, 4);") + _assert_wf_qasm(wf, "waveform const_wf_x = constant(4.0ms, 4);") def test_constant_waveform_default_params(): @@ -101,14 +101,13 @@ def test_constant_wf_free_params(): assert wf.parameters == [FreeParameter("length_v") + FreeParameter("length_w")] _assert_wf_qasm( wf, - "waveform const_wf = " - "constant((1000000000.0*length_v + 1000000000.0*length_w)ns, 2.0 - 3.0im);", + "waveform const_wf = " "constant(length_v + length_w, 2.0 - 3.0im);", ) wf_2 = wf.bind_values(length_v=2e-6, length_w=4e-6) assert len(wf_2.parameters) == 1 assert math.isclose(wf_2.parameters[0], 6e-6) - _assert_wf_qasm(wf_2, "waveform const_wf = constant(6000.0ns, 2.0 - 3.0im);") + _assert_wf_qasm(wf_2, "waveform const_wf = constant(6.0us, 2.0 - 3.0im);") def test_drag_gaussian_waveform(): @@ -126,9 +125,7 @@ def test_drag_gaussian_waveform(): assert wf.sigma == sigma assert wf.length == length - _assert_wf_qasm( - wf, "waveform drag_gauss_wf = drag_gaussian(4.0ns, 300000000.0ns, 0.6, 0.4, false);" - ) + _assert_wf_qasm(wf, "waveform drag_gauss_wf = drag_gaussian(4.0ns, 300.0ms, 0.6, 0.4, false);") def test_drag_gaussian_waveform_default_params(): @@ -167,7 +164,7 @@ def test_gaussian_waveform(): assert wf.sigma == sigma assert wf.length == length - _assert_wf_qasm(wf, "waveform gauss_wf = gaussian(4.0ns, 300000000.0ns, 0.4, false);") + _assert_wf_qasm(wf, "waveform gauss_wf = gaussian(4.0ns, 300.0ms, 0.4, false);") def test_drag_gaussian_wf_free_params(): @@ -187,8 +184,8 @@ def test_drag_gaussian_wf_free_params(): _assert_wf_qasm( wf, "waveform d_gauss_wf = " - "drag_gaussian((1000000000.0*length_v)ns, (1000000000.0*sigma_a + " - "1000000000.0*sigma_b)ns, beta_y, amp_z, false);", + "drag_gaussian(length_v, sigma_a + " + "sigma_b, beta_y, amp_z, false);", ) wf_2 = wf.bind_values(length_v=0.6, sigma_a=0.4) @@ -200,15 +197,12 @@ def test_drag_gaussian_wf_free_params(): ] _assert_wf_qasm( wf_2, - "waveform d_gauss_wf = drag_gaussian(600000000.0ns, (1000000000.0*sigma_b " - "+ 400000000.0)ns, beta_y, amp_z, false);", + "waveform d_gauss_wf = drag_gaussian(600.0ms, sigma_b + 0.4, beta_y, amp_z, false);", ) wf_3 = wf.bind_values(length_v=0.6, sigma_a=0.3, sigma_b=0.1, beta_y=0.2, amp_z=0.1) assert wf_3.parameters == [0.6, 0.4, 0.2, 0.1] - _assert_wf_qasm( - wf_3, "waveform d_gauss_wf = drag_gaussian(600000000.0ns, 400000000.0ns, 0.2, 0.1, false);" - ) + _assert_wf_qasm(wf_3, "waveform d_gauss_wf = drag_gaussian(600.0ms, 400.0ms, 0.2, 0.1, false);") def test_gaussian_waveform_default_params(): @@ -243,19 +237,16 @@ def test_gaussian_wf_free_params(): ] _assert_wf_qasm( wf, - "waveform gauss_wf = gaussian((1000000000.0*length_v)ns, (1000000000.0*sigma_x)ns, " - "amp_z, false);", + "waveform gauss_wf = gaussian(length_v, sigma_x, " "amp_z, false);", ) wf_2 = wf.bind_values(length_v=0.6, sigma_x=0.4) assert wf_2.parameters == [0.6, 0.4, FreeParameter("amp_z")] - _assert_wf_qasm( - wf_2, "waveform gauss_wf = gaussian(600000000.0ns, 400000000.0ns, amp_z, false);" - ) + _assert_wf_qasm(wf_2, "waveform gauss_wf = gaussian(600.0ms, 400.0ms, amp_z, false);") wf_3 = wf.bind_values(length_v=0.6, sigma_x=0.3, amp_z=0.1) assert wf_3.parameters == [0.6, 0.3, 0.1] - _assert_wf_qasm(wf_3, "waveform gauss_wf = gaussian(600000000.0ns, 300000000.0ns, 0.1, false);") + _assert_wf_qasm(wf_3, "waveform gauss_wf = gaussian(600.0ms, 300.0ms, 0.1, false);") def _assert_wf_qasm(waveform, expected_qasm): From ed3adcec2991eae8a49a9d1a633cbc0ece91180b Mon Sep 17 00:00:00 2001 From: Jean-Christophe Jaskula Date: Fri, 8 Sep 2023 18:45:17 -0400 Subject: [PATCH 02/43] consider FreeParameter as float --- src/braket/pulse/ast/free_parameters.py | 44 +++++++++++++------ src/braket/pulse/ast/qasm_parser.py | 5 ++- src/braket/pulse/pulse_sequence.py | 5 +-- test/unit_tests/braket/circuits/test_gates.py | 2 +- .../braket/pulse/test_pulse_sequence.py | 16 +++---- .../unit_tests/braket/pulse/test_waveforms.py | 10 ++--- 6 files changed, 50 insertions(+), 32 deletions(-) diff --git a/src/braket/pulse/ast/free_parameters.py b/src/braket/pulse/ast/free_parameters.py index 6bcd719f6..16acb5134 100644 --- a/src/braket/pulse/ast/free_parameters.py +++ b/src/braket/pulse/ast/free_parameters.py @@ -20,21 +20,26 @@ from braket.parametric.free_parameter_expression import FreeParameterExpression -class _FreeParameterExpressionIdentifier(ast.Identifier): +class _FreeParameterExpressionIdentifier(ast.QASMNode): """Dummy AST node with FreeParameterExpression instance attached""" def __init__( - self, expression: FreeParameterExpression, type: ast.ClassicalType = ast.FloatType() + self, expression: FreeParameterExpression, type_: ast.ClassicalType = ast.FloatType() ): - super().__init__(name=f"FreeParameterExpression({expression})") + self.name = f"FreeParameterExpression({expression})" self._expression = expression - self.type = type + self.type_ = type_ @property def expression(self) -> FreeParameterExpression: return self._expression - def to_ast(self) -> ast.Identifier: + def __repr__(self) -> str: + return f"_FreeParameterExpressionIdentifier(name={self.name})" + + def to_ast(self, program: Program) -> ast.Expression: + if isinstance(self.type_, ast.DurationType): + return ast.DurationLiteral(self, ast.TimeUnit.s) return self @@ -58,11 +63,24 @@ def visit__FreeParameterExpressionIdentifier( """ new_value = identifier.expression.subs(self.param_values) if isinstance(new_value, FreeParameterExpression): - return _FreeParameterExpressionIdentifier(new_value, identifier.type) - else: - if isinstance(identifier.type, ast.FloatType): - return ast.FloatLiteral(new_value) - elif isinstance(identifier.type, ast.DurationType): - return OQDurationLiteral(new_value).to_ast(self.program) - else: - raise NotImplementedError(f"{identifier.type} is not a supported type.") + return _FreeParameterExpressionIdentifier(new_value, identifier.type_) + return ast.FloatLiteral(new_value) + + def visit_DurationLiteral(self, duration_literal: ast.DurationLiteral) -> ast.DurationLiteral: + """Visit Duration Literal. + node.value, node.unit (node.unit.name, node.unit.value) + 1 + Args: + duration_literal (DurationLiteral): The duration literal. + Returns: + DurationLiteral: The transformed duration literal. + """ + duration = duration_literal.value + if not isinstance(duration, _FreeParameterExpressionIdentifier): + return duration_literal + new_duration = duration.expression.subs(self.param_values) + if isinstance(new_duration, FreeParameterExpression): + return _FreeParameterExpressionIdentifier(new_duration, duration.type_).to_ast( + self.program + ) + return OQDurationLiteral(new_duration).to_ast(self.program) diff --git a/src/braket/pulse/ast/qasm_parser.py b/src/braket/pulse/ast/qasm_parser.py index 8e6f94ded..9f3d82f36 100644 --- a/src/braket/pulse/ast/qasm_parser.py +++ b/src/braket/pulse/ast/qasm_parser.py @@ -18,6 +18,7 @@ from openqasm3.printer import PrinterState from braket.parametric.free_parameter_expression import FreeParameterExpression +from braket.pulse.ast.free_parameters import _FreeParameterExpressionIdentifier class _PulsePrinter(Printer): @@ -45,8 +46,8 @@ def visit_DurationLiteral(self, node: DurationLiteral, context: PrinterState) -> context (PrinterState): The printer state context. """ duration = node.value - if isinstance(duration, FreeParameterExpression): - self.stream.write(f"({duration.expression}){node.unit.name}") + if isinstance(duration, _FreeParameterExpressionIdentifier): + self.stream.write(f"({duration.expression}) * 1{node.unit.name}") else: super().visit_DurationLiteral(node, context) diff --git a/src/braket/pulse/pulse_sequence.py b/src/braket/pulse/pulse_sequence.py index 08d792681..b797756d4 100644 --- a/src/braket/pulse/pulse_sequence.py +++ b/src/braket/pulse/pulse_sequence.py @@ -331,10 +331,9 @@ def _format_parameter_ast( self._free_parameters.add(FreeParameter(p.name)) return _FreeParameterExpressionIdentifier(parameter, type_) else: - if isinstance(type_, ast.FloatType): - return parameter - elif isinstance(type_, ast.DurationType): + if isinstance(type_, ast.DurationType): return OQDurationLiteral(parameter) + return parameter def _parse_arg_from_calibration_schema( self, argument: Dict, waveforms: Dict[Waveform], frames: Dict[Frame] diff --git a/test/unit_tests/braket/circuits/test_gates.py b/test/unit_tests/braket/circuits/test_gates.py index 51e73b83e..61f890f84 100644 --- a/test/unit_tests/braket/circuits/test_gates.py +++ b/test/unit_tests/braket/circuits/test_gates.py @@ -962,7 +962,7 @@ def to_ir(pulse_gate): [ "cal {", " set_frequency(user_frame, b + 3);", - " delay[c] user_frame;", + " delay[(c) * 1s] user_frame;", "}", ] ) diff --git a/test/unit_tests/braket/pulse/test_pulse_sequence.py b/test/unit_tests/braket/pulse/test_pulse_sequence.py index 411074727..0f6128174 100644 --- a/test/unit_tests/braket/pulse/test_pulse_sequence.py +++ b/test/unit_tests/braket/pulse/test_pulse_sequence.py @@ -125,9 +125,9 @@ def test_pulse_sequence_make_bound_pulse_sequence(predefined_frame_1, predefined [ "OPENQASM 3.0;", "cal {", - " waveform gauss_wf = gaussian(length_g, sigma_g, 1, false);", - " waveform drag_gauss_wf = drag_gaussian(length_dg, sigma_dg, 0.2, 1, false);", - " waveform constant_wf = constant(length_c, 2.0 + 0.3im);", + " waveform gauss_wf = gaussian((length_g) * 1s, (sigma_g) * 1s, 1, false);", + " waveform drag_gauss_wf = drag_gaussian((length_dg) * 1s, (sigma_dg) * 1s, 0.2, 1, false);", + " waveform constant_wf = constant((length_c) * 1s, 2.0 + 0.3im);", " waveform arb_wf = {1.0 + 0.4im, 0, 0.3, 0.1 + 0.2im};", " bit[2] psb;", " set_frequency(predefined_frame_1, a + 2*b);", @@ -136,8 +136,8 @@ def test_pulse_sequence_make_bound_pulse_sequence(predefined_frame_1, predefined " shift_phase(predefined_frame_1, a + 2*b);", " set_scale(predefined_frame_1, a + 2*b);", " psb[0] = capture_v0(predefined_frame_1);", - " delay[a + 2*b] predefined_frame_1, predefined_frame_2;", - " delay[a + 2*b] predefined_frame_1;", + " delay[(a + 2*b) * 1s] predefined_frame_1, predefined_frame_2;", + " delay[(a + 2*b) * 1s] predefined_frame_1;", " delay[1.0ms] predefined_frame_1;", " barrier predefined_frame_1, predefined_frame_2;", " play(predefined_frame_1, gauss_wf);", @@ -168,7 +168,7 @@ def test_pulse_sequence_make_bound_pulse_sequence(predefined_frame_1, predefined [ "OPENQASM 3.0;", "cal {", - " waveform gauss_wf = gaussian(1.0ms, sigma_g, 1, false);", + " waveform gauss_wf = gaussian(1.0ms, (sigma_g) * 1s, 1, false);", " waveform drag_gauss_wf = drag_gaussian(3.0ms, 400.0ms, 0.2, 1, false);", " waveform constant_wf = constant(4.0ms, 2.0 + 0.3im);", " waveform arb_wf = {1.0 + 0.4im, 0, 0.3, 0.1 + 0.2im};", @@ -179,8 +179,8 @@ def test_pulse_sequence_make_bound_pulse_sequence(predefined_frame_1, predefined " shift_phase(predefined_frame_1, a + 4);", " set_scale(predefined_frame_1, a + 4);", " psb[0] = capture_v0(predefined_frame_1);", - " delay[a + 4] predefined_frame_1, predefined_frame_2;", - " delay[a + 4] predefined_frame_1;", + " delay[(a + 4) * 1s] predefined_frame_1, predefined_frame_2;", + " delay[(a + 4) * 1s] predefined_frame_1;", " delay[1.0ms] predefined_frame_1;", " barrier predefined_frame_1, predefined_frame_2;", " play(predefined_frame_1, gauss_wf);", diff --git a/test/unit_tests/braket/pulse/test_waveforms.py b/test/unit_tests/braket/pulse/test_waveforms.py index 86f8253c1..5e8e78111 100644 --- a/test/unit_tests/braket/pulse/test_waveforms.py +++ b/test/unit_tests/braket/pulse/test_waveforms.py @@ -101,7 +101,7 @@ def test_constant_wf_free_params(): assert wf.parameters == [FreeParameter("length_v") + FreeParameter("length_w")] _assert_wf_qasm( wf, - "waveform const_wf = " "constant(length_v + length_w, 2.0 - 3.0im);", + "waveform const_wf = " "constant((length_v + length_w) * 1s, 2.0 - 3.0im);", ) wf_2 = wf.bind_values(length_v=2e-6, length_w=4e-6) @@ -184,8 +184,8 @@ def test_drag_gaussian_wf_free_params(): _assert_wf_qasm( wf, "waveform d_gauss_wf = " - "drag_gaussian(length_v, sigma_a + " - "sigma_b, beta_y, amp_z, false);", + "drag_gaussian((length_v) * 1s, (sigma_a + " + "sigma_b) * 1s, beta_y, amp_z, false);", ) wf_2 = wf.bind_values(length_v=0.6, sigma_a=0.4) @@ -197,7 +197,7 @@ def test_drag_gaussian_wf_free_params(): ] _assert_wf_qasm( wf_2, - "waveform d_gauss_wf = drag_gaussian(600.0ms, sigma_b + 0.4, beta_y, amp_z, false);", + "waveform d_gauss_wf = drag_gaussian(600.0ms, (sigma_b + 0.4) * 1s, beta_y, amp_z, false);", ) wf_3 = wf.bind_values(length_v=0.6, sigma_a=0.3, sigma_b=0.1, beta_y=0.2, amp_z=0.1) @@ -237,7 +237,7 @@ def test_gaussian_wf_free_params(): ] _assert_wf_qasm( wf, - "waveform gauss_wf = gaussian(length_v, sigma_x, " "amp_z, false);", + "waveform gauss_wf = gaussian((length_v) * 1s, (sigma_x) * 1s, " "amp_z, false);", ) wf_2 = wf.bind_values(length_v=0.6, sigma_x=0.4) From 9c4293cdb3e779414c18ea67ae89648faab9a7f6 Mon Sep 17 00:00:00 2001 From: Jean-Christophe Jaskula Date: Fri, 8 Sep 2023 20:15:06 -0400 Subject: [PATCH 03/43] move to_ast to FreeParameterExprsesion --- .../parametric/free_parameter_expression.py | 51 +++++++++++++++++-- src/braket/pulse/ast/free_parameters.py | 38 ++++---------- src/braket/pulse/ast/qasm_parser.py | 1 - src/braket/pulse/pulse_sequence.py | 4 +- src/braket/pulse/waveforms.py | 11 ++-- .../braket/pulse/test_pulse_sequence.py | 3 +- 6 files changed, 66 insertions(+), 42 deletions(-) diff --git a/src/braket/parametric/free_parameter_expression.py b/src/braket/parametric/free_parameter_expression.py index 1b64e2ddb..1da2469e4 100644 --- a/src/braket/parametric/free_parameter_expression.py +++ b/src/braket/parametric/free_parameter_expression.py @@ -15,8 +15,17 @@ import ast from numbers import Number -from typing import Any, Dict, Union - +from typing import Any, Dict, Optional, Union + +from openpulse.ast import ( + ClassicalType, + DurationLiteral, + DurationType, + FloatType, + QASMNode, + TimeUnit, +) +from oqpy import Program from sympy import Expr, Float, Symbol, sympify @@ -30,7 +39,11 @@ class FreeParameterExpression: present will NOT run. Values must be substituted prior to execution. """ - def __init__(self, expression: Union[FreeParameterExpression, Number, Expr, str]): + def __init__( + self, + expression: Union[FreeParameterExpression, Number, Expr, str], + _type: Optional[ClassicalType] = None, + ): """ Initializes a FreeParameterExpression. Best practice is to initialize using FreeParameters and Numbers. Not meant to be initialized directly. @@ -39,6 +52,7 @@ def __init__(self, expression: Union[FreeParameterExpression, Number, Expr, str] Args: expression (Union[FreeParameterExpression, Number, Expr, str]): The expression to use. + _type (Optional[ClassicalType]): type of the expression Examples: >>> expression_1 = FreeParameter("theta") * FreeParameter("alpha") @@ -51,8 +65,11 @@ def __init__(self, expression: Union[FreeParameterExpression, Number, Expr, str] ast.Pow: self.__pow__, ast.USub: self.__neg__, } + self._type = _type if _type is not None else FloatType() if isinstance(expression, FreeParameterExpression): self._expression = expression.expression + if _type is None: + self._type = expression._type elif isinstance(expression, (Number, Expr)): self._expression = expression elif isinstance(expression, str): @@ -170,6 +187,34 @@ def __repr__(self) -> str: """ return repr(self.expression) + def to_ast(self, program: Program) -> QASMNode: + """Creates an AST node for the :class:'FreeParameterExpression'. + + Args: + program (Program): Unused. + + Returns: + QASMNode: The AST node. + """ + if isinstance(self._type, DurationType): + return DurationLiteral(_FreeParameterExpressionIdentifier(self), TimeUnit.s) + return _FreeParameterExpressionIdentifier(self) + + +class _FreeParameterExpressionIdentifier(QASMNode): + """Dummy AST node with FreeParameterExpression instance attached""" + + def __init__(self, expression: FreeParameterExpression): + self.name = f"FreeParameterExpression({expression})" + self._expression = expression + + @property + def expression(self) -> FreeParameterExpression: + return self._expression + + def __repr__(self) -> str: + return f"_FreeParameterExpressionIdentifier(name={self.name})" + def subs_if_free_parameter(parameter: Any, **kwargs) -> Any: """Substitute a free parameter with the given kwargs, if any. diff --git a/src/braket/pulse/ast/free_parameters.py b/src/braket/pulse/ast/free_parameters.py index 16acb5134..56bedc755 100644 --- a/src/braket/pulse/ast/free_parameters.py +++ b/src/braket/pulse/ast/free_parameters.py @@ -17,30 +17,10 @@ from oqpy.program import Program from oqpy.timing import OQDurationLiteral -from braket.parametric.free_parameter_expression import FreeParameterExpression - - -class _FreeParameterExpressionIdentifier(ast.QASMNode): - """Dummy AST node with FreeParameterExpression instance attached""" - - def __init__( - self, expression: FreeParameterExpression, type_: ast.ClassicalType = ast.FloatType() - ): - self.name = f"FreeParameterExpression({expression})" - self._expression = expression - self.type_ = type_ - - @property - def expression(self) -> FreeParameterExpression: - return self._expression - - def __repr__(self) -> str: - return f"_FreeParameterExpressionIdentifier(name={self.name})" - - def to_ast(self, program: Program) -> ast.Expression: - if isinstance(self.type_, ast.DurationType): - return ast.DurationLiteral(self, ast.TimeUnit.s) - return self +from braket.parametric.free_parameter_expression import ( + FreeParameterExpression, + _FreeParameterExpressionIdentifier, +) class _FreeParameterTransformer(QASMTransformer): @@ -63,8 +43,9 @@ def visit__FreeParameterExpressionIdentifier( """ new_value = identifier.expression.subs(self.param_values) if isinstance(new_value, FreeParameterExpression): - return _FreeParameterExpressionIdentifier(new_value, identifier.type_) - return ast.FloatLiteral(new_value) + return _FreeParameterExpressionIdentifier(new_value) + else: + return ast.FloatLiteral(new_value) def visit_DurationLiteral(self, duration_literal: ast.DurationLiteral) -> ast.DurationLiteral: """Visit Duration Literal. @@ -80,7 +61,8 @@ def visit_DurationLiteral(self, duration_literal: ast.DurationLiteral) -> ast.Du return duration_literal new_duration = duration.expression.subs(self.param_values) if isinstance(new_duration, FreeParameterExpression): - return _FreeParameterExpressionIdentifier(new_duration, duration.type_).to_ast( - self.program + return ast.DurationLiteral( + _FreeParameterExpressionIdentifier(new_duration), duration_literal.unit ) + # return super().visit(duration_literal) return OQDurationLiteral(new_duration).to_ast(self.program) diff --git a/src/braket/pulse/ast/qasm_parser.py b/src/braket/pulse/ast/qasm_parser.py index 9f3d82f36..f43fb7e7c 100644 --- a/src/braket/pulse/ast/qasm_parser.py +++ b/src/braket/pulse/ast/qasm_parser.py @@ -17,7 +17,6 @@ from openqasm3.ast import DurationLiteral from openqasm3.printer import PrinterState -from braket.parametric.free_parameter_expression import FreeParameterExpression from braket.pulse.ast.free_parameters import _FreeParameterExpressionIdentifier diff --git a/src/braket/pulse/pulse_sequence.py b/src/braket/pulse/pulse_sequence.py index b797756d4..b55e94f5b 100644 --- a/src/braket/pulse/pulse_sequence.py +++ b/src/braket/pulse/pulse_sequence.py @@ -329,7 +329,9 @@ def _format_parameter_ast( if isinstance(parameter, FreeParameterExpression): for p in parameter.expression.free_symbols: self._free_parameters.add(FreeParameter(p.name)) - return _FreeParameterExpressionIdentifier(parameter, type_) + if isinstance(type_, ast.DurationType): + return FreeParameterExpression(parameter, type_) + return parameter else: if isinstance(type_, ast.DurationType): return OQDurationLiteral(parameter) diff --git a/src/braket/pulse/waveforms.py b/src/braket/pulse/waveforms.py index 48da8904b..16f454b03 100644 --- a/src/braket/pulse/waveforms.py +++ b/src/braket/pulse/waveforms.py @@ -28,7 +28,6 @@ subs_if_free_parameter, ) from braket.parametric.parameterizable import Parameterizable -from braket.pulse.ast.free_parameters import _FreeParameterExpressionIdentifier class Waveform(ABC): @@ -453,13 +452,9 @@ def _make_identifier_name() -> str: def _map_to_oqpy_type( parameter: Union[FreeParameterExpression, float], is_duration_type: bool = False -) -> Union[_FreeParameterExpressionIdentifier, OQPyExpression]: - if isinstance(parameter, FreeParameterExpression): - return ( - _FreeParameterExpressionIdentifier(parameter, duration) - if is_duration_type - else _FreeParameterExpressionIdentifier(parameter) - ) +) -> Union[FreeParameterExpression, OQPyExpression]: + if isinstance(parameter, FreeParameterExpression) and is_duration_type: + return FreeParameterExpression(parameter, duration) return parameter diff --git a/test/unit_tests/braket/pulse/test_pulse_sequence.py b/test/unit_tests/braket/pulse/test_pulse_sequence.py index 0f6128174..4890a4b36 100644 --- a/test/unit_tests/braket/pulse/test_pulse_sequence.py +++ b/test/unit_tests/braket/pulse/test_pulse_sequence.py @@ -126,7 +126,8 @@ def test_pulse_sequence_make_bound_pulse_sequence(predefined_frame_1, predefined "OPENQASM 3.0;", "cal {", " waveform gauss_wf = gaussian((length_g) * 1s, (sigma_g) * 1s, 1, false);", - " waveform drag_gauss_wf = drag_gaussian((length_dg) * 1s, (sigma_dg) * 1s, 0.2, 1, false);", + " waveform drag_gauss_wf = drag_gaussian((length_dg) * 1s," + " (sigma_dg) * 1s, 0.2, 1, false);", " waveform constant_wf = constant((length_c) * 1s, 2.0 + 0.3im);", " waveform arb_wf = {1.0 + 0.4im, 0, 0.3, 0.1 + 0.2im};", " bit[2] psb;", From fc6aa3dddecca46ebe412d0c7a5af1b3661b764d Mon Sep 17 00:00:00 2001 From: Jean-Christophe Jaskula Date: Mon, 11 Sep 2023 13:02:53 -0400 Subject: [PATCH 04/43] change back FPEIdentifier's parent to Identifier --- src/braket/parametric/free_parameter_expression.py | 14 ++++++-------- src/braket/pulse/ast/free_parameters.py | 1 - 2 files changed, 6 insertions(+), 9 deletions(-) diff --git a/src/braket/parametric/free_parameter_expression.py b/src/braket/parametric/free_parameter_expression.py index 1da2469e4..9eada787f 100644 --- a/src/braket/parametric/free_parameter_expression.py +++ b/src/braket/parametric/free_parameter_expression.py @@ -21,8 +21,9 @@ ClassicalType, DurationLiteral, DurationType, + Expression, FloatType, - QASMNode, + Identifier, TimeUnit, ) from oqpy import Program @@ -187,34 +188,31 @@ def __repr__(self) -> str: """ return repr(self.expression) - def to_ast(self, program: Program) -> QASMNode: + def to_ast(self, program: Program) -> Expression: """Creates an AST node for the :class:'FreeParameterExpression'. Args: program (Program): Unused. Returns: - QASMNode: The AST node. + Expression: The AST node. """ if isinstance(self._type, DurationType): return DurationLiteral(_FreeParameterExpressionIdentifier(self), TimeUnit.s) return _FreeParameterExpressionIdentifier(self) -class _FreeParameterExpressionIdentifier(QASMNode): +class _FreeParameterExpressionIdentifier(Identifier): """Dummy AST node with FreeParameterExpression instance attached""" def __init__(self, expression: FreeParameterExpression): - self.name = f"FreeParameterExpression({expression})" + super().__init__(name=f"FreeParameterExpression({expression})") self._expression = expression @property def expression(self) -> FreeParameterExpression: return self._expression - def __repr__(self) -> str: - return f"_FreeParameterExpressionIdentifier(name={self.name})" - def subs_if_free_parameter(parameter: Any, **kwargs) -> Any: """Substitute a free parameter with the given kwargs, if any. diff --git a/src/braket/pulse/ast/free_parameters.py b/src/braket/pulse/ast/free_parameters.py index 56bedc755..1d8545933 100644 --- a/src/braket/pulse/ast/free_parameters.py +++ b/src/braket/pulse/ast/free_parameters.py @@ -64,5 +64,4 @@ def visit_DurationLiteral(self, duration_literal: ast.DurationLiteral) -> ast.Du return ast.DurationLiteral( _FreeParameterExpressionIdentifier(new_duration), duration_literal.unit ) - # return super().visit(duration_literal) return OQDurationLiteral(new_duration).to_ast(self.program) From f7da1155d463747fd6e087707a3e4d522260d2ca Mon Sep 17 00:00:00 2001 From: Jean-Christophe Jaskula Date: Mon, 11 Sep 2023 13:53:05 -0400 Subject: [PATCH 05/43] clean up syntax --- src/braket/pulse/pulse_sequence.py | 17 ++++++++--------- src/braket/pulse/waveforms.py | 8 +++++--- 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/src/braket/pulse/pulse_sequence.py b/src/braket/pulse/pulse_sequence.py index b55e94f5b..c7583b82d 100644 --- a/src/braket/pulse/pulse_sequence.py +++ b/src/braket/pulse/pulse_sequence.py @@ -183,7 +183,7 @@ def delay( Returns: PulseSequence: self, with the instruction added. """ - duration = self._format_parameter_ast(duration, type_=ast.DurationType()) + duration = self._format_parameter_ast(duration, _type=ast.DurationType()) if not isinstance(qubits_or_frames, QubitSet): if not isinstance(qubits_or_frames, list): qubits_or_frames = [qubits_or_frames] @@ -324,18 +324,17 @@ def to_ir(self) -> str: def _format_parameter_ast( self, parameter: Union[float, FreeParameterExpression], - type_: ast.ClassicalType = ast.FloatType(), + _type: ast.ClassicalType = ast.FloatType(), ) -> Union[float, _FreeParameterExpressionIdentifier]: if isinstance(parameter, FreeParameterExpression): for p in parameter.expression.free_symbols: self._free_parameters.add(FreeParameter(p.name)) - if isinstance(type_, ast.DurationType): - return FreeParameterExpression(parameter, type_) - return parameter - else: - if isinstance(type_, ast.DurationType): - return OQDurationLiteral(parameter) - return parameter + return ( + FreeParameterExpression(parameter, _type) + if isinstance(_type, ast.DurationType) + else parameter + ) + return OQDurationLiteral(parameter) if isinstance(_type, ast.DurationType) else parameter def _parse_arg_from_calibration_schema( self, argument: Dict, waveforms: Dict[Waveform], frames: Dict[Frame] diff --git a/src/braket/pulse/waveforms.py b/src/braket/pulse/waveforms.py index 16f454b03..72d2120bc 100644 --- a/src/braket/pulse/waveforms.py +++ b/src/braket/pulse/waveforms.py @@ -453,9 +453,11 @@ def _make_identifier_name() -> str: def _map_to_oqpy_type( parameter: Union[FreeParameterExpression, float], is_duration_type: bool = False ) -> Union[FreeParameterExpression, OQPyExpression]: - if isinstance(parameter, FreeParameterExpression) and is_duration_type: - return FreeParameterExpression(parameter, duration) - return parameter + return ( + FreeParameterExpression(parameter, duration) + if isinstance(parameter, FreeParameterExpression) and is_duration_type + else parameter + ) def _parse_waveform_from_calibration_schema(waveform: Dict) -> Waveform: From ede6a679cc6223c923ddf563b89df93a75eaffac Mon Sep 17 00:00:00 2001 From: Jean-Christophe Jaskula Date: Thu, 14 Sep 2023 20:23:10 -0400 Subject: [PATCH 06/43] add precision about the expression type --- src/braket/parametric/free_parameter_expression.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/braket/parametric/free_parameter_expression.py b/src/braket/parametric/free_parameter_expression.py index 9eada787f..eeba3c207 100644 --- a/src/braket/parametric/free_parameter_expression.py +++ b/src/braket/parametric/free_parameter_expression.py @@ -53,7 +53,10 @@ def __init__( Args: expression (Union[FreeParameterExpression, Number, Expr, str]): The expression to use. - _type (Optional[ClassicalType]): type of the expression + _type (Optional[ClassicalType]): The OpenQASM3 type associated with the expression. + Subtypes of openqasm3.ast.ClassicalType are used to specify how to express the + expression in the OpenQASM3 IR. Any type other than DurationType is considered + as FloatType. Examples: >>> expression_1 = FreeParameter("theta") * FreeParameter("alpha") From cebf58e3fbcb70304f20e2325e86c6521aabb79d Mon Sep 17 00:00:00 2001 From: Jean-Christophe Jaskula Date: Thu, 28 Sep 2023 13:06:35 -0400 Subject: [PATCH 07/43] add __repr__ to waveforms --- src/braket/circuits/circuit.py | 7 +- src/braket/pulse/waveforms.py | 19 +++++ .../unit_tests/braket/pulse/test_waveforms.py | 78 +++++++++++++++---- 3 files changed, 85 insertions(+), 19 deletions(-) diff --git a/src/braket/circuits/circuit.py b/src/braket/circuits/circuit.py index 7e3bfcd4d..5eedd5573 100644 --- a/src/braket/circuits/circuit.py +++ b/src/braket/circuits/circuit.py @@ -57,9 +57,10 @@ from braket.ir.jaqcd import Program as JaqcdProgram from braket.ir.openqasm import Program as OpenQasmProgram from braket.ir.openqasm.program_v1 import io_type -from braket.pulse import ArbitraryWaveform, Frame from braket.pulse.ast.qasm_parser import ast_to_qasm +from braket.pulse.frame import Frame from braket.pulse.pulse_sequence import PulseSequence, _validate_uniqueness +from braket.pulse.waveforms import Waveform SubroutineReturn = TypeVar( "SubroutineReturn", Iterable[Instruction], Instruction, ResultType, Iterable[ResultType] @@ -1245,7 +1246,7 @@ def _validate_gate_calbrations_uniqueness( self, gate_definitions: Dict[Tuple[Gate, QubitSet], PulseSequence], frames: Dict[Frame], - waveforms: Dict[ArbitraryWaveform], + waveforms: Dict[Waveform], ) -> None: for key, calibration in gate_definitions.items(): for frame in calibration._frames.values(): @@ -1303,7 +1304,7 @@ def _generate_frame_wf_defcal_declarations( def _get_frames_waveforms_from_instrs( self, gate_definitions: Optional[Dict[Tuple[Gate, QubitSet], PulseSequence]] - ) -> Tuple[Dict[Frame], Dict[ArbitraryWaveform]]: + ) -> Tuple[Dict[Frame], Dict[Waveform]]: from braket.circuits.gates import PulseGate frames = {} diff --git a/src/braket/pulse/waveforms.py b/src/braket/pulse/waveforms.py index 72d2120bc..e05fec207 100644 --- a/src/braket/pulse/waveforms.py +++ b/src/braket/pulse/waveforms.py @@ -83,6 +83,9 @@ def __init__(self, amplitudes: List[complex], id: Optional[str] = None): self.amplitudes = list(amplitudes) self.id = id or _make_identifier_name() + def __repr__(self) -> str: + return f"ArbitraryWaveform('id': {self.id}, 'amplitudes': {self.amplitudes})" + def __eq__(self, other): return isinstance(other, ArbitraryWaveform) and (self.amplitudes, self.id) == ( other.amplitudes, @@ -131,6 +134,9 @@ def __init__( self.iq = iq self.id = id or _make_identifier_name() + def __repr__(self) -> str: + return f"ConstantWaveform('id': {self.id}, 'length': {self.length}, 'iq': {self.iq})" + @property def parameters(self) -> List[Union[FreeParameterExpression, FreeParameter, float]]: """Returns the parameters associated with the object, either unbound free parameter @@ -236,6 +242,13 @@ def __init__( self.zero_at_edges = zero_at_edges self.id = id or _make_identifier_name() + def __repr__(self) -> str: + return ( + f"DragGaussianWaveform('id': {self.id}, 'length': {self.length}, " + f"'sigma': {self.sigma}, 'beta': {self.beta}, 'amplitude': {self.amplitude}, " + f"'zero_at_edges': {self.zero_at_edges})" + ) + @property def parameters(self) -> List[Union[FreeParameterExpression, FreeParameter, float]]: """Returns the parameters associated with the object, either unbound free parameter @@ -360,6 +373,12 @@ def __init__( self.zero_at_edges = zero_at_edges self.id = id or _make_identifier_name() + def __repr__(self) -> str: + return ( + f"GaussianWaveform('id': {self.id}, 'length': {self.length}, 'sigma': {self.sigma}, " + f"'amplitude': {self.amplitude}, 'zero_at_edges': {self.zero_at_edges})" + ) + @property def parameters(self) -> List[Union[FreeParameterExpression, FreeParameter, float]]: """Returns the parameters associated with the object, either unbound free parameter diff --git a/test/unit_tests/braket/pulse/test_waveforms.py b/test/unit_tests/braket/pulse/test_waveforms.py index 5e8e78111..09ff3290e 100644 --- a/test/unit_tests/braket/pulse/test_waveforms.py +++ b/test/unit_tests/braket/pulse/test_waveforms.py @@ -42,6 +42,14 @@ def test_arbitrary_waveform(amps): assert oq_exp.name == wf.id +def test_arbitrary_waveform_repr(): + amps = [1, 4, 5] + id = "arb_wf_x" + wf = ArbitraryWaveform(amps, id) + expected = f"ArbitraryWaveform('id': {wf.id}, 'amplitudes': {wf.amplitudes})" + assert repr(wf) == expected + + def test_arbitrary_waveform_default_params(): amps = [1, 4, 5] wf = ArbitraryWaveform(amps) @@ -77,6 +85,15 @@ def test_constant_waveform(): _assert_wf_qasm(wf, "waveform const_wf_x = constant(4.0ms, 4);") +def test_constant_waveform_repr(): + length = 4e-3 + iq = 4 + id = "const_wf_x" + wf = ConstantWaveform(length, iq, id) + expected = f"ConstantWaveform('id': {wf.id}, 'length': {wf.length}, 'iq': {wf.iq})" + assert repr(wf) == expected + + def test_constant_waveform_default_params(): amps = [1, 4, 5] wf = ArbitraryWaveform(amps) @@ -128,6 +145,21 @@ def test_drag_gaussian_waveform(): _assert_wf_qasm(wf, "waveform drag_gauss_wf = drag_gaussian(4.0ns, 300.0ms, 0.6, 0.4, false);") +def test_drag_gaussian_waveform_repr(): + length = 4e-9 + sigma = 0.3 + beta = 0.6 + amplitude = 0.4 + zero_at_edges = False + id = "drag_gauss_wf" + wf = DragGaussianWaveform(length, sigma, beta, amplitude, zero_at_edges, id) + expected = ( + f"DragGaussianWaveform('id': {wf.id}, 'length': {wf.length}, 'sigma': {wf.sigma}, " + f"'beta': {wf.beta}, 'amplitude': {wf.amplitude}, 'zero_at_edges': {wf.zero_at_edges})" + ) + assert repr(wf) == expected + + def test_drag_gaussian_waveform_default_params(): length = 4e-9 sigma = 0.3 @@ -151,22 +183,6 @@ def test_drag_gaussian_wf_eq(): assert wf != wfc -def test_gaussian_waveform(): - length = 4e-9 - sigma = 0.3 - amplitude = 0.4 - zero_at_edges = False - id = "gauss_wf" - wf = GaussianWaveform(length, sigma, amplitude, zero_at_edges, id) - assert wf.id == id - assert wf.zero_at_edges == zero_at_edges - assert wf.amplitude == amplitude - assert wf.sigma == sigma - assert wf.length == length - - _assert_wf_qasm(wf, "waveform gauss_wf = gaussian(4.0ns, 300.0ms, 0.4, false);") - - def test_drag_gaussian_wf_free_params(): wf = DragGaussianWaveform( FreeParameter("length_v"), @@ -205,6 +221,36 @@ def test_drag_gaussian_wf_free_params(): _assert_wf_qasm(wf_3, "waveform d_gauss_wf = drag_gaussian(600.0ms, 400.0ms, 0.2, 0.1, false);") +def test_gaussian_waveform(): + length = 4e-9 + sigma = 0.3 + amplitude = 0.4 + zero_at_edges = False + id = "gauss_wf" + wf = GaussianWaveform(length, sigma, amplitude, zero_at_edges, id) + assert wf.id == id + assert wf.zero_at_edges == zero_at_edges + assert wf.amplitude == amplitude + assert wf.sigma == sigma + assert wf.length == length + + _assert_wf_qasm(wf, "waveform gauss_wf = gaussian(4.0ns, 300.0ms, 0.4, false);") + + +def test_gaussian_waveform_repr(): + length = 4e-9 + sigma = 0.3 + amplitude = 0.4 + zero_at_edges = False + id = "gauss_wf" + wf = GaussianWaveform(length, sigma, amplitude, zero_at_edges, id) + expected = ( + f"GaussianWaveform('id': {wf.id}, 'length': {wf.length}, 'sigma': {wf.sigma}, " + f"'amplitude': {wf.amplitude}, 'zero_at_edges': {wf.zero_at_edges})" + ) + assert repr(wf) == expected + + def test_gaussian_waveform_default_params(): length = 4e-9 sigma = 0.3 From c75ddd6435a56471cee296a9e7ebba782ccaff5a Mon Sep 17 00:00:00 2001 From: Jean-Christophe Jaskula Date: Thu, 28 Sep 2023 23:06:20 -0400 Subject: [PATCH 08/43] do not simplify constants with defcals --- src/braket/circuits/circuit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/braket/circuits/circuit.py b/src/braket/circuits/circuit.py index 5eedd5573..7e0f5c919 100644 --- a/src/braket/circuits/circuit.py +++ b/src/braket/circuits/circuit.py @@ -1259,7 +1259,7 @@ def _validate_gate_calbrations_uniqueness( def _generate_frame_wf_defcal_declarations( self, gate_definitions: Optional[Dict[Tuple[Gate, QubitSet], PulseSequence]] ) -> Optional[str]: - program = oqpy.Program(None) + program = oqpy.Program(None, simplify_constants=False) frames, waveforms = self._get_frames_waveforms_from_instrs(gate_definitions) From 68afef4392ec72d369621623bda214c48e30e49f Mon Sep 17 00:00:00 2001 From: Jean-Christophe Jaskula Date: Wed, 4 Oct 2023 18:08:35 -0400 Subject: [PATCH 09/43] add type validation --- src/braket/parametric/free_parameter_expression.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/braket/parametric/free_parameter_expression.py b/src/braket/parametric/free_parameter_expression.py index eeba3c207..8ee9fd21c 100644 --- a/src/braket/parametric/free_parameter_expression.py +++ b/src/braket/parametric/free_parameter_expression.py @@ -80,6 +80,7 @@ def __init__( self._expression = self._parse_string_expression(expression).expression else: raise NotImplementedError + self._validate_type() @property def expression(self) -> Union[Number, Expr]: @@ -117,6 +118,13 @@ def subs( else: return FreeParameterExpression(subbed_expr) + def _validate_type(self) -> None: + if not isinstance(self._type, (FloatType, DurationType)): + raise TypeError( + "FreeParameterExpression must be of type openqasm3.ast.FloatType " + "or openqasm3.ast.DurationType" + ) + def _parse_string_expression(self, expression: str) -> FreeParameterExpression: return self._eval_operation(ast.parse(expression, mode="eval").body) From b7313c0c95d403a55b0644c0fe7f017f3259fc60 Mon Sep 17 00:00:00 2001 From: Jean-Christophe Jaskula Date: Thu, 5 Oct 2023 09:53:58 -0400 Subject: [PATCH 10/43] update oqpy to 0.3.2 --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 0108656b1..47ff1acdf 100644 --- a/setup.py +++ b/setup.py @@ -29,7 +29,7 @@ install_requires=[ "amazon-braket-schemas>=1.19.1", "amazon-braket-default-simulator>=1.19.1", - "oqpy~=0.3.1", + "oqpy~=0.3.2", "setuptools", "backoff", "boltons", From ce30feea8b70997a678eae49e65b1bd8c84b4806 Mon Sep 17 00:00:00 2001 From: Jean-Christophe Jaskula Date: Thu, 5 Oct 2023 09:56:04 -0400 Subject: [PATCH 11/43] fix linters --- examples/job.py | 2 -- src/braket/jobs/data_persistence.py | 2 +- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/examples/job.py b/examples/job.py index f5e330d1a..87b06bf49 100644 --- a/examples/job.py +++ b/examples/job.py @@ -11,8 +11,6 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. -import os - from braket.aws import AwsDevice, AwsQuantumJob from braket.circuits import Circuit from braket.devices import Devices diff --git a/src/braket/jobs/data_persistence.py b/src/braket/jobs/data_persistence.py index aa574d0d9..10806812f 100644 --- a/src/braket/jobs/data_persistence.py +++ b/src/braket/jobs/data_persistence.py @@ -126,7 +126,7 @@ def load_job_result(filename: Union[str, Path] = None) -> Dict[str, Any]: must be in the format used by `save_job_result`. Returns: - Dict[str, Any]: Job result data of current job + Dict[str, Any]: Job result data of current job """ persisted_data = _load_persisted_data(filename) deserialized_data = deserialize_values(persisted_data.dataDictionary, persisted_data.dataFormat) From a64971a92fa72a12fb796b1df0f07867bda1774f Mon Sep 17 00:00:00 2001 From: Jean-Christophe Jaskula Date: Fri, 6 Oct 2023 18:42:21 -0400 Subject: [PATCH 12/43] increase test coverage --- .../braket/parametric/test_free_parameter_expression.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/test/unit_tests/braket/parametric/test_free_parameter_expression.py b/test/unit_tests/braket/parametric/test_free_parameter_expression.py index 879706fe0..d991ec236 100644 --- a/test/unit_tests/braket/parametric/test_free_parameter_expression.py +++ b/test/unit_tests/braket/parametric/test_free_parameter_expression.py @@ -67,6 +67,11 @@ def test_unsupported_node_str(): FreeParameterExpression("theta , 1") +@pytest.mark.xfail(raises=TypeError) +def test_unsupported_type(): + FreeParameterExpression("theta", _type=float) + + def test_commutativity(): add_1 = 1 + FreeParameterExpression(FreeParameter("theta")) add_2 = FreeParameterExpression(FreeParameter("theta")) + 1 From 4c1103aabdd90ea179595ee5375d5ea28b49eb72 Mon Sep 17 00:00:00 2001 From: Jean-Christophe Jaskula Date: Mon, 16 Oct 2023 15:39:06 -0400 Subject: [PATCH 13/43] update to oqpy 0.3.3 --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 47ff1acdf..212009286 100644 --- a/setup.py +++ b/setup.py @@ -29,7 +29,7 @@ install_requires=[ "amazon-braket-schemas>=1.19.1", "amazon-braket-default-simulator>=1.19.1", - "oqpy~=0.3.2", + "oqpy~=0.3.3", "setuptools", "backoff", "boltons", From bba404e737ed11e776afc0634d28aeffa827a796 Mon Sep 17 00:00:00 2001 From: Jean-Christophe Jaskula Date: Mon, 16 Oct 2023 16:07:49 -0400 Subject: [PATCH 14/43] fix last merge commit --- src/braket/circuits/circuit.py | 4 ++-- src/braket/parametric/free_parameter_expression.py | 4 ++-- src/braket/pulse/ast/free_parameters.py | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/braket/circuits/circuit.py b/src/braket/circuits/circuit.py index f42ef3618..f6f6be4b0 100644 --- a/src/braket/circuits/circuit.py +++ b/src/braket/circuits/circuit.py @@ -1247,7 +1247,7 @@ def _validate_gate_calbrations_uniqueness( self, gate_definitions: dict[tuple[Gate, QubitSet], PulseSequence], frames: dict[Frame], - waveforms: dict[ArbitraryWaveform], + waveforms: dict[Waveform], ) -> None: for key, calibration in gate_definitions.items(): for frame in calibration._frames.values(): @@ -1305,7 +1305,7 @@ def _generate_frame_wf_defcal_declarations( def _get_frames_waveforms_from_instrs( self, gate_definitions: Optional[dict[tuple[Gate, QubitSet], PulseSequence]] - ) -> tuple[dict[Frame], dict[ArbitraryWaveform]]: + ) -> tuple[dict[Frame], dict[Waveform]]: from braket.circuits.gates import PulseGate frames = {} diff --git a/src/braket/parametric/free_parameter_expression.py b/src/braket/parametric/free_parameter_expression.py index 73d9c643c..8a864120b 100644 --- a/src/braket/parametric/free_parameter_expression.py +++ b/src/braket/parametric/free_parameter_expression.py @@ -43,7 +43,7 @@ class FreeParameterExpression: def __init__( self, expression: Union[FreeParameterExpression, Number, Expr, str], - _type: Optional[ClassicalType] = None, + _type: ClassicalType | None = None, ): """ Initializes a FreeParameterExpression. Best practice is to initialize using @@ -53,7 +53,7 @@ def __init__( Args: expression (Union[FreeParameterExpression, Number, Expr, str]): The expression to use. - _type (Optional[ClassicalType]): The OpenQASM3 type associated with the expression. + _type (ClassicalType | None): The OpenQASM3 type associated with the expression. Subtypes of openqasm3.ast.ClassicalType are used to specify how to express the expression in the OpenQASM3 IR. Any type other than DurationType is considered as FloatType. diff --git a/src/braket/pulse/ast/free_parameters.py b/src/braket/pulse/ast/free_parameters.py index f51d80127..639645b27 100644 --- a/src/braket/pulse/ast/free_parameters.py +++ b/src/braket/pulse/ast/free_parameters.py @@ -27,7 +27,7 @@ class _FreeParameterTransformer(QASMTransformer): """Walk the AST and evaluate FreeParameterExpressions.""" - def __init__(self, param_values: dict[str, float]): + def __init__(self, param_values: dict[str, float], program: Program): self.param_values = param_values self.program = program super().__init__() From 5093176b2998340f30f955f94245a274e0bc2527 Mon Sep 17 00:00:00 2001 From: Jean-Christophe Jaskula Date: Mon, 16 Oct 2023 16:19:34 -0400 Subject: [PATCH 15/43] fix type hints --- src/braket/circuits/circuit.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/braket/circuits/circuit.py b/src/braket/circuits/circuit.py index f6f6be4b0..a189807e5 100644 --- a/src/braket/circuits/circuit.py +++ b/src/braket/circuits/circuit.py @@ -1246,8 +1246,8 @@ def _create_openqasm_header( def _validate_gate_calbrations_uniqueness( self, gate_definitions: dict[tuple[Gate, QubitSet], PulseSequence], - frames: dict[Frame], - waveforms: dict[Waveform], + frames: dict[str, Frame], + waveforms: dict[str, Waveform], ) -> None: for key, calibration in gate_definitions.items(): for frame in calibration._frames.values(): @@ -1305,7 +1305,7 @@ def _generate_frame_wf_defcal_declarations( def _get_frames_waveforms_from_instrs( self, gate_definitions: Optional[dict[tuple[Gate, QubitSet], PulseSequence]] - ) -> tuple[dict[Frame], dict[Waveform]]: + ) -> tuple[dict[str, Frame], dict[str, Waveform]]: from braket.circuits.gates import PulseGate frames = {} From 423a97c9508ae0f4aba792c6b74494029d0a0a57 Mon Sep 17 00:00:00 2001 From: Jean-Christophe Jaskula Date: Mon, 6 Nov 2023 15:04:11 -0500 Subject: [PATCH 16/43] update to oqpy 0.3.4 --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index cfb1365dc..e08d37ba7 100644 --- a/setup.py +++ b/setup.py @@ -29,7 +29,7 @@ install_requires=[ "amazon-braket-schemas>=1.19.1", "amazon-braket-default-simulator>=1.19.1", - "oqpy~=0.3.3", + "oqpy~=0.3.4", "setuptools", "backoff", "boltons", From 9da40a662ada84dae7af130945f91509ffa43d60 Mon Sep 17 00:00:00 2001 From: Jean-Christophe Jaskula Date: Mon, 11 Dec 2023 16:48:54 +0100 Subject: [PATCH 17/43] fix oqpy to 0.3.3 --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 6c9f31f9b..328f0a1bf 100644 --- a/setup.py +++ b/setup.py @@ -29,7 +29,7 @@ install_requires=[ "amazon-braket-schemas>=1.19.1", "amazon-braket-default-simulator>=1.19.1", - "oqpy~=0.3.4", + "oqpy~=0.3.3", "setuptools", "backoff", "boltons", From 004de4a201540e17f0fd64efd85b3351c9d87610 Mon Sep 17 00:00:00 2001 From: Jean-Christophe Jaskula Date: Wed, 6 Dec 2023 20:28:11 +0100 Subject: [PATCH 18/43] declare FreeParameter as oqpy var --- src/braket/pulse/pulse_sequence.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/braket/pulse/pulse_sequence.py b/src/braket/pulse/pulse_sequence.py index 4e3bb06d4..cafb96a42 100644 --- a/src/braket/pulse/pulse_sequence.py +++ b/src/braket/pulse/pulse_sequence.py @@ -19,7 +19,7 @@ from typing import Any, Union from openpulse import ast -from oqpy import BitVar, PhysicalQubits, Program +from oqpy import BitVar, FloatVar, PhysicalQubits, Program from oqpy.timing import OQDurationLiteral from braket.parametric.free_parameter import FreeParameter @@ -310,6 +310,8 @@ def to_ir(self) -> str: str: a str representing the OpenPulse program encoding the PulseSequence. """ program = deepcopy(self._program) + for param in self.parameters: + program.declare(FloatVar(name=param.name, size=None, init_expression="input"), to_beginning=True) if self._capture_v0_count: register_identifier = "psb" program.declare( From 3ae122e34c7c55155f8ea5467f376b30d0fc8d29 Mon Sep 17 00:00:00 2001 From: Jean-Christophe Jaskula Date: Mon, 11 Dec 2023 16:32:13 +0100 Subject: [PATCH 19/43] declare free params before playing the waveform --- src/braket/pulse/pulse_sequence.py | 19 ++++++++++++++++--- .../braket/pulse/test_pulse_sequence.py | 9 +++++++++ 2 files changed, 25 insertions(+), 3 deletions(-) diff --git a/src/braket/pulse/pulse_sequence.py b/src/braket/pulse/pulse_sequence.py index cafb96a42..5a4b2535a 100644 --- a/src/braket/pulse/pulse_sequence.py +++ b/src/braket/pulse/pulse_sequence.py @@ -231,12 +231,20 @@ def play(self, frame: Frame, waveform: Waveform) -> PulseSequence: """ _validate_uniqueness(self._frames, frame) _validate_uniqueness(self._waveforms, waveform) - self._program.play(frame=frame, waveform=waveform) if isinstance(waveform, Parameterizable): for param in waveform.parameters: if isinstance(param, FreeParameterExpression): for p in param.expression.free_symbols: + self._program._add_var( + FloatVar( + name=p.name, + size=None, + init_expression="input", + needs_declaration=True, + ) + ) self._free_parameters.add(FreeParameter(p.name)) + self._program.play(frame=frame, waveform=waveform) self._frames[frame.id] = frame self._waveforms[waveform.id] = waveform return self @@ -283,6 +291,8 @@ def make_bound_pulse_sequence(self, param_values: dict[str, float]) -> PulseSequ new_pulse_sequence = PulseSequence() new_pulse_sequence._program = new_program + for param_name in param_values: + new_pulse_sequence._program.undeclared_vars.pop(param_name, None) new_pulse_sequence._frames = deepcopy(self._frames) new_pulse_sequence._waveforms = { wf.id: wf.bind_values(**param_values) if isinstance(wf, Parameterizable) else wf @@ -310,8 +320,6 @@ def to_ir(self) -> str: str: a str representing the OpenPulse program encoding the PulseSequence. """ program = deepcopy(self._program) - for param in self.parameters: - program.declare(FloatVar(name=param.name, size=None, init_expression="input"), to_beginning=True) if self._capture_v0_count: register_identifier = "psb" program.declare( @@ -330,6 +338,11 @@ def _format_parameter_ast( ) -> Union[float, _FreeParameterExpressionIdentifier]: if isinstance(parameter, FreeParameterExpression): for p in parameter.expression.free_symbols: + self._program._add_var( + FloatVar( + name=p.name, size=None, init_expression="input", needs_declaration=True + ) + ) self._free_parameters.add(FreeParameter(p.name)) return ( FreeParameterExpression(parameter, _type) diff --git a/test/unit_tests/braket/pulse/test_pulse_sequence.py b/test/unit_tests/braket/pulse/test_pulse_sequence.py index 4890a4b36..dc06ffdcc 100644 --- a/test/unit_tests/braket/pulse/test_pulse_sequence.py +++ b/test/unit_tests/braket/pulse/test_pulse_sequence.py @@ -125,9 +125,16 @@ def test_pulse_sequence_make_bound_pulse_sequence(predefined_frame_1, predefined [ "OPENQASM 3.0;", "cal {", + " input float[64] b;", + " input float[64] a;", + " input float[64] length_g;", + " input float[64] sigma_g;", " waveform gauss_wf = gaussian((length_g) * 1s, (sigma_g) * 1s, 1, false);", + " input float[64] length_dg;", + " input float[64] sigma_dg;", " waveform drag_gauss_wf = drag_gaussian((length_dg) * 1s," " (sigma_dg) * 1s, 0.2, 1, false);", + " input float[64] length_c;", " waveform constant_wf = constant((length_c) * 1s, 2.0 + 0.3im);", " waveform arb_wf = {1.0 + 0.4im, 0, 0.3, 0.1 + 0.2im};", " bit[2] psb;", @@ -169,6 +176,8 @@ def test_pulse_sequence_make_bound_pulse_sequence(predefined_frame_1, predefined [ "OPENQASM 3.0;", "cal {", + " input float[64] a;", + " input float[64] sigma_g;", " waveform gauss_wf = gaussian(1.0ms, (sigma_g) * 1s, 1, false);", " waveform drag_gauss_wf = drag_gaussian(3.0ms, 400.0ms, 0.2, 1, false);", " waveform constant_wf = constant(4.0ms, 2.0 + 0.3im);", From 5c557a885929227e09a8b85f63f41810d23754cd Mon Sep 17 00:00:00 2001 From: Jean-Christophe Jaskula Date: Thu, 14 Dec 2023 20:04:18 +0100 Subject: [PATCH 20/43] remove FreeParameterExpressionIdentitifer --- .../parametric/free_parameter_expression.py | 28 ++++--------- src/braket/pulse/ast/free_parameters.py | 41 ++++++++++--------- src/braket/pulse/ast/qasm_parser.py | 14 +++---- src/braket/pulse/pulse_sequence.py | 7 +--- .../braket/pulse/test_pulse_sequence.py | 12 +++--- 5 files changed, 43 insertions(+), 59 deletions(-) diff --git a/src/braket/parametric/free_parameter_expression.py b/src/braket/parametric/free_parameter_expression.py index 8a864120b..71292d306 100644 --- a/src/braket/parametric/free_parameter_expression.py +++ b/src/braket/parametric/free_parameter_expression.py @@ -107,7 +107,7 @@ def subs( """ new_parameter_values = dict() for key, val in parameter_values.items(): - if issubclass(type(key), FreeParameterExpression): + if isinstance(key, FreeParameterExpression): new_parameter_values[key.expression] = val else: new_parameter_values[key] = val @@ -147,7 +147,7 @@ def _eval_operation(self, node: Any) -> FreeParameterExpression: raise ValueError(f"Unsupported string detected: {node}") def __add__(self, other): - if issubclass(type(other), FreeParameterExpression): + if isinstance(other, FreeParameterExpression): return FreeParameterExpression(self.expression + other.expression) else: return FreeParameterExpression(self.expression + other) @@ -156,7 +156,7 @@ def __radd__(self, other): return FreeParameterExpression(other + self.expression) def __sub__(self, other): - if issubclass(type(other), FreeParameterExpression): + if isinstance(other, FreeParameterExpression): return FreeParameterExpression(self.expression - other.expression) else: return FreeParameterExpression(self.expression - other) @@ -165,7 +165,7 @@ def __rsub__(self, other): return FreeParameterExpression(other - self.expression) def __mul__(self, other): - if issubclass(type(other), FreeParameterExpression): + if isinstance(other, FreeParameterExpression): return FreeParameterExpression(self.expression * other.expression) else: return FreeParameterExpression(self.expression * other) @@ -174,7 +174,7 @@ def __rmul__(self, other): return FreeParameterExpression(other * self.expression) def __pow__(self, other, modulo=None): - if issubclass(type(other), FreeParameterExpression): + if isinstance(other, FreeParameterExpression): return FreeParameterExpression(self.expression**other.expression) else: return FreeParameterExpression(self.expression**other) @@ -208,21 +208,11 @@ def to_ast(self, program: Program) -> Expression: Returns: Expression: The AST node. """ + # TODO (#822): capture expressions into expression ASTs rather than just an Identifier + identifier = Identifier(name=self) if isinstance(self._type, DurationType): - return DurationLiteral(_FreeParameterExpressionIdentifier(self), TimeUnit.s) - return _FreeParameterExpressionIdentifier(self) - - -class _FreeParameterExpressionIdentifier(Identifier): - """Dummy AST node with FreeParameterExpression instance attached""" - - def __init__(self, expression: FreeParameterExpression): - super().__init__(name=f"FreeParameterExpression({expression})") - self._expression = expression - - @property - def expression(self) -> FreeParameterExpression: - return self._expression + return DurationLiteral(identifier, TimeUnit.s) + return identifier def subs_if_free_parameter(parameter: Any, **kwargs) -> Any: diff --git a/src/braket/pulse/ast/free_parameters.py b/src/braket/pulse/ast/free_parameters.py index 639645b27..96e319eb6 100644 --- a/src/braket/pulse/ast/free_parameters.py +++ b/src/braket/pulse/ast/free_parameters.py @@ -18,10 +18,7 @@ from oqpy.program import Program from oqpy.timing import OQDurationLiteral -from braket.parametric.free_parameter_expression import ( - FreeParameterExpression, - _FreeParameterExpressionIdentifier, -) +from braket.parametric.free_parameter_expression import FreeParameterExpression class _FreeParameterTransformer(QASMTransformer): @@ -32,21 +29,27 @@ def __init__(self, param_values: dict[str, float], program: Program): self.program = program super().__init__() - def visit__FreeParameterExpressionIdentifier( - self, identifier: _FreeParameterExpressionIdentifier - ) -> Union[_FreeParameterExpressionIdentifier, ast.FloatLiteral]: - """Visit a FreeParameterExpressionIdentifier. + def visit_Identifier( + self, identifier: ast.Identifier + ) -> Union[ast.Identifier, ast.FloatLiteral]: + """Visit an Identifier. + + If the Identifier is used to hold a `FreeParameterExpression`, it will be simplified + using the given parameter values. + Args: - identifier (_FreeParameterExpressionIdentifier): The identifier. + identifier (Identifier): The identifier. Returns: - Union[_FreeParameterExpressionIdentifier, FloatLiteral]: The transformed expression. + Union[Identifier, FloatLiteral]: The transformed identifier. """ - new_value = identifier.expression.subs(self.param_values) - if isinstance(new_value, FreeParameterExpression): - return _FreeParameterExpressionIdentifier(new_value) - else: - return ast.FloatLiteral(new_value) + if isinstance(identifier.name, FreeParameterExpression): + new_value = FreeParameterExpression(identifier.name).subs(self.param_values) + if isinstance(new_value, FreeParameterExpression): + return ast.Identifier(new_value) + else: + return ast.FloatLiteral(float(new_value)) + return identifier def visit_DurationLiteral(self, duration_literal: ast.DurationLiteral) -> ast.DurationLiteral: """Visit Duration Literal. @@ -58,11 +61,9 @@ def visit_DurationLiteral(self, duration_literal: ast.DurationLiteral) -> ast.Du DurationLiteral: The transformed duration literal. """ duration = duration_literal.value - if not isinstance(duration, _FreeParameterExpressionIdentifier): + if not isinstance(duration, ast.Identifier): return duration_literal - new_duration = duration.expression.subs(self.param_values) + new_duration = FreeParameterExpression(duration.name).subs(self.param_values) if isinstance(new_duration, FreeParameterExpression): - return ast.DurationLiteral( - _FreeParameterExpressionIdentifier(new_duration), duration_literal.unit - ) + return ast.DurationLiteral(ast.Identifier(str(new_duration)), duration_literal.unit) return OQDurationLiteral(new_duration).to_ast(self.program) diff --git a/src/braket/pulse/ast/qasm_parser.py b/src/braket/pulse/ast/qasm_parser.py index 1f5c48cce..e892ba86b 100644 --- a/src/braket/pulse/ast/qasm_parser.py +++ b/src/braket/pulse/ast/qasm_parser.py @@ -18,8 +18,6 @@ from openqasm3.ast import DurationLiteral from openqasm3.printer import PrinterState -from braket.pulse.ast.free_parameters import _FreeParameterExpressionIdentifier - class _PulsePrinter(Printer): """Walks the AST and prints it to an OpenQASM3 string.""" @@ -27,15 +25,13 @@ class _PulsePrinter(Printer): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - def visit__FreeParameterExpressionIdentifier( - self, node: ast.Identifier, context: PrinterState - ) -> None: - """Visit a FreeParameterExpressionIdentifier. + def visit_Identifier(self, node: ast.Identifier, context: PrinterState) -> None: + """Visit an Identifier. Args: node (ast.Identifier): The identifier. context (PrinterState): The printer state context. """ - self.stream.write(str(node.expression.expression)) + self.stream.write(str(node.name)) def visit_DurationLiteral(self, node: DurationLiteral, context: PrinterState) -> None: """Visit Duration Literal. @@ -46,8 +42,8 @@ def visit_DurationLiteral(self, node: DurationLiteral, context: PrinterState) -> context (PrinterState): The printer state context. """ duration = node.value - if isinstance(duration, _FreeParameterExpressionIdentifier): - self.stream.write(f"({duration.expression}) * 1{node.unit.name}") + if isinstance(duration, ast.Identifier): + self.stream.write(f"({duration.name}) * 1{node.unit.name}") else: super().visit_DurationLiteral(node, context) diff --git a/src/braket/pulse/pulse_sequence.py b/src/braket/pulse/pulse_sequence.py index 4e3bb06d4..4fac2bde0 100644 --- a/src/braket/pulse/pulse_sequence.py +++ b/src/braket/pulse/pulse_sequence.py @@ -26,10 +26,7 @@ from braket.parametric.free_parameter_expression import FreeParameterExpression from braket.parametric.parameterizable import Parameterizable from braket.pulse.ast.approximation_parser import _ApproximationParser -from braket.pulse.ast.free_parameters import ( - _FreeParameterExpressionIdentifier, - _FreeParameterTransformer, -) +from braket.pulse.ast.free_parameters import _FreeParameterTransformer from braket.pulse.ast.qasm_parser import ast_to_qasm from braket.pulse.ast.qasm_transformer import _IRQASMTransformer from braket.pulse.frame import Frame @@ -325,7 +322,7 @@ def _format_parameter_ast( self, parameter: Union[float, FreeParameterExpression], _type: ast.ClassicalType = ast.FloatType(), - ) -> Union[float, _FreeParameterExpressionIdentifier]: + ) -> Union[float, FreeParameterExpression]: if isinstance(parameter, FreeParameterExpression): for p in parameter.expression.free_symbols: self._free_parameters.add(FreeParameter(p.name)) diff --git a/test/unit_tests/braket/pulse/test_pulse_sequence.py b/test/unit_tests/braket/pulse/test_pulse_sequence.py index 4890a4b36..cbf57a985 100644 --- a/test/unit_tests/braket/pulse/test_pulse_sequence.py +++ b/test/unit_tests/braket/pulse/test_pulse_sequence.py @@ -194,7 +194,7 @@ def test_pulse_sequence_make_bound_pulse_sequence(predefined_frame_1, predefined ) assert b_bound.to_ir() == b_bound_call.to_ir() == expected_str_b_bound assert pulse_sequence.to_ir() == expected_str_unbound - assert b_bound.parameters == set([FreeParameter("sigma_g"), FreeParameter("a")]) + assert b_bound.parameters == {FreeParameter("sigma_g"), FreeParameter("a")} both_bound = b_bound.make_bound_pulse_sequence({"a": 1, "sigma_g": 0.7}) both_bound_call = b_bound_call(1, sigma_g=0.7) # use arg 1 for a expected_str_both_bound = "\n".join( @@ -206,11 +206,11 @@ def test_pulse_sequence_make_bound_pulse_sequence(predefined_frame_1, predefined " waveform constant_wf = constant(4.0ms, 2.0 + 0.3im);", " waveform arb_wf = {1.0 + 0.4im, 0, 0.3, 0.1 + 0.2im};", " bit[2] psb;", - " set_frequency(predefined_frame_1, 5);", - " shift_frequency(predefined_frame_1, 5);", - " set_phase(predefined_frame_1, 5);", - " shift_phase(predefined_frame_1, 5);", - " set_scale(predefined_frame_1, 5);", + " set_frequency(predefined_frame_1, 5.0);", + " shift_frequency(predefined_frame_1, 5.0);", + " set_phase(predefined_frame_1, 5.0);", + " shift_phase(predefined_frame_1, 5.0);", + " set_scale(predefined_frame_1, 5.0);", " psb[0] = capture_v0(predefined_frame_1);", " delay[5s] predefined_frame_1, predefined_frame_2;", " delay[5s] predefined_frame_1;", From ec2a559441a64ef84e626ef4f53ca92862ee49ea Mon Sep 17 00:00:00 2001 From: Jean-Christophe Jaskula Date: Tue, 19 Dec 2023 11:34:31 -0500 Subject: [PATCH 21/43] declare input parameters with pulse sequences --- src/braket/pulse/ast/qasm_transformer.py | 40 +++++++++++++++++++ src/braket/pulse/pulse_sequence.py | 19 ++++++++- .../braket/pulse/test_pulse_sequence.py | 9 +++++ 3 files changed, 66 insertions(+), 2 deletions(-) diff --git a/src/braket/pulse/ast/qasm_transformer.py b/src/braket/pulse/ast/qasm_transformer.py index f5e350883..6626bf6b5 100644 --- a/src/braket/pulse/ast/qasm_transformer.py +++ b/src/braket/pulse/ast/qasm_transformer.py @@ -57,3 +57,43 @@ def visit_ExpressionStatement(self, expression_statement: ast.ExpressionStatemen return new_val else: return expression_statement + + def visit_Program(self, program: ast.Program) -> ast.Program: + """Visit a Program. + Args: + program (Program): The program. + Returns: + Program: the modified program. + """ + new_statement_list = [] + for statement in program.statements: + if isinstance(statement, ast.CalibrationStatement): + input_vars, body = self.split_input_vars(statement.body) + new_statement_list.extend(input_vars) + new_statement_list.append(ast.CalibrationStatement(body)) + else: + new_statement_list.append(statement) + + program.statements = new_statement_list + return self.generic_visit(program) + + def split_input_vars( + self, + body: list[ast.Statement], + ) -> tuple[list[ast.IODeclaration], list[ast.Statement]]: + """Split input vars out of the calibrationStatement block + + Args: + body (list[Statement]): The list of statement in the CalibrationStatement block + Returns: + tuple[list[IODeclaration], list[Statement]]: A tuple of input vars and the list + of remaining statements. + """ + input_vars = [] + new_body = [] + for child in body: + if isinstance(child, ast.IODeclaration) and child.io_identifier is ast.IOKeyword.input: + input_vars.append(child) + else: + new_body.append(child) + return input_vars, new_body diff --git a/src/braket/pulse/pulse_sequence.py b/src/braket/pulse/pulse_sequence.py index 4fac2bde0..059af35b5 100644 --- a/src/braket/pulse/pulse_sequence.py +++ b/src/braket/pulse/pulse_sequence.py @@ -19,7 +19,7 @@ from typing import Any, Union from openpulse import ast -from oqpy import BitVar, PhysicalQubits, Program +from oqpy import BitVar, FloatVar, PhysicalQubits, Program from oqpy.timing import OQDurationLiteral from braket.parametric.free_parameter import FreeParameter @@ -228,12 +228,20 @@ def play(self, frame: Frame, waveform: Waveform) -> PulseSequence: """ _validate_uniqueness(self._frames, frame) _validate_uniqueness(self._waveforms, waveform) - self._program.play(frame=frame, waveform=waveform) if isinstance(waveform, Parameterizable): for param in waveform.parameters: if isinstance(param, FreeParameterExpression): for p in param.expression.free_symbols: + self._program._add_var( + FloatVar( + name=p.name, + size=None, + init_expression="input", + needs_declaration=True, + ) + ) self._free_parameters.add(FreeParameter(p.name)) + self._program.play(frame=frame, waveform=waveform) self._frames[frame.id] = frame self._waveforms[waveform.id] = waveform return self @@ -280,6 +288,8 @@ def make_bound_pulse_sequence(self, param_values: dict[str, float]) -> PulseSequ new_pulse_sequence = PulseSequence() new_pulse_sequence._program = new_program + for param_name in param_values: + new_pulse_sequence._program.undeclared_vars.pop(param_name, None) new_pulse_sequence._frames = deepcopy(self._frames) new_pulse_sequence._waveforms = { wf.id: wf.bind_values(**param_values) if isinstance(wf, Parameterizable) else wf @@ -325,6 +335,11 @@ def _format_parameter_ast( ) -> Union[float, FreeParameterExpression]: if isinstance(parameter, FreeParameterExpression): for p in parameter.expression.free_symbols: + self._program._add_var( + FloatVar( + name=p.name, size=None, init_expression="input", needs_declaration=True + ) + ) self._free_parameters.add(FreeParameter(p.name)) return ( FreeParameterExpression(parameter, _type) diff --git a/test/unit_tests/braket/pulse/test_pulse_sequence.py b/test/unit_tests/braket/pulse/test_pulse_sequence.py index cbf57a985..555cac1f9 100644 --- a/test/unit_tests/braket/pulse/test_pulse_sequence.py +++ b/test/unit_tests/braket/pulse/test_pulse_sequence.py @@ -124,6 +124,13 @@ def test_pulse_sequence_make_bound_pulse_sequence(predefined_frame_1, predefined expected_str_unbound = "\n".join( [ "OPENQASM 3.0;", + "input float[64] b;", + "input float[64] a;", + "input float[64] length_g;", + "input float[64] sigma_g;", + "input float[64] length_dg;", + "input float[64] sigma_dg;", + "input float[64] length_c;", "cal {", " waveform gauss_wf = gaussian((length_g) * 1s, (sigma_g) * 1s, 1, false);", " waveform drag_gauss_wf = drag_gaussian((length_dg) * 1s," @@ -168,6 +175,8 @@ def test_pulse_sequence_make_bound_pulse_sequence(predefined_frame_1, predefined expected_str_b_bound = "\n".join( [ "OPENQASM 3.0;", + "input float[64] a;", + "input float[64] sigma_g;", "cal {", " waveform gauss_wf = gaussian(1.0ms, (sigma_g) * 1s, 1, false);", " waveform drag_gauss_wf = drag_gaussian(3.0ms, 400.0ms, 0.2, 1, false);", From 7d7b24f93435f7efafb3fca8edbc4380cb342af8 Mon Sep 17 00:00:00 2001 From: Jean-Christophe Jaskula Date: Tue, 19 Dec 2023 15:38:11 -0500 Subject: [PATCH 22/43] use machine-size types --- src/braket/pulse/__init__.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/braket/pulse/__init__.py b/src/braket/pulse/__init__.py index 01ef66892..51414682d 100644 --- a/src/braket/pulse/__init__.py +++ b/src/braket/pulse/__init__.py @@ -11,6 +11,8 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. +import oqpy + from braket.pulse.frame import Frame # noqa: F401 from braket.pulse.port import Port # noqa: F401 from braket.pulse.pulse_sequence import PulseSequence # noqa: F401 @@ -20,3 +22,8 @@ DragGaussianWaveform, GaussianWaveform, ) + +oqpy.AngleVar.default_size = None +oqpy.FloatVar.default_size = None +oqpy.IntVar.default_size = None +oqpy.UintVar.default_size = None From c21768ed4fe6fe626dd9feabcffe3c1fba615054 Mon Sep 17 00:00:00 2001 From: Jean-Christophe Jaskula Date: Tue, 19 Dec 2023 15:50:57 -0500 Subject: [PATCH 23/43] create _InputVarSplitter --- src/braket/pulse/ast/qasm_transformer.py | 28 +++++++++++-------- src/braket/pulse/pulse_sequence.py | 3 +- .../braket/pulse/test_pulse_sequence.py | 18 ++++++------ 3 files changed, 28 insertions(+), 21 deletions(-) diff --git a/src/braket/pulse/ast/qasm_transformer.py b/src/braket/pulse/ast/qasm_transformer.py index 6626bf6b5..904bebc67 100644 --- a/src/braket/pulse/ast/qasm_transformer.py +++ b/src/braket/pulse/ast/qasm_transformer.py @@ -58,6 +58,14 @@ def visit_ExpressionStatement(self, expression_statement: ast.ExpressionStatemen else: return expression_statement + +class _InputVarSplitter(QASMTransformer): + """ + QASMTransformer which walks the AST and makes the necessary modifications needed + for IR generation. Currently, it performs the following operations: + * Bubbles up input variables to the top of the CalibrationStatement block. + """ + def visit_Program(self, program: ast.Program) -> ast.Program: """Visit a Program. Args: @@ -68,9 +76,8 @@ def visit_Program(self, program: ast.Program) -> ast.Program: new_statement_list = [] for statement in program.statements: if isinstance(statement, ast.CalibrationStatement): - input_vars, body = self.split_input_vars(statement.body) - new_statement_list.extend(input_vars) - new_statement_list.append(ast.CalibrationStatement(body)) + reordered_cal_block_statements = self.split_input_vars(statement) + new_statement_list.extend(reordered_cal_block_statements) else: new_statement_list.append(statement) @@ -79,21 +86,20 @@ def visit_Program(self, program: ast.Program) -> ast.Program: def split_input_vars( self, - body: list[ast.Statement], - ) -> tuple[list[ast.IODeclaration], list[ast.Statement]]: - """Split input vars out of the calibrationStatement block + node: ast.CalibrationStatement, + ) -> list[ast.Statement]: + """Split input variables out of the calibrationStatement block. Args: - body (list[Statement]): The list of statement in the CalibrationStatement block + node (CalibrationStatement): The CalibrationStatement block. Returns: - tuple[list[IODeclaration], list[Statement]]: A tuple of input vars and the list - of remaining statements. + list[Statement]: The list of statements with input variables outside and in front. """ input_vars = [] new_body = [] - for child in body: + for child in node.body: if isinstance(child, ast.IODeclaration) and child.io_identifier is ast.IOKeyword.input: input_vars.append(child) else: new_body.append(child) - return input_vars, new_body + return input_vars + [ast.CalibrationStatement(new_body)] diff --git a/src/braket/pulse/pulse_sequence.py b/src/braket/pulse/pulse_sequence.py index 059af35b5..7d820a014 100644 --- a/src/braket/pulse/pulse_sequence.py +++ b/src/braket/pulse/pulse_sequence.py @@ -28,7 +28,7 @@ from braket.pulse.ast.approximation_parser import _ApproximationParser from braket.pulse.ast.free_parameters import _FreeParameterTransformer from braket.pulse.ast.qasm_parser import ast_to_qasm -from braket.pulse.ast.qasm_transformer import _IRQASMTransformer +from braket.pulse.ast.qasm_transformer import _InputVarSplitter, _IRQASMTransformer from braket.pulse.frame import Frame from braket.pulse.pulse_sequence_trace import PulseSequenceTrace from braket.pulse.waveforms import Waveform @@ -326,6 +326,7 @@ def to_ir(self) -> str: tree = _IRQASMTransformer(register_identifier).visit(tree) else: tree = program.to_ast(encal=True, include_externs=False) + tree = _InputVarSplitter().visit(tree) return ast_to_qasm(tree) def _format_parameter_ast( diff --git a/test/unit_tests/braket/pulse/test_pulse_sequence.py b/test/unit_tests/braket/pulse/test_pulse_sequence.py index 555cac1f9..13d835965 100644 --- a/test/unit_tests/braket/pulse/test_pulse_sequence.py +++ b/test/unit_tests/braket/pulse/test_pulse_sequence.py @@ -124,13 +124,13 @@ def test_pulse_sequence_make_bound_pulse_sequence(predefined_frame_1, predefined expected_str_unbound = "\n".join( [ "OPENQASM 3.0;", - "input float[64] b;", - "input float[64] a;", - "input float[64] length_g;", - "input float[64] sigma_g;", - "input float[64] length_dg;", - "input float[64] sigma_dg;", - "input float[64] length_c;", + "input float b;", + "input float a;", + "input float length_g;", + "input float sigma_g;", + "input float length_dg;", + "input float sigma_dg;", + "input float length_c;", "cal {", " waveform gauss_wf = gaussian((length_g) * 1s, (sigma_g) * 1s, 1, false);", " waveform drag_gauss_wf = drag_gaussian((length_dg) * 1s," @@ -175,8 +175,8 @@ def test_pulse_sequence_make_bound_pulse_sequence(predefined_frame_1, predefined expected_str_b_bound = "\n".join( [ "OPENQASM 3.0;", - "input float[64] a;", - "input float[64] sigma_g;", + "input float a;", + "input float sigma_g;", "cal {", " waveform gauss_wf = gaussian(1.0ms, (sigma_g) * 1s, 1, false);", " waveform drag_gauss_wf = drag_gaussian(3.0ms, 400.0ms, 0.2, 1, false);", From e536d9fac495c195cee2d26cef82000257dda9e0 Mon Sep 17 00:00:00 2001 From: Jean-Christophe Jaskula Date: Tue, 19 Dec 2023 15:56:00 -0500 Subject: [PATCH 24/43] remove never visited branch --- src/braket/pulse/ast/qasm_transformer.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/braket/pulse/ast/qasm_transformer.py b/src/braket/pulse/ast/qasm_transformer.py index 904bebc67..74586fafe 100644 --- a/src/braket/pulse/ast/qasm_transformer.py +++ b/src/braket/pulse/ast/qasm_transformer.py @@ -78,8 +78,6 @@ def visit_Program(self, program: ast.Program) -> ast.Program: if isinstance(statement, ast.CalibrationStatement): reordered_cal_block_statements = self.split_input_vars(statement) new_statement_list.extend(reordered_cal_block_statements) - else: - new_statement_list.append(statement) program.statements = new_statement_list return self.generic_visit(program) From 4c9710a79cdf32e0e9c918fbc610354aafc73b05 Mon Sep 17 00:00:00 2001 From: Jean-Christophe Jaskula Date: Tue, 19 Dec 2023 16:32:46 -0500 Subject: [PATCH 25/43] fix partial coverage --- src/braket/pulse/ast/qasm_transformer.py | 11 ++++------- test/unit_tests/braket/pulse/test_pulse_sequence.py | 2 +- 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/src/braket/pulse/ast/qasm_transformer.py b/src/braket/pulse/ast/qasm_transformer.py index 74586fafe..a8cc9b9b3 100644 --- a/src/braket/pulse/ast/qasm_transformer.py +++ b/src/braket/pulse/ast/qasm_transformer.py @@ -73,13 +73,10 @@ def visit_Program(self, program: ast.Program) -> ast.Program: Returns: Program: the modified program. """ - new_statement_list = [] - for statement in program.statements: - if isinstance(statement, ast.CalibrationStatement): - reordered_cal_block_statements = self.split_input_vars(statement) - new_statement_list.extend(reordered_cal_block_statements) - - program.statements = new_statement_list + assert len(program.statements) == 1 and isinstance( + program.statements[0], ast.CalibrationStatement + ) + program.statements = self.split_input_vars(program.statements[0]) return self.generic_visit(program) def split_input_vars( diff --git a/test/unit_tests/braket/pulse/test_pulse_sequence.py b/test/unit_tests/braket/pulse/test_pulse_sequence.py index 13d835965..37c4c61f2 100644 --- a/test/unit_tests/braket/pulse/test_pulse_sequence.py +++ b/test/unit_tests/braket/pulse/test_pulse_sequence.py @@ -124,8 +124,8 @@ def test_pulse_sequence_make_bound_pulse_sequence(predefined_frame_1, predefined expected_str_unbound = "\n".join( [ "OPENQASM 3.0;", - "input float b;", "input float a;", + "input float b;", "input float length_g;", "input float sigma_g;", "input float length_dg;", From 1bc972bc13b61ad7e63c20b4c97a25f058bf08de Mon Sep 17 00:00:00 2001 From: Jean-Christophe Jaskula Date: Tue, 19 Dec 2023 17:02:30 -0500 Subject: [PATCH 26/43] hacking test because the set order changes with python version --- test/unit_tests/braket/pulse/test_pulse_sequence.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/test/unit_tests/braket/pulse/test_pulse_sequence.py b/test/unit_tests/braket/pulse/test_pulse_sequence.py index 37c4c61f2..88945aa9c 100644 --- a/test/unit_tests/braket/pulse/test_pulse_sequence.py +++ b/test/unit_tests/braket/pulse/test_pulse_sequence.py @@ -12,6 +12,7 @@ # language governing permissions and limitations under the License. import pytest +from oqpy import FloatVar from braket.circuits import FreeParameter, QubitSet from braket.pulse import ( @@ -124,13 +125,11 @@ def test_pulse_sequence_make_bound_pulse_sequence(predefined_frame_1, predefined expected_str_unbound = "\n".join( [ "OPENQASM 3.0;", - "input float a;", - "input float b;", - "input float length_g;", - "input float sigma_g;", - "input float length_dg;", - "input float sigma_dg;", - "input float length_c;", + *[ + f"input float {var.name};" + for var in pulse_sequence._program.undeclared_vars.values() + if isinstance(var, FloatVar) + ], "cal {", " waveform gauss_wf = gaussian((length_g) * 1s, (sigma_g) * 1s, 1, false);", " waveform drag_gauss_wf = drag_gaussian((length_dg) * 1s," From ab492d905a60cf615cea06e23cea5c6c1399cd99 Mon Sep 17 00:00:00 2001 From: Jean-Christophe Jaskula Date: Tue, 19 Dec 2023 17:47:19 -0500 Subject: [PATCH 27/43] pass inputs with PulseSequence --- src/braket/aws/aws_quantum_task.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/braket/aws/aws_quantum_task.py b/src/braket/aws/aws_quantum_task.py index c490a4190..91815d0c8 100644 --- a/src/braket/aws/aws_quantum_task.py +++ b/src/braket/aws/aws_quantum_task.py @@ -569,7 +569,12 @@ def _( *args, **kwargs, ) -> AwsQuantumTask: - create_task_kwargs.update({"action": OpenQASMProgram(source=pulse_sequence.to_ir()).json()}) + openqasm_program = OpenQASMProgram( + source=pulse_sequence.to_ir(), + inputs=inputs if inputs else None, + ) + + create_task_kwargs.update({"action": openqasm_program.json()}) task_arn = aws_session.create_quantum_task(**create_task_kwargs) return AwsQuantumTask(task_arn, aws_session, *args, **kwargs) From e7bdac805dbc6b755789b9a207fc2d0890e6d756 Mon Sep 17 00:00:00 2001 From: Jean-Christophe Jaskula Date: Tue, 19 Dec 2023 18:53:30 -0500 Subject: [PATCH 28/43] pass empty dict to OpenQasmProgram --- src/braket/aws/aws_quantum_task.py | 2 +- test/unit_tests/braket/aws/test_aws_quantum_task.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/braket/aws/aws_quantum_task.py b/src/braket/aws/aws_quantum_task.py index 91815d0c8..1b43694f7 100644 --- a/src/braket/aws/aws_quantum_task.py +++ b/src/braket/aws/aws_quantum_task.py @@ -571,7 +571,7 @@ def _( ) -> AwsQuantumTask: openqasm_program = OpenQASMProgram( source=pulse_sequence.to_ir(), - inputs=inputs if inputs else None, + inputs=inputs if inputs else {}, ) create_task_kwargs.update({"action": openqasm_program.json()}) diff --git a/test/unit_tests/braket/aws/test_aws_quantum_task.py b/test/unit_tests/braket/aws/test_aws_quantum_task.py index 656c37dcf..4d8daa36f 100644 --- a/test/unit_tests/braket/aws/test_aws_quantum_task.py +++ b/test/unit_tests/braket/aws/test_aws_quantum_task.py @@ -617,7 +617,7 @@ def test_create_pulse_sequence(aws_session, arn, pulse_sequence): "}", ] ) - expected_program = OpenQASMProgram(source=expected_openqasm) + expected_program = OpenQASMProgram(source=expected_openqasm, inputs={}) aws_session.create_quantum_task.return_value = arn AwsQuantumTask.create(aws_session, SIMULATOR_ARN, pulse_sequence, S3_TARGET, 10) From 3b47b4311405b6fdf715315047179e5ed8253ed5 Mon Sep 17 00:00:00 2001 From: ci Date: Thu, 21 Dec 2023 16:15:31 +0000 Subject: [PATCH 29/43] prepare release v1.65.0 --- CHANGELOG.md | 6 ++++++ src/braket/_sdk/_version.py | 2 +- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3fc496e21..37503c397 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,11 @@ # Changelog +## v1.65.0 (2023-12-21) + +### Features + + * add U and GPhase gates + ## v1.64.2 (2023-12-19) ### Bug Fixes and Other Changes diff --git a/src/braket/_sdk/_version.py b/src/braket/_sdk/_version.py index a30db2260..d21a3b476 100644 --- a/src/braket/_sdk/_version.py +++ b/src/braket/_sdk/_version.py @@ -15,4 +15,4 @@ Version number (major.minor.patch[-label]) """ -__version__ = "1.64.3.dev0" +__version__ = "1.65.0" From 3cf9a080dc18f448af1df0cd54a8e5b33b612b25 Mon Sep 17 00:00:00 2001 From: ci Date: Thu, 21 Dec 2023 16:15:31 +0000 Subject: [PATCH 30/43] update development version to v1.65.1.dev0 --- src/braket/_sdk/_version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/braket/_sdk/_version.py b/src/braket/_sdk/_version.py index d21a3b476..ec7335e82 100644 --- a/src/braket/_sdk/_version.py +++ b/src/braket/_sdk/_version.py @@ -15,4 +15,4 @@ Version number (major.minor.patch[-label]) """ -__version__ = "1.65.0" +__version__ = "1.65.1.dev0" From d85699672cea2c8b88e4efc76e4bd43c0d554e8d Mon Sep 17 00:00:00 2001 From: Jean-Christophe Jaskula <99367153+jcjaskula-aws@users.noreply.github.com> Date: Fri, 22 Dec 2023 16:10:14 -0500 Subject: [PATCH 31/43] fix: validate out circuits that contain only non-zero-qubit gates (#842) * validate out all gphase circuits * update docstring * consider ctrl @ gphase * group error cases --- src/braket/circuits/circuit_helpers.py | 10 +++--- .../braket/circuits/test_circuit_helpers.py | 31 +++++++++++++++---- 2 files changed, 31 insertions(+), 10 deletions(-) diff --git a/src/braket/circuits/circuit_helpers.py b/src/braket/circuits/circuit_helpers.py index 44e9bd571..f0e3f3144 100644 --- a/src/braket/circuits/circuit_helpers.py +++ b/src/braket/circuits/circuit_helpers.py @@ -23,13 +23,15 @@ def validate_circuit_and_shots(circuit: Circuit, shots: int) -> None: shots (int): shots to validate Raises: - ValueError: If circuit has no instructions; if no result types - specified for circuit and `shots=0`. See `braket.circuit.result_types`; + ValueError: If circuit has no instructions; if circuit has a non-gphase instruction; if no + result types specified for circuit and `shots=0`. See `braket.circuit.result_types`; if circuit has observables that cannot be simultaneously measured and `shots>0`; or, if `StateVector` or `Amplitude` are specified as result types when `shots>0`. """ - if not circuit.instructions: - raise ValueError("Circuit must have instructions to run on a device") + if not circuit.instructions or all( + not (inst.target or inst.control) for inst in circuit.instructions + ): + raise ValueError("Circuit must have at least one non-zero-qubit gate to run on a device") if not shots and not circuit.result_types: raise ValueError( "No result types specified for circuit and shots=0. See `braket.circuits.result_types`" diff --git a/test/unit_tests/braket/circuits/test_circuit_helpers.py b/test/unit_tests/braket/circuits/test_circuit_helpers.py index 56f7fd2ec..8960325cb 100644 --- a/test/unit_tests/braket/circuits/test_circuit_helpers.py +++ b/test/unit_tests/braket/circuits/test_circuit_helpers.py @@ -18,17 +18,32 @@ def test_validate_circuit_and_shots_no_instructions(): - with pytest.raises(ValueError): + with pytest.raises( + ValueError, match="Circuit must have at least one non-zero-qubit gate to run on a device" + ): validate_circuit_and_shots(Circuit(), 100) +def test_validate_circuit_and_shots_only_gphase(): + with pytest.raises( + ValueError, match="Circuit must have at least one non-zero-qubit gate to run on a device" + ): + validate_circuit_and_shots(Circuit().gphase(0.15), 100) + + +def test_validate_circuit_and_shots_ctrl_gphase(): + assert validate_circuit_and_shots(Circuit().gphase(0.15, control=[0]), 100) is None + + def test_validate_circuit_and_shots_0_no_instructions(): - with pytest.raises(ValueError): + with pytest.raises( + ValueError, match="Circuit must have at least one non-zero-qubit gate to run on a device" + ): validate_circuit_and_shots(Circuit(), 0) def test_validate_circuit_and_shots_0_no_results(): - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="No result types specified for circuit and shots=0."): validate_circuit_and_shots(Circuit().h(0), 0) @@ -54,12 +69,16 @@ def test_validate_circuit_and_shots_100_results_mixed_result(): def test_validate_circuit_and_shots_100_result_state_vector(): - with pytest.raises(ValueError): + with pytest.raises( + ValueError, match="StateVector or Amplitude cannot be specified when shots>0" + ): validate_circuit_and_shots(Circuit().h(0).state_vector(), 100) def test_validate_circuit_and_shots_100_result_amplitude(): - with pytest.raises(ValueError): + with pytest.raises( + ValueError, match="StateVector or Amplitude cannot be specified when shots>0" + ): validate_circuit_and_shots(Circuit().h(0).amplitude(state=["0"]), 100) @@ -74,7 +93,7 @@ def test_validate_circuit_and_shots_0_noncommuting(): def test_validate_circuit_and_shots_100_noncommuting(): - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="Observables cannot be sampled simultaneously"): validate_circuit_and_shots( Circuit() .h(0) From e96fad8e6f74553f5e39743909966fa7b213be25 Mon Sep 17 00:00:00 2001 From: ci Date: Mon, 25 Dec 2023 16:16:49 +0000 Subject: [PATCH 32/43] prepare release v1.65.1 --- CHANGELOG.md | 6 ++++++ src/braket/_sdk/_version.py | 2 +- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 37503c397..47d87d33a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,11 @@ # Changelog +## v1.65.1 (2023-12-25) + +### Bug Fixes and Other Changes + + * validate out circuits that contain only non-zero-qubit gates + ## v1.65.0 (2023-12-21) ### Features diff --git a/src/braket/_sdk/_version.py b/src/braket/_sdk/_version.py index ec7335e82..ad26cf05f 100644 --- a/src/braket/_sdk/_version.py +++ b/src/braket/_sdk/_version.py @@ -15,4 +15,4 @@ Version number (major.minor.patch[-label]) """ -__version__ = "1.65.1.dev0" +__version__ = "1.65.1" From eb4d1be6813fec41750bf16fab5db941cbf07ac7 Mon Sep 17 00:00:00 2001 From: ci Date: Mon, 25 Dec 2023 16:16:49 +0000 Subject: [PATCH 33/43] update development version to v1.65.2.dev0 --- src/braket/_sdk/_version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/braket/_sdk/_version.py b/src/braket/_sdk/_version.py index ad26cf05f..278494287 100644 --- a/src/braket/_sdk/_version.py +++ b/src/braket/_sdk/_version.py @@ -15,4 +15,4 @@ Version number (major.minor.patch[-label]) """ -__version__ = "1.65.1" +__version__ = "1.65.2.dev0" From 6874b4090d5cfdbf49aa850ef84fd1ce2e95a5a6 Mon Sep 17 00:00:00 2001 From: Jean-Christophe Jaskula Date: Fri, 5 Jan 2024 23:34:06 -0500 Subject: [PATCH 34/43] remove workaround --- src/braket/circuits/circuit.py | 82 ---------------------------------- 1 file changed, 82 deletions(-) diff --git a/src/braket/circuits/circuit.py b/src/braket/circuits/circuit.py index e5a9d0c71..798619a0c 100644 --- a/src/braket/circuits/circuit.py +++ b/src/braket/circuits/circuit.py @@ -1288,16 +1288,6 @@ def _generate_frame_wf_defcal_declarations( if gate_definitions is not None: for key, calibration in gate_definitions.items(): gate, qubits = key - - # Ignoring parametric gates - # Corresponding defcals with fixed arguments have been added - # in _get_frames_waveforms_from_instrs - if isinstance(gate, Parameterizable) and any( - not isinstance(parameter, (float, int, complex)) - for parameter in gate.parameters - ): - continue - gate_name = gate._qasm_name arguments = ( [calibration._format_parameter_ast(value) for value in gate.parameters] @@ -1329,80 +1319,8 @@ def _get_frames_waveforms_from_instrs( for waveform in instruction.operator.pulse_sequence._waveforms.values(): _validate_uniqueness(waveforms, waveform) waveforms[waveform.id] = waveform - # this will change with full parametric calibration support - elif isinstance(instruction.operator, Parameterizable) and gate_definitions is not None: - fixed_argument_calibrations = self._add_fixed_argument_calibrations( - gate_definitions, instruction - ) - gate_definitions.update(fixed_argument_calibrations) return frames, waveforms - def _add_fixed_argument_calibrations( - self, - gate_definitions: dict[tuple[Gate, QubitSet], PulseSequence], - instruction: Instruction, - ) -> dict[tuple[Gate, QubitSet], PulseSequence]: - """Adds calibrations with arguments set to the instruction parameter values - - Given the collection of parameters in instruction.operator, this function looks for matching - parametric calibrations that have free parameters. If such a calibration is found and the - number N of its free parameters equals the number of instruction parameters, we can bind - the arguments of the calibration and add it to the calibration dictionary. - - If N is smaller, it is probably impossible to assign the instruction parameter values to the - corresponding calibration parameters so we raise an error. - If N=0, we ignore it as it will not be removed by _generate_frame_wf_defcal_declarations. - - Args: - gate_definitions (dict[tuple[Gate, QubitSet], PulseSequence]): a dictionary of - calibrations - instruction (Instruction): a Circuit instruction - - Returns: - dict[tuple[Gate, QubitSet], PulseSequence]: additional calibrations - - Raises: - NotImplementedError: in two cases: (i) if the instruction contains unbound parameters - and the calibration dictionary contains a parametric calibration applicable to this - instructions; (ii) if the calibration is defined with a partial number of unbound - parameters. - """ - additional_calibrations = {} - for key, calibration in gate_definitions.items(): - gate = key[0] - target = key[1] - if target != instruction.target: - continue - if isinstance(gate, type(instruction.operator)) and len( - instruction.operator.parameters - ) == len(gate.parameters): - free_parameter_number = sum( - [isinstance(p, FreeParameterExpression) for p in gate.parameters] - ) - if free_parameter_number == 0: - continue - elif free_parameter_number < len(gate.parameters): - raise NotImplementedError( - "Calibrations with a partial number of fixed parameters are not supported." - ) - elif any( - isinstance(p, FreeParameterExpression) for p in instruction.operator.parameters - ): - raise NotImplementedError( - "Parametric calibrations cannot be attached with parametric circuits." - ) - bound_key = ( - type(instruction.operator)(*instruction.operator.parameters), - instruction.target, - ) - additional_calibrations[bound_key] = calibration( - **{ - p.name if isinstance(p, FreeParameterExpression) else p: v - for p, v in zip(gate.parameters, instruction.operator.parameters) - } - ) - return additional_calibrations - def to_unitary(self) -> np.ndarray: """ Returns the unitary matrix representation of the entire circuit. From ea106d9d553280a57a99614c42648317cb924fec Mon Sep 17 00:00:00 2001 From: Jean-Christophe Jaskula Date: Tue, 9 Jan 2024 17:38:16 -0500 Subject: [PATCH 35/43] fix tests --- src/braket/circuits/circuit.py | 7 +- src/braket/pulse/pulse_sequence.py | 11 +-- .../braket/circuits/test_circuit.py | 74 ++++++++++++++++--- .../braket/pulse/test_pulse_sequence.py | 24 ++---- 4 files changed, 82 insertions(+), 34 deletions(-) diff --git a/src/braket/circuits/circuit.py b/src/braket/circuits/circuit.py index 798619a0c..ed4a19e83 100644 --- a/src/braket/circuits/circuit.py +++ b/src/braket/circuits/circuit.py @@ -1290,7 +1290,12 @@ def _generate_frame_wf_defcal_declarations( gate, qubits = key gate_name = gate._qasm_name arguments = ( - [calibration._format_parameter_ast(value) for value in gate.parameters] + [ + oqpy.FloatVar(name=value.name) + if isinstance(value, FreeParameter) + else value + for value in gate.parameters + ] if isinstance(gate, Parameterizable) else None ) diff --git a/src/braket/pulse/pulse_sequence.py b/src/braket/pulse/pulse_sequence.py index 7d820a014..a243daa89 100644 --- a/src/braket/pulse/pulse_sequence.py +++ b/src/braket/pulse/pulse_sequence.py @@ -317,6 +317,12 @@ def to_ir(self) -> str: str: a str representing the OpenPulse program encoding the PulseSequence. """ program = deepcopy(self._program) + for param in self._free_parameters: + program._add_var( + FloatVar( + name=param.name, size=None, init_expression="input", needs_declaration=True + ) + ) if self._capture_v0_count: register_identifier = "psb" program.declare( @@ -336,11 +342,6 @@ def _format_parameter_ast( ) -> Union[float, FreeParameterExpression]: if isinstance(parameter, FreeParameterExpression): for p in parameter.expression.free_symbols: - self._program._add_var( - FloatVar( - name=p.name, size=None, init_expression="input", needs_declaration=True - ) - ) self._free_parameters.add(FreeParameter(p.name)) return ( FreeParameterExpression(parameter, _type) diff --git a/test/unit_tests/braket/circuits/test_circuit.py b/test/unit_tests/braket/circuits/test_circuit.py index 626b4d4cf..73074e3f4 100644 --- a/test/unit_tests/braket/circuits/test_circuit.py +++ b/test/unit_tests/braket/circuits/test_circuit.py @@ -746,10 +746,16 @@ def test_ir_non_empty_instructions_result_types_basis_rotation_instructions(): " set_frequency(predefined_frame_1, 6000000.0);", " play(predefined_frame_1, drag_gauss_wf);", "}", - "defcal rx(0.15) $0 {", + "defcal rx(float theta) $0 {", " set_frequency(predefined_frame_1, 6000000.0);", " play(predefined_frame_1, drag_gauss_wf);", "}", + "defcal ms(float alpha, float beta, float gamma) $0, $1 {", + " shift_phase(predefined_frame_1, alpha);", + " set_phase(predefined_frame_1, gamma);", + " shift_phase(predefined_frame_1, beta);", + " play(predefined_frame_1, drag_gauss_wf);", + "}", "rx(0.15) q[0];", "rx(0.3) q[1];", "b[0] = measure q[0];", @@ -775,10 +781,16 @@ def test_ir_non_empty_instructions_result_types_basis_rotation_instructions(): " set_frequency(predefined_frame_1, 6000000.0);", " play(predefined_frame_1, drag_gauss_wf);", "}", - "defcal rx(0.15) $0 {", + "defcal rx(float theta) $0 {", " set_frequency(predefined_frame_1, 6000000.0);", " play(predefined_frame_1, drag_gauss_wf);", "}", + "defcal ms(float alpha, float beta, float gamma) $0, $1 {", + " shift_phase(predefined_frame_1, alpha);", + " set_phase(predefined_frame_1, gamma);", + " shift_phase(predefined_frame_1, beta);", + " play(predefined_frame_1, drag_gauss_wf);", + "}", "rx(0.15) $0;", "rx(0.3) $4;", "b[0] = measure $0;", @@ -806,10 +818,16 @@ def test_ir_non_empty_instructions_result_types_basis_rotation_instructions(): " set_frequency(predefined_frame_1, 6000000.0);", " play(predefined_frame_1, drag_gauss_wf);", "}", - "defcal rx(0.15) $0 {", + "defcal rx(float theta) $0 {", " set_frequency(predefined_frame_1, 6000000.0);", " play(predefined_frame_1, drag_gauss_wf);", "}", + "defcal ms(float alpha, float beta, float gamma) $0, $1 {", + " shift_phase(predefined_frame_1, alpha);", + " set_phase(predefined_frame_1, gamma);", + " shift_phase(predefined_frame_1, beta);", + " play(predefined_frame_1, drag_gauss_wf);", + "}", "rx(0.15) $0;", "#pragma braket verbatim", "box{", @@ -841,10 +859,16 @@ def test_ir_non_empty_instructions_result_types_basis_rotation_instructions(): " set_frequency(predefined_frame_1, 6000000.0);", " play(predefined_frame_1, drag_gauss_wf);", "}", - "defcal rx(0.15) $0 {", + "defcal rx(float theta) $0 {", " set_frequency(predefined_frame_1, 6000000.0);", " play(predefined_frame_1, drag_gauss_wf);", "}", + "defcal ms(float alpha, float beta, float gamma) $0, $1 {", + " shift_phase(predefined_frame_1, alpha);", + " set_phase(predefined_frame_1, gamma);", + " shift_phase(predefined_frame_1, beta);", + " play(predefined_frame_1, drag_gauss_wf);", + "}", "rx(0.15) q[0];", "rx(0.3) q[4];", "#pragma braket noise bit_flip(0.2) q[3]", @@ -872,10 +896,16 @@ def test_ir_non_empty_instructions_result_types_basis_rotation_instructions(): " set_frequency(predefined_frame_1, 6000000.0);", " play(predefined_frame_1, drag_gauss_wf);", "}", - "defcal rx(0.15) $0 {", + "defcal rx(float theta) $0 {", " set_frequency(predefined_frame_1, 6000000.0);", " play(predefined_frame_1, drag_gauss_wf);", "}", + "defcal ms(float alpha, float beta, float gamma) $0, $1 {", + " shift_phase(predefined_frame_1, alpha);", + " set_phase(predefined_frame_1, gamma);", + " shift_phase(predefined_frame_1, beta);", + " play(predefined_frame_1, drag_gauss_wf);", + "}", "rx(0.15) q[0];", "rx(theta) q[1];", "b[0] = measure q[0];", @@ -905,10 +935,16 @@ def test_ir_non_empty_instructions_result_types_basis_rotation_instructions(): " set_frequency(predefined_frame_1, 6000000.0);", " play(predefined_frame_1, drag_gauss_wf);", "}", - "defcal rx(0.15) $0 {", + "defcal rx(float theta) $0 {", " set_frequency(predefined_frame_1, 6000000.0);", " play(predefined_frame_1, drag_gauss_wf);", "}", + "defcal ms(float alpha, float beta, float gamma) $0, $1 {", + " shift_phase(predefined_frame_1, alpha);", + " set_phase(predefined_frame_1, gamma);", + " shift_phase(predefined_frame_1, beta);", + " play(predefined_frame_1, drag_gauss_wf);", + "}", "negctrl @ rx(0.15) q[2], q[0];", "ctrl(2) @ rx(0.3) q[2], q[3], q[1];", "ctrl(2) @ cnot q[2], q[3], q[4], q[0];", @@ -939,6 +975,16 @@ def test_ir_non_empty_instructions_result_types_basis_rotation_instructions(): " set_frequency(predefined_frame_1, 6000000.0);", " play(predefined_frame_1, drag_gauss_wf);", "}", + "defcal rx(float theta) $0 {", + " set_frequency(predefined_frame_1, 6000000.0);", + " play(predefined_frame_1, drag_gauss_wf);", + "}", + "defcal ms(float alpha, float beta, float gamma) $0, $1 {", + " shift_phase(predefined_frame_1, alpha);", + " set_phase(predefined_frame_1, gamma);", + " shift_phase(predefined_frame_1, beta);", + " play(predefined_frame_1, drag_gauss_wf);", + "}", "cnot q[0], q[1];", "cnot q[3], q[2];", "ctrl @ cnot q[5], q[6], q[4];", @@ -971,10 +1017,14 @@ def test_ir_non_empty_instructions_result_types_basis_rotation_instructions(): " set_frequency(predefined_frame_1, 6000000.0);", " play(predefined_frame_1, drag_gauss_wf);", "}", - "defcal ms(-0.1, -0.2, -0.3) $0, $1 {", - " shift_phase(predefined_frame_1, -0.1);", - " set_phase(predefined_frame_1, -0.3);", - " shift_phase(predefined_frame_1, -0.2);", + "defcal rx(float theta) $0 {", + " set_frequency(predefined_frame_1, 6000000.0);", + " play(predefined_frame_1, drag_gauss_wf);", + "}", + "defcal ms(float alpha, float beta, float gamma) $0, $1 {", + " shift_phase(predefined_frame_1, alpha);", + " set_phase(predefined_frame_1, gamma);", + " shift_phase(predefined_frame_1, beta);", " play(predefined_frame_1, drag_gauss_wf);", "}", "inv @ pow(2.5) @ h q[0];", @@ -1132,10 +1182,10 @@ def foo( "cal {", " waveform drag_gauss_wf = drag_gaussian" + "(3.0ms, 400.0ms, 0.2, 1, false);", "}", - "defcal foo(-0.2) $0 {", + "defcal foo(float beta) $0 {", " shift_phase(predefined_frame_1, -0.1);", " set_phase(predefined_frame_1, -0.3);", - " shift_phase(predefined_frame_1, -0.2);", + " shift_phase(predefined_frame_1, beta);", " play(predefined_frame_1, drag_gauss_wf);", "}", "foo(-0.2) q[0];", diff --git a/test/unit_tests/braket/pulse/test_pulse_sequence.py b/test/unit_tests/braket/pulse/test_pulse_sequence.py index 02f28e93f..8de93ce1a 100644 --- a/test/unit_tests/braket/pulse/test_pulse_sequence.py +++ b/test/unit_tests/braket/pulse/test_pulse_sequence.py @@ -12,7 +12,6 @@ # language governing permissions and limitations under the License. import pytest -from oqpy import FloatVar from braket.circuits import FreeParameter, QubitSet from braket.pulse import ( @@ -125,22 +124,17 @@ def test_pulse_sequence_make_bound_pulse_sequence(predefined_frame_1, predefined expected_str_unbound = "\n".join( [ "OPENQASM 3.0;", - *[ - f"input float {var.name};" - for var in pulse_sequence._program.undeclared_vars.values() - if isinstance(var, FloatVar) - ], + "input float length_g;", + "input float sigma_g;", + "input float length_dg;", + "input float sigma_dg;", + "input float length_c;", + "input float b;", + "input float a;", "cal {", - " input float[64] b;", - " input float[64] a;", - " input float[64] length_g;", - " input float[64] sigma_g;", " waveform gauss_wf = gaussian((length_g) * 1s, (sigma_g) * 1s, 1, false);", - " input float[64] length_dg;", - " input float[64] sigma_dg;", " waveform drag_gauss_wf = drag_gaussian((length_dg) * 1s," " (sigma_dg) * 1s, 0.2, 1, false);", - " input float[64] length_c;", " waveform constant_wf = constant((length_c) * 1s, 2.0 + 0.3im);", " waveform arb_wf = {1.0 + 0.4im, 0, 0.3, 0.1 + 0.2im};", " bit[2] psb;", @@ -181,11 +175,9 @@ def test_pulse_sequence_make_bound_pulse_sequence(predefined_frame_1, predefined expected_str_b_bound = "\n".join( [ "OPENQASM 3.0;", - "input float a;", "input float sigma_g;", + "input float a;", "cal {", - " input float[64] a;", - " input float[64] sigma_g;", " waveform gauss_wf = gaussian(1.0ms, (sigma_g) * 1s, 1, false);", " waveform drag_gauss_wf = drag_gaussian(3.0ms, 400.0ms, 0.2, 1, false);", " waveform constant_wf = constant(4.0ms, 2.0 + 0.3im);", From 612f0db1b57b00e2c524c667eecdd0330d7e875c Mon Sep 17 00:00:00 2001 From: Jean-Christophe Jaskula Date: Fri, 26 Jan 2024 22:28:01 -0500 Subject: [PATCH 36/43] Merge branch 'jcjaskula-aws/fix_oqpy_upgrade' into jcjaskula-aws/clean_parametric_defcal_code --- CHANGELOG.md | 18 ++++ setup.py | 3 +- src/braket/_sdk/_version.py | 2 +- src/braket/aws/aws_quantum_job.py | 21 +++-- src/braket/aws/aws_quantum_task.py | 14 +++ src/braket/circuits/circuit.py | 11 +-- src/braket/jobs/hybrid_job.py | 7 +- src/braket/jobs/logs.py | 13 ++- src/braket/jobs/quantum_job_creation.py | 42 +++++++-- .../parametric/free_parameter_expression.py | 88 ++++++++----------- src/braket/pulse/__init__.py | 7 -- src/braket/pulse/ast/free_parameters.py | 68 +++++++++----- src/braket/pulse/ast/qasm_parser.py | 15 ---- src/braket/pulse/pulse_sequence.py | 53 ++++------- src/braket/pulse/waveforms.py | 26 ++---- test/integ_tests/test_create_quantum_job.py | 52 +++++++---- .../braket/aws/test_aws_quantum_job.py | 64 +++++++++++++- .../braket/aws/test_aws_quantum_task.py | 59 +++++++++++++ .../braket/circuits/test_circuit.py | 4 +- test/unit_tests/braket/circuits/test_gates.py | 4 +- .../braket/jobs/test_quantum_job_creation.py | 18 ++-- .../test_free_parameter_expression.py | 1 + .../braket/pulse/test_pulse_sequence.py | 54 ++++++------ .../unit_tests/braket/pulse/test_waveforms.py | 25 ++++-- tox.ini | 1 - 25 files changed, 427 insertions(+), 243 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 47d87d33a..7bafd45c3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,23 @@ # Changelog +## v1.68.0 (2024-01-25) + +### Features + + * update S3 locations for jobs + +## v1.67.0 (2024-01-23) + +### Features + + * add queue position to the logs for tasks and jobs + +## v1.66.0 (2024-01-11) + +### Features + + * update job name to use metadata + ## v1.65.1 (2023-12-25) ### Bug Fixes and Other Changes diff --git a/setup.py b/setup.py index 328f0a1bf..1e85974d5 100644 --- a/setup.py +++ b/setup.py @@ -29,7 +29,7 @@ install_requires=[ "amazon-braket-schemas>=1.19.1", "amazon-braket-default-simulator>=1.19.1", - "oqpy~=0.3.3", + "oqpy~=0.3.5", "setuptools", "backoff", "boltons", @@ -47,7 +47,6 @@ "black", "botocore", "flake8<=5.0.4", - "flake8-rst-docstrings", "isort", "jsonschema==3.2.0", "pre-commit", diff --git a/src/braket/_sdk/_version.py b/src/braket/_sdk/_version.py index 278494287..051177529 100644 --- a/src/braket/_sdk/_version.py +++ b/src/braket/_sdk/_version.py @@ -15,4 +15,4 @@ Version number (major.minor.patch[-label]) """ -__version__ = "1.65.2.dev0" +__version__ = "1.68.1.dev0" diff --git a/src/braket/aws/aws_quantum_job.py b/src/braket/aws/aws_quantum_job.py index 1e929e857..562c61ba8 100644 --- a/src/braket/aws/aws_quantum_job.py +++ b/src/braket/aws/aws_quantum_job.py @@ -81,6 +81,7 @@ def create( aws_session: AwsSession | None = None, tags: dict[str, str] | None = None, logger: Logger = getLogger(__name__), + quiet: bool = False, reservation_arn: str | None = None, ) -> AwsQuantumJob: """Creates a hybrid job by invoking the Braket CreateJob API. @@ -176,6 +177,9 @@ def create( while waiting for quantum task to be in a terminal state. Default is `getLogger(__name__)` + quiet (bool): Sets the verbosity of the logger to low and does not report queue + position. Default is `False`. + reservation_arn (str | None): the reservation window arn provided by Braket Direct to reserve exclusive usage for the device to run the hybrid job on. Default: None. @@ -210,7 +214,7 @@ def create( ) job_arn = aws_session.create_job(**create_job_kwargs) - job = AwsQuantumJob(job_arn, aws_session) + job = AwsQuantumJob(job_arn, aws_session, quiet) if wait_until_complete: print(f"Initializing Braket Job: {job_arn}") @@ -218,15 +222,18 @@ def create( return job - def __init__(self, arn: str, aws_session: AwsSession | None = None): + def __init__(self, arn: str, aws_session: AwsSession | None = None, quiet: bool = False): """ Args: arn (str): The ARN of the hybrid job. aws_session (AwsSession | None): The `AwsSession` for connecting to AWS services. Default is `None`, in which case an `AwsSession` object will be created with the region of the hybrid job. + quiet (bool): Sets the verbosity of the logger to low and does not report queue + position. Default is `False`. """ self._arn: str = arn + self._quiet = quiet if aws_session: if not self._is_valid_aws_session_region_for_job_arn(aws_session, arn): raise ValueError( @@ -268,7 +275,7 @@ def arn(self) -> str: @property def name(self) -> str: """str: The name of the quantum job.""" - return self._arn.partition("job/")[-1] + return self.metadata(use_cached_value=True).get("jobName") def state(self, use_cached_value: bool = False) -> str: """The state of the quantum hybrid job. @@ -371,10 +378,11 @@ def logs(self, wait: bool = False, poll_interval_seconds: int = 5) -> None: instance_count = self.metadata(use_cached_value=True)["instanceConfig"]["instanceCount"] has_streams = False color_wrap = logs.ColorWrap() + previous_state = self.state() while True: time.sleep(poll_interval_seconds) - + current_state = self.state() has_streams = logs.flush_log_streams( self._aws_session, log_group, @@ -384,14 +392,17 @@ def logs(self, wait: bool = False, poll_interval_seconds: int = 5) -> None: instance_count, has_streams, color_wrap, + [previous_state, current_state], + self.queue_position().queue_position if not self._quiet else None, ) + previous_state = current_state if log_state == AwsQuantumJob.LogState.COMPLETE: break if log_state == AwsQuantumJob.LogState.JOB_COMPLETE: log_state = AwsQuantumJob.LogState.COMPLETE - elif self.state() in AwsQuantumJob.TERMINAL_STATES: + elif current_state in AwsQuantumJob.TERMINAL_STATES: log_state = AwsQuantumJob.LogState.JOB_COMPLETE def metadata(self, use_cached_value: bool = False) -> dict[str, Any]: diff --git a/src/braket/aws/aws_quantum_task.py b/src/braket/aws/aws_quantum_task.py index 1b43694f7..c53f5cda5 100644 --- a/src/braket/aws/aws_quantum_task.py +++ b/src/braket/aws/aws_quantum_task.py @@ -105,6 +105,7 @@ def create( tags: dict[str, str] | None = None, inputs: dict[str, float] | None = None, gate_definitions: Optional[dict[tuple[Gate, QubitSet], PulseSequence]] | None = None, + quiet: bool = False, reservation_arn: str | None = None, *args, **kwargs, @@ -152,6 +153,9 @@ def create( a `PulseSequence`. Default: None. + quiet (bool): Sets the verbosity of the logger to low and does not report queue + position. Default is `False`. + reservation_arn (str | None): The reservation ARN provided by Braket Direct to reserve exclusive usage for the device to run the quantum task on. Note: If you are creating tasks in a job that itself was created reservation ARN, @@ -215,6 +219,7 @@ def create( disable_qubit_rewiring, inputs, gate_definitions=gate_definitions, + quiet=quiet, *args, **kwargs, ) @@ -226,6 +231,7 @@ def __init__( poll_timeout_seconds: float = DEFAULT_RESULTS_POLL_TIMEOUT, poll_interval_seconds: float = DEFAULT_RESULTS_POLL_INTERVAL, logger: Logger = getLogger(__name__), + quiet: bool = False, ): """ Args: @@ -238,6 +244,8 @@ def __init__( logger (Logger): Logger object with which to write logs, such as quantum task statuses while waiting for quantum task to be in a terminal state. Default is `getLogger(__name__)` + quiet (bool): Sets the verbosity of the logger to low and does not report queue + position. Default is `False`. Examples: >>> task = AwsQuantumTask(arn='task_arn') @@ -259,6 +267,7 @@ def __init__( self._poll_interval_seconds = poll_interval_seconds self._logger = logger + self._quiet = quiet self._metadata: dict[str, Any] = {} self._result: Union[ @@ -477,6 +486,11 @@ async def _wait_for_completion( while (time.time() - start_time) < self._poll_timeout_seconds: # Used cached metadata if cached status is terminal task_status = self._update_status_if_nonterminal() + if not self._quiet and task_status == "QUEUED": + queue = self.queue_position() + self._logger.debug( + f"Task is in {queue.queue_type} queue position: {queue.queue_position}" + ) self._logger.debug(f"Task {self._arn}: task status {task_status}") if task_status in AwsQuantumTask.RESULTS_READY_STATES: return self._download_result() diff --git a/src/braket/circuits/circuit.py b/src/braket/circuits/circuit.py index ed4a19e83..a1015f68d 100644 --- a/src/braket/circuits/circuit.py +++ b/src/braket/circuits/circuit.py @@ -1289,16 +1289,7 @@ def _generate_frame_wf_defcal_declarations( for key, calibration in gate_definitions.items(): gate, qubits = key gate_name = gate._qasm_name - arguments = ( - [ - oqpy.FloatVar(name=value.name) - if isinstance(value, FreeParameter) - else value - for value in gate.parameters - ] - if isinstance(gate, Parameterizable) - else None - ) + arguments = gate.parameters if isinstance(gate, Parameterizable) else None with oqpy.defcal( program, [oqpy.PhysicalQubits[int(k)] for k in qubits], gate_name, arguments ): diff --git a/src/braket/jobs/hybrid_job.py b/src/braket/jobs/hybrid_job.py index 707f18fd5..b8e1e58bf 100644 --- a/src/braket/jobs/hybrid_job.py +++ b/src/braket/jobs/hybrid_job.py @@ -63,6 +63,7 @@ def hybrid_job( aws_session: AwsSession | None = None, tags: dict[str, str] | None = None, logger: Logger = getLogger(__name__), + quiet: bool | None = None, reservation_arn: str | None = None, ) -> Callable: """Defines a hybrid job by decorating the entry point function. The job will be created @@ -71,7 +72,7 @@ def hybrid_job( The job created will be a `LocalQuantumJob` when `local` is set to `True`, otherwise an `AwsQuantumJob`. The following parameters will be ignored when running a job with `local` set to `True`: `wait_until_complete`, `instance_config`, `distribution`, - `copy_checkpoints_from_job`, `stopping_condition`, `tags`, and `logger`. + `copy_checkpoints_from_job`, `stopping_condition`, `tags`, `logger`, and `quiet`. Args: device (str | None): Device ARN of the QPU device that receives priority quantum @@ -153,6 +154,9 @@ def hybrid_job( logger (Logger): Logger object with which to write logs, such as task statuses while waiting for task to be in a terminal state. Default: `getLogger(__name__)` + quiet (bool | None): Sets the verbosity of the logger to low and does not report queue + position. Default is `False`. + reservation_arn (str | None): the reservation window arn provided by Braket Direct to reserve exclusive usage for the device to run the hybrid job on. Default: None. @@ -210,6 +214,7 @@ def job_wrapper(*args, **kwargs) -> Callable: "output_data_config": output_data_config, "aws_session": aws_session, "tags": tags, + "quiet": quiet, "reservation_arn": reservation_arn, } for key, value in optional_args.items(): diff --git a/src/braket/jobs/logs.py b/src/braket/jobs/logs.py index e0f54458d..734d51123 100644 --- a/src/braket/jobs/logs.py +++ b/src/braket/jobs/logs.py @@ -20,7 +20,7 @@ # Support for reading logs # ############################################################################## -from typing import Dict, List, Tuple +from typing import Dict, List, Optional, Tuple from botocore.exceptions import ClientError @@ -155,7 +155,7 @@ def log_stream( yield ev -def flush_log_streams( +def flush_log_streams( # noqa C901 aws_session: AwsSession, log_group: str, stream_prefix: str, @@ -164,6 +164,8 @@ def flush_log_streams( stream_count: int, has_streams: bool, color_wrap: ColorWrap, + state: list[str], + queue_position: Optional[str] = None, ) -> bool: """Flushes log streams to stdout. @@ -183,6 +185,9 @@ def flush_log_streams( been found. This value is possibly updated and returned at the end of execution. color_wrap (ColorWrap): An instance of ColorWrap to potentially color-wrap print statements from different streams. + state (list[str]): The previous and current state of the job. + queue_position (Optional[str]): The current queue position. This is not passed in if the job + is ran with `quiet=True` Returns: bool: Returns 'True' if any streams have been flushed. @@ -225,6 +230,10 @@ def flush_log_streams( positions[stream_names[idx]] = Position(timestamp=ts, skip=count + 1) else: positions[stream_names[idx]] = Position(timestamp=event["timestamp"], skip=1) + elif queue_position is not None and state[1] == "QUEUED": + print(f"Job queue position: {queue_position}", end="\n", flush=True) + elif state[0] != state[1] and state[1] == "RUNNING" and queue_position is not None: + print("Running:", end="\n", flush=True) else: print(".", end="", flush=True) return has_streams diff --git a/src/braket/jobs/quantum_job_creation.py b/src/braket/jobs/quantum_job_creation.py index 657ed0829..9e18faeab 100644 --- a/src/braket/jobs/quantum_job_creation.py +++ b/src/braket/jobs/quantum_job_creation.py @@ -161,14 +161,15 @@ def prepare_quantum_job( _validate_params(param_datatype_map) aws_session = aws_session or AwsSession() device_config = DeviceConfig(device) - job_name = job_name or _generate_default_job_name(image_uri=image_uri) + timestamp = str(int(time.time() * 1000)) + job_name = job_name or _generate_default_job_name(image_uri=image_uri, timestamp=timestamp) role_arn = role_arn or os.getenv("BRAKET_JOBS_ROLE_ARN", aws_session.get_default_jobs_role()) hyperparameters = hyperparameters or {} hyperparameters = {str(key): str(value) for key, value in hyperparameters.items()} input_data = input_data or {} tags = tags or {} default_bucket = aws_session.default_bucket() - input_data_list = _process_input_data(input_data, job_name, aws_session) + input_data_list = _process_input_data(input_data, job_name, aws_session, timestamp) instance_config = instance_config or InstanceConfig() stopping_condition = stopping_condition or StoppingCondition() output_data_config = output_data_config or OutputDataConfig() @@ -177,6 +178,7 @@ def prepare_quantum_job( default_bucket, "jobs", job_name, + timestamp, "script", ) @@ -201,6 +203,7 @@ def prepare_quantum_job( default_bucket, "jobs", job_name, + timestamp, "data", ) if not checkpoint_config.s3Uri: @@ -208,6 +211,7 @@ def prepare_quantum_job( default_bucket, "jobs", job_name, + timestamp, "checkpoints", ) if copy_checkpoints_from_job: @@ -251,19 +255,22 @@ def prepare_quantum_job( return create_job_kwargs -def _generate_default_job_name(image_uri: str | None = None, func: Callable | None = None) -> str: +def _generate_default_job_name( + image_uri: str | None = None, func: Callable | None = None, timestamp: int | str | None = None +) -> str: """ Generate default job name using the image uri and entrypoint function. Args: image_uri (str | None): URI for the image container. func (Callable | None): The entry point function. + timestamp (int | str | None): Optional timestamp to use instead of generating one. Returns: str: Hybrid job name. """ max_length = 50 - timestamp = str(int(time.time() * 1000)) + timestamp = timestamp if timestamp is not None else str(int(time.time() * 1000)) if func: name = func.__name__.replace("_", "-") @@ -395,7 +402,10 @@ def _validate_params(dict_arr: dict[str, tuple[any, any]]) -> None: def _process_input_data( - input_data: str | dict | S3DataSourceConfig, job_name: str, aws_session: AwsSession + input_data: str | dict | S3DataSourceConfig, + job_name: str, + aws_session: AwsSession, + subdirectory: str, ) -> list[dict[str, Any]]: """ Convert input data into a list of dicts compatible with the Braket API. @@ -405,6 +415,7 @@ def _process_input_data( can be an S3DataSourceConfig or a str corresponding to a local prefix or S3 prefix. job_name (str): Hybrid job name. aws_session (AwsSession): AwsSession for possibly uploading local data. + subdirectory (str): Subdirectory within job name for S3 locations. Returns: list[dict[str, Any]]: A list of channel configs. @@ -413,12 +424,18 @@ def _process_input_data( input_data = {"input": input_data} for channel_name, data in input_data.items(): if not isinstance(data, S3DataSourceConfig): - input_data[channel_name] = _process_channel(data, job_name, aws_session, channel_name) + input_data[channel_name] = _process_channel( + data, job_name, aws_session, channel_name, subdirectory + ) return _convert_input_to_config(input_data) def _process_channel( - location: str, job_name: str, aws_session: AwsSession, channel_name: str + location: str, + job_name: str, + aws_session: AwsSession, + channel_name: str, + subdirectory: str, ) -> S3DataSourceConfig: """ Convert a location to an S3DataSourceConfig, uploading local data to S3, if necessary. @@ -427,6 +444,7 @@ def _process_channel( job_name (str): Hybrid job name. aws_session (AwsSession): AwsSession to be used for uploading local data. channel_name (str): Name of the channel. + subdirectory (str): Subdirectory within job name for S3 locations. Returns: S3DataSourceConfig: S3DataSourceConfig for the channel. @@ -435,10 +453,16 @@ def _process_channel( return S3DataSourceConfig(location) else: # local prefix "path/to/prefix" will be mapped to - # s3://bucket/jobs/job-name/data/input/prefix + # s3://bucket/jobs/job-name/subdirectory/data/input/prefix location_name = Path(location).name s3_prefix = AwsSession.construct_s3_uri( - aws_session.default_bucket(), "jobs", job_name, "data", channel_name, location_name + aws_session.default_bucket(), + "jobs", + job_name, + subdirectory, + "data", + channel_name, + location_name, ) aws_session.upload_local_data(location, s3_prefix) return S3DataSourceConfig(s3_prefix) diff --git a/src/braket/parametric/free_parameter_expression.py b/src/braket/parametric/free_parameter_expression.py index 71292d306..98916fbf5 100644 --- a/src/braket/parametric/free_parameter_expression.py +++ b/src/braket/parametric/free_parameter_expression.py @@ -14,20 +14,14 @@ from __future__ import annotations import ast +import operator +from functools import reduce from numbers import Number from typing import Any, Union -from openpulse.ast import ( - ClassicalType, - DurationLiteral, - DurationType, - Expression, - FloatType, - Identifier, - TimeUnit, -) -from oqpy import Program -from sympy import Expr, Float, Symbol, sympify +import sympy +from oqpy.base import OQPyExpression +from oqpy.classical_types import FloatVar class FreeParameterExpression: @@ -40,11 +34,7 @@ class FreeParameterExpression: present will NOT run. Values must be substituted prior to execution. """ - def __init__( - self, - expression: Union[FreeParameterExpression, Number, Expr, str], - _type: ClassicalType | None = None, - ): + def __init__(self, expression: Union[FreeParameterExpression, Number, sympy.Expr, str]): """ Initializes a FreeParameterExpression. Best practice is to initialize using FreeParameters and Numbers. Not meant to be initialized directly. @@ -53,10 +43,6 @@ def __init__( Args: expression (Union[FreeParameterExpression, Number, Expr, str]): The expression to use. - _type (ClassicalType | None): The OpenQASM3 type associated with the expression. - Subtypes of openqasm3.ast.ClassicalType are used to specify how to express the - expression in the OpenQASM3 IR. Any type other than DurationType is considered - as FloatType. Examples: >>> expression_1 = FreeParameter("theta") * FreeParameter("alpha") @@ -69,21 +55,17 @@ def __init__( ast.Pow: self.__pow__, ast.USub: self.__neg__, } - self._type = _type if _type is not None else FloatType() if isinstance(expression, FreeParameterExpression): self._expression = expression.expression - if _type is None: - self._type = expression._type - elif isinstance(expression, (Number, Expr)): + elif isinstance(expression, (Number, sympy.Expr)): self._expression = expression elif isinstance(expression, str): self._expression = self._parse_string_expression(expression).expression else: raise NotImplementedError - self._validate_type() @property - def expression(self) -> Union[Number, Expr]: + def expression(self) -> Union[Number, sympy.Expr]: """Gets the expression. Returns: Union[Number, Expr]: The expression for the FreeParameterExpression. @@ -92,7 +74,7 @@ def expression(self) -> Union[Number, Expr]: def subs( self, parameter_values: dict[str, Number] - ) -> Union[FreeParameterExpression, Number, Expr]: + ) -> Union[FreeParameterExpression, Number, sympy.Expr]: """ Similar to a substitution in Sympy. Parameters are swapped for corresponding values or expressions from the dictionary. @@ -107,7 +89,7 @@ def subs( """ new_parameter_values = dict() for key, val in parameter_values.items(): - if isinstance(key, FreeParameterExpression): + if issubclass(type(key), FreeParameterExpression): new_parameter_values[key.expression] = val else: new_parameter_values[key] = val @@ -118,13 +100,6 @@ def subs( else: return FreeParameterExpression(subbed_expr) - def _validate_type(self) -> None: - if not isinstance(self._type, (FloatType, DurationType)): - raise TypeError( - "FreeParameterExpression must be of type openqasm3.ast.FloatType " - "or openqasm3.ast.DurationType" - ) - def _parse_string_expression(self, expression: str) -> FreeParameterExpression: return self._eval_operation(ast.parse(expression, mode="eval").body) @@ -132,7 +107,7 @@ def _eval_operation(self, node: Any) -> FreeParameterExpression: if isinstance(node, ast.Num): return FreeParameterExpression(node.n) elif isinstance(node, ast.Name): - return FreeParameterExpression(Symbol(node.id)) + return FreeParameterExpression(sympy.Symbol(node.id)) elif isinstance(node, ast.BinOp): if type(node.op) not in self._operations.keys(): raise ValueError(f"Unsupported binary operation: {type(node.op)}") @@ -147,7 +122,7 @@ def _eval_operation(self, node: Any) -> FreeParameterExpression: raise ValueError(f"Unsupported string detected: {node}") def __add__(self, other): - if isinstance(other, FreeParameterExpression): + if issubclass(type(other), FreeParameterExpression): return FreeParameterExpression(self.expression + other.expression) else: return FreeParameterExpression(self.expression + other) @@ -156,7 +131,7 @@ def __radd__(self, other): return FreeParameterExpression(other + self.expression) def __sub__(self, other): - if isinstance(other, FreeParameterExpression): + if issubclass(type(other), FreeParameterExpression): return FreeParameterExpression(self.expression - other.expression) else: return FreeParameterExpression(self.expression - other) @@ -165,7 +140,7 @@ def __rsub__(self, other): return FreeParameterExpression(other - self.expression) def __mul__(self, other): - if isinstance(other, FreeParameterExpression): + if issubclass(type(other), FreeParameterExpression): return FreeParameterExpression(self.expression * other.expression) else: return FreeParameterExpression(self.expression * other) @@ -174,7 +149,7 @@ def __rmul__(self, other): return FreeParameterExpression(other * self.expression) def __pow__(self, other, modulo=None): - if isinstance(other, FreeParameterExpression): + if issubclass(type(other), FreeParameterExpression): return FreeParameterExpression(self.expression**other.expression) else: return FreeParameterExpression(self.expression**other) @@ -187,7 +162,7 @@ def __neg__(self): def __eq__(self, other): if isinstance(other, FreeParameterExpression): - return sympify(self.expression).equals(sympify(other.expression)) + return sympy.sympify(self.expression).equals(sympy.sympify(other.expression)) return False def __repr__(self) -> str: @@ -199,20 +174,27 @@ def __repr__(self) -> str: """ return repr(self.expression) - def to_ast(self, program: Program) -> Expression: - """Creates an AST node for the :class:'FreeParameterExpression'. - - Args: - program (Program): Unused. + def _to_oqpy_expression(self) -> OQPyExpression: + """Transforms into an OQPyExpression. Returns: - Expression: The AST node. + OQPyExpression: The AST node. """ - # TODO (#822): capture expressions into expression ASTs rather than just an Identifier - identifier = Identifier(name=self) - if isinstance(self._type, DurationType): - return DurationLiteral(identifier, TimeUnit.s) - return identifier + ops = {sympy.Add: operator.add, sympy.Mul: operator.mul, sympy.Pow: operator.pow} + if isinstance(self.expression, tuple(ops)): + return reduce( + ops[type(self.expression)], + map( + lambda x: FreeParameterExpression(x)._to_oqpy_expression(), self.expression.args + ), + ) + elif isinstance(self.expression, sympy.Number): + return float(self.expression) + else: + fvar = FloatVar(name=self.expression.name, init_expression="input") + fvar.size = None + fvar.type.size = None + return fvar def subs_if_free_parameter(parameter: Any, **kwargs) -> Any: @@ -226,7 +208,7 @@ def subs_if_free_parameter(parameter: Any, **kwargs) -> Any: """ if isinstance(parameter, FreeParameterExpression): substituted = parameter.subs(kwargs) - if isinstance(substituted, Float): + if isinstance(substituted, sympy.Number): substituted = float(substituted) return substituted return parameter diff --git a/src/braket/pulse/__init__.py b/src/braket/pulse/__init__.py index 51414682d..01ef66892 100644 --- a/src/braket/pulse/__init__.py +++ b/src/braket/pulse/__init__.py @@ -11,8 +11,6 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. -import oqpy - from braket.pulse.frame import Frame # noqa: F401 from braket.pulse.port import Port # noqa: F401 from braket.pulse.pulse_sequence import PulseSequence # noqa: F401 @@ -22,8 +20,3 @@ DragGaussianWaveform, GaussianWaveform, ) - -oqpy.AngleVar.default_size = None -oqpy.FloatVar.default_size = None -oqpy.IntVar.default_size = None -oqpy.UintVar.default_size = None diff --git a/src/braket/pulse/ast/free_parameters.py b/src/braket/pulse/ast/free_parameters.py index 96e319eb6..41c541da8 100644 --- a/src/braket/pulse/ast/free_parameters.py +++ b/src/braket/pulse/ast/free_parameters.py @@ -11,6 +11,7 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. +import operator from typing import Union from openpulse import ast @@ -18,8 +19,6 @@ from oqpy.program import Program from oqpy.timing import OQDurationLiteral -from braket.parametric.free_parameter_expression import FreeParameterExpression - class _FreeParameterTransformer(QASMTransformer): """Walk the AST and evaluate FreeParameterExpressions.""" @@ -43,27 +42,54 @@ def visit_Identifier( Returns: Union[Identifier, FloatLiteral]: The transformed identifier. """ - if isinstance(identifier.name, FreeParameterExpression): - new_value = FreeParameterExpression(identifier.name).subs(self.param_values) - if isinstance(new_value, FreeParameterExpression): - return ast.Identifier(new_value) - else: - return ast.FloatLiteral(float(new_value)) + if identifier.name in self.param_values: + return ast.FloatLiteral(float(self.param_values[identifier.name])) return identifier - def visit_DurationLiteral(self, duration_literal: ast.DurationLiteral) -> ast.DurationLiteral: - """Visit Duration Literal. - node.value, node.unit (node.unit.name, node.unit.value) - 1 + def visit_BinaryExpression( + self, node: ast.BinaryExpression + ) -> Union[ast.BinaryExpression, ast.FloatLiteral]: + """Visit a BinaryExpression. + + Visit the operands and simplify if they are literals. + Args: - duration_literal (DurationLiteral): The duration literal. + node (BinaryExpression): The node. + + Returns: + Union[BinaryExpression, FloatLiteral]: The transformed identifier. + """ + lhs = self.visit(node.lhs) + rhs = self.visit(node.rhs) + ops = { + ast.BinaryOperator["+"]: operator.add, + ast.BinaryOperator["*"]: operator.mul, + ast.BinaryOperator["**"]: operator.pow, + } + if isinstance(lhs, ast.FloatLiteral): + if isinstance(rhs, ast.FloatLiteral): + return ast.FloatLiteral(ops[node.op](lhs.value, rhs.value)) + elif isinstance(rhs, ast.DurationLiteral) and node.op == ast.BinaryOperator["*"]: + return OQDurationLiteral(lhs.value * rhs.value).to_ast(self.program) + return ast.BinaryExpression(op=node.op, lhs=lhs, rhs=rhs) + + def visit_UnaryExpression( + self, node: ast.UnaryExpression + ) -> Union[ast.UnaryExpression, ast.FloatLiteral]: + """Visit an UnaryExpression. + + Visit the operand and simplify if it is a literal. + + Args: + node (UnaryExpression): The node. + Returns: - DurationLiteral: The transformed duration literal. + Union[UnaryExpression, FloatLiteral]: The transformed identifier. """ - duration = duration_literal.value - if not isinstance(duration, ast.Identifier): - return duration_literal - new_duration = FreeParameterExpression(duration.name).subs(self.param_values) - if isinstance(new_duration, FreeParameterExpression): - return ast.DurationLiteral(ast.Identifier(str(new_duration)), duration_literal.unit) - return OQDurationLiteral(new_duration).to_ast(self.program) + expression = self.visit(node.expression) + if ( + isinstance(expression, (ast.FloatLiteral, ast.DurationLiteral)) + and node.op == ast.UnaryOperator["-"] + ): + return type(expression)(-expression.value) + return ast.UnaryExpression(op=node.op, expression=node.expression) # pragma: no cover diff --git a/src/braket/pulse/ast/qasm_parser.py b/src/braket/pulse/ast/qasm_parser.py index e892ba86b..0146eca66 100644 --- a/src/braket/pulse/ast/qasm_parser.py +++ b/src/braket/pulse/ast/qasm_parser.py @@ -15,7 +15,6 @@ from openpulse import ast from openpulse.printer import Printer -from openqasm3.ast import DurationLiteral from openqasm3.printer import PrinterState @@ -33,20 +32,6 @@ def visit_Identifier(self, node: ast.Identifier, context: PrinterState) -> None: """ self.stream.write(str(node.name)) - def visit_DurationLiteral(self, node: DurationLiteral, context: PrinterState) -> None: - """Visit Duration Literal. - node.value, node.unit (node.unit.name, node.unit.value) - 1 - Args: - node (ast.DurationLiteral): The duration literal. - context (PrinterState): The printer state context. - """ - duration = node.value - if isinstance(duration, ast.Identifier): - self.stream.write(f"({duration.name}) * 1{node.unit.name}") - else: - super().visit_DurationLiteral(node, context) - def visit_ClassicalDeclaration( self, node: ast.ClassicalDeclaration, context: PrinterState ) -> None: diff --git a/src/braket/pulse/pulse_sequence.py b/src/braket/pulse/pulse_sequence.py index a243daa89..831030923 100644 --- a/src/braket/pulse/pulse_sequence.py +++ b/src/braket/pulse/pulse_sequence.py @@ -19,8 +19,7 @@ from typing import Any, Union from openpulse import ast -from oqpy import BitVar, FloatVar, PhysicalQubits, Program -from oqpy.timing import OQDurationLiteral +from oqpy import BitVar, PhysicalQubits, Program from braket.parametric.free_parameter import FreeParameter from braket.parametric.free_parameter_expression import FreeParameterExpression @@ -84,7 +83,8 @@ def set_frequency( """ _validate_uniqueness(self._frames, frame) - self._program.set_frequency(frame=frame, freq=self._format_parameter_ast(frequency)) + self._register_free_parameters(frequency) + self._program.set_frequency(frame=frame, freq=frequency) self._frames[frame.id] = frame return self @@ -103,7 +103,8 @@ def shift_frequency( PulseSequence: self, with the instruction added. """ _validate_uniqueness(self._frames, frame) - self._program.shift_frequency(frame=frame, freq=self._format_parameter_ast(frequency)) + self._register_free_parameters(frequency) + self._program.shift_frequency(frame=frame, freq=frequency) self._frames[frame.id] = frame return self @@ -122,7 +123,8 @@ def set_phase( PulseSequence: self, with the instruction added. """ _validate_uniqueness(self._frames, frame) - self._program.set_phase(frame=frame, phase=self._format_parameter_ast(phase)) + self._register_free_parameters(phase) + self._program.set_phase(frame=frame, phase=phase) self._frames[frame.id] = frame return self @@ -141,7 +143,8 @@ def shift_phase( PulseSequence: self, with the instruction added. """ _validate_uniqueness(self._frames, frame) - self._program.shift_phase(frame=frame, phase=self._format_parameter_ast(phase)) + self._register_free_parameters(phase) + self._program.shift_phase(frame=frame, phase=phase) self._frames[frame.id] = frame return self @@ -160,7 +163,8 @@ def set_scale( PulseSequence: self, with the instruction added. """ _validate_uniqueness(self._frames, frame) - self._program.set_scale(frame=frame, scale=self._format_parameter_ast(scale)) + self._register_free_parameters(scale) + self._program.set_scale(frame=frame, scale=scale) self._frames[frame.id] = frame return self @@ -180,7 +184,7 @@ def delay( Returns: PulseSequence: self, with the instruction added. """ - duration = self._format_parameter_ast(duration, _type=ast.DurationType()) + self._register_free_parameters(duration) if not isinstance(qubits_or_frames, QubitSet): if not isinstance(qubits_or_frames, list): qubits_or_frames = [qubits_or_frames] @@ -230,17 +234,7 @@ def play(self, frame: Frame, waveform: Waveform) -> PulseSequence: _validate_uniqueness(self._waveforms, waveform) if isinstance(waveform, Parameterizable): for param in waveform.parameters: - if isinstance(param, FreeParameterExpression): - for p in param.expression.free_symbols: - self._program._add_var( - FloatVar( - name=p.name, - size=None, - init_expression="input", - needs_declaration=True, - ) - ) - self._free_parameters.add(FreeParameter(p.name)) + self._register_free_parameters(param) self._program.play(frame=frame, waveform=waveform) self._frames[frame.id] = frame self._waveforms[waveform.id] = waveform @@ -283,13 +277,13 @@ def make_bound_pulse_sequence(self, param_values: dict[str, float]) -> PulseSequ new_program = Program(simplify_constants=False) new_program.declared_vars = program.declared_vars new_program.undeclared_vars = program.undeclared_vars + for param_name in param_values: + new_program.undeclared_vars.pop(param_name, None) for x in new_tree.statements: new_program._add_statement(x) new_pulse_sequence = PulseSequence() new_pulse_sequence._program = new_program - for param_name in param_values: - new_pulse_sequence._program.undeclared_vars.pop(param_name, None) new_pulse_sequence._frames = deepcopy(self._frames) new_pulse_sequence._waveforms = { wf.id: wf.bind_values(**param_values) if isinstance(wf, Parameterizable) else wf @@ -317,12 +311,6 @@ def to_ir(self) -> str: str: a str representing the OpenPulse program encoding the PulseSequence. """ program = deepcopy(self._program) - for param in self._free_parameters: - program._add_var( - FloatVar( - name=param.name, size=None, init_expression="input", needs_declaration=True - ) - ) if self._capture_v0_count: register_identifier = "psb" program.declare( @@ -335,20 +323,13 @@ def to_ir(self) -> str: tree = _InputVarSplitter().visit(tree) return ast_to_qasm(tree) - def _format_parameter_ast( + def _register_free_parameters( self, parameter: Union[float, FreeParameterExpression], - _type: ast.ClassicalType = ast.FloatType(), - ) -> Union[float, FreeParameterExpression]: + ) -> None: if isinstance(parameter, FreeParameterExpression): for p in parameter.expression.free_symbols: self._free_parameters.add(FreeParameter(p.name)) - return ( - FreeParameterExpression(parameter, _type) - if isinstance(_type, ast.DurationType) - else parameter - ) - return OQDurationLiteral(parameter) if isinstance(_type, ast.DurationType) else parameter def _parse_arg_from_calibration_schema( self, argument: dict, waveforms: dict[Waveform], frames: dict[Frame] diff --git a/src/braket/pulse/waveforms.py b/src/braket/pulse/waveforms.py index 714340498..971321a71 100644 --- a/src/braket/pulse/waveforms.py +++ b/src/braket/pulse/waveforms.py @@ -173,7 +173,7 @@ def _to_oqpy_expression(self) -> OQPyExpression: "constant", [("length", duration), ("iq", complex128)] ) return WaveformVar( - init_expression=constant_generator(_map_to_oqpy_type(self.length, True), self.iq), + init_expression=constant_generator(self.length, self.iq), name=self.id, ) @@ -299,10 +299,10 @@ def _to_oqpy_expression(self) -> OQPyExpression: ) return WaveformVar( init_expression=drag_gaussian_generator( - _map_to_oqpy_type(self.length, True), - _map_to_oqpy_type(self.sigma, True), - _map_to_oqpy_type(self.beta), - _map_to_oqpy_type(self.amplitude), + self.length, + self.sigma, + self.beta, + self.amplitude, self.zero_at_edges, ), name=self.id, @@ -426,9 +426,9 @@ def _to_oqpy_expression(self) -> OQPyExpression: ) return WaveformVar( init_expression=gaussian_generator( - _map_to_oqpy_type(self.length, True), - _map_to_oqpy_type(self.sigma, True), - _map_to_oqpy_type(self.amplitude), + self.length, + self.sigma, + self.amplitude, self.zero_at_edges, ), name=self.id, @@ -469,16 +469,6 @@ def _make_identifier_name() -> str: return "".join([random.choice(string.ascii_letters) for _ in range(10)]) -def _map_to_oqpy_type( - parameter: Union[FreeParameterExpression, float], is_duration_type: bool = False -) -> Union[FreeParameterExpression, OQPyExpression]: - return ( - FreeParameterExpression(parameter, duration) - if isinstance(parameter, FreeParameterExpression) and is_duration_type - else parameter - ) - - def _parse_waveform_from_calibration_schema(waveform: dict) -> Waveform: waveform_names = { "arbitrary": ArbitraryWaveform._from_calibration_schema, diff --git a/test/integ_tests/test_create_quantum_job.py b/test/integ_tests/test_create_quantum_job.py index 3b1b8ae95..02c16313b 100644 --- a/test/integ_tests/test_create_quantum_job.py +++ b/test/integ_tests/test_create_quantum_job.py @@ -52,20 +52,25 @@ def test_failed_quantum_job(aws_session, capsys): hyperparameters={"test_case": "failed"}, ) - job_name = job.name - pattern = f"^arn:aws:braket:{aws_session.region}:\\d12:job/{job_name}$" - re.match(pattern=pattern, string=job.arn) + pattern = f"^arn:aws:braket:{aws_session.region}:\\d{{12}}:job/[a-z0-9-]+$" + assert re.match(pattern=pattern, string=job.arn) # Check job is in failed state. assert job.state() == "FAILED" # Check whether the respective folder with files are created for script, # output, tasks and checkpoints. + job_name = job.name + s3_bucket = aws_session.default_bucket() + subdirectory = re.match( + rf"s3://{s3_bucket}/jobs/{job.name}/(\d+)/script/source.tar.gz", + job.metadata()["algorithmSpecification"]["scriptModeConfig"]["s3Uri"], + ).group(1) keys = aws_session.list_keys( - bucket=f"amazon-braket-{aws_session.region}-{aws_session.account_id}", - prefix=f"jobs/{job_name}", + bucket=s3_bucket, + prefix=f"jobs/{job_name}/{subdirectory}/", ) - assert keys == [f"jobs/{job_name}/script/source.tar.gz"] + assert keys == [f"jobs/{job_name}/{subdirectory}/script/source.tar.gz"] # no results saved assert job.result() == {} @@ -108,33 +113,44 @@ def test_completed_quantum_job(aws_session, capsys): hyperparameters={"test_case": "completed"}, ) - job_name = job.name - pattern = f"^arn:aws:braket:{aws_session.region}:\\d12:job/{job_name}$" - re.match(pattern=pattern, string=job.arn) + pattern = f"^arn:aws:braket:{aws_session.region}:\\d{{12}}:job/[a-z0-9-]+$" + assert re.match(pattern=pattern, string=job.arn) # check job is in completed state. assert job.state() == "COMPLETED" # Check whether the respective folder with files are created for script, # output, tasks and checkpoints. - s3_bucket = f"amazon-braket-{aws_session.region}-{aws_session.account_id}" + job_name = job.name + s3_bucket = aws_session.default_bucket() + subdirectory = re.match( + rf"s3://{s3_bucket}/jobs/{job.name}/(\d+)/script/source.tar.gz", + job.metadata()["algorithmSpecification"]["scriptModeConfig"]["s3Uri"], + ).group(1) keys = aws_session.list_keys( bucket=s3_bucket, - prefix=f"jobs/{job_name}", + prefix=f"jobs/{job_name}/{subdirectory}/", ) for expected_key in [ - f"jobs/{job_name}/script/source.tar.gz", - f"jobs/{job_name}/data/output/model.tar.gz", - f"jobs/{job_name}/tasks/[^/]*/results.json", - f"jobs/{job_name}/checkpoints/{job_name}_plain_data.json", - f"jobs/{job_name}/checkpoints/{job_name}.json", + f"jobs/{job_name}/{subdirectory}/script/source.tar.gz", + f"jobs/{job_name}/{subdirectory}/data/output/model.tar.gz", + f"jobs/{job_name}/{subdirectory}/checkpoints/{job_name}_plain_data.json", + f"jobs/{job_name}/{subdirectory}/checkpoints/{job_name}.json", ]: assert any(re.match(expected_key, key) for key in keys) + # Check that tasks exist in the correct location + tasks_keys = aws_session.list_keys( + bucket=s3_bucket, + prefix=f"jobs/{job_name}/tasks/", + ) + expected_task_location = f"jobs/{job_name}/tasks/[^/]*/results.json" + assert any(re.match(expected_task_location, key) for key in tasks_keys) + # Check if checkpoint is uploaded in requested format. for s3_key, expected_data in [ ( - f"jobs/{job_name}/checkpoints/{job_name}_plain_data.json", + f"jobs/{job_name}/{subdirectory}/checkpoints/{job_name}_plain_data.json", { "braketSchemaHeader": { "name": "braket.jobs_data.persisted_job_data", @@ -145,7 +161,7 @@ def test_completed_quantum_job(aws_session, capsys): }, ), ( - f"jobs/{job_name}/checkpoints/{job_name}.json", + f"jobs/{job_name}/{subdirectory}/checkpoints/{job_name}.json", { "braketSchemaHeader": { "name": "braket.jobs_data.persisted_job_data", diff --git a/test/unit_tests/braket/aws/test_aws_quantum_job.py b/test/unit_tests/braket/aws/test_aws_quantum_job.py index 3a36d8e75..7f9dc1a84 100644 --- a/test/unit_tests/braket/aws/test_aws_quantum_job.py +++ b/test/unit_tests/braket/aws/test_aws_quantum_job.py @@ -93,6 +93,7 @@ def _get_job_response(**kwargs): "jobArn": "arn:aws:braket:us-west-2:875981177017:job/job-test-20210628140446", "jobName": "job-test-20210628140446", "outputDataConfig": {"s3Path": "s3://amazon-braket-jobs/job-path/data"}, + "queueInfo": {"position": "1", "queue": "JOBS_QUEUE"}, "roleArn": "arn:aws:iam::875981177017:role/AmazonBraketJobRole", "status": "RUNNING", "stoppingCondition": {"maxRuntimeInSeconds": 1200}, @@ -554,8 +555,9 @@ def test_arn(quantum_job_arn, aws_session): assert quantum_job.arn == quantum_job_arn -def test_name(quantum_job_arn, quantum_job_name, aws_session): +def test_name(quantum_job_arn, quantum_job_name, aws_session, generate_get_job_response): quantum_job = AwsQuantumJob(quantum_job_arn, aws_session) + aws_session.get_job.return_value = generate_get_job_response(jobName=quantum_job_name) assert quantum_job.name == quantum_job_name @@ -719,6 +721,14 @@ def test_logs( generate_get_job_response(status="RUNNING"), generate_get_job_response(status="RUNNING"), generate_get_job_response(status="RUNNING"), + generate_get_job_response(status="RUNNING"), + generate_get_job_response(status="RUNNING"), + generate_get_job_response(status="RUNNING"), + generate_get_job_response(status="COMPLETED"), + generate_get_job_response(status="COMPLETED"), + generate_get_job_response(status="COMPLETED"), + generate_get_job_response(status="COMPLETED"), + generate_get_job_response(status="COMPLETED"), generate_get_job_response(status="COMPLETED"), ) quantum_job._aws_session.describe_log_streams.side_effect = log_stream_responses @@ -739,6 +749,48 @@ def test_logs( ) +def test_logs_queue_progress( + quantum_job, + generate_get_job_response, + log_events_responses, + log_stream_responses, + capsys, +): + queue_info = {"queue": "JOBS_QUEUE", "position": "1"} + quantum_job._aws_session.get_job.side_effect = ( + generate_get_job_response(status="QUEUED", queue_info=queue_info), + generate_get_job_response(status="QUEUED", queue_info=queue_info), + generate_get_job_response(status="QUEUED", queue_info=queue_info), + generate_get_job_response(status="RUNNING"), + generate_get_job_response(status="RUNNING"), + generate_get_job_response(status="RUNNING"), + generate_get_job_response(status="COMPLETED"), + generate_get_job_response(status="COMPLETED"), + generate_get_job_response(status="COMPLETED"), + generate_get_job_response(status="COMPLETED"), + generate_get_job_response(status="COMPLETED"), + generate_get_job_response(status="COMPLETED"), + ) + quantum_job._aws_session.describe_log_streams.side_effect = log_stream_responses + quantum_job._aws_session.get_log_events.side_effect = log_events_responses + + quantum_job.logs(wait=True, poll_interval_seconds=0) + + captured = capsys.readouterr() + assert captured.out == "\n".join( + ( + f"Job queue position: {queue_info['position']}", + "Running:", + "", + "hi there #1", + "hi there #2", + "hi there #2a", + "hi there #3", + "", + ) + ) + + @patch.dict("os.environ", {"JPY_PARENT_PID": "True"}) def test_logs_multiple_instances( quantum_job, @@ -752,6 +804,15 @@ def test_logs_multiple_instances( generate_get_job_response(status="RUNNING"), generate_get_job_response(status="RUNNING"), generate_get_job_response(status="RUNNING"), + generate_get_job_response(status="RUNNING"), + generate_get_job_response(status="RUNNING"), + generate_get_job_response(status="RUNNING"), + generate_get_job_response(status="RUNNING"), + generate_get_job_response(status="COMPLETED"), + generate_get_job_response(status="COMPLETED"), + generate_get_job_response(status="COMPLETED"), + generate_get_job_response(status="COMPLETED"), + generate_get_job_response(status="COMPLETED"), generate_get_job_response(status="COMPLETED"), ) log_stream_responses[-1]["logStreams"].append({"logStreamName": "stream-2"}) @@ -817,6 +878,7 @@ def get_log_events(log_group, log_stream, start_time, start_from_head, next_toke def test_logs_error(quantum_job, generate_get_job_response, capsys): quantum_job._aws_session.get_job.side_effect = ( + generate_get_job_response(status="RUNNING"), generate_get_job_response(status="RUNNING"), generate_get_job_response(status="RUNNING"), generate_get_job_response(status="COMPLETED"), diff --git a/test/unit_tests/braket/aws/test_aws_quantum_task.py b/test/unit_tests/braket/aws/test_aws_quantum_task.py index 4d8daa36f..6e789ec92 100644 --- a/test/unit_tests/braket/aws/test_aws_quantum_task.py +++ b/test/unit_tests/braket/aws/test_aws_quantum_task.py @@ -83,6 +83,11 @@ def quantum_task(aws_session): return AwsQuantumTask("foo:bar:arn", aws_session, poll_timeout_seconds=2) +@pytest.fixture +def quantum_task_quiet(aws_session): + return AwsQuantumTask("foo:bar:arn", aws_session, poll_timeout_seconds=2, quiet=True) + + @pytest.fixture def circuit_task(aws_session): return AwsQuantumTask("foo:bar:arn", aws_session, poll_timeout_seconds=2) @@ -243,6 +248,23 @@ def test_queue_position(quantum_task): ) +def test_queued_quiet(quantum_task_quiet): + state_1 = "QUEUED" + _mock_metadata(quantum_task_quiet._aws_session, state_1) + assert quantum_task_quiet.queue_position() == QuantumTaskQueueInfo( + queue_type=QueueType.NORMAL, queue_position="2", message=None + ) + + state_2 = "COMPLETED" + message = ( + f"'Task is in {state_2} status. AmazonBraket does not show queue position for this status.'" + ) + _mock_metadata(quantum_task_quiet._aws_session, state_2) + assert quantum_task_quiet.queue_position() == QuantumTaskQueueInfo( + queue_type=QueueType.NORMAL, queue_position=None, message=message + ) + + def test_state(quantum_task): state_1 = "RUNNING" _mock_metadata(quantum_task._aws_session, state_1) @@ -432,6 +454,43 @@ def set_result_from_callback(future): assert result_from_future == result +@pytest.mark.parametrize( + "status, result", + [ + ("COMPLETED", GateModelQuantumTaskResult.from_string(MockS3.MOCK_S3_RESULT_GATE_MODEL)), + ("FAILED", None), + ], +) +def test_async_result_queued(circuit_task, status, result): + def set_result_from_callback(future): + # Set the result_from_callback variable in the enclosing functions scope + nonlocal result_from_callback + result_from_callback = future.result() + + _mock_metadata(circuit_task._aws_session, "QUEUED") + _mock_s3(circuit_task._aws_session, MockS3.MOCK_S3_RESULT_GATE_MODEL) + + future = circuit_task.async_result() + + # test the different ways to get the result from async + + # via callback + result_from_callback = None + future.add_done_callback(set_result_from_callback) + + # via asyncio waiting for result + _mock_metadata(circuit_task._aws_session, status) + event_loop = asyncio.get_event_loop() + result_from_waiting = event_loop.run_until_complete(future) + + # via future.result(). Note that this would fail if the future is not complete. + result_from_future = future.result() + + assert result_from_callback == result + assert result_from_waiting == result + assert result_from_future == result + + def test_failed_task(quantum_task): _mock_metadata(quantum_task._aws_session, "FAILED") _mock_s3(quantum_task._aws_session, MockS3.MOCK_S3_RESULT_GATE_MODEL) diff --git a/test/unit_tests/braket/circuits/test_circuit.py b/test/unit_tests/braket/circuits/test_circuit.py index 73074e3f4..9a112f6e7 100644 --- a/test/unit_tests/braket/circuits/test_circuit.py +++ b/test/unit_tests/braket/circuits/test_circuit.py @@ -1083,7 +1083,7 @@ def test_parametric_circuit_with_fixed_argument_defcal(pulse_sequence): "bit[1] b;", "qubit[1] q;", "cal {", - " waveform drag_gauss_wf = drag_gaussian" + "(3.0ms, 400.0ms, 0.2, 1, false);", + " waveform drag_gauss_wf = drag_gaussian(3.0ms, 400.0ms, 0.2, 1, false);", "}", "defcal z $0, $1 {", " set_frequency(predefined_frame_1, 6000000.0);", @@ -1180,7 +1180,7 @@ def foo( "bit[1] b;", "qubit[1] q;", "cal {", - " waveform drag_gauss_wf = drag_gaussian" + "(3.0ms, 400.0ms, 0.2, 1, false);", + " waveform drag_gauss_wf = drag_gaussian(3.0ms, 400.0ms, 0.2, 1, false);", "}", "defcal foo(float beta) $0 {", " shift_phase(predefined_frame_1, -0.1);", diff --git a/test/unit_tests/braket/circuits/test_gates.py b/test/unit_tests/braket/circuits/test_gates.py index 5639399ad..1b6ac56c7 100644 --- a/test/unit_tests/braket/circuits/test_gates.py +++ b/test/unit_tests/braket/circuits/test_gates.py @@ -1048,8 +1048,8 @@ def to_ir(pulse_gate): assert a_bound_ir == "\n".join( [ "cal {", - " set_frequency(user_frame, b + 3);", - " delay[(c) * 1s] user_frame;", + " set_frequency(user_frame, 3.0 + b);", + " delay[c * 1s] user_frame;", "}", ] ) diff --git a/test/unit_tests/braket/jobs/test_quantum_job_creation.py b/test/unit_tests/braket/jobs/test_quantum_job_creation.py index bef4fd643..8cd1fbca9 100644 --- a/test/unit_tests/braket/jobs/test_quantum_job_creation.py +++ b/test/unit_tests/braket/jobs/test_quantum_job_creation.py @@ -323,8 +323,9 @@ def _translate_creation_args(create_job_args): image_uri = create_job_args["image_uri"] job_name = create_job_args["job_name"] or _generate_default_job_name(image_uri) default_bucket = aws_session.default_bucket() + timestamp = str(int(time.time() * 1000)) code_location = create_job_args["code_location"] or AwsSession.construct_s3_uri( - default_bucket, "jobs", job_name, "script" + default_bucket, "jobs", job_name, timestamp, "script" ) role_arn = create_job_args["role_arn"] or aws_session.get_default_jobs_role() device = create_job_args["device"] @@ -340,11 +341,13 @@ def _translate_creation_args(create_job_args): } hyperparameters.update(distributed_hyperparams) output_data_config = create_job_args["output_data_config"] or OutputDataConfig( - s3Path=AwsSession.construct_s3_uri(default_bucket, "jobs", job_name, "data") + s3Path=AwsSession.construct_s3_uri(default_bucket, "jobs", job_name, timestamp, "data") ) stopping_condition = create_job_args["stopping_condition"] or StoppingCondition() checkpoint_config = create_job_args["checkpoint_config"] or CheckpointConfig( - s3Uri=AwsSession.construct_s3_uri(default_bucket, "jobs", job_name, "checkpoints") + s3Uri=AwsSession.construct_s3_uri( + default_bucket, "jobs", job_name, timestamp, "checkpoints" + ) ) entry_point = create_job_args["entry_point"] source_module = create_job_args["source_module"] @@ -365,7 +368,7 @@ def _translate_creation_args(create_job_args): "jobName": job_name, "roleArn": role_arn, "algorithmSpecification": algorithm_specification, - "inputDataConfig": _process_input_data(input_data, job_name, aws_session), + "inputDataConfig": _process_input_data(input_data, job_name, aws_session, timestamp), "instanceConfig": asdict(instance_config), "outputDataConfig": asdict(output_data_config, dict_factory=_exclude_nones_factory), "checkpointConfig": asdict(checkpoint_config), @@ -403,6 +406,7 @@ def test_generate_default_job_name(mock_time, image_uri): mock_time.return_value = datetime.datetime.now().timestamp() timestamp = str(int(time.time() * 1000)) assert _generate_default_job_name(image_uri) == f"braket-job{job_type}-{timestamp}" + assert _generate_default_job_name(image_uri, timestamp="ts") == f"braket-job{job_type}-ts" @pytest.mark.parametrize( @@ -602,7 +606,7 @@ def test_invalid_input_parameters(entry_point, aws_session): "channelName": "input", "dataSource": { "s3DataSource": { - "s3Uri": "s3://default-bucket-name/jobs/job-name/data/input/prefix", + "s3Uri": "s3://default-bucket-name/jobs/job-name/ts/data/input/prefix", }, }, } @@ -651,7 +655,7 @@ def test_invalid_input_parameters(entry_point, aws_session): "channelName": "local-input", "dataSource": { "s3DataSource": { - "s3Uri": "s3://default-bucket-name/jobs/job-name/" + "s3Uri": "s3://default-bucket-name/jobs/job-name/ts/" "data/local-input/prefix", }, }, @@ -678,4 +682,4 @@ def test_invalid_input_parameters(entry_point, aws_session): ) def test_process_input_data(aws_session, input_data, input_data_configs): job_name = "job-name" - assert _process_input_data(input_data, job_name, aws_session) == input_data_configs + assert _process_input_data(input_data, job_name, aws_session, "ts") == input_data_configs diff --git a/test/unit_tests/braket/parametric/test_free_parameter_expression.py b/test/unit_tests/braket/parametric/test_free_parameter_expression.py index d991ec236..7ba8fd1ac 100644 --- a/test/unit_tests/braket/parametric/test_free_parameter_expression.py +++ b/test/unit_tests/braket/parametric/test_free_parameter_expression.py @@ -162,6 +162,7 @@ def test_sub_return_expression(): (FreeParameter("a") + 2 * FreeParameter("b"), {"a": 0.1, "b": 0.3}, 0.7, float), (FreeParameter("x"), {"y": 1}, FreeParameter("x"), FreeParameter), (FreeParameter("y"), {"y": -0.1}, -0.1, float), + (2 * FreeParameter("i"), {"i": 1}, 2.0, float), ( FreeParameter("a") + 2 * FreeParameter("x"), {"a": 0.4, "b": 0.4}, diff --git a/test/unit_tests/braket/pulse/test_pulse_sequence.py b/test/unit_tests/braket/pulse/test_pulse_sequence.py index 8de93ce1a..ca3edfa2f 100644 --- a/test/unit_tests/braket/pulse/test_pulse_sequence.py +++ b/test/unit_tests/braket/pulse/test_pulse_sequence.py @@ -87,7 +87,7 @@ def test_pulse_sequence_make_bound_pulse_sequence(predefined_frame_1, predefined .set_frequency(predefined_frame_1, param) .shift_frequency(predefined_frame_1, param) .set_phase(predefined_frame_1, param) - .shift_phase(predefined_frame_1, param) + .shift_phase(predefined_frame_1, -param) .set_scale(predefined_frame_1, param) .capture_v0(predefined_frame_1) .delay([predefined_frame_1, predefined_frame_2], param) @@ -124,28 +124,28 @@ def test_pulse_sequence_make_bound_pulse_sequence(predefined_frame_1, predefined expected_str_unbound = "\n".join( [ "OPENQASM 3.0;", - "input float length_g;", - "input float sigma_g;", + "input float length_c;", "input float length_dg;", "input float sigma_dg;", - "input float length_c;", - "input float b;", + "input float length_g;", + "input float sigma_g;", "input float a;", + "input float b;", "cal {", - " waveform gauss_wf = gaussian((length_g) * 1s, (sigma_g) * 1s, 1, false);", - " waveform drag_gauss_wf = drag_gaussian((length_dg) * 1s," - " (sigma_dg) * 1s, 0.2, 1, false);", - " waveform constant_wf = constant((length_c) * 1s, 2.0 + 0.3im);", + " waveform gauss_wf = gaussian(length_g * 1s, sigma_g * 1s, 1, false);", + " waveform drag_gauss_wf = drag_gaussian(length_dg * 1s," + " sigma_dg * 1s, 0.2, 1, false);", + " waveform constant_wf = constant(length_c * 1s, 2.0 + 0.3im);", " waveform arb_wf = {1.0 + 0.4im, 0, 0.3, 0.1 + 0.2im};", " bit[2] psb;", - " set_frequency(predefined_frame_1, a + 2*b);", - " shift_frequency(predefined_frame_1, a + 2*b);", - " set_phase(predefined_frame_1, a + 2*b);", - " shift_phase(predefined_frame_1, a + 2*b);", - " set_scale(predefined_frame_1, a + 2*b);", + " set_frequency(predefined_frame_1, a + 2.0 * b);", + " shift_frequency(predefined_frame_1, a + 2.0 * b);", + " set_phase(predefined_frame_1, a + 2.0 * b);", + " shift_phase(predefined_frame_1, -1.0 * a + -2.0 * b);", + " set_scale(predefined_frame_1, a + 2.0 * b);", " psb[0] = capture_v0(predefined_frame_1);", - " delay[(a + 2*b) * 1s] predefined_frame_1, predefined_frame_2;", - " delay[(a + 2*b) * 1s] predefined_frame_1;", + " delay[(a + 2.0 * b) * 1s] predefined_frame_1, predefined_frame_2;", + " delay[(a + 2.0 * b) * 1s] predefined_frame_1;", " delay[1.0ms] predefined_frame_1;", " barrier predefined_frame_1, predefined_frame_2;", " play(predefined_frame_1, gauss_wf);", @@ -178,19 +178,19 @@ def test_pulse_sequence_make_bound_pulse_sequence(predefined_frame_1, predefined "input float sigma_g;", "input float a;", "cal {", - " waveform gauss_wf = gaussian(1.0ms, (sigma_g) * 1s, 1, false);", + " waveform gauss_wf = gaussian(1.0ms, sigma_g * 1s, 1, false);", " waveform drag_gauss_wf = drag_gaussian(3.0ms, 400.0ms, 0.2, 1, false);", " waveform constant_wf = constant(4.0ms, 2.0 + 0.3im);", " waveform arb_wf = {1.0 + 0.4im, 0, 0.3, 0.1 + 0.2im};", " bit[2] psb;", - " set_frequency(predefined_frame_1, a + 4);", - " shift_frequency(predefined_frame_1, a + 4);", - " set_phase(predefined_frame_1, a + 4);", - " shift_phase(predefined_frame_1, a + 4);", - " set_scale(predefined_frame_1, a + 4);", + " set_frequency(predefined_frame_1, a + 4.0);", + " shift_frequency(predefined_frame_1, a + 4.0);", + " set_phase(predefined_frame_1, a + 4.0);", + " shift_phase(predefined_frame_1, -1.0 * a + -4.0);", + " set_scale(predefined_frame_1, a + 4.0);", " psb[0] = capture_v0(predefined_frame_1);", - " delay[(a + 4) * 1s] predefined_frame_1, predefined_frame_2;", - " delay[(a + 4) * 1s] predefined_frame_1;", + " delay[(a + 4.0) * 1s] predefined_frame_1, predefined_frame_2;", + " delay[(a + 4.0) * 1s] predefined_frame_1;", " delay[1.0ms] predefined_frame_1;", " barrier predefined_frame_1, predefined_frame_2;", " play(predefined_frame_1, gauss_wf);", @@ -218,11 +218,11 @@ def test_pulse_sequence_make_bound_pulse_sequence(predefined_frame_1, predefined " set_frequency(predefined_frame_1, 5.0);", " shift_frequency(predefined_frame_1, 5.0);", " set_phase(predefined_frame_1, 5.0);", - " shift_phase(predefined_frame_1, 5.0);", + " shift_phase(predefined_frame_1, -5.0);", " set_scale(predefined_frame_1, 5.0);", " psb[0] = capture_v0(predefined_frame_1);", - " delay[5s] predefined_frame_1, predefined_frame_2;", - " delay[5s] predefined_frame_1;", + " delay[5.0s] predefined_frame_1, predefined_frame_2;", + " delay[5.0s] predefined_frame_1;", " delay[1.0ms] predefined_frame_1;", " barrier predefined_frame_1, predefined_frame_2;", " play(predefined_frame_1, gauss_wf);", diff --git a/test/unit_tests/braket/pulse/test_waveforms.py b/test/unit_tests/braket/pulse/test_waveforms.py index 09ff3290e..0dd8f4ea5 100644 --- a/test/unit_tests/braket/pulse/test_waveforms.py +++ b/test/unit_tests/braket/pulse/test_waveforms.py @@ -118,7 +118,9 @@ def test_constant_wf_free_params(): assert wf.parameters == [FreeParameter("length_v") + FreeParameter("length_w")] _assert_wf_qasm( wf, - "waveform const_wf = " "constant((length_v + length_w) * 1s, 2.0 - 3.0im);", + "input float length_v;\n" + "input float length_w;\n" + "waveform const_wf = constant((length_v + length_w) * 1s, 2.0 - 3.0im);", ) wf_2 = wf.bind_values(length_v=2e-6, length_w=4e-6) @@ -199,8 +201,13 @@ def test_drag_gaussian_wf_free_params(): ] _assert_wf_qasm( wf, + "input float length_v;\n" + "input float sigma_a;\n" + "input float sigma_b;\n" + "input float beta_y;\n" + "input float amp_z;\n" "waveform d_gauss_wf = " - "drag_gaussian((length_v) * 1s, (sigma_a + " + "drag_gaussian(length_v * 1s, (sigma_a + " "sigma_b) * 1s, beta_y, amp_z, false);", ) @@ -213,7 +220,10 @@ def test_drag_gaussian_wf_free_params(): ] _assert_wf_qasm( wf_2, - "waveform d_gauss_wf = drag_gaussian(600.0ms, (sigma_b + 0.4) * 1s, beta_y, amp_z, false);", + "input float sigma_b;\n" + "input float beta_y;\n" + "input float amp_z;\n" + "waveform d_gauss_wf = drag_gaussian(600.0ms, (0.4 + sigma_b) * 1s, beta_y, amp_z, false);", ) wf_3 = wf.bind_values(length_v=0.6, sigma_a=0.3, sigma_b=0.1, beta_y=0.2, amp_z=0.1) @@ -283,12 +293,17 @@ def test_gaussian_wf_free_params(): ] _assert_wf_qasm( wf, - "waveform gauss_wf = gaussian((length_v) * 1s, (sigma_x) * 1s, " "amp_z, false);", + "input float length_v;\n" + "input float sigma_x;\n" + "input float amp_z;\n" + "waveform gauss_wf = gaussian(length_v * 1s, sigma_x * 1s, amp_z, false);", ) wf_2 = wf.bind_values(length_v=0.6, sigma_x=0.4) assert wf_2.parameters == [0.6, 0.4, FreeParameter("amp_z")] - _assert_wf_qasm(wf_2, "waveform gauss_wf = gaussian(600.0ms, 400.0ms, amp_z, false);") + _assert_wf_qasm( + wf_2, "input float amp_z;\nwaveform gauss_wf = gaussian(600.0ms, 400.0ms, amp_z, false);" + ) wf_3 = wf.bind_values(length_v=0.6, sigma_x=0.3, amp_z=0.1) assert wf_3.parameters == [0.6, 0.3, 0.1] diff --git a/tox.ini b/tox.ini index b77ac4525..467b9ab84 100644 --- a/tox.ini +++ b/tox.ini @@ -59,7 +59,6 @@ basepython = python3 skip_install = true deps = flake8 - flake8-rst-docstrings git+https://github.com/amazon-braket/amazon-braket-build-tools.git commands = flake8 --extend-exclude src {posargs} From d0ba60dba4b6e0ebab1419563d8ca744964c61b6 Mon Sep 17 00:00:00 2001 From: Jean-Christophe Jaskula Date: Tue, 6 Feb 2024 19:34:04 -0500 Subject: [PATCH 37/43] fix merge --- src/braket/pulse/ast/qasm_transformer.py | 41 ------------------- src/braket/pulse/pulse_sequence.py | 3 +- .../test_free_parameter_expression.py | 5 --- .../braket/pulse/test_pulse_sequence.py | 9 ---- 4 files changed, 1 insertion(+), 57 deletions(-) diff --git a/src/braket/pulse/ast/qasm_transformer.py b/src/braket/pulse/ast/qasm_transformer.py index a8cc9b9b3..f5e350883 100644 --- a/src/braket/pulse/ast/qasm_transformer.py +++ b/src/braket/pulse/ast/qasm_transformer.py @@ -57,44 +57,3 @@ def visit_ExpressionStatement(self, expression_statement: ast.ExpressionStatemen return new_val else: return expression_statement - - -class _InputVarSplitter(QASMTransformer): - """ - QASMTransformer which walks the AST and makes the necessary modifications needed - for IR generation. Currently, it performs the following operations: - * Bubbles up input variables to the top of the CalibrationStatement block. - """ - - def visit_Program(self, program: ast.Program) -> ast.Program: - """Visit a Program. - Args: - program (Program): The program. - Returns: - Program: the modified program. - """ - assert len(program.statements) == 1 and isinstance( - program.statements[0], ast.CalibrationStatement - ) - program.statements = self.split_input_vars(program.statements[0]) - return self.generic_visit(program) - - def split_input_vars( - self, - node: ast.CalibrationStatement, - ) -> list[ast.Statement]: - """Split input variables out of the calibrationStatement block. - - Args: - node (CalibrationStatement): The CalibrationStatement block. - Returns: - list[Statement]: The list of statements with input variables outside and in front. - """ - input_vars = [] - new_body = [] - for child in node.body: - if isinstance(child, ast.IODeclaration) and child.io_identifier is ast.IOKeyword.input: - input_vars.append(child) - else: - new_body.append(child) - return input_vars + [ast.CalibrationStatement(new_body)] diff --git a/src/braket/pulse/pulse_sequence.py b/src/braket/pulse/pulse_sequence.py index 18b682c4c..25e5c3b55 100644 --- a/src/braket/pulse/pulse_sequence.py +++ b/src/braket/pulse/pulse_sequence.py @@ -27,7 +27,7 @@ from braket.pulse.ast.approximation_parser import _ApproximationParser from braket.pulse.ast.free_parameters import _FreeParameterTransformer from braket.pulse.ast.qasm_parser import ast_to_qasm -from braket.pulse.ast.qasm_transformer import _InputVarSplitter, _IRQASMTransformer +from braket.pulse.ast.qasm_transformer import _IRQASMTransformer from braket.pulse.frame import Frame from braket.pulse.pulse_sequence_trace import PulseSequenceTrace from braket.pulse.waveforms import Waveform @@ -320,7 +320,6 @@ def to_ir(self) -> str: tree = _IRQASMTransformer(register_identifier).visit(tree) else: tree = program.to_ast(encal=True, include_externs=False) - tree = _InputVarSplitter().visit(tree) return ast_to_qasm(tree) def _register_free_parameters( diff --git a/test/unit_tests/braket/parametric/test_free_parameter_expression.py b/test/unit_tests/braket/parametric/test_free_parameter_expression.py index 7ba8fd1ac..370d45083 100644 --- a/test/unit_tests/braket/parametric/test_free_parameter_expression.py +++ b/test/unit_tests/braket/parametric/test_free_parameter_expression.py @@ -67,11 +67,6 @@ def test_unsupported_node_str(): FreeParameterExpression("theta , 1") -@pytest.mark.xfail(raises=TypeError) -def test_unsupported_type(): - FreeParameterExpression("theta", _type=float) - - def test_commutativity(): add_1 = 1 + FreeParameterExpression(FreeParameter("theta")) add_2 = FreeParameterExpression(FreeParameter("theta")) + 1 diff --git a/test/unit_tests/braket/pulse/test_pulse_sequence.py b/test/unit_tests/braket/pulse/test_pulse_sequence.py index 841b64dfe..006326047 100644 --- a/test/unit_tests/braket/pulse/test_pulse_sequence.py +++ b/test/unit_tests/braket/pulse/test_pulse_sequence.py @@ -124,13 +124,6 @@ def test_pulse_sequence_make_bound_pulse_sequence(predefined_frame_1, predefined expected_str_unbound = "\n".join( [ "OPENQASM 3.0;", - "input float length_c;", - "input float length_dg;", - "input float sigma_dg;", - "input float length_g;", - "input float sigma_g;", - "input float a;", - "input float b;", "cal {", " input float length_c;", " input float length_dg;", @@ -182,8 +175,6 @@ def test_pulse_sequence_make_bound_pulse_sequence(predefined_frame_1, predefined expected_str_b_bound = "\n".join( [ "OPENQASM 3.0;", - "input float sigma_g;", - "input float a;", "cal {", " input float sigma_g;", " input float a;", From 3fc6296aea34c3a29af5f06333280c6c004ec838 Mon Sep 17 00:00:00 2001 From: Jean-Christophe Jaskula Date: Wed, 7 Feb 2024 20:32:18 -0500 Subject: [PATCH 38/43] do not declare FreeParameterExpression by default --- src/braket/circuits/circuit.py | 25 +++++++++++++++++-- .../parametric/free_parameter_expression.py | 4 ++- src/braket/pulse/pulse_sequence.py | 5 ++++ .../braket/pulse/test_pulse_sequence.py | 16 ++++++------ .../unit_tests/braket/pulse/test_waveforms.py | 20 ++------------- 5 files changed, 41 insertions(+), 29 deletions(-) diff --git a/src/braket/circuits/circuit.py b/src/braket/circuits/circuit.py index a1015f68d..ad8fbe33f 100644 --- a/src/braket/circuits/circuit.py +++ b/src/braket/circuits/circuit.py @@ -1235,6 +1235,7 @@ def _create_openqasm_header( gate_definitions: Optional[dict[tuple[Gate, QubitSet], PulseSequence]], ) -> list[str]: ir_instructions = ["OPENQASM 3.0;"] + frame_wf_declarations = self._generate_frame_wf_defcal_declarations(gate_definitions) for parameter in self.parameters: ir_instructions.append(f"input float {parameter};") if not self.result_types: @@ -1249,7 +1250,6 @@ def _create_openqasm_header( f"{serialization_properties.qubit_reference_type} supplied." ) - frame_wf_declarations = self._generate_frame_wf_defcal_declarations(gate_definitions) if frame_wf_declarations: ir_instructions.append(frame_wf_declarations) return ir_instructions @@ -1271,6 +1271,18 @@ def _validate_gate_calbrations_uniqueness( def _generate_frame_wf_defcal_declarations( self, gate_definitions: Optional[dict[tuple[Gate, QubitSet], PulseSequence]] ) -> Optional[str]: + """Generates the header where frame, waveform and defcals are declared. + + It also adds any FreeParameter that is not gate arguments to the circuit parameter set. + + Args: + gate_definitions (Optional[dict[tuple[Gate, QubitSet], PulseSequence]]): The + calibration data for the device. + + Returns: + Optional[str]: An OpenQASM string + """ + program = oqpy.Program(None, simplify_constants=False) frames, waveforms = self._get_frames_waveforms_from_instrs(gate_definitions) @@ -1289,7 +1301,16 @@ def _generate_frame_wf_defcal_declarations( for key, calibration in gate_definitions.items(): gate, qubits = key gate_name = gate._qasm_name - arguments = gate.parameters if isinstance(gate, Parameterizable) else None + arguments = gate.parameters if isinstance(gate, Parameterizable) else [] + + for param in calibration.parameters: + if param not in arguments: + self._parameters.add(param) + arguments = [ + param._to_oqpy_expression() if isinstance(param, FreeParameter) else param + for param in arguments + ] + with oqpy.defcal( program, [oqpy.PhysicalQubits[int(k)] for k in qubits], gate_name, arguments ): diff --git a/src/braket/parametric/free_parameter_expression.py b/src/braket/parametric/free_parameter_expression.py index 98916fbf5..fdfa1469a 100644 --- a/src/braket/parametric/free_parameter_expression.py +++ b/src/braket/parametric/free_parameter_expression.py @@ -191,7 +191,9 @@ def _to_oqpy_expression(self) -> OQPyExpression: elif isinstance(self.expression, sympy.Number): return float(self.expression) else: - fvar = FloatVar(name=self.expression.name, init_expression="input") + fvar = FloatVar( + name=self.expression.name, init_expression="input", needs_declaration=False + ) fvar.size = None fvar.type.size = None return fvar diff --git a/src/braket/pulse/pulse_sequence.py b/src/braket/pulse/pulse_sequence.py index 25e5c3b55..cabd4943f 100644 --- a/src/braket/pulse/pulse_sequence.py +++ b/src/braket/pulse/pulse_sequence.py @@ -311,6 +311,11 @@ def to_ir(self) -> str: str: a str representing the OpenPulse program encoding the PulseSequence. """ program = deepcopy(self._program) + program.autodeclare(encal=False) + sorted_parameters = sorted(self.parameters, key=lambda p: p.name, reverse=True) + for param in sorted_parameters: + program.declare(param._to_oqpy_expression(), to_beginning=True) + if self._capture_v0_count: register_identifier = "psb" program.declare( diff --git a/test/unit_tests/braket/pulse/test_pulse_sequence.py b/test/unit_tests/braket/pulse/test_pulse_sequence.py index 006326047..1008c79f1 100644 --- a/test/unit_tests/braket/pulse/test_pulse_sequence.py +++ b/test/unit_tests/braket/pulse/test_pulse_sequence.py @@ -125,19 +125,19 @@ def test_pulse_sequence_make_bound_pulse_sequence(predefined_frame_1, predefined [ "OPENQASM 3.0;", "cal {", + " bit[2] psb;", + " input float a;", + " input float b;", " input float length_c;", " input float length_dg;", - " input float sigma_dg;", " input float length_g;", + " input float sigma_dg;", " input float sigma_g;", - " input float a;", - " input float b;", " waveform gauss_wf = gaussian(length_g * 1s, sigma_g * 1s, 1, false);", " waveform drag_gauss_wf = drag_gaussian(length_dg * 1s," " sigma_dg * 1s, 0.2, 1, false);", " waveform constant_wf = constant(length_c * 1s, 2.0 + 0.3im);", " waveform arb_wf = {1.0 + 0.4im, 0, 0.3, 0.1 + 0.2im};", - " bit[2] psb;", " set_frequency(predefined_frame_1, a + 2.0 * b);", " shift_frequency(predefined_frame_1, a + 2.0 * b);", " set_phase(predefined_frame_1, a + 2.0 * b);", @@ -176,13 +176,13 @@ def test_pulse_sequence_make_bound_pulse_sequence(predefined_frame_1, predefined [ "OPENQASM 3.0;", "cal {", - " input float sigma_g;", + " bit[2] psb;", " input float a;", + " input float sigma_g;", " waveform gauss_wf = gaussian(1.0ms, sigma_g * 1s, 1, false);", " waveform drag_gauss_wf = drag_gaussian(3.0ms, 400.0ms, 0.2, 1, false);", " waveform constant_wf = constant(4.0ms, 2.0 + 0.3im);", " waveform arb_wf = {1.0 + 0.4im, 0, 0.3, 0.1 + 0.2im};", - " bit[2] psb;", " set_frequency(predefined_frame_1, a + 4.0);", " shift_frequency(predefined_frame_1, a + 4.0);", " set_phase(predefined_frame_1, a + 4.0);", @@ -210,11 +210,11 @@ def test_pulse_sequence_make_bound_pulse_sequence(predefined_frame_1, predefined [ "OPENQASM 3.0;", "cal {", + " bit[2] psb;", " waveform gauss_wf = gaussian(1.0ms, 700.0ms, 1, false);", " waveform drag_gauss_wf = drag_gaussian(3.0ms, 400.0ms, 0.2, 1, false);", " waveform constant_wf = constant(4.0ms, 2.0 + 0.3im);", " waveform arb_wf = {1.0 + 0.4im, 0, 0.3, 0.1 + 0.2im};", - " bit[2] psb;", " set_frequency(predefined_frame_1, 5.0);", " shift_frequency(predefined_frame_1, 5.0);", " set_phase(predefined_frame_1, 5.0);", @@ -314,11 +314,11 @@ def test_pulse_sequence_to_ir(predefined_frame_1, predefined_frame_2): [ "OPENQASM 3.0;", "cal {", + " bit[2] psb;", " waveform gauss_wf = gaussian(1.0ms, 700.0ms, 1, false);", " waveform drag_gauss_wf = drag_gaussian(3.0ms, 400.0ms, 0.2, 1, false);", " waveform constant_wf = constant(4.0ms, 2.0 + 0.3im);", " waveform arb_wf = {1.0 + 0.4im, 0, 0.3, 0.1 + 0.2im};", - " bit[2] psb;", " set_frequency(predefined_frame_1, 3000000000.0);", " shift_frequency(predefined_frame_1, 1000000000.0);", " set_phase(predefined_frame_1, -0.5);", diff --git a/test/unit_tests/braket/pulse/test_waveforms.py b/test/unit_tests/braket/pulse/test_waveforms.py index 0dd8f4ea5..0c56d3542 100644 --- a/test/unit_tests/braket/pulse/test_waveforms.py +++ b/test/unit_tests/braket/pulse/test_waveforms.py @@ -118,8 +118,6 @@ def test_constant_wf_free_params(): assert wf.parameters == [FreeParameter("length_v") + FreeParameter("length_w")] _assert_wf_qasm( wf, - "input float length_v;\n" - "input float length_w;\n" "waveform const_wf = constant((length_v + length_w) * 1s, 2.0 - 3.0im);", ) @@ -201,14 +199,8 @@ def test_drag_gaussian_wf_free_params(): ] _assert_wf_qasm( wf, - "input float length_v;\n" - "input float sigma_a;\n" - "input float sigma_b;\n" - "input float beta_y;\n" - "input float amp_z;\n" "waveform d_gauss_wf = " - "drag_gaussian(length_v * 1s, (sigma_a + " - "sigma_b) * 1s, beta_y, amp_z, false);", + "drag_gaussian(length_v * 1s, (sigma_a + sigma_b) * 1s, beta_y, amp_z, false);", ) wf_2 = wf.bind_values(length_v=0.6, sigma_a=0.4) @@ -220,9 +212,6 @@ def test_drag_gaussian_wf_free_params(): ] _assert_wf_qasm( wf_2, - "input float sigma_b;\n" - "input float beta_y;\n" - "input float amp_z;\n" "waveform d_gauss_wf = drag_gaussian(600.0ms, (0.4 + sigma_b) * 1s, beta_y, amp_z, false);", ) @@ -293,17 +282,12 @@ def test_gaussian_wf_free_params(): ] _assert_wf_qasm( wf, - "input float length_v;\n" - "input float sigma_x;\n" - "input float amp_z;\n" "waveform gauss_wf = gaussian(length_v * 1s, sigma_x * 1s, amp_z, false);", ) wf_2 = wf.bind_values(length_v=0.6, sigma_x=0.4) assert wf_2.parameters == [0.6, 0.4, FreeParameter("amp_z")] - _assert_wf_qasm( - wf_2, "input float amp_z;\nwaveform gauss_wf = gaussian(600.0ms, 400.0ms, amp_z, false);" - ) + _assert_wf_qasm(wf_2, "waveform gauss_wf = gaussian(600.0ms, 400.0ms, amp_z, false);") wf_3 = wf.bind_values(length_v=0.6, sigma_x=0.3, amp_z=0.1) assert wf_3.parameters == [0.6, 0.3, 0.1] From 6047c86e019406c5de76843bab8cad6545c36e0e Mon Sep 17 00:00:00 2001 From: Jean-Christophe Jaskula Date: Wed, 7 Feb 2024 21:51:10 -0500 Subject: [PATCH 39/43] add tests --- .../braket/circuits/test_circuit.py | 174 ++++++++++++++++++ 1 file changed, 174 insertions(+) diff --git a/test/unit_tests/braket/circuits/test_circuit.py b/test/unit_tests/braket/circuits/test_circuit.py index 9a112f6e7..c601e6d38 100644 --- a/test/unit_tests/braket/circuits/test_circuit.py +++ b/test/unit_tests/braket/circuits/test_circuit.py @@ -1063,6 +1063,180 @@ def test_circuit_to_ir_openqasm(circuit, serialization_properties, expected_ir, assert copy_of_gate_calibrations.pulse_sequences == gate_calibrations.pulse_sequences +@pytest.mark.parametrize( + "circuit, calibration_key, expected_ir", + [ + ( + Circuit().rx(0, 0.2), + (Gate.Rx(FreeParameter("alpha")), QubitSet(0)), + OpenQasmProgram( + source="\n".join( + [ + "OPENQASM 3.0;", + "input float gamma;", + "input float beta;", + "bit[1] b;", + "qubit[1] q;", + "cal {", + " waveform drag_gauss_wf = drag_gaussian(3.0ms," + " 400.0ms, 0.2, 1, false);", + "}", + "defcal rx(float alpha) $0 {", + " shift_phase(predefined_frame_1, alpha);", + " set_phase(predefined_frame_1, gamma);", + " shift_phase(predefined_frame_1, beta);", + " play(predefined_frame_1, drag_gauss_wf);", + "}", + "rx(0.2) q[0];", + "b[0] = measure q[0];", + ] + ), + inputs={}, + ), + ), + ( + Circuit().rx(0, FreeParameter("gamma")), + (Gate.Rx(FreeParameter("alpha")), QubitSet(0)), + OpenQasmProgram( + source="\n".join( + [ + "OPENQASM 3.0;", + "input float gamma;", + "input float beta;", + "bit[1] b;", + "qubit[1] q;", + "cal {", + " waveform drag_gauss_wf = drag_gaussian(3.0ms," + " 400.0ms, 0.2, 1, false);", + "}", + "defcal rx(float alpha) $0 {", + " shift_phase(predefined_frame_1, alpha);", + " set_phase(predefined_frame_1, gamma);", + " shift_phase(predefined_frame_1, beta);", + " play(predefined_frame_1, drag_gauss_wf);", + "}", + "rx(gamma) q[0];", + "b[0] = measure q[0];", + ] + ), + inputs={}, + ), + ), + ( + Circuit().ms(0, 1, 0.1, 0.2, 0.3), + ( + Gate.MS(FreeParameter("alpha"), FreeParameter("beta"), FreeParameter("gamma")), + QubitSet([0, 1]), + ), + OpenQasmProgram( + source="\n".join( + [ + "OPENQASM 3.0;", + "bit[2] b;", + "qubit[2] q;", + "cal {", + " waveform drag_gauss_wf = drag_gaussian(3.0ms," + " 400.0ms, 0.2, 1, false);", + "}", + "defcal ms(float alpha, float beta, float gamma) $0, $1 {", + " shift_phase(predefined_frame_1, alpha);", + " set_phase(predefined_frame_1, gamma);", + " shift_phase(predefined_frame_1, beta);", + " play(predefined_frame_1, drag_gauss_wf);", + "}", + "ms(0.1, 0.2, 0.3) q[0], q[1];", + "b[0] = measure q[0];", + "b[1] = measure q[1];", + ] + ), + inputs={}, + ), + ), + ( + Circuit().ms(0, 1, 0.1, 0.2, FreeParameter("gamma")), + ( + Gate.MS(FreeParameter("alpha"), FreeParameter("beta"), FreeParameter("gamma")), + QubitSet([0, 1]), + ), + OpenQasmProgram( + source="\n".join( + [ + "OPENQASM 3.0;", + "input float gamma;", + "bit[2] b;", + "qubit[2] q;", + "cal {", + " waveform drag_gauss_wf = drag_gaussian(3.0ms," + " 400.0ms, 0.2, 1, false);", + "}", + "defcal ms(float alpha, float beta, float gamma) $0, $1 {", + " shift_phase(predefined_frame_1, alpha);", + " set_phase(predefined_frame_1, gamma);", + " shift_phase(predefined_frame_1, beta);", + " play(predefined_frame_1, drag_gauss_wf);", + "}", + "ms(0.1, 0.2, gamma) q[0], q[1];", + "b[0] = measure q[0];", + "b[1] = measure q[1];", + ] + ), + inputs={}, + ), + ), + ( + Circuit().ms(0, 1, 0.1, 0.2, FreeParameter("gamma")), + ( + Gate.MS(FreeParameter("alpha"), FreeParameter("theta"), FreeParameter("gamma")), + QubitSet([0, 1]), + ), + OpenQasmProgram( + source="\n".join( + [ + "OPENQASM 3.0;", + "input float gamma;", + "input float beta;", + "bit[2] b;", + "qubit[2] q;", + "cal {", + " waveform drag_gauss_wf = drag_gaussian(3.0ms," + " 400.0ms, 0.2, 1, false);", + "}", + "defcal ms(float alpha, float theta, float gamma) $0, $1 {", + " shift_phase(predefined_frame_1, alpha);", + " set_phase(predefined_frame_1, gamma);", + " shift_phase(predefined_frame_1, beta);", + " play(predefined_frame_1, drag_gauss_wf);", + "}", + "ms(0.1, 0.2, gamma) q[0], q[1];", + "b[0] = measure q[0];", + "b[1] = measure q[1];", + ] + ), + inputs={}, + ), + ), + ], +) +def test_parametric_circuit_with_parametric_defcal( + circuit, calibration_key, expected_ir, pulse_sequence_2 +): + serialization_properties = OpenQASMSerializationProperties(QubitReferenceType.VIRTUAL) + gate_calibrations = GateCalibrations( + { + calibration_key: pulse_sequence_2, + } + ) + + assert ( + circuit.to_ir( + ir_type=IRType.OPENQASM, + serialization_properties=serialization_properties, + gate_definitions=gate_calibrations.pulse_sequences, + ) + == expected_ir + ) + + def test_parametric_circuit_with_fixed_argument_defcal(pulse_sequence): circ = Circuit().h(0, power=-2.5).h(0, power=0).rx(0, angle=FreeParameter("theta")) serialization_properties = OpenQASMSerializationProperties(QubitReferenceType.VIRTUAL) From 1491e7e262fe4a94f2a364e41f8738769eb7f5fd Mon Sep 17 00:00:00 2001 From: Jean-Christophe Jaskula Date: Sun, 25 Feb 2024 22:39:10 -0500 Subject: [PATCH 40/43] fix merge --- src/braket/circuits/circuit.py | 15 +------- .../braket/circuits/test_circuit.py | 36 ++----------------- 2 files changed, 3 insertions(+), 48 deletions(-) diff --git a/src/braket/circuits/circuit.py b/src/braket/circuits/circuit.py index 4f561c63d..0b89e466e 100644 --- a/src/braket/circuits/circuit.py +++ b/src/braket/circuits/circuit.py @@ -1260,8 +1260,7 @@ def _generate_frame_wf_defcal_declarations( ) -> str | None: """Generates the header where frames, waveforms and defcals are declared. - It also adds any FreeParameter of the calibrations that is not gate arguments - to the circuit parameter set. + It also adds any FreeParameter that is not gate arguments to the circuit parameter set. Args: gate_definitions (dict[tuple[Gate, QubitSet], PulseSequence] | None): The @@ -1271,18 +1270,6 @@ def _generate_frame_wf_defcal_declarations( str | None: An OpenQASM string """ - """Generates the header where frame, waveform and defcals are declared. - - It also adds any FreeParameter that is not gate arguments to the circuit parameter set. - - Args: - gate_definitions (Optional[dict[tuple[Gate, QubitSet], PulseSequence]]): The - calibration data for the device. - - Returns: - Optional[str]: An OpenQASM string - """ - program = oqpy.Program(None, simplify_constants=False) frames, waveforms = self._get_frames_waveforms_from_instrs(gate_definitions) diff --git a/test/unit_tests/braket/circuits/test_circuit.py b/test/unit_tests/braket/circuits/test_circuit.py index c601e6d38..cc1249e76 100644 --- a/test/unit_tests/braket/circuits/test_circuit.py +++ b/test/unit_tests/braket/circuits/test_circuit.py @@ -1073,8 +1073,8 @@ def test_circuit_to_ir_openqasm(circuit, serialization_properties, expected_ir, source="\n".join( [ "OPENQASM 3.0;", - "input float gamma;", "input float beta;", + "input float gamma;", "bit[1] b;", "qubit[1] q;", "cal {", @@ -1101,8 +1101,8 @@ def test_circuit_to_ir_openqasm(circuit, serialization_properties, expected_ir, source="\n".join( [ "OPENQASM 3.0;", - "input float gamma;", "input float beta;", + "input float gamma;", "bit[1] b;", "qubit[1] q;", "cal {", @@ -1183,38 +1183,6 @@ def test_circuit_to_ir_openqasm(circuit, serialization_properties, expected_ir, inputs={}, ), ), - ( - Circuit().ms(0, 1, 0.1, 0.2, FreeParameter("gamma")), - ( - Gate.MS(FreeParameter("alpha"), FreeParameter("theta"), FreeParameter("gamma")), - QubitSet([0, 1]), - ), - OpenQasmProgram( - source="\n".join( - [ - "OPENQASM 3.0;", - "input float gamma;", - "input float beta;", - "bit[2] b;", - "qubit[2] q;", - "cal {", - " waveform drag_gauss_wf = drag_gaussian(3.0ms," - " 400.0ms, 0.2, 1, false);", - "}", - "defcal ms(float alpha, float theta, float gamma) $0, $1 {", - " shift_phase(predefined_frame_1, alpha);", - " set_phase(predefined_frame_1, gamma);", - " shift_phase(predefined_frame_1, beta);", - " play(predefined_frame_1, drag_gauss_wf);", - "}", - "ms(0.1, 0.2, gamma) q[0], q[1];", - "b[0] = measure q[0];", - "b[1] = measure q[1];", - ] - ), - inputs={}, - ), - ), ], ) def test_parametric_circuit_with_parametric_defcal( From 353795523f5e3f131d049bcfb690738a843f5acf Mon Sep 17 00:00:00 2001 From: Jean-Christophe Jaskula Date: Mon, 18 Mar 2024 12:09:12 -0400 Subject: [PATCH 41/43] fix tests --- .../braket/circuits/test_circuit.py | 189 ++++++++---------- 1 file changed, 85 insertions(+), 104 deletions(-) diff --git a/test/unit_tests/braket/circuits/test_circuit.py b/test/unit_tests/braket/circuits/test_circuit.py index 07c7fc2e2..558d5bbb7 100644 --- a/test/unit_tests/braket/circuits/test_circuit.py +++ b/test/unit_tests/braket/circuits/test_circuit.py @@ -1105,63 +1105,49 @@ def test_circuit_to_ir_openqasm_with_gate_calibrations( @pytest.mark.parametrize( - "circuit, calibration_key, expected_ir", + "circuit, calibration_key, input_variables, expected_ir, input_values", [ ( Circuit().rx(0, 0.2), (Gate.Rx(FreeParameter("alpha")), QubitSet(0)), - OpenQasmProgram( - source="\n".join( - [ - "OPENQASM 3.0;", - "input float beta;", - "input float gamma;", - "bit[1] b;", - "qubit[1] q;", - "cal {", - " waveform drag_gauss_wf = drag_gaussian(3.0ms," - " 400.0ms, 0.2, 1, false);", - "}", - "defcal rx(float alpha) $0 {", - " shift_phase(predefined_frame_1, alpha);", - " set_phase(predefined_frame_1, gamma);", - " shift_phase(predefined_frame_1, beta);", - " play(predefined_frame_1, drag_gauss_wf);", - "}", - "rx(0.2) q[0];", - "b[0] = measure q[0];", - ] - ), - inputs={}, - ), + {"beta", "gamma"}, + [ + "bit[1] b;", + "qubit[1] q;", + "cal {", + " waveform drag_gauss_wf = drag_gaussian(3.0ms," " 400.0ms, 0.2, 1, false);", + "}", + "defcal rx(float alpha) $0 {", + " shift_phase(predefined_frame_1, alpha);", + " set_phase(predefined_frame_1, gamma);", + " shift_phase(predefined_frame_1, beta);", + " play(predefined_frame_1, drag_gauss_wf);", + "}", + "rx(0.2) q[0];", + "b[0] = measure q[0];", + ], + {}, ), ( Circuit().rx(0, FreeParameter("gamma")), (Gate.Rx(FreeParameter("alpha")), QubitSet(0)), - OpenQasmProgram( - source="\n".join( - [ - "OPENQASM 3.0;", - "input float beta;", - "input float gamma;", - "bit[1] b;", - "qubit[1] q;", - "cal {", - " waveform drag_gauss_wf = drag_gaussian(3.0ms," - " 400.0ms, 0.2, 1, false);", - "}", - "defcal rx(float alpha) $0 {", - " shift_phase(predefined_frame_1, alpha);", - " set_phase(predefined_frame_1, gamma);", - " shift_phase(predefined_frame_1, beta);", - " play(predefined_frame_1, drag_gauss_wf);", - "}", - "rx(gamma) q[0];", - "b[0] = measure q[0];", - ] - ), - inputs={}, - ), + {"beta", "gamma"}, + [ + "bit[1] b;", + "qubit[1] q;", + "cal {", + " waveform drag_gauss_wf = drag_gaussian(3.0ms," " 400.0ms, 0.2, 1, false);", + "}", + "defcal rx(float alpha) $0 {", + " shift_phase(predefined_frame_1, alpha);", + " set_phase(predefined_frame_1, gamma);", + " shift_phase(predefined_frame_1, beta);", + " play(predefined_frame_1, drag_gauss_wf);", + "}", + "rx(gamma) q[0];", + "b[0] = measure q[0];", + ], + {}, ), ( Circuit().ms(0, 1, 0.1, 0.2, 0.3), @@ -1169,29 +1155,24 @@ def test_circuit_to_ir_openqasm_with_gate_calibrations( Gate.MS(FreeParameter("alpha"), FreeParameter("beta"), FreeParameter("gamma")), QubitSet([0, 1]), ), - OpenQasmProgram( - source="\n".join( - [ - "OPENQASM 3.0;", - "bit[2] b;", - "qubit[2] q;", - "cal {", - " waveform drag_gauss_wf = drag_gaussian(3.0ms," - " 400.0ms, 0.2, 1, false);", - "}", - "defcal ms(float alpha, float beta, float gamma) $0, $1 {", - " shift_phase(predefined_frame_1, alpha);", - " set_phase(predefined_frame_1, gamma);", - " shift_phase(predefined_frame_1, beta);", - " play(predefined_frame_1, drag_gauss_wf);", - "}", - "ms(0.1, 0.2, 0.3) q[0], q[1];", - "b[0] = measure q[0];", - "b[1] = measure q[1];", - ] - ), - inputs={}, - ), + {}, + [ + "bit[2] b;", + "qubit[2] q;", + "cal {", + " waveform drag_gauss_wf = drag_gaussian(3.0ms," " 400.0ms, 0.2, 1, false);", + "}", + "defcal ms(float alpha, float beta, float gamma) $0, $1 {", + " shift_phase(predefined_frame_1, alpha);", + " set_phase(predefined_frame_1, gamma);", + " shift_phase(predefined_frame_1, beta);", + " play(predefined_frame_1, drag_gauss_wf);", + "}", + "ms(0.1, 0.2, 0.3) q[0], q[1];", + "b[0] = measure q[0];", + "b[1] = measure q[1];", + ], + {}, ), ( Circuit().ms(0, 1, 0.1, 0.2, FreeParameter("gamma")), @@ -1199,35 +1180,29 @@ def test_circuit_to_ir_openqasm_with_gate_calibrations( Gate.MS(FreeParameter("alpha"), FreeParameter("beta"), FreeParameter("gamma")), QubitSet([0, 1]), ), - OpenQasmProgram( - source="\n".join( - [ - "OPENQASM 3.0;", - "input float gamma;", - "bit[2] b;", - "qubit[2] q;", - "cal {", - " waveform drag_gauss_wf = drag_gaussian(3.0ms," - " 400.0ms, 0.2, 1, false);", - "}", - "defcal ms(float alpha, float beta, float gamma) $0, $1 {", - " shift_phase(predefined_frame_1, alpha);", - " set_phase(predefined_frame_1, gamma);", - " shift_phase(predefined_frame_1, beta);", - " play(predefined_frame_1, drag_gauss_wf);", - "}", - "ms(0.1, 0.2, gamma) q[0], q[1];", - "b[0] = measure q[0];", - "b[1] = measure q[1];", - ] - ), - inputs={}, - ), + {"gamma"}, + [ + "bit[2] b;", + "qubit[2] q;", + "cal {", + " waveform drag_gauss_wf = drag_gaussian(3.0ms," " 400.0ms, 0.2, 1, false);", + "}", + "defcal ms(float alpha, float beta, float gamma) $0, $1 {", + " shift_phase(predefined_frame_1, alpha);", + " set_phase(predefined_frame_1, gamma);", + " shift_phase(predefined_frame_1, beta);", + " play(predefined_frame_1, drag_gauss_wf);", + "}", + "ms(0.1, 0.2, gamma) q[0], q[1];", + "b[0] = measure q[0];", + "b[1] = measure q[1];", + ], + {}, ), ], ) def test_parametric_circuit_with_parametric_defcal( - circuit, calibration_key, expected_ir, pulse_sequence_2 + circuit, calibration_key, input_variables, expected_ir, input_values, pulse_sequence_2 ): serialization_properties = OpenQASMSerializationProperties(QubitReferenceType.VIRTUAL) gate_calibrations = GateCalibrations( @@ -1236,13 +1211,19 @@ def test_parametric_circuit_with_parametric_defcal( } ) - assert ( - circuit.to_ir( - ir_type=IRType.OPENQASM, - serialization_properties=serialization_properties, - gate_definitions=gate_calibrations.pulse_sequences, - ) - == expected_ir + assert circuit.to_ir( + ir_type=IRType.OPENQASM, + serialization_properties=serialization_properties, + gate_definitions=gate_calibrations.pulse_sequences, + ) == OpenQasmProgram( + source="\n".join( + [ + "OPENQASM 3.0;", + *[f"input float {parameter};" for parameter in input_variables], + *expected_ir, + ] + ), + inputs=input_values, ) From eeb87c344be1d2a758d66ae5f6e9b6352c02c0a9 Mon Sep 17 00:00:00 2001 From: Jean-Christophe Jaskula Date: Mon, 18 Mar 2024 12:40:49 -0400 Subject: [PATCH 42/43] create more robust test --- test/unit_tests/braket/circuits/test_circuit.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/unit_tests/braket/circuits/test_circuit.py b/test/unit_tests/braket/circuits/test_circuit.py index 558d5bbb7..8179a5719 100644 --- a/test/unit_tests/braket/circuits/test_circuit.py +++ b/test/unit_tests/braket/circuits/test_circuit.py @@ -1219,12 +1219,13 @@ def test_parametric_circuit_with_parametric_defcal( source="\n".join( [ "OPENQASM 3.0;", - *[f"input float {parameter};" for parameter in input_variables], + *[f"input float {parameter};" for parameter in circuit.parameters], *expected_ir, ] ), inputs=input_values, ) + assert circuit.parameters == {FreeParameter(name) for name in input_variables} def test_parametric_circuit_with_fixed_argument_defcal(pulse_sequence): From 0bf9b2234803133815263b84d06d00d9ba5b5cf4 Mon Sep 17 00:00:00 2001 From: Jean-Christophe Jaskula Date: Mon, 18 Mar 2024 13:38:08 -0400 Subject: [PATCH 43/43] fix docstring --- src/braket/circuits/circuit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/braket/circuits/circuit.py b/src/braket/circuits/circuit.py index b742ae955..d8397af2e 100644 --- a/src/braket/circuits/circuit.py +++ b/src/braket/circuits/circuit.py @@ -1263,7 +1263,7 @@ def _generate_frame_wf_defcal_declarations( ) -> str | None: """Generates the header where frames, waveforms and defcals are declared. - It also adds any FreeParameter that is not gate arguments to the circuit parameter set. + It also adds any FreeParameter that is not a gate argument to the circuit parameter set. Args: gate_definitions (dict[tuple[Gate, QubitSet], PulseSequence] | None): The