Pass to wrap stablehlo ops in composite
sdasgup3 committed Feb 27, 2025
1 parent 03597b1 commit 21d6e25
Showing 7 changed files with 404 additions and 0 deletions.
1 change: 1 addition & 0 deletions BUILD.bazel
@@ -1239,6 +1239,7 @@ cc_library(
"stablehlo/transforms/StablehloLegalizeToVhlo.cpp",
"stablehlo/transforms/StablehloRefineArguments.cpp",
"stablehlo/transforms/StablehloRefineShapes.cpp",
"stablehlo/transforms/StablehloWrapInComposite.cpp",
"stablehlo/transforms/VhloLegalizeToStablehlo.cpp",
"stablehlo/transforms/VhloToVersion.cpp",
],
53 changes: 53 additions & 0 deletions docs/generated/stablehlo_passes.md
@@ -338,6 +338,59 @@ Modules valid for shape refinement must have the following properties:
* All calls to a single function resolve to the same argument shapes, and no
recursive / co-recursive function calls are made.

### `-stablehlo-wrap-in-composite`

_Wraps a non-composite StableHLO op in a composite op._

Wraps each StableHLO op whose name is listed in the pass option in a
composite op. The composite op inherits all attributes of the original op,
and the original op is moved into the composite's decomposition function.

For example, using the pass option `--stablehlo-wrap-in-composite=op-names='stablehlo.add'`,

```mlir
func.func @add(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
func.return %0 : tensor<2xf32>
}
```

will become:

```mlir
func.func @add(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
%0 = stablehlo.composite "stablehlo.add" %arg0, %arg1 {
    decomposition = @stablehlo.add.impl
} : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
return %0 : tensor<2xf32>
}

func.func private @stablehlo.add.impl(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
func.return %0 : tensor<2xf32>
}
```

The pass is also exposed as a C++ API, `createStablehloWrapInCompositePass`,
which allows more flexible selection of the ops to wrap via a predicate.
For example, the following wraps every non-composite `stablehlo.add` and
`stablehlo.convolution` op:

```c++
auto pass = createStablehloWrapInCompositePass(
    [](Operation *op) {
      return (op->getName().getStringRef() == "stablehlo.add" ||
              op->getName().getStringRef() == "stablehlo.convolution") &&
             !isa<stablehlo::CompositeOp>(op);
    });
```
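
A minimal end-to-end sketch of driving the pass from C++ is shown below. It
assumes the pass is declared in namespace `mlir::stablehlo` (as in
`stablehlo/transforms/Passes.h`); the `wrapAllAdds` helper and its
`stablehlo.add`-only predicate are illustrative, not part of the API:

```c++
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"
#include "stablehlo/transforms/Passes.h"

// Wraps every stablehlo.add op in `module` in a stablehlo.composite
// (illustrative helper; the predicate can be any callable).
void wrapAllAdds(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext());
  pm.addPass(mlir::stablehlo::createStablehloWrapInCompositePass(
      [](mlir::Operation *op) {
        return op->getName().getStringRef() == "stablehlo.add";
      }));
  if (mlir::failed(pm.run(module)))
    module.emitError("stablehlo-wrap-in-composite failed");
}
```
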
#### Options
```
-op-names : The names of the ops to wrap.
```
### `-vhlo-legalize-to-stablehlo`
_Legalize VHLO to StableHLO._
68 changes: 68 additions & 0 deletions stablehlo/tests/transforms/stablehlo_wrap_in_composite.mlir
@@ -0,0 +1,68 @@
// RUN: stablehlo-opt --stablehlo-wrap-in-composite=op-names='stablehlo.add,stablehlo.convolution' --split-input-file --verify-diagnostics %s | FileCheck %s

// CHECK-LABEL: func.func @wrap_in_composite
// CHECK-SAME: %[[ARG_0:.*]]: tensor<64x8x8x8xi8>,
// CHECK-SAME: %[[ARG_1:.*]]: tensor<4x4x8x32xi8>,
// CHECK-SAME: %[[ARG_2:.*]]: tensor<64x3x3x32xi32>) -> tensor<64x3x3x32xi32> {
// CHECK: %[[CONV:.*]] = stablehlo.composite "stablehlo.convolution" %[[ARG_0]], %[[ARG_1]] {
// CHECK-SAME: composite_attributes = {batch_group_count = 1 : i64,
// CHECK-SAME: dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>,
// CHECK-SAME: feature_group_count = 1 : i64,
// CHECK-SAME{LITERAL}: padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>,
// CHECK-SAME{LITERAL}: rhs_dilation = array<i64: 2, 2>,
// CHECK-SAME{LITERAL}: window_strides = array<i64: 1, 1>},
// CHECK-SAME: decomposition = @stablehlo.convolution.impl} : (tensor<64x8x8x8xi8>, tensor<4x4x8x32xi8>) -> tensor<64x3x3x32xi32>
// CHECK: %[[ADD:.*]] = stablehlo.composite "stablehlo.add" %[[CONV]], %[[ARG_2]] {decomposition = @stablehlo.add.impl} : (tensor<64x3x3x32xi32>, tensor<64x3x3x32xi32>) -> tensor<64x3x3x32xi32>
// CHECK-NEXT: %[[ADD1:.*]] = stablehlo.composite "stablehlo.add" %[[ADD]], %[[ADD]] {decomposition = @stablehlo.add.impl1} : (tensor<64x3x3x32xi32>, tensor<64x3x3x32xi32>) -> tensor<64x3x3x32xi32>
// CHECK-NEXT: return %[[ADD1]]

// CHECK-LABEL: func.func private @stablehlo.convolution.impl
// CHECK-SAME: %[[ARG_0:.*]]: tensor<64x8x8x8xi8>,
// CHECK-SAME: %[[ARG_1:.*]]: tensor<4x4x8x32xi8>) -> tensor<64x3x3x32xi32> {
// CHECK: %[[VAL:.*]] = stablehlo.convolution(%[[ARG_0]], %[[ARG_1]])
// CHECK-SAME{LITERAL}: dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
// CHECK-SAME{LITERAL}: stride = [1, 1],
// CHECK-SAME{LITERAL}: pad = [[0, 1], [0, 1]],
// CHECK-SAME{LITERAL}: rhs_dilate = [2, 2]}
// CHECK-SAME: batch_group_count = 1 : i64
// CHECK-SAME: feature_group_count = 1 : i64
// CHECK-SAME: (tensor<64x8x8x8xi8>, tensor<4x4x8x32xi8>) -> tensor<64x3x3x32xi32>
// CHECK-NEXT: return %[[VAL]]

// CHECK-LABEL: func.func private @stablehlo.add.impl
// CHECK-SAME: %[[ARG_0:.*]]: tensor<64x3x3x32xi32>,
// CHECK-SAME: %[[ARG_1:.*]]: tensor<64x3x3x32xi32>) -> tensor<64x3x3x32xi32> {
// CHECK: %[[VAL:.*]] = stablehlo.add %[[ARG_0]], %[[ARG_1]] : tensor<64x3x3x32xi32>
// CHECK-NEXT: return %[[VAL]]

// CHECK-LABEL: func.func private @stablehlo.add.impl1
// CHECK-SAME: %[[ARG_0:.*]]: tensor<64x3x3x32xi32>,
// CHECK-SAME: %[[ARG_1:.*]]: tensor<64x3x3x32xi32>) -> tensor<64x3x3x32xi32> {
// CHECK: %[[VAL:.*]] = stablehlo.add %[[ARG_1]], %[[ARG_1]] : tensor<64x3x3x32xi32>
// CHECK-NEXT: return %[[VAL]]

func.func @wrap_in_composite(
%arg0: tensor<64x8x8x8xi8>,
%arg1: tensor<4x4x8x32xi8>,
%arg2: tensor<64x3x3x32xi32>) -> tensor<64x3x3x32xi32> {
%0 = stablehlo.convolution(%arg0, %arg1)
dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
window = {stride = [1, 1], pad = [[0, 1], [0, 1]], rhs_dilate = [2, 2]}
{batch_group_count = 1 : i64, feature_group_count = 1 : i64} :
(tensor<64x8x8x8xi8>, tensor<4x4x8x32xi8>) -> tensor<64x3x3x32xi32>
%1 = stablehlo.add %0, %arg2 : tensor<64x3x3x32xi32>
%2 = stablehlo.add %1, %1 : tensor<64x3x3x32xi32>
func.return %2 : tensor<64x3x3x32xi32>
}

// -----

// CHECK-LABEL: func.func @cannot_be_wrapped_ops_does_not_match
// CHECK-SAME: %[[ARG_0:.*]]: tensor<2xf32>,
// CHECK-SAME: %[[ARG_1:.*]]: tensor<2xf32>) -> tensor<2xf32> {
// CHECK: %[[VAL:.*]] = stablehlo.multiply %[[ARG_0]], %[[ARG_1]] : tensor<2xf32>
// CHECK-NEXT: return %[[VAL]] : tensor<2xf32>
func.func @cannot_be_wrapped_ops_does_not_match(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
%0 = stablehlo.multiply %arg0, %arg1 : tensor<2xf32>
func.return %0 : tensor<2xf32>
}
1 change: 1 addition & 0 deletions stablehlo/transforms/CMakeLists.txt
@@ -57,6 +57,7 @@ add_mlir_dialect_library(StablehloPasses
StablehloLegalizeToVhlo.cpp
StablehloRefineArguments.cpp
StablehloRefineShapes.cpp
StablehloWrapInComposite.cpp
VhloLegalizeToStablehlo.cpp
VhloToVersion.cpp
PassUtils.cpp
9 changes: 9 additions & 0 deletions stablehlo/transforms/Passes.h
@@ -16,6 +16,7 @@ limitations under the License.
#ifndef STABLEHLO_TRANSFORMS_PASSES_H
#define STABLEHLO_TRANSFORMS_PASSES_H

#include <functional>
#include <memory>

#include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -102,6 +103,14 @@ void populateStablehloCompatibilityExpanderPatterns(
std::unique_ptr<OperationPass<ModuleOp>> createStablehloRefineArgumentsPass(
TypeRange refinedTypes);

/// Creates a pass that wraps StableHLO ops in CompositeOp.
///
/// The pass wraps each op for which `opPredicate` returns true in a
/// CompositeOp.
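///
/// A minimal usage sketch (illustrative; the stablehlo.add-only predicate is
/// an example, not part of the API):
///
///   auto pass = createStablehloWrapInCompositePass(
///       [](Operation *op) {
///         return op->getName().getStringRef() == "stablehlo.add";
///       });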
std::unique_ptr<OperationPass<ModuleOp>> createStablehloWrapInCompositePass(
std::function<bool(Operation *)> opPredicate);

//// Pass pipelines ////

// StableHLO consumers can add this pipeline to convert portable artifacts to
56 changes: 56 additions & 0 deletions stablehlo/transforms/Passes.td
@@ -409,3 +409,59 @@ def VhloToVersionPass : Pass<"vhlo-to-version"> {
];
let dependentDialects = ["mlir::vhlo::VhloDialect"];
}

def StablehloWrapInCompositePass : Pass<"stablehlo-wrap-in-composite", "ModuleOp"> {
let summary = "Wraps a non-composite StableHLO op in a composite op.";
let description = [{
Wraps each StableHLO op whose name is listed in the pass option in a
composite op. The composite op inherits all attributes of the original op,
and the original op is moved into the composite's decomposition function.

For example, using the pass option `--stablehlo-wrap-in-composite=op-names='stablehlo.add'`,

```mlir
func.func @add(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
func.return %0 : tensor<2xf32>
}
```

will become:

```mlir
func.func @add(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
%0 = stablehlo.composite "stablehlo.add" %arg0, %arg1 {
    decomposition = @stablehlo.add.impl
} : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
return %0 : tensor<2xf32>
}

func.func private @stablehlo.add.impl(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
func.return %0 : tensor<2xf32>
}
```

The pass is also exposed as a C++ API, `createStablehloWrapInCompositePass`,
which allows more flexible selection of the ops to wrap via a predicate.
For example, the following wraps every non-composite `stablehlo.add` and
`stablehlo.convolution` op:

```c++
auto pass = createStablehloWrapInCompositePass(
    [](Operation *op) {
      return (op->getName().getStringRef() == "stablehlo.add" ||
              op->getName().getStringRef() == "stablehlo.convolution") &&
             !isa<stablehlo::CompositeOp>(op);
    });
```
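
A minimal end-to-end sketch of driving the pass from C++ (assuming the pass
is declared in namespace `mlir::stablehlo`; the `wrapAllAdds` helper and its
`stablehlo.add`-only predicate are illustrative, not part of the API):

```c++
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"
#include "stablehlo/transforms/Passes.h"

// Wraps every stablehlo.add op in `module` in a stablehlo.composite
// (illustrative helper; the predicate can be any callable).
void wrapAllAdds(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext());
  pm.addPass(mlir::stablehlo::createStablehloWrapInCompositePass(
      [](mlir::Operation *op) {
        return op->getName().getStringRef() == "stablehlo.add";
      }));
  if (mlir::failed(pm.run(module)))
    module.emitError("stablehlo-wrap-in-composite failed");
}
```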
}];
let options = [
ListOption<"opNamesOption", "op-names", "std::string",
"The names of the ops to wrap.">,
];
let dependentDialects = [
"mlir::func::FuncDialect",
"mlir::stablehlo::StablehloDialect",
];
}