Skip to content

Commit

Permalink
fixing mypy hint for identify_distribution_indexes
Browse files Browse the repository at this point in the history
  • Loading branch information
arik-shurygin committed Dec 13, 2024
1 parent c151df6 commit 49078a2
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions src/dynode/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def sample_if_distribution(parameters):

def identify_distribution_indexes(
parameters: dict[str, Any],
) -> dict[str, dict[str, str | tuple[int] | None]]:
) -> dict[str, dict[str, str | tuple | None]]:
"""
A inverse of the `sample_if_distribution()` which allows users to identify the locations
of numpyro samples. Given a dictionary of parameters, identifies which parameters
Expand All @@ -115,17 +115,19 @@ def identify_distribution_indexes(
Returns
------------
a dictionary mapping the sample name the parameter name within `parameters`.
`dict[str, dict[str, str | tuple[int] | None]]`
a dictionary mapping the sample name to the parameter name within `parameters`.
(if the sampled parameter is within a larger list, returns a tuple of indexes as well, otherwise None)
key: str -> sampled parameter name as produced by `sample_if_distribution()`
value: `dict[str:str, str:tuple]` -> "sample_name" = sample name within input `parameters`
value: `dict[str, str | tuple | None]` -> "sample_name" = sample name within input `parameters`
-> "sample_idx" = sample index if within list, else None
"""

def get_index(indexes):
return tuple(indexes)

index_locations = {}
index_locations: dict[str, dict[str, str | tuple | None]] = {}
for key, param in parameters.items():
# if distribution, it does not have an index, so None
if issubclass(type(param), Dist.Distribution):
Expand Down

0 comments on commit 49078a2

Please sign in to comment.