Skip to content

Commit

Permalink
Switch to sqrt(precision) representation in Gaussian (#568)
Browse files Browse the repository at this point in the history
* Switch to sqrt(precision) representation in Gaussian

* Fix some bugs

* Fix more math

* Add GaussianMeta conversions; fix broadcasting bug

* Fix some distribution tests

* Refactor from info_vec to white_vec

* Fix more tests

* Flesh our matrix_and_mvn_to_funsor()

* Work our marginalization

* fix more tests

* Fix more tests

* Fix test_gaussian.py

* Fix distribution patterns

* Fix argmax approximation

* Remove Gaussian.negate attribute

* Fix matrix_and_mvn_to_funsor diag (full still broken)

* Fix old uses of info_vec

* Add a test

* Fix shape bug in matrix_and_mvn_to_funsor()

* Enable pprint for funsors

* Revert pp property

* Fix matrix_and_mvn_to_funsor()

* Relax rank condition

* Fix ._sample()

* Fix eager_contraction_to_binary

* Fix test_joint.py

* Fix comparisons in sequential sum product

* Fix saarka bilmes test

* Add and xfail tests of singular matrices

* Fix rank deficiency issues

* Add gaussian integrate patterns

* Fix comment

* Add a set_compression_threshold context manager

* Update docstring

* Fix backward sampling support bug

* Xfail test_elbo.py::test_complex

* Relax test thresholds

* Fix ops.qr numpy backend

* Fix jax tests

* Fix bugs

* Tweak sensor example

* Address review comments
  • Loading branch information
fritzo authored Oct 18, 2021
1 parent 0cf9ac2 commit c94a9bd
Show file tree
Hide file tree
Showing 25 changed files with 1,155 additions and 1,082 deletions.
21 changes: 11 additions & 10 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ jobs:
strategy:
matrix:
python-version: [3.6, 3.7, 3.8, 3.9]

env:
CI: 1
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
Expand All @@ -29,8 +30,7 @@ jobs:
pip install .[test]
pip freeze
- name: Run test
run: |
make test
run: make test


torch:
Expand All @@ -39,7 +39,9 @@ jobs:
strategy:
matrix:
python-version: [3.6]

env:
CI: 1
FUNSOR_BACKEND: torch
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
Expand All @@ -57,9 +59,7 @@ jobs:
pip install .[test,torch]
pip freeze
- name: Run test
run: |
make test
FUNSOR_BACKEND=torch make test
run: make test


jax:
Expand All @@ -68,7 +68,9 @@ jobs:
strategy:
matrix:
python-version: [3.6]

env:
CI: 1
FUNSOR_BACKEND: jax
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
Expand All @@ -85,5 +87,4 @@ jobs:
pip install .[test,jax]
pip freeze
- name: Run test
run: |
CI=1 FUNSOR_BACKEND=jax make test
run: make test
9 changes: 6 additions & 3 deletions examples/sensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,8 @@ def generate_data(num_frames, num_sensors):
]
)
trans_dist = dist.MultivariateNormal(
torch.zeros(4), scale_tril=trans_noise * NCV_PROCESS_NOISE.cholesky()
torch.zeros(4),
scale_tril=trans_noise * torch.linalg.cholesky(NCV_PROCESS_NOISE),
)

# define biased sensors
Expand Down Expand Up @@ -128,7 +129,7 @@ def forward(self, observations, add_bias=True):
curr = Variable("curr", Reals[4])
self.trans_dist = f_dist.MultivariateNormal(
loc=prev @ NCV_TRANSITION_MATRIX,
scale_tril=trans_noise * NCV_PROCESS_NOISE.cholesky(),
scale_tril=trans_noise * torch.linalg.cholesky(NCV_PROCESS_NOISE),
value=curr,
)

Expand Down Expand Up @@ -239,7 +240,9 @@ def main(args):
or not args.metrics_filename
or not os.path.exists(args.metrics_filename)
):
results = track(args)
# Increase compression threshold for numerical stability.
with funsor.gaussian.Gaussian.set_compression_threshold(3):
results = track(args)
else:
results = torch.load(args.metrics_filename)

Expand Down
5 changes: 1 addition & 4 deletions funsor/approximations.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,10 +152,7 @@ def compute_argmax_gaussian(model, approx_vars):

approx_names = frozenset(v.name for v in approx_vars)
if approx_names == frozenset(real_inputs):
x = model.info_vec[..., None]
x = ops.triangular_solve(x, model._precision_chol)
x = ops.triangular_solve(x, model._precision_chol, transpose=True)
mode = x[..., 0]
mode = model._mean
offsets, _ = _compute_offsets(real_inputs)
result = {}
for key, domain in real_inputs.items():
Expand Down
33 changes: 16 additions & 17 deletions funsor/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -697,19 +697,12 @@ def indep_to_data(funsor_dist, name_to_dim=None):


@to_data.register(Gaussian)
def gaussian_to_data(funsor_dist, name_to_dim=None, normalized=False):
if normalized:
return to_data(
funsor_dist.log_normalizer + funsor_dist, name_to_dim=name_to_dim
)
loc = ops.cholesky_solve(
ops.unsqueeze(funsor_dist.info_vec, -1), ops.cholesky(funsor_dist.precision)
).squeeze(-1)
def gaussian_to_data(funsor_dist, name_to_dim=None):
int_inputs = OrderedDict(
(k, d) for k, d in funsor_dist.inputs.items() if d.dtype != "real"
)
loc = to_data(Tensor(loc, int_inputs), name_to_dim)
precision = to_data(Tensor(funsor_dist.precision, int_inputs), name_to_dim)
loc = to_data(Tensor(funsor_dist._mean, int_inputs), name_to_dim)
precision = to_data(Tensor(funsor_dist._precision, int_inputs), name_to_dim)
backend_dist = import_module(BACKEND_TO_DISTRIBUTIONS_BACKEND[get_backend()])
return backend_dist.MultivariateNormal.dist_class(loc, precision_matrix=precision)

Expand Down Expand Up @@ -845,13 +838,17 @@ def eager_normal(loc, scale, value):
if not is_affine(loc) or not is_affine(value):
return None # lazy

info_vec = ops.new_zeros(scale.data, scale.data.shape + (1,))
precision = ops.pow(scale.data, -2).reshape(scale.data.shape + (1, 1))
log_prob = -0.5 * math.log(2 * math.pi) - ops.log(scale).sum()
white_vec = ops.new_zeros(scale.data, scale.data.shape + (1,))
prec_sqrt = (1 / scale.data)[..., None, None]
log_prob = -0.5 * math.log(2 * math.pi) - ops.log(scale)
inputs = scale.inputs.copy()
var = gensym("value")
inputs[var] = Real
gaussian = log_prob + Gaussian(info_vec, precision, inputs)
gaussian = log_prob + Gaussian(
white_vec=white_vec,
prec_sqrt=prec_sqrt,
inputs=inputs,
)
return gaussian(**{var: value - loc})


Expand All @@ -862,16 +859,18 @@ def eager_mvn(loc, scale_tril, value):
if not is_affine(loc) or not is_affine(value):
return None # lazy

info_vec = ops.new_zeros(scale_tril.data, scale_tril.data.shape[:-1])
precision = ops.cholesky_inverse(scale_tril.data)
white_vec = ops.new_zeros(scale_tril.data, scale_tril.data.shape[:-1])
prec_sqrt = ops.transpose(ops.triangular_inv(scale_tril.data), -1, -2)
scale_diag = Tensor(ops.diagonal(scale_tril.data, -1, -2), scale_tril.inputs)
log_prob = (
-0.5 * scale_diag.shape[0] * math.log(2 * math.pi) - ops.log(scale_diag).sum()
)
inputs = scale_tril.inputs.copy()
var = gensym("value")
inputs[var] = Reals[scale_diag.shape[0]]
gaussian = log_prob + Gaussian(info_vec, precision, inputs)
gaussian = log_prob + Gaussian(
white_vec=white_vec, prec_sqrt=prec_sqrt, inputs=inputs
)
return gaussian(**{var: value - loc})


Expand Down
Loading

0 comments on commit c94a9bd

Please sign in to comment.