Skip to content

Commit bffa32a

Browse files
authored
add tutuorial for cross-encoder model on sagemaker (opensearch-project#2607)
* add tutuorial for cross-encoder model on sagemaker Signed-off-by: Yaliang Wu <ylwu@amazon.com> * add connector helper doc link Signed-off-by: Yaliang Wu <ylwu@amazon.com> * remvoe title field Signed-off-by: Yaliang Wu <ylwu@amazon.com> * address commnets Signed-off-by: Yaliang Wu <ylwu@amazon.com> * use a better input format to invoke model Signed-off-by: Yaliang Wu <ylwu@amazon.com> --------- Signed-off-by: Yaliang Wu <ylwu@amazon.com>
1 parent 4a22eb8 commit bffa32a

File tree

1 file changed

+381
-0
lines changed

1 file changed

+381
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,381 @@
1+
# Topic
2+
3+
[Reranking pipeline](https://opensearch.org/docs/latest/search-plugins/search-relevance/reranking-search-results/) is a feature released in OpenSearch 2.12.
4+
It can rerank search results, providing a relevance score for each document in the search results with respect to the search query.
5+
The relevance score is calculated by a cross-encoder model.
6+
7+
This tutorial explains how to use the [Hugging Face cross-encoder/ms-marco-MiniLM-L-6-v2](https://huggingface.co/cross-encoder/ms-marco-MiniLM-L-6-v2) model in a reranking pipeline.
8+
9+
Note: Replace the placeholders that start with `your_` with your own values.
10+
11+
# Steps
12+
13+
## 0. Deploy the model on Amazon SageMaker
14+
Use the following code to deploy the model on Amazon SageMaker.
15+
You can find all supported instance types and their prices on the [Amazon SageMaker pricing page](https://aws.amazon.com/sagemaker/pricing/). Using a GPU instance is recommended for better performance.
16+
```python
17+
import sagemaker
18+
import boto3
19+
from sagemaker.huggingface import HuggingFaceModel
20+
sess = sagemaker.Session()
21+
role = sagemaker.get_execution_role()
22+
23+
hub = {
24+
'HF_MODEL_ID':'cross-encoder/ms-marco-MiniLM-L-6-v2',
25+
'HF_TASK':'text-classification'
26+
}
27+
huggingface_model = HuggingFaceModel(
28+
transformers_version='4.37.0',
29+
pytorch_version='2.1.0',
30+
py_version='py310',
31+
env=hub,
32+
role=role,
33+
)
34+
predictor = huggingface_model.deploy(
35+
initial_instance_count=1, # number of instances
36+
instance_type='ml.m5.xlarge' # ec2 instance type
37+
)
38+
```
39+
Note the model inference endpoint; you'll use it to create a connector in the next step.
40+
41+
## 1. Create a connector and register the model
42+
43+
To create a connector for the model, send the following request. If you are using self-managed OpenSearch, supply your AWS credentials:
44+
```json
45+
POST /_plugins/_ml/connectors/_create
46+
{
47+
"name": "Sagemaker cross-encoder model",
48+
"description": "Test connector for Sagemaker cross-encoder model",
49+
"version": 1,
50+
"protocol": "aws_sigv4",
51+
"credential": {
52+
"access_key": "your_access_key",
53+
"secret_key": "your_secret_key",
54+
"session_token": "your_session_token"
55+
},
56+
"parameters": {
57+
"region": "your_sagemaker_model_region_like_us-west-2",
58+
"service_name": "sagemaker"
59+
},
60+
"actions": [
61+
{
62+
"action_type": "predict",
63+
"method": "POST",
64+
"url": "your_sagemaker_model_inference_endpoint_created_in_last_step",
65+
"headers": {
66+
"content-type": "application/json"
67+
},
68+
"request_body": "{ \"inputs\": ${parameters.inputs} }",
69+
"pre_process_function": "\n String escape(def input) { \n if (input.contains(\"\\\\\")) {\n input = input.replace(\"\\\\\", \"\\\\\\\\\");\n }\n if (input.contains(\"\\\"\")) {\n input = input.replace(\"\\\"\", \"\\\\\\\"\");\n }\n if (input.contains('\r')) {\n input = input = input.replace('\r', '\\\\r');\n }\n if (input.contains(\"\\\\t\")) {\n input = input.replace(\"\\\\t\", \"\\\\\\\\\\\\t\");\n }\n if (input.contains('\n')) {\n input = input.replace('\n', '\\\\n');\n }\n if (input.contains('\b')) {\n input = input.replace('\b', '\\\\b');\n }\n if (input.contains('\f')) {\n input = input.replace('\f', '\\\\f');\n }\n return input;\n }\n\n String query = params.query_text;\n StringBuilder builder = new StringBuilder('[');\n \n for (int i=0; i<params.text_docs.length; i ++) {\n builder.append('{\"text\":\"');\n builder.append(escape(query));\n builder.append('\", \"text_pair\":\"');\n builder.append(escape(params.text_docs[i]));\n builder.append('\"}');\n if (i<params.text_docs.length - 1) {\n builder.append(',');\n }\n }\n builder.append(']');\n \n def parameters = '{ \"inputs\": ' + builder + ' }';\n return '{\"parameters\": ' + parameters + '}';\n ",
70+
"post_process_function": "\n \n def dataType = \"FLOAT32\";\n \n \n if (params.result == null)\n {\n return 'no result generated';\n //return params.response;\n }\n def outputs = params.result;\n \n \n def resultBuilder = new StringBuilder('[ ');\n for (int i=0; i<outputs.length; i++) {\n resultBuilder.append(' {\"name\": \"similarity\", \"data_type\": \"FLOAT32\", \"shape\": [1],');\n //resultBuilder.append('{\"name\": \"similarity\"}');\n \n resultBuilder.append('\"data\": [');\n resultBuilder.append(outputs[i].score);\n resultBuilder.append(']}');\n if (i<outputs.length - 1) {\n resultBuilder.append(',');\n }\n }\n resultBuilder.append(']');\n \n return resultBuilder.toString();\n "
71+
}
72+
]
73+
}
74+
```
75+
76+
If you are using Amazon OpenSearch Service, you can provide an IAM role ARN that allows access to the SageMaker model inference endpoint. For more information, see [AWS documentation](https://docs.aws.amazon.com/opensearch-service/latest/developerguide/ml-amazon-connector.html), [this tutorial](../aws/semantic_search_with_sagemaker_embedding_model.md), and [this connector helper notebook](../aws/AIConnectorHelper.ipynb):
77+
```json
78+
POST /_plugins/_ml/connectors/_create
79+
{
80+
"name": "Sagemaker cross-encoder model",
81+
"description": "Test connector for Sagemaker cross-encoder model",
82+
"version": 1,
83+
"protocol": "aws_sigv4",
84+
"credential": {
85+
"roleArn": "your_role_arn_which_allows_access_to_sagemaker_model_inference_endpoint"
86+
},
87+
"parameters": {
88+
"region": "your_sagemaker_model_region_like_us-west-2",
89+
"service_name": "sagemaker"
90+
},
91+
"actions": [
92+
{
93+
"action_type": "predict",
94+
"method": "POST",
95+
"url": "your_sagemaker_model_inference_endpoint_created_in_last_step",
96+
"headers": {
97+
"content-type": "application/json"
98+
},
99+
"request_body": "{ \"inputs\": ${parameters.inputs} }",
100+
"pre_process_function": "\n String escape(def input) { \n if (input.contains(\"\\\\\")) {\n input = input.replace(\"\\\\\", \"\\\\\\\\\");\n }\n if (input.contains(\"\\\"\")) {\n input = input.replace(\"\\\"\", \"\\\\\\\"\");\n }\n if (input.contains('\r')) {\n input = input = input.replace('\r', '\\\\r');\n }\n if (input.contains(\"\\\\t\")) {\n input = input.replace(\"\\\\t\", \"\\\\\\\\\\\\t\");\n }\n if (input.contains('\n')) {\n input = input.replace('\n', '\\\\n');\n }\n if (input.contains('\b')) {\n input = input.replace('\b', '\\\\b');\n }\n if (input.contains('\f')) {\n input = input.replace('\f', '\\\\f');\n }\n return input;\n }\n\n String query = params.query_text;\n StringBuilder builder = new StringBuilder('[');\n \n for (int i=0; i<params.text_docs.length; i ++) {\n builder.append('{\"text\":\"');\n builder.append(escape(query));\n builder.append('\", \"text_pair\":\"');\n builder.append(escape(params.text_docs[i]));\n builder.append('\"}');\n if (i<params.text_docs.length - 1) {\n builder.append(',');\n }\n }\n builder.append(']');\n \n def parameters = '{ \"inputs\": ' + builder + ' }';\n return '{\"parameters\": ' + parameters + '}';\n ",
101+
"post_process_function": "\n \n def dataType = \"FLOAT32\";\n \n \n if (params.result == null)\n {\n return 'no result generated';\n //return params.response;\n }\n def outputs = params.result;\n \n \n def resultBuilder = new StringBuilder('[ ');\n for (int i=0; i<outputs.length; i++) {\n resultBuilder.append(' {\"name\": \"similarity\", \"data_type\": \"FLOAT32\", \"shape\": [1],');\n //resultBuilder.append('{\"name\": \"similarity\"}');\n \n resultBuilder.append('\"data\": [');\n resultBuilder.append(outputs[i].score);\n resultBuilder.append(']}');\n if (i<outputs.length - 1) {\n resultBuilder.append(',');\n }\n }\n resultBuilder.append(']');\n \n return resultBuilder.toString();\n "
102+
}
103+
]
104+
}
105+
```
106+
107+
Use the connector ID from the response to register and deploy the model:
108+
```json
109+
POST /_plugins/_ml/models/_register?deploy=true
110+
{
111+
"name": "Sagemaker Cross-Encoder model",
112+
"function_name": "remote",
113+
"description": "test rerank model",
114+
"connector_id": "your_connector_id"
115+
}
116+
```
117+
Note the model ID in the response; you'll use it in the following steps.
118+
119+
Test the model by using the Predict API:
120+
```json
121+
POST _plugins/_ml/models/your_model_id/_predict
122+
{
123+
"parameters": {
124+
"inputs": [
125+
{
126+
"text": "I like you",
127+
"text_pair": "I hate you"
128+
},
129+
{
130+
"text": "I like you",
131+
"text_pair": "I love you"
132+
}
133+
]
134+
}
135+
}
136+
```
137+
138+
Each item in the `inputs` array is an object with two fields: `text`, which contains the query text, and `text_pair`, which contains the document text to be scored against the query.
139+
140+
Alternatively, you can test the model as follows:
141+
```json
142+
POST _plugins/_ml/_predict/text_similarity/your_model_id
143+
{
144+
"query_text": "I like you",
145+
"text_docs": ["I hate you", "I love you"]
146+
}
147+
```
148+
The connector `pre_process_function` transforms the input into the format required by the `inputs` parameter shown previously.
149+
150+
By default, the SageMaker model output has the following format:
151+
```json
152+
[
153+
{
154+
"label": "LABEL_0",
155+
"score": 0.054037678986787796
156+
},
157+
{
158+
"label": "LABEL_0",
159+
"score": 0.5877784490585327
160+
}
161+
]
162+
```
163+
The connector `post_process_function` transforms the model's output into a format that the [Reranker processor](https://opensearch.org/docs/latest/search-plugins/search-pipelines/rerank-processor/) can interpret. This adapted format is as follows:
164+
```json
165+
{
166+
"inference_results": [
167+
{
168+
"output": [
169+
{
170+
"name": "similarity",
171+
"data_type": "FLOAT32",
172+
"shape": [
173+
1
174+
],
175+
"data": [
176+
0.054037678986787796
177+
]
178+
},
179+
{
180+
"name": "similarity",
181+
"data_type": "FLOAT32",
182+
"shape": [
183+
1
184+
],
185+
"data": [
186+
0.5877784490585327
187+
]
188+
}
189+
],
190+
"status_code": 200
191+
}
192+
]
193+
}
194+
```
195+
196+
Explanation of the response:
197+
1. The response contains two `similarity` outputs, one for each input document. In each `similarity` output, the `data` array contains the relevance score of that document with respect to the query.
198+
2. The `similarity` outputs are provided in the order of the input documents; the first result of similarity pertains to the first document.
199+
200+
201+
## 2. Reranking pipeline
202+
### 2.1 Ingest test data
203+
```json
204+
POST _bulk
205+
{ "index": { "_index": "my-test-data" } }
206+
{ "passage_text" : "Carson City is the capital city of the American state of Nevada." }
207+
{ "index": { "_index": "my-test-data" } }
208+
{ "passage_text" : "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan." }
209+
{ "index": { "_index": "my-test-data" } }
210+
{ "passage_text" : "Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district." }
211+
{ "index": { "_index": "my-test-data" } }
212+
{ "passage_text" : "Capital punishment (the death penalty) has existed in the United States since beforethe United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states." }
213+
214+
```
215+
### 2.2 Create a reranking pipeline
216+
```json
217+
PUT /_search/pipeline/rerank_pipeline_sagemaker
218+
{
219+
"description": "Pipeline for reranking with Sagemaker cross-encoder model",
220+
"response_processors": [
221+
{
222+
"rerank": {
223+
"ml_opensearch": {
224+
"model_id": "your_model_id_created_in_step1"
225+
},
226+
"context": {
227+
"document_fields": ["passage_text"]
228+
}
229+
}
230+
}
231+
]
232+
}
233+
```
234+
Note: If you provide multiple field names in `document_fields`, the values of all fields are first concatenated and then reranking is performed.
235+
### 2.3 Test reranking
236+
237+
To return a different number of results, provide the `size` parameter. For example, set `size` to `4` to return the top four documents:
238+
239+
```json
240+
GET my-test-data/_search?search_pipeline=rerank_pipeline_sagemaker
241+
{
242+
"query": {
243+
"match_all": {}
244+
},
245+
"size": 4,
246+
"ext": {
247+
"rerank": {
248+
"query_context": {
249+
"query_text": "What is the capital of the United States?"
250+
}
251+
}
252+
}
253+
}
254+
```
255+
Response:
256+
```json
257+
{
258+
"took": 3,
259+
"timed_out": false,
260+
"_shards": {
261+
"total": 1,
262+
"successful": 1,
263+
"skipped": 0,
264+
"failed": 0
265+
},
266+
"hits": {
267+
"total": {
268+
"value": 4,
269+
"relation": "eq"
270+
},
271+
"max_score": 0.9997217,
272+
"hits": [
273+
{
274+
"_index": "my-test-data",
275+
"_id": "U0xye5AB9ZeWZdmDjWZn",
276+
"_score": 0.9997217,
277+
"_source": {
278+
"passage_text": "Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district."
279+
}
280+
},
281+
{
282+
"_index": "my-test-data",
283+
"_id": "VExye5AB9ZeWZdmDjWZn",
284+
"_score": 0.55655104,
285+
"_source": {
286+
"passage_text": "Capital punishment (the death penalty) has existed in the United States since beforethe United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states."
287+
}
288+
},
289+
{
290+
"_index": "my-test-data",
291+
"_id": "UUxye5AB9ZeWZdmDjWZn",
292+
"_score": 0.115356825,
293+
"_source": {
294+
"passage_text": "Carson City is the capital city of the American state of Nevada."
295+
}
296+
},
297+
{
298+
"_index": "my-test-data",
299+
"_id": "Ukxye5AB9ZeWZdmDjWZn",
300+
"_score": 0.00021142483,
301+
"_source": {
302+
"passage_text": "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan."
303+
}
304+
}
305+
]
306+
},
307+
"profile": {
308+
"shards": []
309+
}
310+
}
311+
```
312+
Test the query without a reranking pipeline:
313+
```json
314+
GET my-test-data/_search
315+
{
316+
"query": {
317+
"match_all": {}
318+
},
319+
"ext": {
320+
"rerank": {
321+
"query_context": {
322+
"query_text": "What is the capital of the United States?"
323+
}
324+
}
325+
}
326+
}
327+
```
328+
The first document in the response is `Carson City is the capital city of the American state of Nevada`, which is incorrect:
329+
```json
330+
{
331+
"took": 1,
332+
"timed_out": false,
333+
"_shards": {
334+
"total": 1,
335+
"successful": 1,
336+
"skipped": 0,
337+
"failed": 0
338+
},
339+
"hits": {
340+
"total": {
341+
"value": 4,
342+
"relation": "eq"
343+
},
344+
"max_score": 1,
345+
"hits": [
346+
{
347+
"_index": "my-test-data",
348+
"_id": "UUxye5AB9ZeWZdmDjWZn",
349+
"_score": 1,
350+
"_source": {
351+
"passage_text": "Carson City is the capital city of the American state of Nevada."
352+
}
353+
},
354+
{
355+
"_index": "my-test-data",
356+
"_id": "Ukxye5AB9ZeWZdmDjWZn",
357+
"_score": 1,
358+
"_source": {
359+
"passage_text": "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan."
360+
}
361+
},
362+
{
363+
"_index": "my-test-data",
364+
"_id": "U0xye5AB9ZeWZdmDjWZn",
365+
"_score": 1,
366+
"_source": {
367+
"passage_text": "Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district."
368+
}
369+
},
370+
{
371+
"_index": "my-test-data",
372+
"_id": "VExye5AB9ZeWZdmDjWZn",
373+
"_score": 1,
374+
"_source": {
375+
"passage_text": "Capital punishment (the death penalty) has existed in the United States since beforethe United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states."
376+
}
377+
}
378+
]
379+
}
380+
}
381+
```

0 commit comments

Comments
 (0)