Skip to content

Commit 3d2f75d

Browse files
committed
pr: Add option to use interactive server in parser conversion pass.
1 parent 1a65ed4 commit 3d2f75d

File tree

4 files changed

+232
-19
lines changed

4 files changed

+232
-19
lines changed

include/vast/Conversion/Parser/Passes.hpp

+2
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
#include "vast/Util/Warnings.hpp"
66

7+
#include "vast/server/server.hpp"
8+
79
VAST_RELAX_WARNINGS
810
#include <mlir/Pass/Pass.h>
911
#include <mlir/Pass/PassManager.h>

include/vast/Conversion/Parser/Passes.td

+3
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@ def HLToParser : Pass<"vast-hl-to-parser", "core::ModuleOp"> {
1010
let options = [
1111
Option< "config", "config", "std::string", "",
1212
"Configuration file for parser transformation."
13+
>,
14+
Option< "socket", "socket", "std::string", "",
15+
"Unix socket path to use for server"
1316
>
1417
];
1518

include/vast/server/types.hpp

+45-1
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,36 @@
44

55
#include <concepts>
66
#include <cstdint>
7+
#include <optional>
78
#include <string>
89
#include <variant>
910

1011
#include <nlohmann/json.hpp>
1112

13+
namespace nlohmann {
14+
template< typename T >
15+
struct adl_serializer< std::optional< T > >
16+
{
17+
static void to_json(json &j, const std::optional< T > &opt) {
18+
if (!opt.has_value()) {
19+
j = nullptr;
20+
} else {
21+
j = *opt; // this will call adl_serializer<T>::to_json which will
22+
// find the free function to_json in T's namespace!
23+
}
24+
}
25+
26+
static void from_json(const json &j, std::optional< T > &opt) {
27+
if (j.is_null()) {
28+
opt = std::nullopt;
29+
} else {
30+
opt = j.template get< T >(); // same as above, but with
31+
// adl_serializer<T>::from_json
32+
}
33+
}
34+
};
35+
} // namespace nlohmann
36+
1237
namespace vast::server {
1338
template< typename T >
1439
concept json_convertible = requires(T obj, nlohmann::json &json) {
@@ -88,14 +113,33 @@ namespace vast::server {
88113
template< request_like request >
89114
using result_type = std::variant< typename request::response_type, error< request > >;
90115

116+
struct position
117+
{
118+
unsigned int line;
119+
unsigned int character;
120+
121+
NLOHMANN_DEFINE_TYPE_INTRUSIVE(position, line, character)
122+
};
123+
124+
struct range
125+
{
126+
position start;
127+
position end;
128+
129+
NLOHMANN_DEFINE_TYPE_INTRUSIVE(range, start, end)
130+
};
131+
91132
struct input_request
92133
{
93134
static constexpr const char *method = "input";
94135
static constexpr bool is_notification = false;
95136

96137
nlohmann::json type;
97138
std::string text;
98-
NLOHMANN_DEFINE_TYPE_INTRUSIVE(input_request, type, text)
139+
std::optional< std::string > filePath;
140+
std::optional< range > range;
141+
142+
NLOHMANN_DEFINE_TYPE_INTRUSIVE(input_request, type, text, filePath, range)
99143

100144
struct response_type
101145
{

lib/vast/Conversion/Parser/ToParser.cpp

+182-18
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@ VAST_UNRELAX_WARNINGS
2929

3030
#include "vast/Conversion/Parser/Config.hpp"
3131

32+
#include "vast/server/server.hpp"
33+
#include "vast/server/types.hpp"
34+
3235
#include <ranges>
3336

3437
namespace vast::conv {
@@ -75,6 +78,141 @@ namespace vast::conv {
7578

7679
using function_models = llvm::StringMap< function_model >;
7780

81+
struct location
82+
{
83+
std::string filePath;
84+
server::range range;
85+
};
86+
87+
location get_location(file_loc_t loc) {
88+
return {
89+
.filePath = loc.getFilename().str(),
90+
.range = {
91+
.start = { loc.getLine(), loc.getColumn(), },
92+
.end = { loc.getLine(), loc.getColumn(), },
93+
},
94+
};
95+
}
96+
97+
location get_location(name_loc_t loc) {
98+
return get_location(mlir::cast< file_loc_t >(loc.getChildLoc()));
99+
}
100+
101+
std::optional< location > get_location(loc_t loc) {
102+
if (auto file_loc = mlir::dyn_cast< file_loc_t >(loc)) {
103+
return get_location(file_loc);
104+
} else if (auto name_loc = mlir::dyn_cast< name_loc_t >(loc)) {
105+
return get_location(name_loc);
106+
}
107+
108+
return std::nullopt;
109+
}
110+
111+
pr::data_type parse_type_name(const std::string &name) {
112+
if (name == "data") {
113+
return pr::data_type::data;
114+
} else if (name == "nodata") {
115+
return pr::data_type::nodata;
116+
} else {
117+
return pr::data_type::maybedata;
118+
}
119+
}
120+
121+
function_category ask_user_for_category(vast::server::server_base &server, hl::FuncOp op) {
122+
auto loc = op.getLoc();
123+
124+
vast::server::input_request req{
125+
.type = {"nonparser", "sink", "source", "parser",},
126+
.text = "Please choose category for function `" + op.getSymName().str() + '`',
127+
.filePath = std::nullopt,
128+
.range = std::nullopt,
129+
};
130+
131+
if (auto req_loc = get_location(loc)) {
132+
req.filePath = req_loc->filePath;
133+
req.range = req_loc->range;
134+
}
135+
136+
auto response = server.send_request(req);
137+
if (auto result = std::get_if< vast::server::input_request::response_type >(&response))
138+
{
139+
if (result->value == "nonparser") {
140+
return function_category::nonparser;
141+
} else if (result->value == "sink") {
142+
return function_category::sink;
143+
} else if (result->value == "source") {
144+
return function_category::source;
145+
} else if (result->value == "parser") {
146+
return function_category::parser;
147+
}
148+
}
149+
return function_category::nonparser;
150+
}
151+
152+
pr::data_type ask_user_for_return_type(vast::server::server_base &server, hl::FuncOp op) {
153+
auto loc = op.getLoc();
154+
155+
vast::server::input_request req{
156+
.type = { "maybedata", "nodata", "data" },
157+
.text = "Please choose return type for function `" + op.getSymName().str() + '`',
158+
.filePath = std::nullopt,
159+
.range = std::nullopt,
160+
};
161+
162+
if (auto req_loc = get_location(loc)) {
163+
req.filePath = req_loc->filePath;
164+
req.range = req_loc->range;
165+
}
166+
167+
auto response = server.send_request(req);
168+
if (auto result = std::get_if< vast::server::input_request::response_type >(&response))
169+
{
170+
return parse_type_name(result->value);
171+
}
172+
return pr::data_type::maybedata;
173+
}
174+
175+
pr::data_type ask_user_for_argument_type(
176+
vast::server::server_base &server, hl::FuncOp op, unsigned int idx
177+
) {
178+
auto num_body_args = op.getFunctionBody().getNumArguments();
179+
180+
vast::server::input_request req{
181+
.type = { "maybedata", "nodata", "data" },
182+
.text = "Please choose a type for argument " + std::to_string(idx)
183+
+ " of function `" + op.getSymName().str() + '`',
184+
.filePath = std::nullopt,
185+
.range = std::nullopt,
186+
};
187+
188+
if (idx < num_body_args) {
189+
auto arg = op.getArgument(idx);
190+
auto loc = arg.getLoc();
191+
if (auto req_loc = get_location(loc)) {
192+
req.filePath = req_loc->filePath;
193+
req.range = req_loc->range;
194+
}
195+
}
196+
197+
auto response = server.send_request(req);
198+
if (auto result = std::get_if< vast::server::input_request::response_type >(&response))
199+
{
200+
return parse_type_name(result->value);
201+
}
202+
return pr::data_type::maybedata;
203+
}
204+
205+
function_model
206+
ask_user_for_function_model(vast::server::server_base &server, hl::FuncOp op) {
207+
function_model model;
208+
model.return_type = ask_user_for_return_type(server, op);
209+
for (unsigned int i = 0; i < op.getNumArguments(); ++i) {
210+
model.arguments.push_back(ask_user_for_argument_type(server, op, i));
211+
}
212+
model.category = ask_user_for_category(server, op);
213+
return model;
214+
}
215+
78216
} // namespace vast::conv
79217

80218
LLVM_YAML_IS_SEQUENCE_VECTOR(vast::pr::data_type);
@@ -130,25 +268,28 @@ namespace vast::conv {
130268
using base = base_conversion_config;
131269

132270
parser_conversion_config(
133-
rewrite_pattern_set patterns, conversion_target target,
134-
const function_models &models
271+
rewrite_pattern_set patterns, conversion_target target, function_models &models,
272+
vast::server::server_base *server
135273
)
136-
: base(std::move(patterns), std::move(target)), models(models)
137-
{}
274+
: base(std::move(patterns), std::move(target)), models(models), server(server) {}
138275

139276
template< typename pattern >
140277
void add_pattern() {
141278
auto ctx = patterns.getContext();
142279
if constexpr (std::is_constructible_v< pattern, mcontext_t * >) {
143280
patterns.template add< pattern >(ctx);
144-
} else if constexpr (std::is_constructible_v< pattern, mcontext_t *, const function_models & >) {
145-
patterns.template add< pattern >(ctx, models);
281+
} else if constexpr (std::is_constructible_v<
282+
pattern, mcontext_t *, function_models &,
283+
vast::server::server_base * >)
284+
{
285+
patterns.template add< pattern >(ctx, models, server);
146286
} else {
147287
static_assert(false, "pattern does not have a valid constructor");
148288
}
149289
}
150290

151-
const function_models &models;
291+
function_models &models;
292+
vast::server::server_base *server;
152293
};
153294

154295
struct function_type_converter
@@ -277,24 +418,33 @@ namespace vast::conv {
277418
{
278419
using base = mlir::OpConversionPattern< op_t >;
279420

280-
parser_conversion_pattern_base(mcontext_t *mctx, const function_models &models)
281-
: base(mctx), models(models)
282-
{}
421+
parser_conversion_pattern_base(
422+
mcontext_t *mctx, function_models &models, vast::server::server_base *server
423+
)
424+
: base(mctx), models(models), server(server) {}
283425

284-
static std::optional< function_model >
285-
get_model(const function_models &models, hl::FuncOp func) {
426+
static std::optional< function_model > get_model(
427+
function_models &models, hl::FuncOp func, vast::server::server_base *server
428+
) {
286429
if (auto kv = models.find(func.getSymName()); kv != models.end()) {
287430
return kv->second;
288431
}
289432

433+
if (server) {
434+
auto model = ask_user_for_function_model(*server, func);
435+
models[func.getSymName()] = model;
436+
return model;
437+
}
438+
290439
return std::nullopt;
291440
}
292441

293442
std::optional< function_model > get_model(hl::FuncOp func) const {
294-
return get_model(models, func);
443+
return get_model(models, func, server);
295444
}
296445

297-
const function_models &models;
446+
function_models &models;
447+
vast::server::server_base *server;
298448
};
299449

300450
//
@@ -541,10 +691,13 @@ namespace vast::conv {
541691
return mlir::failure();
542692
}
543693

544-
static void legalize(parser_conversion_config &cfg) {
694+
static void
695+
legalize(parser_conversion_config &cfg, vast::server::server_base *server) {
545696
cfg.target.addLegalOp< mlir::UnrealizedConversionCastOp >();
546-
cfg.target.addDynamicallyLegalOp< op_t >([models = cfg.models](op_t op) {
547-
return function_type_converter(*op.getContext(), get_model(models, op))
697+
cfg.target.addDynamicallyLegalOp< op_t >([&cfg, server](op_t op) {
698+
return function_type_converter(
699+
*op.getContext(), get_model(cfg.models, op, server)
700+
)
548701
.isLegal(op.getFunctionType());
549702
});
550703
}
@@ -708,6 +861,9 @@ namespace vast::conv {
708861
{
709862
using base = ConversionPassMixin< HLToParserPass, HLToParserBase >;
710863

864+
struct server_handler
865+
{};
866+
711867
static conversion_target create_conversion_target(mcontext_t &mctx) {
712868
return conversion_target(mctx);
713869
}
@@ -722,6 +878,12 @@ namespace vast::conv {
722878
if (!config.empty()) {
723879
load_and_parse(config);
724880
}
881+
882+
if (!socket.empty()) {
883+
server = std::make_shared< vast::server::server< server_handler > >(
884+
vast::server::sock_adapter::create_unix_socket(socket)
885+
);
886+
}
725887
}
726888

727889
void load_and_parse(string_ref config) {
@@ -748,10 +910,12 @@ namespace vast::conv {
748910

749911
parser_conversion_config make_config() {
750912
auto &ctx = getContext();
751-
return { rewrite_pattern_set(&ctx), create_conversion_target(ctx), models };
913+
return { rewrite_pattern_set(&ctx), create_conversion_target(ctx), models,
914+
server.get() };
752915
}
753916

754917
function_models models;
918+
std::shared_ptr< vast::server::server< server_handler > > server;
755919
};
756920

757921
} // namespace vast::conv

0 commit comments

Comments
 (0)