Skip to content

Commit

Permalink
fix(model quality): Simplify performance by metrics (#23)
Browse files Browse the repository at this point in the history
  • Loading branch information
frederik-encord authored Dec 16, 2022
1 parent 27379fd commit 8759144
Showing 1 changed file with 13 additions and 49 deletions.
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import re
from collections import namedtuple
from enum import Enum
from typing import Optional

import altair as alt
import altair.vegalite.v4.api as alt_api
import pandas as pd
import streamlit as st

Expand All @@ -14,10 +14,6 @@
FLOAT_FMT = ",.4f"
PCT_FMT = ",.2f"
COUNT_FMT = ",d"
CompositeChart = namedtuple("CompositeChart", "chart xmin xmax")

sub_plot_args = dict(width=370, height=300)
zoom_view_args = dict(height=50, width=800, title="Select to Zoom")


class ChartSubject(Enum):
Expand All @@ -35,7 +31,7 @@ def build_bar_chart(
show_decomposition: bool,
title: str,
subject: ChartSubject,
) -> alt.Chart:
) -> alt_api.Chart:
str_type = "predictions" if subject == ChartSubject.TPR else "labels"
largest_bin_count = sorted_predictions["bin"].value_counts().max()
chart = (
Expand All @@ -50,7 +46,7 @@ def build_bar_chart(
if show_decomposition:
# Aggregate over each class
return chart.encode(
alt.X("bin:Q", scale=self.x_scale),
alt.X("bin:Q"),
alt.Y("sum(pctf):Q", stack="zero"),
alt.Color("class_name:N", scale=self.class_scale, legend=alt.Legend(symbolOpacity=1)),
tooltip=[
Expand All @@ -63,7 +59,7 @@ def build_bar_chart(
else:
# Only use aggregate over all classes
return chart.encode(
alt.X("bin:Q", scale=self.x_scale),
alt.X("bin:Q"),
alt.Y("sum(pctf):Q", stack="zero"),
tooltip=[
alt.Tooltip("bin", title=metric_name, format=FLOAT_FMT),
Expand All @@ -74,12 +70,12 @@ def build_bar_chart(

def build_line_chart(
self, bar_chart: alt.Chart, metric_name: str, show_decomposition: bool, title: str
) -> alt.Chart:
) -> alt_api.Chart:
legend = alt.Legend(title="class name".title())
title_shorthand = "".join(w[0].upper() for w in title.split())

line_chart = bar_chart.mark_line(point=True, opacity=0.5 if show_decomposition else 1.0).encode(
alt.X("bin:Q", scale=self.x_scale),
alt.X("bin:Q"),
alt.Y("mean(indicator):Q"),
alt.Color("average:N", legend=legend, scale=self.class_scale),
tooltip=[
Expand All @@ -99,12 +95,10 @@ def build_line_chart(
alt.Tooltip("class_name:N", title="Class name"),
],
strokeDash=alt.value([10, 0]),
opacity=alt.condition(self.class_selection, alt.value(1), alt.value(0.1)),
)
line_chart = line_chart.add_selection(self.class_selection)
return line_chart

def build_average_rule(self, indicator_mean: float, title: str):
def build_average_rule(self, indicator_mean: float, title: str) -> alt_api.Chart:
title_shorthand = "".join(w[0].upper() for w in title.split())
return (
alt.Chart(pd.DataFrame({"y": [indicator_mean], "average": ["Average"]}))
Expand All @@ -119,7 +113,7 @@ def build_average_rule(self, indicator_mean: float, title: str):

def make_composite_chart(
self, df: pd.DataFrame, title: str, metric_name: str, subject: ChartSubject
) -> Optional[CompositeChart]:
) -> Optional[alt_api.LayerChart]:
# Avoid over-shooting number of bins.
if metric_name not in df.columns:
return None
Expand Down Expand Up @@ -149,13 +143,9 @@ def make_composite_chart(
line_chart = self.build_line_chart(bar_chart, metric_name, show_decomposition, title=title)
mean_rule = self.build_average_rule(df["indicator"].mean(), title=title)

chart_composition = bar_chart + line_chart + mean_rule
chart_composition = chart_composition.encode(
alt.X(title=metric_name.title()), alt.Y(title=title.title())
).properties(
**sub_plot_args
) # TODO Is there a better way to set size?
return CompositeChart(chart_composition, df["bin"].min(), df["bin"].max())
chart_composition: alt_api.LayerChart = bar_chart + line_chart + mean_rule
chart_composition = chart_composition.encode(alt.X(title=metric_name.title()), alt.Y(title=title.title()))
return chart_composition

def sidebar_options(self):
c1, c2, c3 = st.columns([4, 4, 3])
Expand Down Expand Up @@ -241,17 +231,13 @@ def build(
self.class_scale = alt.Scale(
domain=classes_for_coloring,
) # Used to sync colors between plots.
self.class_selection = alt.selection_multi(fields=["class_name"]) # Used to sync selections between plots.

# Allow zooming via a "zoom_chart"
self.x_scale_interval = alt.selection_interval(encodings=["x"])
self.x_scale = alt.Scale(domain=self.x_scale_interval.ref())

# TPR
predictions = model_predictions.rename(columns={"tps": "indicator"})
tpr = self.make_composite_chart(predictions, "True Positive Rate", metric_name, subject=ChartSubject.TPR)
if tpr is None:
st.stop()
st.altair_chart(tpr.interactive(), use_container_width=True)

# FNR
fnr = self.make_composite_chart(
Expand All @@ -262,28 +248,6 @@ def build(
)

if fnr is None: # Label metric couldn't be matched to
zoom_chart = (
tpr.chart.encode(
alt.X("bin:Q"), alt.Y(title="TPR") # Avoid zooming the actual zoom selection view.
)
.properties(**zoom_view_args)
.add_selection(self.x_scale_interval)
)
st.altair_chart(zoom_chart & tpr.chart.properties(width=800))
st.stop()

# Zoom chart - get entire range of boths tprs and fnrs
xmin = min(tpr.xmin, fnr.xmin)
xmax = max(tpr.xmax, fnr.xmax)
zoom_chart = (
tpr.chart.encode(
alt.X(
"bin:Q", scale=alt.Scale(domain=[xmin, xmax])
), # Avoid zooming the actual zoom selection view.
alt.Y(title="TPR"),
)
.properties(**zoom_view_args)
.add_selection(self.x_scale_interval)
)

st.altair_chart(zoom_chart & (tpr.chart | fnr.chart))
st.altair_chart(fnr.interactive(), use_container_width=True)

0 comments on commit 8759144

Please sign in to comment.