Skip to content

Commit e3821d9

Browse files
committed
Address Vijay offline feedback
Signed-off-by: Finn Roblin <finnrobl@amazon.com>
1 parent 7d99d80 commit e3821d9

File tree

3 files changed

+20
-66
lines changed

3 files changed

+20
-66
lines changed

osbenchmark/utils/dataset.py

+8
Original file line number | Diff line number | Diff line change
@@ -24,6 +24,8 @@ class Context(Enum):
2424
INDEX = 1
2525
QUERY = 2
2626
NEIGHBORS = 3
27+
MAX_DISTANCE_NEIGHBORS = 4
28+
MIN_SCORE_NEIGHBORS = 5
2729
ATTRIBUTES = 7
2830

2931

@@ -142,6 +144,12 @@ def parse_context(context: Context) -> str:
142144
if context == Context.QUERY:
143145
return "test"
144146

147+
if context == Context.MAX_DISTANCE_NEIGHBORS:
148+
return "max_distance_neighbors"
149+
150+
if context == Context.MIN_SCORE_NEIGHBORS:
151+
return "min_score_neighbors"
152+
145153
if context == Context.ATTRIBUTES:
146154
return "attributes"
147155

osbenchmark/workload/params.py

+11 -9
Original file line number | Diff line number | Diff line change
@@ -40,7 +40,7 @@
4040
from osbenchmark import exceptions
4141
from osbenchmark.utils import io
4242
from osbenchmark.utils.dataset import DataSet, get_data_set, Context
43-
from osbenchmark.utils.parse import parse_string_parameter, parse_int_parameter, parse_bool_parameter
43+
from osbenchmark.utils.parse import parse_string_parameter, parse_int_parameter
4444
from osbenchmark.workload import workload
4545

4646
__PARAM_SOURCES_BY_OP = {}
@@ -1127,9 +1127,9 @@ def _update_body_params(self, vector):
11271127
"[%s] param from body will be replaced with vector search query.", self.PARAMS_NAME_QUERY)
11281128

11291129
self.logger.info("Here, we have query_params: %s ", self.query_params)
1130-
efficient_filter=self.query_params.get(self.PARAMS_NAME_FILTER)
11311130
filter_type=self.query_params.get(self.PARAMS_NAME_FILTER_TYPE)
11321131
filter_body=self.query_params.get(self.PARAMS_NAME_FILTER_BODY)
1132+
efficient_filter = filter_body if filter_type == "efficient" else None
11331133

11341134
# override query params with vector search query
11351135
body_params[self.PARAMS_NAME_QUERY] = self._build_vector_search_query_body(vector, efficient_filter, filter_type, filter_body)
@@ -1262,7 +1262,7 @@ def __init__(self, workload, params, **kwargs):
12621262
self.id_field_name: str = parse_string_parameter(
12631263
self.PARAMS_NAME_ID_FIELD_NAME, params, self.DEFAULT_ID_FIELD_NAME
12641264
)
1265-
self.has_attributes = parse_bool_parameter("has_attributes", params, False)
1265+
self.filter_attributes: List[Any] = params.get("filter_attributes", [])
12661266

12671267
self.action_buffer = None
12681268
self.num_nested_vectors = 10
@@ -1294,7 +1294,7 @@ def partition(self, partition_index, total_partitions):
12941294
)
12951295
partition.parent_data_set.seek(partition.offset)
12961296

1297-
if self.has_attributes:
1297+
if self.filter_attributes:
12981298
partition.attributes_data_set = get_data_set(
12991299
self.parent_data_set_format, self.parent_data_set_path, Context.ATTRIBUTES
13001300
)
@@ -1317,8 +1317,10 @@ def bulk_transform_add_attributes(self, partition: np.ndarray, action, attribute
13171317
partition.tolist(), attributes.tolist(), range(self.current, self.current + len(partition))
13181318
):
13191319
row = {self.field_name: vec}
1320-
for idx, attribute_name, attribute_type in zip(range(3), ["taste", "color", "age"], [str, str, int]):
1321-
row.update({attribute_name : attribute_type(attribute_list[idx])})
1320+
for idx, attribute_name in zip(range(len(self.filter_attributes)), self.filter_attributes):
1321+
attribute = attribute_list[idx].decode()
1322+
if attribute != "None":
1323+
row.update({attribute_name : attribute})
13221324
if add_id_field_to_body:
13231325
row.update({self.id_field_name: identifier})
13241326
bulk_contents.append(row)
@@ -1369,11 +1371,11 @@ def bulk_transform(
13691371
An array of transformed vectors in bulk format.
13701372
"""
13711373

1372-
if not self.is_nested and not self.has_attributes:
1374+
if not self.is_nested and not self.filter_attributes:
13731375
return self.bulk_transform_non_nested(partition, action)
13741376

13751377
# TODO: Assumption: we won't add attributes if we're also doing a nested query.
1376-
if self.has_attributes:
1378+
if self.filter_attributes:
13771379
return self.bulk_transform_add_attributes(partition, action, attributes)
13781380
actions = []
13791381

@@ -1457,7 +1459,7 @@ def action(id_field_name, doc_id):
14571459
else:
14581460
parent_ids = None
14591461

1460-
if self.has_attributes:
1462+
if self.filter_attributes:
14611463
attributes = self.attributes_data_set.read(bulk_size)
14621464
else:
14631465
attributes = None

tests/workload/params_test.py

+1 -57
Original file line number | Diff line number | Diff line change
@@ -2900,62 +2900,6 @@ def test_params_default(self):
29002900
with self.assertRaises(StopIteration):
29012901
query_param_source_partition.params()
29022902

2903-
def test_params_custom_body(self):
2904-
# Create a data set
2905-
k = 12
2906-
data_set_path = create_data_set(
2907-
self.DEFAULT_NUM_VECTORS,
2908-
self.DEFAULT_DIMENSION,
2909-
self.DEFAULT_TYPE,
2910-
Context.QUERY,
2911-
self.data_set_dir
2912-
)
2913-
neighbors_data_set_path = create_data_set(
2914-
self.DEFAULT_NUM_VECTORS,
2915-
self.DEFAULT_DIMENSION,
2916-
self.DEFAULT_TYPE,
2917-
Context.NEIGHBORS,
2918-
self.data_set_dir
2919-
)
2920-
filter_body = {
2921-
"key": "value"
2922-
}
2923-
2924-
# Create a QueryVectorsFromDataSetParamSource with relevant params
2925-
test_param_source_params = {
2926-
"field": self.DEFAULT_FIELD_NAME,
2927-
"data_set_format": self.DEFAULT_TYPE,
2928-
"data_set_path": data_set_path,
2929-
"neighbors_data_set_path": neighbors_data_set_path,
2930-
"k": k,
2931-
"filter": filter_body,
2932-
}
2933-
query_param_source = VectorSearchPartitionParamSource(
2934-
workload.Workload(name="unit-test"),
2935-
test_param_source_params, {
2936-
"index": self.DEFAULT_INDEX_NAME,
2937-
"request-params": {},
2938-
"body": {
2939-
"size": 100,
2940-
}
2941-
}
2942-
)
2943-
query_param_source_partition = query_param_source.partition(0, 1)
2944-
2945-
# Check each
2946-
for _ in range(DEFAULT_NUM_VECTORS):
2947-
self._check_params(
2948-
query_param_source_partition.params(),
2949-
self.DEFAULT_FIELD_NAME,
2950-
self.DEFAULT_DIMENSION,
2951-
k,
2952-
100,
2953-
filter_body,
2954-
)
2955-
2956-
# Assert last call creates stop iteration
2957-
with self.assertRaises(StopIteration):
2958-
query_param_source_partition.params()
29592903
def test_post_filter(self):
29602904
# Create a data set
29612905
k = 12
@@ -3434,7 +3378,7 @@ def test_params_efficient_filter(
34343378
"data_set_path": data_set_path,
34353379
"bulk_size": bulk_size,
34363380
"id-field-name": self.DEFAULT_ID_FIELD_NAME,
3437-
"has_attributes": True
3381+
"filter_attributes": self.ATTRIBUTES_LIST
34383382
}
34393383
bulk_param_source = BulkVectorsFromDataSetParamSource(
34403384
workload.Workload(name="unit-test"), test_param_source_params

0 commit comments

Comments
 (0)