diff --git a/README.md b/README.md index 987f0b6..0f91375 100644 --- a/README.md +++ b/README.md @@ -16,6 +16,20 @@ This repository contains a project for visualizing warping distortions such as t ## Usage +### PyPi Package + +Simply install the package as such `pip install elastic_warping_vis` and then use it as follows: +```Python +from elastic_warping_vis.utils import load_data +from elastic_warping_vis.draw_functions import draw_elastic_gif, draw_elastic + +# ignore third output, used for dev version +X, y, _ = load_data(name="ECG200", split="train", znormalize=True) +draw_elastic(x, y, metric="dtw") +draw_elastic_gif(x, y, metric="dtw") + +``` + ### Prerequisites - Python >= 3.10 diff --git a/elastic_warp_vis/__init__.py b/elastic_warping_vis/__init__.py similarity index 100% rename from elastic_warp_vis/__init__.py rename to elastic_warping_vis/__init__.py diff --git a/elastic_warp_vis/draw_functions.py b/elastic_warping_vis/draw_functions.py similarity index 97% rename from elastic_warp_vis/draw_functions.py rename to elastic_warping_vis/draw_functions.py index 77404f8..ad4a90d 100644 --- a/elastic_warp_vis/draw_functions.py +++ b/elastic_warping_vis/draw_functions.py @@ -13,7 +13,7 @@ from aeon.distances import cost_matrix from aeon.distances._alignment_paths import compute_min_return_path -from elastic_warp_vis.utils import alignment_path_to_plot +from elastic_warping_vis.utils import alignment_path_to_plot def draw_elastic( @@ -66,7 +66,10 @@ def draw_elastic( _x = np.copy(x) _y = np.copy(y) - _cost_matrix = cost_matrix(x=_x, y=_y, metric=metric, **metric_params) + if metric_params is not None: + _cost_matrix = cost_matrix(x=_x, y=_y, metric=metric, **metric_params) + else: + _cost_matrix = cost_matrix(x=_x, y=_y, metric=metric) optimal_path = compute_min_return_path(cost_matrix=_cost_matrix) path_dtw_x, path_dtw_y = alignment_path_to_plot(path_dtw=optimal_path) @@ -261,7 +264,10 @@ def draw_elastic_gif( _x = np.copy(x) _y = np.copy(y) - _cost_matrix = cost_matrix(x=_x, y=_y, metric=metric, **metric_params) + if metric_params is not None: + _cost_matrix = cost_matrix(x=_x, y=_y, metric=metric, **metric_params) + else: + _cost_matrix = cost_matrix(x=_x, y=_y, metric=metric) optimal_path = compute_min_return_path(cost_matrix=_cost_matrix) path_dtw_x, path_dtw_y = alignment_path_to_plot(path_dtw=optimal_path) diff --git a/elastic_warp_vis/utils.py b/elastic_warping_vis/utils.py similarity index 100% rename from elastic_warp_vis/utils.py rename to elastic_warping_vis/utils.py diff --git a/main.py b/main.py index e67d76a..930d982 100644 --- a/main.py +++ b/main.py @@ -4,8 +4,8 @@ import hydra from omegaconf import DictConfig, OmegaConf -from elastic_warp_vis.utils import create_directory, load_data -from elastic_warp_vis.draw_functions import draw_elastic, draw_elastic_gif +from elastic_warping_vis.utils import create_directory, load_data +from elastic_warping_vis.draw_functions import draw_elastic, draw_elastic_gif @hydra.main(config_name="config_hydra.yaml", config_path="config")