Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support observations that require broadcasting in SplitReparam #3392

Merged
merged 3 commits into from
Aug 4, 2024

Conversation

BenZickel
Copy link
Contributor

@BenZickel BenZickel commented Aug 3, 2024

The Problem

When splitting a parameter into sections with SplitReparam, and observing a single section of the split parameter, you would get an error when trying to vectorize the model and the other sections of the split parameter (as you would when doing SVI with vectorized particles for instance).

Consider the below code

import pyro
from pyro import distributions as dist
from pyro.infer.reparam import SplitReparam
from pyro import poutine
from pyro.infer.autoguide.initialization import InitMessenger, init_to_median

def model():
    x_dist = dist.TransformedDistribution(
        dist.Normal(0, 1).expand((8,)).to_event(1), dist.transforms.HaarTransform())
    return pyro.sample("x", x_dist)

# Build reparameterized model
rep = SplitReparam([6, 2], -1)
reparam_model = poutine.reparam(model, {"x": rep})

# Sample from the reparameterized model to create an observation
initialized_reparam_model = InitMessenger(init_to_median)(reparam_model)
trace = poutine.trace(initialized_reparam_model).get_trace()
observation = {"x_split_1": trace.nodes["x_split_1"]["value"]}

# Create a model conditioned on the observation
conditioned_reparam_model = poutine.condition(reparam_model, observation)

# Fit a guide for the conditioned model
guide = pyro.infer.autoguide.AutoMultivariateNormal(conditioned_reparam_model)
optim = pyro.optim.Adam(dict(lr=0.1))
loss = pyro.infer.Trace_ELBO(num_particles=20, vectorize_particles=True)
svi = pyro.infer.SVI(conditioned_reparam_model, guide, optim, loss)
for iter_count in range(10):
    svi.step()

which would raise the error

Traceback (most recent call last):
  File "C:\SW\pyro-ppl\pyro\poutine\trace_messenger.py", line 191, in __call__
    ret = self.fn(*args, **kwargs)
  File "C:\SW\pyro-ppl\pyro\poutine\messenger.py", line 32, in _context_wrap
    return fn(*args, **kwargs)
  File "C:\SW\pyro-ppl\pyro\poutine\messenger.py", line 32, in _context_wrap
    return fn(*args, **kwargs)
  File "C:\SW\pyro-ppl\pyro\poutine\messenger.py", line 32, in _context_wrap
    return fn(*args, **kwargs)
  File "C:\SW\pyro-ppl\pyro\poutine\reparam_messenger.py", line 163, in __call__
    return self.fn(*args, **kwargs)
  File "<stdin>", line 4, in model
  File "C:\SW\pyro-ppl\pyro\primitives.py", line 189, in sample
    apply_stack(msg)
  File "C:\SW\pyro-ppl\pyro\poutine\runtime.py", line 378, in apply_stack
    frame._process_message(msg)
  File "C:\SW\pyro-ppl\pyro\poutine\messenger.py", line 189, in _process_message
    method(msg)
  File "C:\SW\pyro-ppl\pyro\poutine\reparam_messenger.py", line 112, in _pyro_sample
    new_msg = reparam.apply(
  File "C:\SW\pyro-ppl\pyro\infer\reparam\split.py", line 133, in apply
    value = torch.cat(value_split, dim=-self.event_dim)
RuntimeError: Tensors must have same number of dimensions: got 2 and 1

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "<stdin>", line 2, in <module>
  File "C:\SW\pyro-ppl\pyro\infer\svi.py", line 145, in step
    loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)
  File "C:\SW\pyro-ppl\pyro\infer\trace_elbo.py", line 140, in loss_and_grads
    for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs):
  File "C:\SW\pyro-ppl\pyro\infer\elbo.py", line 234, in _get_traces
    yield self._get_vectorized_trace(model, guide, args, kwargs)
  File "C:\SW\pyro-ppl\pyro\infer\elbo.py", line 211, in _get_vectorized_trace
    return self._get_trace(
  File "C:\SW\pyro-ppl\pyro\infer\trace_elbo.py", line 57, in _get_trace
    model_trace, guide_trace = get_importance_trace(
  File "C:\SW\pyro-ppl\pyro\infer\enum.py", line 65, in get_importance_trace
    model_trace = poutine.trace(
  File "C:\SW\pyro-ppl\pyro\poutine\trace_messenger.py", line 216, in get_trace
    self(*args, **kwargs)
  File "C:\SW\pyro-ppl\pyro\poutine\trace_messenger.py", line 198, in __call__
    raise exc from e
  File "C:\SW\pyro-ppl\pyro\poutine\trace_messenger.py", line 191, in __call__
    ret = self.fn(*args, **kwargs)
  File "C:\SW\pyro-ppl\pyro\poutine\messenger.py", line 32, in _context_wrap
    return fn(*args, **kwargs)
  File "C:\SW\pyro-ppl\pyro\poutine\messenger.py", line 32, in _context_wrap
    return fn(*args, **kwargs)
  File "C:\SW\pyro-ppl\pyro\poutine\messenger.py", line 32, in _context_wrap
    return fn(*args, **kwargs)
  File "C:\SW\pyro-ppl\pyro\poutine\reparam_messenger.py", line 163, in __call__
    return self.fn(*args, **kwargs)
  File "<stdin>", line 4, in model
  File "C:\SW\pyro-ppl\pyro\primitives.py", line 189, in sample
    apply_stack(msg)
  File "C:\SW\pyro-ppl\pyro\poutine\runtime.py", line 378, in apply_stack
    frame._process_message(msg)
  File "C:\SW\pyro-ppl\pyro\poutine\messenger.py", line 189, in _process_message
    method(msg)
  File "C:\SW\pyro-ppl\pyro\poutine\reparam_messenger.py", line 112, in _pyro_sample
    new_msg = reparam.apply(
  File "C:\SW\pyro-ppl\pyro\infer\reparam\split.py", line 133, in apply
    value = torch.cat(value_split, dim=-self.event_dim)
RuntimeError: Tensors must have same number of dimensions: got 2 and 1
 Trace Shapes:
  Param Sites:
 Sample Sites:
x_split_0 dist 20 | 6
         value 20 | 6
x_split_1 dist 20 | 2
         value    | 2

due to the tensor sections of value_split not having the same dimensions, except for the concatenated dimension.

The Solution

Broadcast the tensor sections of value_split to have the same shape, except for the concatenated dimension, so that the sections can be concatenated together along the required dimension.

@fritzo fritzo added the bug label Aug 4, 2024
@fritzo fritzo changed the title Support observations that require broadcasting in SplitReparam [bugfix] Support observations that require broadcasting in SplitReparam Aug 4, 2024
fritzo
fritzo previously approved these changes Aug 4, 2024
@BenZickel BenZickel requested a review from fritzo August 4, 2024 16:38
@BenZickel BenZickel requested a review from fritzo August 4, 2024 18:05
@fritzo fritzo merged commit 414a4d5 into pyro-ppl:dev Aug 4, 2024
9 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants