Skip to content

Commit e6298fb

Browse files
committed
pr: Add option to use interactive server in parser conversion pass.
1 parent 34718b2 commit e6298fb

File tree

4 files changed

+244
-19
lines changed

4 files changed

+244
-19
lines changed

include/vast/Conversion/Parser/Passes.hpp

+2
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
#include "vast/Util/Warnings.hpp"
66

7+
#include "vast/server/server.hpp"
8+
79
VAST_RELAX_WARNINGS
810
#include <mlir/Pass/Pass.h>
911
#include <mlir/Pass/PassManager.h>

include/vast/Conversion/Parser/Passes.td

+3
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@ def HLToParser : Pass<"vast-hl-to-parser", "core::ModuleOp"> {
1010
let options = [
1111
Option< "config", "config", "std::string", "",
1212
"Configuration file for parser transformation."
13+
>,
14+
Option< "socket", "socket", "std::string", "",
15+
"Unix socket path to use for server"
1316
>
1417
];
1518

include/vast/server/types.hpp

+43-1
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,34 @@
44

55
#include <concepts>
66
#include <cstdint>
7+
#include <optional>
78
#include <string>
89
#include <variant>
910

1011
#include <nlohmann/json.hpp>
1112

13+
namespace nlohmann {
14+
template< typename T >
15+
struct adl_serializer< std::optional< T > >
16+
{
17+
static void to_json(json &j, const std::optional< T > &opt) {
18+
if (!opt.has_value()) {
19+
j = nullptr;
20+
} else {
21+
j = *opt;
22+
}
23+
}
24+
25+
static void from_json(const json &j, std::optional< T > &opt) {
26+
if (j.is_null()) {
27+
opt = std::nullopt;
28+
} else {
29+
opt = j.template get< T >();
30+
}
31+
}
32+
};
33+
} // namespace nlohmann
34+
1235
namespace vast::server {
1336
template< typename T >
1437
concept json_convertible = requires(T obj, nlohmann::json &json) {
@@ -88,14 +111,33 @@ namespace vast::server {
88111
template< request_like request >
89112
using result_type = std::variant< typename request::response_type, error< request > >;
90113

114+
struct position
115+
{
116+
unsigned int line;
117+
unsigned int character;
118+
119+
NLOHMANN_DEFINE_TYPE_INTRUSIVE(position, line, character)
120+
};
121+
122+
struct range
123+
{
124+
position start;
125+
position end;
126+
127+
NLOHMANN_DEFINE_TYPE_INTRUSIVE(range, start, end)
128+
};
129+
91130
struct input_request
92131
{
93132
static constexpr const char *method = "input";
94133
static constexpr bool is_notification = false;
95134

96135
nlohmann::json type;
97136
std::string text;
98-
NLOHMANN_DEFINE_TYPE_INTRUSIVE(input_request, type, text)
137+
std::optional< std::string > filePath;
138+
std::optional< range > range;
139+
140+
NLOHMANN_DEFINE_TYPE_INTRUSIVE(input_request, type, text, filePath, range)
99141

100142
struct response_type
101143
{

lib/vast/Conversion/Parser/ToParser.cpp

+196-18
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@ VAST_UNRELAX_WARNINGS
2929

3030
#include "vast/Conversion/Parser/Config.hpp"
3131

32+
#include "vast/server/server.hpp"
33+
#include "vast/server/types.hpp"
34+
3235
#include <ranges>
3336

3437
namespace vast::conv {
@@ -75,6 +78,154 @@ namespace vast::conv {
7578

7679
using function_models = llvm::StringMap< function_model >;
7780

81+
struct location
82+
{
83+
std::string filePath;
84+
server::range range;
85+
};
86+
87+
location get_location(file_loc_t loc) {
88+
return {
89+
.filePath = loc.getFilename().str(),
90+
.range = {
91+
.start = { loc.getLine(), loc.getColumn(), },
92+
.end = { loc.getLine(), loc.getColumn(), },
93+
},
94+
};
95+
}
96+
97+
location get_location(name_loc_t loc) {
98+
return get_location(mlir::cast< file_loc_t >(loc.getChildLoc()));
99+
}
100+
101+
std::optional< location > get_location(loc_t loc) {
102+
if (auto file_loc = mlir::dyn_cast< file_loc_t >(loc)) {
103+
return get_location(file_loc);
104+
} else if (auto name_loc = mlir::dyn_cast< name_loc_t >(loc)) {
105+
return get_location(name_loc);
106+
}
107+
108+
return std::nullopt;
109+
}
110+
111+
pr::data_type parse_type_name(const std::string &name) {
112+
if (name == "data") {
113+
return pr::data_type::data;
114+
} else if (name == "nodata") {
115+
return pr::data_type::nodata;
116+
} else {
117+
return pr::data_type::maybedata;
118+
}
119+
}
120+
121+
function_category
122+
ask_user_for_category(vast::server::server_base &server, core::function_op_interface op) {
123+
auto loc = op.getLoc();
124+
auto sym = mlir::dyn_cast< core::SymbolOpInterface >(op.getOperation());
125+
VAST_ASSERT(sym);
126+
auto name = sym.getSymbolName().str();
127+
128+
vast::server::input_request req{
129+
.type = {"nonparser", "sink", "source", "parser",},
130+
.text = "Please choose category for function `" + name + '`',
131+
.filePath = std::nullopt,
132+
.range = std::nullopt,
133+
};
134+
135+
if (auto req_loc = get_location(loc)) {
136+
req.filePath = req_loc->filePath;
137+
req.range = req_loc->range;
138+
}
139+
140+
auto response = server.send_request(req);
141+
if (auto result = std::get_if< vast::server::input_request::response_type >(&response))
142+
{
143+
if (result->value == "nonparser") {
144+
return function_category::nonparser;
145+
} else if (result->value == "sink") {
146+
return function_category::sink;
147+
} else if (result->value == "source") {
148+
return function_category::source;
149+
} else if (result->value == "parser") {
150+
return function_category::parser;
151+
}
152+
}
153+
return function_category::nonparser;
154+
}
155+
156+
pr::data_type ask_user_for_return_type(
157+
vast::server::server_base &server, core::function_op_interface op
158+
) {
159+
auto loc = op.getLoc();
160+
auto sym = mlir::dyn_cast< core::SymbolOpInterface >(op.getOperation());
161+
VAST_ASSERT(sym);
162+
auto name = sym.getSymbolName().str();
163+
164+
vast::server::input_request req{
165+
.type = { "maybedata", "nodata", "data" },
166+
.text = "Please choose return type for function `" + name + '`',
167+
.filePath = std::nullopt,
168+
.range = std::nullopt,
169+
};
170+
171+
if (auto req_loc = get_location(loc)) {
172+
req.filePath = req_loc->filePath;
173+
req.range = req_loc->range;
174+
}
175+
176+
auto response = server.send_request(req);
177+
if (auto result = std::get_if< vast::server::input_request::response_type >(&response))
178+
{
179+
return parse_type_name(result->value);
180+
}
181+
return pr::data_type::maybedata;
182+
}
183+
184+
pr::data_type ask_user_for_argument_type(
185+
vast::server::server_base &server, core::function_op_interface op, unsigned int idx
186+
) {
187+
auto num_body_args = op.getFunctionBody().getNumArguments();
188+
auto sym = mlir::dyn_cast< core::SymbolOpInterface >(op.getOperation());
189+
VAST_ASSERT(sym);
190+
auto name = sym.getSymbolName().str();
191+
192+
vast::server::input_request req{
193+
.type = { "maybedata", "nodata", "data" },
194+
.text = "Please choose a type for argument " + std::to_string(idx)
195+
+ " of function `" + name + '`',
196+
.filePath = std::nullopt,
197+
.range = std::nullopt,
198+
};
199+
200+
if (idx < num_body_args) {
201+
auto arg = op.getArgument(idx);
202+
auto loc = arg.getLoc();
203+
if (auto req_loc = get_location(loc)) {
204+
req.filePath = req_loc->filePath;
205+
req.range = req_loc->range;
206+
}
207+
}
208+
209+
auto response = server.send_request(req);
210+
if (auto result = std::get_if< vast::server::input_request::response_type >(&response))
211+
{
212+
return parse_type_name(result->value);
213+
}
214+
return pr::data_type::maybedata;
215+
}
216+
217+
function_model ask_user_for_function_model(
218+
vast::server::server_base &server, core::function_op_interface op
219+
) {
220+
function_model model;
221+
model.return_type = ask_user_for_return_type(server, op);
222+
for (unsigned int i = 0; i < op.getNumArguments(); ++i) {
223+
model.arguments.push_back(ask_user_for_argument_type(server, op, i));
224+
}
225+
model.category = ask_user_for_category(server, op);
226+
return model;
227+
}
228+
78229
} // namespace vast::conv
79230

80231
LLVM_YAML_IS_SEQUENCE_VECTOR(vast::pr::data_type);
@@ -130,25 +281,28 @@ namespace vast::conv {
130281
using base = base_conversion_config;
131282

132283
parser_conversion_config(
133-
rewrite_pattern_set patterns, conversion_target target,
134-
const function_models &models
284+
rewrite_pattern_set patterns, conversion_target target, function_models &models,
285+
vast::server::server_base *server
135286
)
136-
: base(std::move(patterns), std::move(target)), models(models)
137-
{}
287+
: base(std::move(patterns), std::move(target)), models(models), server(server) {}
138288

139289
template< typename pattern >
140290
void add_pattern() {
141291
auto ctx = patterns.getContext();
142292
if constexpr (std::is_constructible_v< pattern, mcontext_t * >) {
143293
patterns.template add< pattern >(ctx);
144-
} else if constexpr (std::is_constructible_v< pattern, mcontext_t *, const function_models & >) {
145-
patterns.template add< pattern >(ctx, models);
294+
} else if constexpr (std::is_constructible_v<
295+
pattern, mcontext_t *, function_models &,
296+
vast::server::server_base * >)
297+
{
298+
patterns.template add< pattern >(ctx, models, server);
146299
} else {
147300
static_assert(false, "pattern does not have a valid constructor");
148301
}
149302
}
150303

151-
const function_models &models;
304+
function_models &models;
305+
vast::server::server_base *server;
152306
};
153307

154308
struct function_type_converter
@@ -277,26 +431,36 @@ namespace vast::conv {
277431
{
278432
using base = mlir::OpConversionPattern< op_t >;
279433

280-
parser_conversion_pattern_base(mcontext_t *mctx, const function_models &models)
281-
: base(mctx), models(models)
282-
{}
434+
parser_conversion_pattern_base(
435+
mcontext_t *mctx, function_models &models, vast::server::server_base *server
436+
)
437+
: base(mctx), models(models), server(server) {}
283438

284-
static std::optional< function_model >
285-
get_model(const function_models &models, core::function_op_interface op) {
439+
static std::optional< function_model > get_model(
440+
function_models &models, core::function_op_interface op,
441+
vast::server::server_base *server
442+
) {
286443
auto sym = mlir::dyn_cast< core::SymbolOpInterface >(op.getOperation());
287444
VAST_ASSERT(sym);
288445
if (auto kv = models.find(sym.getSymbolName()); kv != models.end()) {
289446
return kv->second;
290447
}
291448

449+
if (server) {
450+
auto model = ask_user_for_function_model(*server, op);
451+
models[sym.getSymbolName()] = model;
452+
return model;
453+
}
454+
292455
return std::nullopt;
293456
}
294457

295458
std::optional< function_model > get_model(core::function_op_interface op) const {
296-
return get_model(models, op);
459+
return get_model(models, op, server);
297460
}
298461

299-
const function_models &models;
462+
function_models &models;
463+
vast::server::server_base *server;
300464
};
301465

302466
//
@@ -543,10 +707,13 @@ namespace vast::conv {
543707
return mlir::failure();
544708
}
545709

546-
static void legalize(parser_conversion_config &cfg) {
710+
static void
711+
legalize(parser_conversion_config &cfg, vast::server::server_base *server) {
547712
cfg.target.addLegalOp< mlir::UnrealizedConversionCastOp >();
548-
cfg.target.addDynamicallyLegalOp< op_t >([models = cfg.models](op_t op) {
549-
return function_type_converter(*op.getContext(), get_model(models, op))
713+
cfg.target.addDynamicallyLegalOp< op_t >([&cfg, server](op_t op) {
714+
return function_type_converter(
715+
*op.getContext(), get_model(cfg.models, op, server)
716+
)
550717
.isLegal(op.getFunctionType());
551718
});
552719
}
@@ -724,6 +891,9 @@ namespace vast::conv {
724891
{
725892
using base = ConversionPassMixin< HLToParserPass, HLToParserBase >;
726893

894+
struct server_handler
895+
{};
896+
727897
static conversion_target create_conversion_target(mcontext_t &mctx) {
728898
return conversion_target(mctx);
729899
}
@@ -738,6 +908,12 @@ namespace vast::conv {
738908
if (!config.empty()) {
739909
load_and_parse(config);
740910
}
911+
912+
if (!socket.empty()) {
913+
server = std::make_shared< vast::server::server< server_handler > >(
914+
vast::server::sock_adapter::create_unix_socket(socket)
915+
);
916+
}
741917
}
742918

743919
void load_and_parse(string_ref config) {
@@ -764,10 +940,12 @@ namespace vast::conv {
764940

765941
parser_conversion_config make_config() {
766942
auto &ctx = getContext();
767-
return { rewrite_pattern_set(&ctx), create_conversion_target(ctx), models };
943+
return { rewrite_pattern_set(&ctx), create_conversion_target(ctx), models,
944+
server.get() };
768945
}
769946

770947
function_models models;
948+
std::shared_ptr< vast::server::server< server_handler > > server;
771949
};
772950

773951
} // namespace vast::conv

0 commit comments

Comments
 (0)