adding PosteriorSample() helper type. #357

Open
wants to merge 4 commits into
base: development

arik-shurygin
Collaborator

Modelers often want to run two separate models: one to fit to the data, and then another to project forward in time or run some scenarios. In these cases they need a way, within their configuration files, to mark that a particular parameter depends on a set of posterior samples, usually from a previous run.

The PosteriorSample class offers an easy way for users to mark that a particular parameter's values will be deterministically picked from a set of posterior samples, while also erroring quickly if a user attempts to treat it as a distribution they may randomly sample from.

Here is a basic example of the functionality:

import jax
import numpyro.distributions as dist
from numpyro.handlers import substitute
import numpyro
from dynode.model_configuration.types import PosteriorSample


def tester(distribution):
    """Sample a distribution"""
    rng_key = jax.random.PRNGKey(0)
    sample = numpyro.sample(
        "x", distribution, rng_key=rng_key
    )
    return sample

regular_dist = dist.Normal()
posterior_sample = PosteriorSample()

# a regular distribution is sampled as usual
x_sample = tester(regular_dist)
# >>> -0.7847657764467411

# wrap our tester function in a substitute call
x_posterior = substitute(fn=lambda: tester(posterior_sample), data={"x": 1000})()
# >>> 1000

# now no posterior context
x_error = tester(posterior_sample)
# >>> SamplePosteriorError: Attempted to sample a PosteriorSample parameter outside of a
#     Predictive() context. This likely means you did not provide
#     posterior samples to the context via numpyro.infer.Predictive() or
#     numpyro.handlers.substitute().

In the backend, the PosteriorSample class is treated just like a numpyro.distributions.Distribution object. However, attempting to sample it outside of a numpyro.infer.Predictive() or numpyro.handlers.substitute() context raises the error immediately.

In this example we use the site name x to identify the sampled parameter. As long as we programmatically pick site names based on the parameter name, and users do not modify the site names in their data dictionary, this should properly match all parameters.
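To make the site-name matching concrete, here is a minimal sketch of one possible naming scheme. Everything in it is an assumption for illustration: the `PLACEHOLDER` marker stands in for a PosteriorSample() instance, the config layout is made up, and joining nested keys with underscores is just one convention, not necessarily what dynode does.

```python
PLACEHOLDER = object()  # hypothetical stand-in for a PosteriorSample() instance


def placeholder_sites(config, prefix=""):
    """Return the derived site name of every placeholder in a nested config dict."""
    sites = []
    for key, value in config.items():
        # Site names are built only from parameter names, never typed by hand.
        name = f"{prefix}_{key}" if prefix else key
        if isinstance(value, dict):
            sites.extend(placeholder_sites(value, name))
        elif value is PLACEHOLDER:
            sites.append(name)
    return sites


def missing_sites(config, posterior_samples):
    """List placeholder sites that have no entry in the posterior samples."""
    return sorted(set(placeholder_sites(config)) - set(posterior_samples))


# Made-up configuration with two placeholder parameters.
config = {
    "transmission": {"r0": PLACEHOLDER, "seasonality": 0.1},
    "waning_rate": PLACEHOLDER,
}
posterior_samples = {"transmission_r0": [2.1, 2.4]}

sites = placeholder_sites(config)
missing = missing_sites(config, posterior_samples)
```

Checking for missing sites before building the substitute data would surface a name mismatch up front, rather than as a SamplePosteriorError partway through a run.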

FYSA @kokbent @tjhladish @cdc-ap66 feel free to start a discussion below.

CLOSES #355

@cdc-ap66 cdc-ap66 left a comment

Definitely an important addition. Added a few comments, but feel free to go ahead after you review.

pass


class PosteriorSample(dist.Distribution):

This makes sense to me. I just want to make sure that "PosteriorSample" is the right name for this. It feels like it could also be called "PlaceholderSample", given it might have functionality that pushes the boundaries of what a Bayesian modeler would consider covered by the term "posterior."

I am pretty green on all this Bayesian inference stuff so feel free to ignore if you think this is a dumb comment.

Collaborator Author

yep I think @kokbent will have insight on this

Collaborator

I do like PlaceholderSample a little better, since the value doesn't have to be substituted by a posterior sample; it could be an initial state, or some other number that we substitute at run time for whatever reason... PosteriorSample is kind of like a subset of PlaceholderSample that might help readability, so I'm a bit torn 🤔

3 participants