Add ImputeFormer, fix RevIN, and update docs #454

Merged: 7 commits, Jul 2, 2024
9 changes: 5 additions & 4 deletions README.md
@@ -112,6 +112,7 @@ The paper references and links are all listed at the bottom of this file.

| **Type** | **Algo** | **IMPU** | **FORE** | **CLAS** | **CLUS** | **ANOD** | **Year - Venue** |
|:--------------|:----------------------------|:--------:|:--------:|:--------:|:--------:|:--------:|:-------------------|
| Neural Net | ImputeFormer🧑‍🔧[^34] | ✅ | | | | | `2024 - KDD` |
| Neural Net | iTransformer🧑‍🔧[^24] | ✅ | | | | | `2024 - ICLR` |
| Neural Net | SAITS[^1] | ✅ | | | | | `2023 - ESWA` |
| Neural Net | FreTS🧑‍🔧[^23] | ✅ | | | | | `2023 - NeurIPS` |
@@ -161,7 +162,7 @@ And what else? Please read on ;-)
👈 Time series datasets are taken as coffee beans at PyPOTS, and POTS datasets are incomplete coffee beans with missing parts that have their own meanings.
To make various public time-series datasets readily available to users,
<i>Time Series Data Beans (TSDB)</i> is created to make loading time-series datasets super easy!
- Visit [TSDB](https://github.com/WenjieDu/TSDB) right now to know more about this handy tool 🛠, and it now supports a total of 170 open-source datasets!
+ Visit [TSDB](https://github.com/WenjieDu/TSDB) right now to know more about this handy tool 🛠, and it now supports a total of 172 open-source datasets!

<a href="https://github.com/WenjieDu/PyGrinder">
<img src="https://pypots.com/figs/pypots_logos/PyGrinder/logo_FFBG.svg" align="right" width="140" alt="PyGrinder logo"/>
@@ -293,9 +294,9 @@ year={2023},
}
```
or
- > Wenjie Du. (2023).
+ > Wenjie Du.
  > PyPOTS: a Python toolbox for data mining on Partially-Observed Time Series.
- > arXiv, abs/2305.18811. https://arxiv.org/abs/2305.18811
+ > arXiv, abs/2305.18811, 2023.


## ❖ Contribution
@@ -380,7 +381,7 @@ PyPOTS community is open, transparent, and surely friendly. Let's work together
[^31]: Kim, T., Kim, J., Tae, Y., Park, C., Choi, J. H., & Choo, J. (2022). [Reversible Instance Normalization for Accurate Time-Series Forecasting against Distribution Shift](https://openreview.net/forum?id=cGDAkQo1C0p). *ICLR 2022*.
[^32]: Kitaev, N., Kaiser, Ł., & Levskaya, A. (2020). [Reformer: The Efficient Transformer](https://openreview.net/forum?id=0EXmFzUn5I). *ICLR 2020*.
[^33]: Cao, D., Wang, Y., Duan, J., Zhang, C., Zhu, X., Huang, C., Tong, Y., Xu, B., Bai, J., Tong, J., & Zhang, Q. (2020). [Spectral Temporal Graph Neural Network for Multivariate Time-series Forecasting](https://proceedings.neurips.cc/paper/2020/hash/cdf6581cb7aca4b7e19ef136c6e601a5-Abstract.html). *NeurIPS 2020*.

[^34]: Nie, T., Qin, G., Ma, W., Mei, Y., & Sun, J. (2024). [ImputeFormer: Low Rankness-Induced Transformers for Generalizable Spatiotemporal Imputation](https://arxiv.org/abs/2312.01728). *KDD 2024*.

<details>
<summary>🏠 Visits</summary>
4 changes: 3 additions & 1 deletion README_zh.md
@@ -98,6 +98,7 @@ PyPOTS当前支持多变量POTS数据的插补,预测,分类,聚类以及

| **类型** | **算法** | **插补** | **预测** | **分类** | **聚类** | **异常检测** | **年份 - 刊物** |
|:--------------|:----------------------------|:------:|:------:|:------:|:------:|:--------:|:-----------------|
| Neural Net | ImputeFormer🧑‍🔧[^34] | ✅ | | | | | `2024 - KDD` |
| Neural Net | iTransformer🧑‍🔧[^24] | ✅ | | | | | `2024 - ICLR` |
| Neural Net | SAITS[^1] | ✅ | | | | | `2023 - ESWA` |
| Neural Net | FreTS🧑‍🔧[^23] | ✅ | | | | | `2023 - NeurIPS` |
@@ -145,7 +146,7 @@ PyPOTS当前支持多变量POTS数据的插补,预测,分类,聚类以及

👈 在PyPOTS中,数据可以被看作是咖啡豆,而写的携带缺失值的POTS数据则是不完整的咖啡豆。
为了让用户能够轻松使用各种开源的时间序列数据集,我们创建了开源时间序列数据集的仓库 Time Series Data Beans (TSDB)(可以将其视为咖啡豆仓库),
- TSDB让加载开源时序数据集变得超级简单!访问 [TSDB](https://github.com/WenjieDu/TSDB),了解更多关于TSDB的信息,目前总共支持170个开源数据集
+ TSDB让加载开源时序数据集变得超级简单!访问 [TSDB](https://github.com/WenjieDu/TSDB),了解更多关于TSDB的信息,目前总共支持172个开源数据集

<a href="https://github.com/WenjieDu/PyGrinder">
<img src="https://pypots.com/figs/pypots_logos/PyGrinder/logo_FFBG.svg" align="right" width="140" alt="PyGrinder logo"/>
@@ -351,6 +352,7 @@ PyPOTS社区是一个开放、透明、友好的社区,让我们共同努力
[^31]: Kim, T., Kim, J., Tae, Y., Park, C., Choi, J. H., & Choo, J. (2022). [Reversible Instance Normalization for Accurate Time-Series Forecasting against Distribution Shift](https://openreview.net/forum?id=cGDAkQo1C0p). *ICLR 2022*.
[^32]: Kitaev, N., Kaiser, Ł., & Levskaya, A. (2020). [Reformer: The Efficient Transformer](https://openreview.net/forum?id=0EXmFzUn5I). *ICLR 2020*.
[^33]: Cao, D., Wang, Y., Duan, J., Zhang, C., Zhu, X., Huang, C., Tong, Y., Xu, B., Bai, J., Tong, J., & Zhang, Q. (2020). [Spectral Temporal Graph Neural Network for Multivariate Time-series Forecasting](https://proceedings.neurips.cc/paper/2020/hash/cdf6581cb7aca4b7e19ef136c6e601a5-Abstract.html). *NeurIPS 2020*.
[^34]: Nie, T., Qin, G., Ma, W., Mei, Y., & Sun, J. (2024). [ImputeFormer: Low Rankness-Induced Transformers for Generalizable Spatiotemporal Imputation](https://arxiv.org/abs/2312.01728). *KDD 2024*.


<details>
14 changes: 5 additions & 9 deletions docs/index.rst
@@ -12,7 +12,7 @@ Welcome to PyPOTS docs!

**A Python Toolbox for Machine Learning on Partially-Observed Time Series**

- .. image:: https://img.shields.io/badge/Python-v3.7+-E97040?logo=python&logoColor=white
+ .. image:: https://img.shields.io/badge/Python-v3.8+-E97040?logo=python&logoColor=white
     :alt: Python version
     :target: https://docs.pypots.com/en/latest/install.html#reasons-of-version-limitations-on-dependencies

@@ -220,7 +220,7 @@ And what else? Please read on ;-)
👈 Time series datasets are taken as coffee beans at PyPOTS, and POTS datasets are incomplete coffee beans with missing parts that have their own meanings.
To make various public time-series datasets readily available to users,
*Time Series Data Beans (TSDB)* is created to make loading time-series datasets super easy!
- Visit `TSDB <https://github.com/WenjieDu/TSDB>`_ right now to know more about this handy tool 🛠, and it now supports a total of 170 open-source datasets!
+ Visit `TSDB <https://github.com/WenjieDu/TSDB>`_ right now to know more about this handy tool 🛠, and it now supports a total of 172 open-source datasets!

.. image:: https://pypots.com/figs/pypots_logos/PyGrinder/logo_FFBG.svg
   :width: 150
@@ -298,21 +298,17 @@ please cite it as below and 🌟star `PyPOTS repository <https://github.com/Wenj
@article{du2023pypots,
title={{PyPOTS: a Python toolbox for data mining on Partially-Observed Time Series}},
author={Wenjie Du},
journal={arXiv preprint arXiv:2305.18811},
year={2023},
eprint={2305.18811},
archivePrefix={arXiv},
primaryClass={cs.LG},
url={https://arxiv.org/abs/2305.18811},
doi={10.48550/arXiv.2305.18811},
}

or

..

- Wenjie Du. (2023).
+ Wenjie Du.
  PyPOTS: a Python toolbox for data mining on Partially-Observed Time Series.
- arXiv, abs/2305.18811. https://doi.org/10.48550/arXiv.2305.18811
+ arXiv, abs/2305.18811, 2023.


❖ Contribution
11 changes: 11 additions & 0 deletions docs/references.bib
@@ -745,3 +745,14 @@ @inproceedings{xu2024fits
year={2024},
url={https://openreview.net/forum?id=bWcnvZ3qMb}
}

@inproceedings{nie2024imputeformer,
title={ImputeFormer: Low Rankness-Induced Transformers for Generalizable Spatiotemporal Imputation},
author={Nie, Tong and Qin, Guoyang and Ma, Wei and Mei, Yuewen and Sun, Jian},
booktitle = {Proceedings of the 30th ACM SIGKDD Conference on Knowledge Discovery and Data Mining},
publisher = {Association for Computing Machinery},
year={2024},
series = {KDD '24},
doi = {10.1145/3637528.3671751},
url = {https://doi.org/10.1145/3637528.3671751},
}
2 changes: 2 additions & 0 deletions pypots/classification/grud/data.py
@@ -63,6 +63,8 @@ def __init__(
self.empirical_mean = torch.sum(
self.missing_mask * self.X, dim=[0, 1]
) / torch.sum(self.missing_mask, dim=[0, 1])
# replace NaN means with 0, in case some features have no observations at all (0 / 0 above)
self.empirical_mean = torch.nan_to_num(self.empirical_mean, nan=0.0)

def _fetch_data_from_array(self, idx: int) -> Iterable:
"""Fetch data according to index.
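
The `nan_to_num` guard added in this hunk matters when a feature has no observed values anywhere in the dataset: the masked sum and the observation count are both 0, so the division yields NaN. The following is a minimal sketch on a toy tensor; the data, shapes, and the `has_nan_before` variable are made up for illustration, and only the mean computation mirrors the lines in the diff.

```python
import torch

# Toy data: 2 samples, 3 steps, 2 features. Feature 1 is never observed.
nan = float("nan")
X = torch.tensor(
    [
        [[1.0, nan], [2.0, nan], [nan, nan]],
        [[3.0, nan], [nan, nan], [6.0, nan]],
    ]
)
missing_mask = (~torch.isnan(X)).float()  # 1 = observed, 0 = missing
X = torch.nan_to_num(X, nan=0.0)  # zero-fill so the masked sum is defined

# Per-feature mean over observed entries only, as in the diff.
empirical_mean = torch.sum(missing_mask * X, dim=[0, 1]) / torch.sum(
    missing_mask, dim=[0, 1]
)
# Feature 1 has 0 observations, so its mean is 0 / 0 = NaN at this point.
has_nan_before = bool(torch.isnan(empirical_mean).any())

# The guard from the diff: replace NaN means with 0.
empirical_mean = torch.nan_to_num(empirical_mean, nan=0.0)
```

Without the guard, the NaN mean would propagate into whatever consumes `empirical_mean` downstream for that feature.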
2 changes: 2 additions & 0 deletions pypots/imputation/__init__.py
@@ -34,6 +34,7 @@
from .tide import TiDE
from .grud import GRUD
from .stemgnn import StemGNN
from .imputeformer import ImputeFormer

# naive imputation methods
from .locf import LOCF
@@ -70,6 +71,7 @@
"TiDE",
"GRUD",
"StemGNN",
"ImputeFormer",
# naive imputation methods
"LOCF",
"Mean",
20 changes: 20 additions & 0 deletions pypots/imputation/imputeformer/__init__.py
@@ -0,0 +1,20 @@
"""
The package of the partially-observed time-series imputation model ImputeFormer.

Refer to the paper
`Tong Nie, Guoyang Qin, Wei Ma, Yuewen Mei, Jian Sun.
"ImputeFormer: Low Rankness-Induced Transformers for Generalizable Spatiotemporal Imputation"
KDD 2024.
<https://doi.org/10.48550/arXiv.2312.01728>`_

"""

# Created by Tong Nie <nietong@tongji.edu.cn> and Wenjie Du <wenjay.du@gmail.com>
# License: BSD-3-Clause


from .model import ImputeFormer

__all__ = [
    "ImputeFormer",
]
151 changes: 151 additions & 0 deletions pypots/imputation/imputeformer/core.py
@@ -0,0 +1,151 @@
"""
The core wrapper assembles the submodules of the ImputeFormer imputation model
and takes over the forward process of the algorithm.
"""

# Created by Tong Nie <nietong@tongji.edu.cn> and Wenjie Du <wenjay.du@gmail.com>
# License: BSD-3-Clause

import torch
import torch.nn as nn

from ...nn.modules.imputeformer import (
    EmbeddedAttentionLayer,
    ProjectedAttentionLayer,
    MLP,
)
from ...nn.modules.saits import SaitsLoss


class _ImputeFormer(nn.Module):
    """
    Spatiotemporal Imputation Transformer induced by low-rank factorization, KDD'24.
    Note:
        This is a simplified implementation under the SAITS framework (ORT+MIT).
        The timestamp encoding is also removed for ease of implementation.
    """

    def __init__(
        self,
        n_steps: int,
        n_features: int,
        n_layers: int,
        d_input_embed: int,
        d_learnable_embed: int,
        d_proj: int,
        d_ffn: int,
        n_temporal_heads: int,
        dropout: float = 0.0,
        input_dim: int = 1,
        output_dim: int = 1,
        ORT_weight: float = 1,
        MIT_weight: float = 1,
    ):
        super().__init__()

        self.n_nodes = n_features
        self.in_steps = n_steps
        self.out_steps = n_steps
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.input_embedding_dim = d_input_embed
        self.learnable_embedding_dim = d_learnable_embed
        self.model_dim = d_input_embed + d_learnable_embed

        self.n_temporal_heads = n_temporal_heads
        self.num_layers = n_layers
        self.input_proj = nn.Linear(input_dim, self.input_embedding_dim)
        self.d_proj = d_proj
        self.d_ffn = d_ffn

        self.learnable_embedding = nn.init.xavier_uniform_(
            nn.Parameter(
                torch.empty(self.in_steps, self.n_nodes, self.learnable_embedding_dim)
            )
        )

        self.readout = MLP(self.model_dim, self.model_dim, output_dim, n_layers=2)

        self.attn_layers_t = nn.ModuleList(
            [
                ProjectedAttentionLayer(
                    self.n_nodes,
                    self.d_proj,
                    self.model_dim,
                    self.n_temporal_heads,
                    self.model_dim,
                    dropout,
                )
                for _ in range(self.num_layers)
            ]
        )

        self.attn_layers_s = nn.ModuleList(
            [
                EmbeddedAttentionLayer(
                    self.model_dim,
                    self.learnable_embedding_dim,
                    self.d_ffn,
                )
                for _ in range(self.num_layers)
            ]
        )

        # apply SAITS loss function to Transformer on the imputation task
        self.saits_loss_func = SaitsLoss(ORT_weight, MIT_weight)

    def forward(self, inputs: dict, training: bool = True) -> dict:
        x, missing_mask = inputs["X"], inputs["missing_mask"]

        # x: (batch_size, in_steps, num_nodes)
        # Note that ImputeFormer is designed for spatiotemporal data in the format [B, S, N, C],
        # where N is the number of nodes and C is an additional feature dimension.
        # We simply add an extra axis here to match that layout.
        x = x.unsqueeze(-1)  # [b s n c]
        missing_mask = missing_mask.unsqueeze(-1)  # [b s n c]
        batch_size = x.shape[0]
        # Zero out missing values
        x = x * missing_mask
        x = self.input_proj(x)  # (batch_size, in_steps, num_nodes, input_embedding_dim)

        # Learnable node embedding
        node_emb = self.learnable_embedding.expand(
            batch_size, *self.learnable_embedding.shape
        )
        x = torch.cat(
            [x, node_emb], dim=-1
        )  # (batch_size, in_steps, num_nodes, model_dim)

        # Spatial and temporal processing with customized attention layers
        x = x.permute(0, 2, 1, 3)  # [b n s c]
        for att_t, att_s in zip(self.attn_layers_t, self.attn_layers_s):
            x = att_t(x)
            x = att_s(x, self.learnable_embedding, dim=1)

        # Readout
        x = x.permute(0, 2, 1, 3)  # [b s n c]
        reconstruction = self.readout(x)
        reconstruction = reconstruction.squeeze(-1)  # [b s n]
        missing_mask = missing_mask.squeeze(-1)  # [b s n]

        # Below is the SAITS processing pipeline:
        # keep the observed values from X and fill the missing ones with the reconstruction
        imputed_data = missing_mask * inputs["X"] + (1 - missing_mask) * reconstruction

        # assemble the results into a dictionary for return
        results = {
            "imputed_data": imputed_data,
        }

        # if in training mode, return results with losses
        if training:
            X_ori, indicating_mask = inputs["X_ori"], inputs["indicating_mask"]
            loss, ORT_loss, MIT_loss = self.saits_loss_func(
                reconstruction, X_ori, missing_mask, indicating_mask
            )
            results["ORT_loss"] = ORT_loss
            results["MIT_loss"] = MIT_loss
            # `loss` is always the term that is back-propagated to update the model
            results["loss"] = loss

        return results
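
The tensor bookkeeping in `forward` is easy to lose track of, so here is a minimal shape trace using only plain PyTorch ops. It is a sketch under stated assumptions, not the model itself: the attention layers are skipped entirely (they are assumed to preserve the `[b, n, s, model_dim]` shape), the `MLP` readout is replaced by a single `nn.Linear` stand-in, and all sizes are arbitrary toy values.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch_size, n_steps, n_features = 4, 6, 3
d_input_embed, d_learnable_embed = 8, 5
model_dim = d_input_embed + d_learnable_embed

# Toy inputs: missing entries are stored as 0 in X and flagged by the mask.
missing_mask = (torch.rand(batch_size, n_steps, n_features) > 0.3).float()
X = torch.randn(batch_size, n_steps, n_features) * missing_mask

input_proj = nn.Linear(1, d_input_embed)
learnable_embedding = nn.init.xavier_uniform_(
    nn.Parameter(torch.empty(n_steps, n_features, d_learnable_embed))
)
readout = nn.Linear(model_dim, 1)  # hypothetical stand-in for the MLP readout

x = X.unsqueeze(-1)                # [b, s, n, 1]
mask = missing_mask.unsqueeze(-1)  # [b, s, n, 1]
x = input_proj(x * mask)           # [b, s, n, d_input_embed]

# Broadcast the shared embedding over the batch and append it per position.
node_emb = learnable_embedding.expand(batch_size, *learnable_embedding.shape)
x = torch.cat([x, node_emb], dim=-1)  # [b, s, n, model_dim]

x = x.permute(0, 2, 1, 3)  # [b, n, s, model_dim], the layout the attention layers see
shape_in_attention = tuple(x.shape)
# (temporal and spatial attention layers omitted in this sketch)
x = x.permute(0, 2, 1, 3)  # back to [b, s, n, model_dim]

reconstruction = readout(x).squeeze(-1)  # [b, s, n]

# SAITS-style merge: observed values come from X, the rest from the model.
imputed_data = missing_mask * X + (1 - missing_mask) * reconstruction
```

The final merge guarantees that observed entries pass through unchanged, so only the originally missing positions carry model output.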
22 changes: 22 additions & 0 deletions pypots/imputation/imputeformer/data.py
@@ -0,0 +1,22 @@
"""
Dataset class for the imputation model ImputeFormer.
"""

# Created by Tong Nie <nietong@tongji.edu.cn> and Wenjie Du <wenjay.du@gmail.com>
# License: BSD-3-Clause

from typing import Union

from ..saits.data import DatasetForSAITS


class DatasetForImputeFormer(DatasetForSAITS):
    """Dataset for ImputeFormer. It reuses the SAITS dataset logic without changes."""

    def __init__(
        self,
        data: Union[dict, str],
        return_X_ori: bool,
        return_y: bool,
        file_type: str = "hdf5",
        rate: float = 0.2,
    ):
        super().__init__(data, return_X_ori, return_y, file_type, rate)
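
Neither `DatasetForSAITS` nor `SaitsLoss` is shown in this diff, so the sketch below rests on two assumptions about the SAITS framework: that the dataset hides roughly a fraction `rate` of the observed values at random to create the MIT (masked imputation task) targets, and that the loss is a weighted sum of two masked mean absolute errors, ORT on the values the model can see and MIT on the held-out ones. The `masked_mae` helper and all tensor sizes are hypothetical.

```python
import torch


def masked_mae(pred, target, mask):
    """Mean absolute error over the entries where mask == 1."""
    return torch.sum(torch.abs(pred - target) * mask) / (torch.sum(mask) + 1e-12)


torch.manual_seed(0)
X_ori = torch.randn(2, 5, 3)
X_ori[0, 0, 0] = float("nan")  # values that are missing in the raw data
X_ori[1, 2, 1] = float("nan")

rate = 0.2
observed = ~torch.isnan(X_ori)
# Hide roughly `rate` of the observed values, completely at random.
hide = (torch.rand_like(X_ori) < rate) & observed

missing_mask = (observed & ~hide).float()  # 1 = still visible to the model
indicating_mask = hide.float()             # 1 = artificially hidden (MIT target)
X_ori = torch.nan_to_num(X_ori, nan=0.0)
X = X_ori * missing_mask                   # model input with hidden values zeroed

reconstruction = torch.zeros_like(X)  # placeholder for a model's output

ORT_weight, MIT_weight = 1.0, 1.0
ORT_loss = masked_mae(reconstruction, X_ori, missing_mask)
MIT_loss = masked_mae(reconstruction, X_ori, indicating_mask)
loss = ORT_weight * ORT_loss + MIT_weight * MIT_loss
```

The two masks never overlap, and together they cover exactly the originally observed entries, which is what lets the MIT term score the model on values it never saw.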