diff --git a/include/p4mlir/Dialect/P4HIR/P4HIR_Attrs.td b/include/p4mlir/Dialect/P4HIR/P4HIR_Attrs.td index 2f170dc..fe787bf 100644 --- a/include/p4mlir/Dialect/P4HIR/P4HIR_Attrs.td +++ b/include/p4mlir/Dialect/P4HIR/P4HIR_Attrs.td @@ -80,7 +80,6 @@ def P4HIR_AggAttr : P4HIR_Attr<"Agg", "aggregate", [TypedAttrInterface]> { }]> ]; // let genVerifyDecl = 1; - //let hasCustomAssemblyFormat = 1; let assemblyFormat = [{ `<` $fields `>` }]; diff --git a/include/p4mlir/Dialect/P4HIR/P4HIR_Ops.td b/include/p4mlir/Dialect/P4HIR/P4HIR_Ops.td index b510d37..4e8a864 100644 --- a/include/p4mlir/Dialect/P4HIR/P4HIR_Ops.td +++ b/include/p4mlir/Dialect/P4HIR/P4HIR_Ops.td @@ -947,4 +947,52 @@ def StructExtractRefOp : P4HIR_Op<"struct_extract_ref", }]; } +def TupleOp : P4HIR_Op<"tuple", + [Pure, + DeclareOpInterfaceMethods]> { + let summary = "Create a tuplet from constituent parts."; + // FIXME: Better constraint type + let arguments = (ins Variadic:$input); + let results = (outs Builtin_Tuple:$result); + let hasCustomAssemblyFormat = 1; + // FIXME: use declarative format + // let assemblyFormat = [{ + // `(` $input `)` attr-dict `:` type($result) + // }]; + + let hasVerifier = 1; +} + +def TupleExtractOp : P4HIR_Op<"tuple_extract", + [Pure, + DeclareOpInterfaceMethods + ]> { + let summary = "Extract a field from a tuple."; + let description = [{ + ``` + %result = p4hir.tuple_extract %input[0] : !tuple + + ``` + }]; + + // TODO: find a way to use fieldid inteerface on tuples... + let arguments = (ins Builtin_Tuple:$input, I32Attr:$fieldIndex); + // FIXME: Better constraint type + let results = (outs AnyP4Type:$result); + + let builders = [ + OpBuilder<(ins "mlir::Value":$input, "unsigned":$fieldIndex)>, + OpBuilder<(ins "mlir::Value":$input, "P4HIR::IntAttr":$fieldIndex), [{ + build($_builder, $_state, input, fieldIndex.getUInt()); + }]> + ]; + + // FIXME: use declarative format + // let assemblyFormat = [{ + // $input `[` $fieldIndex `]` attr-dict `:` type($result) + // }]; + let hasCustomAssemblyFormat = 1; + let hasVerifier = 1; +} + #endif // P4MLIR_DIALECT_P4HIR_P4HIR_OPS_TD diff --git a/include/p4mlir/Dialect/P4HIR/P4HIR_Types.td b/include/p4mlir/Dialect/P4HIR/P4HIR_Types.td index 36f704c..a160ba9 100644 --- a/include/p4mlir/Dialect/P4HIR/P4HIR_Types.td +++ b/include/p4mlir/Dialect/P4HIR/P4HIR_Types.td @@ -2,6 +2,7 @@ #define P4MLIR_DIALECT_P4HIR_P4HIR_TYPES_TD include "mlir/IR/AttrTypeBase.td" +include "mlir/IR/BuiltinTypes.td" include "mlir/IR/EnumAttr.td" include "mlir/Interfaces/MemorySlotInterfaces.td" @@ -488,14 +489,14 @@ def SerEnumType : P4HIR_Type<"SerEnum", "ser_enum", []> { //===----------------------------------------------------------------------===// def AnyP4Type : AnyTypeOf<[BitsType, BooleanType, InfIntType, - StructType, HeaderType, + StructType, HeaderType, Builtin_Tuple, EnumType, SerEnumType, ValidBitType, DontcareType, ErrorType, UnknownType]> {} def AnyIntP4Type : AnyTypeOf<[BitsType, InfIntType]> {} def CallResultP4Type : AnyTypeOf<[BitsType, BooleanType, InfIntType, VoidType]> {} def LoadableP4Type : AnyTypeOf<[BitsType, BooleanType, InfIntType, - StructType, HeaderType, + StructType, HeaderType, Builtin_Tuple, EnumType, SerEnumType, ErrorType, ValidBitType]> {} def AnyEnumType : AnyTypeOf<[EnumType, SerEnumType]>; diff --git a/lib/Dialect/P4HIR/P4HIR_Ops.cpp b/lib/Dialect/P4HIR/P4HIR_Ops.cpp index 7cf0ff3..522d39a 100644 --- a/lib/Dialect/P4HIR/P4HIR_Ops.cpp +++ b/lib/Dialect/P4HIR/P4HIR_Ops.cpp @@ -6,6 +6,7 @@ #include "llvm/ADT/StringExtras.h" #include "llvm/Support/LogicalResult.h" #include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/SymbolTable.h" #include "mlir/Interfaces/FunctionImplementation.h" @@ -41,7 +42,7 @@ static LogicalResult checkConstantTypes(mlir::Operation *op, mlir::Type opType, if (mlir::isa(attrType)) return success(); if (mlir::isa(attrType)) { - if (!mlir::isa(opType)) + if (!mlir::isa(opType)) return op->emitOpError("result type (") << opType << ") is not an aggregate type"; return success(); @@ -887,6 +888,109 @@ void P4HIR::StructExtractRefOp::build(OpBuilder &builder, OperationState &odsSta build(builder, odsState, ReferenceType::get(fieldType), input, *fieldIndex); } +//===----------------------------------------------------------------------===// +// StructOp +//===----------------------------------------------------------------------===// + +ParseResult P4HIR::TupleOp::parse(OpAsmParser &parser, OperationState &result) { + llvm::SMLoc inputOperandsLoc = parser.getCurrentLocation(); + llvm::SmallVector operands; + Type declType; + + if (parser.parseLParen() || parser.parseOperandList(operands) || parser.parseRParen() || + parser.parseOptionalAttrDict(result.attributes) || parser.parseColonType(declType)) + return failure(); + + auto tupleType = mlir::dyn_cast(declType); + if (!tupleType) return parser.emitError(parser.getNameLoc(), "expected !tuple type"); + + result.addTypes(tupleType); + if (parser.resolveOperands(operands, tupleType.getTypes(), inputOperandsLoc, result.operands)) + return failure(); + return success(); +} + +void P4HIR::TupleOp::print(OpAsmPrinter &printer) { + printer << " ("; + printer.printOperands(getInput()); + printer << ")"; + printer.printOptionalAttrDict((*this)->getAttrs()); + printer << " : " << getType(); +} + +LogicalResult P4HIR::TupleOp::verify() { + auto elementTypes = getType().getTypes(); + + if (elementTypes.size() != getInput().size()) return emitOpError("tuple field count mismatch"); + + for (const auto &[field, value] : llvm::zip(elementTypes, getInput())) + if (field != value.getType()) return emitOpError("tuple field types do not match"); + + return success(); +} + +void P4HIR::TupleOp::getAsmResultNames(function_ref setNameFn) { + setNameFn(getResult(), "tuple"); +} + +// TODO: This duplicates lots of things above for structs. Find a way to generalize +LogicalResult P4HIR::TupleExtractOp::verify() { + auto index = getFieldIndex(); + auto fields = getInput().getType(); + if (index >= fields.size()) + return emitOpError() << "field index " << index + << " exceeds element count of aggregate type"; + + if (getType() != fields.getType(index)) + return emitOpError() << "type " << fields.getType(index) + << " of accessed field in aggregate at index " << index + << " does not match expected type " << getType(); + + return success(); +} + +ParseResult P4HIR::TupleExtractOp::parse(OpAsmParser &parser, OperationState &result) { + OpAsmParser::UnresolvedOperand operand; + unsigned fieldIndex = -1U; + mlir::TupleType declType; + + if (parser.parseOperand(operand) || parser.parseLSquare() || parser.parseInteger(fieldIndex) || + parser.parseRSquare() || parser.parseOptionalAttrDict(result.attributes) || + parser.parseColon() || parser.parseType(declType)) + return failure(); + + auto indexAttr = IntegerAttr::get(IntegerType::get(parser.getContext(), 32), fieldIndex); + result.addAttribute("fieldIndex", indexAttr); + Type resultType = declType.getType(fieldIndex); + result.addTypes(resultType); + + if (parser.resolveOperand(operand, declType, result.operands)) return failure(); + return success(); +} + +void P4HIR::TupleExtractOp::print(OpAsmPrinter &printer) { + printer << " "; + printer.printOperand(getInput()); + printer << "[" << getFieldIndex() << "]"; + printer.printOptionalAttrDict((*this)->getAttrs(), {"fieldIndex"}); + printer << " : "; + printer << getInput().getType(); +} + +void P4HIR::TupleExtractOp::getAsmResultNames(function_ref setNameFn) { + llvm::SmallString<16> name; + llvm::raw_svector_ostream specialName(name); + specialName << 't' << getFieldIndex(); + + setNameFn(getResult(), name); +} + +void P4HIR::TupleExtractOp::build(OpBuilder &builder, OperationState &odsState, Value input, + unsigned fieldIndex) { + auto tupleType = mlir::cast(input.getType()); + build(builder, odsState, tupleType.getType(fieldIndex), input, fieldIndex); +} + namespace { struct P4HIROpAsmDialectInterface : public OpAsmDialectInterface { using OpAsmDialectInterface::OpAsmDialectInterface; diff --git a/test/Dialect/P4HIR/tuple.mlir b/test/Dialect/P4HIR/tuple.mlir new file mode 100644 index 0000000..1b50c78 --- /dev/null +++ b/test/Dialect/P4HIR/tuple.mlir @@ -0,0 +1,39 @@ +// RUN: p4mlir-opt %s | FileCheck %s + +!b16i = !p4hir.bit<16> +!b32i = !p4hir.bit<32> +#int10_b32i = #p4hir.int<10> : !b32i +#int0_b32i = #p4hir.int<0> : !b32i +#int1_b32i = #p4hir.int<1> : !b32i +#false = #p4hir.bool : !p4hir.bool +#int12_b16i = #p4hir.int<12> : !b16i + +// CHECK: module +module { + %t = p4hir.const ["t"] #p4hir.aggregate<[#int0_b32i, #int1_b32i]> : tuple + + p4hir.func action @test(%arg0: !p4hir.ref {p4hir.dir = #p4hir}) { + %c10_b32i = p4hir.const #int10_b32i + %c12_b16i = p4hir.const #int12_b16i + %tuple = p4hir.tuple (%c10_b32i, %c12_b16i) : tuple + %x_0 = p4hir.variable ["x", init] : > + p4hir.assign %tuple, %x_0 : > + %val_4 = p4hir.read %x_0 : > + %t1 = p4hir.tuple_extract %val_4[1] : tuple + p4hir.assign %t1, %arg0 : + + p4hir.return + } + + p4hir.func action @test2() { + %c10_b32i = p4hir.const #int10_b32i + %false = p4hir.const #false + %tuple = p4hir.tuple (%c10_b32i, %false) : tuple + %x_0 = p4hir.variable ["x", init] : > + p4hir.assign %tuple, %x_0 : > + %y = p4hir.variable ["y"] : > + %val = p4hir.read %x_0 : > + p4hir.assign %val, %y : > + p4hir.return + } +} diff --git a/test/Dialect/P4HIR/types.mlir b/test/Dialect/P4HIR/types.mlir index 26ae41e..b755b04 100644 --- a/test/Dialect/P4HIR/types.mlir +++ b/test/Dialect/P4HIR/types.mlir @@ -27,6 +27,8 @@ #valid = #p4hir #invalid = #p4hir +!tuple = tuple + // No need to check stuff. If it parses, it's fine. // CHECK: module module { diff --git a/test/Translate/Ops/tuple.p4 b/test/Translate/Ops/tuple.p4 new file mode 100644 index 0000000..53454c0 --- /dev/null +++ b/test/Translate/Ops/tuple.p4 @@ -0,0 +1,37 @@ +// RUN: p4mlir-translate --typeinference-only %s | FileCheck %s + +struct S { + bit<32> f; + bool s; +} + +const S x = { 42, false }; + +const tuple, bit<32>> t = { 0, 1 }; +const bit<32> f = t[0]; + +// CHECK: p4hir.const ["t"] #p4hir.aggregate<[#int0_b32i, #int1_b32i]> : tuple +// CHECK: p4hir.const ["f"] #int0_b32i + +// CHECK-LABEL: p4hir.func action @test +action test(out bit<16> r) { + // CHECK: %[[c10_b32i:.*]] = p4hir.const #int10_b32i + // CHECK: %[[c12_b16i:.*]] = p4hir.const #int12_b16i + // CHECK: %[[tuple:.*]] = p4hir.tuple (%[[c10_b32i]], %[[c12_b16i]]) : tuple + tuple, bit<16>> x = { 10, 12 }; + // CHECK: %[[x_0:.*]] = p4hir.variable ["x", init] : > + // CHECK: p4hir.if + // CHECK: %[[val_4:.*]] = p4hir.read %[[x_0]] : > + // CHECK: p4hir.tuple_extract %[[val_4]][1] : tuple + if (x == { 10, 12 }) + r = x[1]; + else + r = (bit<16>)x[0]; +} + +typedef tuple, bool> pair; +action test2() { + pair x = { 10, false }; + tuple, bool> y; + y = x; +} diff --git a/tools/p4mlir-translate/translate.cpp b/tools/p4mlir-translate/translate.cpp index 70535b3..9d11fca 100644 --- a/tools/p4mlir-translate/translate.cpp +++ b/tools/p4mlir-translate/translate.cpp @@ -33,6 +33,7 @@ #include "mlir/IR/BuiltinAttributeInterfaces.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Location.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Types.h" @@ -137,6 +138,7 @@ class P4TypeConverter : public P4::Inspector { bool preorder(const P4::IR::Type_Error *e) override; bool preorder(const P4::IR::Type_SerEnum *se) override; bool preorder(const P4::IR::Type_Header *h) override; + bool preorder(const P4::IR::Type_BaseList *l) override; // covers both Type_Tuple and Type_List mlir::Type getType() const { return type; } bool setType(const P4::IR::Type *type, mlir::Type mlirType); @@ -172,6 +174,17 @@ class P4HIRConverter : public P4::Inspector, public P4::ResolutionContext { return builder.create(loc, P4HIR::BoolAttr::get(context(), boolType, value)); } + mlir::TypedAttr getTypedConstant(mlir::Type type, mlir::Attribute constant) { + if (mlir::isa(type)) return mlir::cast(constant); + + if (mlir::isa(type)) + return mlir::cast(constant); + + if (mlir::isa(type)) return mlir::cast(constant); + + return mlir::cast(constant); + } + mlir::Value getValidHeaderConstant(mlir::Location loc, P4HIR::ValidityBit valid = P4HIR::ValidityBit::Valid) { return builder.create(loc, P4HIR::ValidityBitAttr::get(context(), valid)); @@ -365,6 +378,7 @@ class P4HIRConverter : public P4::Inspector, public P4::ResolutionContext { HANDLE_IN_POSTORDER(Cast) HANDLE_IN_POSTORDER(Declaration_Variable) HANDLE_IN_POSTORDER(ReturnStatement) + HANDLE_IN_POSTORDER(ArrayIndex) #undef HANDLE_IN_POSTORDER @@ -383,6 +397,7 @@ class P4HIRConverter : public P4::Inspector, public P4::ResolutionContext { bool preorder(const P4::IR::MethodCallExpression *mce) override; bool preorder(const P4::IR::StructExpression *str) override; + bool preorder(const P4::IR::ListExpression *lst) override; bool preorder(const P4::IR::Member *m) override; bool preorder(const P4::IR::Equ *) override; bool preorder(const P4::IR::Neq *) override; @@ -554,6 +569,19 @@ bool P4TypeConverter::preorder(const P4::IR::Type_SerEnum *type) { return setType(type, mlirType); } +bool P4TypeConverter::preorder(const P4::IR::Type_BaseList *type) { + if ((this->type = converter.findType(type))) return false; + + ConversionTracer trace("TypeConverting ", type); + llvm::SmallVector fields; + for (const auto *field : type->components) { + fields.push_back(convert(field)); + } + + auto mlirType = mlir::TupleType::get(converter.context(), fields); + return setType(type, mlirType); +} + bool P4TypeConverter::setType(const P4::IR::Type *type, mlir::Type mlirType) { this->type = mlirType; converter.setType(type, mlirType); @@ -652,6 +680,12 @@ mlir::TypedAttr P4HIRConverter::getOrCreateConstantExpr(const P4::IR::Expression } } } + if (const auto *lst = expr->to()) { + auto type = getOrCreateType(lst->type); + llvm::SmallVector fields; + for (const auto *field : lst->components) fields.push_back(getOrCreateConstantExpr(field)); + return setConstantExpr(expr, P4HIR::AggAttr::get(type, builder.getArrayAttr(fields))); + } if (const auto *str = expr->to()) { auto type = getOrCreateType(str->type); llvm::SmallVector fields; @@ -659,6 +693,14 @@ mlir::TypedAttr P4HIRConverter::getOrCreateConstantExpr(const P4::IR::Expression fields.push_back(getOrCreateConstantExpr(field->expression)); return setConstantExpr(expr, P4HIR::AggAttr::get(type, builder.getArrayAttr(fields))); } + if (const auto *arr = expr->to()) { + auto base = mlir::cast(getOrCreateConstantExpr(arr->left)); + auto idx = mlir::cast(getOrCreateConstantExpr(arr->right)); + + auto field = base.getFields()[idx.getUInt()]; + auto fieldType = getOrCreateType(arr->type); + return setConstantExpr(expr, getTypedConstant(fieldType, field)); + } if (const auto *m = expr->to()) { if (const auto *typeNameExpr = m->expr->to()) { auto baseType = getOrCreateType(typeNameExpr->typeName); @@ -679,17 +721,7 @@ mlir::TypedAttr P4HIRConverter::getOrCreateConstantExpr(const P4::IR::Expression auto field = base.getFields()[*maybeIdx]; auto fieldType = structType.getFieldType(m->member.string_view()); - // TODO: We'd likely would want to convert this to some kind of interface, - if (mlir::isa(fieldType)) - return setConstantExpr(expr, mlir::cast(field)); - - if (mlir::isa(fieldType)) - return setConstantExpr(expr, mlir::cast(field)); - - if (mlir::isa(fieldType)) - return setConstantExpr(expr, mlir::cast(field)); - - return setConstantExpr(expr, mlir::cast(field)); + return setConstantExpr(expr, getTypedConstant(fieldType, field)); } else BUG("invalid member reference %1%", m); } @@ -1305,6 +1337,33 @@ bool P4HIRConverter::preorder(const P4::IR::StructExpression *str) { return false; } +bool P4HIRConverter::preorder(const P4::IR::ListExpression *lst) { + auto type = getOrCreateType(lst->type); + + auto loc = getLoc(builder, lst); + llvm::SmallVector fields; + for (const auto *field : lst->components) { + visit(field); + fields.push_back(getValue(field)); + } + + setValue(lst, builder.create(loc, type, fields).getResult()); + + return false; +} + +void P4HIRConverter::postorder(const P4::IR::ArrayIndex *arr) { + auto lhs = getValue(arr->left); + auto loc = getLoc(builder, arr); + if (mlir::isa(lhs.getType())) { + auto idx = mlir::cast(getOrCreateConstantExpr(arr->right)); + setValue(arr, builder.create(loc, lhs, idx).getResult()); + return; + } + + BUG("cannot handle this array yet: %1%", arr); +} + } // namespace namespace P4::P4MLIR {