Skip to content

Commit

Permalink
Add proper support for Type_Error (#71)
Browse files Browse the repository at this point in the history
This makes handling pretty similar to enums. It could be refined later.

Signed-off-by: Anton Korobeynikov <anton@korobeynikov.info>
  • Loading branch information
asl authored Feb 24, 2025
1 parent e5a1258 commit bfdaf7b
Show file tree
Hide file tree
Showing 8 changed files with 198 additions and 13 deletions.
30 changes: 30 additions & 0 deletions include/p4mlir/Dialect/P4HIR/P4HIR_Attrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,36 @@ def P4HIR_EnumFieldAttr : P4HIR_Attr<"EnumField", "enum_field", [TypedAttrInterf
];
}

//===----------------------------------------------------------------------===//
// ErrorAttr
//===----------------------------------------------------------------------===//
// An attribute to indicate a particular error code.
// TODO: Decide if we'd want to unify with EnumFieldAttr?
def P4HIR_ErrorCodeAttr : P4HIR_Attr<"ErrorCode", "error", [TypedAttrInterface]> {
let summary = "Error code attribute";
let description = [{
This attribute represents an error code.

Examples:
```mlir
#p4hir.error<ErrorA, !p4hir.error<ErrorA, ErrorB, ErrorC>> : !p4hir.error<ErrorA, ErrorB, ErrorC>
```
}];
let parameters = (ins AttributeSelfTypeParameter<"">:$type, "::mlir::StringAttr":$field);

// Force all clients to go through custom builder so we can check
// whether the requested error value is part of the provided error type.
let skipDefaultBuilders = 1;
let hasCustomAssemblyFormat = 1;

let builders = [
AttrBuilderWithInferredContext<(ins "mlir::Type":$type, "mlir::StringAttr": $value)>,
AttrBuilderWithInferredContext<(ins "mlir::Type":$type, "mlir::StringRef": $value), [{
return $_get(type.getContext(), type, mlir::StringAttr::get(type.getContext(), value));
}]>
];
}

//===----------------------------------------------------------------------===//
// ValidAttr
//===----------------------------------------------------------------------===//
Expand Down
29 changes: 25 additions & 4 deletions include/p4mlir/Dialect/P4HIR/P4HIR_Types.td
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,6 @@ def BooleanType : P4HIR_Type<"Bool", "bool"> {
//===----------------------------------------------------------------------===//

def DontcareType : P4HIR_Type<"Dontcare", "dontcare"> {}
// FIXME: Add string for error here & declarations
def ErrorType : P4HIR_Type<"Error", "error"> {}
def UnknownType : P4HIR_Type<"Unknown", "unknown"> {}

def VoidType : P4HIR_Type<"Void", "void"> {
Expand Down Expand Up @@ -404,7 +402,7 @@ def HeaderType : StructLikeType<"Header", "header"> {
}

//===----------------------------------------------------------------------===//
// EnumType & SerEnumType
// EnumType, ErrorType & SerEnumType
//===----------------------------------------------------------------------===//
def EnumType : P4HIR_Type<"Enum", "enum", []> {
let summary = "enum type";
Expand All @@ -428,6 +426,29 @@ def EnumType : P4HIR_Type<"Enum", "enum", []> {
}];
}

def ErrorType : P4HIR_Type<"Error", "error", []> {
let summary = "error type";
let description = [{
Represents an enumeration of error values, essentially an enum
!p4hir.error<Case1, Case2>
}];

let hasCustomAssemblyFormat = 1;

let parameters = (ins "mlir::ArrayAttr":$fields);

let extraClassDeclaration = [{
/// Returns true if the requested field is part of this enum
bool contains(mlir::StringRef field) { return indexOf(field).has_value(); }

/// Returns the index of the requested field, or a nullopt if the field is
/// not part of this enum.
std::optional<size_t> indexOf(mlir::StringRef field);

llvm::StringRef getAlias() const { return "error"; };
}];
}

def SerEnumType : P4HIR_Type<"SerEnum", "ser_enum", []> {
let summary = "serializable enum type";
let description = [{
Expand Down Expand Up @@ -475,7 +496,7 @@ def AnyIntP4Type : AnyTypeOf<[BitsType, InfIntType]> {}
def CallResultP4Type : AnyTypeOf<[BitsType, BooleanType, InfIntType, VoidType]> {}
def LoadableP4Type : AnyTypeOf<[BitsType, BooleanType, InfIntType,
StructType, HeaderType,
EnumType, SerEnumType,
EnumType, SerEnumType, ErrorType,
ValidBitType]> {}
def AnyEnumType : AnyTypeOf<[EnumType, SerEnumType]>;
def StructLikeType : AnyTypeOf<[StructType, HeaderType]>;
Expand Down
30 changes: 30 additions & 0 deletions lib/Dialect/P4HIR/P4HIR_Attrs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,36 @@ EnumFieldAttr EnumFieldAttr::get(mlir::Type type, StringAttr value) {
return Base::get(value.getContext(), type, value);
}

Attribute ErrorCodeAttr::parse(AsmParser &p, Type) {
StringRef field;
P4HIR::ErrorType type;
if (p.parseLess() || p.parseKeyword(&field) || p.parseComma() ||
p.parseCustomTypeWithFallback<P4HIR::ErrorType>(type) || p.parseGreater())
return {};

return EnumFieldAttr::get(type, field);
}

void ErrorCodeAttr::print(AsmPrinter &p) const {
p << "<" << getField().getValue() << ", ";
p.printType(getType());
p << ">";
}

ErrorCodeAttr ErrorCodeAttr::get(mlir::Type type, StringAttr value) {
ErrorType errorType = llvm::dyn_cast<ErrorType>(type);
if (!errorType) return nullptr;

// Check whether the provided value is a member of the enum type.
if (!errorType.contains(value.getValue())) {
// emitError() << "error code '" << value.getValue()
// << "' is not a member of error type " << errorType;
return nullptr;
}

return Base::get(value.getContext(), type, value);
}

void P4HIRDialect::registerAttributes() {
addAttributes<
#define GET_ATTRDEF_LIST
Expand Down
25 changes: 24 additions & 1 deletion lib/Dialect/P4HIR/P4HIR_Ops.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#include "p4mlir/Dialect/P4HIR/P4HIR_Ops.h"

#include <string>

#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/LogicalResult.h"
Expand Down Expand Up @@ -52,6 +54,13 @@ static LogicalResult checkConstantTypes(mlir::Operation *op, mlir::Type opType,
return success();
}

if (mlir::isa<P4HIR::ErrorCodeAttr>(attrType)) {
if (!mlir::isa<P4HIR::ErrorType>(opType))
return op->emitOpError("result type (") << opType << ") is not an error type";

return success();
}

if (mlir::isa<P4HIR::ValidityBitAttr>(attrType)) {
if (!mlir::isa<P4HIR::ValidBitType>(opType))
return op->emitOpError("result type (") << opType << ") is not a validity bit type";
Expand Down Expand Up @@ -91,6 +100,10 @@ void P4HIR::ConstOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
setNameFn(getResult(), boolCst.getValue() ? "true" : "false");
} else if (auto validityCst = mlir::dyn_cast<P4HIR::ValidityBitAttr>(getValue())) {
setNameFn(getResult(), stringifyEnum(validityCst.getValue()));
} else if (auto errorCst = mlir::dyn_cast<P4HIR::ErrorCodeAttr>(getValue())) {
llvm::SmallString<32> error("error_");
error += errorCst.getField().getValue();
setNameFn(getResult(), error);
} else if (auto enumCst = mlir::dyn_cast<P4HIR::EnumFieldAttr>(getValue())) {
llvm::SmallString<32> specialNameBuffer;
llvm::raw_svector_ostream specialName(specialNameBuffer);
Expand Down Expand Up @@ -919,6 +932,11 @@ struct P4HIROpAsmDialectInterface : public OpAsmDialectInterface {
return AliasResult::OverridableAlias;
}

if (auto errorType = mlir::dyn_cast<P4HIR::ErrorType>(type)) {
os << errorType.getAlias();
return AliasResult::OverridableAlias;
}

return AliasResult::NoAlias;
}

Expand Down Expand Up @@ -948,6 +966,11 @@ struct P4HIROpAsmDialectInterface : public OpAsmDialectInterface {
return AliasResult::FinalAlias;
}

if (auto errorAttr = mlir::dyn_cast<P4HIR::ErrorCodeAttr>(attr)) {
os << "error_" << errorAttr.getField().getValue();
return AliasResult::FinalAlias;
}

if (auto enumFieldAttr = mlir::dyn_cast<P4HIR::EnumFieldAttr>(attr)) {
if (auto enumType = mlir::dyn_cast<P4HIR::EnumType>(enumFieldAttr.getType()))
os << enumType.getName() << "_" << enumFieldAttr.getField().getValue();
Expand Down Expand Up @@ -975,5 +998,5 @@ void P4HIR::P4HIRDialect::initialize() {

#define GET_OP_CLASSES
#include "p4mlir/Dialect/P4HIR/P4HIR_Dialect.cpp.inc"
#include "p4mlir/Dialect/P4HIR/P4HIR_Ops.cpp.inc"
#include "p4mlir/Dialect/P4HIR/P4HIR_Ops.cpp.inc" // NOLINT
#include "p4mlir/Dialect/P4HIR/P4HIR_OpsEnums.cpp.inc"
28 changes: 28 additions & 0 deletions lib/Dialect/P4HIR/P4HIR_Types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,34 @@ std::optional<size_t> EnumType::indexOf(mlir::StringRef field) {
return {};
}

Type ErrorType::parse(AsmParser &p) {
llvm::SmallVector<Attribute> fields;
if (p.parseCommaSeparatedList(AsmParser::Delimiter::LessGreater, [&]() {
StringRef caseName;
if (p.parseKeyword(&caseName)) return failure();
fields.push_back(StringAttr::get(p.getContext(), name));
return success();
}))
return {};

return get(p.getContext(), ArrayAttr::get(p.getContext(), fields));
}

void ErrorType::print(AsmPrinter &p) const {
auto fields = getFields();
p << '<';
llvm::interleaveComma(fields, p, [&](Attribute enumerator) {
p << mlir::cast<StringAttr>(enumerator).getValue();
});
p << ">";
}

std::optional<size_t> ErrorType::indexOf(mlir::StringRef field) {
for (auto it : llvm::enumerate(getFields()))
if (mlir::cast<StringAttr>(it.value()).getValue() == field) return it.index();
return {};
}

void SerEnumType::print(AsmPrinter &p) const {
auto fields = getFields();
p << '<';
Expand Down
2 changes: 1 addition & 1 deletion test/Dialect/P4HIR/types.mlir
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// RUN: p4mlir-opt %s | FileCheck %s

!unknown = !p4hir.unknown
!error = !p4hir.error
!error = !p4hir.error<ErrorA, ErrorB>
!dontcare = !p4hir.dontcare
!bit42 = !p4hir.bit<42>
!ref = !p4hir.ref<!p4hir.bit<42>>
Expand Down
24 changes: 24 additions & 0 deletions test/Translate/Ops/error.p4
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
// RUN: p4mlir-translate --typeinference-only %s | FileCheck %s

// CHECK: !error = !p4hir.error<Foo, Bar, Baz>
// CHECK: !S = !p4hir.struct<"S", e: !error>
// CHECK: #error_Baz = #p4hir.error<Baz, !error> : !error
// CHECK: #error_Foo = #p4hir.error<Foo, !error> : !error

error { Foo, Bar };

struct S {
error e;
};

action test(inout S s) {
// CHECK-LABEL: test
// CHECK: %[[e_field_ref:.*]] = p4hir.struct_extract_ref %arg0["e"] : <!S>
// CHECK: %[[error_Foo:.*]] = p4hir.const #error_Foo
// CHECK: p4hir.assign %[[error_Foo]], %[[e_field_ref]] : <!error>
s.e = error.Foo;
}

error { Baz }

const S s = { error.Baz };
43 changes: 36 additions & 7 deletions tools/p4mlir-translate/translate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ class ConversionTracer {
// representation.
class P4TypeConverter : public P4::Inspector {
public:
P4TypeConverter(P4HIRConverter &converter) : converter(converter) {}
explicit P4TypeConverter(P4HIRConverter &converter) : converter(converter) {}

profile_t init_apply(const P4::IR::Node *node) override {
BUG_CHECK(!type, "Type already converted");
Expand Down Expand Up @@ -134,6 +134,7 @@ class P4TypeConverter : public P4::Inspector {
bool preorder(const P4::IR::Type_Void *v) override;
bool preorder(const P4::IR::Type_Struct *s) override;
bool preorder(const P4::IR::Type_Enum *e) override;
bool preorder(const P4::IR::Type_Error *e) override;
bool preorder(const P4::IR::Type_SerEnum *se) override;
bool preorder(const P4::IR::Type_Header *h) override;

Expand Down Expand Up @@ -523,6 +524,19 @@ bool P4TypeConverter::preorder(const P4::IR::Type_Enum *type) {
return setType(type, mlirType);
}

bool P4TypeConverter::preorder(const P4::IR::Type_Error *type) {
if ((this->type = converter.findType(type))) return false;

ConversionTracer trace("TypeConverting ", type);
llvm::SmallVector<mlir::Attribute, 4> cases;
for (const auto *field : type->members) {
cases.push_back(mlir::StringAttr::get(converter.context(), field->name.string_view()));
}
auto mlirType = P4HIR::ErrorType::get(converter.context(),
mlir::ArrayAttr::get(converter.context(), cases));
return setType(type, mlirType);
}

bool P4TypeConverter::preorder(const P4::IR::Type_SerEnum *type) {
if ((this->type = converter.findType(type))) return false;

Expand Down Expand Up @@ -647,9 +661,15 @@ mlir::TypedAttr P4HIRConverter::getOrCreateConstantExpr(const P4::IR::Expression
}
if (const auto *m = expr->to<P4::IR::Member>()) {
if (const auto *typeNameExpr = m->expr->to<P4::IR::TypeNameExpression>()) {
auto baseType = mlir::cast<P4HIR::EnumType>(getOrCreateType(typeNameExpr->typeName));
auto baseType = getOrCreateType(typeNameExpr->typeName);
if (auto errorType = mlir::dyn_cast<P4HIR::ErrorType>(baseType)) {
return setConstantExpr(
expr, P4HIR::ErrorCodeAttr::get(errorType, m->member.string_view()));
}

auto enumType = mlir::cast<P4HIR::EnumType>(baseType);
return setConstantExpr(expr,
P4HIR::EnumFieldAttr::get(baseType, m->member.string_view()));
P4HIR::EnumFieldAttr::get(enumType, m->member.string_view()));
}

auto base = mlir::cast<P4HIR::AggAttr>(getOrCreateConstantExpr(m->expr));
Expand All @@ -666,6 +686,9 @@ mlir::TypedAttr P4HIRConverter::getOrCreateConstantExpr(const P4::IR::Expression
if (mlir::isa<P4HIR::BitsType, P4HIR::InfIntType>(fieldType))
return setConstantExpr(expr, mlir::cast<P4HIR::IntAttr>(field));

if (mlir::isa<P4HIR::ErrorType>(fieldType))
return setConstantExpr(expr, mlir::cast<P4HIR::ErrorCodeAttr>(field));

return setConstantExpr(expr, mlir::cast<P4HIR::AggAttr>(field));
} else
BUG("invalid member reference %1%", m);
Expand Down Expand Up @@ -1229,12 +1252,18 @@ bool P4HIRConverter::preorder(const P4::IR::Member *m) {
// This is just enum constant
if (const auto *typeNameExpr = m->expr->to<P4::IR::TypeNameExpression>()) {
auto type = getOrCreateType(typeNameExpr->typeName);
BUG_CHECK((mlir::isa<P4HIR::EnumType, P4HIR::SerEnumType>(type)),
BUG_CHECK((mlir::isa<P4HIR::EnumType, P4HIR::SerEnumType, P4HIR::ErrorType>(type)),
"unexpected type for expression %1%", typeNameExpr);

setValue(m, builder.create<P4HIR::ConstOp>(
getLoc(builder, m),
P4HIR::EnumFieldAttr::get(type, m->member.name.string_view())));
if (mlir::isa<P4HIR::ErrorType>(type))
setValue(m, builder.create<P4HIR::ConstOp>(
getLoc(builder, m),
P4HIR::ErrorCodeAttr::get(type, m->member.name.string_view())));
else
setValue(m, builder.create<P4HIR::ConstOp>(
getLoc(builder, m),
P4HIR::EnumFieldAttr::get(type, m->member.name.string_view())));

return false;
}

Expand Down

0 comments on commit bfdaf7b

Please sign in to comment.