Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BUG] Fix when metric params is none #7

Merged
merged 2 commits into from
Jul 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,20 @@ This repository contains a project for visualizing warping distortions such as t

## Usage

### PyPi Package

Simply install the package as such `pip install elastic_warping_vis` and then use it as follows:
```Python
from elastic_warping_vis.utils import load_data
from elastic_warping_vis.draw_functions import draw_elastic_gif, draw_elastic

# ignore third output, used for dev version
X, y, _ = load_data(name="ECG200", split="train", znormalize=True)
draw_elastic(x, y, metric="dtw")
draw_elastic_gif(x, y, metric="dtw")

```

### Prerequisites

- Python >= 3.10
Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from aeon.distances import cost_matrix
from aeon.distances._alignment_paths import compute_min_return_path

from elastic_warp_vis.utils import alignment_path_to_plot
from elastic_warping_vis.utils import alignment_path_to_plot


def draw_elastic(
Expand Down Expand Up @@ -66,7 +66,10 @@ def draw_elastic(
_x = np.copy(x)
_y = np.copy(y)

_cost_matrix = cost_matrix(x=_x, y=_y, metric=metric, **metric_params)
if metric_params is not None:
_cost_matrix = cost_matrix(x=_x, y=_y, metric=metric, **metric_params)
else:
_cost_matrix = cost_matrix(x=_x, y=_y, metric=metric)
optimal_path = compute_min_return_path(cost_matrix=_cost_matrix)
path_dtw_x, path_dtw_y = alignment_path_to_plot(path_dtw=optimal_path)

Expand Down Expand Up @@ -261,7 +264,10 @@ def draw_elastic_gif(
_x = np.copy(x)
_y = np.copy(y)

_cost_matrix = cost_matrix(x=_x, y=_y, metric=metric, **metric_params)
if metric_params is not None:
_cost_matrix = cost_matrix(x=_x, y=_y, metric=metric, **metric_params)
else:
_cost_matrix = cost_matrix(x=_x, y=_y, metric=metric)
optimal_path = compute_min_return_path(cost_matrix=_cost_matrix)
path_dtw_x, path_dtw_y = alignment_path_to_plot(path_dtw=optimal_path)

Expand Down
File renamed without changes.
4 changes: 2 additions & 2 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
import hydra
from omegaconf import DictConfig, OmegaConf

from elastic_warp_vis.utils import create_directory, load_data
from elastic_warp_vis.draw_functions import draw_elastic, draw_elastic_gif
from elastic_warping_vis.utils import create_directory, load_data
from elastic_warping_vis.draw_functions import draw_elastic, draw_elastic_gif


@hydra.main(config_name="config_hydra.yaml", config_path="config")
Expand Down
Loading