forked from Kolkir/code2seq
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathinteractive_predict.py
86 lines (79 loc) · 3.23 KB
/
interactive_predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
from common import Common
from java_extractor import JavaExtractor
from cpp_extractor import CppExtractor
# Limits on extracted path length / width (presumably meant to match the
# settings the model was trained with -- confirm against training config).
MAX_PATH_LENGTH = 8
MAX_PATH_WIDTH = 2

# Passed as ``topk`` to Common.parse_results when formatting predictions.
SHOW_TOP_CONTEXTS = 10

# Remote path-extraction endpoint handed to the Java extractor.
EXTRACTION_API = (
    "https://po3g2dx2qa.execute-api.us-east-1.amazonaws.com"
    "/production/extractmethods"
)
class InteractivePredictor:
    """Interactive loop that predicts method names for code in a local file.

    The user edits ``Input.source`` and presses Enter. The predictor then
    extracts paths from the file with a language-specific extractor, runs the
    model on them, and prints the predictions. Attention details are printed
    when ``config.BEAM_WIDTH`` is 0.
    """

    # Inputs that end the loop (compared case-insensitively, whitespace ignored).
    exit_keywords = ["exit", "quit", "q"]

    def __init__(self, config, model, language):
        """Create a predictor for ``language``.

        Args:
            config: Configuration object; ``MAX_PATH_LENGTH`` (Java) and
                ``BEAM_WIDTH`` are read from it.
            model: Object with a ``predict(lines)`` method whose output is
                passed to ``Common.parse_results``.
            language: ``"java"`` or ``"cpp"``; selects the path extractor.

        Raises:
            ValueError: if ``language`` is not supported.
        """
        self.model = model
        self.config = config
        if language == "java":
            self.path_extractor = JavaExtractor(
                config,
                EXTRACTION_API,
                self.config.MAX_PATH_LENGTH,
                max_path_width=MAX_PATH_WIDTH,
            )
        elif language == "cpp":
            self.path_extractor = CppExtractor(config)
        else:
            # An explicit raise (unlike ``assert False``) still fires under
            # ``python -O`` instead of leaving ``path_extractor`` unset.
            raise ValueError("Unsupported language model: %r" % (language,))

    @staticmethod
    def read_file(input_filename):
        """Return the lines of ``input_filename``, line endings included.

        The file is decoded as UTF-8 explicitly so the result does not depend
        on the platform's locale encoding.

        Raises:
            OSError: if the file cannot be opened.
            UnicodeDecodeError: if the file is not valid UTF-8.
        """
        with open(input_filename, "r", encoding="utf-8") as file:
            return file.readlines()

    def predict(self):
        """Run the read / extract / predict / print loop until the user exits.

        Each iteration waits for a line on stdin. An exit keyword ends the
        loop; anything else re-reads ``Input.source`` and prints the
        predictions for every result returned by ``Common.parse_results``.
        Read and extraction errors are reported and the loop continues.
        """
        input_filename = "Input.source"
        print("Serving")
        while True:
            print(
                'Modify the file: "'
                + input_filename
                + '" and press Enter when ready, or "q" / "exit" to exit'
            )
            user_input = input()
            if user_input.strip().lower() in self.exit_keywords:
                print("Exiting...")
                return
            try:
                # readlines() keeps each "\n", so joining with " " also puts a
                # leading space on every line after the first.
                code = " ".join(self.read_file(input_filename))
            except (OSError, UnicodeDecodeError) as err:
                # Keep serving so the user can create or fix the file.
                print("Could not read %s: %s" % (input_filename, err))
                continue
            try:
                predict_lines, pc_info_dict = self.path_extractor.extract_paths(code)
            except ValueError as err:
                # Previously swallowed silently; report why nothing was printed.
                print("Path extraction failed: %s" % (err,))
                continue
            model_results = self.model.predict(predict_lines)
            prediction_results = Common.parse_results(
                model_results, pc_info_dict, topk=SHOW_TOP_CONTEXTS
            )
            for method_prediction in prediction_results.values():
                self._print_method_prediction(method_prediction)

    def _print_method_prediction(self, method_prediction):
        """Print one method's original name, predictions and (maybe) attention."""
        print("Original name:\t" + method_prediction.original_name)
        if self.config.BEAM_WIDTH != 0:
            # Non-zero beam width: print each entry's prediction on its own
            # line (presumably one candidate per beam -- confirm).
            print("Predicted:")
            for predicted_seq in method_prediction.predictions:
                print("\t%s" % predicted_seq.prediction)
            return
        # BEAM_WIDTH == 0: each entry is one output timestep carrying its own
        # attention paths.
        print(
            "Predicted:\t%s"
            % [step.prediction for step in method_prediction.predictions]
        )
        for timestep, single_timestep_prediction in enumerate(
            method_prediction.predictions
        ):
            print("Attention:")
            print(
                "TIMESTEP: %d\t: %s"
                % (timestep, single_timestep_prediction.prediction)
            )
            for attention_obj in single_timestep_prediction.attention_paths:
                print(
                    "%f\tcontext: %s,%s,%s"
                    % (
                        attention_obj["score"],
                        attention_obj["token1"],
                        attention_obj["path"],
                        attention_obj["token2"],
                    )
                )