From 65c607b632de2ed3ecc1bc9788d94ecf1e32e3ea Mon Sep 17 00:00:00 2001 From: Zihao Xu Date: Wed, 12 Jun 2024 20:18:31 +0800 Subject: [PATCH 1/3] add new interp --- appletree/config.py | 19 ++++++++++-- appletree/interpolation.py | 59 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 75 insertions(+), 3 deletions(-) diff --git a/appletree/config.py b/appletree/config.py index 3cecd642..b0bff12e 100644 --- a/appletree/config.py +++ b/appletree/config.py @@ -148,6 +148,9 @@ class Map(Config): When using log-binning, we will first convert the positions to log space. """ + def __init__(self, method="IDW", **kwargs): + super().__init__(**kwargs) + self.method = method def build(self, llh_name: Optional[str] = None): """Cache the map to jnp.array.""" @@ -248,9 +251,19 @@ def build_regbin(self, data): if len(self.coordinate_lowers) == 1: setattr(self, "interpolator", interpolation.map_interpolator_regular_binning_1d) elif len(self.coordinate_lowers) == 2: - setattr(self, "interpolator", interpolation.map_interpolator_regular_binning_2d) - elif len(self.coordinate_lowers) == 3: - setattr(self, "interpolator", interpolation.map_interpolator_regular_binning_3d) + if self.method == "IDW": + setattr(self, "interpolator", interpolation.map_interpolator_regular_binning_2d) + elif self.method == "NN": + setattr(self, "interpolator", interpolation.map_interpolator_regular_binning_nearest_neighbor_2d) + else: + raise ValueError(f"Unknown method {self.method} for 2D regular binning.") + elif len(self.coordinate_lowers) == 3 and self.method == "IDW": + if self.method == "IDW": + setattr(self, "interpolator", interpolation.map_interpolator_regular_binning_3d) + elif self.method == "NN": + setattr(self, "interpolator", interpolation.map_interpolator_regular_binning_nearest_neighbor_3d) + else: + raise ValueError(f"Unknown method {self.method} for 3D regular binning.") if self.coordinate_type == "log_regbin": if jnp.any(self.coordinate_lowers <= 0) or jnp.any(self.coordinate_uppers <= 0): raise ValueError( diff --git a/appletree/interpolation.py b/appletree/interpolation.py index 71db3612..f9474fd9 100644 --- a/appletree/interpolation.py +++ b/appletree/interpolation.py @@ -257,3 +257,62 @@ def map_interpolator_regular_binning_3d(pos, ref_pos_lowers, ref_pos_uppers, ref ) return val + + +@jit +def find_nearest_indices(x, y): + x = x[:, jnp.newaxis] + differences = jnp.abs(x - y) + indices = jnp.argmin(differences, axis=1) + return indices + + + +@export +@jit +def map_interpolator_regular_binning_nearest_neighbor_2d(pos, ref_pos_lowers, ref_pos_uppers, ref_val): + """Nearest neighbor 2D interpolation. A uniform mesh grid binning is assumed. + + Args: + pos: array with shape (N, 2), positions at which the interp is calculated. + ref_pos_lowers: array with shape (2,), the lower edges of the binning on each dimension. + ref_pos_uppers: array with shape (2,), the upper edges of the binning on each dimension. + ref_val: array with shape (M1, M2), map values. + + """ + n0, n1 = ref_val.shape + + bins0 = jnp.linspace(ref_pos_lowers[0], ref_pos_uppers[0], n0) + ind0 = find_nearest_indices(pos[:, 0], bins0) + + bins1 = jnp.linspace(ref_pos_lowers[1], ref_pos_uppers[1], n1) + ind1 = find_nearest_indices(pos[:, 1], bins1) + + val = ref_val[ind0, ind1] + return val + +@export +@jit +def map_interpolator_regular_binning_nearest_neighbor_3d(pos, ref_pos_lowers, ref_pos_uppers, ref_val): + """Nearest neighbor 3D interpolation. A uniform mesh grid binning is assumed. + + Args: + pos: array with shape (N, 3), positions at which the interp is calculated. + ref_pos_lowers: array with shape (3,), the lower edges of the binning on each dimension. + ref_pos_uppers: array with shape (3,), the upper edges of the binning on each dimension. + ref_val: array with shape (M1, M2, M3), map values. + + """ + n0, n1, n2 = ref_val.shape + + bins0 = jnp.linspace(ref_pos_lowers[0], ref_pos_uppers[0], n0) + ind0 = find_nearest_indices(pos[:, 0], bins0) + + bins1 = jnp.linspace(ref_pos_lowers[1], ref_pos_uppers[1], n1) + ind1 = find_nearest_indices(pos[:, 1], bins1) + + bins2 = jnp.linspace(ref_pos_lowers[2], ref_pos_uppers[2], n2) + ind2 = find_nearest_indices(pos[:, 2], bins2) + + val = ref_val[ind0, ind1, ind2] + return val From e48cbb8e876684bf1e022bd7d37727c405558ab3 Mon Sep 17 00:00:00 2001 From: Zihao Xu Date: Wed, 12 Jun 2024 20:20:32 +0800 Subject: [PATCH 2/3] update sigma map as well --- appletree/config.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/appletree/config.py b/appletree/config.py index b0bff12e..7f294f6e 100644 --- a/appletree/config.py +++ b/appletree/config.py @@ -315,6 +315,9 @@ class SigmaMap(Config): in Component.needed_parameters. """ + def __init__(self, method="IDW", **kwargs): + super().__init__(**kwargs) + self.method = method def build(self, llh_name: Optional[str] = None): """Read maps.""" @@ -343,7 +346,7 @@ def build(self, llh_name: Optional[str] = None): ) # If only one file is given, then use the same file for all sigmas default = _configs_default - maps[sigma] = Map(name=self.name + f"_{sigma}", default=default) + maps[sigma] = Map(method=self.method, name=self.name + f"_{sigma}", default=default) setattr(self, sigma, maps[sigma]) From 7ed47f9df00449f46a31f2fd1319d9caad27c966 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 12 Jun 2024 13:04:00 +0000 Subject: [PATCH 3/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- appletree/config.py | 14 ++++++++++++-- appletree/interpolation.py | 10 +++++++--- 2 files changed, 19 insertions(+), 5 deletions(-) diff --git a/appletree/config.py b/appletree/config.py index 7f294f6e..2fc76145 100644 --- a/appletree/config.py +++ b/appletree/config.py @@ -148,6 +148,7 @@ class Map(Config): When using log-binning, we will first convert the positions to log space. """ + def __init__(self, method="IDW", **kwargs): super().__init__(**kwargs) self.method = method @@ -254,14 +255,22 @@ def build_regbin(self, data): if self.method == "IDW": setattr(self, "interpolator", interpolation.map_interpolator_regular_binning_2d) elif self.method == "NN": - setattr(self, "interpolator", interpolation.map_interpolator_regular_binning_nearest_neighbor_2d) + setattr( + self, + "interpolator", + interpolation.map_interpolator_regular_binning_nearest_neighbor_2d, + ) else: raise ValueError(f"Unknown method {self.method} for 2D regular binning.") elif len(self.coordinate_lowers) == 3 and self.method == "IDW": if self.method == "IDW": setattr(self, "interpolator", interpolation.map_interpolator_regular_binning_3d) elif self.method == "NN": - setattr(self, "interpolator", interpolation.map_interpolator_regular_binning_nearest_neighbor_3d) + setattr( + self, + "interpolator", + interpolation.map_interpolator_regular_binning_nearest_neighbor_3d, + ) else: raise ValueError(f"Unknown method {self.method} for 3D regular binning.") if self.coordinate_type == "log_regbin": @@ -315,6 +324,7 @@ class SigmaMap(Config): in Component.needed_parameters. """ + def __init__(self, method="IDW", **kwargs): super().__init__(**kwargs) self.method = method diff --git a/appletree/interpolation.py b/appletree/interpolation.py index f9474fd9..d11103ce 100644 --- a/appletree/interpolation.py +++ b/appletree/interpolation.py @@ -267,10 +267,11 @@ def find_nearest_indices(x, y): return indices - @export @jit -def map_interpolator_regular_binning_nearest_neighbor_2d(pos, ref_pos_lowers, ref_pos_uppers, ref_val): +def map_interpolator_regular_binning_nearest_neighbor_2d( + pos, ref_pos_lowers, ref_pos_uppers, ref_val +): """Nearest neighbor 2D interpolation. A uniform mesh grid binning is assumed. Args: @@ -291,9 +292,12 @@ def map_interpolator_regular_binning_nearest_neighbor_2d(pos, ref_pos_lowers, re val = ref_val[ind0, ind1] return val + @export @jit -def map_interpolator_regular_binning_nearest_neighbor_3d(pos, ref_pos_lowers, ref_pos_uppers, ref_val): +def map_interpolator_regular_binning_nearest_neighbor_3d( + pos, ref_pos_lowers, ref_pos_uppers, ref_val +): """Nearest neighbor 3D interpolation. A uniform mesh grid binning is assumed. Args: