@@ -94,13 +94,11 @@ def closure():
     optim = torch.optim.LBFGS([ability], lr=0.1, max_iter=20, history_size=10, line_search_fn="strong_wolfe")
 
     for iteration in range(100):
-        if iteration > 0:
-            prev_ability = ability.clone()
-            prev_loss = loss.clone()
-
         loss = optim.step(closure)
 
         if iteration > 0:
+            prev_ability = ability.clone()
+            prev_loss = loss.clone()
             d_loss = prev_loss - loss
             d_theta = torch.norm(prev_ability - ability, p=2)
             grad_norm = torch.norm(optim.param_groups[0]["params"][0].grad, p=2)
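The hunk above fits a single ability parameter with LBFGS and stops once the change in loss, the change in the parameter, and the gradient norm have all settled. A minimal, self-contained sketch of that estimation loop follows; the response vector, the item difficulties, and the 1PL-style closure are assumptions made for illustration, since the diff shows only the optimizer and the convergence checks.

```python
import torch

# Illustrative inputs (assumed): per-item correctness and item difficulties.
responses = torch.tensor([1.0, 0.0, 1.0, 1.0])
difficulties = torch.tensor([-0.3, 0.8, 0.1, -1.2])
ability = torch.zeros(1, requires_grad=True)

optim = torch.optim.LBFGS([ability], lr=0.1, max_iter=20, history_size=10, line_search_fn="strong_wolfe")


def closure():
    optim.zero_grad()
    # Assumed 1PL (Rasch) negative log-likelihood; the real closure is not shown in the diff.
    prob_correct = torch.sigmoid(ability - difficulties)
    loss = torch.nn.functional.binary_cross_entropy(prob_correct, responses)
    loss.backward()
    return loss


prev_ability = prev_loss = None
for iteration in range(100):
    loss = optim.step(closure)

    if prev_loss is not None:
        d_loss = (prev_loss - loss).abs()
        d_theta = torch.norm(prev_ability - ability, p=2)
        grad_norm = torch.norm(optim.param_groups[0]["params"][0].grad, p=2)
        # Stop once loss, parameter, and gradient have all stabilized (tolerances are illustrative).
        if d_loss < 1e-6 and d_theta < 1e-6 and grad_norm < 1e-6:
            break

    # Snapshot at the end of the iteration so the next pass compares consecutive iterates.
    prev_ability = ability.detach().clone()
    prev_loss = loss.detach().clone()

print(float(ability))  # maximum-likelihood ability estimate
```

Saving the previous values at the end of each iteration keeps the deltas comparing consecutive iterates, which is what makes the three stopping criteria meaningful.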
@@ -167,10 +165,10 @@ def run_one(self, run_spec: RunSpec):
         scenario_name = scenario.name
         try:
             difficulty_dataset = load_dataset("stair-lab/reeval-difficulty", split=scenario_name)
-            lookup = {row["request.prompt"]: row["z"] for row in difficulty_dataset}
-        except Exception as e:
+            lookup: dict[str, float] = {row["request.prompt"]: row["z"] for row in difficulty_dataset}
+        except Exception:
             hlog(f"WARNING: no available difficulty for {scenario_name}, skipping")
-            lookup = {}
+            return
 
         unasked_request_states: List[RequestState] = []
         for request_state in unasked_request_states_without_z:
@@ -182,8 +180,17 @@ def run_one(self, run_spec: RunSpec):
             unasked_request_states.append(new_request_state)
 
         # Execute the requests in an reeval manner
-        model_ability = run_spec.adapter_spec.reeval_parameters.model_ability
-        scenario_metric_name = run_spec.adapter_spec.reeval_parameters.metric_name
+        # TODO: look for better way to fix the type-checker error
+        # model_ability = run_spec.adapter_spec.reeval_parameters.model_ability
+        # scenario_metric_name = run_spec.adapter_spec.reeval_parameters.metric_name
+        # max_samples = run_spec.adapter_spec.reeval_parameters.max_samples
+        if run_spec.adapter_spec.reeval_parameters:
+            if run_spec.adapter_spec.reeval_parameters.model_ability:
+                model_ability = run_spec.adapter_spec.reeval_parameters.model_ability
+            if run_spec.adapter_spec.reeval_parameters.metric_name:
+                scenario_metric_name = run_spec.adapter_spec.reeval_parameters.metric_name
+            if run_spec.adapter_spec.reeval_parameters.max_samples:
+                max_samples = run_spec.adapter_spec.reeval_parameters.max_samples
 
         asked_request_states: List[RequestState] = []
         stats: List[Stat] = []
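The TODO in this hunk works around the type checker flagging `reeval_parameters` and its fields as possibly `None`. One alternative, sketched below with hypothetical stand-in dataclasses rather than HELM's real `RunSpec`/`AdapterSpec`, is to validate and narrow all the optional fields in one place, which also guarantees that `model_ability`, `scenario_metric_name`, and `max_samples` can never be left unbound.

```python
from dataclasses import dataclass
from typing import Optional, Tuple


# Hypothetical stand-ins for illustration only; the real classes live in HELM.
@dataclass(frozen=True)
class ReevalParameters:
    model_ability: Optional[float] = None
    metric_name: Optional[str] = None
    max_samples: Optional[int] = None


@dataclass(frozen=True)
class AdapterSpec:
    reeval_parameters: Optional[ReevalParameters] = None


def read_reeval_parameters(adapter_spec: AdapterSpec) -> Tuple[float, str, int]:
    """Validate the optional reeval fields once and return them with non-Optional types."""
    params = adapter_spec.reeval_parameters
    if params is None:
        raise ValueError("adapter_spec.reeval_parameters is required for reeval runs")
    if params.model_ability is None or params.metric_name is None or params.max_samples is None:
        raise ValueError("reeval_parameters must set model_ability, metric_name, and max_samples")
    return params.model_ability, params.metric_name, params.max_samples


model_ability, scenario_metric_name, max_samples = read_reeval_parameters(
    AdapterSpec(ReevalParameters(model_ability=0.0, metric_name="exact_match", max_samples=50))
)
```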
@@ -194,13 +201,24 @@ def run_one(self, run_spec: RunSpec):
             "instance_difficulties": [],
         }
 
-        for _ in tqdm(range(run_spec.adapter_spec.reeval_parameters.max_samples), desc="Reeval execution"):
+        for _ in tqdm(range(max_samples), desc="Reeval execution"):
             if not unasked_request_states:
                 break
 
-            selected_item = min(
-                unasked_request_states, key=lambda item: abs(item.instance.extra_data["difficulty"] - model_ability)
-            )
+            # TODO: look for better way to fix the type-checker error
+            # selected_item = min(
+            #     unasked_request_states, key=lambda item: abs(item.instance.extra_data["difficulty"] - model_ability)
+            # )
+            # unasked_request_states.remove(selected_item)
+            selected_item = None
+            min_diff = float("inf")
+            for item in unasked_request_states:
+                assert item.instance.extra_data
+                diff = abs(item.instance.extra_data["difficulty"] - model_ability)
+                if diff < min_diff:
+                    min_diff = diff
+                    selected_item = item
+            assert type(selected_item) is RequestState
             unasked_request_states.remove(selected_item)
 
             # Execute the request
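The explicit loop added in this hunk replaces `min(..., key=...)` so the type checker can see that `extra_data` is present and that a `RequestState` was actually found. The same selection rule, picking the unasked item whose difficulty is closest to the current ability estimate, can also be factored into a small helper; the sketch below uses a stripped-down stand-in class rather than HELM's real `RequestState`.

```python
from typing import List, Optional


class Item:
    """Stand-in for a RequestState that only carries the field the selection uses."""

    def __init__(self, difficulty: float) -> None:
        self.difficulty = difficulty


def select_closest(unasked: List[Item], model_ability: float) -> Item:
    """Return the unasked item whose difficulty is nearest the current ability estimate."""
    selected: Optional[Item] = None
    min_diff = float("inf")
    for item in unasked:
        diff = abs(item.difficulty - model_ability)
        if diff < min_diff:
            min_diff = diff
            selected = item
    # Equivalent to min(unasked, key=lambda i: abs(i.difficulty - model_ability)) on a non-empty pool.
    assert selected is not None, "the pool of unasked items must not be empty"
    return selected


pool = [Item(-1.0), Item(0.2), Item(1.5)]
chosen = select_closest(pool, model_ability=0.4)
pool.remove(chosen)  # the adaptive loop then executes the chosen request
print(chosen.difficulty)  # -> 0.2
```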
@@ -256,7 +274,11 @@ def run_one(self, run_spec: RunSpec):
             scenario_metric_value = [s for s in per_instance_stat[0].stats if s.name.name == scenario_metric_name][
                 0
             ].mean
+
+            # TODO: look for better way to fix the type-checker error
+            assert scenario_metric_value
             reeval_trajectory["response_correctness"].append(scenario_metric_value)
+            assert selected_item.instance.extra_data
             reeval_trajectory["instance_difficulties"].append(selected_item.instance.extra_data["difficulty"])
 
             # Estimate the model ability