Skip to content

Commit

Permalink
[XLA:GPU] Remove unused should_process callback in pipeliner
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 729604132
  • Loading branch information
frgossen authored and Google-ML-Automation committed Feb 27, 2025
1 parent a32279e commit 6019eb6
Show file tree
Hide file tree
Showing 8 changed files with 79 additions and 24 deletions.
39 changes: 28 additions & 11 deletions xla/service/collective_pipeliner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -745,7 +745,8 @@ absl::StatusOr<HloInstruction*> CloneBackwardChain(
}
clone_map[chain_op] = cloned;
if (postprocess_pipelined_ops.has_value()) {
TF_RETURN_IF_ERROR((*postprocess_pipelined_ops)(cloned));
TF_RETURN_IF_ERROR(
(*postprocess_pipelined_ops)(cloned, /*new_while_instr=*/nullptr));
}
last_cloned = cloned;
if (loop_variant_parameter_info != nullptr &&
Expand Down Expand Up @@ -1941,7 +1942,8 @@ absl::Status TransformLoopForward(
}

if (post_processing_fn.has_value()) {
TF_RETURN_IF_ERROR((*post_processing_fn)(processed));
TF_RETURN_IF_ERROR(
(*post_processing_fn)(processed, /*new_while_instr=*/nullptr));
}

InstructionMap cloned_map = pipelined_values_map;
Expand All @@ -1957,7 +1959,8 @@ absl::Status TransformLoopForward(
}
cloned_map[formatting_op] = processed;
if (post_processing_fn.has_value()) {
TF_RETURN_IF_ERROR((*post_processing_fn)(processed));
TF_RETURN_IF_ERROR(
(*post_processing_fn)(processed, /*new_while_instr=*/nullptr));
}
}
return processed;
Expand Down Expand Up @@ -2669,9 +2672,10 @@ absl::Status TransformLoopForwardSink(const WhileLoopAnalysis& loop_analysis,
static absl::Status TransformLoopBackward(
const WhileLoopAnalysis& loop_analysis, bool insert_non_alias_custom_call,
int64_t level_to_operate_on, bool process_different_sized_ops,
HloPredicate should_process, HloPredicate acceptable_formatting,
HloPredicate acceptable_formatting,
CollectivePipeliner::HloPostprocessor postprocess_peeled,
CollectivePipeliner::HloPostprocessor postprocess_rotated,
CollectivePipeliner::HloPostprocessor postprocess_peeled_trailing_op,
int64_t& next_channel_id,
CollectivePipeliner::HloPostprocessor post_processing_fn) {
// Defining some maps/sets to keep track of instructions duplicated.
Expand Down Expand Up @@ -2777,10 +2781,12 @@ static absl::Status TransformLoopBackward(
/*loop_variant_parameter_info=*/nullptr, post_processing_fn));

if (post_processing_fn.has_value()) {
TF_RETURN_IF_ERROR((*post_processing_fn)(new_init_operands[idx]));
TF_RETURN_IF_ERROR((*post_processing_fn)(new_init_operands[idx],
/*new_while_instr=*/nullptr));
}
if (postprocess_peeled.has_value()) {
TF_RETURN_IF_ERROR(postprocess_peeled.value()(new_init_operands[idx]));
TF_RETURN_IF_ERROR(postprocess_peeled.value()(
new_init_operands[idx], /*new_while_instr=*/nullptr));
}
}
ConstantValue next_loop_iteration =
Expand Down Expand Up @@ -2835,10 +2841,12 @@ static absl::Status TransformLoopBackward(
post_processing_fn));

if (post_processing_fn.has_value()) {
TF_RETURN_IF_ERROR((*post_processing_fn)(cloned_instr));
TF_RETURN_IF_ERROR(
(*post_processing_fn)(cloned_instr, /*new_while_instr=*/nullptr));
}
if (postprocess_rotated.has_value()) {
TF_RETURN_IF_ERROR(postprocess_rotated.value()(cloned_instr));
TF_RETURN_IF_ERROR(postprocess_rotated.value()(
cloned_instr, /*new_while_instr=*/nullptr));
}
} else {
auto new_operands =
Expand Down Expand Up @@ -2972,6 +2980,13 @@ static absl::Status TransformLoopBackward(
MapNewOperands(instr->operands(), while_body_replacement_map);
HloInstruction* cloned_instr = while_loop->parent()->AddInstruction(
instr->CloneWithNewOperands(instr->shape(), new_operands));

if (postprocess_peeled_trailing_op.has_value()) {
CHECK_NE(new_while_loop, nullptr);
TF_RETURN_IF_ERROR(
postprocess_peeled_trailing_op.value()(cloned_instr, new_while_loop));
}

TF_RETURN_IF_ERROR(UpdateControlDependencies(instr, cloned_instr,
while_body_replacement_map));
UpdateInstructionChannelId(cloned_instr, next_channel_id);
Expand All @@ -2998,6 +3013,7 @@ static absl::Status TransformLoopBackward(
TF_RETURN_IF_ERROR(loop_computation->parent()->RemoveUnusedComputations());
return absl::OkStatus();
}

bool IsForwardSinkIterationFeasible(HloInstruction* while_inst,
int64_t collective_size_threshold) {
for (HloInstruction* inst :
Expand Down Expand Up @@ -3103,9 +3119,10 @@ absl::StatusOr<bool> CollectivePipeliner::RunPipeliner(
CHECK_EQ(config_.pipelining_direction, PipeliningDirection::kBackward);
TF_RETURN_IF_ERROR(TransformLoopBackward(
*loop_analysis, !config_.last_run, config_.level_to_operate_on,
config_.process_different_sized_ops, config_.should_process,
config_.acceptable_formatting, config_.postprocess_backward_peeled_op,
config_.postprocess_backward_rotated_op, next_channel_id,
config_.process_different_sized_ops, config_.acceptable_formatting,
config_.postprocess_backward_peeled_op,
config_.postprocess_backward_rotated_op,
config_.postprocess_backward_peeled_trailing_op, next_channel_id,
config_.postprocess_pipelined_ops));
}
++transformed_loops;
Expand Down
13 changes: 9 additions & 4 deletions xla/service/collective_pipeliner.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ limitations under the License.
#include "absl/container/flat_hash_set.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/hlo/pass/hlo_pass_interface.h"

Expand Down Expand Up @@ -65,10 +66,12 @@ class CollectivePipeliner : public HloModulePass {
kForwardSink,
};

// Postprocessing cloned collective instructions, such as for modifying loop
// iteration related frontend attributes to reflect loop pipelining.
using HloPostprocessor =
std::optional<std::function<absl::Status(HloInstruction* instr)>>;
// Postprocessing cloned collective instructions, such as peeled instructions
// before and after the loop, and rotated instructions. The new while op is
// only passed for the peeled trailing ops when the new while op was already
// created.
using HloPostprocessor = std::optional<std::function<absl::Status(
HloInstruction* instr, HloInstruction* new_while_instr)>>;

struct Config {
int64_t level_to_operate_on = 0;
Expand Down Expand Up @@ -100,8 +103,10 @@ class CollectivePipeliner : public HloModulePass {
// pipelined. The control dependencies will be dropped when the operation is
// pipelined. This is currently only used to support kBackward pipelining.
bool should_allow_control_dependencies = false;
// TODO(b/399476667): Consolidate these postprocessing functions.
HloPostprocessor postprocess_backward_peeled_op = std::nullopt;
HloPostprocessor postprocess_backward_rotated_op = std::nullopt;
HloPostprocessor postprocess_backward_peeled_trailing_op = std::nullopt;
// Determines whether a loop invariant instruction can be considered
// in the pipelining chain.
bool should_add_loop_invariant_op_in_chain = false;
Expand Down
14 changes: 11 additions & 3 deletions xla/service/collective_pipeliner_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,8 @@ absl::StatusOr<bool> RunOptimizer(
std::nullopt,
CollectivePipeliner::HloPostprocessor postprocess_backward_rotated =
std::nullopt,
CollectivePipeliner::HloPostprocessor postprocess_backward_peeled_trailing =
std::nullopt,
bool should_add_loop_invariant_op_in_chain = false,
int64_t collective_size_threshold_to_stop_sinking = INT64_MAX) {
CollectivePipeliner::Config config = {
Expand All @@ -101,7 +103,8 @@ absl::StatusOr<bool> RunOptimizer(
/*reuse_pipelined_op_buffer=*/reuse_pipelined_op_buffer,
should_allow_loop_variant_parameter_in_chain,
/*should_allow_control_dependencies=*/false, postprocess_backward_peeled,
postprocess_backward_rotated, should_add_loop_invariant_op_in_chain,
postprocess_backward_rotated, postprocess_backward_peeled_trailing,
should_add_loop_invariant_op_in_chain,
/*postprocess_pipelined_ops=*/std::nullopt,
collective_size_threshold_to_stop_sinking};
HloPassPipeline pass("optimizer");
Expand Down Expand Up @@ -2790,13 +2793,15 @@ TEST_F(CollectivePipelinerTest,
};
const char* kAttr = "_xla_other_attr";
// Mutate an existing attribute.
auto postprocess_peeled = [&](HloInstruction* instr) {
auto postprocess_peeled = [&](HloInstruction* instr,
HloInstruction* new_while_instr) {
xla::FrontendAttributes attributes = instr->frontend_attributes();
(*attributes.mutable_map())[kAttr] = "1";
instr->set_frontend_attributes(attributes);
return absl::OkStatus();
};
auto postprocess_rotated = [&](HloInstruction* instr) {
auto postprocess_rotated = [&](HloInstruction* instr,
HloInstruction* new_while_instr) {
xla::FrontendAttributes attributes = instr->frontend_attributes();
(*attributes.mutable_map())[kAttr] = "2";
instr->set_frontend_attributes(attributes);
Expand Down Expand Up @@ -3172,6 +3177,7 @@ ENTRY entry {
/*should_allow_loop_variant_parameter_in_chain=*/HloPredicateTrue,
/*postprocess_backward_peeled=*/std::nullopt,
/*postprocess_backward_rotated=*/std::nullopt,
/*postprocess_backward_peeled_trailing=*/std::nullopt,
/*should_add_loop_invariant_op_in_chain=*/true)
.value());
XLA_VLOG_LINES(1, module->ToString());
Expand Down Expand Up @@ -3202,6 +3208,7 @@ ENTRY entry {
/*should_allow_loop_variant_parameter_in_chain=*/HloPredicateTrue,
/*postprocess_backward_peeled=*/std::nullopt,
/*postprocess_backward_rotated=*/std::nullopt,
/*postprocess_backward_peeled_trailing=*/std::nullopt,
/*should_add_loop_invariant_op_in_chain=*/false)
.value());
}
Expand Down Expand Up @@ -3593,6 +3600,7 @@ ENTRY entry {
/*should_allow_loop_variant_parameter_in_chain=*/HloPredicateFalse,
/*postprocess_backward_peeled=*/std::nullopt,
/*postprocess_backward_rotated=*/std::nullopt,
/*postprocess_backward_peeled_trailing=*/std::nullopt,
/*should_add_loop_invariant_op_in_chain=*/false,
/*collective_size_threshold_to_stop_sinking=*/1024)
.value());
Expand Down
3 changes: 3 additions & 0 deletions xla/service/gpu/gpu_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -882,6 +882,7 @@ absl::Status RunCollectiveOptimizationPasses(
/*should_allow_control_dependencies=*/false,
/*postprocess_backward_peeled_op=*/std::nullopt,
/*postprocess_backward_rotated_op=*/std::nullopt,
/*postprocess_backward_peeled_trailing_op=*/std::nullopt,
/*should_add_loop_invariant_op_in_chain=*/false,
/*postprocess_pipelined_ops=*/AppendPipelinedInstruction,
};
Expand All @@ -905,6 +906,7 @@ absl::Status RunCollectiveOptimizationPasses(
/*should_allow_control_dependencies=*/false,
/*postprocess_backward_peeled_op=*/std::nullopt,
/*postprocess_backward_rotated_op=*/std::nullopt,
/*postprocess_backward_peeled_trailing_op=*/std::nullopt,
/*should_add_loop_invariant_op_in_chain=*/true,
/*postprocess_pipelined_ops=*/AppendPipelinedInstruction,
};
Expand All @@ -928,6 +930,7 @@ absl::Status RunCollectiveOptimizationPasses(
/*should_allow_control_dependencies=*/false,
/*postprocess_backward_peeled_op=*/std::nullopt,
/*postprocess_backward_rotated_op=*/std::nullopt,
/*postprocess_backward_peeled_trailing_op=*/std::nullopt,
/*should_add_loop_invariant_op_in_chain=*/false,
/*postprocess_pipelined_ops=*/AppendPipelinedInstruction,
};
Expand Down
26 changes: 22 additions & 4 deletions xla/service/gpu/gpu_p2p_pipeliner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,12 @@ absl::Status PostprocessP2PImpl(

// Modifies the loop iteration frontend attribute for the peeled off Send and
// Recv for the first iteration of a loop.
absl::Status PostprocessPeeledP2P(HloInstruction* instr) {
absl::Status PostprocessPeeledP2P(HloInstruction* instr,
HloInstruction* new_while_instr) {
// We only use this to post-process the peeled send/recv before the new loop
// was created.
CHECK(new_while_instr == nullptr);

auto transform_bounds = [&](std::vector<ReplicaGroup>& replica_groups) {
std::vector<std::pair<int64_t, int64_t>> bounds;
bounds.reserve(replica_groups.size());
Expand Down Expand Up @@ -210,7 +215,12 @@ absl::Status PostprocessPeeledP2P(HloInstruction* instr) {

// Modifies the loop iteration frontend attribute for the rotated Send and Recv
// for the remaining iterations in a loop.
absl::Status PostprocessRotatedP2P(HloInstruction* instr) {
absl::Status PostprocessRotatedP2P(HloInstruction* instr,
HloInstruction* new_while_instr) {
// We only use this to post-process the peeled send/recv before the new loop
// was created.
CHECK(new_while_instr == nullptr);

auto transform_bounds = [&](std::vector<ReplicaGroup>& replica_groups) {
std::vector<std::pair<int64_t, int64_t>> bounds;
bounds.reserve(replica_groups.size());
Expand Down Expand Up @@ -471,11 +481,19 @@ absl::StatusOr<bool> GpuP2PPipeliner::Run(

if (enable_partial_send_recv_pipelining_) {
should_process = FullyPipelineRecv;
postprocess_backward_peeled_op = [&](HloInstruction* it) {
postprocess_backward_peeled_op = [&](HloInstruction* it,
HloInstruction* new_while_instr) {
// When post-processing non-trailing peeled send/recv, the new while loop
// was not yet created.
CHECK_EQ(new_while_instr, nullptr);
peeled_send_recvs.push_back(it);
return absl::OkStatus();
};
postprocess_backward_rotated_op = [&](HloInstruction* it) {
postprocess_backward_rotated_op = [&](HloInstruction* it,
HloInstruction* new_while_instr) {
// When post-processing non-trailing peeled send/recv, the new while loop
// was not yet created.
CHECK_EQ(new_while_instr, nullptr);
rotated_send_recvs.push_back(it);
return absl::OkStatus();
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,8 @@ int64_t ComputeSuggestedCombinerThreshold(
return MaxAvailableMemory(module, device_info) - peak_memory_bytes;
}

absl::Status AppendPipelinedInstruction(HloInstruction* instr) {
absl::Status AppendPipelinedInstruction(HloInstruction* instr,
HloInstruction* new_while_instr) {
if (!IsCollective(instr)) {
return absl::OkStatus();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ int64_t ComputeSuggestedCombinerThreshold(
// Adds information that `instr` has been pipelined to the
// `CollectiveBackendInfo`. It is up to the caller to decide when to invoke
// this.
absl::Status AppendPipelinedInstruction(HloInstruction* instr);
absl::Status AppendPipelinedInstruction(HloInstruction* instr,
HloInstruction* new_while_instr);

// Returns true if module contains any pipelined instruction. False otherwise.
bool ContainsPipelinedInstruction(const HloModule& module);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,7 @@ TEST_F(CollectiveCombinerUtilsTest,
/*should_allow_control_dependencies=*/false,
/*postprocess_backward_peeled_op=*/std::nullopt,
/*postprocess_backward_rotated_op=*/std::nullopt,
/*postprocess_backward_peeled_trailing_op=*/std::nullopt,
/*should_add_loop_invariant_op_in_chain=*/true,
};
config.postprocess_pipelined_ops = AppendPipelinedInstruction;
Expand Down Expand Up @@ -434,6 +435,7 @@ TEST_F(CollectiveCombinerUtilsTest,
/*should_allow_control_dependencies=*/false,
/*postprocess_backward_peeled_op=*/std::nullopt,
/*postprocess_backward_rotated_op=*/std::nullopt,
/*postprocess_backward_peeled_trailing_op=*/std::nullopt,
/*should_add_loop_invariant_op_in_chain=*/true,
};
config.postprocess_pipelined_ops = AppendPipelinedInstruction;
Expand Down

0 comments on commit 6019eb6

Please sign in to comment.