diff --git a/include/p4mlir/Dialect/P4HIR/P4HIR_Types.td b/include/p4mlir/Dialect/P4HIR/P4HIR_Types.td index 685de0e..e9acbc6 100644 --- a/include/p4mlir/Dialect/P4HIR/P4HIR_Types.td +++ b/include/p4mlir/Dialect/P4HIR/P4HIR_Types.td @@ -252,7 +252,7 @@ def StructType : P4HIR_Type<"Struct", "struct", [ } //===----------------------------------------------------------------------===// -// EnumType +// EnumType & SerEnumType //===----------------------------------------------------------------------===// def EnumType : P4HIR_Type<"Enum", "enum", []> { let summary = "enum type"; @@ -263,29 +263,65 @@ def EnumType : P4HIR_Type<"Enum", "enum", []> { let hasCustomAssemblyFormat = 1; - let parameters = ( - ins StringRefParameter<"enum name">:$name, "mlir::ArrayAttr":$fields); + let parameters = (ins StringRefParameter<"enum name">:$name, + "mlir::ArrayAttr":$fields); let extraClassDeclaration = [{ /// Returns true if the requested field is part of this enum - bool contains(mlir::StringRef field); + bool contains(mlir::StringRef field) { return indexOf(field).has_value(); } /// Returns the index of the requested field, or a nullopt if the field is - // not part of this enum. + /// not part of this enum. std::optional indexOf(mlir::StringRef field); }]; } +def SerEnumType : P4HIR_Type<"SerEnum", "ser.enum", []> { + let summary = "serializable enum type"; + let description = [{ + Represents an enumeration of values backed by some integer value + !p4hir.ser.enum<"name", !p4hir.bit<32>, Case1 : 42, Case2 : 0> + }]; + + let hasCustomAssemblyFormat = 1; + + let parameters = (ins StringRefParameter<"enum name">:$name, + "P4HIR::BitsType":$type, "mlir::DictionaryAttr":$fields); + + let builders = [ + TypeBuilderWithInferredContext<(ins "llvm::StringRef":$name, + "P4HIR::BitsType":$type, "mlir::DictionaryAttr":$fields), [{ + return $_get(type.getContext(), name, type, fields); + }]>, + TypeBuilderWithInferredContext<(ins "llvm::StringRef":$name, + "P4HIR::BitsType":$type, "llvm::ArrayRef":$fields), [{ + return $_get(type.getContext(), name, type, + DictionaryAttr::get(type.getContext(), fields)); + }]> + + ]; + + let extraClassDeclaration = [{ + /// Returns true if the requested field is part of this enum + bool contains(mlir::StringRef field) { return getFields().contains(field); } + + /// Returns the underlying value of the requested field. Must be BitsAttr. + mlir::Attribute valueOf(mlir::StringRef field) { return getFields().get(field); } + }]; +} + //===----------------------------------------------------------------------===// // P4HIR type constraints. //===----------------------------------------------------------------------===// def AnyP4Type : AnyTypeOf<[BitsType, BooleanType, InfIntType, StructType, - EnumType, + EnumType, SerEnumType, DontcareType, ErrorType, UnknownType]> {} def CallResultP4Type : AnyTypeOf<[BitsType, BooleanType, InfIntType, VoidType]> {} -def LoadableP4Type : AnyTypeOf<[BitsType, BooleanType, InfIntType, StructType, EnumType]> {} +def LoadableP4Type : AnyTypeOf<[BitsType, BooleanType, InfIntType, StructType, + EnumType, SerEnumType]> {} +def AnyEnumType : AnyTypeOf<[EnumType, SerEnumType]>; /// A ref type with the specified constraints on the nested type. class SpecificRefType : ConfinedType(attrType)) { - if (!mlir::isa(opType)) + if (!mlir::isa(opType)) return op->emitOpError("result type (") << opType << ") is not an enum type"; return success(); @@ -83,8 +83,13 @@ void P4HIR::ConstOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) { } else if (auto enumCst = mlir::dyn_cast(getValue())) { llvm::SmallString<32> specialNameBuffer; llvm::raw_svector_ostream specialName(specialNameBuffer); - specialName << mlir::cast(enumCst.getType()).getName() << '_' - << enumCst.getField().getValue(); + if (auto enumType = mlir::dyn_cast(enumCst.getType())) + specialName << enumType.getName() << '_' << enumCst.getField().getValue(); + else { + specialName << mlir::cast(enumCst.getType()).getName() << '_' + << enumCst.getField().getValue(); + } + setNameFn(getResult(), specialName.str()); } else { setNameFn(getResult(), "cst"); @@ -856,6 +861,11 @@ struct P4HIROpAsmDialectInterface : public OpAsmDialectInterface { return AliasResult::OverridableAlias; } + if (auto serEnumType = mlir::dyn_cast(type)) { + os << serEnumType.getName(); + return AliasResult::OverridableAlias; + } + return AliasResult::NoAlias; } @@ -881,8 +891,12 @@ struct P4HIROpAsmDialectInterface : public OpAsmDialectInterface { } if (auto enumFieldAttr = mlir::dyn_cast(attr)) { - os << mlir::cast(enumFieldAttr.getType()).getName() << "_" - << enumFieldAttr.getField().getValue(); + if (auto enumType = mlir::dyn_cast(enumFieldAttr.getType())) + os << enumType.getName() << "_" << enumFieldAttr.getField().getValue(); + else + os << mlir::cast(enumFieldAttr.getType()).getName() << "_" + << enumFieldAttr.getField().getValue(); + return AliasResult::FinalAlias; } diff --git a/lib/Dialect/P4HIR/P4HIR_Types.cpp b/lib/Dialect/P4HIR/P4HIR_Types.cpp index 1abb14d..90d76be 100644 --- a/lib/Dialect/P4HIR/P4HIR_Types.cpp +++ b/lib/Dialect/P4HIR/P4HIR_Types.cpp @@ -2,8 +2,10 @@ #include "llvm/ADT/StringSet.h" #include "llvm/ADT/TypeSwitch.h" +#include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/DialectImplementation.h" +#include "p4mlir/Dialect/P4HIR/P4HIR_Attrs.h" #include "p4mlir/Dialect/P4HIR/P4HIR_Dialect.h" #include "p4mlir/Dialect/P4HIR/P4HIR_OpsEnums.h" @@ -347,14 +349,56 @@ void EnumType::print(AsmPrinter &p) const { p << ">"; } -bool EnumType::contains(mlir::StringRef field) { return indexOf(field).has_value(); } - std::optional EnumType::indexOf(mlir::StringRef field) { for (auto it : llvm::enumerate(getFields())) if (mlir::cast(it.value()).getValue() == field) return it.index(); return {}; } +void SerEnumType::print(AsmPrinter &p) const { + auto fields = getFields(); + p << '<'; + p.printString(getName()); + p << ", "; + p.printType(getType()); + if (!fields.empty()) p << ", "; + llvm::interleaveComma(fields, p, [&](NamedAttribute enumerator) { + p.printKeywordOrString(enumerator.getName()); + p << " : "; + p.printAttribute(enumerator.getValue()); + }); + p << ">"; +} + +Type SerEnumType::parse(AsmParser &p) { + std::string name; + llvm::SmallVector fields; + P4HIR::BitsType type; + + // Parse "(type) || p.parseComma()) + return {}; + + if (p.parseCommaSeparatedList([&]() { + StringRef caseName; + P4HIR::IntAttr attr; + // Parse fields "name : #value" + if (p.parseKeyword(&caseName) || p.parseColon() || + p.parseCustomAttributeWithFallback(attr)) + return failure(); + + fields.emplace_back(StringAttr::get(p.getContext(), caseName), attr); + return success(); + })) + return {}; + + // Parse closing > + if (p.parseGreater()) return {}; + + return get(name, type, fields); +} + void P4HIRDialect::registerTypes() { addTypes< #define GET_TYPEDEF_LIST diff --git a/test/Dialect/P4HIR/types.mlir b/test/Dialect/P4HIR/types.mlir index 7c449ba..110bf97 100644 --- a/test/Dialect/P4HIR/types.mlir +++ b/test/Dialect/P4HIR/types.mlir @@ -15,6 +15,13 @@ !Suits = !p4hir.enum<"Suits", Clubs, Diamonds, Hearths, Spades> +#b1 = #p4hir.int<1> : !bit42 +#b2 = #p4hir.int<2> : !bit42 +#b3 = #p4hir.int<3> : !bit42 +#b4 = #p4hir.int<4> : !bit42 + +!SuitsSerializable = !p4hir.ser.enum<"Suits", !bit42, Clubs : #b1, Diamonds : #b2, Hearths : #b3, Spades : #b4> + // No need to check stuff. If it parses, it's fine. // CHECK: module module { diff --git a/test/Translate/Ops/serenum.p4 b/test/Translate/Ops/serenum.p4 new file mode 100644 index 0000000..0d6640e --- /dev/null +++ b/test/Translate/Ops/serenum.p4 @@ -0,0 +1,35 @@ +// RUN: p4mlir-translate --typeinference-only %s | FileCheck %s + +enum bit<16> EthTypes { + IPv4 = 0x0800, + ARP = 0x0806, + RARP = 0x8035, + EtherTalk = 0x809B, + VLAN = 0x8100, + IPX = 0x8137, + IPv6 = 0x86DD +} + +struct Ethernet { + bit<48> src; + bit<48> dest; + EthTypes type; +} + +struct Headers { + Ethernet eth; +} + +// CHECK: !EthTypes = !p4hir.ser.enum<"EthTypes", !b16i, ARP : #int2054_b16i, EtherTalk : #int-32613_b16i, IPX : #int-32457_b16i, IPv4 : #int2048_b16i, IPv6 : #int-31011_b16i, RARP : #int-32715_b16i, VLAN : #int-32512_b16i> +// CHECK: !Ethernet = !p4hir.struct<"Ethernet", src: !b48i, dest: !b48i, type: !EthTypes> +// CHECK: #EthTypes_IPv4_ = #p4hir.enum.field : !EthTypes +// CHECK-LABEL: module + +// CHECK-LABEL: p4hir.func action @test(%arg0: !p4hir.ref +// CHECK: p4hir.const #EthTypes_IPv4_ +action test(inout Headers h) { + if (h.eth.type == EthTypes.IPv4) + h.eth.src = h.eth.dest; + else + h.eth.type = (EthTypes)(bit<16>)0; +} diff --git a/tools/p4mlir-translate/translate.cpp b/tools/p4mlir-translate/translate.cpp index 09b47ec..8dbf5c3 100644 --- a/tools/p4mlir-translate/translate.cpp +++ b/tools/p4mlir-translate/translate.cpp @@ -3,8 +3,6 @@ #include #include -#include "ir/ir-generated.h" - #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wcovered-switch-default" #include "frontends/common/resolveReferences/resolveReferences.h" @@ -136,6 +134,7 @@ class P4TypeConverter : public P4::Inspector { bool preorder(const P4::IR::Type_Void *v) override; bool preorder(const P4::IR::Type_Struct *s) override; bool preorder(const P4::IR::Type_Enum *e) override; + bool preorder(const P4::IR::Type_SerEnum *se) override; mlir::Type getType() const { return type; } bool setType(const P4::IR::Type *type, mlir::Type mlirType); @@ -484,12 +483,29 @@ bool P4TypeConverter::preorder(const P4::IR::Type_Enum *type) { if ((this->type = converter.findType(type))) return false; ConversionTracer trace("TypeConverting ", type); - llvm::SmallVector fields; + llvm::SmallVector cases; for (const auto *field : type->members) { - fields.push_back(mlir::StringAttr::get(converter.context(), field->name.string_view())); + cases.push_back(mlir::StringAttr::get(converter.context(), field->name.string_view())); } auto mlirType = P4HIR::EnumType::get(converter.context(), type->name.string_view(), - mlir::ArrayAttr::get(converter.context(), fields)); + mlir::ArrayAttr::get(converter.context(), cases)); + return setType(type, mlirType); +} + +bool P4TypeConverter::preorder(const P4::IR::Type_SerEnum *type) { + if ((this->type = converter.findType(type))) return false; + + ConversionTracer trace("TypeConverting ", type); + llvm::SmallVector cases; + + auto enumType = mlir::cast(convert(type->type)); + for (const auto *field : type->members) { + auto value = mlir::cast(converter.getOrCreateConstantExpr(field->value)); + cases.emplace_back(mlir::StringAttr::get(converter.context(), field->name.string_view()), + value); + } + + auto mlirType = P4HIR::SerEnumType::get(type->name.string_view(), enumType, cases); return setType(type, mlirType); } @@ -1084,10 +1100,13 @@ bool P4HIRConverter::preorder(const P4::IR::MethodCallExpression *mce) { bool P4HIRConverter::preorder(const P4::IR::Member *m) { // This is just enum constant if (const auto *typeNameExpr = m->expr->to()) { - auto enumType = mlir::cast(getOrCreateType(typeNameExpr->typeName)); + auto type = getOrCreateType(typeNameExpr->typeName); + BUG_CHECK((mlir::isa(type)), + "unexpected type for expression %1%", typeNameExpr); + setValue(m, builder.create( getLoc(builder, m), - P4HIR::EnumFieldAttr::get(enumType, m->member.name.string_view()))); + P4HIR::EnumFieldAttr::get(type, m->member.name.string_view()))); return false; }