Implement Linear Interpolation (Lerp) Imputation Method #459

Merged · 3 commits · Jul 17, 2024
2 changes: 2 additions & 0 deletions pypots/imputation/__init__.py
@@ -40,6 +40,7 @@
from .locf import LOCF
from .mean import Mean
from .median import Median
from .lerp import Lerp

__all__ = [
    # neural network imputation methods
@@ -76,4 +77,5 @@
"LOCF",
"Mean",
"Median",
"Lerp",
]
12 changes: 12 additions & 0 deletions pypots/imputation/lerp/__init__.py
@@ -0,0 +1,12 @@
"""
The package of the partially-observed time-series imputation method linear interpolation.
"""

# Created by Cole Sussmeier <colesussmeier@gmail.com>
# License: BSD-3-Clause

from .model import Lerp

__all__ = [
"Lerp",
]
160 changes: 160 additions & 0 deletions pypots/imputation/lerp/model.py
@@ -0,0 +1,160 @@
"""
The implementation of linear interpolation for the partially-observed time-series imputation task.
"""

# Created by Cole Sussmeier <colesussmeier@gmail.com>
# License: BSD-3-Clause

import warnings
from typing import Union, Optional

import h5py
import numpy as np
import torch

from ..base import BaseImputer


class Lerp(BaseImputer):
    """Linear interpolation (Lerp) imputation method.

    Lerp linearly interpolates missing values between the nearest non-missing values.
    If there are missing values at the beginning or end of a series, they are back-filled
    or forward-filled with the nearest non-missing value, respectively.
    If an entire series is empty, all NaN values are filled with zeros.
    """

    def __init__(
        self,
    ):
        super().__init__()

    def fit(
        self,
        train_set: Union[dict, str],
        val_set: Optional[Union[dict, str]] = None,
        file_type: str = "hdf5",
    ) -> None:
        """Train the imputer on the given data.

        Warnings
        --------
        The linear interpolation class does not need to run fit().
        Please run ``predict()`` directly.
        """
        warnings.warn(
            "The linear interpolation class has no parameters to train. "
            "Please run `predict()` directly."
        )

    def predict(
        self,
        test_set: Union[dict, str],
        file_type: str = "hdf5",
    ) -> dict:
        """Make predictions (imputations) for the input data.

        Parameters
        ----------
        test_set : dict or str
            The dataset for testing, should be a dictionary including the key 'X',
            or a path string locating a data file supported by PyPOTS (e.g. an h5 file).
            If it is a dict, X should be array-like with shape [n_samples, sequence length (n_steps), n_features],
            which is time-series data that may contain missing values.
            If it is a path string, the path should point to a data file, e.g. an h5 file, which contains
            key-value pairs like a dict and has to include the key 'X'.

        file_type :
            The type of the given file if test_set is a path string.

        Returns
        -------
        result_dict: dict
            Prediction results in a Python dictionary for the given samples.
            Only the keys of tasks supported by the model are included; for Lerp that is 'imputation'.
        """
        if isinstance(test_set, str):
            with h5py.File(test_set, "r") as f:
                X = f["X"][:]
        else:
            X = test_set["X"]

        # Convert list input to an ndarray before checking its shape.
        if isinstance(X, list):
            X = np.asarray(X)

        assert len(X.shape) == 3, (
            f"Input X should have 3 dimensions [n_samples, n_steps, n_features], "
            f"but the actual shape of X: {X.shape}"
        )

        def _interpolate_missing_values(X: np.ndarray):
            # Fill NaNs in one univariate series in place. np.interp needs at least
            # two observed points; otherwise the NaNs are set to zero.
            nans = np.isnan(X)
            nan_index = np.where(nans)[0]
            index = np.where(~nans)[0]
            if np.any(nans) and index.size > 1:
                X[nans] = np.interp(nan_index, index, X[~nans])
            elif np.any(nans):
                X[nans] = 0

        if isinstance(X, np.ndarray):

            # Work feature-wise: reshape into a batch of univariate series of length n_steps.
            trans_X = X.transpose((0, 2, 1))
            n_samples, n_features, n_steps = trans_X.shape
            reshaped_X = np.reshape(trans_X, (-1, n_steps))
            imputed_X = np.ones(reshaped_X.shape)

            for i, univariate_series in enumerate(reshaped_X):
                t = np.copy(univariate_series)
                _interpolate_missing_values(t)
                imputed_X[i] = t

            imputed_trans_X = np.reshape(imputed_X, (n_samples, n_features, -1))
            imputed_data = imputed_trans_X.transpose((0, 2, 1))

        elif isinstance(X, torch.Tensor):

            trans_X = X.permute(0, 2, 1)
            n_samples, n_features, n_steps = trans_X.shape
            reshaped_X = trans_X.reshape(-1, n_steps)
            imputed_X = torch.ones_like(reshaped_X)

            for i, univariate_series in enumerate(reshaped_X):
                t = univariate_series.clone().cpu().detach().numpy()
                _interpolate_missing_values(t)
                imputed_X[i] = torch.from_numpy(t)

            imputed_trans_X = imputed_X.reshape(n_samples, n_features, -1)
            imputed_data = imputed_trans_X.permute(0, 2, 1)

        else:
            raise ValueError(f"X must be a numpy.ndarray or a torch.Tensor, but got {type(X)}")

        result_dict = {
            "imputation": imputed_data,
        }
        return result_dict

    def impute(
        self,
        test_set: Union[dict, str],
        file_type: str = "hdf5",
    ) -> np.ndarray:
        """Impute missing values in the given data.

        Parameters
        ----------
        test_set :
            The data samples for testing, should be array-like with shape [n_samples, sequence length (n_steps),
            n_features], or a path string locating a data file, e.g. an h5 file.

        file_type :
            The type of the given file if test_set is a path string.

        Returns
        -------
        array-like, shape [n_samples, sequence length (n_steps), n_features],
            Imputed data.
        """

        result_dict = self.predict(test_set, file_type=file_type)
        return result_dict["imputation"]
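
A quick usage sketch of the class added above (not part of this PR; it assumes PyPOTS with this change installed and uses a made-up toy array). The interior NaN is linearly interpolated, and the trailing NaNs are forward-filled with the last observed value, matching the class docstring:

import numpy as np
from pypots.imputation import Lerp

# Toy input: 1 sample, 5 time steps, 1 feature, with NaNs in the middle and at the end.
X = np.array([[[1.0], [np.nan], [3.0], [np.nan], [np.nan]]])

lerp = Lerp()
imputed = lerp.impute({"X": X})
print(imputed[:, :, 0])  # [[1. 2. 3. 3. 3.]]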
74 changes: 74 additions & 0 deletions tests/imputation/lerp.py
@@ -0,0 +1,74 @@
"""
Test cases for the Linear Interpolation (Lerp) imputation method.
"""

# Created by Cole Sussmeier <colesussmeier@gmail.com>
# License: BSD-3-Clause


import unittest

import numpy as np
import pytest
import torch

from pypots.imputation import Lerp
from pypots.utils.logging import logger
from pypots.utils.metrics import calc_mse
from tests.global_test_config import (
    DATA,
    TEST_SET,
    GENERAL_H5_TRAIN_SET_PATH,
    GENERAL_H5_VAL_SET_PATH,
    GENERAL_H5_TEST_SET_PATH,
)


class TestLerp(unittest.TestCase):
    logger.info("Running tests for an imputation model Lerp...")
    lerp = Lerp()

    @pytest.mark.xdist_group(name="imputation-lerp")
    def test_0_impute(self):
        # if input data is numpy ndarray
        test_X_imputed = self.lerp.predict(TEST_SET)["imputation"]
        assert not np.isnan(
            test_X_imputed
        ).any(), "Output still has missing values after running impute()."
        test_MSE = calc_mse(
            test_X_imputed, DATA["test_X_ori"], DATA["test_X_indicating_mask"]
        )
        logger.info(f"Lerp test_MSE: {test_MSE}")

        # if input data is torch tensor
        X = torch.from_numpy(np.copy(TEST_SET["X"]))
        test_X_ori = torch.from_numpy(np.copy(DATA["test_X_ori"]))
        test_X_indicating_mask = torch.from_numpy(
            np.copy(DATA["test_X_indicating_mask"])
        )

        test_X_imputed = self.lerp.predict({"X": X})["imputation"]
        assert not torch.isnan(
            test_X_imputed
        ).any(), "Output still has missing values after running impute()."
        test_MSE = calc_mse(test_X_imputed, test_X_ori, test_X_indicating_mask)
        logger.info(f"Lerp test_MSE: {test_MSE}")

    @pytest.mark.xdist_group(name="imputation-lerp")
    def test_4_lazy_loading(self):
        self.lerp.fit(GENERAL_H5_TRAIN_SET_PATH, GENERAL_H5_VAL_SET_PATH)
        imputation_results = self.lerp.predict(GENERAL_H5_TEST_SET_PATH)
        assert not np.isnan(
            imputation_results["imputation"]
        ).any(), "Output still has missing values after running impute()."

        test_MSE = calc_mse(
            imputation_results["imputation"],
            DATA["test_X_ori"],
            DATA["test_X_indicating_mask"],
        )
        logger.info(f"Lazy-loading Lerp test_MSE: {test_MSE}")


if __name__ == "__main__":
    unittest.main()
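
For reference, the edge handling these tests exercise comes from ``np.interp`` inside ``_interpolate_missing_values``: query points outside the observed range are clamped to the nearest observed value. A small standalone check (illustrative only, not part of this PR, using made-up values):

import numpy as np

series = np.array([np.nan, np.nan, 2.0, np.nan, 6.0, np.nan])
nans = np.isnan(series)
observed = np.where(~nans)[0]
# Leading NaNs take the first observed value, trailing NaNs the last one,
# and interior NaNs are linearly interpolated.
series[nans] = np.interp(np.where(nans)[0], observed, series[~nans])
print(series)  # [2. 2. 2. 4. 6. 6.]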