@@ -29,6 +29,9 @@ VAST_UNRELAX_WARNINGS
29
29
30
30
#include " vast/Conversion/Parser/Config.hpp"
31
31
32
+ #include " vast/server/server.hpp"
33
+ #include " vast/server/types.hpp"
34
+
32
35
#include < ranges>
33
36
34
37
namespace vast ::conv {
@@ -75,6 +78,135 @@ namespace vast::conv {
75
78
76
79
using function_models = llvm::StringMap< function_model >;
77
80
81
+ struct location
82
+ {
83
+ std::string filePath;
84
+ server::range range;
85
+ };
86
+
87
+ location get_location (file_loc_t loc) {
88
+ return {
89
+ .filePath = loc.getFilename ().str (),
90
+ .range = {
91
+ .start = { loc.getLine (), loc.getColumn (), },
92
+ .end = { loc.getLine (), loc.getColumn (), },
93
+ },
94
+ };
95
+ }
96
+
97
+ location get_location (name_loc_t loc) {
98
+ return get_location (mlir::cast< file_loc_t >(loc.getChildLoc ()));
99
+ }
100
+
101
+ std::optional< location > get_location (loc_t loc) {
102
+ if (auto file_loc = mlir::dyn_cast< file_loc_t >(loc)) {
103
+ return get_location (file_loc);
104
+ } else if (auto name_loc = mlir::dyn_cast< name_loc_t >(loc)) {
105
+ return get_location (name_loc);
106
+ }
107
+
108
+ return std::nullopt;
109
+ }
110
+
111
+ pr::data_type parse_type_name (const std::string &name) {
112
+ if (name == " data" ) {
113
+ return pr::data_type::data;
114
+ } else if (name == " nodata" ) {
115
+ return pr::data_type::nodata;
116
+ } else {
117
+ return pr::data_type::maybedata;
118
+ }
119
+ }
120
+
121
+ function_category ask_user_for_category (vast::server::server_base &server, hl::FuncOp op) {
122
+ auto loc = op.getLoc ();
123
+
124
+ vast::server::input_request req{
125
+ .type = {" nonparser" , " sink" , " source" , " parser" ,},
126
+ .text = " Please choose category for function `" + op.getSymName ().str () + ' `' ,
127
+ };
128
+
129
+ if (auto req_loc = get_location (loc)) {
130
+ req.filePath = req_loc->filePath ;
131
+ req.range = req_loc->range ;
132
+ }
133
+
134
+ auto response = server.send_request (req);
135
+ if (auto result = std::get_if< vast::server::input_request::response_type >(&response))
136
+ {
137
+ if (result->value == " nonparser" ) {
138
+ return function_category::nonparser;
139
+ } else if (result->value == " sink" ) {
140
+ return function_category::sink;
141
+ } else if (result->value == " source" ) {
142
+ return function_category::source;
143
+ } else if (result->value == " parser" ) {
144
+ return function_category::parser;
145
+ }
146
+ }
147
+ return function_category::nonparser;
148
+ }
149
+
150
+ pr::data_type ask_user_for_return_type (vast::server::server_base &server, hl::FuncOp op) {
151
+ auto loc = op.getLoc ();
152
+
153
+ vast::server::input_request req{
154
+ .type = { " maybedata" , " nodata" , " data" },
155
+ .text = " Please choose return type for function `" + op.getSymName ().str () + ' `' ,
156
+ };
157
+
158
+ if (auto req_loc = get_location (loc)) {
159
+ req.filePath = req_loc->filePath ;
160
+ req.range = req_loc->range ;
161
+ }
162
+
163
+ auto response = server.send_request (req);
164
+ if (auto result = std::get_if< vast::server::input_request::response_type >(&response))
165
+ {
166
+ return parse_type_name (result->value );
167
+ }
168
+ return pr::data_type::maybedata;
169
+ }
170
+
171
+ pr::data_type ask_user_for_argument_type (
172
+ vast::server::server_base &server, hl::FuncOp op, unsigned int idx
173
+ ) {
174
+ auto num_body_args = op.getFunctionBody ().getNumArguments ();
175
+
176
+ vast::server::input_request req{
177
+ .type = { " maybedata" , " nodata" , " data" },
178
+ .text = " Please choose a type for argument " + std::to_string (idx)
179
+ + " of function `" + op.getSymName ().str () + ' `' ,
180
+ };
181
+
182
+ if (idx < num_body_args) {
183
+ auto arg = op.getArgument (idx);
184
+ auto loc = arg.getLoc ();
185
+ if (auto req_loc = get_location (loc)) {
186
+ req.filePath = req_loc->filePath ;
187
+ req.range = req_loc->range ;
188
+ }
189
+ }
190
+
191
+ auto response = server.send_request (req);
192
+ if (auto result = std::get_if< vast::server::input_request::response_type >(&response))
193
+ {
194
+ return parse_type_name (result->value );
195
+ }
196
+ return pr::data_type::maybedata;
197
+ }
198
+
199
+ function_model
200
+ ask_user_for_function_model (vast::server::server_base &server, hl::FuncOp op) {
201
+ function_model model;
202
+ model.return_type = ask_user_for_return_type (server, op);
203
+ for (unsigned int i = 0 ; i < op.getNumArguments (); ++i) {
204
+ model.arguments .push_back (ask_user_for_argument_type (server, op, i));
205
+ }
206
+ model.category = ask_user_for_category (server, op);
207
+ return model;
208
+ }
209
+
78
210
} // namespace vast::conv
79
211
80
212
LLVM_YAML_IS_SEQUENCE_VECTOR (vast::pr::data_type);
@@ -130,25 +262,28 @@ namespace vast::conv {
130
262
using base = base_conversion_config;
131
263
132
264
parser_conversion_config (
133
- rewrite_pattern_set patterns, conversion_target target,
134
- const function_models &models
265
+ rewrite_pattern_set patterns, conversion_target target, function_models &models,
266
+ vast::server::server_base *server
135
267
)
136
- : base(std::move(patterns), std::move(target)), models(models)
137
- {}
268
+ : base(std::move(patterns), std::move(target)), models(models), server(server) {}
138
269
139
270
template < typename pattern >
140
271
void add_pattern () {
141
272
auto ctx = patterns.getContext ();
142
273
if constexpr (std::is_constructible_v< pattern, mcontext_t * >) {
143
274
patterns.template add < pattern >(ctx);
144
- } else if constexpr (std::is_constructible_v< pattern, mcontext_t *, const function_models & >) {
145
- patterns.template add < pattern >(ctx, models);
275
+ } else if constexpr (std::is_constructible_v<
276
+ pattern, mcontext_t *, function_models &,
277
+ vast::server::server_base * >)
278
+ {
279
+ patterns.template add < pattern >(ctx, models, server);
146
280
} else {
147
281
static_assert (false , " pattern does not have a valid constructor" );
148
282
}
149
283
}
150
284
151
- const function_models ⊧
285
+ function_models ⊧
286
+ vast::server::server_base *server;
152
287
};
153
288
154
289
struct function_type_converter
@@ -277,24 +412,33 @@ namespace vast::conv {
277
412
{
278
413
using base = mlir::OpConversionPattern< op_t >;
279
414
280
- parser_conversion_pattern_base (mcontext_t *mctx, const function_models &models)
281
- : base(mctx), models(models)
282
- {}
415
+ parser_conversion_pattern_base (
416
+ mcontext_t *mctx, function_models &models, vast::server::server_base *server
417
+ )
418
+ : base(mctx), models(models), server(server) {}
283
419
284
- static std::optional< function_model >
285
- get_model (const function_models &models, hl::FuncOp func) {
420
+ static std::optional< function_model > get_model (
421
+ function_models &models, hl::FuncOp func, vast::server::server_base *server
422
+ ) {
286
423
if (auto kv = models.find (func.getSymName ()); kv != models.end ()) {
287
424
return kv->second ;
288
425
}
289
426
427
+ if (server) {
428
+ auto model = ask_user_for_function_model (*server, func);
429
+ models[func.getSymName ()] = model;
430
+ return model;
431
+ }
432
+
290
433
return std::nullopt;
291
434
}
292
435
293
436
std::optional< function_model > get_model (hl::FuncOp func) const {
294
- return get_model (models, func);
437
+ return get_model (models, func, server );
295
438
}
296
439
297
- const function_models ⊧
440
+ function_models ⊧
441
+ vast::server::server_base *server;
298
442
};
299
443
300
444
//
@@ -541,10 +685,13 @@ namespace vast::conv {
541
685
return mlir::failure ();
542
686
}
543
687
544
- static void legalize (parser_conversion_config &cfg) {
688
+ static void
689
+ legalize (parser_conversion_config &cfg, vast::server::server_base *server) {
545
690
cfg.target .addLegalOp < mlir::UnrealizedConversionCastOp >();
546
- cfg.target .addDynamicallyLegalOp < op_t >([models = cfg.models ](op_t op) {
547
- return function_type_converter (*op.getContext (), get_model (models, op))
691
+ cfg.target .addDynamicallyLegalOp < op_t >([&cfg, server](op_t op) {
692
+ return function_type_converter (
693
+ *op.getContext (), get_model (cfg.models , op, server)
694
+ )
548
695
.isLegal (op.getFunctionType ());
549
696
});
550
697
}
@@ -708,6 +855,9 @@ namespace vast::conv {
708
855
{
709
856
using base = ConversionPassMixin< HLToParserPass, HLToParserBase >;
710
857
858
+ struct server_handler
859
+ {};
860
+
711
861
static conversion_target create_conversion_target (mcontext_t &mctx) {
712
862
return conversion_target (mctx);
713
863
}
@@ -722,6 +872,12 @@ namespace vast::conv {
722
872
if (!config.empty ()) {
723
873
load_and_parse (config);
724
874
}
875
+
876
+ if (!socket.empty ()) {
877
+ server = std::make_shared< vast::server::server< server_handler > >(
878
+ vast::server::sock_adapter::create_unix_socket (socket)
879
+ );
880
+ }
725
881
}
726
882
727
883
void load_and_parse (string_ref config) {
@@ -748,10 +904,12 @@ namespace vast::conv {
748
904
749
905
parser_conversion_config make_config () {
750
906
auto &ctx = getContext ();
751
- return { rewrite_pattern_set (&ctx), create_conversion_target (ctx), models };
907
+ return { rewrite_pattern_set (&ctx), create_conversion_target (ctx), models,
908
+ server.get () };
752
909
}
753
910
754
911
function_models models;
912
+ std::shared_ptr< vast::server::server< server_handler > > server;
755
913
};
756
914
757
915
} // namespace vast::conv
0 commit comments