From 6ddab483ca848ac1749ecc0135c0b426fad16de9 Mon Sep 17 00:00:00 2001 From: Adrian Kuegel Date: Wed, 26 Feb 2025 05:01:30 -0800 Subject: [PATCH] Fix variable type and add comments for PriorityFusion (NFC). PiperOrigin-RevId: 731270233 --- xla/service/gpu/transforms/priority_fusion.cc | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/xla/service/gpu/transforms/priority_fusion.cc b/xla/service/gpu/transforms/priority_fusion.cc index c9a2f12d244cf..20c23d57c77d3 100644 --- a/xla/service/gpu/transforms/priority_fusion.cc +++ b/xla/service/gpu/transforms/priority_fusion.cc @@ -507,6 +507,8 @@ class PriorityFusionQueue { is_incremental_update ? operands_to_new_consumers_.find(producer)->second : absl::MakeConstSpan(producer->users()); + // Note that `gpu_performance_model_cache_` may contain a runtime estimate + // from the Triton cost model. GpuPerformanceModel::RunTimes run_times = GpuPerformanceModel::EstimateRunTimes( producer, *device_info_, &cost_analysis_, @@ -635,6 +637,8 @@ class PriorityFusionQueue { TiledRunTimeData tiled_run_time_data = std::get<TiledRunTimeData>(std::move(tiled_run_time_data_or_error)); + // This is our way to pass the runtime estimate to the CalculatePriorities() + // function. gpu_performance_model_cache_.Set( *producer, *consumer, tiled_run_time_data.runtime_data.exec_time); @@ -983,7 +987,7 @@ absl::StatusOr<bool> PriorityFusion::Run( FusionDeduplicationCache fusion_deduplication_cache = FusionDeduplicationCache::Create(*module, IsFusible); - int changed = false; + bool changed = false; for (auto* computation : fusible_computations) { CHECK(!computation->IsFusionComputation());