Skip to content

Commit

Permalink
Add support for tuple type. Use builtin tuple type (for now).
Browse files Browse the repository at this point in the history
Signed-off-by: Anton Korobeynikov <anton@korobeynikov.info>
  • Loading branch information
asl committed Mar 4, 2025
1 parent dd9e2ae commit c1bf688
Show file tree
Hide file tree
Showing 8 changed files with 304 additions and 15 deletions.
1 change: 0 additions & 1 deletion include/p4mlir/Dialect/P4HIR/P4HIR_Attrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,6 @@ def P4HIR_AggAttr : P4HIR_Attr<"Agg", "aggregate", [TypedAttrInterface]> {
}]>
];
// let genVerifyDecl = 1;
//let hasCustomAssemblyFormat = 1;
let assemblyFormat = [{
`<` $fields `>`
}];
Expand Down
48 changes: 48 additions & 0 deletions include/p4mlir/Dialect/P4HIR/P4HIR_Ops.td
Original file line number Diff line number Diff line change
Expand Up @@ -947,4 +947,52 @@ def StructExtractRefOp : P4HIR_Op<"struct_extract_ref",
}];
}

def TupleOp : P4HIR_Op<"tuple",
[Pure,
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>]> {
let summary = "Create a tuplet from constituent parts.";
// FIXME: Better constraint type
let arguments = (ins Variadic<AnyP4Type>:$input);
let results = (outs Builtin_Tuple:$result);
let hasCustomAssemblyFormat = 1;
// FIXME: use declarative format
// let assemblyFormat = [{
// `(` $input `)` attr-dict `:` type($result)
// }];

let hasVerifier = 1;
}

def TupleExtractOp : P4HIR_Op<"tuple_extract",
[Pure,
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
]> {
let summary = "Extract a field from a tuple.";
let description = [{
```
%result = p4hir.tuple_extract %input[0] : !tuple<!bit32, !bit64>

```
}];

// TODO: find a way to use fieldid inteerface on tuples...
let arguments = (ins Builtin_Tuple:$input, I32Attr:$fieldIndex);
// FIXME: Better constraint type
let results = (outs AnyP4Type:$result);

let builders = [
OpBuilder<(ins "mlir::Value":$input, "unsigned":$fieldIndex)>,
OpBuilder<(ins "mlir::Value":$input, "P4HIR::IntAttr":$fieldIndex), [{
build($_builder, $_state, input, fieldIndex.getUInt());
}]>
];

// FIXME: use declarative format
// let assemblyFormat = [{
// $input `[` $fieldIndex `]` attr-dict `:` type($result)
// }];
let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;
}

#endif // P4MLIR_DIALECT_P4HIR_P4HIR_OPS_TD
5 changes: 3 additions & 2 deletions include/p4mlir/Dialect/P4HIR/P4HIR_Types.td
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#define P4MLIR_DIALECT_P4HIR_P4HIR_TYPES_TD

include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/BuiltinTypes.td"
include "mlir/IR/EnumAttr.td"
include "mlir/Interfaces/MemorySlotInterfaces.td"

Expand Down Expand Up @@ -488,14 +489,14 @@ def SerEnumType : P4HIR_Type<"SerEnum", "ser_enum", []> {
//===----------------------------------------------------------------------===//

def AnyP4Type : AnyTypeOf<[BitsType, BooleanType, InfIntType,
StructType, HeaderType,
StructType, HeaderType, Builtin_Tuple,
EnumType, SerEnumType,
ValidBitType,
DontcareType, ErrorType, UnknownType]> {}
def AnyIntP4Type : AnyTypeOf<[BitsType, InfIntType]> {}
def CallResultP4Type : AnyTypeOf<[BitsType, BooleanType, InfIntType, VoidType]> {}
def LoadableP4Type : AnyTypeOf<[BitsType, BooleanType, InfIntType,
StructType, HeaderType,
StructType, HeaderType, Builtin_Tuple,
EnumType, SerEnumType, ErrorType,
ValidBitType]> {}
def AnyEnumType : AnyTypeOf<[EnumType, SerEnumType]>;
Expand Down
106 changes: 105 additions & 1 deletion lib/Dialect/P4HIR/P4HIR_Ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/LogicalResult.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Interfaces/FunctionImplementation.h"
Expand Down Expand Up @@ -41,7 +42,7 @@ static LogicalResult checkConstantTypes(mlir::Operation *op, mlir::Type opType,
if (mlir::isa<P4HIR::IntAttr, P4HIR::BoolAttr>(attrType)) return success();

if (mlir::isa<P4HIR::AggAttr>(attrType)) {
if (!mlir::isa<P4HIR::StructType, P4HIR::HeaderType>(opType))
if (!mlir::isa<P4HIR::StructType, P4HIR::HeaderType, mlir::TupleType>(opType))
return op->emitOpError("result type (") << opType << ") is not an aggregate type";

return success();
Expand Down Expand Up @@ -887,6 +888,109 @@ void P4HIR::StructExtractRefOp::build(OpBuilder &builder, OperationState &odsSta
build(builder, odsState, ReferenceType::get(fieldType), input, *fieldIndex);
}

//===----------------------------------------------------------------------===//
// StructOp
//===----------------------------------------------------------------------===//

ParseResult P4HIR::TupleOp::parse(OpAsmParser &parser, OperationState &result) {
llvm::SMLoc inputOperandsLoc = parser.getCurrentLocation();
llvm::SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
Type declType;

if (parser.parseLParen() || parser.parseOperandList(operands) || parser.parseRParen() ||
parser.parseOptionalAttrDict(result.attributes) || parser.parseColonType(declType))
return failure();

auto tupleType = mlir::dyn_cast<mlir::TupleType>(declType);
if (!tupleType) return parser.emitError(parser.getNameLoc(), "expected !tuple type");

result.addTypes(tupleType);
if (parser.resolveOperands(operands, tupleType.getTypes(), inputOperandsLoc, result.operands))
return failure();
return success();
}

void P4HIR::TupleOp::print(OpAsmPrinter &printer) {
printer << " (";
printer.printOperands(getInput());
printer << ")";
printer.printOptionalAttrDict((*this)->getAttrs());
printer << " : " << getType();
}

LogicalResult P4HIR::TupleOp::verify() {
auto elementTypes = getType().getTypes();

if (elementTypes.size() != getInput().size()) return emitOpError("tuple field count mismatch");

for (const auto &[field, value] : llvm::zip(elementTypes, getInput()))
if (field != value.getType()) return emitOpError("tuple field types do not match");

return success();
}

void P4HIR::TupleOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
setNameFn(getResult(), "tuple");
}

// TODO: This duplicates lots of things above for structs. Find a way to generalize
LogicalResult P4HIR::TupleExtractOp::verify() {
auto index = getFieldIndex();
auto fields = getInput().getType();
if (index >= fields.size())
return emitOpError() << "field index " << index
<< " exceeds element count of aggregate type";

if (getType() != fields.getType(index))
return emitOpError() << "type " << fields.getType(index)
<< " of accessed field in aggregate at index " << index
<< " does not match expected type " << getType();

return success();
}

ParseResult P4HIR::TupleExtractOp::parse(OpAsmParser &parser, OperationState &result) {
OpAsmParser::UnresolvedOperand operand;
unsigned fieldIndex = -1U;
mlir::TupleType declType;

if (parser.parseOperand(operand) || parser.parseLSquare() || parser.parseInteger(fieldIndex) ||
parser.parseRSquare() || parser.parseOptionalAttrDict(result.attributes) ||
parser.parseColon() || parser.parseType<mlir::TupleType>(declType))
return failure();

auto indexAttr = IntegerAttr::get(IntegerType::get(parser.getContext(), 32), fieldIndex);
result.addAttribute("fieldIndex", indexAttr);
Type resultType = declType.getType(fieldIndex);
result.addTypes(resultType);

if (parser.resolveOperand(operand, declType, result.operands)) return failure();
return success();
}

void P4HIR::TupleExtractOp::print(OpAsmPrinter &printer) {
printer << " ";
printer.printOperand(getInput());
printer << "[" << getFieldIndex() << "]";
printer.printOptionalAttrDict((*this)->getAttrs(), {"fieldIndex"});
printer << " : ";
printer << getInput().getType();
}

void P4HIR::TupleExtractOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
llvm::SmallString<16> name;
llvm::raw_svector_ostream specialName(name);
specialName << 't' << getFieldIndex();

setNameFn(getResult(), name);
}

void P4HIR::TupleExtractOp::build(OpBuilder &builder, OperationState &odsState, Value input,
unsigned fieldIndex) {
auto tupleType = mlir::cast<mlir::TupleType>(input.getType());
build(builder, odsState, tupleType.getType(fieldIndex), input, fieldIndex);
}

namespace {
struct P4HIROpAsmDialectInterface : public OpAsmDialectInterface {
using OpAsmDialectInterface::OpAsmDialectInterface;
Expand Down
39 changes: 39 additions & 0 deletions test/Dialect/P4HIR/tuple.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
// RUN: p4mlir-opt %s | FileCheck %s

!b16i = !p4hir.bit<16>
!b32i = !p4hir.bit<32>
#int10_b32i = #p4hir.int<10> : !b32i
#int0_b32i = #p4hir.int<0> : !b32i
#int1_b32i = #p4hir.int<1> : !b32i
#false = #p4hir.bool<false> : !p4hir.bool
#int12_b16i = #p4hir.int<12> : !b16i

// CHECK: module
module {
%t = p4hir.const ["t"] #p4hir.aggregate<[#int0_b32i, #int1_b32i]> : tuple<!b32i, !b32i>

p4hir.func action @test(%arg0: !p4hir.ref<!b16i> {p4hir.dir = #p4hir<dir out>}) {
%c10_b32i = p4hir.const #int10_b32i
%c12_b16i = p4hir.const #int12_b16i
%tuple = p4hir.tuple (%c10_b32i, %c12_b16i) : tuple<!b32i, !b16i>
%x_0 = p4hir.variable ["x", init] : <tuple<!b32i, !b16i>>
p4hir.assign %tuple, %x_0 : <tuple<!b32i, !b16i>>
%val_4 = p4hir.read %x_0 : <tuple<!b32i, !b16i>>
%t1 = p4hir.tuple_extract %val_4[1] : tuple<!b32i, !b16i>
p4hir.assign %t1, %arg0 : <!b16i>

p4hir.return
}

p4hir.func action @test2() {
%c10_b32i = p4hir.const #int10_b32i
%false = p4hir.const #false
%tuple = p4hir.tuple (%c10_b32i, %false) : tuple<!b32i, !p4hir.bool>
%x_0 = p4hir.variable ["x", init] : <tuple<!b32i, !p4hir.bool>>
p4hir.assign %tuple, %x_0 : <tuple<!b32i, !p4hir.bool>>
%y = p4hir.variable ["y"] : <tuple<!b32i, !p4hir.bool>>
%val = p4hir.read %x_0 : <tuple<!b32i, !p4hir.bool>>
p4hir.assign %val, %y : <tuple<!b32i, !p4hir.bool>>
p4hir.return
}
}
2 changes: 2 additions & 0 deletions test/Dialect/P4HIR/types.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
#valid = #p4hir<validity.bit valid>
#invalid = #p4hir<validity.bit invalid>

!tuple = tuple<!bit42, !void, !SuitsSerializable>

// No need to check stuff. If it parses, it's fine.
// CHECK: module
module {
Expand Down
37 changes: 37 additions & 0 deletions test/Translate/Ops/tuple.p4
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
// RUN: p4mlir-translate --typeinference-only %s | FileCheck %s

struct S {
bit<32> f;
bool s;
}

const S x = { 42, false };

const tuple<bit<32>, bit<32>> t = { 0, 1 };
const bit<32> f = t[0];

// CHECK: p4hir.const ["t"] #p4hir.aggregate<[#int0_b32i, #int1_b32i]> : tuple<!b32i, !b32i>
// CHECK: p4hir.const ["f"] #int0_b32i

// CHECK-LABEL: p4hir.func action @test
action test(out bit<16> r) {
// CHECK: %[[c10_b32i:.*]] = p4hir.const #int10_b32i
// CHECK: %[[c12_b16i:.*]] = p4hir.const #int12_b16i
// CHECK: %[[tuple:.*]] = p4hir.tuple (%[[c10_b32i]], %[[c12_b16i]]) : tuple<!b32i, !b16i>
tuple<bit<32>, bit<16>> x = { 10, 12 };
// CHECK: %[[x_0:.*]] = p4hir.variable ["x", init] : <tuple<!b32i, !b16i>>
// CHECK: p4hir.if
// CHECK: %[[val_4:.*]] = p4hir.read %[[x_0]] : <tuple<!b32i, !b16i>>
// CHECK: p4hir.tuple_extract %[[val_4]][1] : tuple<!b32i, !b16i>
if (x == { 10, 12 })
r = x[1];
else
r = (bit<16>)x[0];
}

typedef tuple<bit<32>, bool> pair;
action test2() {
pair x = { 10, false };
tuple<bit<32>, bool> y;
y = x;
}
Loading

0 comments on commit c1bf688

Please sign in to comment.