40
40
from osbenchmark import exceptions
41
41
from osbenchmark .utils import io
42
42
from osbenchmark .utils .dataset import DataSet , get_data_set , Context
43
- from osbenchmark .utils .parse import parse_string_parameter , parse_int_parameter , parse_bool_parameter
43
+ from osbenchmark .utils .parse import parse_string_parameter , parse_int_parameter
44
44
from osbenchmark .workload import workload
45
45
46
46
__PARAM_SOURCES_BY_OP = {}
@@ -1127,9 +1127,9 @@ def _update_body_params(self, vector):
1127
1127
"[%s] param from body will be replaced with vector search query." , self .PARAMS_NAME_QUERY )
1128
1128
1129
1129
self .logger .info ("Here, we have query_params: %s " , self .query_params )
1130
- efficient_filter = self .query_params .get (self .PARAMS_NAME_FILTER )
1131
1130
filter_type = self .query_params .get (self .PARAMS_NAME_FILTER_TYPE )
1132
1131
filter_body = self .query_params .get (self .PARAMS_NAME_FILTER_BODY )
1132
+ efficient_filter = filter_body if filter_type == "efficient" else None
1133
1133
1134
1134
# override query params with vector search query
1135
1135
body_params [self .PARAMS_NAME_QUERY ] = self ._build_vector_search_query_body (vector , efficient_filter , filter_type , filter_body )
@@ -1262,7 +1262,7 @@ def __init__(self, workload, params, **kwargs):
1262
1262
self .id_field_name : str = parse_string_parameter (
1263
1263
self .PARAMS_NAME_ID_FIELD_NAME , params , self .DEFAULT_ID_FIELD_NAME
1264
1264
)
1265
- self .has_attributes = parse_bool_parameter ( "has_attributes " , params , False )
1265
+ self .filter_attributes : List [ Any ] = params . get ( "filter_attributes " , [] )
1266
1266
1267
1267
self .action_buffer = None
1268
1268
self .num_nested_vectors = 10
@@ -1294,7 +1294,7 @@ def partition(self, partition_index, total_partitions):
1294
1294
)
1295
1295
partition .parent_data_set .seek (partition .offset )
1296
1296
1297
- if self .has_attributes :
1297
+ if self .filter_attributes :
1298
1298
partition .attributes_data_set = get_data_set (
1299
1299
self .parent_data_set_format , self .parent_data_set_path , Context .ATTRIBUTES
1300
1300
)
@@ -1317,8 +1317,10 @@ def bulk_transform_add_attributes(self, partition: np.ndarray, action, attribute
1317
1317
partition .tolist (), attributes .tolist (), range (self .current , self .current + len (partition ))
1318
1318
):
1319
1319
row = {self .field_name : vec }
1320
- for idx , attribute_name , attribute_type in zip (range (3 ), ["taste" , "color" , "age" ], [str , str , int ]):
1321
- row .update ({attribute_name : attribute_type (attribute_list [idx ])})
1320
+ for idx , attribute_name in zip (range (len (self .filter_attributes )), self .filter_attributes ):
1321
+ attribute = attribute_list [idx ].decode ()
1322
+ if attribute != "None" :
1323
+ row .update ({attribute_name : attribute })
1322
1324
if add_id_field_to_body :
1323
1325
row .update ({self .id_field_name : identifier })
1324
1326
bulk_contents .append (row )
@@ -1369,11 +1371,11 @@ def bulk_transform(
1369
1371
An array of transformed vectors in bulk format.
1370
1372
"""
1371
1373
1372
- if not self .is_nested and not self .has_attributes :
1374
+ if not self .is_nested and not self .filter_attributes :
1373
1375
return self .bulk_transform_non_nested (partition , action )
1374
1376
1375
1377
# TODO: Assumption: we won't add attributes if we're also doing a nested query.
1376
- if self .has_attributes :
1378
+ if self .filter_attributes :
1377
1379
return self .bulk_transform_add_attributes (partition , action , attributes )
1378
1380
actions = []
1379
1381
@@ -1457,7 +1459,7 @@ def action(id_field_name, doc_id):
1457
1459
else :
1458
1460
parent_ids = None
1459
1461
1460
- if self .has_attributes :
1462
+ if self .filter_attributes :
1461
1463
attributes = self .attributes_data_set .read (bulk_size )
1462
1464
else :
1463
1465
attributes = None
0 commit comments