37
37
from functools import total_ordering
38
38
from io import BytesIO
39
39
from os .path import commonprefix
40
+ from os import cpu_count as os_cpu_count
40
41
from typing import List , Optional
41
42
42
43
import ijson
@@ -1320,29 +1321,47 @@ def calculate_radial_search_recall(predictions, neighbors, enable_top_1_recall=F
1320
1321
1321
1322
return correct / min_num_of_results
1322
1323
1324
+ def _set_initial_recall_values (params : dict , result : dict ) -> None :
1325
+ # Add recall@k and recall@1 to the initial result only if k is present in the params and calculate_recall is true
1326
+ if "k" in params :
1327
+ result .update ({
1328
+ "recall@k" : 0 ,
1329
+ "recall@1" : 0
1330
+ })
1331
+ # Add recall@max_distance and recall@max_distance_1 to the initial result only if max_distance is present in the params
1332
+ elif "max_distance" in params :
1333
+ result .update ({
1334
+ "recall@max_distance" : 0 ,
1335
+ "recall@max_distance_1" : 0
1336
+ })
1337
+ # Add recall@min_score and recall@min_score_1 to the initial result only if min_score is present in the params
1338
+ elif "min_score" in params :
1339
+ result .update ({
1340
+ "recall@min_score" : 0 ,
1341
+ "recall@min_score_1" : 0
1342
+ })
1343
+
1344
+ def _get_should_calculate_recall (params : dict ) -> bool :
1345
+ num_clients = params .get ("num_clients" , 0 )
1346
+ if num_clients == 0 :
1347
+ self .logger .debug ("Expected num_clients to be specified but was not." )
1348
+ cpu_count = os_cpu_count ()
1349
+ if cpu_count < num_clients :
1350
+ self .logger .warning ("Number of clients, %s, specified is greater than the number of CPUs, %s, available." \
1351
+ "This will lead to unperformant context switching on load generation host. Performance " \
1352
+ "metrics may not be accurate. Skipping recall calculation." , num_clients , cpu_count )
1353
+ return False
1354
+ return params .get ("calculate-recall" , True )
1355
+
1323
1356
result = {
1324
1357
"weight" : 1 ,
1325
1358
"unit" : "ops" ,
1326
1359
"success" : True ,
1327
1360
}
1328
- # Add recall@k and recall@1 to the initial result only if k is present in the params
1329
- if "k" in params :
1330
- result .update ({
1331
- "recall@k" : 0 ,
1332
- "recall@1" : 0
1333
- })
1334
- # Add recall@max_distance and recall@max_distance_1 to the initial result only if max_distance is present in the params
1335
- elif "max_distance" in params :
1336
- result .update ({
1337
- "recall@max_distance" : 0 ,
1338
- "recall@max_distance_1" : 0
1339
- })
1340
- # Add recall@min_score and recall@min_score_1 to the initial result only if min_score is present in the params
1341
- elif "min_score" in params :
1342
- result .update ({
1343
- "recall@min_score" : 0 ,
1344
- "recall@min_score_1" : 0
1345
- })
1361
+ # Decide up front whether recall should be calculated; this depends on num_clients.
1362
+ should_calculate_recall = _get_should_calculate_recall (params )
1363
+ if should_calculate_recall :
1364
+ _set_initial_recall_values (params , result )
1346
1365
1347
1366
doc_type = params .get ("type" )
1348
1367
response = await self ._raw_search (opensearch , doc_type , index , body , request_params , headers = headers )
@@ -1366,6 +1385,10 @@ def calculate_radial_search_recall(predictions, neighbors, enable_top_1_recall=F
1366
1385
if _is_empty_search_results (response_json ):
1367
1386
self .logger .info ("Vector search query returned no results." )
1368
1387
return result
1388
+
1389
+ if not should_calculate_recall :
1390
+ return result
1391
+
1369
1392
id_field = parse_string_parameter ("id-field-name" , params , "_id" )
1370
1393
candidates = []
1371
1394
for hit in response_json ['hits' ]['hits' ]:
0 commit comments