diff --git a/xla/service/gpu/transforms/priority_fusion.cc b/xla/service/gpu/transforms/priority_fusion.cc index c9a2f12d244cf..20c23d57c77d3 100644 --- a/xla/service/gpu/transforms/priority_fusion.cc +++ b/xla/service/gpu/transforms/priority_fusion.cc @@ -507,6 +507,8 @@ class PriorityFusionQueue { is_incremental_update ? operands_to_new_consumers_.find(producer)->second : absl::MakeConstSpan(producer->users()); + // Note that `gpu_performance_model_cache_` may contain a runtime estimate + // from the Triton cost model. GpuPerformanceModel::RunTimes run_times = GpuPerformanceModel::EstimateRunTimes( producer, *device_info_, &cost_analysis_, @@ -635,6 +637,8 @@ class PriorityFusionQueue { TiledRunTimeData tiled_run_time_data = std::get(std::move(tiled_run_time_data_or_error)); + // This is our way to pass the runtime estimate to the CalculatePriorities() + // function. gpu_performance_model_cache_.Set( *producer, *consumer, tiled_run_time_data.runtime_data.exec_time); @@ -983,7 +987,7 @@ absl::StatusOr PriorityFusion::Run( FusionDeduplicationCache fusion_deduplication_cache = FusionDeduplicationCache::Create(*module, IsFusible); - int changed = false; + bool changed = false; for (auto* computation : fusible_computations) { CHECK(!computation->IsFusionComputation());