Flowfusion.jl is a Julia package for training and sampling from diffusion and flow matching models (and some things in between), across continuous, discrete, and manifold spaces, all in a single unified framework and interface.
The animated logo shows samples from a model trained to jointly transport a 2D point and an angular hue between two distributions. For the 2D point, the left side uses flow matching with deterministic trajectories, and the right side uses a Brownian bridge. On both sides, the angular hue is diffused via an angular Brownian bridge. The hue endpoints are antipodal, and you can see that paths in both angular directions are sampled.
- Flexible initial $X_0$ distribution
- Conditioning via masking
- States: continuous, discrete, and a wide variety of manifolds supported (via Manifolds.jl)
- Compound states supported (e.g. jointly sampling from both continuous and discrete variables)
- Controllable noise (or fully deterministic for flow matching)
- Time-scaling schedules (see `examples/logo_example.jl`)
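Conditioning via masking is demonstrated in `examples/continuous_masked.jl`. As a rough sketch of the idea, using the `MaskedState(state, cmask, lmask)` constructor shown in the examples below (the mask shapes and exact semantics here are assumptions; see the example file for real usage):

```julia
# Sketch only: condition generation on a partially observed X1.
# cmask: positions where X1 is observed and should be held fixed (assumption).
# lmask: positions that contribute to the training loss (assumption).
X1 = ContinuousState(randn(Float32, 2, 100))
cmask = rand(100) .< 0.5   # assumed per-sample Bool conditioning mask
lmask = .!cmask            # score only the unobserved positions
X1cond = MaskedState(X1, cmask, lmask)
```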
- Generate `X0` and `X1` states from your favorite distributions, and a random `t` between 0 and 1
- `Xt = bridge(P, X0, X1, t)`: sample intermediate states conditioned on the start and end states
- Train a model to predict how to get to `X1` from `Xt`
- `gen(P, X0, model, steps)`: generate samples using the learned model
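For intuition about what `bridge` samples in the continuous case: for a Brownian motion process, the intermediate state conditioned on both endpoints follows the classical Brownian bridge distribution (standard stochastic-process math, not Flowfusion-specific; this assumes the process parameter plays the role of the volatility $\sigma$):

```latex
X_t \mid X_0, X_1 \;\sim\; \mathcal{N}\!\big((1 - t)\,X_0 + t\,X_1,\;\; \sigma^2\, t(1 - t)\, I\big), \qquad t \in [0, 1]
```

Setting $\sigma = 0$ collapses this to the deterministic linear interpolation $X_t = (1-t)X_0 + tX_1$ used in flow matching.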
The package includes several examples demonstrating different use cases:

- `continuous.jl`: Learning a continuous distribution
- `torus.jl`: Continuous distributions on a manifold
- `discrete.jl`: Discrete distributions with discrete processes
- `probabilitysimplex.jl`: Discrete distributions with continuous processes via a probability simplex manifold
- `continuous_masked.jl`: Conditioning on partial observations
- `masked_tuple.jl`: Jointly sampling from continuous and discrete variables, with conditioning
```julia
]add https://github.com/MurrellGroup/Flowfusion.jl
```
```julia
using ForwardBackward, Flowfusion, Flux, RandomFeatureMaps, Optimisers, Plots

# Set up a Flux model: X̂1 = model(t, Xt)
struct FModel{A}
    layers::A
end
Flux.@layer FModel
function FModel(; embeddim = 128, spacedim = 2, layers = 3)
    embed_time = Chain(RandomFourierFeatures(1 => embeddim, 1f0), Dense(embeddim => embeddim, swish))
    embed_state = Chain(RandomFourierFeatures(2 => embeddim, 1f0), Dense(embeddim => embeddim, swish))
    ffs = [Dense(embeddim => embeddim, swish) for _ in 1:layers]
    decode = Dense(embeddim => spacedim)
    layers = (; embed_time, embed_state, ffs, decode)
    FModel(layers)
end
function (f::FModel)(t, Xt)
    l = f.layers
    tXt = tensor(Xt)
    tv = zero(tXt[1:1, :]) .+ expand(t, ndims(tXt))
    x = l.embed_time(tv) .+ l.embed_state(tXt)
    for ff in l.ffs
        x = x .+ ff(x)
    end
    tXt .+ l.decode(x) .* (1.05f0 .- expand(t, ndims(tXt)))
end
model = FModel(embeddim = 256, layers = 3, spacedim = 2)

# Distributions for training:
T = Float32
sampleX0(n_samples) = rand(T, 2, n_samples) .+ 2
sampleX1(n_samples) = Flowfusion.random_literal_cat(n_samples, sigma = T(0.05))
n_samples = 400

# The process:
P = BrownianMotion(0.15f0)
#P = Deterministic()

# Optimizer:
eta = 0.001
opt_state = Flux.setup(AdamW(eta = eta), model)

iters = 4000
for i in 1:iters
    # Set up a batch of training pairs, and t:
    X0 = ContinuousState(sampleX0(n_samples))
    X1 = ContinuousState(sampleX1(n_samples))
    t = rand(T, n_samples)
    # Construct the bridge:
    Xt = bridge(P, X0, X1, t)
    # Gradient & update:
    l, g = Flux.withgradient(model) do m
        floss(P, m(t, Xt), X1, scalefloss(P, t))
    end
    Flux.update!(opt_state, model, g[1])
    (i % 10 == 0) && println("i: $i; Loss: $l")
end

# Generate samples by stepping from X0:
n_inference_samples = 5000
X0 = ContinuousState(sampleX0(n_inference_samples))
samples = gen(P, X0, model, 0f0:0.005f0:1f0)

# Plotting:
pl = scatter(X0.state[1, :], X0.state[2, :], msw = 0, ms = 1, color = "blue", alpha = 0.5, size = (400, 400), legend = :topleft, label = "X0")
X1true = sampleX1(n_inference_samples)
scatter!(X1true[1, :], X1true[2, :], msw = 0, ms = 1, color = "orange", alpha = 0.5, label = "X1 (true)")
scatter!(samples.state[1, :], samples.state[2, :], msw = 0, ms = 1, color = "green", alpha = 0.5, label = "X1 (generated)")
savefig("readmeexamplecat.svg")
```
```julia
# Generate samples by stepping from X0, recording the trajectory:
n_inference_samples = 5000
X0 = ContinuousState(sampleX0(n_inference_samples))
paths = Tracker() # <- A tracker to record the trajectory
samples = gen(P, X0, model, 0f0:0.005f0:1f0, tracker = paths)

# Plotting:
pl = scatter(X0.state[1, :], X0.state[2, :], msw = 0, ms = 1, color = "blue", alpha = 0.5, size = (400, 400), legend = :topleft, label = "X0")
tvec = stack_tracker(paths, :t)
xttraj = stack_tracker(paths, :xt)
for i in 1:50:1000
    plot!(xttraj[1, i, :], xttraj[2, i, :], color = "red", label = i == 1 ? "Trajectory" : :none, alpha = 0.4)
end
X1true = sampleX1(n_inference_samples)
scatter!(X1true[1, :], X1true[2, :], msw = 0, ms = 1, color = "orange", alpha = 0.5, label = "X1 (true)")
scatter!(samples.state[1, :], samples.state[2, :], msw = 0, ms = 1, color = "green", alpha = 0.5, label = "X1 (generated)")
```
These can be found in `examples/`, with configurations including:

- `P = Deterministic()`
- `P = Deterministic()` with `X0 = ManifoldState(Torus(2), ...)`
- `P = ManifoldProcess(0.2)` with `X0 = ManifoldState(Torus(2), ...)`
- `P = NoisyInterpolatingDiscreteFlow(0.1)` with `X0 = DiscreteState(...)`
- `X0 = MaskedState(state, cmask, lmask)` for conditioning
- `P = ManifoldProcess(0.5)` with `X0 = ManifoldState(ProbabilitySimplex(32), ...)`
For background reading please see:
- Denoising Diffusion Probabilistic Models
- Flow Matching for Generative Modeling
- Denoising Diffusion Bridge Models
- Flow Matching on General Geometries
- Diffusion on Manifolds
- Flow Matching (a review/guide)
- Discrete Flow Matching
Except where noted in the code, this package mostly does not attempt faithful reproductions of approaches described in the literature, preferring to be inspired by, rather than constrained by, precise mathematical correctness. The main goals are:
- Bringing a variety of different processes under a single unified and flexible framework
- Providing processes that work well in practice