Merge pull request #4 from MaastrichtU-BISS/metrics_alternative
Introduce alternative wrappers for regression metrics
Showing 20 changed files with 734 additions and 464 deletions.
@@ -0,0 +1,14 @@
from .classification_metrics import (
    PERFORMANCE_METRICS as performance,
    FAIRNESS_METRICS as fairness,
    EXPLAINABILITY_METRICS as explainability
)


class ClassificationMetrics:
    def __init__(self):
        self.performance = performance
        self.fairness = fairness
        self.explainability = explainability


# Create an instance for easy access
metrics = ClassificationMetrics()
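A minimal usage sketch of the module-level instance. The import path is an assumption (it presumes this file is the classification package's __init__.py under faivor.metrics.classification, consistent with the YAML path further down):

# Usage sketch; the package path is assumed, not confirmed by this diff.
from faivor.metrics.classification import metrics

print(len(metrics.performance))      # metric definitions loaded from the YAML
print(len(metrics.fairness))
print(len(metrics.explainability))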
@@ -0,0 +1,3 @@
from ..config_loader import load_metrics

PERFORMANCE_METRICS, FAIRNESS_METRICS, EXPLAINABILITY_METRICS = load_metrics("classification/classification_metrics.yaml")
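The implementation of load_metrics is not part of this diff; a hypothetical sketch of such a loader, assuming the YAML layout shown below and a PyYAML dependency, could look like this:

# Hypothetical loader sketch -- not the repository's actual config_loader implementation.
from importlib import resources
import yaml  # assumes PyYAML is available

def load_metrics(relative_path):
    # Read the YAML file shipped inside the metrics package (path handling is an assumption).
    text = resources.files("faivor.metrics").joinpath(relative_path).read_text()
    config = yaml.safe_load(text)
    # Return the three categories in the order the caller unpacks them.
    return config["performance"], config["fairness"], config["explainability"]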
100 changes: 100 additions & 0 deletions
src/faivor/metrics/classification/classification_metrics.yaml
@@ -0,0 +1,100 @@
performance:
  - function_name: accuracy_score
    regular_name: Accuracy Score
    description: Accuracy classification score.
    func: sklearn.metrics.accuracy_score
    is_torch: false
  - function_name: balanced_accuracy_score
    regular_name: Balanced Accuracy Score
    description: Balanced accuracy classification score.
    func: sklearn.metrics.balanced_accuracy_score
    is_torch: false
  - function_name: average_precision_score
    regular_name: Average Precision Score
    description: Compute average precision (AP) from prediction scores.
    func: sklearn.metrics.average_precision_score
    is_torch: false
  - function_name: f1_score
    regular_name: F1 Score
    description: F1 score, harmonic mean of precision and recall.
    func: sklearn.metrics.f1_score
    is_torch: false
  - function_name: precision_score
    regular_name: Precision Score
    description: Precision classification score.
    func: sklearn.metrics.precision_score
    is_torch: false
  - function_name: recall_score
    regular_name: Recall Score
    description: Recall classification score.
    func: sklearn.metrics.recall_score
    is_torch: false
  - function_name: roc_auc_score
    regular_name: ROC AUC Score
    description: Compute Area Under the Receiver Operating Characteristic Curve (ROC AUC) from prediction scores.
    func: sklearn.metrics.roc_auc_score
    is_torch: false
  - function_name: jaccard_score
    regular_name: Jaccard Score
    description: Jaccard similarity coefficient.
    func: sklearn.metrics.jaccard_score
    is_torch: false
  - function_name: log_loss
    regular_name: Log Loss
    description: Log loss, aka logistic regression loss or cross-entropy loss.
    func: sklearn.metrics.log_loss
    is_torch: false
  - function_name: matthews_corrcoef
    regular_name: Matthews Correlation Coefficient
    description: Compute the Matthews correlation coefficient (MCC).
    func: sklearn.metrics.matthews_corrcoef
    is_torch: false
  - function_name: brier_score_loss
    regular_name: Brier Score Loss
    description: Compute the Brier score loss.
    func: sklearn.metrics.brier_score_loss
    is_torch: false
  - function_name: top_k_accuracy_score
    regular_name: Top K Accuracy Score
    description: Top-k accuracy classification score.
    func: sklearn.metrics.top_k_accuracy_score
    is_torch: false
  - function_name: roc_curve
    regular_name: ROC Curve
    description: Compute Receiver operating characteristic (ROC) curve.
    func: sklearn.metrics.roc_curve
    is_torch: false
  - function_name: precision_recall_curve
    regular_name: Precision Recall Curve
    description: Compute precision-recall pairs for different probability thresholds.
    func: sklearn.metrics.precision_recall_curve
    is_torch: false
  - function_name: hamming_loss
    regular_name: Hamming Loss
    description: Compute the average Hamming loss or Hamming distance between two sets of samples.
    func: sklearn.metrics.hamming_loss
    is_torch: false
  - function_name: zero_one_loss
    regular_name: Zero One Loss
    description: Zero-one classification loss.
    func: sklearn.metrics.zero_one_loss
    is_torch: false
  - function_name: confusion_matrix
    regular_name: Confusion Matrix
    description: Compute confusion matrix to evaluate the accuracy of a classification.
    func: sklearn.metrics.confusion_matrix
    is_torch: false

fairness:
  - function_name: disparate_impact
    regular_name: Disparate Impact
    description: Calculates the disparate impact for classification by comparing the rate of favorable outcomes for different groups.
    func: faivor.metrics.classification.fairness.disparate_impact
    is_torch: false

explainability:
  - function_name: prediction_entropy
    regular_name: Prediction Entropy
    description: Calculates the entropy of predictions to measure model uncertainty.
    func: faivor.metrics.classification.explainability.prediction_entropy
    is_torch: false
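Each func entry is a dotted import path. A small sketch of resolving one entry into a callable (hypothetical helper, not code from this PR):

# Hypothetical example of resolving a YAML `func` entry at runtime.
from importlib import import_module

def resolve(dotted_path):
    module_name, attr = dotted_path.rsplit(".", 1)
    return getattr(import_module(module_name), attr)

accuracy = resolve("sklearn.metrics.accuracy_score")
print(accuracy([0, 1, 1, 0], [0, 1, 0, 0]))  # 0.75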
@@ -1,36 +1,41 @@
from typing import List
from sklearn import metrics as skm

__all__ = ["ClassificationExplainabilityMetrics"]


class ClassificationExplainabilityMetricsMeta(type):
    """Metaclass for dynamically creating classification explainability metric classes."""

    _WHITELISTED_METRICS: List[str] = []  # sklearn doesn't provide direct explainability metrics

    def __new__(mcs, name, bases, dct):
        """Creates a new class, inheriting from skm metrics."""
        for metric_name in mcs._WHITELISTED_METRICS:
            metric_function = getattr(skm, metric_name, None)
            if metric_function:
                def method_wrapper(self, y_true, y_pred, **kwargs):
                    return metric_function(y_true, y_pred, **kwargs)
                dct[metric_name] = method_wrapper
        return super().__new__(mcs, name, bases, dct)


class BaseClassificationExplainabilityMetrics:
    """Base class for classification explainability metrics."""
    pass


class ClassificationExplainabilityMetrics(BaseClassificationExplainabilityMetrics, metaclass=ClassificationExplainabilityMetricsMeta):
    """Class for classification explainability metrics."""

    def custom_prediction_entropy(self, probas):
        """Calculate the average entropy of prediction probabilities."""
        import numpy as np
        probas = np.asarray(probas)
        log_probs = np.log2(probas)
        return -np.mean(np.sum(probas * log_probs, axis=1))
import numpy as np
from scipy.stats import entropy


def prediction_entropy(y_prob) -> float:
    """
    Calculates the entropy of predictions for classification.

    Entropy is a measure of uncertainty. Higher entropy in predictions indicates
    higher model uncertainty. This function computes the average entropy across all predictions.

    Parameters
    ----------
    y_prob : array-like of shape (n_samples, n_classes) or (n_samples,)
        The predicted probabilities for each class. Can be either:
        - A 2D array of shape (n_samples, n_classes) where each row represents
          the probability distribution over classes for a single sample.
        - A 1D array of shape (n_samples,) for binary classification, representing
          the probability of the positive class (class 1).

    Returns
    -------
    float
        The average prediction entropy. Returns np.nan if input is empty or invalid.
    """
    y_prob = np.asarray(y_prob)
    if y_prob.size == 0:
        return np.nan

    if y_prob.ndim == 1:  # assume binary classification and probabilities are for positive class
        y_prob = np.vstack([1 - y_prob, y_prob]).T  # create 2D prob array: [[p(class0), p(class1)], ...]

    if np.any(y_prob < 0) or np.any(y_prob > 1):
        return np.nan  # probabilities should be between 0 and 1

    # Normalize probabilities to ensure they sum to 1 (handle potential rounding errors)
    y_prob_normalized = y_prob / np.sum(y_prob, axis=1, keepdims=True)

    # Calculate entropy for each prediction
    entropies = entropy(y_prob_normalized, axis=1)

    return np.mean(entropies)
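A quick usage check of the new function; the import path follows the func entry in the YAML above (natural-log entropy from scipy, so uniform binary probabilities give about 0.693):

import numpy as np
from faivor.metrics.classification.explainability import prediction_entropy

# Uniform probabilities -> maximal uncertainty; confident predictions -> entropy near 0.
print(prediction_entropy(np.array([[0.5, 0.5], [0.5, 0.5]])))  # about 0.693
print(prediction_entropy(np.array([0.99, 0.01])))              # 1D binary input, about 0.056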
@@ -1,74 +1,55 @@
from typing import List
from sklearn import metrics as skm
from torchmetrics import Accuracy, F1Score, Precision, Recall
import torch
__all__ = ["ClassificationFairnessMetrics"]


class ClassificationFairnessMetricsMeta(type):
    """Metaclass for dynamically creating classification fairness metric classes."""

    _WHITELISTED_METRICS: List[str] = [
        "accuracy_score",  # useful for group fairness comparisons
    ]

    def __new__(mcs, name, bases, dct):
        """Creates a new class, inheriting from skm metrics."""
        for metric_name in mcs._WHITELISTED_METRICS:
            metric_function = getattr(skm, metric_name, None)
            if metric_function:
                def method_wrapper(self, y_true, y_pred, **kwargs):
                    return metric_function(y_true, y_pred, **kwargs)
                dct[metric_name] = method_wrapper

        for metric_name in ["accuracy", "f1_score", "precision", "recall"]:
            if metric_name == "accuracy":
                metric_class = Accuracy
            elif metric_name == "f1_score":
                metric_class = F1Score
            elif metric_name == "precision":
                metric_class = Precision
            elif metric_name == "recall":
                metric_class = Recall

            def torchmetrics_method_wrapper(self, y_true, y_pred, **kwargs):
                metric = metric_class(task="binary", **kwargs)
                return metric(
                    torch.tensor(y_pred, dtype=torch.float32),
                    torch.tensor(y_true, dtype=torch.int),
                ).detach().cpu().item()
            dct[metric_name] = torchmetrics_method_wrapper
        return super().__new__(mcs, name, bases, dct)


class BaseClassificationFairnessMetrics:
    """Base class for classification fairness metrics."""
    pass


class ClassificationFairnessMetrics(BaseClassificationFairnessMetrics, metaclass=ClassificationFairnessMetricsMeta):
    """Class for classification fairness metrics."""

    def custom_disparate_impact(self, y_true, y_pred, sensitive_attribute):
        """Calculates Disparate Impact for classification."""
        import numpy as np
        y_true, y_pred, sensitive_attribute = np.asarray(y_true), np.asarray(y_pred), np.asarray(sensitive_attribute)

        unique_sensitive_values = np.unique(sensitive_attribute)
        if len(unique_sensitive_values) < 2:
            return np.nan

        group_positive_rates = []
        for value in unique_sensitive_values:
            group_mask = sensitive_attribute == value
            if group_mask.sum() == 0:
                group_positive_rates.append(np.nan)
            else:
                group_positive_rates.append(np.mean(y_pred[group_mask] == np.max(y_pred)))  # Assuming 1 is the positive class

        group_positive_rates = np.asarray(group_positive_rates)
        if np.isnan(group_positive_rates).any():
            return np.nan

        return np.min(group_positive_rates) / np.max(group_positive_rates)

import numpy as np


def disparate_impact(y_true, y_pred, sensitive_attribute, favorable_outcome=1) -> float:
    """
    Calculates Disparate Impact for classification.

    Disparate Impact (DI) is the ratio of the rate of favorable outcomes for the
    disadvantaged group compared to the advantaged group. A common threshold for
    concern is DI < 0.8, indicating potential adverse impact.

    Parameters
    ----------
    y_true : array-like of shape (n_samples,)
        The true target values (binary: 0 or 1).
    y_pred : array-like of shape (n_samples,)
        The predicted target values (binary: 0 or 1).
    sensitive_attribute : array-like of shape (n_samples,)
        The sensitive attribute values (categorical).
    favorable_outcome : int or float, default=1
        The value representing the favorable outcome in y_true and y_pred.

    Returns
    -------
    float
        The disparate impact ratio. Returns np.nan if there's only one group or
        if the advantaged group has no favorable outcomes.
    """
    y_true, y_pred, sensitive_attribute = (
        np.asarray(y_true),
        np.asarray(y_pred),
        np.asarray(sensitive_attribute),
    )

    unique_sensitive_values = np.unique(sensitive_attribute)
    if len(unique_sensitive_values) < 2:
        return np.nan  # Not applicable for less than 2 groups

    favorable_rates = {}
    for value in unique_sensitive_values:
        group_mask = sensitive_attribute == value
        group_size = group_mask.sum()
        if group_size == 0:
            favorable_rates[value] = 0  # Handle empty groups to avoid division by zero later, assume 0 favorable rate
        else:
            favorable_outcomes_count = np.sum(y_pred[group_mask] == favorable_outcome)
            favorable_rates[value] = favorable_outcomes_count / group_size

    rates = np.array(list(favorable_rates.values()))
    min_rate = np.min(rates)
    max_rate = np.max(rates)

    if max_rate == 0:  # avoid division by zero if advantaged group has no favorable outcomes
        return np.nan

    return min_rate / max_rate
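And a quick usage check for the new fairness metric; the import path again follows the YAML func entry:

import numpy as np
from faivor.metrics.classification.fairness import disparate_impact

y_true = np.array([1, 0, 1, 1, 0, 1])
y_pred = np.array([1, 0, 1, 0, 0, 1])
group = np.array(["a", "a", "a", "b", "b", "b"])

# Group "a" favorable rate: 2/3; group "b": 1/3, so DI = (1/3) / (2/3) = 0.5
print(disparate_impact(y_true, y_pred, group))  # 0.5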