diff --git a/src/braket/circuits/circuit.py b/src/braket/circuits/circuit.py index 3f4918a1f..56f0f2a8d 100644 --- a/src/braket/circuits/circuit.py +++ b/src/braket/circuits/circuit.py @@ -1251,7 +1251,7 @@ def _validate_gate_calibrations_uniqueness( frames: dict[str, Frame], waveforms: dict[str, Waveform], ) -> None: - for _key, calibration in gate_definitions.items(): + for calibration in gate_definitions.values(): for frame in calibration._frames.values(): _validate_uniqueness(frames, frame) frames[frame.id] = frame @@ -1264,7 +1264,7 @@ def _generate_frame_wf_defcal_declarations( ) -> str | None: """Generates the header where frames, waveforms and defcals are declared. - It also adds any FreeParameter of the calibrations to the circuit parameter set. + It also adds any FreeParameter that is not a gate argument to the circuit parameter set. Args: gate_definitions (dict[tuple[Gate, QubitSet], PulseSequence] | None): The @@ -1289,21 +1289,12 @@ def _generate_frame_wf_defcal_declarations( for key, calibration in gate_definitions.items(): gate, qubits = key - - # Ignoring parametric gates - # Corresponding defcals with fixed arguments have been added - # in _get_frames_waveforms_from_instrs - if isinstance(gate, Parameterizable) and any( - not isinstance(parameter, (float, int, complex)) - for parameter in gate.parameters - ): - continue - gate_name = gate._qasm_name arguments = gate.parameters if isinstance(gate, Parameterizable) else [] for param in calibration.parameters: - self._parameters.add(param) + if param not in arguments: + self._parameters.add(param) arguments = [ param._to_oqpy_expression() if isinstance(param, FreeParameter) else param for param in arguments @@ -1334,80 +1325,8 @@ def _get_frames_waveforms_from_instrs( for waveform in instruction.operator.pulse_sequence._waveforms.values(): _validate_uniqueness(waveforms, waveform) waveforms[waveform.id] = waveform - # this will change with full parametric calibration support - elif isinstance(instruction.operator, Parameterizable): - fixed_argument_calibrations = self._add_fixed_argument_calibrations( - gate_definitions, instruction - ) - gate_definitions.update(fixed_argument_calibrations) return frames, waveforms - def _add_fixed_argument_calibrations( - self, - gate_definitions: dict[tuple[Gate, QubitSet], PulseSequence], - instruction: Instruction, - ) -> dict[tuple[Gate, QubitSet], PulseSequence]: - """Adds calibrations with arguments set to the instruction parameter values - - Given the collection of parameters in instruction.operator, this function looks for matching - parametric calibrations that have free parameters. If such a calibration is found and the - number N of its free parameters equals the number of instruction parameters, we can bind - the arguments of the calibration and add it to the calibration dictionary. - - If N is smaller, it is probably impossible to assign the instruction parameter values to the - corresponding calibration parameters so we raise an error. - If N=0, we ignore it as it will not be removed by _generate_frame_wf_defcal_declarations. - - Args: - gate_definitions (dict[tuple[Gate, QubitSet], PulseSequence]): a dictionary of - calibrations - instruction (Instruction): a Circuit instruction - - Returns: - dict[tuple[Gate, QubitSet], PulseSequence]: additional calibrations - - Raises: - NotImplementedError: in two cases: (i) if the instruction contains unbound parameters - and the calibration dictionary contains a parametric calibration applicable to this - instructions; (ii) if the calibration is defined with a partial number of unbound - parameters. - """ - additional_calibrations = {} - for key, calibration in gate_definitions.items(): - gate = key[0] - target = key[1] - if target != instruction.target: - continue - if isinstance(gate, type(instruction.operator)) and len( - instruction.operator.parameters - ) == len(gate.parameters): - free_parameter_number = sum( - [isinstance(p, FreeParameterExpression) for p in gate.parameters] - ) - if free_parameter_number == 0: - continue - elif free_parameter_number < len(gate.parameters): - raise NotImplementedError( - "Calibrations with a partial number of fixed parameters are not supported." - ) - elif any( - isinstance(p, FreeParameterExpression) for p in instruction.operator.parameters - ): - raise NotImplementedError( - "Parametric calibrations cannot be attached with parametric circuits." - ) - bound_key = ( - type(instruction.operator)(*instruction.operator.parameters), - instruction.target, - ) - additional_calibrations[bound_key] = calibration( - **{ - p.name if isinstance(p, FreeParameterExpression) else p: v - for p, v in zip(gate.parameters, instruction.operator.parameters) - } - ) - return additional_calibrations - def to_unitary(self) -> np.ndarray: """Returns the unitary matrix representation of the entire circuit. diff --git a/test/unit_tests/braket/circuits/test_circuit.py b/test/unit_tests/braket/circuits/test_circuit.py index ecaad12bf..8179a5719 100644 --- a/test/unit_tests/braket/circuits/test_circuit.py +++ b/test/unit_tests/braket/circuits/test_circuit.py @@ -148,25 +148,6 @@ def pulse_sequence_2(predefined_frame_1): ) -@pytest.fixture -def pulse_sequence_3(predefined_frame_1): - return ( - PulseSequence() - .shift_phase( - predefined_frame_1, - FreeParameter("alpha"), - ) - .shift_phase( - predefined_frame_1, - FreeParameter("beta"), - ) - .play( - predefined_frame_1, - DragGaussianWaveform(length=3e-3, sigma=0.4, beta=0.2, id="drag_gauss_wf"), - ) - ) - - @pytest.fixture def gate_calibrations(pulse_sequence, pulse_sequence_2): calibration_key = (Gate.Z(), QubitSet([0, 1])) @@ -804,10 +785,16 @@ def test_circuit_to_ir_openqasm(circuit, serialization_properties, expected_ir): " set_frequency(predefined_frame_1, 6000000.0);", " play(predefined_frame_1, drag_gauss_wf);", "}", - "defcal rx(0.15) $0 {", + "defcal rx(float theta) $0 {", " set_frequency(predefined_frame_1, 6000000.0);", " play(predefined_frame_1, drag_gauss_wf);", "}", + "defcal ms(float alpha, float beta, float gamma) $0, $1 {", + " shift_phase(predefined_frame_1, alpha);", + " set_phase(predefined_frame_1, gamma);", + " shift_phase(predefined_frame_1, beta);", + " play(predefined_frame_1, drag_gauss_wf);", + "}", "rx(0.15) q[0];", "rx(0.3) q[1];", "b[0] = measure q[0];", @@ -833,10 +820,16 @@ def test_circuit_to_ir_openqasm(circuit, serialization_properties, expected_ir): " set_frequency(predefined_frame_1, 6000000.0);", " play(predefined_frame_1, drag_gauss_wf);", "}", - "defcal rx(0.15) $0 {", + "defcal rx(float theta) $0 {", " set_frequency(predefined_frame_1, 6000000.0);", " play(predefined_frame_1, drag_gauss_wf);", "}", + "defcal ms(float alpha, float beta, float gamma) $0, $1 {", + " shift_phase(predefined_frame_1, alpha);", + " set_phase(predefined_frame_1, gamma);", + " shift_phase(predefined_frame_1, beta);", + " play(predefined_frame_1, drag_gauss_wf);", + "}", "rx(0.15) $0;", "rx(0.3) $4;", "b[0] = measure $0;", @@ -864,10 +857,16 @@ def test_circuit_to_ir_openqasm(circuit, serialization_properties, expected_ir): " set_frequency(predefined_frame_1, 6000000.0);", " play(predefined_frame_1, drag_gauss_wf);", "}", - "defcal rx(0.15) $0 {", + "defcal rx(float theta) $0 {", " set_frequency(predefined_frame_1, 6000000.0);", " play(predefined_frame_1, drag_gauss_wf);", "}", + "defcal ms(float alpha, float beta, float gamma) $0, $1 {", + " shift_phase(predefined_frame_1, alpha);", + " set_phase(predefined_frame_1, gamma);", + " shift_phase(predefined_frame_1, beta);", + " play(predefined_frame_1, drag_gauss_wf);", + "}", "rx(0.15) $0;", "#pragma braket verbatim", "box{", @@ -899,10 +898,16 @@ def test_circuit_to_ir_openqasm(circuit, serialization_properties, expected_ir): " set_frequency(predefined_frame_1, 6000000.0);", " play(predefined_frame_1, drag_gauss_wf);", "}", - "defcal rx(0.15) $0 {", + "defcal rx(float theta) $0 {", " set_frequency(predefined_frame_1, 6000000.0);", " play(predefined_frame_1, drag_gauss_wf);", "}", + "defcal ms(float alpha, float beta, float gamma) $0, $1 {", + " shift_phase(predefined_frame_1, alpha);", + " set_phase(predefined_frame_1, gamma);", + " shift_phase(predefined_frame_1, beta);", + " play(predefined_frame_1, drag_gauss_wf);", + "}", "rx(0.15) q[0];", "rx(0.3) q[4];", "#pragma braket noise bit_flip(0.2) q[3]", @@ -930,10 +935,16 @@ def test_circuit_to_ir_openqasm(circuit, serialization_properties, expected_ir): " set_frequency(predefined_frame_1, 6000000.0);", " play(predefined_frame_1, drag_gauss_wf);", "}", - "defcal rx(0.15) $0 {", + "defcal rx(float theta) $0 {", " set_frequency(predefined_frame_1, 6000000.0);", " play(predefined_frame_1, drag_gauss_wf);", "}", + "defcal ms(float alpha, float beta, float gamma) $0, $1 {", + " shift_phase(predefined_frame_1, alpha);", + " set_phase(predefined_frame_1, gamma);", + " shift_phase(predefined_frame_1, beta);", + " play(predefined_frame_1, drag_gauss_wf);", + "}", "rx(0.15) q[0];", "rx(theta) q[1];", "b[0] = measure q[0];", @@ -963,10 +974,16 @@ def test_circuit_to_ir_openqasm(circuit, serialization_properties, expected_ir): " set_frequency(predefined_frame_1, 6000000.0);", " play(predefined_frame_1, drag_gauss_wf);", "}", - "defcal rx(0.15) $0 {", + "defcal rx(float theta) $0 {", " set_frequency(predefined_frame_1, 6000000.0);", " play(predefined_frame_1, drag_gauss_wf);", "}", + "defcal ms(float alpha, float beta, float gamma) $0, $1 {", + " shift_phase(predefined_frame_1, alpha);", + " set_phase(predefined_frame_1, gamma);", + " shift_phase(predefined_frame_1, beta);", + " play(predefined_frame_1, drag_gauss_wf);", + "}", "negctrl @ rx(0.15) q[2], q[0];", "ctrl(2) @ rx(0.3) q[2], q[3], q[1];", "ctrl(2) @ cnot q[2], q[3], q[4], q[0];", @@ -997,6 +1014,16 @@ def test_circuit_to_ir_openqasm(circuit, serialization_properties, expected_ir): " set_frequency(predefined_frame_1, 6000000.0);", " play(predefined_frame_1, drag_gauss_wf);", "}", + "defcal rx(float theta) $0 {", + " set_frequency(predefined_frame_1, 6000000.0);", + " play(predefined_frame_1, drag_gauss_wf);", + "}", + "defcal ms(float alpha, float beta, float gamma) $0, $1 {", + " shift_phase(predefined_frame_1, alpha);", + " set_phase(predefined_frame_1, gamma);", + " shift_phase(predefined_frame_1, beta);", + " play(predefined_frame_1, drag_gauss_wf);", + "}", "cnot q[0], q[1];", "cnot q[3], q[2];", "ctrl @ cnot q[5], q[6], q[4];", @@ -1029,10 +1056,14 @@ def test_circuit_to_ir_openqasm(circuit, serialization_properties, expected_ir): " set_frequency(predefined_frame_1, 6000000.0);", " play(predefined_frame_1, drag_gauss_wf);", "}", - "defcal ms(-0.1, -0.2, -0.3) $0, $1 {", - " shift_phase(predefined_frame_1, -0.1);", - " set_phase(predefined_frame_1, -0.3);", - " shift_phase(predefined_frame_1, -0.2);", + "defcal rx(float theta) $0 {", + " set_frequency(predefined_frame_1, 6000000.0);", + " play(predefined_frame_1, drag_gauss_wf);", + "}", + "defcal ms(float alpha, float beta, float gamma) $0, $1 {", + " shift_phase(predefined_frame_1, alpha);", + " set_phase(predefined_frame_1, gamma);", + " shift_phase(predefined_frame_1, beta);", " play(predefined_frame_1, drag_gauss_wf);", "}", "inv @ pow(2.5) @ h q[0];", @@ -1074,52 +1105,127 @@ def test_circuit_to_ir_openqasm_with_gate_calibrations( @pytest.mark.parametrize( - "circuit, calibration_key, expected_ir", + "circuit, calibration_key, input_variables, expected_ir, input_values", [ ( Circuit().rx(0, 0.2), (Gate.Rx(FreeParameter("alpha")), QubitSet(0)), - OpenQasmProgram( - source="\n".join( - [ - "OPENQASM 3.0;", - "input float beta;", - "bit[1] b;", - "qubit[1] q;", - "cal {", - " waveform drag_gauss_wf = drag_gaussian(3.0ms," - " 400.0ms, 0.2, 1, false);", - "}", - "defcal rx(0.2) $0 {", - " shift_phase(predefined_frame_1, 0.2);", - " shift_phase(predefined_frame_1, beta);", - " play(predefined_frame_1, drag_gauss_wf);", - "}", - "rx(0.2) q[0];", - "b[0] = measure q[0];", - ] - ), - inputs={}, + {"beta", "gamma"}, + [ + "bit[1] b;", + "qubit[1] q;", + "cal {", + " waveform drag_gauss_wf = drag_gaussian(3.0ms," " 400.0ms, 0.2, 1, false);", + "}", + "defcal rx(float alpha) $0 {", + " shift_phase(predefined_frame_1, alpha);", + " set_phase(predefined_frame_1, gamma);", + " shift_phase(predefined_frame_1, beta);", + " play(predefined_frame_1, drag_gauss_wf);", + "}", + "rx(0.2) q[0];", + "b[0] = measure q[0];", + ], + {}, + ), + ( + Circuit().rx(0, FreeParameter("gamma")), + (Gate.Rx(FreeParameter("alpha")), QubitSet(0)), + {"beta", "gamma"}, + [ + "bit[1] b;", + "qubit[1] q;", + "cal {", + " waveform drag_gauss_wf = drag_gaussian(3.0ms," " 400.0ms, 0.2, 1, false);", + "}", + "defcal rx(float alpha) $0 {", + " shift_phase(predefined_frame_1, alpha);", + " set_phase(predefined_frame_1, gamma);", + " shift_phase(predefined_frame_1, beta);", + " play(predefined_frame_1, drag_gauss_wf);", + "}", + "rx(gamma) q[0];", + "b[0] = measure q[0];", + ], + {}, + ), + ( + Circuit().ms(0, 1, 0.1, 0.2, 0.3), + ( + Gate.MS(FreeParameter("alpha"), FreeParameter("beta"), FreeParameter("gamma")), + QubitSet([0, 1]), ), + {}, + [ + "bit[2] b;", + "qubit[2] q;", + "cal {", + " waveform drag_gauss_wf = drag_gaussian(3.0ms," " 400.0ms, 0.2, 1, false);", + "}", + "defcal ms(float alpha, float beta, float gamma) $0, $1 {", + " shift_phase(predefined_frame_1, alpha);", + " set_phase(predefined_frame_1, gamma);", + " shift_phase(predefined_frame_1, beta);", + " play(predefined_frame_1, drag_gauss_wf);", + "}", + "ms(0.1, 0.2, 0.3) q[0], q[1];", + "b[0] = measure q[0];", + "b[1] = measure q[1];", + ], + {}, + ), + ( + Circuit().ms(0, 1, 0.1, 0.2, FreeParameter("gamma")), + ( + Gate.MS(FreeParameter("alpha"), FreeParameter("beta"), FreeParameter("gamma")), + QubitSet([0, 1]), + ), + {"gamma"}, + [ + "bit[2] b;", + "qubit[2] q;", + "cal {", + " waveform drag_gauss_wf = drag_gaussian(3.0ms," " 400.0ms, 0.2, 1, false);", + "}", + "defcal ms(float alpha, float beta, float gamma) $0, $1 {", + " shift_phase(predefined_frame_1, alpha);", + " set_phase(predefined_frame_1, gamma);", + " shift_phase(predefined_frame_1, beta);", + " play(predefined_frame_1, drag_gauss_wf);", + "}", + "ms(0.1, 0.2, gamma) q[0], q[1];", + "b[0] = measure q[0];", + "b[1] = measure q[1];", + ], + {}, ), ], ) -def test_circuit_with_parametric_defcal(circuit, calibration_key, expected_ir, pulse_sequence_3): +def test_parametric_circuit_with_parametric_defcal( + circuit, calibration_key, input_variables, expected_ir, input_values, pulse_sequence_2 +): serialization_properties = OpenQASMSerializationProperties(QubitReferenceType.VIRTUAL) gate_calibrations = GateCalibrations( { - calibration_key: pulse_sequence_3, + calibration_key: pulse_sequence_2, } ) - assert ( - circuit.to_ir( - ir_type=IRType.OPENQASM, - serialization_properties=serialization_properties, - gate_definitions=gate_calibrations.pulse_sequences, - ) - == expected_ir + assert circuit.to_ir( + ir_type=IRType.OPENQASM, + serialization_properties=serialization_properties, + gate_definitions=gate_calibrations.pulse_sequences, + ) == OpenQasmProgram( + source="\n".join( + [ + "OPENQASM 3.0;", + *[f"input float {parameter};" for parameter in circuit.parameters], + *expected_ir, + ] + ), + inputs=input_values, ) + assert circuit.parameters == {FreeParameter(name) for name in input_variables} def test_parametric_circuit_with_fixed_argument_defcal(pulse_sequence): @@ -1241,10 +1347,10 @@ def foo( "cal {", " waveform drag_gauss_wf = drag_gaussian(3.0ms, 400.0ms, 0.2, 1, false);", "}", - "defcal foo(-0.2) $0 {", + "defcal foo(float beta) $0 {", " shift_phase(predefined_frame_1, -0.1);", " set_phase(predefined_frame_1, -0.3);", - " shift_phase(predefined_frame_1, -0.2);", + " shift_phase(predefined_frame_1, beta);", " play(predefined_frame_1, drag_gauss_wf);", "}", "foo(-0.2) q[0];",