Skip to content

Commit ab01756

Browse files
committed
Add algorithm field; clean up dropdown displays;
Signed-off-by: Tyler Ohlsen <ohltyler@amazon.com>
1 parent 8994fbe commit ab01756

File tree

6 files changed

+47
-4
lines changed

6 files changed

+47
-4
lines changed

common/constants.ts

+4
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
*/
55

66
import {
7+
MODEL_ALGORITHM,
78
PRETRAINED_MODEL_FORMAT,
89
PretrainedSentenceTransformer,
910
WORKFLOW_STATE,
@@ -68,6 +69,7 @@ export const ROBERTA_SENTENCE_TRANSFORMER = {
6869
shortenedName: 'all-distilroberta-v1',
6970
description: 'A sentence transformer from Hugging Face',
7071
format: PRETRAINED_MODEL_FORMAT.TORCH_SCRIPT,
72+
algorithm: MODEL_ALGORITHM.TEXT_EMBEDDING,
7173
version: '1.0.1',
7274
vectorDimensions: 768,
7375
} as PretrainedSentenceTransformer;
@@ -77,6 +79,7 @@ export const MPNET_SENTENCE_TRANSFORMER = {
7779
shortenedName: 'all-mpnet-base-v2',
7880
description: 'A sentence transformer from Hugging Face',
7981
format: PRETRAINED_MODEL_FORMAT.TORCH_SCRIPT,
82+
algorithm: MODEL_ALGORITHM.TEXT_EMBEDDING,
8083
version: '1.0.1',
8184
vectorDimensions: 768,
8285
} as PretrainedSentenceTransformer;
@@ -86,6 +89,7 @@ export const BERT_SENTENCE_TRANSFORMER = {
8689
shortenedName: 'msmarco-distilbert-base-tas-b',
8790
description: 'A sentence transformer from Hugging Face',
8891
format: PRETRAINED_MODEL_FORMAT.TORCH_SCRIPT,
92+
algorithm: MODEL_ALGORITHM.TEXT_EMBEDDING,
8993
version: '1.0.2',
9094
vectorDimensions: 768,
9195
} as PretrainedSentenceTransformer;

common/interfaces.ts

+25-1
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,28 @@ export enum MODEL_STATE {
152152
DEPLOY_FAILED = 'Deploy failed',
153153
}
154154

155+
// Based off of https://github.com/opensearch-project/ml-commons/blob/main/common/src/main/java/org/opensearch/ml/common/FunctionName.java
156+
export enum MODEL_ALGORITHM {
157+
LINEAR_REGRESSION = 'Linear regression',
158+
KMEANS = 'K-means',
159+
AD_LIBSVM = 'AD LIBSVM',
160+
SAMPLE_ALGO = 'Sample algorithm',
161+
LOCAL_SAMPLE_CALCULATOR = 'Local sample calculator',
162+
FIT_RCF = 'Fit RCF',
163+
BATCH_RCF = 'Batch RCF',
164+
ANOMALY_LOCALIZATION = 'Anomaly localization',
165+
RCF_SUMMARIZE = 'RCF summarize',
166+
LOGISTIC_REGRESSION = 'Logistic regression',
167+
TEXT_EMBEDDING = 'Text embedding',
168+
METRICS_CORRELATION = 'Metrics correlation',
169+
REMOTE = 'Remote',
170+
SPARSE_ENCODING = 'Sparse encoding',
171+
SPARSE_TOKENIZE = 'Sparse tokenize',
172+
TEXT_SIMILARITY = 'Text similarity',
173+
QUESTION_ANSWERING = 'Question answering',
174+
AGENT = 'Agent',
175+
}
176+
155177
export enum MODEL_CATEGORY {
156178
DEPLOYED = 'Deployed',
157179
PRETRAINED = 'Pretrained',
@@ -166,6 +188,7 @@ export type PretrainedModel = {
166188
shortenedName: string;
167189
description: string;
168190
format: PRETRAINED_MODEL_FORMAT;
191+
algorithm: MODEL_ALGORITHM;
169192
version: string;
170193
};
171194

@@ -181,7 +204,7 @@ export type ModelConfig = {
181204
export type Model = {
182205
id: string;
183206
name: string;
184-
algorithm: string;
207+
algorithm: MODEL_ALGORITHM;
185208
state: MODEL_STATE;
186209
modelConfig?: ModelConfig;
187210
};
@@ -193,6 +216,7 @@ export type ModelDict = {
193216
export type ModelFormValue = {
194217
id: string;
195218
category?: MODEL_CATEGORY;
219+
algorithm?: MODEL_ALGORITHM;
196220
};
197221

198222
/**

public/component_types/transformer/text_embedding_transformer.ts

+1-1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ export class TextEmbeddingTransformer extends MLTransformer {
1919
this.inputs = [];
2020
this.createFields = [
2121
{
22-
label: 'Model',
22+
label: 'Text Embedding Model',
2323
id: 'model',
2424
type: 'model',
2525
helpText: 'A text embedding model for embedding text.',

public/pages/workflow_detail/component_details/input_fields/model_field.tsx

+13-1
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ export function ModelField(props: ModelFieldProps) {
8484
id: modelId,
8585
name: models[modelId].name,
8686
category: MODEL_CATEGORY.DEPLOYED,
87+
algorithm: models[modelId].algorithm,
8788
} as ModelItem);
8889
}
8990
});
@@ -98,16 +99,19 @@ export function ModelField(props: ModelFieldProps) {
9899
id: ROBERTA_SENTENCE_TRANSFORMER.name,
99100
name: ROBERTA_SENTENCE_TRANSFORMER.shortenedName,
100101
category: MODEL_CATEGORY.PRETRAINED,
102+
algorithm: ROBERTA_SENTENCE_TRANSFORMER.algorithm,
101103
},
102104
{
103105
id: MPNET_SENTENCE_TRANSFORMER.name,
104106
name: MPNET_SENTENCE_TRANSFORMER.shortenedName,
105107
category: MODEL_CATEGORY.PRETRAINED,
108+
algorithm: MPNET_SENTENCE_TRANSFORMER.algorithm,
106109
},
107110
{
108111
id: BERT_SENTENCE_TRANSFORMER.name,
109112
name: BERT_SENTENCE_TRANSFORMER.shortenedName,
110113
category: MODEL_CATEGORY.PRETRAINED,
114+
algorithm: BERT_SENTENCE_TRANSFORMER.algorithm,
111115
},
112116
];
113117
setPretrainedModels(modelItems);
@@ -170,10 +174,18 @@ export function ModelField(props: ModelFieldProps) {
170174
value: option.id,
171175
inputDisplay: (
172176
<>
173-
<EuiText size="xs">{option.name}</EuiText>
177+
<EuiText size="s">{option.name}</EuiText>
178+
</>
179+
),
180+
dropdownDisplay: (
181+
<>
182+
<EuiText size="s">{option.name}</EuiText>
174183
<EuiText size="xs" color="subdued">
175184
{option.category}
176185
</EuiText>
186+
<EuiText size="xs" color="subdued">
187+
{option.algorithm}
188+
</EuiText>
177189
</>
178190
),
179191
disabled: false,

public/utils/utils.ts

+1
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ export function getInitialValue(fieldType: FieldType): FieldValue {
102102
return {
103103
id: '',
104104
category: undefined,
105+
algorithm: undefined,
105106
} as ModelFormValue;
106107
}
107108
case 'json': {

server/routes/helpers.ts

+3-1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import {
77
DEFAULT_NEW_WORKFLOW_STATE_TYPE,
88
INDEX_NOT_FOUND_EXCEPTION,
9+
MODEL_ALGORITHM,
910
MODEL_STATE,
1011
Model,
1112
ModelDict,
@@ -94,7 +95,8 @@ export function getModelsFromResponses(modelHits: any[]): ModelDict {
9495
modelDict[modelId] = {
9596
id: modelId,
9697
name: modelHit._source?.name,
97-
algorithm: modelHit._source?.algorithm,
98+
// @ts-ignore
99+
algorithm: MODEL_ALGORITHM[modelHit._source?.algorithm],
98100
// @ts-ignore
99101
state: MODEL_STATE[modelHit._source?.model_state],
100102
modelConfig: {

0 commit comments

Comments
 (0)