Skip to content

Commit 93cd954

Browse files
authored
infra: format all .py files with black (aws#2223)
1 parent 15f53ea commit 93cd954

File tree

746 files changed

+36993
-27194
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

746 files changed

+36993
-27194
lines changed

advanced_functionality/autogluon-tabular/container-training/inference.py

+37-29
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import subprocess
99
import copy
1010

11-
warnings.filterwarnings('ignore', category=FutureWarning)
11+
warnings.filterwarnings("ignore", category=FutureWarning)
1212

1313
import numpy as np
1414
import pandas as pd
@@ -19,48 +19,53 @@
1919
from collections import Counter
2020

2121
with warnings.catch_warnings():
22-
warnings.filterwarnings('ignore', category=DeprecationWarning)
22+
warnings.filterwarnings("ignore", category=DeprecationWarning)
2323
from prettytable import PrettyTable
2424
from autogluon.tabular import TabularPredictor, TabularDataset
2525

26+
2627
def make_str_table(df):
27-
table = PrettyTable(['index']+list(df.columns))
28+
table = PrettyTable(["index"] + list(df.columns))
2829
for row in df.itertuples():
2930
table.add_row(row)
3031
return str(table)
3132

33+
3234
def take(n, iterable):
3335
"Return first n items of the iterable as a list"
3436
return list(islice(iterable, n))
3537

38+
3639
def preprocess(df, columns, target):
3740
features = copy.deepcopy(columns)
3841
features.remove(target)
39-
first_row_list = df.iloc[0].tolist()
42+
first_row_list = df.iloc[0].tolist()
4043

4144
if set(first_row_list) >= set(features):
4245
df.drop(0, inplace=True)
4346
if len(first_row_list) == len(columns):
4447
df.columns = columns
4548
if len(first_row_list) == len(features):
4649
df.columns = features
47-
50+
4851
return df
4952

53+
5054
# ------------------------------------------------------------ #
5155
# Hosting methods #
5256
# ------------------------------------------------------------ #
5357

58+
5459
def model_fn(model_dir):
5560
"""
5661
Load the gluon model. Called once when hosting service starts.
5762
:param: model_dir The directory where model files are stored.
5863
:return: a model (in this case a Gluon network) and the column info.
5964
"""
60-
print(f'Loading model from {model_dir} with contents {os.listdir(model_dir)}')
65+
print(f"Loading model from {model_dir} with contents {os.listdir(model_dir)}")
6166

6267
net = TabularPredictor.load(model_dir, verbosity=True)
63-
with open(f'{model_dir}/code/columns.pkl', 'rb') as f:
68+
with open(f"{model_dir}/code/columns.pkl", "rb") as f:
6469
column_dict = pickle.load(f)
6570
return net, column_dict
6671

@@ -77,72 +82,75 @@ def transform_fn(models, data, input_content_type, output_content_type):
7782
start = timer()
7883
net = models[0]
7984
column_dict = models[1]
80-
label_map = net.class_labels_internal_map ###
85+
label_map = net.class_labels_internal_map ###
8186

8287
# text/csv
83-
if 'text/csv' in input_content_type:
88+
if "text/csv" in input_content_type:
8489
# Load dataset
85-
columns = column_dict['columns']
90+
columns = column_dict["columns"]
8691

8792
if type(data) == str:
88-
# Load dataset
93+
# Load dataset
8994
df = pd.read_csv(StringIO(data), header=None)
9095
else:
9196
df = pd.read_csv(StringIO(data.decode()), header=None)
9297

9398
df_preprosessed = preprocess(df, columns, net.label)
9499

95100
ds = TabularDataset(data=df_preprosessed)
96-
101+
97102
try:
98103
predictions = net.predict_proba(ds)
99104
predictions_ = net.predict(ds)
100105
except:
101106
try:
102107
predictions = net.predict_proba(ds.fillna(0.0))
103108
predictions_ = net.predict(ds.fillna(0.0))
104-
warnings.warn('Filled NaN\'s with 0.0 in order to predict.')
109+
warnings.warn("Filled NaN's with 0.0 in order to predict.")
105110
except Exception as e:
106111
response_body = e
107112
return response_body, output_content_type
108113

109-
#threshold = 0.5
110-
#predictions_label = [[k for k, v in label_map.items() if v == 1][0] if i > threshold else [k for k, v in label_map.items() if v == 0][0] for i in predictions]
114+
# threshold = 0.5
115+
# predictions_label = [[k for k, v in label_map.items() if v == 1][0] if i > threshold else [k for k, v in label_map.items() if v == 0][0] for i in predictions]
111116
predictions_label = predictions_.tolist()
112-
113117

114118
# Print prediction counts, limit in case of regression problem
115119
pred_counts = Counter(predictions_label)
116120
n_display_items = 30
117121
if len(pred_counts) > n_display_items:
118-
print(f'Top {n_display_items} prediction counts: '
119-
f'{dict(take(n_display_items, pred_counts.items()))}')
122+
print(
123+
f"Top {n_display_items} prediction counts: "
124+
f"{dict(take(n_display_items, pred_counts.items()))}"
125+
)
120126
else:
121-
print(f'Prediction counts: {pred_counts}')
127+
print(f"Prediction counts: {pred_counts}")
122128

123129
# Form response
124130
output = StringIO()
125131
pd.DataFrame(predictions).to_csv(output, header=False, index=False)
126-
response_body = output.getvalue()
132+
response_body = output.getvalue()
127133

128134
# If target column passed, evaluate predictions performance
129135
target = net.label
130136
if target in ds:
131-
print(f'Label column ({target}) found in input data. '
132-
'Therefore, evaluating prediction performance...')
137+
print(
138+
f"Label column ({target}) found in input data. "
139+
"Therefore, evaluating prediction performance..."
140+
)
133141
try:
134-
performance = net.evaluate_predictions(y_true=ds[target],
135-
y_pred=np.array(predictions_label),
136-
auxiliary_metrics=True)
142+
performance = net.evaluate_predictions(
143+
y_true=ds[target], y_pred=np.array(predictions_label), auxiliary_metrics=True
144+
)
137145
print(json.dumps(performance, indent=4, default=pd.DataFrame.to_json))
138146
time.sleep(0.1)
139147
except Exception as e:
140148
# Print exceptions on evaluate, continue to return predictions
141-
print(f'Exception: {e}')
149+
print(f"Exception: {e}")
142150
else:
143151
raise NotImplementedError("content_type must be 'text/csv'")
144152

145-
elapsed_time = round(timer()-start,3)
146-
print(f'Elapsed time: {round(timer()-start,3)} seconds')
147-
153+
elapsed_time = round(timer() - start, 3)
154+
print(f"Elapsed time: {round(timer()-start,3)} seconds")
155+
148156
return response_body, output_content_type

0 commit comments

Comments
 (0)