@@ -29,6 +29,9 @@ VAST_UNRELAX_WARNINGS
29
29
30
30
#include " vast/Conversion/Parser/Config.hpp"
31
31
32
+ #include " vast/server/server.hpp"
33
+ #include " vast/server/types.hpp"
34
+
32
35
#include < ranges>
33
36
34
37
namespace vast ::conv {
@@ -75,6 +78,141 @@ namespace vast::conv {
75
78
76
79
using function_models = llvm::StringMap< function_model >;
77
80
81
+ struct location
82
+ {
83
+ std::string filePath;
84
+ server::range range;
85
+ };
86
+
87
+ location get_location (file_loc_t loc) {
88
+ return {
89
+ .filePath = loc.getFilename ().str (),
90
+ .range = {
91
+ .start = { loc.getLine (), loc.getColumn (), },
92
+ .end = { loc.getLine (), loc.getColumn (), },
93
+ },
94
+ };
95
+ }
96
+
97
+ location get_location (name_loc_t loc) {
98
+ return get_location (mlir::cast< file_loc_t >(loc.getChildLoc ()));
99
+ }
100
+
101
+ std::optional< location > get_location (loc_t loc) {
102
+ if (auto file_loc = mlir::dyn_cast< file_loc_t >(loc)) {
103
+ return get_location (file_loc);
104
+ } else if (auto name_loc = mlir::dyn_cast< name_loc_t >(loc)) {
105
+ return get_location (name_loc);
106
+ }
107
+
108
+ return std::nullopt;
109
+ }
110
+
111
+ pr::data_type parse_type_name (const std::string &name) {
112
+ if (name == " data" ) {
113
+ return pr::data_type::data;
114
+ } else if (name == " nodata" ) {
115
+ return pr::data_type::nodata;
116
+ } else {
117
+ return pr::data_type::maybedata;
118
+ }
119
+ }
120
+
121
+ function_category ask_user_for_category (vast::server::server_base &server, hl::FuncOp op) {
122
+ auto loc = op.getLoc ();
123
+
124
+ vast::server::input_request req{
125
+ .type = {" nonparser" , " sink" , " source" , " parser" ,},
126
+ .text = " Please choose category for function `" + op.getSymName ().str () + ' `' ,
127
+ .filePath = std::nullopt,
128
+ .range = std::nullopt,
129
+ };
130
+
131
+ if (auto req_loc = get_location (loc)) {
132
+ req.filePath = req_loc->filePath ;
133
+ req.range = req_loc->range ;
134
+ }
135
+
136
+ auto response = server.send_request (req);
137
+ if (auto result = std::get_if< vast::server::input_request::response_type >(&response))
138
+ {
139
+ if (result->value == " nonparser" ) {
140
+ return function_category::nonparser;
141
+ } else if (result->value == " sink" ) {
142
+ return function_category::sink;
143
+ } else if (result->value == " source" ) {
144
+ return function_category::source;
145
+ } else if (result->value == " parser" ) {
146
+ return function_category::parser;
147
+ }
148
+ }
149
+ return function_category::nonparser;
150
+ }
151
+
152
+ pr::data_type ask_user_for_return_type (vast::server::server_base &server, hl::FuncOp op) {
153
+ auto loc = op.getLoc ();
154
+
155
+ vast::server::input_request req{
156
+ .type = { " maybedata" , " nodata" , " data" },
157
+ .text = " Please choose return type for function `" + op.getSymName ().str () + ' `' ,
158
+ .filePath = std::nullopt,
159
+ .range = std::nullopt,
160
+ };
161
+
162
+ if (auto req_loc = get_location (loc)) {
163
+ req.filePath = req_loc->filePath ;
164
+ req.range = req_loc->range ;
165
+ }
166
+
167
+ auto response = server.send_request (req);
168
+ if (auto result = std::get_if< vast::server::input_request::response_type >(&response))
169
+ {
170
+ return parse_type_name (result->value );
171
+ }
172
+ return pr::data_type::maybedata;
173
+ }
174
+
175
+ pr::data_type ask_user_for_argument_type (
176
+ vast::server::server_base &server, hl::FuncOp op, unsigned int idx
177
+ ) {
178
+ auto num_body_args = op.getFunctionBody ().getNumArguments ();
179
+
180
+ vast::server::input_request req{
181
+ .type = { " maybedata" , " nodata" , " data" },
182
+ .text = " Please choose a type for argument " + std::to_string (idx)
183
+ + " of function `" + op.getSymName ().str () + ' `' ,
184
+ .filePath = std::nullopt,
185
+ .range = std::nullopt,
186
+ };
187
+
188
+ if (idx < num_body_args) {
189
+ auto arg = op.getArgument (idx);
190
+ auto loc = arg.getLoc ();
191
+ if (auto req_loc = get_location (loc)) {
192
+ req.filePath = req_loc->filePath ;
193
+ req.range = req_loc->range ;
194
+ }
195
+ }
196
+
197
+ auto response = server.send_request (req);
198
+ if (auto result = std::get_if< vast::server::input_request::response_type >(&response))
199
+ {
200
+ return parse_type_name (result->value );
201
+ }
202
+ return pr::data_type::maybedata;
203
+ }
204
+
205
+ function_model
206
+ ask_user_for_function_model (vast::server::server_base &server, hl::FuncOp op) {
207
+ function_model model;
208
+ model.return_type = ask_user_for_return_type (server, op);
209
+ for (unsigned int i = 0 ; i < op.getNumArguments (); ++i) {
210
+ model.arguments .push_back (ask_user_for_argument_type (server, op, i));
211
+ }
212
+ model.category = ask_user_for_category (server, op);
213
+ return model;
214
+ }
215
+
78
216
} // namespace vast::conv
79
217
80
218
LLVM_YAML_IS_SEQUENCE_VECTOR (vast::pr::data_type);
@@ -130,25 +268,28 @@ namespace vast::conv {
130
268
using base = base_conversion_config;
131
269
132
270
parser_conversion_config (
133
- rewrite_pattern_set patterns, conversion_target target,
134
- const function_models &models
271
+ rewrite_pattern_set patterns, conversion_target target, function_models &models,
272
+ vast::server::server_base *server
135
273
)
136
- : base(std::move(patterns), std::move(target)), models(models)
137
- {}
274
+ : base(std::move(patterns), std::move(target)), models(models), server(server) {}
138
275
139
276
template < typename pattern >
140
277
void add_pattern () {
141
278
auto ctx = patterns.getContext ();
142
279
if constexpr (std::is_constructible_v< pattern, mcontext_t * >) {
143
280
patterns.template add < pattern >(ctx);
144
- } else if constexpr (std::is_constructible_v< pattern, mcontext_t *, const function_models & >) {
145
- patterns.template add < pattern >(ctx, models);
281
+ } else if constexpr (std::is_constructible_v<
282
+ pattern, mcontext_t *, function_models &,
283
+ vast::server::server_base * >)
284
+ {
285
+ patterns.template add < pattern >(ctx, models, server);
146
286
} else {
147
287
static_assert (false , " pattern does not have a valid constructor" );
148
288
}
149
289
}
150
290
151
- const function_models ⊧
291
+ function_models ⊧
292
+ vast::server::server_base *server;
152
293
};
153
294
154
295
struct function_type_converter
@@ -277,24 +418,33 @@ namespace vast::conv {
277
418
{
278
419
using base = mlir::OpConversionPattern< op_t >;
279
420
280
- parser_conversion_pattern_base (mcontext_t *mctx, const function_models &models)
281
- : base(mctx), models(models)
282
- {}
421
+ parser_conversion_pattern_base (
422
+ mcontext_t *mctx, function_models &models, vast::server::server_base *server
423
+ )
424
+ : base(mctx), models(models), server(server) {}
283
425
284
- static std::optional< function_model >
285
- get_model (const function_models &models, hl::FuncOp func) {
426
+ static std::optional< function_model > get_model (
427
+ function_models &models, hl::FuncOp func, vast::server::server_base *server
428
+ ) {
286
429
if (auto kv = models.find (func.getSymName ()); kv != models.end ()) {
287
430
return kv->second ;
288
431
}
289
432
433
+ if (server) {
434
+ auto model = ask_user_for_function_model (*server, func);
435
+ models[func.getSymName ()] = model;
436
+ return model;
437
+ }
438
+
290
439
return std::nullopt;
291
440
}
292
441
293
442
std::optional< function_model > get_model (hl::FuncOp func) const {
294
- return get_model (models, func);
443
+ return get_model (models, func, server );
295
444
}
296
445
297
- const function_models ⊧
446
+ function_models ⊧
447
+ vast::server::server_base *server;
298
448
};
299
449
300
450
//
@@ -541,10 +691,13 @@ namespace vast::conv {
541
691
return mlir::failure ();
542
692
}
543
693
544
- static void legalize (parser_conversion_config &cfg) {
694
+ static void
695
+ legalize (parser_conversion_config &cfg, vast::server::server_base *server) {
545
696
cfg.target .addLegalOp < mlir::UnrealizedConversionCastOp >();
546
- cfg.target .addDynamicallyLegalOp < op_t >([models = cfg.models ](op_t op) {
547
- return function_type_converter (*op.getContext (), get_model (models, op))
697
+ cfg.target .addDynamicallyLegalOp < op_t >([&cfg, server](op_t op) {
698
+ return function_type_converter (
699
+ *op.getContext (), get_model (cfg.models , op, server)
700
+ )
548
701
.isLegal (op.getFunctionType ());
549
702
});
550
703
}
@@ -708,6 +861,9 @@ namespace vast::conv {
708
861
{
709
862
using base = ConversionPassMixin< HLToParserPass, HLToParserBase >;
710
863
864
+ struct server_handler
865
+ {};
866
+
711
867
static conversion_target create_conversion_target (mcontext_t &mctx) {
712
868
return conversion_target (mctx);
713
869
}
@@ -722,6 +878,12 @@ namespace vast::conv {
722
878
if (!config.empty ()) {
723
879
load_and_parse (config);
724
880
}
881
+
882
+ if (!socket.empty ()) {
883
+ server = std::make_shared< vast::server::server< server_handler > >(
884
+ vast::server::sock_adapter::create_unix_socket (socket)
885
+ );
886
+ }
725
887
}
726
888
727
889
void load_and_parse (string_ref config) {
@@ -748,10 +910,12 @@ namespace vast::conv {
748
910
749
911
parser_conversion_config make_config () {
750
912
auto &ctx = getContext ();
751
- return { rewrite_pattern_set (&ctx), create_conversion_target (ctx), models };
913
+ return { rewrite_pattern_set (&ctx), create_conversion_target (ctx), models,
914
+ server.get () };
752
915
}
753
916
754
917
function_models models;
918
+ std::shared_ptr< vast::server::server< server_handler > > server;
755
919
};
756
920
757
921
} // namespace vast::conv
0 commit comments