Skip to content

Commit 76584f4

Browse files
authored
Add option to select and provision pretrained text embedding models (#137)
Signed-off-by: Tyler Ohlsen <ohltyler@amazon.com>
1 parent 16a0967 commit 76584f4

File tree

18 files changed

+527
-65
lines changed

18 files changed

+527
-65
lines changed

common/constants.ts

+42-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,12 @@
33
* SPDX-License-Identifier: Apache-2.0
44
*/
55

6-
import { TemplateNode, WORKFLOW_STATE } from './interfaces';
6+
import {
7+
MODEL_ALGORITHM,
8+
PRETRAINED_MODEL_FORMAT,
9+
PretrainedSentenceTransformer,
10+
WORKFLOW_STATE,
11+
} from './interfaces';
712

813
export const PLUGIN_ID = 'flow-framework';
914

@@ -52,6 +57,42 @@ export const SEARCH_MODELS_NODE_API_PATH = `${BASE_MODEL_NODE_API_PATH}/search`;
5257
*/
5358
// Step-type identifier strings.
// NOTE(review): presumably these must match the step type names expected by
// the backend workflow templates — confirm against the API before renaming.
export const CREATE_INGEST_PIPELINE_STEP_TYPE = 'create_ingest_pipeline';
5459
export const CREATE_INDEX_STEP_TYPE = 'create_index';
60+
// Step type for registering a local pretrained model.
export const REGISTER_LOCAL_PRETRAINED_MODEL_STEP_TYPE =
61+
'register_local_pretrained_model';
62+
63+
/**
64+
* ML PLUGIN PRETRAINED MODELS
65+
* (based off of https://opensearch.org/docs/latest/ml-commons-plugin/pretrained-models/#sentence-transformers)
66+
*/
67+
/**
 * Pretrained `all-distilroberta-v1` sentence transformer: a TorchScript
 * text-embedding model, version 1.0.1, with 768 vector dimensions.
 *
 * Declared with a type annotation instead of an `as` assertion so the
 * compiler reports any missing or misspelled field of
 * `PretrainedSentenceTransformer`.
 */
export const ROBERTA_SENTENCE_TRANSFORMER: PretrainedSentenceTransformer = {
  name: 'huggingface/sentence-transformers/all-distilroberta-v1',
  shortenedName: 'all-distilroberta-v1',
  description: 'A sentence transformer from Hugging Face',
  format: PRETRAINED_MODEL_FORMAT.TORCH_SCRIPT,
  algorithm: MODEL_ALGORITHM.TEXT_EMBEDDING,
  version: '1.0.1',
  vectorDimensions: 768,
};
76+
77+
/**
 * Pretrained `all-mpnet-base-v2` sentence transformer: a TorchScript
 * text-embedding model, version 1.0.1, with 768 vector dimensions.
 *
 * Declared with a type annotation instead of an `as` assertion so the
 * compiler reports any missing or misspelled field of
 * `PretrainedSentenceTransformer`.
 */
export const MPNET_SENTENCE_TRANSFORMER: PretrainedSentenceTransformer = {
  name: 'huggingface/sentence-transformers/all-mpnet-base-v2',
  shortenedName: 'all-mpnet-base-v2',
  description: 'A sentence transformer from Hugging Face',
  format: PRETRAINED_MODEL_FORMAT.TORCH_SCRIPT,
  algorithm: MODEL_ALGORITHM.TEXT_EMBEDDING,
  version: '1.0.1',
  vectorDimensions: 768,
};
86+
87+
/**
 * Pretrained `msmarco-distilbert-base-tas-b` sentence transformer: a
 * TorchScript text-embedding model, version 1.0.2, with 768 vector dimensions.
 *
 * Declared with a type annotation instead of an `as` assertion so the
 * compiler reports any missing or misspelled field of
 * `PretrainedSentenceTransformer`.
 */
export const BERT_SENTENCE_TRANSFORMER: PretrainedSentenceTransformer = {
  name: 'huggingface/sentence-transformers/msmarco-distilbert-base-tas-b',
  shortenedName: 'msmarco-distilbert-base-tas-b',
  description: 'A sentence transformer from Hugging Face',
  format: PRETRAINED_MODEL_FORMAT.TORCH_SCRIPT,
  algorithm: MODEL_ALGORITHM.TEXT_EMBEDDING,
  version: '1.0.2',
  vectorDimensions: 768,
};
5596

5697
/**
5798
* MISCELLANEOUS

common/interfaces.ts

+85-5
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,16 @@ export type CreateIndexNode = TemplateNode & {
8282
};
8383
};
8484

85+
/**
 * A template node that additionally carries the user inputs for a pretrained
 * model: its name, description, format, version, and a deploy flag.
 * NOTE(review): `model_format` is a plain string here, not
 * PRETRAINED_MODEL_FORMAT — confirm whether that looseness is intentional.
 */
export type RegisterPretrainedModelNode = TemplateNode & {
86+
user_inputs: {
87+
name: string;
88+
description: string;
89+
model_format: string;
90+
version: string;
91+
deploy: boolean;
92+
};
93+
};
94+
8595
export type TemplateEdge = {
8696
source: string;
8797
dest: string;
@@ -130,9 +140,83 @@ export enum USE_CASE {
130140
/**
131141
********** ML PLUGIN TYPES/INTERFACES **********
132142
*/
143+
144+
// Based off of https://github.com/opensearch-project/ml-commons/blob/main/common/src/main/java/org/opensearch/ml/common/model/MLModelState.java
145+
/**
 * States a model can be in.
 * NOTE(review): the string values look like display labels (e.g.
 * 'Partially deployed') rather than raw upper-case API constants — confirm
 * how API responses are mapped onto these members.
 */
export enum MODEL_STATE {
146+
REGISTERED = 'Registered',
147+
REGISTERING = 'Registering',
148+
DEPLOYING = 'Deploying',
149+
DEPLOYED = 'Deployed',
150+
PARTIALLY_DEPLOYED = 'Partially deployed',
151+
UNDEPLOYED = 'Undeployed',
152+
DEPLOY_FAILED = 'Deploy failed',
153+
}
154+
155+
// Based off of https://github.com/opensearch-project/ml-commons/blob/main/common/src/main/java/org/opensearch/ml/common/FunctionName.java
156+
/**
 * Algorithms / function names a model can be associated with.
 * NOTE(review): as with MODEL_STATE, the values appear to be human-readable
 * labels, not the raw identifiers — confirm the mapping from API responses.
 */
export enum MODEL_ALGORITHM {
157+
LINEAR_REGRESSION = 'Linear regression',
158+
KMEANS = 'K-means',
159+
AD_LIBSVM = 'AD LIBSVM',
160+
SAMPLE_ALGO = 'Sample algorithm',
161+
LOCAL_SAMPLE_CALCULATOR = 'Local sample calculator',
162+
FIT_RCF = 'Fit RCF',
163+
BATCH_RCF = 'Batch RCF',
164+
ANOMALY_LOCALIZATION = 'Anomaly localization',
165+
RCF_SUMMARIZE = 'RCF summarize',
166+
LOGISTIC_REGRESSION = 'Logistic regression',
167+
TEXT_EMBEDDING = 'Text embedding',
168+
METRICS_CORRELATION = 'Metrics correlation',
169+
REMOTE = 'Remote',
170+
SPARSE_ENCODING = 'Sparse encoding',
171+
SPARSE_TOKENIZE = 'Sparse tokenize',
172+
TEXT_SIMILARITY = 'Text similarity',
173+
QUESTION_ANSWERING = 'Question answering',
174+
AGENT = 'Agent',
175+
}
176+
177+
/**
 * Category of a model choice: an already-deployed model or a pretrained one.
 * Referenced by the optional `category` field of ModelFormValue.
 */
export enum MODEL_CATEGORY {
178+
DEPLOYED = 'Deployed',
179+
PRETRAINED = 'Pretrained',
180+
}
181+
182+
/**
 * Formats a pretrained model can be provided in. Only TorchScript is listed.
 * Used as the type of the `format` field of PretrainedModel.
 */
export enum PRETRAINED_MODEL_FORMAT {
183+
TORCH_SCRIPT = 'TORCH_SCRIPT',
184+
}
185+
186+
/**
 * Static description of a pretrained model: full and shortened names, a
 * description, its format, its algorithm, and a version string.
 */
export type PretrainedModel = {
187+
name: string;
188+
shortenedName: string;
189+
description: string;
190+
format: PRETRAINED_MODEL_FORMAT;
191+
algorithm: MODEL_ALGORITHM;
192+
version: string;
193+
};
194+
195+
/**
 * A PretrainedModel that additionally records a numeric vector dimension
 * count.
 */
export type PretrainedSentenceTransformer = PretrainedModel & {
196+
vectorDimensions: number;
197+
};
198+
199+
/**
 * Optional configuration details attached to a Model. Both fields are
 * optional, so consumers must handle their absence.
 */
export type ModelConfig = {
200+
modelType?: string;
201+
embeddingDimension?: number;
202+
};
203+
133204
/**
 * A model record: its ID, name, algorithm, state, and an optional ModelConfig.
 */
export type Model = {
134205
id: string;
135-
algorithm: string;
206+
name: string;
207+
algorithm: MODEL_ALGORITHM;
208+
state: MODEL_STATE;
209+
modelConfig?: ModelConfig;
210+
};
211+
212+
/** Dictionary of Model records keyed by model ID. */
export type ModelDict = {
213+
[modelId: string]: Model;
214+
};
215+
216+
/**
 * Value describing a model selection: a required ID plus an optional category
 * and algorithm.
 * NOTE(review): presumably this is the form value stored for fields of type
 * 'model' — confirm against the ModelField component.
 */
export type ModelFormValue = {
217+
id: string;
218+
category?: MODEL_CATEGORY;
219+
algorithm?: MODEL_ALGORITHM;
136220
};
137221

138222
/**
@@ -171,7 +255,3 @@ export enum WORKFLOW_RESOURCE_TYPE {
171255
/** Dictionary of Workflow records keyed by workflow ID. */
export type WorkflowDict = {
172256
[workflowId: string]: Workflow;
173257
};
174-
175-
export type ModelDict = {
176-
[modelId: string]: Model;
177-
};

public/app.tsx

+13-4
Original file line numberDiff line numberDiff line change
@@ -62,12 +62,21 @@ export const FlowFrameworkDashboardsApp = (props: Props) => {
6262
<Workflows {...routeProps} />
6363
)}
6464
/>
65-
{/* Defaulting to Workflows page */}
65+
{/*
66+
Defaulting to Workflows page. The pathname will need to be updated
67+
to handle the redirection and get the router props consistent.
68+
*/}
6669
<Route
6770
path={`${APP_PATH.HOME}`}
68-
render={(routeProps: RouteComponentProps<WorkflowsRouterProps>) => (
69-
<Workflows {...routeProps} />
70-
)}
71+
render={(routeProps: RouteComponentProps<WorkflowsRouterProps>) => {
72+
if (props.history.location.pathname !== APP_PATH.WORKFLOWS) {
73+
props.history.replace({
74+
...history,
75+
pathname: APP_PATH.WORKFLOWS,
76+
});
77+
}
78+
return <Workflows {...routeProps} />;
79+
}}
7180
/>
7281
</Switch>
7382
</EuiPageTemplate>

public/component_types/interfaces.ts

+1-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ import { COMPONENT_CATEGORY, COMPONENT_CLASS } from '../utils';
1010
/**
1111
* ************ Types *************************
1212
*/
13-
export type FieldType = 'string' | 'json' | 'select';
13+
// Kinds of component input field.
// NOTE(review): presumably the discriminator used when choosing which input
// component to render — confirm against InputFieldList.
export type FieldType = 'string' | 'json' | 'select' | 'model';
1414
// Sub-kind for 'select' fields; 'model' is the only option listed.
export type SelectType = 'model';
1515
// NOTE(review): in TypeScript `{}` accepts any non-nullish value, so this
// union is effectively unconstrained — consider a narrower object type.
export type FieldValue = string | {};
1616
// Alias of Formik's value bag.
export type ComponentFormValues = FormikValues;

public/component_types/transformer/text_embedding_transformer.ts

+4-6
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,10 @@ export class TextEmbeddingTransformer extends MLTransformer {
1919
this.inputs = [];
2020
this.createFields = [
2121
{
22-
label: 'Model ID',
23-
id: 'modelId',
24-
type: 'select',
25-
selectType: 'model',
26-
helpText: 'The deployed text embedding model to use for embedding.',
22+
label: 'Text Embedding Model',
23+
id: 'model',
24+
type: 'model',
25+
helpText: 'A text embedding model for embedding text.',
2726
helpLink:
2827
'https://opensearch.org/docs/latest/ml-commons-plugin/integrating-ml-models/#choosing-a-model',
2928
},
@@ -36,7 +35,6 @@ export class TextEmbeddingTransformer extends MLTransformer {
3635
helpLink:
3736
'https://opensearch.org/docs/latest/ingest-pipelines/processors/text-embedding/',
3837
},
39-
4038
{
4139
label: 'Vector Field',
4240
id: 'vectorField',

public/pages/workflow_detail/component_details/component_details.tsx

+5-2
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,12 @@ interface ComponentDetailsProps {
2828
export function ComponentDetails(props: ComponentDetailsProps) {
2929
return (
3030
<EuiPanel paddingSize="m">
31-
{props.isDeprovisionable ? (
31+
{/* TODO: determine if we need this view if we want the workspace to remain
32+
readonly once provisioned */}
33+
{/* {props.isDeprovisionable ? (
3234
<ProvisionedComponentInputs />
33-
) : props.selectedComponent ? (
35+
) : */}
36+
{props.selectedComponent ? (
3437
<ComponentInputs
3538
selectedComponent={props.selectedComponent}
3639
onFormChange={props.onFormChange}

public/pages/workflow_detail/component_details/component_inputs.tsx

+3
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,9 @@ export function ComponentInputs(props: ComponentInputsProps) {
5555
<EuiTitle size="m">
5656
<h2>{props.selectedComponent.data.label || ''}</h2>
5757
</EuiTitle>
58+
<EuiText color="subdued">
59+
{props.selectedComponent.data.description}
60+
</EuiText>
5861
<NewOrExistingTabs
5962
selectedTabId={selectedTabId}
6063
setSelectedTabId={setSelectedTabId}

public/pages/workflow_detail/component_details/input_field_list.tsx

+14-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
import React from 'react';
77
import { EuiFlexItem, EuiSpacer } from '@elastic/eui';
8-
import { TextField, JsonField, SelectField } from './input_fields';
8+
import { TextField, JsonField, SelectField, ModelField } from './input_fields';
99
import { IComponentField } from '../../../../common';
1010

1111
/**
@@ -54,6 +54,19 @@ export function InputFieldList(props: InputFieldListProps) {
5454
);
5555
break;
5656
}
57+
case 'model': {
58+
el = (
59+
<EuiFlexItem key={idx}>
60+
<ModelField
61+
field={field}
62+
componentId={props.componentId}
63+
onFormChange={props.onFormChange}
64+
/>
65+
<EuiSpacer size={INPUT_FIELD_SPACER_SIZE} />
66+
</EuiFlexItem>
67+
);
68+
break;
69+
}
5770
case 'json': {
5871
el = (
5972
<EuiFlexItem key={idx}>

public/pages/workflow_detail/component_details/input_fields/index.ts

+1
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,4 @@
66
export { TextField } from './text_field';
77
export { JsonField } from './json_field';
88
export { SelectField } from './select_field';
9+
export { ModelField } from './model_field';

0 commit comments

Comments
 (0)