Skip to content

Commit 0261648

Browse files
authoredDec 27, 2024··
fix: Improvements to precompilation and adjustments to new Quasar api (#61)
* change: Updates for latest Quasar parsing * fix: More cleanup, improved precompilation
1 parent 42a9214 commit 0261648

11 files changed

+63
-45
lines changed
 

‎Project.toml

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "BraketSimulator"
22
uuid = "76d27892-9a0b-406c-98e4-7c178e9b3dff"
33
authors = ["Katharine Hyatt <hyatkath@amazon.com> and contributors"]
4-
version = "0.0.6"
4+
version = "0.0.7"
55

66
[deps]
77
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
@@ -39,7 +39,7 @@ Logging = "1.6"
3939
OrderedCollections = "=1.6.3"
4040
PrecompileTools = "=1.2.1"
4141
PythonCall = "=0.9.23"
42-
Quasar = "0.0.1"
42+
Quasar = "0.0.2"
4343
Random = "1.6"
4444
StaticArrays = "1.9"
4545
StatsBase = "0.34"

‎src/BraketSimulator.jl

+19-3
Original file line numberDiff line numberDiff line change
@@ -78,9 +78,11 @@ include("noise_kernels.jl")
7878

7979
function __init__()
8080
Quasar.builtin_gates[] = builtin_gates
81+
Quasar.parse_pragma[] = parse_pragma
82+
Quasar.visit_pragma[] = visit_pragma
8183
end
8284

83-
const LOG2_CHUNK_SIZE = 10
85+
const LOG2_CHUNK_SIZE = 16
8486
const CHUNK_SIZE = 2^LOG2_CHUNK_SIZE
8587

8688
function _index_to_endian_bits(ix::Int, qubit_count::Int)
@@ -203,7 +205,7 @@ basis rotation instructions if running with non-zero shots. Return the `Program`
203205
parsing and the qubit count of the circuit.
204206
"""
205207
function _prepare_program(circuit_ir::OpenQasmProgram, inputs::Dict{String, <:Any}, shots::Int)
206-
ir_inputs = isnothing(circuit_ir.inputs) ? Dict{String, Float64}() : circuit_ir.inputs
208+
ir_inputs = isnothing(circuit_ir.inputs) ? Dict{String, Float64}() : circuit_ir.inputs
207209
merged_inputs = merge(ir_inputs, inputs)
208210
src = circuit_ir.source::String
209211
circuit = to_circuit(src, merged_inputs)
@@ -704,11 +706,14 @@ include("dm_simulator.jl")
704706
bit[3] b;
705707
qubit[3] q;
706708
rx(0.1) q[0];
709+
rx(1) q[0];
707710
prx(0.1, 0.2) q[0];
708711
x q[0];
709712
ry(0.1) q[0];
713+
ry(1) q[0];
710714
y q[0];
711715
rz(0.1) q[0];
716+
rz(1) q[0];
712717
z q[0];
713718
h q[0];
714719
i q[0];
@@ -758,6 +763,11 @@ include("dm_simulator.jl")
758763
#pragma braket result density_matrix
759764
#pragma braket result probability
760765
#pragma braket result expectation x(q[0])
766+
#pragma braket result expectation x(q[0]) @ x(q[1])
767+
#pragma braket result expectation z(q[0]) @ z(q[1])
768+
#pragma braket result expectation y(q[0]) @ y(q[1])
769+
#pragma braket result expectation h(q[0]) @ h(q[1])
770+
#pragma braket result expectation i(q[0]) @ i(q[1])
761771
#pragma braket result variance x(q[0]) @ y(q[1])
762772
"""
763773
dm_exact_results_qasm = """
@@ -772,16 +782,22 @@ include("dm_simulator.jl")
772782
"""
773783
shots_results_qasm = """
774784
OPENQASM 3.0;
775-
qubit[2] q;
785+
qubit[10] q;
776786
h q;
777787
#pragma braket result probability
778788
#pragma braket result expectation x(q[0])
779789
#pragma braket result variance x(q[0]) @ y(q[1])
780790
#pragma braket result sample x(q[0]) @ y(q[1])
791+
#pragma braket result expectation z(q[2]) @ z(q[3])
792+
#pragma braket result expectation x(q[4]) @ x(q[5])
793+
#pragma braket result expectation y(q[6]) @ y(q[7])
794+
#pragma braket result expectation h(q[8]) @ h(q[9])
781795
"""
782796
@compile_workload begin
783797
using BraketSimulator, Quasar
784798
Quasar.builtin_gates[] = BraketSimulator.builtin_gates
799+
Quasar.parse_pragma[] = BraketSimulator.parse_pragma
800+
Quasar.visit_pragma[] = BraketSimulator.visit_pragma
785801
simulator = StateVectorSimulator(5, 0)
786802
oq3_program = OpenQasmProgram(braketSchemaHeader("braket.ir.openqasm.program", "1"), custom_qasm, nothing)
787803
simulate(simulator, oq3_program, 100)

‎src/circuit.jl

+5-4
Original file line numberDiff line numberDiff line change
@@ -169,11 +169,12 @@ function basis_rotation_instructions!(c::Circuit)
169169
c.basis_rotation_instructions = reduce(vcat, _observable_to_instruction(all_qubit_observable, target) for target in qubits(c))
170170
return c
171171
end
172-
unsorted = collect(Set(values(c.qubit_observable_target_mapping)))
173-
target_lists = sort(unsorted)
172+
mapping_vals = collect(values(c.qubit_observable_target_mapping))
173+
target_lists = unique(mapping_vals)
174174
for target_list in target_lists
175-
observable = c.qubit_observable_mapping[first(target_list)]
176-
append!(basis_rotation_instructions, _observable_to_instruction(observable, target_list))
175+
observable = c.qubit_observable_mapping[first(target_list)]
176+
observable_ix = _observable_to_instruction(observable, target_list)
177+
append!(basis_rotation_instructions, observable_ix)
177178
end
178179
c.basis_rotation_instructions = basis_rotation_instructions
179180
return c

‎src/gate_kernels.jl

+13-15
Original file line numberDiff line numberDiff line change
@@ -96,13 +96,14 @@ function matrix_rep(g::PRx)
9696
end
9797

9898
for G in (:Rx, :Ry, :Rz, :PhaseShift)
99-
@eval function matrix_rep(g::$G)
99+
@eval function matrix_rep(g::$G)::SMatrix{2,2,ComplexF64}
100100
n = g.pow_exponent::Float64
101101
θ = @inbounds g.angle[1]
102-
iszero(n) && return matrix_rep_raw(I())::SMatrix{2,2,ComplexF64}
103-
isone(n) && return matrix_rep_raw(g, θ)::SMatrix{2,2,ComplexF64}
104-
isinteger(n) && return matrix_rep_raw(g, θ*n)::SMatrix{2,2,ComplexF64}
105-
return SMatrix{2,2,ComplexF64}(matrix_rep_raw(g, θ) ^ n)
102+
one_mat = matrix_rep_raw(g, θ)
103+
iszero(n) && return matrix_rep_raw(I())
104+
isone(n) && return one_mat
105+
isinteger(n) && return matrix_rep_raw(g, θ*n)
106+
return SMatrix{2,2,ComplexF64}(one_mat ^ n)
106107
end
107108
end
108109

@@ -211,9 +212,9 @@ matrix_rep_raw(g::PRx) = SMatrix{2,2}(
211212
-im*exp(im*g.angle[2])*sin(g.angle[1]/2.0) cos(g.angle[1] / 2.0)
212213
],
213214
)
214-
matrix_rep_raw(g::Rz, ϕ) = (θ = ϕ/2.0; return SMatrix{2,2}(exp(-im*θ), 0.0, 0.0, exp(im*θ)))
215-
matrix_rep_raw(g::Rx, ϕ) = ((sθ, cθ) = sincos/2.0); return SMatrix{2,2}(cθ, -im*sθ, -im*sθ, cθ))
216-
matrix_rep_raw(g::Ry, ϕ) = ((sθ, cθ) = sincos/2.0); return SMatrix{2,2}(complex(cθ), complex(sθ), -complex(sθ), complex(cθ)))
215+
matrix_rep_raw(g::Rz, ϕ)::SMatrix{2,2,ComplexF64} = ((sθ, cθ) = sincos(ϕ/2.0); return SMatrix{2,2}(- im*, 0.0, 0.0, + im*))
216+
matrix_rep_raw(g::Rx, ϕ)::SMatrix{2,2,ComplexF64} = ((sθ, cθ) = sincos/2.0); return SMatrix{2,2}(cθ, -im*sθ, -im*sθ, cθ))
217+
matrix_rep_raw(g::Ry, ϕ)::SMatrix{2,2,ComplexF64} = ((sθ, cθ) = sincos/2.0); return SMatrix{2,2}(complex(cθ), complex(sθ), -complex(sθ), complex(cθ)))
217218
matrix_rep_raw(g::GPi) =
218219
SMatrix{2,2}(complex([0 exp(-im * g.angle[1]); exp(im * g.angle[1]) 0]))
219220

@@ -311,14 +312,11 @@ function apply_gate!(
311312
g_00, g_10, g_01, g_11 = g_matrix
312313
Threads.@threads for chunk_index = 0:n_chunks-1
313314
# my_amps is the group of amplitude generators which this `Task` will process
314-
my_amps = if n_chunks > 1
315-
chunk_index*CHUNK_SIZE:((chunk_index+1)*CHUNK_SIZE-1)
316-
else
317-
0:n_tasks-1
318-
end
319-
lower_ix = pad_bit(my_amps[1], endian_qubit) + 1
315+
first_amp = n_chunks > 1 ? chunk_index*CHUNK_SIZE : 0
316+
amp_block = n_chunks > 1 ? CHUNK_SIZE : n_tasks
317+
lower_ix = pad_bit(first_amp, endian_qubit) + 1
320318
higher_ix = lower_ix + flipper
321-
for task_amp = 0:length(my_amps)-1
319+
for task_amp = 0:amp_block-1
322320
if is_small_target && div(task_amp, flipper) > 0 && mod(task_amp, flipper) == 0
323321
lower_ix = higher_ix
324322
higher_ix = lower_ix + flipper

‎src/gates.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ mutable struct Unitary <: Gate
112112
Unitary(matrix::Matrix{<:Number}, pow_exponent=1.0) = new(ComplexF64.(matrix), Float64(pow_exponent))
113113
end
114114
Base.:(==)(u1::Unitary, u2::Unitary) = u1.matrix == u2.matrix && u1.pow_exponent == u2.pow_exponent
115-
qubit_count(g::Unitary) = convert(Int, log2(size(g.matrix, 1)))
115+
qubit_count(g::Unitary) = qubit_count(g.matrix)
116116
StructTypes.constructfrom(::Type{Unitary}, nt::Quasar.CircuitInstruction) = Unitary(only(nt.arguments), nt.exponent)
117117

118118
Parametrizable(g::AngledGate) = Parametrized()

‎src/observables.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ HermitianObservable(v::Vector{Vector{Vector{T}}}) where {T<:Number} = HermitianO
7171
Base.copy(o::HermitianObservable) = HermitianObservable(copy(o.matrix))
7272
StructTypes.lower(x::HermitianObservable) = Union{String, Vector{Vector{Vector{Float64}}}}[complex_matrix_to_ir(ComplexF64.(x.matrix))]
7373
Base.:(==)(h1::HermitianObservable, h2::HermitianObservable) = (size(h1.matrix) == size(h2.matrix) && h1.matrix h2.matrix)
74-
qubit_count(o::HermitianObservable) = convert(Int, log2(size(o.matrix, 1)))
74+
qubit_count(o::HermitianObservable) = qubit_count(o.matrix)
7575
LinearAlgebra.eigvals(o::HermitianObservable) = eigvals(Hermitian(o.matrix))
7676
unscaled(o::HermitianObservable) = o
7777
Base.:(*)(o::HermitianObservable, n::Real) = HermitianObservable(Float64(n) .* o.matrix)

‎src/operators.jl

+3-1
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ StructTypes.StructType(::Type{QuantumOperator}) = StructTypes.AbstractType()
2121
StructTypes.subtypes(::Type{QuantumOperator}) = (h=H, i=I, x=X, y=Y, z=Z, s=S, si=Si, t=T, ti=Ti, v=V, vi=Vi, cnot=CNot, swap=Swap, iswap=ISwap, cv=CV, cy=CY, cz=CZ, ecr=ECR, ccnot=CCNot, cswap=CSwap, unitary=Unitary, rx=Rx, ry=Ry, rz=Rz, phaseshift=PhaseShift, pswap=PSwap, xy=XY, cphaseshift=CPhaseShift, cphaseshift00=CPhaseShift00, cphaseshift01=CPhaseShift01, cphaseshift10=CPhaseShift10, xx=XX, yy=YY, zz=ZZ, gpi=GPi, gpi2=GPi2, ms=MS, prx=PRx, u=U, gphase=GPhase, kraus=Kraus, bit_flip=BitFlip, phase_flip=PhaseFlip, pauli_channel=PauliChannel, amplitude_damping=AmplitudeDamping, phase_damping=PhaseDamping, depolarizing=Depolarizing, two_qubit_dephasing=TwoQubitDephasing, two_qubit_depolarizing=TwoQubitDepolarizing, generalized_amplitude_damping=GeneralizedAmplitudeDamping, multi_qubit_pauli_channel=MultiQubitPauliChannel, measure=Measure, reset=Reset, barrier=Barrier, delay=Delay)
2222
parameters(::QuantumOperator) = FreeParameter[]
2323

24+
qubit_count(o::Matrix) = Int(log2(size(o, 1)))
25+
2426
struct PauliEigenvalues{N}
2527
coeff::Float64
2628
PauliEigenvalues{N}(coeff::Float64=1.0) where {N} = new(coeff)
@@ -99,4 +101,4 @@ for T in (:Barrier, :Reset, :Delay, :Measure)
99101
qubit_count(o::$T) = qubit_count($T)
100102
Parametrizable(::$T) = NonParametrized()
101103
end
102-
end
104+
end

‎src/pragmas.jl

+11-12
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
Quasar.qubit_count(o::String) = length(o)
2-
Quasar.qubit_count(o::Matrix) = Int(log2(size(o, 1)))
1+
qubit_count(o::String) = length(o)
32

43
function _observable_targets_error(observable::Matrix{ComplexF64}, targets)
54
mat = Vector{Vector{Vector{Float64}}}(undef, size(observable, 1))
@@ -14,7 +13,7 @@ end
1413
_observable_targets_error(::String, targets) = throw(Quasar.QasmVisitorError("Standard observable target must be exactly 1 qubit.", "ValueError"))
1514

1615
function _check_observable_targets(observable::Union{Matrix{ComplexF64}, String}, targets)
17-
qc = Quasar.qubit_count(observable)
16+
qc = qubit_count(observable)
1817
qc == 1 && (isempty(targets) || length(targets) == 1) && return
1918
qc == length(targets) && return
2019
_observable_targets_error(observable, targets)
@@ -42,7 +41,7 @@ function visit_observable(v, expr)
4241
end
4342
end
4443

45-
function Quasar.visit_pragma(v, program_expr)
44+
function visit_pragma(v, program_expr)
4645
pragma_type::Symbol = program_expr.args[1]
4746
if pragma_type == :result
4847
result_type = program_expr.args[2]
@@ -95,12 +94,12 @@ function Quasar.visit_pragma(v, program_expr)
9594
end
9695

9796
function parse_matrix(tokens::Vector{Tuple{Int64, Int32, Quasar.Token}}, stack, start, qasm)
98-
inner = Quasar.extract_braced_block(tokens, stack, start, qasm)
97+
inner = Quasar.extract_expression(tokens, Quasar.lbracket, Quasar.rbracket, stack, start, qasm)
9998
n_rows = count(triplet->triplet[end] == Quasar.lbracket, inner)
10099
matrix = Matrix{Quasar.QasmExpression}(undef, n_rows, n_rows)
101100
row = 1
102101
while !isempty(inner)
103-
row_tokens = Quasar.extract_braced_block(inner, stack, start, qasm)
102+
row_tokens = Quasar.extract_expression(inner, Quasar.lbracket, Quasar.rbracket, stack, start, qasm)
104103
push!(row_tokens, (-1, Int32(-1), Quasar.semicolon))
105104
col = 1
106105
while !isempty(row_tokens)
@@ -126,7 +125,7 @@ function parse_pragma_observables(tokens::Vector{Tuple{Int64, Int32, Quasar.Toke
126125
observable_token = popfirst!(tokens)
127126
observable_id = Quasar.parse_identifier(observable_token, qasm)
128127
if observable_id.args[1] == "hermitian"
129-
matrix_tokens = Quasar.extract_parensed(tokens, stack, start, qasm)
128+
matrix_tokens = Quasar.extract_expression(tokens, Quasar.lparen, Quasar.rparen, stack, start, qasm)
130129
# next token is targets
131130
h_mat = parse_matrix(matrix_tokens, stack, start, qasm)
132131
# next token is targets
@@ -146,7 +145,7 @@ function parse_pragma_observables(tokens::Vector{Tuple{Int64, Int32, Quasar.Toke
146145
break
147146
else
148147
if !isempty(tokens) && first(tokens)[end] == Quasar.lparen
149-
arg_tokens = Quasar.extract_parensed(tokens, stack, start, qasm)
148+
arg_tokens = Quasar.extract_expression(tokens, Quasar.lparen, Quasar.rparen, stack, start, qasm)
150149
push!(arg_tokens, (-1, Int32(-1), Quasar.semicolon))
151150
target_expr = Quasar.parse_expression(arg_tokens, stack, start, qasm)
152151
push!(obs_targets, target_expr)
@@ -175,7 +174,7 @@ function parse_pragma_targets(tokens::Vector{Tuple{Int64, Int32, Quasar.Token}},
175174
end
176175

177176

178-
function Quasar.parse_pragma(tokens, stack, start, qasm)
177+
function parse_pragma(tokens, stack, start, qasm)
179178
prefix = popfirst!(tokens)
180179
prefix_id = Quasar.parse_identifier(prefix, qasm)
181180
prefix_id.args[1] == "braket" || throw(Quasar.QasmParseError("pragma expression must begin with `#pragma braket`", stack, start, qasm))
@@ -204,7 +203,7 @@ function Quasar.parse_pragma(tokens, stack, start, qasm)
204203
end
205204
elseif pragma_type == "unitary"
206205
push!(expr, :unitary)
207-
matrix_tokens = Quasar.extract_parensed(tokens, stack, start, qasm)
206+
matrix_tokens = Quasar.extract_expression(tokens, Quasar.lparen, Quasar.rparen, stack, start, qasm)
208207
unitary_matrix = parse_matrix(matrix_tokens, stack, start, qasm)
209208
push!(expr, unitary_matrix)
210209
target_expr = parse_pragma_targets(tokens, stack, start, qasm)
@@ -213,8 +212,8 @@ function Quasar.parse_pragma(tokens, stack, start, qasm)
213212
push!(expr, :noise)
214213
noise_type = Quasar.parse_identifier(popfirst!(tokens), qasm)::Quasar.QasmExpression
215214
if noise_type.args[1] == "kraus"
216-
matrix_tokens = Quasar.extract_parensed(tokens, stack, start, qasm)
217-
all(triplet->triplet[end] == Quasar.lbracket, matrix_tokens[1:3]) && (matrix_tokens = Quasar.extract_braced_block(matrix_tokens, stack, start, qasm))
215+
matrix_tokens = Quasar.extract_expression(tokens, Quasar.lparen, Quasar.rparen, stack, start, qasm)
216+
all(triplet->triplet[end] == Quasar.lbracket, matrix_tokens[1:3]) && (matrix_tokens = Quasar.extract_expression(matrix_tokens, Quasar.lbracket, Quasar.rbracket, stack, start, qasm))
218217
mats = Matrix{Quasar.QasmExpression}[]
219218
while !isempty(matrix_tokens)
220219
push!(mats, parse_matrix(matrix_tokens, stack, start, qasm))

‎src/validation.jl

+5-2
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,7 @@ function _check_observable(observable_map, observable, qubits)
164164
observable_map[qubits] = observable
165165
return observable_map
166166
end
167+
_check_observable(observable_map, observable, qubits::Int) = _check_observable(observable_map, observable, [qubits])
167168

168169
function _combine_obs_and_targets(observable::Observables.HermitianObservable, result_targets::Vector{Int})
169170
obs_qc = qubit_count(observable)
@@ -175,12 +176,14 @@ _combine_obs_and_targets(observable::Observables.TensorProduct, result_targets::
175176
_combine_obs_and_targets(observable, result_targets::Vector{Int}) = length(result_targets) == 1 ? [(observable, result_targets)] : [(copy(observable), t) for t in result_targets]
176177

177178
function _verify_openqasm_shots_observables(circuit::Circuit, n_qubits::Int)
178-
observable_map = Dict()
179+
observable_map = LittleDict{Vector{Int}, Observables.Observable}()
179180
for result in filter(rt->rt isa ObservableResult, circuit.result_types)
180181
result.observable isa Observables.I && continue
181182
result_targets = isempty(result.targets) ? collect(0:n_qubits-1) : collect(result.targets)
182183
for obs_and_target in _combine_obs_and_targets(result.observable, result_targets)
183-
observable_map = _check_observable(observable_map, obs_and_target...)
184+
obs = obs_and_target[1]
185+
targ = obs_and_target[2]
186+
observable_map = _check_observable(observable_map, obs, targ)
184187
end
185188
end
186189
return

‎test/runtests.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
using Test, Aqua, Documenter, BraketSimulator
22

3-
Aqua.test_all(BraketSimulator, ambiguities=false, piracies=false)
3+
Aqua.test_all(BraketSimulator, ambiguities=false)
44
Aqua.test_ambiguities(BraketSimulator)
55
dir_list = filter(x-> startswith(x, "test_") && endswith(x, ".jl"), readdir(@__DIR__))
66

‎test/test_openqasm3.jl

+2-3
Original file line numberDiff line numberDiff line change
@@ -762,9 +762,7 @@ get_tol(shots::Int) = return (
762762
qubit[3] q;
763763
i q;
764764
#pragma braket result expectation x(q[2])
765-
// # noqa: E501
766765
#pragma braket result expectation hermitian([[-6+0im, 2+1im, -3+0im, -5+2im], [2-1im, 0im, 2-1im, -5+4im], [-3+0im, 2+1im, 0im, -4+3im], [-5-2im, -5-4im, -4-3im, -6+0im]]) q[0:1]
767-
// # noqa: E501
768766
#pragma braket result expectation x(q[2]) @ hermitian([[-6+0im, 2+1im, -3+0im, -5+2im], [2-1im, 0im, 2-1im, -5+4im], [-3+0im, 2+1im, 0im, -4+3im], [-5-2im, -5-4im, -4-3im, -6+0im]]) q[0:1]
769767
"""
770768
circuit = BraketSimulator.to_circuit(qasm)
@@ -773,7 +771,8 @@ get_tol(shots::Int) = return (
773771
2-1im 0 2-1im -5+4im;
774772
-3 2+1im 0 -4+3im;
775773
-5-2im -5-4im -4-3im -6]
776-
h = BraketSimulator.Observables.HermitianObservable(arr)
774+
h = BraketSimulator.Observables.HermitianObservable(arr)
775+
@test circuit.result_types[2].observable.matrix == arr
777776
bris = vcat(BraketSimulator.diagonalizing_gates(h, [0, 1]), BraketSimulator.Instruction(BraketSimulator.H(), [2]))
778777
for (ix, bix) in zip(circuit.basis_rotation_instructions, bris)
779778
@test Matrix(BraketSimulator.matrix_rep(ix.operator)) adjoint(BraketSimulator.fix_endianness(Matrix(BraketSimulator.matrix_rep(bix.operator))))

0 commit comments

Comments
 (0)