From ae0103cecb844e604426e285298a997005ebe91b Mon Sep 17 00:00:00 2001 From: Anton Korobeynikov Date: Tue, 11 Feb 2025 01:00:26 -0800 Subject: [PATCH 1/3] Add support for ordinary enums Signed-off-by: Anton Korobeynikov --- include/p4mlir/Dialect/P4HIR/CMakeLists.txt | 12 ++++-- include/p4mlir/Dialect/P4HIR/P4HIR.td | 1 + include/p4mlir/Dialect/P4HIR/P4HIR_Attrs.td | 28 ++++++++++++ include/p4mlir/Dialect/P4HIR/P4HIR_Types.td | 29 ++++++++++++- lib/Dialect/P4HIR/P4HIR_Attrs.cpp | 30 +++++++++++++ lib/Dialect/P4HIR/P4HIR_Ops.cpp | 24 +++++++++++ lib/Dialect/P4HIR/P4HIR_Types.cpp | 41 ++++++++++++++++++ test/Dialect/P4HIR/enum.mlir | 11 +++++ test/Dialect/P4HIR/types.mlir | 2 + test/Translate/Ops/enum.p4 | 25 +++++++++++ tools/p4mlir-translate/translate.cpp | 48 +++++++++++++++++++-- 11 files changed, 242 insertions(+), 9 deletions(-) create mode 100644 test/Dialect/P4HIR/enum.mlir create mode 100644 test/Translate/Ops/enum.p4 diff --git a/include/p4mlir/Dialect/P4HIR/CMakeLists.txt b/include/p4mlir/Dialect/P4HIR/CMakeLists.txt index d7ca937..1882c93 100644 --- a/include/p4mlir/Dialect/P4HIR/CMakeLists.txt +++ b/include/p4mlir/Dialect/P4HIR/CMakeLists.txt @@ -5,15 +5,19 @@ mlir_tablegen(P4HIR_Ops.h.inc -gen-op-decls) mlir_tablegen(P4HIR_Ops.cpp.inc -gen-op-defs) mlir_tablegen(P4HIR_Types.h.inc -gen-typedef-decls -typedefs-dialect=p4hir) mlir_tablegen(P4HIR_Types.cpp.inc -gen-typedef-defs -typedefs-dialect=p4hir) +add_public_tablegen_target(P4MLIR_P4HIR_IncGen) +add_dependencies(mlir-headers P4MLIR_P4HIR_IncGen) # Generate extra headers for custom enum and attrs. -mlir_tablegen(P4HIR_OpsEnums.h.inc -gen-enum-decls) -mlir_tablegen(P4HIR_OpsEnums.cpp.inc -gen-enum-defs) mlir_tablegen(P4HIR_Attrs.h.inc -gen-attrdef-decls -attrdefs-dialect=p4hir) mlir_tablegen(P4HIR_Attrs.cpp.inc -gen-attrdef-defs -attrdefs-dialect=p4hir) +add_public_tablegen_target(P4MLIR_P4HIR_AttrIncGen) +add_dependencies(mlir-headers P4MLIR_P4HIR_AttrIncGen) -add_public_tablegen_target(P4MLIR_P4HIR_IncGen) -add_dependencies(mlir-headers P4MLIR_P4HIR_IncGen) +mlir_tablegen(P4HIR_OpsEnums.h.inc -gen-enum-decls) +mlir_tablegen(P4HIR_OpsEnums.cpp.inc -gen-enum-defs) +add_public_tablegen_target(P4MLIR_P4HIR_EnumIncGen) +add_dependencies(mlir-headers P4MLIR_P4HIR_EnumIncGen) set(LLVM_TARGET_DEFINITIONS P4HIR_TypeInterfaces.td) mlir_tablegen(P4HIR_TypeInterfaces.h.inc -gen-type-interface-decls) diff --git a/include/p4mlir/Dialect/P4HIR/P4HIR.td b/include/p4mlir/Dialect/P4HIR/P4HIR.td index b4e45c1..1018154 100644 --- a/include/p4mlir/Dialect/P4HIR/P4HIR.td +++ b/include/p4mlir/Dialect/P4HIR/P4HIR.td @@ -2,6 +2,7 @@ #define P4MLIR_DIALECT_P4HIR_P4HIR_TD include "p4mlir/Dialect/P4HIR/P4HIR_Dialect.td" +include "p4mlir/Dialect/P4HIR/P4HIR_Attrs.td" include "p4mlir/Dialect/P4HIR/P4HIR_Ops.td" include "p4mlir/Dialect/P4HIR/P4HIR_Types.td" include "p4mlir/Dialect/P4HIR/P4HIR_TypeInterfaces.td" diff --git a/include/p4mlir/Dialect/P4HIR/P4HIR_Attrs.td b/include/p4mlir/Dialect/P4HIR/P4HIR_Attrs.td index b5c37f8..d9f2a32 100644 --- a/include/p4mlir/Dialect/P4HIR/P4HIR_Attrs.td +++ b/include/p4mlir/Dialect/P4HIR/P4HIR_Attrs.td @@ -87,6 +87,34 @@ def P4HIR_AggAttr : P4HIR_Attr<"Agg", "aggregate", [TypedAttrInterface]> { } +//===----------------------------------------------------------------------===// +// EnumFieldAttr +//===----------------------------------------------------------------------===// +// An attribute to indicate an enumeration value. +def P4HIR_EnumFieldAttr : P4HIR_Attr<"EnumField", "enum.field", [TypedAttrInterface]> { + let summary = "Enumeration field attribute"; + let description = [{ + This attribute represents a field of an enumeration. + + Examples: + ```mlir + #p4hir.enum.field> : !p4hir.enum<"name", A, B, C> + ``` + }]; + let parameters = (ins AttributeSelfTypeParameter<"">:$type, "::mlir::StringAttr":$field); + + // Force all clients to go through custom builder so we can check + // whether the requested enum value is part of the provided enum type. + let skipDefaultBuilders = 1; + let hasCustomAssemblyFormat = 1; + + let builders = [ + AttrBuilderWithInferredContext<(ins "mlir::Type":$type, "mlir::StringAttr": $value)>, + AttrBuilderWithInferredContext<(ins "mlir::Type":$type, "mlir::StringRef": $value), [{ + return $_get(type.getContext(), type, mlir::StringAttr::get(type.getContext(), value)); + }]> + ]; +} //===----------------------------------------------------------------------===// // ParamDirAttr diff --git a/include/p4mlir/Dialect/P4HIR/P4HIR_Types.td b/include/p4mlir/Dialect/P4HIR/P4HIR_Types.td index 4002e45..93b04b6 100644 --- a/include/p4mlir/Dialect/P4HIR/P4HIR_Types.td +++ b/include/p4mlir/Dialect/P4HIR/P4HIR_Types.td @@ -251,15 +251,42 @@ def StructType : P4HIR_Type<"Struct", "struct", [ }]; } +//===----------------------------------------------------------------------===// +// EnumType +//===----------------------------------------------------------------------===// +def EnumType : P4HIR_Type<"Enum", "enum", []> { + let summary = "enum type"; + let description = [{ + Represents an enumeration of values + !p4hir.enum<"name", Case1, Case2> + }]; + + let hasCustomAssemblyFormat = 1; + + let parameters = ( + ins StringRefParameter<"enum name">:$name, "mlir::ArrayAttr":$fields); + + let extraClassDeclaration = [{ + /// Returns true if the requested field is part of this enum + bool contains(mlir::StringRef field); + + /// Returns the index of the requested field, or a nullopt if the field is + // not part of this enum. + std::optional indexOf(mlir::StringRef field); + }]; +} + + //===----------------------------------------------------------------------===// // P4HIR type constraints. //===----------------------------------------------------------------------===// def AnyP4Type : AnyTypeOf<[BitsType, BooleanType, InfIntType, StructType, + EnumType, DontcareType, ErrorType, UnknownType]> {} def AnyIntP4Type : AnyTypeOf<[BitsType, InfIntType]> {} def CallResultP4Type : AnyTypeOf<[BitsType, BooleanType, InfIntType, VoidType]> {} -def LoadableP4Type : AnyTypeOf<[BitsType, BooleanType, InfIntType, StructType]> {} +def LoadableP4Type : AnyTypeOf<[BitsType, BooleanType, InfIntType, StructType, EnumType]> {} /// A ref type with the specified constraints on the nested type. class SpecificRefType : ConfinedType emitError, Type return success(); } +Attribute EnumFieldAttr::parse(AsmParser &p, Type) { + StringRef field; + P4HIR::EnumType type; + if (p.parseLess() || p.parseKeyword(&field) || p.parseComma() || + p.parseCustomTypeWithFallback(type) || p.parseGreater()) + return {}; + + return EnumFieldAttr::get(type, field); +} + +void EnumFieldAttr::print(AsmPrinter &p) const { + p << "<" << getField().getValue() << ", "; + p.printType(getType()); + p << ">"; +} + +EnumFieldAttr EnumFieldAttr::get(mlir::Type type, StringAttr value) { + EnumType enumType = llvm::dyn_cast(type); + if (!enumType) return nullptr; + + // Check whether the provided value is a member of the enum type. + if (!enumType.contains(value.getValue())) { + // emitError() << "enum value '" << value.getValue() + // << "' is not a member of enum type " << enumType; + return nullptr; + } + + return Base::get(value.getContext(), type, value); +} + void P4HIRDialect::registerAttributes() { addAttributes< #define GET_ATTRDEF_LIST diff --git a/lib/Dialect/P4HIR/P4HIR_Ops.cpp b/lib/Dialect/P4HIR/P4HIR_Ops.cpp index cfeae8d..d8d8304 100644 --- a/lib/Dialect/P4HIR/P4HIR_Ops.cpp +++ b/lib/Dialect/P4HIR/P4HIR_Ops.cpp @@ -43,6 +43,13 @@ static LogicalResult checkConstantTypes(mlir::Operation *op, mlir::Type opType, return success(); } + if (mlir::isa(attrType)) { + if (!mlir::isa(opType)) + return op->emitOpError("result type (") << opType << ") is not an enum type"; + + return success(); + } + assert(isa(attrType) && "expected typed attribute"); return op->emitOpError("constant with type ") << cast(attrType).getType() << " not supported"; @@ -73,6 +80,12 @@ void P4HIR::ConstOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) { setNameFn(getResult(), specialName.str()); } else if (auto boolCst = mlir::dyn_cast(getValue())) { setNameFn(getResult(), boolCst.getValue() ? "true" : "false"); + } else if (auto enumCst = mlir::dyn_cast(getValue())) { + llvm::SmallString<32> specialNameBuffer; + llvm::raw_svector_ostream specialName(specialNameBuffer); + specialName << mlir::cast(enumCst.getType()).getName() << '_' + << enumCst.getField().getValue(); + setNameFn(getResult(), specialName.str()); } else { setNameFn(getResult(), "cst"); } @@ -859,6 +872,11 @@ struct P4HIROpAsmDialectInterface : public OpAsmDialectInterface { return AliasResult::OverridableAlias; } + if (auto enumType = mlir::dyn_cast(type)) { + os << enumType.getName(); + return AliasResult::OverridableAlias; + } + return AliasResult::NoAlias; } @@ -883,6 +901,12 @@ struct P4HIROpAsmDialectInterface : public OpAsmDialectInterface { return AliasResult::FinalAlias; } + if (auto enumFieldAttr = mlir::dyn_cast(attr)) { + os << mlir::cast(enumFieldAttr.getType()).getName() << "_" + << enumFieldAttr.getField().getValue(); + return AliasResult::FinalAlias; + } + return AliasResult::NoAlias; } }; diff --git a/lib/Dialect/P4HIR/P4HIR_Types.cpp b/lib/Dialect/P4HIR/P4HIR_Types.cpp index c79a0b0..31d42f3 100644 --- a/lib/Dialect/P4HIR/P4HIR_Types.cpp +++ b/lib/Dialect/P4HIR/P4HIR_Types.cpp @@ -315,6 +315,47 @@ void StructType::getInnerTypes(SmallVectorImpl &types) { for (const auto &field : getElements()) types.push_back(field.type); } +Type EnumType::parse(AsmParser &p) { + std::string name; + llvm::SmallVector fields; + bool parsedName = false; + if (p.parseCommaSeparatedList(AsmParser::Delimiter::LessGreater, [&]() { + // First, try to parse name + if (!parsedName) { + if (p.parseKeywordOrString(&name)) return failure(); + parsedName = true; + return success(); + } + + StringRef caseName; + if (p.parseKeyword(&caseName)) return failure(); + fields.push_back(StringAttr::get(p.getContext(), name)); + return success(); + })) + return {}; + + return get(p.getContext(), name, ArrayAttr::get(p.getContext(), fields)); +} + +void EnumType::print(AsmPrinter &p) const { + auto fields = getFields(); + p << '<'; + p.printString(getName()); + if (!fields.empty()) p << ", "; + llvm::interleaveComma(fields, p, [&](Attribute enumerator) { + p << mlir::cast(enumerator).getValue(); + }); + p << ">"; +} + +bool EnumType::contains(mlir::StringRef field) { return indexOf(field).has_value(); } + +std::optional EnumType::indexOf(mlir::StringRef field) { + for (auto it : llvm::enumerate(getFields())) + if (mlir::cast(it.value()).getValue() == field) return it.index(); + return {}; +} + void P4HIRDialect::registerTypes() { addTypes< #define GET_TYPEDEF_LIST diff --git a/test/Dialect/P4HIR/enum.mlir b/test/Dialect/P4HIR/enum.mlir new file mode 100644 index 0000000..24b1650 --- /dev/null +++ b/test/Dialect/P4HIR/enum.mlir @@ -0,0 +1,11 @@ +// RUN: p4mlir-opt %s | FileCheck %s + +!Suits = !p4hir.enum<"Suits", Clubs, Diamonds, Hearths, Spades> + +#Suits_Clubs = #p4hir.enum.field : !Suits +#Suits_Diamonds = #p4hir.enum.field : !Suits + +// CHECK: module +module { + %Suits_Diamonds = p4hir.const #Suits_Diamonds +} diff --git a/test/Dialect/P4HIR/types.mlir b/test/Dialect/P4HIR/types.mlir index ac45688..7c449ba 100644 --- a/test/Dialect/P4HIR/types.mlir +++ b/test/Dialect/P4HIR/types.mlir @@ -13,6 +13,8 @@ !struct = !p4hir.struct<"struct_name", boolfield : !p4hir.bool, bitfield : !bit42> !nested_struct = !p4hir.struct<"another_name", neststructfield : !struct, bitfield : !bit42> +!Suits = !p4hir.enum<"Suits", Clubs, Diamonds, Hearths, Spades> + // No need to check stuff. If it parses, it's fine. // CHECK: module module { diff --git a/test/Translate/Ops/enum.p4 b/test/Translate/Ops/enum.p4 new file mode 100644 index 0000000..f9a5815 --- /dev/null +++ b/test/Translate/Ops/enum.p4 @@ -0,0 +1,25 @@ +// RUN: p4mlir-translate --typeinference-only %s | FileCheck %s + +// CHECK: !Suits = !p4hir.enum<"Suits", Clubs, Diamonds, Hearths, Spades> +// CHECK: #Suits_Diamonds = #p4hir.enum.field : !Suits +// CHECK: #Suits_Hearths = #p4hir.enum.field : !Suits +// CHECK: #Suits_Spades = #p4hir.enum.field : !Suits + +enum Suits { Clubs, Diamonds, Hearths, Spades } + +// CHECK-LABEL: module +// CHECK: p4hir.const ["cEnum"] #Suits_Hearths +const Suits cEnum = Suits.Hearths; + +// CHECK-LABEL: p4hir.func action @test +action test(inout bit<42> a, Suits b) { + // CHECK: %Suits_Diamonds = p4hir.const #Suits_Diamonds + // CHECK: %d = p4hir.variable ["d", init] : + // CHECK: p4hir.assign %Suits_Diamonds, %d : + Suits d = Suits.Diamonds; + if (b == Suits.Spades) { + a = a + 1; + } else if (b == d) { + a = a - 1; + } +} diff --git a/tools/p4mlir-translate/translate.cpp b/tools/p4mlir-translate/translate.cpp index ca40764..f0a1923 100644 --- a/tools/p4mlir-translate/translate.cpp +++ b/tools/p4mlir-translate/translate.cpp @@ -3,6 +3,8 @@ #include #include +#include "ir/ir-generated.h" + #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wcovered-switch-default" #include "frontends/common/resolveReferences/resolveReferences.h" @@ -132,7 +134,8 @@ class P4TypeConverter : public P4::Inspector { bool preorder(const P4::IR::Type_Action *act) override; bool preorder(const P4::IR::Type_Method *m) override; bool preorder(const P4::IR::Type_Void *v) override; - bool preorder(const P4::IR::Type_Struct *v) override; + bool preorder(const P4::IR::Type_Struct *s) override; + bool preorder(const P4::IR::Type_Enum *e) override; mlir::Type getType() const { return type; } bool setType(const P4::IR::Type *type, mlir::Type mlirType); @@ -353,10 +356,11 @@ class P4HIRConverter : public P4::Inspector, public P4::ResolutionContext { HANDLE_IN_POSTORDER(Cast) HANDLE_IN_POSTORDER(Declaration_Variable) HANDLE_IN_POSTORDER(ReturnStatement) - HANDLE_IN_POSTORDER(Member) #undef HANDLE_IN_POSTORDER + void postorder(const P4::IR::Member *m) override; + bool preorder(const P4::IR::Declaration_Constant *decl) override; bool preorder(const P4::IR::AssignmentStatement *assign) override; bool preorder(const P4::IR::Mux *mux) override; @@ -370,6 +374,7 @@ class P4HIRConverter : public P4::Inspector, public P4::ResolutionContext { bool preorder(const P4::IR::MethodCallExpression *mce) override; bool preorder(const P4::IR::StructExpression *str) override; + bool preorder(const P4::IR::Member *m) override; mlir::Value emitUnOp(const P4::IR::Operation_Unary *unop, P4HIR::UnaryOpKind kind); mlir::Value emitBinOp(const P4::IR::Operation_Binary *binop, P4HIR::BinOpKind kind); @@ -479,6 +484,19 @@ bool P4TypeConverter::preorder(const P4::IR::Type_Struct *type) { return setType(type, mlirType); } +bool P4TypeConverter::preorder(const P4::IR::Type_Enum *type) { + if ((this->type = converter.findType(type))) return false; + + ConversionTracer trace("TypeConverting ", type); + llvm::SmallVector fields; + for (const auto *field : type->members) { + fields.push_back(mlir::StringAttr::get(converter.context(), field->name.string_view())); + } + auto mlirType = P4HIR::EnumType::get(converter.context(), type->name.string_view(), + mlir::ArrayAttr::get(converter.context(), fields)); + return setType(type, mlirType); +} + bool P4TypeConverter::setType(const P4::IR::Type *type, mlir::Type mlirType) { this->type = mlirType; converter.setType(type, mlirType); @@ -585,6 +603,12 @@ mlir::TypedAttr P4HIRConverter::getOrCreateConstantExpr(const P4::IR::Expression return setConstantExpr(expr, P4HIR::AggAttr::get(type, builder.getArrayAttr(fields))); } if (const auto *m = expr->to()) { + if (const auto *typeNameExpr = m->expr->to()) { + auto baseType = mlir::cast(getOrCreateType(typeNameExpr->typeName)); + return setConstantExpr(expr, + P4HIR::EnumFieldAttr::get(baseType, m->member.string_view())); + } + auto base = mlir::cast(getOrCreateConstantExpr(m->expr)); auto structType = mlir::cast(base.getType()); @@ -999,8 +1023,9 @@ bool P4HIRConverter::preorder(const P4::IR::MethodCallExpression *mce) { auto ref = resolveReference(arg->expression); auto copyIn = b.create( loc, ref.getType(), - b.getStringAttr(llvm::Twine(params[idx]->name.string_view()) + - (dir == P4::IR::Direction::InOut ? "_inout_arg" : "_out_arg"))); + b.getStringAttr( + llvm::Twine(params[idx]->name.string_view()) + + (dir == P4::IR::Direction::InOut ? "_inout_arg" : "_out_arg"))); if (dir == P4::IR::Direction::InOut) { copyIn.setInit(true); @@ -1070,6 +1095,21 @@ bool P4HIRConverter::preorder(const P4::IR::MethodCallExpression *mce) { return false; } + +bool P4HIRConverter::preorder(const P4::IR::Member *m) { + // This is just enum constant + if (const auto *typeNameExpr = m->expr->to()) { + auto enumType = mlir::cast(getOrCreateType(typeNameExpr->typeName)); + setValue(m, builder.create( + getLoc(builder, m), + P4HIR::EnumFieldAttr::get(enumType, m->member.name.string_view()))); + return false; + } + + // Handle other members in postorder traversal + return true; +} + void P4HIRConverter::postorder(const P4::IR::Member *m) { // Resolve member rvalue expression to something we can reason about // TODO: Likely we can do similar things for the majority of struct-like From a13e2580c7ae802e5f893b036da1735ee833586c Mon Sep 17 00:00:00 2001 From: Anton Korobeynikov Date: Tue, 11 Feb 2025 15:45:38 -0800 Subject: [PATCH 2/3] Add serializable enum Signed-off-by: Anton Korobeynikov --- include/p4mlir/Dialect/P4HIR/P4HIR_Types.td | 50 ++++++++++++++++++--- lib/Dialect/P4HIR/P4HIR_Ops.cpp | 24 +++++++--- lib/Dialect/P4HIR/P4HIR_Types.cpp | 48 +++++++++++++++++++- test/Dialect/P4HIR/types.mlir | 7 +++ test/Translate/Ops/serenum.p4 | 35 +++++++++++++++ tools/p4mlir-translate/translate.cpp | 33 +++++++++++--- 6 files changed, 176 insertions(+), 21 deletions(-) create mode 100644 test/Translate/Ops/serenum.p4 diff --git a/include/p4mlir/Dialect/P4HIR/P4HIR_Types.td b/include/p4mlir/Dialect/P4HIR/P4HIR_Types.td index 93b04b6..ae3f892 100644 --- a/include/p4mlir/Dialect/P4HIR/P4HIR_Types.td +++ b/include/p4mlir/Dialect/P4HIR/P4HIR_Types.td @@ -252,7 +252,7 @@ def StructType : P4HIR_Type<"Struct", "struct", [ } //===----------------------------------------------------------------------===// -// EnumType +// EnumType & SerEnumType //===----------------------------------------------------------------------===// def EnumType : P4HIR_Type<"Enum", "enum", []> { let summary = "enum type"; @@ -263,30 +263,66 @@ def EnumType : P4HIR_Type<"Enum", "enum", []> { let hasCustomAssemblyFormat = 1; - let parameters = ( - ins StringRefParameter<"enum name">:$name, "mlir::ArrayAttr":$fields); + let parameters = (ins StringRefParameter<"enum name">:$name, + "mlir::ArrayAttr":$fields); let extraClassDeclaration = [{ /// Returns true if the requested field is part of this enum - bool contains(mlir::StringRef field); + bool contains(mlir::StringRef field) { return indexOf(field).has_value(); } /// Returns the index of the requested field, or a nullopt if the field is - // not part of this enum. + /// not part of this enum. std::optional indexOf(mlir::StringRef field); }]; } +def SerEnumType : P4HIR_Type<"SerEnum", "ser.enum", []> { + let summary = "serializable enum type"; + let description = [{ + Represents an enumeration of values backed by some integer value + !p4hir.ser.enum<"name", !p4hir.bit<32>, Case1 : 42, Case2 : 0> + }]; + + let hasCustomAssemblyFormat = 1; + + let parameters = (ins StringRefParameter<"enum name">:$name, + "P4HIR::BitsType":$type, "mlir::DictionaryAttr":$fields); + + let builders = [ + TypeBuilderWithInferredContext<(ins "llvm::StringRef":$name, + "P4HIR::BitsType":$type, "mlir::DictionaryAttr":$fields), [{ + return $_get(type.getContext(), name, type, fields); + }]>, + TypeBuilderWithInferredContext<(ins "llvm::StringRef":$name, + "P4HIR::BitsType":$type, "llvm::ArrayRef":$fields), [{ + return $_get(type.getContext(), name, type, + DictionaryAttr::get(type.getContext(), fields)); + }]> + + ]; + + let extraClassDeclaration = [{ + /// Returns true if the requested field is part of this enum + bool contains(mlir::StringRef field) { return getFields().contains(field); } + + /// Returns the underlying value of the requested field. Must be BitsAttr. + mlir::Attribute valueOf(mlir::StringRef field) { return getFields().get(field); } + }]; +} + //===----------------------------------------------------------------------===// // P4HIR type constraints. //===----------------------------------------------------------------------===// def AnyP4Type : AnyTypeOf<[BitsType, BooleanType, InfIntType, StructType, - EnumType, + EnumType, SerEnumType, DontcareType, ErrorType, UnknownType]> {} def AnyIntP4Type : AnyTypeOf<[BitsType, InfIntType]> {} def CallResultP4Type : AnyTypeOf<[BitsType, BooleanType, InfIntType, VoidType]> {} -def LoadableP4Type : AnyTypeOf<[BitsType, BooleanType, InfIntType, StructType, EnumType]> {} +def LoadableP4Type : AnyTypeOf<[BitsType, BooleanType, InfIntType, StructType, + EnumType, SerEnumType]> {} +def AnyEnumType : AnyTypeOf<[EnumType, SerEnumType]>; /// A ref type with the specified constraints on the nested type. class SpecificRefType : ConfinedType(attrType)) { - if (!mlir::isa(opType)) + if (!mlir::isa(opType)) return op->emitOpError("result type (") << opType << ") is not an enum type"; return success(); @@ -83,8 +83,13 @@ void P4HIR::ConstOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) { } else if (auto enumCst = mlir::dyn_cast(getValue())) { llvm::SmallString<32> specialNameBuffer; llvm::raw_svector_ostream specialName(specialNameBuffer); - specialName << mlir::cast(enumCst.getType()).getName() << '_' - << enumCst.getField().getValue(); + if (auto enumType = mlir::dyn_cast(enumCst.getType())) + specialName << enumType.getName() << '_' << enumCst.getField().getValue(); + else { + specialName << mlir::cast(enumCst.getType()).getName() << '_' + << enumCst.getField().getValue(); + } + setNameFn(getResult(), specialName.str()); } else { setNameFn(getResult(), "cst"); @@ -877,6 +882,11 @@ struct P4HIROpAsmDialectInterface : public OpAsmDialectInterface { return AliasResult::OverridableAlias; } + if (auto serEnumType = mlir::dyn_cast(type)) { + os << serEnumType.getName(); + return AliasResult::OverridableAlias; + } + return AliasResult::NoAlias; } @@ -902,8 +912,12 @@ struct P4HIROpAsmDialectInterface : public OpAsmDialectInterface { } if (auto enumFieldAttr = mlir::dyn_cast(attr)) { - os << mlir::cast(enumFieldAttr.getType()).getName() << "_" - << enumFieldAttr.getField().getValue(); + if (auto enumType = mlir::dyn_cast(enumFieldAttr.getType())) + os << enumType.getName() << "_" << enumFieldAttr.getField().getValue(); + else + os << mlir::cast(enumFieldAttr.getType()).getName() << "_" + << enumFieldAttr.getField().getValue(); + return AliasResult::FinalAlias; } diff --git a/lib/Dialect/P4HIR/P4HIR_Types.cpp b/lib/Dialect/P4HIR/P4HIR_Types.cpp index 31d42f3..596225f 100644 --- a/lib/Dialect/P4HIR/P4HIR_Types.cpp +++ b/lib/Dialect/P4HIR/P4HIR_Types.cpp @@ -2,8 +2,10 @@ #include "llvm/ADT/StringSet.h" #include "llvm/ADT/TypeSwitch.h" +#include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/DialectImplementation.h" +#include "p4mlir/Dialect/P4HIR/P4HIR_Attrs.h" #include "p4mlir/Dialect/P4HIR/P4HIR_Dialect.h" #include "p4mlir/Dialect/P4HIR/P4HIR_OpsEnums.h" @@ -348,14 +350,56 @@ void EnumType::print(AsmPrinter &p) const { p << ">"; } -bool EnumType::contains(mlir::StringRef field) { return indexOf(field).has_value(); } - std::optional EnumType::indexOf(mlir::StringRef field) { for (auto it : llvm::enumerate(getFields())) if (mlir::cast(it.value()).getValue() == field) return it.index(); return {}; } +void SerEnumType::print(AsmPrinter &p) const { + auto fields = getFields(); + p << '<'; + p.printString(getName()); + p << ", "; + p.printType(getType()); + if (!fields.empty()) p << ", "; + llvm::interleaveComma(fields, p, [&](NamedAttribute enumerator) { + p.printKeywordOrString(enumerator.getName()); + p << " : "; + p.printAttribute(enumerator.getValue()); + }); + p << ">"; +} + +Type SerEnumType::parse(AsmParser &p) { + std::string name; + llvm::SmallVector fields; + P4HIR::BitsType type; + + // Parse "(type) || p.parseComma()) + return {}; + + if (p.parseCommaSeparatedList([&]() { + StringRef caseName; + P4HIR::IntAttr attr; + // Parse fields "name : #value" + if (p.parseKeyword(&caseName) || p.parseColon() || + p.parseCustomAttributeWithFallback(attr)) + return failure(); + + fields.emplace_back(StringAttr::get(p.getContext(), caseName), attr); + return success(); + })) + return {}; + + // Parse closing > + if (p.parseGreater()) return {}; + + return get(name, type, fields); +} + void P4HIRDialect::registerTypes() { addTypes< #define GET_TYPEDEF_LIST diff --git a/test/Dialect/P4HIR/types.mlir b/test/Dialect/P4HIR/types.mlir index 7c449ba..110bf97 100644 --- a/test/Dialect/P4HIR/types.mlir +++ b/test/Dialect/P4HIR/types.mlir @@ -15,6 +15,13 @@ !Suits = !p4hir.enum<"Suits", Clubs, Diamonds, Hearths, Spades> +#b1 = #p4hir.int<1> : !bit42 +#b2 = #p4hir.int<2> : !bit42 +#b3 = #p4hir.int<3> : !bit42 +#b4 = #p4hir.int<4> : !bit42 + +!SuitsSerializable = !p4hir.ser.enum<"Suits", !bit42, Clubs : #b1, Diamonds : #b2, Hearths : #b3, Spades : #b4> + // No need to check stuff. If it parses, it's fine. // CHECK: module module { diff --git a/test/Translate/Ops/serenum.p4 b/test/Translate/Ops/serenum.p4 new file mode 100644 index 0000000..0d6640e --- /dev/null +++ b/test/Translate/Ops/serenum.p4 @@ -0,0 +1,35 @@ +// RUN: p4mlir-translate --typeinference-only %s | FileCheck %s + +enum bit<16> EthTypes { + IPv4 = 0x0800, + ARP = 0x0806, + RARP = 0x8035, + EtherTalk = 0x809B, + VLAN = 0x8100, + IPX = 0x8137, + IPv6 = 0x86DD +} + +struct Ethernet { + bit<48> src; + bit<48> dest; + EthTypes type; +} + +struct Headers { + Ethernet eth; +} + +// CHECK: !EthTypes = !p4hir.ser.enum<"EthTypes", !b16i, ARP : #int2054_b16i, EtherTalk : #int-32613_b16i, IPX : #int-32457_b16i, IPv4 : #int2048_b16i, IPv6 : #int-31011_b16i, RARP : #int-32715_b16i, VLAN : #int-32512_b16i> +// CHECK: !Ethernet = !p4hir.struct<"Ethernet", src: !b48i, dest: !b48i, type: !EthTypes> +// CHECK: #EthTypes_IPv4_ = #p4hir.enum.field : !EthTypes +// CHECK-LABEL: module + +// CHECK-LABEL: p4hir.func action @test(%arg0: !p4hir.ref +// CHECK: p4hir.const #EthTypes_IPv4_ +action test(inout Headers h) { + if (h.eth.type == EthTypes.IPv4) + h.eth.src = h.eth.dest; + else + h.eth.type = (EthTypes)(bit<16>)0; +} diff --git a/tools/p4mlir-translate/translate.cpp b/tools/p4mlir-translate/translate.cpp index f0a1923..5b64bca 100644 --- a/tools/p4mlir-translate/translate.cpp +++ b/tools/p4mlir-translate/translate.cpp @@ -3,8 +3,6 @@ #include #include -#include "ir/ir-generated.h" - #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wcovered-switch-default" #include "frontends/common/resolveReferences/resolveReferences.h" @@ -136,6 +134,7 @@ class P4TypeConverter : public P4::Inspector { bool preorder(const P4::IR::Type_Void *v) override; bool preorder(const P4::IR::Type_Struct *s) override; bool preorder(const P4::IR::Type_Enum *e) override; + bool preorder(const P4::IR::Type_SerEnum *se) override; mlir::Type getType() const { return type; } bool setType(const P4::IR::Type *type, mlir::Type mlirType); @@ -488,12 +487,29 @@ bool P4TypeConverter::preorder(const P4::IR::Type_Enum *type) { if ((this->type = converter.findType(type))) return false; ConversionTracer trace("TypeConverting ", type); - llvm::SmallVector fields; + llvm::SmallVector cases; for (const auto *field : type->members) { - fields.push_back(mlir::StringAttr::get(converter.context(), field->name.string_view())); + cases.push_back(mlir::StringAttr::get(converter.context(), field->name.string_view())); } auto mlirType = P4HIR::EnumType::get(converter.context(), type->name.string_view(), - mlir::ArrayAttr::get(converter.context(), fields)); + mlir::ArrayAttr::get(converter.context(), cases)); + return setType(type, mlirType); +} + +bool P4TypeConverter::preorder(const P4::IR::Type_SerEnum *type) { + if ((this->type = converter.findType(type))) return false; + + ConversionTracer trace("TypeConverting ", type); + llvm::SmallVector cases; + + auto enumType = mlir::cast(convert(type->type)); + for (const auto *field : type->members) { + auto value = mlir::cast(converter.getOrCreateConstantExpr(field->value)); + cases.emplace_back(mlir::StringAttr::get(converter.context(), field->name.string_view()), + value); + } + + auto mlirType = P4HIR::SerEnumType::get(type->name.string_view(), enumType, cases); return setType(type, mlirType); } @@ -1099,10 +1115,13 @@ bool P4HIRConverter::preorder(const P4::IR::MethodCallExpression *mce) { bool P4HIRConverter::preorder(const P4::IR::Member *m) { // This is just enum constant if (const auto *typeNameExpr = m->expr->to()) { - auto enumType = mlir::cast(getOrCreateType(typeNameExpr->typeName)); + auto type = getOrCreateType(typeNameExpr->typeName); + BUG_CHECK((mlir::isa(type)), + "unexpected type for expression %1%", typeNameExpr); + setValue(m, builder.create( getLoc(builder, m), - P4HIR::EnumFieldAttr::get(enumType, m->member.name.string_view()))); + P4HIR::EnumFieldAttr::get(type, m->member.name.string_view()))); return false; } From 4299c9069f43b1838d64f0dff3a583e52842e7af Mon Sep 17 00:00:00 2001 From: Anton Korobeynikov Date: Sun, 23 Feb 2025 16:18:52 -0800 Subject: [PATCH 3/3] Rename ser.enum => ser_enum and enum.field => enum_field Signed-off-by: Anton Korobeynikov --- include/p4mlir/Dialect/P4HIR/P4HIR_Attrs.td | 4 ++-- include/p4mlir/Dialect/P4HIR/P4HIR_Types.td | 4 ++-- test/Dialect/P4HIR/enum.mlir | 4 ++-- test/Dialect/P4HIR/types.mlir | 2 +- test/Translate/Ops/enum.p4 | 6 +++--- test/Translate/Ops/serenum.p4 | 4 ++-- 6 files changed, 12 insertions(+), 12 deletions(-) diff --git a/include/p4mlir/Dialect/P4HIR/P4HIR_Attrs.td b/include/p4mlir/Dialect/P4HIR/P4HIR_Attrs.td index d9f2a32..999971e 100644 --- a/include/p4mlir/Dialect/P4HIR/P4HIR_Attrs.td +++ b/include/p4mlir/Dialect/P4HIR/P4HIR_Attrs.td @@ -91,14 +91,14 @@ def P4HIR_AggAttr : P4HIR_Attr<"Agg", "aggregate", [TypedAttrInterface]> { // EnumFieldAttr //===----------------------------------------------------------------------===// // An attribute to indicate an enumeration value. -def P4HIR_EnumFieldAttr : P4HIR_Attr<"EnumField", "enum.field", [TypedAttrInterface]> { +def P4HIR_EnumFieldAttr : P4HIR_Attr<"EnumField", "enum_field", [TypedAttrInterface]> { let summary = "Enumeration field attribute"; let description = [{ This attribute represents a field of an enumeration. Examples: ```mlir - #p4hir.enum.field> : !p4hir.enum<"name", A, B, C> + #p4hir.enum_field> : !p4hir.enum<"name", A, B, C> ``` }]; let parameters = (ins AttributeSelfTypeParameter<"">:$type, "::mlir::StringAttr":$field); diff --git a/include/p4mlir/Dialect/P4HIR/P4HIR_Types.td b/include/p4mlir/Dialect/P4HIR/P4HIR_Types.td index ae3f892..4d87d68 100644 --- a/include/p4mlir/Dialect/P4HIR/P4HIR_Types.td +++ b/include/p4mlir/Dialect/P4HIR/P4HIR_Types.td @@ -276,11 +276,11 @@ def EnumType : P4HIR_Type<"Enum", "enum", []> { }]; } -def SerEnumType : P4HIR_Type<"SerEnum", "ser.enum", []> { +def SerEnumType : P4HIR_Type<"SerEnum", "ser_enum", []> { let summary = "serializable enum type"; let description = [{ Represents an enumeration of values backed by some integer value - !p4hir.ser.enum<"name", !p4hir.bit<32>, Case1 : 42, Case2 : 0> + !p4hir.ser_enum<"name", !p4hir.bit<32>, Case1 : 42, Case2 : 0> }]; let hasCustomAssemblyFormat = 1; diff --git a/test/Dialect/P4HIR/enum.mlir b/test/Dialect/P4HIR/enum.mlir index 24b1650..1e5615a 100644 --- a/test/Dialect/P4HIR/enum.mlir +++ b/test/Dialect/P4HIR/enum.mlir @@ -2,8 +2,8 @@ !Suits = !p4hir.enum<"Suits", Clubs, Diamonds, Hearths, Spades> -#Suits_Clubs = #p4hir.enum.field : !Suits -#Suits_Diamonds = #p4hir.enum.field : !Suits +#Suits_Clubs = #p4hir.enum_field : !Suits +#Suits_Diamonds = #p4hir.enum_field : !Suits // CHECK: module module { diff --git a/test/Dialect/P4HIR/types.mlir b/test/Dialect/P4HIR/types.mlir index 110bf97..870a093 100644 --- a/test/Dialect/P4HIR/types.mlir +++ b/test/Dialect/P4HIR/types.mlir @@ -20,7 +20,7 @@ #b3 = #p4hir.int<3> : !bit42 #b4 = #p4hir.int<4> : !bit42 -!SuitsSerializable = !p4hir.ser.enum<"Suits", !bit42, Clubs : #b1, Diamonds : #b2, Hearths : #b3, Spades : #b4> +!SuitsSerializable = !p4hir.ser_enum<"Suits", !bit42, Clubs : #b1, Diamonds : #b2, Hearths : #b3, Spades : #b4> // No need to check stuff. If it parses, it's fine. // CHECK: module diff --git a/test/Translate/Ops/enum.p4 b/test/Translate/Ops/enum.p4 index f9a5815..264db2a 100644 --- a/test/Translate/Ops/enum.p4 +++ b/test/Translate/Ops/enum.p4 @@ -1,9 +1,9 @@ // RUN: p4mlir-translate --typeinference-only %s | FileCheck %s // CHECK: !Suits = !p4hir.enum<"Suits", Clubs, Diamonds, Hearths, Spades> -// CHECK: #Suits_Diamonds = #p4hir.enum.field : !Suits -// CHECK: #Suits_Hearths = #p4hir.enum.field : !Suits -// CHECK: #Suits_Spades = #p4hir.enum.field : !Suits +// CHECK: #Suits_Diamonds = #p4hir.enum_field : !Suits +// CHECK: #Suits_Hearths = #p4hir.enum_field : !Suits +// CHECK: #Suits_Spades = #p4hir.enum_field : !Suits enum Suits { Clubs, Diamonds, Hearths, Spades } diff --git a/test/Translate/Ops/serenum.p4 b/test/Translate/Ops/serenum.p4 index 0d6640e..f7755de 100644 --- a/test/Translate/Ops/serenum.p4 +++ b/test/Translate/Ops/serenum.p4 @@ -20,9 +20,9 @@ struct Headers { Ethernet eth; } -// CHECK: !EthTypes = !p4hir.ser.enum<"EthTypes", !b16i, ARP : #int2054_b16i, EtherTalk : #int-32613_b16i, IPX : #int-32457_b16i, IPv4 : #int2048_b16i, IPv6 : #int-31011_b16i, RARP : #int-32715_b16i, VLAN : #int-32512_b16i> +// CHECK: !EthTypes = !p4hir.ser_enum<"EthTypes", !b16i, ARP : #int2054_b16i, EtherTalk : #int-32613_b16i, IPX : #int-32457_b16i, IPv4 : #int2048_b16i, IPv6 : #int-31011_b16i, RARP : #int-32715_b16i, VLAN : #int-32512_b16i> // CHECK: !Ethernet = !p4hir.struct<"Ethernet", src: !b48i, dest: !b48i, type: !EthTypes> -// CHECK: #EthTypes_IPv4_ = #p4hir.enum.field : !EthTypes +// CHECK: #EthTypes_IPv4_ = #p4hir.enum_field : !EthTypes // CHECK-LABEL: module // CHECK-LABEL: p4hir.func action @test(%arg0: !p4hir.ref