@@ -38,8 +38,9 @@ def __init__(self, step_config: StepConfig):
38
38
default_port = 9200 if self .endpoint == 'localhost' else 80
39
39
self .port = parse_int_param ('port' , step_config .config ,
40
40
step_config .implicit_config , default_port )
41
+ self .timeout = parse_int_param ('timeout' , step_config .config , {}, 60 )
41
42
self .opensearch = get_opensearch_client (str (self .endpoint ),
42
- int (self .port ))
43
+ int (self .port ), int ( self . timeout ) )
43
44
44
45
45
46
class CreateIndexStep (OpenSearchStep ):
@@ -163,6 +164,25 @@ def _get_measures(self) -> List[str]:
163
164
return ['took' ]
164
165
165
166
167
+ class WarmupStep (OpenSearchStep ):
168
+ """See base class."""
169
+
170
+ label = 'warmup_operation'
171
+
172
+ def __init__ (self , step_config : StepConfig ):
173
+ super ().__init__ (step_config )
174
+ self .index_name = parse_string_param ('index_name' , step_config .config , {},
175
+ None )
176
+
177
+ def _action (self ):
178
+ """Performs warmup operation on an index."""
179
+ warmup_operation (self .endpoint , self .port , self .index_name )
180
+ return {}
181
+
182
+ def _get_measures (self ) -> List [str ]:
183
+ return ['took' ]
184
+
185
+
166
186
class TrainModelStep (OpenSearchStep ):
167
187
"""See base class."""
168
188
@@ -739,9 +759,6 @@ def get_body(self, vec):
739
759
}
740
760
}
741
761
742
- def get_exclude_fields (self ):
743
- return ['nested_field.' + self .field_name ]
744
-
745
762
class GetStatsStep (OpenSearchStep ):
746
763
"""See base class."""
747
764
@@ -841,6 +858,23 @@ def delete_model(endpoint, port, model_id):
841
858
return response .json ()
842
859
843
860
861
+ def warmup_operation (endpoint , port , index ):
862
+ """
863
+ Performs warmup operation on index to load native library files
864
+ of that index to reduce query latencies.
865
+ Args:
866
+ endpoint: Endpoint OpenSearch is running on
867
+ port: Port OpenSearch is running on
868
+ index: index name
869
+ Returns:
870
+ number of shards the plugin succeeded and failed to warm up.
871
+ """
872
+ response = requests .get ('http://' + endpoint + ':' + str (port ) +
873
+ '/_plugins/_knn/warmup/' + index ,
874
+ headers = {'content-type' : 'application/json' })
875
+ return response .json ()
876
+
877
+
844
878
def get_opensearch_client (endpoint : str , port : int , timeout = 60 ):
845
879
"""
846
880
Get an opensearch client from an endpoint and port
@@ -947,7 +981,7 @@ def query_index(opensearch: OpenSearch, index_name: str, body: dict,
947
981
948
982
949
983
def bulk_index (opensearch : OpenSearch , index_name : str , body : List ):
950
- return opensearch .bulk (index = index_name , body = body , timeout = '5m' )
984
+ return opensearch .bulk (index = index_name , body = body )
951
985
952
986
def get_segment_stats (opensearch : OpenSearch , index_name : str ):
953
987
return opensearch .indices .segments (index = index_name )
0 commit comments