
Commit 1e1ce9f

Author: Yuheng Tu (committed)
fix linter and type checker
1 parent: 378c26b

File tree: 3 files changed, +42 -14 lines changed

src/helm/benchmark/adaptation/adapter_spec.py (+1 -1)

@@ -132,7 +132,7 @@ class AdapterSpec:
 
     image_generation_parameters: Optional[ImageGenerationParameters] = None
     """Parameters for image generation."""
-
+
     reeval_parameters: Optional[ReevalParameters] = None
     """Parameters for reeval evaluation."""

src/helm/benchmark/reeval_run.py (+6)

@@ -168,6 +168,12 @@ def main():
     args = parser.parse_args()
     validate_args(args)
 
+    if args.max_eval_instances:
+        hlog(
+            "WARNING: In reeval mode, max-eval-instances will not be used to downsample the evaluation instances. "
+            "Use --max-samples"
+        )
+
     register_builtin_configs_from_helm_package()
     register_configs_from_directory(args.local_path)

src/helm/benchmark/reeval_runner.py (+35 -13)

@@ -94,13 +94,11 @@ def closure():
     optim = torch.optim.LBFGS([ability], lr=0.1, max_iter=20, history_size=10, line_search_fn="strong_wolfe")
 
     for iteration in range(100):
-        if iteration > 0:
-            prev_ability = ability.clone()
-            prev_loss = loss.clone()
-
         loss = optim.step(closure)
 
         if iteration > 0:
+            prev_ability = ability.clone()
+            prev_loss = loss.clone()
             d_loss = prev_loss - loss
             d_theta = torch.norm(prev_ability - ability, p=2)
             grad_norm = torch.norm(optim.param_groups[0]["params"][0].grad, p=2)
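
Note on this hunk: the clones of the previous iterate are moved below optim.step(closure), presumably because the linter flags loss (and prev_loss) as possibly unbound when they are read before the first step. Below is a minimal, self-contained sketch of one way to implement a step-then-compare convergence check with LBFGS; it is not the HELM code, and the quadratic objective, tolerances, and variable names are illustrative assumptions.

import torch

# Toy 1-D "ability" estimation: minimize a quadratic with LBFGS and stop when the
# objective, the iterate, and the gradient all stop moving.
ability = torch.zeros(1, requires_grad=True)
target = torch.tensor([2.0])

def closure():
    optim.zero_grad()
    loss = ((ability - target) ** 2).sum()
    loss.backward()
    return loss

optim = torch.optim.LBFGS([ability], lr=0.1, max_iter=20, history_size=10, line_search_fn="strong_wolfe")

prev_ability = None
prev_loss = None
for iteration in range(100):
    loss = optim.step(closure)  # loss is bound before it is ever compared
    if prev_loss is not None:
        d_loss = prev_loss - loss                           # decrease in the objective
        d_theta = torch.norm(prev_ability - ability, p=2)   # movement of the estimate
        grad_norm = torch.norm(ability.grad, p=2)           # gradient magnitude
        if abs(d_loss) < 1e-8 and d_theta < 1e-8 and grad_norm < 1e-8:
            break
    prev_ability = ability.detach().clone()
    prev_loss = loss.detach().clone()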
@@ -167,10 +165,10 @@ def run_one(self, run_spec: RunSpec):
         scenario_name = scenario.name
         try:
             difficulty_dataset = load_dataset("stair-lab/reeval-difficulty", split=scenario_name)
-            lookup = {row["request.prompt"]: row["z"] for row in difficulty_dataset}
-        except Exception as e:
+            lookup: dict[str, float] = {row["request.prompt"]: row["z"] for row in difficulty_dataset}
+        except Exception:
             hlog(f"WARNING: no available difficulty for {scenario_name}, skipping")
-            lookup = {}
+            return
 
         unasked_request_states: List[RequestState] = []
         for request_state in unasked_request_states_without_z:
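
Note on this hunk: the empty-lookup fallback is replaced by an early return, so scenarios without published difficulties are skipped outright. A hedged sketch of that load-or-skip pattern as a standalone helper (the function name and the print stand-in for hlog are assumptions, not HELM code):

from typing import Dict, Optional

from datasets import load_dataset

def load_difficulty_lookup(scenario_name: str) -> Optional[Dict[str, float]]:
    """Return a {prompt: difficulty} map for the scenario, or None if unavailable."""
    try:
        difficulty_dataset = load_dataset("stair-lab/reeval-difficulty", split=scenario_name)
        return {row["request.prompt"]: row["z"] for row in difficulty_dataset}
    except Exception:
        print(f"WARNING: no available difficulty for {scenario_name}, skipping")
        return None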
@@ -182,8 +180,17 @@ def run_one(self, run_spec: RunSpec):
             unasked_request_states.append(new_request_state)
 
         # Execute the requests in an reeval manner
-        model_ability = run_spec.adapter_spec.reeval_parameters.model_ability
-        scenario_metric_name = run_spec.adapter_spec.reeval_parameters.metric_name
+        # TODO: look for better way to fix the type-checker error
+        # model_ability = run_spec.adapter_spec.reeval_parameters.model_ability
+        # scenario_metric_name = run_spec.adapter_spec.reeval_parameters.metric_name
+        # max_samples = run_spec.adapter_spec.reeval_parameters.max_samples
+        if run_spec.adapter_spec.reeval_parameters:
+            if run_spec.adapter_spec.reeval_parameters.model_ability:
+                model_ability = run_spec.adapter_spec.reeval_parameters.model_ability
+            if run_spec.adapter_spec.reeval_parameters.metric_name:
+                scenario_metric_name = run_spec.adapter_spec.reeval_parameters.metric_name
+            if run_spec.adapter_spec.reeval_parameters.max_samples:
+                max_samples = run_spec.adapter_spec.reeval_parameters.max_samples
 
         asked_request_states: List[RequestState] = []
         stats: List[Stat] = []
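
Note on this hunk: reeval_parameters is declared Optional on AdapterSpec (see the first file), so the type checker rejects direct attribute access; the new guards narrow it before each field is read. A generic sketch of that narrowing pattern, with ReevalParams and Spec as hypothetical stand-ins for ReevalParameters and AdapterSpec:

from dataclasses import dataclass
from typing import Optional

@dataclass
class ReevalParams:  # hypothetical stand-in for ReevalParameters
    model_ability: float = 0.0
    metric_name: str = "exact_match"
    max_samples: int = 10

@dataclass
class Spec:  # hypothetical stand-in for AdapterSpec
    reeval_parameters: Optional[ReevalParams] = None

def read(spec: Spec) -> None:
    # spec.reeval_parameters.max_samples        # rejected: reeval_parameters may be None
    if spec.reeval_parameters:                  # narrows Optional[ReevalParams] to ReevalParams
        max_samples = spec.reeval_parameters.max_samples
        print(max_samples)

read(Spec(reeval_parameters=ReevalParams()))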
@@ -194,13 +201,24 @@ def run_one(self, run_spec: RunSpec):
             "instance_difficulties": [],
         }
 
-        for _ in tqdm(range(run_spec.adapter_spec.reeval_parameters.max_samples), desc="Reeval execution"):
+        for _ in tqdm(range(max_samples), desc="Reeval execution"):
             if not unasked_request_states:
                 break
 
-            selected_item = min(
-                unasked_request_states, key=lambda item: abs(item.instance.extra_data["difficulty"] - model_ability)
-            )
+            # TODO: look for better way to fix the type-checker error
+            # selected_item = min(
+            #     unasked_request_states, key=lambda item: abs(item.instance.extra_data["difficulty"] - model_ability)
+            # )
+            # unasked_request_states.remove(selected_item)
+            selected_item = None
+            min_diff = float("inf")
+            for item in unasked_request_states:
+                assert item.instance.extra_data
+                diff = abs(item.instance.extra_data["difficulty"] - model_ability)
+                if diff < min_diff:
+                    min_diff = diff
+                    selected_item = item
+            assert type(selected_item) is RequestState
             unasked_request_states.remove(selected_item)
 
             # Execute the request
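
Note on this hunk: the min(..., key=...) call is unrolled into an explicit loop so the checker can see the assert on extra_data before the subscript, and the final assert convinces it that selected_item is a RequestState. A standalone sketch of the same nearest-difficulty selection (the Item type and the sample values are assumptions for illustration):

from dataclasses import dataclass
from typing import List, Optional

@dataclass
class Item:  # hypothetical stand-in for a request state with a known difficulty
    prompt: str
    difficulty: float

def select_closest(items: List[Item], ability: float) -> Optional[Item]:
    """Return the item whose difficulty is closest to the current ability estimate."""
    selected: Optional[Item] = None
    min_diff = float("inf")
    for item in items:
        diff = abs(item.difficulty - ability)
        if diff < min_diff:
            min_diff = diff
            selected = item
    return selected

pool = [Item("a", -1.2), Item("b", 0.3), Item("c", 1.7)]
print(select_closest(pool, ability=0.5))  # Item(prompt='b', difficulty=0.3)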
@@ -256,7 +274,11 @@ def run_one(self, run_spec: RunSpec):
             scenario_metric_value = [s for s in per_instance_stat[0].stats if s.name.name == scenario_metric_name][
                 0
             ].mean
+
+            # TODO: look for better way to fix the type-checker error
+            assert scenario_metric_value
             reeval_trajectory["response_correctness"].append(scenario_metric_value)
+            assert selected_item.instance.extra_data
             reeval_trajectory["instance_difficulties"].append(selected_item.instance.extra_data["difficulty"])
 
             # Estimate the model ability