Skip to content

Commit

Permalink
Update docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
ziatdinovmax committed Oct 20, 2024
1 parent e06d1ef commit 97e4170
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 7 deletions.
2 changes: 1 addition & 1 deletion neurobayes/models/bnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class BNN:
outcomes, thus quantifying the inherent uncertainty.
Args:
target_dim (int): Dimensionality of the outputs/targets. For example, if predicting a
target_dim (int): Dimensionality of the target variable. For example, if predicting a
single scalar property, set target_dim=1.
hidden_dim (List[int], optional): List specifying the number of hidden units in each layer
of the neural network architecture. Defaults to [32, 16, 8].
Expand Down
2 changes: 1 addition & 1 deletion neurobayes/models/bnn_heteroskedastic.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class HeteroskedasticBNN(BNN):
Heteroskedastic Bayesian Neural Network for input-dependent observational noise
Args:
target_dim (int): Dimensionality of the outputs/targets. For example, if predicting a
target_dim (int): Dimensionality of the target variable. For example, if predicting a
single scalar property, set target_dim=1.
hidden_dim (List[int], optional): List specifying the number of hidden units in each layer
of the neural network architecture. Defaults to [32, 16, 8].
Expand Down
23 changes: 18 additions & 5 deletions neurobayes/models/bnn_heteroskedastic_model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Callable
from typing import List, Callable, Optional, Dict
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
Expand All @@ -10,12 +10,25 @@

class VarianceModelHeteroskedasticBNN(HeteroskedasticBNN):
"""
Variance model based heteroskedastic Bayesian Neural Net
Variance model based heteroskedastic Bayesian Neural Network
Args:
target_dim (int): Dimensionality of the target variable.
variance_model (Callable): Function to compute the variance given inputs and parameters.
variance_model_prior (Callable): Function to sample prior parameters for the variance model.
hidden_dim (List[int], optional): List specifying the number of hidden units in each layer
of the neural network architecture. Defaults to [32, 16, 8].
conv_layers (List[int], optional): List specifying the number of filters in each
convolutional layer. If provided, enables a ConvNet architecture with max pooling
between each conv layer.
input_dim (int, optional): Input dimensionality (between 1 and 3). Required only for
ConvNet architecture.
activation (str, optional): Non-linear activation function to use. Defaults to 'tanh'.
"""
def __init__(self,
target_dim: int,
variance_model: Callable,
variance_model_prior: Callable,
variance_model: Optional[Callable[[jnp.ndarray, Dict[str, jnp.ndarray]], jnp.ndarray]],
variance_model_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]],
hidden_dim: List[int] = None,
conv_layers: List[int] = None,
input_dim: int = None,
Expand All @@ -32,7 +45,7 @@ def __init__(self,
self.variance_model_prior = variance_model_prior

def model(self, X: jnp.ndarray, y: jnp.ndarray = None, **kwargs) -> None:
"""BNN probabilistic model"""
"""Heteroskedastic BNN model"""

input_shape = X.shape[1:] if X.ndim > 2 else (X.shape[-1],)

Expand Down

0 comments on commit 97e4170

Please sign in to comment.