-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
27 lines (22 loc) · 914 Bytes
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
from utils import load_data, introduce_nan, save_inference
from imputer.models import mice_impute, knn_impute, gain_impute
from evaluate import evaluate
########################
#########CONFIG#########
########################
DATASET_PATH = './data/pressure-data.xlsx' # Dataset to impute
SHEET_NAME = 'Back' # Sheet name handed to load_data
MODEL = 'MICE' # Imputation model to run: 'KNN', 'MICE' or 'GAIN'
K = 3 # If the model is KNN, specify the K value
SAVE_PATH = './results/' # Path handed to save_inference
NUM_ROWS = 1000 # Number of rows, i.e. time-series steps
# Imputation pipeline: load the data, create a copy with missing values,
# impute it with the configured model, evaluate against the original and
# save the result.

# Honor the NUM_ROWS setting instead of a hard-coded row count.
actual_df = load_data(data_path=DATASET_PATH, sheet_name=SHEET_NAME, n_rows=NUM_ROWS)
# NOTE(review): presumably returns a copy of the data with NaNs injected;
# confirm against utils.introduce_nan.
nan_df = introduce_nan(actual_df)

# Dispatch on the configured model name; only KNN takes the K setting.
if MODEL == 'KNN':
    imputed_df = knn_impute(df=nan_df, k=K)
elif MODEL == 'MICE':
    imputed_df = mice_impute(df=nan_df)
elif MODEL == 'GAIN':
    imputed_df = gain_impute(df=nan_df)
else:
    # Without this branch an unrecognized MODEL would leave imputed_df
    # unbound and surface later as a confusing NameError.
    raise ValueError(
        "Unknown MODEL {!r}; expected 'KNN', 'MICE' or 'GAIN'".format(MODEL)
    )

# Compare the imputed frame with the original and persist the outcome.
error_df = evaluate(actual_df=actual_df, imputed_df=imputed_df)
save_inference(actual_df, error_df, model=MODEL, path=SAVE_PATH)