|
| 1 | +/* |
| 2 | + * Copyright (c) Facebook, Inc. and its affiliates. |
| 3 | + * |
| 4 | + * Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | + * you may not use this file except in compliance with the License. |
| 6 | + * You may obtain a copy of the License at |
| 7 | + * |
| 8 | + * http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | + * |
| 10 | + * Unless required by applicable law or agreed to in writing, software |
| 11 | + * distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | + * See the License for the specific language governing permissions and |
| 14 | + * limitations under the License. |
| 15 | + */ |
| 16 | + |
| 17 | +#pragma once |
| 18 | + |
| 19 | +#include <pybind11/pybind11.h> |
| 20 | +#include <pybind11/stl.h> |
| 21 | +#include <pybind11/stl_bind.h> |
| 22 | +#include <velox/type/Type.h> |
| 23 | +#include "folly/json.h" |
| 24 | + |
| 25 | +namespace facebook::velox::py { |
| 26 | + |
| 27 | +std::string serializeType(const std::shared_ptr<const velox::Type>& type); |
| 28 | + |
| 29 | +/// Adds Velox Python Bindings to the module m. |
| 30 | +/// |
| 31 | +/// This function adds the following bindings: |
| 32 | +/// * velox::TypeKind enum |
| 33 | +/// * velox::Type and its derived types |
| 34 | +/// * Basic functions on Type and its derived types. |
| 35 | +/// |
| 36 | +/// @param m Module to add bindings too. |
| 37 | +/// @param asLocalModule If true then these bindings are only visible inside |
| 38 | +/// the module. Refer to |
| 39 | +/// https://pybind11.readthedocs.io/en/stable/advanced/classes.html#module-local-class-bindings |
| 40 | +/// for further details. |
| 41 | +inline void addVeloxBindings(pybind11::module& m, bool asLocalModule = true) { |
| 42 | + // Inlining these bindings since adding them to the cpp file results in a |
| 43 | + // ASAN error. |
| 44 | + using namespace velox; |
| 45 | + namespace py = pybind11; |
| 46 | + |
| 47 | + // Add TypeKind enum. |
| 48 | + py::enum_<velox::TypeKind>(m, "TypeKind", py::module_local(asLocalModule)) |
| 49 | + .value("BOOLEAN", velox::TypeKind::BOOLEAN) |
| 50 | + .value("TINYINT", velox::TypeKind::TINYINT) |
| 51 | + .value("SMALLINT", velox::TypeKind::SMALLINT) |
| 52 | + .value("INTEGER", velox::TypeKind::INTEGER) |
| 53 | + .value("BIGINT", velox::TypeKind::BIGINT) |
| 54 | + .value("REAL", velox::TypeKind::REAL) |
| 55 | + .value("DOUBLE", velox::TypeKind::DOUBLE) |
| 56 | + .value("VARCHAR", velox::TypeKind::VARCHAR) |
| 57 | + .value("VARBINARY", velox::TypeKind::VARBINARY) |
| 58 | + .value("TIMESTAMP", velox::TypeKind::TIMESTAMP) |
| 59 | + .value("OPAQUE", velox::TypeKind::OPAQUE) |
| 60 | + .value("ARRAY", velox::TypeKind::ARRAY) |
| 61 | + .value("MAP", velox::TypeKind::MAP) |
| 62 | + .value("ROW", velox::TypeKind::ROW) |
| 63 | + .export_values(); |
| 64 | + |
| 65 | + // Create VeloxType bound to velox::Type. |
| 66 | + py::class_<Type, std::shared_ptr<Type>> type( |
| 67 | + m, "VeloxType", py::module_local(asLocalModule)); |
| 68 | + |
| 69 | + // Adding all the derived types of Type here. |
| 70 | + py::class_<BooleanType, Type, std::shared_ptr<BooleanType>> booleanType( |
| 71 | + m, "BooleanType", py::module_local(asLocalModule)); |
| 72 | + py::class_<IntegerType, Type, std::shared_ptr<IntegerType>> integerType( |
| 73 | + m, "IntegerType", py::module_local(asLocalModule)); |
| 74 | + py::class_<BigintType, Type, std::shared_ptr<BigintType>> bigintType( |
| 75 | + m, "BigintType", py::module_local(asLocalModule)); |
| 76 | + py::class_<SmallintType, Type, std::shared_ptr<SmallintType>> smallintType( |
| 77 | + m, "SmallintType", py::module_local(asLocalModule)); |
| 78 | + py::class_<TinyintType, Type, std::shared_ptr<TinyintType>> tinyintType( |
| 79 | + m, "TinyintType", py::module_local(asLocalModule)); |
| 80 | + py::class_<RealType, Type, std::shared_ptr<RealType>> realType( |
| 81 | + m, "RealType", py::module_local(asLocalModule)); |
| 82 | + py::class_<DoubleType, Type, std::shared_ptr<DoubleType>> doubleType( |
| 83 | + m, "DoubleType", py::module_local(asLocalModule)); |
| 84 | + py::class_<TimestampType, Type, std::shared_ptr<TimestampType>> timestampType( |
| 85 | + m, "TimestampType", py::module_local(asLocalModule)); |
| 86 | + py::class_<VarcharType, Type, std::shared_ptr<VarcharType>> varcharType( |
| 87 | + m, "VarcharType", py::module_local(asLocalModule)); |
| 88 | + py::class_<VarbinaryType, Type, std::shared_ptr<VarbinaryType>> varbinaryType( |
| 89 | + m, "VarbinaryType", py::module_local(asLocalModule)); |
| 90 | + py::class_<ArrayType, Type, std::shared_ptr<ArrayType>> arrayType( |
| 91 | + m, "ArrayType", py::module_local(asLocalModule)); |
| 92 | + py::class_<MapType, Type, std::shared_ptr<MapType>> mapType( |
| 93 | + m, "MapType", py::module_local(asLocalModule)); |
| 94 | + py::class_<RowType, Type, std::shared_ptr<RowType>> rowType( |
| 95 | + m, "RowType", py::module_local(asLocalModule)); |
| 96 | + py::class_<FixedSizeArrayType, Type, std::shared_ptr<FixedSizeArrayType>> |
| 97 | + fixedArrayType(m, "FixedSizeArrayType", py::module_local(asLocalModule)); |
| 98 | + |
| 99 | + // Basic operations on Type. |
| 100 | + type.def("__str__", &Type::toString); |
| 101 | + // Gcc doesnt support the below kind of templatization. |
| 102 | +#if defined(__clang__) |
| 103 | + // Adds equality and inequality comparison operators. |
| 104 | + type.def(py::self == py::self); |
| 105 | + type.def(py::self != py::self); |
| 106 | +#endif |
| 107 | + type.def( |
| 108 | + "cpp_size_in_bytes", |
| 109 | + &Type::cppSizeInBytes, |
| 110 | + "Return the C++ size in bytes"); |
| 111 | + type.def( |
| 112 | + "is_fixed_width", |
| 113 | + &Type::isFixedWidth, |
| 114 | + "Check if the type is fixed width"); |
| 115 | + type.def( |
| 116 | + "is_primitive_type", |
| 117 | + &Type::isPrimitiveType, |
| 118 | + "Check if the type is a primitive type"); |
| 119 | + type.def("kind", &Type::kind, "Returns the kind of the type"); |
| 120 | + type.def("serialize", &serializeType, "Serializes the type as JSON"); |
| 121 | + |
| 122 | + booleanType.def(py::init()); |
| 123 | + tinyintType.def(py::init()); |
| 124 | + smallintType.def(py::init()); |
| 125 | + integerType.def(py::init()); |
| 126 | + bigintType.def(py::init()); |
| 127 | + realType.def(py::init()); |
| 128 | + doubleType.def(py::init()); |
| 129 | + varcharType.def(py::init()); |
| 130 | + varbinaryType.def(py::init()); |
| 131 | + timestampType.def(py::init()); |
| 132 | + arrayType.def(py::init<std::shared_ptr<Type>>()); |
| 133 | + arrayType.def( |
| 134 | + "element_type", &ArrayType::elementType, "Return the element type"); |
| 135 | + fixedArrayType.def(py::init<int, velox::TypePtr>()) |
| 136 | + .def("element_type", &velox::FixedSizeArrayType::elementType) |
| 137 | + .def("fixed_width", &velox::FixedSizeArrayType::fixedElementsWidth); |
| 138 | + mapType.def(py::init<std::shared_ptr<Type>, std::shared_ptr<Type>>()); |
| 139 | + mapType.def("key_type", &MapType::keyType, "Return the key type"); |
| 140 | + mapType.def("value_type", &MapType::valueType, "Return the value type"); |
| 141 | + |
| 142 | + rowType.def(py::init< |
| 143 | + std::vector<std::string>, |
| 144 | + std::vector<std::shared_ptr<const Type>>>()); |
| 145 | + rowType.def("size", &RowType::size, "Return the number of columns"); |
| 146 | + rowType.def( |
| 147 | + "child_at", |
| 148 | + &RowType::childAt, |
| 149 | + "Return the type of the column at a given index", |
| 150 | + py::arg("idx")); |
| 151 | + rowType.def( |
| 152 | + "find_child", |
| 153 | + [](const std::shared_ptr<RowType>& type, const std::string& name) { |
| 154 | + return type->findChild(name); |
| 155 | + }, |
| 156 | + "Return the type of the column with the given name", |
| 157 | + py::arg("name")); |
| 158 | + rowType.def( |
| 159 | + "get_child_idx", |
| 160 | + &RowType::getChildIdx, |
| 161 | + "Return the index of the column with the given name", |
| 162 | + py::arg("name")); |
| 163 | + rowType.def( |
| 164 | + "name_of", |
| 165 | + &RowType::nameOf, |
| 166 | + "Return the name of the column at the given index", |
| 167 | + py::arg("idx")); |
| 168 | + rowType.def("names", &RowType::names, "Return the names of the columns"); |
| 169 | +} |
| 170 | + |
| 171 | +} // namespace facebook::velox::py |
0 commit comments