Merge branch 'dev' into mypy-dist
ordabayevy committed Aug 2, 2024
2 parents 7acec8d + 871abb8 commit 158f297
Showing 22 changed files with 187 additions and 23 deletions.
2 changes: 1 addition & 1 deletion CODE_OF_CONDUCT.md
@@ -34,7 +34,7 @@ This Code of Conduct applies both within project spaces and in public spaces whe

## Enforcement

- Instances of abusive, harassing, or otherwise unacceptable behavior may be reported by contacting the project team at fritzo@uber.com. The project team will review and investigate all complaints, and will respond in a way that it deems appropriate to the circumstances. The project team is obligated to maintain confidentiality with regard to the reporter of an incident. Further details of specific enforcement policies may be posted separately.
+ Instances of abusive, harassing, or otherwise unacceptable behavior may be reported by contacting the project team at fritz.obermeyer@gmail.com or fehiepsi@gmail.com. The project team will review and investigate all complaints, and will respond in a way that it deems appropriate to the circumstances. The project team is obligated to maintain confidentiality with regard to the reporter of an incident. Further details of specific enforcement policies may be posted separately.

Project maintainers who do not follow or enforce the Code of Conduct in good faith may face temporary or permanent repercussions as determined by other members of the project's leadership.

6 changes: 3 additions & 3 deletions docs/source/conf.py
@@ -194,10 +194,10 @@
intersphinx_mapping = {
"python": ("https://docs.python.org/3/", None),
"torch": ("https://pytorch.org/docs/master/", None),
"funsor": ("http://funsor.pyro.ai/en/stable/", None),
"funsor": ("https://funsor.pyro.ai/en/stable/", None),
"opt_einsum": ("https://optimized-einsum.readthedocs.io/en/stable/", None),
"scipy": ("https://docs.scipy.org/doc/scipy/reference/", None),
"Bio": ("https://biopython.org/docs/latest/api/", None),
"scipy": ("https://docs.scipy.org/doc/scipy/", None),
"Bio": ("https://biopython.org/docs/latest/", None),
"horovod": ("https://horovod.readthedocs.io/en/stable/", None),
"graphviz": ("https://graphviz.readthedocs.io/en/stable/", None),
}
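Note on the URL fixes above: intersphinx resolves cross-references by fetching an objects.inv inventory beneath each base URL, so stale bases such as scipy/reference/ and docs/latest/api/ silently break links. A minimal sketch (not part of this commit) for checking that each base still serves an inventory:

import urllib.request

intersphinx_mapping = {
    "funsor": ("https://funsor.pyro.ai/en/stable/", None),
    "scipy": ("https://docs.scipy.org/doc/scipy/", None),
    "Bio": ("https://biopython.org/docs/latest/", None),
}

for name, (base_url, _) in intersphinx_mapping.items():
    # Sphinx fetches <base_url>objects.inv; an HTTPError here means broken links.
    with urllib.request.urlopen(base_url + "objects.inv") as response:
        print(name, response.status)
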
8 changes: 8 additions & 0 deletions docs/source/pyro.poutine.txt
@@ -54,6 +54,14 @@ ________________
:undoc-members:
:show-inheritance:

+ EqualizeMessenger
+ ____________________
+
+ .. automodule:: pyro.poutine.equalize_messenger
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
EscapeMessenger
________________

2 changes: 1 addition & 1 deletion examples/air/main.py
@@ -200,7 +200,7 @@ def z_pres_prior_p(opt_step, time_step):

if "load" in args:
print("Loading parameters...")
- air.load_state_dict(torch.load(args.load))
+ air.load_state_dict(torch.load(args.load, weights_only=False))

# Viz sample from prior.
if args.viz:
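The recurring weights_only=False edits here and in the files below anticipate PyTorch's change of the torch.load default (a FutureWarning in recent releases; the default flipped to weights_only=True in PyTorch 2.6). A hedged sketch of the distinction, with hypothetical file names:

import torch

# Plain tensor containers still load under the restricted mode.
torch.save({"w": torch.randn(3)}, "state.pt")
state = torch.load("state.pt", weights_only=True)

# Full pickled objects (e.g. an nn.Module) require weights_only=False,
# which executes arbitrary pickle code -- only safe for trusted files.
torch.save(torch.nn.Linear(2, 2), "model.pt")
model = torch.load("model.pt", weights_only=False)
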
2 changes: 1 addition & 1 deletion examples/cvae/cvae.py
@@ -184,6 +184,6 @@ def train(
break

# Save model weights
- cvae_net.load_state_dict(torch.load(model_path))
+ cvae_net.load_state_dict(torch.load(model_path, weights_only=False))
cvae_net.eval()
return cvae_net
2 changes: 1 addition & 1 deletion examples/dmm.py
@@ -465,7 +465,7 @@ def load_checkpoint():
args.load_model
), "--load-model and/or --load-opt misspecified"
logging.info("loading model from %s..." % args.load_model)
- dmm.load_state_dict(torch.load(args.load_model))
+ dmm.load_state_dict(torch.load(args.load_model, weights_only=False))
logging.info("loading optimizer states from %s..." % args.load_opt)
adam.load(args.load_opt)
logging.info("done loading model and optimizer states.")
5 changes: 3 additions & 2 deletions pyro/contrib/examples/bart.py
@@ -11,6 +11,7 @@
import subprocess
import sys
import urllib
+ from functools import partial

import torch

@@ -120,12 +121,12 @@ def load_bart_od():
except urllib.error.HTTPError:
logging.debug("cache miss, preprocessing from scratch")
if os.path.exists(pkl_file):
- return torch.load(pkl_file)
+ return torch.load(pkl_file, weights_only=False)

filenames = multiprocessing.Pool(len(SOURCE_FILES)).map(
_load_hourly_od, SOURCE_FILES
)
- datasets = list(map(torch.load, filenames))
+ datasets = list(map(partial(torch.load, weights_only=False), filenames))

stations = sorted(set().union(*(d["stations"].keys() for d in datasets)))
min_time = min(int(d["rows"][:, 0].min()) for d in datasets)
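Since map passes each filename as the sole argument, functools.partial is used above to pre-bind the weights_only=False keyword. A small self-contained illustration of the pattern (stand-in load function, not Pyro code):

from functools import partial

def load(path, weights_only=True):
    return path, weights_only

loader = partial(load, weights_only=False)
# Every mapped call now carries the pre-bound keyword argument.
assert list(map(loader, ["a.pt", "b.pt"])) == [("a.pt", False), ("b.pt", False)]
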
2 changes: 1 addition & 1 deletion pyro/contrib/examples/nextstrain.py
@@ -41,4 +41,4 @@ def load_nextstrain_counts(map_location=None) -> dict:
# Load tensors to the default location.
if map_location is None:
map_location = torch.tensor(0.0).device
- return torch.load(filename, map_location=map_location)
+ return torch.load(filename, map_location=map_location, weights_only=False)
2 changes: 1 addition & 1 deletion pyro/nn/module.py
@@ -138,7 +138,7 @@ def __get__(
if name not in obj.__dict__["_pyro_params"]:
init_value, constraint, event_dim = self
# bind method's self arg
- init_value = functools.partial(init_value, obj)  # type: ignore[arg-type]
+ init_value = functools.partial(init_value, obj)  # type: ignore[arg-type,misc,operator]
setattr(obj, name, PyroParam(init_value, constraint, event_dim))
value: PyroParam = obj.__getattr__(name)
return value
4 changes: 3 additions & 1 deletion pyro/optim/optim.py
@@ -192,7 +192,9 @@ def load(self, filename: str, map_location=None) -> None:
Load optimizer state from disk
"""
with open(filename, "rb") as input_file:
- state = torch.load(input_file, map_location=map_location)
+ state = torch.load(
+     input_file, map_location=map_location, weights_only=False
+ )
self.set_state(state)

def _get_optim(self, param: Union[Iterable[Tensor], Iterable[Dict[Any, Any]]]):
2 changes: 1 addition & 1 deletion pyro/params/param_store.py
@@ -331,7 +331,7 @@ def load(self, filename: str, map_location: MAP_LOCATION = None) -> None:
:type map_location: function, torch.device, string or a dict
"""
with open(filename, "rb") as input_file:
- state = torch.load(input_file, map_location)
+ state = torch.load(input_file, map_location, weights_only=False)
self.set_state(state)

@contextmanager
2 changes: 2 additions & 0 deletions pyro/poutine/__init__.py
@@ -8,6 +8,7 @@
condition,
do,
enum,
+ equalize,
escape,
infer_config,
lift,
@@ -36,6 +37,7 @@
"enable_validation",
"enum",
"escape",
"equalize",
"get_mask",
"infer_config",
"is_validation_enabled",
77 changes: 77 additions & 0 deletions pyro/poutine/equalize_messenger.py
@@ -0,0 +1,77 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import re
from typing import List, Optional, Union

from typing_extensions import Self

from pyro.distributions import Delta
from pyro.poutine.messenger import Messenger
from pyro.poutine.runtime import Message


class EqualizeMessenger(Messenger):
"""
Given a stochastic function with some primitive statements and a list of names,
force the primitive statements at those names to have the same value,
with that value being the result of the first primitive statement matching those names.
Consider the following Pyro program:
>>> def per_category_model(category):
... shift = pyro.param(f'{category}_shift', torch.randn(1))
... mean = pyro.sample(f'{category}_mean', pyro.distributions.Normal(0, 1))
... std = pyro.sample(f'{category}_std', pyro.distributions.LogNormal(0, 1))
... return pyro.sample(f'{category}_values', pyro.distributions.Normal(mean + shift, std))
Running the program for multiple categories can be done by
>>> def model(categories):
... return {category:per_category_model(category) for category in categories}
To make the `std` sample sites have the same value, we can write
>>> equal_std_model = pyro.poutine.equalize(model, '.+_std')
If on top of the above we would like to make the 'shift' parameters identical, we can write
>>> equal_std_param_model = pyro.poutine.equalize(equal_std_model, '.+_shift', 'param')
:param fn: a stochastic function (callable containing Pyro primitive calls)
:param sites: a string or list of strings to match site names (the strings can be regular expressions)
:param type: a string specifying the site type (default is 'sample')
:returns: stochastic function decorated with a :class:`~pyro.poutine.equalize_messenger.EqualizeMessenger`
"""

def __init__(
self, sites: Union[str, List[str]], type: Optional[str] = "sample"
) -> None:
super().__init__()
self.sites = [sites] if isinstance(sites, str) else sites
self.type = type

def __enter__(self) -> Self:
self.value = None
return super().__enter__()

def _is_matching(self, msg: Message) -> bool:
if msg["type"] == self.type:
for site in self.sites:
if re.compile(site).fullmatch(msg["name"]) is not None: # type: ignore[arg-type]
return True
return False

def _postprocess_message(self, msg: Message) -> None:
if self.value is None and self._is_matching(msg):
value = msg["value"]
assert value is not None
self.value = value

def _process_message(self, msg: Message) -> None:
if self.value is not None and self._is_matching(msg): # type: ignore[unreachable]
msg["value"] = self.value # type: ignore[unreachable]
if msg["type"] == "sample":
msg["fn"] = Delta(self.value, event_dim=msg["fn"].event_dim).mask(False)
msg["infer"] = {"_deterministic": True}
msg["is_observed"] = True
24 changes: 24 additions & 0 deletions pyro/poutine/handlers.py
@@ -74,6 +74,7 @@
from pyro.poutine.condition_messenger import ConditionMessenger
from pyro.poutine.do_messenger import DoMessenger
from pyro.poutine.enum_messenger import EnumMessenger
+ from pyro.poutine.equalize_messenger import EqualizeMessenger
from pyro.poutine.escape_messenger import EscapeMessenger
from pyro.poutine.infer_config_messenger import InferConfigMessenger
from pyro.poutine.lift_messenger import LiftMessenger
@@ -301,6 +302,29 @@ def escape( # type: ignore[empty-body]
) -> Union[EscapeMessenger, Callable[_P, _T]]: ...


+ @overload
+ def equalize(
+     sites: Union[str, List[str]],
+     type: Optional[str],
+ ) -> EqualizeMessenger: ...


+ @overload
+ def equalize(
+     fn: Callable[_P, _T],
+     sites: Union[str, List[str]],
+     type: Optional[str],
+ ) -> Callable[_P, _T]: ...


+ @_make_handler(EqualizeMessenger)
+ def equalize(  # type: ignore[empty-body]
+     fn: Callable[_P, _T],
+     sites: Union[str, List[str]],
+     type: Optional[str],
+ ) -> Union[EqualizeMessenger, Callable[_P, _T]]: ...


@overload
def infer_config(
config_fn: Callable[["Message"], "InferDict"],
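The @overload stubs above serve only the type checker; @_make_handler(EqualizeMessenger) supplies the runtime body, wrapping fn in the messenger when a function is given and otherwise returning the messenger for use as a context manager. Roughly, as a simplified sketch (not Pyro's exact internals):

import functools

def make_handler(messenger_cls):
    def decorator(stub):
        @functools.wraps(stub)
        def handler(fn=None, *args, **kwargs):
            messenger = messenger_cls(*args, **kwargs)
            # Messenger.__call__ wraps a stochastic function; without one,
            # hand back the messenger itself as a reusable context manager.
            return messenger(fn) if fn is not None else messenger
        return handler
    return decorator
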
2 changes: 1 addition & 1 deletion tests/contrib/cevae/test_cevae.py
@@ -64,7 +64,7 @@ def test_serialization(jit, feature_dim, outcome_dist):
warnings.filterwarnings("ignore", category=UserWarning)
torch.save(cevae, f)
f.seek(0)
- loaded_cevae = torch.load(f)
+ loaded_cevae = torch.load(f, weights_only=False)

pyro.set_rng_seed(0)
actual_ite = loaded_cevae.ite(x)
2 changes: 1 addition & 1 deletion tests/contrib/easyguide/test_easyguide.py
@@ -89,7 +89,7 @@ def test_serialize():
f = io.BytesIO()
torch.save(guide, f)
f.seek(0)
- actual = torch.load(f)
+ actual = torch.load(f, weights_only=False)

assert type(actual) == type(guide)
assert dir(actual) == dir(guide)
2 changes: 1 addition & 1 deletion tests/distributions/test_pickle.py
@@ -88,5 +88,5 @@ def test_pickle(Dist):
# Note that pickling torch.Size() requires protocol >= 2
torch.save(dist, buffer, pickle_protocol=pickle.HIGHEST_PROTOCOL)
buffer.seek(0)
- deserialized = torch.load(buffer)
+ deserialized = torch.load(buffer, weights_only=False)
assert isinstance(deserialized, Dist)
2 changes: 1 addition & 1 deletion tests/infer/mcmc/test_valid_models.py
@@ -420,7 +420,7 @@ def test_potential_fn_pickling(jit):
buffer = io.BytesIO()
torch.save(potential_fn, buffer)
buffer.seek(0)
- deser_potential_fn = torch.load(buffer)
+ deser_potential_fn = torch.load(buffer, weights_only=False)
assert_close(deser_potential_fn(test_data), potential_fn(test_data))


2 changes: 1 addition & 1 deletion tests/infer/test_autoguide.py
@@ -489,7 +489,7 @@ def test_serialization(auto_class, jit):
f = io.BytesIO()
torch.save(guide, f)
f.seek(0)
- guide_deser = torch.load(f)
+ guide_deser = torch.load(f, weights_only=False)

# Check .call() result.
pyro.set_rng_seed(0)
6 changes: 3 additions & 3 deletions tests/nn/test_module.py
@@ -598,7 +598,7 @@ def test_mixin_factory():
del module
pyro.clear_param_store()
f.seek(0)
- module = torch.load(f)
+ module = torch.load(f, weights_only=False)
assert type(module).__name__ == "PyroSequential"
actual = module(data)
assert_equal(actual, expected)
@@ -680,7 +680,7 @@ def test_torch_serialize_attributes(local_params):
torch.save(module, f)
pyro.clear_param_store()
f.seek(0)
- actual = torch.load(f)
+ actual = torch.load(f, weights_only=False)

assert_equal(actual.x, module.x)
actual_names = {name for name, _ in actual.named_parameters()}
@@ -704,7 +704,7 @@ def test_torch_serialize_decorators(local_params):
torch.save(module, f)
pyro.clear_param_store()
f.seek(0)
- actual = torch.load(f)
+ actual = torch.load(f, weights_only=False)

assert_equal(actual.x, module.x)
assert_equal(actual.y, module.y)
2 changes: 1 addition & 1 deletion tests/ops/einsum/test_adjoint.py
@@ -117,7 +117,7 @@ def test_marginal(equation):
assert_equal(expected, actual)


- @pytest.mark.filterwarnings("ignore:.*reduce_op is deprecated")
+ @pytest.mark.filterwarnings("ignore:.*reduce_op`? is deprecated")
def test_require_backward_memory_leak():
tensors = [o for o in gc.get_objects() if torch.is_tensor(o)]
num_global_tensors = len(tensors)
52 changes: 51 additions & 1 deletion tests/poutine/test_poutines.py
@@ -755,6 +755,56 @@ def test_infer_config_sample(self):
assert tr.nodes["p"]["infer"] == {}


+ class EqualizeHandlerTests(TestCase):
+     def setUp(self):
+         def per_category_model(category):
+             shift = pyro.param(f"{category}_shift", torch.randn(1))
+             mean = pyro.sample(f"{category}_mean", pyro.distributions.Normal(0, 1))
+             std = pyro.sample(f"{category}_std", pyro.distributions.LogNormal(0, 1))
+             with pyro.plate(f"{category}_num_samples", 5):
+                 return pyro.sample(
+                     f"{category}_values", pyro.distributions.Normal(mean + shift, std)
+                 )
+
+         def model(categories=["dogs", "cats"]):
+             return {category: per_category_model(category) for category in categories}
+
+         self.model = model
+
+     def test_sample_site_equalization(self):
+         pyro.set_rng_seed(20240616)
+         pyro.clear_param_store()
+         model = poutine.equalize(self.model, ".+_std")
+         tr = pyro.poutine.trace(model).get_trace()
+         assert_equal(tr.nodes["cats_std"]["value"], tr.nodes["dogs_std"]["value"])
+         assert_not_equal(
+             tr.nodes["cats_shift"]["value"], tr.nodes["dogs_shift"]["value"]
+         )
+         guide = pyro.infer.autoguide.AutoNormal(model)
+         guide_sites = [*guide()]
+         assert guide_sites == [
+             "dogs_mean",
+             "dogs_std",
+             "dogs_values",
+             "cats_mean",
+             "cats_values",
+         ]
+
+     def test_param_equalization(self):
+         pyro.set_rng_seed(20240616)
+         pyro.clear_param_store()
+         model = poutine.equalize(self.model, ".+_shift", "param")
+         tr = pyro.poutine.trace(model).get_trace()
+         assert_equal(tr.nodes["cats_shift"]["value"], tr.nodes["dogs_shift"]["value"])
+         assert_not_equal(tr.nodes["cats_std"]["value"], tr.nodes["dogs_std"]["value"])
+
+     def test_render_model(self):
+         pyro.set_rng_seed(20240616)
+         pyro.clear_param_store()
+         model = poutine.equalize(self.model, ".+_std")
+         pyro.render_model(model)

@pytest.mark.parametrize("first_available_dim", [-1, -2, -3])
@pytest.mark.parametrize("depth", [0, 1, 2])
def test_enumerate_poutine(depth, first_available_dim):
@@ -977,7 +1027,7 @@ def test_pickling(wrapper):
# default protocol cannot serialize torch.Size objects (see https://github.com/pytorch/pytorch/issues/20823)
torch.save(wrapped, buffer, pickle_protocol=pickle.HIGHEST_PROTOCOL)
buffer.seek(0)
- deserialized = torch.load(buffer)
+ deserialized = torch.load(buffer, weights_only=False)
obs = torch.tensor(0.5)
pyro.set_rng_seed(0)
actual_trace = poutine.trace(deserialized).get_trace(obs)
