Skip to content

Commit

Permalink
[XLA:GPU] Fix post-optimization pipeline parallelism tests
Browse files Browse the repository at this point in the history
The test does not run HLO passes.
In particular, copy insertion does not run on this input.
This means the input must guarantee non-conflicting live ranges of all buffers.
The new copies and control dependencies enforce this guarantee.
With HLO passes enabled, this would be enforced by the copy insertion pass.

PiperOrigin-RevId: 730568729
  • Loading branch information
frgossen authored and Google-ML-Automation committed Feb 27, 2025
1 parent df40c9e commit 03c4f79
Showing 1 changed file with 37 additions and 5 deletions.
42 changes: 37 additions & 5 deletions xla/tests/collective_pipeline_parallelism_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ class CollectivePipelineParallelismTest
DebugOptions debug_options = GetDebugOptionsForTest();
debug_options.set_xla_gpu_experimental_pipeline_parallelism_opt_level(
xla_gpu_experimental_pipeline_parallelism_opt_level_);
debug_options.set_xla_gpu_enable_latency_hiding_scheduler(true);
debug_options.set_xla_gpu_collective_permute_decomposer_threshold(0);
config.set_debug_options(debug_options);

return config;
Expand Down Expand Up @@ -113,6 +115,12 @@ XLA_TEST_P(CollectivePipelineParallelismTest,
<< test_runner().device_count() << " available)";
}

// TODO(b/398888176): Remove this skip once cycle decomposer is removed.
if (xla_gpu_experimental_pipeline_parallelism_opt_level_ ==
DebugOptions::PIPELINE_PARALLELISM_OPT_LEVEL_ENABLE_CYCLE_DECOMPOSER) {
GTEST_SKIP();
}

// Parse HLO module.
HloModuleConfig config = GetModuleConfigForTest(
/*replica_count=*/kNumReplicas, /*num_partitions=*/kNumPartitions);
Expand Down Expand Up @@ -309,6 +317,12 @@ XLA_TEST_P(CollectivePipelineParallelismTest, NaiveBFSMicrobatch4Replica4) {
<< test_runner().device_count() << " available)";
}

// TODO(b/398888176): Remove this skip once cycle decomposer is removed.
if (xla_gpu_experimental_pipeline_parallelism_opt_level_ ==
DebugOptions::PIPELINE_PARALLELISM_OPT_LEVEL_ENABLE_CYCLE_DECOMPOSER) {
GTEST_SKIP();
}

// Parse HLO module.
HloModuleConfig config = GetModuleConfigForTest(
/*replica_count=*/kNumReplicas, /*num_partitions=*/kNumPartitions);
Expand All @@ -333,6 +347,7 @@ XLA_TEST_P(CollectivePipelineParallelismTest, NaiveBFSMicrobatch4Replica4) {
const int64_t kMicrobatches = 4;
Literal real_input =
LiteralUtil::CreateFingerprintMatixR2<float>(kMicrobatches, kInputSize);

Literal fake_input =
LiteralUtil::CreateFull<float>({kMicrobatches, kInputSize}, 0.0);

Expand Down Expand Up @@ -432,6 +447,12 @@ XLA_TEST_P(CollectivePipelineParallelismTest, NaiveBFSMicrobatch5Replica4) {
<< test_runner().device_count() << " available)";
}

// TODO(b/398888176): Remove this skip once cycle decomposer is removed.
if (xla_gpu_experimental_pipeline_parallelism_opt_level_ ==
DebugOptions::PIPELINE_PARALLELISM_OPT_LEVEL_ENABLE_CYCLE_DECOMPOSER) {
GTEST_SKIP();
}

// Parse HLO module.
HloModuleConfig config = GetModuleConfigForTest(
/*replica_count=*/kNumReplicas, /*num_partitions=*/kNumPartitions);
Expand Down Expand Up @@ -555,6 +576,12 @@ XLA_TEST_P(CollectivePipelineParallelismTest,
<< test_runner().device_count() << " available)";
}

// TODO(b/398888176): Remove this skip once cycle decomposer is removed.
if (xla_gpu_experimental_pipeline_parallelism_opt_level_ ==
DebugOptions::PIPELINE_PARALLELISM_OPT_LEVEL_ENABLE_CYCLE_DECOMPOSER) {
GTEST_SKIP();
}

// Parse HLO module.
HloModuleConfig config = GetModuleConfigForTest(
/*replica_count=*/kNumReplicas, /*num_partitions=*/kNumPartitions);
Expand Down Expand Up @@ -903,7 +930,8 @@ XLA_TEST_P(CollectivePipelineParallelismTest, SendRecvLoop) {
// Send data from GPU i to i+1. Break cycle to avoid deadlock.
after_all = token[] after-all()
send_ctx = (f32[2,2], u32[], token[]) send(data, after_all),
data_cpy = f32[2,2] copy(data)
send_ctx = (f32[2,2], u32[], token[]) send(data_cpy, after_all),
frontend_attributes={
_xla_send_recv_source_target_pairs={{0,1},{1,2},{2,3}}}, channel_id=1
recv_ctx = (f32[2,2], u32[], token[]) recv(after_all),
Expand Down Expand Up @@ -967,6 +995,7 @@ XLA_TEST_P(CollectivePipelineParallelismTest, SendRecvLoop) {
ExecuteReplicated(std::move(module), inputs,
/*num_replicas=*/kNumPartitions,
/*run_hlo_passes=*/false, &device_assignment));

LiteralTestUtil::ExpectR2Equal<float>({{0, 0}, {0, 0}}, results[0]);
LiteralTestUtil::ExpectR2Equal<float>({{0, 0}, {0, 0}}, results[1]);
LiteralTestUtil::ExpectR2Equal<float>({{0, 0}, {0, 0}}, results[2]);
Expand Down Expand Up @@ -1083,12 +1112,14 @@ XLA_TEST_P(CollectivePipelineParallelismTest,
recv_done = (f32[2,2], token[]) recv-done(recv_ctx), channel_id=2
data = get-tuple-element(recv_done), index=0
after_all = token[] after-all()
send_ctx_ = (f32[2,2], u32[], token[]) send(data, after_all),
data_cpy = f32[2,2] copy(data)
send_ctx_ = (f32[2,2], u32[], token[]) send(data_cpy, after_all),
frontend_attributes={
_xla_send_recv_source_target_pairs={{0,1},{1,2},{2,3}}}, channel_id=1
recv_ctx_ = (f32[2,2], u32[], token[]) recv(after_all),
frontend_attributes={
_xla_send_recv_source_target_pairs={{0,1},{1,2},{2,3}}}, channel_id=2
_xla_send_recv_source_target_pairs={{0,1},{1,2},{2,3}}}, channel_id=2,
control-predecessors={data_cpy}
c1 = u32[] constant(1)
i_ = u32[] add(i, c1)
ROOT result = (u32[], (f32[2,2], u32[], token[]),
Expand Down Expand Up @@ -1156,6 +1187,7 @@ XLA_TEST_P(CollectivePipelineParallelismTest,
ExecuteReplicated(std::move(module), inputs,
/*num_replicas=*/kNumPartitions,
/*run_hlo_passes=*/false, &device_assignment));

LiteralTestUtil::ExpectR2Equal<float>({{0, 0}, {0, 0}}, results[0]);
LiteralTestUtil::ExpectR2Equal<float>({{0, 0}, {0, 0}}, results[1]);
LiteralTestUtil::ExpectR2Equal<float>({{0, 0}, {0, 0}}, results[2]);
Expand Down Expand Up @@ -1186,10 +1218,10 @@ XLA_TEST_P(CollectivePipelineParallelismTest,
data = get-tuple-element(recv_done), index=0
after_all = token[] after-all()
send_ctx_ = (f32[2,2], u32[], token[]) send(data, after_all),
frontend_attributes={_xla_send_recv_source_target_pairs={{0,1}}},
frontend_attributes={_xla_send_recv_source_target_pairs={{0,1}}},
channel_id=1
recv_ctx_ = (f32[2,2], u32[], token[]) recv(after_all),
frontend_attributes={_xla_send_recv_source_target_pairs={{0,1}}},
frontend_attributes={_xla_send_recv_source_target_pairs={{0,1}}},
channel_id=2
c1 = u32[] constant(1)
i_ = u32[] add(i, c1)
Expand Down

0 comments on commit 03c4f79

Please sign in to comment.