Skip to content

Commit 63106ed

Browse files
committed
pr: Add option to use interactive server in parser conversion pass.
1 parent 1a65ed4 commit 63106ed

File tree

4 files changed

+226
-19
lines changed

4 files changed

+226
-19
lines changed

include/vast/Conversion/Parser/Passes.hpp

+2
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
#include "vast/Util/Warnings.hpp"
66

7+
#include "vast/server/server.hpp"
8+
79
VAST_RELAX_WARNINGS
810
#include <mlir/Pass/Pass.h>
911
#include <mlir/Pass/PassManager.h>

include/vast/Conversion/Parser/Passes.td

+3
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@ def HLToParser : Pass<"vast-hl-to-parser", "core::ModuleOp"> {
1010
let options = [
1111
Option< "config", "config", "std::string", "",
1212
"Configuration file for parser transformation."
13+
>,
14+
Option< "socket", "socket", "std::string", "",
15+
"Unix socket path to use for server"
1316
>
1417
];
1518

include/vast/server/types.hpp

+45-1
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,36 @@
44

55
#include <concepts>
66
#include <cstdint>
7+
#include <optional>
78
#include <string>
89
#include <variant>
910

1011
#include <nlohmann/json.hpp>
1112

13+
namespace nlohmann {
14+
template< typename T >
15+
struct adl_serializer< std::optional< T > >
16+
{
17+
static void to_json(json &j, const std::optional< T > &opt) {
18+
if (!opt.has_value()) {
19+
j = nullptr;
20+
} else {
21+
j = *opt; // this will call adl_serializer<T>::to_json which will
22+
// find the free function to_json in T's namespace!
23+
}
24+
}
25+
26+
static void from_json(const json &j, std::optional< T > &opt) {
27+
if (j.is_null()) {
28+
opt = std::nullopt;
29+
} else {
30+
opt = j.template get< T >(); // same as above, but with
31+
// adl_serializer<T>::from_json
32+
}
33+
}
34+
};
35+
} // namespace nlohmann
36+
1237
namespace vast::server {
1338
template< typename T >
1439
concept json_convertible = requires(T obj, nlohmann::json &json) {
@@ -88,14 +113,33 @@ namespace vast::server {
88113
template< request_like request >
89114
using result_type = std::variant< typename request::response_type, error< request > >;
90115

116+
struct position
117+
{
118+
unsigned int line;
119+
unsigned int character;
120+
121+
NLOHMANN_DEFINE_TYPE_INTRUSIVE(position, line, character)
122+
};
123+
124+
struct range
125+
{
126+
position start;
127+
position end;
128+
129+
NLOHMANN_DEFINE_TYPE_INTRUSIVE(range, start, end)
130+
};
131+
91132
struct input_request
92133
{
93134
static constexpr const char *method = "input";
94135
static constexpr bool is_notification = false;
95136

96137
nlohmann::json type;
97138
std::string text;
98-
NLOHMANN_DEFINE_TYPE_INTRUSIVE(input_request, type, text)
139+
std::optional< std::string > filePath;
140+
std::optional< range > range;
141+
142+
NLOHMANN_DEFINE_TYPE_INTRUSIVE(input_request, type, text, filePath, range)
99143

100144
struct response_type
101145
{

lib/vast/Conversion/Parser/ToParser.cpp

+176-18
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@ VAST_UNRELAX_WARNINGS
2929

3030
#include "vast/Conversion/Parser/Config.hpp"
3131

32+
#include "vast/server/server.hpp"
33+
#include "vast/server/types.hpp"
34+
3235
#include <ranges>
3336

3437
namespace vast::conv {
@@ -75,6 +78,135 @@ namespace vast::conv {
7578

7679
using function_models = llvm::StringMap< function_model >;
7780

81+
struct location
82+
{
83+
std::string filePath;
84+
server::range range;
85+
};
86+
87+
location get_location(file_loc_t loc) {
88+
return {
89+
.filePath = loc.getFilename().str(),
90+
.range = {
91+
.start = { loc.getLine(), loc.getColumn(), },
92+
.end = { loc.getLine(), loc.getColumn(), },
93+
},
94+
};
95+
}
96+
97+
location get_location(name_loc_t loc) {
98+
return get_location(mlir::cast< file_loc_t >(loc.getChildLoc()));
99+
}
100+
101+
std::optional< location > get_location(loc_t loc) {
102+
if (auto file_loc = mlir::dyn_cast< file_loc_t >(loc)) {
103+
return get_location(file_loc);
104+
} else if (auto name_loc = mlir::dyn_cast< name_loc_t >(loc)) {
105+
return get_location(name_loc);
106+
}
107+
108+
return std::nullopt;
109+
}
110+
111+
pr::data_type parse_type_name(const std::string &name) {
112+
if (name == "data") {
113+
return pr::data_type::data;
114+
} else if (name == "nodata") {
115+
return pr::data_type::nodata;
116+
} else {
117+
return pr::data_type::maybedata;
118+
}
119+
}
120+
121+
function_category ask_user_for_category(vast::server::server_base &server, hl::FuncOp op) {
122+
auto loc = op.getLoc();
123+
124+
vast::server::input_request req{
125+
.type = {"nonparser", "sink", "source", "parser",},
126+
.text = "Please choose category for function `" + op.getSymName().str() + '`',
127+
};
128+
129+
if (auto req_loc = get_location(loc)) {
130+
req.filePath = req_loc->filePath;
131+
req.range = req_loc->range;
132+
}
133+
134+
auto response = server.send_request(req);
135+
if (auto result = std::get_if< vast::server::input_request::response_type >(&response))
136+
{
137+
if (result->value == "nonparser") {
138+
return function_category::nonparser;
139+
} else if (result->value == "sink") {
140+
return function_category::sink;
141+
} else if (result->value == "source") {
142+
return function_category::source;
143+
} else if (result->value == "parser") {
144+
return function_category::parser;
145+
}
146+
}
147+
return function_category::nonparser;
148+
}
149+
150+
pr::data_type ask_user_for_return_type(vast::server::server_base &server, hl::FuncOp op) {
151+
auto loc = op.getLoc();
152+
153+
vast::server::input_request req{
154+
.type = { "maybedata", "nodata", "data" },
155+
.text = "Please choose return type for function `" + op.getSymName().str() + '`',
156+
};
157+
158+
if (auto req_loc = get_location(loc)) {
159+
req.filePath = req_loc->filePath;
160+
req.range = req_loc->range;
161+
}
162+
163+
auto response = server.send_request(req);
164+
if (auto result = std::get_if< vast::server::input_request::response_type >(&response))
165+
{
166+
return parse_type_name(result->value);
167+
}
168+
return pr::data_type::maybedata;
169+
}
170+
171+
pr::data_type ask_user_for_argument_type(
172+
vast::server::server_base &server, hl::FuncOp op, unsigned int idx
173+
) {
174+
auto num_body_args = op.getFunctionBody().getNumArguments();
175+
176+
vast::server::input_request req{
177+
.type = { "maybedata", "nodata", "data" },
178+
.text = "Please choose a type for argument " + std::to_string(idx)
179+
+ " of function `" + op.getSymName().str() + '`',
180+
};
181+
182+
if (idx < num_body_args) {
183+
auto arg = op.getArgument(idx);
184+
auto loc = arg.getLoc();
185+
if (auto req_loc = get_location(loc)) {
186+
req.filePath = req_loc->filePath;
187+
req.range = req_loc->range;
188+
}
189+
}
190+
191+
auto response = server.send_request(req);
192+
if (auto result = std::get_if< vast::server::input_request::response_type >(&response))
193+
{
194+
return parse_type_name(result->value);
195+
}
196+
return pr::data_type::maybedata;
197+
}
198+
199+
function_model
200+
ask_user_for_function_model(vast::server::server_base &server, hl::FuncOp op) {
201+
function_model model;
202+
model.return_type = ask_user_for_return_type(server, op);
203+
for (unsigned int i = 0; i < op.getNumArguments(); ++i) {
204+
model.arguments.push_back(ask_user_for_argument_type(server, op, i));
205+
}
206+
model.category = ask_user_for_category(server, op);
207+
return model;
208+
}
209+
78210
} // namespace vast::conv
79211

80212
LLVM_YAML_IS_SEQUENCE_VECTOR(vast::pr::data_type);
@@ -130,25 +262,28 @@ namespace vast::conv {
130262
using base = base_conversion_config;
131263

132264
parser_conversion_config(
133-
rewrite_pattern_set patterns, conversion_target target,
134-
const function_models &models
265+
rewrite_pattern_set patterns, conversion_target target, function_models &models,
266+
vast::server::server_base *server
135267
)
136-
: base(std::move(patterns), std::move(target)), models(models)
137-
{}
268+
: base(std::move(patterns), std::move(target)), models(models), server(server) {}
138269

139270
template< typename pattern >
140271
void add_pattern() {
141272
auto ctx = patterns.getContext();
142273
if constexpr (std::is_constructible_v< pattern, mcontext_t * >) {
143274
patterns.template add< pattern >(ctx);
144-
} else if constexpr (std::is_constructible_v< pattern, mcontext_t *, const function_models & >) {
145-
patterns.template add< pattern >(ctx, models);
275+
} else if constexpr (std::is_constructible_v<
276+
pattern, mcontext_t *, function_models &,
277+
vast::server::server_base * >)
278+
{
279+
patterns.template add< pattern >(ctx, models, server);
146280
} else {
147281
static_assert(false, "pattern does not have a valid constructor");
148282
}
149283
}
150284

151-
const function_models &models;
285+
function_models &models;
286+
vast::server::server_base *server;
152287
};
153288

154289
struct function_type_converter
@@ -277,24 +412,33 @@ namespace vast::conv {
277412
{
278413
using base = mlir::OpConversionPattern< op_t >;
279414

280-
parser_conversion_pattern_base(mcontext_t *mctx, const function_models &models)
281-
: base(mctx), models(models)
282-
{}
415+
parser_conversion_pattern_base(
416+
mcontext_t *mctx, function_models &models, vast::server::server_base *server
417+
)
418+
: base(mctx), models(models), server(server) {}
283419

284-
static std::optional< function_model >
285-
get_model(const function_models &models, hl::FuncOp func) {
420+
static std::optional< function_model > get_model(
421+
function_models &models, hl::FuncOp func, vast::server::server_base *server
422+
) {
286423
if (auto kv = models.find(func.getSymName()); kv != models.end()) {
287424
return kv->second;
288425
}
289426

427+
if (server) {
428+
auto model = ask_user_for_function_model(*server, func);
429+
models[func.getSymName()] = model;
430+
return model;
431+
}
432+
290433
return std::nullopt;
291434
}
292435

293436
std::optional< function_model > get_model(hl::FuncOp func) const {
294-
return get_model(models, func);
437+
return get_model(models, func, server);
295438
}
296439

297-
const function_models &models;
440+
function_models &models;
441+
vast::server::server_base *server;
298442
};
299443

300444
//
@@ -541,10 +685,13 @@ namespace vast::conv {
541685
return mlir::failure();
542686
}
543687

544-
static void legalize(parser_conversion_config &cfg) {
688+
static void
689+
legalize(parser_conversion_config &cfg, vast::server::server_base *server) {
545690
cfg.target.addLegalOp< mlir::UnrealizedConversionCastOp >();
546-
cfg.target.addDynamicallyLegalOp< op_t >([models = cfg.models](op_t op) {
547-
return function_type_converter(*op.getContext(), get_model(models, op))
691+
cfg.target.addDynamicallyLegalOp< op_t >([&cfg, server](op_t op) {
692+
return function_type_converter(
693+
*op.getContext(), get_model(cfg.models, op, server)
694+
)
548695
.isLegal(op.getFunctionType());
549696
});
550697
}
@@ -708,6 +855,9 @@ namespace vast::conv {
708855
{
709856
using base = ConversionPassMixin< HLToParserPass, HLToParserBase >;
710857

858+
struct server_handler
859+
{};
860+
711861
static conversion_target create_conversion_target(mcontext_t &mctx) {
712862
return conversion_target(mctx);
713863
}
@@ -722,6 +872,12 @@ namespace vast::conv {
722872
if (!config.empty()) {
723873
load_and_parse(config);
724874
}
875+
876+
if (!socket.empty()) {
877+
server = std::make_shared< vast::server::server< server_handler > >(
878+
vast::server::sock_adapter::create_unix_socket(socket)
879+
);
880+
}
725881
}
726882

727883
void load_and_parse(string_ref config) {
@@ -748,10 +904,12 @@ namespace vast::conv {
748904

749905
parser_conversion_config make_config() {
750906
auto &ctx = getContext();
751-
return { rewrite_pattern_set(&ctx), create_conversion_target(ctx), models };
907+
return { rewrite_pattern_set(&ctx), create_conversion_target(ctx), models,
908+
server.get() };
752909
}
753910

754911
function_models models;
912+
std::shared_ptr< vast::server::server< server_handler > > server;
755913
};
756914

757915
} // namespace vast::conv

0 commit comments

Comments
 (0)