Skip to content

Commit 7b23bee

Browse files
authored
docs: More docstrings and comments for gate kernels (#62)
1 parent 0261648 commit 7b23bee

File tree

2 files changed

+230
-12
lines changed

2 files changed

+230
-12
lines changed

docs/src/internals.md

+7
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,11 @@ BraketSimulator._combine_operations
1111
BraketSimulator._prepare_program
1212
BraketSimulator._get_measured_qubits
1313
BraketSimulator._compute_results
14+
BraketSimulator.flip_bit
15+
BraketSimulator.flip_bits
16+
BraketSimulator.pad_bit
17+
BraketSimulator.pad_bits
18+
BraketSimulator.matrix_rep
19+
BraketSimulator.endian_qubits
20+
BraketSimulator.get_amps_and_qubits
1421
```

src/gate_kernels.jl

+223-12
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,80 @@
1+
"""
2+
pad_bit(amp_index::Integer, bit::Integer)
3+
4+
Insert a `0` at location `bit` of `amp_index` (in its bits representation).
5+
The first valid value of `bit` is **zero**.
6+
7+
# Examples
8+
```jldoctest
9+
julia> amp_index = 10
10+
10
11+
12+
julia> digits(amp_index, base=2, pad=6)
13+
6-element Vector{Int64}:
14+
0
15+
1
16+
0
17+
1
18+
0
19+
0
20+
21+
julia> amp_index = BraketSimulator.pad_bit(amp_index, 2);
22+
23+
julia> digits(amp_index, base=2, pad=7)
24+
7-element Vector{Int64}:
25+
0
26+
1
27+
0
28+
0
29+
1
30+
0
31+
0
32+
```
33+
"""
134
@inline function pad_bit(amp_index::Ti, bit::Tj)::Ti where {Ti<:Integer,Tj<:Integer}
235
left = (amp_index >> bit) << bit
336
right = amp_index - left
437
return (left << one(Ti)) right
538
end
39+
"""
40+
pad_bits(amp_index::Integer, to_pad)
41+
42+
Insert a `0` in `amp_index` at each location `bit` in the collection `to_pad`.
43+
The first valid value of any `bit` is **zero**.
44+
45+
# Examples
46+
```jldoctest
47+
julia> amp_index = 10
48+
10
49+
50+
julia> digits(amp_index, base=2, pad=6)
51+
6-element Vector{Int64}:
52+
0
53+
1
54+
0
55+
1
56+
0
57+
0
58+
59+
julia> amp_index = BraketSimulator.pad_bits(amp_index, (2, 4));
60+
61+
julia> digits(amp_index, base=2, pad=8)
62+
8-element Vector{Int64}:
63+
0
64+
1
65+
0
66+
0
67+
0
68+
1
69+
0
70+
0
71+
```
72+
73+
!!! note
74+
75+
The indices in `pad_bits` aren't adjusted based on previous indices -- this can be seen in the above example,
76+
where the bit at index 4 is different **before** and **after** inserting a bit at index 2.
77+
"""
678
function pad_bits(ix::Ti, to_pad)::Ti where {Ti<:Integer}
779
padded_ix = ix
880
for bit in to_pad
@@ -11,18 +83,139 @@ function pad_bits(ix::Ti, to_pad)::Ti where {Ti<:Integer}
1183
return padded_ix
1284
end
1385

14-
@inline function flip_bit(amp_index::Ti, bit::Tj)::Ti where {Ti<:Integer,Tj<:Integer}
86+
"""
87+
flip_bit(amp_index::Integer, bit::Integer)
88+
89+
Flip the `bit`-th bit of `amp_index`, so that 0 becomes 1 and 1 becomes 0.
90+
The first valid value of `bit` is **zero**.
91+
92+
# Examples
93+
```jldoctest
94+
julia> amp_index = 10
95+
10
96+
97+
julia> digits(amp_index, base=2, pad=6)
98+
6-element Vector{Int64}:
99+
0
100+
1
101+
0
102+
1
103+
0
104+
0
105+
106+
julia> amp_index = BraketSimulator.flip_bit(amp_index, 1)
107+
8
108+
109+
julia> digits(amp_index, base=2, pad=6)
110+
6-element Vector{Int64}:
111+
0
112+
0
113+
0
114+
1
115+
0
116+
0
117+
```
118+
"""
119+
@inline function flip_bit(amp_index::Ti, bit::Tj)::Ti where {Ti<:Integer, Tj<:Integer}
15120
return amp_index (one(Ti) << bit)
16121
end
122+
123+
"""
124+
flip_bits(amp_index::Integer, to_flip)
125+
126+
Flip the `bit`-th bit of `amp_index` for every `bit` in `to_flip`,
127+
so that 0 becomes 1 and 1 becomes 0.
128+
The first valid value of `bit` is **zero**.
129+
130+
# Examples
131+
```jldoctest
132+
julia> amp_index = 10
133+
10
134+
135+
julia> digits(amp_index, base=2, pad=6)
136+
6-element Vector{Int64}:
137+
0
138+
1
139+
0
140+
1
141+
0
142+
0
143+
144+
julia> amp_index = BraketSimulator.flip_bits(amp_index, (1, 3, 2))
145+
4
146+
147+
julia> digits(amp_index, base=2, pad=6)
148+
6-element Vector{Int64}:
149+
0
150+
0
151+
1
152+
0
153+
0
154+
0
155+
```
156+
"""
17157
function flip_bits(ix::Ti, to_flip)::Ti where {Ti<:Integer}
18158
flipped_ix = ix
19159
for bit in to_flip
20160
flipped_ix = flip_bit(flipped_ix, bit)
21161
end
22162
return flipped_ix
23163
end
164+
"""
165+
endian_qubits(n_qubits::Int, qubit::Int)
166+
167+
Rotate the qubit index `qubit` to match what Braket expects with the
168+
correct endianness. This has to be done because Braket and Julia have different
169+
[endianness](https://en.wikipedia.org/wiki/Endianness).
170+
171+
!!! note
172+
173+
The first valid value for `qubit` is **zero**, since qubits are zero-indexed.
174+
175+
# Examples
176+
```jldoctest
177+
julia> qubit = 2
178+
2
179+
180+
julia> n_qubits = 5
181+
5
182+
183+
julia> BraketSimulator.endian_qubits(n_qubits, qubit)
184+
2
185+
186+
julia> qubit = 3
187+
3
188+
189+
julia> BraketSimulator.endian_qubits(n_qubits, qubit)
190+
1
191+
```
192+
"""
24193
@inline endian_qubits(n_qubits::Int, qubit::Int) = n_qubits - 1 - qubit
194+
"""
195+
endian_qubits(n_qubits::Int, qubits::Int...)
196+
197+
Rotate each qubit index in `qubits` to match what Braket expects with the
198+
correct endianness. This has to be done because Braket and Julia have different
199+
[endianness](https://en.wikipedia.org/wiki/Endianness).
200+
201+
!!! note
202+
203+
The first valid value for any element of `qubits` is **zero**,
204+
since qubits are zero-indexed.
205+
"""
25206
@inline endian_qubits(n_qubits::Int, qubits::Int...) = n_qubits .- 1 .- qubits
207+
"""
208+
get_amps_and_qubits(state_vec::AbstractStateVector, qubits::Int...)
209+
210+
Get the total number of amplitudes of `state_vec` (its length) and use this
211+
to apply [`endian_qubits`](@ref) to `qubits`. This is a convenience function
212+
to automate several common operations.
213+
214+
!!! note
215+
216+
The first valid value for any element of `qubits` is **zero**,
217+
since qubits are zero-indexed.
218+
"""
26219
@inline function get_amps_and_qubits(state_vec::AbstractStateVector, qubits::Int...)
27220
n_amps = length(state_vec)
28221
n_qubits = Int(log2(n_amps))
@@ -269,6 +462,12 @@ matrix_rep_raw(::ZZ, ϕ) = (θ = ϕ/2.0; return Diagonal(SVector{4}(exp(-im * θ
269462
# 1/√2 * (IX - XY)
270463
matrix_rep_raw(g::ECR) = SMatrix{4,4}(1/√2 * [0 1 0 im; 1 0 -im 0; 0 im 0 1; -im 0 1 0])
271464
matrix_rep_raw(g::Unitary) = g.matrix
465+
"""
466+
matrix_rep(g::Gate)
467+
468+
Convert `g` into its matrix form, applying its argument values and any
469+
exponent it is raised to.
470+
"""
272471
function matrix_rep(g::Gate)
273472
n = g.pow_exponent
274473
iszero(n) && matrix_rep_raw(I(), qubit_count(g))
@@ -311,12 +510,16 @@ function apply_gate!(
311510
is_small_target = flipper < CHUNK_SIZE
312511
g_00, g_10, g_01, g_11 = g_matrix
313512
Threads.@threads for chunk_index = 0:n_chunks-1
314-
# my_amps is the group of amplitude generators which this `Task` will process
513+
# first_amp is the leading index in the group
514+
# of amplitude generators which this `Task` will process
315515
first_amp = n_chunks > 1 ? chunk_index*CHUNK_SIZE : 0
516+
# amp_block is the total size of the block this `Task` will process
316517
amp_block = n_chunks > 1 ? CHUNK_SIZE : n_tasks
317518
lower_ix = pad_bit(first_amp, endian_qubit) + 1
318519
higher_ix = lower_ix + flipper
319520
for task_amp = 0:amp_block-1
521+
# this avoids hitting an index pair already "touched" earlier in the block
522+
# if 2 ^ qubit_index is smaller than the block size
320523
if is_small_target && div(task_amp, flipper) > 0 && mod(task_amp, flipper) == 0
321524
lower_ix = higher_ix
322525
higher_ix = lower_ix + flipper
@@ -372,12 +575,13 @@ function apply_gate!(
372575
return
373576
end
374577

578+
# single controlled single target unitaries like CZ, CV, CPhaseShift
375579
function apply_controlled_gate!(
376580
g_matrix::Union{SMatrix{2,2,T}, Diagonal{T,SVector{2,T}}, Matrix{T}},
377-
c_bit::Bool,
581+
c_bit::Bool, # the bit-value to control on (0 or 1)
378582
state_vec::AbstractStateVector{T},
379-
control::Int,
380-
target::Int,
583+
control::Int, # the qubit to control on
584+
target::Int, # the qubit to target
381585
) where {T<:Complex}
382586
n_amps, (endian_control, endian_target) =
383587
get_amps_and_qubits(state_vec, control, target)
@@ -400,13 +604,14 @@ function apply_controlled_gate!(
400604
end
401605
return
402606
end
607+
# single controlled two target unitaries like CSWAP
403608
function apply_controlled_gate!(
404609
g_matrix::Union{SMatrix{4, 4, T}, Diagonal{T, SVector{4, T}}, Matrix{T}},
405-
c_bit::Bool,
610+
c_bit::Bool, # the bit-value to control on (0 or 1)
406611
state_vec::AbstractStateVector{T},
407-
control::Int,
408-
target_1::Int,
409-
target_2::Int,
612+
control::Int, # the qubit to control on
613+
target_1::Int, # the first qubit to target
614+
target_2::Int, # the second qubit to target
410615
) where {T<:Complex}
411616
n_amps, (endian_control, endian_t1, endian_t2) = get_amps_and_qubits(state_vec, control, target_1, target_2)
412617
small_t = min(endian_control, endian_t1, endian_t2)
@@ -429,11 +634,11 @@ function apply_controlled_gate!(
429634
end
430635
return
431636
end
432-
# doubly controlled unitaries
637+
# doubly controlled single target unitaries like CCNot
433638
function apply_controlled_gate!(
434639
g_matrix::Union{SMatrix{2, 2, T}, Diagonal{T, SVector{2, T}}, Matrix{T}},
435-
c1_bit::Bool,
436-
c2_bit::Bool,
640+
c1_bit::Bool, # the bit-value to control on (0 or 1) for the first control qubit
641+
c2_bit::Bool, # the bit-value to control on (0 or 1) for the second control qubit
437642
state_vec::AbstractStateVector{T},
438643
control_1::Int,
439644
control_2::Int,
@@ -463,6 +668,10 @@ function apply_controlled_gate!(
463668
end
464669
return
465670
end
671+
# these are "intermediate" dispatch methods which turn a `Gate` into the appropriate
672+
# static matrix and dispatch to the appropriate kernel to apply it
673+
# the `:conj` versions are there for *density matrices*, and apply the conjugated
674+
# (but *not* transposed) version of the gate matrix.
466675
for (V, f) in ((true, :conj), (false, :identity))
467676
@eval begin
468677
apply_gate!(::Val{$V}, gate::Control{G, B}, state_vec::AbstractStateVector{T}, qubits::Int...) where {T<:Complex, G<:Gate, B} = apply_controlled_gate!(Val($V), Val(B), gate, gate.g ^ gate.pow_exponent, state_vec, gate.bitvals, qubits...)
@@ -539,6 +748,8 @@ function apply_gate!(
539748
apply_gate!(Diagonal(SVector{2^N, ComplexF64}(g_matrix)), state_vec, qubits...)
540749
end
541750

751+
# fallback method for arbitrary unitaries with `NQ` targets
752+
# such as a Unitary on 5 qubits
542753
function apply_gate!(
543754
g_matrix::Union{SMatrix{N, N, T}, Diagonal{T, SVector{N, T}}},
544755
state_vec::AbstractStateVector{T},

0 commit comments

Comments
 (0)