diff --git a/include/p4mlir/Dialect/P4HIR/P4HIR_Attrs.td b/include/p4mlir/Dialect/P4HIR/P4HIR_Attrs.td index 6d521c5..2f170dc 100644 --- a/include/p4mlir/Dialect/P4HIR/P4HIR_Attrs.td +++ b/include/p4mlir/Dialect/P4HIR/P4HIR_Attrs.td @@ -116,6 +116,36 @@ def P4HIR_EnumFieldAttr : P4HIR_Attr<"EnumField", "enum_field", [TypedAttrInterf ]; } +//===----------------------------------------------------------------------===// +// ErrorAttr +//===----------------------------------------------------------------------===// +// An attribute to indicate a particular error code. +// TODO: Decide if we'd want to unify with EnumFieldAttr? +def P4HIR_ErrorCodeAttr : P4HIR_Attr<"ErrorCode", "error", [TypedAttrInterface]> { + let summary = "Error code attribute"; + let description = [{ + This attribute represents an error code. + + Examples: + ```mlir + #p4hir.error> : !p4hir.error + ``` + }]; + let parameters = (ins AttributeSelfTypeParameter<"">:$type, "::mlir::StringAttr":$field); + + // Force all clients to go through custom builder so we can check + // whether the requested error value is part of the provided error type. + let skipDefaultBuilders = 1; + let hasCustomAssemblyFormat = 1; + + let builders = [ + AttrBuilderWithInferredContext<(ins "mlir::Type":$type, "mlir::StringAttr": $value)>, + AttrBuilderWithInferredContext<(ins "mlir::Type":$type, "mlir::StringRef": $value), [{ + return $_get(type.getContext(), type, mlir::StringAttr::get(type.getContext(), value)); + }]> + ]; +} + //===----------------------------------------------------------------------===// // ValidAttr //===----------------------------------------------------------------------===// diff --git a/include/p4mlir/Dialect/P4HIR/P4HIR_Types.td b/include/p4mlir/Dialect/P4HIR/P4HIR_Types.td index 7217f33..36f704c 100644 --- a/include/p4mlir/Dialect/P4HIR/P4HIR_Types.td +++ b/include/p4mlir/Dialect/P4HIR/P4HIR_Types.td @@ -111,8 +111,6 @@ def BooleanType : P4HIR_Type<"Bool", "bool"> { //===----------------------------------------------------------------------===// def DontcareType : P4HIR_Type<"Dontcare", "dontcare"> {} -// FIXME: Add string for error here & declarations -def ErrorType : P4HIR_Type<"Error", "error"> {} def UnknownType : P4HIR_Type<"Unknown", "unknown"> {} def VoidType : P4HIR_Type<"Void", "void"> { @@ -404,7 +402,7 @@ def HeaderType : StructLikeType<"Header", "header"> { } //===----------------------------------------------------------------------===// -// EnumType & SerEnumType +// EnumType, ErrorType & SerEnumType //===----------------------------------------------------------------------===// def EnumType : P4HIR_Type<"Enum", "enum", []> { let summary = "enum type"; @@ -428,6 +426,29 @@ def EnumType : P4HIR_Type<"Enum", "enum", []> { }]; } +def ErrorType : P4HIR_Type<"Error", "error", []> { + let summary = "error type"; + let description = [{ + Represents an enumeration of error values, essentially an enum + !p4hir.error + }]; + + let hasCustomAssemblyFormat = 1; + + let parameters = (ins "mlir::ArrayAttr":$fields); + + let extraClassDeclaration = [{ + /// Returns true if the requested field is part of this enum + bool contains(mlir::StringRef field) { return indexOf(field).has_value(); } + + /// Returns the index of the requested field, or a nullopt if the field is + /// not part of this enum. + std::optional indexOf(mlir::StringRef field); + + llvm::StringRef getAlias() const { return "error"; }; + }]; +} + def SerEnumType : P4HIR_Type<"SerEnum", "ser_enum", []> { let summary = "serializable enum type"; let description = [{ @@ -475,7 +496,7 @@ def AnyIntP4Type : AnyTypeOf<[BitsType, InfIntType]> {} def CallResultP4Type : AnyTypeOf<[BitsType, BooleanType, InfIntType, VoidType]> {} def LoadableP4Type : AnyTypeOf<[BitsType, BooleanType, InfIntType, StructType, HeaderType, - EnumType, SerEnumType, + EnumType, SerEnumType, ErrorType, ValidBitType]> {} def AnyEnumType : AnyTypeOf<[EnumType, SerEnumType]>; def StructLikeType : AnyTypeOf<[StructType, HeaderType]>; diff --git a/lib/Dialect/P4HIR/P4HIR_Attrs.cpp b/lib/Dialect/P4HIR/P4HIR_Attrs.cpp index 78c3f13..4b506e3 100644 --- a/lib/Dialect/P4HIR/P4HIR_Attrs.cpp +++ b/lib/Dialect/P4HIR/P4HIR_Attrs.cpp @@ -122,6 +122,36 @@ EnumFieldAttr EnumFieldAttr::get(mlir::Type type, StringAttr value) { return Base::get(value.getContext(), type, value); } +Attribute ErrorCodeAttr::parse(AsmParser &p, Type) { + StringRef field; + P4HIR::ErrorType type; + if (p.parseLess() || p.parseKeyword(&field) || p.parseComma() || + p.parseCustomTypeWithFallback(type) || p.parseGreater()) + return {}; + + return EnumFieldAttr::get(type, field); +} + +void ErrorCodeAttr::print(AsmPrinter &p) const { + p << "<" << getField().getValue() << ", "; + p.printType(getType()); + p << ">"; +} + +ErrorCodeAttr ErrorCodeAttr::get(mlir::Type type, StringAttr value) { + ErrorType errorType = llvm::dyn_cast(type); + if (!errorType) return nullptr; + + // Check whether the provided value is a member of the enum type. + if (!errorType.contains(value.getValue())) { + // emitError() << "error code '" << value.getValue() + // << "' is not a member of error type " << errorType; + return nullptr; + } + + return Base::get(value.getContext(), type, value); +} + void P4HIRDialect::registerAttributes() { addAttributes< #define GET_ATTRDEF_LIST diff --git a/lib/Dialect/P4HIR/P4HIR_Ops.cpp b/lib/Dialect/P4HIR/P4HIR_Ops.cpp index 1176f8e..7cf0ff3 100644 --- a/lib/Dialect/P4HIR/P4HIR_Ops.cpp +++ b/lib/Dialect/P4HIR/P4HIR_Ops.cpp @@ -1,5 +1,7 @@ #include "p4mlir/Dialect/P4HIR/P4HIR_Ops.h" +#include + #include "llvm/ADT/SmallString.h" #include "llvm/ADT/StringExtras.h" #include "llvm/Support/LogicalResult.h" @@ -52,6 +54,13 @@ static LogicalResult checkConstantTypes(mlir::Operation *op, mlir::Type opType, return success(); } + if (mlir::isa(attrType)) { + if (!mlir::isa(opType)) + return op->emitOpError("result type (") << opType << ") is not an error type"; + + return success(); + } + if (mlir::isa(attrType)) { if (!mlir::isa(opType)) return op->emitOpError("result type (") << opType << ") is not a validity bit type"; @@ -91,6 +100,10 @@ void P4HIR::ConstOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) { setNameFn(getResult(), boolCst.getValue() ? "true" : "false"); } else if (auto validityCst = mlir::dyn_cast(getValue())) { setNameFn(getResult(), stringifyEnum(validityCst.getValue())); + } else if (auto errorCst = mlir::dyn_cast(getValue())) { + llvm::SmallString<32> error("error_"); + error += errorCst.getField().getValue(); + setNameFn(getResult(), error); } else if (auto enumCst = mlir::dyn_cast(getValue())) { llvm::SmallString<32> specialNameBuffer; llvm::raw_svector_ostream specialName(specialNameBuffer); @@ -919,6 +932,11 @@ struct P4HIROpAsmDialectInterface : public OpAsmDialectInterface { return AliasResult::OverridableAlias; } + if (auto errorType = mlir::dyn_cast(type)) { + os << errorType.getAlias(); + return AliasResult::OverridableAlias; + } + return AliasResult::NoAlias; } @@ -948,6 +966,11 @@ struct P4HIROpAsmDialectInterface : public OpAsmDialectInterface { return AliasResult::FinalAlias; } + if (auto errorAttr = mlir::dyn_cast(attr)) { + os << "error_" << errorAttr.getField().getValue(); + return AliasResult::FinalAlias; + } + if (auto enumFieldAttr = mlir::dyn_cast(attr)) { if (auto enumType = mlir::dyn_cast(enumFieldAttr.getType())) os << enumType.getName() << "_" << enumFieldAttr.getField().getValue(); @@ -975,5 +998,5 @@ void P4HIR::P4HIRDialect::initialize() { #define GET_OP_CLASSES #include "p4mlir/Dialect/P4HIR/P4HIR_Dialect.cpp.inc" -#include "p4mlir/Dialect/P4HIR/P4HIR_Ops.cpp.inc" +#include "p4mlir/Dialect/P4HIR/P4HIR_Ops.cpp.inc" // NOLINT #include "p4mlir/Dialect/P4HIR/P4HIR_OpsEnums.cpp.inc" diff --git a/lib/Dialect/P4HIR/P4HIR_Types.cpp b/lib/Dialect/P4HIR/P4HIR_Types.cpp index 8ac7125..745d136 100644 --- a/lib/Dialect/P4HIR/P4HIR_Types.cpp +++ b/lib/Dialect/P4HIR/P4HIR_Types.cpp @@ -303,6 +303,34 @@ std::optional EnumType::indexOf(mlir::StringRef field) { return {}; } +Type ErrorType::parse(AsmParser &p) { + llvm::SmallVector fields; + if (p.parseCommaSeparatedList(AsmParser::Delimiter::LessGreater, [&]() { + StringRef caseName; + if (p.parseKeyword(&caseName)) return failure(); + fields.push_back(StringAttr::get(p.getContext(), name)); + return success(); + })) + return {}; + + return get(p.getContext(), ArrayAttr::get(p.getContext(), fields)); +} + +void ErrorType::print(AsmPrinter &p) const { + auto fields = getFields(); + p << '<'; + llvm::interleaveComma(fields, p, [&](Attribute enumerator) { + p << mlir::cast(enumerator).getValue(); + }); + p << ">"; +} + +std::optional ErrorType::indexOf(mlir::StringRef field) { + for (auto it : llvm::enumerate(getFields())) + if (mlir::cast(it.value()).getValue() == field) return it.index(); + return {}; +} + void SerEnumType::print(AsmPrinter &p) const { auto fields = getFields(); p << '<'; diff --git a/test/Dialect/P4HIR/types.mlir b/test/Dialect/P4HIR/types.mlir index 55c0b64..26ae41e 100644 --- a/test/Dialect/P4HIR/types.mlir +++ b/test/Dialect/P4HIR/types.mlir @@ -1,7 +1,7 @@ // RUN: p4mlir-opt %s | FileCheck %s !unknown = !p4hir.unknown -!error = !p4hir.error +!error = !p4hir.error !dontcare = !p4hir.dontcare !bit42 = !p4hir.bit<42> !ref = !p4hir.ref> diff --git a/test/Translate/Ops/error.p4 b/test/Translate/Ops/error.p4 new file mode 100644 index 0000000..705e551 --- /dev/null +++ b/test/Translate/Ops/error.p4 @@ -0,0 +1,24 @@ +// RUN: p4mlir-translate --typeinference-only %s | FileCheck %s + +// CHECK: !error = !p4hir.error +// CHECK: !S = !p4hir.struct<"S", e: !error> +// CHECK: #error_Baz = #p4hir.error : !error +// CHECK: #error_Foo = #p4hir.error : !error + +error { Foo, Bar }; + +struct S { + error e; +}; + +action test(inout S s) { +// CHECK-LABEL: test +// CHECK: %[[e_field_ref:.*]] = p4hir.struct_extract_ref %arg0["e"] : +// CHECK: %[[error_Foo:.*]] = p4hir.const #error_Foo +// CHECK: p4hir.assign %[[error_Foo]], %[[e_field_ref]] : + s.e = error.Foo; +} + +error { Baz } + +const S s = { error.Baz }; diff --git a/tools/p4mlir-translate/translate.cpp b/tools/p4mlir-translate/translate.cpp index 3fae31d..70535b3 100644 --- a/tools/p4mlir-translate/translate.cpp +++ b/tools/p4mlir-translate/translate.cpp @@ -99,7 +99,7 @@ class ConversionTracer { // representation. class P4TypeConverter : public P4::Inspector { public: - P4TypeConverter(P4HIRConverter &converter) : converter(converter) {} + explicit P4TypeConverter(P4HIRConverter &converter) : converter(converter) {} profile_t init_apply(const P4::IR::Node *node) override { BUG_CHECK(!type, "Type already converted"); @@ -134,6 +134,7 @@ class P4TypeConverter : public P4::Inspector { bool preorder(const P4::IR::Type_Void *v) override; bool preorder(const P4::IR::Type_Struct *s) override; bool preorder(const P4::IR::Type_Enum *e) override; + bool preorder(const P4::IR::Type_Error *e) override; bool preorder(const P4::IR::Type_SerEnum *se) override; bool preorder(const P4::IR::Type_Header *h) override; @@ -523,6 +524,19 @@ bool P4TypeConverter::preorder(const P4::IR::Type_Enum *type) { return setType(type, mlirType); } +bool P4TypeConverter::preorder(const P4::IR::Type_Error *type) { + if ((this->type = converter.findType(type))) return false; + + ConversionTracer trace("TypeConverting ", type); + llvm::SmallVector cases; + for (const auto *field : type->members) { + cases.push_back(mlir::StringAttr::get(converter.context(), field->name.string_view())); + } + auto mlirType = P4HIR::ErrorType::get(converter.context(), + mlir::ArrayAttr::get(converter.context(), cases)); + return setType(type, mlirType); +} + bool P4TypeConverter::preorder(const P4::IR::Type_SerEnum *type) { if ((this->type = converter.findType(type))) return false; @@ -647,9 +661,15 @@ mlir::TypedAttr P4HIRConverter::getOrCreateConstantExpr(const P4::IR::Expression } if (const auto *m = expr->to()) { if (const auto *typeNameExpr = m->expr->to()) { - auto baseType = mlir::cast(getOrCreateType(typeNameExpr->typeName)); + auto baseType = getOrCreateType(typeNameExpr->typeName); + if (auto errorType = mlir::dyn_cast(baseType)) { + return setConstantExpr( + expr, P4HIR::ErrorCodeAttr::get(errorType, m->member.string_view())); + } + + auto enumType = mlir::cast(baseType); return setConstantExpr(expr, - P4HIR::EnumFieldAttr::get(baseType, m->member.string_view())); + P4HIR::EnumFieldAttr::get(enumType, m->member.string_view())); } auto base = mlir::cast(getOrCreateConstantExpr(m->expr)); @@ -666,6 +686,9 @@ mlir::TypedAttr P4HIRConverter::getOrCreateConstantExpr(const P4::IR::Expression if (mlir::isa(fieldType)) return setConstantExpr(expr, mlir::cast(field)); + if (mlir::isa(fieldType)) + return setConstantExpr(expr, mlir::cast(field)); + return setConstantExpr(expr, mlir::cast(field)); } else BUG("invalid member reference %1%", m); @@ -1229,12 +1252,18 @@ bool P4HIRConverter::preorder(const P4::IR::Member *m) { // This is just enum constant if (const auto *typeNameExpr = m->expr->to()) { auto type = getOrCreateType(typeNameExpr->typeName); - BUG_CHECK((mlir::isa(type)), + BUG_CHECK((mlir::isa(type)), "unexpected type for expression %1%", typeNameExpr); - setValue(m, builder.create( - getLoc(builder, m), - P4HIR::EnumFieldAttr::get(type, m->member.name.string_view()))); + if (mlir::isa(type)) + setValue(m, builder.create( + getLoc(builder, m), + P4HIR::ErrorCodeAttr::get(type, m->member.name.string_view()))); + else + setValue(m, builder.create( + getLoc(builder, m), + P4HIR::EnumFieldAttr::get(type, m->member.name.string_view()))); + return false; }