105
105
from .trace_rules import is_numpy
106
106
from .utils import (
107
107
CleanupManager ,
108
+ codecache_metrics ,
108
109
CompilationMetrics ,
109
110
counters ,
110
111
dynamo_timed ,
@@ -973,6 +974,7 @@ def format_guard_failures() -> str:
973
974
fail_user_frame_lineno : Optional [int ] = None
974
975
torch ._dynamo .utils .ReinplaceCounters .clear ()
975
976
guarded_code = None
977
+ codecache_metrics .clear ()
976
978
try :
977
979
guarded_code = compile_inner (code , one_graph , hooks , transform )
978
980
return guarded_code
@@ -1058,6 +1060,7 @@ def format_guard_failures() -> str:
1058
1060
remote_fx_graph_cache_put_time = frame_phase_timing [frame_key ].get (
1059
1061
"remote_fx_graph_cache_put" , None
1060
1062
)
1063
+ num_triton_bundles = codecache_metrics .get ("num_triton_bundles" , None )
1061
1064
torch ._dynamo .utils .ReinplaceCounters .log ()
1062
1065
1063
1066
else :
@@ -1078,6 +1081,7 @@ def format_guard_failures() -> str:
1078
1081
remote_cache_time_saved = None
1079
1082
remote_fx_graph_cache_get_time = None
1080
1083
remote_fx_graph_cache_put_time = None
1084
+ num_triton_bundles = None
1081
1085
1082
1086
structured_logging_overhead_s = (
1083
1087
torch ._logging .get_structured_logging_overhead ()
@@ -1146,6 +1150,7 @@ def clean_for_json(d: Dict[str, Any]) -> Dict[str, Any]:
1146
1150
config .specialize_float ,
1147
1151
json .dumps (config_dict ),
1148
1152
True , # is_forward
1153
+ num_triton_bundles ,
1149
1154
to_int_ms (remote_fx_graph_cache_get_time ),
1150
1155
to_int_ms (remote_fx_graph_cache_put_time ),
1151
1156
start_time_us = start_time_ns // 1000 ,
0 commit comments