diff --git a/torchbenchmark/util/triton_op.py b/torchbenchmark/util/triton_op.py index 6d763b6ca3..6a012233d5 100644 --- a/torchbenchmark/util/triton_op.py +++ b/torchbenchmark/util/triton_op.py @@ -209,8 +209,13 @@ def select_metric(m): row.append(x_val) # Append x_val_only metrics for x_only_metric in x_only_metrics: - x_only_metric_dict = asdict(y_val[y_val_keys[0]]) - if "extra_metrics" in x_only_metric_dict and x_only_metric in x_only_metric_dict["extra_metrics"]: + x_only_metric_dict = asdict( + y_val[y_val_keys[0][0]] + ) # retrieve canonical name for metric function, where y_val_keys[0] = (canonical name, customized label name) + if ( + "extra_metrics" in x_only_metric_dict + and x_only_metric in x_only_metric_dict["extra_metrics"] + ): row.append(x_only_metric_dict["extra_metrics"][x_only_metric]) else: row.append(x_only_metric_dict[x_only_metric])