From 246c2c43f82e7cc4e25f9234193e4df9410daabe Mon Sep 17 00:00:00 2001
From: Francesco Bertolaccini
Date: Mon, 27 Jan 2025 11:56:50 +0100
Subject: [PATCH 1/3] server: Add JSON-RPC scaffolding code.

---
 include/vast/server/io.hpp               | 103 ++++++
 include/vast/server/server.hpp           | 424 +++++++++++++++++++++++
 include/vast/server/sync_collections.hpp | 106 ++++++
 include/vast/server/types.hpp            | 155 +++++++++
 include/vast/server/util.hpp             |  49 +++
 lib/vast/CMakeLists.txt                  |   2 +
 lib/vast/server/CMakeLists.txt           |   5 +
 lib/vast/server/io.cpp                   | 114 ++++++
 tools/CMakeLists.txt                     |   2 +-
 vcpkg.json                               |   3 +-
 10 files changed, 961 insertions(+), 2 deletions(-)
 create mode 100644 include/vast/server/io.hpp
 create mode 100644 include/vast/server/server.hpp
 create mode 100644 include/vast/server/sync_collections.hpp
 create mode 100644 include/vast/server/types.hpp
 create mode 100644 include/vast/server/util.hpp
 create mode 100644 lib/vast/server/CMakeLists.txt
 create mode 100644 lib/vast/server/io.cpp

diff --git a/include/vast/server/io.hpp b/include/vast/server/io.hpp
new file mode 100644
index 0000000000..b860502103
--- /dev/null
+++ b/include/vast/server/io.hpp
@@ -0,0 +1,103 @@
+// Copyright (c) 2024-present, Trail of Bits, Inc.
+
+#pragma once
+
+#include <cstddef>
+#include <cstdio>
+#include <memory>
+#include <stdexcept>
+
+#include <span>
+
+#include <string>
+
+#include "vast/Util/Warnings.hpp"
+
+namespace vast::server {
+    // Thrown by the I/O adapters when the peer can no longer be read from or written to.
+    class connection_closed : public std::runtime_error
+    {
+      public:
+        connection_closed() : std::runtime_error("Connection closed") {}
+
+        connection_closed(const char *what) : std::runtime_error(what) {}
+    };
+
+    // Byte-stream transport interface. Implementations provide the partial
+    // `read_some`/`write_some` primitives; the `*_all` helpers loop over them.
+    struct io_adapter
+    {
+        virtual ~io_adapter() = default;
+
+        virtual void close() {}
+
+        // Transfer at most `dst.size()` / `src.size()` bytes and return how many were moved.
+        virtual size_t read_some(std::span< char > dst)        = 0;
+        virtual size_t write_some(std::span< const char > src) = 0;
+
+        // Upon completion, `dst` is filled with data.
+        // Throws `connection_closed` if the adapter stops making progress.
+        void read_all(std::span< char > dst) {
+            while (!dst.empty()) {
+                size_t nread = read_some(dst);
+                if (nread == 0) {
+                    // No progress on a non-empty buffer: report it instead of spinning.
+                    throw connection_closed{};
+                }
+                dst = dst.subspan(nread);
+            }
+        }
+
+        // Upon completion, all of the data in `src` is written to the client.
+        // Throws `connection_closed` if the adapter stops making progress.
+        void write_all(std::span< const char > src) {
+            while (!src.empty()) {
+                size_t nwritten = write_some(src);
+                if (nwritten == 0) {
+                    // No progress on a non-empty buffer: report it instead of spinning.
+                    throw connection_closed{};
+                }
+                src = src.subspan(nwritten);
+            }
+        }
+
+        // Blocks until exactly one byte has been read and returns it.
+        char read() {
+            char res[1];
+            read_all(res);
+            return res[0];
+        }
+    };
+
+    // Adapter over a pair of C stdio streams (stdin/stdout by default).
+    class file_adapter final : public io_adapter
+    {
+        FILE *ifd;
+        FILE *ofd;
+
+      public:
+        file_adapter(FILE *ifd = stdin, FILE *ofd = stdout) : ifd(ifd), ofd(ofd) {
+            VAST_ASSERT(ifd != nullptr);
+            VAST_ASSERT(ofd != nullptr);
+
+            // Switch the output stream to unbuffered mode.
+            setvbuf(ofd, nullptr, _IONBF, 0);
+        }
+
+        size_t read_some(std::span< char > dst) override {
+            size_t nread = fread(dst.data(), 1, dst.size_bytes(), ifd);
+            if (nread == 0 && (feof(ifd) || ferror(ifd))) {
+                throw connection_closed{};
+            }
+            return nread;
+        }
+
+        size_t write_some(std::span< const char > src) override {
+            size_t nwritten = fwrite(src.data(), 1, src.size_bytes(), ofd);
+            if (src.size() != 0 && nwritten == 0) {
+                throw connection_closed{};
+            }
+            return nwritten;
+        }
+    };
+
+    // Adapter over a socket; its state is held behind the opaque `impl` type.
+    class sock_adapter final : public io_adapter
+    {
+      public:
+        struct impl;
+
+        size_t read_some(std::span< char > dst) override;
+        size_t write_some(std::span< const char > src) override;
+        ~sock_adapter();
+        void close() override;
+
+        static std::unique_ptr< sock_adapter > create_unix_socket(const std::string &path);
+
+      private:
+        std::unique_ptr< impl > pimpl;
+        sock_adapter(std::unique_ptr< impl > pimpl);
+    };
+} // namespace vast::server
diff --git a/include/vast/server/server.hpp b/include/vast/server/server.hpp
new file mode 100644
index 0000000000..620593e422
--- /dev/null
+++ b/include/vast/server/server.hpp
@@ -0,0 +1,424 @@
+// Copyright (c) 2024-present, Trail of Bits, Inc.
+
+#pragma once
+
+#include <atomic>
+#include <cstddef>
+#include <cstdint>
+#include <exception>
+#include <limits>
+#include <memory>
+#include <mutex>
+#include <string>
+#include <thread>
+#include <tuple>
+#include <unordered_map>
+#include <utility>
+#include <variant>
+#include <vector>
+
+#include <nlohmann/json.hpp>
+
+#include "io.hpp"
+#include "sync_collections.hpp"
+#include "types.hpp"
+#include "util.hpp"
+
+namespace vast::server {
+    // Raised by the reader when the incoming byte stream violates the header framing.
+    class protocol_error : public std::runtime_error
+    {
+      public:
+        protocol_error(const char *what) : std::runtime_error(what) {}
+    };
+
+    // Error codes predefined by the JSON-RPC 2.0 specification.
+    enum JSONRPC_ERRORS {
+        JSONRPC_PARSE_ERROR      = -32700,
+        JSONRPC_INVALID_REQUEST  = -32600,
+        JSONRPC_METHOD_NOT_FOUND = -32601,
+        JSONRPC_INVALID_PARAMS   = -32602,
+        JSONRPC_INTERNAL_ERROR   = -32603,
+    };
+
+    // Move-only handle for a request sent with `server_base::send_request_nonblock`.
+    // A ticket that is still valid when destroyed must have been passed to
+    // `server_base::wait_request` (asserted in the destructor).
+    template< request_like request >
+    class ticket
+    {
+        friend class server_base;
+        size_t id;
+        bool is_valid        = true;
+        bool has_been_waited = false;
+
+        ticket(size_t id) : id(id) {}
+
+      public:
+        ticket(const ticket &) = delete;
+
+        ticket(ticket &&other) noexcept
+            : id(other.id), is_valid(other.is_valid), has_been_waited(other.has_been_waited) {
+            other.is_valid = false;
+        }
+
+        ticket &operator=(const ticket &) = delete;
+
+        ticket &operator=(ticket &&other) noexcept {
+            // Self-assignment must not invalidate the ticket.
+            if (this != &other) {
+                id              = other.id;
+                is_valid        = other.is_valid;
+                has_been_waited = other.has_been_waited;
+
+                other.is_valid = false;
+            }
+            return *this;
+        }
+
+        ~ticket() {
+            if (is_valid) {
+                VAST_ASSERT(has_been_waited);
+            }
+        }
+    };
+
+    // Transport-independent half of the endpoint: builds the JSON-RPC envelopes and
+    // pairs outgoing requests with their responses. Derived classes supply
+    // `send_message` and `wait_message`.
+    class server_base
+    {
+        std::atomic_size_t progressive_id = 0;
+
+      public:
+        virtual ~server_base() = default;
+
+        using json = nlohmann::json;
+
+      protected:
+        // Blocks until the response carrying `id` is available and returns it.
+        virtual json wait_message(size_t id) = 0;
+        // Serializes and transmits one message.
+        virtual void send_message(const json &j) = 0;
+
+      public:
+        // Sends an error response. `id` is a top-level member of the response object,
+        // which is also where this file's reader looks for it.
+        void send_error(int64_t code, const std::string &message, const json &id) {
+            send_message({
+                { "jsonrpc", "2.0" },
+                { "id", id },
+                { "error",
+                  {
+                      { "code", code },
+                      { "message", message },
+                  } },
+            });
+        }
+
+        // Same as above, with an additional `data` member inside the error object.
+        void
+        send_error(int64_t code, const std::string &message, const json &id, const json &data) {
+            send_message({
+                { "jsonrpc", "2.0" },
+                { "id", id },
+                { "error",
+                  {
+                      { "code", code },
+                      { "message", message },
+                      { "data", data },
+                  } },
+            });
+        }
+
+        // Sends a success response: top-level `id` and `result`, matching what
+        // `wait_request` reads back.
+        void send_response(const json &id, const json &data) {
+            send_message({
+                { "jsonrpc", "2.0" },
+                { "id", id },
+                { "result", data },
+            });
+        }
+
+        // Sends whichever alternative `result` holds; errors carry their `body` as `data`.
+        template< request_with_error_like request >
+        void send_result(const json &id, const result_type< request > &result) {
+            if (auto res = std::get_if< typename request::response_type >(&result)) {
+                send_response(id, *res);
+            } else if (auto err = std::get_if< error< request > >(&result)) {
+                send_error(err->code, err->message, id, err->body);
+            }
+        }
+
+        // Sends whichever alternative `result` holds.
+        template< request_like request >
+        void send_result(const json &id, const result_type< request > &result) {
+            if (auto res = std::get_if< typename request::response_type >(&result)) {
+                send_response(id, *res);
+            } else if (auto err = std::get_if< error< request > >(&result)) {
+                send_error(err->code, err->message, id);
+            }
+        }
+
+        template< notification_like notification >
+        void send_notification(const notification &noti) {
+            send_message({
+                { "jsonrpc", "2.0" },
+                { "method", notification::method },
+                { "params", noti },
+            });
+        }
+
+        // Sends `req` under a fresh id and returns the ticket needed to collect the answer.
+        template< request_like request >
+        [[nodiscard]] ticket< request > send_request_nonblock(const request &req) {
+            size_t id = progressive_id++;
+            send_message({
+                { "jsonrpc", "2.0" },
+                { "method", request::method },
+                { "id", id },
+                { "params", req },
+            });
+
+            return ticket< request >(id);
+        }
+
+        // Blocks until the response for `ticket` arrives and converts it to the typed result.
+        template< request_like request >
+        [[nodiscard]] result_type< request > wait_request(ticket< request > ticket) {
+            VAST_ASSERT(ticket.is_valid && !ticket.has_been_waited);
+            ticket.has_been_waited = true;
+            json response          = wait_message(ticket.id);
+
+            if (response.find("error") != response.end()) {
+                error< request > err = response["error"];
+                return err;
+            } else {
+                typename request::response_type body = response["result"];
+                return body;
+            }
+        }
+
+        template< request_like... requests >
+        [[nodiscard]] std::tuple< result_type< requests >... >
+        wait_requests(ticket< requests >... tickets) {
+            return std::make_tuple(wait_request(std::move(tickets))...);
+        }
+
+        template< request_like request >
+        [[nodiscard]] result_type< request > send_request(const request &req) {
+            return wait_request(send_request_nonblock(req));
+        }
+    };
+
+    namespace detail {
+        // Compile-time chain that routes an incoming message to the handler overload
+        // of the first message type whose `method` matches.
+        template< message_like... message_types >
+        struct dispatch_handler;
+
+        // End of the chain: no message type matched.
+        template<>
+        struct dispatch_handler<>
+        {
+            template< typename handler >
+            void operator()(handler &, server_base &server, const nlohmann::json &req) {
+                nlohmann::json id = req.find("id") != req.end() ? req["id"] : nullptr;
+                server.send_error(JSONRPC_METHOD_NOT_FOUND, "Unsupported method", id);
+            }
+        };
+
+        // Requests must carry an id; the handler's result is sent back under it.
+        template< request_like message_type, message_like... messages >
+        struct dispatch_handler< message_type, messages... >
+        {
+            template< typename handler >
+            void operator()(handler &h, server_base &server, const nlohmann::json &j) {
+                if (j["method"] == message_type::method) {
+                    if (j.find("id") == j.end()) {
+                        server.send_error(
+                            JSONRPC_INVALID_REQUEST, "ID was expected but not found", nullptr
+                        );
+                        return;
+                    }
+
+                    server.send_result(j["id"], h(server, j["params"]));
+                } else {
+                    dispatch_handler< messages... > dispatcher;
+                    return dispatcher(h, server, j);
+                }
+            }
+        };
+
+        // Notifications must not carry an id and produce no response.
+        template< notification_like message_type, message_like... messages >
+        struct dispatch_handler< message_type, messages... >
+        {
+            template< typename handler >
+            void operator()(handler &h, server_base &server, const nlohmann::json &j) {
+                if (j["method"] == message_type::method) {
+                    if (j.find("id") != j.end()) {
+                        server.send_error(
+                            JSONRPC_INVALID_REQUEST, "ID found but not expected", nullptr
+                        );
+                        return;
+                    }
+
+                    h(server, j["params"]);
+                } else {
+                    dispatch_handler< messages... > dispatcher;
+                    return dispatcher(h, server, j);
+                }
+            }
+        };
+    } // namespace detail
+
+    // Endpoint bound to an `io_adapter`. One reader thread parses `Content-Length`
+    // framed messages; `num_request_threads` workers dispatch them to `message_handler`.
+    template< typename message_handler, message_like... message_types >
+    class server final : public server_base
+    {
+        std::mutex write_mutex;
+        std::thread reader_thread;
+        std::vector< std::thread > request_threads;
+
+        std::unique_ptr< io_adapter > adapter;
+
+        message_handler handler;
+
+        sync_map< size_t, json > responses;
+        sync_queue< json > requests;
+
+        void read_lit(char lit) {
+            if (adapter->read() != lit) {
+                throw protocol_error("Invalid literal");
+            }
+        }
+
+        // Reads one `name: value\r\n` line. A bare `\r\n` yields an empty name, which
+        // marks the end of the header section.
+        std::pair< std::string, std::string > read_header() {
+            std::string header_name;
+            std::string header_value;
+
+            char c = adapter->read();
+            if (c != '\r') {
+                for (; c != ':'; c = adapter->read()) {
+                    header_name.push_back(c);
+                }
+
+                // Skip whitespace between : and header value
+                do {
+                    c = adapter->read();
+                } while (c == ' ' || c == '\t');
+
+                for (; c != '\r'; c = adapter->read()) {
+                    header_value.push_back(c);
+                }
+            }
+            read_lit('\n');
+            return std::make_pair(std::move(header_name), std::move(header_value));
+        }
+
+        // Consumes the header section and returns the parsed Content-Length.
+        size_t read_headers() {
+            std::unordered_map< std::string, std::string, ci_hash<>, ci_comparison<> > headers;
+
+            while (true) {
+                auto [name, value] = read_header();
+                if (name.empty()) {
+                    break;
+                }
+
+                headers[name] = value;
+            }
+
+            auto content_length = headers.find("content-length");
+            if (content_length == headers.end()) {
+                throw protocol_error{ "Missing Content-Length header" };
+            }
+
+            size_t size = 0;
+            for (char c : content_length->second) {
+                if (c < '0' || c > '9') {
+                    throw protocol_error{ "Invalid Content-Length value" };
+                }
+                size_t digit = static_cast< size_t >(c - '0');
+                // Exact check that `size * 10 + digit` still fits in size_t.
+                if (size > (std::numeric_limits< size_t >::max() - digit) / 10) {
+                    throw protocol_error{ "Content-Length too large" };
+                }
+                size = size * 10 + digit;
+            }
+
+            return size;
+        }
+
+        json receive_message() {
+            auto body_size = read_headers();
+            std::string body;
+            body.resize(body_size);
+            adapter->read_all(body);
+            return json::parse(body);
+        }
+
+        void receive_msg_with_id(const json &msg) {
+            // Message is either a response or a request
+            json id = msg["id"];
+
+            if (msg.find("method") != msg.end() && msg.find("params") != msg.end()) {
+                requests.enqueue(msg);
+            } else if (msg.find("result") != msg.end() || msg.find("error") != msg.end()) {
+                responses.insert(id, msg);
+            } else {
+                send_error(JSONRPC_INVALID_REQUEST, "Invalid message", id);
+            }
+        }
+
+        // Best-effort error report from a thread that is about to exit; a failure to
+        // send must not escape the thread.
+        void report_fatal(int64_t code, const char *what) noexcept {
+            try {
+                send_error(code, what, nullptr);
+            } catch (const std::exception &) {
+                // Nothing left to do: the stream is unusable.
+            }
+        }
+
+        void reader_thread_routine() {
+            try {
+                while (true) {
+                    json msg = receive_message();
+
+                    if (msg.find("id") != msg.end()) {
+                        receive_msg_with_id(msg);
+                    } else if (msg.find("method") != msg.end()
+                               && msg.find("params") != msg.end())
+                    {
+                        // Message is a notification
+                        requests.enqueue(msg);
+                    } else {
+                        send_error(JSONRPC_INVALID_REQUEST, "Invalid request", nullptr);
+                    }
+                }
+            } catch (const execution_stopped &) {
+            } catch (const connection_closed &) {
+            } catch (const json::parse_error &err) {
+                report_fatal(JSONRPC_PARSE_ERROR, err.what());
+            } catch (const protocol_error &err) {
+                report_fatal(JSONRPC_INVALID_REQUEST, err.what());
+            } catch (const std::exception &err) {
+                // e.g. allocation failure on an oversized Content-Length
+                report_fatal(JSONRPC_INTERNAL_ERROR, err.what());
+            }
+
+            // Whatever ended the loop, no further message will be read: release every
+            // thread blocked on a response or on the request queue.
+            responses.stop();
+            requests.stop();
+        }
+
+        void request_thread_routine() {
+            try {
+                detail::dispatch_handler< message_types... > dispatcher;
+                while (true) {
+                    auto req = requests.dequeue();
+                    try {
+                        dispatcher(this->handler, *this, req);
+                    } catch (const execution_stopped &) {
+                        throw;
+                    } catch (const connection_closed &) {
+                        throw;
+                    } catch (const std::exception &err) {
+                        // A failing handler must not take the process down. Requests
+                        // get an error response; notifications have no id to answer.
+                        if (auto id = req.find("id"); id != req.end()) {
+                            send_error(JSONRPC_INTERNAL_ERROR, err.what(), *id);
+                        }
+                    }
+                }
+            } catch (const execution_stopped &) {
+            } catch (const connection_closed &) {
+                responses.stop();
+                requests.stop();
+            }
+        }
+
+      protected:
+        void send_message(const json &j) override {
+            // The final \r\n is not necessary, but makes things easier to
+            // read when debugging from communication dumps
+            std::string data = j.dump() + "\r\n";
+
+            std::string res = "Content-Type: application/json;charset=utf-8\r\n";
+            res += "Content-Length: " + std::to_string(data.size()) + "\r\n\r\n";
+            res += data;
+
+            // Only the write itself needs to be serialized between threads.
+            std::unique_lock< std::mutex > lock(write_mutex);
+            adapter->write_all(res);
+        }
+
+        json wait_message(size_t id) override { return responses.get(id); }
+
+      public:
+        server(
+            std::unique_ptr< io_adapter > adapter, int num_request_threads = 1,
+            const message_handler &handler = {}
+        )
+            : adapter(std::move(adapter)), handler(handler) {
+            // Threads are started here rather than in the initializer list: members are
+            // initialized in declaration order, and `reader_thread` is declared before
+            // the adapter and the collections its routine uses.
+            reader_thread = std::thread(&server::reader_thread_routine, this);
+            for (int i = 0; i < num_request_threads; ++i) {
+                request_threads.emplace_back(&server::request_thread_routine, this);
+            }
+        }
+
+        ~server() override {
+            responses.stop();
+            requests.stop();
+            this->adapter->close();
+            reader_thread.join();
+            for (auto &thread : request_threads) {
+                thread.join();
+            }
+        }
+    };
+} // namespace vast::server
diff --git a/include/vast/server/sync_collections.hpp b/include/vast/server/sync_collections.hpp
new file mode 100644
index 0000000000..058ff1ab78
--- /dev/null
+++ b/include/vast/server/sync_collections.hpp
@@ -0,0 +1,106 @@
+// Copyright (c) 2024-present, Trail of Bits, Inc.
+
+#pragma once
+
+#include <atomic>
+#include <condition_variable>
+#include <deque>
+#include <mutex>
+#include <stdexcept>
+#include <unordered_map>
+#include <utility>
+
+#include "vast/Util/Warnings.hpp"
+
+namespace vast::server {
+    // Thrown by the synchronized collections once `stop()` has been called.
+    class execution_stopped : public std::runtime_error
+    {
+      public:
+        execution_stopped(const char *what) : std::runtime_error(what) {}
+    };
+
+    // Map whose `get` blocks until the requested key has been inserted, then removes
+    // and returns the value. `stop()` wakes every waiter with `execution_stopped`.
+    template<
+        typename Key, typename Value, typename BackingStore = std::unordered_map< Key, Value > >
+    class sync_map
+    {
+        std::mutex mutex;
+        std::condition_variable cv;
+        BackingStore data;
+
+        std::atomic_bool stopped = false;
+
+      public:
+        void stop() {
+            stopped = true;
+            std::lock_guard< std::mutex > lock(mutex);
+            cv.notify_all();
+        }
+
+        void insert(Key k, Value v) {
+            if (stopped) {
+                throw execution_stopped("User requested stop");
+            }
+            std::lock_guard< std::mutex > lock(mutex);
+            data[std::move(k)] = std::move(v);
+            cv.notify_all();
+        }
+
+        Value get(Key k) {
+            std::unique_lock< std::mutex > lock(mutex);
+            cv.wait(lock, [&k, this]() { return data.find(k) != data.end() || stopped; });
+
+            if (stopped) {
+                throw execution_stopped("User requested stop");
+            }
+
+            auto it = data.find(k);
+            VAST_ASSERT(it != data.end());
+            // The entry is erased right away, so its value can be moved out.
+            auto response = std::move(it->second);
+            data.erase(it);
+            return response;
+        }
+    };
+
+    // FIFO queue whose `dequeue` blocks while the queue is empty. `stop()` wakes
+    // every waiter with `execution_stopped`.
+    template< typename Value, typename BackingStore = std::deque< Value > >
+    class sync_queue
+    {
+        std::mutex mutex;
+        std::condition_variable cv;
+        BackingStore data;
+
+        std::atomic_bool stopped = false;
+
+      public:
+        void stop() {
+            stopped = true;
+            std::lock_guard< std::mutex > lock(mutex);
+            cv.notify_all();
+        }
+
+        void enqueue(const Value &v) {
+            if (stopped) {
+                throw execution_stopped("User requested stop");
+            }
+            std::lock_guard< std::mutex > lock(mutex);
+            data.push_back(v);
+            cv.notify_one();
+        }
+
+        // In-place construction. Overload resolution also selects this overload for
+        // non-const lvalues of `Value`, so it has to honor `stopped` as well.
+        template< typename... Args >
+        void enqueue(Args &&...args) {
+            if (stopped) {
+                throw execution_stopped("User requested stop");
+            }
+            std::lock_guard< std::mutex > lock(mutex);
+            data.emplace_back(std::forward< Args >(args)...);
+            cv.notify_one();
+        }
+
+        Value dequeue() {
+            std::unique_lock< std::mutex > lock(mutex);
+            cv.wait(lock, [this]() { return !data.empty() || stopped; });
+
+            if (stopped) {
+                throw execution_stopped("User requested stop");
+            }
+
+            // The front element is popped right away, so it can be moved out.
+            auto res = std::move(data.front());
+            data.pop_front();
+            return res;
+        }
+    };
+} // namespace vast::server
diff --git a/include/vast/server/types.hpp b/include/vast/server/types.hpp
new file mode 100644
index 0000000000..337e9fab11
--- /dev/null
+++ b/include/vast/server/types.hpp
@@ -0,0 +1,155 @@
+// Copyright (c) 2024-present, Trail of Bits, Inc.
+
+#pragma once
+
+#include <concepts>
+#include <cstdint>
+#include <string>
+#include <variant>
+
+#include <nlohmann/json.hpp>
+
+namespace vast::server {
+    // A type that can be converted to and from `nlohmann::json`.
+    template< typename T >
+    concept json_convertible = requires(T obj, nlohmann::json &json) {
+        {
+            nlohmann::to_json(json, obj)
+        };
+        {
+            nlohmann::from_json(json, obj)
+        };
+    };
+
+    // A JSON-convertible type exposing a static `method` name and `is_notification` flag.
+    template< typename T >
+    concept message_like = json_convertible< T > && requires {
+        {
+            T::is_notification
+        } -> std::convertible_to< bool >;
+        {
+            T::method
+        } -> std::convertible_to< std::string >;
+    };
+
+    template< typename message >
+    concept notification_like = message_like< message > && message::is_notification;
+
+    // A non-notification message that also names a JSON-convertible `response_type`.
+    template< typename message >
+    concept request_like = message_like< message > && !message::is_notification
+        && json_convertible< typename message::response_type >;
+
+    // A request that additionally names a JSON-convertible `error_type`.
+    template< typename message >
+    concept request_with_error_like =
+        request_like< message > && json_convertible< typename message::error_type >;
+
+    template< typename T >
+    struct error;
+
+    // Error for requests that declare an `error_type`: carries it as `body`.
+    template< request_with_error_like request >
+    struct error< request >
+    {
+        int64_t code;
+        std::string message;
+        typename request::error_type body;
+    };
+
+    template< request_like request >
+    struct error< request >
+    {
+        int64_t code;
+        std::string message;
+    };
+
+    template< request_with_error_like request >
+    void to_json(nlohmann::json &json, const error< request > &err) {
+        json["code"]    = err.code;
+        json["message"] = err.message;
+        to_json(json["data"], err.body);
+    }
+
+    template< request_like request >
+    void to_json(nlohmann::json &json, const error< request > &err) {
+        json["code"]    = err.code;
+        json["message"] = err.message;
+    }
+
+    // `at()` is used instead of `operator[]`: on a const json a missing key then
+    // raises a json exception instead of violating operator[]'s precondition.
+    template< request_with_error_like request >
+    void from_json(const nlohmann::json &json, error< request > &err) {
+        json.at("code").get_to(err.code);
+        json.at("message").get_to(err.message);
+        json.at("data").get_to(err.body);
+    }
+
+    template< request_like request >
+    void from_json(const nlohmann::json &json, error< request > &err) {
+        json.at("code").get_to(err.code);
+        json.at("message").get_to(err.message);
+    }
+
+    // Outcome of a request: either its response or an error.
+    template< request_like request >
+    using result_type = std::variant< typename request::response_type, error< request > >;
+
+    struct input_request
+    {
+        static constexpr const char *method   = "input";
+        static constexpr bool is_notification = false;
+
+        nlohmann::json type;
+        std::string text;
+        NLOHMANN_DEFINE_TYPE_INTRUSIVE(input_request, type, text)
+
+        struct response_type
+        {
+            nlohmann::json value;
+            NLOHMANN_DEFINE_TYPE_INTRUSIVE(response_type, value)
+        };
+    };
+
+    enum class message_kind {
+        info,
+        warn,
+        err,
+    };
+
+    NLOHMANN_JSON_SERIALIZE_ENUM(
+        message_kind,
+        {
+            { message_kind::info, "info" },
+            { message_kind::warn, "warn" },
+            {  message_kind::err,  "err" },
+        }
+    )
+
+    struct message_notification
+    {
+        static constexpr const char *method   = "message";
+        static constexpr bool is_notification = true;
+
+        message_kind kind;
+        std::string text;
+        NLOHMANN_DEFINE_TYPE_INTRUSIVE(message_notification, kind, text)
+    };
+
+    enum class console_severity { trace, debug, info, warn, err };
+
+    NLOHMANN_JSON_SERIALIZE_ENUM(
+        console_severity,
+        {
+            { console_severity::trace, "trace" },
+            { console_severity::debug, "debug" },
+            {  console_severity::info,  "info" },
+            {  console_severity::warn,  "warn" },
+            {   console_severity::err,   "err" },
+        }
+    )
+
+    struct console_notification
+    {
+        static constexpr const char *method   = "console";
+        static constexpr bool is_notification = true;
+
+        console_severity severity;
+        std::string message;
+        nlohmann::json params;
+        NLOHMANN_DEFINE_TYPE_INTRUSIVE(console_notification, severity, message, params)
+    };
+} // namespace vast::server
diff --git a/include/vast/server/util.hpp b/include/vast/server/util.hpp
new file mode 100644
index 0000000000..c9b8809233
--- /dev/null
+++ b/include/vast/server/util.hpp
@@ -0,0 +1,49 @@
+// Copyright (c) 2024-present, Trail of Bits, Inc.
+
+#pragma once
+
+#include <algorithm>
+#include <cctype>
+#include <cstdint>
+#include <functional>
+#include <string>
+
+namespace vast::server {
+    // Hash functor that lower-cases its argument before delegating to `Hash`, so
+    // strings differing only in case hash alike.
+    template< typename Hash = std::hash< std::string > >
+    class ci_hash
+    {
+        Hash hash;
+
+      public:
+        ci_hash(const Hash &hash = {}) : hash(hash) {}
+
+        std::uint64_t operator()(const std::string &s) const {
+            std::string lower(s.size(), '\0');
+            // Go through unsigned char: std::tolower is undefined for negative values.
+            std::transform(s.begin(), s.end(), lower.begin(), [](char c) {
+                return static_cast< char >(std::tolower(static_cast< unsigned char >(c)));
+            });
+            return hash(lower);
+        }
+    };
+
+    // Comparison functor that lower-cases both operands before delegating to `Comparison`.
+    template< typename Comparison = std::equal_to< std::string > >
+    class ci_comparison
+    {
+        Comparison comp;
+
+      public:
+        ci_comparison(const Comparison &comp = {}) : comp(comp) {}
+
+        bool operator()(const std::string &a, const std::string &b) const {
+            std::string a_lower(a.size(), '\0');
+            std::string b_lower(b.size(), '\0');
+
+            auto to_lower = [](char c) -> char {
+                return static_cast< char >(std::tolower(static_cast< unsigned char >(c)));
+            };
+
+            std::transform(a.begin(), a.end(), a_lower.begin(), to_lower);
+            std::transform(b.begin(), b.end(), b_lower.begin(), to_lower);
+
+            return comp(a_lower, b_lower);
+        }
+    };
+} // namespace vast::server
diff --git a/lib/vast/CMakeLists.txt b/lib/vast/CMakeLists.txt
index b8961dbe89..3340ad0866 100644
--- a/lib/vast/CMakeLists.txt
+++ b/lib/vast/CMakeLists.txt
@@ -14,5 +14,7 @@ if
(VAST_BUILD_FRONTEND)
     add_subdirectory(Frontend)
 endif()
 
+add_subdirectory(server)
+
 add_subdirectory(Tower)
 add_subdirectory(Util)
diff --git a/lib/vast/server/CMakeLists.txt b/lib/vast/server/CMakeLists.txt
new file mode 100644
index 0000000000..894884c758
--- /dev/null
+++ b/lib/vast/server/CMakeLists.txt
@@ -0,0 +1,5 @@
+# Copyright (c) 2025-present, Trail of Bits, Inc.
+
+add_vast_library(server
+    io.cpp
+)
diff --git a/lib/vast/server/io.cpp b/lib/vast/server/io.cpp
new file mode 100644
index 0000000000..03df91711f
--- /dev/null
+++ b/lib/vast/server/io.cpp
@@ -0,0 +1,114 @@
+#include "vast/server/io.hpp"
+
+#include <sys/socket.h>
+#include <sys/un.h>
+#include <unistd.h>
+
+#include <algorithm>
+#include <cerrno>
+#include <memory>
+#include <stdexcept>
+#include <system_error>
+#include <utility>
+
+namespace vast::server {
+    // The member is named `un` rather than `unix`: GCC predefines `unix` as a macro
+    // in its GNU language modes.
+    union addr {
+        sockaddr base;
+        sockaddr_un un;
+    };
+
+    // Move-only owner of a file descriptor; a negative value means "no descriptor".
+    struct descriptor
+    {
+        int fd;
+
+        explicit descriptor(int fd) : fd(fd) {}
+
+        descriptor(const descriptor &)            = delete;
+        descriptor &operator=(const descriptor &) = delete;
+
+        descriptor(descriptor &&other) noexcept : fd(other.fd) { other.fd = -1; }
+
+        descriptor &operator=(descriptor &&other) noexcept {
+            if (this != &other) {
+                if (fd >= 0) {
+                    ::close(fd);
+                }
+                fd       = other.fd;
+                other.fd = -1;
+            }
+            return *this;
+        }
+
+        operator int() const { return fd; }
+
+        ~descriptor() {
+            if (fd >= 0) {
+                ::close(fd);
+            }
+        }
+    };
+
+    struct sock_adapter::impl
+    {
+        descriptor serverd;
+        descriptor clientd;
+    };
+
+    sock_adapter::sock_adapter(std::unique_ptr< impl > pimpl) : pimpl(std::move(pimpl)) {}
+
+    sock_adapter::~sock_adapter() = default;
+
+    // Assigning an empty descriptor closes the one currently held.
+    void sock_adapter::close() {
+        pimpl->clientd = descriptor{ -1 };
+        pimpl->serverd = descriptor{ -1 };
+    }
+
+    size_t sock_adapter::read_some(std::span< char > dst) {
+        auto res = ::read(pimpl->clientd, dst.data(), dst.size_bytes());
+        if (res <= 0) {
+            throw connection_closed{};
+        }
+        return static_cast< size_t >(res);
+    }
+
+    size_t sock_adapter::write_some(std::span< const char > src) {
+        auto res = ::write(pimpl->clientd, src.data(), src.size_bytes());
+        if (res == -1) {
+            throw connection_closed{};
+        }
+        return static_cast< size_t >(res);
+    }
+
+    // Binds and listens on `path`, then blocks until one client has connected.
+    // Throws std::runtime_error for an over-long path and std::system_error when a
+    // socket call fails.
+    std::unique_ptr< sock_adapter > sock_adapter::create_unix_socket(const std::string &path) {
+        if (path.size() > (sizeof(sockaddr_un::sun_path) - 1)) {
+            throw std::runtime_error("Unix socket pathname is too long");
+        }
+
+        descriptor serverd{ socket(AF_UNIX, SOCK_STREAM, 0) };
+        if (serverd < 0) {
+            throw std::system_error(errno, std::generic_category());
+        }
+
+        // Value-initialized, so the copied path stays NUL-terminated.
+        addr sock_addr{};
+        sock_addr.un.sun_family = AF_UNIX;
+        std::copy(path.begin(), path.end(), sock_addr.un.sun_path);
+
+        // Here we pre-emptively try to delete an old socket in case it was used before,
+        // in order to avoid having to check for existence and running into time-of-check to
+        // time-of-use issues.
+        // This also means that it's possible that we're trying to delete a socket that does
+        // not yet exist, so we ignore ENOENT errors.
+        if (unlink(path.c_str()) < 0 && errno != ENOENT) {
+            throw std::system_error(errno, std::generic_category());
+        }
+        int rc =
+            bind(serverd, &sock_addr.base, static_cast< socklen_t >(SUN_LEN(&sock_addr.un)));
+        if (rc < 0) {
+            throw std::system_error(errno, std::generic_category());
+        }
+
+        rc = listen(serverd, 1);
+        if (rc < 0) {
+            throw std::system_error(errno, std::generic_category());
+        }
+
+        descriptor clientd{ accept(serverd, nullptr, nullptr) };
+        if (clientd < 0) {
+            // Report the failure here instead of handing out an adapter with no client.
+            throw std::system_error(errno, std::generic_category());
+        }
+
+        return std::unique_ptr< sock_adapter >(new sock_adapter{
+            std::make_unique< impl >(std::move(serverd), std::move(clientd)) });
+    }
+} // namespace vast::server
diff --git a/tools/CMakeLists.txt b/tools/CMakeLists.txt
index 5e9356a6b6..9cd9f27f07 100644
--- a/tools/CMakeLists.txt
+++ b/tools/CMakeLists.txt
@@ -3,4 +3,4 @@ add_subdirectory(vast-front)
 add_subdirectory(vast-opt)
 add_subdirectory(vast-query)
 add_subdirectory(vast-repl)
-add_subdirectory(vast-lsp-server)
+add_subdirectory(vast-lsp-server)
\ No newline at end of file
diff --git a/vcpkg.json b/vcpkg.json
index 046e6560fe..46869f07a2 100644
--- a/vcpkg.json
+++ b/vcpkg.json
@@ -14,6 +14,7 @@
     }
   },
   "dependencies": [
-    "gap"
+    "gap",
+    "nlohmann-json"
   ]
 }
\ No newline at end of file

From e3dfaf1196c90a6b1c2a76b05d1177a3c0e6fe3d Mon Sep 17 00:00:00 2001
From: Francesco Bertolaccini
Date: Fri, 31 Jan 2025 12:23:53 +0100
Subject: [PATCH 2/3] pr: Pass a `function_op_interface` instead of a name to `parser_conversion_pattern_base`.

---
 lib/vast/Conversion/Parser/ToParser.cpp | 22 +++++++++++-----------
 1 file changed, 11 insertions(+), 11 deletions(-)

diff --git a/lib/vast/Conversion/Parser/ToParser.cpp b/lib/vast/Conversion/Parser/ToParser.cpp
index 85d952a882..0356e63a7f 100644
--- a/lib/vast/Conversion/Parser/ToParser.cpp
+++ b/lib/vast/Conversion/Parser/ToParser.cpp
@@ -281,18 +281,19 @@ namespace vast::conv {
             : base(mctx), models(models)
         {}
 
-        static std::optional< function_model > get_model(
-            const function_models &models, string_ref name
-        ) {
-            if (auto kv = models.find(name); kv != models.end()) {
+        static std::optional< function_model >
+        get_model(const function_models &models, core::function_op_interface op) {
+            auto sym = mlir::dyn_cast< core::SymbolOpInterface >(op.getOperation());
+            VAST_ASSERT(sym);
+            if (auto kv = models.find(sym.getSymbolName()); kv != models.end()) {
                 return kv->second;
             }
 
             return std::nullopt;
         }
 
-        std::optional< function_model > get_model(string_ref name) const {
-            return get_model(models, name);
+        std::optional< function_model > get_model(core::function_op_interface op) const {
+            return get_model(models, op);
         }
 
         const function_models &models;
@@ -494,7 +495,7 @@ namespace vast::conv {
             op_t op, adaptor_t adaptor, conversion_rewriter &rewriter
         ) const override {
             auto func = op->getParentOfType< hl::FuncOp >();
-            auto model = get_model(func.getSymName());
+            auto model = get_model(func);
 
             auto rty = model
                 ? model->get_return_type(rewriter.getContext())
@@ -534,7 +535,7 @@ namespace vast::conv {
         logical_result matchAndRewrite(
             op_t op, adaptor_t adaptor, conversion_rewriter &rewriter
         ) const override {
-            auto tc = function_type_converter(*rewriter.getContext(), get_model(op.getSymName()));
+            auto tc = function_type_converter(*rewriter.getContext(), get_model(op));
             if (auto func_op = mlir::dyn_cast< core::function_op_interface >(op.getOperation())) {
                 return this->replace(func_op, rewriter, tc);
             }
@@ -545,9 +546,8 @@ namespace vast::conv {
         static void legalize(parser_conversion_config &cfg) {
             cfg.target.addLegalOp< mlir::UnrealizedConversionCastOp >();
             cfg.target.addDynamicallyLegalOp< op_t >([models = cfg.models](op_t op) {
-                return function_type_converter(
-                    *op.getContext(), get_model(models, op.getSymName())
-                ).isLegal(op.getFunctionType());
+                return function_type_converter(*op.getContext(), get_model(models, op))
+                    .isLegal(op.getFunctionType());
             });
         }
     };

From 67a30887eea503bdb8db63dadd38300229076dbf Mon Sep 17 00:00:00 2001
From: Francesco Bertolaccini
Date: Mon, 27 Jan 2025 14:31:39 +0100
Subject: [PATCH 3/3] pr: Add option to use interactive server in parser conversion pass.
--- include/vast/Conversion/Parser/Passes.hpp | 2 + include/vast/Conversion/Parser/Passes.td | 3 + include/vast/server/types.hpp | 44 ++++- lib/vast/Conversion/Parser/ToParser.cpp | 214 ++++++++++++++++++++-- 4 files changed, 244 insertions(+), 19 deletions(-) diff --git a/include/vast/Conversion/Parser/Passes.hpp b/include/vast/Conversion/Parser/Passes.hpp index e0f3c2f5eb..4da64db19d 100644 --- a/include/vast/Conversion/Parser/Passes.hpp +++ b/include/vast/Conversion/Parser/Passes.hpp @@ -4,6 +4,8 @@ #include "vast/Util/Warnings.hpp" +#include "vast/server/server.hpp" + VAST_RELAX_WARNINGS #include #include diff --git a/include/vast/Conversion/Parser/Passes.td b/include/vast/Conversion/Parser/Passes.td index 585fb54b82..2cdd1abb57 100644 --- a/include/vast/Conversion/Parser/Passes.td +++ b/include/vast/Conversion/Parser/Passes.td @@ -10,6 +10,9 @@ def HLToParser : Pass<"vast-hl-to-parser", "core::ModuleOp"> { let options = [ Option< "config", "config", "std::string", "", "Configuration file for parser transformation." 
+ >, + Option< "socket", "socket", "std::string", "", + "Unix socket path to use for server" > ]; diff --git a/include/vast/server/types.hpp b/include/vast/server/types.hpp index 337e9fab11..2c6370fab1 100644 --- a/include/vast/server/types.hpp +++ b/include/vast/server/types.hpp @@ -4,11 +4,34 @@ #include #include +#include #include #include #include +namespace nlohmann { + template< typename T > + struct adl_serializer< std::optional< T > > + { + static void to_json(json &j, const std::optional< T > &opt) { + if (!opt.has_value()) { + j = nullptr; + } else { + j = *opt; + } + } + + static void from_json(const json &j, std::optional< T > &opt) { + if (j.is_null()) { + opt = std::nullopt; + } else { + opt = j.template get< T >(); + } + } + }; +} // namespace nlohmann + namespace vast::server { template< typename T > concept json_convertible = requires(T obj, nlohmann::json &json) { @@ -88,6 +111,22 @@ namespace vast::server { template< request_like request > using result_type = std::variant< typename request::response_type, error< request > >; + struct position + { + unsigned int line; + unsigned int character; + + NLOHMANN_DEFINE_TYPE_INTRUSIVE(position, line, character) + }; + + struct range + { + position start; + position end; + + NLOHMANN_DEFINE_TYPE_INTRUSIVE(range, start, end) + }; + struct input_request { static constexpr const char *method = "input"; @@ -95,7 +134,10 @@ namespace vast::server { nlohmann::json type; std::string text; - NLOHMANN_DEFINE_TYPE_INTRUSIVE(input_request, type, text) + std::optional< std::string > filePath; + std::optional< range > range; + + NLOHMANN_DEFINE_TYPE_INTRUSIVE(input_request, type, text, filePath, range) struct response_type { diff --git a/lib/vast/Conversion/Parser/ToParser.cpp b/lib/vast/Conversion/Parser/ToParser.cpp index 0356e63a7f..6e6968c971 100644 --- a/lib/vast/Conversion/Parser/ToParser.cpp +++ b/lib/vast/Conversion/Parser/ToParser.cpp @@ -29,6 +29,9 @@ VAST_UNRELAX_WARNINGS #include 
"vast/Conversion/Parser/Config.hpp"
+#include "vast/server/server.hpp"
+#include "vast/server/types.hpp"
+
 #include
 
 namespace vast::conv {
@@ -75,6 +78,176 @@ namespace vast::conv {
 
     using function_models = llvm::StringMap< function_model >;
 
+    // A source span (file path plus range) that can be attached to a request.
+    struct location
+    {
+        std::string filePath;
+        server::range range;
+    };
+
+    // Converts a file location into a `location`. The resulting range is empty:
+    // start and end are both the line/column reported by `loc`.
+    location get_location(file_loc_t loc) {
+        return {
+            .filePath = loc.getFilename().str(),
+            .range = {
+                .start = { loc.getLine(), loc.getColumn(), },
+                .end = { loc.getLine(), loc.getColumn(), },
+            },
+        };
+    }
+
+    // Converts the child of a named location into a `location`.
+    // Precondition: the child is a `file_loc_t`, since it is converted with
+    // `mlir::cast`. Use the `loc_t` overload when that is not guaranteed.
+    location get_location(name_loc_t loc) {
+        return get_location(mlir::cast< file_loc_t >(loc.getChildLoc()));
+    }
+
+    // Converts an arbitrary location into a `location`, or returns `std::nullopt`
+    // when no file location can be found in it.
+    std::optional< location > get_location(loc_t loc) {
+        if (auto file_loc = mlir::dyn_cast< file_loc_t >(loc)) {
+            return get_location(file_loc);
+        } else if (auto name_loc = mlir::dyn_cast< name_loc_t >(loc)) {
+            // Recurse on the child rather than casting it: the child of a named
+            // location is not necessarily a file location.
+            return get_location(name_loc.getChildLoc());
+        }
+
+        return std::nullopt;
+    }
+
+    // Translates a type name chosen by the user into a `pr::data_type`. Anything
+    // other than "data" or "nodata" is treated as `maybedata`.
+    pr::data_type parse_type_name(const std::string &name) {
+        if (name == "data") {
+            return pr::data_type::data;
+        } else if (name == "nodata") {
+            return pr::data_type::nodata;
+        } else {
+            return pr::data_type::maybedata;
+        }
+    }
+
+    // Asks the user, through `server`, which category the function `op` belongs
+    // to. Falls back to `nonparser` when no response value is returned or the
+    // answer is not one of the offered choices.
+    function_category
+    ask_user_for_category(vast::server::server_base &server, core::function_op_interface op) {
+        auto loc  = op.getLoc();
+        auto sym  = mlir::dyn_cast< core::SymbolOpInterface >(op.getOperation());
+        VAST_ASSERT(sym);
+        auto name = sym.getSymbolName().str();
+
+        vast::server::input_request req{
+            .type     = {"nonparser", "sink", "source", "parser",},
+            .text     = "Please choose category for function `" + name + '`',
+            .filePath = std::nullopt,
+            .range    = std::nullopt,
+        };
+
+        if (auto req_loc = get_location(loc)) {
+            req.filePath = req_loc->filePath;
+            req.range    = req_loc->range;
+        }
+
+        auto response = server.send_request(req);
+        if (auto result = std::get_if< vast::server::input_request::response_type >(&response))
+        {
+            if (result->value == "nonparser") {
+                return function_category::nonparser;
+            } else if (result->value == "sink") {
+                return function_category::sink;
+            } else if (result->value == "source") {
+                return function_category::source;
+            } else if (result->value == "parser") {
+                return function_category::parser;
+            }
+        }
+        return function_category::nonparser;
+    }
+
+    // Asks the user, through `server`, for the return type of the function `op`.
+    // Falls back to `maybedata` when no response value is returned.
+    pr::data_type ask_user_for_return_type(
+        vast::server::server_base &server, core::function_op_interface op
+    ) {
+        auto loc  = op.getLoc();
+        auto sym  = mlir::dyn_cast< core::SymbolOpInterface >(op.getOperation());
+        VAST_ASSERT(sym);
+        auto name = sym.getSymbolName().str();
+
+        vast::server::input_request req{
+            .type     = { "maybedata", "nodata", "data" },
+            .text     = "Please choose return type for function `" + name + '`',
+            .filePath = std::nullopt,
+            .range    = std::nullopt,
+        };
+
+        if (auto req_loc = get_location(loc)) {
+            req.filePath = req_loc->filePath;
+            req.range    = req_loc->range;
+        }
+
+        auto response = server.send_request(req);
+        if (auto result = std::get_if< vast::server::input_request::response_type >(&response))
+        {
+            return parse_type_name(result->value);
+        }
+        return pr::data_type::maybedata;
+    }
+
+    // Asks the user, through `server`, for the type of argument `idx` of `op`.
+    // Falls back to `maybedata` when no response value is returned.
+    pr::data_type ask_user_for_argument_type(
+        vast::server::server_base &server, core::function_op_interface op, unsigned int idx
+    ) {
+        auto num_body_args = op.getFunctionBody().getNumArguments();
+        auto sym           = mlir::dyn_cast< core::SymbolOpInterface >(op.getOperation());
+        VAST_ASSERT(sym);
+        auto name = sym.getSymbolName().str();
+
+        vast::server::input_request req{
+            .type = { "maybedata", "nodata", "data" },
+            .text = "Please choose a type for argument " + std::to_string(idx)
+                + " of function `" + name + '`',
+            .filePath = std::nullopt,
+            .range    = std::nullopt,
+        };
+
+        // `idx` can exceed the number of body arguments (presumably for declarations
+        // without a body), so only look up a location when the argument exists.
+        if (idx < num_body_args) {
+            auto arg = op.getArgument(idx);
+            auto loc = arg.getLoc();
+            if (auto req_loc = get_location(loc)) {
+                req.filePath = req_loc->filePath;
+                req.range    = req_loc->range;
+            }
+        }
+
+        auto response = server.send_request(req);
+        if (auto result = std::get_if< vast::server::input_request::response_type >(&response))
+        {
+            return parse_type_name(result->value);
+        }
+        return pr::data_type::maybedata;
+    }
+
+    // Builds a complete model for `op` by asking for its return type, the type of
+    // each argument and finally its category.
+    function_model ask_user_for_function_model(
+        vast::server::server_base &server, core::function_op_interface op
+    ) {
+        function_model model;
+        model.return_type = ask_user_for_return_type(server, op);
+        for (unsigned int i = 0; i < op.getNumArguments(); ++i) {
+            model.arguments.push_back(ask_user_for_argument_type(server, op, i));
+        }
+        model.category = ask_user_for_category(server, op);
+        return model;
+    }
+
 } // namespace vast::conv
 
 LLVM_YAML_IS_SEQUENCE_VECTOR(vast::pr::data_type);
@@ -130,25 +303,30 @@ namespace vast::conv {
         using base = base_conversion_config;
 
         parser_conversion_config(
-            rewrite_pattern_set patterns, conversion_target target,
-            const function_models &models
+            rewrite_pattern_set patterns, conversion_target target, function_models &models,
+            vast::server::server_base *server
         )
-            : base(std::move(patterns), std::move(target)), models(models)
-        {}
+            : base(std::move(patterns), std::move(target)), models(models), server(server) {}
 
         template< typename pattern >
         void add_pattern() {
             auto ctx = patterns.getContext();
             if constexpr (std::is_constructible_v< pattern, mcontext_t * >) {
                 patterns.template add< pattern >(ctx);
-            } else if constexpr (std::is_constructible_v< pattern, mcontext_t *, const function_models & >) {
-                patterns.template add< pattern >(ctx, models);
+            } else if constexpr (std::is_constructible_v<
+                                     pattern, mcontext_t *, function_models &,
+                                     vast::server::server_base * >)
+            {
+                patterns.template add< pattern >(ctx, models, server);
             } else {
                 static_assert(false, "pattern does not have a valid constructor");
             }
         }
 
-        const function_models &models;
+        function_models &models;
+        // May be null. Handed to the patterns, which use it to ask the user about
+        // functions that have no model.
+        vast::server::server_base *server;
     };
 
     struct function_type_converter
@@ -277,26 +455,39 @@ namespace vast::conv {
     {
         using base = mlir::OpConversionPattern< op_t >;
 
-        parser_conversion_pattern_base(mcontext_t *mctx, const function_models &models)
-            : base(mctx), models(models)
-        {}
+        parser_conversion_pattern_base(
+            mcontext_t *mctx, function_models &models, vast::server::server_base *server
+        )
+            : base(mctx), models(models), server(server) {}
 
-        static std::optional< function_model >
-        get_model(const function_models &models, core::function_op_interface op) {
+        // Returns the model registered for `op` under its symbol name. If there is
+        // none and `server` is non-null, the user is asked for one and the answer is
+        // stored in `models`, which is why `models` is no longer const.
+        static std::optional< function_model > get_model(
+            function_models &models, core::function_op_interface op,
+            vast::server::server_base *server
+        ) {
             auto sym = mlir::dyn_cast< core::SymbolOpInterface >(op.getOperation());
             VAST_ASSERT(sym);
             if (auto kv = models.find(sym.getSymbolName()); kv != models.end()) {
                 return kv->second;
             }
 
+            if (server) {
+                auto model = ask_user_for_function_model(*server, op);
+                models[sym.getSymbolName()] = model;
+                return model;
+            }
+
             return std::nullopt;
         }
 
         std::optional< function_model > get_model(core::function_op_interface op) const {
-            return get_model(models, op);
+            return get_model(models, op, server);
         }
 
-        const function_models &models;
+        function_models &models;
+        vast::server::server_base *server;
     };
 
     //
@@ -543,10 +734,15 @@ namespace vast::conv {
             return mlir::failure();
         }
 
-        static void legalize(parser_conversion_config &cfg) {
+        static void
+        legalize(parser_conversion_config &cfg, vast::server::server_base *server) {
             cfg.target.addLegalOp< mlir::UnrealizedConversionCastOp >();
-            cfg.target.addDynamicallyLegalOp< op_t >([models = cfg.models](op_t op) {
-                return function_type_converter(*op.getContext(), get_model(models, op))
+            // Capture the model map itself, not `cfg`: the callback is stored in
+            // `cfg.target` and must stay valid even if the config object is moved.
+            cfg.target.addDynamicallyLegalOp< op_t >([&models = cfg.models, server](op_t op) {
+                return function_type_converter(
+                           *op.getContext(), get_model(models, op, server)
+                )
                     .isLegal(op.getFunctionType());
             });
         }
@@ -729,6 +925,10 @@ namespace vast::conv {
     {
         using base = ConversionPassMixin< HLToParserPass, HLToParserBase >;
 
+        // Handler type the server is instantiated with; it has no members yet.
+        struct server_handler
+        {};
+
         static conversion_target create_conversion_target(mcontext_t &mctx) {
             return conversion_target(mctx);
         }
@@ -743,6 +943,13 @@ namespace vast::conv {
             if (!config.empty()) {
                 load_and_parse(config);
             }
+
+            // Set up the server only when a socket path was given; otherwise `server` stays null.
+            if (!socket.empty()) {
+                server = std::make_shared< vast::server::server< server_handler > >(
+                    vast::server::sock_adapter::create_unix_socket(socket)
+                );
+            }
         }
 
         void load_and_parse(string_ref config) {
@@ -769,10 +976,13 @@ namespace vast::conv {
 
         parser_conversion_config make_config() {
             auto &ctx = getContext();
-            return { rewrite_pattern_set(&ctx), create_conversion_target(ctx), models };
+            return { rewrite_pattern_set(&ctx), create_conversion_target(ctx), models,
+                     server.get() };
         }
 
         function_models models;
+        // Null unless the `socket` option is set.
+        std::shared_ptr< vast::server::server< server_handler > > server;
     };
 
 } // namespace vast::conv