
Commit 87cd4d8

MiguelAFH, aunell, suhana13, Miking98, and haoqiu1 authored
MedHELM V1 (#3403)
Co-authored-by: Alyssa Unell <alyunell9@gmail.com>
Co-authored-by: suhana13 <suhana@stanford.edu>
Co-authored-by: Michael Wornow <mwornow98@gmail.com>
Co-authored-by: Suhana Bedi <57412795+suhana13@users.noreply.github.com>
Co-authored-by: haoqiu1 <haoqiu@microsoft.com>
Co-authored-by: HennyJie <cuihejie331771@gmail.com>
Co-authored-by: Juan M. Banda <juan@jmbanda.com>
1 parent 4a1dd46 commit 87cd4d8

File tree

60 files changed: +8571 −669 lines


requirements.txt

+7 −1

@@ -14,13 +14,14 @@ anthropic==0.38.0
 antlr4-python3-runtime==4.9.3
 anyio==4.8.0
 astunparse==1.6.3
-async-timeout==5.0.1
+async-timeout==4.0.3
 attrs==24.3.0
 audioread==3.0.1
 autokeras==1.0.20
 av==14.0.1
 awscli==1.33.44
 beautifulsoup4==4.12.3
+bert_score==0.3.13
 black==24.3.0
 blis==1.1.0
 boto3==1.34.162
@@ -131,6 +132,8 @@ keras==3.8.0
 keras-tuner==1.4.7
 kiwisolver==1.4.7
 kt-legacy==1.0.5
+langchain==0.3.9
+langchain-community==0.3.8
 langcodes==3.5.0
 langdetect==1.0.9
 language_data==1.3.0
@@ -223,6 +226,7 @@ pypinyin==0.49.0
 PySocks==1.7.1
 pytest==7.2.2
 python-dateutil==2.8.2
+python-docx==1.1.2
 python-utils==3.9.1
 pytorch-fid==0.3.0
 pytorch-lightning==2.0.9.post0
@@ -232,6 +236,8 @@ PyWavelets==1.6.0
 PyYAML==6.0.2
 qwen-vl-utils==0.0.8
 RapidFuzz==3.11.0
+rank_bm25==0.2.2
+referencing==0.35.1
 regex==2024.11.6
 reka-api==2.0.0
 requests==2.32.3

setup.cfg

+9

@@ -279,6 +279,15 @@ heim =
     # Shared image dependencies
     crfm-helm[images]
 
+medhelm =
+    # Summarization metrics
+    crfm-helm[summarization]
+
+    #MedHELM scenarios
+    python-docx~=1.1.2
+    langchain~=0.3.9
+    lxml~=5.3.0
+
 audiolm =
     crfm-helm[openai]
     crfm-helm[google]
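For context: medhelm is a new pip extras group. Assuming the package is installed from PyPI under the crfm-helm name already used elsewhere in this file, it would typically be enabled with a command along the lines of pip install "crfm-helm[medhelm]", which pulls in the summarization extra together with python-docx, langchain, and lxml.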
New file (filename not shown in this view): EhrSqlAnnotator

+87

@@ -0,0 +1,87 @@
from typing import Any, List, Optional
import os
import re
import sqlite3
from helm.benchmark.adaptation.request_state import RequestState
from helm.benchmark.annotation.annotator import Annotator
from helm.common.hierarchical_logger import hlog
from helm.benchmark.runner import get_benchmark_output_path


class EhrSqlAnnotator(Annotator):
    """
    Executes both ground truth and generated SQL queries on the eicu.sqlite database.
    """

    name = "ehr_sql"

    def annotate(self, request_state: RequestState) -> Any:
        """Evaluate SQL execution accuracy by running queries against the eicu.sqlite database."""

        databases_root_path = os.path.join(get_benchmark_output_path(), "scenarios", "ehr_sql")
        database_path = os.path.join(databases_root_path, "eicu.sqlite")

        assert len(request_state.instance.references) == 1
        ground_truth_sql = request_state.instance.references[0].output.text.strip()
        ground_truth_result: List[str] = []

        # Execute the ground truth query
        try:
            with sqlite3.connect(database_path) as conn:
                cursor = conn.cursor()
                cursor.execute(ground_truth_sql)
                ground_truth_result = cursor.fetchall()
        except (sqlite3.OperationalError, sqlite3.Warning) as e:
            hlog(f"WARNING: Ground truth SQL failed with error: {e}")

        # If ground truth SQL execution didn't return results, attempt to use extra_data["value"]
        if not ground_truth_result and request_state.instance.extra_data is not None:
            if "value" in request_state.instance.extra_data:
                extra_values = list(request_state.instance.extra_data["value"].values())

                # Try inferring types from the database schema if possible
                with sqlite3.connect(database_path) as conn:
                    cursor = conn.cursor()
                    try:
                        cursor.execute(ground_truth_sql)
                        fetched_result = cursor.fetchone()
                        if fetched_result:
                            # Convert extra_values to match SQLite's expected types
                            converted_values = [
                                type(fetched_result[i])(extra_values[i]) for i in range(len(extra_values))
                            ]
                            ground_truth_result = converted_values
                        else:
                            # If no rows were fetched, use `extra_values` as-is
                            ground_truth_result = extra_values
                    except sqlite3.OperationalError:
                        # If query fails (syntax error, etc.), just use `extra_values` as-is
                        ground_truth_result = extra_values

        assert request_state.result is not None
        assert len(request_state.result.completions) == 1
        predicted_text = request_state.result.completions[0].text.strip()

        predicted_sql_match = re.search(r"<\s*sql\s*>(.*?)<\/?\s*sql\s*>", predicted_text, re.DOTALL | re.IGNORECASE)
        predicted_sql = predicted_sql_match.group(1).strip() if predicted_sql_match else predicted_text.strip()

        predicted_result: List[str] = []
        query_error: Optional[str] = None
        predicted_sql = predicted_sql.replace("`", "").strip()
        predicted_sql = re.sub(r"^sql\n", "", predicted_sql, flags=re.MULTILINE)
        if not predicted_sql:
            query_error = "No query generated"
        else:
            try:
                with sqlite3.connect(database_path) as conn:
                    cursor = conn.cursor()
                    cursor.execute(predicted_sql)
                    predicted_result = cursor.fetchall()
            except (sqlite3.OperationalError, sqlite3.Warning) as e:
                query_error = str(e)

        return {
            "predicted_result": predicted_result,
            "ground_truth_result": ground_truth_result,
            "query_error": query_error,
        }
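To make the output post-processing concrete, here is a minimal standalone sketch, not part of the commit, of the same cleanup the annotator applies to a model completion before executing it; the two sample completions are invented for illustration.

import re


def extract_sql(completion: str) -> str:
    """Illustrative copy of the cleanup steps in EhrSqlAnnotator.annotate."""
    match = re.search(r"<\s*sql\s*>(.*?)<\/?\s*sql\s*>", completion, re.DOTALL | re.IGNORECASE)
    sql = match.group(1).strip() if match else completion.strip()
    sql = sql.replace("`", "").strip()  # strip Markdown backticks
    return re.sub(r"^sql\n", "", sql, flags=re.MULTILINE)  # strip a leading "sql" language tag


# Hypothetical model outputs:
print(extract_sql("<sql>SELECT count(*) FROM patient</sql>"))    # SELECT count(*) FROM patient
print(extract_sql("```sql\nSELECT count(*) FROM patient\n```"))  # SELECT count(*) FROM patient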

src/helm/benchmark/metrics/basic_metrics.py

+2 −2

@@ -5,8 +5,8 @@
 from urllib.parse import unquote
 
 import numpy as np
-import scipy
-import calibration as cal
+import scipy  # type: ignore
+import calibration as cal  # type: ignore
 from helm.benchmark.adaptation.scenario_state import ScenarioState
 from helm.benchmark.metrics.evaluate_reference_metrics import compute_reference_metrics
 from helm.benchmark.metrics.efficiency_metrics import EfficiencyMetric
New file (filename not shown in this view): EhrSqlMetric

+103

@@ -0,0 +1,103 @@
from typing import List
from helm.benchmark.adaptation.adapter_spec import AdapterSpec
from helm.benchmark.adaptation.request_state import RequestState
from helm.benchmark.metrics.metric import Metric
from helm.benchmark.metrics.metric_name import MetricName
from helm.benchmark.metrics.metric_service import MetricService
from helm.benchmark.metrics.statistic import Stat
from helm.common.hierarchical_logger import hlog


class EhrSqlMetric(Metric):
    """
    Metric for evaluating the EHR SQL dataset, focusing on:
    1. Execution Accuracy – Whether the generated SQL query produces the same results as the ground truth.
    2. Query Validity – Whether the generated SQL executes without errors.
    3. Precision for Answerable Questions (Pans).
    4. Recall for Answerable Questions (Rans).
    """

    def evaluate_generation(
        self,
        adapter_spec: AdapterSpec,
        request_state: RequestState,
        metric_service: MetricService,
        eval_cache_path: str,
    ) -> List[Stat]:
        """
        Evaluate execution accuracy, query validity, and answerability metrics.
        """

        if not request_state.annotations:
            hlog(f"Warning: Request state missing annotations for instance {request_state.instance}")
            return []

        if "ehr_sql" not in request_state.annotations:
            hlog(f"Warning: 'ehr_sql' key missing in annotations for instance {request_state.instance}")
            return []

        # Extract execution results
        predicted_result = request_state.annotations["ehr_sql"].get("predicted_result", [])
        ground_truth_result = request_state.annotations["ehr_sql"].get("ground_truth_result", [])
        query_error = request_state.annotations["ehr_sql"].get("query_error", None)

        # Extract predictions from the model output
        if request_state.result is None:
            predictions = []
        else:
            predictions = [completion.text.strip() for completion in request_state.result.completions]
        if not predictions:
            hlog(f"Warning: No predictions found in the completions for instance {request_state.instance}")
            return []

        # Process the first prediction as the primary output
        prediction = predictions[0].strip()

        # Extract references and input text
        references = getattr(request_state.instance, "references", None)

        if not references or len(references) == 0:
            hlog(f"Warning: Missing references for instance {request_state.instance}")
            return []

        # Check if the ground truth is answerable based on `is_impossible` flag
        ground_truth_query = references[0].output.text.strip() if references else None
        is_impossible = (
            request_state.instance.extra_data.get("is_impossible", False)
            if request_state.instance.extra_data
            else False
        )

        is_answerable = not is_impossible and bool(ground_truth_query)  # True if the ground truth is answerable
        is_predicted_answerable = bool(prediction)  # True if the model generated a non-empty SQL query
        correct_answerable = int(is_answerable and is_predicted_answerable)  # Correct if both are answerable

        # **Execution Accuracy Fix:**
        execution_accuracy = 0

        if ground_truth_query:
            if ground_truth_result and predicted_result:
                execution_accuracy = int(set(predicted_result) == set(ground_truth_result))  # Compare sets.
            elif not ground_truth_result and not predicted_result and not prediction:
                execution_accuracy = 1  # Both empty and no query was generated.
        elif not ground_truth_query and prediction:
            execution_accuracy = 0  # LLM generated a query when no gold query exists.

        # **Query Validity Fix:**
        if not prediction:  # No SQL query was generated
            query_validity = 0
        elif query_error is None:
            query_validity = 1  # Query executed successfully.
        else:
            query_validity = 0  # Execution error occurred.

        return [
            # Execution-based Metrics
            Stat(MetricName("ehr_sql_execution_accuracy")).add(execution_accuracy),
            Stat(MetricName("ehr_sql_query_validity")).add(query_validity),
            # Answerability Metrics
            Stat(MetricName("ehr_sql_precision_answerable")).add(correct_answerable if is_predicted_answerable else 0),
            Stat(MetricName("ehr_sql_recall_answerable")).add(correct_answerable if is_answerable else 0),
            Stat(MetricName("ehr_sql_total_predicted_answerable")).add(int(is_predicted_answerable)),
            Stat(MetricName("ehr_sql_total_ground_truth_answerable")).add(int(is_answerable)),
        ]
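For intuition, here is a minimal sketch, not from the commit, of how the per-instance answerability flags could be rolled up into the Pans/Rans figures named in the docstring. HELM's Stat machinery averages values per run, so the corpus-level ratios below are an assumption about downstream aggregation, and the dict keys simply mirror the local variable names above.

from typing import Dict, List


def answerability_summary(per_instance: List[Dict[str, int]]) -> Dict[str, float]:
    """Aggregate per-instance 0/1 flags into precision/recall over answerable questions."""
    correct = sum(x["correct_answerable"] for x in per_instance)
    predicted = sum(x["is_predicted_answerable"] for x in per_instance)
    answerable = sum(x["is_answerable"] for x in per_instance)
    return {
        "precision_answerable": correct / predicted if predicted else 0.0,  # Pans
        "recall_answerable": correct / answerable if answerable else 0.0,   # Rans
    }


# Toy example with three invented instances:
print(answerability_summary([
    {"correct_answerable": 1, "is_predicted_answerable": 1, "is_answerable": 1},
    {"correct_answerable": 0, "is_predicted_answerable": 1, "is_answerable": 0},
    {"correct_answerable": 0, "is_predicted_answerable": 0, "is_answerable": 1},
]))  # {'precision_answerable': 0.5, 'recall_answerable': 0.5}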
New file (filename not shown in this view): MedCalcBenchMetric

+124

@@ -0,0 +1,124 @@
import re

from datetime import datetime
from typing import List, Dict, Any
from helm.benchmark.adaptation.adapter_spec import AdapterSpec
from helm.benchmark.adaptation.request_state import RequestState
from helm.benchmark.metrics.metric import Metric
from helm.benchmark.metrics.metric_name import MetricName
from helm.benchmark.metrics.metric_service import MetricService
from helm.benchmark.metrics.statistic import Stat
from helm.common.hierarchical_logger import hlog


class MedCalcBenchMetric(Metric):
    """
    Metric for evaluating the MedCalc Bench dataset, assessing the model's ability to
    be a clinical calculator.

    Exact match based on category:
    1. Normal exact match: for categories "risk", "severity" or "diagnosis".
    2. Variant exact match: for other categories, if the number calculated by the model falls between the values
       in the Lower limit and Upper limit columns, we mark it as accurate.
    """

    def parse_duration(self, duration_str) -> int:
        """Parses a duration tuple (weeks, days) from a string format like ('14 weeks', '2 days')."""
        match = re.match(r"\('(\d+) weeks', '(\d+) days'\)", duration_str)
        if match:
            weeks, days = map(int, match.groups())
            return weeks * 7 + days  # Convert to total days
        else:
            raise ValueError(f"Invalid format: {duration_str}")

    def is_within_range(self, lower_bound, upper_bound, prediction) -> int:
        """
        Checks if a predicted duration falls within the given range.

        Args:
            lower_bound (str): The lower bound in format "('X weeks', 'Y days')".
            upper_bound (str): The upper bound in format "('X weeks', 'Y days')".
            prediction (str): The predicted duration in the same format.

        Returns:
            int: 1 if within range (inclusive), 0 otherwise.
        """
        lower_days = self.parse_duration(lower_bound)
        upper_days = self.parse_duration(upper_bound)
        prediction_days = self.parse_duration(prediction)
        return 1 if lower_days <= prediction_days <= upper_days else 0

    def check_date(self, prediction: str, reference: str, extra_data: Dict[str, Any]) -> int:
        """Checks if the prediction date is within limits."""
        if re.match(r"\('(\d+) weeks', '(\d+) days'\)", reference):
            exact_match = self.is_within_range(extra_data["lower_limit"], extra_data["upper_limit"], prediction)
        else:
            prediction_date = self._str_to_date(prediction)
            upper_limit_date = self._str_to_date(extra_data["upper_limit"])
            lower_limit_date = self._str_to_date(extra_data["lower_limit"])
            exact_match = 1 if lower_limit_date <= prediction_date <= upper_limit_date else 0
        return exact_match

    def _str_to_date(self, date_str: str) -> datetime:
        """Convert string to datetime object."""
        return datetime.strptime(date_str, "%m/%d/%Y")

    def check_in_range(self, prediction: str, reference: str, extra_data: Dict[str, Any], category: str) -> int:
        """Check if the prediction falls within the range specified by the reference."""
        try:
            if category == "date":
                exact_match = self.check_date(prediction, reference, extra_data)
            elif category in ["dosage conversion", "physical"]:
                lower_limit = float(extra_data["lower_limit"])
                upper_limit = float(extra_data["upper_limit"])
                float_prediction = float(prediction)
                exact_match = 1 if lower_limit <= float_prediction <= upper_limit else 0
            else:
                raise ValueError(f"Category {category} not supported")
        except ValueError:
            return 0

        return exact_match

    def evaluate_generation(
        self,
        adapter_spec: AdapterSpec,
        request_state: RequestState,
        metric_service: MetricService,
        eval_cache_path: str,
    ) -> List[Stat]:
        """
        Evaluate a single generation against reference labels.
        """
        # Extract predictions
        assert request_state.result, "request_state.result is unexpectedly None"
        predictions = [completion.text.strip() for completion in request_state.result.completions]

        if not predictions:
            hlog("Warning: No predictions found in completions")
            return []

        # Get the first prediction
        prediction = predictions[0]

        # Get references
        references = getattr(request_state.instance, "references", None)

        if not references or len(references) == 0:
            hlog(f"Warning: Missing references for instance {request_state.instance}")
            return []

        reference = references[0].output.text

        # Extract category, upper limit and lower limit
        assert request_state.instance.extra_data, "Extra data dict was expected but got None"
        category = request_state.instance.extra_data["category"]

        if category in ["risk", "severity", "diagnosis"]:
            exact_match = 1 if prediction == reference else 0
        else:
            exact_match = self.check_in_range(prediction, reference, request_state.instance.extra_data, category)

        return [
            Stat(MetricName("medcalc_bench_accuracy")).add(exact_match),
        ]
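As a quick illustration, not part of the commit, of how the duration-style bounds behave: the snippet below reimplements parse_duration as a standalone function and checks an invented prediction against invented lower/upper limits in the "('X weeks', 'Y days')" format the docstrings describe.

import re


def parse_duration(duration_str: str) -> int:
    """Convert "('X weeks', 'Y days')" into a total number of days (illustrative copy)."""
    match = re.match(r"\('(\d+) weeks', '(\d+) days'\)", duration_str)
    if not match:
        raise ValueError(f"Invalid format: {duration_str}")
    weeks, days = map(int, match.groups())
    return weeks * 7 + days


# Invented bounds and prediction:
lower, upper = "('14 weeks', '0 days')", "('14 weeks', '4 days')"
prediction = "('14 weeks', '2 days')"
print(parse_duration(prediction))                                                     # 100 (days)
print(parse_duration(lower) <= parse_duration(prediction) <= parse_duration(upper))   # True -> counted as accurate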
