@@ -29,6 +29,9 @@ VAST_UNRELAX_WARNINGS
29
29
30
30
#include " vast/Conversion/Parser/Config.hpp"
31
31
32
+ #include " vast/server/server.hpp"
33
+ #include " vast/server/types.hpp"
34
+
32
35
#include < ranges>
33
36
34
37
namespace vast ::conv {
@@ -75,6 +78,154 @@ namespace vast::conv {
75
78
76
79
using function_models = llvm::StringMap< function_model >;
77
80
81
+ struct location
82
+ {
83
+ std::string filePath;
84
+ server::range range;
85
+ };
86
+
87
+ location get_location (file_loc_t loc) {
88
+ return {
89
+ .filePath = loc.getFilename ().str (),
90
+ .range = {
91
+ .start = { loc.getLine (), loc.getColumn (), },
92
+ .end = { loc.getLine (), loc.getColumn (), },
93
+ },
94
+ };
95
+ }
96
+
97
+ location get_location (name_loc_t loc) {
98
+ return get_location (mlir::cast< file_loc_t >(loc.getChildLoc ()));
99
+ }
100
+
101
+ std::optional< location > get_location (loc_t loc) {
102
+ if (auto file_loc = mlir::dyn_cast< file_loc_t >(loc)) {
103
+ return get_location (file_loc);
104
+ } else if (auto name_loc = mlir::dyn_cast< name_loc_t >(loc)) {
105
+ return get_location (name_loc);
106
+ }
107
+
108
+ return std::nullopt;
109
+ }
110
+
111
+ pr::data_type parse_type_name (const std::string &name) {
112
+ if (name == " data" ) {
113
+ return pr::data_type::data;
114
+ } else if (name == " nodata" ) {
115
+ return pr::data_type::nodata;
116
+ } else {
117
+ return pr::data_type::maybedata;
118
+ }
119
+ }
120
+
121
+ function_category
122
+ ask_user_for_category (vast::server::server_base &server, core::function_op_interface op) {
123
+ auto loc = op.getLoc ();
124
+ auto sym = mlir::dyn_cast< core::SymbolOpInterface >(op.getOperation ());
125
+ VAST_ASSERT (sym);
126
+ auto name = sym.getSymbolName ().str ();
127
+
128
+ vast::server::input_request req{
129
+ .type = {" nonparser" , " sink" , " source" , " parser" ,},
130
+ .text = " Please choose category for function `" + name + ' `' ,
131
+ .filePath = std::nullopt,
132
+ .range = std::nullopt,
133
+ };
134
+
135
+ if (auto req_loc = get_location (loc)) {
136
+ req.filePath = req_loc->filePath ;
137
+ req.range = req_loc->range ;
138
+ }
139
+
140
+ auto response = server.send_request (req);
141
+ if (auto result = std::get_if< vast::server::input_request::response_type >(&response))
142
+ {
143
+ if (result->value == " nonparser" ) {
144
+ return function_category::nonparser;
145
+ } else if (result->value == " sink" ) {
146
+ return function_category::sink;
147
+ } else if (result->value == " source" ) {
148
+ return function_category::source;
149
+ } else if (result->value == " parser" ) {
150
+ return function_category::parser;
151
+ }
152
+ }
153
+ return function_category::nonparser;
154
+ }
155
+
156
+ pr::data_type ask_user_for_return_type (
157
+ vast::server::server_base &server, core::function_op_interface op
158
+ ) {
159
+ auto loc = op.getLoc ();
160
+ auto sym = mlir::dyn_cast< core::SymbolOpInterface >(op.getOperation ());
161
+ VAST_ASSERT (sym);
162
+ auto name = sym.getSymbolName ().str ();
163
+
164
+ vast::server::input_request req{
165
+ .type = { " maybedata" , " nodata" , " data" },
166
+ .text = " Please choose return type for function `" + name + ' `' ,
167
+ .filePath = std::nullopt,
168
+ .range = std::nullopt,
169
+ };
170
+
171
+ if (auto req_loc = get_location (loc)) {
172
+ req.filePath = req_loc->filePath ;
173
+ req.range = req_loc->range ;
174
+ }
175
+
176
+ auto response = server.send_request (req);
177
+ if (auto result = std::get_if< vast::server::input_request::response_type >(&response))
178
+ {
179
+ return parse_type_name (result->value );
180
+ }
181
+ return pr::data_type::maybedata;
182
+ }
183
+
184
+ pr::data_type ask_user_for_argument_type (
185
+ vast::server::server_base &server, core::function_op_interface op, unsigned int idx
186
+ ) {
187
+ auto num_body_args = op.getFunctionBody ().getNumArguments ();
188
+ auto sym = mlir::dyn_cast< core::SymbolOpInterface >(op.getOperation ());
189
+ VAST_ASSERT (sym);
190
+ auto name = sym.getSymbolName ().str ();
191
+
192
+ vast::server::input_request req{
193
+ .type = { " maybedata" , " nodata" , " data" },
194
+ .text = " Please choose a type for argument " + std::to_string (idx)
195
+ + " of function `" + name + ' `' ,
196
+ .filePath = std::nullopt,
197
+ .range = std::nullopt,
198
+ };
199
+
200
+ if (idx < num_body_args) {
201
+ auto arg = op.getArgument (idx);
202
+ auto loc = arg.getLoc ();
203
+ if (auto req_loc = get_location (loc)) {
204
+ req.filePath = req_loc->filePath ;
205
+ req.range = req_loc->range ;
206
+ }
207
+ }
208
+
209
+ auto response = server.send_request (req);
210
+ if (auto result = std::get_if< vast::server::input_request::response_type >(&response))
211
+ {
212
+ return parse_type_name (result->value );
213
+ }
214
+ return pr::data_type::maybedata;
215
+ }
216
+
217
+ function_model ask_user_for_function_model (
218
+ vast::server::server_base &server, core::function_op_interface op
219
+ ) {
220
+ function_model model;
221
+ model.return_type = ask_user_for_return_type (server, op);
222
+ for (unsigned int i = 0 ; i < op.getNumArguments (); ++i) {
223
+ model.arguments .push_back (ask_user_for_argument_type (server, op, i));
224
+ }
225
+ model.category = ask_user_for_category (server, op);
226
+ return model;
227
+ }
228
+
78
229
} // namespace vast::conv
79
230
80
231
LLVM_YAML_IS_SEQUENCE_VECTOR (vast::pr::data_type);
@@ -130,25 +281,28 @@ namespace vast::conv {
130
281
using base = base_conversion_config;
131
282
132
283
parser_conversion_config (
133
- rewrite_pattern_set patterns, conversion_target target,
134
- const function_models &models
284
+ rewrite_pattern_set patterns, conversion_target target, function_models &models,
285
+ vast::server::server_base *server
135
286
)
136
- : base(std::move(patterns), std::move(target)), models(models)
137
- {}
287
+ : base(std::move(patterns), std::move(target)), models(models), server(server) {}
138
288
139
289
template < typename pattern >
140
290
void add_pattern () {
141
291
auto ctx = patterns.getContext ();
142
292
if constexpr (std::is_constructible_v< pattern, mcontext_t * >) {
143
293
patterns.template add < pattern >(ctx);
144
- } else if constexpr (std::is_constructible_v< pattern, mcontext_t *, const function_models & >) {
145
- patterns.template add < pattern >(ctx, models);
294
+ } else if constexpr (std::is_constructible_v<
295
+ pattern, mcontext_t *, function_models &,
296
+ vast::server::server_base * >)
297
+ {
298
+ patterns.template add < pattern >(ctx, models, server);
146
299
} else {
147
300
static_assert (false , " pattern does not have a valid constructor" );
148
301
}
149
302
}
150
303
151
- const function_models ⊧
304
+ function_models ⊧
305
+ vast::server::server_base *server;
152
306
};
153
307
154
308
struct function_type_converter
@@ -277,26 +431,36 @@ namespace vast::conv {
277
431
{
278
432
using base = mlir::OpConversionPattern< op_t >;
279
433
280
- parser_conversion_pattern_base (mcontext_t *mctx, const function_models &models)
281
- : base(mctx), models(models)
282
- {}
434
+ parser_conversion_pattern_base (
435
+ mcontext_t *mctx, function_models &models, vast::server::server_base *server
436
+ )
437
+ : base(mctx), models(models), server(server) {}
283
438
284
- static std::optional< function_model >
285
- get_model (const function_models &models, core::function_op_interface op) {
439
+ static std::optional< function_model > get_model (
440
+ function_models &models, core::function_op_interface op,
441
+ vast::server::server_base *server
442
+ ) {
286
443
auto sym = mlir::dyn_cast< core::SymbolOpInterface >(op.getOperation ());
287
444
VAST_ASSERT (sym);
288
445
if (auto kv = models.find (sym.getSymbolName ()); kv != models.end ()) {
289
446
return kv->second ;
290
447
}
291
448
449
+ if (server) {
450
+ auto model = ask_user_for_function_model (*server, op);
451
+ models[sym.getSymbolName ()] = model;
452
+ return model;
453
+ }
454
+
292
455
return std::nullopt;
293
456
}
294
457
295
458
std::optional< function_model > get_model (core::function_op_interface op) const {
296
- return get_model (models, op);
459
+ return get_model (models, op, server );
297
460
}
298
461
299
- const function_models ⊧
462
+ function_models ⊧
463
+ vast::server::server_base *server;
300
464
};
301
465
302
466
//
@@ -543,10 +707,13 @@ namespace vast::conv {
543
707
return mlir::failure ();
544
708
}
545
709
546
- static void legalize (parser_conversion_config &cfg) {
710
+ static void
711
+ legalize (parser_conversion_config &cfg, vast::server::server_base *server) {
547
712
cfg.target .addLegalOp < mlir::UnrealizedConversionCastOp >();
548
- cfg.target .addDynamicallyLegalOp < op_t >([models = cfg.models ](op_t op) {
549
- return function_type_converter (*op.getContext (), get_model (models, op))
713
+ cfg.target .addDynamicallyLegalOp < op_t >([&cfg, server](op_t op) {
714
+ return function_type_converter (
715
+ *op.getContext (), get_model (cfg.models , op, server)
716
+ )
550
717
.isLegal (op.getFunctionType ());
551
718
});
552
719
}
@@ -724,6 +891,9 @@ namespace vast::conv {
724
891
{
725
892
using base = ConversionPassMixin< HLToParserPass, HLToParserBase >;
726
893
894
+ struct server_handler
895
+ {};
896
+
727
897
static conversion_target create_conversion_target (mcontext_t &mctx) {
728
898
return conversion_target (mctx);
729
899
}
@@ -738,6 +908,12 @@ namespace vast::conv {
738
908
if (!config.empty ()) {
739
909
load_and_parse (config);
740
910
}
911
+
912
+ if (!socket.empty ()) {
913
+ server = std::make_shared< vast::server::server< server_handler > >(
914
+ vast::server::sock_adapter::create_unix_socket (socket)
915
+ );
916
+ }
741
917
}
742
918
743
919
void load_and_parse (string_ref config) {
@@ -764,10 +940,12 @@ namespace vast::conv {
764
940
765
941
parser_conversion_config make_config () {
766
942
auto &ctx = getContext ();
767
- return { rewrite_pattern_set (&ctx), create_conversion_target (ctx), models };
943
+ return { rewrite_pattern_set (&ctx), create_conversion_target (ctx), models,
944
+ server.get () };
768
945
}
769
946
770
947
function_models models;
948
+ std::shared_ptr< vast::server::server< server_handler > > server;
771
949
};
772
950
773
951
} // namespace vast::conv
0 commit comments