forked from Kolkir/code2seq
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcode2seq.py
44 lines (39 loc) · 1.29 KB
/
code2seq.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import numpy as np
import tensorflow as tf
from config import Config
from interactive_predict import InteractivePredictor
from modelrunner import ModelRunner
from args import read_args
def _enable_gpu_memory_growth():
    """Make TensorFlow allocate GPU memory on demand for every visible GPU.

    Per the TensorFlow docs, memory growth must be configured before the GPUs
    are initialised, which is why ``main`` calls this before anything else
    touches TensorFlow devices.
    """
    for device in tf.config.list_physical_devices("GPU"):
        tf.config.experimental.set_memory_growth(device, True)


def _run_functions_eagerly():
    """Force ``tf.function``-decorated code to run eagerly (debug mode).

    ``tf.config.run_functions_eagerly`` replaced the deprecated
    ``tf.config.experimental_run_functions_eagerly`` in TensorFlow 2.3; fall
    back to the old name only when running on an older release.
    """
    run_eagerly = getattr(tf.config, "run_functions_eagerly", None)
    if run_eagerly is None:
        run_eagerly = tf.config.experimental_run_functions_eagerly
    run_eagerly(True)


def _create_model(config):
    """Construct a new ``ModelRunner`` for ``config`` and report it on stdout."""
    model = ModelRunner(config)
    print("Created model")
    return model


def main():
    """Parse CLI arguments, then train, evaluate and/or predict as requested.

    Each requested phase builds its own ``ModelRunner`` from the same config,
    exactly as before.
    """
    _enable_gpu_memory_growth()

    args = read_args()
    # Seed both NumPy's and TensorFlow's global RNGs from the same CLI seed.
    np.random.seed(args.seed)
    tf.random.set_seed(args.seed)

    if args.debug:
        config = Config.get_debug_config(args)
        _run_functions_eagerly()
    else:
        config = Config.get_default_config(args)

    if config.TRAIN_PATH:
        _create_model(config).train()

    # NOTE(review): stand-alone evaluation is skipped whenever args.data_path
    # is set — presumably because that flag selects a training run; confirm
    # against read_args / Config.
    if config.TEST_PATH and not args.data_path:
        results, precision, recall, f1, rouge = _create_model(config).evaluate()
        print(f"Accuracy: {results}")
        print(f"Precision: {precision}, recall: {recall}, F1: {f1}")
        print("Rouge: ", rouge)

    if args.predict:
        model = _create_model(config)
        predictor = InteractivePredictor(config, model, args.predict)
        predictor.predict()


if __name__ == "__main__":
    main()