From 539cd48ad14b5475cdc4fc83e415951d4ad4eb58 Mon Sep 17 00:00:00 2001 From: Yulong Ruan Date: Tue, 31 Jan 2023 14:02:54 +0800 Subject: [PATCH 01/75] Revert "feat: remove model registry related code and dependencies (#63)" This reverts commit 6bdd0d3ed6ab1bc828a2d59cccff96234bc99d54. Signed-off-by: Lin Wang --- common/constant.ts | 6 + common/index.ts | 1 + common/model.ts | 4 +- common/router.ts | 14 ++ common/router_paths.ts | 2 + package.json | 9 +- public/apis/api_provider.ts | 9 + public/apis/model.ts | 112 +++++++++- public/apis/model_aggregate.ts | 40 ++++ public/components/app.tsx | 8 +- public/components/common/custom.tsx | 35 +++ public/components/common/index.ts | 1 + public/components/model_drawer/index.tsx | 85 +++++++ .../components/model_drawer/version_table.tsx | 89 ++++++++ .../model_deployed_versions.test.tsx | 30 +++ .../__tests__/model_filter.test.tsx | 130 +++++++++++ .../__tests__/model_filter_item.test.tsx | 34 +++ .../__tests__/model_list_filter.test.tsx | 31 +++ .../model_list/__tests__/model_owner.test.tsx | 27 +++ .../model_list/__tests__/model_table.test.tsx | 155 +++++++++++++ .../model_table_uploading_cell.test.tsx | 85 +++++++ .../__tests__/owner_filter.test.tsx | 21 ++ .../__tests__/stage_filter.test.tsx | 21 ++ .../model_list/__tests__/tag_filter.test.tsx | 21 ++ public/components/model_list/index.tsx | 122 ++++++++++ .../model_list/model_confirm_delete_modal.tsx | 93 ++++++++ .../model_list/model_deployed_versions.tsx | 27 +++ public/components/model_list/model_filter.tsx | 98 ++++++++ .../model_list/model_filter_item.tsx | 24 ++ .../model_list/model_list_filter.tsx | 71 ++++++ public/components/model_list/model_owner.tsx | 15 ++ public/components/model_list/model_table.tsx | 167 ++++++++++++++ .../model_list/model_table_uploading_cell.tsx | 49 ++++ public/components/model_list/owner_filter.tsx | 19 ++ public/components/model_list/stage_filter.tsx | 19 ++ public/components/model_list/tag_filter.tsx | 19 ++ public/components/nav_panel.tsx | 32 +++ public/components/primitive_combo_box.tsx | 92 ++++++++ .../register_model_artifact.test.tsx | 56 +++++ .../__tests__/register_model_details.test.tsx | 52 +++++ .../__tests__/register_model_metrics.test.tsx | 119 ++++++++++ .../__tests__/register_model_tags.test.tsx | 96 ++++++++ .../register_model/__tests__/setup.tsx | 60 +++++ public/components/register_model/artifact.tsx | 69 ++++++ .../register_model/artifact_file.tsx | 42 ++++ .../register_model/artifact_url.tsx | 48 ++++ .../register_model/evaluation_metrics.tsx | 157 +++++++++++++ .../register_model/form_constants.ts | 6 + public/components/register_model/index.ts | 6 + .../register_model/model_configuration.tsx | 95 ++++++++ .../register_model/model_details.tsx | 109 +++++++++ .../components/register_model/model_tags.tsx | 53 +++++ .../register_model/register_model.hooks.ts | 49 ++++ .../register_model/register_model.tsx | 70 ++++++ .../register_model/register_model.types.ts | 36 +++ .../components/register_model/tag_field.tsx | 118 ++++++++++ public/hooks/tests/use_polling_unit.test.ts | 113 ++++++++++ public/hooks/use_polling_until.ts | 65 ++++++ public/utils/index.ts | 6 + public/utils/regex.ts | 6 + public/utils/table.ts | 14 ++ server/clusters/create_model_cluster.ts | 15 ++ server/clusters/model_plugin.ts | 90 ++++++++ server/plugin.ts | 15 +- server/routes/constants.ts | 5 + server/routes/index.ts | 1 + server/routes/model_aggregate_router.ts | 37 +++ server/routes/model_router.ts | 211 +++++++++++++++++- server/services/errors.ts | 6 + server/services/index.ts | 1 + server/services/model_aggregate_service.ts | 183 +++++++++++++++ server/services/model_service.ts | 170 +++++++++++++- server/services/utils/constants.ts | 10 + server/services/utils/model.ts | 64 ++++++ yarn.lock | 27 +++ 75 files changed, 4082 insertions(+), 15 deletions(-) create mode 100644 common/constant.ts create mode 100644 public/apis/model_aggregate.ts create mode 100644 public/components/common/custom.tsx create mode 100644 public/components/model_drawer/index.tsx create mode 100644 public/components/model_drawer/version_table.tsx create mode 100644 public/components/model_list/__tests__/model_deployed_versions.test.tsx create mode 100644 public/components/model_list/__tests__/model_filter.test.tsx create mode 100644 public/components/model_list/__tests__/model_filter_item.test.tsx create mode 100644 public/components/model_list/__tests__/model_list_filter.test.tsx create mode 100644 public/components/model_list/__tests__/model_owner.test.tsx create mode 100644 public/components/model_list/__tests__/model_table.test.tsx create mode 100644 public/components/model_list/__tests__/model_table_uploading_cell.test.tsx create mode 100644 public/components/model_list/__tests__/owner_filter.test.tsx create mode 100644 public/components/model_list/__tests__/stage_filter.test.tsx create mode 100644 public/components/model_list/__tests__/tag_filter.test.tsx create mode 100644 public/components/model_list/index.tsx create mode 100644 public/components/model_list/model_confirm_delete_modal.tsx create mode 100644 public/components/model_list/model_deployed_versions.tsx create mode 100644 public/components/model_list/model_filter.tsx create mode 100644 public/components/model_list/model_filter_item.tsx create mode 100644 public/components/model_list/model_list_filter.tsx create mode 100644 public/components/model_list/model_owner.tsx create mode 100644 public/components/model_list/model_table.tsx create mode 100644 public/components/model_list/model_table_uploading_cell.tsx create mode 100644 public/components/model_list/owner_filter.tsx create mode 100644 public/components/model_list/stage_filter.tsx create mode 100644 public/components/model_list/tag_filter.tsx create mode 100644 public/components/nav_panel.tsx create mode 100644 public/components/primitive_combo_box.tsx create mode 100644 public/components/register_model/__tests__/register_model_artifact.test.tsx create mode 100644 public/components/register_model/__tests__/register_model_details.test.tsx create mode 100644 public/components/register_model/__tests__/register_model_metrics.test.tsx create mode 100644 public/components/register_model/__tests__/register_model_tags.test.tsx create mode 100644 public/components/register_model/__tests__/setup.tsx create mode 100644 public/components/register_model/artifact.tsx create mode 100644 public/components/register_model/artifact_file.tsx create mode 100644 public/components/register_model/artifact_url.tsx create mode 100644 public/components/register_model/evaluation_metrics.tsx create mode 100644 public/components/register_model/form_constants.ts create mode 100644 public/components/register_model/index.ts create mode 100644 public/components/register_model/model_configuration.tsx create mode 100644 public/components/register_model/model_details.tsx create mode 100644 public/components/register_model/model_tags.tsx create mode 100644 public/components/register_model/register_model.hooks.ts create mode 100644 public/components/register_model/register_model.tsx create mode 100644 public/components/register_model/register_model.types.ts create mode 100644 public/components/register_model/tag_field.tsx create mode 100644 public/hooks/tests/use_polling_unit.test.ts create mode 100644 public/hooks/use_polling_until.ts create mode 100644 public/utils/index.ts create mode 100644 public/utils/regex.ts create mode 100644 public/utils/table.ts create mode 100644 server/clusters/create_model_cluster.ts create mode 100644 server/clusters/model_plugin.ts create mode 100644 server/routes/model_aggregate_router.ts create mode 100644 server/services/errors.ts create mode 100644 server/services/model_aggregate_service.ts diff --git a/common/constant.ts b/common/constant.ts new file mode 100644 index 00000000..8a66b751 --- /dev/null +++ b/common/constant.ts @@ -0,0 +1,6 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +export const MAX_MODEL_CHUNK_SIZE = 10 * 1000 * 1000; diff --git a/common/index.ts b/common/index.ts index f9dadd1d..40f8005c 100644 --- a/common/index.ts +++ b/common/index.ts @@ -7,5 +7,6 @@ export const PLUGIN_ID = 'ml-commons-dashboards'; export const PLUGIN_NAME = 'Machine Learning'; export const PLUGIN_DESC = `ML Commons for OpenSearch eases the development of machine learning features by providing a set of common machine learning (ML) algorithms through transport and REST API calls. Those calls choose the right nodes and resources for each ML request and monitors ML tasks to ensure uptime. This allows you to leverage existing open-source ML algorithms and reduce the effort required to develop new ML features.`; +export * from './constant'; export * from './status'; export * from './model'; diff --git a/common/model.ts b/common/model.ts index 12389182..6a274c80 100644 --- a/common/model.ts +++ b/common/model.ts @@ -54,4 +54,6 @@ export type ModelSearchSort = | 'id-asc' | 'model_state-asc' | 'model_state-desc' - | 'id-desc'; + | 'id-desc' + | 'version-desc' + | 'version-asc'; diff --git a/common/router.ts b/common/router.ts index 4771c2e1..e8d2a37d 100644 --- a/common/router.ts +++ b/common/router.ts @@ -20,3 +20,17 @@ export const ROUTES: RouteConfig[] = [ label: 'Overview', }, ]; + +/* export const ROUTES1 = [ + { + path: routerPaths.modelList, + Component: ModelList, + label: 'Model List', + icon: 'createSingleMetricJob', + }, + { + path: routerPaths.registerModel, + label: 'Register Model', + Component: RegisterModelForm, + }, +];*/ diff --git a/common/router_paths.ts b/common/router_paths.ts index 5e4cd2cd..9ca63c68 100644 --- a/common/router_paths.ts +++ b/common/router_paths.ts @@ -6,4 +6,6 @@ export const routerPaths = { root: '/', overview: '/overview', + monitoring: '/monitoring', + registerModel: '/model-registry/register-model', }; diff --git a/package.json b/package.json index 2b61a071..41a94788 100644 --- a/package.json +++ b/package.json @@ -13,10 +13,15 @@ "test:jest": "../../node_modules/.bin/jest --config ./test/jest.config.js", "prepare": "husky install" }, - "dependencies": {}, + "dependencies": { + "hash-wasm": "^4.9.0", + "papaparse": "^5.3.2", + "react-hook-form": "^7.39.4" + }, "devDependencies": { "@testing-library/user-event": "^14.4.3", "husky": "^8.0.0", + "@types/papaparse": "^5.3.5", "lint-staged": "^10.0.0" } -} \ No newline at end of file +} diff --git a/public/apis/api_provider.ts b/public/apis/api_provider.ts index ab107b83..d1468d4d 100644 --- a/public/apis/api_provider.ts +++ b/public/apis/api_provider.ts @@ -5,20 +5,24 @@ import { Connector } from './connector'; import { Model } from './model'; +import { ModelAggregate } from './model_aggregate'; import { Profile } from './profile'; const apiInstanceStore: { model: Model | undefined; + modelAggregate: ModelAggregate | undefined; profile: Profile | undefined; connector: Connector | undefined; } = { model: undefined, + modelAggregate: undefined, profile: undefined, connector: undefined, }; export class APIProvider { public static getAPI(type: 'model'): Model; + public static getAPI(type: 'modelAggregate'): ModelAggregate; public static getAPI(type: 'profile'): Profile; public static getAPI(type: 'connector'): Connector; public static getAPI(type: keyof typeof apiInstanceStore) { @@ -31,6 +35,11 @@ export class APIProvider { apiInstanceStore.model = newInstance; return newInstance; } + case 'modelAggregate': { + const newInstance = new ModelAggregate(); + apiInstanceStore.modelAggregate = newInstance; + return newInstance; + } case 'profile': { const newInstance = new Profile(); apiInstanceStore.profile = newInstance; diff --git a/public/apis/model.ts b/public/apis/model.ts index 75d78191..1d82a7ad 100644 --- a/public/apis/model.ts +++ b/public/apis/model.ts @@ -3,8 +3,14 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { MODEL_API_ENDPOINT } from '../../server/routes/constants'; import { MODEL_STATE, ModelSearchSort } from '../../common'; +import { + MODEL_API_ENDPOINT, + MODEL_LOAD_API_ENDPOINT, + MODEL_UNLOAD_API_ENDPOINT, + MODEL_UPLOAD_API_ENDPOINT, + MODEL_PROFILE_API_ENDPOINT, +} from '../../server/routes/constants'; import { InnerHttpProvider } from './inner_http_provider'; export interface ModelSearchItem { @@ -23,13 +29,68 @@ export interface ModelSearchItem { }; } +export interface ModelDetail extends ModelSearchItem { + content: string; +} + export interface ModelSearchResponse { data: ModelSearchItem[]; total_models: number; } +export interface ModelLoadResponse { + task_id: string; + status: string; +} + +export interface ModelUnloadResponse { + [nodeId: string]: { + stats: { + [modelId: string]: string; + }; + }; +} + +export interface ModelProfileResponse { + nodes: { + [nodeId: string]: { + models: { + [modelId: string]: { + model_state: string; + predictor: string; + worker_nodes: string[]; + }; + }; + }; + }; +} + +interface UploadModelBase { + name: string; + version: string; + description: string; + modelFormat: string; + modelConfig: { + modelType: string; + embeddingDimension: number; + frameworkType: string; + }; +} + +export interface UploadModelByURL extends UploadModelBase { + url: string; +} + +export interface UploadModelByChunk extends UploadModelBase { + modelTaskType: string; + modelContentHashValue: string; + totalChunks: number; +} + export class Model { public search(query: { + algorithms?: string[]; + ids?: string[]; sort?: ModelSearchSort[]; from: number; size: number; @@ -45,4 +106,53 @@ export class Model { : { ...restQuery, data_source_id: dataSourceId }, }); } + + public delete(modelId: string) { + return InnerHttpProvider.getHttp().delete(`${MODEL_API_ENDPOINT}/${modelId}`); + } + + public getOne(modelId: string) { + return InnerHttpProvider.getHttp().get(`${MODEL_API_ENDPOINT}/${modelId}`); + } + + public load(modelId: string) { + return InnerHttpProvider.getHttp().post( + `${MODEL_LOAD_API_ENDPOINT}/${modelId}` + ); + } + + public unload(modelId: string) { + return InnerHttpProvider.getHttp().post( + `${MODEL_UNLOAD_API_ENDPOINT}/${modelId}` + ); + } + + public profile(modelId: string) { + return InnerHttpProvider.getHttp().get( + `${MODEL_PROFILE_API_ENDPOINT}/${modelId}` + ); + } + + public upload( + model: T + ): Promise< + T extends UploadModelByURL + ? { taskId: string } + : T extends UploadModelByChunk + ? { modelId: string } + : never + > { + return InnerHttpProvider.getHttp().post(MODEL_UPLOAD_API_ENDPOINT, { + body: JSON.stringify(model), + }); + } + + public uploadChunk(modelId: string, chunkId: string, chunkContent: Blob) { + return InnerHttpProvider.getHttp().post(`${MODEL_API_ENDPOINT}/${modelId}/chunk/${chunkId}`, { + body: chunkContent, + headers: { + 'Content-Type': 'application/octet-stream', + }, + }); + } } diff --git a/public/apis/model_aggregate.ts b/public/apis/model_aggregate.ts new file mode 100644 index 00000000..24ca58c5 --- /dev/null +++ b/public/apis/model_aggregate.ts @@ -0,0 +1,40 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import { MODEL_AGGREGATE_API_ENDPOINT } from '../../server/routes/constants'; +import { InnerHttpProvider } from './inner_http_provider'; +import { MODEL_STATE } from '../../common/model'; + +export interface ModelAggregateSearchItem { + name: string; + description?: string; + latest_version: string; + latest_version_state: MODEL_STATE; + deployed_versions: string[]; + owner: string; + created_time?: number; +} + +interface ModelAggregateSearchResponse { + data: ModelAggregateSearchItem[]; + total_models: number; +} + +export class ModelAggregate { + public search(query: { + size: number; + from: number; + sort: 'created_time'; + order: 'desc' | 'asc'; + name?: string; + }) { + return InnerHttpProvider.getHttp().get( + MODEL_AGGREGATE_API_ENDPOINT, + { + query, + } + ); + } +} diff --git a/public/components/app.tsx b/public/components/app.tsx index 892d475a..dc6df48b 100644 --- a/public/components/app.tsx +++ b/public/components/app.tsx @@ -6,7 +6,7 @@ import React from 'react'; import { I18nProvider } from '@osd/i18n/react'; import { Redirect, Route, Switch } from 'react-router-dom'; -import { EuiPage, EuiPageBody } from '@elastic/eui'; +import { EuiPage, EuiPageBody, EuiPageSideBar } from '@elastic/eui'; import { useObservable } from 'react-use'; import { ROUTES } from '../../common/router'; import { routerPaths } from '../../common/router_paths'; @@ -25,6 +25,7 @@ import { DataSourceContextProvider } from '../contexts/data_source_context'; import { GlobalBreadcrumbs } from './global_breadcrumbs'; import { DataSourceTopNavMenu } from './data_source_top_nav_menu'; +import { NavPanel } from './nav_panel'; interface MlCommonsPluginAppDeps { basename: string; @@ -72,6 +73,11 @@ export const MlCommonsPluginApp = ({ > <> + {!useNewPageHeader && ( + + + + )} {ROUTES.map(({ path, Component, exact }) => ( diff --git a/public/components/common/custom.tsx b/public/components/common/custom.tsx new file mode 100644 index 00000000..43a8a02a --- /dev/null +++ b/public/components/common/custom.tsx @@ -0,0 +1,35 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React from 'react'; +import { EuiButton, EuiLink } from '@elastic/eui'; +import { useHistory } from 'react-router-dom'; + +type LinkProps = React.ComponentProps & { + to: string; +}; + +type ButtonProps = React.ComponentProps & { + to: string; +}; + +export const EuiCustomLink = ({ to, children, ...rest }: LinkProps) => { + const history = useHistory(); + + return ( + history.push(to)} {...rest}> + {children} + + ); +}; + +export const EuiLinkButton = ({ to, children, ...rest }: ButtonProps) => { + const history = useHistory(); + return ( + history.push(to)} {...rest}> + {children} + + ); +}; diff --git a/public/components/common/index.ts b/public/components/common/index.ts index 553d8444..8ece445d 100644 --- a/public/components/common/index.ts +++ b/public/components/common/index.ts @@ -3,4 +3,5 @@ * SPDX-License-Identifier: Apache-2.0 */ +export * from './custom'; export * from './copyable_text'; diff --git a/public/components/model_drawer/index.tsx b/public/components/model_drawer/index.tsx new file mode 100644 index 00000000..b7b4346b --- /dev/null +++ b/public/components/model_drawer/index.tsx @@ -0,0 +1,85 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React, { useMemo, useCallback, useState } from 'react'; +import { + EuiFlyout, + EuiFlyoutBody, + EuiFlyoutHeader, + EuiTitle, + EuiLink, + EuiSpacer, + EuiFlexGroup, + EuiFlexItem, + EuiDescriptionList, +} from '@elastic/eui'; +import { generatePath, useHistory } from 'react-router-dom'; +import { APIProvider } from '../../apis/api_provider'; +import { useFetcher } from '../../hooks/use_fetcher'; +import { routerPaths } from '../../../common/router_paths'; +import { VersionTable } from './version_table'; +import { EuiLinkButton, EuiCustomLink } from '../common'; + +export type VersionTableSort = 'version-desc' | 'version-asc'; + +interface Props { + onClose: () => void; + name: string; +} + +export const ModelDrawer = ({ onClose, name }: Props) => { + const [sort, setSort] = useState('version-desc'); + const { data: model } = useFetcher(APIProvider.getAPI('model').search, { + nameOrId: name, + from: 0, + size: 50, + sort: [sort], + }); + const latestVersion = useMemo(() => { + // TODO: currently assume that api will return versions in order + if (model?.data) { + const data = model.data; + return data[data.length - 1]; + } + return { id: '' }; + }, [model]); + + const handleTableChange = useCallback((criteria) => { + setSort(criteria.sort); + }, []); + + return ( + + + +

{name}

+
+ {latestVersion.id ? ( + <> + , + + View Full Details + + + ) : null} +
+ + {model && } + + + +

Versions

+
+
+ + + Register new version + +
+ +
+
+ ); +}; diff --git a/public/components/model_drawer/version_table.tsx b/public/components/model_drawer/version_table.tsx new file mode 100644 index 00000000..959f57af --- /dev/null +++ b/public/components/model_drawer/version_table.tsx @@ -0,0 +1,89 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React, { useMemo, useCallback, useRef } from 'react'; +import { generatePath, useHistory } from 'react-router-dom'; +import { EuiBasicTable, Direction, Criteria } from '@elastic/eui'; + +import { ModelSearchItem } from '../../apis/model'; +import { routerPaths } from '../../../common/router_paths'; +import { renderTime } from '../../utils'; +import type { VersionTableSort } from './'; + +export interface VersionTableCriteria { + sort?: VersionTableSort; +} + +export function VersionTable(props: { + models: ModelSearchItem[]; + sort: VersionTableSort; + onChange: (criteria: VersionTableCriteria) => void; +}) { + const { models, sort, onChange } = props; + const history = useHistory(); + const onChangeRef = useRef(onChange); + onChangeRef.current = onChange; + + const columns = useMemo( + () => [ + { + field: 'version', + name: 'Version', + sortable: true, + }, + { + field: 'state', + name: 'Stage', + }, + { + field: 'algorithm', + name: 'Algorithm', + }, + { + field: 'created_time', + name: 'Time', + render: renderTime, + sortable: true, + }, + ], + [] + ); + const rowProps = useCallback( + ({ id }) => ({ + onClick: () => { + history.push(generatePath(routerPaths.modelDetail, { id })); + }, + }), + [history] + ); + + const sorting = useMemo(() => { + const [field, direction] = sort.split('-'); + return { + sort: { + field: field as keyof ModelSearchItem, + direction: direction as Direction, + }, + }; + }, [sort]); + + const handleChange = useCallback(({ sort: newSort }: Criteria) => { + if (newSort) { + onChangeRef.current({ + sort: `${newSort.field}-${newSort.direction}` as VersionTableSort, + }); + } + }, []); + + return ( + + ); +} diff --git a/public/components/model_list/__tests__/model_deployed_versions.test.tsx b/public/components/model_list/__tests__/model_deployed_versions.test.tsx new file mode 100644 index 00000000..83fe182a --- /dev/null +++ b/public/components/model_list/__tests__/model_deployed_versions.test.tsx @@ -0,0 +1,30 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React from 'react'; + +import { ModelDeployedVersions } from '../model_deployed_versions'; +import { render, screen } from '../../../../test/test_utils'; + +describe('', () => { + it('should render "-" when pass empty versions', () => { + render(); + expect(screen.getByText('-')).toBeInTheDocument(); + }); + + it('should render displayed three versions', () => { + render(); + expect(screen.getByText('1, 2, 3')).toBeInTheDocument(); + }); + + it('should render displayed four versions', () => { + render(); + expect( + screen.getByText((_content, element) => { + return element?.tagName === 'SPAN' && element?.textContent === '1, 2, 3, + 2 more'; + }) + ).toBeInTheDocument(); + }); +}); diff --git a/public/components/model_list/__tests__/model_filter.test.tsx b/public/components/model_list/__tests__/model_filter.test.tsx new file mode 100644 index 00000000..7a246725 --- /dev/null +++ b/public/components/model_list/__tests__/model_filter.test.tsx @@ -0,0 +1,130 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React from 'react'; +import userEvent from '@testing-library/user-event'; + +import { ModelFilter } from '../model_filter'; +import { render, screen } from '../../../../test/test_utils'; + +describe('', () => { + afterEach(() => { + jest.resetAllMocks(); + }); + + it('should render "Tags" with 0 active filter', () => { + render( + {}} + /> + ); + expect(screen.getByText('Tags')).toBeInTheDocument(); + expect(screen.getByText('0')).toBeInTheDocument(); + }); + + it('should render Tags with 2 active filter', () => { + render( + {}} + /> + ); + expect(screen.getByText('Tags')).toBeInTheDocument(); + expect(screen.getByText('2')).toBeInTheDocument(); + }); + + it('should render options filter after click tags', async () => { + render( + {}} + /> + ); + expect(screen.queryByText('foo')).not.toBeInTheDocument(); + expect(screen.queryByPlaceholderText('Search Tags')).not.toBeInTheDocument(); + + await userEvent.click(screen.getByText('Tags')); + + expect(screen.getByText('foo')).toBeInTheDocument(); + expect(screen.getByPlaceholderText('Search Tags')).toBeInTheDocument(); + }); + + it('should only show "bar" after search', async () => { + render( + {}} + /> + ); + + await userEvent.click(screen.getByText('Tags')); + expect(screen.getByText('foo')).toBeInTheDocument(); + + await userEvent.type(screen.getByPlaceholderText('Search Tags'), 'bAr{enter}'); + expect(screen.queryByText('foo')).not.toBeInTheDocument(); + expect(screen.getByText('bar')).toBeInTheDocument(); + }); + + it('should call onChange with consistent value after option click', async () => { + const onChangeMock = jest.fn(); + const { rerender } = render( + + ); + + expect(onChangeMock).not.toHaveBeenCalled(); + + await userEvent.click(screen.getByText('Tags')); + await userEvent.click(screen.getByText('foo')); + expect(onChangeMock).toHaveBeenCalledWith(['foo']); + onChangeMock.mockClear(); + + rerender( + + ); + + await userEvent.click(screen.getByText('bar')); + expect(onChangeMock).toHaveBeenCalledWith(['foo', 'bar']); + onChangeMock.mockClear(); + + rerender( + + ); + + await userEvent.click(screen.getByText('bar')); + expect(onChangeMock).toHaveBeenCalledWith(['foo']); + onChangeMock.mockClear(); + }); +}); diff --git a/public/components/model_list/__tests__/model_filter_item.test.tsx b/public/components/model_list/__tests__/model_filter_item.test.tsx new file mode 100644 index 00000000..8b219960 --- /dev/null +++ b/public/components/model_list/__tests__/model_filter_item.test.tsx @@ -0,0 +1,34 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React from 'react'; +import userEvent from '@testing-library/user-event'; + +import { ModelFilterItem } from '../model_filter_item'; + +import { render, screen } from '../../../../test/test_utils'; + +describe('', () => { + it('should render passed children and check icon', () => { + render( + {}}> + foo + + ); + expect(screen.getByText('foo')).toBeInTheDocument(); + expect(screen.getByRole('img', { hidden: true })).toBeInTheDocument(); + }); + + it('should call onClick with "foo" after click', async () => { + const onClickMock = jest.fn(); + render( + + foo + + ); + await userEvent.click(screen.getByRole('option')); + expect(onClickMock).toHaveBeenCalledWith('foo'); + }); +}); diff --git a/public/components/model_list/__tests__/model_list_filter.test.tsx b/public/components/model_list/__tests__/model_list_filter.test.tsx new file mode 100644 index 00000000..020c5cf7 --- /dev/null +++ b/public/components/model_list/__tests__/model_list_filter.test.tsx @@ -0,0 +1,31 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React from 'react'; + +import { ModelListFilter } from '../model_list_filter'; +import { render, screen } from '../../../../test/test_utils'; + +describe('', () => { + it('should render default search bar with tag, stage and owner filter', () => { + render( {}} />); + expect(screen.getByPlaceholderText('Search by name, person, or keyword')).toBeInTheDocument(); + expect(screen.getByText('Tags')).toBeInTheDocument(); + expect(screen.getByText('Stage')).toBeInTheDocument(); + expect(screen.getByText('Owner')).toBeInTheDocument(); + }); + + it('should render default search value and filter value', () => { + render( + {}} + /> + ); + expect(screen.getByDisplayValue('foo')).toBeInTheDocument(); + expect(screen.queryAllByText('1')).toHaveLength(3); + }); +}); diff --git a/public/components/model_list/__tests__/model_owner.test.tsx b/public/components/model_list/__tests__/model_owner.test.tsx new file mode 100644 index 00000000..2fe2823b --- /dev/null +++ b/public/components/model_list/__tests__/model_owner.test.tsx @@ -0,0 +1,27 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React from 'react'; +import userEvent from '@testing-library/user-event'; + +import { ModelOwner } from '../model_owner'; +import { render, screen } from '../../../../test/test_utils'; + +describe('', () => { + it('should render avatar with name abbreviation', async () => { + render(); + expect(screen.getByText('FB')).toBeInTheDocument(); + }); + it('should show tooltip when hower', async () => { + jest.useFakeTimers(); + render(); + expect(screen.queryByText('Foo Bar')).not.toBeInTheDocument(); + + await userEvent.hover(screen.getByText('FB'), { delay: null }); + jest.advanceTimersByTime(1000); + expect(screen.getByText('Foo Bar')).toBeInTheDocument(); + jest.useRealTimers(); + }); +}); diff --git a/public/components/model_list/__tests__/model_table.test.tsx b/public/components/model_list/__tests__/model_table.test.tsx new file mode 100644 index 00000000..c2bb7e80 --- /dev/null +++ b/public/components/model_list/__tests__/model_table.test.tsx @@ -0,0 +1,155 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React from 'react'; +import moment from 'moment'; +import userEvent from '@testing-library/user-event'; + +import { ModelTable } from '../model_table'; +import { render, screen, within } from '../../../../test/test_utils'; +import { MODEL_STATE } from '../../../../common/model'; + +const tableData = [ + { + name: 'model1', + owner: 'foo', + latest_version: '5', + description: 'model 1 description', + latest_version_state: MODEL_STATE.loaded, + deployed_versions: ['1,2'], + created_time: Date.now(), + }, + { + name: 'model2', + owner: 'bar', + latest_version: '3', + description: 'model 2 description', + latest_version_state: MODEL_STATE.uploading, + deployed_versions: ['1,2'], + created_time: Date.now(), + }, +]; + +const setup = () => { + const onChangeMock = jest.fn(); + const onModelNameClickMock = jest.fn(); + const renderResult = render( + + ); + return { + renderResult, + onChangeMock, + onModelNameClickMock, + }; +}; + +describe('', () => { + it('should render consistent table header', () => { + setup(); + const tableHeaders = screen.queryAllByRole('columnheader'); + expect(within(tableHeaders[0]).getByText('Model Name')).toBeInTheDocument(); + expect(within(tableHeaders[1]).getByText('Latest version')).toBeInTheDocument(); + expect(within(tableHeaders[2]).getByText('Description')).toBeInTheDocument(); + expect(within(tableHeaders[3]).getByText('Owner')).toBeInTheDocument(); + expect(within(tableHeaders[4]).getByText('Deployed versions')).toBeInTheDocument(); + expect(within(tableHeaders[5]).getByText('Created at')).toBeInTheDocument(); + }); + + it('should render consistent table body', () => { + const { renderResult } = setup(); + const model1FirstCellContent = renderResult.getByText(tableData[0].name); + expect(model1FirstCellContent).toBeInTheDocument(); + const model1Cells = model1FirstCellContent.closest('tr')?.querySelectorAll('td'); + expect(model1Cells).not.toBeUndefined(); + expect(within(model1Cells!.item(1)).getByText(tableData[0].latest_version)).toBeInTheDocument(); + expect(within(model1Cells!.item(2)).getByText(tableData[0].description)).toBeInTheDocument(); + expect( + within(model1Cells!.item(3)).getByText(tableData[0].owner.slice(0, 1)) + ).toBeInTheDocument(); + expect( + within(model1Cells!.item(4)).getByText(tableData[0].deployed_versions.join(', ')) + ).toBeInTheDocument(); + expect( + within(model1Cells!.item(5)).getByText( + moment(tableData[0].created_time).format('MMM D, YYYY') + ) + ).toBeInTheDocument(); + + const model2FirstCellContent = renderResult.getByText('New model'); + expect(model2FirstCellContent).toBeInTheDocument(); + const model2Cells = model2FirstCellContent.closest('tr')?.querySelectorAll('td'); + expect(model2Cells).not.toBeUndefined(); + expect(within(model2Cells!.item(1)).getByRole('progressbar')).toBeInTheDocument(); + expect(within(model2Cells!.item(2)).getByText('...')).toBeInTheDocument(); + expect(within(model2Cells!.item(3)).getByRole('progressbar')).toBeInTheDocument(); + expect(within(model2Cells!.item(4)).getByText('updating')).toBeInTheDocument(); + expect(within(model2Cells!.item(5)).getByText('updating')).toBeInTheDocument(); + }); + + it('should call onChange with consistent params after pageSize change', async () => { + const { renderResult, onChangeMock } = setup(); + expect(onChangeMock).not.toHaveBeenCalled(); + await userEvent.click(renderResult.getByText(/Rows per page/)); + await userEvent.click(renderResult.getByText('50 rows')); + expect(onChangeMock).toHaveBeenCalledWith({ + pagination: { + currentPage: 1, + pageSize: 50, + }, + sort: { + field: 'created_time', + direction: 'desc', + }, + }); + }); + + it('should call onChange with consistent params after page change', async () => { + const { renderResult, onChangeMock } = setup(); + expect(onChangeMock).not.toHaveBeenCalled(); + await userEvent.click(renderResult.getByTestId('pagination-button-next')); + expect(onChangeMock).toHaveBeenCalledWith({ + pagination: { + currentPage: 2, + pageSize: 15, + }, + sort: { + field: 'created_time', + direction: 'desc', + }, + }); + }); + + it('should call onChange with consistent params after sort change', async () => { + const { renderResult, onChangeMock } = setup(); + expect(onChangeMock).not.toHaveBeenCalled(); + await userEvent.click(renderResult.getByTitle('Created at')); + expect(onChangeMock).toHaveBeenCalledWith({ + pagination: { + currentPage: 1, + pageSize: 15, + }, + sort: { + field: 'created_time', + direction: 'asc', + }, + }); + }); + + it('should call onModelNameClick with consistent params after model name click', async () => { + const { renderResult, onModelNameClickMock } = setup(); + expect(onModelNameClickMock).not.toHaveBeenCalled(); + await userEvent.click(renderResult.getByText('model1')); + expect(onModelNameClickMock).toHaveBeenCalledWith('model1'); + }); +}); diff --git a/public/components/model_list/__tests__/model_table_uploading_cell.test.tsx b/public/components/model_list/__tests__/model_table_uploading_cell.test.tsx new file mode 100644 index 00000000..45dc12a6 --- /dev/null +++ b/public/components/model_list/__tests__/model_table_uploading_cell.test.tsx @@ -0,0 +1,85 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React from 'react'; + +import { ModelTableUploadingCell } from '../model_table_uploading_cell'; +import { render, screen } from '../../../../test/test_utils'; +import { MODEL_STATE } from '../../../../common/model'; + +describe('', () => { + it('should render "updating" if column is deployedVersions or createdAt', () => { + const { rerender } = render( + } + latestVersionState={MODEL_STATE.uploading} + /> + ); + expect(screen.getByText('updating')).toBeInTheDocument(); + + rerender( + } + latestVersionState={MODEL_STATE.uploading} + /> + ); + expect(screen.getByText('updating')).toBeInTheDocument(); + }); + + it('should render loading spinner if column is latestVersion and owner', () => { + const { rerender } = render( + } + latestVersionState={MODEL_STATE.uploading} + /> + ); + expect(screen.getByRole('progressbar')).toBeInTheDocument(); + + rerender( + } + latestVersionState={MODEL_STATE.uploading} + /> + ); + expect(screen.getByRole('progressbar')).toBeInTheDocument(); + }); + + it('should render "New model" if column is name', () => { + render( + } + latestVersionState={MODEL_STATE.uploading} + /> + ); + expect(screen.getByText('New model')).toBeInTheDocument(); + }); + + it('should render "..." if column is description', () => { + render( + } + latestVersionState={MODEL_STATE.uploading} + /> + ); + expect(screen.getByText('...')).toBeInTheDocument(); + }); + + it('should render fallback if not uploading state', () => { + render( + Foo Bar} + latestVersionState={MODEL_STATE.loaded} + /> + ); + expect(screen.getByText('Foo Bar')).toBeInTheDocument(); + }); +}); diff --git a/public/components/model_list/__tests__/owner_filter.test.tsx b/public/components/model_list/__tests__/owner_filter.test.tsx new file mode 100644 index 00000000..8252702f --- /dev/null +++ b/public/components/model_list/__tests__/owner_filter.test.tsx @@ -0,0 +1,21 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React from 'react'; + +import { render, screen } from '../../../../test/test_utils'; +import { OwnerFilter } from '../owner_filter'; + +describe('', () => { + afterEach(() => { + jest.resetAllMocks(); + }); + + it('should render "Owner" with 0 active filter for normal', () => { + render( {}} />); + expect(screen.getByText('Owner')).toBeInTheDocument(); + expect(screen.getByText('0')).toBeInTheDocument(); + }); +}); diff --git a/public/components/model_list/__tests__/stage_filter.test.tsx b/public/components/model_list/__tests__/stage_filter.test.tsx new file mode 100644 index 00000000..9992b02c --- /dev/null +++ b/public/components/model_list/__tests__/stage_filter.test.tsx @@ -0,0 +1,21 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React from 'react'; + +import { render, screen } from '../../../../test/test_utils'; +import { StageFilter } from '../stage_filter'; + +describe('', () => { + afterEach(() => { + jest.resetAllMocks(); + }); + + it('should render "Stage" with 0 active filter for normal', () => { + render( {}} />); + expect(screen.getByText('Stage')).toBeInTheDocument(); + expect(screen.getByText('0')).toBeInTheDocument(); + }); +}); diff --git a/public/components/model_list/__tests__/tag_filter.test.tsx b/public/components/model_list/__tests__/tag_filter.test.tsx new file mode 100644 index 00000000..9280b727 --- /dev/null +++ b/public/components/model_list/__tests__/tag_filter.test.tsx @@ -0,0 +1,21 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React from 'react'; + +import { render, screen } from '../../../../test/test_utils'; +import { TagFilter } from '../tag_filter'; + +describe('', () => { + afterEach(() => { + jest.resetAllMocks(); + }); + + it('should render "Tags" with 0 active filter for normal', () => { + render( {}} />); + expect(screen.queryByText('Tags')).toBeInTheDocument(); + expect(screen.queryByText('0')).toBeInTheDocument(); + }); +}); diff --git a/public/components/model_list/index.tsx b/public/components/model_list/index.tsx new file mode 100644 index 00000000..c8d48dbe --- /dev/null +++ b/public/components/model_list/index.tsx @@ -0,0 +1,122 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React, { useState, useCallback, useMemo, useRef } from 'react'; +import { EuiPageHeader, EuiSpacer, EuiPanel } from '@elastic/eui'; + +import { CoreStart } from '../../../../../src/core/public'; +import { APIProvider } from '../../apis/api_provider'; +import { routerPaths } from '../../../common/router_paths'; +import { useFetcher } from '../../hooks/use_fetcher'; +import { ModelDrawer } from '../model_drawer'; +import { EuiLinkButton } from '../common'; + +import { ModelTable, ModelTableSort } from './model_table'; +import { ModelListFilter, ModelListFilterFilterValue } from './model_list_filter'; +import { + ModelConfirmDeleteModal, + ModelConfirmDeleteModalInstance, +} from './model_confirm_delete_modal'; + +export const ModelList = ({ notifications }: { notifications: CoreStart['notifications'] }) => { + const confirmModelDeleteRef = useRef(null); + const [params, setParams] = useState<{ + sort: ModelTableSort; + currentPage: number; + pageSize: number; + filterValue: ModelListFilterFilterValue; + }>({ + currentPage: 1, + pageSize: 15, + filterValue: { tag: [], owner: [], stage: [] }, + sort: { field: 'created_time', direction: 'desc' }, + }); + const [drawerModelName, setDrawerModelName] = useState(''); + + const { data, reload } = useFetcher(APIProvider.getAPI('modelAggregate').search, { + from: Math.max(0, (params.currentPage - 1) * params.pageSize), + size: params.pageSize, + sort: params.sort?.field, + order: params.sort?.direction, + name: params.filterValue.search, + }); + const models = useMemo(() => data?.data || [], [data]); + const totalModelCounts = data?.total_models || 0; + + const pagination = useMemo( + () => ({ + currentPage: params.currentPage, + pageSize: params.pageSize, + totalRecords: totalModelCounts, + }), + [totalModelCounts, params.currentPage, params.pageSize] + ); + + const handleModelDeleted = useCallback(async () => { + reload(); + notifications.toasts.addSuccess('Model has been deleted.'); + }, [reload, notifications.toasts]); + + const handleModelDelete = useCallback((modelId: string) => { + confirmModelDeleteRef.current?.show(modelId); + }, []); + + const handleViewModelDrawer = useCallback((name: string) => { + setDrawerModelName(name); + }, []); + + const handleTableChange = useCallback((criteria) => { + const { + pagination: { currentPage, pageSize }, + sort, + } = criteria; + setParams((previousValue) => { + if ( + currentPage === previousValue.currentPage && + pageSize === previousValue.pageSize && + (!sort || sort === previousValue.sort) + ) { + return previousValue; + } + return { + ...previousValue, + currentPage, + pageSize, + ...(sort ? { sort } : {}), + }; + }); + }, []); + + const handleFilterChange = useCallback((filterValue: ModelListFilterFilterValue) => { + setParams((prevValue) => ({ ...prevValue, filterValue, currentPage: 1 })); + }, []); + + return ( + + Models} + rightSideItems={[ + + Register new model + , + ]} + /> + + + + + + {drawerModelName && ( + setDrawerModelName('')} name={drawerModelName} /> + )} + + ); +}; diff --git a/public/components/model_list/model_confirm_delete_modal.tsx b/public/components/model_list/model_confirm_delete_modal.tsx new file mode 100644 index 00000000..a224699b --- /dev/null +++ b/public/components/model_list/model_confirm_delete_modal.tsx @@ -0,0 +1,93 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React, { useCallback, useImperativeHandle, useRef, useState } from 'react'; +import { EuiConfirmModal } from '@elastic/eui'; +import { APIProvider } from '../../apis/api_provider'; +import { usePollingUntil } from '../../hooks/use_polling_until'; + +export class NoIdProvideError {} + +export interface ModelConfirmDeleteModalInstance { + show: (modelId: string) => void; +} + +export const ModelConfirmDeleteModal = React.forwardRef< + ModelConfirmDeleteModalInstance, + { onDeleted: () => void } +>(({ onDeleted }, ref) => { + const deleteIdRef = useRef(); + const [visible, setVisible] = useState(false); + const [isDeleting, setIsDeleting] = useState(false); + const { start: startPolling } = usePollingUntil({ + continueChecker: async () => { + if (!deleteIdRef.current) { + throw new NoIdProvideError(); + } + return ( + ( + await APIProvider.getAPI('model').search({ + ids: [deleteIdRef.current], + from: 0, + size: 1, + }) + ).total_models === 1 + ); + }, + onGiveUp: () => { + setIsDeleting(false); + setVisible(false); + onDeleted(); + }, + onMaxRetries: () => { + setIsDeleting(false); + setVisible(false); + }, + }); + + const handleConfirm = useCallback( + async (e) => { + if (!deleteIdRef.current) { + throw new NoIdProvideError(); + } + e.stopPropagation(); + setIsDeleting(true); + await APIProvider.getAPI('model').delete(deleteIdRef.current); + startPolling(); + }, + [startPolling] + ); + + const handleCancel = useCallback(() => { + setVisible(false); + deleteIdRef.current = undefined; + }, []); + + useImperativeHandle( + ref, + () => ({ + show: (id: string) => { + deleteIdRef.current = id; + setVisible(true); + }, + }), + [] + ); + + if (!visible) { + return null; + } + + return ( + + ); +}); diff --git a/public/components/model_list/model_deployed_versions.tsx b/public/components/model_list/model_deployed_versions.tsx new file mode 100644 index 00000000..706a8c7f --- /dev/null +++ b/public/components/model_list/model_deployed_versions.tsx @@ -0,0 +1,27 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React from 'react'; + +const DISPLAY_VERSION = 3; + +export const ModelDeployedVersions = ({ versions }: { versions: string[] }) => { + if (versions.length === 0) { + return -; + } + const appendMore = versions.length > DISPLAY_VERSION; + + return ( + + + {versions.slice(0, DISPLAY_VERSION).join(', ')} + {appendMore ? ', ' : ''} + + {appendMore && ( + + {versions.length - DISPLAY_VERSION} more + )} + + ); +}; diff --git a/public/components/model_list/model_filter.tsx b/public/components/model_list/model_filter.tsx new file mode 100644 index 00000000..c01e7227 --- /dev/null +++ b/public/components/model_list/model_filter.tsx @@ -0,0 +1,98 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React, { useCallback, useMemo, useRef, useState } from 'react'; +import { EuiPopover, EuiPopoverTitle, EuiFieldSearch, EuiFilterButton } from '@elastic/eui'; +import { ModelFilterItem } from './model_filter_item'; + +export interface ModelFilterProps { + name: string; + searchPlaceholder: string; + options: Array; + value: string[]; + onChange: (value: string[]) => void; +} + +export const ModelFilter = ({ + name, + value, + options, + searchPlaceholder, + onChange, +}: ModelFilterProps) => { + const valueRef = useRef(value); + valueRef.current = value; + const onChangeRef = useRef(onChange); + onChangeRef.current = onChange; + const [isPopoverOpen, setIsPopoverOpen] = useState(false); + const [searchText, setSearchText] = useState(); + + const filteredOptions = useMemo( + () => + searchText + ? options.filter((option) => + (typeof option === 'string' ? option : option.name) + .toLowerCase() + .includes(searchText.toLowerCase()) + ) + : options, + [searchText, options] + ); + + const hadleButtonClick = useCallback(() => { + setIsPopoverOpen((prevState) => !prevState); + }, []); + + const closePopover = useCallback(() => { + setIsPopoverOpen(false); + }, []); + + const handleFilterItemClick = useCallback((clickItemValue: string) => { + onChangeRef.current( + valueRef.current.includes(clickItemValue) + ? valueRef.current.filter((item) => item !== clickItemValue) + : valueRef.current.concat(clickItemValue) + ); + }, []); + + return ( + 0} + numActiveFilters={value.length} + > + {name} + + } + isOpen={isPopoverOpen} + closePopover={closePopover} + panelPaddingSize="none" + > + + + + {filteredOptions.map((item, index) => { + const itemValue = typeof item === 'string' ? item : item.value; + const checked = value.includes(itemValue) ? 'on' : undefined; + return ( + + {typeof item === 'string' ? item : item.name} + + ); + })} + + ); +}; diff --git a/public/components/model_list/model_filter_item.tsx b/public/components/model_list/model_filter_item.tsx new file mode 100644 index 00000000..c3fbf809 --- /dev/null +++ b/public/components/model_list/model_filter_item.tsx @@ -0,0 +1,24 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React, { useCallback } from 'react'; +import { EuiFilterSelectItem, EuiFilterSelectItemProps } from '@elastic/eui'; + +export interface ModelFilterItemProps + extends Pick { + value: string; + onClick: (value: string) => void; +} + +export const ModelFilterItem = ({ checked, children, onClick, value }: ModelFilterItemProps) => { + const handleClick = useCallback(() => { + onClick(value); + }, [onClick, value]); + return ( + + {children} + + ); +}; diff --git a/public/components/model_list/model_list_filter.tsx b/public/components/model_list/model_list_filter.tsx new file mode 100644 index 00000000..fb621b03 --- /dev/null +++ b/public/components/model_list/model_list_filter.tsx @@ -0,0 +1,71 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import { EuiFlexItem, EuiFlexGroup, EuiFieldSearch, EuiFilterGroup } from '@elastic/eui'; +import React, { useCallback, useRef } from 'react'; + +import { TagFilter } from './tag_filter'; +import { OwnerFilter } from './owner_filter'; +import { StageFilter } from './stage_filter'; + +export interface ModelListFilterFilterValue { + search?: string; + tag: string[]; + owner: string[]; + stage: string[]; +} + +export const ModelListFilter = ({ + value, + onChange, + defaultSearch, +}: { + defaultSearch?: string; + value: Omit; + onChange: (value: ModelListFilterFilterValue) => void; +}) => { + const valueRef = useRef(value); + valueRef.current = value; + const onChangeRef = useRef(onChange); + onChangeRef.current = onChange; + + const handleSearch = useCallback((search) => { + onChangeRef.current({ ...valueRef.current, search }); + }, []); + + const handleTagChange = useCallback((tag: string[]) => { + onChangeRef.current({ ...valueRef.current, tag }); + }, []); + + const handleOwnerChange = useCallback((owner: string[]) => { + onChangeRef.current({ ...valueRef.current, owner }); + }, []); + + const handleStageChange = useCallback((stage: string[]) => { + onChangeRef.current({ ...valueRef.current, stage }); + }, []); + + return ( + <> + + + + + + + + + + + + + + ); +}; diff --git a/public/components/model_list/model_owner.tsx b/public/components/model_list/model_owner.tsx new file mode 100644 index 00000000..91c3d38c --- /dev/null +++ b/public/components/model_list/model_owner.tsx @@ -0,0 +1,15 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React from 'react'; +import { EuiAvatar, EuiToolTip } from '@elastic/eui'; + +export function ModelOwner({ name }: { name: string }) { + return ( + + + + ); +} diff --git a/public/components/model_list/model_table.tsx b/public/components/model_list/model_table.tsx new file mode 100644 index 00000000..efaf46c4 --- /dev/null +++ b/public/components/model_list/model_table.tsx @@ -0,0 +1,167 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React, { useMemo, useCallback, useRef } from 'react'; +import { + CriteriaWithPagination, + EuiBasicTable, + EuiBasicTableColumn, + EuiText, + Direction, +} from '@elastic/eui'; + +import { renderTime } from '../../utils'; +import { ModelOwner } from './model_owner'; +import { ModelDeployedVersions } from './model_deployed_versions'; +import { ModelTableUploadingCell } from './model_table_uploading_cell'; +import { ModelAggregateSearchItem } from '../../apis/model_aggregate'; + +export interface ModelTableSort { + field: 'created_time'; + direction: Direction; +} + +export interface ModelTableCriteria { + pagination: { currentPage: number; pageSize: number }; + sort?: ModelTableSort; +} + +export interface ModelTableProps { + models: ModelAggregateSearchItem[]; + pagination: { + currentPage: number; + pageSize: number; + totalRecords: number | undefined; + }; + sort: ModelTableSort; + onChange: (criteria: ModelTableCriteria) => void; + onModelNameClick: (name: string) => void; +} + +export function ModelTable(props: ModelTableProps) { + const { models, sort, onChange, onModelNameClick } = props; + const onChangeRef = useRef(onChange); + onChangeRef.current = onChange; + + const columns = useMemo>>( + () => [ + { + field: 'name', + name: 'Model Name', + width: '266px', + render: (name: string, record) => ( + { + onModelNameClick(name); + }} + style={{ color: '#006BB4' }} + > + {name} + + } + latestVersionState={record.latest_version_state} + column="name" + /> + ), + }, + { + field: 'latest_version', + name: 'Latest version', + width: '98px', + align: 'center', + render: (latestVersion: string, record) => ( + {latestVersion}} + latestVersionState={record.latest_version_state} + column="latestVersion" + /> + ), + }, + { + field: 'description', + name: 'Description', + render: (description: string, record) => ( + {description}} + latestVersionState={record.latest_version_state} + column="description" + /> + ), + }, + { + field: 'owner', + name: 'Owner', + width: '79px', + render: (owner: string, record) => ( + } + latestVersionState={record.latest_version_state} + column="owner" + /> + ), + align: 'center', + }, + { + field: 'deployed_versions', + name: 'Deployed versions', + render: (deployedVersions: string[], record) => ( + } + latestVersionState={record.latest_version_state} + column="deployedVersions" + /> + ), + }, + { + field: 'created_time', + name: 'Created at', + render: (createdTime: string, record) => ( + {renderTime(createdTime, 'MMM D, YYYY')}} + latestVersionState={record.latest_version_state} + column="createdAt" + /> + ), + sortable: true, + }, + ], + [onModelNameClick] + ); + + const pagination = useMemo( + () => ({ + pageIndex: props.pagination.currentPage - 1, + pageSize: props.pagination.pageSize, + totalItemCount: props.pagination.totalRecords || 0, + pageSizeOptions: [15, 30, 50, 100], + showPerPageOptions: true, + }), + [props.pagination] + ); + + const sorting = useMemo(() => ({ sort }), [sort]); + + const handleChange = useCallback((criteria: CriteriaWithPagination) => { + const newPagination = { currentPage: criteria.page.index + 1, pageSize: criteria.page.size }; + + onChangeRef.current({ + pagination: newPagination, + ...(criteria.sort ? { sort: criteria.sort as ModelTableSort } : {}), + }); + }, []); + + return ( + + columns={columns} + items={models} + pagination={pagination} + onChange={handleChange} + sorting={sorting} + hasActions + /> + ); +} diff --git a/public/components/model_list/model_table_uploading_cell.tsx b/public/components/model_list/model_table_uploading_cell.tsx new file mode 100644 index 00000000..a62b015f --- /dev/null +++ b/public/components/model_list/model_table_uploading_cell.tsx @@ -0,0 +1,49 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React from 'react'; +import { EuiLoadingSpinner, EuiText } from '@elastic/eui'; +import { MODEL_STATE } from '../../../common/model'; + +type ColumnType = + | 'name' + | 'latestVersion' + | 'description' + | 'owner' + | 'deployedVersions' + | 'createdAt'; + +const getUploadingText = (column: ColumnType) => { + switch (column) { + case 'name': + return 'New model'; + case 'description': + return '...'; + default: + return 'updating'; + } +}; + +export const ModelTableUploadingCell = ({ + column, + fallback, + latestVersionState, +}: { + column: ColumnType; + latestVersionState: MODEL_STATE; + fallback: JSX.Element; +}) => { + if (latestVersionState !== MODEL_STATE.uploading) { + return fallback; + } + if (column === 'latestVersion' || column === 'owner') { + return ; + } + return ( + + {getUploadingText(column)} + + ); +}; diff --git a/public/components/model_list/owner_filter.tsx b/public/components/model_list/owner_filter.tsx new file mode 100644 index 00000000..60254bf2 --- /dev/null +++ b/public/components/model_list/owner_filter.tsx @@ -0,0 +1,19 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React from 'react'; +import { ModelFilter, ModelFilterProps } from './model_filter'; + +export const OwnerFilter = ({ value, onChange }: Pick) => { + return ( + + ); +}; diff --git a/public/components/model_list/stage_filter.tsx b/public/components/model_list/stage_filter.tsx new file mode 100644 index 00000000..9acb92e4 --- /dev/null +++ b/public/components/model_list/stage_filter.tsx @@ -0,0 +1,19 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React from 'react'; +import { ModelFilter, ModelFilterProps } from './model_filter'; + +export const StageFilter = ({ value, onChange }: Pick) => { + return ( + + ); +}; diff --git a/public/components/model_list/tag_filter.tsx b/public/components/model_list/tag_filter.tsx new file mode 100644 index 00000000..7f7d679d --- /dev/null +++ b/public/components/model_list/tag_filter.tsx @@ -0,0 +1,19 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React from 'react'; +import { ModelFilter, ModelFilterProps } from './model_filter'; + +export const TagFilter = ({ value, onChange }: Pick) => { + return ( + + ); +}; diff --git a/public/components/nav_panel.tsx b/public/components/nav_panel.tsx new file mode 100644 index 00000000..7a779d8c --- /dev/null +++ b/public/components/nav_panel.tsx @@ -0,0 +1,32 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React, { useCallback, useMemo } from 'react'; +import { EuiSideNav } from '@elastic/eui'; +import { generatePath, Link, matchPath, useLocation } from 'react-router-dom'; + +import { ROUTES } from '../../common/router'; + +export function NavPanel() { + const location = useLocation(); + const items = useMemo( + () => + ROUTES.filter((item) => !!item.label).map((item) => { + const href = generatePath(item.path); + return { + id: href, + name: item.label, + href, + isSelected: matchPath(location.pathname, { path: item.path, exact: item.exact }) !== null, + }; + }), + [location.pathname] + ); + const renderItem = useCallback( + ({ href, ...restProps }) => , + [] + ); + return ; +} diff --git a/public/components/primitive_combo_box.tsx b/public/components/primitive_combo_box.tsx new file mode 100644 index 00000000..91961563 --- /dev/null +++ b/public/components/primitive_combo_box.tsx @@ -0,0 +1,92 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React, { useMemo, useCallback } from 'react'; +import { CommonProps, EuiComboBox, EuiComboBoxProps } from '@elastic/eui'; + +export type OptionWithCommonProps = { value: T } & CommonProps; + +export type PrimitiveComboBoxProps = Omit< + EuiComboBoxProps, + 'options' | 'selectedOptions' | 'onChange' | 'singleSelection' +> & { + options: Array>; + attachOptionTestSubj?: boolean; +} & ( + | { + multi?: false; + value: T | undefined; + onChange: (value: T | undefined) => void; + } + | { + multi: true; + value: T[] | undefined; + onChange: (value: T[] | undefined) => void; + } + ); + +export const PrimitiveComboBox = ({ + multi, + value, + onChange, + options: optionsInProps, + attachOptionTestSubj, + 'data-test-subj': parentDataTestSubj, + ...restProps +}: PrimitiveComboBoxProps) => { + const options = useMemo( + () => + optionsInProps.map((option) => + typeof option === 'object' + ? { label: option.value.toString(), ...option } + : { + label: option.toString(), + value: option, + ...(attachOptionTestSubj + ? { + 'data-test-subj': `${ + parentDataTestSubj ? `${parentDataTestSubj}-` : '' + }${option.toString()}`, + } + : {}), + } + ), + [optionsInProps, attachOptionTestSubj, parentDataTestSubj] + ); + const selectedOptions = useMemo(() => { + if (multi) { + return options.filter((option) => value?.includes(option.value)); + } + return options.filter((option) => value === option.value); + }, [multi, value, options]); + + const handleChange = useCallback>['onChange']>( + (newOptions) => { + const result: T[] = []; + newOptions.forEach((item) => { + if (item.value !== undefined) { + result.push(item.value); + } + }); + if (multi) { + onChange(result.length === 0 ? undefined : result); + return; + } + onChange(result[0]); + }, + [multi, onChange] + ); + + return ( + + ); +}; diff --git a/public/components/register_model/__tests__/register_model_artifact.test.tsx b/public/components/register_model/__tests__/register_model_artifact.test.tsx new file mode 100644 index 00000000..7ac40b71 --- /dev/null +++ b/public/components/register_model/__tests__/register_model_artifact.test.tsx @@ -0,0 +1,56 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import { screen } from '../../../../test/test_utils'; +import { setup } from './setup'; + +describe(' Artifact', () => { + it('should render an artifact panel', async () => { + const onSubmitMock = jest.fn(); + const result = await setup({ onSubmit: onSubmitMock }); + expect(result.modelFileInput).toBeInTheDocument(); + expect(screen.getByLabelText(/from computer/i)).toBeInTheDocument(); + expect(screen.getByLabelText(/from url/i)).toBeInTheDocument(); + }); + + it('should submit the register model form', async () => { + const onSubmitMock = jest.fn(); + const result = await setup({ onSubmit: onSubmitMock }); + expect(onSubmitMock).not.toHaveBeenCalled(); + + await result.user.click(result.submitButton); + + expect(onSubmitMock).toHaveBeenCalled(); + }); + + it('should NOT submit the register model form if model file is empty', async () => { + const onSubmitMock = jest.fn(); + const result = await setup({ onSubmit: onSubmitMock }); + + // Empty model file selection by clicking the `Remove` button on EuiFilePicker + await result.user.click(screen.getByLabelText(/clear selected files/i)); + await result.user.click(result.submitButton); + + expect(result.modelFileInput).toBeInvalid(); + expect(onSubmitMock).not.toHaveBeenCalled(); + }); + + it('should NOT submit the register model form if model url is empty', async () => { + const onSubmitMock = jest.fn(); + const result = await setup({ onSubmit: onSubmitMock }); + + // select option: From URL + await result.user.click(screen.getByLabelText(/from url/i)); + + const urlInput = screen.getByLabelText(/model url/i); + + // Empty URL input + await result.user.clear(urlInput); + await result.user.click(result.submitButton); + + expect(urlInput).toBeInvalid(); + expect(onSubmitMock).not.toHaveBeenCalled(); + }); +}); diff --git a/public/components/register_model/__tests__/register_model_details.test.tsx b/public/components/register_model/__tests__/register_model_details.test.tsx new file mode 100644 index 00000000..eb010b65 --- /dev/null +++ b/public/components/register_model/__tests__/register_model_details.test.tsx @@ -0,0 +1,52 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import { setup } from './setup'; + +describe(' Details', () => { + it('should render a model details panel', async () => { + const onSubmitMock = jest.fn(); + const result = await setup({ onSubmit: onSubmitMock }); + expect(result.nameInput).toBeInTheDocument(); + // Model version is not editable + expect(result.versionInput).toBeDisabled(); + // Model Version should alway have a value + expect(result.versionInput.value).not.toBe(''); + expect(result.descriptionInput).toBeInTheDocument(); + expect(result.annotationsInput).toBeInTheDocument(); + }); + + it('should submit the register model form', async () => { + const onSubmitMock = jest.fn(); + const result = await setup({ onSubmit: onSubmitMock }); + expect(onSubmitMock).not.toHaveBeenCalled(); + + await result.user.click(result.submitButton); + + expect(onSubmitMock).toHaveBeenCalled(); + }); + + it('should NOT submit the register model form if model name is empty', async () => { + const onSubmitMock = jest.fn(); + const result = await setup({ onSubmit: onSubmitMock }); + + await result.user.clear(result.nameInput); + await result.user.click(result.submitButton); + + expect(result.nameInput).toBeInvalid(); + expect(onSubmitMock).not.toHaveBeenCalled(); + }); + + it('should NOT submit the register model form if model description is empty', async () => { + const onSubmitMock = jest.fn(); + const result = await setup({ onSubmit: onSubmitMock }); + + await result.user.clear(result.descriptionInput); + await result.user.click(result.submitButton); + + expect(result.descriptionInput).toBeInvalid(); + expect(onSubmitMock).not.toHaveBeenCalled(); + }); +}); diff --git a/public/components/register_model/__tests__/register_model_metrics.test.tsx b/public/components/register_model/__tests__/register_model_metrics.test.tsx new file mode 100644 index 00000000..1314de4e --- /dev/null +++ b/public/components/register_model/__tests__/register_model_metrics.test.tsx @@ -0,0 +1,119 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import { screen } from '../../../../test/test_utils'; +import { setup } from './setup'; +import * as formHooks from '../register_model.hooks'; + +describe(' Evaluation Metrics', () => { + beforeEach(() => { + jest + .spyOn(formHooks, 'useMetricNames') + .mockReturnValue([false, ['Metric 1', 'Metric 2', 'Metric 3', 'Metric 4']]); + }); + + it('should render a evaluation metrics panel', async () => { + const onSubmitMock = jest.fn(); + const result = await setup({ onSubmit: onSubmitMock }); + + expect(result.metricNameInput).toBeInTheDocument(); + expect(result.trainingMetricValueInput).toBeInTheDocument(); + expect(result.validationMetricValueInput).toBeInTheDocument(); + expect(result.testingMetricValueInput).toBeInTheDocument(); + }); + + it('should render metric value input as disabled by default', async () => { + const onSubmitMock = jest.fn(); + const result = await setup({ onSubmit: onSubmitMock }); + + expect(result.trainingMetricValueInput).toBeDisabled(); + expect(result.validationMetricValueInput).toBeDisabled(); + expect(result.testingMetricValueInput).toBeDisabled(); + }); + + it('should render metric value input as enabled after selecting a metric name', async () => { + const onSubmitMock = jest.fn(); + const result = await setup({ onSubmit: onSubmitMock }); + + await result.user.click(result.metricNameInput); + await result.user.click(screen.getByText('Metric 1')); + + expect(result.trainingMetricValueInput).toBeEnabled(); + expect(result.validationMetricValueInput).toBeEnabled(); + expect(result.testingMetricValueInput).toBeEnabled(); + }); + + it('should submit the form without selecting metric name', async () => { + const onSubmitMock = jest.fn(); + const result = await setup({ onSubmit: onSubmitMock }); + await result.user.click(result.submitButton); + + expect(onSubmitMock).toHaveBeenCalled(); + }); + + it('should submit the form if metric name is selected but metric value are empty', async () => { + const onSubmitMock = jest.fn(); + const result = await setup({ onSubmit: onSubmitMock }); + await result.user.click(result.metricNameInput); + await result.user.click(screen.getByText('Metric 1')); + await result.user.click(result.submitButton); + + expect(onSubmitMock).toHaveBeenCalled(); + }); + + it('should submit the form if metric name and all metric value are selected', async () => { + const onSubmitMock = jest.fn(); + const result = await setup({ onSubmit: onSubmitMock }); + await result.user.click(result.metricNameInput); + await result.user.click(screen.getByText('Metric 1')); + + await result.user.type(result.trainingMetricValueInput, '1'); + await result.user.type(result.validationMetricValueInput, '1'); + await result.user.type(result.testingMetricValueInput, '1'); + + await result.user.click(result.submitButton); + + expect(onSubmitMock).toHaveBeenCalled(); + }); + + it('should submit the form if metric name is selected but metric value are partially selected', async () => { + const onSubmitMock = jest.fn(); + const result = await setup({ onSubmit: onSubmitMock }); + await result.user.click(result.metricNameInput); + await result.user.click(screen.getByText('Metric 1')); + + // Only input Training metric value + await result.user.type(result.trainingMetricValueInput, '1'); + await result.user.click(result.submitButton); + + expect(onSubmitMock).toHaveBeenCalled(); + }); + + it('should NOT submit the form if metric value < 0', async () => { + const onSubmitMock = jest.fn(); + const result = await setup({ onSubmit: onSubmitMock }); + await result.user.click(result.metricNameInput); + await result.user.click(screen.getByText('Metric 1')); + + // Type an invalid value + await result.user.type(result.trainingMetricValueInput, '-.1'); + await result.user.click(result.submitButton); + + expect(onSubmitMock).not.toHaveBeenCalled(); + }); + + it('should NOT submit the form if metric value > 1', async () => { + const onSubmitMock = jest.fn(); + const result = await setup({ onSubmit: onSubmitMock }); + await result.user.click(result.metricNameInput); + await result.user.click(screen.getByText('Metric 1')); + + // Type an invalid value + await result.user.type(result.trainingMetricValueInput, '1.1'); + await result.user.click(result.submitButton); + + expect(onSubmitMock).not.toHaveBeenCalled(); + }); +}); diff --git a/public/components/register_model/__tests__/register_model_tags.test.tsx b/public/components/register_model/__tests__/register_model_tags.test.tsx new file mode 100644 index 00000000..7ccc22f7 --- /dev/null +++ b/public/components/register_model/__tests__/register_model_tags.test.tsx @@ -0,0 +1,96 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import { screen } from '../../../../test/test_utils'; +import { setup } from './setup'; +import * as formHooks from '../register_model.hooks'; + +describe(' Tags', () => { + beforeEach(() => { + jest + .spyOn(formHooks, 'useModelTags') + .mockReturnValue([false, { keys: ['Key1', 'Key2'], values: ['Value1', 'Value2'] }]); + }); + + it('should render a tags panel', async () => { + const onSubmitMock = jest.fn(); + const result = await setup({ onSubmit: onSubmitMock }); + + expect(result.tagKeyInput).toBeInTheDocument(); + expect(result.tagValueInput).toBeInTheDocument(); + }); + + it('should submit the form without selecting tags', async () => { + const onSubmitMock = jest.fn(); + const result = await setup({ onSubmit: onSubmitMock }); + await result.user.click(result.submitButton); + + expect(onSubmitMock).toHaveBeenCalled(); + }); + + it('should submit the form with selected tags', async () => { + const onSubmitMock = jest.fn(); + const result = await setup({ onSubmit: onSubmitMock }); + + await result.user.type(result.tagKeyInput, 'Key1'); + await result.user.type(result.tagValueInput, 'Value1'); + + await result.user.click(result.submitButton); + + expect(onSubmitMock).toHaveBeenCalledWith( + expect.objectContaining({ tags: [{ key: 'Key1', value: 'Value1' }] }) + ); + }); + + it('should allow to add multiple tags', async () => { + const onSubmitMock = jest.fn(); + const result = await setup({ onSubmit: onSubmitMock }); + + // Add two tags + await result.user.click(screen.getByText(/add new tag/i)); + await result.user.click(screen.getByText(/add new tag/i)); + + expect( + screen.getAllByText(/select or add a key/i, { selector: '.euiComboBoxPlaceholder' }) + ).toHaveLength(3); + + await result.user.click(result.submitButton); + + expect(onSubmitMock).toHaveBeenCalledWith( + expect.objectContaining({ + tags: [ + { key: '', value: '' }, + { key: '', value: '' }, + { key: '', value: '' }, + ], + }) + ); + }); + + it('should allow to remove multiple tags', async () => { + const onSubmitMock = jest.fn(); + const result = await setup({ onSubmit: onSubmitMock }); + + // Add two tags + await result.user.click(screen.getByText(/add new tag/i)); + await result.user.click(screen.getByText(/add new tag/i)); + + expect( + screen.getAllByText(/select or add a key/i, { selector: '.euiComboBoxPlaceholder' }) + ).toHaveLength(3); + + // Remove 2n tag, and 1st tag + await result.user.click(screen.getByLabelText(/remove tag at row 2/i)); + await result.user.click(screen.getByLabelText(/remove tag at row 1/i)); + + await result.user.click(result.submitButton); + + expect(onSubmitMock).toHaveBeenCalledWith( + expect.objectContaining({ + tags: [{ key: '', value: '' }], + }) + ); + }); +}); diff --git a/public/components/register_model/__tests__/setup.tsx b/public/components/register_model/__tests__/setup.tsx new file mode 100644 index 00000000..806fb5a1 --- /dev/null +++ b/public/components/register_model/__tests__/setup.tsx @@ -0,0 +1,60 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React from 'react'; +import userEvent from '@testing-library/user-event'; + +import { RegisterModelForm } from '../register_model'; +import type { RegisterModelFormProps } from '../register_model'; +import { render, screen } from '../../../../test/test_utils'; + +export async function setup({ onSubmit }: RegisterModelFormProps) { + render(); + const nameInput = screen.getByLabelText(/model name/i); + const versionInput = screen.getByLabelText(/version/i); + const descriptionInput = screen.getByLabelText(/model description/i); + const annotationsInput = screen.getByLabelText(/annotations\(optional\)/i); + const submitButton = screen.getByRole('button', { + name: /register model/i, + }); + const modelFileInput = screen.getByLabelText(/model file/i); + const configurationInput = screen.getByLabelText(/configuration object/i); + const metricNameInput = screen.getByLabelText(/metric name/i); + const trainingMetricValueInput = screen.getByLabelText(/training metric value/i); + const validationMetricValueInput = screen.getByLabelText(/validation metric value/i); + const testingMetricValueInput = screen.getByLabelText(/testing metric value/i); + const tagKeyInput = screen.getByLabelText(/^key$/i); + const tagValueInput = screen.getByLabelText(/^value$/i); + const form = screen.getByTestId('mlCommonsPlugin-registerModelForm'); + const user = userEvent.setup(); + + // fill model name + await user.type(nameInput, 'test model name'); + // fill model description + await user.type(descriptionInput, 'test model description'); + // fill model file + await user.upload( + modelFileInput, + new File(['test model file'], 'model.zip', { type: 'application/zip' }) + ); + + return { + nameInput, + versionInput, + descriptionInput, + annotationsInput, + configurationInput, + submitButton, + modelFileInput, + metricNameInput, + trainingMetricValueInput, + validationMetricValueInput, + testingMetricValueInput, + tagKeyInput, + tagValueInput, + form, + user, + }; +} diff --git a/public/components/register_model/artifact.tsx b/public/components/register_model/artifact.tsx new file mode 100644 index 00000000..aa78435a --- /dev/null +++ b/public/components/register_model/artifact.tsx @@ -0,0 +1,69 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React, { useState } from 'react'; +import { + EuiFormRow, + EuiPanel, + EuiTitle, + EuiHorizontalRule, + htmlIdGenerator, + EuiSpacer, + EuiFlexGroup, + EuiFlexItem, + EuiCheckableCard, +} from '@elastic/eui'; +import type { Control } from 'react-hook-form'; + +import { FORM_ITEM_WIDTH } from './form_constants'; +import type { ModelFileFormData, ModelUrlFormData } from './register_model.types'; +import { ModelFileUploader } from './artifact_file'; +import { ArtifactUrl } from './artifact_url'; + +export const ArtifactPanel = (props: { + formControl: Control; +}) => { + const [selectedSource, setSelectedSource] = useState<'source_from_computer' | 'source_from_url'>( + 'source_from_computer' + ); + + return ( + + +

Artifact

+
+ + + + + setSelectedSource('source_from_computer')} + /> + + + setSelectedSource('source_from_url')} + /> + + + + + {selectedSource === 'source_from_computer' && ( + + )} + {selectedSource === 'source_from_url' && } +
+ ); +}; diff --git a/public/components/register_model/artifact_file.tsx b/public/components/register_model/artifact_file.tsx new file mode 100644 index 00000000..f342b21f --- /dev/null +++ b/public/components/register_model/artifact_file.tsx @@ -0,0 +1,42 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React from 'react'; +import { EuiFormRow, EuiFilePicker } from '@elastic/eui'; +import { useController } from 'react-hook-form'; +import type { Control } from 'react-hook-form'; + +import { FORM_ITEM_WIDTH } from './form_constants'; +import { ModelFileFormData, ModelUrlFormData } from './register_model.types'; + +export const ModelFileUploader = (props: { + formControl: Control; +}) => { + const modelFileFieldController = useController({ + name: 'modelFile', + control: props.formControl, + rules: { required: true }, + shouldUnregister: true, + }); + + return ( + + { + modelFileFieldController.field.onChange(fileList?.item(0)); + }} + /> + + ); +}; diff --git a/public/components/register_model/artifact_url.tsx b/public/components/register_model/artifact_url.tsx new file mode 100644 index 00000000..47749d6a --- /dev/null +++ b/public/components/register_model/artifact_url.tsx @@ -0,0 +1,48 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React from 'react'; +import { EuiFormRow, htmlIdGenerator, EuiFieldText } from '@elastic/eui'; +import { useController } from 'react-hook-form'; +import type { Control } from 'react-hook-form'; + +import { FORM_ITEM_WIDTH } from './form_constants'; +import type { ModelFileFormData, ModelUrlFormData } from './register_model.types'; +import { URL_REGEX } from '../../utils/regex'; + +export const ArtifactUrl = (props: { + formControl: Control; +}) => { + const modelUrlFieldController = useController({ + name: 'modelURL', + control: props.formControl, + rules: { + required: true, + pattern: URL_REGEX, + }, + shouldUnregister: true, + }); + + return ( + + + + ); +}; diff --git a/public/components/register_model/evaluation_metrics.tsx b/public/components/register_model/evaluation_metrics.tsx new file mode 100644 index 00000000..3a19d2a8 --- /dev/null +++ b/public/components/register_model/evaluation_metrics.tsx @@ -0,0 +1,157 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React, { useCallback, useMemo } from 'react'; +import { + EuiFormRow, + EuiPanel, + EuiTitle, + EuiHorizontalRule, + EuiComboBox, + EuiComboBoxOptionOption, + EuiFlexItem, + EuiFlexGroup, + EuiFieldNumber, + EuiSpacer, +} from '@elastic/eui'; +import { useController } from 'react-hook-form'; +import type { Control } from 'react-hook-form'; + +import { FORM_ITEM_WIDTH } from './form_constants'; +import type { ModelFileFormData, ModelUrlFormData } from './register_model.types'; +import { useMetricNames } from './register_model.hooks'; + +const METRIC_VALUE_STEP = 0.01; + +export const EvaluationMetricsPanel = (props: { + formControl: Control; +}) => { + const [metricNamesLoading, metricNames] = useMetricNames(); + + // TODO: this has to be hooked with data from BE API + const options = useMemo(() => { + return metricNames.map((n) => ({ label: n })); + }, [metricNames]); + + const metricFieldController = useController({ + name: 'metricName', + control: props.formControl, + }); + + const trainingMetricFieldController = useController({ + name: 'trainingMetricValue', + control: props.formControl, + rules: { + max: 1, + min: 0, + }, + }); + + const validationMetricFieldController = useController({ + name: 'validationMetricValue', + control: props.formControl, + rules: { + max: 1, + min: 0, + }, + }); + + const testingMetricFieldController = useController({ + name: 'testingMetricValue', + control: props.formControl, + rules: { + max: 1, + min: 0, + }, + }); + + const onMetricNameChange = useCallback( + (data: EuiComboBoxOptionOption[]) => { + if (data.length === 0) { + trainingMetricFieldController.field.onChange(''); + validationMetricFieldController.field.onChange(''); + testingMetricFieldController.field.onChange(''); + metricFieldController.field.onChange(''); + } else { + metricFieldController.field.onChange(data[0].label); + } + }, + [ + metricFieldController, + trainingMetricFieldController, + validationMetricFieldController, + testingMetricFieldController, + ] + ); + + const onCreateMetricName = useCallback( + (metricName: string) => { + metricFieldController.field.onChange(metricName); + }, + [metricFieldController] + ); + + const metricValueFields = [ + { label: 'Training metric value', controller: trainingMetricFieldController }, + { label: 'Validation metric value', controller: validationMetricFieldController }, + { label: 'Testing metric value', controller: testingMetricFieldController }, + ]; + + return ( + + +

+ Evaluation Metrics - optional +

+
+ + + + + + + {metricValueFields.map(({ label, controller }) => ( + + + + + + ))} + +
+ ); +}; diff --git a/public/components/register_model/form_constants.ts b/public/components/register_model/form_constants.ts new file mode 100644 index 00000000..c8068177 --- /dev/null +++ b/public/components/register_model/form_constants.ts @@ -0,0 +1,6 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +export const FORM_ITEM_WIDTH = 400; diff --git a/public/components/register_model/index.ts b/public/components/register_model/index.ts new file mode 100644 index 00000000..12ed6bea --- /dev/null +++ b/public/components/register_model/index.ts @@ -0,0 +1,6 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +export { RegisterModelForm } from './register_model'; diff --git a/public/components/register_model/model_configuration.tsx b/public/components/register_model/model_configuration.tsx new file mode 100644 index 00000000..fcc4dd07 --- /dev/null +++ b/public/components/register_model/model_configuration.tsx @@ -0,0 +1,95 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React, { useState } from 'react'; +import { + EuiFormRow, + EuiPanel, + EuiTitle, + EuiHorizontalRule, + EuiCodeEditor, + EuiText, + EuiButtonEmpty, + EuiFlyout, + EuiFlyoutHeader, + EuiFlyoutBody, +} from '@elastic/eui'; +import { useController } from 'react-hook-form'; +import type { Control } from 'react-hook-form'; + +import '../../ace-themes/sql_console.js'; +import { FORM_ITEM_WIDTH } from './form_constants'; +import type { ModelFileFormData, ModelUrlFormData } from './register_model.types'; + +function validateConfigurationObject(value: string) { + try { + JSON.parse(value.trim()); + } catch { + return false; + } + return true; +} + +export const ConfigurationPanel = (props: { + formControl: Control; +}) => { + const [isHelpVisible, setIsHelpVisible] = useState(false); + const configurationFieldController = useController({ + name: 'configuration', + control: props.formControl, + rules: { required: true, validate: validateConfigurationObject }, + }); + + return ( + + +

Configuration

+
+ + + setIsHelpVisible(true)} size="xs" color="primary"> + Help + + + } + > + configurationFieldController.field.onChange(value)} + setOptions={{ + fontSize: '14px', + enableBasicAutocompletion: true, + enableLiveAutocompletion: true, + }} + /> + + {isHelpVisible && ( + setIsHelpVisible(false)}> + + +

Help

+
+
+ + +

TODO

+
+
+
+ )} +
+ ); +}; diff --git a/public/components/register_model/model_details.tsx b/public/components/register_model/model_details.tsx new file mode 100644 index 00000000..acd16561 --- /dev/null +++ b/public/components/register_model/model_details.tsx @@ -0,0 +1,109 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React from 'react'; +import { + EuiFieldText, + EuiFieldNumber, + EuiFlexItem, + EuiFormRow, + EuiPanel, + EuiTitle, + EuiHorizontalRule, + EuiFlexGroup, + EuiTextArea, +} from '@elastic/eui'; +import { useController } from 'react-hook-form'; +import type { Control } from 'react-hook-form'; + +import { FORM_ITEM_WIDTH } from './form_constants'; +import type { ModelFileFormData, ModelUrlFormData } from './register_model.types'; + +export const ModelDetailsPanel = (props: { + formControl: Control; +}) => { + const nameFieldController = useController({ + name: 'name', + control: props.formControl, + rules: { required: true }, + }); + + const versionFieldController = useController({ + name: 'version', + control: props.formControl, + rules: { required: true }, + }); + + const descriptionFieldController = useController({ + name: 'description', + control: props.formControl, + rules: { required: true }, + }); + + const annotationsFieldController = useController({ + name: 'annotations', + control: props.formControl, + }); + + const { ref: nameInputRef, ...nameField } = nameFieldController.field; + const { ref: versionInputRef, ...versionField } = versionFieldController.field; + const { ref: descriptionInputRef, ...descriptionField } = descriptionFieldController.field; + const { ref: annotationsInputRef, ...annotationsField } = annotationsFieldController.field; + + return ( + + +

Model Details

+
+ + + + + + + + + + + + + + + + + + + + + + + + + +
+ ); +}; diff --git a/public/components/register_model/model_tags.tsx b/public/components/register_model/model_tags.tsx new file mode 100644 index 00000000..e3faf9db --- /dev/null +++ b/public/components/register_model/model_tags.tsx @@ -0,0 +1,53 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React, { useCallback } from 'react'; +import { EuiButton, EuiPanel, EuiTitle, EuiHorizontalRule, EuiSpacer } from '@elastic/eui'; +import { useFieldArray } from 'react-hook-form'; +import type { Control } from 'react-hook-form'; + +import type { ModelFileFormData, ModelUrlFormData } from './register_model.types'; +import { ModelTagField } from './tag_field'; +import { useModelTags } from './register_model.hooks'; + +export const ModelTagsPanel = (props: { + formControl: Control; +}) => { + const [, { keys, values }] = useModelTags(); + const { fields, append, remove } = useFieldArray({ + name: 'tags', + control: props.formControl, + }); + + const addNewTag = useCallback(() => { + append({ key: '', value: '' }); + }, [append]); + + return ( + + +

+ Tags - optional +

+
+ + {fields.map((field, index) => { + return ( + + ); + })} + + Add new tag +
+ ); +}; diff --git a/public/components/register_model/register_model.hooks.ts b/public/components/register_model/register_model.hooks.ts new file mode 100644 index 00000000..487b2510 --- /dev/null +++ b/public/components/register_model/register_model.hooks.ts @@ -0,0 +1,49 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import { useEffect, useState } from 'react'; + +const metricNames = ['Metric 1', 'Metric 2', 'Metric 3', 'Metric 4']; + +/** + * TODO: implement this function so that it retrieve metric names from BE + */ +export const useMetricNames = () => { + const [loading, setLoading] = useState(true); + + useEffect(() => { + const timeoutId = window.setTimeout(() => { + setLoading(false); + }, 1000); + + return () => { + window.clearTimeout(timeoutId); + }; + }, []); + + return [loading, metricNames] as const; +}; + +const keys = ['tag1', 'tag2']; +const values = ['value1', 'value2']; + +/** + * TODO: implement this function so that it retrieve tags from BE + */ +export const useModelTags = () => { + const [loading, setLoading] = useState(true); + + useEffect(() => { + const timeoutId = window.setTimeout(() => { + setLoading(false); + }, 1000); + + return () => { + window.clearTimeout(timeoutId); + }; + }, []); + + return [loading, { keys, values }] as const; +}; diff --git a/public/components/register_model/register_model.tsx b/public/components/register_model/register_model.tsx new file mode 100644 index 00000000..d44818b6 --- /dev/null +++ b/public/components/register_model/register_model.tsx @@ -0,0 +1,70 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React, { useCallback } from 'react'; +import { FieldErrors, useForm } from 'react-hook-form'; +import { EuiPageHeader, EuiSpacer, EuiForm, EuiButton } from '@elastic/eui'; + +import { ModelDetailsPanel } from './model_details'; +import type { ModelFileFormData, ModelUrlFormData } from './register_model.types'; +import { ArtifactPanel } from './artifact'; +import { ConfigurationPanel } from './model_configuration'; +import { EvaluationMetricsPanel } from './evaluation_metrics'; +import { ModelTagsPanel } from './model_tags'; + +export interface RegisterModelFormProps { + onSubmit?: (data: ModelFileFormData | ModelUrlFormData) => void; +} + +export const RegisterModelForm = (props: RegisterModelFormProps) => { + const { handleSubmit, control } = useForm({ + defaultValues: { + name: '', + description: '', + version: '1', + configuration: '{}', + tags: [{ key: '', value: '' }], + }, + }); + + const onSubmit = (data: ModelFileFormData | ModelUrlFormData) => { + if (props.onSubmit) { + props.onSubmit(data); + } + // TODO + // eslint-disable-next-line no-console + console.log(data); + }; + + const onError = useCallback((errors: FieldErrors) => { + // TODO + // eslint-disable-next-line no-console + console.log(errors); + }, []); + + return ( + + + + + + + + + + + + + + + Register model + + + ); +}; diff --git a/public/components/register_model/register_model.types.ts b/public/components/register_model/register_model.types.ts new file mode 100644 index 00000000..6f271f3f --- /dev/null +++ b/public/components/register_model/register_model.types.ts @@ -0,0 +1,36 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +export interface Tag { + key: string; + value: string; +} + +interface ModelFormBase { + name: string; + version: string; + description: string; + annotations: string; + configuration: string; + metricName?: string; + trainingMetricValue?: string; + validationMetricValue?: string; + testingMetricValue?: string; + tags?: Tag[]; +} + +/** + * The type of the register model form data via uploading a model file + */ +export interface ModelFileFormData extends ModelFormBase { + modelFile: File; +} + +/** + * The type of the register model form data via typing a model URL + */ +export interface ModelUrlFormData extends ModelFormBase { + modelURL: string; +} diff --git a/public/components/register_model/tag_field.tsx b/public/components/register_model/tag_field.tsx new file mode 100644 index 00000000..52e23acc --- /dev/null +++ b/public/components/register_model/tag_field.tsx @@ -0,0 +1,118 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import { + EuiButton, + EuiComboBox, + EuiComboBoxOptionOption, + EuiFlexGroup, + EuiFlexItem, + EuiFormRow, +} from '@elastic/eui'; +import React, { useCallback, useMemo } from 'react'; +import { Control, useController } from 'react-hook-form'; +import { FORM_ITEM_WIDTH } from './form_constants'; + +interface ModelTagFieldProps { + name: string; + index: number; + formControl: Control; + onDelete: (index: number) => void; + tagKeys: string[]; + tagValues: string[]; +} + +function getComboBoxValue(data: EuiComboBoxOptionOption[]) { + if (data.length === 0) { + return ''; + } else { + return data[0].label; + } +} + +export const ModelTagField = ({ + name, + formControl, + index, + tagKeys, + tagValues, + onDelete, +}: ModelTagFieldProps) => { + const tagKeyController = useController({ + name: `${name}.${index}.key`, + control: formControl, + }); + + const tagValueController = useController({ + name: `${name}.${index}.value`, + control: formControl, + }); + + const onKeyChange = useCallback( + (data: EuiComboBoxOptionOption[]) => { + tagKeyController.field.onChange(getComboBoxValue(data)); + }, + [tagKeyController.field] + ); + + const onValueChange = useCallback( + (data: EuiComboBoxOptionOption[]) => { + tagValueController.field.onChange(getComboBoxValue(data)); + }, + [tagValueController.field] + ); + + const keyOptions = useMemo(() => { + return tagKeys.map((key) => ({ label: key })); + }, [tagKeys]); + + const valueOptions = useMemo(() => { + return tagValues.map((value) => ({ label: value })); + }, [tagValues]); + + return ( + + + + + + + + + + + + + onDelete(index)}> + Remove + + + + ); +}; diff --git a/public/hooks/tests/use_polling_unit.test.ts b/public/hooks/tests/use_polling_unit.test.ts new file mode 100644 index 00000000..76e613b1 --- /dev/null +++ b/public/hooks/tests/use_polling_unit.test.ts @@ -0,0 +1,113 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import { act, renderHook } from '@testing-library/react-hooks'; + +import { usePollingUntil } from '../use_polling_until'; + +describe('usePollingUntil', () => { + beforeEach(() => { + jest.useFakeTimers(); + }); + + afterEach(() => { + jest.useRealTimers(); + }); + + it('should call onGiveUp after continueChecker return false', async () => { + const onGiveUp = jest.fn(() => {}); + const continueChecker = () => Promise.resolve(false); + const { result } = renderHook(() => + usePollingUntil({ + continueChecker, + onGiveUp, + onMaxRetries: () => {}, + }) + ); + expect(onGiveUp).not.toHaveBeenCalled(); + await act(async () => { + result.current.start(); + jest.runOnlyPendingTimers(); + }); + expect(onGiveUp).toHaveBeenCalled(); + }); + + it('should call onMaxRetries after call continueChecker 3 times', async () => { + const onGiveUp = () => {}; + const onMaxRetries = jest.fn(() => {}); + const continueChecker = jest.fn(() => Promise.resolve(true)); + const { result } = renderHook(() => + usePollingUntil({ + maxRetries: 3, + continueChecker, + onGiveUp, + onMaxRetries, + }) + ); + expect(onMaxRetries).not.toHaveBeenCalled(); + await act(async () => { + result.current.start(); + jest.runOnlyPendingTimers(); + }); + await act(async () => { + jest.runOnlyPendingTimers(); + }); + await act(async () => { + jest.runOnlyPendingTimers(); + }); + expect(continueChecker).toHaveBeenCalledTimes(3); + expect(onMaxRetries).toHaveBeenCalled(); + }); + + it('should not call continueChecker after unmount', async () => { + const continueChecker = jest.fn(() => Promise.resolve(true)); + const { result, unmount } = renderHook(() => + usePollingUntil({ + continueChecker, + onGiveUp: () => {}, + onMaxRetries: () => {}, + }) + ); + await act(async () => { + result.current.start(); + jest.runOnlyPendingTimers(); + }); + expect(continueChecker).toHaveBeenCalledTimes(1); + await act(async () => { + unmount(); + jest.advanceTimersByTime(300); + }); + expect(continueChecker).toHaveBeenCalledTimes(1); + }); + + it('should not call onGiveUp after unmount', async () => { + let continueCheckerResolveFn: Function; + const onGiveUp = jest.fn(); + const continueChecker = () => + new Promise((resolve) => { + continueCheckerResolveFn = () => { + resolve(false); + }; + }); + const { result, unmount } = renderHook(() => + usePollingUntil({ + continueChecker, + onGiveUp, + onMaxRetries: () => {}, + }) + ); + await act(async () => { + result.current.start(); + jest.runOnlyPendingTimers(); + }); + await act(async () => { + unmount(); + }); + await act(async () => { + continueCheckerResolveFn(); + }); + expect(onGiveUp).not.toHaveBeenCalled(); + }); +}); diff --git a/public/hooks/use_polling_until.ts b/public/hooks/use_polling_until.ts new file mode 100644 index 00000000..4ac92148 --- /dev/null +++ b/public/hooks/use_polling_until.ts @@ -0,0 +1,65 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import { useCallback, useEffect, useRef } from 'react'; + +const delay = async (ms: number) => + await new Promise((resolve) => { + setTimeout(resolve, ms); + }); + +export const usePollingUntil = ({ + pollingGap = 300, + maxRetries = 100, + continueChecker, + onMaxRetries, + onGiveUp, +}: { + pollingGap?: number; + maxRetries?: number; + continueChecker: () => Promise; + onGiveUp: () => void; + onMaxRetries: () => void; +}) => { + const mountedRef = useRef(true); + const continueCheckerRef = useRef(continueChecker); + continueCheckerRef.current = continueChecker; + const pollingTimes = useRef(0); + const onMaxRetiresRef = useRef(onMaxRetries); + onMaxRetiresRef.current = onMaxRetries; + const onGiveUpRef = useRef(onGiveUp); + onGiveUpRef.current = onGiveUp; + + const start = useCallback(async () => { + if (pollingTimes.current >= maxRetries) { + onMaxRetiresRef.current(); + return; + } + await delay(pollingGap); + if (!mountedRef.current) { + return; + } + pollingTimes.current += 1; + const flag = await continueCheckerRef.current(); + if (!mountedRef.current) { + return; + } + if (!flag) { + onGiveUpRef.current(); + return; + } + start(); + }, [pollingGap, maxRetries]); + + useEffect(() => { + return () => { + mountedRef.current = false; + }; + }, []); + + return { + start, + }; +}; diff --git a/public/utils/index.ts b/public/utils/index.ts new file mode 100644 index 00000000..636354ea --- /dev/null +++ b/public/utils/index.ts @@ -0,0 +1,6 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +export * from './table'; diff --git a/public/utils/regex.ts b/public/utils/regex.ts new file mode 100644 index 00000000..a5719ea9 --- /dev/null +++ b/public/utils/regex.ts @@ -0,0 +1,6 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +export const URL_REGEX = /^(https?|ftp|file):\/\/[-a-zA-Z0-9+&@#/%?=~_\|!:,.;]*[-a-zA-Z0-9+&@#/%=~_\|]/; diff --git a/public/utils/table.ts b/public/utils/table.ts new file mode 100644 index 00000000..ece9a5b9 --- /dev/null +++ b/public/utils/table.ts @@ -0,0 +1,14 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import moment from 'moment'; + +export const DEFAULT_EMPTY_DATA = '-'; + +export const renderTime = (time: string | number, format = 'MM/DD/YY h:mm a') => { + const momentTime = moment(time); + if (time && momentTime.isValid()) return momentTime.format(format); + return DEFAULT_EMPTY_DATA; +}; diff --git a/server/clusters/create_model_cluster.ts b/server/clusters/create_model_cluster.ts new file mode 100644 index 00000000..78b1cd2c --- /dev/null +++ b/server/clusters/create_model_cluster.ts @@ -0,0 +1,15 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import { CoreSetup } from '../../../../src/core/server'; + +import modelPlugin from './model_plugin'; +import { CLUSTER } from '../services/utils/constants'; + +export const createModelCluster = (core: CoreSetup) => { + return core.opensearch.legacy.createClient(CLUSTER.MODEL, { + plugins: [modelPlugin], + }); +}; diff --git a/server/clusters/model_plugin.ts b/server/clusters/model_plugin.ts new file mode 100644 index 00000000..40e7bb97 --- /dev/null +++ b/server/clusters/model_plugin.ts @@ -0,0 +1,90 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import { API_ROUTE_PREFIX, MODEL_BASE_API, MODEL_PROFILE_API } from '../services/utils/constants'; + +// eslint-disable-next-line import/no-default-export +export default function (Client: any, config: any, components: any) { + const ca = components.clientAction.factory; + + if (!Client.prototype.mlCommonsModel) { + Client.prototype.mlCommonsModel = components.clientAction.namespaceFactory(); + } + + const mlCommonsModel = Client.prototype.mlCommonsModel.prototype; + + mlCommonsModel.search = ca({ + method: 'POST', + url: { + fmt: `${MODEL_BASE_API}/_search`, + }, + needBody: true, + }); + + mlCommonsModel.getOne = ca({ + method: 'GET', + url: { + fmt: `${MODEL_BASE_API}/<%=modelId%>`, + req: { + modelId: { + type: 'string', + required: true, + }, + }, + }, + }); + + mlCommonsModel.delete = ca({ + method: 'DELETE', + url: { + fmt: `${MODEL_BASE_API}/<%=modelId%>`, + req: { + modelId: { + type: 'string', + required: true, + }, + }, + }, + }); + + mlCommonsModel.load = ca({ + method: 'POST', + url: { + fmt: `${MODEL_BASE_API}/<%=modelId%>/_load`, + req: { + modelId: { + type: 'string', + required: true, + }, + }, + }, + }); + + mlCommonsModel.unload = ca({ + method: 'POST', + url: { + fmt: `${MODEL_BASE_API}/<%=modelId%>/_unload`, + req: { + modelId: { + type: 'string', + required: true, + }, + }, + }, + }); + + mlCommonsModel.profile = ca({ + method: 'GET', + url: { + fmt: `${MODEL_PROFILE_API}/<%=modelId%>`, + req: { + modelId: { + type: 'string', + required: true, + }, + }, + }, + }); +} diff --git a/server/plugin.ts b/server/plugin.ts index 5e968077..c6338020 100644 --- a/server/plugin.ts +++ b/server/plugin.ts @@ -11,8 +11,10 @@ import { Logger, } from '../../../src/core/server'; +import { createModelCluster } from './clusters/create_model_cluster'; import { MlCommonsPluginSetup, MlCommonsPluginStart } from './types'; -import { connectorRouter, modelRouter, profileRouter } from './routes'; +import { connectorRouter, modelRouter, profileRouter, modelAggregateRouter } from './routes'; +import { ModelService } from './services'; export class MlCommonsPlugin implements Plugin { private readonly logger: Logger; @@ -25,7 +27,16 @@ export class MlCommonsPlugin implements Plugin { + router.get( + { + path: MODEL_AGGREGATE_API_ENDPOINT, + validate: { + query: schema.object({ + from: schema.number(), + size: schema.number(), + sort: schema.literal('created_time'), + order: schema.oneOf([schema.literal('asc'), schema.literal('desc')]), + name: schema.maybe(schema.string()), + }), + }, + }, + async (context, request) => { + try { + const payload = await ModelAggregateService.search({ + client: context.core.opensearch.client, + ...request.query, + }); + return opensearchDashboardsResponseFactory.ok({ body: payload }); + } catch (err) { + return opensearchDashboardsResponseFactory.badRequest({ body: err.message }); + } + } + ); +}; diff --git a/server/routes/model_router.ts b/server/routes/model_router.ts index 4a352dd6..ba247721 100644 --- a/server/routes/model_router.ts +++ b/server/routes/model_router.ts @@ -4,13 +4,21 @@ */ import { schema } from '@osd/config-schema'; -import { MODEL_STATE } from '../../common'; -import { IRouter } from '../../../../src/core/server'; -import { ModelService } from '../services'; -import { MODEL_API_ENDPOINT } from './constants'; +import { MAX_MODEL_CHUNK_SIZE, MODEL_STATE } from '../../common'; +import { IRouter, opensearchDashboardsResponseFactory } from '../../../../src/core/server'; +import { ModelService, RecordNotFoundError } from '../services'; +import { + MODEL_API_ENDPOINT, + MODEL_LOAD_API_ENDPOINT, + MODEL_UNLOAD_API_ENDPOINT, + MODEL_UPLOAD_API_ENDPOINT, + MODEL_PROFILE_API_ENDPOINT, +} from './constants'; import { getOpenSearchClientTransport } from './utils'; const modelSortQuerySchema = schema.oneOf([ + schema.literal('version-desc'), + schema.literal('version-asc'), schema.literal('name-asc'), schema.literal('name-desc'), schema.literal('model_state-asc'), @@ -30,7 +38,33 @@ const modelStateSchema = schema.oneOf([ schema.literal(MODEL_STATE.uploading), ]); -export const modelRouter = (router: IRouter) => { +const modelUploadBaseSchema = { + name: schema.string(), + version: schema.string(), + description: schema.string(), + modelFormat: schema.string(), + modelConfig: schema.object({ + modelType: schema.string(), + embeddingDimension: schema.number(), + frameworkType: schema.string(), + }), +}; + +const modelUploadByURLSchema = schema.object({ + ...modelUploadBaseSchema, + url: schema.string(), +}); + +const modelUploadByChunkSchema = schema.object({ + ...modelUploadBaseSchema, + modelTaskType: schema.string(), + modelContentHashValue: schema.string(), + totalChunks: schema.number(), +}); + +export const modelRouter = (services: { modelService: ModelService }, router: IRouter) => { + const { modelService } = services; + router.get( { path: MODEL_API_ENDPOINT, @@ -77,4 +111,171 @@ export const modelRouter = (router: IRouter) => { } } ); + + router.get( + { + path: `${MODEL_API_ENDPOINT}/{modelId}`, + validate: { + params: schema.object({ + modelId: schema.string(), + }), + }, + }, + async (_context, request) => { + try { + const model = await modelService.getOne({ + request, + modelId: request.params.modelId, + }); + return opensearchDashboardsResponseFactory.ok({ body: model }); + } catch (err) { + return opensearchDashboardsResponseFactory.badRequest({ body: err.message }); + } + } + ); + + router.delete( + { + path: `${MODEL_API_ENDPOINT}/{modelId}`, + validate: { + params: schema.object({ + modelId: schema.string(), + }), + }, + }, + async (_context, request) => { + try { + await modelService.delete({ + request, + modelId: request.params.modelId, + }); + return opensearchDashboardsResponseFactory.ok(); + } catch (err) { + if (err instanceof RecordNotFoundError) { + return opensearchDashboardsResponseFactory.notFound(); + } + return opensearchDashboardsResponseFactory.badRequest({ body: err.message }); + } + } + ); + + router.post( + { + path: `${MODEL_LOAD_API_ENDPOINT}/{modelId}`, + validate: { + params: schema.object({ + modelId: schema.string(), + }), + }, + }, + async (_context, request) => { + try { + const result = await modelService.load({ + request, + modelId: request.params.modelId, + }); + return opensearchDashboardsResponseFactory.ok({ body: result }); + } catch (err) { + return opensearchDashboardsResponseFactory.badRequest({ body: err.message }); + } + } + ); + + router.post( + { + path: `${MODEL_UNLOAD_API_ENDPOINT}/{modelId}`, + validate: { + params: schema.object({ + modelId: schema.string(), + }), + }, + }, + async (_context, request) => { + try { + const result = await modelService.unload({ + request, + modelId: request.params.modelId, + }); + return opensearchDashboardsResponseFactory.ok({ body: result }); + } catch (err) { + return opensearchDashboardsResponseFactory.badRequest({ body: err.message }); + } + } + ); + + router.get( + { + path: `${MODEL_PROFILE_API_ENDPOINT}/{modelId}`, + validate: { + params: schema.object({ + modelId: schema.string(), + }), + }, + }, + async (_context, request) => { + try { + const result = await modelService.profile({ + request, + modelId: request.params.modelId, + }); + return opensearchDashboardsResponseFactory.ok({ body: result }); + } catch (err) { + return opensearchDashboardsResponseFactory.badRequest({ body: err.message }); + } + } + ); + + router.post( + { + path: MODEL_UPLOAD_API_ENDPOINT, + validate: { + body: schema.oneOf([modelUploadByURLSchema, modelUploadByChunkSchema]), + }, + }, + async (context, request) => { + try { + const body = await ModelService.upload({ + client: context.core.opensearch.client, + model: request.body, + }); + + return opensearchDashboardsResponseFactory.ok({ + body, + }); + } catch (err) { + return opensearchDashboardsResponseFactory.badRequest({ body: err.message }); + } + } + ); + + router.post( + { + path: `${MODEL_API_ENDPOINT}/{modelId}/chunk/{chunkId}`, + validate: { + params: schema.object({ + modelId: schema.string(), + chunkId: schema.string(), + }), + body: schema.buffer(), + }, + options: { + body: { + maxBytes: MAX_MODEL_CHUNK_SIZE, + }, + }, + }, + async (context, request) => { + try { + await ModelService.uploadModelChunk({ + client: context.core.opensearch.client, + modelId: request.params.modelId, + chunkId: request.params.chunkId, + chunk: request.body, + }); + return opensearchDashboardsResponseFactory.ok(); + } catch (err) { + return opensearchDashboardsResponseFactory.badRequest(err.message); + } + } + ); }; diff --git a/server/services/errors.ts b/server/services/errors.ts new file mode 100644 index 00000000..71cc7d8a --- /dev/null +++ b/server/services/errors.ts @@ -0,0 +1,6 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +export class RecordNotFoundError extends Error {} diff --git a/server/services/index.ts b/server/services/index.ts index 26b3a585..249e61ea 100644 --- a/server/services/index.ts +++ b/server/services/index.ts @@ -4,3 +4,4 @@ */ export { ModelService } from './model_service'; +export { RecordNotFoundError } from './errors'; diff --git a/server/services/model_aggregate_service.ts b/server/services/model_aggregate_service.ts new file mode 100644 index 00000000..0c773c09 --- /dev/null +++ b/server/services/model_aggregate_service.ts @@ -0,0 +1,183 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +/* + * Copyright OpenSearch Contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +import { IScopedClusterClient } from '../../../../src/core/server'; +import { MODEL_STATE, OpenSearchModelBase } from '../../common/model'; + +import { MODEL_SEARCH_API } from './utils/constants'; + +const MAX_MODEL_BUCKET_NUM = 10000; + +interface GetAggregateModelsParams { + client: IScopedClusterClient; + from: number; + size: number; + name?: string; + sort: 'created_time'; + order: 'desc' | 'asc'; +} + +export class ModelAggregateService { + public static async getAggregateModels({ + client, + from, + size, + sort, + name, + order, + }: GetAggregateModelsParams) { + const aggregateResult = await client.asCurrentUser.transport.request({ + method: 'GET', + path: MODEL_SEARCH_API, + body: { + size: 0, + query: { + bool: { + must: [...(name ? [{ match: { name } }] : [])], + must_not: { + exists: { + field: 'chunk_number', + }, + }, + }, + }, + aggs: { + models: { + terms: { + field: 'name.keyword', + size: MAX_MODEL_BUCKET_NUM, + }, + aggs: { + latest_version_hits: { + top_hits: { + sort: [ + { + created_time: { + order: 'desc', + }, + }, + ], + size: 1, + _source: ['model_version', 'model_state', 'description', 'created_time'], + }, + }, + }, + }, + }, + }, + }); + const models = aggregateResult.body.aggregations.models.buckets as Array<{ + key: string; + doc_count: number; + latest_version_hits: { + hits: { + hits: [ + { + _source: Pick & { + created_time: number; + description?: string; + }; + } + ]; + }; + }; + }>; + + return { + models: models + .sort( + (a, b) => + ((a.latest_version_hits.hits.hits[0]._source.created_time ?? 0) - + (b.latest_version_hits.hits.hits[0]._source.created_time ?? 0)) * + (sort === 'created_time' && order === 'asc' ? 1 : -1) + ) + .slice(from, from + size), + total_models: models.length, + }; + } + + public static async search(params: GetAggregateModelsParams) { + const { client } = params; + const { models, total_models: totalModels } = await ModelAggregateService.getAggregateModels( + params + ); + const { names, count } = models.reduce<{ names: string[]; count: number }>( + (previous, { key, doc_count: docCount }: { key: string; doc_count: number }) => ({ + names: previous.names.concat(key), + count: docCount + previous.count, + }), + { names: [], count: 0 } + ); + const versionResult = await client.asCurrentUser.transport.request({ + method: 'GET', + path: MODEL_SEARCH_API, + body: { + size: count, + query: { + bool: { + should: names.map((name) => ({ term: { 'name.keyword': name } })), + must_not: { + exists: { + field: 'chunk_number', + }, + }, + }, + }, + _source: ['name', 'model_version', 'model_state', 'model_id'], + }, + }); + const versionResultMap = (versionResult.body.hits.hits as Array<{ + _id: string; + _source: OpenSearchModelBase; + }>).reduce<{ + [key: string]: Array>; + }>( + (pValue, { _source: { name, ...resetProperties } }) => ({ + ...pValue, + [name]: (pValue[name] ?? []).concat(resetProperties), + }), + {} + ); + return { + data: models.map( + ({ + key, + latest_version_hits: { + hits: { hits }, + }, + }) => { + const latestVersion = hits[0]._source; + return { + name: key, + deployed_versions: (versionResultMap[key] ?? []) + .filter((item) => item.model_state === MODEL_STATE.loaded) + .map((item) => item.model_version), + // TODO: Change to the real model owner + owner: key, + latest_version: latestVersion.model_version, + latest_version_state: latestVersion.model_state, + created_time: latestVersion.created_time, + }; + } + ), + total_models: totalModels, + }; + } +} diff --git a/server/services/model_service.ts b/server/services/model_service.ts index 730240f2..ba172bbd 100644 --- a/server/services/model_service.ts +++ b/server/services/model_service.ts @@ -18,18 +18,66 @@ * permissions and limitations under the License. */ -import { OpenSearchClient } from '../../../../src/core/server'; +import { + IScopedClusterClient, + OpenSearchClient, + ScopeableRequest, + ILegacyClusterClient, +} from '../../../../src/core/server'; import { MODEL_STATE, ModelSearchSort } from '../../common'; -import { generateModelSearchQuery } from './utils/model'; -import { MODEL_BASE_API } from './utils/constants'; +import { convertModelSource, generateModelSearchQuery } from './utils/model'; +import { MODEL_BASE_API, MODEL_META_API, MODEL_UPLOAD_API } from './utils/constants'; +import { RecordNotFoundError } from './errors'; const modelSortFieldMapping: { [key: string]: string } = { + version: 'model_version', name: 'name.keyword', id: '_id', }; +interface UploadModelBase { + name: string; + version: string; + description: string; + modelFormat: string; + modelConfig: { + modelType: string; + embeddingDimension: number; + frameworkType: string; + }; +} + +interface UploadModelByURL extends UploadModelBase { + url: string; +} + +interface UploadModelByChunk extends UploadModelBase { + modelTaskType: string; + modelContentHashValue: string; + totalChunks: number; +} + +type UploadResultInner< + T extends UploadModelByURL | UploadModelByChunk +> = T extends UploadModelByChunk + ? { modelId: string; status: string } + : T extends UploadModelByURL + ? { taskId: string; status: string } + : never; + +type UploadResult = Promise>; + +const isUploaModelByURL = (test: UploadModelByURL | UploadModelByChunk): test is UploadModelByURL => + (test as UploadModelByURL).url !== undefined; + export class ModelService { + private osClient: ILegacyClusterClient; + + constructor(osClient: ILegacyClusterClient) { + this.osClient = osClient; + } + public static async search({ from, size, @@ -41,6 +89,7 @@ export class ModelService { from: number; size: number; sort?: ModelSearchSort[]; + name?: string; states?: MODEL_STATE[]; extraQuery?: Record; nameOrId?: string; @@ -75,4 +124,119 @@ export class ModelService { total_models: hits.total.value, }; } + + public async getOne({ request, modelId }: { request: ScopeableRequest; modelId: string }) { + const modelSource = await this.osClient + .asScoped(request) + .callAsCurrentUser('mlCommonsModel.getOne', { + modelId, + }); + return { + id: modelId, + ...convertModelSource(modelSource), + }; + } + + public async delete({ request, modelId }: { request: ScopeableRequest; modelId: string }) { + const { result } = await this.osClient + .asScoped(request) + .callAsCurrentUser('mlCommonsModel.delete', { + modelId, + }); + if (result === 'not_found') { + throw new RecordNotFoundError(); + } + return true; + } + + public async load({ request, modelId }: { request: ScopeableRequest; modelId: string }) { + const result = await this.osClient.asScoped(request).callAsCurrentUser('mlCommonsModel.load', { + modelId, + }); + return result; + } + + public async unload({ request, modelId }: { request: ScopeableRequest; modelId: string }) { + const result = await this.osClient + .asScoped(request) + .callAsCurrentUser('mlCommonsModel.unload', { + modelId, + }); + return result; + } + + public async profile({ request, modelId }: { request: ScopeableRequest; modelId: string }) { + const result = await this.osClient + .asScoped(request) + .callAsCurrentUser('mlCommonsModel.profile', { + modelId, + }); + return result; + } + + public static async upload({ + client, + model, + }: { + client: IScopedClusterClient; + model: T; + }): UploadResult { + const { name, version, description, modelFormat, modelConfig } = model; + const uploadModelBase = { + name, + version, + description, + model_format: modelFormat, + model_config: { + model_type: modelConfig.modelType, + embedding_dimension: modelConfig.embeddingDimension, + framework_type: modelConfig.frameworkType, + }, + }; + if (isUploaModelByURL(model)) { + const { task_id: taskId, status } = ( + await client.asCurrentUser.transport.request({ + method: 'POST', + path: MODEL_UPLOAD_API, + body: { + ...uploadModelBase, + url: model.url, + }, + }) + ).body; + return { taskId, status } as UploadResultInner; + } + + const { model_id: modelId, status } = ( + await client.asCurrentUser.transport.request({ + method: 'POST', + path: MODEL_META_API, + body: { + ...uploadModelBase, + model_task_type: model.modelTaskType, + model_content_hash_value: model.modelContentHashValue, + total_chunks: model.totalChunks, + }, + }) + ).body; + return { modelId, status } as UploadResultInner; + } + + public static async uploadModelChunk({ + client, + modelId, + chunkId, + chunk, + }: { + client: IScopedClusterClient; + modelId: string; + chunkId: string; + chunk: Buffer; + }) { + return client.asCurrentUser.transport.request({ + method: 'POST', + path: `${MODEL_BASE_API}/${modelId}/chunk/${chunkId}`, + body: chunk, + }); + } } diff --git a/server/services/utils/constants.ts b/server/services/utils/constants.ts index bda18b7e..f5129b17 100644 --- a/server/services/utils/constants.ts +++ b/server/services/utils/constants.ts @@ -22,6 +22,16 @@ export const API_ROUTE_PREFIX = '/_plugins/_ml'; export const PROFILE_BASE_API = `${API_ROUTE_PREFIX}/profile`; export const MODEL_BASE_API = `${API_ROUTE_PREFIX}/models`; export const MODEL_SEARCH_API = `${MODEL_BASE_API}/_search`; +export const MODEL_UPLOAD_API = `${MODEL_BASE_API}/_upload`; +export const MODEL_META_API = `${MODEL_BASE_API}/meta`; +export const MODEL_PROFILE_API = `${PROFILE_BASE_API}/models`; + +export const CLUSTER = { + TRAIN: 'opensearch_mlCommonsTrain', + MODEL: 'opensearch_mlCommonsModel', + TASK: 'opensearch_mlCommonsTask', + PREDICT: 'opensearch_mlCommonsPredict', +}; export const CONNECTOR_BASE_API = `${API_ROUTE_PREFIX}/connectors`; export const CONNECTOR_SEARCH_API = `${CONNECTOR_BASE_API}/_search`; diff --git a/server/services/utils/model.ts b/server/services/utils/model.ts index 5196445a..0643f97a 100644 --- a/server/services/utils/model.ts +++ b/server/services/utils/model.ts @@ -6,17 +6,48 @@ import { MODEL_STATE } from '../../../common'; import { generateTermQuery } from './query'; +export const convertModelSource = (source: { + model_content: string; + name: string; + algorithm: string; + model_state: string; + model_version: string; +}) => ({ + content: source.model_content, + name: source.name, + algorithm: source.algorithm, + state: source.model_state, + version: source.model_version, +}); + export const generateModelSearchQuery = ({ + ids, + algorithms, + name, states, nameOrId, extraQuery, }: { + ids?: string[]; + algorithms?: string[]; + name?: string; states?: MODEL_STATE[]; nameOrId?: string; extraQuery?: Record; }) => ({ bool: { must: [ + ...(ids ? [{ ids: { values: ids } }] : []), + ...(algorithms ? [generateTermQuery('algorithm', algorithms)] : []), + ...(name + ? [ + { + term: { + 'name.keyword': name, + }, + }, + ] + : []), ...(states ? [generateTermQuery('model_state', states)] : []), ...(nameOrId ? [ @@ -43,3 +74,36 @@ export const generateModelSearchQuery = ({ }, }, }); + +export interface UploadModel { + name: string; + version: string; + description: string; + modelFormat: string; + modelConfig: { + modelType: string; + membeddingDimension: number; + frameworkType: string; + }; + url?: string; +} + +export const convertUploadModel = ({ + name, + version, + description, + modelFormat, + modelConfig, + url, +}: UploadModel) => ({ + name, + version, + description, + model_format: modelFormat, + model_config: { + model_type: modelConfig.modelType, + embedding_dimension: modelConfig.membeddingDimension, + framework_type: modelConfig.frameworkType, + }, + url, +}); diff --git a/yarn.lock b/yarn.lock index be3d6867..fd5e7963 100644 --- a/yarn.lock +++ b/yarn.lock @@ -28,6 +28,18 @@ resolved "https://registry.yarnpkg.com/@testing-library/user-event/-/user-event-14.4.3.tgz#af975e367743fa91989cd666666aec31a8f50591" integrity sha512-kCUc5MEwaEMakkO5x7aoD+DLi02ehmEM2QCGWvNqAS1dV/fAvORWEjnjsEIvml59M7Y5kCkWN6fCCyPOe8OL6Q== +"@types/node@*": + version "18.7.14" + resolved "https://registry.yarnpkg.com/@types/node/-/node-18.7.14.tgz#0fe081752a3333392d00586d815485a17c2cf3c9" + integrity sha512-6bbDaETVi8oyIARulOE9qF1/Qdi/23z6emrUh0fNJRUmjznqrixD4MpGDdgOFk5Xb0m2H6Xu42JGdvAxaJR/wA== + +"@types/papaparse@^5.3.5": + version "5.3.5" + resolved "https://registry.yarnpkg.com/@types/papaparse/-/papaparse-5.3.5.tgz#e5ad94b1fe98e2a8ea0b03284b83d2cb252bbf39" + integrity sha512-R1icl/hrJPFRpuYj9PVG03WBAlghJj4JW9Py5QdR8FFSxaLmZRyu7xYDCCBZIJNfUv3MYaeBbhBoX958mUTAaw== + dependencies: + "@types/node" "*" + "@types/parse-json@^4.0.0": version "4.0.0" resolved "https://registry.yarnpkg.com/@types/parse-json/-/parse-json-4.0.0.tgz#2f8bb441434d163b35fb8ffdccd7138927ffb8c0" @@ -267,6 +279,11 @@ has-flag@^4.0.0: resolved "https://registry.yarnpkg.com/has-flag/-/has-flag-4.0.0.tgz#944771fd9c81c81265c4d6941860da06bb59479b" integrity sha512-EykJT/Q1KjTWctppgIAgfSO0tKVuZUjhgMr17kqTumMl6Afv3EISleU7qZUzoXDFTAHTDC4NOoG/ZxU3EvlMPQ== +hash-wasm@^4.9.0: + version "4.9.0" + resolved "https://registry.yarnpkg.com/hash-wasm/-/hash-wasm-4.9.0.tgz#7e9dcc9f7d6bd0cc802f2a58f24edce999744206" + integrity sha512-7SW7ejyfnRxuOc7ptQHSf4LDoZaWOivfzqw+5rpcQku0nHfmicPKE51ra9BiRLAmT8+gGLestr1XroUkqdjL6w== + human-signals@^1.1.1: version "1.1.1" resolved "https://registry.yarnpkg.com/human-signals/-/human-signals-1.1.1.tgz#c5b1cd14f50aeae09ab6c59fe63ba3395fe4dfa3" @@ -454,6 +471,11 @@ p-map@^4.0.0: dependencies: aggregate-error "^3.0.0" +papaparse@^5.3.2: + version "5.3.2" + resolved "https://registry.yarnpkg.com/papaparse/-/papaparse-5.3.2.tgz#d1abed498a0ee299f103130a6109720404fbd467" + integrity sha512-6dNZu0Ki+gyV0eBsFKJhYr+MdQYAzFUGlBMNj3GNrmHxmz1lfRa24CjFObPXtjcetlOv5Ad299MhIK0znp3afw== + parent-module@^1.0.0: version "1.0.1" resolved "https://registry.yarnpkg.com/parent-module/-/parent-module-1.0.1.tgz#691d2709e78c79fae3a156622452d00762caaaa2" @@ -501,6 +523,11 @@ pump@^3.0.0: end-of-stream "^1.1.0" once "^1.3.1" +react-hook-form@^7.39.4: + version "7.39.4" + resolved "https://registry.yarnpkg.com/react-hook-form/-/react-hook-form-7.39.4.tgz#7d9edf4e778a0cec4383f0119cd0699e3826a14a" + integrity sha512-B0e78r9kR9L2M4A4AXGbHoA/vyv34sB/n8QWJAw33TFz8f5t9helBbYAeqnbvcQf1EYzJxKX/bGQQh9K+evCyQ== + resolve-from@^4.0.0: version "4.0.0" resolved "https://registry.yarnpkg.com/resolve-from/-/resolve-from-4.0.0.tgz#4abcd852ad32dd7baabfe9b40e00a36db5f392e6" From b5ef886720605b3ddc6887a7cf5dc3f0c68a9c59 Mon Sep 17 00:00:00 2001 From: wanglam Date: Mon, 30 Jan 2023 09:50:19 +0800 Subject: [PATCH 02/75] Feat add show my models in owner filter (#29) * feat: add security account API Signed-off-by: Lin Wang * feat: add Show Only My Models button to OwnerFilter Signed-off-by: Lin Wang * test: split owner filter select test cases Signed-off-by: Lin Wang --------- Signed-off-by: Lin Wang --- common/security.ts | 20 ++++++ public/apis/__mocks__/security.ts | 10 +++ public/apis/api_provider.ts | 15 ++++- public/apis/security.ts | 16 +++++ .../__tests__/model_filter.test.tsx | 20 +++++- .../__tests__/owner_filter.test.tsx | 63 +++++++++++++++++-- public/components/model_list/model_filter.tsx | 11 +++- public/components/model_list/owner_filter.tsx | 35 ++++++++++- server/clusters/model_plugin.ts | 2 +- server/plugin.ts | 9 ++- server/routes/constants.ts | 3 + server/routes/index.ts | 1 + server/routes/security_router.ts | 27 ++++++++ server/services/security_service.ts | 19 ++++++ server/services/utils/constants.ts | 12 ++-- 15 files changed, 246 insertions(+), 17 deletions(-) create mode 100644 common/security.ts create mode 100644 public/apis/__mocks__/security.ts create mode 100644 public/apis/security.ts create mode 100644 server/routes/security_router.ts create mode 100644 server/services/security_service.ts diff --git a/common/security.ts b/common/security.ts new file mode 100644 index 00000000..06ff5570 --- /dev/null +++ b/common/security.ts @@ -0,0 +1,20 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +export interface OpenSearchSecurityAccount { + user_name: string; + is_reserved: boolean; + is_hidden: boolean; + is_interval_user: boolean; + user_required_tenant: null; + backed_roles: string[]; + custom_attribute_names: string[]; + tenants: { + global_tenant: boolean; + admin_tenant: true; + admin: true; + }; + roles: string[]; +} diff --git a/public/apis/__mocks__/security.ts b/public/apis/__mocks__/security.ts new file mode 100644 index 00000000..a91be589 --- /dev/null +++ b/public/apis/__mocks__/security.ts @@ -0,0 +1,10 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +export class Security { + public getAccount() { + return Promise.resolve({ user_name: 'admin' }); + } +} diff --git a/public/apis/api_provider.ts b/public/apis/api_provider.ts index d1468d4d..a5a7d1b3 100644 --- a/public/apis/api_provider.ts +++ b/public/apis/api_provider.ts @@ -7,17 +7,20 @@ import { Connector } from './connector'; import { Model } from './model'; import { ModelAggregate } from './model_aggregate'; import { Profile } from './profile'; +import { Security } from './security'; const apiInstanceStore: { model: Model | undefined; modelAggregate: ModelAggregate | undefined; profile: Profile | undefined; connector: Connector | undefined; + security: Security | undefined; } = { model: undefined, modelAggregate: undefined, profile: undefined, connector: undefined, + security: undefined, }; export class APIProvider { @@ -25,6 +28,7 @@ export class APIProvider { public static getAPI(type: 'modelAggregate'): ModelAggregate; public static getAPI(type: 'profile'): Profile; public static getAPI(type: 'connector'): Connector; + public static getAPI(type: 'security'): Security; public static getAPI(type: keyof typeof apiInstanceStore) { if (apiInstanceStore[type]) { return apiInstanceStore[type]!; @@ -50,11 +54,16 @@ export class APIProvider { apiInstanceStore.connector = newInstance; return newInstance; } + case 'security': { + const newInstance = new Security(); + apiInstanceStore.security = newInstance; + return newInstance; + } } } public static clear() { - apiInstanceStore.model = undefined; - apiInstanceStore.profile = undefined; - apiInstanceStore.connector = undefined; + Object.keys(apiInstanceStore).forEach((key) => { + apiInstanceStore[key as keyof typeof apiInstanceStore] = undefined; + }); } } diff --git a/public/apis/security.ts b/public/apis/security.ts new file mode 100644 index 00000000..bfa6b5a8 --- /dev/null +++ b/public/apis/security.ts @@ -0,0 +1,16 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import { InnerHttpProvider } from './inner_http_provider'; +import { OpenSearchSecurityAccount } from '../../common/security'; +import { SECURITY_ACCOUNT_API_ENDPOINT } from '../../server/routes/constants'; + +export class Security { + public getAccount() { + return InnerHttpProvider.getHttp().get( + SECURITY_ACCOUNT_API_ENDPOINT + ); + } +} diff --git a/public/components/model_list/__tests__/model_filter.test.tsx b/public/components/model_list/__tests__/model_filter.test.tsx index 7a246725..e4b63c33 100644 --- a/public/components/model_list/__tests__/model_filter.test.tsx +++ b/public/components/model_list/__tests__/model_filter.test.tsx @@ -2,6 +2,7 @@ * Copyright OpenSearch Contributors * SPDX-License-Identifier: Apache-2.0 */ +jest.mock('../../../apis/security'); import React from 'react'; import userEvent from '@testing-library/user-event'; @@ -42,7 +43,7 @@ describe('', () => { expect(screen.getByText('2')).toBeInTheDocument(); }); - it('should render options filter after click tags', async () => { + it('should render options filter after filter button clicked', async () => { render( ', () => { expect(screen.getByPlaceholderText('Search Tags')).toBeInTheDocument(); }); + it('should render passed footer after filter button clicked', async () => { + const { getByText, queryByText } = render( + {}} + footer="footer" + /> + ); + expect(queryByText('footer')).not.toBeInTheDocument(); + + await userEvent.click(screen.getByText('Tags')); + expect(getByText('footer')).toBeInTheDocument(); + }); + it('should only show "bar" after search', async () => { render( ', () => { @@ -13,9 +16,61 @@ describe('', () => { jest.resetAllMocks(); }); - it('should render "Owner" with 0 active filter for normal', () => { + it('should render "Owner" with 3 filter for normal', async () => { + const { getByText, findByText } = render( {}} />); + expect(getByText('Owner')).toBeInTheDocument(); + expect(await findByText('3')).toBeInTheDocument(); + }); + + it('should render three options with 1 checked option and 1 active filter', async () => { + render( {}} />); + expect(screen.getByText('1')).toBeInTheDocument(); + await userEvent.click(screen.getByText('Owner')); + const allOptions = screen.getAllByRole('option'); + expect(allOptions.length).toBe(3); + expect(await within(allOptions[0]).getByText('admin (Me)')).toBeInTheDocument(); + expect( + within(allOptions[0]).getByRole('img', { hidden: true }).querySelector('path') + ).toBeInTheDocument(); + expect(within(allOptions[1]).getByText('owner-1')).toBeInTheDocument(); + expect(within(allOptions[2]).getByText('owner-2')).toBeInTheDocument(); + }); + + it('should render "Show Only My Models" button with (Me) option', async () => { + render( {}} />); + await userEvent.click(screen.getByText('Owner')); + expect(screen.getByText('Show Only My Models')).toBeInTheDocument(); + expect(screen.getByText('admin (Me)')).toBeInTheDocument(); + }); + + it('should call onChange with user selection', async () => { + const onChangeMock = jest.fn(); + const { rerender } = render(); + await userEvent.click(screen.getByText('Owner')); + await userEvent.click(screen.getByText('owner-1')); + expect(onChangeMock).toHaveBeenCalledWith(['owner-1']); + onChangeMock.mockClear(); + + rerender(); + await userEvent.click(screen.getByText('owner-1')); + expect(onChangeMock).toHaveBeenCalledWith([]); + }); + + it('should call onChange with current user', async () => { + const onChangeMock = jest.fn(); + render(); + await userEvent.click(screen.getByText('Owner')); + await userEvent.click(screen.getByText('Show Only My Models')); + expect(onChangeMock).toHaveBeenCalledWith(['admin']); + }); + + it('should NOT render "Show Only My Models" button and "(Me)" option after error fetch account', async () => { + jest + .spyOn(APIProvider.getAPI('security'), 'getAccount') + .mockRejectedValue(new Error('Failed to fetch account')); render( {}} />); - expect(screen.getByText('Owner')).toBeInTheDocument(); - expect(screen.getByText('0')).toBeInTheDocument(); + await userEvent.click(screen.getByText('Owner')); + expect(screen.queryByText('Show Only My Models')).not.toBeInTheDocument(); + expect(screen.queryByText('admin (Me)')).not.toBeInTheDocument(); }); }); diff --git a/public/components/model_list/model_filter.tsx b/public/components/model_list/model_filter.tsx index c01e7227..d8506891 100644 --- a/public/components/model_list/model_filter.tsx +++ b/public/components/model_list/model_filter.tsx @@ -4,7 +4,13 @@ */ import React, { useCallback, useMemo, useRef, useState } from 'react'; -import { EuiPopover, EuiPopoverTitle, EuiFieldSearch, EuiFilterButton } from '@elastic/eui'; +import { + EuiPopover, + EuiPopoverTitle, + EuiFieldSearch, + EuiFilterButton, + EuiPopoverFooter, +} from '@elastic/eui'; import { ModelFilterItem } from './model_filter_item'; export interface ModelFilterProps { @@ -13,11 +19,13 @@ export interface ModelFilterProps { options: Array; value: string[]; onChange: (value: string[]) => void; + footer?: React.ReactNode; } export const ModelFilter = ({ name, value, + footer, options, searchPlaceholder, onChange, @@ -93,6 +101,7 @@ export const ModelFilter = ({
); })} + {footer && {footer}} ); }; diff --git a/public/components/model_list/owner_filter.tsx b/public/components/model_list/owner_filter.tsx index 60254bf2..67f7eeb4 100644 --- a/public/components/model_list/owner_filter.tsx +++ b/public/components/model_list/owner_filter.tsx @@ -3,17 +3,48 @@ * SPDX-License-Identifier: Apache-2.0 */ -import React from 'react'; +import React, { useCallback, useMemo } from 'react'; +import { EuiButton } from '@elastic/eui'; import { ModelFilter, ModelFilterProps } from './model_filter'; +import { useFetcher } from '../../hooks/use_fetcher'; +import { APIProvider } from '../../apis/api_provider'; + +const ownerFetcher = () => Promise.resolve(['admin', 'owner-1', 'owner-2']); export const OwnerFilter = ({ value, onChange }: Pick) => { + const { data: accountData } = useFetcher(APIProvider.getAPI('security').getAccount); + const { data: ownerData } = useFetcher(ownerFetcher); + const currentAccountName = accountData?.user_name; + const options = useMemo( + () => + (ownerData ?? []).map((owner) => ({ + name: owner === currentAccountName ? `${owner} (Me)` : owner, + value: owner, + })), + [ownerData, currentAccountName] + ); + + const handleOnlyMyModelsClick = useCallback(() => { + if (!currentAccountName) { + return; + } + onChange([currentAccountName]); + }, [currentAccountName, onChange]); + return ( + Show Only My Models + + ) + } /> ); }; diff --git a/server/clusters/model_plugin.ts b/server/clusters/model_plugin.ts index 40e7bb97..40151a31 100644 --- a/server/clusters/model_plugin.ts +++ b/server/clusters/model_plugin.ts @@ -3,7 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { API_ROUTE_PREFIX, MODEL_BASE_API, MODEL_PROFILE_API } from '../services/utils/constants'; +import { MODEL_BASE_API, MODEL_PROFILE_API } from '../services/utils/constants'; // eslint-disable-next-line import/no-default-export export default function (Client: any, config: any, components: any) { diff --git a/server/plugin.ts b/server/plugin.ts index c6338020..e106437b 100644 --- a/server/plugin.ts +++ b/server/plugin.ts @@ -13,7 +13,13 @@ import { import { createModelCluster } from './clusters/create_model_cluster'; import { MlCommonsPluginSetup, MlCommonsPluginStart } from './types'; -import { connectorRouter, modelRouter, profileRouter, modelAggregateRouter } from './routes'; +import { + connectorRouter, + modelRouter, + profileRouter, + modelAggregateRouter, + securityRouter, +} from './routes'; import { ModelService } from './services'; export class MlCommonsPlugin implements Plugin { @@ -39,6 +45,7 @@ export class MlCommonsPlugin implements Plugin { + router.get( + { + path: SECURITY_ACCOUNT_API_ENDPOINT, + validate: false, + }, + async (context) => { + try { + const body = await SecurityService.getAccount({ + client: context.core.opensearch.client, + }); + return opensearchDashboardsResponseFactory.ok({ body }); + } catch (error) { + return opensearchDashboardsResponseFactory.badRequest({ body: error as Error }); + } + } + ); +}; diff --git a/server/services/security_service.ts b/server/services/security_service.ts new file mode 100644 index 00000000..9721fd00 --- /dev/null +++ b/server/services/security_service.ts @@ -0,0 +1,19 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import { OpenSearchSecurityAccount } from '../../common/security'; +import { IScopedClusterClient } from '../../../../src/core/server/opensearch/client'; +import { SECURITY_ACCOUNT_API } from './utils/constants'; + +export class SecurityService { + public static async getAccount({ client }: { client: IScopedClusterClient }) { + return ( + await client.asCurrentUser.transport.request({ + method: 'GET', + path: SECURITY_ACCOUNT_API, + }) + ).body as OpenSearchSecurityAccount; + } +} diff --git a/server/services/utils/constants.ts b/server/services/utils/constants.ts index f5129b17..bc7deab9 100644 --- a/server/services/utils/constants.ts +++ b/server/services/utils/constants.ts @@ -18,13 +18,17 @@ * permissions and limitations under the License. */ -export const API_ROUTE_PREFIX = '/_plugins/_ml'; -export const PROFILE_BASE_API = `${API_ROUTE_PREFIX}/profile`; -export const MODEL_BASE_API = `${API_ROUTE_PREFIX}/models`; +export const ML_COMMONS_API_PREFIX = '/_plugins/_ml'; +export const PROFILE_BASE_API = `${ML_COMMONS_API_PREFIX}/profile`; +export const MODEL_BASE_API = `${ML_COMMONS_API_PREFIX}/models`; export const MODEL_SEARCH_API = `${MODEL_BASE_API}/_search`; export const MODEL_UPLOAD_API = `${MODEL_BASE_API}/_upload`; export const MODEL_META_API = `${MODEL_BASE_API}/meta`; export const MODEL_PROFILE_API = `${PROFILE_BASE_API}/models`; +export const PREDICT_BASE_API = `${ML_COMMONS_API_PREFIX}/_predict`; + +export const SECURITY_API_PREFIX = '/_plugins/_security/api'; +export const SECURITY_ACCOUNT_API = `${SECURITY_API_PREFIX}/account`; export const CLUSTER = { TRAIN: 'opensearch_mlCommonsTrain', @@ -33,7 +37,7 @@ export const CLUSTER = { PREDICT: 'opensearch_mlCommonsPredict', }; -export const CONNECTOR_BASE_API = `${API_ROUTE_PREFIX}/connectors`; +export const CONNECTOR_BASE_API = `${ML_COMMONS_API_PREFIX}/connectors`; export const CONNECTOR_SEARCH_API = `${CONNECTOR_BASE_API}/_search`; export const MODEL_INDEX = '.plugins-ml-model'; From 3a6000801a7b5315e4e9f87422458700008d89af Mon Sep 17 00:00:00 2001 From: Yulong Ruan Date: Tue, 31 Jan 2023 15:41:35 +0800 Subject: [PATCH 03/75] bring in-app navigation bar back (#72) And also fix an issue that refresh the page always redirect to root path Signed-off-by: Yulong Ruan Signed-off-by: Lin Wang --- common/router.ts | 12 ++ common/router_paths.ts | 1 + public/ace-themes/sql_console.js | 18 +++ public/index.scss | 183 +++++++++++++++++++++++++++++++ 4 files changed, 214 insertions(+) create mode 100644 public/ace-themes/sql_console.js create mode 100644 public/index.scss diff --git a/common/router.ts b/common/router.ts index e8d2a37d..7e87ac0f 100644 --- a/common/router.ts +++ b/common/router.ts @@ -3,7 +3,9 @@ * SPDX-License-Identifier: Apache-2.0 */ +import { ModelList } from '../public/components/model_list'; import { Monitoring } from '../public/components/monitoring'; +import { RegisterModelForm } from '../public/components/register_model'; import { routerPaths } from './router_paths'; interface RouteConfig { @@ -19,6 +21,16 @@ export const ROUTES: RouteConfig[] = [ Component: Monitoring, label: 'Overview', }, + { + path: routerPaths.registerModel, + label: 'Register Model', + Component: RegisterModelForm, + }, + { + path: routerPaths.modelList, + label: 'Model List', + Component: ModelList, + }, ]; /* export const ROUTES1 = [ diff --git a/common/router_paths.ts b/common/router_paths.ts index 9ca63c68..e171781b 100644 --- a/common/router_paths.ts +++ b/common/router_paths.ts @@ -8,4 +8,5 @@ export const routerPaths = { overview: '/overview', monitoring: '/monitoring', registerModel: '/model-registry/register-model', + modelList: '/model-registry/model-list', }; diff --git a/public/ace-themes/sql_console.js b/public/ace-themes/sql_console.js new file mode 100644 index 00000000..fa7b3668 --- /dev/null +++ b/public/ace-themes/sql_console.js @@ -0,0 +1,18 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import * as ace from 'brace'; + +ace.define('ace/theme/sql_console', ['require', 'exports', 'module', 'ace/lib/dom'], function ( + acequire, + exports +) { + exports.isDark = false; + exports.cssClass = 'ace-sql-console'; + exports.cssText = require('../index.scss'); + + const dom = acequire('../lib/dom'); + dom.importCssString(exports.cssText, exports.cssClass); +}); diff --git a/public/index.scss b/public/index.scss new file mode 100644 index 00000000..945af8c8 --- /dev/null +++ b/public/index.scss @@ -0,0 +1,183 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +/* stylelint-disable no-empty-source */ +.ace-sql-console { + .ace_gutter { + background: rgb(237, 239, 243); + color: rgb(0, 0, 0); + } + + .ace_print-margin { + width: 1px; + background: #e8e8e8; + } + + .ace_fold { + background-color: #6b72e6; + } + + background-color: rgb(245, 247, 250); + color: black; + + .ace_marker-layer { + .ace_active-line.ace_active-line { + background-color: rgb(211, 218, 230); + } + + .ace_selection { + background: rgb(181, 213, 255); + } + + .ace_step { + background: rgb(252, 255, 0); + } + + .ace_stack { + background: rgb(164, 229, 101); + } + + .ace_bracket { + margin: -1px 0 0 -1px; + border: 1px solid rgb(192, 192, 192); + } + + .ace_active-line { + background: rgba(0, 0, 0, 0.07); + } + + .ace_selected-word { + background: rgb(250, 250, 255); + border: 1px solid rgb(200, 200, 250); + } + } + + .ace_cursor { + color: black; + } + + .ace_invisible { + color: rgb(191, 191, 191); + } + + .ace_storage { + color: rgb(157, 106, 242); + } + + .ace_keyword { + color: rgb(157, 106, 242); + } + + .ace_constant { + color: rgb(197, 6, 11); + } + + .ace_constant.ace_buildin { + color: rgb(88, 72, 246); + } + + .ace_constant.ace_language { + color: rgb(88, 92, 246); + } + + .ace_constant.ace_library { + color: rgb(6, 150, 14); + } + + .ace_invalid { + background-color: rgba(255, 0, 0, 0.1); + color: red; + } + + .ace_support.ace_function { + color: rgb(60, 76, 114); + } + + .ace_support.ace_constant { + color: rgb(6, 150, 14); + } + + .ace_support.ace_type { + color: rgb(109, 121, 222); + } + + .ace_support.ace_class { + color: rgb(109, 121, 222); + } + + .ace_keyword.ace_operator { + color: rgb(104, 118, 135); + } + + .ace_string { + color: rgb(3, 106, 7); + } + + .ace_comment { + color: rgb(76, 136, 107); + } + + .ace_comment.ace_doc { + color: rgb(0, 102, 255); + } + + .ace_comment.ace_doc.ace_tag { + color: rgb(128, 159, 191); + } + + .ace_constant.ace_numeric { + color: rgb(0, 0, 205); + } + + .ace_variable { + color: rgb(49, 132, 149); + } + + .ace_xml-pe { + color: rgb(104, 104, 91); + } + + .ace_entity.ace_name.ace_function { + color: #0000a2; + } + + .ace_heading { + color: rgb(12, 7, 255); + } + + .ace_list { + color: rgb(185, 6, 144); + } + + .ace_meta.ace_tag { + color: rgb(0, 22, 142); + } + + .ace_string.ace_regex { + color: rgb(255, 0, 0); + } + + .ace_gutter-active-line { + background-color: rgb(211, 218, 230); + } + + .ace_indent-guide { + background: url('data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAACCAYAAACZgbYnAAAAE0lEQVQImWP4////f4bLly//BwAmVgd1/w11/gAAAABJRU5ErkJggg==') + right repeat-y; + } +} + +.ace-sql-console.ace_multiselect { + .ace_selection.ace_start { + box-shadow: 0 0 3px 0px white; + } +} + +.ace_editor { + .ace-sql-console { + height: 200px; + } +} + From 7048cfcc13ee3ee4237a93754351abbe19cd731c Mon Sep 17 00:00:00 2001 From: raintygao Date: Wed, 1 Feb 2023 10:20:59 +0800 Subject: [PATCH 04/75] feat: add more actions in model list (#77) * feat: add more actions in model list Signed-off-by: raintygao * fix: define constant variable without using useMemo Signed-off-by: raintygao --------- Signed-off-by: raintygao Signed-off-by: Lin Wang --- common/model.ts | 1 + public/components/model_drawer/index.tsx | 13 ++--- .../components/model_drawer/version_table.tsx | 57 +++++++++---------- public/components/model_list/model_table.tsx | 17 +++++- 4 files changed, 50 insertions(+), 38 deletions(-) diff --git a/common/model.ts b/common/model.ts index 6a274c80..590dd497 100644 --- a/common/model.ts +++ b/common/model.ts @@ -27,6 +27,7 @@ export interface OpenSearchSelfTrainedModel extends OpenSearchModelBase { } export interface OpenSearchCustomerModel extends OpenSearchModelBase { + algorithm: string; chunk_number: number; created_time: number; description: string; diff --git a/public/components/model_drawer/index.tsx b/public/components/model_drawer/index.tsx index b7b4346b..e6b3c4a7 100644 --- a/public/components/model_drawer/index.tsx +++ b/public/components/model_drawer/index.tsx @@ -9,18 +9,16 @@ import { EuiFlyoutBody, EuiFlyoutHeader, EuiTitle, - EuiLink, EuiSpacer, EuiFlexGroup, EuiFlexItem, EuiDescriptionList, } from '@elastic/eui'; -import { generatePath, useHistory } from 'react-router-dom'; import { APIProvider } from '../../apis/api_provider'; import { useFetcher } from '../../hooks/use_fetcher'; import { routerPaths } from '../../../common/router_paths'; import { VersionTable } from './version_table'; -import { EuiLinkButton, EuiCustomLink } from '../common'; +import { EuiLinkButton } from '../common'; export type VersionTableSort = 'version-desc' | 'version-asc'; @@ -58,10 +56,11 @@ export const ModelDrawer = ({ onClose, name }: Props) => { {latestVersion.id ? ( <> - , - + + {/* TODO: update after exsiting detail page */} + {/* View Full Details - + */} ) : null} @@ -74,7 +73,7 @@ export const ModelDrawer = ({ onClose, name }: Props) => { - + Register new version diff --git a/public/components/model_drawer/version_table.tsx b/public/components/model_drawer/version_table.tsx index 959f57af..f3ed69ea 100644 --- a/public/components/model_drawer/version_table.tsx +++ b/public/components/model_drawer/version_table.tsx @@ -4,11 +4,9 @@ */ import React, { useMemo, useCallback, useRef } from 'react'; -import { generatePath, useHistory } from 'react-router-dom'; -import { EuiBasicTable, Direction, Criteria } from '@elastic/eui'; +import { EuiBasicTable, Direction, Criteria, EuiBasicTableColumn } from '@elastic/eui'; import { ModelSearchItem } from '../../apis/model'; -import { routerPaths } from '../../../common/router_paths'; import { renderTime } from '../../utils'; import type { VersionTableSort } from './'; @@ -22,41 +20,40 @@ export function VersionTable(props: { onChange: (criteria: VersionTableCriteria) => void; }) { const { models, sort, onChange } = props; - const history = useHistory(); const onChangeRef = useRef(onChange); onChangeRef.current = onChange; - const columns = useMemo( - () => [ - { - field: 'version', - name: 'Version', - sortable: true, + const columns: Array> = [ + { + field: 'model_version', + name: 'Version', + sortable: false, + }, + { + field: 'model_state', + name: 'Stage', + }, + { + field: 'algorithm', + name: 'Algorithm', + }, + { + field: 'created_time', + name: 'Time', + render: (time: string) => { + return renderTime(time); }, - { - field: 'state', - name: 'Stage', - }, - { - field: 'algorithm', - name: 'Algorithm', - }, - { - field: 'created_time', - name: 'Time', - render: renderTime, - sortable: true, - }, - ], - [] - ); + sortable: false, + }, + ]; const rowProps = useCallback( ({ id }) => ({ onClick: () => { - history.push(generatePath(routerPaths.modelDetail, { id })); + // TODO: update after exsiting detail page + // history.push(generatePath(routerPaths.modelDetail, { id })); }, }), - [history] + [] ); const sorting = useMemo(() => { @@ -78,7 +75,7 @@ export function VersionTable(props: { }, []); return ( - columns={columns} items={models} rowProps={rowProps} diff --git a/public/components/model_list/model_table.tsx b/public/components/model_list/model_table.tsx index efaf46c4..16fee0bc 100644 --- a/public/components/model_list/model_table.tsx +++ b/public/components/model_list/model_table.tsx @@ -128,6 +128,21 @@ export function ModelTable(props: ModelTableProps) { ), sortable: true, }, + { + name: 'Actions', + actions: [ + // TODO: add a new task to update after design completed + { + name: 'Prevew', + description: 'Preview model group', + type: 'icon', + icon: 'boxesHorizontal', + onClick: ({ name }) => { + onModelNameClick(name); + }, + }, + ], + }, ], [onModelNameClick] ); @@ -155,7 +170,7 @@ export function ModelTable(props: ModelTableProps) { }, []); return ( - + columns={columns} items={models} pagination={pagination} From 67a806627f15bab4cb6ab39202857f74154ecb3d Mon Sep 17 00:00:00 2001 From: raintygao Date: Thu, 2 Feb 2023 18:52:26 +0800 Subject: [PATCH 05/75] populate register form with existing model version (#82) Signed-off-by: Lin Wang --- common/router_paths.ts | 2 +- public/apis/model.ts | 6 ++ public/components/model_drawer/index.tsx | 14 +++-- .../register_model_artifact.test.tsx | 7 +++ .../__tests__/register_model_details.test.tsx | 7 +++ .../__tests__/register_model_form.test.tsx | 60 +++++++++++++++++++ .../__tests__/register_model_metrics.test.tsx | 7 +++ .../__tests__/register_model_tags.test.tsx | 7 +++ .../register_model/register_model.tsx | 29 ++++++++- public/utils/index.ts | 1 + public/utils/version.ts | 16 +++++ 11 files changed, 148 insertions(+), 8 deletions(-) create mode 100644 public/components/register_model/__tests__/register_model_form.test.tsx create mode 100644 public/utils/version.ts diff --git a/common/router_paths.ts b/common/router_paths.ts index e171781b..3282976f 100644 --- a/common/router_paths.ts +++ b/common/router_paths.ts @@ -7,6 +7,6 @@ export const routerPaths = { root: '/', overview: '/overview', monitoring: '/monitoring', - registerModel: '/model-registry/register-model', + registerModel: '/model-registry/register-model/:id?', modelList: '/model-registry/model-list', }; diff --git a/public/apis/model.ts b/public/apis/model.ts index 1d82a7ad..36748370 100644 --- a/public/apis/model.ts +++ b/public/apis/model.ts @@ -27,6 +27,12 @@ export interface ModelSearchItem { name: string; description?: string; }; + model_config?: { + all_config?: string; + embedding_dimension: number; + framework_type: string; + model_type: string; + }; } export interface ModelDetail extends ModelSearchItem { diff --git a/public/components/model_drawer/index.tsx b/public/components/model_drawer/index.tsx index e6b3c4a7..106b146f 100644 --- a/public/components/model_drawer/index.tsx +++ b/public/components/model_drawer/index.tsx @@ -14,6 +14,7 @@ import { EuiFlexItem, EuiDescriptionList, } from '@elastic/eui'; +import { generatePath } from 'react-router-dom'; import { APIProvider } from '../../apis/api_provider'; import { useFetcher } from '../../hooks/use_fetcher'; import { routerPaths } from '../../../common/router_paths'; @@ -39,7 +40,7 @@ export const ModelDrawer = ({ onClose, name }: Props) => { // TODO: currently assume that api will return versions in order if (model?.data) { const data = model.data; - return data[data.length - 1]; + return data[0]; } return { id: '' }; }, [model]); @@ -73,9 +74,14 @@ export const ModelDrawer = ({ onClose, name }: Props) => { - - Register new version - + {latestVersion?.id && ( + + Register new version + + )} diff --git a/public/components/register_model/__tests__/register_model_artifact.test.tsx b/public/components/register_model/__tests__/register_model_artifact.test.tsx index 7ac40b71..c4cf18c0 100644 --- a/public/components/register_model/__tests__/register_model_artifact.test.tsx +++ b/public/components/register_model/__tests__/register_model_artifact.test.tsx @@ -6,6 +6,13 @@ import { screen } from '../../../../test/test_utils'; import { setup } from './setup'; +jest.mock('react-router-dom', () => ({ + ...jest.requireActual('react-router-dom'), + useParams: () => ({ + id: '', + }), +})); + describe(' Artifact', () => { it('should render an artifact panel', async () => { const onSubmitMock = jest.fn(); diff --git a/public/components/register_model/__tests__/register_model_details.test.tsx b/public/components/register_model/__tests__/register_model_details.test.tsx index eb010b65..6d6e4c1c 100644 --- a/public/components/register_model/__tests__/register_model_details.test.tsx +++ b/public/components/register_model/__tests__/register_model_details.test.tsx @@ -5,6 +5,13 @@ import { setup } from './setup'; +jest.mock('react-router-dom', () => ({ + ...jest.requireActual('react-router-dom'), + useParams: () => ({ + id: '', + }), +})); + describe(' Details', () => { it('should render a model details panel', async () => { const onSubmitMock = jest.fn(); diff --git a/public/components/register_model/__tests__/register_model_form.test.tsx b/public/components/register_model/__tests__/register_model_form.test.tsx new file mode 100644 index 00000000..0eba41e8 --- /dev/null +++ b/public/components/register_model/__tests__/register_model_form.test.tsx @@ -0,0 +1,60 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React from 'react'; +import { render, screen, waitFor } from '../../../../test/test_utils'; +import { RegisterModelForm } from '../register_model'; +import { APIProvider } from '../../../apis/api_provider'; + +jest.mock('react-router-dom', () => ({ + ...jest.requireActual('react-router-dom'), + useParams: () => ({ + id: 'test_model_id', + }), +})); + +describe(' Form', () => { + it('should init form when id param in url route', async () => { + const request = jest.spyOn(APIProvider.getAPI('model'), 'search'); + const mockResult = { + data: [ + { + id: 'C7jN0YQBjgpeQQ_RmiDE', + model_version: '1.0.7', + created_time: 1669967223491, + model_config: { + all_config: + '{"_name_or_path":"nreimers/MiniLM-L6-H384-uncased","architectures":["BertModel"],"attention_probs_dropout_prob":0.1,"gradient_checkpointing":false,"hidden_act":"gelu","hidden_dropout_prob":0.1,"hidden_size":384,"initializer_range":0.02,"intermediate_size":1536,"layer_norm_eps":1e-12,"max_position_embeddings":512,"model_type":"bert","num_attention_heads":12,"num_hidden_layers":6,"pad_token_id":0,"position_embedding_type":"absolute","transformers_version":"4.8.2","type_vocab_size":2,"use_cache":true,"vocab_size":30522}', + model_type: 'bert', + embedding_dimension: 384, + framework_type: 'SENTENCE_TRANSFORMERS', + }, + last_loaded_time: 1672895017422, + model_format: 'TORCH_SCRIPT', + last_uploaded_time: 1669967226531, + name: 'all-MiniLM-L6-v2', + model_state: 'LOADED', + total_chunks: 9, + model_content_size_in_bytes: 83408741, + algorithm: 'TEXT_EMBEDDING', + model_content_hash_value: + '9376c2ebd7c83f99ec2526323786c348d2382e6d86576f750c89ea544d6bbb14', + current_worker_node_count: 1, + planning_worker_node_count: 1, + }, + ], + pagination: { currentPage: 1, pageSize: 1, totalRecords: 1, totalPages: 1 }, + }; + request.mockResolvedValue(mockResult); + render(); + + const { name } = mockResult.data[0]; + + await waitFor(() => { + const nameInput = screen.getByLabelText(/model name/i); + expect(nameInput.value).toBe(name); + }); + }); +}); diff --git a/public/components/register_model/__tests__/register_model_metrics.test.tsx b/public/components/register_model/__tests__/register_model_metrics.test.tsx index 1314de4e..0fb11b3c 100644 --- a/public/components/register_model/__tests__/register_model_metrics.test.tsx +++ b/public/components/register_model/__tests__/register_model_metrics.test.tsx @@ -7,6 +7,13 @@ import { screen } from '../../../../test/test_utils'; import { setup } from './setup'; import * as formHooks from '../register_model.hooks'; +jest.mock('react-router-dom', () => ({ + ...jest.requireActual('react-router-dom'), + useParams: () => ({ + id: '', + }), +})); + describe(' Evaluation Metrics', () => { beforeEach(() => { jest diff --git a/public/components/register_model/__tests__/register_model_tags.test.tsx b/public/components/register_model/__tests__/register_model_tags.test.tsx index 7ccc22f7..dea20588 100644 --- a/public/components/register_model/__tests__/register_model_tags.test.tsx +++ b/public/components/register_model/__tests__/register_model_tags.test.tsx @@ -7,6 +7,13 @@ import { screen } from '../../../../test/test_utils'; import { setup } from './setup'; import * as formHooks from '../register_model.hooks'; +jest.mock('react-router-dom', () => ({ + ...jest.requireActual('react-router-dom'), + useParams: () => ({ + id: '', + }), +})); + describe(' Tags', () => { beforeEach(() => { jest diff --git a/public/components/register_model/register_model.tsx b/public/components/register_model/register_model.tsx index d44818b6..e192b32f 100644 --- a/public/components/register_model/register_model.tsx +++ b/public/components/register_model/register_model.tsx @@ -3,23 +3,26 @@ * SPDX-License-Identifier: Apache-2.0 */ -import React, { useCallback } from 'react'; +import React, { useCallback, useEffect } from 'react'; import { FieldErrors, useForm } from 'react-hook-form'; import { EuiPageHeader, EuiSpacer, EuiForm, EuiButton } from '@elastic/eui'; - +import { useParams } from 'react-router-dom'; import { ModelDetailsPanel } from './model_details'; import type { ModelFileFormData, ModelUrlFormData } from './register_model.types'; import { ArtifactPanel } from './artifact'; import { ConfigurationPanel } from './model_configuration'; import { EvaluationMetricsPanel } from './evaluation_metrics'; import { ModelTagsPanel } from './model_tags'; +import { APIProvider } from '../../apis/api_provider'; +import { upgradeModelVersion } from '../../utils'; export interface RegisterModelFormProps { onSubmit?: (data: ModelFileFormData | ModelUrlFormData) => void; } export const RegisterModelForm = (props: RegisterModelFormProps) => { - const { handleSubmit, control } = useForm({ + const { id: latestVersioinId } = useParams<{ id: string | undefined }>(); + const { handleSubmit, control, setValue } = useForm({ defaultValues: { name: '', description: '', @@ -38,6 +41,26 @@ export const RegisterModelForm = (props: RegisterModelFormProps) => { console.log(data); }; + useEffect(() => { + if (!latestVersioinId) return; + const initializeForm = async () => { + const { data } = await APIProvider.getAPI('model').search({ + ids: [latestVersioinId], + from: 0, + size: 1, + }); + if (data?.[0]) { + // TODO: clarify which fields to pre-populate + const { model_version: modelVersion, name, model_config: modelConfig } = data?.[0]; + const newVersion = upgradeModelVersion(modelVersion); + setValue('name', name); + setValue('version', newVersion); + setValue('configuration', modelConfig?.all_config ?? ''); + } + }; + initializeForm(); + }, [latestVersioinId, setValue]); + const onError = useCallback((errors: FieldErrors) => { // TODO // eslint-disable-next-line no-console diff --git a/public/utils/index.ts b/public/utils/index.ts index 636354ea..44ae52e4 100644 --- a/public/utils/index.ts +++ b/public/utils/index.ts @@ -4,3 +4,4 @@ */ export * from './table'; +export * from './version'; diff --git a/public/utils/version.ts b/public/utils/version.ts new file mode 100644 index 00000000..98cb2c15 --- /dev/null +++ b/public/utils/version.ts @@ -0,0 +1,16 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +export const upgradeModelVersion = (version: string) => { + // TODO:determine whether BE version follows semver + // const num = Number(version.split('.').reduce((prev, i) => prev + i, '')); + // return String(num + 1) + // .split('') + // .toString() + // .replaceAll(',', '.'); + + // TODO:determine whether user can input version + return version; +}; From 110dc6466db79fbbb3ad15c455d8bedc7519b7e2 Mon Sep 17 00:00:00 2001 From: Lin Wang Date: Wed, 8 Feb 2023 11:04:33 +0800 Subject: [PATCH 06/75] Feature/add model upload logic (#83) * feat: add task getOne API Signed-off-by: Lin Wang * feat: update model upload API Signed-off-by: Lin Wang * feat: add model upload hook Signed-off-by: Lin Wang * feat: add model upload hook to model register form Signed-off-by: Lin Wang * feat: run workflow when branch name like feature/* Signed-off-by: Lin Wang * ci: update to recursive match run workflow Signed-off-by: Lin Wang * fix: es-lint error fixup Signed-off-by: Lin Wang * test: add model task API mock for model register form Signed-off-by: Lin Wang * chore: update model upload modelConfig type Signed-off-by: Lin Wang * refactor: remove isUploadByURL and change to 'modelURL' in model Signed-off-by: Lin Wang --------- Signed-off-by: Lin Wang --- .github/workflows/lint-workflow.yml | 4 +- .github/workflows/unit-tests-workflow.yml | 4 +- public/apis/__mocks__/model.ts | 10 + public/apis/__mocks__/task.ts | 10 + public/apis/api_provider.ts | 9 + public/apis/model.ts | 11 +- public/apis/task.ts | 25 +++ public/components/model_drawer/index.tsx | 3 +- .../get_model_content_hash_value.test.ts | 14 ++ .../__tests__/register_model.hooks.test.ts | 188 ++++++++++++++++++ .../register_model/__tests__/setup.tsx | 3 + .../get_model_content_hash_value.ts | 32 +++ .../register_model/register_model.hooks.ts | 81 +++++++- .../register_model/register_model.tsx | 30 ++- server/plugin.ts | 4 +- server/routes/constants.ts | 2 + server/routes/index.ts | 1 + server/routes/model_router.ts | 7 +- server/routes/task_router.ts | 33 +++ server/services/index.ts | 1 + server/services/model_service.ts | 29 +-- server/services/task_service.ts | 33 +++ server/services/utils/constants.ts | 2 + server/services/utils/model.ts | 33 --- 24 files changed, 487 insertions(+), 82 deletions(-) create mode 100644 public/apis/__mocks__/task.ts create mode 100644 public/apis/task.ts create mode 100644 public/components/register_model/__tests__/get_model_content_hash_value.test.ts create mode 100644 public/components/register_model/__tests__/register_model.hooks.test.ts create mode 100644 public/components/register_model/get_model_content_hash_value.ts create mode 100644 server/routes/task_router.ts create mode 100644 server/services/task_service.ts diff --git a/.github/workflows/lint-workflow.yml b/.github/workflows/lint-workflow.yml index 2d71f79c..99b36274 100644 --- a/.github/workflows/lint-workflow.yml +++ b/.github/workflows/lint-workflow.yml @@ -2,10 +2,10 @@ name: Lint workflow on: push: branches: - - '*' + - '**' pull_request: branches: - - '*' + - '**' env: OPENSEARCH_DASHBOARDS_VERSION: 'main' jobs: diff --git a/.github/workflows/unit-tests-workflow.yml b/.github/workflows/unit-tests-workflow.yml index f28c00b8..13f2a3c7 100644 --- a/.github/workflows/unit-tests-workflow.yml +++ b/.github/workflows/unit-tests-workflow.yml @@ -2,10 +2,10 @@ name: Unit tests workflow on: push: branches: - - '*' + - '**' pull_request: branches: - - '*' + - '**' env: OPENSEARCH_DASHBOARDS_VERSION: 'main' jobs: diff --git a/public/apis/__mocks__/model.ts b/public/apis/__mocks__/model.ts index d206d888..1435b938 100644 --- a/public/apis/__mocks__/model.ts +++ b/public/apis/__mocks__/model.ts @@ -18,4 +18,14 @@ export class Model { total_models: 1, }); } + + public upload({ url }: { url?: string }) { + return Promise.resolve( + url === undefined ? { model_id: 'model-id-1' } : { task_id: 'task-id-1' } + ); + } + + public uploadChunk() { + return Promise.resolve(); + } } diff --git a/public/apis/__mocks__/task.ts b/public/apis/__mocks__/task.ts new file mode 100644 index 00000000..becf33a4 --- /dev/null +++ b/public/apis/__mocks__/task.ts @@ -0,0 +1,10 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +export class Task { + public getOne() { + return Promise.resolve({}); + } +} diff --git a/public/apis/api_provider.ts b/public/apis/api_provider.ts index a5a7d1b3..dfc9bbe6 100644 --- a/public/apis/api_provider.ts +++ b/public/apis/api_provider.ts @@ -8,6 +8,7 @@ import { Model } from './model'; import { ModelAggregate } from './model_aggregate'; import { Profile } from './profile'; import { Security } from './security'; +import { Task } from './task'; const apiInstanceStore: { model: Model | undefined; @@ -15,15 +16,18 @@ const apiInstanceStore: { profile: Profile | undefined; connector: Connector | undefined; security: Security | undefined; + task: Task | undefined; } = { model: undefined, modelAggregate: undefined, profile: undefined, connector: undefined, security: undefined, + task: undefined, }; export class APIProvider { + public static getAPI(type: 'task'): Task; public static getAPI(type: 'model'): Model; public static getAPI(type: 'modelAggregate'): ModelAggregate; public static getAPI(type: 'profile'): Profile; @@ -59,6 +63,11 @@ export class APIProvider { apiInstanceStore.security = newInstance; return newInstance; } + case 'task': { + const newInstance = new Task(); + apiInstanceStore.task = newInstance; + return newInstance; + } } } public static clear() { diff --git a/public/apis/model.ts b/public/apis/model.ts index 36748370..6ea9be65 100644 --- a/public/apis/model.ts +++ b/public/apis/model.ts @@ -76,11 +76,7 @@ interface UploadModelBase { version: string; description: string; modelFormat: string; - modelConfig: { - modelType: string; - embeddingDimension: number; - frameworkType: string; - }; + modelConfig: Record; } export interface UploadModelByURL extends UploadModelBase { @@ -88,7 +84,6 @@ export interface UploadModelByURL extends UploadModelBase { } export interface UploadModelByChunk extends UploadModelBase { - modelTaskType: string; modelContentHashValue: string; totalChunks: number; } @@ -143,9 +138,9 @@ export class Model { model: T ): Promise< T extends UploadModelByURL - ? { taskId: string } + ? { task_id: string } : T extends UploadModelByChunk - ? { modelId: string } + ? { model_id: string } : never > { return InnerHttpProvider.getHttp().post(MODEL_UPLOAD_API_ENDPOINT, { diff --git a/public/apis/task.ts b/public/apis/task.ts new file mode 100644 index 00000000..e0a42014 --- /dev/null +++ b/public/apis/task.ts @@ -0,0 +1,25 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import { TASK_API_ENDPOINT } from '../../server/routes/constants'; +import { InnerHttpProvider } from './inner_http_provider'; + +export interface TaskGetOneResponse { + error?: string; + last_update_time: number; + create_time: number; + is_async: boolean; + function_name: string; + state: string; + model_id?: string; + task_type: string; + worker_node: string[]; +} + +export class Task { + public getOne(taskId: string) { + return InnerHttpProvider.getHttp().get(`${TASK_API_ENDPOINT}/${taskId}`); + } +} diff --git a/public/components/model_drawer/index.tsx b/public/components/model_drawer/index.tsx index 106b146f..4383fe0b 100644 --- a/public/components/model_drawer/index.tsx +++ b/public/components/model_drawer/index.tsx @@ -18,9 +18,10 @@ import { generatePath } from 'react-router-dom'; import { APIProvider } from '../../apis/api_provider'; import { useFetcher } from '../../hooks/use_fetcher'; import { routerPaths } from '../../../common/router_paths'; -import { VersionTable } from './version_table'; import { EuiLinkButton } from '../common'; +import { VersionTable } from './version_table'; + export type VersionTableSort = 'version-desc' | 'version-asc'; interface Props { diff --git a/public/components/register_model/__tests__/get_model_content_hash_value.test.ts b/public/components/register_model/__tests__/get_model_content_hash_value.test.ts new file mode 100644 index 00000000..b4e27cbb --- /dev/null +++ b/public/components/register_model/__tests__/get_model_content_hash_value.test.ts @@ -0,0 +1,14 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import { getModelContentHashValue } from '../get_model_content_hash_value'; + +describe('getModelContentHashValue', () => { + it('should return consistent sha256 value', async () => { + expect(await getModelContentHashValue(new Blob(new Array(10000).fill(1)))).toBe( + 'd25f01257c9890622d78cf9cf3362457ef75712bd187ae33b383f80d618d0f06' + ); + }); +}); diff --git a/public/components/register_model/__tests__/register_model.hooks.test.ts b/public/components/register_model/__tests__/register_model.hooks.test.ts new file mode 100644 index 00000000..9493eac5 --- /dev/null +++ b/public/components/register_model/__tests__/register_model.hooks.test.ts @@ -0,0 +1,188 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import { renderHook } from '@testing-library/react-hooks'; +import { Task, TaskGetOneResponse } from '../../../../public/apis/task'; +import { Model } from '../../../../public/apis/model'; +import * as getModelContentHashValueExports from '../get_model_content_hash_value'; + +import { useModelUpload } from '../register_model.hooks'; + +const modelBaseData = { + name: 'foo', + version: '1', + description: 'foo bar', + annotations: '', + configuration: `{ + "foo":"bar" + }`, +}; + +const modelUrlFormData = { + ...modelBaseData, + modelURL: 'http://localhost/', +}; + +const modelFileFormData = { + ...modelBaseData, + modelFile: new File(new Array(10000).fill(1), 'test-model'), +}; +// Make file size to 30MB, so we can split 3 chunks to upload. +Object.defineProperty(modelFileFormData.modelFile, 'size', { + get() { + return 3 * 10000000; + }, +}); + +describe('useModelUpload', () => { + describe('upload by url', () => { + beforeEach(() => { + jest.spyOn(Model.prototype, 'upload').mockResolvedValue({ task_id: 'task-id-1' }); + jest + .spyOn(Task.prototype, 'getOne') + .mockResolvedValue({ model_id: 'new-model-1' } as TaskGetOneResponse); + }); + + afterEach(() => { + jest.spyOn(Model.prototype, 'upload').mockClear(); + jest.spyOn(Task.prototype, 'getOne').mockClear(); + }); + + it('should call model upload with consistent params', () => { + const { result } = renderHook(() => useModelUpload()); + const modelUploadMock = jest.spyOn(Model.prototype, 'upload'); + expect(modelUploadMock).not.toHaveBeenCalled(); + + result.current(modelUrlFormData); + + expect(modelUploadMock).toHaveBeenCalledWith({ + name: 'foo', + version: '1', + description: 'foo bar', + modelFormat: 'TORCH_SCRIPT', + modelConfig: { + foo: 'bar', + }, + url: 'http://localhost/', + }); + }); + + it('should call get task cycling and resolved with model id when upload by url', async () => { + jest.useFakeTimers(); + const { result } = renderHook(() => useModelUpload()); + + const taskGetOneMock = jest + .spyOn(Task.prototype, 'getOne') + .mockResolvedValueOnce({} as TaskGetOneResponse); + expect(taskGetOneMock).not.toHaveBeenCalled(); + + const uploadPromise = result.current(modelUrlFormData); + + await jest.spyOn(Model.prototype, 'upload').mock.results[0].value; + expect(taskGetOneMock).toHaveBeenCalledWith('task-id-1'); + + await taskGetOneMock.mock.results[0].value; + taskGetOneMock.mockResolvedValueOnce({ model_id: 'new-model-1' } as TaskGetOneResponse); + jest.advanceTimersByTime(1000); + + expect(taskGetOneMock).toHaveBeenCalledTimes(2); + expect(await uploadPromise).toBe('new-model-1'); + + jest.useRealTimers(); + }); + + it('should NOT call get task if component unmount', () => { + const { result, unmount } = renderHook(() => useModelUpload()); + let uploadAPIResolveFn: Function; + const uploadAPIPromise = new Promise<{ task_id: string }>((resolve, reject) => { + uploadAPIResolveFn = () => { + resolve({ task_id: 'task-id-1' }); + }; + }); + jest.spyOn(Model.prototype, 'upload').mockReturnValue(uploadAPIPromise); + + const uploadPromise = result.current(modelUrlFormData); + unmount(); + uploadAPIResolveFn!(); + + expect(jest.spyOn(Task.prototype, 'getOne')).not.toHaveBeenCalled(); + + expect(uploadPromise).rejects.toMatch('component unmounted'); + }); + + it('should NOT cycling call get task after component unmount', async () => { + jest.useFakeTimers(); + const { result, unmount } = renderHook(() => useModelUpload()); + + const taskGetOneMock = jest + .spyOn(Task.prototype, 'getOne') + .mockResolvedValue({} as TaskGetOneResponse); + expect(taskGetOneMock).not.toHaveBeenCalled(); + + const uploadPromise = result.current(modelUrlFormData); + + await jest.spyOn(Model.prototype, 'upload').mock.results[0].value; + + await taskGetOneMock.mock.results[0].value; + expect(taskGetOneMock).toHaveBeenCalledTimes(1); + unmount(); + + jest.advanceTimersByTime(1000); + + expect(taskGetOneMock).toHaveBeenCalledTimes(1); + expect(uploadPromise).rejects.toMatch('component unmounted'); + + jest.useRealTimers(); + }); + }); + + describe('upload by file', () => { + beforeEach(() => { + jest.spyOn(Model.prototype, 'upload').mockResolvedValue({ model_id: 'model-id-1' }); + jest.spyOn(Model.prototype, 'uploadChunk').mockResolvedValue({}); + jest + .spyOn(getModelContentHashValueExports, 'getModelContentHashValue') + .mockResolvedValue('file-hash'); + }); + + afterEach(() => { + jest.spyOn(Model.prototype, 'upload').mockClear(); + jest.spyOn(Model.prototype, 'uploadChunk').mockClear(); + jest.spyOn(getModelContentHashValueExports, 'getModelContentHashValue').mockClear(); + }); + + it('should call model upload with consistent params', async () => { + const { result } = renderHook(() => useModelUpload()); + + expect(jest.spyOn(Model.prototype, 'upload')).not.toHaveBeenCalled(); + + result.current(modelFileFormData); + + await jest.spyOn(getModelContentHashValueExports, 'getModelContentHashValue').mock.results[0] + .value; + await jest.spyOn(Model.prototype, 'upload').mock.results[0].value; + + expect(jest.spyOn(Model.prototype, 'upload')).toHaveBeenCalledWith({ + name: 'foo', + version: '1', + description: 'foo bar', + modelFormat: 'TORCH_SCRIPT', + modelConfig: { + foo: 'bar', + }, + modelContentHashValue: 'file-hash', + totalChunks: 3, + }); + }); + + it('should call model uploadChunk for 3 times', async () => { + const { result } = renderHook(() => useModelUpload()); + + await result.current(modelFileFormData); + + expect(jest.spyOn(Model.prototype, 'uploadChunk')).toHaveBeenCalledTimes(3); + }); + }); +}); diff --git a/public/components/register_model/__tests__/setup.tsx b/public/components/register_model/__tests__/setup.tsx index 806fb5a1..ab0f5cb9 100644 --- a/public/components/register_model/__tests__/setup.tsx +++ b/public/components/register_model/__tests__/setup.tsx @@ -10,6 +10,9 @@ import { RegisterModelForm } from '../register_model'; import type { RegisterModelFormProps } from '../register_model'; import { render, screen } from '../../../../test/test_utils'; +jest.mock('../../../apis/model'); +jest.mock('../../../apis/task'); + export async function setup({ onSubmit }: RegisterModelFormProps) { render(); const nameInput = screen.getByLabelText(/model name/i); diff --git a/public/components/register_model/get_model_content_hash_value.ts b/public/components/register_model/get_model_content_hash_value.ts new file mode 100644 index 00000000..d0e33db5 --- /dev/null +++ b/public/components/register_model/get_model_content_hash_value.ts @@ -0,0 +1,32 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import { createSHA256 } from 'hash-wasm'; + +const readChunkFile = (chunk: Blob) => { + return new Promise((resolve) => { + const fileReader = new FileReader(); + fileReader.onload = async (e) => { + if (e.target?.result) { + resolve(new Uint8Array(e.target.result as ArrayBuffer)); + } + }; + fileReader.readAsArrayBuffer(chunk); + }); +}; + +export const getModelContentHashValue = async (file: Blob) => { + const hasher = await createSHA256(); + const chunkSize = 64 * 1024 * 1024; + + const chunkNumber = Math.floor(file.size / chunkSize); + + for (let i = 0; i <= chunkNumber; i++) { + const chunk = file.slice(chunkSize * i, Math.min(chunkSize * (i + 1), file.size)); + hasher.update(await readChunkFile(chunk)); + } + + return hasher.digest(); +}; diff --git a/public/components/register_model/register_model.hooks.ts b/public/components/register_model/register_model.hooks.ts index 487b2510..91c4608e 100644 --- a/public/components/register_model/register_model.hooks.ts +++ b/public/components/register_model/register_model.hooks.ts @@ -3,7 +3,10 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { useEffect, useState } from 'react'; +import { useCallback, useEffect, useRef, useState } from 'react'; +import { APIProvider } from '../../apis/api_provider'; +import { getModelContentHashValue } from './get_model_content_hash_value'; +import { ModelFileFormData, ModelUrlFormData } from './register_model.types'; const metricNames = ['Metric 1', 'Metric 2', 'Metric 3', 'Metric 4']; @@ -47,3 +50,79 @@ export const useModelTags = () => { return [loading, { keys, values }] as const; }; + +export const useModelUpload = () => { + const timeoutIdRef = useRef(-1); + const mountedRef = useRef(true); + + useEffect(() => { + return () => { + mountedRef.current = false; + window.clearTimeout(timeoutIdRef.current); + }; + }, []); + + return useCallback(async (model: ModelFileFormData | ModelUrlFormData) => { + const modelUploadBase = { + name: model.name, + version: model.version, + description: model.description, + // TODO: Need to confirm if we have the model format input + modelFormat: 'TORCH_SCRIPT', + modelConfig: JSON.parse(model.configuration), + }; + if ('modelURL' in model) { + const { task_id: taskId } = await APIProvider.getAPI('model').upload({ + ...modelUploadBase, + url: model.modelURL, + }); + return new Promise((resolve, reject) => { + const refreshTaskStatus = () => { + APIProvider.getAPI('task') + .getOne(taskId) + .then(({ model_id: modelId, error }) => { + if (error) { + reject(error); + return; + } + if (modelId === undefined) { + if (!mountedRef.current) { + reject('component unmounted'); + return; + } + timeoutIdRef.current = window.setTimeout(refreshTaskStatus, 1000); + return; + } + resolve(modelId); + }); + }; + if (!mountedRef.current) { + reject('component unmounted'); + return; + } + refreshTaskStatus(); + }); + } + const { modelFile } = model; + const MAX_CHUNK_SIZE = 10 * 1000 * 1000; + const totalChunks = Math.ceil(modelFile.size / MAX_CHUNK_SIZE); + const modelContentHashValue = await getModelContentHashValue(modelFile); + + const modelId = ( + await APIProvider.getAPI('model').upload({ + ...modelUploadBase, + totalChunks, + modelContentHashValue, + }) + ).model_id; + + for (let i = 0; i < totalChunks; i++) { + const chunk = modelFile.slice( + MAX_CHUNK_SIZE * i, + Math.min(MAX_CHUNK_SIZE * (i + 1), modelFile.size) + ); + await APIProvider.getAPI('model').uploadChunk(modelId, `${i}`, chunk); + } + return modelId; + }, []); +}; diff --git a/public/components/register_model/register_model.tsx b/public/components/register_model/register_model.tsx index e192b32f..eb0887c9 100644 --- a/public/components/register_model/register_model.tsx +++ b/public/components/register_model/register_model.tsx @@ -7,22 +7,27 @@ import React, { useCallback, useEffect } from 'react'; import { FieldErrors, useForm } from 'react-hook-form'; import { EuiPageHeader, EuiSpacer, EuiForm, EuiButton } from '@elastic/eui'; import { useParams } from 'react-router-dom'; + +import { APIProvider } from '../../apis/api_provider'; +import { upgradeModelVersion } from '../../utils'; + import { ModelDetailsPanel } from './model_details'; import type { ModelFileFormData, ModelUrlFormData } from './register_model.types'; import { ArtifactPanel } from './artifact'; import { ConfigurationPanel } from './model_configuration'; import { EvaluationMetricsPanel } from './evaluation_metrics'; import { ModelTagsPanel } from './model_tags'; -import { APIProvider } from '../../apis/api_provider'; -import { upgradeModelVersion } from '../../utils'; +import { useModelUpload } from './register_model.hooks'; export interface RegisterModelFormProps { onSubmit?: (data: ModelFileFormData | ModelUrlFormData) => void; } export const RegisterModelForm = (props: RegisterModelFormProps) => { - const { id: latestVersioinId } = useParams<{ id: string | undefined }>(); - const { handleSubmit, control, setValue } = useForm({ + const { id: latestVersionId } = useParams<{ id: string | undefined }>(); + const { handleSubmit, control, setValue, formState } = useForm< + ModelFileFormData | ModelUrlFormData + >({ defaultValues: { name: '', description: '', @@ -31,21 +36,23 @@ export const RegisterModelForm = (props: RegisterModelFormProps) => { tags: [{ key: '', value: '' }], }, }); + const submitModel = useModelUpload(); - const onSubmit = (data: ModelFileFormData | ModelUrlFormData) => { + const onSubmit = async (data: ModelFileFormData | ModelUrlFormData) => { if (props.onSubmit) { props.onSubmit(data); } + await submitModel(data); // TODO // eslint-disable-next-line no-console console.log(data); }; useEffect(() => { - if (!latestVersioinId) return; + if (!latestVersionId) return; const initializeForm = async () => { const { data } = await APIProvider.getAPI('model').search({ - ids: [latestVersioinId], + ids: [latestVersionId], from: 0, size: 1, }); @@ -59,7 +66,7 @@ export const RegisterModelForm = (props: RegisterModelFormProps) => { } }; initializeForm(); - }, [latestVersioinId, setValue]); + }, [latestVersionId, setValue]); const onError = useCallback((errors: FieldErrors) => { // TODO @@ -85,7 +92,12 @@ export const RegisterModelForm = (props: RegisterModelFormProps) => { - + Register model diff --git a/server/plugin.ts b/server/plugin.ts index e106437b..352eabde 100644 --- a/server/plugin.ts +++ b/server/plugin.ts @@ -16,9 +16,10 @@ import { MlCommonsPluginSetup, MlCommonsPluginStart } from './types'; import { connectorRouter, modelRouter, - profileRouter, modelAggregateRouter, + profileRouter, securityRouter, + taskRouter, } from './routes'; import { ModelService } from './services'; @@ -46,6 +47,7 @@ export class MlCommonsPlugin implements Plugin { + router.get( + { + path: `${TASK_API_ENDPOINT}/{taskId}`, + validate: { + params: schema.object({ + taskId: schema.string(), + }), + }, + }, + async (context, request) => { + try { + const body = await TaskService.getOne({ + client: context.core.opensearch.client, + taskId: request.params.taskId, + }); + return opensearchDashboardsResponseFactory.ok({ body }); + } catch (err) { + return opensearchDashboardsResponseFactory.badRequest({ body: err.message }); + } + } + ); +}; diff --git a/server/services/index.ts b/server/services/index.ts index 249e61ea..865e2b7e 100644 --- a/server/services/index.ts +++ b/server/services/index.ts @@ -4,4 +4,5 @@ */ export { ModelService } from './model_service'; +export { TaskService } from './task_service'; export { RecordNotFoundError } from './errors'; diff --git a/server/services/model_service.ts b/server/services/model_service.ts index ba172bbd..f757c3fd 100644 --- a/server/services/model_service.ts +++ b/server/services/model_service.ts @@ -41,11 +41,7 @@ interface UploadModelBase { version: string; description: string; modelFormat: string; - modelConfig: { - modelType: string; - embeddingDimension: number; - frameworkType: string; - }; + modelConfig: Record; } interface UploadModelByURL extends UploadModelBase { @@ -53,7 +49,6 @@ interface UploadModelByURL extends UploadModelBase { } interface UploadModelByChunk extends UploadModelBase { - modelTaskType: string; modelContentHashValue: string; totalChunks: number; } @@ -61,15 +56,16 @@ interface UploadModelByChunk extends UploadModelBase { type UploadResultInner< T extends UploadModelByURL | UploadModelByChunk > = T extends UploadModelByChunk - ? { modelId: string; status: string } + ? { model_id: string; status: string } : T extends UploadModelByURL - ? { taskId: string; status: string } + ? { task_id: string; status: string } : never; type UploadResult = Promise>; -const isUploaModelByURL = (test: UploadModelByURL | UploadModelByChunk): test is UploadModelByURL => - (test as UploadModelByURL).url !== undefined; +const isUploadModelByURL = ( + test: UploadModelByURL | UploadModelByChunk +): test is UploadModelByURL => (test as UploadModelByURL).url !== undefined; export class ModelService { private osClient: ILegacyClusterClient; @@ -187,13 +183,9 @@ export class ModelService { version, description, model_format: modelFormat, - model_config: { - model_type: modelConfig.modelType, - embedding_dimension: modelConfig.embeddingDimension, - framework_type: modelConfig.frameworkType, - }, + model_config: modelConfig, }; - if (isUploaModelByURL(model)) { + if (isUploadModelByURL(model)) { const { task_id: taskId, status } = ( await client.asCurrentUser.transport.request({ method: 'POST', @@ -204,7 +196,7 @@ export class ModelService { }, }) ).body; - return { taskId, status } as UploadResultInner; + return { task_id: taskId, status } as UploadResultInner; } const { model_id: modelId, status } = ( @@ -213,13 +205,12 @@ export class ModelService { path: MODEL_META_API, body: { ...uploadModelBase, - model_task_type: model.modelTaskType, model_content_hash_value: model.modelContentHashValue, total_chunks: model.totalChunks, }, }) ).body; - return { modelId, status } as UploadResultInner; + return { model_id: modelId, status } as UploadResultInner; } public static async uploadModelChunk({ diff --git a/server/services/task_service.ts b/server/services/task_service.ts new file mode 100644 index 00000000..45476da9 --- /dev/null +++ b/server/services/task_service.ts @@ -0,0 +1,33 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +/* + * Copyright OpenSearch Contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +import { IScopedClusterClient } from '../../../../src/core/server'; +import { TASK_BASE_API } from './utils/constants'; + +export class TaskService { + public static async getOne({ client, taskId }: { client: IScopedClusterClient; taskId: string }) { + return ( + await client.asCurrentUser.transport.request({ + method: 'GET', + path: `${TASK_BASE_API}/${taskId}`, + }) + ).body; + } +} diff --git a/server/services/utils/constants.ts b/server/services/utils/constants.ts index bc7deab9..847d0ead 100644 --- a/server/services/utils/constants.ts +++ b/server/services/utils/constants.ts @@ -30,6 +30,8 @@ export const PREDICT_BASE_API = `${ML_COMMONS_API_PREFIX}/_predict`; export const SECURITY_API_PREFIX = '/_plugins/_security/api'; export const SECURITY_ACCOUNT_API = `${SECURITY_API_PREFIX}/account`; +export const TASK_BASE_API = `${ML_COMMONS_API_PREFIX}/tasks`; + export const CLUSTER = { TRAIN: 'opensearch_mlCommonsTrain', MODEL: 'opensearch_mlCommonsModel', diff --git a/server/services/utils/model.ts b/server/services/utils/model.ts index 0643f97a..708fdfc6 100644 --- a/server/services/utils/model.ts +++ b/server/services/utils/model.ts @@ -74,36 +74,3 @@ export const generateModelSearchQuery = ({ }, }, }); - -export interface UploadModel { - name: string; - version: string; - description: string; - modelFormat: string; - modelConfig: { - modelType: string; - membeddingDimension: number; - frameworkType: string; - }; - url?: string; -} - -export const convertUploadModel = ({ - name, - version, - description, - modelFormat, - modelConfig, - url, -}: UploadModel) => ({ - name, - version, - description, - model_format: modelFormat, - model_config: { - model_type: modelConfig.modelType, - embedding_dimension: modelConfig.membeddingDimension, - framework_type: modelConfig.frameworkType, - }, - url, -}); From f8c0f5703adfb1a423f1c97be35bc7a7d0be9391 Mon Sep 17 00:00:00 2001 From: Yulong Ruan Date: Wed, 8 Feb 2023 12:16:36 +0800 Subject: [PATCH 07/75] feat: update register model form ui according to the new design (#85) Signed-off-by: Yulong Ruan Signed-off-by: Lin Wang --- common/router.ts | 2 +- public/apis/inner_http_provider.ts | 9 +- .../register_model_artifact.test.tsx | 41 +++-- .../__tests__/register_model_details.test.tsx | 59 ++++++-- .../__tests__/register_model_form.test.tsx | 19 +-- .../__tests__/register_model_metrics.test.tsx | 55 +++---- .../__tests__/register_model_tags.test.tsx | 7 - .../register_model/__tests__/setup.tsx | 37 ++--- public/components/register_model/artifact.tsx | 62 ++++---- .../register_model/artifact_file.tsx | 25 +++- .../register_model/artifact_url.tsx | 9 +- .../register_model/evaluation_metrics.tsx | 34 +++-- .../register_model/model_configuration.tsx | 29 ++-- .../register_model/model_details.tsx | 141 +++++++++--------- .../components/register_model/model_tags.tsx | 9 +- .../register_model/register_model.tsx | 84 +++++++---- .../register_model/register_model.types.ts | 2 +- public/components/register_model/utils.ts | 12 ++ public/hooks/use_search_params.ts | 12 ++ test/test_utils.tsx | 38 ++++- 20 files changed, 412 insertions(+), 274 deletions(-) create mode 100644 public/components/register_model/utils.ts create mode 100644 public/hooks/use_search_params.ts diff --git a/common/router.ts b/common/router.ts index 7e87ac0f..c16ef73a 100644 --- a/common/router.ts +++ b/common/router.ts @@ -5,7 +5,7 @@ import { ModelList } from '../public/components/model_list'; import { Monitoring } from '../public/components/monitoring'; -import { RegisterModelForm } from '../public/components/register_model'; +import { RegisterModelForm } from '../public/components/register_model/register_model'; import { routerPaths } from './router_paths'; interface RouteConfig { diff --git a/public/apis/inner_http_provider.ts b/public/apis/inner_http_provider.ts index 915538d1..a2a4060f 100644 --- a/public/apis/inner_http_provider.ts +++ b/public/apis/inner_http_provider.ts @@ -5,14 +5,17 @@ import { CoreStart } from '../../../../src/core/public'; -let httpSotre: CoreStart['http'] | undefined; +let httpClient: CoreStart['http'] | undefined; export class InnerHttpProvider { public static setHttp(http: CoreStart['http'] | undefined) { - httpSotre = http; + httpClient = http; } public static getHttp() { - return httpSotre!; + if (!httpClient) { + throw Error('Http Client not set'); + } + return httpClient; } } diff --git a/public/components/register_model/__tests__/register_model_artifact.test.tsx b/public/components/register_model/__tests__/register_model_artifact.test.tsx index c4cf18c0..2d405f73 100644 --- a/public/components/register_model/__tests__/register_model_artifact.test.tsx +++ b/public/components/register_model/__tests__/register_model_artifact.test.tsx @@ -6,22 +6,21 @@ import { screen } from '../../../../test/test_utils'; import { setup } from './setup'; -jest.mock('react-router-dom', () => ({ - ...jest.requireActual('react-router-dom'), - useParams: () => ({ - id: '', - }), -})); - describe(' Artifact', () => { it('should render an artifact panel', async () => { const onSubmitMock = jest.fn(); - const result = await setup({ onSubmit: onSubmitMock }); - expect(result.modelFileInput).toBeInTheDocument(); + await setup({ onSubmit: onSubmitMock }); + expect(screen.getByLabelText(/file/i, { selector: 'input[type="file"]' })).toBeInTheDocument(); expect(screen.getByLabelText(/from computer/i)).toBeInTheDocument(); expect(screen.getByLabelText(/from url/i)).toBeInTheDocument(); }); + it('should not render an artifact panel if importing an opensearch defined model', async () => { + const onSubmitMock = jest.fn(); + await setup({ onSubmit: onSubmitMock }, { route: '/?type=import' }); + expect(screen.queryByLabelText(/file/i, { selector: 'input[type="file"]' })).toBeNull(); + }); + it('should submit the register model form', async () => { const onSubmitMock = jest.fn(); const result = await setup({ onSubmit: onSubmitMock }); @@ -32,6 +31,24 @@ describe(' Artifact', () => { expect(onSubmitMock).toHaveBeenCalled(); }); + it('should NOT submit the register model form if model file size exceed 80MB', async () => { + const onSubmitMock = jest.fn(); + const result = await setup({ onSubmit: onSubmitMock }); + + // Empty model file selection by clicking the `Remove` button on EuiFilePicker + await result.user.click(screen.getByLabelText(/clear selected files/i)); + await result.user.click(result.submitButton); + + const modelFileInput = screen.getByLabelText(/file/i); + // File size can not exceed 80MB + const invalidFile = new File(['test model file'], 'model.zip', { type: 'application/zip' }); + Object.defineProperty(invalidFile, 'size', { value: 81 * 1000 * 1000 }); + await result.user.upload(modelFileInput, invalidFile); + + expect(screen.getByLabelText(/file/i, { selector: 'input[type="file"]' })).toBeInvalid(); + expect(onSubmitMock).not.toHaveBeenCalled(); + }); + it('should NOT submit the register model form if model file is empty', async () => { const onSubmitMock = jest.fn(); const result = await setup({ onSubmit: onSubmitMock }); @@ -40,7 +57,7 @@ describe(' Artifact', () => { await result.user.click(screen.getByLabelText(/clear selected files/i)); await result.user.click(result.submitButton); - expect(result.modelFileInput).toBeInvalid(); + expect(screen.getByLabelText(/file/i, { selector: 'input[type="file"]' })).toBeInvalid(); expect(onSubmitMock).not.toHaveBeenCalled(); }); @@ -51,7 +68,9 @@ describe(' Artifact', () => { // select option: From URL await result.user.click(screen.getByLabelText(/from url/i)); - const urlInput = screen.getByLabelText(/model url/i); + const urlInput = screen.getByLabelText(/url/i, { + selector: 'input[type="text"]', + }); // Empty URL input await result.user.clear(urlInput); diff --git a/public/components/register_model/__tests__/register_model_details.test.tsx b/public/components/register_model/__tests__/register_model_details.test.tsx index 6d6e4c1c..fa5583aa 100644 --- a/public/components/register_model/__tests__/register_model_details.test.tsx +++ b/public/components/register_model/__tests__/register_model_details.test.tsx @@ -5,22 +5,11 @@ import { setup } from './setup'; -jest.mock('react-router-dom', () => ({ - ...jest.requireActual('react-router-dom'), - useParams: () => ({ - id: '', - }), -})); - describe(' Details', () => { it('should render a model details panel', async () => { const onSubmitMock = jest.fn(); const result = await setup({ onSubmit: onSubmitMock }); expect(result.nameInput).toBeInTheDocument(); - // Model version is not editable - expect(result.versionInput).toBeDisabled(); - // Model Version should alway have a value - expect(result.versionInput.value).not.toBe(''); expect(result.descriptionInput).toBeInTheDocument(); expect(result.annotationsInput).toBeInTheDocument(); }); @@ -46,6 +35,22 @@ describe(' Details', () => { expect(onSubmitMock).not.toHaveBeenCalled(); }); + it('should NOT submit the register model form if model name length exceeded 60', async () => { + const onSubmitMock = jest.fn(); + const result = await setup({ onSubmit: onSubmitMock }); + + await result.user.clear(result.nameInput); + await result.user.type(result.nameInput, 'x'.repeat(60)); + expect(result.nameInput).toBeValid(); + + await result.user.clear(result.nameInput); + await result.user.type(result.nameInput, 'x'.repeat(61)); + expect(result.nameInput).toBeInvalid(); + + await result.user.click(result.submitButton); + expect(onSubmitMock).not.toHaveBeenCalled(); + }); + it('should NOT submit the register model form if model description is empty', async () => { const onSubmitMock = jest.fn(); const result = await setup({ onSubmit: onSubmitMock }); @@ -56,4 +61,36 @@ describe(' Details', () => { expect(result.descriptionInput).toBeInvalid(); expect(onSubmitMock).not.toHaveBeenCalled(); }); + + it('should NOT submit the register model form if model description length exceed 200', async () => { + const onSubmitMock = jest.fn(); + const result = await setup({ onSubmit: onSubmitMock }); + + await result.user.clear(result.descriptionInput); + await result.user.type(result.descriptionInput, 'x'.repeat(200)); + expect(result.descriptionInput).toBeValid(); + + await result.user.clear(result.descriptionInput); + await result.user.type(result.descriptionInput, 'x'.repeat(201)); + expect(result.descriptionInput).toBeInvalid(); + + await result.user.click(result.submitButton); + expect(onSubmitMock).not.toHaveBeenCalled(); + }); + + it('annotation text length should not exceed 200', async () => { + const onSubmitMock = jest.fn(); + const result = await setup({ onSubmit: onSubmitMock }); + + await result.user.clear(result.annotationsInput); + await result.user.type(result.annotationsInput, 'x'.repeat(200)); + expect(result.annotationsInput).toBeValid(); + + await result.user.clear(result.annotationsInput); + await result.user.type(result.annotationsInput, 'x'.repeat(201)); + expect(result.annotationsInput).toBeInvalid(); + + await result.user.click(result.submitButton); + expect(onSubmitMock).not.toHaveBeenCalled(); + }); }); diff --git a/public/components/register_model/__tests__/register_model_form.test.tsx b/public/components/register_model/__tests__/register_model_form.test.tsx index 0eba41e8..6687c78a 100644 --- a/public/components/register_model/__tests__/register_model_form.test.tsx +++ b/public/components/register_model/__tests__/register_model_form.test.tsx @@ -4,16 +4,12 @@ */ import React from 'react'; +import { Route } from 'react-router-dom'; + import { render, screen, waitFor } from '../../../../test/test_utils'; import { RegisterModelForm } from '../register_model'; import { APIProvider } from '../../../apis/api_provider'; - -jest.mock('react-router-dom', () => ({ - ...jest.requireActual('react-router-dom'), - useParams: () => ({ - id: 'test_model_id', - }), -})); +import { routerPaths } from '../../../../common/router_paths'; describe(' Form', () => { it('should init form when id param in url route', async () => { @@ -48,12 +44,17 @@ describe(' Form', () => { pagination: { currentPage: 1, pageSize: 1, totalRecords: 1, totalPages: 1 }, }; request.mockResolvedValue(mockResult); - render(); + render( + + + , + { route: '/model-registry/register-model/test_model_id' } + ); const { name } = mockResult.data[0]; await waitFor(() => { - const nameInput = screen.getByLabelText(/model name/i); + const nameInput = screen.getByLabelText(/^name$/i); expect(nameInput.value).toBe(name); }); }); diff --git a/public/components/register_model/__tests__/register_model_metrics.test.tsx b/public/components/register_model/__tests__/register_model_metrics.test.tsx index 0fb11b3c..cea70198 100644 --- a/public/components/register_model/__tests__/register_model_metrics.test.tsx +++ b/public/components/register_model/__tests__/register_model_metrics.test.tsx @@ -7,13 +7,6 @@ import { screen } from '../../../../test/test_utils'; import { setup } from './setup'; import * as formHooks from '../register_model.hooks'; -jest.mock('react-router-dom', () => ({ - ...jest.requireActual('react-router-dom'), - useParams: () => ({ - id: '', - }), -})); - describe(' Evaluation Metrics', () => { beforeEach(() => { jest @@ -23,33 +16,33 @@ describe(' Evaluation Metrics', () => { it('should render a evaluation metrics panel', async () => { const onSubmitMock = jest.fn(); - const result = await setup({ onSubmit: onSubmitMock }); + await setup({ onSubmit: onSubmitMock }); - expect(result.metricNameInput).toBeInTheDocument(); - expect(result.trainingMetricValueInput).toBeInTheDocument(); - expect(result.validationMetricValueInput).toBeInTheDocument(); - expect(result.testingMetricValueInput).toBeInTheDocument(); + expect(screen.getByLabelText(/^metric$/i)).toBeInTheDocument(); + expect(screen.getByLabelText(/training value/i)).toBeInTheDocument(); + expect(screen.getByLabelText(/validation value/i)).toBeInTheDocument(); + expect(screen.getByLabelText(/testing value/i)).toBeInTheDocument(); }); it('should render metric value input as disabled by default', async () => { const onSubmitMock = jest.fn(); - const result = await setup({ onSubmit: onSubmitMock }); + await setup({ onSubmit: onSubmitMock }); - expect(result.trainingMetricValueInput).toBeDisabled(); - expect(result.validationMetricValueInput).toBeDisabled(); - expect(result.testingMetricValueInput).toBeDisabled(); + expect(screen.getByLabelText(/training value/i)).toBeDisabled(); + expect(screen.getByLabelText(/validation value/i)).toBeDisabled(); + expect(screen.getByLabelText(/testing value/i)).toBeDisabled(); }); it('should render metric value input as enabled after selecting a metric name', async () => { const onSubmitMock = jest.fn(); const result = await setup({ onSubmit: onSubmitMock }); - await result.user.click(result.metricNameInput); + await result.user.click(screen.getByLabelText(/^metric$/i)); await result.user.click(screen.getByText('Metric 1')); - expect(result.trainingMetricValueInput).toBeEnabled(); - expect(result.validationMetricValueInput).toBeEnabled(); - expect(result.testingMetricValueInput).toBeEnabled(); + expect(screen.getByLabelText(/training value/i)).toBeEnabled(); + expect(screen.getByLabelText(/validation value/i)).toBeEnabled(); + expect(screen.getByLabelText(/testing value/i)).toBeEnabled(); }); it('should submit the form without selecting metric name', async () => { @@ -63,7 +56,7 @@ describe(' Evaluation Metrics', () => { it('should submit the form if metric name is selected but metric value are empty', async () => { const onSubmitMock = jest.fn(); const result = await setup({ onSubmit: onSubmitMock }); - await result.user.click(result.metricNameInput); + await result.user.click(screen.getByLabelText(/^metric$/i)); await result.user.click(screen.getByText('Metric 1')); await result.user.click(result.submitButton); @@ -73,12 +66,12 @@ describe(' Evaluation Metrics', () => { it('should submit the form if metric name and all metric value are selected', async () => { const onSubmitMock = jest.fn(); const result = await setup({ onSubmit: onSubmitMock }); - await result.user.click(result.metricNameInput); + await result.user.click(screen.getByLabelText(/^metric$/i)); await result.user.click(screen.getByText('Metric 1')); - await result.user.type(result.trainingMetricValueInput, '1'); - await result.user.type(result.validationMetricValueInput, '1'); - await result.user.type(result.testingMetricValueInput, '1'); + await result.user.type(screen.getByLabelText(/training value/i), '1'); + await result.user.type(screen.getByLabelText(/validation value/i), '1'); + await result.user.type(screen.getByLabelText(/testing value/i), '1'); await result.user.click(result.submitButton); @@ -88,11 +81,11 @@ describe(' Evaluation Metrics', () => { it('should submit the form if metric name is selected but metric value are partially selected', async () => { const onSubmitMock = jest.fn(); const result = await setup({ onSubmit: onSubmitMock }); - await result.user.click(result.metricNameInput); + await result.user.click(screen.getByLabelText(/^metric$/i)); await result.user.click(screen.getByText('Metric 1')); // Only input Training metric value - await result.user.type(result.trainingMetricValueInput, '1'); + await result.user.type(screen.getByLabelText(/training value/i), '1'); await result.user.click(result.submitButton); expect(onSubmitMock).toHaveBeenCalled(); @@ -101,11 +94,11 @@ describe(' Evaluation Metrics', () => { it('should NOT submit the form if metric value < 0', async () => { const onSubmitMock = jest.fn(); const result = await setup({ onSubmit: onSubmitMock }); - await result.user.click(result.metricNameInput); + await result.user.click(screen.getByLabelText(/^metric$/i)); await result.user.click(screen.getByText('Metric 1')); // Type an invalid value - await result.user.type(result.trainingMetricValueInput, '-.1'); + await result.user.type(screen.getByLabelText(/training value/i), '-.1'); await result.user.click(result.submitButton); expect(onSubmitMock).not.toHaveBeenCalled(); @@ -114,11 +107,11 @@ describe(' Evaluation Metrics', () => { it('should NOT submit the form if metric value > 1', async () => { const onSubmitMock = jest.fn(); const result = await setup({ onSubmit: onSubmitMock }); - await result.user.click(result.metricNameInput); + await result.user.click(screen.getByLabelText(/^metric$/i)); await result.user.click(screen.getByText('Metric 1')); // Type an invalid value - await result.user.type(result.trainingMetricValueInput, '1.1'); + await result.user.type(screen.getByLabelText(/training value/i), '1.1'); await result.user.click(result.submitButton); expect(onSubmitMock).not.toHaveBeenCalled(); diff --git a/public/components/register_model/__tests__/register_model_tags.test.tsx b/public/components/register_model/__tests__/register_model_tags.test.tsx index dea20588..7ccc22f7 100644 --- a/public/components/register_model/__tests__/register_model_tags.test.tsx +++ b/public/components/register_model/__tests__/register_model_tags.test.tsx @@ -7,13 +7,6 @@ import { screen } from '../../../../test/test_utils'; import { setup } from './setup'; import * as formHooks from '../register_model.hooks'; -jest.mock('react-router-dom', () => ({ - ...jest.requireActual('react-router-dom'), - useParams: () => ({ - id: '', - }), -})); - describe(' Tags', () => { beforeEach(() => { jest diff --git a/public/components/register_model/__tests__/setup.tsx b/public/components/register_model/__tests__/setup.tsx index ab0f5cb9..db7ec12a 100644 --- a/public/components/register_model/__tests__/setup.tsx +++ b/public/components/register_model/__tests__/setup.tsx @@ -8,26 +8,20 @@ import userEvent from '@testing-library/user-event'; import { RegisterModelForm } from '../register_model'; import type { RegisterModelFormProps } from '../register_model'; -import { render, screen } from '../../../../test/test_utils'; +import { render, RenderWithRouteProps, screen } from '../../../../test/test_utils'; jest.mock('../../../apis/model'); jest.mock('../../../apis/task'); -export async function setup({ onSubmit }: RegisterModelFormProps) { - render(); - const nameInput = screen.getByLabelText(/model name/i); - const versionInput = screen.getByLabelText(/version/i); - const descriptionInput = screen.getByLabelText(/model description/i); - const annotationsInput = screen.getByLabelText(/annotations\(optional\)/i); +export async function setup({ onSubmit }: RegisterModelFormProps, options?: RenderWithRouteProps) { + render(, { route: options?.route ?? '/' }); + const nameInput = screen.getByLabelText(/^name$/i); + const descriptionInput = screen.getByLabelText(/description/i); + const annotationsInput = screen.getByLabelText(/annotation/i); const submitButton = screen.getByRole('button', { name: /register model/i, }); - const modelFileInput = screen.getByLabelText(/model file/i); - const configurationInput = screen.getByLabelText(/configuration object/i); - const metricNameInput = screen.getByLabelText(/metric name/i); - const trainingMetricValueInput = screen.getByLabelText(/training metric value/i); - const validationMetricValueInput = screen.getByLabelText(/validation metric value/i); - const testingMetricValueInput = screen.getByLabelText(/testing metric value/i); + const modelFileInput = screen.queryByLabelText(/file/i); const tagKeyInput = screen.getByLabelText(/^key$/i); const tagValueInput = screen.getByLabelText(/^value$/i); const form = screen.getByTestId('mlCommonsPlugin-registerModelForm'); @@ -38,23 +32,18 @@ export async function setup({ onSubmit }: RegisterModelFormProps) { // fill model description await user.type(descriptionInput, 'test model description'); // fill model file - await user.upload( - modelFileInput, - new File(['test model file'], 'model.zip', { type: 'application/zip' }) - ); + if (modelFileInput) { + await user.upload( + modelFileInput, + new File(['test model file'], 'model.zip', { type: 'application/zip' }) + ); + } return { nameInput, - versionInput, descriptionInput, annotationsInput, - configurationInput, submitButton, - modelFileInput, - metricNameInput, - trainingMetricValueInput, - validationMetricValueInput, - testingMetricValueInput, tagKeyInput, tagValueInput, form, diff --git a/public/components/register_model/artifact.tsx b/public/components/register_model/artifact.tsx index aa78435a..1cfd189c 100644 --- a/public/components/register_model/artifact.tsx +++ b/public/components/register_model/artifact.tsx @@ -6,7 +6,6 @@ import React, { useState } from 'react'; import { EuiFormRow, - EuiPanel, EuiTitle, EuiHorizontalRule, htmlIdGenerator, @@ -14,6 +13,9 @@ import { EuiFlexGroup, EuiFlexItem, EuiCheckableCard, + EuiText, + EuiRadio, + EuiLink, } from '@elastic/eui'; import type { Control } from 'react-hook-form'; @@ -24,46 +26,48 @@ import { ArtifactUrl } from './artifact_url'; export const ArtifactPanel = (props: { formControl: Control; + ordinalNumber: number; }) => { const [selectedSource, setSelectedSource] = useState<'source_from_computer' | 'source_from_url'>( 'source_from_computer' ); return ( - +
-

Artifact

+

{props.ordinalNumber}. Artifact

- - - - - setSelectedSource('source_from_computer')} - /> - - - setSelectedSource('source_from_url')} - /> - - - + + + Provide the model artifact for upload. If uploading from local file, keep your browser + open until the upload is complete.{' '} + + Learn more + + + + + setSelectedSource('source_from_computer')} + /> + setSelectedSource('source_from_url')} + /> {selectedSource === 'source_from_computer' && ( )} {selectedSource === 'source_from_url' && } - +
); }; diff --git a/public/components/register_model/artifact_file.tsx b/public/components/register_model/artifact_file.tsx index f342b21f..c34acf95 100644 --- a/public/components/register_model/artifact_file.tsx +++ b/public/components/register_model/artifact_file.tsx @@ -8,30 +8,41 @@ import { EuiFormRow, EuiFilePicker } from '@elastic/eui'; import { useController } from 'react-hook-form'; import type { Control } from 'react-hook-form'; -import { FORM_ITEM_WIDTH } from './form_constants'; import { ModelFileFormData, ModelUrlFormData } from './register_model.types'; +const ONE_MB = 1000 * 1000; +const MAX_FILE_SIZE = 80 * ONE_MB; + +function validateFile(file: File) { + if (file.size > MAX_FILE_SIZE) { + return 'Maximum file size exceeded. Add a smaller file.'; + } + return true; +} + export const ModelFileUploader = (props: { formControl: Control; }) => { const modelFileFieldController = useController({ name: 'modelFile', control: props.formControl, - rules: { required: true }, + rules: { + required: { value: true, message: 'A file is required. Add a file.' }, + validate: validateFile, + }, shouldUnregister: true, }); return ( { modelFileFieldController.field.onChange(fileList?.item(0)); diff --git a/public/components/register_model/artifact_url.tsx b/public/components/register_model/artifact_url.tsx index 47749d6a..89bcb6df 100644 --- a/public/components/register_model/artifact_url.tsx +++ b/public/components/register_model/artifact_url.tsx @@ -19,24 +19,23 @@ export const ArtifactUrl = (props: { name: 'modelURL', control: props.formControl, rules: { - required: true, - pattern: URL_REGEX, + required: { value: true, message: 'URL is required. Enter a URL.' }, + pattern: { value: URL_REGEX, message: 'URL is invalid. Enter a valid URL.' }, }, shouldUnregister: true, }); return ( ; + ordinalNumber: number; }) => { const [metricNamesLoading, metricNames] = useMetricNames(); @@ -94,28 +93,33 @@ export const EvaluationMetricsPanel = (props: { ); const metricValueFields = [ - { label: 'Training metric value', controller: trainingMetricFieldController }, - { label: 'Validation metric value', controller: validationMetricFieldController }, - { label: 'Testing metric value', controller: testingMetricFieldController }, + { label: 'Training value', controller: trainingMetricFieldController }, + { label: 'Validation value', controller: validationMetricFieldController }, + { label: 'Testing value', controller: testingMetricFieldController }, ]; return ( - +

- Evaluation Metrics - optional + {props.ordinalNumber}. Evaluation Metrics - optional

- + + + Track training, validation, and testing metrics to compare across versions and models. + + + - + {metricValueFields.map(({ label, controller }) => ( ))} - +
); }; diff --git a/public/components/register_model/model_configuration.tsx b/public/components/register_model/model_configuration.tsx index fcc4dd07..99965d54 100644 --- a/public/components/register_model/model_configuration.tsx +++ b/public/components/register_model/model_configuration.tsx @@ -6,15 +6,15 @@ import React, { useState } from 'react'; import { EuiFormRow, - EuiPanel, EuiTitle, - EuiHorizontalRule, EuiCodeEditor, EuiText, EuiButtonEmpty, EuiFlyout, EuiFlyoutHeader, EuiFlyoutBody, + EuiLink, + EuiSpacer, } from '@elastic/eui'; import { useController } from 'react-hook-form'; import type { Control } from 'react-hook-form'; @@ -27,31 +27,42 @@ function validateConfigurationObject(value: string) { try { JSON.parse(value.trim()); } catch { - return false; + return 'JSON is invalid. Enter a valid JSON'; } return true; } export const ConfigurationPanel = (props: { formControl: Control; + ordinalNumber: number; }) => { const [isHelpVisible, setIsHelpVisible] = useState(false); const configurationFieldController = useController({ name: 'configuration', control: props.formControl, - rules: { required: true, validate: validateConfigurationObject }, + rules: { + required: { value: true, message: 'Configuration is required.' }, + validate: validateConfigurationObject, + }, }); return ( - +
-

Configuration

+

{props.ordinalNumber}. Configuration

- + + + The model configuration JSON object.{' '} + Help. + + + setIsHelpVisible(true)} size="xs" color="primary"> @@ -90,6 +101,6 @@ export const ConfigurationPanel = (props: { )} - +
); }; diff --git a/public/components/register_model/model_details.tsx b/public/components/register_model/model_details.tsx index acd16561..f5076215 100644 --- a/public/components/register_model/model_details.tsx +++ b/public/components/register_model/model_details.tsx @@ -4,106 +4,105 @@ */ import React from 'react'; -import { - EuiFieldText, - EuiFieldNumber, - EuiFlexItem, - EuiFormRow, - EuiPanel, - EuiTitle, - EuiHorizontalRule, - EuiFlexGroup, - EuiTextArea, -} from '@elastic/eui'; +import { EuiFieldText, EuiFormRow, EuiTitle, EuiTextArea, EuiText } from '@elastic/eui'; import { useController } from 'react-hook-form'; import type { Control } from 'react-hook-form'; -import { FORM_ITEM_WIDTH } from './form_constants'; import type { ModelFileFormData, ModelUrlFormData } from './register_model.types'; +const NAME_MAX_LENGTH = 60; +const DESCRIPTION_MAX_LENGTH = 200; +const ANNOTATION_MAX_LENGTH = 200; + export const ModelDetailsPanel = (props: { formControl: Control; + ordinalNumber: number; }) => { const nameFieldController = useController({ name: 'name', control: props.formControl, - rules: { required: true }, - }); - - const versionFieldController = useController({ - name: 'version', - control: props.formControl, - rules: { required: true }, + rules: { + required: { value: true, message: 'Name can not be empty' }, + maxLength: { value: NAME_MAX_LENGTH, message: 'Text exceed max length' }, + }, }); const descriptionFieldController = useController({ name: 'description', control: props.formControl, - rules: { required: true }, + rules: { + required: { value: true, message: 'Description can not be empty' }, + maxLength: { value: DESCRIPTION_MAX_LENGTH, message: 'Text exceed max length' }, + }, }); const annotationsFieldController = useController({ name: 'annotations', control: props.formControl, + rules: { maxLength: { value: ANNOTATION_MAX_LENGTH, message: 'Text exceed max length' } }, }); const { ref: nameInputRef, ...nameField } = nameFieldController.field; - const { ref: versionInputRef, ...versionField } = versionFieldController.field; const { ref: descriptionInputRef, ...descriptionField } = descriptionFieldController.field; const { ref: annotationsInputRef, ...annotationsField } = annotationsFieldController.field; return ( - +
-

Model Details

+

{props.ordinalNumber}. Model Details

- - - - - - - - - - - - - - - - - - - - - - - - - - + + {Math.max(NAME_MAX_LENGTH - nameField.value.length, 0)} characters allowed. +
+ Use a unique for the model. + + } + > + +
+ + + + + Annotation - Optional + + } + > + + +
); }; diff --git a/public/components/register_model/model_tags.tsx b/public/components/register_model/model_tags.tsx index e3faf9db..91365b3f 100644 --- a/public/components/register_model/model_tags.tsx +++ b/public/components/register_model/model_tags.tsx @@ -4,7 +4,7 @@ */ import React, { useCallback } from 'react'; -import { EuiButton, EuiPanel, EuiTitle, EuiHorizontalRule, EuiSpacer } from '@elastic/eui'; +import { EuiButton, EuiTitle, EuiHorizontalRule, EuiSpacer } from '@elastic/eui'; import { useFieldArray } from 'react-hook-form'; import type { Control } from 'react-hook-form'; @@ -14,6 +14,7 @@ import { useModelTags } from './register_model.hooks'; export const ModelTagsPanel = (props: { formControl: Control; + ordinalNumber: number; }) => { const [, { keys, values }] = useModelTags(); const { fields, append, remove } = useFieldArray({ @@ -26,10 +27,10 @@ export const ModelTagsPanel = (props: { }, [append]); return ( - +

- Tags - optional + {props.ordinalNumber}. Tags - optional

@@ -48,6 +49,6 @@ export const ModelTagsPanel = (props: { })} Add new tag - +
); }; diff --git a/public/components/register_model/register_model.tsx b/public/components/register_model/register_model.tsx index eb0887c9..062e4291 100644 --- a/public/components/register_model/register_model.tsx +++ b/public/components/register_model/register_model.tsx @@ -5,11 +5,8 @@ import React, { useCallback, useEffect } from 'react'; import { FieldErrors, useForm } from 'react-hook-form'; -import { EuiPageHeader, EuiSpacer, EuiForm, EuiButton } from '@elastic/eui'; import { useParams } from 'react-router-dom'; - -import { APIProvider } from '../../apis/api_provider'; -import { upgradeModelVersion } from '../../utils'; +import { EuiPageHeader, EuiSpacer, EuiForm, EuiButton, EuiPanel, EuiText } from '@elastic/eui'; import { ModelDetailsPanel } from './model_details'; import type { ModelFileFormData, ModelUrlFormData } from './register_model.types'; @@ -18,23 +15,44 @@ import { ConfigurationPanel } from './model_configuration'; import { EvaluationMetricsPanel } from './evaluation_metrics'; import { ModelTagsPanel } from './model_tags'; import { useModelUpload } from './register_model.hooks'; +import { APIProvider } from '../../apis/api_provider'; +import { upgradeModelVersion } from '../../utils'; +import { useSearchParams } from '../../hooks/use_search_params'; +import { isValidModelRegisterFormType } from './utils'; export interface RegisterModelFormProps { onSubmit?: (data: ModelFileFormData | ModelUrlFormData) => void; } +const DEFAULT_VALUES = { + name: '', + description: '', + version: '1', + configuration: '{}', + tags: [{ key: '', value: '' }], +}; + export const RegisterModelForm = (props: RegisterModelFormProps) => { const { id: latestVersionId } = useParams<{ id: string | undefined }>(); + const typeParams = useSearchParams().get('type'); + + const formType = isValidModelRegisterFormType(typeParams) ? typeParams : 'upload'; + const partials = + formType === 'import' + ? [ModelDetailsPanel, ModelTagsPanel] + : [ + ModelDetailsPanel, + ArtifactPanel, + ConfigurationPanel, + EvaluationMetricsPanel, + ModelTagsPanel, + ]; + const { handleSubmit, control, setValue, formState } = useForm< ModelFileFormData | ModelUrlFormData >({ - defaultValues: { - name: '', - description: '', - version: '1', - configuration: '{}', - tags: [{ key: '', value: '' }], - }, + mode: 'onChange', + defaultValues: DEFAULT_VALUES, }); const submitModel = useModelUpload(); @@ -80,26 +98,30 @@ export const RegisterModelForm = (props: RegisterModelFormProps) => { onSubmit={handleSubmit(onSubmit, onError)} component="form" > - - - - - - - - - - - - - - Register model - + + + + + Register your model to collaboratively manage its life cycle, and facilitate model + discovery across your organization. + + + + {partials.map((FormPartial, i) => ( + + + + + ))} + + Register model + + ); }; diff --git a/public/components/register_model/register_model.types.ts b/public/components/register_model/register_model.types.ts index 6f271f3f..1a60d101 100644 --- a/public/components/register_model/register_model.types.ts +++ b/public/components/register_model/register_model.types.ts @@ -12,7 +12,7 @@ interface ModelFormBase { name: string; version: string; description: string; - annotations: string; + annotations?: string; configuration: string; metricName?: string; trainingMetricValue?: string; diff --git a/public/components/register_model/utils.ts b/public/components/register_model/utils.ts new file mode 100644 index 00000000..a9f9e765 --- /dev/null +++ b/public/components/register_model/utils.ts @@ -0,0 +1,12 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +/** + * import - import a OpenSearch pre-defined model + * upload - user upload a model by himself/herself by register a new model or register a new version + */ +export function isValidModelRegisterFormType(type: string | null): type is 'upload' | 'import' { + return type === 'upload' || type === 'import'; +} diff --git a/public/hooks/use_search_params.ts b/public/hooks/use_search_params.ts new file mode 100644 index 00000000..fd9d95d0 --- /dev/null +++ b/public/hooks/use_search_params.ts @@ -0,0 +1,12 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React from 'react'; +import { useLocation } from 'react-router-dom'; + +export function useSearchParams() { + const { search } = useLocation(); + return React.useMemo(() => new URLSearchParams(search), [search]); +} diff --git a/test/test_utils.tsx b/test/test_utils.tsx index fadc47f0..e8aa475d 100644 --- a/test/test_utils.tsx +++ b/test/test_utils.tsx @@ -6,18 +6,46 @@ import React, { FC, ReactElement } from 'react'; import { I18nProvider } from '@osd/i18n/react'; import { render, RenderOptions } from '@testing-library/react'; +import { createBrowserHistory } from 'history'; +import { Router } from 'react-router-dom'; import { DataSourceContextProvider } from '../public/contexts'; +export interface RenderWithRouteProps { + route: string; +} + +const history = { + current: createBrowserHistory(), +}; + const AllTheProviders: FC<{ children: React.ReactNode }> = ({ children }) => { return ( - - {children} - + + + {children} + + ); }; -const customRender = (ui: ReactElement, options?: Omit) => - render(ui, { wrapper: AllTheProviders, ...options }); +/** + * Example 1: render with a route + * customRender(, {route: '/app'}) + * + */ +const customRender = ( + ui: ReactElement, + options?: Omit & RenderWithRouteProps +) => { + const currentHistory = createBrowserHistory(); + history.current = currentHistory; + + if (options?.route) { + currentHistory.push(options?.route); + } + + return render(ui, { wrapper: AllTheProviders, ...options }); +}; export * from '@testing-library/react'; export { customRender as render }; From 115752fa54f93c50ee5eb550d0b916f18b3743bf Mon Sep 17 00:00:00 2001 From: raintygao Date: Thu, 9 Feb 2023 18:43:30 +0800 Subject: [PATCH 08/75] update help button location and flyout content according to updated design and fix ci failed (#90) * feat: update help location and flyout content Signed-off-by: raintygao * chore: add slient param to unit test Signed-off-by: raintygao * chore: add silent param to unit test Signed-off-by: raintygao * chore: move silent to workflow Signed-off-by: raintygao * chore: add watch mode for unit test Signed-off-by: raintygao --------- Signed-off-by: raintygao Signed-off-by: Lin Wang --- package.json | 6 +++ .../register_model_configuration.test.tsx | 17 ++++++ .../components/register_model/help_flyout.tsx | 53 +++++++++++++++++++ .../register_model/model_configuration.tsx | 36 ++++--------- 4 files changed, 86 insertions(+), 26 deletions(-) create mode 100644 public/components/register_model/__tests__/register_model_configuration.test.tsx create mode 100644 public/components/register_model/help_flyout.tsx diff --git a/package.json b/package.json index 41a94788..cf7a726c 100644 --- a/package.json +++ b/package.json @@ -11,8 +11,14 @@ "osd": "node ../../scripts/osd", "lint:es": "node ../../scripts/eslint", "test:jest": "../../node_modules/.bin/jest --config ./test/jest.config.js", + "test:watch": "../../node_modules/.bin/jest --config ./test/jest.config.js --watch", "prepare": "husky install" }, + "husky": { + "hooks": { + "pre-commit": "lint-staged" + } + }, "dependencies": { "hash-wasm": "^4.9.0", "papaparse": "^5.3.2", diff --git a/public/components/register_model/__tests__/register_model_configuration.test.tsx b/public/components/register_model/__tests__/register_model_configuration.test.tsx new file mode 100644 index 00000000..8ff8d618 --- /dev/null +++ b/public/components/register_model/__tests__/register_model_configuration.test.tsx @@ -0,0 +1,17 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import { screen } from '../../../../test/test_utils'; +import { setup } from './setup'; + +describe(' Configuration', () => { + it('should render a help flyout when click help button', async () => { + const { user } = await setup({}); + + expect(screen.getByLabelText('Configuration in JSON')).toBeInTheDocument(); + await user.click(screen.getByTestId('model-configuration-help-button')); + expect(screen.getByRole('dialog')).toBeInTheDocument(); + }); +}); diff --git a/public/components/register_model/help_flyout.tsx b/public/components/register_model/help_flyout.tsx new file mode 100644 index 00000000..b545c36f --- /dev/null +++ b/public/components/register_model/help_flyout.tsx @@ -0,0 +1,53 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React from 'react'; +import { + EuiTitle, + EuiSpacer, + EuiText, + EuiFlyout, + EuiFlyoutHeader, + EuiFlyoutBody, + EuiCodeBlock, +} from '@elastic/eui'; + +interface Props { + onClose: () => void; +} + +export const HelpFlyout = ({ onClose }: Props) => { + const jsonCode = ` + "model_type": "bert", + "embedding_dimension": 384, + "framework_type": "sentence_transformers" +`; + + return ( + + + +

Help

+
+
+ + +

Example

+
+ + +

+ For consistency across the many flyouts, please utilize the following code for + implementing the flyout with a header. +

+
+ + + {jsonCode} + +
+
+ ); +}; diff --git a/public/components/register_model/model_configuration.tsx b/public/components/register_model/model_configuration.tsx index 99965d54..e4bc8384 100644 --- a/public/components/register_model/model_configuration.tsx +++ b/public/components/register_model/model_configuration.tsx @@ -10,10 +10,6 @@ import { EuiCodeEditor, EuiText, EuiButtonEmpty, - EuiFlyout, - EuiFlyoutHeader, - EuiFlyoutBody, - EuiLink, EuiSpacer, } from '@elastic/eui'; import { useController } from 'react-hook-form'; @@ -22,6 +18,7 @@ import type { Control } from 'react-hook-form'; import '../../ace-themes/sql_console.js'; import { FORM_ITEM_WIDTH } from './form_constants'; import type { ModelFileFormData, ModelUrlFormData } from './register_model.types'; +import { HelpFlyout } from './help_flyout'; function validateConfigurationObject(value: string) { try { @@ -54,7 +51,14 @@ export const ConfigurationPanel = (props: { The model configuration JSON object.{' '} - Help. + setIsHelpVisible(true)} + size="xs" + color="primary" + data-test-subj="model-configuration-help-button" + > + Help. + @@ -63,13 +67,6 @@ export const ConfigurationPanel = (props: { label="Configuration in JSON" isInvalid={Boolean(configurationFieldController.fieldState.error)} error={configurationFieldController.fieldState.error?.message} - labelAppend={ - - setIsHelpVisible(true)} size="xs" color="primary"> - Help - - - } >
- {isHelpVisible && ( - setIsHelpVisible(false)}> - - -

Help

-
-
- - -

TODO

-
-
-
- )} + {isHelpVisible && setIsHelpVisible(false)} />} ); }; From 04fad6f8c3b4e9d51900e336d5b3f41592579dd2 Mon Sep 17 00:00:00 2001 From: Yulong Ruan Date: Mon, 13 Feb 2023 09:54:51 +0800 Subject: [PATCH 09/75] feat: add multiple validation rules on tag field (#93) + Only allow to add maximum 25 tags + Not allow to add tag with empty key or value from UI, but tags with both key and value empty are allowed from UI(such tag will be just be ignored) + Not allow to add duplicated tags which have the same key-value pairs + Tag key length and tag value length cannot exceed 80 characters --------- Signed-off-by: Yulong Ruan Signed-off-by: Lin Wang --- .../register_model_artifact.test.tsx | 35 +++-- .../__tests__/register_model_details.test.tsx | 38 +++-- .../__tests__/register_model_metrics.test.tsx | 37 ++--- .../__tests__/register_model_tags.test.tsx | 142 +++++++++++++++-- .../register_model/__tests__/setup.tsx | 9 +- public/components/register_model/artifact.tsx | 28 +--- .../register_model/artifact_file.tsx | 10 +- .../register_model/artifact_url.tsx | 10 +- .../register_model/evaluation_metrics.tsx | 17 +- .../register_model/model_configuration.tsx | 11 +- .../register_model/model_details.tsx | 18 +-- .../components/register_model/model_tags.tsx | 25 +-- .../register_model/register_model.tsx | 88 +++++------ .../components/register_model/tag_field.tsx | 147 +++++++++++++++--- 14 files changed, 401 insertions(+), 214 deletions(-) diff --git a/public/components/register_model/__tests__/register_model_artifact.test.tsx b/public/components/register_model/__tests__/register_model_artifact.test.tsx index 2d405f73..752a01b3 100644 --- a/public/components/register_model/__tests__/register_model_artifact.test.tsx +++ b/public/components/register_model/__tests__/register_model_artifact.test.tsx @@ -5,25 +5,39 @@ import { screen } from '../../../../test/test_utils'; import { setup } from './setup'; +import * as formHooks from '../register_model.hooks'; describe(' Artifact', () => { + const onSubmitMock = jest.fn(); + + beforeEach(() => { + jest + .spyOn(formHooks, 'useMetricNames') + .mockReturnValue([false, ['Metric 1', 'Metric 2', 'Metric 3', 'Metric 4']]); + jest + .spyOn(formHooks, 'useModelTags') + .mockReturnValue([false, { keys: ['Key1', 'Key2'], values: ['Value1', 'Value2'] }]); + jest.spyOn(formHooks, 'useModelUpload').mockReturnValue(onSubmitMock); + }); + + afterEach(() => { + jest.clearAllMocks(); + }); + it('should render an artifact panel', async () => { - const onSubmitMock = jest.fn(); - await setup({ onSubmit: onSubmitMock }); + await setup(); expect(screen.getByLabelText(/file/i, { selector: 'input[type="file"]' })).toBeInTheDocument(); expect(screen.getByLabelText(/from computer/i)).toBeInTheDocument(); expect(screen.getByLabelText(/from url/i)).toBeInTheDocument(); }); it('should not render an artifact panel if importing an opensearch defined model', async () => { - const onSubmitMock = jest.fn(); - await setup({ onSubmit: onSubmitMock }, { route: '/?type=import' }); + await setup({ route: '/?type=import' }); expect(screen.queryByLabelText(/file/i, { selector: 'input[type="file"]' })).toBeNull(); }); it('should submit the register model form', async () => { - const onSubmitMock = jest.fn(); - const result = await setup({ onSubmit: onSubmitMock }); + const result = await setup(); expect(onSubmitMock).not.toHaveBeenCalled(); await result.user.click(result.submitButton); @@ -32,8 +46,7 @@ describe(' Artifact', () => { }); it('should NOT submit the register model form if model file size exceed 80MB', async () => { - const onSubmitMock = jest.fn(); - const result = await setup({ onSubmit: onSubmitMock }); + const result = await setup(); // Empty model file selection by clicking the `Remove` button on EuiFilePicker await result.user.click(screen.getByLabelText(/clear selected files/i)); @@ -50,8 +63,7 @@ describe(' Artifact', () => { }); it('should NOT submit the register model form if model file is empty', async () => { - const onSubmitMock = jest.fn(); - const result = await setup({ onSubmit: onSubmitMock }); + const result = await setup(); // Empty model file selection by clicking the `Remove` button on EuiFilePicker await result.user.click(screen.getByLabelText(/clear selected files/i)); @@ -62,8 +74,7 @@ describe(' Artifact', () => { }); it('should NOT submit the register model form if model url is empty', async () => { - const onSubmitMock = jest.fn(); - const result = await setup({ onSubmit: onSubmitMock }); + const result = await setup(); // select option: From URL await result.user.click(screen.getByLabelText(/from url/i)); diff --git a/public/components/register_model/__tests__/register_model_details.test.tsx b/public/components/register_model/__tests__/register_model_details.test.tsx index fa5583aa..eb2f924c 100644 --- a/public/components/register_model/__tests__/register_model_details.test.tsx +++ b/public/components/register_model/__tests__/register_model_details.test.tsx @@ -4,19 +4,34 @@ */ import { setup } from './setup'; +import * as formHooks from '../register_model.hooks'; describe(' Details', () => { + const onSubmitMock = jest.fn(); + + beforeEach(() => { + jest + .spyOn(formHooks, 'useMetricNames') + .mockReturnValue([false, ['Metric 1', 'Metric 2', 'Metric 3', 'Metric 4']]); + jest + .spyOn(formHooks, 'useModelTags') + .mockReturnValue([false, { keys: ['Key1', 'Key2'], values: ['Value1', 'Value2'] }]); + jest.spyOn(formHooks, 'useModelUpload').mockReturnValue(onSubmitMock); + }); + + afterEach(() => { + jest.clearAllMocks(); + }); + it('should render a model details panel', async () => { - const onSubmitMock = jest.fn(); - const result = await setup({ onSubmit: onSubmitMock }); + const result = await setup(); expect(result.nameInput).toBeInTheDocument(); expect(result.descriptionInput).toBeInTheDocument(); expect(result.annotationsInput).toBeInTheDocument(); }); it('should submit the register model form', async () => { - const onSubmitMock = jest.fn(); - const result = await setup({ onSubmit: onSubmitMock }); + const result = await setup(); expect(onSubmitMock).not.toHaveBeenCalled(); await result.user.click(result.submitButton); @@ -25,8 +40,7 @@ describe(' Details', () => { }); it('should NOT submit the register model form if model name is empty', async () => { - const onSubmitMock = jest.fn(); - const result = await setup({ onSubmit: onSubmitMock }); + const result = await setup(); await result.user.clear(result.nameInput); await result.user.click(result.submitButton); @@ -36,8 +50,7 @@ describe(' Details', () => { }); it('should NOT submit the register model form if model name length exceeded 60', async () => { - const onSubmitMock = jest.fn(); - const result = await setup({ onSubmit: onSubmitMock }); + const result = await setup(); await result.user.clear(result.nameInput); await result.user.type(result.nameInput, 'x'.repeat(60)); @@ -52,8 +65,7 @@ describe(' Details', () => { }); it('should NOT submit the register model form if model description is empty', async () => { - const onSubmitMock = jest.fn(); - const result = await setup({ onSubmit: onSubmitMock }); + const result = await setup(); await result.user.clear(result.descriptionInput); await result.user.click(result.submitButton); @@ -63,8 +75,7 @@ describe(' Details', () => { }); it('should NOT submit the register model form if model description length exceed 200', async () => { - const onSubmitMock = jest.fn(); - const result = await setup({ onSubmit: onSubmitMock }); + const result = await setup(); await result.user.clear(result.descriptionInput); await result.user.type(result.descriptionInput, 'x'.repeat(200)); @@ -79,8 +90,7 @@ describe(' Details', () => { }); it('annotation text length should not exceed 200', async () => { - const onSubmitMock = jest.fn(); - const result = await setup({ onSubmit: onSubmitMock }); + const result = await setup(); await result.user.clear(result.annotationsInput); await result.user.type(result.annotationsInput, 'x'.repeat(200)); diff --git a/public/components/register_model/__tests__/register_model_metrics.test.tsx b/public/components/register_model/__tests__/register_model_metrics.test.tsx index cea70198..88874596 100644 --- a/public/components/register_model/__tests__/register_model_metrics.test.tsx +++ b/public/components/register_model/__tests__/register_model_metrics.test.tsx @@ -8,15 +8,24 @@ import { setup } from './setup'; import * as formHooks from '../register_model.hooks'; describe(' Evaluation Metrics', () => { + const onSubmitMock = jest.fn(); + beforeEach(() => { jest .spyOn(formHooks, 'useMetricNames') .mockReturnValue([false, ['Metric 1', 'Metric 2', 'Metric 3', 'Metric 4']]); + jest + .spyOn(formHooks, 'useModelTags') + .mockReturnValue([false, { keys: ['Key1', 'Key2'], values: ['Value1', 'Value2'] }]); + jest.spyOn(formHooks, 'useModelUpload').mockReturnValue(onSubmitMock); + }); + + afterEach(() => { + jest.clearAllMocks(); }); it('should render a evaluation metrics panel', async () => { - const onSubmitMock = jest.fn(); - await setup({ onSubmit: onSubmitMock }); + await setup(); expect(screen.getByLabelText(/^metric$/i)).toBeInTheDocument(); expect(screen.getByLabelText(/training value/i)).toBeInTheDocument(); @@ -25,8 +34,7 @@ describe(' Evaluation Metrics', () => { }); it('should render metric value input as disabled by default', async () => { - const onSubmitMock = jest.fn(); - await setup({ onSubmit: onSubmitMock }); + await setup(); expect(screen.getByLabelText(/training value/i)).toBeDisabled(); expect(screen.getByLabelText(/validation value/i)).toBeDisabled(); @@ -34,8 +42,7 @@ describe(' Evaluation Metrics', () => { }); it('should render metric value input as enabled after selecting a metric name', async () => { - const onSubmitMock = jest.fn(); - const result = await setup({ onSubmit: onSubmitMock }); + const result = await setup(); await result.user.click(screen.getByLabelText(/^metric$/i)); await result.user.click(screen.getByText('Metric 1')); @@ -46,16 +53,14 @@ describe(' Evaluation Metrics', () => { }); it('should submit the form without selecting metric name', async () => { - const onSubmitMock = jest.fn(); - const result = await setup({ onSubmit: onSubmitMock }); + const result = await setup(); await result.user.click(result.submitButton); expect(onSubmitMock).toHaveBeenCalled(); }); it('should submit the form if metric name is selected but metric value are empty', async () => { - const onSubmitMock = jest.fn(); - const result = await setup({ onSubmit: onSubmitMock }); + const result = await setup(); await result.user.click(screen.getByLabelText(/^metric$/i)); await result.user.click(screen.getByText('Metric 1')); await result.user.click(result.submitButton); @@ -64,8 +69,7 @@ describe(' Evaluation Metrics', () => { }); it('should submit the form if metric name and all metric value are selected', async () => { - const onSubmitMock = jest.fn(); - const result = await setup({ onSubmit: onSubmitMock }); + const result = await setup(); await result.user.click(screen.getByLabelText(/^metric$/i)); await result.user.click(screen.getByText('Metric 1')); @@ -79,8 +83,7 @@ describe(' Evaluation Metrics', () => { }); it('should submit the form if metric name is selected but metric value are partially selected', async () => { - const onSubmitMock = jest.fn(); - const result = await setup({ onSubmit: onSubmitMock }); + const result = await setup(); await result.user.click(screen.getByLabelText(/^metric$/i)); await result.user.click(screen.getByText('Metric 1')); @@ -92,8 +95,7 @@ describe(' Evaluation Metrics', () => { }); it('should NOT submit the form if metric value < 0', async () => { - const onSubmitMock = jest.fn(); - const result = await setup({ onSubmit: onSubmitMock }); + const result = await setup(); await result.user.click(screen.getByLabelText(/^metric$/i)); await result.user.click(screen.getByText('Metric 1')); @@ -105,8 +107,7 @@ describe(' Evaluation Metrics', () => { }); it('should NOT submit the form if metric value > 1', async () => { - const onSubmitMock = jest.fn(); - const result = await setup({ onSubmit: onSubmitMock }); + const result = await setup(); await result.user.click(screen.getByLabelText(/^metric$/i)); await result.user.click(screen.getByText('Metric 1')); diff --git a/public/components/register_model/__tests__/register_model_tags.test.tsx b/public/components/register_model/__tests__/register_model_tags.test.tsx index 7ccc22f7..5d65aa77 100644 --- a/public/components/register_model/__tests__/register_model_tags.test.tsx +++ b/public/components/register_model/__tests__/register_model_tags.test.tsx @@ -3,39 +3,55 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { screen } from '../../../../test/test_utils'; +import { screen, waitFor, within } from '../../../../test/test_utils'; import { setup } from './setup'; import * as formHooks from '../register_model.hooks'; describe(' Tags', () => { + const onSubmitMock = jest.fn(); + beforeEach(() => { + jest + .spyOn(formHooks, 'useMetricNames') + .mockReturnValue([false, ['Metric 1', 'Metric 2', 'Metric 3', 'Metric 4']]); jest .spyOn(formHooks, 'useModelTags') .mockReturnValue([false, { keys: ['Key1', 'Key2'], values: ['Value1', 'Value2'] }]); + jest.spyOn(formHooks, 'useModelUpload').mockReturnValue(onSubmitMock); + }); + + afterEach(() => { + jest.clearAllMocks(); }); it('should render a tags panel', async () => { - const onSubmitMock = jest.fn(); - const result = await setup({ onSubmit: onSubmitMock }); + await setup(); - expect(result.tagKeyInput).toBeInTheDocument(); - expect(result.tagValueInput).toBeInTheDocument(); + const keyContainer = screen.queryByTestId('ml-tagKey1'); + const valueContainer = screen.queryByTestId('ml-tagValue1'); + + expect(keyContainer).toBeInTheDocument(); + expect(valueContainer).toBeInTheDocument(); }); it('should submit the form without selecting tags', async () => { - const onSubmitMock = jest.fn(); - const result = await setup({ onSubmit: onSubmitMock }); + const result = await setup(); await result.user.click(result.submitButton); expect(onSubmitMock).toHaveBeenCalled(); }); it('should submit the form with selected tags', async () => { - const onSubmitMock = jest.fn(); - const result = await setup({ onSubmit: onSubmitMock }); + const result = await setup(); + + const keyContainer = screen.getByTestId('ml-tagKey1'); + const keyInput = within(keyContainer).getByRole('textbox'); - await result.user.type(result.tagKeyInput, 'Key1'); - await result.user.type(result.tagValueInput, 'Value1'); + const valueContainer = screen.getByTestId('ml-tagValue1'); + const valueInput = within(valueContainer).getByRole('textbox'); + + await result.user.type(keyInput, 'Key1'); + await result.user.type(valueInput, 'Value1'); await result.user.click(result.submitButton); @@ -45,8 +61,7 @@ describe(' Tags', () => { }); it('should allow to add multiple tags', async () => { - const onSubmitMock = jest.fn(); - const result = await setup({ onSubmit: onSubmitMock }); + const result = await setup(); // Add two tags await result.user.click(screen.getByText(/add new tag/i)); @@ -69,9 +84,103 @@ describe(' Tags', () => { ); }); + it('should NOT allow to submit tag which does NOT have value', async () => { + const result = await setup(); + + const keyContainer = screen.getByTestId('ml-tagKey1'); + const keyInput = within(keyContainer).getByRole('textbox'); + // only input key, but NOT value + await result.user.type(keyInput, 'key 1'); + await result.user.click(result.submitButton); + + // tag value input should be invalid + const valueContainer = screen.getByTestId('ml-tagValue1'); + const valueInput = within(valueContainer).queryByText('A value is required. Enter a value.'); + expect(valueInput).toBeInTheDocument(); + + // it should not submit the form + expect(onSubmitMock).not.toHaveBeenCalled(); + }); + + it('should NOT allow to submit tag which does NOT have key', async () => { + const result = await setup(); + + const valueContainer = screen.getByTestId('ml-tagValue1'); + const valueInput = within(valueContainer).getByRole('textbox'); + // only input value, but NOT key + await result.user.type(valueInput, 'Value 1'); + await result.user.click(result.submitButton); + + // tag key input should be invalid + const keyContainer = screen.getByTestId('ml-tagKey1'); + const keyInput = within(keyContainer).queryByText('A key is required. Enter a key.'); + expect(keyInput).toBeInTheDocument(); + + // it should not submit the form + expect(onSubmitMock).not.toHaveBeenCalled(); + }); + + it('should NOT allow to submit if it has duplicate tags', async () => { + const result = await setup(); + + // input tag key: 'Key 1' + const keyContainer1 = screen.getByTestId('ml-tagKey1'); + const keyInput1 = within(keyContainer1).getByRole('textbox'); + await result.user.type(keyInput1, 'Key 1'); + + // input tag key: 'Value 1' + const valueContainer1 = screen.getByTestId('ml-tagValue1'); + const valueInput1 = within(valueContainer1).getByRole('textbox'); + await result.user.type(valueInput1, 'Value 1'); + + // Add a new tag, and input the same tag key and value + await result.user.click(screen.getByText(/add new tag/i)); + // input tag key: 'Key 1' + const keyContainer2 = screen.getByTestId('ml-tagKey2'); + const keyInput2 = within(keyContainer2).getByRole('textbox'); + await result.user.type(keyInput2, 'Key 1'); + + // input tag key: 'Value 1' + const valueContainer2 = screen.getByTestId('ml-tagValue2'); + const valueInput2 = within(valueContainer2).getByRole('textbox'); + await result.user.type(valueInput2, 'Value 1'); + + await result.user.click(result.submitButton); + + // Display error message + expect( + within(keyContainer2).queryByText( + 'This tag has already been added. Remove the duplicate tag.' + ) + ).toBeInTheDocument(); + // it should not submit the form + expect(onSubmitMock).not.toHaveBeenCalled(); + }); + + it( + 'should only allow to add maximum 25 tags', + async () => { + const result = await setup(); + const MAX_TAG_NUM = 25; + + // It has one tag by default, we can add 24 more tags + const addNewTagButton = screen.getByText(/add new tag/i); + for (let i = 1; i < MAX_TAG_NUM; i++) { + await result.user.click(addNewTagButton); + } + + // 25 tags are displayed + waitFor(() => expect(screen.queryAllByTestId(/ml-tagKey/i)).toHaveLength(25)); + // add new tag button should not be displayed + waitFor(() => expect(screen.getByRole('button', { name: /add new tag/i })).toBeDisabled()); + }, + // The test will fail due to timeout as we interact with the page a lot(24 button click to add new tags) + // So we try to increase test running timeout to 60000ms to mitigate the timeout issue + 60 * 1000 + ); + it('should allow to remove multiple tags', async () => { - const onSubmitMock = jest.fn(); - const result = await setup({ onSubmit: onSubmitMock }); + const result = await setup(); // Add two tags await result.user.click(screen.getByText(/add new tag/i)); @@ -85,6 +194,9 @@ describe(' Tags', () => { await result.user.click(screen.getByLabelText(/remove tag at row 2/i)); await result.user.click(screen.getByLabelText(/remove tag at row 1/i)); + // should have only one tag left + await waitFor(() => expect(screen.queryAllByTestId(/ml-tagKey/i)).toHaveLength(1)); + await result.user.click(result.submitButton); expect(onSubmitMock).toHaveBeenCalledWith( diff --git a/public/components/register_model/__tests__/setup.tsx b/public/components/register_model/__tests__/setup.tsx index db7ec12a..849b0063 100644 --- a/public/components/register_model/__tests__/setup.tsx +++ b/public/components/register_model/__tests__/setup.tsx @@ -7,14 +7,13 @@ import React from 'react'; import userEvent from '@testing-library/user-event'; import { RegisterModelForm } from '../register_model'; -import type { RegisterModelFormProps } from '../register_model'; import { render, RenderWithRouteProps, screen } from '../../../../test/test_utils'; jest.mock('../../../apis/model'); jest.mock('../../../apis/task'); -export async function setup({ onSubmit }: RegisterModelFormProps, options?: RenderWithRouteProps) { - render(, { route: options?.route ?? '/' }); +export async function setup(options?: RenderWithRouteProps) { + render(, { route: options?.route ?? '/' }); const nameInput = screen.getByLabelText(/^name$/i); const descriptionInput = screen.getByLabelText(/description/i); const annotationsInput = screen.getByLabelText(/annotation/i); @@ -22,8 +21,6 @@ export async function setup({ onSubmit }: RegisterModelFormProps, options?: Rend name: /register model/i, }); const modelFileInput = screen.queryByLabelText(/file/i); - const tagKeyInput = screen.getByLabelText(/^key$/i); - const tagValueInput = screen.getByLabelText(/^value$/i); const form = screen.getByTestId('mlCommonsPlugin-registerModelForm'); const user = userEvent.setup(); @@ -44,8 +41,6 @@ export async function setup({ onSubmit }: RegisterModelFormProps, options?: Rend descriptionInput, annotationsInput, submitButton, - tagKeyInput, - tagValueInput, form, user, }; diff --git a/public/components/register_model/artifact.tsx b/public/components/register_model/artifact.tsx index 1cfd189c..96c80fe2 100644 --- a/public/components/register_model/artifact.tsx +++ b/public/components/register_model/artifact.tsx @@ -4,30 +4,12 @@ */ import React, { useState } from 'react'; -import { - EuiFormRow, - EuiTitle, - EuiHorizontalRule, - htmlIdGenerator, - EuiSpacer, - EuiFlexGroup, - EuiFlexItem, - EuiCheckableCard, - EuiText, - EuiRadio, - EuiLink, -} from '@elastic/eui'; -import type { Control } from 'react-hook-form'; +import { EuiTitle, htmlIdGenerator, EuiSpacer, EuiText, EuiRadio, EuiLink } from '@elastic/eui'; -import { FORM_ITEM_WIDTH } from './form_constants'; -import type { ModelFileFormData, ModelUrlFormData } from './register_model.types'; import { ModelFileUploader } from './artifact_file'; import { ArtifactUrl } from './artifact_url'; -export const ArtifactPanel = (props: { - formControl: Control; - ordinalNumber: number; -}) => { +export const ArtifactPanel = (props: { ordinalNumber: number }) => { const [selectedSource, setSelectedSource] = useState<'source_from_computer' | 'source_from_url'>( 'source_from_computer' ); @@ -64,10 +46,8 @@ export const ArtifactPanel = (props: { onChange={() => setSelectedSource('source_from_url')} /> - {selectedSource === 'source_from_computer' && ( - - )} - {selectedSource === 'source_from_url' && } + {selectedSource === 'source_from_computer' && } + {selectedSource === 'source_from_url' && } ); }; diff --git a/public/components/register_model/artifact_file.tsx b/public/components/register_model/artifact_file.tsx index c34acf95..7dca17de 100644 --- a/public/components/register_model/artifact_file.tsx +++ b/public/components/register_model/artifact_file.tsx @@ -5,8 +5,7 @@ import React from 'react'; import { EuiFormRow, EuiFilePicker } from '@elastic/eui'; -import { useController } from 'react-hook-form'; -import type { Control } from 'react-hook-form'; +import { useController, useFormContext } from 'react-hook-form'; import { ModelFileFormData, ModelUrlFormData } from './register_model.types'; @@ -20,12 +19,11 @@ function validateFile(file: File) { return true; } -export const ModelFileUploader = (props: { - formControl: Control; -}) => { +export const ModelFileUploader = () => { + const { control } = useFormContext(); const modelFileFieldController = useController({ name: 'modelFile', - control: props.formControl, + control, rules: { required: { value: true, message: 'A file is required. Add a file.' }, validate: validateFile, diff --git a/public/components/register_model/artifact_url.tsx b/public/components/register_model/artifact_url.tsx index 89bcb6df..cde55f5c 100644 --- a/public/components/register_model/artifact_url.tsx +++ b/public/components/register_model/artifact_url.tsx @@ -5,19 +5,17 @@ import React from 'react'; import { EuiFormRow, htmlIdGenerator, EuiFieldText } from '@elastic/eui'; -import { useController } from 'react-hook-form'; -import type { Control } from 'react-hook-form'; +import { useController, useFormContext } from 'react-hook-form'; import { FORM_ITEM_WIDTH } from './form_constants'; import type { ModelFileFormData, ModelUrlFormData } from './register_model.types'; import { URL_REGEX } from '../../utils/regex'; -export const ArtifactUrl = (props: { - formControl: Control; -}) => { +export const ArtifactUrl = () => { + const { control } = useFormContext(); const modelUrlFieldController = useController({ name: 'modelURL', - control: props.formControl, + control, rules: { required: { value: true, message: 'URL is required. Enter a URL.' }, pattern: { value: URL_REGEX, message: 'URL is invalid. Enter a valid URL.' }, diff --git a/public/components/register_model/evaluation_metrics.tsx b/public/components/register_model/evaluation_metrics.tsx index 74aa9751..3748033f 100644 --- a/public/components/register_model/evaluation_metrics.tsx +++ b/public/components/register_model/evaluation_metrics.tsx @@ -15,18 +15,15 @@ import { EuiSpacer, EuiText, } from '@elastic/eui'; -import { useController } from 'react-hook-form'; -import type { Control } from 'react-hook-form'; +import { useController, useFormContext } from 'react-hook-form'; import type { ModelFileFormData, ModelUrlFormData } from './register_model.types'; import { useMetricNames } from './register_model.hooks'; const METRIC_VALUE_STEP = 0.01; -export const EvaluationMetricsPanel = (props: { - formControl: Control; - ordinalNumber: number; -}) => { +export const EvaluationMetricsPanel = (props: { ordinalNumber: number }) => { + const { control } = useFormContext(); const [metricNamesLoading, metricNames] = useMetricNames(); // TODO: this has to be hooked with data from BE API @@ -36,12 +33,12 @@ export const EvaluationMetricsPanel = (props: { const metricFieldController = useController({ name: 'metricName', - control: props.formControl, + control, }); const trainingMetricFieldController = useController({ name: 'trainingMetricValue', - control: props.formControl, + control, rules: { max: 1, min: 0, @@ -50,7 +47,7 @@ export const EvaluationMetricsPanel = (props: { const validationMetricFieldController = useController({ name: 'validationMetricValue', - control: props.formControl, + control, rules: { max: 1, min: 0, @@ -59,7 +56,7 @@ export const EvaluationMetricsPanel = (props: { const testingMetricFieldController = useController({ name: 'testingMetricValue', - control: props.formControl, + control, rules: { max: 1, min: 0, diff --git a/public/components/register_model/model_configuration.tsx b/public/components/register_model/model_configuration.tsx index e4bc8384..5926a26b 100644 --- a/public/components/register_model/model_configuration.tsx +++ b/public/components/register_model/model_configuration.tsx @@ -12,8 +12,7 @@ import { EuiButtonEmpty, EuiSpacer, } from '@elastic/eui'; -import { useController } from 'react-hook-form'; -import type { Control } from 'react-hook-form'; +import { useController, useFormContext } from 'react-hook-form'; import '../../ace-themes/sql_console.js'; import { FORM_ITEM_WIDTH } from './form_constants'; @@ -29,14 +28,12 @@ function validateConfigurationObject(value: string) { return true; } -export const ConfigurationPanel = (props: { - formControl: Control; - ordinalNumber: number; -}) => { +export const ConfigurationPanel = (props: { ordinalNumber: number }) => { + const { control } = useFormContext(); const [isHelpVisible, setIsHelpVisible] = useState(false); const configurationFieldController = useController({ name: 'configuration', - control: props.formControl, + control, rules: { required: { value: true, message: 'Configuration is required.' }, validate: validateConfigurationObject, diff --git a/public/components/register_model/model_details.tsx b/public/components/register_model/model_details.tsx index f5076215..103c4742 100644 --- a/public/components/register_model/model_details.tsx +++ b/public/components/register_model/model_details.tsx @@ -5,22 +5,18 @@ import React from 'react'; import { EuiFieldText, EuiFormRow, EuiTitle, EuiTextArea, EuiText } from '@elastic/eui'; -import { useController } from 'react-hook-form'; -import type { Control } from 'react-hook-form'; - -import type { ModelFileFormData, ModelUrlFormData } from './register_model.types'; +import { useController, useFormContext } from 'react-hook-form'; +import { ModelFileFormData, ModelUrlFormData } from './register_model.types'; const NAME_MAX_LENGTH = 60; const DESCRIPTION_MAX_LENGTH = 200; const ANNOTATION_MAX_LENGTH = 200; -export const ModelDetailsPanel = (props: { - formControl: Control; - ordinalNumber: number; -}) => { +export const ModelDetailsPanel = (props: { ordinalNumber: number }) => { + const { control } = useFormContext(); const nameFieldController = useController({ name: 'name', - control: props.formControl, + control, rules: { required: { value: true, message: 'Name can not be empty' }, maxLength: { value: NAME_MAX_LENGTH, message: 'Text exceed max length' }, @@ -29,7 +25,7 @@ export const ModelDetailsPanel = (props: { const descriptionFieldController = useController({ name: 'description', - control: props.formControl, + control, rules: { required: { value: true, message: 'Description can not be empty' }, maxLength: { value: DESCRIPTION_MAX_LENGTH, message: 'Text exceed max length' }, @@ -38,7 +34,7 @@ export const ModelDetailsPanel = (props: { const annotationsFieldController = useController({ name: 'annotations', - control: props.formControl, + control, rules: { maxLength: { value: ANNOTATION_MAX_LENGTH, message: 'Text exceed max length' } }, }); diff --git a/public/components/register_model/model_tags.tsx b/public/components/register_model/model_tags.tsx index 91365b3f..6ca71030 100644 --- a/public/components/register_model/model_tags.tsx +++ b/public/components/register_model/model_tags.tsx @@ -4,22 +4,21 @@ */ import React, { useCallback } from 'react'; -import { EuiButton, EuiTitle, EuiHorizontalRule, EuiSpacer } from '@elastic/eui'; -import { useFieldArray } from 'react-hook-form'; -import type { Control } from 'react-hook-form'; +import { EuiButton, EuiTitle, EuiHorizontalRule, EuiSpacer, EuiText } from '@elastic/eui'; +import { useFieldArray, useFormContext } from 'react-hook-form'; import type { ModelFileFormData, ModelUrlFormData } from './register_model.types'; import { ModelTagField } from './tag_field'; import { useModelTags } from './register_model.hooks'; -export const ModelTagsPanel = (props: { - formControl: Control; - ordinalNumber: number; -}) => { +const MAX_TAG_NUM = 25; + +export const ModelTagsPanel = (props: { ordinalNumber: number }) => { + const { control } = useFormContext(); const [, { keys, values }] = useModelTags(); const { fields, append, remove } = useFieldArray({ name: 'tags', - control: props.formControl, + control, }); const addNewTag = useCallback(() => { @@ -38,9 +37,7 @@ export const ModelTagsPanel = (props: { return ( - Add new tag + = MAX_TAG_NUM} onClick={addNewTag}> + Add new tag + + + + {`You can add up to ${MAX_TAG_NUM - fields.length} more tags.`} + ); }; diff --git a/public/components/register_model/register_model.tsx b/public/components/register_model/register_model.tsx index 062e4291..319fdd03 100644 --- a/public/components/register_model/register_model.tsx +++ b/public/components/register_model/register_model.tsx @@ -4,7 +4,7 @@ */ import React, { useCallback, useEffect } from 'react'; -import { FieldErrors, useForm } from 'react-hook-form'; +import { FieldErrors, useForm, FormProvider } from 'react-hook-form'; import { useParams } from 'react-router-dom'; import { EuiPageHeader, EuiSpacer, EuiForm, EuiButton, EuiPanel, EuiText } from '@elastic/eui'; @@ -20,10 +20,6 @@ import { upgradeModelVersion } from '../../utils'; import { useSearchParams } from '../../hooks/use_search_params'; import { isValidModelRegisterFormType } from './utils'; -export interface RegisterModelFormProps { - onSubmit?: (data: ModelFileFormData | ModelUrlFormData) => void; -} - const DEFAULT_VALUES = { name: '', description: '', @@ -32,7 +28,7 @@ const DEFAULT_VALUES = { tags: [{ key: '', value: '' }], }; -export const RegisterModelForm = (props: RegisterModelFormProps) => { +export const RegisterModelForm = () => { const { id: latestVersionId } = useParams<{ id: string | undefined }>(); const typeParams = useSearchParams().get('type'); @@ -48,22 +44,14 @@ export const RegisterModelForm = (props: RegisterModelFormProps) => { ModelTagsPanel, ]; - const { handleSubmit, control, setValue, formState } = useForm< - ModelFileFormData | ModelUrlFormData - >({ + const form = useForm({ mode: 'onChange', defaultValues: DEFAULT_VALUES, }); const submitModel = useModelUpload(); const onSubmit = async (data: ModelFileFormData | ModelUrlFormData) => { - if (props.onSubmit) { - props.onSubmit(data); - } await submitModel(data); - // TODO - // eslint-disable-next-line no-console - console.log(data); }; useEffect(() => { @@ -78,13 +66,13 @@ export const RegisterModelForm = (props: RegisterModelFormProps) => { // TODO: clarify which fields to pre-populate const { model_version: modelVersion, name, model_config: modelConfig } = data?.[0]; const newVersion = upgradeModelVersion(modelVersion); - setValue('name', name); - setValue('version', newVersion); - setValue('configuration', modelConfig?.all_config ?? ''); + form.setValue('name', name); + form.setValue('version', newVersion); + form.setValue('configuration', modelConfig?.all_config ?? ''); } }; initializeForm(); - }, [latestVersionId, setValue]); + }, [latestVersionId, form]); const onError = useCallback((errors: FieldErrors) => { // TODO @@ -93,35 +81,37 @@ export const RegisterModelForm = (props: RegisterModelFormProps) => { }, []); return ( - - - - - - Register your model to collaboratively manage its life cycle, and facilitate model - discovery across your organization. - - - - {partials.map((FormPartial, i) => ( - - - - - ))} - - Register model - - - + + + + + + + Register your model to collaboratively manage its life cycle, and facilitate model + discovery across your organization. + + + + {partials.map((FormPartial, i) => ( + + + + + ))} + + Register model + + + + ); }; diff --git a/public/components/register_model/tag_field.tsx b/public/components/register_model/tag_field.tsx index 52e23acc..ed47ca03 100644 --- a/public/components/register_model/tag_field.tsx +++ b/public/components/register_model/tag_field.tsx @@ -11,19 +11,20 @@ import { EuiFlexItem, EuiFormRow, } from '@elastic/eui'; -import React, { useCallback, useMemo } from 'react'; -import { Control, useController } from 'react-hook-form'; +import React, { useCallback, useMemo, useRef } from 'react'; +import { useController, useWatch, useFormContext } from 'react-hook-form'; import { FORM_ITEM_WIDTH } from './form_constants'; +import { ModelFileFormData, ModelUrlFormData } from './register_model.types'; interface ModelTagFieldProps { - name: string; index: number; - formControl: Control; onDelete: (index: number) => void; tagKeys: string[]; tagValues: string[]; } +const MAX_TAG_LENGTH = 80; + function getComboBoxValue(data: EuiComboBoxOptionOption[]) { if (data.length === 0) { return ''; @@ -32,22 +33,67 @@ function getComboBoxValue(data: EuiComboBoxOptionOption[]) { } } -export const ModelTagField = ({ - name, - formControl, - index, - tagKeys, - tagValues, - onDelete, -}: ModelTagFieldProps) => { +export const ModelTagField = ({ index, tagKeys, tagValues, onDelete }: ModelTagFieldProps) => { + const rowEleRef = useRef(null); + const { trigger, control } = useFormContext(); + const tags = useWatch({ + control, + name: 'tags', + }); + const tagKeyController = useController({ - name: `${name}.${index}.key`, - control: formControl, + name: `tags.${index}.key` as const, + control, + rules: { + validate: (tagKey) => { + if (tags) { + const tag = tags[index]; + // If it has value, key cannot be empty + if (!tagKey && tag.value) { + return 'A key is required. Enter a key.'; + } + // If a tag has both key and value, validate if the same tag was added before + if (tagKey && tag.value) { + // Find if the same tag appears before the current tag + for (let i = 0; i < index; i++) { + // If found the same tag, then the current tag is invalid + if (tags[i].key === tagKey && tags[i].value === tag.value) { + return 'This tag has already been added. Remove the duplicate tag.'; + } + } + } + } + return true; + }, + }, }); const tagValueController = useController({ - name: `${name}.${index}.value`, - control: formControl, + name: `tags.${index}.value` as const, + control, + rules: { + validate: (tagValue) => { + if (tags) { + const tag = tags[index]; + // If it has key, value cannot be empty + if (!tagValue && tag.key) { + return 'A value is required. Enter a value.'; + } + // If a tag has both key and value, validate if the same tag was added before + if (tag.key && tagValue) { + // Find if the same tag appears before the current tag + for (let i = 0; i < index; i++) { + // If found the same tag, then the current tag is invalid + if (tags[i].key === tag.key && tags[i].value === tagValue) { + // return `false` instead of error message because we don't show error message on value field + return false; + } + } + } + } + return true; + }, + }, }); const onKeyChange = useCallback( @@ -64,6 +110,26 @@ export const ModelTagField = ({ [tagValueController.field] ); + const onKeyCreate = useCallback( + (value: string) => { + if (value.length > MAX_TAG_LENGTH) { + return; + } + tagKeyController.field.onChange(value); + }, + [tagKeyController.field] + ); + + const onValueCreate = useCallback( + (value: string) => { + if (value.length > MAX_TAG_LENGTH) { + return; + } + tagValueController.field.onChange(value); + }, + [tagValueController.field] + ); + const keyOptions = useMemo(() => { return tagKeys.map((key) => ({ label: key })); }, [tagKeys]); @@ -72,43 +138,76 @@ export const ModelTagField = ({ return tagValues.map((value) => ({ label: value })); }, [tagValues]); + const onBlur = useCallback( + (e: React.FocusEvent) => { + // If the blur event was stolen by the child element of the current row + // We don't want to validate the form field yet + if (e.relatedTarget && rowEleRef.current && rowEleRef.current.contains(e.relatedTarget)) { + return; + } + // The blur could happen when selecting combo box dropdown + // But we don't want to trigger form validation in this case + if ( + e.relatedTarget?.getAttribute('role') === 'option' && + e.relatedTarget?.tagName === 'BUTTON' + ) { + return; + } + // Validate the form only when the current tag row blurred + trigger('tags'); + }, + [trigger] + ); + return ( - + - + - + - + onDelete(index)}> Remove From c16f225c4a8712cfe21789fe6013a00724beee9f Mon Sep 17 00:00:00 2001 From: Yulong Ruan Date: Mon, 13 Feb 2023 18:43:39 +0800 Subject: [PATCH 10/75] feat: align ui with the latest design changes (#95) + model name max length 60 characters -> 80 characters + remove ordinal number from form section title + add missing tag description to form tag section Signed-off-by: Yulong Ruan Signed-off-by: Lin Wang --- .../__tests__/register_model_details.test.tsx | 6 +++--- public/components/register_model/artifact.tsx | 4 ++-- .../components/register_model/evaluation_metrics.tsx | 4 ++-- .../components/register_model/model_configuration.tsx | 4 ++-- public/components/register_model/model_details.tsx | 6 +++--- public/components/register_model/model_tags.tsx | 11 +++++++---- public/components/register_model/register_model.tsx | 2 +- 7 files changed, 20 insertions(+), 17 deletions(-) diff --git a/public/components/register_model/__tests__/register_model_details.test.tsx b/public/components/register_model/__tests__/register_model_details.test.tsx index eb2f924c..01343dcf 100644 --- a/public/components/register_model/__tests__/register_model_details.test.tsx +++ b/public/components/register_model/__tests__/register_model_details.test.tsx @@ -49,15 +49,15 @@ describe(' Details', () => { expect(onSubmitMock).not.toHaveBeenCalled(); }); - it('should NOT submit the register model form if model name length exceeded 60', async () => { + it('should NOT submit the register model form if model name length exceeded 80', async () => { const result = await setup(); await result.user.clear(result.nameInput); - await result.user.type(result.nameInput, 'x'.repeat(60)); + await result.user.type(result.nameInput, 'x'.repeat(80)); expect(result.nameInput).toBeValid(); await result.user.clear(result.nameInput); - await result.user.type(result.nameInput, 'x'.repeat(61)); + await result.user.type(result.nameInput, 'x'.repeat(81)); expect(result.nameInput).toBeInvalid(); await result.user.click(result.submitButton); diff --git a/public/components/register_model/artifact.tsx b/public/components/register_model/artifact.tsx index 96c80fe2..737b6eaf 100644 --- a/public/components/register_model/artifact.tsx +++ b/public/components/register_model/artifact.tsx @@ -9,7 +9,7 @@ import { EuiTitle, htmlIdGenerator, EuiSpacer, EuiText, EuiRadio, EuiLink } from import { ModelFileUploader } from './artifact_file'; import { ArtifactUrl } from './artifact_url'; -export const ArtifactPanel = (props: { ordinalNumber: number }) => { +export const ArtifactPanel = () => { const [selectedSource, setSelectedSource] = useState<'source_from_computer' | 'source_from_url'>( 'source_from_computer' ); @@ -17,7 +17,7 @@ export const ArtifactPanel = (props: { ordinalNumber: number }) => { return (
-

{props.ordinalNumber}. Artifact

+

Artifact

diff --git a/public/components/register_model/evaluation_metrics.tsx b/public/components/register_model/evaluation_metrics.tsx index 3748033f..d80d2d4f 100644 --- a/public/components/register_model/evaluation_metrics.tsx +++ b/public/components/register_model/evaluation_metrics.tsx @@ -22,7 +22,7 @@ import { useMetricNames } from './register_model.hooks'; const METRIC_VALUE_STEP = 0.01; -export const EvaluationMetricsPanel = (props: { ordinalNumber: number }) => { +export const EvaluationMetricsPanel = () => { const { control } = useFormContext(); const [metricNamesLoading, metricNames] = useMetricNames(); @@ -99,7 +99,7 @@ export const EvaluationMetricsPanel = (props: { ordinalNumber: number }) => {

- {props.ordinalNumber}. Evaluation Metrics - optional + Evaluation Metrics - optional

diff --git a/public/components/register_model/model_configuration.tsx b/public/components/register_model/model_configuration.tsx index 5926a26b..2e47a721 100644 --- a/public/components/register_model/model_configuration.tsx +++ b/public/components/register_model/model_configuration.tsx @@ -28,7 +28,7 @@ function validateConfigurationObject(value: string) { return true; } -export const ConfigurationPanel = (props: { ordinalNumber: number }) => { +export const ConfigurationPanel = () => { const { control } = useFormContext(); const [isHelpVisible, setIsHelpVisible] = useState(false); const configurationFieldController = useController({ @@ -43,7 +43,7 @@ export const ConfigurationPanel = (props: { ordinalNumber: number }) => { return (
-

{props.ordinalNumber}. Configuration

+

Configuration

diff --git a/public/components/register_model/model_details.tsx b/public/components/register_model/model_details.tsx index 103c4742..8d962981 100644 --- a/public/components/register_model/model_details.tsx +++ b/public/components/register_model/model_details.tsx @@ -8,11 +8,11 @@ import { EuiFieldText, EuiFormRow, EuiTitle, EuiTextArea, EuiText } from '@elast import { useController, useFormContext } from 'react-hook-form'; import { ModelFileFormData, ModelUrlFormData } from './register_model.types'; -const NAME_MAX_LENGTH = 60; +const NAME_MAX_LENGTH = 80; const DESCRIPTION_MAX_LENGTH = 200; const ANNOTATION_MAX_LENGTH = 200; -export const ModelDetailsPanel = (props: { ordinalNumber: number }) => { +export const ModelDetailsPanel = () => { const { control } = useFormContext(); const nameFieldController = useController({ name: 'name', @@ -45,7 +45,7 @@ export const ModelDetailsPanel = (props: { ordinalNumber: number }) => { return (
-

{props.ordinalNumber}. Model Details

+

Model Details

{ +export const ModelTagsPanel = () => { const { control } = useFormContext(); const [, { keys, values }] = useModelTags(); const { fields, append, remove } = useFieldArray({ @@ -29,10 +29,13 @@ export const ModelTagsPanel = (props: { ordinalNumber: number }) => {

- {props.ordinalNumber}. Tags - optional + Tags - optional

- + + Add tags to facilitate model discovery and tracking across your organization. + + {fields.map((field, index) => { return ( { {partials.map((FormPartial, i) => ( - + ))} From 1eb46886b73e9d1860eaf4f3e5e641130c917374 Mon Sep 17 00:00:00 2001 From: raintygao Date: Tue, 14 Feb 2023 17:40:15 +0800 Subject: [PATCH 11/75] add metrics validation (#97) * feat: add metrics validation Signed-off-by: raintygao * test: add UT Signed-off-by: raintygao --------- Signed-off-by: raintygao Signed-off-by: Lin Wang --- .../__tests__/register_model_metrics.test.tsx | 14 ++- .../register_model/evaluation_metrics.tsx | 94 ++++++++++++++----- .../register_model/register_model.types.ts | 12 ++- public/utils/index.ts | 1 + public/utils/number.ts | 8 ++ 5 files changed, 102 insertions(+), 27 deletions(-) create mode 100644 public/utils/number.ts diff --git a/public/components/register_model/__tests__/register_model_metrics.test.tsx b/public/components/register_model/__tests__/register_model_metrics.test.tsx index 88874596..26759411 100644 --- a/public/components/register_model/__tests__/register_model_metrics.test.tsx +++ b/public/components/register_model/__tests__/register_model_metrics.test.tsx @@ -59,13 +59,14 @@ describe(' Evaluation Metrics', () => { expect(onSubmitMock).toHaveBeenCalled(); }); - it('should submit the form if metric name is selected but metric value are empty', async () => { + it('should NOT submit the form if metric name is selected but metric value are empty and error message in screen', async () => { const result = await setup(); await result.user.click(screen.getByLabelText(/^metric$/i)); await result.user.click(screen.getByText('Metric 1')); await result.user.click(result.submitButton); - expect(onSubmitMock).toHaveBeenCalled(); + expect(onSubmitMock).not.toHaveBeenCalled(); + expect(screen.getByText('At least one value is required. Enter a value')).toBeInTheDocument(); }); it('should submit the form if metric name and all metric value are selected', async () => { @@ -117,4 +118,13 @@ describe(' Evaluation Metrics', () => { expect(onSubmitMock).not.toHaveBeenCalled(); }); + + it('should keep metric value not more than 2 decimal point', async () => { + const result = await setup(); + await result.user.click(screen.getByLabelText(/^metric$/i)); + await result.user.click(screen.getByText('Metric 1')); + + await result.user.type(screen.getByLabelText(/training value/i), '1.111'); + expect(screen.getByLabelText(/training value/i)).toHaveValue(1.11); + }); }); diff --git a/public/components/register_model/evaluation_metrics.tsx b/public/components/register_model/evaluation_metrics.tsx index d80d2d4f..a46c26ab 100644 --- a/public/components/register_model/evaluation_metrics.tsx +++ b/public/components/register_model/evaluation_metrics.tsx @@ -3,7 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -import React, { useCallback, useMemo } from 'react'; +import React, { useCallback, useMemo, useState } from 'react'; import { EuiFormRow, EuiTitle, @@ -15,15 +15,18 @@ import { EuiSpacer, EuiText, } from '@elastic/eui'; -import { useController, useFormContext } from 'react-hook-form'; +import { useController, useFormContext, useWatch } from 'react-hook-form'; import type { ModelFileFormData, ModelUrlFormData } from './register_model.types'; import { useMetricNames } from './register_model.hooks'; +import { fixTwoDecimalPoint } from '../../utils'; const METRIC_VALUE_STEP = 0.01; +const MAX_METRIC_NAME_LENGTH = 50; export const EvaluationMetricsPanel = () => { - const { control } = useFormContext(); + const { trigger, control } = useFormContext(); + const [isRequiredValueText, setIsRequiredValueText] = useState(false); const [metricNamesLoading, metricNames] = useMetricNames(); // TODO: this has to be hooked with data from BE API @@ -31,35 +34,56 @@ export const EvaluationMetricsPanel = () => { return metricNames.map((n) => ({ label: n })); }, [metricNames]); - const metricFieldController = useController({ - name: 'metricName', + const metricKeyController = useController({ + name: 'metric.key', control, }); + const metric = useWatch({ + control, + name: 'metric', + }); + + const valueValidateFn = () => { + if (metric) { + const { trainingValue, validationValue, testingValue, key } = metric; + if (key && !trainingValue && !validationValue && !testingValue) { + setIsRequiredValueText(true); + return false; + } else { + setIsRequiredValueText(false); + return true; + } + } + return true; + }; const trainingMetricFieldController = useController({ - name: 'trainingMetricValue', + name: 'metric.trainingValue', control, rules: { max: 1, min: 0, + validate: valueValidateFn, }, }); const validationMetricFieldController = useController({ - name: 'validationMetricValue', + name: 'metric.validationValue', control, rules: { max: 1, min: 0, + validate: valueValidateFn, }, }); const testingMetricFieldController = useController({ - name: 'testingMetricValue', + name: 'metric.testingValue', control, rules: { max: 1, min: 0, + validate: valueValidateFn, }, }); @@ -69,13 +93,13 @@ export const EvaluationMetricsPanel = () => { trainingMetricFieldController.field.onChange(''); validationMetricFieldController.field.onChange(''); testingMetricFieldController.field.onChange(''); - metricFieldController.field.onChange(''); + metricKeyController.field.onChange(''); } else { - metricFieldController.field.onChange(data[0].label); + metricKeyController.field.onChange(data[0].label); } }, [ - metricFieldController, + metricKeyController, trainingMetricFieldController, validationMetricFieldController, testingMetricFieldController, @@ -84,9 +108,12 @@ export const EvaluationMetricsPanel = () => { const onCreateMetricName = useCallback( (metricName: string) => { - metricFieldController.field.onChange(metricName); + if (metricName.length > MAX_METRIC_NAME_LENGTH) { + return; + } + metricKeyController.field.onChange(metricName); }, - [metricFieldController] + [metricKeyController.field] ); const metricValueFields = [ @@ -95,8 +122,25 @@ export const EvaluationMetricsPanel = () => { { label: 'Testing value', controller: testingMetricFieldController }, ]; + const onBlur = useCallback( + (e: React.FocusEvent) => { + // The blur could happen when selecting combo box dropdown + // But we don't want to trigger form validation in this case + if ( + (e.relatedTarget?.getAttribute('role') === 'option' && + e.relatedTarget?.tagName === 'BUTTON') || + e.relatedTarget?.getAttribute('role') === 'textbox' + ) { + return; + } + // Validate the form only when the current tag row blurred + trigger('metric'); + }, + [trigger] + ); + return ( -
+

Evaluation Metrics - optional @@ -111,22 +155,23 @@ export const EvaluationMetricsPanel = () => { @@ -141,11 +186,13 @@ export const EvaluationMetricsPanel = () => { + controller.field.onChange(fixTwoDecimalPoint(value.target.value)) + } onBlur={controller.field.onBlur} inputRef={controller.field.ref} /> @@ -153,6 +200,11 @@ export const EvaluationMetricsPanel = () => { ))} + {isRequiredValueText && ( + + At least one value is required. Enter a value + + )}

); }; diff --git a/public/components/register_model/register_model.types.ts b/public/components/register_model/register_model.types.ts index 1a60d101..648cae3e 100644 --- a/public/components/register_model/register_model.types.ts +++ b/public/components/register_model/register_model.types.ts @@ -8,16 +8,20 @@ export interface Tag { value: string; } +interface Metric { + key: string; + trainingValue: string; + validationValue: string; + testingValue: string; +} + interface ModelFormBase { name: string; version: string; description: string; annotations?: string; configuration: string; - metricName?: string; - trainingMetricValue?: string; - validationMetricValue?: string; - testingMetricValue?: string; + metric?: Metric; tags?: Tag[]; } diff --git a/public/utils/index.ts b/public/utils/index.ts index 44ae52e4..62da79ee 100644 --- a/public/utils/index.ts +++ b/public/utils/index.ts @@ -5,3 +5,4 @@ export * from './table'; export * from './version'; +export * from './number'; diff --git a/public/utils/number.ts b/public/utils/number.ts new file mode 100644 index 00000000..0b720e7b --- /dev/null +++ b/public/utils/number.ts @@ -0,0 +1,8 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +export const fixTwoDecimalPoint = (value: string) => { + return value.replace(/^(\-)*(\d+)\.(\d\d).*$/, '$1$2.$3'); +}; From 52f873dea077ba945ce1493fa1ce72e1c10b5f04 Mon Sep 17 00:00:00 2001 From: Yulong Ruan Date: Tue, 14 Feb 2023 20:54:56 +0800 Subject: [PATCH 12/75] feat: add register form submission footer (#99) + Form footer will stick at the bottom of the screen + Display error counts in the footer Signed-off-by: Yulong Ruan Signed-off-by: Lin Wang --- .../__tests__/register_model_form.test.tsx | 106 +++++++++++++----- .../register_model/register_model.tsx | 63 +++++++++-- test/jest.config.js | 2 +- 3 files changed, 132 insertions(+), 39 deletions(-) diff --git a/public/components/register_model/__tests__/register_model_form.test.tsx b/public/components/register_model/__tests__/register_model_form.test.tsx index 6687c78a..2dfbb1ce 100644 --- a/public/components/register_model/__tests__/register_model_form.test.tsx +++ b/public/components/register_model/__tests__/register_model_form.test.tsx @@ -10,39 +10,41 @@ import { render, screen, waitFor } from '../../../../test/test_utils'; import { RegisterModelForm } from '../register_model'; import { APIProvider } from '../../../apis/api_provider'; import { routerPaths } from '../../../../common/router_paths'; +import { setup } from './setup'; + +const MOCKED_DATA = { + data: [ + { + id: 'C7jN0YQBjgpeQQ_RmiDE', + model_version: '1.0.7', + created_time: 1669967223491, + model_config: { + all_config: + '{"_name_or_path":"nreimers/MiniLM-L6-H384-uncased","architectures":["BertModel"],"attention_probs_dropout_prob":0.1,"gradient_checkpointing":false,"hidden_act":"gelu","hidden_dropout_prob":0.1,"hidden_size":384,"initializer_range":0.02,"intermediate_size":1536,"layer_norm_eps":1e-12,"max_position_embeddings":512,"model_type":"bert","num_attention_heads":12,"num_hidden_layers":6,"pad_token_id":0,"position_embedding_type":"absolute","transformers_version":"4.8.2","type_vocab_size":2,"use_cache":true,"vocab_size":30522}', + model_type: 'bert', + embedding_dimension: 384, + framework_type: 'SENTENCE_TRANSFORMERS', + }, + last_loaded_time: 1672895017422, + model_format: 'TORCH_SCRIPT', + last_uploaded_time: 1669967226531, + name: 'all-MiniLM-L6-v2', + model_state: 'LOADED', + total_chunks: 9, + model_content_size_in_bytes: 83408741, + algorithm: 'TEXT_EMBEDDING', + model_content_hash_value: '9376c2ebd7c83f99ec2526323786c348d2382e6d86576f750c89ea544d6bbb14', + current_worker_node_count: 1, + planning_worker_node_count: 1, + }, + ], + pagination: { currentPage: 1, pageSize: 1, totalRecords: 1, totalPages: 1 }, +}; describe(' Form', () => { it('should init form when id param in url route', async () => { const request = jest.spyOn(APIProvider.getAPI('model'), 'search'); - const mockResult = { - data: [ - { - id: 'C7jN0YQBjgpeQQ_RmiDE', - model_version: '1.0.7', - created_time: 1669967223491, - model_config: { - all_config: - '{"_name_or_path":"nreimers/MiniLM-L6-H384-uncased","architectures":["BertModel"],"attention_probs_dropout_prob":0.1,"gradient_checkpointing":false,"hidden_act":"gelu","hidden_dropout_prob":0.1,"hidden_size":384,"initializer_range":0.02,"intermediate_size":1536,"layer_norm_eps":1e-12,"max_position_embeddings":512,"model_type":"bert","num_attention_heads":12,"num_hidden_layers":6,"pad_token_id":0,"position_embedding_type":"absolute","transformers_version":"4.8.2","type_vocab_size":2,"use_cache":true,"vocab_size":30522}', - model_type: 'bert', - embedding_dimension: 384, - framework_type: 'SENTENCE_TRANSFORMERS', - }, - last_loaded_time: 1672895017422, - model_format: 'TORCH_SCRIPT', - last_uploaded_time: 1669967226531, - name: 'all-MiniLM-L6-v2', - model_state: 'LOADED', - total_chunks: 9, - model_content_size_in_bytes: 83408741, - algorithm: 'TEXT_EMBEDDING', - model_content_hash_value: - '9376c2ebd7c83f99ec2526323786c348d2382e6d86576f750c89ea544d6bbb14', - current_worker_node_count: 1, - planning_worker_node_count: 1, - }, - ], - pagination: { currentPage: 1, pageSize: 1, totalRecords: 1, totalPages: 1 }, - }; + const mockResult = MOCKED_DATA; request.mockResolvedValue(mockResult); render( @@ -58,4 +60,50 @@ describe(' Form', () => { expect(nameInput.value).toBe(name); }); }); + + it('submit button label should be `Register version` when register new version', async () => { + const request = jest.spyOn(APIProvider.getAPI('model'), 'search'); + const mockResult = MOCKED_DATA; + request.mockResolvedValue(mockResult); + + render( + + + , + { route: '/model-registry/register-model/test_model_id' } + ); + + expect(screen.getByRole('button', { name: /register version/i })).toBeInTheDocument(); + }); + + it('submit button label should be `Register model` when import a model', async () => { + render( + + + , + { route: '/model-registry/register-model?type=import' } + ); + expect(screen.getByRole('button', { name: /register model/i })).toBeInTheDocument(); + }); + + it('submit button label should be `Register model` when register new model', async () => { + render( + + + , + { route: '/model-registry/register-model' } + ); + expect(screen.getByRole('button', { name: /register model/i })).toBeInTheDocument(); + }); + + it('should display number of form errors in form footer', async () => { + const { user, nameInput, descriptionInput } = await setup(); + await user.clear(nameInput); + await user.clear(descriptionInput); + await user.click(screen.getByRole('button', { name: /register model/i })); + expect(screen.queryByText(/2 form errors/i)).toBeInTheDocument(); + + await user.type(nameInput, 'test model name'); + expect(screen.queryByText(/1 form error/i)).toBeInTheDocument(); + }); }); diff --git a/public/components/register_model/register_model.tsx b/public/components/register_model/register_model.tsx index 0b84d27e..bd2919af 100644 --- a/public/components/register_model/register_model.tsx +++ b/public/components/register_model/register_model.tsx @@ -6,7 +6,19 @@ import React, { useCallback, useEffect } from 'react'; import { FieldErrors, useForm, FormProvider } from 'react-hook-form'; import { useParams } from 'react-router-dom'; -import { EuiPageHeader, EuiSpacer, EuiForm, EuiButton, EuiPanel, EuiText } from '@elastic/eui'; +import { + EuiPageHeader, + EuiSpacer, + EuiForm, + EuiButton, + EuiPanel, + EuiText, + EuiBottomBar, + EuiFlexGroup, + EuiFlexItem, +} from '@elastic/eui'; +import useObservable from 'react-use/lib/useObservable'; +import { from } from 'rxjs'; import { ModelDetailsPanel } from './model_details'; import type { ModelFileFormData, ModelUrlFormData } from './register_model.types'; @@ -19,6 +31,7 @@ import { APIProvider } from '../../apis/api_provider'; import { upgradeModelVersion } from '../../utils'; import { useSearchParams } from '../../hooks/use_search_params'; import { isValidModelRegisterFormType } from './utils'; +import { useOpenSearchDashboards } from '../../../../../src/plugins/opensearch_dashboards_react/public'; const DEFAULT_VALUES = { name: '', @@ -28,10 +41,17 @@ const DEFAULT_VALUES = { tags: [{ key: '', value: '' }], }; +const FORM_ID = 'mlModelUploadForm'; + export const RegisterModelForm = () => { const { id: latestVersionId } = useParams<{ id: string | undefined }>(); const typeParams = useSearchParams().get('type'); + const { + services: { chrome }, + } = useOpenSearchDashboards(); + const isLocked = useObservable(chrome?.getIsNavDrawerLocked$() ?? from([false])); + const formType = isValidModelRegisterFormType(typeParams) ? typeParams : 'upload'; const partials = formType === 'import' @@ -80,9 +100,12 @@ export const RegisterModelForm = () => { console.log(errors); }, []); + const errorCount = Object.keys(form.formState.errors).length; + return ( { ))} - - Register model - + + + + + {errorCount > 0 && ( + + + {errorCount} form {errorCount > 1 ? 'errors' : 'error'} + + + )} + + + {latestVersionId ? 'Register version' : 'Register model'} + + + + ); diff --git a/test/jest.config.js b/test/jest.config.js index ee1ca887..065a2f01 100644 --- a/test/jest.config.js +++ b/test/jest.config.js @@ -9,7 +9,7 @@ module.exports = { setupFilesAfterEnv: ['/test/setup.jest.ts'], roots: [''], moduleNameMapper: { - '\\.(css|less|scss)$': '/test/mocks/styleMock.ts', + '\\.(css|less|scss|svg)$': '/test/mocks/styleMock.ts', '^ui/(.*)': '/../../src/legacy/ui/public/$1/', }, testMatch: ['**/*.test.{js,mjs,ts,tsx}'], From 622a31cf26409668eb7f1f4767abaae679ab7b59 Mon Sep 17 00:00:00 2001 From: Yulong Ruan Date: Mon, 20 Feb 2023 15:03:48 +0800 Subject: [PATCH 13/75] feat: show notifications if form submit success or fail (#110) Signed-off-by: Yulong Ruan Signed-off-by: Lin Wang --- .../__tests__/register_model_form.test.tsx | 49 +++++++++++++++++++ .../register_model/register_model.tsx | 48 ++++++++++++++++-- 2 files changed, 93 insertions(+), 4 deletions(-) diff --git a/public/components/register_model/__tests__/register_model_form.test.tsx b/public/components/register_model/__tests__/register_model_form.test.tsx index 2dfbb1ce..0eddd38f 100644 --- a/public/components/register_model/__tests__/register_model_form.test.tsx +++ b/public/components/register_model/__tests__/register_model_form.test.tsx @@ -11,6 +11,18 @@ import { RegisterModelForm } from '../register_model'; import { APIProvider } from '../../../apis/api_provider'; import { routerPaths } from '../../../../common/router_paths'; import { setup } from './setup'; +import * as formHooks from '../register_model.hooks'; +import * as PluginContext from '../../../../../../src/plugins/opensearch_dashboards_react/public'; + +// Cannot spyOn(PluginContext, 'useOpenSearchDashboards') directly as it results in error: +// TypeError: Cannot redefine property: useOpenSearchDashboards +// So we have to mock the entire module first as a workaround +jest.mock('../../../../../../src/plugins/opensearch_dashboards_react/public', () => { + return { + __esModule: true, + ...jest.requireActual('../../../../../../src/plugins/opensearch_dashboards_react/public'), + }; +}); const MOCKED_DATA = { data: [ @@ -42,6 +54,28 @@ const MOCKED_DATA = { }; describe(' Form', () => { + const addDangerMock = jest.fn(); + const addSuccessMock = jest.fn(); + const onSubmitMock = jest.fn(); + + beforeEach(() => { + jest.spyOn(PluginContext, 'useOpenSearchDashboards').mockReturnValue({ + services: { + notifications: { + toasts: { + addDanger: addDangerMock, + addSuccess: addSuccessMock, + }, + }, + }, + }); + jest.spyOn(formHooks, 'useModelUpload').mockReturnValue(onSubmitMock); + }); + + afterEach(() => { + jest.clearAllMocks(); + }); + it('should init form when id param in url route', async () => { const request = jest.spyOn(APIProvider.getAPI('model'), 'search'); const mockResult = MOCKED_DATA; @@ -106,4 +140,19 @@ describe(' Form', () => { await user.type(nameInput, 'test model name'); expect(screen.queryByText(/1 form error/i)).toBeInTheDocument(); }); + + it('should call addSuccess to display a success toast', async () => { + const { user } = await setup(); + await user.click(screen.getByRole('button', { name: /register model/i })); + expect(addSuccessMock).toHaveBeenCalled(); + }); + + it('should call addDanger to display an error toast', async () => { + jest + .spyOn(formHooks, 'useModelUpload') + .mockReturnValue(jest.fn().mockRejectedValue(new Error('error'))); + const { user } = await setup(); + await user.click(screen.getByRole('button', { name: /register model/i })); + expect(addDangerMock).toHaveBeenCalled(); + }); }); diff --git a/public/components/register_model/register_model.tsx b/public/components/register_model/register_model.tsx index bd2919af..38f9c3c0 100644 --- a/public/components/register_model/register_model.tsx +++ b/public/components/register_model/register_model.tsx @@ -16,6 +16,7 @@ import { EuiBottomBar, EuiFlexGroup, EuiFlexItem, + EuiTextColor, } from '@elastic/eui'; import useObservable from 'react-use/lib/useObservable'; import { from } from 'rxjs'; @@ -32,6 +33,7 @@ import { upgradeModelVersion } from '../../utils'; import { useSearchParams } from '../../hooks/use_search_params'; import { isValidModelRegisterFormType } from './utils'; import { useOpenSearchDashboards } from '../../../../../src/plugins/opensearch_dashboards_react/public'; +import { mountReactNode } from '../../../../../src/core/public/utils'; const DEFAULT_VALUES = { name: '', @@ -48,7 +50,7 @@ export const RegisterModelForm = () => { const typeParams = useSearchParams().get('type'); const { - services: { chrome }, + services: { chrome, notifications }, } = useOpenSearchDashboards(); const isLocked = useObservable(chrome?.getIsNavDrawerLocked$() ?? from([false])); @@ -70,9 +72,47 @@ export const RegisterModelForm = () => { }); const submitModel = useModelUpload(); - const onSubmit = async (data: ModelFileFormData | ModelUrlFormData) => { - await submitModel(data); - }; + const onSubmit = useCallback( + async (data: ModelFileFormData | ModelUrlFormData) => { + try { + await submitModel(data); + if (latestVersionId) { + notifications?.toasts.addSuccess({ + title: mountReactNode( + + A model artifact for{' '} + {form.getValues('name')} is uploading + + ), + text: 'Once it uploads, a new version will be created.', + }); + } else { + notifications?.toasts.addSuccess({ + title: mountReactNode( + + {form.getValues('name')} was created + + ), + text: + 'The model artifact is uploading. Once it uploads, a new version will be created.', + }); + } + } catch (e) { + if (e instanceof Error) { + notifications?.toasts.addDanger({ + title: 'Model creation failed', + text: e.message, + }); + } else { + notifications?.toasts.addDanger({ + title: 'Model creation failed', + text: 'Unknown error', + }); + } + } + }, + [submitModel, notifications, form, latestVersionId] + ); useEffect(() => { if (!latestVersionId) return; From 0be9346839db2e478495f0ceccde8b88e02575d7 Mon Sep 17 00:00:00 2001 From: xyinshen Date: Tue, 21 Feb 2023 12:41:16 +0800 Subject: [PATCH 14/75] feat: add model_register_button to model_list (#80) * feat: add model_register_button to model_list Signed-off-by: xyinshen * fix: fix some code standards Signed-off-by: xyinshen * fix: fix code with the review comment Signed-off-by: xyinshen * fix: resolve modal scroll bar Signed-off-by: xyinshen * fix: update to only one EuiModalBody Signed-off-by: xyinshen * fix: update EuiCombobox to Euisearchable Signed-off-by: xyinshen * fix: update searchfiled with Euisearchable and update Euitext Signed-off-by: xyinshen * fix: update searchfiled with Euisearchable and update Euitext Signed-off-by: xyinshen * fix: fix conflict and update unit-test Signed-off-by: xyinshen * fix: resolve conflict and update unit test Signed-off-by: xyinshen * fix: restore an incorrect change Signed-off-by: xyinshen * fix: update unit test error Signed-off-by: xyinshen * fix: update to newest UI Signed-off-by: xyinshen * fix: fix test_utils lint error Signed-off-by: xyinshen * fix: add description and update code with review comment Signed-off-by: xyinshen * fix: update code with newest code standards and update register form header descriptions Signed-off-by: xyinshen * fix: update unit_test of register_model_type_modal Signed-off-by: xyinshen * fix: remove test/test_utils.tsx change Signed-off-by: xyinshen * fix: Let the modal dialog load the model list Signed-off-by: xyinshen * fix: update register_modal unit-test Signed-off-by: xyinshen * fix: delete customProps and use values in jsx Signed-off-by: xyinshen * refactor: remove common/registry_modal_option.ts Signed-off-by: xyinshen * fix: remove monitoring/model_status_filter.tsx change Signed-off-by: xyinshen --------- Signed-off-by: xyinshen Signed-off-by: Lin Wang --- public/components/model_list/index.tsx | 17 +- .../model_list/regsister_new_model_button.tsx | 22 ++ .../__tests__/index.test.tsx | 86 +++++++ .../register_model_type_modal/index.tsx | 236 ++++++++++++++++++ 4 files changed, 346 insertions(+), 15 deletions(-) create mode 100644 public/components/model_list/regsister_new_model_button.tsx create mode 100644 public/components/register_model_type_modal/__tests__/index.test.tsx create mode 100644 public/components/register_model_type_modal/index.tsx diff --git a/public/components/model_list/index.tsx b/public/components/model_list/index.tsx index c8d48dbe..37a4a775 100644 --- a/public/components/model_list/index.tsx +++ b/public/components/model_list/index.tsx @@ -2,24 +2,19 @@ * Copyright OpenSearch Contributors * SPDX-License-Identifier: Apache-2.0 */ - import React, { useState, useCallback, useMemo, useRef } from 'react'; import { EuiPageHeader, EuiSpacer, EuiPanel } from '@elastic/eui'; - import { CoreStart } from '../../../../../src/core/public'; import { APIProvider } from '../../apis/api_provider'; -import { routerPaths } from '../../../common/router_paths'; import { useFetcher } from '../../hooks/use_fetcher'; import { ModelDrawer } from '../model_drawer'; -import { EuiLinkButton } from '../common'; - import { ModelTable, ModelTableSort } from './model_table'; import { ModelListFilter, ModelListFilterFilterValue } from './model_list_filter'; +import { RegisterNewModelButton } from './regsister_new_model_button'; import { ModelConfirmDeleteModal, ModelConfirmDeleteModalInstance, } from './model_confirm_delete_modal'; - export const ModelList = ({ notifications }: { notifications: CoreStart['notifications'] }) => { const confirmModelDeleteRef = useRef(null); const [params, setParams] = useState<{ @@ -92,17 +87,9 @@ export const ModelList = ({ notifications }: { notifications: CoreStart['notific const handleFilterChange = useCallback((filterValue: ModelListFilterFilterValue) => { setParams((prevValue) => ({ ...prevValue, filterValue, currentPage: 1 })); }, []); - return ( - Models} - rightSideItems={[ - - Register new model - , - ]} - /> + Models} rightSideItems={[]} /> diff --git a/public/components/model_list/regsister_new_model_button.tsx b/public/components/model_list/regsister_new_model_button.tsx new file mode 100644 index 00000000..ad0fa9d2 --- /dev/null +++ b/public/components/model_list/regsister_new_model_button.tsx @@ -0,0 +1,22 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +import React, { useState, useCallback } from 'react'; +import { EuiButton } from '@elastic/eui'; +import { RegisterModelTypeModal } from '../register_model_type_modal'; +export function RegisterNewModelButton() { + const [isModalVisible, setIsModalVisible] = useState(false); + const showModal = useCallback(() => { + setIsModalVisible(true); + }, []); + const closeModal = useCallback(() => { + setIsModalVisible(false); + }, []); + return ( + <> + Register new model + {isModalVisible && } + + ); +} diff --git a/public/components/register_model_type_modal/__tests__/index.test.tsx b/public/components/register_model_type_modal/__tests__/index.test.tsx new file mode 100644 index 00000000..d9e3fdc1 --- /dev/null +++ b/public/components/register_model_type_modal/__tests__/index.test.tsx @@ -0,0 +1,86 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React from 'react'; +import userEvent from '@testing-library/user-event'; +import { RegisterModelTypeModal } from '../index'; +import { render, screen } from '../../../../test/test_utils'; + +const mockOffsetMethods = () => { + const originalOffsetHeight = Object.getOwnPropertyDescriptor( + HTMLElement.prototype, + 'offsetHeight' + ); + const originalOffsetWidth = Object.getOwnPropertyDescriptor(HTMLElement.prototype, 'offsetWidth'); + Object.defineProperty(HTMLElement.prototype, 'offsetHeight', { + configurable: true, + value: 600, + }); + Object.defineProperty(HTMLElement.prototype, 'offsetWidth', { + configurable: true, + value: 600, + }); + return () => { + Object.defineProperty( + HTMLElement.prototype, + 'offsetHeight', + originalOffsetHeight as PropertyDescriptor + ); + Object.defineProperty( + HTMLElement.prototype, + 'offsetWidth', + originalOffsetWidth as PropertyDescriptor + ); + }; +}; + +describe('', () => { + it('should render two checkablecard', () => { + render( {}} />); + expect(screen.getByLabelText('Opensearch model repository')).toBeInTheDocument(); + expect(screen.getByLabelText('Add your own model')).toBeInTheDocument(); + }); + + it('should render select with Opensearch model repository', () => { + render( {}} />); + expect(screen.getByLabelText('Opensearch model repository')).toBeInTheDocument(); + expect(screen.getByLabelText('OpenSearch model repository models')).toBeInTheDocument(); + }); + + it('should call onCloseModal after click "cancel"', async () => { + const onClickMock = jest.fn(); + render(); + await userEvent.click(screen.getByTestId('cancelRegister')); + expect(onClickMock).toHaveBeenCalled(); + }); + + it('should call opensearch model repository model list and link to url with selected option after click "Find model" and continue', async () => { + const mockReset = mockOffsetMethods(); + render( {}} />); + await userEvent.click(screen.getByLabelText('Opensearch model repository')); + expect(screen.getByTestId('findModel')).toBeInTheDocument(); + expect(screen.getByTestId('opensearchModelList')).toBeInTheDocument(); + expect(screen.getByText('tapas-tiny')).toBeInTheDocument(); + await userEvent.click(screen.getByText('tapas-tiny')); + await userEvent.click(screen.getByTestId('continueRegister')); + expect(document.URL).toContain( + 'model-registry/register-model/?type=import&name=tapas-tiny&version=tapas-tiny' + ); + mockReset(); + }); + + it('should render no model found when input a invalid text to search model', async () => { + render( {}} />); + await userEvent.click(screen.getByLabelText('Opensearch model repository')); + await userEvent.type(screen.getByLabelText('OpenSearch model repository models'), '1'); + expect(screen.getByText('No model found')).toBeInTheDocument(); + }); + + it('should link href after selecting "add your own model" and continue ', async () => { + render( {}} />); + await userEvent.click(screen.getByTestId('continueRegister')); + expect(document.URL).toEqual('http://localhost/model-registry/register-model/?type=upload'); + }); +}); diff --git a/public/components/register_model_type_modal/index.tsx b/public/components/register_model_type_modal/index.tsx new file mode 100644 index 00000000..f40c5580 --- /dev/null +++ b/public/components/register_model_type_modal/index.tsx @@ -0,0 +1,236 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +import { EuiSpacer } from '@elastic/eui'; +import React, { useState, useCallback, Fragment } from 'react'; +import { useHistory } from 'react-router-dom'; +import { + EuiButton, + EuiModal, + EuiModalBody, + EuiModalFooter, + EuiModalHeader, + EuiModalHeaderTitle, + EuiFlexGroup, + EuiFlexItem, + EuiCheckableCard, + EuiText, + EuiSelectable, + EuiTextColor, + EuiLink, + EuiSelectableOption, + EuiHighlight, +} from '@elastic/eui'; +import { htmlIdGenerator } from '@elastic/eui'; +import { generatePath } from 'react-router-dom'; +import { routerPaths } from '../../../common/router_paths'; +enum ModelSource { + USER_MODEL = 'UserModel', + PRE_TRAINED_MODEL = 'PreTrainedModel', +} +interface Props { + onCloseModal: () => void; +} +interface IItem { + label: string; + checked?: 'on' | undefined; + description: string; +} +const MODEL_LIST = [ + { + name: 'tapas-tiny', + description: + 'TAPAS is a BERT-like transformers model pretrained on a large corpus of English data from Wikipedia in a self-supervised fashion', + checked: undefined, + }, + { + name: 'electra-small-generator', + description: 'ELECTRA is a new method for self-supervised language representation learning', + checked: undefined, + }, + { + name: 'flan-T5-large-grammer-synthesis', + description: + 'A fine-tuned version of google/flan-t5-large for grammer correction on an expanded version of the JFLEG dataset', + checked: undefined, + }, + { + name: 'BEiT', + description: + 'The BEiT model is a version Transformer(ViT),which is a transformer encoder model(BERT-like)', + checked: undefined, + }, +]; +const renderModelOption = (option: IItem, searchValue: string) => { + return ( + <> + {option.label} +
+ + + {option.description} + + + + ); +}; +export function RegisterModelTypeModal({ onCloseModal }: Props) { + const [modelRepoSelection, setModelRepoSelection] = useState>>( + () => + MODEL_LIST.map((item) => ({ + checked: item.checked, + label: item.name, + description: item.description, + })) + ); + const history = useHistory(); + const [modelSource, setModelSource] = useState(ModelSource.PRE_TRAINED_MODEL); + const onChange = useCallback((modelSelection: Array>) => { + setModelRepoSelection(modelSelection); + }, []); + const handleContinue = useCallback( + (selectedOption) => { + selectedOption = onChange(modelRepoSelection); + switch (modelSource) { + case ModelSource.USER_MODEL: + selectedOption = modelRepoSelection.find((option) => option.checked === 'on'); + if (selectedOption?.label) { + history.push( + `${generatePath(routerPaths.registerModel, { id: undefined })}/?type=import&name=${ + selectedOption?.label + }&version=${selectedOption?.label}` + ); + } + break; + case ModelSource.PRE_TRAINED_MODEL: + history.push( + `${generatePath(routerPaths.registerModel, { id: undefined })}/?type=upload` + ); + break; + } + }, + [history, modelSource, modelRepoSelection, onChange] + ); + return ( +
+ onCloseModal()} maxWidth="1000px"> + + +

Register model

+
+
+ +
+ + Model source + + + + + Opensearch model repository + + + + Select from a curated list of pre-trained models for search use cases. + + +
+ } + aria-label="Opensearch model repository" + checked={modelSource === ModelSource.USER_MODEL} + onChange={() => setModelSource(ModelSource.USER_MODEL)} + /> + + + + Add your own model + + + + Upload your own model in Torchscript format, as a local file via URL. + + +
+ } + aria-label="Add your own model" + checked={modelSource === ModelSource.PRE_TRAINED_MODEL} + onChange={() => setModelSource(ModelSource.PRE_TRAINED_MODEL)} + /> + + +
+ + +
+ + Model + + +
+ + For more information on each model, see + + + + OpenSearch model repository documentation + + +
+ + + {(list, search) => ( + + {search} + {list} + + )} + +
+ + + + Cancel + + + Continue + + + +
+ ); +} From 531b6037b6c9a0bd995668cd10a94f1b2509d316 Mon Sep 17 00:00:00 2001 From: Lin Wang Date: Wed, 22 Feb 2023 15:38:18 +0800 Subject: [PATCH 15/75] feat: add upload callout (#113) Signed-off-by: Lin Wang --- .../__tests__/upload_callout.test.tsx | 17 ++++++++++++++ public/components/model_list/index.tsx | 4 ++++ .../components/model_list/upload_callout.tsx | 23 +++++++++++++++++++ 3 files changed, 44 insertions(+) create mode 100644 public/components/model_list/__tests__/upload_callout.test.tsx create mode 100644 public/components/model_list/upload_callout.tsx diff --git a/public/components/model_list/__tests__/upload_callout.test.tsx b/public/components/model_list/__tests__/upload_callout.test.tsx new file mode 100644 index 00000000..610eaaa0 --- /dev/null +++ b/public/components/model_list/__tests__/upload_callout.test.tsx @@ -0,0 +1,17 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React from 'react'; +import { UploadCallout } from '../upload_callout'; + +import { render, screen } from '../../../../test/test_utils'; + +describe('', () => { + it('should display consistent call title and content', () => { + render(); + expect(screen.getByText('1 upload in progress')); + expect(screen.getByText('image-classifier is uploading to the model registry.')); + }); +}); diff --git a/public/components/model_list/index.tsx b/public/components/model_list/index.tsx index 37a4a775..f8d0d17f 100644 --- a/public/components/model_list/index.tsx +++ b/public/components/model_list/index.tsx @@ -15,6 +15,8 @@ import { ModelConfirmDeleteModal, ModelConfirmDeleteModalInstance, } from './model_confirm_delete_modal'; +import { UploadCallout } from './upload_callout'; + export const ModelList = ({ notifications }: { notifications: CoreStart['notifications'] }) => { const confirmModelDeleteRef = useRef(null); const [params, setParams] = useState<{ @@ -93,6 +95,8 @@ export const ModelList = ({ notifications }: { notifications: CoreStart['notific + + { + return ( + + {models.join(', ')} is uploading to the model registry. + + ); +}; From a302d489f1be0cb835e33291b2845b2394725d5a Mon Sep 17 00:00:00 2001 From: xyinshen Date: Wed, 22 Feb 2023 15:39:23 +0800 Subject: [PATCH 16/75] fix: update-register-form-hearder-descriptions (#114) * fix: update-register-form-hearder-descriptions Signed-off-by: xyinshen * fix: restore unnecessary remove Signed-off-by: xyinshen * fix: remove wrap with div and update code structure Signed-off-by: xyinshen --------- Signed-off-by: xyinshen Signed-off-by: Lin Wang --- .../register_model/register_model.tsx | 30 +++++++++++++++++-- 1 file changed, 27 insertions(+), 3 deletions(-) diff --git a/public/components/register_model/register_model.tsx b/public/components/register_model/register_model.tsx index 38f9c3c0..1926e97a 100644 --- a/public/components/register_model/register_model.tsx +++ b/public/components/register_model/register_model.tsx @@ -17,6 +17,7 @@ import { EuiFlexGroup, EuiFlexItem, EuiTextColor, + EuiLink, } from '@elastic/eui'; import useObservable from 'react-use/lib/useObservable'; import { from } from 'rxjs'; @@ -151,11 +152,34 @@ export const RegisterModelForm = () => { component="form" > - + - Register your model to collaboratively manage its life cycle, and facilitate model - discovery across your organization. + {latestVersionId && ( + <> + Register a new version of Image-classifiar.The version number will be + automatically incremented. For more information on versioning, see{' '} + + Model Registry Documentation + + . + + )} + {formType === 'import' && !latestVersionId && ( + <> + Register a pre-trained model. For more information, see{' '} + + OpenSearch model repository documentation + + . + + )} + {formType === 'upload' && !latestVersionId && ( + <> + Register your model to collaboratively manage its life cycle, and facilitate model + discovery across your organization. + + )} From a493f781fb2d48e79471007a8abc9fba86a27726 Mon Sep 17 00:00:00 2001 From: Yulong Ruan Date: Wed, 22 Feb 2023 20:39:33 +0800 Subject: [PATCH 17/75] feat: upload file after register form submitted (#117) + display notification if file upload failed/succeed + add rxjs Signed-off-by: Yulong Ruan Signed-off-by: Lin Wang --- package.json | 3 +- .../model_file_uploader_manager.test.ts | 88 +++++++++++++++++++ .../__tests__/register_model.hooks.test.ts | 8 -- .../register_model_artifact.test.tsx | 9 ++ .../__tests__/register_model_form.test.tsx | 2 + public/components/register_model/constants.ts | 6 ++ .../model_file_upload_manager.ts | 86 ++++++++++++++++++ .../register_model/register_model.hooks.ts | 6 +- .../register_model/register_model.tsx | 44 +++++++++- yarn.lock | 12 +++ 10 files changed, 249 insertions(+), 15 deletions(-) create mode 100644 public/components/register_model/__tests__/model_file_uploader_manager.test.ts create mode 100644 public/components/register_model/constants.ts create mode 100644 public/components/register_model/model_file_upload_manager.ts diff --git a/package.json b/package.json index cf7a726c..50ef67f0 100644 --- a/package.json +++ b/package.json @@ -22,7 +22,8 @@ "dependencies": { "hash-wasm": "^4.9.0", "papaparse": "^5.3.2", - "react-hook-form": "^7.39.4" + "react-hook-form": "^7.39.4", + "rxjs": "^6.5.5" }, "devDependencies": { "@testing-library/user-event": "^14.4.3", diff --git a/public/components/register_model/__tests__/model_file_uploader_manager.test.ts b/public/components/register_model/__tests__/model_file_uploader_manager.test.ts new file mode 100644 index 00000000..b0f95186 --- /dev/null +++ b/public/components/register_model/__tests__/model_file_uploader_manager.test.ts @@ -0,0 +1,88 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import { waitFor } from '@testing-library/dom'; +import { Model } from '../../../../public/apis/model'; +import { ModelFileUploadManager } from '../model_file_upload_manager'; + +describe('ModelFileUploadManager', () => { + const uploadChunkMock = jest.fn(); + + beforeEach(() => { + jest.spyOn(Model.prototype, 'uploadChunk').mockImplementation(uploadChunkMock); + }); + + afterEach(() => { + jest.clearAllMocks(); + }); + + it('should upload file by chunk', async () => { + const uploader = new ModelFileUploadManager(); + const file = new File(['test model file'], 'model.zip', { type: 'application/zip' }); + Object.defineProperty(file, 'size', { value: 30 * 1000 * 1000 }); + + uploader.upload({ + modelId: 'test model id', + file, + chunkSize: 10 * 1000 * 1000, + }); + await waitFor(() => expect(uploadChunkMock).toHaveBeenCalledTimes(3)); + }); + + it('should call onUpdate', async () => { + const onUpdateMock = jest.fn(); + const uploader = new ModelFileUploadManager(); + const file = new File(['test model file'], 'model.zip', { type: 'application/zip' }); + Object.defineProperty(file, 'size', { value: 30 * 1000 * 1000 }); + + uploader.upload({ + modelId: 'test model id', + file, + chunkSize: 10 * 1000 * 1000, + onUpdate: onUpdateMock, + }); + await waitFor(() => { + expect(onUpdateMock).toHaveBeenNthCalledWith(1, { total: 3, current: 1 }); + expect(onUpdateMock).toHaveBeenNthCalledWith(2, { total: 3, current: 2 }); + expect(onUpdateMock).toHaveBeenNthCalledWith(3, { total: 3, current: 3 }); + }); + }); + + it('should call onComplete', async () => { + const functionCallOrder: string[] = []; + const onCompleteMock = jest.fn().mockImplementation(() => functionCallOrder.push('onComplete')); + const onUpdateMock = jest.fn().mockImplementation(() => functionCallOrder.push('onUpdate')); + + const uploader = new ModelFileUploadManager(); + const file = new File(['test model file'], 'model.zip', { type: 'application/zip' }); + Object.defineProperty(file, 'size', { value: 30 * 1000 * 1000 }); + + uploader.upload({ + modelId: 'test model id', + file, + chunkSize: 10 * 1000 * 1000, + onComplete: onCompleteMock, + onUpdate: onUpdateMock, + }); + await waitFor(() => expect(onCompleteMock).toHaveBeenCalled()); + expect(functionCallOrder).toEqual(['onUpdate', 'onUpdate', 'onUpdate', 'onComplete']); + }); + + it('should call onError', async () => { + jest.spyOn(Model.prototype, 'uploadChunk').mockRejectedValue(new Error()); + const onErrorMock = jest.fn(); + const uploader = new ModelFileUploadManager(); + const file = new File(['test model file'], 'model.zip', { type: 'application/zip' }); + Object.defineProperty(file, 'size', { value: 30 * 1000 * 1000 }); + + uploader.upload({ + modelId: 'test model id', + file, + chunkSize: 10 * 1000 * 1000, + onError: onErrorMock, + }); + await waitFor(() => expect(onErrorMock).toHaveBeenCalled()); + }); +}); diff --git a/public/components/register_model/__tests__/register_model.hooks.test.ts b/public/components/register_model/__tests__/register_model.hooks.test.ts index 9493eac5..de611618 100644 --- a/public/components/register_model/__tests__/register_model.hooks.test.ts +++ b/public/components/register_model/__tests__/register_model.hooks.test.ts @@ -176,13 +176,5 @@ describe('useModelUpload', () => { totalChunks: 3, }); }); - - it('should call model uploadChunk for 3 times', async () => { - const { result } = renderHook(() => useModelUpload()); - - await result.current(modelFileFormData); - - expect(jest.spyOn(Model.prototype, 'uploadChunk')).toHaveBeenCalledTimes(3); - }); }); }); diff --git a/public/components/register_model/__tests__/register_model_artifact.test.tsx b/public/components/register_model/__tests__/register_model_artifact.test.tsx index 752a01b3..2d079257 100644 --- a/public/components/register_model/__tests__/register_model_artifact.test.tsx +++ b/public/components/register_model/__tests__/register_model_artifact.test.tsx @@ -6,9 +6,11 @@ import { screen } from '../../../../test/test_utils'; import { setup } from './setup'; import * as formHooks from '../register_model.hooks'; +import { ModelFileUploadManager } from '../model_file_upload_manager'; describe(' Artifact', () => { const onSubmitMock = jest.fn(); + const uploadMock = jest.fn(); beforeEach(() => { jest @@ -18,6 +20,7 @@ describe(' Artifact', () => { .spyOn(formHooks, 'useModelTags') .mockReturnValue([false, { keys: ['Key1', 'Key2'], values: ['Value1', 'Value2'] }]); jest.spyOn(formHooks, 'useModelUpload').mockReturnValue(onSubmitMock); + jest.spyOn(ModelFileUploadManager.prototype, 'upload').mockImplementation(uploadMock); }); afterEach(() => { @@ -41,8 +44,14 @@ describe(' Artifact', () => { expect(onSubmitMock).not.toHaveBeenCalled(); await result.user.click(result.submitButton); + expect(onSubmitMock).toHaveBeenCalled(); + }); + it('should upload the model file', async () => { + const result = await setup(); + await result.user.click(result.submitButton); expect(onSubmitMock).toHaveBeenCalled(); + expect(uploadMock).toHaveBeenCalled(); }); it('should NOT submit the register model form if model file size exceed 80MB', async () => { diff --git a/public/components/register_model/__tests__/register_model_form.test.tsx b/public/components/register_model/__tests__/register_model_form.test.tsx index 0eddd38f..f99c79e5 100644 --- a/public/components/register_model/__tests__/register_model_form.test.tsx +++ b/public/components/register_model/__tests__/register_model_form.test.tsx @@ -11,6 +11,7 @@ import { RegisterModelForm } from '../register_model'; import { APIProvider } from '../../../apis/api_provider'; import { routerPaths } from '../../../../common/router_paths'; import { setup } from './setup'; +import { Model } from '../../../../public/apis/model'; import * as formHooks from '../register_model.hooks'; import * as PluginContext from '../../../../../../src/plugins/opensearch_dashboards_react/public'; @@ -70,6 +71,7 @@ describe(' Form', () => { }, }); jest.spyOn(formHooks, 'useModelUpload').mockReturnValue(onSubmitMock); + jest.spyOn(Model.prototype, 'uploadChunk').mockResolvedValue({}); }); afterEach(() => { diff --git a/public/components/register_model/constants.ts b/public/components/register_model/constants.ts new file mode 100644 index 00000000..9ec2d9d6 --- /dev/null +++ b/public/components/register_model/constants.ts @@ -0,0 +1,6 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +export const MAX_CHUNK_SIZE = 10 * 1000 * 1000; diff --git a/public/components/register_model/model_file_upload_manager.ts b/public/components/register_model/model_file_upload_manager.ts new file mode 100644 index 00000000..249c0ac9 --- /dev/null +++ b/public/components/register_model/model_file_upload_manager.ts @@ -0,0 +1,86 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import { BehaviorSubject, Observable, range } from 'rxjs'; +import { concatMap } from 'rxjs/operators'; + +import { APIProvider } from '../../apis/api_provider'; + +interface FileUploadStatus { + current: number; + total: number; +} + +interface UploadOptions { + file: File; + chunkSize: number; + modelId: string; + onUpdate?: (status: FileUploadStatus) => void; + onError?: () => void; + onComplete?: () => void; +} + +const MIN_CHUNK_SIZE = 10 * 1000 * 1000; + +export class ModelFileUploadManager { + uploads = new BehaviorSubject>>([]); + + constructor() {} + + upload(options: UploadOptions) { + const chunkSize = options.chunkSize < MIN_CHUNK_SIZE ? MIN_CHUNK_SIZE : options.chunkSize; + const totalChunks = Math.ceil(options.file.size / chunkSize); + + const observable = range(1, totalChunks).pipe( + concatMap(async (i) => { + const chunk = options.file.slice( + chunkSize * (i - 1), + Math.min(chunkSize * i, options.file.size) + ); + await APIProvider.getAPI('model').uploadChunk(options.modelId, `${i - 1}`, chunk); + return { total: totalChunks, current: i }; + }) + ); + + this.uploads.next(this.uploads.getValue().concat(observable)); + + observable.subscribe({ + next: (v) => { + if (options.onUpdate) { + options.onUpdate(v); + } + }, + error: () => { + this.uploads.next(this.uploads.getValue().filter((obs) => obs !== observable)); + + if (options.onError) { + options.onError(); + } + }, + complete: () => { + this.uploads.next(this.uploads.getValue().filter((obs) => obs !== observable)); + + if (options.onComplete) { + options.onComplete(); + } + }, + }); + } + + /** + * Get the running uploads + */ + getUploads$() { + return this.uploads.asObservable(); + } +} + +export const modelFileUploadManager = new ModelFileUploadManager(); + +window.onbeforeunload = () => { + if (modelFileUploadManager.uploads.getValue().length > 0) { + return 'File upload will be terminated if you leave the page, are you sure?'; + } +}; diff --git a/public/components/register_model/register_model.hooks.ts b/public/components/register_model/register_model.hooks.ts index 91c4608e..6781e318 100644 --- a/public/components/register_model/register_model.hooks.ts +++ b/public/components/register_model/register_model.hooks.ts @@ -5,6 +5,7 @@ import { useCallback, useEffect, useRef, useState } from 'react'; import { APIProvider } from '../../apis/api_provider'; +import { MAX_CHUNK_SIZE } from './constants'; import { getModelContentHashValue } from './get_model_content_hash_value'; import { ModelFileFormData, ModelUrlFormData } from './register_model.types'; @@ -104,7 +105,6 @@ export const useModelUpload = () => { }); } const { modelFile } = model; - const MAX_CHUNK_SIZE = 10 * 1000 * 1000; const totalChunks = Math.ceil(modelFile.size / MAX_CHUNK_SIZE); const modelContentHashValue = await getModelContentHashValue(modelFile); @@ -116,13 +116,13 @@ export const useModelUpload = () => { }) ).model_id; - for (let i = 0; i < totalChunks; i++) { + /* for (let i = 0; i < totalChunks; i++) { const chunk = modelFile.slice( MAX_CHUNK_SIZE * i, Math.min(MAX_CHUNK_SIZE * (i + 1), modelFile.size) ); await APIProvider.getAPI('model').uploadChunk(modelId, `${i}`, chunk); - } + }*/ return modelId; }, []); }; diff --git a/public/components/register_model/register_model.tsx b/public/components/register_model/register_model.tsx index 1926e97a..5d78efb3 100644 --- a/public/components/register_model/register_model.tsx +++ b/public/components/register_model/register_model.tsx @@ -5,7 +5,7 @@ import React, { useCallback, useEffect } from 'react'; import { FieldErrors, useForm, FormProvider } from 'react-hook-form'; -import { useParams } from 'react-router-dom'; +import { useHistory, useParams } from 'react-router-dom'; import { EuiPageHeader, EuiSpacer, @@ -35,6 +35,9 @@ import { useSearchParams } from '../../hooks/use_search_params'; import { isValidModelRegisterFormType } from './utils'; import { useOpenSearchDashboards } from '../../../../../src/plugins/opensearch_dashboards_react/public'; import { mountReactNode } from '../../../../../src/core/public/utils'; +import { modelFileUploadManager } from './model_file_upload_manager'; +import { MAX_CHUNK_SIZE } from './constants'; +import { routerPaths } from '../../../common/router_paths'; const DEFAULT_VALUES = { name: '', @@ -47,6 +50,7 @@ const DEFAULT_VALUES = { const FORM_ID = 'mlModelUploadForm'; export const RegisterModelForm = () => { + const history = useHistory(); const { id: latestVersionId } = useParams<{ id: string | undefined }>(); const typeParams = useSearchParams().get('type'); @@ -76,7 +80,41 @@ export const RegisterModelForm = () => { const onSubmit = useCallback( async (data: ModelFileFormData | ModelUrlFormData) => { try { - await submitModel(data); + const modelId = await submitModel(data); + // Navigate to model list if form submit successfully + history.push(routerPaths.modelList); + + // Upload model artifact + if ('modelFile' in data) { + modelFileUploadManager.upload({ + file: data.modelFile, + modelId, + chunkSize: MAX_CHUNK_SIZE, + onComplete: () => { + notifications?.toasts.addSuccess({ + title: mountReactNode( + + Artifact for{' '} + {form.getValues('name')} uploaded + + ), + text: `The artifact for ${form.getValues('name')} uploaded successfully`, + }); + }, + onError: () => { + notifications?.toasts.addDanger({ + title: mountReactNode( + + {form.getValues('name')} artifact + upload failed. + + ), + text: 'The new version was not created.', + }); + }, + }); + } + if (latestVersionId) { notifications?.toasts.addSuccess({ title: mountReactNode( @@ -112,7 +150,7 @@ export const RegisterModelForm = () => { } } }, - [submitModel, notifications, form, latestVersionId] + [submitModel, notifications, form, latestVersionId, history] ); useEffect(() => { diff --git a/yarn.lock b/yarn.lock index fd5e7963..27e1f2ed 100644 --- a/yarn.lock +++ b/yarn.lock @@ -546,6 +546,13 @@ rfdc@^1.3.0: resolved "https://registry.yarnpkg.com/rfdc/-/rfdc-1.3.0.tgz#d0b7c441ab2720d05dc4cf26e01c89631d9da08b" integrity sha512-V2hovdzFbOi77/WajaSMXk2OLm+xNIeQdMMuB7icj7bk6zi2F8GGAxigcnDFpJHbNyNcgyJDiP+8nOrY5cZGrA== +rxjs@^6.5.5: + version "6.6.7" + resolved "https://registry.yarnpkg.com/rxjs/-/rxjs-6.6.7.tgz#90ac018acabf491bf65044235d5863c4dab804c9" + integrity sha512-hTdwr+7yYNIT5n4AMYp85KA6yw2Va0FLa3Rguvbpa4W3I5xynaBZo41cM3XM+4Q6fRMj3sBYIR1VAmZMXYJvRQ== + dependencies: + tslib "^1.9.0" + rxjs@^7.5.1: version "7.5.7" resolved "https://registry.yarnpkg.com/rxjs/-/rxjs-7.5.7.tgz#2ec0d57fdc89ece220d2e702730ae8f1e49def39" @@ -654,6 +661,11 @@ to-regex-range@^5.0.1: dependencies: is-number "^7.0.0" +tslib@^1.9.0: + version "1.14.1" + resolved "https://registry.yarnpkg.com/tslib/-/tslib-1.14.1.tgz#cf2d38bdc34a134bcaf1091c41f6619e2f672d00" + integrity sha512-Xni35NKzjgMrwevysHTCArtLDpPvye8zV/0E4EyYn43P7/7qvQwPh9BGkHewbMulVntbigmcT7rdX3BNo9wRJg== + tslib@^2.1.0: version "2.4.0" resolved "https://registry.yarnpkg.com/tslib/-/tslib-2.4.0.tgz#7cecaa7f073ce680a05847aa77be941098f36dc3" From e0bca757cf84eeb64d8f5b394ff7c5cf4329d0d4 Mon Sep 17 00:00:00 2001 From: Yulong Ruan Date: Fri, 24 Feb 2023 11:31:23 +0800 Subject: [PATCH 18/75] feat: update artifact file validation rules (#118) + increase model file size limit to 4GB + add model file extension validation to only allow ZIP file --------- Signed-off-by: Yulong Ruan Signed-off-by: Lin Wang --- common/constant.ts | 5 ++- .../register_model_artifact.test.tsx | 40 +++++++++++++++++-- public/components/register_model/artifact.tsx | 18 +++++++-- .../register_model/artifact_file.tsx | 12 ++++-- 4 files changed, 64 insertions(+), 11 deletions(-) diff --git a/common/constant.ts b/common/constant.ts index 8a66b751..742b1475 100644 --- a/common/constant.ts +++ b/common/constant.ts @@ -3,4 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -export const MAX_MODEL_CHUNK_SIZE = 10 * 1000 * 1000; +export const ONE_MB = 1000 * 1000; +export const ONE_GB = 1000 * ONE_MB; + +export const MAX_MODEL_CHUNK_SIZE = 10 * ONE_MB; diff --git a/public/components/register_model/__tests__/register_model_artifact.test.tsx b/public/components/register_model/__tests__/register_model_artifact.test.tsx index 2d079257..33859ef9 100644 --- a/public/components/register_model/__tests__/register_model_artifact.test.tsx +++ b/public/components/register_model/__tests__/register_model_artifact.test.tsx @@ -7,6 +7,7 @@ import { screen } from '../../../../test/test_utils'; import { setup } from './setup'; import * as formHooks from '../register_model.hooks'; import { ModelFileUploadManager } from '../model_file_upload_manager'; +import { ONE_GB } from '../../../../common/constant'; describe(' Artifact', () => { const onSubmitMock = jest.fn(); @@ -54,20 +55,53 @@ describe(' Artifact', () => { expect(uploadMock).toHaveBeenCalled(); }); - it('should NOT submit the register model form if model file size exceed 80MB', async () => { + it('should submit the register model form if model file size is 4GB', async () => { const result = await setup(); // Empty model file selection by clicking the `Remove` button on EuiFilePicker await result.user.click(screen.getByLabelText(/clear selected files/i)); + + const modelFileInput = screen.getByLabelText(/file/i); + // User select a file with maximum accepted size + const validFile = new File(['test model file'], 'model.zip', { type: 'application/zip' }); + Object.defineProperty(validFile, 'size', { value: 4 * ONE_GB }); + await result.user.upload(modelFileInput, validFile); + + expect(screen.getByLabelText(/file/i, { selector: 'input[type="file"]' })).toBeValid(); await result.user.click(result.submitButton); + expect(onSubmitMock).toHaveBeenCalled(); + }); + + it('should NOT submit the register model form if model file size exceed 4GB', async () => { + const result = await setup(); + + // Empty model file selection by clicking the `Remove` button on EuiFilePicker + await result.user.click(screen.getByLabelText(/clear selected files/i)); const modelFileInput = screen.getByLabelText(/file/i); - // File size can not exceed 80MB + // File size can not exceed 4GB const invalidFile = new File(['test model file'], 'model.zip', { type: 'application/zip' }); - Object.defineProperty(invalidFile, 'size', { value: 81 * 1000 * 1000 }); + Object.defineProperty(invalidFile, 'size', { value: 4 * ONE_GB + 1 }); + await result.user.upload(modelFileInput, invalidFile); + + expect(screen.getByLabelText(/file/i, { selector: 'input[type="file"]' })).toBeInvalid(); + await result.user.click(result.submitButton); + expect(onSubmitMock).not.toHaveBeenCalled(); + }); + + it('should NOT submit the register model form if model file is not ZIP', async () => { + const result = await setup(); + + // Empty model file selection by clicking the `Remove` button on EuiFilePicker + await result.user.click(screen.getByLabelText(/clear selected files/i)); + + const modelFileInput = screen.getByLabelText(/file/i); + // Only ZIP(.zip) file is allowed + const invalidFile = new File(['test model file'], 'model.json', { type: 'application/json' }); await result.user.upload(modelFileInput, invalidFile); expect(screen.getByLabelText(/file/i, { selector: 'input[type="file"]' })).toBeInvalid(); + await result.user.click(result.submitButton); expect(onSubmitMock).not.toHaveBeenCalled(); }); diff --git a/public/components/register_model/artifact.tsx b/public/components/register_model/artifact.tsx index 737b6eaf..bb7cb4fa 100644 --- a/public/components/register_model/artifact.tsx +++ b/public/components/register_model/artifact.tsx @@ -6,8 +6,9 @@ import React, { useState } from 'react'; import { EuiTitle, htmlIdGenerator, EuiSpacer, EuiText, EuiRadio, EuiLink } from '@elastic/eui'; -import { ModelFileUploader } from './artifact_file'; +import { MAX_MODEL_FILE_SIZE, ModelFileUploader } from './artifact_file'; import { ArtifactUrl } from './artifact_url'; +import { ONE_GB } from '../../../common/constant'; export const ArtifactPanel = () => { const [selectedSource, setSelectedSource] = useState<'source_from_computer' | 'source_from_url'>( @@ -21,8 +22,8 @@ export const ArtifactPanel = () => { - Provide the model artifact for upload. If uploading from local file, keep your browser - open until the upload is complete.{' '} + The zipped artifact must include a model file and a tokenizer file. If uploading with a + local file, keep this browser open util the upload completes.{' '} Learn more @@ -48,6 +49,17 @@ export const ArtifactPanel = () => { {selectedSource === 'source_from_computer' && } {selectedSource === 'source_from_url' && } + + + Accepted file format: ZIP (.zip). Maximum size, {MAX_MODEL_FILE_SIZE / ONE_GB}GB. + + + The ZIP mush include the following contents: +
    +
  • Model File, accepted formats: Torchscript(.pt), ONNX(.onnx)
  • +
  • Tokenizer file, accepted format: JSON(.json)
  • +
+
); }; diff --git a/public/components/register_model/artifact_file.tsx b/public/components/register_model/artifact_file.tsx index 7dca17de..3d4c9e9d 100644 --- a/public/components/register_model/artifact_file.tsx +++ b/public/components/register_model/artifact_file.tsx @@ -8,14 +8,18 @@ import { EuiFormRow, EuiFilePicker } from '@elastic/eui'; import { useController, useFormContext } from 'react-hook-form'; import { ModelFileFormData, ModelUrlFormData } from './register_model.types'; +import { ONE_GB } from '../../../common/constant'; -const ONE_MB = 1000 * 1000; -const MAX_FILE_SIZE = 80 * ONE_MB; +// 4GB +export const MAX_MODEL_FILE_SIZE = 4 * ONE_GB; function validateFile(file: File) { - if (file.size > MAX_FILE_SIZE) { + if (file.size > MAX_MODEL_FILE_SIZE) { return 'Maximum file size exceeded. Add a smaller file.'; } + if (!file.name.endsWith('.zip')) { + return 'Invalid file format. Add a ZIP(.zip) file.'; + } return true; } @@ -34,12 +38,12 @@ export const ModelFileUploader = () => { return ( { From 9509ceb31b1fa97c70cd507824dd551f857ff716 Mon Sep 17 00:00:00 2001 From: Yulong Ruan Date: Mon, 27 Feb 2023 14:36:17 +0800 Subject: [PATCH 19/75] feat: display notification when upload model by URL (#126) + polling model upload task status and show notifications if task failed/successful + removed useModelUpload hook and refactor it with rxjs --------- Signed-off-by: Yulong Ruan Signed-off-by: Lin Wang --- .../__tests__/model_task_manager.test.ts | 94 +++++++++ .../__tests__/register_model.hooks.test.ts | 180 ------------------ .../register_model_artifact.test.tsx | 43 ++++- .../__tests__/register_model_details.test.tsx | 3 +- .../__tests__/register_model_form.test.tsx | 8 +- .../__tests__/register_model_metrics.test.tsx | 3 +- .../__tests__/register_model_tags.test.tsx | 3 +- .../register_model/model_task_manager.ts | 87 +++++++++ .../register_model/register_model.hooks.ts | 81 +------- .../register_model/register_model.tsx | 64 ++++--- .../register_model/register_model_api.ts | 51 +++++ 11 files changed, 310 insertions(+), 307 deletions(-) create mode 100644 public/components/register_model/__tests__/model_task_manager.test.ts delete mode 100644 public/components/register_model/__tests__/register_model.hooks.test.ts create mode 100644 public/components/register_model/model_task_manager.ts create mode 100644 public/components/register_model/register_model_api.ts diff --git a/public/components/register_model/__tests__/model_task_manager.test.ts b/public/components/register_model/__tests__/model_task_manager.test.ts new file mode 100644 index 00000000..7e9b1b4f --- /dev/null +++ b/public/components/register_model/__tests__/model_task_manager.test.ts @@ -0,0 +1,94 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import { waitFor } from '../../../../test/test_utils'; +import { Task } from '../../../../public/apis/task'; + +import { ModelTaskManager } from '../model_task_manager'; + +describe('ModelTaskManager', () => { + const getOneMock = jest.fn(); + + beforeEach(() => { + jest.spyOn(Task.prototype, 'getOne').mockImplementation(getOneMock); + }); + + afterEach(() => { + jest.clearAllMocks(); + }); + + it('should call onComplete if model task complete successfully', async () => { + const onCompleteMock = jest.fn(); + const onErrorMock = jest.fn(); + const onUpdateMock = jest.fn(); + const res = { model_id: 'model_id' }; + getOneMock.mockResolvedValue(res); + + const taskManager = new ModelTaskManager(); + taskManager.query({ + taskId: 'task_id', + onComplete: onCompleteMock, + onError: onErrorMock, + onUpdate: onUpdateMock, + }); + + await waitFor(() => { + expect(onCompleteMock).toHaveBeenCalledWith(res); + expect(onUpdateMock).toHaveBeenCalledWith(res); + expect(onErrorMock).not.toHaveBeenCalled(); + }); + }); + + it('should call onError if model task complete with error', async () => { + const onCompleteMock = jest.fn(); + const onErrorMock = jest.fn(); + const onUpdateMock = jest.fn(); + const res = { error: 'error msg' }; + getOneMock.mockResolvedValue(res); + + const taskManager = new ModelTaskManager(); + taskManager.query({ + taskId: 'task_id', + onComplete: onCompleteMock, + onError: onErrorMock, + onUpdate: onUpdateMock, + }); + + await waitFor(() => { + expect(onCompleteMock).not.toHaveBeenCalled(); + expect(onUpdateMock).not.toHaveBeenCalled(); + expect(onErrorMock).toHaveBeenCalled(); + }); + }); + + it('should poll get task API util model is created', async () => { + const onCompleteMock = jest.fn(); + const onErrorMock = jest.fn(); + const onUpdateMock = jest.fn(); + const res = { model_id: 'model_id' }; + // 1st call -> {}, model is not created + // 2nd call -> {}, model is not created + // 3rd call -> { model_id: 'model_id' }, model is created + getOneMock.mockResolvedValue(res).mockResolvedValueOnce({}).mockResolvedValueOnce({}); + + const taskManager = new ModelTaskManager(); + taskManager.query({ + taskId: 'task_id', + onComplete: onCompleteMock, + onError: onErrorMock, + onUpdate: onUpdateMock, + }); + + await waitFor( + () => { + expect(onCompleteMock).toHaveBeenCalledWith(res); + expect(onUpdateMock).toHaveBeenCalledWith(res); + expect(onErrorMock).not.toHaveBeenCalled(); + }, + { timeout: 8000 } + ); + expect(getOneMock).toHaveBeenCalledTimes(3); + }); +}); diff --git a/public/components/register_model/__tests__/register_model.hooks.test.ts b/public/components/register_model/__tests__/register_model.hooks.test.ts deleted file mode 100644 index de611618..00000000 --- a/public/components/register_model/__tests__/register_model.hooks.test.ts +++ /dev/null @@ -1,180 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -import { renderHook } from '@testing-library/react-hooks'; -import { Task, TaskGetOneResponse } from '../../../../public/apis/task'; -import { Model } from '../../../../public/apis/model'; -import * as getModelContentHashValueExports from '../get_model_content_hash_value'; - -import { useModelUpload } from '../register_model.hooks'; - -const modelBaseData = { - name: 'foo', - version: '1', - description: 'foo bar', - annotations: '', - configuration: `{ - "foo":"bar" - }`, -}; - -const modelUrlFormData = { - ...modelBaseData, - modelURL: 'http://localhost/', -}; - -const modelFileFormData = { - ...modelBaseData, - modelFile: new File(new Array(10000).fill(1), 'test-model'), -}; -// Make file size to 30MB, so we can split 3 chunks to upload. -Object.defineProperty(modelFileFormData.modelFile, 'size', { - get() { - return 3 * 10000000; - }, -}); - -describe('useModelUpload', () => { - describe('upload by url', () => { - beforeEach(() => { - jest.spyOn(Model.prototype, 'upload').mockResolvedValue({ task_id: 'task-id-1' }); - jest - .spyOn(Task.prototype, 'getOne') - .mockResolvedValue({ model_id: 'new-model-1' } as TaskGetOneResponse); - }); - - afterEach(() => { - jest.spyOn(Model.prototype, 'upload').mockClear(); - jest.spyOn(Task.prototype, 'getOne').mockClear(); - }); - - it('should call model upload with consistent params', () => { - const { result } = renderHook(() => useModelUpload()); - const modelUploadMock = jest.spyOn(Model.prototype, 'upload'); - expect(modelUploadMock).not.toHaveBeenCalled(); - - result.current(modelUrlFormData); - - expect(modelUploadMock).toHaveBeenCalledWith({ - name: 'foo', - version: '1', - description: 'foo bar', - modelFormat: 'TORCH_SCRIPT', - modelConfig: { - foo: 'bar', - }, - url: 'http://localhost/', - }); - }); - - it('should call get task cycling and resolved with model id when upload by url', async () => { - jest.useFakeTimers(); - const { result } = renderHook(() => useModelUpload()); - - const taskGetOneMock = jest - .spyOn(Task.prototype, 'getOne') - .mockResolvedValueOnce({} as TaskGetOneResponse); - expect(taskGetOneMock).not.toHaveBeenCalled(); - - const uploadPromise = result.current(modelUrlFormData); - - await jest.spyOn(Model.prototype, 'upload').mock.results[0].value; - expect(taskGetOneMock).toHaveBeenCalledWith('task-id-1'); - - await taskGetOneMock.mock.results[0].value; - taskGetOneMock.mockResolvedValueOnce({ model_id: 'new-model-1' } as TaskGetOneResponse); - jest.advanceTimersByTime(1000); - - expect(taskGetOneMock).toHaveBeenCalledTimes(2); - expect(await uploadPromise).toBe('new-model-1'); - - jest.useRealTimers(); - }); - - it('should NOT call get task if component unmount', () => { - const { result, unmount } = renderHook(() => useModelUpload()); - let uploadAPIResolveFn: Function; - const uploadAPIPromise = new Promise<{ task_id: string }>((resolve, reject) => { - uploadAPIResolveFn = () => { - resolve({ task_id: 'task-id-1' }); - }; - }); - jest.spyOn(Model.prototype, 'upload').mockReturnValue(uploadAPIPromise); - - const uploadPromise = result.current(modelUrlFormData); - unmount(); - uploadAPIResolveFn!(); - - expect(jest.spyOn(Task.prototype, 'getOne')).not.toHaveBeenCalled(); - - expect(uploadPromise).rejects.toMatch('component unmounted'); - }); - - it('should NOT cycling call get task after component unmount', async () => { - jest.useFakeTimers(); - const { result, unmount } = renderHook(() => useModelUpload()); - - const taskGetOneMock = jest - .spyOn(Task.prototype, 'getOne') - .mockResolvedValue({} as TaskGetOneResponse); - expect(taskGetOneMock).not.toHaveBeenCalled(); - - const uploadPromise = result.current(modelUrlFormData); - - await jest.spyOn(Model.prototype, 'upload').mock.results[0].value; - - await taskGetOneMock.mock.results[0].value; - expect(taskGetOneMock).toHaveBeenCalledTimes(1); - unmount(); - - jest.advanceTimersByTime(1000); - - expect(taskGetOneMock).toHaveBeenCalledTimes(1); - expect(uploadPromise).rejects.toMatch('component unmounted'); - - jest.useRealTimers(); - }); - }); - - describe('upload by file', () => { - beforeEach(() => { - jest.spyOn(Model.prototype, 'upload').mockResolvedValue({ model_id: 'model-id-1' }); - jest.spyOn(Model.prototype, 'uploadChunk').mockResolvedValue({}); - jest - .spyOn(getModelContentHashValueExports, 'getModelContentHashValue') - .mockResolvedValue('file-hash'); - }); - - afterEach(() => { - jest.spyOn(Model.prototype, 'upload').mockClear(); - jest.spyOn(Model.prototype, 'uploadChunk').mockClear(); - jest.spyOn(getModelContentHashValueExports, 'getModelContentHashValue').mockClear(); - }); - - it('should call model upload with consistent params', async () => { - const { result } = renderHook(() => useModelUpload()); - - expect(jest.spyOn(Model.prototype, 'upload')).not.toHaveBeenCalled(); - - result.current(modelFileFormData); - - await jest.spyOn(getModelContentHashValueExports, 'getModelContentHashValue').mock.results[0] - .value; - await jest.spyOn(Model.prototype, 'upload').mock.results[0].value; - - expect(jest.spyOn(Model.prototype, 'upload')).toHaveBeenCalledWith({ - name: 'foo', - version: '1', - description: 'foo bar', - modelFormat: 'TORCH_SCRIPT', - modelConfig: { - foo: 'bar', - }, - modelContentHashValue: 'file-hash', - totalChunks: 3, - }); - }); - }); -}); diff --git a/public/components/register_model/__tests__/register_model_artifact.test.tsx b/public/components/register_model/__tests__/register_model_artifact.test.tsx index 33859ef9..422136d7 100644 --- a/public/components/register_model/__tests__/register_model_artifact.test.tsx +++ b/public/components/register_model/__tests__/register_model_artifact.test.tsx @@ -7,10 +7,12 @@ import { screen } from '../../../../test/test_utils'; import { setup } from './setup'; import * as formHooks from '../register_model.hooks'; import { ModelFileUploadManager } from '../model_file_upload_manager'; +import * as formAPI from '../register_model_api'; import { ONE_GB } from '../../../../common/constant'; describe(' Artifact', () => { - const onSubmitMock = jest.fn(); + const onSubmitWithFileMock = jest.fn(); + const onSubmitWithURLMock = jest.fn(); const uploadMock = jest.fn(); beforeEach(() => { @@ -20,7 +22,8 @@ describe(' Artifact', () => { jest .spyOn(formHooks, 'useModelTags') .mockReturnValue([false, { keys: ['Key1', 'Key2'], values: ['Value1', 'Value2'] }]); - jest.spyOn(formHooks, 'useModelUpload').mockReturnValue(onSubmitMock); + jest.spyOn(formAPI, 'submitModelWithFile').mockImplementation(onSubmitWithFileMock); + jest.spyOn(formAPI, 'submitModelWithURL').mockImplementation(onSubmitWithURLMock); jest.spyOn(ModelFileUploadManager.prototype, 'upload').mockImplementation(uploadMock); }); @@ -42,19 +45,39 @@ describe(' Artifact', () => { it('should submit the register model form', async () => { const result = await setup(); - expect(onSubmitMock).not.toHaveBeenCalled(); + expect(onSubmitWithFileMock).not.toHaveBeenCalled(); await result.user.click(result.submitButton); - expect(onSubmitMock).toHaveBeenCalled(); + expect(onSubmitWithFileMock).toHaveBeenCalled(); }); it('should upload the model file', async () => { const result = await setup(); await result.user.click(result.submitButton); - expect(onSubmitMock).toHaveBeenCalled(); + expect(onSubmitWithFileMock).toHaveBeenCalled(); expect(uploadMock).toHaveBeenCalled(); }); + it('should upload with model url', async () => { + const result = await setup(); + + // select option: From URL + await result.user.click(screen.getByLabelText(/from url/i)); + + const urlInput = screen.getByLabelText(/url/i, { + selector: 'input[type="text"]', + }); + + await result.user.clear(urlInput); + await result.user.type( + urlInput, + 'https://artifacts.opensearch.org/models/ml-models/huggingface/sentence-transformers/all-MiniLM-L6-v2/1.0.1/torch_script/sentence-transformers_all-MiniLM-L6-v2-1.0.1-torch_script.zip' + ); + await result.user.click(result.submitButton); + + expect(onSubmitWithURLMock).toHaveBeenCalled(); + }); + it('should submit the register model form if model file size is 4GB', async () => { const result = await setup(); @@ -69,7 +92,7 @@ describe(' Artifact', () => { expect(screen.getByLabelText(/file/i, { selector: 'input[type="file"]' })).toBeValid(); await result.user.click(result.submitButton); - expect(onSubmitMock).toHaveBeenCalled(); + expect(onSubmitWithFileMock).toHaveBeenCalled(); }); it('should NOT submit the register model form if model file size exceed 4GB', async () => { @@ -86,7 +109,7 @@ describe(' Artifact', () => { expect(screen.getByLabelText(/file/i, { selector: 'input[type="file"]' })).toBeInvalid(); await result.user.click(result.submitButton); - expect(onSubmitMock).not.toHaveBeenCalled(); + expect(onSubmitWithFileMock).not.toHaveBeenCalled(); }); it('should NOT submit the register model form if model file is not ZIP', async () => { @@ -102,7 +125,7 @@ describe(' Artifact', () => { expect(screen.getByLabelText(/file/i, { selector: 'input[type="file"]' })).toBeInvalid(); await result.user.click(result.submitButton); - expect(onSubmitMock).not.toHaveBeenCalled(); + expect(onSubmitWithFileMock).not.toHaveBeenCalled(); }); it('should NOT submit the register model form if model file is empty', async () => { @@ -113,7 +136,7 @@ describe(' Artifact', () => { await result.user.click(result.submitButton); expect(screen.getByLabelText(/file/i, { selector: 'input[type="file"]' })).toBeInvalid(); - expect(onSubmitMock).not.toHaveBeenCalled(); + expect(onSubmitWithFileMock).not.toHaveBeenCalled(); }); it('should NOT submit the register model form if model url is empty', async () => { @@ -131,6 +154,6 @@ describe(' Artifact', () => { await result.user.click(result.submitButton); expect(urlInput).toBeInvalid(); - expect(onSubmitMock).not.toHaveBeenCalled(); + expect(onSubmitWithURLMock).not.toHaveBeenCalled(); }); }); diff --git a/public/components/register_model/__tests__/register_model_details.test.tsx b/public/components/register_model/__tests__/register_model_details.test.tsx index 01343dcf..1eba64bc 100644 --- a/public/components/register_model/__tests__/register_model_details.test.tsx +++ b/public/components/register_model/__tests__/register_model_details.test.tsx @@ -5,6 +5,7 @@ import { setup } from './setup'; import * as formHooks from '../register_model.hooks'; +import * as formAPI from '../register_model_api'; describe(' Details', () => { const onSubmitMock = jest.fn(); @@ -16,7 +17,7 @@ describe(' Details', () => { jest .spyOn(formHooks, 'useModelTags') .mockReturnValue([false, { keys: ['Key1', 'Key2'], values: ['Value1', 'Value2'] }]); - jest.spyOn(formHooks, 'useModelUpload').mockReturnValue(onSubmitMock); + jest.spyOn(formAPI, 'submitModelWithFile').mockImplementation(onSubmitMock); }); afterEach(() => { diff --git a/public/components/register_model/__tests__/register_model_form.test.tsx b/public/components/register_model/__tests__/register_model_form.test.tsx index f99c79e5..415ec09d 100644 --- a/public/components/register_model/__tests__/register_model_form.test.tsx +++ b/public/components/register_model/__tests__/register_model_form.test.tsx @@ -12,8 +12,8 @@ import { APIProvider } from '../../../apis/api_provider'; import { routerPaths } from '../../../../common/router_paths'; import { setup } from './setup'; import { Model } from '../../../../public/apis/model'; -import * as formHooks from '../register_model.hooks'; import * as PluginContext from '../../../../../../src/plugins/opensearch_dashboards_react/public'; +import * as formAPI from '../register_model_api'; // Cannot spyOn(PluginContext, 'useOpenSearchDashboards') directly as it results in error: // TypeError: Cannot redefine property: useOpenSearchDashboards @@ -70,7 +70,7 @@ describe(' Form', () => { }, }, }); - jest.spyOn(formHooks, 'useModelUpload').mockReturnValue(onSubmitMock); + jest.spyOn(formAPI, 'submitModelWithFile').mockImplementation(onSubmitMock); jest.spyOn(Model.prototype, 'uploadChunk').mockResolvedValue({}); }); @@ -150,9 +150,7 @@ describe(' Form', () => { }); it('should call addDanger to display an error toast', async () => { - jest - .spyOn(formHooks, 'useModelUpload') - .mockReturnValue(jest.fn().mockRejectedValue(new Error('error'))); + jest.spyOn(formAPI, 'submitModelWithFile').mockRejectedValue(new Error('error')); const { user } = await setup(); await user.click(screen.getByRole('button', { name: /register model/i })); expect(addDangerMock).toHaveBeenCalled(); diff --git a/public/components/register_model/__tests__/register_model_metrics.test.tsx b/public/components/register_model/__tests__/register_model_metrics.test.tsx index 26759411..0e9c6653 100644 --- a/public/components/register_model/__tests__/register_model_metrics.test.tsx +++ b/public/components/register_model/__tests__/register_model_metrics.test.tsx @@ -6,6 +6,7 @@ import { screen } from '../../../../test/test_utils'; import { setup } from './setup'; import * as formHooks from '../register_model.hooks'; +import * as formAPI from '../register_model_api'; describe(' Evaluation Metrics', () => { const onSubmitMock = jest.fn(); @@ -17,7 +18,7 @@ describe(' Evaluation Metrics', () => { jest .spyOn(formHooks, 'useModelTags') .mockReturnValue([false, { keys: ['Key1', 'Key2'], values: ['Value1', 'Value2'] }]); - jest.spyOn(formHooks, 'useModelUpload').mockReturnValue(onSubmitMock); + jest.spyOn(formAPI, 'submitModelWithFile').mockImplementation(onSubmitMock); }); afterEach(() => { diff --git a/public/components/register_model/__tests__/register_model_tags.test.tsx b/public/components/register_model/__tests__/register_model_tags.test.tsx index 5d65aa77..e0f14355 100644 --- a/public/components/register_model/__tests__/register_model_tags.test.tsx +++ b/public/components/register_model/__tests__/register_model_tags.test.tsx @@ -6,6 +6,7 @@ import { screen, waitFor, within } from '../../../../test/test_utils'; import { setup } from './setup'; import * as formHooks from '../register_model.hooks'; +import * as formAPI from '../register_model_api'; describe(' Tags', () => { const onSubmitMock = jest.fn(); @@ -17,7 +18,7 @@ describe(' Tags', () => { jest .spyOn(formHooks, 'useModelTags') .mockReturnValue([false, { keys: ['Key1', 'Key2'], values: ['Value1', 'Value2'] }]); - jest.spyOn(formHooks, 'useModelUpload').mockReturnValue(onSubmitMock); + jest.spyOn(formAPI, 'submitModelWithFile').mockImplementation(onSubmitMock); }); afterEach(() => { diff --git a/public/components/register_model/model_task_manager.ts b/public/components/register_model/model_task_manager.ts new file mode 100644 index 00000000..625a4201 --- /dev/null +++ b/public/components/register_model/model_task_manager.ts @@ -0,0 +1,87 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import { Observable, timer, BehaviorSubject } from 'rxjs'; +import { takeWhile, switchMap } from 'rxjs/operators'; +import { APIProvider } from '../../apis/api_provider'; +import { TaskGetOneResponse } from '../../apis/task'; + +interface TaskQueryOptions { + taskId: string; + onUpdate?: (status: TaskGetOneResponse) => void; + onError?: (err: Error) => void; + onComplete?: (status: TaskGetOneResponse) => void; +} + +// Model download task is still running if +// 1. model id doesn't exist +// 2. the current task is running fine without error +function isTaskRunning(res: TaskGetOneResponse) { + return !Boolean(res.model_id) && !Boolean(res.error); +} + +export class ModelTaskManager { + /** + * The model download tasks which are still running in BE + */ + tasks = new BehaviorSubject>>(new Map()); + + constructor() {} + + remove(taskId: string) { + this.tasks.getValue().delete(taskId); + this.tasks.next(this.tasks.getValue()); + } + + add(taskId: string, taskObservable: Observable) { + if (!this.tasks.getValue().has(taskId)) { + this.tasks.next(this.tasks.getValue().set(taskId, taskObservable)); + } + } + + query(options: TaskQueryOptions) { + if (!this.tasks.getValue().has(options.taskId)) { + const observable = timer(0, 2000) + .pipe(switchMap((_) => APIProvider.getAPI('task').getOne(options.taskId))) + // TODO: should it also check res.state? + // The intention here is to stop polling once a model is created + .pipe(takeWhile((res) => !Boolean(res.model_id) && !Boolean(res.error), true)); + + observable.subscribe({ + next: (res) => { + if (options.onUpdate && !res.error) { + options.onUpdate(res); + } + + if (isTaskRunning(res)) { + this.add(options.taskId, observable); + } else { + this.remove(options.taskId); + } + // Model download task is complete if model id exists + if (res.model_id && options.onComplete) { + options.onComplete(res); + } + + if (res.error && options.onError) { + options.onError(new Error(res.error)); + } + }, + error: (err: Error) => { + this.remove(options.taskId); + if (options.onError) { + options.onError(err); + } + }, + }); + } + } + + getTasks$() { + return this.tasks.asObservable(); + } +} + +export const modelTaskManager = new ModelTaskManager(); diff --git a/public/components/register_model/register_model.hooks.ts b/public/components/register_model/register_model.hooks.ts index 6781e318..487b2510 100644 --- a/public/components/register_model/register_model.hooks.ts +++ b/public/components/register_model/register_model.hooks.ts @@ -3,11 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { useCallback, useEffect, useRef, useState } from 'react'; -import { APIProvider } from '../../apis/api_provider'; -import { MAX_CHUNK_SIZE } from './constants'; -import { getModelContentHashValue } from './get_model_content_hash_value'; -import { ModelFileFormData, ModelUrlFormData } from './register_model.types'; +import { useEffect, useState } from 'react'; const metricNames = ['Metric 1', 'Metric 2', 'Metric 3', 'Metric 4']; @@ -51,78 +47,3 @@ export const useModelTags = () => { return [loading, { keys, values }] as const; }; - -export const useModelUpload = () => { - const timeoutIdRef = useRef(-1); - const mountedRef = useRef(true); - - useEffect(() => { - return () => { - mountedRef.current = false; - window.clearTimeout(timeoutIdRef.current); - }; - }, []); - - return useCallback(async (model: ModelFileFormData | ModelUrlFormData) => { - const modelUploadBase = { - name: model.name, - version: model.version, - description: model.description, - // TODO: Need to confirm if we have the model format input - modelFormat: 'TORCH_SCRIPT', - modelConfig: JSON.parse(model.configuration), - }; - if ('modelURL' in model) { - const { task_id: taskId } = await APIProvider.getAPI('model').upload({ - ...modelUploadBase, - url: model.modelURL, - }); - return new Promise((resolve, reject) => { - const refreshTaskStatus = () => { - APIProvider.getAPI('task') - .getOne(taskId) - .then(({ model_id: modelId, error }) => { - if (error) { - reject(error); - return; - } - if (modelId === undefined) { - if (!mountedRef.current) { - reject('component unmounted'); - return; - } - timeoutIdRef.current = window.setTimeout(refreshTaskStatus, 1000); - return; - } - resolve(modelId); - }); - }; - if (!mountedRef.current) { - reject('component unmounted'); - return; - } - refreshTaskStatus(); - }); - } - const { modelFile } = model; - const totalChunks = Math.ceil(modelFile.size / MAX_CHUNK_SIZE); - const modelContentHashValue = await getModelContentHashValue(modelFile); - - const modelId = ( - await APIProvider.getAPI('model').upload({ - ...modelUploadBase, - totalChunks, - modelContentHashValue, - }) - ).model_id; - - /* for (let i = 0; i < totalChunks; i++) { - const chunk = modelFile.slice( - MAX_CHUNK_SIZE * i, - Math.min(MAX_CHUNK_SIZE * (i + 1), modelFile.size) - ); - await APIProvider.getAPI('model').uploadChunk(modelId, `${i}`, chunk); - }*/ - return modelId; - }, []); -}; diff --git a/public/components/register_model/register_model.tsx b/public/components/register_model/register_model.tsx index 5d78efb3..6702ca2b 100644 --- a/public/components/register_model/register_model.tsx +++ b/public/components/register_model/register_model.tsx @@ -28,7 +28,7 @@ import { ArtifactPanel } from './artifact'; import { ConfigurationPanel } from './model_configuration'; import { EvaluationMetricsPanel } from './evaluation_metrics'; import { ModelTagsPanel } from './model_tags'; -import { useModelUpload } from './register_model.hooks'; +import { submitModelWithFile, submitModelWithURL } from './register_model_api'; import { APIProvider } from '../../apis/api_provider'; import { upgradeModelVersion } from '../../utils'; import { useSearchParams } from '../../hooks/use_search_params'; @@ -38,6 +38,7 @@ import { mountReactNode } from '../../../../../src/core/public/utils'; import { modelFileUploadManager } from './model_file_upload_manager'; import { MAX_CHUNK_SIZE } from './constants'; import { routerPaths } from '../../../common/router_paths'; +import { modelTaskManager } from './model_task_manager'; const DEFAULT_VALUES = { name: '', @@ -75,46 +76,51 @@ export const RegisterModelForm = () => { mode: 'onChange', defaultValues: DEFAULT_VALUES, }); - const submitModel = useModelUpload(); const onSubmit = useCallback( async (data: ModelFileFormData | ModelUrlFormData) => { try { - const modelId = await submitModel(data); - // Navigate to model list if form submit successfully - history.push(routerPaths.modelList); + const onComplete = () => { + notifications?.toasts.addSuccess({ + title: mountReactNode( + + Artifact for {form.getValues('name')}{' '} + uploaded + + ), + text: `The artifact for ${form.getValues('name')} uploaded successfully`, + }); + }; + + const onError = () => { + notifications?.toasts.addDanger({ + title: mountReactNode( + + {form.getValues('name')} artifact + upload failed. + + ), + text: 'The new version was not created.', + }); + }; - // Upload model artifact if ('modelFile' in data) { + const modelId = await submitModelWithFile(data); modelFileUploadManager.upload({ file: data.modelFile, modelId, chunkSize: MAX_CHUNK_SIZE, - onComplete: () => { - notifications?.toasts.addSuccess({ - title: mountReactNode( - - Artifact for{' '} - {form.getValues('name')} uploaded - - ), - text: `The artifact for ${form.getValues('name')} uploaded successfully`, - }); - }, - onError: () => { - notifications?.toasts.addDanger({ - title: mountReactNode( - - {form.getValues('name')} artifact - upload failed. - - ), - text: 'The new version was not created.', - }); - }, + onComplete, + onError, }); + } else { + const taskId = await submitModelWithURL(data); + modelTaskManager.query({ taskId, onComplete, onError }); } + // Navigate to model list if form submit successfully + history.push(routerPaths.modelList); + if (latestVersionId) { notifications?.toasts.addSuccess({ title: mountReactNode( @@ -150,7 +156,7 @@ export const RegisterModelForm = () => { } } }, - [submitModel, notifications, form, latestVersionId, history] + [notifications, form, latestVersionId, history] ); useEffect(() => { diff --git a/public/components/register_model/register_model_api.ts b/public/components/register_model/register_model_api.ts new file mode 100644 index 00000000..f4333398 --- /dev/null +++ b/public/components/register_model/register_model_api.ts @@ -0,0 +1,51 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import { APIProvider } from '../../apis/api_provider'; +import { MAX_CHUNK_SIZE } from './constants'; +import { getModelContentHashValue } from './get_model_content_hash_value'; +import { ModelFileFormData, ModelUrlFormData } from './register_model.types'; + +export async function submitModelWithFile(model: ModelFileFormData) { + const modelUploadBase = { + name: model.name, + version: model.version, + description: model.description, + // TODO: Need to confirm if we have the model format input + modelFormat: 'TORCH_SCRIPT', + modelConfig: JSON.parse(model.configuration), + }; + const { modelFile } = model; + const totalChunks = Math.ceil(modelFile.size / MAX_CHUNK_SIZE); + const modelContentHashValue = await getModelContentHashValue(modelFile); + + const modelId = ( + await APIProvider.getAPI('model').upload({ + ...modelUploadBase, + totalChunks, + modelContentHashValue, + }) + ).model_id; + + return modelId; +} + +export async function submitModelWithURL(model: ModelUrlFormData) { + const modelUploadBase = { + name: model.name, + version: model.version, + description: model.description, + // TODO: Need to confirm if we have the model format input + modelFormat: 'TORCH_SCRIPT', + modelConfig: JSON.parse(model.configuration), + }; + + const { task_id: taskId } = await APIProvider.getAPI('model').upload({ + ...modelUploadBase, + url: model.modelURL, + }); + + return taskId; +} From ea5d5ecbf7d828c7e5bc0a40aaa99c726a4d8856 Mon Sep 17 00:00:00 2001 From: Lin Wang Date: Tue, 7 Mar 2023 16:26:09 +0800 Subject: [PATCH 20/75] Feature/add model name unique verification (#129) * feat: validate model name after input blurred Signed-off-by: Lin Wang * refactor: remove model API mock and unified search API mock Signed-off-by: Lin Wang --------- Signed-off-by: Lin Wang --- .../__tests__/register_model_details.test.tsx | 16 ++++++++++ .../__tests__/register_model_form.test.tsx | 8 ++--- .../register_model/__tests__/setup.tsx | 7 +++- .../register_model/model_details.tsx | 32 +++++++++++++++++-- 4 files changed, 54 insertions(+), 9 deletions(-) diff --git a/public/components/register_model/__tests__/register_model_details.test.tsx b/public/components/register_model/__tests__/register_model_details.test.tsx index 1eba64bc..0c2a0c22 100644 --- a/public/components/register_model/__tests__/register_model_details.test.tsx +++ b/public/components/register_model/__tests__/register_model_details.test.tsx @@ -6,6 +6,7 @@ import { setup } from './setup'; import * as formHooks from '../register_model.hooks'; import * as formAPI from '../register_model_api'; +import { Model } from '../../../apis/model'; describe(' Details', () => { const onSubmitMock = jest.fn(); @@ -65,6 +66,21 @@ describe(' Details', () => { expect(onSubmitMock).not.toHaveBeenCalled(); }); + it('should NOT submit the register model form if model name is duplicated', async () => { + const result = await setup(); + jest.spyOn(Model.prototype, 'search').mockResolvedValue({ + data: [], + pagination: { totalPages: 1, totalRecords: 1, currentPage: 1, pageSize: 1 }, + }); + + await result.user.clear(result.nameInput); + await result.user.type(result.nameInput, 'a-duplicated-model-name'); + await result.user.click(result.submitButton); + + expect(result.nameInput).toBeInvalid(); + expect(onSubmitMock).not.toHaveBeenCalled(); + }); + it('should NOT submit the register model form if model description is empty', async () => { const result = await setup(); diff --git a/public/components/register_model/__tests__/register_model_form.test.tsx b/public/components/register_model/__tests__/register_model_form.test.tsx index 415ec09d..e1ea9cbf 100644 --- a/public/components/register_model/__tests__/register_model_form.test.tsx +++ b/public/components/register_model/__tests__/register_model_form.test.tsx @@ -8,7 +8,6 @@ import { Route } from 'react-router-dom'; import { render, screen, waitFor } from '../../../../test/test_utils'; import { RegisterModelForm } from '../register_model'; -import { APIProvider } from '../../../apis/api_provider'; import { routerPaths } from '../../../../common/router_paths'; import { setup } from './setup'; import { Model } from '../../../../public/apis/model'; @@ -79,9 +78,8 @@ describe(' Form', () => { }); it('should init form when id param in url route', async () => { - const request = jest.spyOn(APIProvider.getAPI('model'), 'search'); const mockResult = MOCKED_DATA; - request.mockResolvedValue(mockResult); + jest.spyOn(Model.prototype, 'search').mockResolvedValue(mockResult); render( @@ -98,9 +96,7 @@ describe(' Form', () => { }); it('submit button label should be `Register version` when register new version', async () => { - const request = jest.spyOn(APIProvider.getAPI('model'), 'search'); - const mockResult = MOCKED_DATA; - request.mockResolvedValue(mockResult); + jest.spyOn(Model.prototype, 'search').mockResolvedValue(MOCKED_DATA); render( diff --git a/public/components/register_model/__tests__/setup.tsx b/public/components/register_model/__tests__/setup.tsx index 849b0063..16247a3a 100644 --- a/public/components/register_model/__tests__/setup.tsx +++ b/public/components/register_model/__tests__/setup.tsx @@ -7,9 +7,9 @@ import React from 'react'; import userEvent from '@testing-library/user-event'; import { RegisterModelForm } from '../register_model'; +import { Model } from '../../../apis/model'; import { render, RenderWithRouteProps, screen } from '../../../../test/test_utils'; -jest.mock('../../../apis/model'); jest.mock('../../../apis/task'); export async function setup(options?: RenderWithRouteProps) { @@ -24,6 +24,11 @@ export async function setup(options?: RenderWithRouteProps) { const form = screen.getByTestId('mlCommonsPlugin-registerModelForm'); const user = userEvent.setup(); + // Mock model name unique + jest.spyOn(Model.prototype, 'search').mockResolvedValue({ + data: [], + pagination: { totalRecords: 0, currentPage: 1, pageSize: 1, totalPages: 0 }, + }); // fill model name await user.type(nameInput, 'test model name'); // fill model description diff --git a/public/components/register_model/model_details.tsx b/public/components/register_model/model_details.tsx index 8d962981..3c3f5ffe 100644 --- a/public/components/register_model/model_details.tsx +++ b/public/components/register_model/model_details.tsx @@ -3,22 +3,38 @@ * SPDX-License-Identifier: Apache-2.0 */ -import React from 'react'; +import React, { useCallback, useRef } from 'react'; import { EuiFieldText, EuiFormRow, EuiTitle, EuiTextArea, EuiText } from '@elastic/eui'; import { useController, useFormContext } from 'react-hook-form'; import { ModelFileFormData, ModelUrlFormData } from './register_model.types'; +import { APIProvider } from '../../apis/api_provider'; const NAME_MAX_LENGTH = 80; const DESCRIPTION_MAX_LENGTH = 200; const ANNOTATION_MAX_LENGTH = 200; +const isUniqueModelName = async (name: string) => { + const searchResult = await APIProvider.getAPI('model').search({ + name, + pageSize: 1, + currentPage: 1, + }); + return searchResult.pagination.totalRecords >= 1; +}; + export const ModelDetailsPanel = () => { - const { control } = useFormContext(); + const { control, trigger } = useFormContext(); + const modelNameFocusedRef = useRef(false); const nameFieldController = useController({ name: 'name', control, rules: { required: { value: true, message: 'Name can not be empty' }, + validate: async (name) => { + return !modelNameFocusedRef.current && !!name && (await isUniqueModelName(name)) + ? 'This name is already in use. Use a unique name for the model.' + : undefined; + }, maxLength: { value: NAME_MAX_LENGTH, message: 'Text exceed max length' }, }, }); @@ -42,6 +58,16 @@ export const ModelDetailsPanel = () => { const { ref: descriptionInputRef, ...descriptionField } = descriptionFieldController.field; const { ref: annotationsInputRef, ...annotationsField } = annotationsFieldController.field; + const handleModelNameFocus = useCallback(() => { + modelNameFocusedRef.current = true; + }, []); + + const handleModelNameBlur = useCallback(() => { + nameField.onBlur(); + modelNameFocusedRef.current = false; + trigger('name'); + }, [nameField, trigger]); + return (
@@ -63,6 +89,8 @@ export const ModelDetailsPanel = () => { inputRef={nameInputRef} isInvalid={Boolean(nameFieldController.fieldState.error)} {...nameField} + onFocus={handleModelNameFocus} + onBlur={handleModelNameBlur} /> Date: Tue, 7 Mar 2023 16:33:50 +0800 Subject: [PATCH 21/75] Feature/rename annotation and remove model details (#133) * feat: add version notes to register model form Signed-off-by: Lin Wang * feat: remove annotation field Signed-off-by: Lin Wang * feat: remove model details and update form header description Signed-off-by: Lin Wang * feat: remove annotations field in ModelFormBase Signed-off-by: Lin Wang --------- Signed-off-by: Lin Wang --- .../__tests__/register_model_details.test.tsx | 16 ------ .../__tests__/register_model_form.test.tsx | 3 +- .../register_model_version_notes.test.tsx | 53 +++++++++++++++++++ .../register_model/__tests__/setup.tsx | 4 +- .../register_model/model_details.tsx | 27 ---------- .../register_model/model_version_notes.tsx | 50 +++++++++++++++++ .../register_model/register_model.tsx | 32 +++++------ .../register_model/register_model.types.ts | 2 +- 8 files changed, 121 insertions(+), 66 deletions(-) create mode 100644 public/components/register_model/__tests__/register_model_version_notes.test.tsx create mode 100644 public/components/register_model/model_version_notes.tsx diff --git a/public/components/register_model/__tests__/register_model_details.test.tsx b/public/components/register_model/__tests__/register_model_details.test.tsx index 0c2a0c22..c5478c91 100644 --- a/public/components/register_model/__tests__/register_model_details.test.tsx +++ b/public/components/register_model/__tests__/register_model_details.test.tsx @@ -29,7 +29,6 @@ describe(' Details', () => { const result = await setup(); expect(result.nameInput).toBeInTheDocument(); expect(result.descriptionInput).toBeInTheDocument(); - expect(result.annotationsInput).toBeInTheDocument(); }); it('should submit the register model form', async () => { @@ -105,19 +104,4 @@ describe(' Details', () => { await result.user.click(result.submitButton); expect(onSubmitMock).not.toHaveBeenCalled(); }); - - it('annotation text length should not exceed 200', async () => { - const result = await setup(); - - await result.user.clear(result.annotationsInput); - await result.user.type(result.annotationsInput, 'x'.repeat(200)); - expect(result.annotationsInput).toBeValid(); - - await result.user.clear(result.annotationsInput); - await result.user.type(result.annotationsInput, 'x'.repeat(201)); - expect(result.annotationsInput).toBeInvalid(); - - await result.user.click(result.submitButton); - expect(onSubmitMock).not.toHaveBeenCalled(); - }); }); diff --git a/public/components/register_model/__tests__/register_model_form.test.tsx b/public/components/register_model/__tests__/register_model_form.test.tsx index e1ea9cbf..abad062c 100644 --- a/public/components/register_model/__tests__/register_model_form.test.tsx +++ b/public/components/register_model/__tests__/register_model_form.test.tsx @@ -90,8 +90,7 @@ describe(' Form', () => { const { name } = mockResult.data[0]; await waitFor(() => { - const nameInput = screen.getByLabelText(/^name$/i); - expect(nameInput.value).toBe(name); + expect(screen.getByText(name)).toBeInTheDocument(); }); }); diff --git a/public/components/register_model/__tests__/register_model_version_notes.test.tsx b/public/components/register_model/__tests__/register_model_version_notes.test.tsx new file mode 100644 index 00000000..2a65bd67 --- /dev/null +++ b/public/components/register_model/__tests__/register_model_version_notes.test.tsx @@ -0,0 +1,53 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import { setup } from './setup'; +import * as formHooks from '../register_model.hooks'; +import * as formAPI from '../register_model_api'; + +describe(' Version notes', () => { + const onSubmitMock = jest.fn(); + + beforeEach(() => { + jest + .spyOn(formHooks, 'useMetricNames') + .mockReturnValue([false, ['Metric 1', 'Metric 2', 'Metric 3', 'Metric 4']]); + jest + .spyOn(formHooks, 'useModelTags') + .mockReturnValue([false, { keys: ['Key1', 'Key2'], values: ['Value1', 'Value2'] }]); + jest.spyOn(formAPI, 'submitModelWithFile').mockImplementation(onSubmitMock); + }); + + afterEach(() => { + jest.clearAllMocks(); + }); + + it('should render a version notes panel', async () => { + const result = await setup(); + expect(result.versionNotesInput).toBeInTheDocument(); + }); + + it('should submit the form without fill version notes', async () => { + const result = await setup(); + await result.user.click(result.submitButton); + + expect(onSubmitMock).toHaveBeenCalled(); + }); + + it('should NOT submit the register model form if model version notes length exceed 200', async () => { + const result = await setup(); + + await result.user.clear(result.versionNotesInput); + await result.user.type(result.versionNotesInput, 'x'.repeat(200)); + expect(result.versionNotesInput).toBeValid(); + + await result.user.clear(result.versionNotesInput); + await result.user.type(result.versionNotesInput, 'x'.repeat(201)); + expect(result.versionNotesInput).toBeInvalid(); + + await result.user.click(result.submitButton); + expect(onSubmitMock).not.toHaveBeenCalled(); + }); +}); diff --git a/public/components/register_model/__tests__/setup.tsx b/public/components/register_model/__tests__/setup.tsx index 16247a3a..ae9e9755 100644 --- a/public/components/register_model/__tests__/setup.tsx +++ b/public/components/register_model/__tests__/setup.tsx @@ -16,13 +16,13 @@ export async function setup(options?: RenderWithRouteProps) { render(, { route: options?.route ?? '/' }); const nameInput = screen.getByLabelText(/^name$/i); const descriptionInput = screen.getByLabelText(/description/i); - const annotationsInput = screen.getByLabelText(/annotation/i); const submitButton = screen.getByRole('button', { name: /register model/i, }); const modelFileInput = screen.queryByLabelText(/file/i); const form = screen.getByTestId('mlCommonsPlugin-registerModelForm'); const user = userEvent.setup(); + const versionNotesInput = screen.getByLabelText(/notes/i); // Mock model name unique jest.spyOn(Model.prototype, 'search').mockResolvedValue({ @@ -44,9 +44,9 @@ export async function setup(options?: RenderWithRouteProps) { return { nameInput, descriptionInput, - annotationsInput, submitButton, form, user, + versionNotesInput, }; } diff --git a/public/components/register_model/model_details.tsx b/public/components/register_model/model_details.tsx index 3c3f5ffe..74598005 100644 --- a/public/components/register_model/model_details.tsx +++ b/public/components/register_model/model_details.tsx @@ -11,7 +11,6 @@ import { APIProvider } from '../../apis/api_provider'; const NAME_MAX_LENGTH = 80; const DESCRIPTION_MAX_LENGTH = 200; -const ANNOTATION_MAX_LENGTH = 200; const isUniqueModelName = async (name: string) => { const searchResult = await APIProvider.getAPI('model').search({ @@ -48,15 +47,8 @@ export const ModelDetailsPanel = () => { }, }); - const annotationsFieldController = useController({ - name: 'annotations', - control, - rules: { maxLength: { value: ANNOTATION_MAX_LENGTH, message: 'Text exceed max length' } }, - }); - const { ref: nameInputRef, ...nameField } = nameFieldController.field; const { ref: descriptionInputRef, ...descriptionField } = descriptionFieldController.field; - const { ref: annotationsInputRef, ...annotationsField } = annotationsFieldController.field; const handleModelNameFocus = useCallback(() => { modelNameFocusedRef.current = true; @@ -108,25 +100,6 @@ export const ModelDetailsPanel = () => { {...descriptionField} /> - - Annotation - Optional - - } - > - -
); }; diff --git a/public/components/register_model/model_version_notes.tsx b/public/components/register_model/model_version_notes.tsx new file mode 100644 index 00000000..839b6340 --- /dev/null +++ b/public/components/register_model/model_version_notes.tsx @@ -0,0 +1,50 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React from 'react'; +import { EuiTitle, EuiSpacer, EuiFormRow, EuiTextArea } from '@elastic/eui'; +import { useFormContext, useController } from 'react-hook-form'; + +import type { ModelFileFormData, ModelUrlFormData } from './register_model.types'; + +const VERSION_NOTES_MAX_LENGTH = 200; + +export const ModelVersionNotesPanel = () => { + const { control } = useFormContext(); + + const fieldController = useController({ + name: 'versionNotes', + control, + rules: { maxLength: { value: VERSION_NOTES_MAX_LENGTH, message: 'Text exceed max length' } }, + }); + const { ref, ...versionNotesField } = fieldController.field; + + return ( +
+ +

+ Version notes - optional +

+
+ + + + +
+ ); +}; diff --git a/public/components/register_model/register_model.tsx b/public/components/register_model/register_model.tsx index 6702ca2b..d8fd548a 100644 --- a/public/components/register_model/register_model.tsx +++ b/public/components/register_model/register_model.tsx @@ -3,7 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -import React, { useCallback, useEffect } from 'react'; +import React, { useCallback, useEffect, useState } from 'react'; import { FieldErrors, useForm, FormProvider } from 'react-hook-form'; import { useHistory, useParams } from 'react-router-dom'; import { @@ -39,6 +39,7 @@ import { modelFileUploadManager } from './model_file_upload_manager'; import { MAX_CHUNK_SIZE } from './constants'; import { routerPaths } from '../../../common/router_paths'; import { modelTaskManager } from './model_task_manager'; +import { ModelVersionNotesPanel } from './model_version_notes'; const DEFAULT_VALUES = { name: '', @@ -53,6 +54,7 @@ const FORM_ID = 'mlModelUploadForm'; export const RegisterModelForm = () => { const history = useHistory(); const { id: latestVersionId } = useParams<{ id: string | undefined }>(); + const [modelGroupName, setModelGroupName] = useState(); const typeParams = useSearchParams().get('type'); const { @@ -63,13 +65,14 @@ export const RegisterModelForm = () => { const formType = isValidModelRegisterFormType(typeParams) ? typeParams : 'upload'; const partials = formType === 'import' - ? [ModelDetailsPanel, ModelTagsPanel] + ? [ModelDetailsPanel, ModelTagsPanel, ModelVersionNotesPanel] : [ - ModelDetailsPanel, + ...(latestVersionId ? [] : [ModelDetailsPanel]), ArtifactPanel, ConfigurationPanel, EvaluationMetricsPanel, ModelTagsPanel, + ModelVersionNotesPanel, ]; const form = useForm({ @@ -174,6 +177,7 @@ export const RegisterModelForm = () => { form.setValue('name', name); form.setValue('version', newVersion); form.setValue('configuration', modelConfig?.all_config ?? ''); + setModelGroupName(name); } }; initializeForm(); @@ -197,31 +201,23 @@ export const RegisterModelForm = () => { > - + {latestVersionId && ( <> - Register a new version of Image-classifiar.The version number will be - automatically incremented. For more information on versioning, see{' '} + Register a new version of {modelGroupName}. The version number will be + automatically incremented.  - Model Registry Documentation - - . - - )} - {formType === 'import' && !latestVersionId && ( - <> - Register a pre-trained model. For more information, see{' '} - - OpenSearch model repository documentation + Learn More . )} + {formType === 'import' && !latestVersionId && <>Register a pre-trained model.} {formType === 'upload' && !latestVersionId && ( <> - Register your model to collaboratively manage its life cycle, and facilitate model - discovery across your organization. + Register your model to manage its life cycle, and facilitate model discovery + across your organization. )} diff --git a/public/components/register_model/register_model.types.ts b/public/components/register_model/register_model.types.ts index 648cae3e..95855303 100644 --- a/public/components/register_model/register_model.types.ts +++ b/public/components/register_model/register_model.types.ts @@ -19,10 +19,10 @@ interface ModelFormBase { name: string; version: string; description: string; - annotations?: string; configuration: string; metric?: Metric; tags?: Tag[]; + versionNotes?: string; } /** From 3083a8b6674f5992079bbf1c5fbb91f6a570e868 Mon Sep 17 00:00:00 2001 From: Lin Wang Date: Wed, 8 Mar 2023 14:00:37 +0800 Subject: [PATCH 22/75] Feature/fetch pre trained model list (#131) * feat: remove unused DIV wrapper Signed-off-by: Lin Wang * feat: add model repository related API Signed-off-by: Lin Wang * feat: add model repository manager Signed-off-by: Lin Wang * feat: fetch model repository data in register_model_type_modal and register_model form Signed-off-by: Lin Wang * refactor: change switchMap to map Signed-off-by: Lin Wang * feat: update to main branch pre-trained models URL Signed-off-by: Lin Wang --------- Signed-off-by: Lin Wang --- public/apis/__mocks__/model_repository.ts | 49 ++++ public/apis/api_provider.ts | 9 + public/apis/model_repository.ts | 38 +++ .../register_model/register_model.tsx | 21 +- .../__tests__/index.test.tsx | 19 +- .../register_model_type_modal/index.tsx | 275 ++++++++---------- public/utils/model_repository_manager.ts | 62 ++++ .../tests/model_repository_manager.test.ts | 71 +++++ server/plugin.ts | 2 + server/routes/constants.ts | 3 + server/routes/index.ts | 1 + server/routes/model_repository_router.ts | 49 ++++ 12 files changed, 446 insertions(+), 153 deletions(-) create mode 100644 public/apis/__mocks__/model_repository.ts create mode 100644 public/apis/model_repository.ts create mode 100644 public/utils/model_repository_manager.ts create mode 100644 public/utils/tests/model_repository_manager.test.ts create mode 100644 server/routes/model_repository_router.ts diff --git a/public/apis/__mocks__/model_repository.ts b/public/apis/__mocks__/model_repository.ts new file mode 100644 index 00000000..798a0622 --- /dev/null +++ b/public/apis/__mocks__/model_repository.ts @@ -0,0 +1,49 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +export class ModelRepository { + public getPreTrainedModels() { + return Promise.resolve({ + 'sentence-transformers/all-distilroberta-v1': { + version: '1.0.1', + description: + 'This is a sentence-transformers model: It maps sentences & paragraphs to a 768 dimensional dense vector space and can be used for tasks like clustering or semantic search.', + torch_script: { + model_url: + 'https://artifacts.opensearch.org/models/ml-models/huggingface/sentence-transformers/all-distilroberta-v1/1.0.1/torch_script/sentence-transformers_all-distilroberta-v1-1.0.1-torch_script.zip', + config_url: + 'https://artifacts.opensearch.org/models/ml-models/huggingface/sentence-transformers/all-distilroberta-v1/1.0.1/torch_script/config.json', + }, + onnx: { + model_url: + 'https://artifacts.opensearch.org/models/ml-models/huggingface/sentence-transformers/all-distilroberta-v1/1.0.1/onnx/sentence-transformers_all-distilroberta-v1-1.0.1-onnx.zip', + config_url: + 'https://artifacts.opensearch.org/models/ml-models/huggingface/sentence-transformers/all-distilroberta-v1/1.0.1/onnx/config.json', + }, + }, + }); + } + + public getPreTrainedModelConfig() { + return Promise.resolve({ + name: 'sentence-transformers/msmarco-distilbert-base-tas-b', + version: '1.0.1', + description: + 'This is a port of the DistilBert TAS-B Model to sentence-transformers model: It maps sentences & paragraphs to a 768 dimensional dense vector space and is optimized for the task of semantic search.', + model_task_type: 'TEXT_EMBEDDING', + model_format: 'TORCH_SCRIPT', + model_content_size_in_bytes: 266352827, + model_content_hash_value: 'acdc81b652b83121f914c5912ae27c0fca8fabf270e6f191ace6979a19830413', + model_config: { + model_type: 'distilbert', + embedding_dimension: 768, + framework_type: 'sentence_transformers', + all_config: + '{"_name_or_path":"old_models/msmarco-distilbert-base-tas-b/0_Transformer","activation":"gelu","architectures":["DistilBertModel"],"attention_dropout":0.1,"dim":768,"dropout":0.1,"hidden_dim":3072,"initializer_range":0.02,"max_position_embeddings":512,"model_type":"distilbert","n_heads":12,"n_layers":6,"pad_token_id":0,"qa_dropout":0.1,"seq_classif_dropout":0.2,"sinusoidal_pos_embds":false,"tie_weights_":true,"transformers_version":"4.7.0","vocab_size":30522}', + }, + created_time: 1676073973126, + }); + } +} diff --git a/public/apis/api_provider.ts b/public/apis/api_provider.ts index dfc9bbe6..bcb8b109 100644 --- a/public/apis/api_provider.ts +++ b/public/apis/api_provider.ts @@ -6,6 +6,7 @@ import { Connector } from './connector'; import { Model } from './model'; import { ModelAggregate } from './model_aggregate'; +import { ModelRepository } from './model_repository'; import { Profile } from './profile'; import { Security } from './security'; import { Task } from './task'; @@ -17,6 +18,7 @@ const apiInstanceStore: { connector: Connector | undefined; security: Security | undefined; task: Task | undefined; + modelRepository: ModelRepository | undefined; } = { model: undefined, modelAggregate: undefined, @@ -24,6 +26,7 @@ const apiInstanceStore: { connector: undefined, security: undefined, task: undefined, + modelRepository: undefined, }; export class APIProvider { @@ -33,6 +36,7 @@ export class APIProvider { public static getAPI(type: 'profile'): Profile; public static getAPI(type: 'connector'): Connector; public static getAPI(type: 'security'): Security; + public static getAPI(type: 'modelRepository'): ModelRepository; public static getAPI(type: keyof typeof apiInstanceStore) { if (apiInstanceStore[type]) { return apiInstanceStore[type]!; @@ -68,6 +72,11 @@ export class APIProvider { apiInstanceStore.task = newInstance; return newInstance; } + case 'modelRepository': { + const newInstance = new ModelRepository(); + apiInstanceStore.modelRepository = newInstance; + return newInstance; + } } } public static clear() { diff --git a/public/apis/model_repository.ts b/public/apis/model_repository.ts new file mode 100644 index 00000000..e1dbd123 --- /dev/null +++ b/public/apis/model_repository.ts @@ -0,0 +1,38 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import { + MODEL_REPOSITORY_API_ENDPOINT, + MODEL_REPOSITORY_CONFIG_URL_API_ENDPOINT, +} from '../../server/routes/constants'; +import { InnerHttpProvider } from './inner_http_provider'; + +interface PreTrainedModelInfo { + model_url: string; + config_url: string; +} + +interface PreTrainedModel { + version: string; + description: string; + torch_script: PreTrainedModelInfo; + onnx: PreTrainedModelInfo; +} + +interface PreTrainedModels { + [key: string]: PreTrainedModel; +} + +export class ModelRepository { + public getPreTrainedModels() { + return InnerHttpProvider.getHttp().get(MODEL_REPOSITORY_API_ENDPOINT); + } + + public getPreTrainedModelConfig(configURL: string) { + return InnerHttpProvider.getHttp().get( + `${MODEL_REPOSITORY_CONFIG_URL_API_ENDPOINT}/${encodeURIComponent(configURL)}` + ); + } +} diff --git a/public/components/register_model/register_model.tsx b/public/components/register_model/register_model.tsx index d8fd548a..0dd03f5b 100644 --- a/public/components/register_model/register_model.tsx +++ b/public/components/register_model/register_model.tsx @@ -40,6 +40,7 @@ import { MAX_CHUNK_SIZE } from './constants'; import { routerPaths } from '../../../common/router_paths'; import { modelTaskManager } from './model_task_manager'; import { ModelVersionNotesPanel } from './model_version_notes'; +import { modelRepositoryManager } from '../../utils/model_repository_manager'; const DEFAULT_VALUES = { name: '', @@ -55,7 +56,9 @@ export const RegisterModelForm = () => { const history = useHistory(); const { id: latestVersionId } = useParams<{ id: string | undefined }>(); const [modelGroupName, setModelGroupName] = useState(); - const typeParams = useSearchParams().get('type'); + const searchParams = useSearchParams(); + const typeParams = searchParams.get('type'); + const nameParams = searchParams.get('name'); const { services: { chrome, notifications }, @@ -183,6 +186,22 @@ export const RegisterModelForm = () => { initializeForm(); }, [latestVersionId, form]); + useEffect(() => { + if (!nameParams) { + return; + } + const subscriber = modelRepositoryManager + .getPreTrainedModel$(nameParams, 'torch_script') + .subscribe((preTrainedModel) => { + // TODO: store pre-trained model data + // eslint-disable-next-line no-console + console.log(preTrainedModel); + }); + return () => { + subscriber.unsubscribe(); + }; + }, [nameParams]); + const onError = useCallback((errors: FieldErrors) => { // TODO // eslint-disable-next-line no-console diff --git a/public/components/register_model_type_modal/__tests__/index.test.tsx b/public/components/register_model_type_modal/__tests__/index.test.tsx index d9e3fdc1..42cd72da 100644 --- a/public/components/register_model_type_modal/__tests__/index.test.tsx +++ b/public/components/register_model_type_modal/__tests__/index.test.tsx @@ -6,7 +6,9 @@ import React from 'react'; import userEvent from '@testing-library/user-event'; import { RegisterModelTypeModal } from '../index'; -import { render, screen } from '../../../../test/test_utils'; +import { render, screen, waitFor } from '../../../../test/test_utils'; + +jest.mock('../../../apis/model_repository'); const mockOffsetMethods = () => { const originalOffsetHeight = Object.getOwnPropertyDescriptor( @@ -62,20 +64,27 @@ describe('', () => { await userEvent.click(screen.getByLabelText('Opensearch model repository')); expect(screen.getByTestId('findModel')).toBeInTheDocument(); expect(screen.getByTestId('opensearchModelList')).toBeInTheDocument(); - expect(screen.getByText('tapas-tiny')).toBeInTheDocument(); - await userEvent.click(screen.getByText('tapas-tiny')); + await waitFor(() => + expect(screen.getByText('sentence-transformers/all-distilroberta-v1')).toBeInTheDocument() + ); + await userEvent.click(screen.getByText('sentence-transformers/all-distilroberta-v1')); await userEvent.click(screen.getByTestId('continueRegister')); expect(document.URL).toContain( - 'model-registry/register-model/?type=import&name=tapas-tiny&version=tapas-tiny' + 'model-registry/register-model/?type=import&name=sentence-transformers/all-distilroberta-v1&version=sentence-transformers/all-distilroberta-v1' ); mockReset(); }); it('should render no model found when input a invalid text to search model', async () => { + const mockReset = mockOffsetMethods(); render( {}} />); await userEvent.click(screen.getByLabelText('Opensearch model repository')); - await userEvent.type(screen.getByLabelText('OpenSearch model repository models'), '1'); + await waitFor(() => + expect(screen.getByText('sentence-transformers/all-distilroberta-v1')).toBeInTheDocument() + ); + await userEvent.type(screen.getByTestId('findModel'), 'foo'); expect(screen.getByText('No model found')).toBeInTheDocument(); + mockReset(); }); it('should link href after selecting "add your own model" and continue ', async () => { diff --git a/public/components/register_model_type_modal/index.tsx b/public/components/register_model_type_modal/index.tsx index f40c5580..1bfd2d8f 100644 --- a/public/components/register_model_type_modal/index.tsx +++ b/public/components/register_model_type_modal/index.tsx @@ -3,7 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ import { EuiSpacer } from '@elastic/eui'; -import React, { useState, useCallback, Fragment } from 'react'; +import React, { useState, useCallback, Fragment, useEffect } from 'react'; import { useHistory } from 'react-router-dom'; import { EuiButton, @@ -25,6 +25,8 @@ import { import { htmlIdGenerator } from '@elastic/eui'; import { generatePath } from 'react-router-dom'; import { routerPaths } from '../../../common/router_paths'; +import { modelRepositoryManager } from '../../utils/model_repository_manager'; + enum ModelSource { USER_MODEL = 'UserModel', PRE_TRAINED_MODEL = 'PreTrainedModel', @@ -37,31 +39,6 @@ interface IItem { checked?: 'on' | undefined; description: string; } -const MODEL_LIST = [ - { - name: 'tapas-tiny', - description: - 'TAPAS is a BERT-like transformers model pretrained on a large corpus of English data from Wikipedia in a self-supervised fashion', - checked: undefined, - }, - { - name: 'electra-small-generator', - description: 'ELECTRA is a new method for self-supervised language representation learning', - checked: undefined, - }, - { - name: 'flan-T5-large-grammer-synthesis', - description: - 'A fine-tuned version of google/flan-t5-large for grammer correction on an expanded version of the JFLEG dataset', - checked: undefined, - }, - { - name: 'BEiT', - description: - 'The BEiT model is a version Transformer(ViT),which is a transformer encoder model(BERT-like)', - checked: undefined, - }, -]; const renderModelOption = (option: IItem, searchValue: string) => { return ( <> @@ -77,12 +54,7 @@ const renderModelOption = (option: IItem, searchValue: string) => { }; export function RegisterModelTypeModal({ onCloseModal }: Props) { const [modelRepoSelection, setModelRepoSelection] = useState>>( - () => - MODEL_LIST.map((item) => ({ - checked: item.checked, - label: item.name, - description: item.description, - })) + [] ); const history = useHistory(); const [modelSource, setModelSource] = useState(ModelSource.PRE_TRAINED_MODEL); @@ -112,125 +84,134 @@ export function RegisterModelTypeModal({ onCloseModal }: Props) { }, [history, modelSource, modelRepoSelection, onChange] ); + + useEffect(() => { + const subscribe = modelRepositoryManager.getPreTrainedModels$().subscribe((models) => { + setModelRepoSelection( + Object.keys(models).map((name) => ({ + label: name, + description: models[name].description, + checked: undefined, + })) + ); + }); + return () => { + subscribe.unsubscribe(); + }; + }, []); return ( -
- onCloseModal()} maxWidth="1000px"> - - -

Register model

-
-
- -
- - Model source - - - - - Opensearch model repository - - - - Select from a curated list of pre-trained models for search use cases. - - -
- } - aria-label="Opensearch model repository" - checked={modelSource === ModelSource.USER_MODEL} - onChange={() => setModelSource(ModelSource.USER_MODEL)} - /> - - - - Add your own model - - - - Upload your own model in Torchscript format, as a local file via URL. - - -
- } - aria-label="Add your own model" - checked={modelSource === ModelSource.PRE_TRAINED_MODEL} - onChange={() => setModelSource(ModelSource.PRE_TRAINED_MODEL)} - /> - - -
- - -
+ onCloseModal()} maxWidth="1000px"> + + +

Register model

+
+
+ +
+ + Model source + + + + + Opensearch model repository + + + + Select from a curated list of pre-trained models for search use cases. + + +
+ } + aria-label="Opensearch model repository" + checked={modelSource === ModelSource.USER_MODEL} + onChange={() => setModelSource(ModelSource.USER_MODEL)} + /> + + + + Add your own model + + + + Upload your own model in Torchscript format, as a local file via URL. + + +
+ } + aria-label="Add your own model" + checked={modelSource === ModelSource.PRE_TRAINED_MODEL} + onChange={() => setModelSource(ModelSource.PRE_TRAINED_MODEL)} + /> + + +
+ + +
+ + Model + + +
+ + For more information on each model, see + - Model + + OpenSearch model repository documentation + - -
- - For more information on each model, see - - - - OpenSearch model repository documentation - - -
- - - {(list, search) => ( - - {search} - {list} - - )} -
- - - - Cancel - - + - Continue - - - -
+ {(list, search) => ( + + {search} + {list} + + )} + +
+ + + + Cancel + + + Continue + + + ); } diff --git a/public/utils/model_repository_manager.ts b/public/utils/model_repository_manager.ts new file mode 100644 index 00000000..14feba80 --- /dev/null +++ b/public/utils/model_repository_manager.ts @@ -0,0 +1,62 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import { Observable, of, from } from 'rxjs'; +import { map, switchMap } from 'rxjs/operators'; +import { APIProvider } from '../apis/api_provider'; + +interface PreTrainedModelInfo { + model_url: string; + config_url: string; +} + +interface PreTrainedModel { + version: string; + description: string; + torch_script: PreTrainedModelInfo; + onnx: PreTrainedModelInfo; +} + +interface PreTrainedModels { + [key: string]: PreTrainedModel; +} + +export class ModelRepositoryManager { + private preTrainedModels: Observable | null = null; + private preTrainedModelConfigs: Map> = new Map(); + + constructor() {} + + getPreTrainedModels$() { + if (!this.preTrainedModels) { + this.preTrainedModels = from(APIProvider.getAPI('modelRepository').getPreTrainedModels()); + } + return this.preTrainedModels; + } + + getPreTrainedModel$(name: string, format: 'torch_script' | 'onnx') { + return this.getPreTrainedModels$().pipe( + switchMap((models) => { + const model = models[name]; + const modelInfo = model[format]; + let modelConfig$ = this.preTrainedModelConfigs.get(modelInfo.config_url); + if (!modelConfig$) { + modelConfig$ = from( + APIProvider.getAPI('modelRepository').getPreTrainedModelConfig(modelInfo.config_url) + ); + this.preTrainedModelConfigs.set(modelInfo.config_url, modelConfig$); + } + return modelConfig$.pipe( + map((config) => ({ + url: modelInfo.model_url, + config, + })) + ); + }) + ); + } +} + +export const modelRepositoryManager = new ModelRepositoryManager(); diff --git a/public/utils/tests/model_repository_manager.test.ts b/public/utils/tests/model_repository_manager.test.ts new file mode 100644 index 00000000..828e0f67 --- /dev/null +++ b/public/utils/tests/model_repository_manager.test.ts @@ -0,0 +1,71 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import { ModelRepository } from '../../apis/model_repository'; +import { ModelRepositoryManager } from '../model_repository_manager'; + +jest.mock('../../apis/model_repository'); + +describe('ModelRepositoryManager', () => { + beforeEach(() => { + jest.spyOn(ModelRepository.prototype, 'getPreTrainedModelConfig'); + jest.spyOn(ModelRepository.prototype, 'getPreTrainedModels'); + }); + + afterEach(() => { + jest.clearAllMocks(); + }); + + it('should return consistent pre-trained models', async () => { + const models = await new ModelRepositoryManager().getPreTrainedModels$().toPromise(); + expect(models).toEqual( + expect.objectContaining({ + 'sentence-transformers/all-distilroberta-v1': expect.anything(), + }) + ); + }); + + it('should call getPreTrainedModels once after call getPreTrainedModels$ multi times', async () => { + const manager = new ModelRepositoryManager(); + expect(ModelRepository.prototype.getPreTrainedModels).not.toHaveBeenCalled(); + await manager.getPreTrainedModels$().toPromise(); + expect(ModelRepository.prototype.getPreTrainedModels).toHaveBeenCalledTimes(1); + await manager.getPreTrainedModels$().toPromise(); + expect(ModelRepository.prototype.getPreTrainedModels).toHaveBeenCalledTimes(1); + }); + + it('should call getPreTrainedModelConfig with consistent config URL and return consistent config', async () => { + const manager = new ModelRepositoryManager(); + const result = await manager + .getPreTrainedModel$('sentence-transformers/all-distilroberta-v1', 'torch_script') + .toPromise(); + expect(ModelRepository.prototype.getPreTrainedModelConfig).toHaveBeenCalledWith( + 'https://artifacts.opensearch.org/models/ml-models/huggingface/sentence-transformers/all-distilroberta-v1/1.0.1/torch_script/config.json' + ); + expect(result).toEqual( + expect.objectContaining({ + url: + 'https://artifacts.opensearch.org/models/ml-models/huggingface/sentence-transformers/all-distilroberta-v1/1.0.1/torch_script/sentence-transformers_all-distilroberta-v1-1.0.1-torch_script.zip', + config: expect.objectContaining({ + model_content_hash_value: + 'acdc81b652b83121f914c5912ae27c0fca8fabf270e6f191ace6979a19830413', + }), + }) + ); + }); + + it('should call getPreTrainedModelConfig once after call getPreTrainedModel$ multi times', async () => { + const manager = new ModelRepositoryManager(); + expect(ModelRepository.prototype.getPreTrainedModelConfig).not.toHaveBeenCalled(); + await manager + .getPreTrainedModel$('sentence-transformers/all-distilroberta-v1', 'torch_script') + .toPromise(); + expect(ModelRepository.prototype.getPreTrainedModelConfig).toHaveBeenCalledTimes(1); + await manager + .getPreTrainedModel$('sentence-transformers/all-distilroberta-v1', 'torch_script') + .toPromise(); + expect(ModelRepository.prototype.getPreTrainedModelConfig).toHaveBeenCalledTimes(1); + }); +}); diff --git a/server/plugin.ts b/server/plugin.ts index 352eabde..7c8d4b56 100644 --- a/server/plugin.ts +++ b/server/plugin.ts @@ -20,6 +20,7 @@ import { profileRouter, securityRouter, taskRouter, + modelRepositoryRouter, } from './routes'; import { ModelService } from './services'; @@ -48,6 +49,7 @@ export class MlCommonsPlugin implements Plugin fetch(url).then((response: any) => response.json()); + +export const modelRepositoryRouter = (router: IRouter) => { + router.get({ path: MODEL_REPOSITORY_API_ENDPOINT, validate: false }, async () => { + try { + const data = await fetchURLAsJSONData(PRE_TRAINED_MODELS_URL); + return opensearchDashboardsResponseFactory.ok({ body: data }); + } catch (error) { + return opensearchDashboardsResponseFactory.badRequest({ body: error.message }); + } + }); + + router.get( + { + path: `${MODEL_REPOSITORY_CONFIG_URL_API_ENDPOINT}/{configURL}`, + validate: { + params: schema.object({ + configURL: schema.string(), + }), + }, + }, + async (_context, request) => { + try { + const data = await fetchURLAsJSONData(request.params.configURL); + return opensearchDashboardsResponseFactory.ok({ body: data }); + } catch (error) { + return opensearchDashboardsResponseFactory.badRequest({ body: error.message }); + } + } + ); +}; From d22c759a8f28529b8bb64b03cb74a97fdf1f1925 Mon Sep 17 00:00:00 2001 From: Yulong Ruan Date: Thu, 9 Mar 2023 17:59:59 +0800 Subject: [PATCH 23/75] fix: revert legacy pagination methods There will be a separate task to refactor the rest pagination code to the new one and then we can safely get rid of the legacy pagination methods Signed-off-by: Yulong Ruan Signed-off-by: Lin Wang --- server/services/utils/pagination.ts | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) create mode 100644 server/services/utils/pagination.ts diff --git a/server/services/utils/pagination.ts b/server/services/utils/pagination.ts new file mode 100644 index 00000000..c2634206 --- /dev/null +++ b/server/services/utils/pagination.ts @@ -0,0 +1,25 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +export interface Pagination { + currentPage: number; + pageSize: number; + totalRecords: number; + totalPages: number; +} + +export type RequestPagination = Pick; + +export const getQueryFromSize = (pagination: RequestPagination) => ({ + from: Math.max(0, pagination.currentPage - 1) * pagination.pageSize, + size: pagination.pageSize, +}); + +export const getPagination = (currentPage: number, pageSize: number, totalRecords: number) => ({ + currentPage, + pageSize, + totalRecords, + totalPages: Math.ceil(totalRecords / pageSize), +}); From ed8161061d99b8f88fdc691b8817c8bc0fa02020 Mon Sep 17 00:00:00 2001 From: Yulong Ruan Date: Fri, 10 Mar 2023 11:55:50 +0800 Subject: [PATCH 24/75] feat: init model group page (#134) + Navigate to model group page after model creation succeed + Added a new route `/model-registry/model/:id` for model group fixed a couple of issues due to conflicts of merging 2.x to feature/model-registry: 1. pagination refactor 2. removed unused convertModelSource --------- Signed-off-by: Yulong Ruan Signed-off-by: Lin Wang --- common/router.ts | 15 +++++ common/router_paths.ts | 1 + public/components/model_group/index.ts | 6 ++ public/components/model_group/model_group.tsx | 47 ++++++++++++++ public/components/nav_panel.tsx | 2 +- .../__tests__/model_task_manager.test.ts | 4 +- .../__tests__/register_model_details.test.tsx | 2 +- .../__tests__/register_model_form.test.tsx | 62 ++++++++++--------- .../__tests__/register_model_tags.test.tsx | 6 +- .../register_model/model_details.tsx | 6 +- .../model_file_upload_manager.ts | 4 +- .../register_model/model_task_manager.ts | 4 +- .../register_model/register_model.tsx | 22 ++++--- public/utils/model_repository_manager.ts | 2 +- server/services/model_service.ts | 2 +- 15 files changed, 130 insertions(+), 55 deletions(-) create mode 100644 public/components/model_group/index.ts create mode 100644 public/components/model_group/model_group.tsx diff --git a/common/router.ts b/common/router.ts index c16ef73a..9f5dd420 100644 --- a/common/router.ts +++ b/common/router.ts @@ -3,6 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ +import { ModelGroup } from '../public/components/model_group'; import { ModelList } from '../public/components/model_list'; import { Monitoring } from '../public/components/monitoring'; import { RegisterModelForm } from '../public/components/register_model/register_model'; @@ -13,6 +14,10 @@ interface RouteConfig { Component: React.ComponentType; label: string; exact?: boolean; + /** + * true: display route in nav bar + */ + nav: boolean; } export const ROUTES: RouteConfig[] = [ @@ -20,16 +25,26 @@ export const ROUTES: RouteConfig[] = [ path: routerPaths.overview, Component: Monitoring, label: 'Overview', + nav: true, }, { path: routerPaths.registerModel, label: 'Register Model', Component: RegisterModelForm, + nav: true, }, { path: routerPaths.modelList, label: 'Model List', Component: ModelList, + nav: true, + }, + { + path: routerPaths.modelGroup, + // TODO: refactor label to be dynamic so that we can display group name in breadcrumb + label: 'Model Group', + Component: ModelGroup, + nav: false, }, ]; diff --git a/common/router_paths.ts b/common/router_paths.ts index 3282976f..105cd17f 100644 --- a/common/router_paths.ts +++ b/common/router_paths.ts @@ -9,4 +9,5 @@ export const routerPaths = { monitoring: '/monitoring', registerModel: '/model-registry/register-model/:id?', modelList: '/model-registry/model-list', + modelGroup: '/model-registry/model/:id', }; diff --git a/public/components/model_group/index.ts b/public/components/model_group/index.ts new file mode 100644 index 00000000..8bd34dad --- /dev/null +++ b/public/components/model_group/index.ts @@ -0,0 +1,6 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +export * from './model_group'; diff --git a/public/components/model_group/model_group.tsx b/public/components/model_group/model_group.tsx new file mode 100644 index 00000000..80f30f61 --- /dev/null +++ b/public/components/model_group/model_group.tsx @@ -0,0 +1,47 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import { EuiButton, EuiLoadingSpinner, EuiPageHeader, EuiPanel, EuiText } from '@elastic/eui'; +import React from 'react'; +import { useParams } from 'react-router-dom'; +import { useFetcher } from '../../hooks'; +import { APIProvider } from '../../apis/api_provider'; + +export const ModelGroup = () => { + const { id: modelId } = useParams<{ id: string }>(); + const { data, loading, error } = useFetcher(APIProvider.getAPI('model').getOne, modelId); + + if (loading) { + // TODO: need to update per design + return ; + } + + if (error) { + // TODO: need to update per design + return 'Error happened while loading the model'; + } + + return ( + <> + +

{data?.name}

+ + } + rightSideItems={[ + Register version, + Edit, + Delete, + ]} + /> + + +

Versions

+
+
+ + ); +}; diff --git a/public/components/nav_panel.tsx b/public/components/nav_panel.tsx index 7a779d8c..db799cb6 100644 --- a/public/components/nav_panel.tsx +++ b/public/components/nav_panel.tsx @@ -13,7 +13,7 @@ export function NavPanel() { const location = useLocation(); const items = useMemo( () => - ROUTES.filter((item) => !!item.label).map((item) => { + ROUTES.filter((item) => !!item.label && item.nav).map((item) => { const href = generatePath(item.path); return { id: href, diff --git a/public/components/register_model/__tests__/model_task_manager.test.ts b/public/components/register_model/__tests__/model_task_manager.test.ts index 7e9b1b4f..e93658f7 100644 --- a/public/components/register_model/__tests__/model_task_manager.test.ts +++ b/public/components/register_model/__tests__/model_task_manager.test.ts @@ -35,7 +35,7 @@ describe('ModelTaskManager', () => { }); await waitFor(() => { - expect(onCompleteMock).toHaveBeenCalledWith(res); + expect(onCompleteMock).toHaveBeenCalledWith('model_id'); expect(onUpdateMock).toHaveBeenCalledWith(res); expect(onErrorMock).not.toHaveBeenCalled(); }); @@ -83,7 +83,7 @@ describe('ModelTaskManager', () => { await waitFor( () => { - expect(onCompleteMock).toHaveBeenCalledWith(res); + expect(onCompleteMock).toHaveBeenCalledWith('model_id'); expect(onUpdateMock).toHaveBeenCalledWith(res); expect(onErrorMock).not.toHaveBeenCalled(); }, diff --git a/public/components/register_model/__tests__/register_model_details.test.tsx b/public/components/register_model/__tests__/register_model_details.test.tsx index c5478c91..a8462b12 100644 --- a/public/components/register_model/__tests__/register_model_details.test.tsx +++ b/public/components/register_model/__tests__/register_model_details.test.tsx @@ -69,7 +69,7 @@ describe(' Details', () => { const result = await setup(); jest.spyOn(Model.prototype, 'search').mockResolvedValue({ data: [], - pagination: { totalPages: 1, totalRecords: 1, currentPage: 1, pageSize: 1 }, + total_models: 1, }); await result.user.clear(result.nameInput); diff --git a/public/components/register_model/__tests__/register_model_form.test.tsx b/public/components/register_model/__tests__/register_model_form.test.tsx index abad062c..7b26301f 100644 --- a/public/components/register_model/__tests__/register_model_form.test.tsx +++ b/public/components/register_model/__tests__/register_model_form.test.tsx @@ -25,38 +25,34 @@ jest.mock('../../../../../../src/plugins/opensearch_dashboards_react/public', () }); const MOCKED_DATA = { - data: [ - { - id: 'C7jN0YQBjgpeQQ_RmiDE', - model_version: '1.0.7', - created_time: 1669967223491, - model_config: { - all_config: - '{"_name_or_path":"nreimers/MiniLM-L6-H384-uncased","architectures":["BertModel"],"attention_probs_dropout_prob":0.1,"gradient_checkpointing":false,"hidden_act":"gelu","hidden_dropout_prob":0.1,"hidden_size":384,"initializer_range":0.02,"intermediate_size":1536,"layer_norm_eps":1e-12,"max_position_embeddings":512,"model_type":"bert","num_attention_heads":12,"num_hidden_layers":6,"pad_token_id":0,"position_embedding_type":"absolute","transformers_version":"4.8.2","type_vocab_size":2,"use_cache":true,"vocab_size":30522}', - model_type: 'bert', - embedding_dimension: 384, - framework_type: 'SENTENCE_TRANSFORMERS', - }, - last_loaded_time: 1672895017422, - model_format: 'TORCH_SCRIPT', - last_uploaded_time: 1669967226531, - name: 'all-MiniLM-L6-v2', - model_state: 'LOADED', - total_chunks: 9, - model_content_size_in_bytes: 83408741, - algorithm: 'TEXT_EMBEDDING', - model_content_hash_value: '9376c2ebd7c83f99ec2526323786c348d2382e6d86576f750c89ea544d6bbb14', - current_worker_node_count: 1, - planning_worker_node_count: 1, - }, - ], - pagination: { currentPage: 1, pageSize: 1, totalRecords: 1, totalPages: 1 }, + id: 'C7jN0YQBjgpeQQ_RmiDE', + model_version: '1.0.7', + created_time: 1669967223491, + model_config: { + all_config: + '{"_name_or_path":"nreimers/MiniLM-L6-H384-uncased","architectures":["BertModel"],"attention_probs_dropout_prob":0.1,"gradient_checkpointing":false,"hidden_act":"gelu","hidden_dropout_prob":0.1,"hidden_size":384,"initializer_range":0.02,"intermediate_size":1536,"layer_norm_eps":1e-12,"max_position_embeddings":512,"model_type":"bert","num_attention_heads":12,"num_hidden_layers":6,"pad_token_id":0,"position_embedding_type":"absolute","transformers_version":"4.8.2","type_vocab_size":2,"use_cache":true,"vocab_size":30522}', + model_type: 'bert', + embedding_dimension: 384, + framework_type: 'SENTENCE_TRANSFORMERS', + }, + last_loaded_time: 1672895017422, + model_format: 'TORCH_SCRIPT', + last_uploaded_time: 1669967226531, + name: 'all-MiniLM-L6-v2', + model_state: 'LOADED', + total_chunks: 9, + model_content_size_in_bytes: 83408741, + algorithm: 'TEXT_EMBEDDING', + model_content_hash_value: '9376c2ebd7c83f99ec2526323786c348d2382e6d86576f750c89ea544d6bbb14', + current_worker_node_count: 1, + planning_worker_node_count: 1, }; describe(' Form', () => { + const MOCKED_MODEL_ID = 'model_id'; const addDangerMock = jest.fn(); const addSuccessMock = jest.fn(); - const onSubmitMock = jest.fn(); + const onSubmitMock = jest.fn().mockResolvedValue(MOCKED_MODEL_ID); beforeEach(() => { jest.spyOn(PluginContext, 'useOpenSearchDashboards').mockReturnValue({ @@ -79,7 +75,7 @@ describe(' Form', () => { it('should init form when id param in url route', async () => { const mockResult = MOCKED_DATA; - jest.spyOn(Model.prototype, 'search').mockResolvedValue(mockResult); + jest.spyOn(Model.prototype, 'getOne').mockResolvedValue(mockResult); render( @@ -87,7 +83,7 @@ describe(' Form', () => { { route: '/model-registry/register-model/test_model_id' } ); - const { name } = mockResult.data[0]; + const { name } = mockResult; await waitFor(() => { expect(screen.getByText(name)).toBeInTheDocument(); @@ -95,7 +91,7 @@ describe(' Form', () => { }); it('submit button label should be `Register version` when register new version', async () => { - jest.spyOn(Model.prototype, 'search').mockResolvedValue(MOCKED_DATA); + jest.spyOn(Model.prototype, 'getOne').mockResolvedValue(MOCKED_DATA); render( @@ -144,6 +140,12 @@ describe(' Form', () => { expect(addSuccessMock).toHaveBeenCalled(); }); + it('should navigate to model group page when submit succeed', async () => { + const { user } = await setup(); + await user.click(screen.getByRole('button', { name: /register model/i })); + expect(location.href).toContain(`model-registry/model/${MOCKED_MODEL_ID}`); + }); + it('should call addDanger to display an error toast', async () => { jest.spyOn(formAPI, 'submitModelWithFile').mockRejectedValue(new Error('error')); const { user } = await setup(); diff --git a/public/components/register_model/__tests__/register_model_tags.test.tsx b/public/components/register_model/__tests__/register_model_tags.test.tsx index e0f14355..8f14cb16 100644 --- a/public/components/register_model/__tests__/register_model_tags.test.tsx +++ b/public/components/register_model/__tests__/register_model_tags.test.tsx @@ -171,9 +171,11 @@ describe(' Tags', () => { } // 25 tags are displayed - waitFor(() => expect(screen.queryAllByTestId(/ml-tagKey/i)).toHaveLength(25)); + await waitFor(() => expect(screen.queryAllByTestId(/ml-tagKey/i)).toHaveLength(25)); // add new tag button should not be displayed - waitFor(() => expect(screen.getByRole('button', { name: /add new tag/i })).toBeDisabled()); + await waitFor(() => + expect(screen.getByRole('button', { name: /add new tag/i })).toBeDisabled() + ); }, // The test will fail due to timeout as we interact with the page a lot(24 button click to add new tags) // So we try to increase test running timeout to 60000ms to mitigate the timeout issue diff --git a/public/components/register_model/model_details.tsx b/public/components/register_model/model_details.tsx index 74598005..5ff3513b 100644 --- a/public/components/register_model/model_details.tsx +++ b/public/components/register_model/model_details.tsx @@ -15,10 +15,10 @@ const DESCRIPTION_MAX_LENGTH = 200; const isUniqueModelName = async (name: string) => { const searchResult = await APIProvider.getAPI('model').search({ name, - pageSize: 1, - currentPage: 1, + from: 0, + size: 1, }); - return searchResult.pagination.totalRecords >= 1; + return searchResult.total_models >= 1; }; export const ModelDetailsPanel = () => { diff --git a/public/components/register_model/model_file_upload_manager.ts b/public/components/register_model/model_file_upload_manager.ts index 249c0ac9..26bdc15c 100644 --- a/public/components/register_model/model_file_upload_manager.ts +++ b/public/components/register_model/model_file_upload_manager.ts @@ -19,7 +19,7 @@ interface UploadOptions { modelId: string; onUpdate?: (status: FileUploadStatus) => void; onError?: () => void; - onComplete?: () => void; + onComplete?: (modelId: string) => void; } const MIN_CHUNK_SIZE = 10 * 1000 * 1000; @@ -63,7 +63,7 @@ export class ModelFileUploadManager { this.uploads.next(this.uploads.getValue().filter((obs) => obs !== observable)); if (options.onComplete) { - options.onComplete(); + options.onComplete(options.modelId); } }, }); diff --git a/public/components/register_model/model_task_manager.ts b/public/components/register_model/model_task_manager.ts index 625a4201..96149034 100644 --- a/public/components/register_model/model_task_manager.ts +++ b/public/components/register_model/model_task_manager.ts @@ -12,7 +12,7 @@ interface TaskQueryOptions { taskId: string; onUpdate?: (status: TaskGetOneResponse) => void; onError?: (err: Error) => void; - onComplete?: (status: TaskGetOneResponse) => void; + onComplete?: (modelId: string) => void; } // Model download task is still running if @@ -62,7 +62,7 @@ export class ModelTaskManager { } // Model download task is complete if model id exists if (res.model_id && options.onComplete) { - options.onComplete(res); + options.onComplete(res.model_id); } if (res.error && options.onError) { diff --git a/public/components/register_model/register_model.tsx b/public/components/register_model/register_model.tsx index 0dd03f5b..1ce778ef 100644 --- a/public/components/register_model/register_model.tsx +++ b/public/components/register_model/register_model.tsx @@ -5,7 +5,7 @@ import React, { useCallback, useEffect, useState } from 'react'; import { FieldErrors, useForm, FormProvider } from 'react-hook-form'; -import { useHistory, useParams } from 'react-router-dom'; +import { generatePath, useHistory, useParams } from 'react-router-dom'; import { EuiPageHeader, EuiSpacer, @@ -86,7 +86,10 @@ export const RegisterModelForm = () => { const onSubmit = useCallback( async (data: ModelFileFormData | ModelUrlFormData) => { try { - const onComplete = () => { + const onComplete = (modelId: string) => { + // Navigate to model group page + history.push(generatePath(routerPaths.modelGroup, { id: modelId })); + notifications?.toasts.addSuccess({ title: mountReactNode( @@ -141,7 +144,8 @@ export const RegisterModelForm = () => { notifications?.toasts.addSuccess({ title: mountReactNode( - {form.getValues('name')} was created + {form.getValues('name')} model creation + complete. ), text: @@ -168,19 +172,17 @@ export const RegisterModelForm = () => { useEffect(() => { if (!latestVersionId) return; const initializeForm = async () => { - const { data } = await APIProvider.getAPI('model').search({ - ids: [latestVersionId], - from: 0, - size: 1, - }); - if (data?.[0]) { + try { + const data = await APIProvider.getAPI('model').getOne(latestVersionId); // TODO: clarify which fields to pre-populate - const { model_version: modelVersion, name, model_config: modelConfig } = data?.[0]; + const { model_version: modelVersion, name, model_config: modelConfig } = data; const newVersion = upgradeModelVersion(modelVersion); form.setValue('name', name); form.setValue('version', newVersion); form.setValue('configuration', modelConfig?.all_config ?? ''); setModelGroupName(name); + } catch (e) { + // TODO: handle error here } }; initializeForm(); diff --git a/public/utils/model_repository_manager.ts b/public/utils/model_repository_manager.ts index 14feba80..f53b4f39 100644 --- a/public/utils/model_repository_manager.ts +++ b/public/utils/model_repository_manager.ts @@ -3,7 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { Observable, of, from } from 'rxjs'; +import { Observable, from } from 'rxjs'; import { map, switchMap } from 'rxjs/operators'; import { APIProvider } from '../apis/api_provider'; diff --git a/server/services/model_service.ts b/server/services/model_service.ts index f757c3fd..efd80d09 100644 --- a/server/services/model_service.ts +++ b/server/services/model_service.ts @@ -129,7 +129,7 @@ export class ModelService { }); return { id: modelId, - ...convertModelSource(modelSource), + ...modelSource, }; } From 22ebc9c9c08c5cef026cb39510b54ff23dd36781 Mon Sep 17 00:00:00 2001 From: Lin Wang Date: Fri, 10 Mar 2023 15:12:03 +0800 Subject: [PATCH 25/75] Featuer/fill pre trained model data to register form (#135) * feat: clear observable object after error catched Signed-off-by: Lin Wang * feat: update pre-trained model config data Signed-off-by: Lin Wang * feat: fill pre-trained name description configuration in model import form Signed-off-by: Lin Wang * feat: add loading screen for importing pre-trained model form Signed-off-by: Lin Wang * feat: add todo for failed to load pre-trained model Signed-off-by: Lin Wang --------- Signed-off-by: Lin Wang --- public/apis/__mocks__/model_repository.ts | 14 +-- .../register_model_artifact.test.tsx | 4 +- .../__tests__/register_model_form.test.tsx | 36 ++++++-- .../register_model/__tests__/setup.tsx | 18 +++- .../register_model/register_model.tsx | 90 +++++++++++++------ public/utils/model_repository_manager.ts | 36 +++++--- .../tests/model_repository_manager.test.ts | 30 ++++++- 7 files changed, 170 insertions(+), 58 deletions(-) diff --git a/public/apis/__mocks__/model_repository.ts b/public/apis/__mocks__/model_repository.ts index 798a0622..23d70856 100644 --- a/public/apis/__mocks__/model_repository.ts +++ b/public/apis/__mocks__/model_repository.ts @@ -28,22 +28,22 @@ export class ModelRepository { public getPreTrainedModelConfig() { return Promise.resolve({ - name: 'sentence-transformers/msmarco-distilbert-base-tas-b', + name: 'sentence-transformers/all-distilroberta-v1', version: '1.0.1', description: - 'This is a port of the DistilBert TAS-B Model to sentence-transformers model: It maps sentences & paragraphs to a 768 dimensional dense vector space and is optimized for the task of semantic search.', + 'This is a sentence-transformers model: It maps sentences & paragraphs to a 768 dimensional dense vector space and can be used for tasks like clustering or semantic search.', model_task_type: 'TEXT_EMBEDDING', model_format: 'TORCH_SCRIPT', - model_content_size_in_bytes: 266352827, - model_content_hash_value: 'acdc81b652b83121f914c5912ae27c0fca8fabf270e6f191ace6979a19830413', + model_content_size_in_bytes: 330811571, + model_content_hash_value: '92bc10216c720b57a6bab1d7ca2cc2e559156997212a7f0d8bb70f2edfedc78b', model_config: { - model_type: 'distilbert', + model_type: 'roberta', embedding_dimension: 768, framework_type: 'sentence_transformers', all_config: - '{"_name_or_path":"old_models/msmarco-distilbert-base-tas-b/0_Transformer","activation":"gelu","architectures":["DistilBertModel"],"attention_dropout":0.1,"dim":768,"dropout":0.1,"hidden_dim":3072,"initializer_range":0.02,"max_position_embeddings":512,"model_type":"distilbert","n_heads":12,"n_layers":6,"pad_token_id":0,"qa_dropout":0.1,"seq_classif_dropout":0.2,"sinusoidal_pos_embds":false,"tie_weights_":true,"transformers_version":"4.7.0","vocab_size":30522}', + '{"_name_or_path":"distilroberta-base","architectures":["RobertaForMaskedLM"],"attention_probs_dropout_prob":0.1,"bos_token_id":0,"eos_token_id":2,"gradient_checkpointing":false,"hidden_act":"gelu","hidden_dropout_prob":0.1,"hidden_size":768,"initializer_range":0.02,"intermediate_size":3072,"layer_norm_eps":0.00001,"max_position_embeddings":514,"model_type":"roberta","num_attention_heads":12,"num_hidden_layers":6,"pad_token_id":1,"position_embedding_type":"absolute","transformers_version":"4.8.2","type_vocab_size":1,"use_cache":true,"vocab_size":50265}', }, - created_time: 1676073973126, + created_time: 1676072210947, }); } } diff --git a/public/components/register_model/__tests__/register_model_artifact.test.tsx b/public/components/register_model/__tests__/register_model_artifact.test.tsx index 422136d7..763985a5 100644 --- a/public/components/register_model/__tests__/register_model_artifact.test.tsx +++ b/public/components/register_model/__tests__/register_model_artifact.test.tsx @@ -10,6 +10,8 @@ import { ModelFileUploadManager } from '../model_file_upload_manager'; import * as formAPI from '../register_model_api'; import { ONE_GB } from '../../../../common/constant'; +jest.mock('../../../apis/model_repository'); + describe(' Artifact', () => { const onSubmitWithFileMock = jest.fn(); const onSubmitWithURLMock = jest.fn(); @@ -39,7 +41,7 @@ describe(' Artifact', () => { }); it('should not render an artifact panel if importing an opensearch defined model', async () => { - await setup({ route: '/?type=import' }); + await setup({ route: '/?type=import&name=sentence-transformers/all-distilroberta-v1' }); expect(screen.queryByLabelText(/file/i, { selector: 'input[type="file"]' })).toBeNull(); }); diff --git a/public/components/register_model/__tests__/register_model_form.test.tsx b/public/components/register_model/__tests__/register_model_form.test.tsx index 7b26301f..72c2eef3 100644 --- a/public/components/register_model/__tests__/register_model_form.test.tsx +++ b/public/components/register_model/__tests__/register_model_form.test.tsx @@ -23,6 +23,7 @@ jest.mock('../../../../../../src/plugins/opensearch_dashboards_react/public', () ...jest.requireActual('../../../../../../src/plugins/opensearch_dashboards_react/public'), }; }); +jest.mock('../../../apis/model_repository'); const MOCKED_DATA = { id: 'C7jN0YQBjgpeQQ_RmiDE', @@ -104,15 +105,38 @@ describe(' Form', () => { }); it('submit button label should be `Register model` when import a model', async () => { - render( - - - , - { route: '/model-registry/register-model?type=import' } - ); + await setup({ + route: '/?type=import&name=sentence-transformers/all-distilroberta-v1', + ignoreFillFields: ['name', 'description'], + }); expect(screen.getByRole('button', { name: /register model/i })).toBeInTheDocument(); }); + it('should call submitModelWithURL with pre-filled model data after register model button clicked', async () => { + jest.spyOn(formAPI, 'submitModelWithURL').mockImplementation(onSubmitMock); + const { user } = await setup({ + route: '/?type=import&name=sentence-transformers/all-distilroberta-v1', + ignoreFillFields: ['name', 'description'], + }); + await waitFor(() => + expect(screen.getByLabelText(/^name$/i).value).toEqual( + 'sentence-transformers/all-distilroberta-v1' + ) + ); + expect(onSubmitMock).not.toHaveBeenCalled(); + await user.click(screen.getByRole('button', { name: /register model/i })); + expect(onSubmitMock).toHaveBeenCalledWith( + expect.objectContaining({ + name: 'sentence-transformers/all-distilroberta-v1', + description: + 'This is a sentence-transformers model: It maps sentences & paragraphs to a 768 dimensional dense vector space and can be used for tasks like clustering or semantic search.', + modelURL: + 'https://artifacts.opensearch.org/models/ml-models/huggingface/sentence-transformers/all-distilroberta-v1/1.0.1/torch_script/sentence-transformers_all-distilroberta-v1-1.0.1-torch_script.zip', + configuration: expect.stringContaining('sentence_transformers'), + }) + ); + }); + it('submit button label should be `Register model` when register new model', async () => { render( diff --git a/public/components/register_model/__tests__/setup.tsx b/public/components/register_model/__tests__/setup.tsx index ae9e9755..6d114102 100644 --- a/public/components/register_model/__tests__/setup.tsx +++ b/public/components/register_model/__tests__/setup.tsx @@ -8,12 +8,17 @@ import userEvent from '@testing-library/user-event'; import { RegisterModelForm } from '../register_model'; import { Model } from '../../../apis/model'; -import { render, RenderWithRouteProps, screen } from '../../../../test/test_utils'; +import { render, RenderWithRouteProps, screen, waitFor } from '../../../../test/test_utils'; jest.mock('../../../apis/task'); -export async function setup(options?: RenderWithRouteProps) { +interface SetupOptions extends RenderWithRouteProps { + ignoreFillFields?: Array<'name' | 'description'>; +} + +export async function setup(options?: SetupOptions) { render(, { route: options?.route ?? '/' }); + await waitFor(() => expect(screen.queryByLabelText('Model Form Loading')).toBe(null)); const nameInput = screen.getByLabelText(/^name$/i); const descriptionInput = screen.getByLabelText(/description/i); const submitButton = screen.getByRole('button', { @@ -29,10 +34,15 @@ export async function setup(options?: RenderWithRouteProps) { data: [], pagination: { totalRecords: 0, currentPage: 1, pageSize: 1, totalPages: 0 }, }); + // fill model name - await user.type(nameInput, 'test model name'); + if (!options?.ignoreFillFields?.includes('name')) { + await user.type(nameInput, 'test model name'); + } // fill model description - await user.type(descriptionInput, 'test model description'); + if (!options?.ignoreFillFields?.includes('description')) { + await user.type(descriptionInput, 'test model description'); + } // fill model file if (modelFileInput) { await user.upload( diff --git a/public/components/register_model/register_model.tsx b/public/components/register_model/register_model.tsx index 1ce778ef..877d6909 100644 --- a/public/components/register_model/register_model.tsx +++ b/public/components/register_model/register_model.tsx @@ -18,6 +18,7 @@ import { EuiFlexItem, EuiTextColor, EuiLink, + EuiLoadingSpinner, } from '@elastic/eui'; import useObservable from 'react-use/lib/useObservable'; import { from } from 'rxjs'; @@ -66,6 +67,7 @@ export const RegisterModelForm = () => { const isLocked = useObservable(chrome?.getIsNavDrawerLocked$() ?? from([false])); const formType = isValidModelRegisterFormType(typeParams) ? typeParams : 'upload'; + const [preTrainedModelLoading, setPreTrainedModelLoading] = useState(formType === 'import'); const partials = formType === 'import' ? [ModelDetailsPanel, ModelTagsPanel, ModelVersionNotesPanel] @@ -194,15 +196,32 @@ export const RegisterModelForm = () => { } const subscriber = modelRepositoryManager .getPreTrainedModel$(nameParams, 'torch_script') - .subscribe((preTrainedModel) => { - // TODO: store pre-trained model data - // eslint-disable-next-line no-console - console.log(preTrainedModel); - }); + .subscribe( + (preTrainedModel) => { + // TODO: Fill model format here + const { config, url } = preTrainedModel; + form.setValue('modelURL', url); + if (config.name) { + form.setValue('name', config.name); + } + if (config.description) { + form.setValue('description', config.description); + } + if (config.model_config) { + form.setValue('configuration', JSON.stringify(config.model_config)); + } + setPreTrainedModelLoading(false); + }, + (error) => { + // TODO: Should handle loading error here + // eslint-disable-next-line no-console + console.log(error); + } + ); return () => { subscriber.unsubscribe(); }; - }, [nameParams]); + }, [nameParams, form]); const onError = useCallback((errors: FieldErrors) => { // TODO @@ -211,6 +230,42 @@ export const RegisterModelForm = () => { }, []); const errorCount = Object.keys(form.formState.errors).length; + const formHeader = ( + <> + + + + {latestVersionId && ( + <> + Register a new version of {modelGroupName}. The version number will be + automatically incremented.  + + Learn More + + . + + )} + {formType === 'import' && !latestVersionId && <>Register a pre-trained model.} + {formType === 'upload' && !latestVersionId && ( + <> + Register your model to manage its life cycle, and facilitate model discovery across + your organization. + + )} + + + + ); + + if (preTrainedModelLoading) { + return ( + + {formHeader} + + + + ); + } return ( @@ -221,28 +276,7 @@ export const RegisterModelForm = () => { component="form" > - - - - {latestVersionId && ( - <> - Register a new version of {modelGroupName}. The version number will be - automatically incremented.  - - Learn More - - . - - )} - {formType === 'import' && !latestVersionId && <>Register a pre-trained model.} - {formType === 'upload' && !latestVersionId && ( - <> - Register your model to manage its life cycle, and facilitate model discovery - across your organization. - - )} - - + {formHeader} {partials.map((FormPartial, i) => ( diff --git a/public/utils/model_repository_manager.ts b/public/utils/model_repository_manager.ts index f53b4f39..39dcd879 100644 --- a/public/utils/model_repository_manager.ts +++ b/public/utils/model_repository_manager.ts @@ -3,8 +3,8 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { Observable, from } from 'rxjs'; -import { map, switchMap } from 'rxjs/operators'; +import { Observable, from, throwError } from 'rxjs'; +import { catchError, map, switchMap } from 'rxjs/operators'; import { APIProvider } from '../apis/api_provider'; interface PreTrainedModelInfo { @@ -25,13 +25,20 @@ interface PreTrainedModels { export class ModelRepositoryManager { private preTrainedModels: Observable | null = null; - private preTrainedModelConfigs: Map> = new Map(); + private preTrainedModelConfigs: Map> = new Map(); constructor() {} getPreTrainedModels$() { if (!this.preTrainedModels) { - this.preTrainedModels = from(APIProvider.getAPI('modelRepository').getPreTrainedModels()); + this.preTrainedModels = from( + APIProvider.getAPI('modelRepository').getPreTrainedModels() + ).pipe( + catchError((err) => { + this.preTrainedModels = null; + return throwError(err); + }) + ); } return this.preTrainedModels; } @@ -45,15 +52,22 @@ export class ModelRepositoryManager { if (!modelConfig$) { modelConfig$ = from( APIProvider.getAPI('modelRepository').getPreTrainedModelConfig(modelInfo.config_url) - ); + ) + .pipe( + map((config) => ({ + url: modelInfo.model_url, + config, + })) + ) + .pipe( + catchError((err) => { + this.preTrainedModelConfigs.delete(modelInfo.config_url); + return throwError(err); + }) + ); this.preTrainedModelConfigs.set(modelInfo.config_url, modelConfig$); } - return modelConfig$.pipe( - map((config) => ({ - url: modelInfo.model_url, - config, - })) - ); + return modelConfig$; }) ); } diff --git a/public/utils/tests/model_repository_manager.test.ts b/public/utils/tests/model_repository_manager.test.ts index 828e0f67..ba7bc19e 100644 --- a/public/utils/tests/model_repository_manager.test.ts +++ b/public/utils/tests/model_repository_manager.test.ts @@ -36,6 +36,16 @@ describe('ModelRepositoryManager', () => { expect(ModelRepository.prototype.getPreTrainedModels).toHaveBeenCalledTimes(1); }); + it('should call getPreTrainedModels twice after getPreTrainedModels throw error first time and call getPreTrainedModels$ multi times', async () => { + const manager = new ModelRepositoryManager(); + const mockError = new Error(); + jest.spyOn(ModelRepository.prototype, 'getPreTrainedModels').mockRejectedValueOnce(mockError); + await expect(manager.getPreTrainedModels$().toPromise()).rejects.toThrowError(mockError); + expect(ModelRepository.prototype.getPreTrainedModels).toHaveBeenCalledTimes(1); + await manager.getPreTrainedModels$().toPromise(); + expect(ModelRepository.prototype.getPreTrainedModels).toHaveBeenCalledTimes(2); + }); + it('should call getPreTrainedModelConfig with consistent config URL and return consistent config', async () => { const manager = new ModelRepositoryManager(); const result = await manager @@ -50,7 +60,7 @@ describe('ModelRepositoryManager', () => { 'https://artifacts.opensearch.org/models/ml-models/huggingface/sentence-transformers/all-distilroberta-v1/1.0.1/torch_script/sentence-transformers_all-distilroberta-v1-1.0.1-torch_script.zip', config: expect.objectContaining({ model_content_hash_value: - 'acdc81b652b83121f914c5912ae27c0fca8fabf270e6f191ace6979a19830413', + '92bc10216c720b57a6bab1d7ca2cc2e559156997212a7f0d8bb70f2edfedc78b', }), }) ); @@ -68,4 +78,22 @@ describe('ModelRepositoryManager', () => { .toPromise(); expect(ModelRepository.prototype.getPreTrainedModelConfig).toHaveBeenCalledTimes(1); }); + + it('should call getPreTrainedModelConfig twice after after getPreTrainedModelConfig throw error first time and call getPreTrainedModel$ multi times', async () => { + const manager = new ModelRepositoryManager(); + const mockError = new Error(); + jest + .spyOn(ModelRepository.prototype, 'getPreTrainedModelConfig') + .mockRejectedValueOnce(mockError); + await expect( + manager + .getPreTrainedModel$('sentence-transformers/all-distilroberta-v1', 'torch_script') + .toPromise() + ).rejects.toThrowError(mockError); + expect(ModelRepository.prototype.getPreTrainedModelConfig).toHaveBeenCalledTimes(1); + await manager + .getPreTrainedModel$('sentence-transformers/all-distilroberta-v1', 'torch_script') + .toPromise(); + expect(ModelRepository.prototype.getPreTrainedModelConfig).toHaveBeenCalledTimes(2); + }); }); From 7654fcd339c097d5b5afe088c54c103f8893851e Mon Sep 17 00:00:00 2001 From: xyinshen Date: Fri, 10 Mar 2023 16:12:38 +0800 Subject: [PATCH 26/75] Feature file version title and configuration description (#130) * fix: update-register-form-hearder-descriptions Signed-off-by: xyinshen * fix: update File & version title Signed-off-by: xyinshen * fix: update File & version title Signed-off-by: xyinshen * fix: update configuration description Signed-off-by: xyinshen * fix: using font-size displayed in figma Signed-off-by: xyinshen * fix: update configuration unit-test Signed-off-by: xyinshen * fix: remove eui-textInheritColor Signed-off-by: xyinshen * fix: remove self-defined css Signed-off-by: xyinshen * fix: add help button Signed-off-by: xyinshen * fix: add help button unit-test Signed-off-by: xyinshen * fix: remove unnecessary change Signed-off-by: xyinshen * fix: restore empty line change Signed-off-by: xyinshen * fix: restore empty line change Signed-off-by: xyinshen * fix: restore unnecessary empty line Signed-off-by: xyinshen --------- Signed-off-by: xyinshen Signed-off-by: Lin Wang --- .../register_model_configuration.test.tsx | 2 +- public/components/register_model/artifact.tsx | 3 ++ .../register_model/model_configuration.tsx | 36 +++++++++++++------ 3 files changed, 30 insertions(+), 11 deletions(-) diff --git a/public/components/register_model/__tests__/register_model_configuration.test.tsx b/public/components/register_model/__tests__/register_model_configuration.test.tsx index 8ff8d618..d8a75f96 100644 --- a/public/components/register_model/__tests__/register_model_configuration.test.tsx +++ b/public/components/register_model/__tests__/register_model_configuration.test.tsx @@ -8,7 +8,7 @@ import { setup } from './setup'; describe(' Configuration', () => { it('should render a help flyout when click help button', async () => { - const { user } = await setup({}); + const { user } = await setup(); expect(screen.getByLabelText('Configuration in JSON')).toBeInTheDocument(); await user.click(screen.getByTestId('model-configuration-help-button')); diff --git a/public/components/register_model/artifact.tsx b/public/components/register_model/artifact.tsx index bb7cb4fa..89b63ada 100644 --- a/public/components/register_model/artifact.tsx +++ b/public/components/register_model/artifact.tsx @@ -17,6 +17,9 @@ export const ArtifactPanel = () => { return (
+ +

File and version information

+

Artifact

diff --git a/public/components/register_model/model_configuration.tsx b/public/components/register_model/model_configuration.tsx index 2e47a721..27f2e029 100644 --- a/public/components/register_model/model_configuration.tsx +++ b/public/components/register_model/model_configuration.tsx @@ -9,8 +9,10 @@ import { EuiTitle, EuiCodeEditor, EuiText, - EuiButtonEmpty, + EuiTextColor, + EuiCode, EuiSpacer, + EuiButtonEmpty, } from '@elastic/eui'; import { useController, useFormContext } from 'react-hook-form'; @@ -47,15 +49,19 @@ export const ConfigurationPanel = () => { - The model configuration JSON object.{' '} - setIsHelpVisible(true)} - size="xs" - color="primary" - data-test-subj="model-configuration-help-button" - > - Help. - + The model configuration specifies the{' '} + + model_type + + , + + embedding_dimension + {' '} + , and{' '} + + framework_type + {' '} + of the model. @@ -64,6 +70,16 @@ export const ConfigurationPanel = () => { label="Configuration in JSON" isInvalid={Boolean(configurationFieldController.fieldState.error)} error={configurationFieldController.fieldState.error?.message} + labelAppend={ + setIsHelpVisible(true)} + size="xs" + color="primary" + data-test-subj="model-configuration-help-button" + > + Help + + } > Date: Mon, 13 Mar 2023 09:48:58 +0800 Subject: [PATCH 27/75] fix: tweak test mocks (#137) Signed-off-by: Yulong Ruan Signed-off-by: Lin Wang --- .../__tests__/register_model_artifact.test.tsx | 2 +- .../__tests__/register_model_details.test.tsx | 2 +- .../__tests__/register_model_metrics.test.tsx | 2 +- .../register_model/__tests__/register_model_tags.test.tsx | 2 +- .../__tests__/register_model_version_notes.test.tsx | 2 +- public/components/register_model/__tests__/setup.tsx | 6 +----- 6 files changed, 6 insertions(+), 10 deletions(-) diff --git a/public/components/register_model/__tests__/register_model_artifact.test.tsx b/public/components/register_model/__tests__/register_model_artifact.test.tsx index 763985a5..d2a71f05 100644 --- a/public/components/register_model/__tests__/register_model_artifact.test.tsx +++ b/public/components/register_model/__tests__/register_model_artifact.test.tsx @@ -13,7 +13,7 @@ import { ONE_GB } from '../../../../common/constant'; jest.mock('../../../apis/model_repository'); describe(' Artifact', () => { - const onSubmitWithFileMock = jest.fn(); + const onSubmitWithFileMock = jest.fn().mockResolvedValue('model_id'); const onSubmitWithURLMock = jest.fn(); const uploadMock = jest.fn(); diff --git a/public/components/register_model/__tests__/register_model_details.test.tsx b/public/components/register_model/__tests__/register_model_details.test.tsx index a8462b12..55aea11e 100644 --- a/public/components/register_model/__tests__/register_model_details.test.tsx +++ b/public/components/register_model/__tests__/register_model_details.test.tsx @@ -9,7 +9,7 @@ import * as formAPI from '../register_model_api'; import { Model } from '../../../apis/model'; describe(' Details', () => { - const onSubmitMock = jest.fn(); + const onSubmitMock = jest.fn().mockResolvedValue('model_id'); beforeEach(() => { jest diff --git a/public/components/register_model/__tests__/register_model_metrics.test.tsx b/public/components/register_model/__tests__/register_model_metrics.test.tsx index 0e9c6653..e5b3a2ca 100644 --- a/public/components/register_model/__tests__/register_model_metrics.test.tsx +++ b/public/components/register_model/__tests__/register_model_metrics.test.tsx @@ -9,7 +9,7 @@ import * as formHooks from '../register_model.hooks'; import * as formAPI from '../register_model_api'; describe(' Evaluation Metrics', () => { - const onSubmitMock = jest.fn(); + const onSubmitMock = jest.fn().mockResolvedValue('model_id'); beforeEach(() => { jest diff --git a/public/components/register_model/__tests__/register_model_tags.test.tsx b/public/components/register_model/__tests__/register_model_tags.test.tsx index 8f14cb16..db0a6a42 100644 --- a/public/components/register_model/__tests__/register_model_tags.test.tsx +++ b/public/components/register_model/__tests__/register_model_tags.test.tsx @@ -9,7 +9,7 @@ import * as formHooks from '../register_model.hooks'; import * as formAPI from '../register_model_api'; describe(' Tags', () => { - const onSubmitMock = jest.fn(); + const onSubmitMock = jest.fn().mockResolvedValue('model_id'); beforeEach(() => { jest diff --git a/public/components/register_model/__tests__/register_model_version_notes.test.tsx b/public/components/register_model/__tests__/register_model_version_notes.test.tsx index 2a65bd67..64bd154d 100644 --- a/public/components/register_model/__tests__/register_model_version_notes.test.tsx +++ b/public/components/register_model/__tests__/register_model_version_notes.test.tsx @@ -8,7 +8,7 @@ import * as formHooks from '../register_model.hooks'; import * as formAPI from '../register_model_api'; describe(' Version notes', () => { - const onSubmitMock = jest.fn(); + const onSubmitMock = jest.fn().mockResolvedValue('model_id'); beforeEach(() => { jest diff --git a/public/components/register_model/__tests__/setup.tsx b/public/components/register_model/__tests__/setup.tsx index 6d114102..c8634f19 100644 --- a/public/components/register_model/__tests__/setup.tsx +++ b/public/components/register_model/__tests__/setup.tsx @@ -30,11 +30,7 @@ export async function setup(options?: SetupOptions) { const versionNotesInput = screen.getByLabelText(/notes/i); // Mock model name unique - jest.spyOn(Model.prototype, 'search').mockResolvedValue({ - data: [], - pagination: { totalRecords: 0, currentPage: 1, pageSize: 1, totalPages: 0 }, - }); - + jest.spyOn(Model.prototype, 'search').mockResolvedValue({ data: [], total_models: 0 }); // fill model name if (!options?.ignoreFillFields?.includes('name')) { await user.type(nameInput, 'test model name'); From 2ba8108f2c2350d720881cc617e0ab21b3c231fa Mon Sep 17 00:00:00 2001 From: xyinshen Date: Mon, 13 Mar 2023 17:07:45 +0800 Subject: [PATCH 28/75] Feature update description max width 725 (#140) * fix: update-register-form-hearder-descriptions Signed-off-by: xyinshen * fix: update File & version title Signed-off-by: xyinshen * fix: update File & version title Signed-off-by: xyinshen * fix: using font-size displayed in figma Signed-off-by: xyinshen * fix: remove self-defined css Signed-off-by: xyinshen * fix: add help button Signed-off-by: xyinshen * fix: add help button unit-test Signed-off-by: xyinshen * fix: remove unnecessary change Signed-off-by: xyinshen * fix: restore unnecessary empty line Signed-off-by: xyinshen * fix: update form_partial_description_max-width-to-725px Signed-off-by: xyinshen * fix: fix test_error Signed-off-by: xyinshen --------- Signed-off-by: xyinshen Signed-off-by: Lin Wang --- public/components/register_model/artifact.tsx | 2 +- public/components/register_model/evaluation_metrics.tsx | 2 +- public/components/register_model/model_configuration.tsx | 2 +- public/components/register_model/model_tags.tsx | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/public/components/register_model/artifact.tsx b/public/components/register_model/artifact.tsx index 89b63ada..6cfa413a 100644 --- a/public/components/register_model/artifact.tsx +++ b/public/components/register_model/artifact.tsx @@ -23,7 +23,7 @@ export const ArtifactPanel = () => {

Artifact

- + The zipped artifact must include a model file and a tokenizer file. If uploading with a local file, keep this browser open util the upload completes.{' '} diff --git a/public/components/register_model/evaluation_metrics.tsx b/public/components/register_model/evaluation_metrics.tsx index a46c26ab..16378c1a 100644 --- a/public/components/register_model/evaluation_metrics.tsx +++ b/public/components/register_model/evaluation_metrics.tsx @@ -146,7 +146,7 @@ export const EvaluationMetricsPanel = () => { Evaluation Metrics - optional - + Track training, validation, and testing metrics to compare across versions and models. diff --git a/public/components/register_model/model_configuration.tsx b/public/components/register_model/model_configuration.tsx index 27f2e029..e43773f9 100644 --- a/public/components/register_model/model_configuration.tsx +++ b/public/components/register_model/model_configuration.tsx @@ -47,7 +47,7 @@ export const ConfigurationPanel = () => {

Configuration

- + The model configuration specifies the{' '} diff --git a/public/components/register_model/model_tags.tsx b/public/components/register_model/model_tags.tsx index 8a26c994..89d088a2 100644 --- a/public/components/register_model/model_tags.tsx +++ b/public/components/register_model/model_tags.tsx @@ -32,7 +32,7 @@ export const ModelTagsPanel = () => { Tags - optional - + Add tags to facilitate model discovery and tracking across your organization. From c24d55dca7a90e1fc61cf90d42e46083c223622a Mon Sep 17 00:00:00 2001 From: Yulong Ruan Date: Tue, 14 Mar 2023 15:03:37 +0800 Subject: [PATCH 29/75] feat: disallow user to type if text exceed max length (#138) The rule is applied to: 1. Model name input 2. Model description input 3. Model note input + removed validation rules on name, description and note input + renamed isUniqueModelName to isDuplicateModelName --------- Signed-off-by: Yulong Ruan Signed-off-by: Lin Wang --- .../__tests__/register_model_details.test.tsx | 22 ++++--------------- .../register_model_version_notes.test.tsx | 11 ++-------- .../register_model/model_details.tsx | 13 ++++++----- .../register_model/model_version_notes.tsx | 4 ++-- 4 files changed, 15 insertions(+), 35 deletions(-) diff --git a/public/components/register_model/__tests__/register_model_details.test.tsx b/public/components/register_model/__tests__/register_model_details.test.tsx index 55aea11e..9fd6b7cd 100644 --- a/public/components/register_model/__tests__/register_model_details.test.tsx +++ b/public/components/register_model/__tests__/register_model_details.test.tsx @@ -50,19 +50,12 @@ describe(' Details', () => { expect(onSubmitMock).not.toHaveBeenCalled(); }); - it('should NOT submit the register model form if model name length exceeded 80', async () => { + it('should NOT allow to input a model name which exceed max length: 80', async () => { const result = await setup(); - await result.user.clear(result.nameInput); - await result.user.type(result.nameInput, 'x'.repeat(80)); - expect(result.nameInput).toBeValid(); - await result.user.clear(result.nameInput); await result.user.type(result.nameInput, 'x'.repeat(81)); - expect(result.nameInput).toBeInvalid(); - - await result.user.click(result.submitButton); - expect(onSubmitMock).not.toHaveBeenCalled(); + expect(result.nameInput.value).toHaveLength(80); }); it('should NOT submit the register model form if model name is duplicated', async () => { @@ -90,18 +83,11 @@ describe(' Details', () => { expect(onSubmitMock).not.toHaveBeenCalled(); }); - it('should NOT submit the register model form if model description length exceed 200', async () => { + it('should NOT allow to input a model description which exceed max length: 200', async () => { const result = await setup(); - await result.user.clear(result.descriptionInput); - await result.user.type(result.descriptionInput, 'x'.repeat(200)); - expect(result.descriptionInput).toBeValid(); - await result.user.clear(result.descriptionInput); await result.user.type(result.descriptionInput, 'x'.repeat(201)); - expect(result.descriptionInput).toBeInvalid(); - - await result.user.click(result.submitButton); - expect(onSubmitMock).not.toHaveBeenCalled(); + expect(result.descriptionInput.value).toHaveLength(200); }); }); diff --git a/public/components/register_model/__tests__/register_model_version_notes.test.tsx b/public/components/register_model/__tests__/register_model_version_notes.test.tsx index 64bd154d..560ba174 100644 --- a/public/components/register_model/__tests__/register_model_version_notes.test.tsx +++ b/public/components/register_model/__tests__/register_model_version_notes.test.tsx @@ -36,18 +36,11 @@ describe(' Version notes', () => { expect(onSubmitMock).toHaveBeenCalled(); }); - it('should NOT submit the register model form if model version notes length exceed 200', async () => { + it('should NOT allow to input a model note which exceed max length: 200', async () => { const result = await setup(); - await result.user.clear(result.versionNotesInput); - await result.user.type(result.versionNotesInput, 'x'.repeat(200)); - expect(result.versionNotesInput).toBeValid(); - await result.user.clear(result.versionNotesInput); await result.user.type(result.versionNotesInput, 'x'.repeat(201)); - expect(result.versionNotesInput).toBeInvalid(); - - await result.user.click(result.submitButton); - expect(onSubmitMock).not.toHaveBeenCalled(); + expect(result.versionNotesInput.value).toHaveLength(200); }); }); diff --git a/public/components/register_model/model_details.tsx b/public/components/register_model/model_details.tsx index 5ff3513b..337f60f8 100644 --- a/public/components/register_model/model_details.tsx +++ b/public/components/register_model/model_details.tsx @@ -12,7 +12,7 @@ import { APIProvider } from '../../apis/api_provider'; const NAME_MAX_LENGTH = 80; const DESCRIPTION_MAX_LENGTH = 200; -const isUniqueModelName = async (name: string) => { +const isDuplicateModelName = async (name: string) => { const searchResult = await APIProvider.getAPI('model').search({ name, from: 0, @@ -30,11 +30,10 @@ export const ModelDetailsPanel = () => { rules: { required: { value: true, message: 'Name can not be empty' }, validate: async (name) => { - return !modelNameFocusedRef.current && !!name && (await isUniqueModelName(name)) + return !modelNameFocusedRef.current && !!name && (await isDuplicateModelName(name)) ? 'This name is already in use. Use a unique name for the model.' : undefined; }, - maxLength: { value: NAME_MAX_LENGTH, message: 'Text exceed max length' }, }, }); @@ -43,7 +42,6 @@ export const ModelDetailsPanel = () => { control, rules: { required: { value: true, message: 'Description can not be empty' }, - maxLength: { value: DESCRIPTION_MAX_LENGTH, message: 'Text exceed max length' }, }, }); @@ -71,7 +69,8 @@ export const ModelDetailsPanel = () => { error={nameFieldController.fieldState.error?.message} helpText={ - {Math.max(NAME_MAX_LENGTH - nameField.value.length, 0)} characters allowed. + {Math.max(NAME_MAX_LENGTH - nameField.value.length, 0)} characters{' '} + {nameField.value.length ? 'left' : 'allowed'}.
Use a unique for the model.
@@ -80,6 +79,7 @@ export const ModelDetailsPanel = () => { { helpText={`${Math.max( DESCRIPTION_MAX_LENGTH - descriptionField.value.length, 0 - )} characters allowed.`} + )} characters ${descriptionField.value.length ? 'left' : 'allowed'}.`} > diff --git a/public/components/register_model/model_version_notes.tsx b/public/components/register_model/model_version_notes.tsx index 839b6340..e0983c5b 100644 --- a/public/components/register_model/model_version_notes.tsx +++ b/public/components/register_model/model_version_notes.tsx @@ -17,7 +17,6 @@ export const ModelVersionNotesPanel = () => { const fieldController = useController({ name: 'versionNotes', control, - rules: { maxLength: { value: VERSION_NOTES_MAX_LENGTH, message: 'Text exceed max length' } }, }); const { ref, ...versionNotesField } = fieldController.field; @@ -33,7 +32,7 @@ export const ModelVersionNotesPanel = () => { helpText={`${Math.max( VERSION_NOTES_MAX_LENGTH - (versionNotesField.value?.length ?? 0), 0 - )} characters allowed.`} + )} characters ${versionNotesField.value?.length ? 'left' : 'allowed'}.`} isInvalid={Boolean(fieldController.fieldState.error)} error={fieldController.fieldState.error?.message} label="Notes" @@ -41,6 +40,7 @@ export const ModelVersionNotesPanel = () => { From 8a7ed889468ef40d049e71816a6bf896d690ab88 Mon Sep 17 00:00:00 2001 From: Lin Wang Date: Thu, 16 Mar 2023 09:33:22 +0800 Subject: [PATCH 30/75] Feature/update model register tags logic (#142) * refactor: use mode to finish different register form setup Signed-off-by: Lin Wang * feat: update tag logic and UI according the new design Signed-off-by: Lin Wang * feat: remove metric input Signed-off-by: Lin Wang * feat: update tags panel order for register own model Signed-off-by: Lin Wang * refactor: change to use setup init register form Signed-off-by: Lin Wang * test: add case for hidden selected item in option list Signed-off-by: Lin Wang * fix: add back selected value in value option list Signed-off-by: Lin Wang * feat: update validate to avoid duplicate tag keys Signed-off-by: Lin Wang * feat: update tags panel description for register new version Signed-off-by: Lin Wang * test: correct error message in register model setup Signed-off-by: Lin Wang * chore: update the annotation for duplicate tag key Signed-off-by: Lin Wang --------- Signed-off-by: Lin Wang --- .../register_model_artifact.test.tsx | 3 - .../__tests__/register_model_details.test.tsx | 3 - .../__tests__/register_model_form.test.tsx | 28 +-- .../__tests__/register_model_metrics.test.tsx | 131 ----------- .../__tests__/register_model_tags.test.tsx | 109 +++++++-- .../register_model_version_notes.test.tsx | 3 - .../register_model/__tests__/setup.tsx | 76 +++++-- .../register_model/evaluation_metrics.tsx | 210 ------------------ .../components/register_model/model_tags.tsx | 32 ++- .../register_model/register_model.hooks.ts | 21 -- .../register_model/register_model.tsx | 5 +- .../register_model/register_model.types.ts | 8 - .../components/register_model/tag_field.tsx | 149 +++++++------ 13 files changed, 274 insertions(+), 504 deletions(-) delete mode 100644 public/components/register_model/__tests__/register_model_metrics.test.tsx delete mode 100644 public/components/register_model/evaluation_metrics.tsx diff --git a/public/components/register_model/__tests__/register_model_artifact.test.tsx b/public/components/register_model/__tests__/register_model_artifact.test.tsx index d2a71f05..848bef71 100644 --- a/public/components/register_model/__tests__/register_model_artifact.test.tsx +++ b/public/components/register_model/__tests__/register_model_artifact.test.tsx @@ -18,9 +18,6 @@ describe(' Artifact', () => { const uploadMock = jest.fn(); beforeEach(() => { - jest - .spyOn(formHooks, 'useMetricNames') - .mockReturnValue([false, ['Metric 1', 'Metric 2', 'Metric 3', 'Metric 4']]); jest .spyOn(formHooks, 'useModelTags') .mockReturnValue([false, { keys: ['Key1', 'Key2'], values: ['Value1', 'Value2'] }]); diff --git a/public/components/register_model/__tests__/register_model_details.test.tsx b/public/components/register_model/__tests__/register_model_details.test.tsx index 9fd6b7cd..5f10860b 100644 --- a/public/components/register_model/__tests__/register_model_details.test.tsx +++ b/public/components/register_model/__tests__/register_model_details.test.tsx @@ -12,9 +12,6 @@ describe(' Details', () => { const onSubmitMock = jest.fn().mockResolvedValue('model_id'); beforeEach(() => { - jest - .spyOn(formHooks, 'useMetricNames') - .mockReturnValue([false, ['Metric 1', 'Metric 2', 'Metric 3', 'Metric 4']]); jest .spyOn(formHooks, 'useModelTags') .mockReturnValue([false, { keys: ['Key1', 'Key2'], values: ['Value1', 'Value2'] }]); diff --git a/public/components/register_model/__tests__/register_model_form.test.tsx b/public/components/register_model/__tests__/register_model_form.test.tsx index 72c2eef3..cbf20b97 100644 --- a/public/components/register_model/__tests__/register_model_form.test.tsx +++ b/public/components/register_model/__tests__/register_model_form.test.tsx @@ -4,11 +4,8 @@ */ import React from 'react'; -import { Route } from 'react-router-dom'; import { render, screen, waitFor } from '../../../../test/test_utils'; -import { RegisterModelForm } from '../register_model'; -import { routerPaths } from '../../../../common/router_paths'; import { setup } from './setup'; import { Model } from '../../../../public/apis/model'; import * as PluginContext from '../../../../../../src/plugins/opensearch_dashboards_react/public'; @@ -77,12 +74,7 @@ describe(' Form', () => { it('should init form when id param in url route', async () => { const mockResult = MOCKED_DATA; jest.spyOn(Model.prototype, 'getOne').mockResolvedValue(mockResult); - render( - - - , - { route: '/model-registry/register-model/test_model_id' } - ); + await setup({ route: '/test_model_id', mode: 'version' }); const { name } = mockResult; @@ -94,12 +86,7 @@ describe(' Form', () => { it('submit button label should be `Register version` when register new version', async () => { jest.spyOn(Model.prototype, 'getOne').mockResolvedValue(MOCKED_DATA); - render( - - - , - { route: '/model-registry/register-model/test_model_id' } - ); + await setup({ route: '/test_model_id', mode: 'version' }); expect(screen.getByRole('button', { name: /register version/i })).toBeInTheDocument(); }); @@ -107,7 +94,7 @@ describe(' Form', () => { it('submit button label should be `Register model` when import a model', async () => { await setup({ route: '/?type=import&name=sentence-transformers/all-distilroberta-v1', - ignoreFillFields: ['name', 'description'], + mode: 'import', }); expect(screen.getByRole('button', { name: /register model/i })).toBeInTheDocument(); }); @@ -116,7 +103,7 @@ describe(' Form', () => { jest.spyOn(formAPI, 'submitModelWithURL').mockImplementation(onSubmitMock); const { user } = await setup({ route: '/?type=import&name=sentence-transformers/all-distilroberta-v1', - ignoreFillFields: ['name', 'description'], + mode: 'import', }); await waitFor(() => expect(screen.getByLabelText(/^name$/i).value).toEqual( @@ -138,12 +125,7 @@ describe(' Form', () => { }); it('submit button label should be `Register model` when register new model', async () => { - render( - - - , - { route: '/model-registry/register-model' } - ); + await setup(); expect(screen.getByRole('button', { name: /register model/i })).toBeInTheDocument(); }); diff --git a/public/components/register_model/__tests__/register_model_metrics.test.tsx b/public/components/register_model/__tests__/register_model_metrics.test.tsx deleted file mode 100644 index e5b3a2ca..00000000 --- a/public/components/register_model/__tests__/register_model_metrics.test.tsx +++ /dev/null @@ -1,131 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -import { screen } from '../../../../test/test_utils'; -import { setup } from './setup'; -import * as formHooks from '../register_model.hooks'; -import * as formAPI from '../register_model_api'; - -describe(' Evaluation Metrics', () => { - const onSubmitMock = jest.fn().mockResolvedValue('model_id'); - - beforeEach(() => { - jest - .spyOn(formHooks, 'useMetricNames') - .mockReturnValue([false, ['Metric 1', 'Metric 2', 'Metric 3', 'Metric 4']]); - jest - .spyOn(formHooks, 'useModelTags') - .mockReturnValue([false, { keys: ['Key1', 'Key2'], values: ['Value1', 'Value2'] }]); - jest.spyOn(formAPI, 'submitModelWithFile').mockImplementation(onSubmitMock); - }); - - afterEach(() => { - jest.clearAllMocks(); - }); - - it('should render a evaluation metrics panel', async () => { - await setup(); - - expect(screen.getByLabelText(/^metric$/i)).toBeInTheDocument(); - expect(screen.getByLabelText(/training value/i)).toBeInTheDocument(); - expect(screen.getByLabelText(/validation value/i)).toBeInTheDocument(); - expect(screen.getByLabelText(/testing value/i)).toBeInTheDocument(); - }); - - it('should render metric value input as disabled by default', async () => { - await setup(); - - expect(screen.getByLabelText(/training value/i)).toBeDisabled(); - expect(screen.getByLabelText(/validation value/i)).toBeDisabled(); - expect(screen.getByLabelText(/testing value/i)).toBeDisabled(); - }); - - it('should render metric value input as enabled after selecting a metric name', async () => { - const result = await setup(); - - await result.user.click(screen.getByLabelText(/^metric$/i)); - await result.user.click(screen.getByText('Metric 1')); - - expect(screen.getByLabelText(/training value/i)).toBeEnabled(); - expect(screen.getByLabelText(/validation value/i)).toBeEnabled(); - expect(screen.getByLabelText(/testing value/i)).toBeEnabled(); - }); - - it('should submit the form without selecting metric name', async () => { - const result = await setup(); - await result.user.click(result.submitButton); - - expect(onSubmitMock).toHaveBeenCalled(); - }); - - it('should NOT submit the form if metric name is selected but metric value are empty and error message in screen', async () => { - const result = await setup(); - await result.user.click(screen.getByLabelText(/^metric$/i)); - await result.user.click(screen.getByText('Metric 1')); - await result.user.click(result.submitButton); - - expect(onSubmitMock).not.toHaveBeenCalled(); - expect(screen.getByText('At least one value is required. Enter a value')).toBeInTheDocument(); - }); - - it('should submit the form if metric name and all metric value are selected', async () => { - const result = await setup(); - await result.user.click(screen.getByLabelText(/^metric$/i)); - await result.user.click(screen.getByText('Metric 1')); - - await result.user.type(screen.getByLabelText(/training value/i), '1'); - await result.user.type(screen.getByLabelText(/validation value/i), '1'); - await result.user.type(screen.getByLabelText(/testing value/i), '1'); - - await result.user.click(result.submitButton); - - expect(onSubmitMock).toHaveBeenCalled(); - }); - - it('should submit the form if metric name is selected but metric value are partially selected', async () => { - const result = await setup(); - await result.user.click(screen.getByLabelText(/^metric$/i)); - await result.user.click(screen.getByText('Metric 1')); - - // Only input Training metric value - await result.user.type(screen.getByLabelText(/training value/i), '1'); - await result.user.click(result.submitButton); - - expect(onSubmitMock).toHaveBeenCalled(); - }); - - it('should NOT submit the form if metric value < 0', async () => { - const result = await setup(); - await result.user.click(screen.getByLabelText(/^metric$/i)); - await result.user.click(screen.getByText('Metric 1')); - - // Type an invalid value - await result.user.type(screen.getByLabelText(/training value/i), '-.1'); - await result.user.click(result.submitButton); - - expect(onSubmitMock).not.toHaveBeenCalled(); - }); - - it('should NOT submit the form if metric value > 1', async () => { - const result = await setup(); - await result.user.click(screen.getByLabelText(/^metric$/i)); - await result.user.click(screen.getByText('Metric 1')); - - // Type an invalid value - await result.user.type(screen.getByLabelText(/training value/i), '1.1'); - await result.user.click(result.submitButton); - - expect(onSubmitMock).not.toHaveBeenCalled(); - }); - - it('should keep metric value not more than 2 decimal point', async () => { - const result = await setup(); - await result.user.click(screen.getByLabelText(/^metric$/i)); - await result.user.click(screen.getByText('Metric 1')); - - await result.user.type(screen.getByLabelText(/training value/i), '1.111'); - expect(screen.getByLabelText(/training value/i)).toHaveValue(1.11); - }); -}); diff --git a/public/components/register_model/__tests__/register_model_tags.test.tsx b/public/components/register_model/__tests__/register_model_tags.test.tsx index db0a6a42..b54f3158 100644 --- a/public/components/register_model/__tests__/register_model_tags.test.tsx +++ b/public/components/register_model/__tests__/register_model_tags.test.tsx @@ -12,9 +12,6 @@ describe(' Tags', () => { const onSubmitMock = jest.fn().mockResolvedValue('model_id'); beforeEach(() => { - jest - .spyOn(formHooks, 'useMetricNames') - .mockReturnValue([false, ['Metric 1', 'Metric 2', 'Metric 3', 'Metric 4']]); jest .spyOn(formHooks, 'useModelTags') .mockReturnValue([false, { keys: ['Key1', 'Key2'], values: ['Value1', 'Value2'] }]); @@ -121,7 +118,7 @@ describe(' Tags', () => { expect(onSubmitMock).not.toHaveBeenCalled(); }); - it('should NOT allow to submit if it has duplicate tags', async () => { + it('should NOT allow to submit if it has duplicate tags key', async () => { const result = await setup(); // input tag key: 'Key 1' @@ -141,28 +138,26 @@ describe(' Tags', () => { const keyInput2 = within(keyContainer2).getByRole('textbox'); await result.user.type(keyInput2, 'Key 1'); - // input tag key: 'Value 1' + // input tag key: 'Value 2' const valueContainer2 = screen.getByTestId('ml-tagValue2'); const valueInput2 = within(valueContainer2).getByRole('textbox'); - await result.user.type(valueInput2, 'Value 1'); + await result.user.type(valueInput2, 'Value 2'); await result.user.click(result.submitButton); // Display error message expect( - within(keyContainer2).queryByText( - 'This tag has already been added. Remove the duplicate tag.' - ) + within(keyContainer2).queryByText('Tag keys must be unique. Use a unique key.') ).toBeInTheDocument(); // it should not submit the form expect(onSubmitMock).not.toHaveBeenCalled(); }); it( - 'should only allow to add maximum 25 tags', + 'should only allow to add maximum 10 tags', async () => { const result = await setup(); - const MAX_TAG_NUM = 25; + const MAX_TAG_NUM = 10; // It has one tag by default, we can add 24 more tags const addNewTagButton = screen.getByText(/add new tag/i); @@ -170,8 +165,8 @@ describe(' Tags', () => { await result.user.click(addNewTagButton); } - // 25 tags are displayed - await waitFor(() => expect(screen.queryAllByTestId(/ml-tagKey/i)).toHaveLength(25)); + // 10 tags are displayed + await waitFor(() => expect(screen.queryAllByTestId(/ml-tagKey/i)).toHaveLength(10)); // add new tag button should not be displayed await waitFor(() => expect(screen.getByRole('button', { name: /add new tag/i })).toBeDisabled() @@ -208,4 +203,92 @@ describe(' Tags', () => { }) ); }); + + it('should allow adding one more tag when registering new version if model group has only two tags', async () => { + const result = await setup({ + route: '/foo', + mode: 'version', + }); + + await result.user.click(screen.getByText(/add new tag/i)); + + await waitFor(() => + expect(screen.getByRole('button', { name: /add new tag/i })).toBeDisabled() + ); + }); + + it('should prevent creating new tag key when registering new version', async () => { + const result = await setup({ + route: '/foo', + mode: 'version', + }); + + const keyContainer = screen.getByTestId('ml-tagKey1'); + const keyInput = within(keyContainer).getByRole('textbox'); + await result.user.type(keyInput, 'foo{enter}'); + expect( + screen.getByText((content, element) => { + return ( + element?.tagName.toLowerCase() === 'strong' && + content === 'foo' && + element?.nextSibling?.textContent?.trim() === "doesn't match any options" + ); + }) + ).toBeInTheDocument(); + }); + + it('should display error when creating new tag key with more than 80 characters', async () => { + const result = await setup(); + + const keyContainer = screen.getByTestId('ml-tagKey1'); + const keyInput = within(keyContainer).getByRole('textbox'); + await result.user.type(keyInput, `${'x'.repeat(81)}{enter}`); + expect( + within(keyContainer).queryByText('80 characters allowed. Use 80 characters or less.') + ).toBeInTheDocument(); + }); + + it('should display error when creating new tag value with more than 80 characters', async () => { + const result = await setup(); + + const valueContainer = screen.getByTestId('ml-tagValue1'); + const valueInput = within(valueContainer).getByRole('textbox'); + await result.user.type(valueInput, `${'x'.repeat(81)}{enter}`); + expect( + within(valueContainer).queryByText('80 characters allowed. Use 80 characters or less.') + ).toBeInTheDocument(); + }); + + it('should display "No keys found" and "No values found" if no tag keys and no tag values are provided', async () => { + jest.spyOn(formHooks, 'useModelTags').mockReturnValue([false, { keys: [], values: [] }]); + + const result = await setup(); + + const keyContainer = screen.getByTestId('ml-tagKey1'); + const keyInput = within(keyContainer).getByRole('textbox'); + await result.user.click(keyInput); + expect(screen.getByText('No keys found. Add a key.')).toBeInTheDocument(); + + const valueContainer = screen.getByTestId('ml-tagValue1'); + const valueInput = within(valueContainer).getByRole('textbox'); + await result.user.click(valueInput); + expect(screen.getByText('No values found. Add a value.')).toBeInTheDocument(); + }); + + it('should only display "Key2" in the option list after "Key1" selected', async () => { + const result = await setup(); + + const keyContainer = screen.getByTestId('ml-tagKey1'); + const keyInput = within(keyContainer).getByRole('textbox'); + await result.user.click(keyInput); + const optionListContainer = screen.getByTestId('comboBoxOptionsList'); + + expect(within(optionListContainer).getByTitle('Key2')).toBeInTheDocument(); + expect(within(optionListContainer).getByTitle('Key1')).toBeInTheDocument(); + + await result.user.click(within(optionListContainer).getByTitle('Key1')); + + expect(within(optionListContainer).getByTitle('Key2')).toBeInTheDocument(); + expect(within(optionListContainer).queryByTitle('Key1')).toBe(null); + }); }); diff --git a/public/components/register_model/__tests__/register_model_version_notes.test.tsx b/public/components/register_model/__tests__/register_model_version_notes.test.tsx index 560ba174..304e1570 100644 --- a/public/components/register_model/__tests__/register_model_version_notes.test.tsx +++ b/public/components/register_model/__tests__/register_model_version_notes.test.tsx @@ -11,9 +11,6 @@ describe(' Version notes', () => { const onSubmitMock = jest.fn().mockResolvedValue('model_id'); beforeEach(() => { - jest - .spyOn(formHooks, 'useMetricNames') - .mockReturnValue([false, ['Metric 1', 'Metric 2', 'Metric 3', 'Metric 4']]); jest .spyOn(formHooks, 'useModelTags') .mockReturnValue([false, { keys: ['Key1', 'Key2'], values: ['Value1', 'Value2'] }]); diff --git a/public/components/register_model/__tests__/setup.tsx b/public/components/register_model/__tests__/setup.tsx index c8634f19..db2bfd4a 100644 --- a/public/components/register_model/__tests__/setup.tsx +++ b/public/components/register_model/__tests__/setup.tsx @@ -5,6 +5,8 @@ import React from 'react'; import userEvent from '@testing-library/user-event'; +import { Route } from 'react-router-dom'; +import { UserEvent } from '@testing-library/user-event/dist/types/setup/setup'; import { RegisterModelForm } from '../register_model'; import { Model } from '../../../apis/model'; @@ -13,39 +15,83 @@ import { render, RenderWithRouteProps, screen, waitFor } from '../../../../test/ jest.mock('../../../apis/task'); interface SetupOptions extends RenderWithRouteProps { - ignoreFillFields?: Array<'name' | 'description'>; + mode?: 'model' | 'version' | 'import'; } -export async function setup(options?: SetupOptions) { - render(, { route: options?.route ?? '/' }); +interface SetupReturn { + nameInput: HTMLInputElement; + descriptionInput: HTMLTextAreaElement; + submitButton: HTMLButtonElement; + form: HTMLElement; + user: UserEvent; + versionNotesInput: HTMLTextAreaElement; +} + +export async function setup(options: { + route: string; + mode: 'version'; +}): Promise>; +export async function setup(options?: { + route: string; + mode: 'model' | 'import'; +}): Promise; +export async function setup( + { route, mode }: SetupOptions = { + route: '/', + mode: 'model', + } +) { + render( + + + , + { route } + ); await waitFor(() => expect(screen.queryByLabelText('Model Form Loading')).toBe(null)); - const nameInput = screen.getByLabelText(/^name$/i); - const descriptionInput = screen.getByLabelText(/description/i); + const nameInput = screen.queryByLabelText(/^name$/i); + const descriptionInput = screen.queryByLabelText(/description/i); const submitButton = screen.getByRole('button', { - name: /register model/i, + name: mode === 'version' ? /register version/i : /register model/i, }); const modelFileInput = screen.queryByLabelText(/file/i); const form = screen.getByTestId('mlCommonsPlugin-registerModelForm'); const user = userEvent.setup(); const versionNotesInput = screen.getByLabelText(/notes/i); + // fill model file + if (modelFileInput) { + await user.upload( + modelFileInput, + new File(['test model file'], 'model.zip', { type: 'application/zip' }) + ); + } + + if (mode === 'version') { + return { + submitButton, + form, + user, + versionNotesInput, + }; + } + + if (!nameInput) { + throw new Error('Name input not found'); + } + if (!descriptionInput) { + throw new Error('Description input not found'); + } + // Mock model name unique jest.spyOn(Model.prototype, 'search').mockResolvedValue({ data: [], total_models: 0 }); // fill model name - if (!options?.ignoreFillFields?.includes('name')) { + if (mode === 'model') { await user.type(nameInput, 'test model name'); } // fill model description - if (!options?.ignoreFillFields?.includes('description')) { + if (mode === 'model') { await user.type(descriptionInput, 'test model description'); } - // fill model file - if (modelFileInput) { - await user.upload( - modelFileInput, - new File(['test model file'], 'model.zip', { type: 'application/zip' }) - ); - } return { nameInput, diff --git a/public/components/register_model/evaluation_metrics.tsx b/public/components/register_model/evaluation_metrics.tsx deleted file mode 100644 index 16378c1a..00000000 --- a/public/components/register_model/evaluation_metrics.tsx +++ /dev/null @@ -1,210 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -import React, { useCallback, useMemo, useState } from 'react'; -import { - EuiFormRow, - EuiTitle, - EuiComboBox, - EuiComboBoxOptionOption, - EuiFlexItem, - EuiFlexGroup, - EuiFieldNumber, - EuiSpacer, - EuiText, -} from '@elastic/eui'; -import { useController, useFormContext, useWatch } from 'react-hook-form'; - -import type { ModelFileFormData, ModelUrlFormData } from './register_model.types'; -import { useMetricNames } from './register_model.hooks'; -import { fixTwoDecimalPoint } from '../../utils'; - -const METRIC_VALUE_STEP = 0.01; -const MAX_METRIC_NAME_LENGTH = 50; - -export const EvaluationMetricsPanel = () => { - const { trigger, control } = useFormContext(); - const [isRequiredValueText, setIsRequiredValueText] = useState(false); - const [metricNamesLoading, metricNames] = useMetricNames(); - - // TODO: this has to be hooked with data from BE API - const options = useMemo(() => { - return metricNames.map((n) => ({ label: n })); - }, [metricNames]); - - const metricKeyController = useController({ - name: 'metric.key', - control, - }); - - const metric = useWatch({ - control, - name: 'metric', - }); - - const valueValidateFn = () => { - if (metric) { - const { trainingValue, validationValue, testingValue, key } = metric; - if (key && !trainingValue && !validationValue && !testingValue) { - setIsRequiredValueText(true); - return false; - } else { - setIsRequiredValueText(false); - return true; - } - } - return true; - }; - const trainingMetricFieldController = useController({ - name: 'metric.trainingValue', - control, - rules: { - max: 1, - min: 0, - validate: valueValidateFn, - }, - }); - - const validationMetricFieldController = useController({ - name: 'metric.validationValue', - control, - rules: { - max: 1, - min: 0, - validate: valueValidateFn, - }, - }); - - const testingMetricFieldController = useController({ - name: 'metric.testingValue', - control, - rules: { - max: 1, - min: 0, - validate: valueValidateFn, - }, - }); - - const onMetricNameChange = useCallback( - (data: EuiComboBoxOptionOption[]) => { - if (data.length === 0) { - trainingMetricFieldController.field.onChange(''); - validationMetricFieldController.field.onChange(''); - testingMetricFieldController.field.onChange(''); - metricKeyController.field.onChange(''); - } else { - metricKeyController.field.onChange(data[0].label); - } - }, - [ - metricKeyController, - trainingMetricFieldController, - validationMetricFieldController, - testingMetricFieldController, - ] - ); - - const onCreateMetricName = useCallback( - (metricName: string) => { - if (metricName.length > MAX_METRIC_NAME_LENGTH) { - return; - } - metricKeyController.field.onChange(metricName); - }, - [metricKeyController.field] - ); - - const metricValueFields = [ - { label: 'Training value', controller: trainingMetricFieldController }, - { label: 'Validation value', controller: validationMetricFieldController }, - { label: 'Testing value', controller: testingMetricFieldController }, - ]; - - const onBlur = useCallback( - (e: React.FocusEvent) => { - // The blur could happen when selecting combo box dropdown - // But we don't want to trigger form validation in this case - if ( - (e.relatedTarget?.getAttribute('role') === 'option' && - e.relatedTarget?.tagName === 'BUTTON') || - e.relatedTarget?.getAttribute('role') === 'textbox' - ) { - return; - } - // Validate the form only when the current tag row blurred - trigger('metric'); - }, - [trigger] - ); - - return ( -
- -

- Evaluation Metrics - optional -

-
- - - Track training, validation, and testing metrics to compare across versions and models. - - - - - - - - - {metricValueFields.map(({ label, controller }) => ( - - - - controller.field.onChange(fixTwoDecimalPoint(value.target.value)) - } - onBlur={controller.field.onBlur} - inputRef={controller.field.ref} - /> - - - ))} - - {isRequiredValueText && ( - - At least one value is required. Enter a value - - )} -
- ); -}; diff --git a/public/components/register_model/model_tags.tsx b/public/components/register_model/model_tags.tsx index 89d088a2..ace1befb 100644 --- a/public/components/register_model/model_tags.tsx +++ b/public/components/register_model/model_tags.tsx @@ -4,22 +4,26 @@ */ import React, { useCallback } from 'react'; -import { EuiButton, EuiTitle, EuiSpacer, EuiText } from '@elastic/eui'; +import { EuiButton, EuiTitle, EuiSpacer, EuiText, EuiLink } from '@elastic/eui'; import { useFieldArray, useFormContext } from 'react-hook-form'; +import { useParams } from 'react-router-dom'; import type { ModelFileFormData, ModelUrlFormData } from './register_model.types'; import { ModelTagField } from './tag_field'; import { useModelTags } from './register_model.hooks'; -const MAX_TAG_NUM = 25; +const MAX_TAG_NUM = 10; export const ModelTagsPanel = () => { const { control } = useFormContext(); + const { id: latestVersionId } = useParams<{ id: string | undefined }>(); const [, { keys, values }] = useModelTags(); const { fields, append, remove } = useFieldArray({ name: 'tags', control, }); + const isRegisterNewVersion = !!latestVersionId; + const maxTagNum = isRegisterNewVersion ? keys.length : MAX_TAG_NUM; const addNewTag = useCallback(() => { append({ key: '', value: '' }); @@ -33,7 +37,24 @@ export const ModelTagsPanel = () => { - Add tags to facilitate model discovery and tracking across your organization. + + {isRegisterNewVersion ? ( + <> + Add tags to help your organization discover and compare models, and track information + related to new versions of this model, such as evaluation metrics. + + ) : ( + <> + Add tags to help your organization discover, compare, and track information related to + your model. The tag keys used here will define the available keys for all versions of + this model.{' '} + + Learn more + + . + + )} + {fields.map((field, index) => { @@ -44,16 +65,17 @@ export const ModelTagsPanel = () => { tagKeys={keys} tagValues={values} onDelete={remove} + allowKeyCreate={!latestVersionId} /> ); })} - = MAX_TAG_NUM} onClick={addNewTag}> + = maxTagNum} onClick={addNewTag}> Add new tag - {`You can add up to ${MAX_TAG_NUM - fields.length} more tags.`} + {`You can add up to ${maxTagNum - fields.length} more tags.`}
); diff --git a/public/components/register_model/register_model.hooks.ts b/public/components/register_model/register_model.hooks.ts index 487b2510..0cc31c56 100644 --- a/public/components/register_model/register_model.hooks.ts +++ b/public/components/register_model/register_model.hooks.ts @@ -5,27 +5,6 @@ import { useEffect, useState } from 'react'; -const metricNames = ['Metric 1', 'Metric 2', 'Metric 3', 'Metric 4']; - -/** - * TODO: implement this function so that it retrieve metric names from BE - */ -export const useMetricNames = () => { - const [loading, setLoading] = useState(true); - - useEffect(() => { - const timeoutId = window.setTimeout(() => { - setLoading(false); - }, 1000); - - return () => { - window.clearTimeout(timeoutId); - }; - }, []); - - return [loading, metricNames] as const; -}; - const keys = ['tag1', 'tag2']; const values = ['value1', 'value2']; diff --git a/public/components/register_model/register_model.tsx b/public/components/register_model/register_model.tsx index 877d6909..bb00baae 100644 --- a/public/components/register_model/register_model.tsx +++ b/public/components/register_model/register_model.tsx @@ -27,7 +27,6 @@ import { ModelDetailsPanel } from './model_details'; import type { ModelFileFormData, ModelUrlFormData } from './register_model.types'; import { ArtifactPanel } from './artifact'; import { ConfigurationPanel } from './model_configuration'; -import { EvaluationMetricsPanel } from './evaluation_metrics'; import { ModelTagsPanel } from './model_tags'; import { submitModelWithFile, submitModelWithURL } from './register_model_api'; import { APIProvider } from '../../apis/api_provider'; @@ -73,10 +72,10 @@ export const RegisterModelForm = () => { ? [ModelDetailsPanel, ModelTagsPanel, ModelVersionNotesPanel] : [ ...(latestVersionId ? [] : [ModelDetailsPanel]), + ...(latestVersionId ? [] : [ModelTagsPanel]), ArtifactPanel, ConfigurationPanel, - EvaluationMetricsPanel, - ModelTagsPanel, + ...(latestVersionId ? [ModelTagsPanel] : []), ModelVersionNotesPanel, ]; diff --git a/public/components/register_model/register_model.types.ts b/public/components/register_model/register_model.types.ts index 95855303..b7a2baec 100644 --- a/public/components/register_model/register_model.types.ts +++ b/public/components/register_model/register_model.types.ts @@ -8,19 +8,11 @@ export interface Tag { value: string; } -interface Metric { - key: string; - trainingValue: string; - validationValue: string; - testingValue: string; -} - interface ModelFormBase { name: string; version: string; description: string; configuration: string; - metric?: Metric; tags?: Tag[]; versionNotes?: string; } diff --git a/public/components/register_model/tag_field.tsx b/public/components/register_model/tag_field.tsx index ed47ca03..e076a955 100644 --- a/public/components/register_model/tag_field.tsx +++ b/public/components/register_model/tag_field.tsx @@ -10,6 +10,7 @@ import { EuiFlexGroup, EuiFlexItem, EuiFormRow, + EuiContext, } from '@elastic/eui'; import React, { useCallback, useMemo, useRef } from 'react'; import { useController, useWatch, useFormContext } from 'react-hook-form'; @@ -21,10 +22,23 @@ interface ModelTagFieldProps { onDelete: (index: number) => void; tagKeys: string[]; tagValues: string[]; + allowKeyCreate?: boolean; } const MAX_TAG_LENGTH = 80; +const KEY_COMBOBOX_I18N = { + mapping: { + 'euiComboBoxOptionsList.noAvailableOptions': 'No keys found. Add a key.', + }, +}; + +const VALUE_COMBOBOX_I18N = { + mapping: { + 'euiComboBoxOptionsList.noAvailableOptions': 'No values found. Add a value.', + }, +}; + function getComboBoxValue(data: EuiComboBoxOptionOption[]) { if (data.length === 0) { return ''; @@ -33,7 +47,13 @@ function getComboBoxValue(data: EuiComboBoxOptionOption[]) { } } -export const ModelTagField = ({ index, tagKeys, tagValues, onDelete }: ModelTagFieldProps) => { +export const ModelTagField = ({ + index, + tagKeys, + tagValues, + allowKeyCreate, + onDelete, +}: ModelTagFieldProps) => { const rowEleRef = useRef(null); const { trigger, control } = useFormContext(); const tags = useWatch({ @@ -45,6 +65,10 @@ export const ModelTagField = ({ index, tagKeys, tagValues, onDelete }: ModelTagF name: `tags.${index}.key` as const, control, rules: { + maxLength: { + value: MAX_TAG_LENGTH, + message: '80 characters allowed. Use 80 characters or less.', + }, validate: (tagKey) => { if (tags) { const tag = tags[index]; @@ -52,13 +76,13 @@ export const ModelTagField = ({ index, tagKeys, tagValues, onDelete }: ModelTagF if (!tagKey && tag.value) { return 'A key is required. Enter a key.'; } - // If a tag has both key and value, validate if the same tag was added before - if (tagKey && tag.value) { - // Find if the same tag appears before the current tag + // If a tag has key, validate if the same tag key was added before + if (tagKey) { + // Find if the same tag key appears before the current tag key for (let i = 0; i < index; i++) { - // If found the same tag, then the current tag is invalid - if (tags[i].key === tagKey && tags[i].value === tag.value) { - return 'This tag has already been added. Remove the duplicate tag.'; + // If found the same tag key, then the current tag key is invalid + if (tags[i].key === tagKey) { + return 'Tag keys must be unique. Use a unique key.'; } } } @@ -72,6 +96,10 @@ export const ModelTagField = ({ index, tagKeys, tagValues, onDelete }: ModelTagF name: `tags.${index}.value` as const, control, rules: { + maxLength: { + value: MAX_TAG_LENGTH, + message: '80 characters allowed. Use 80 characters or less.', + }, validate: (tagValue) => { if (tags) { const tag = tags[index]; @@ -79,17 +107,6 @@ export const ModelTagField = ({ index, tagKeys, tagValues, onDelete }: ModelTagF if (!tagValue && tag.key) { return 'A value is required. Enter a value.'; } - // If a tag has both key and value, validate if the same tag was added before - if (tag.key && tagValue) { - // Find if the same tag appears before the current tag - for (let i = 0; i < index; i++) { - // If found the same tag, then the current tag is invalid - if (tags[i].key === tag.key && tags[i].value === tagValue) { - // return `false` instead of error message because we don't show error message on value field - return false; - } - } - } } return true; }, @@ -112,9 +129,6 @@ export const ModelTagField = ({ index, tagKeys, tagValues, onDelete }: ModelTagF const onKeyCreate = useCallback( (value: string) => { - if (value.length > MAX_TAG_LENGTH) { - return; - } tagKeyController.field.onChange(value); }, [tagKeyController.field] @@ -122,17 +136,16 @@ export const ModelTagField = ({ index, tagKeys, tagValues, onDelete }: ModelTagF const onValueCreate = useCallback( (value: string) => { - if (value.length > MAX_TAG_LENGTH) { - return; - } tagValueController.field.onChange(value); }, [tagValueController.field] ); const keyOptions = useMemo(() => { - return tagKeys.map((key) => ({ label: key })); - }, [tagKeys]); + return tagKeys + .filter((key) => !tags?.find((tag) => tag.key === key)) + .map((key) => ({ label: key })); + }, [tagKeys, tags]); const valueOptions = useMemo(() => { return tagValues.map((value) => ({ label: value })); @@ -162,50 +175,54 @@ export const ModelTagField = ({ index, tagKeys, tagValues, onDelete }: ModelTagF return ( - - + - + error={tagKeyController.fieldState.error?.message} + > + + + - - + - + error={tagValueController.fieldState.error?.message} + > + + + onDelete(index)}> From 8787bd4ac32c36ece67766c5fbfa1ad9e229572f Mon Sep 17 00:00:00 2001 From: Yulong Ruan Date: Thu, 16 Mar 2023 15:13:13 +0800 Subject: [PATCH 31/75] feat: add form error call-out (#141) 1. Display error call-out after the first time trying to submit the form but with errors 2. Changed the form behavior to display multiple errors of one field --------- Signed-off-by: Yulong Ruan Signed-off-by: Lin Wang --- .../__tests__/error_call_out.test.tsx | 49 ++++++++ .../register_model_configuration.test.tsx | 114 +++++++++++++++++- .../__tests__/register_model_details.test.tsx | 5 +- .../__tests__/register_model_form.test.tsx | 7 +- .../register_model/__tests__/setup.tsx | 32 +++-- public/components/register_model/artifact.tsx | 3 +- .../register_model/artifact_file.tsx | 16 +-- public/components/register_model/constants.ts | 80 ++++++++++++ .../register_model/error_call_out.tsx | 66 ++++++++++ .../register_model/error_message.tsx | 29 +++++ .../register_model/model_configuration.tsx | 63 +++++++++- .../register_model/model_details.tsx | 20 +-- .../register_model/register_model.tsx | 20 ++- 13 files changed, 464 insertions(+), 40 deletions(-) create mode 100644 public/components/register_model/__tests__/error_call_out.test.tsx create mode 100644 public/components/register_model/error_call_out.tsx create mode 100644 public/components/register_model/error_message.tsx diff --git a/public/components/register_model/__tests__/error_call_out.test.tsx b/public/components/register_model/__tests__/error_call_out.test.tsx new file mode 100644 index 00000000..4821cc6f --- /dev/null +++ b/public/components/register_model/__tests__/error_call_out.test.tsx @@ -0,0 +1,49 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import { screen } from '../../../../test/test_utils'; +import { setup } from './setup'; +import * as formHooks from '../register_model.hooks'; +import * as formAPI from '../register_model_api'; +import { ModelFileUploadManager } from '../model_file_upload_manager'; + +describe(' ErrorCallOut', () => { + const onSubmitWithFileMock = jest.fn().mockResolvedValue('model_id'); + const onSubmitWithURLMock = jest.fn(); + const uploadMock = jest.fn(); + + beforeEach(() => { + jest + .spyOn(formHooks, 'useModelTags') + .mockReturnValue([false, { keys: ['Key1', 'Key2'], values: ['Value1', 'Value2'] }]); + jest.spyOn(formAPI, 'submitModelWithFile').mockImplementation(onSubmitWithFileMock); + jest.spyOn(formAPI, 'submitModelWithURL').mockImplementation(onSubmitWithURLMock); + jest.spyOn(ModelFileUploadManager.prototype, 'upload').mockImplementation(uploadMock); + }); + + afterEach(() => { + jest.clearAllMocks(); + }); + + it('should display error call-out', async () => { + const { user, nameInput } = await setup(); + await user.clear(nameInput); + await user.click(screen.getByRole('button', { name: /register model/i })); + + expect(screen.queryByLabelText(/Address errors in the form/i)).toBeInTheDocument(); + }); + + it('should not display error call-out after errors been fixed', async () => { + const { user, nameInput } = await setup(); + await user.clear(nameInput); + await user.click(screen.getByRole('button', { name: /register model/i })); + + expect(screen.queryByLabelText(/Address errors in the form/i)).toBeInTheDocument(); + + // fix the form errors + await user.type(nameInput, 'test model name'); + expect(screen.queryByLabelText(/Address errors in the form/i)).not.toBeInTheDocument(); + }); +}); diff --git a/public/components/register_model/__tests__/register_model_configuration.test.tsx b/public/components/register_model/__tests__/register_model_configuration.test.tsx index d8a75f96..25b62bfa 100644 --- a/public/components/register_model/__tests__/register_model_configuration.test.tsx +++ b/public/components/register_model/__tests__/register_model_configuration.test.tsx @@ -3,10 +3,30 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { screen } from '../../../../test/test_utils'; +import { screen, within } from '../../../../test/test_utils'; import { setup } from './setup'; +import * as formAPI from '../register_model_api'; +import * as formHooks from '../register_model.hooks'; +import { ModelFileUploadManager } from '../model_file_upload_manager'; describe(' Configuration', () => { + const onSubmitWithFileMock = jest.fn().mockResolvedValue('model_id'); + const onSubmitWithURLMock = jest.fn(); + const uploadMock = jest.fn(); + + beforeEach(() => { + jest + .spyOn(formHooks, 'useModelTags') + .mockReturnValue([false, { keys: ['Key1', 'Key2'], values: ['Value1', 'Value2'] }]); + jest.spyOn(formAPI, 'submitModelWithFile').mockImplementation(onSubmitWithFileMock); + jest.spyOn(formAPI, 'submitModelWithURL').mockImplementation(onSubmitWithURLMock); + jest.spyOn(ModelFileUploadManager.prototype, 'upload').mockImplementation(uploadMock); + }); + + afterEach(() => { + jest.clearAllMocks(); + }); + it('should render a help flyout when click help button', async () => { const { user } = await setup(); @@ -14,4 +34,96 @@ describe(' Configuration', () => { await user.click(screen.getByTestId('model-configuration-help-button')); expect(screen.getByRole('dialog')).toBeInTheDocument(); }); + + it('should not allow to submit form if model_type is missing', async () => { + // Missing model_type + const invalidConfiguration = `{}`; + const result = await setup({ defaultValues: { configuration: invalidConfiguration } }); + await result.user.click(result.submitButton); + expect(onSubmitWithFileMock).not.toHaveBeenCalled(); + + // Field error messages + const configurationContainer = screen.getByTestId('ml-registerModelConfiguration'); + expect( + within(configurationContainer).queryByText(/model_type is required/i) + ).toBeInTheDocument(); + + // Error Callout + const errorCallOutContainer = screen.getByLabelText(/Address errors in the form/i); + expect( + within(errorCallOutContainer).queryByText(/JSON configuration: specify the model_type/i) + ).toBeInTheDocument(); + }); + + it('should not allow to submit form if model_type is invalid', async () => { + // Incorrect model_type type, model_type must be a string + const invalidConfiguration = `{ + "model_type": false + }`; + const result = await setup({ defaultValues: { configuration: invalidConfiguration } }); + await result.user.click(result.submitButton); + expect(onSubmitWithFileMock).not.toHaveBeenCalled(); + + // Field error messages + const configurationContainer = screen.getByTestId('ml-registerModelConfiguration'); + expect( + within(configurationContainer).queryByText(/model_type must be a string/i) + ).toBeInTheDocument(); + + // Error Callout + const errorCallOutContainer = screen.getByLabelText(/Address errors in the form/i); + expect( + within(errorCallOutContainer).queryByText(/JSON configuration: model_type must be a string./i) + ).toBeInTheDocument(); + }); + + it('should not allow to submit form if embedding_dimension is invalid', async () => { + // Incorrect embedding_dimension type, embedding_dimension must be a number + const invalidConfiguration = `{ + "model_type": "bert", + "embedding_dimension": "must_be_a_number" + }`; + const result = await setup({ defaultValues: { configuration: invalidConfiguration } }); + await result.user.click(result.submitButton); + expect(onSubmitWithFileMock).not.toHaveBeenCalled(); + + // Field error messages + const configurationContainer = screen.getByTestId('ml-registerModelConfiguration'); + expect( + within(configurationContainer).queryByText(/embedding_dimension must be a number/i) + ).toBeInTheDocument(); + + // Error Callout + const errorCallOutContainer = screen.getByLabelText(/Address errors in the form/i); + expect( + within(errorCallOutContainer).queryByText( + /JSON configuration: embedding_dimension must be a number/i + ) + ).toBeInTheDocument(); + }); + + it('should not allow to submit form if framework_type is invalid', async () => { + // Incorrect framework_type, framework_type must be a string + const invalidConfiguration = `{ + "model_type": "bert", + "framework_type": 384 + }`; + const result = await setup({ defaultValues: { configuration: invalidConfiguration } }); + await result.user.click(result.submitButton); + expect(onSubmitWithFileMock).not.toHaveBeenCalled(); + + // Field error messages + const configurationContainer = screen.getByTestId('ml-registerModelConfiguration'); + expect( + within(configurationContainer).queryByText(/framework_type must be a string/i) + ).toBeInTheDocument(); + + // Error Callout + const errorCallOutContainer = screen.getByLabelText(/Address errors in the form/i); + expect( + within(errorCallOutContainer).queryByText( + /JSON configuration: framework_type must be a string./i + ) + ).toBeInTheDocument(); + }); }); diff --git a/public/components/register_model/__tests__/register_model_details.test.tsx b/public/components/register_model/__tests__/register_model_details.test.tsx index 5f10860b..64beaf72 100644 --- a/public/components/register_model/__tests__/register_model_details.test.tsx +++ b/public/components/register_model/__tests__/register_model_details.test.tsx @@ -70,14 +70,13 @@ describe(' Details', () => { expect(onSubmitMock).not.toHaveBeenCalled(); }); - it('should NOT submit the register model form if model description is empty', async () => { + it('should submit the register model form if model description is empty', async () => { const result = await setup(); await result.user.clear(result.descriptionInput); await result.user.click(result.submitButton); - expect(result.descriptionInput).toBeInvalid(); - expect(onSubmitMock).not.toHaveBeenCalled(); + expect(onSubmitMock).toHaveBeenCalled(); }); it('should NOT allow to input a model description which exceed max length: 200', async () => { diff --git a/public/components/register_model/__tests__/register_model_form.test.tsx b/public/components/register_model/__tests__/register_model_form.test.tsx index cbf20b97..3da411ff 100644 --- a/public/components/register_model/__tests__/register_model_form.test.tsx +++ b/public/components/register_model/__tests__/register_model_form.test.tsx @@ -130,14 +130,13 @@ describe(' Form', () => { }); it('should display number of form errors in form footer', async () => { - const { user, nameInput, descriptionInput } = await setup(); + const { user, nameInput } = await setup(); await user.clear(nameInput); - await user.clear(descriptionInput); await user.click(screen.getByRole('button', { name: /register model/i })); - expect(screen.queryByText(/2 form errors/i)).toBeInTheDocument(); + expect(screen.queryByText(/1 form error/i)).toBeInTheDocument(); await user.type(nameInput, 'test model name'); - expect(screen.queryByText(/1 form error/i)).toBeInTheDocument(); + expect(screen.queryByText(/1 form error/i)).not.toBeInTheDocument(); }); it('should call addSuccess to display a success toast', async () => { diff --git a/public/components/register_model/__tests__/setup.tsx b/public/components/register_model/__tests__/setup.tsx index db2bfd4a..5bbc3493 100644 --- a/public/components/register_model/__tests__/setup.tsx +++ b/public/components/register_model/__tests__/setup.tsx @@ -11,11 +11,13 @@ import { UserEvent } from '@testing-library/user-event/dist/types/setup/setup'; import { RegisterModelForm } from '../register_model'; import { Model } from '../../../apis/model'; import { render, RenderWithRouteProps, screen, waitFor } from '../../../../test/test_utils'; +import { ModelFileFormData, ModelUrlFormData } from '../register_model.types'; jest.mock('../../../apis/task'); -interface SetupOptions extends RenderWithRouteProps { +interface SetupOptions extends Partial { mode?: 'model' | 'version' | 'import'; + defaultValues?: Partial | Partial; } interface SetupReturn { @@ -27,25 +29,41 @@ interface SetupReturn { versionNotesInput: HTMLTextAreaElement; } +const CONFIGURATION = `{ + "model_type": "bert", + "embedding_dimension": 384, + "framework_type": "sentence_transformers" +}`; + +const DEFAULT_VALUES = { + name: '', + description: '', + version: '1', + configuration: CONFIGURATION, + tags: [{ key: '', value: '' }], +}; + export async function setup(options: { - route: string; + route?: string; mode: 'version'; + defaultValues?: Partial | Partial; }): Promise>; export async function setup(options?: { - route: string; - mode: 'model' | 'import'; + route?: string; + mode?: 'model' | 'import'; + defaultValues?: Partial | Partial; }): Promise; export async function setup( - { route, mode }: SetupOptions = { + { route, mode, defaultValues }: SetupOptions = { route: '/', mode: 'model', } ) { render( - + , - { route } + { route: route ?? '/' } ); await waitFor(() => expect(screen.queryByLabelText('Model Form Loading')).toBe(null)); const nameInput = screen.queryByLabelText(/^name$/i); diff --git a/public/components/register_model/artifact.tsx b/public/components/register_model/artifact.tsx index 6cfa413a..54319de2 100644 --- a/public/components/register_model/artifact.tsx +++ b/public/components/register_model/artifact.tsx @@ -6,9 +6,10 @@ import React, { useState } from 'react'; import { EuiTitle, htmlIdGenerator, EuiSpacer, EuiText, EuiRadio, EuiLink } from '@elastic/eui'; -import { MAX_MODEL_FILE_SIZE, ModelFileUploader } from './artifact_file'; +import { ModelFileUploader } from './artifact_file'; import { ArtifactUrl } from './artifact_url'; import { ONE_GB } from '../../../common/constant'; +import { MAX_MODEL_FILE_SIZE } from './constants'; export const ArtifactPanel = () => { const [selectedSource, setSelectedSource] = useState<'source_from_computer' | 'source_from_url'>( diff --git a/public/components/register_model/artifact_file.tsx b/public/components/register_model/artifact_file.tsx index 3d4c9e9d..b5a12b2d 100644 --- a/public/components/register_model/artifact_file.tsx +++ b/public/components/register_model/artifact_file.tsx @@ -8,18 +8,12 @@ import { EuiFormRow, EuiFilePicker } from '@elastic/eui'; import { useController, useFormContext } from 'react-hook-form'; import { ModelFileFormData, ModelUrlFormData } from './register_model.types'; -import { ONE_GB } from '../../../common/constant'; +import { CUSTOM_FORM_ERROR_TYPES, MAX_MODEL_FILE_SIZE } from './constants'; -// 4GB -export const MAX_MODEL_FILE_SIZE = 4 * ONE_GB; - -function validateFile(file: File) { - if (file.size > MAX_MODEL_FILE_SIZE) { +function validateFileSize(file?: File) { + if (file && file.size > MAX_MODEL_FILE_SIZE) { return 'Maximum file size exceeded. Add a smaller file.'; } - if (!file.name.endsWith('.zip')) { - return 'Invalid file format. Add a ZIP(.zip) file.'; - } return true; } @@ -30,7 +24,9 @@ export const ModelFileUploader = () => { control, rules: { required: { value: true, message: 'A file is required. Add a file.' }, - validate: validateFile, + validate: { + [CUSTOM_FORM_ERROR_TYPES.FILE_SIZE_EXCEED_LIMIT]: validateFileSize, + }, }, shouldUnregister: true, }); diff --git a/public/components/register_model/constants.ts b/public/components/register_model/constants.ts index 9ec2d9d6..107ecd27 100644 --- a/public/components/register_model/constants.ts +++ b/public/components/register_model/constants.ts @@ -2,5 +2,85 @@ * Copyright OpenSearch Contributors * SPDX-License-Identifier: Apache-2.0 */ +import { ONE_GB } from '../../../common/constant'; export const MAX_CHUNK_SIZE = 10 * 1000 * 1000; +export const MAX_MODEL_FILE_SIZE = 4 * ONE_GB; + +export enum CUSTOM_FORM_ERROR_TYPES { + DUPLICATE_NAME = 'duplicateName', + FILE_SIZE_EXCEED_LIMIT = 'fileSizeExceedLimit', + INVALID_CONFIGURATION = 'invalidConfiguration', + CONFIGURATION_MISSING_MODEL_TYPE = 'configurationMissingModelType', + INVALID_MODEL_TYPE_VALUE = 'invalidModelTypeValue', + INVALID_EMBEDDING_DIMENSION_VALUE = 'invalidEmbeddingDimensionValue', + INVALID_FRAMEWORK_TYPE_VALUE = 'invalidFrameworkTypeValue', +} + +export const FORM_ERRORS = [ + { + field: 'name', + type: 'required', + message: 'Name: Enter a name.', + }, + { + field: 'name', + type: CUSTOM_FORM_ERROR_TYPES.DUPLICATE_NAME, + message: 'Name: Use a unique name.', + }, + { + field: 'modelFile', + type: 'required', + message: 'File: Add a file.', + }, + { + field: 'modelFile', + type: CUSTOM_FORM_ERROR_TYPES.FILE_SIZE_EXCEED_LIMIT, + message: `File: Add a file below ${MAX_MODEL_FILE_SIZE / ONE_GB} GB.`, + }, + { + field: 'modelURL', + type: 'required', + message: 'URL: Enter a URL.', + }, + { + field: 'modelURL', + type: 'pattern', + message: 'URL: Enter a valid URL.', + }, + { + field: 'modelFormat', + type: 'required', + message: 'Model file format: Select a model file format.', + }, + { + field: 'configuration', + type: 'required', + message: 'JSON configuration: Add a JSON configuration.', + }, + { + field: 'configuration', + type: CUSTOM_FORM_ERROR_TYPES.INVALID_CONFIGURATION, + message: 'JSON configuration: Add valid JSON.', + }, + { + field: 'configuration', + type: CUSTOM_FORM_ERROR_TYPES.CONFIGURATION_MISSING_MODEL_TYPE, + message: 'JSON configuration: specify the model_type.', + }, + { + field: 'configuration', + type: CUSTOM_FORM_ERROR_TYPES.INVALID_MODEL_TYPE_VALUE, + message: 'JSON configuration: model_type must be a string.', + }, + { + field: 'configuration', + type: CUSTOM_FORM_ERROR_TYPES.INVALID_EMBEDDING_DIMENSION_VALUE, + message: 'JSON configuration: embedding_dimension must be a number.', + }, + { + field: 'configuration', + type: CUSTOM_FORM_ERROR_TYPES.INVALID_FRAMEWORK_TYPE_VALUE, + message: 'JSON configuration: framework_type must be a string.', + }, +]; diff --git a/public/components/register_model/error_call_out.tsx b/public/components/register_model/error_call_out.tsx new file mode 100644 index 00000000..84607aaa --- /dev/null +++ b/public/components/register_model/error_call_out.tsx @@ -0,0 +1,66 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import { EuiCallOut, EuiText } from '@elastic/eui'; +import React, { useMemo } from 'react'; +import { useFormContext } from 'react-hook-form'; +import { FORM_ERRORS } from './constants'; + +import type { ModelFileFormData, ModelUrlFormData } from './register_model.types'; + +export const ErrorCallOut = () => { + const form = useFormContext(); + + const errors = useMemo(() => { + const messages: string[] = []; + Object.keys(form.formState.errors).forEach((errorField) => { + const error = form.formState.errors[errorField as keyof typeof form.formState.errors]; + // If form have: criteriaMode: 'all', error.types will be set a value + // error.types will contain all the errors of each field + // In this case, we will display all the errors in the callout + if (error?.types) { + Object.keys(error.types).forEach((k) => { + const errorMessage = FORM_ERRORS.find((e) => e.field === errorField && e.type === k); + if (errorMessage) { + messages.push(errorMessage.message); + } + }); + } else { + // If form didn't have: criteriaMode: 'all', the default behavior of react-hook-form is + // to only produce the first error, even if a field has multiple errors. + // In this case, error.types won't be set, and error.type and error.field represent the + // first error + const errorMessage = FORM_ERRORS.find( + (e) => e.field === errorField && e.type === error?.type + ); + if (errorMessage) { + messages.push(errorMessage.message); + } + } + }); + return messages; + }, [form]); + + if (errors.length === 0) { + return null; + } + + return ( + + +
    + {errors.map((e) => ( +
  • - {e}
  • + ))} +
+
+
+ ); +}; diff --git a/public/components/register_model/error_message.tsx b/public/components/register_model/error_message.tsx new file mode 100644 index 00000000..6db6f386 --- /dev/null +++ b/public/components/register_model/error_message.tsx @@ -0,0 +1,29 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React from 'react'; +import { FieldError } from 'react-hook-form'; + +interface ErrorMessageProps { + error?: FieldError; +} + +export const ErrorMessage = ({ error }: ErrorMessageProps) => { + if (!error) { + return null; + } + + if (error.types) { + return ( +
    + {Object.keys(error.types).map((k) => ( +
  • {error.types?.[k]}
  • + ))} +
+ ); + } + + return {error.message}; +}; diff --git a/public/components/register_model/model_configuration.tsx b/public/components/register_model/model_configuration.tsx index e43773f9..0d8bc423 100644 --- a/public/components/register_model/model_configuration.tsx +++ b/public/components/register_model/model_configuration.tsx @@ -20,6 +20,8 @@ import '../../ace-themes/sql_console.js'; import { FORM_ITEM_WIDTH } from './form_constants'; import type { ModelFileFormData, ModelUrlFormData } from './register_model.types'; import { HelpFlyout } from './help_flyout'; +import { CUSTOM_FORM_ERROR_TYPES } from './constants'; +import { ErrorMessage } from './error_message'; function validateConfigurationObject(value: string) { try { @@ -30,6 +32,55 @@ function validateConfigurationObject(value: string) { return true; } +function validateModelType(value: string) { + try { + const config = JSON.parse(value.trim()); + if (!('model_type' in config)) { + return 'model_type is required. Specify the model_type'; + } + } catch { + return true; + } + return true; +} + +function validateModelTypeValue(value: string) { + try { + const config = JSON.parse(value.trim()); + if ('model_type' in config && typeof config.model_type !== 'string') { + return 'model_type must be a string'; + } + } catch { + return true; + } + return true; +} + +function validateEmbeddingDimensionValue(value: string) { + try { + const config = JSON.parse(value.trim()); + if ('embedding_dimension' in config && typeof config.embedding_dimension !== 'number') { + return 'embedding_dimension must be a number'; + } + } catch { + return true; + } + + return true; +} + +function validateFrameworkTypeValue(value: string) { + try { + const config = JSON.parse(value.trim()); + if ('framework_type' in config && typeof config.framework_type !== 'string') { + return 'framework_type must be a string'; + } + } catch { + return true; + } + return true; +} + export const ConfigurationPanel = () => { const { control } = useFormContext(); const [isHelpVisible, setIsHelpVisible] = useState(false); @@ -38,12 +89,18 @@ export const ConfigurationPanel = () => { control, rules: { required: { value: true, message: 'Configuration is required.' }, - validate: validateConfigurationObject, + validate: { + [CUSTOM_FORM_ERROR_TYPES.INVALID_CONFIGURATION]: validateConfigurationObject, + [CUSTOM_FORM_ERROR_TYPES.CONFIGURATION_MISSING_MODEL_TYPE]: validateModelType, + [CUSTOM_FORM_ERROR_TYPES.INVALID_MODEL_TYPE_VALUE]: validateModelTypeValue, + [CUSTOM_FORM_ERROR_TYPES.INVALID_EMBEDDING_DIMENSION_VALUE]: validateEmbeddingDimensionValue, + [CUSTOM_FORM_ERROR_TYPES.INVALID_FRAMEWORK_TYPE_VALUE]: validateFrameworkTypeValue, + }, }, }); return ( -
+

Configuration

@@ -69,7 +126,7 @@ export const ConfigurationPanel = () => { style={{ maxWidth: FORM_ITEM_WIDTH * 2 }} label="Configuration in JSON" isInvalid={Boolean(configurationFieldController.fieldState.error)} - error={configurationFieldController.fieldState.error?.message} + error={} labelAppend={ setIsHelpVisible(true)} diff --git a/public/components/register_model/model_details.tsx b/public/components/register_model/model_details.tsx index 337f60f8..8586ff1b 100644 --- a/public/components/register_model/model_details.tsx +++ b/public/components/register_model/model_details.tsx @@ -8,6 +8,7 @@ import { EuiFieldText, EuiFormRow, EuiTitle, EuiTextArea, EuiText } from '@elast import { useController, useFormContext } from 'react-hook-form'; import { ModelFileFormData, ModelUrlFormData } from './register_model.types'; import { APIProvider } from '../../apis/api_provider'; +import { CUSTOM_FORM_ERROR_TYPES } from './constants'; const NAME_MAX_LENGTH = 80; const DESCRIPTION_MAX_LENGTH = 200; @@ -29,10 +30,12 @@ export const ModelDetailsPanel = () => { control, rules: { required: { value: true, message: 'Name can not be empty' }, - validate: async (name) => { - return !modelNameFocusedRef.current && !!name && (await isDuplicateModelName(name)) - ? 'This name is already in use. Use a unique name for the model.' - : undefined; + validate: { + [CUSTOM_FORM_ERROR_TYPES.DUPLICATE_NAME]: async (name) => { + return !modelNameFocusedRef.current && !!name && (await isDuplicateModelName(name)) + ? 'This name is already in use. Use a unique name for the model.' + : undefined; + }, }, }, }); @@ -40,9 +43,6 @@ export const ModelDetailsPanel = () => { const descriptionFieldController = useController({ name: 'description', control, - rules: { - required: { value: true, message: 'Description can not be empty' }, - }, }); const { ref: nameInputRef, ...nameField } = nameFieldController.field; @@ -86,7 +86,11 @@ export const ModelDetailsPanel = () => { /> + Description - optional + + } isInvalid={Boolean(descriptionFieldController.fieldState.error)} error={descriptionFieldController.fieldState.error?.message} helpText={`${Math.max( diff --git a/public/components/register_model/register_model.tsx b/public/components/register_model/register_model.tsx index bb00baae..3b600c37 100644 --- a/public/components/register_model/register_model.tsx +++ b/public/components/register_model/register_model.tsx @@ -41,19 +41,25 @@ import { routerPaths } from '../../../common/router_paths'; import { modelTaskManager } from './model_task_manager'; import { ModelVersionNotesPanel } from './model_version_notes'; import { modelRepositoryManager } from '../../utils/model_repository_manager'; +import { ErrorCallOut } from './error_call_out'; const DEFAULT_VALUES = { name: '', description: '', version: '1', - configuration: '{}', + configuration: '', tags: [{ key: '', value: '' }], }; const FORM_ID = 'mlModelUploadForm'; -export const RegisterModelForm = () => { +interface RegisterModelFormProps { + defaultValues?: Partial | Partial; +} + +export const RegisterModelForm = ({ defaultValues = DEFAULT_VALUES }: RegisterModelFormProps) => { const history = useHistory(); + const [isSubmitted, setIsSubmitted] = useState(false); const { id: latestVersionId } = useParams<{ id: string | undefined }>(); const [modelGroupName, setModelGroupName] = useState(); const searchParams = useSearchParams(); @@ -81,7 +87,8 @@ export const RegisterModelForm = () => { const form = useForm({ mode: 'onChange', - defaultValues: DEFAULT_VALUES, + defaultValues, + criteriaMode: 'all', }); const onSubmit = useCallback( @@ -277,6 +284,12 @@ export const RegisterModelForm = () => { {formHeader} + {isSubmitted && !form.formState.isValid && ( + <> + + + + )} {partials.map((FormPartial, i) => ( @@ -307,6 +320,7 @@ export const RegisterModelForm = () => { disabled={form.formState.isSubmitting} isLoading={form.formState.isSubmitting} type="submit" + onClick={() => setIsSubmitted(true)} fill > {latestVersionId ? 'Register version' : 'Register model'} From 88e983c48555b1d2ae26033b5928489710577e28 Mon Sep 17 00:00:00 2001 From: Yulong Ruan Date: Fri, 17 Mar 2023 13:32:28 +0800 Subject: [PATCH 32/75] feat: add model file format select (#143) Added a select to allow user to select model file format from the dropdown menu, user can select from ONNX or Torchscript --------- Signed-off-by: Yulong Ruan Signed-off-by: Lin Wang --- .../register_model_artifact.test.tsx | 8 +-- .../register_model_file_format.test.tsx | 67 ++++++++++++++++++ .../register_model/__tests__/setup.tsx | 10 ++- public/components/register_model/artifact.tsx | 70 ++++++++++++++++++- public/components/register_model/constants.ts | 5 ++ .../register_model/register_model.tsx | 1 + .../register_model/register_model.types.ts | 1 + 7 files changed, 154 insertions(+), 8 deletions(-) create mode 100644 public/components/register_model/__tests__/register_model_file_format.test.tsx diff --git a/public/components/register_model/__tests__/register_model_artifact.test.tsx b/public/components/register_model/__tests__/register_model_artifact.test.tsx index 848bef71..683318ec 100644 --- a/public/components/register_model/__tests__/register_model_artifact.test.tsx +++ b/public/components/register_model/__tests__/register_model_artifact.test.tsx @@ -83,13 +83,13 @@ describe(' Artifact', () => { // Empty model file selection by clicking the `Remove` button on EuiFilePicker await result.user.click(screen.getByLabelText(/clear selected files/i)); - const modelFileInput = screen.getByLabelText(/file/i); + const modelFileInput = screen.getByLabelText(/^file$/i); // User select a file with maximum accepted size const validFile = new File(['test model file'], 'model.zip', { type: 'application/zip' }); Object.defineProperty(validFile, 'size', { value: 4 * ONE_GB }); await result.user.upload(modelFileInput, validFile); - expect(screen.getByLabelText(/file/i, { selector: 'input[type="file"]' })).toBeValid(); + expect(screen.getByLabelText(/^file$/i, { selector: 'input[type="file"]' })).toBeValid(); await result.user.click(result.submitButton); expect(onSubmitWithFileMock).toHaveBeenCalled(); }); @@ -100,7 +100,7 @@ describe(' Artifact', () => { // Empty model file selection by clicking the `Remove` button on EuiFilePicker await result.user.click(screen.getByLabelText(/clear selected files/i)); - const modelFileInput = screen.getByLabelText(/file/i); + const modelFileInput = screen.getByLabelText(/^file$/i); // File size can not exceed 4GB const invalidFile = new File(['test model file'], 'model.zip', { type: 'application/zip' }); Object.defineProperty(invalidFile, 'size', { value: 4 * ONE_GB + 1 }); @@ -117,7 +117,7 @@ describe(' Artifact', () => { // Empty model file selection by clicking the `Remove` button on EuiFilePicker await result.user.click(screen.getByLabelText(/clear selected files/i)); - const modelFileInput = screen.getByLabelText(/file/i); + const modelFileInput = screen.getByLabelText(/^file$/i); // Only ZIP(.zip) file is allowed const invalidFile = new File(['test model file'], 'model.json', { type: 'application/json' }); await result.user.upload(modelFileInput, invalidFile); diff --git a/public/components/register_model/__tests__/register_model_file_format.test.tsx b/public/components/register_model/__tests__/register_model_file_format.test.tsx new file mode 100644 index 00000000..b3e014d4 --- /dev/null +++ b/public/components/register_model/__tests__/register_model_file_format.test.tsx @@ -0,0 +1,67 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import { screen } from '../../../../test/test_utils'; +import { setup } from './setup'; +import * as formHooks from '../register_model.hooks'; +import { ModelFileUploadManager } from '../model_file_upload_manager'; +import * as formAPI from '../register_model_api'; + +jest.mock('../../../apis/model_repository'); + +describe(' Artifact', () => { + const onSubmitWithFileMock = jest.fn().mockResolvedValue('model_id'); + const onSubmitWithURLMock = jest.fn(); + const uploadMock = jest.fn(); + + beforeEach(() => { + jest + .spyOn(formHooks, 'useModelTags') + .mockReturnValue([false, { keys: ['Key1', 'Key2'], values: ['Value1', 'Value2'] }]); + jest.spyOn(formAPI, 'submitModelWithFile').mockImplementation(onSubmitWithFileMock); + jest.spyOn(formAPI, 'submitModelWithURL').mockImplementation(onSubmitWithURLMock); + jest.spyOn(ModelFileUploadManager.prototype, 'upload').mockImplementation(uploadMock); + }); + + afterEach(() => { + jest.clearAllMocks(); + }); + + it('should not submit the form if model format is not selected', async () => { + const result = await setup(); + result.user.click(screen.getByLabelText('Clear input')); + + await result.user.click(result.submitButton); + expect(onSubmitWithFileMock).not.toHaveBeenCalled(); + }); + + it('should display error messages if model format is not selected', async () => { + const result = await setup(); + result.user.click(screen.getByLabelText('Clear input')); + + await result.user.click(result.submitButton); + expect(onSubmitWithFileMock).not.toHaveBeenCalled(); + + // Field error message + expect( + screen.queryByText(/Model file format is required. Select a model file format/i) + ).toBeInTheDocument(); + + // Error callout + expect(screen.queryByText(/Model file format: Select a model format/i)).toBeInTheDocument(); + }); + + it('should submit the form with selected model file format', async () => { + const result = await setup(); + const fileFormatInput = screen.getByLabelText(/^Model file format$/i); + await result.user.click(fileFormatInput); + await result.user.click(screen.getByText('Torchscript(.pt)')); + + await result.user.click(result.submitButton); + expect(onSubmitWithFileMock).toHaveBeenCalledWith( + expect.objectContaining({ modelFileFormat: 'TORCH_SCRIPT' }) + ); + }); +}); diff --git a/public/components/register_model/__tests__/setup.tsx b/public/components/register_model/__tests__/setup.tsx index 5bbc3493..650cd938 100644 --- a/public/components/register_model/__tests__/setup.tsx +++ b/public/components/register_model/__tests__/setup.tsx @@ -10,7 +10,7 @@ import { UserEvent } from '@testing-library/user-event/dist/types/setup/setup'; import { RegisterModelForm } from '../register_model'; import { Model } from '../../../apis/model'; -import { render, RenderWithRouteProps, screen, waitFor } from '../../../../test/test_utils'; +import { render, RenderWithRouteProps, screen, waitFor, within } from '../../../../test/test_utils'; import { ModelFileFormData, ModelUrlFormData } from '../register_model.types'; jest.mock('../../../apis/task'); @@ -71,7 +71,8 @@ export async function setup( const submitButton = screen.getByRole('button', { name: mode === 'version' ? /register version/i : /register model/i, }); - const modelFileInput = screen.queryByLabelText(/file/i); + const modelFileInput = screen.queryByLabelText(/^file$/i); + const fileFormatInput = screen.queryByLabelText(/^Model file format$/i); const form = screen.getByTestId('mlCommonsPlugin-registerModelForm'); const user = userEvent.setup(); const versionNotesInput = screen.getByLabelText(/notes/i); @@ -84,6 +85,11 @@ export async function setup( ); } + if (fileFormatInput) { + await user.click(fileFormatInput); + await user.click(screen.getByText('ONNX(.onnx)')); + } + if (mode === 'version') { return { submitButton, diff --git a/public/components/register_model/artifact.tsx b/public/components/register_model/artifact.tsx index 54319de2..0b3891e2 100644 --- a/public/components/register_model/artifact.tsx +++ b/public/components/register_model/artifact.tsx @@ -3,19 +3,70 @@ * SPDX-License-Identifier: Apache-2.0 */ -import React, { useState } from 'react'; -import { EuiTitle, htmlIdGenerator, EuiSpacer, EuiText, EuiRadio, EuiLink } from '@elastic/eui'; +import React, { useCallback, useMemo, useState } from 'react'; +import { + EuiTitle, + htmlIdGenerator, + EuiSpacer, + EuiText, + EuiRadio, + EuiLink, + EuiFormRow, + EuiComboBox, + EuiComboBoxOptionOption, +} from '@elastic/eui'; +import { useController, useFormContext } from 'react-hook-form'; import { ModelFileUploader } from './artifact_file'; import { ArtifactUrl } from './artifact_url'; import { ONE_GB } from '../../../common/constant'; import { MAX_MODEL_FILE_SIZE } from './constants'; +import { ModelFileFormData, ModelUrlFormData } from './register_model.types'; + +const FILE_FORMAT_OPTIONS = [ + { + label: 'ONNX(.onnx)', + value: 'ONNX', + }, + { + label: 'Torchscript(.pt)', + value: 'TORCH_SCRIPT', + }, +]; export const ArtifactPanel = () => { + const { control } = useFormContext(); const [selectedSource, setSelectedSource] = useState<'source_from_computer' | 'source_from_url'>( 'source_from_computer' ); + const modelFileFormatController = useController({ + name: 'modelFileFormat', + control, + rules: { + required: { + value: true, + message: 'Model file format is required. Select a model file format.', + }, + }, + }); + + const { ref: fileFormatInputRef, ...fileFormatField } = modelFileFormatController.field; + + const selectedFileFormatOption = useMemo(() => { + if (fileFormatField.value) { + return FILE_FORMAT_OPTIONS.find((fmt) => fmt.value === fileFormatField.value); + } + }, [fileFormatField]); + + const onFileFormatChange = useCallback( + (options: Array>) => { + const value = options[0]?.value; + fileFormatField.onChange(value); + }, + [fileFormatField] + ); + return (
@@ -64,6 +115,21 @@ export const ArtifactPanel = () => {
  • Tokenizer file, accepted format: JSON(.json)
  • + + + +
    ); }; diff --git a/public/components/register_model/constants.ts b/public/components/register_model/constants.ts index 107ecd27..6c653d98 100644 --- a/public/components/register_model/constants.ts +++ b/public/components/register_model/constants.ts @@ -53,6 +53,11 @@ export const FORM_ERRORS = [ type: 'required', message: 'Model file format: Select a model file format.', }, + { + field: 'modelFileFormat', + type: 'required', + message: 'Model file format: Select a model format.', + }, { field: 'configuration', type: 'required', diff --git a/public/components/register_model/register_model.tsx b/public/components/register_model/register_model.tsx index 3b600c37..941abd56 100644 --- a/public/components/register_model/register_model.tsx +++ b/public/components/register_model/register_model.tsx @@ -49,6 +49,7 @@ const DEFAULT_VALUES = { version: '1', configuration: '', tags: [{ key: '', value: '' }], + modelFileFormat: '', }; const FORM_ID = 'mlModelUploadForm'; diff --git a/public/components/register_model/register_model.types.ts b/public/components/register_model/register_model.types.ts index b7a2baec..b359eb8d 100644 --- a/public/components/register_model/register_model.types.ts +++ b/public/components/register_model/register_model.types.ts @@ -13,6 +13,7 @@ interface ModelFormBase { version: string; description: string; configuration: string; + modelFileFormat: string; tags?: Tag[]; versionNotes?: string; } From 69a6823491818631f57d3eeb1951e27d0e4c4931 Mon Sep 17 00:00:00 2001 From: Yulong Ruan Date: Mon, 20 Mar 2023 18:07:28 +0800 Subject: [PATCH 33/75] feat: tweaks form section titles per new design (#144) Signed-off-by: Yulong Ruan Signed-off-by: Lin Wang --- public/components/register_model/artifact.tsx | 8 ++----- .../register_model/model_configuration.tsx | 5 ++-- .../register_model/model_details.tsx | 8 +++---- .../components/register_model/model_tags.tsx | 6 ++--- .../register_model/model_version_notes.tsx | 6 ++--- .../register_model/register_model.tsx | 24 ++++++++++++++++++- 6 files changed, 37 insertions(+), 20 deletions(-) diff --git a/public/components/register_model/artifact.tsx b/public/components/register_model/artifact.tsx index 0b3891e2..9498f7fe 100644 --- a/public/components/register_model/artifact.tsx +++ b/public/components/register_model/artifact.tsx @@ -5,7 +5,6 @@ import React, { useCallback, useMemo, useState } from 'react'; import { - EuiTitle, htmlIdGenerator, EuiSpacer, EuiText, @@ -69,12 +68,9 @@ export const ArtifactPanel = () => { return (
    - -

    File and version information

    -
    - +

    Artifact

    -
    + The zipped artifact must include a model file and a tokenizer file. If uploading with a diff --git a/public/components/register_model/model_configuration.tsx b/public/components/register_model/model_configuration.tsx index 0d8bc423..73436742 100644 --- a/public/components/register_model/model_configuration.tsx +++ b/public/components/register_model/model_configuration.tsx @@ -6,7 +6,6 @@ import React, { useState } from 'react'; import { EuiFormRow, - EuiTitle, EuiCodeEditor, EuiText, EuiTextColor, @@ -101,9 +100,9 @@ export const ConfigurationPanel = () => { return (
    - +

    Configuration

    -
    + The model configuration specifies the{' '} diff --git a/public/components/register_model/model_details.tsx b/public/components/register_model/model_details.tsx index 8586ff1b..01594f93 100644 --- a/public/components/register_model/model_details.tsx +++ b/public/components/register_model/model_details.tsx @@ -4,7 +4,7 @@ */ import React, { useCallback, useRef } from 'react'; -import { EuiFieldText, EuiFormRow, EuiTitle, EuiTextArea, EuiText } from '@elastic/eui'; +import { EuiFieldText, EuiFormRow, EuiTextArea, EuiText } from '@elastic/eui'; import { useController, useFormContext } from 'react-hook-form'; import { ModelFileFormData, ModelUrlFormData } from './register_model.types'; import { APIProvider } from '../../apis/api_provider'; @@ -60,9 +60,9 @@ export const ModelDetailsPanel = () => { return (
    - -

    Model Details

    -
    + +

    Details

    +
    { return (
    - +

    Tags - optional

    -
    + {isRegisterNewVersion ? ( diff --git a/public/components/register_model/model_version_notes.tsx b/public/components/register_model/model_version_notes.tsx index e0983c5b..23b2fd7c 100644 --- a/public/components/register_model/model_version_notes.tsx +++ b/public/components/register_model/model_version_notes.tsx @@ -4,7 +4,7 @@ */ import React from 'react'; -import { EuiTitle, EuiSpacer, EuiFormRow, EuiTextArea } from '@elastic/eui'; +import { EuiText, EuiSpacer, EuiFormRow, EuiTextArea } from '@elastic/eui'; import { useFormContext, useController } from 'react-hook-form'; import type { ModelFileFormData, ModelUrlFormData } from './register_model.types'; @@ -22,11 +22,11 @@ export const ModelVersionNotesPanel = () => { return (
    - +

    Version notes - optional

    -
    + | Partial; } +const ModelOverviewTitle = () => { + return ( + +

    Model overview

    +
    + ); +}; + +const FileAndVersionTitle = () => { + return ( + +

    File and version information

    +
    + ); +}; + export const RegisterModelForm = ({ defaultValues = DEFAULT_VALUES }: RegisterModelFormProps) => { const history = useHistory(); const [isSubmitted, setIsSubmitted] = useState(false); @@ -78,8 +94,10 @@ export const RegisterModelForm = ({ defaultValues = DEFAULT_VALUES }: RegisterMo formType === 'import' ? [ModelDetailsPanel, ModelTagsPanel, ModelVersionNotesPanel] : [ + ...(latestVersionId ? [] : [ModelOverviewTitle]), ...(latestVersionId ? [] : [ModelDetailsPanel]), ...(latestVersionId ? [] : [ModelTagsPanel]), + ...(latestVersionId ? [] : [FileAndVersionTitle]), ArtifactPanel, ConfigurationPanel, ...(latestVersionId ? [ModelTagsPanel] : []), @@ -294,7 +312,11 @@ export const RegisterModelForm = ({ defaultValues = DEFAULT_VALUES }: RegisterMo {partials.map((FormPartial, i) => ( - + {FormPartial === ModelOverviewTitle || FormPartial === FileAndVersionTitle ? ( + + ) : ( + + )} ))} From 95918d1a46afcf0295ff49bcebaa31ecd8eddf84 Mon Sep 17 00:00:00 2001 From: xyinshen Date: Wed, 22 Mar 2023 09:29:16 +0800 Subject: [PATCH 34/75] Feature change register form max width to 1000px and make it centered if container growth (#148) * fix: update-register-form-hearder-descriptions Signed-off-by: xyinshen * fix: Change register form max-width to 1000px and make it centered if container growth Signed-off-by: xyinshen * fix: Change register form max-width to 1000px and make it centered if container growth Signed-off-by: xyinshen --------- Signed-off-by: xyinshen Signed-off-by: Lin Wang --- .../register_model/register_model.tsx | 124 ++++++++++-------- 1 file changed, 66 insertions(+), 58 deletions(-) diff --git a/public/components/register_model/register_model.tsx b/public/components/register_model/register_model.tsx index 4db0c8fa..d3ec94f1 100644 --- a/public/components/register_model/register_model.tsx +++ b/public/components/register_model/register_model.tsx @@ -19,6 +19,7 @@ import { EuiTextColor, EuiLink, EuiLoadingSpinner, + EuiPageContent, } from '@elastic/eui'; import useObservable from 'react-use/lib/useObservable'; import { from } from 'rxjs'; @@ -293,65 +294,72 @@ export const RegisterModelForm = ({ defaultValues = DEFAULT_VALUES }: RegisterMo } return ( - - - - {formHeader} - - {isSubmitted && !form.formState.isValid && ( - <> - - - - )} - {partials.map((FormPartial, i) => ( - - - {FormPartial === ModelOverviewTitle || FormPartial === FileAndVersionTitle ? ( - - ) : ( - + + + + + {formHeader} + + {isSubmitted && !form.formState.isValid && ( + <> + + + + )} + {partials.map((FormPartial, i) => ( + + + {FormPartial === ModelOverviewTitle || FormPartial === FileAndVersionTitle ? ( + + ) : ( + + )} + + ))} + + + + + + {errorCount > 0 && ( + + + {errorCount} form {errorCount > 1 ? 'errors' : 'error'} + + )} - - ))} - - - - - - {errorCount > 0 && ( - - - {errorCount} form {errorCount > 1 ? 'errors' : 'error'} - + + setIsSubmitted(true)} + fill + > + {latestVersionId ? 'Register version' : 'Register model'} + - )} - - setIsSubmitted(true)} - fill - > - {latestVersionId ? 'Register version' : 'Register model'} - - - - - - + + + + + ); }; From 33705dcf83683ac25f59222051945ec50dc63972 Mon Sep 17 00:00:00 2001 From: Yulong Ruan Date: Fri, 24 Mar 2023 14:35:32 +0800 Subject: [PATCH 35/75] test: initiate the use of MSW for API mocking (#147) This commit initiate the use of MSW for API mocking, model repository APIs are now mocked by MSW and the usages of jest.mock are removed accordingly --------- Signed-off-by: Yulong Ruan Signed-off-by: Lin Wang --- package.json | 3 +- public/apis/__mocks__/model_repository.ts | 49 -- .../register_model_artifact.test.tsx | 2 - .../register_model_file_format.test.tsx | 2 - .../__tests__/register_model_form.test.tsx | 1 - .../__tests__/index.test.tsx | 2 - .../tests/model_repository_manager.test.ts | 2 - test/fetch_polyfill.ts | 13 + test/mocks/data/model_config.ts | 23 + test/mocks/data/model_repository.ts | 160 +++++ test/mocks/handlers.ts | 17 + test/mocks/server.ts | 9 + test/setup.jest.ts | 10 + test/setupTests.ts | 5 + test/setup_dashboard.ts | 49 ++ yarn.lock | 647 +++++++++++++++++- 16 files changed, 927 insertions(+), 67 deletions(-) delete mode 100644 public/apis/__mocks__/model_repository.ts create mode 100644 test/fetch_polyfill.ts create mode 100644 test/mocks/data/model_config.ts create mode 100644 test/mocks/data/model_repository.ts create mode 100644 test/mocks/handlers.ts create mode 100644 test/mocks/server.ts create mode 100644 test/setup_dashboard.ts diff --git a/package.json b/package.json index 50ef67f0..9475571a 100644 --- a/package.json +++ b/package.json @@ -29,6 +29,7 @@ "@testing-library/user-event": "^14.4.3", "husky": "^8.0.0", "@types/papaparse": "^5.3.5", - "lint-staged": "^10.0.0" + "lint-staged": "^10.0.0", + "msw": "^1.1.0" } } diff --git a/public/apis/__mocks__/model_repository.ts b/public/apis/__mocks__/model_repository.ts deleted file mode 100644 index 23d70856..00000000 --- a/public/apis/__mocks__/model_repository.ts +++ /dev/null @@ -1,49 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -export class ModelRepository { - public getPreTrainedModels() { - return Promise.resolve({ - 'sentence-transformers/all-distilroberta-v1': { - version: '1.0.1', - description: - 'This is a sentence-transformers model: It maps sentences & paragraphs to a 768 dimensional dense vector space and can be used for tasks like clustering or semantic search.', - torch_script: { - model_url: - 'https://artifacts.opensearch.org/models/ml-models/huggingface/sentence-transformers/all-distilroberta-v1/1.0.1/torch_script/sentence-transformers_all-distilroberta-v1-1.0.1-torch_script.zip', - config_url: - 'https://artifacts.opensearch.org/models/ml-models/huggingface/sentence-transformers/all-distilroberta-v1/1.0.1/torch_script/config.json', - }, - onnx: { - model_url: - 'https://artifacts.opensearch.org/models/ml-models/huggingface/sentence-transformers/all-distilroberta-v1/1.0.1/onnx/sentence-transformers_all-distilroberta-v1-1.0.1-onnx.zip', - config_url: - 'https://artifacts.opensearch.org/models/ml-models/huggingface/sentence-transformers/all-distilroberta-v1/1.0.1/onnx/config.json', - }, - }, - }); - } - - public getPreTrainedModelConfig() { - return Promise.resolve({ - name: 'sentence-transformers/all-distilroberta-v1', - version: '1.0.1', - description: - 'This is a sentence-transformers model: It maps sentences & paragraphs to a 768 dimensional dense vector space and can be used for tasks like clustering or semantic search.', - model_task_type: 'TEXT_EMBEDDING', - model_format: 'TORCH_SCRIPT', - model_content_size_in_bytes: 330811571, - model_content_hash_value: '92bc10216c720b57a6bab1d7ca2cc2e559156997212a7f0d8bb70f2edfedc78b', - model_config: { - model_type: 'roberta', - embedding_dimension: 768, - framework_type: 'sentence_transformers', - all_config: - '{"_name_or_path":"distilroberta-base","architectures":["RobertaForMaskedLM"],"attention_probs_dropout_prob":0.1,"bos_token_id":0,"eos_token_id":2,"gradient_checkpointing":false,"hidden_act":"gelu","hidden_dropout_prob":0.1,"hidden_size":768,"initializer_range":0.02,"intermediate_size":3072,"layer_norm_eps":0.00001,"max_position_embeddings":514,"model_type":"roberta","num_attention_heads":12,"num_hidden_layers":6,"pad_token_id":1,"position_embedding_type":"absolute","transformers_version":"4.8.2","type_vocab_size":1,"use_cache":true,"vocab_size":50265}', - }, - created_time: 1676072210947, - }); - } -} diff --git a/public/components/register_model/__tests__/register_model_artifact.test.tsx b/public/components/register_model/__tests__/register_model_artifact.test.tsx index 683318ec..c22f6a2b 100644 --- a/public/components/register_model/__tests__/register_model_artifact.test.tsx +++ b/public/components/register_model/__tests__/register_model_artifact.test.tsx @@ -10,8 +10,6 @@ import { ModelFileUploadManager } from '../model_file_upload_manager'; import * as formAPI from '../register_model_api'; import { ONE_GB } from '../../../../common/constant'; -jest.mock('../../../apis/model_repository'); - describe(' Artifact', () => { const onSubmitWithFileMock = jest.fn().mockResolvedValue('model_id'); const onSubmitWithURLMock = jest.fn(); diff --git a/public/components/register_model/__tests__/register_model_file_format.test.tsx b/public/components/register_model/__tests__/register_model_file_format.test.tsx index b3e014d4..e7b2e865 100644 --- a/public/components/register_model/__tests__/register_model_file_format.test.tsx +++ b/public/components/register_model/__tests__/register_model_file_format.test.tsx @@ -9,8 +9,6 @@ import * as formHooks from '../register_model.hooks'; import { ModelFileUploadManager } from '../model_file_upload_manager'; import * as formAPI from '../register_model_api'; -jest.mock('../../../apis/model_repository'); - describe(' Artifact', () => { const onSubmitWithFileMock = jest.fn().mockResolvedValue('model_id'); const onSubmitWithURLMock = jest.fn(); diff --git a/public/components/register_model/__tests__/register_model_form.test.tsx b/public/components/register_model/__tests__/register_model_form.test.tsx index 3da411ff..fb350bc6 100644 --- a/public/components/register_model/__tests__/register_model_form.test.tsx +++ b/public/components/register_model/__tests__/register_model_form.test.tsx @@ -20,7 +20,6 @@ jest.mock('../../../../../../src/plugins/opensearch_dashboards_react/public', () ...jest.requireActual('../../../../../../src/plugins/opensearch_dashboards_react/public'), }; }); -jest.mock('../../../apis/model_repository'); const MOCKED_DATA = { id: 'C7jN0YQBjgpeQQ_RmiDE', diff --git a/public/components/register_model_type_modal/__tests__/index.test.tsx b/public/components/register_model_type_modal/__tests__/index.test.tsx index 42cd72da..4252cb86 100644 --- a/public/components/register_model_type_modal/__tests__/index.test.tsx +++ b/public/components/register_model_type_modal/__tests__/index.test.tsx @@ -8,8 +8,6 @@ import userEvent from '@testing-library/user-event'; import { RegisterModelTypeModal } from '../index'; import { render, screen, waitFor } from '../../../../test/test_utils'; -jest.mock('../../../apis/model_repository'); - const mockOffsetMethods = () => { const originalOffsetHeight = Object.getOwnPropertyDescriptor( HTMLElement.prototype, diff --git a/public/utils/tests/model_repository_manager.test.ts b/public/utils/tests/model_repository_manager.test.ts index ba7bc19e..259205e0 100644 --- a/public/utils/tests/model_repository_manager.test.ts +++ b/public/utils/tests/model_repository_manager.test.ts @@ -6,8 +6,6 @@ import { ModelRepository } from '../../apis/model_repository'; import { ModelRepositoryManager } from '../model_repository_manager'; -jest.mock('../../apis/model_repository'); - describe('ModelRepositoryManager', () => { beforeEach(() => { jest.spyOn(ModelRepository.prototype, 'getPreTrainedModelConfig'); diff --git a/test/fetch_polyfill.ts b/test/fetch_polyfill.ts new file mode 100644 index 00000000..ccb375e9 --- /dev/null +++ b/test/fetch_polyfill.ts @@ -0,0 +1,13 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import fetch, { Headers, Request, Response } from 'node-fetch'; + +if (!globalThis.fetch) { + globalThis.fetch = fetch; + globalThis.Headers = Headers; + globalThis.Request = Request; + globalThis.Response = Response; +} diff --git a/test/mocks/data/model_config.ts b/test/mocks/data/model_config.ts new file mode 100644 index 00000000..b32d9560 --- /dev/null +++ b/test/mocks/data/model_config.ts @@ -0,0 +1,23 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +export const modelConfig = { + name: 'sentence-transformers/all-distilroberta-v1', + version: '1.0.1', + description: + 'This is a sentence-transformers model: It maps sentences & paragraphs to a 768 dimensional dense vector space and can be used for tasks like clustering or semantic search.', + model_task_type: 'TEXT_EMBEDDING', + model_format: 'TORCH_SCRIPT', + model_content_size_in_bytes: 330811571, + model_content_hash_value: '92bc10216c720b57a6bab1d7ca2cc2e559156997212a7f0d8bb70f2edfedc78b', + model_config: { + model_type: 'roberta', + embedding_dimension: 768, + framework_type: 'sentence_transformers', + all_config: + '{"_name_or_path":"distilroberta-base","architectures":["RobertaForMaskedLM"],"attention_probs_dropout_prob":0.1,"bos_token_id":0,"eos_token_id":2,"gradient_checkpointing":false,"hidden_act":"gelu","hidden_dropout_prob":0.1,"hidden_size":768,"initializer_range":0.02,"intermediate_size":3072,"layer_norm_eps":0.00001,"max_position_embeddings":514,"model_type":"roberta","num_attention_heads":12,"num_hidden_layers":6,"pad_token_id":1,"position_embedding_type":"absolute","transformers_version":"4.8.2","type_vocab_size":1,"use_cache":true,"vocab_size":50265}', + }, + created_time: 1676072210947, +}; diff --git a/test/mocks/data/model_repository.ts b/test/mocks/data/model_repository.ts new file mode 100644 index 00000000..4e1bae56 --- /dev/null +++ b/test/mocks/data/model_repository.ts @@ -0,0 +1,160 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +export const modelRepositoryResponse = { + 'sentence-transformers/all-distilroberta-v1': { + version: '1.0.1', + description: + 'This is a sentence-transformers model: It maps sentences & paragraphs to a 768 dimensional dense vector space and can be used for tasks like clustering or semantic search.', + torch_script: { + model_url: + 'https://artifacts.opensearch.org/models/ml-models/huggingface/sentence-transformers/all-distilroberta-v1/1.0.1/torch_script/sentence-transformers_all-distilroberta-v1-1.0.1-torch_script.zip', + config_url: + 'https://artifacts.opensearch.org/models/ml-models/huggingface/sentence-transformers/all-distilroberta-v1/1.0.1/torch_script/config.json', + }, + onnx: { + model_url: + 'https://artifacts.opensearch.org/models/ml-models/huggingface/sentence-transformers/all-distilroberta-v1/1.0.1/onnx/sentence-transformers_all-distilroberta-v1-1.0.1-onnx.zip', + config_url: + 'https://artifacts.opensearch.org/models/ml-models/huggingface/sentence-transformers/all-distilroberta-v1/1.0.1/onnx/config.json', + }, + }, + 'sentence-transformers/all-MiniLM-L12-v2': { + version: '1.0.1', + description: + 'This is a sentence-transformers model: It maps sentences & paragraphs to a 384 dimensional dense vector space and can be used for tasks like clustering or semantic search.', + torch_script: { + model_url: + 'https://artifacts.opensearch.org/models/ml-models/huggingface/sentence-transformers/all-MiniLM-L12-v2/1.0.1/torch_script/sentence-transformers_all-MiniLM-L12-v2-1.0.1-torch_script.zip', + config_url: + 'https://artifacts.opensearch.org/models/ml-models/huggingface/sentence-transformers/all-MiniLM-L12-v2/1.0.1/torch_script/config.json', + }, + onnx: { + model_url: + 'https://artifacts.opensearch.org/models/ml-models/huggingface/sentence-transformers/all-MiniLM-L12-v2/1.0.1/onnx/sentence-transformers_all-MiniLM-L12-v2-1.0.1-onnx.zip', + config_url: + 'https://artifacts.opensearch.org/models/ml-models/huggingface/sentence-transformers/all-MiniLM-L12-v2/1.0.1/onnx/config.json', + }, + }, + 'sentence-transformers/all-MiniLM-L6-v2': { + version: '1.0.1', + description: + 'This is a sentence-transformers model: It maps sentences & paragraphs to a 384 dimensional dense vector space and can be used for tasks like clustering or semantic search.', + torch_script: { + model_url: + 'https://artifacts.opensearch.org/models/ml-models/huggingface/sentence-transformers/all-MiniLM-L6-v2/1.0.1/torch_script/sentence-transformers_all-MiniLM-L6-v2-1.0.1-torch_script.zip', + config_url: + 'https://artifacts.opensearch.org/models/ml-models/huggingface/sentence-transformers/all-MiniLM-L6-v2/1.0.1/torch_script/config.json', + }, + onnx: { + model_url: + 'https://artifacts.opensearch.org/models/ml-models/huggingface/sentence-transformers/all-MiniLM-L6-v2/1.0.1/onnx/sentence-transformers_all-MiniLM-L6-v2-1.0.1-onnx.zip', + config_url: + 'https://artifacts.opensearch.org/models/ml-models/huggingface/sentence-transformers/all-MiniLM-L6-v2/1.0.1/onnx/config.json', + }, + }, + 'sentence-transformers/all-mpnet-base-v2': { + version: '1.0.1', + description: + 'This is a sentence-transformers model: It maps sentences & paragraphs to a 768 dimensional dense vector space and can be used for tasks like clustering or semantic search.', + torch_script: { + model_url: + 'https://artifacts.opensearch.org/models/ml-models/huggingface/sentence-transformers/all-mpnet-base-v2/1.0.1/torch_script/sentence-transformers_all-mpnet-base-v2-1.0.1-torch_script.zip', + config_url: + 'https://artifacts.opensearch.org/models/ml-models/huggingface/sentence-transformers/all-mpnet-base-v2/1.0.1/torch_script/config.json', + }, + onnx: { + model_url: + 'https://artifacts.opensearch.org/models/ml-models/huggingface/sentence-transformers/all-mpnet-base-v2/1.0.1/onnx/sentence-transformers_all-mpnet-base-v2-1.0.1-onnx.zip', + config_url: + 'https://artifacts.opensearch.org/models/ml-models/huggingface/sentence-transformers/all-mpnet-base-v2/1.0.1/onnx/config.json', + }, + }, + 'sentence-transformers/msmarco-distilbert-base-tas-b': { + version: '1.0.1', + description: + 'This is a port of the DistilBert TAS-B Model to sentence-transformers model: It maps sentences & paragraphs to a 768 dimensional dense vector space and is optimized for the task of semantic search.', + torch_script: { + model_url: + 'https://artifacts.opensearch.org/models/ml-models/huggingface/sentence-transformers/msmarco-distilbert-base-tas-b/1.0.1/torch_script/sentence-transformers_msmarco-distilbert-base-tas-b-1.0.1-torch_script.zip', + config_url: + 'https://artifacts.opensearch.org/models/ml-models/huggingface/sentence-transformers/msmarco-distilbert-base-tas-b/1.0.1/torch_script/config.json', + }, + onnx: { + model_url: + 'https://artifacts.opensearch.org/models/ml-models/huggingface/sentence-transformers/msmarco-distilbert-base-tas-b/1.0.1/onnx/sentence-transformers_msmarco-distilbert-base-tas-b-1.0.1-onnx.zip', + config_url: + 'https://artifacts.opensearch.org/models/ml-models/huggingface/sentence-transformers/msmarco-distilbert-base-tas-b/1.0.1/onnx/config.json', + }, + }, + 'sentence-transformers/multi-qa-MiniLM-L6-cos-v1': { + version: '1.0.1', + description: + 'This is a sentence-transformers model: It maps sentences & paragraphs to a 384 dimensional dense vector space and was designed for semantic search. It has been trained on 215M (question, answer) pairs from diverse sources.', + torch_script: { + model_url: + 'https://artifacts.opensearch.org/models/ml-models/huggingface/sentence-transformers/multi-qa-MiniLM-L6-cos-v1/1.0.1/torch_script/sentence-transformers_multi-qa-MiniLM-L6-cos-v1-1.0.1-torch_script.zip', + config_url: + 'https://artifacts.opensearch.org/models/ml-models/huggingface/sentence-transformers/multi-qa-MiniLM-L6-cos-v1/1.0.1/torch_script/config.json', + }, + onnx: { + model_url: + 'https://artifacts.opensearch.org/models/ml-models/huggingface/sentence-transformers/multi-qa-MiniLM-L6-cos-v1/1.0.1/onnx/sentence-transformers_multi-qa-MiniLM-L6-cos-v1-1.0.1-onnx.zip', + config_url: + 'https://artifacts.opensearch.org/models/ml-models/huggingface/sentence-transformers/multi-qa-MiniLM-L6-cos-v1/1.0.1/onnx/config.json', + }, + }, + 'sentence-transformers/multi-qa-mpnet-base-dot-v1': { + version: '1.0.1', + description: + 'This is a sentence-transformers model: It maps sentences & paragraphs to a 384 dimensional dense vector space and can be used for tasks like clustering or semantic search.', + torch_script: { + model_url: + 'https://artifacts.opensearch.org/models/ml-models/huggingface/sentence-transformers/multi-qa-mpnet-base-dot-v1/1.0.1/torch_script/sentence-transformers_multi-qa-mpnet-base-dot-v1-1.0.1-torch_script.zip', + config_url: + 'https://artifacts.opensearch.org/models/ml-models/huggingface/sentence-transformers/multi-qa-mpnet-base-dot-v1/1.0.1/torch_script/config.json', + }, + onnx: { + model_url: + 'https://artifacts.opensearch.org/models/ml-models/huggingface/sentence-transformers/multi-qa-mpnet-base-dot-v1/1.0.1/onnx/sentence-transformers_multi-qa-mpnet-base-dot-v1-1.0.1-onnx.zip', + config_url: + 'https://artifacts.opensearch.org/models/ml-models/huggingface/sentence-transformers/multi-qa-mpnet-base-dot-v1/1.0.1/onnx/config.json', + }, + }, + 'sentence-transformers/paraphrase-MiniLM-L3-v2': { + version: '1.0.1', + description: + 'This is a sentence-transformers model: It maps sentences & paragraphs to a 384 dimensional dense vector space and can be used for tasks like clustering or semantic search.', + torch_script: { + model_url: + 'https://artifacts.opensearch.org/models/ml-models/huggingface/sentence-transformers/paraphrase-MiniLM-L3-v2/1.0.1/torch_script/sentence-transformers_paraphrase-MiniLM-L3-v2-1.0.1-torch_script.zip', + config_url: + 'https://artifacts.opensearch.org/models/ml-models/huggingface/sentence-transformers/paraphrase-MiniLM-L3-v2/1.0.1/torch_script/config.json', + }, + onnx: { + model_url: + 'https://artifacts.opensearch.org/models/ml-models/huggingface/sentence-transformers/paraphrase-MiniLM-L3-v2/1.0.1/onnx/sentence-transformers_paraphrase-MiniLM-L3-v2-1.0.1-onnx.zip', + config_url: + 'https://artifacts.opensearch.org/models/ml-models/huggingface/sentence-transformers/paraphrase-MiniLM-L3-v2/1.0.1/onnx/config.json', + }, + }, + 'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2': { + version: '1.0.1', + description: + 'This is a sentence-transformers model: It maps sentences & paragraphs to a 384 dimensional dense vector space and can be used for tasks like clustering or semantic search.', + torch_script: { + model_url: + 'https://artifacts.opensearch.org/models/ml-models/huggingface/sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2/1.0.1/torch_script/sentence-transformers_paraphrase-multilingual-MiniLM-L12-v2-1.0.1-torch_script.zip', + config_url: + 'https://artifacts.opensearch.org/models/ml-models/huggingface/sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2/1.0.1/torch_script/config.json', + }, + onnx: { + model_url: + 'https://artifacts.opensearch.org/models/ml-models/huggingface/sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2/1.0.1/onnx/sentence-transformers_paraphrase-multilingual-MiniLM-L12-v2-1.0.1-onnx.zip', + config_url: + 'https://artifacts.opensearch.org/models/ml-models/huggingface/sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2/1.0.1/onnx/config.json', + }, + }, +}; diff --git a/test/mocks/handlers.ts b/test/mocks/handlers.ts new file mode 100644 index 00000000..ec2ed385 --- /dev/null +++ b/test/mocks/handlers.ts @@ -0,0 +1,17 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import { rest } from 'msw'; +import { modelConfig } from './data/model_config'; +import { modelRepositoryResponse } from './data/model_repository'; + +export const handlers = [ + rest.get('/api/ml-commons/model-repository', (req, res, ctx) => { + return res(ctx.status(200), ctx.json(modelRepositoryResponse)); + }), + rest.get('/api/ml-commons/model-repository/config-url/:config_url', (req, res, ctx) => { + return res(ctx.status(200), ctx.json(modelConfig)); + }), +]; diff --git a/test/mocks/server.ts b/test/mocks/server.ts new file mode 100644 index 00000000..2c1057b1 --- /dev/null +++ b/test/mocks/server.ts @@ -0,0 +1,9 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import { setupServer } from 'msw/node'; +import { handlers } from './handlers'; + +export const server = setupServer(...handlers); diff --git a/test/setup.jest.ts b/test/setup.jest.ts index 9c7266ba..65b15487 100644 --- a/test/setup.jest.ts +++ b/test/setup.jest.ts @@ -6,4 +6,14 @@ import { configure } from '@testing-library/react'; import '@testing-library/jest-dom'; +import { server } from './mocks/server'; + configure({ testIdAttribute: 'data-test-subj' }); + +// Establish API mocking before all tests. +beforeAll(() => server.listen()); +// Reset any request handlers that we may add during the tests, +// so they don't affect other tests. +afterEach(() => server.resetHandlers()); +// Clean up after the tests are finished. +afterAll(() => server.close()); diff --git a/test/setupTests.ts b/test/setupTests.ts index 6ea8ab39..fab3b451 100644 --- a/test/setupTests.ts +++ b/test/setupTests.ts @@ -4,3 +4,8 @@ */ import 'babel-polyfill'; + +import './fetch_polyfill'; +import { setupDashboard } from './setup_dashboard'; + +setupDashboard(); diff --git a/test/setup_dashboard.ts b/test/setup_dashboard.ts new file mode 100644 index 00000000..219b9a56 --- /dev/null +++ b/test/setup_dashboard.ts @@ -0,0 +1,49 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import { InjectedMetadataService } from '../../../src/core/public/injected_metadata'; +import { HttpService } from '../../../src/core/public/http'; +import { FatalErrorsService } from '../../../src/core/public/fatal_errors'; +import { I18nService } from '../../../src/core/public/i18n'; +import { InnerHttpProvider } from '../public/apis/inner_http_provider'; + +export function setupDashboard() { + const injectedMetadataService = new InjectedMetadataService({ + injectedMetadata: { + version: 'x.x', + buildNumber: 0, + branch: 'x', + basePath: '', + serverBasePath: '', + csp: { warnLegacyBrowsers: false }, + vars: {}, + env: { + mode: { name: 'development', dev: true, prod: false }, + packageInfo: { + version: 'x.x', + branch: '', + buildNum: 0, + buildSha: '', + dist: false, + }, + }, + uiPlugins: [], + anonymousStatusPage: false, + legacyMetadata: { uiSettings: { defaults: {} } }, + branding: {}, + }, + }); + + const fatalErrorsService = new FatalErrorsService(document.body, () => {}); + const injectedMetadata = injectedMetadataService.setup(); + const i18n = new I18nService().getContext(); + + const http = new HttpService().setup({ + injectedMetadata, + fatalErrors: fatalErrorsService.setup({ injectedMetadata, i18n }), + }); + + InnerHttpProvider.setHttp(http); +} diff --git a/yarn.lock b/yarn.lock index 27e1f2ed..fe5661a9 100644 --- a/yarn.lock +++ b/yarn.lock @@ -23,11 +23,60 @@ chalk "^2.0.0" js-tokens "^4.0.0" +"@mswjs/cookies@^0.2.2": + version "0.2.2" + resolved "https://registry.yarnpkg.com/@mswjs/cookies/-/cookies-0.2.2.tgz#b4e207bf6989e5d5427539c2443380a33ebb922b" + integrity sha512-mlN83YSrcFgk7Dm1Mys40DLssI1KdJji2CMKN8eOlBqsTADYzj2+jWzsANsUTFbxDMWPD5e9bfA1RGqBpS3O1g== + dependencies: + "@types/set-cookie-parser" "^2.4.0" + set-cookie-parser "^2.4.6" + +"@mswjs/interceptors@^0.17.5": + version "0.17.9" + resolved "https://registry.yarnpkg.com/@mswjs/interceptors/-/interceptors-0.17.9.tgz#0096fc88fea63ee42e36836acae8f4ae33651c04" + integrity sha512-4LVGt03RobMH/7ZrbHqRxQrS9cc2uh+iNKSj8UWr8M26A2i793ju+csaB5zaqYltqJmA2jUq4VeYfKmVqvsXQg== + dependencies: + "@open-draft/until" "^1.0.3" + "@types/debug" "^4.1.7" + "@xmldom/xmldom" "^0.8.3" + debug "^4.3.3" + headers-polyfill "^3.1.0" + outvariant "^1.2.1" + strict-event-emitter "^0.2.4" + web-encoding "^1.1.5" + +"@open-draft/until@^1.0.3": + version "1.0.3" + resolved "https://registry.yarnpkg.com/@open-draft/until/-/until-1.0.3.tgz#db9cc719191a62e7d9200f6e7bab21c5b848adca" + integrity sha512-Aq58f5HiWdyDlFffbbSjAlv596h/cOnt2DO1w3DOC7OJ5EHs0hd/nycJfiu9RJbT6Yk6F1knnRRXNSpxoIVZ9Q== + "@testing-library/user-event@^14.4.3": version "14.4.3" resolved "https://registry.yarnpkg.com/@testing-library/user-event/-/user-event-14.4.3.tgz#af975e367743fa91989cd666666aec31a8f50591" integrity sha512-kCUc5MEwaEMakkO5x7aoD+DLi02ehmEM2QCGWvNqAS1dV/fAvORWEjnjsEIvml59M7Y5kCkWN6fCCyPOe8OL6Q== +"@types/cookie@^0.4.1": + version "0.4.1" + resolved "https://registry.yarnpkg.com/@types/cookie/-/cookie-0.4.1.tgz#bfd02c1f2224567676c1545199f87c3a861d878d" + integrity sha512-XW/Aa8APYr6jSVVA1y/DEIZX0/GMKLEVekNG727R8cs56ahETkRAy/3DR7+fJyh7oUgGwNQaRfXCun0+KbWY7Q== + +"@types/debug@^4.1.7": + version "4.1.7" + resolved "https://registry.yarnpkg.com/@types/debug/-/debug-4.1.7.tgz#7cc0ea761509124709b8b2d1090d8f6c17aadb82" + integrity sha512-9AonUzyTjXXhEOa0DnqpzZi6VHlqKMswga9EXjpXnnqxwLtdvPPtlO8evrI5D9S6asFRCQ6v+wpiUKbw+vKqyg== + dependencies: + "@types/ms" "*" + +"@types/js-levenshtein@^1.1.1": + version "1.1.1" + resolved "https://registry.yarnpkg.com/@types/js-levenshtein/-/js-levenshtein-1.1.1.tgz#ba05426a43f9e4e30b631941e0aa17bf0c890ed5" + integrity sha512-qC4bCqYGy1y/NP7dDVr7KJarn+PbX1nSpwA7JXdu0HxT3QYjO8MJ+cntENtHFVy2dRAyBV23OZ6MxsW1AM1L8g== + +"@types/ms@*": + version "0.7.31" + resolved "https://registry.yarnpkg.com/@types/ms/-/ms-0.7.31.tgz#31b7ca6407128a3d2bbc27fe2d21b345397f6197" + integrity sha512-iiUgKzV9AuaEkZqkOLDIvlQiL6ltuZd9tGcW3gwpnX8JbuiuhFlEGmmFXEXkN50Cvq7Os88IY2v0dkDqXYWVgA== + "@types/node@*": version "18.7.14" resolved "https://registry.yarnpkg.com/@types/node/-/node-18.7.14.tgz#0fe081752a3333392d00586d815485a17c2cf3c9" @@ -45,6 +94,23 @@ resolved "https://registry.yarnpkg.com/@types/parse-json/-/parse-json-4.0.0.tgz#2f8bb441434d163b35fb8ffdccd7138927ffb8c0" integrity sha512-//oorEZjL6sbPcKUaCdIGlIUeH26mgzimjBB77G6XRgnDl/L5wOnpyBGRe/Mmf5CVW3PwEBE1NjiMZ/ssFh4wA== +"@types/set-cookie-parser@^2.4.0": + version "2.4.2" + resolved "https://registry.yarnpkg.com/@types/set-cookie-parser/-/set-cookie-parser-2.4.2.tgz#b6a955219b54151bfebd4521170723df5e13caad" + integrity sha512-fBZgytwhYAUkj/jC/FAV4RQ5EerRup1YQsXQCh8rZfiHkc4UahC192oH0smGwsXol3cL3A5oETuAHeQHmhXM4w== + dependencies: + "@types/node" "*" + +"@xmldom/xmldom@^0.8.3": + version "0.8.6" + resolved "https://registry.yarnpkg.com/@xmldom/xmldom/-/xmldom-0.8.6.tgz#8a1524eb5bd5e965c1e3735476f0262469f71440" + integrity sha512-uRjjusqpoqfmRkTaNuLJ2VohVr67Q5YwDATW3VU7PfzTj6IRaihGrYI7zckGZjxQPBIp63nfvJbM+Yu5ICh0Bg== + +"@zxing/text-encoding@0.9.0": + version "0.9.0" + resolved "https://registry.yarnpkg.com/@zxing/text-encoding/-/text-encoding-0.9.0.tgz#fb50ffabc6c7c66a0c96b4c03e3d9be74864b70b" + integrity sha512-U/4aVJ2mxI0aDNI8Uq0wEhMgY+u4CNtEb0om3+y3+niDAsoTCOB33UF0sxpzqzdqXLqmvc+vZyAt4O8pPdfkwA== + aggregate-error@^3.0.0: version "3.1.0" resolved "https://registry.yarnpkg.com/aggregate-error/-/aggregate-error-3.1.0.tgz#92670ff50f5359bdb7a3e0d40d0ec30c5737687a" @@ -58,7 +124,7 @@ ansi-colors@^4.1.1: resolved "https://registry.yarnpkg.com/ansi-colors/-/ansi-colors-4.1.3.tgz#37611340eb2243e70cc604cad35d63270d48781b" integrity sha512-/6w/C21Pm1A7aZitlI5Ni/2J6FFQN8i1Cvz3kHABAAbw93v/NlvKdVOqz7CCWz/3iv/JplRSEEZ83XION15ovw== -ansi-escapes@^4.3.0: +ansi-escapes@^4.2.1, ansi-escapes@^4.3.0: version "4.3.2" resolved "https://registry.yarnpkg.com/ansi-escapes/-/ansi-escapes-4.3.2.tgz#6b2291d1db7d98b6521d5f1efa42d0f3a9feb65e" integrity sha512-gKXj5ALrKWQLsYG9jlTRmR/xKluxHV+Z9QEwNIgCfM1/uwPMCuzVVnh5mwTd+OuBZcwSIMbqssNWRm1lE51QaQ== @@ -84,11 +150,43 @@ ansi-styles@^4.0.0, ansi-styles@^4.1.0: dependencies: color-convert "^2.0.1" +anymatch@~3.1.2: + version "3.1.3" + resolved "https://registry.yarnpkg.com/anymatch/-/anymatch-3.1.3.tgz#790c58b19ba1720a84205b57c618d5ad8524973e" + integrity sha512-KMReFUr0B4t+D+OBkjR3KYqvocp2XaSzO55UcB6mgQMd3KbcE+mWTyvVV7D/zsdEbNnV6acZUutkiHQXvTr1Rw== + dependencies: + normalize-path "^3.0.0" + picomatch "^2.0.4" + astral-regex@^2.0.0: version "2.0.0" resolved "https://registry.yarnpkg.com/astral-regex/-/astral-regex-2.0.0.tgz#483143c567aeed4785759c0865786dc77d7d2e31" integrity sha512-Z7tMw1ytTXt5jqMcOP+OQteU1VuNK9Y02uuJtKQ1Sv69jXQKKg5cibLwGJow8yzZP+eAc18EmLGPal0bp36rvQ== +available-typed-arrays@^1.0.5: + version "1.0.5" + resolved "https://registry.yarnpkg.com/available-typed-arrays/-/available-typed-arrays-1.0.5.tgz#92f95616501069d07d10edb2fc37d3e1c65123b7" + integrity sha512-DMD0KiN46eipeziST1LPP/STfDU0sufISXmjSgvVsoU2tqxctQeASejWcfNtxYKqETM1UxQ8sp2OrSBWpHY6sw== + +base64-js@^1.3.1: + version "1.5.1" + resolved "https://registry.yarnpkg.com/base64-js/-/base64-js-1.5.1.tgz#1b1b440160a5bf7ad40b650f095963481903930a" + integrity sha512-AKpaYlHn8t4SVbOHCy+b5+KKgvR4vrsD8vbvrbiQJps7fKDTkjkDry6ji0rUJjC0kzbNePLwzxq8iypo41qeWA== + +binary-extensions@^2.0.0: + version "2.2.0" + resolved "https://registry.yarnpkg.com/binary-extensions/-/binary-extensions-2.2.0.tgz#75f502eeaf9ffde42fc98829645be4ea76bd9e2d" + integrity sha512-jDctJ/IVQbZoJykoeHbhXpOlNBqGNcwXJKJog42E5HDPUwQTSdjCHdihjj0DlnheQ7blbT6dHOafNAiS8ooQKA== + +bl@^4.1.0: + version "4.1.0" + resolved "https://registry.yarnpkg.com/bl/-/bl-4.1.0.tgz#451535264182bec2fbbc83a62ab98cf11d9f7b3a" + integrity sha512-1W07cM9gS6DcLperZfFSj+bWLtaPGSOHWhPiGzXmvVJbRLdG82sH/Kn8EtW1VqWVA54AKf2h5k5BbnIbwF3h6w== + dependencies: + buffer "^5.5.0" + inherits "^2.0.4" + readable-stream "^3.4.0" + braces@^3.0.3: version "3.0.3" resolved "https://registry.yarnpkg.com/braces/-/braces-3.0.3.tgz#490332f40919452272d55a8480adc0c441358789" @@ -96,11 +194,42 @@ braces@^3.0.3: dependencies: fill-range "^7.1.1" +braces@~3.0.2: + version "3.0.2" + resolved "https://registry.yarnpkg.com/braces/-/braces-3.0.2.tgz#3454e1a462ee8d599e236df336cd9ea4f8afe107" + integrity sha512-b8um+L1RzM3WDSzvhm6gIz1yfTbBt6YTlcEKAvsmqCZZFw46z626lVj9j1yEPW33H5H+lBQpZMP1k8l+78Ha0A== + dependencies: + fill-range "^7.0.1" + +buffer@^5.5.0: + version "5.7.1" + resolved "https://registry.yarnpkg.com/buffer/-/buffer-5.7.1.tgz#ba62e7c13133053582197160851a8f648e99eed0" + integrity sha512-EHcyIPBQ4BSGlvjB16k5KgAJ27CIsHY/2JBmCRReo48y9rQ3MaUzWX3KVlBa4U7MyX02HdVj0K7C3WaB3ju7FQ== + dependencies: + base64-js "^1.3.1" + ieee754 "^1.1.13" + +call-bind@^1.0.2: + version "1.0.2" + resolved "https://registry.yarnpkg.com/call-bind/-/call-bind-1.0.2.tgz#b1d4e89e688119c3c9a903ad30abb2f6a919be3c" + integrity sha512-7O+FbCihrB5WGbFYesctwmTKae6rOiIzmz1icreWJ+0aA7LJfuqhEso2T9ncpcFtzMQtzXf2QGGueWJGTYsqrA== + dependencies: + function-bind "^1.1.1" + get-intrinsic "^1.0.2" + callsites@^3.0.0: version "3.1.0" resolved "https://registry.yarnpkg.com/callsites/-/callsites-3.1.0.tgz#b3630abd8943432f54b3f0519238e33cd7df2f73" integrity sha512-P8BjAsXvZS+VIDUI11hHCQEv74YT67YUi5JJFNWIqL235sBmjX4+qx9Muvls5ivyNENctx46xQLQ3aTuE7ssaQ== +chalk@4.1.1: + version "4.1.1" + resolved "https://registry.yarnpkg.com/chalk/-/chalk-4.1.1.tgz#c80b3fab28bf6371e6863325eee67e618b77e6ad" + integrity sha512-diHzdDKxcU+bAsUboHLPEDQiw0qEe0qd7SYUn3HgcFlWgbDcfLGswOHYeGrHKzG9z6UYf01d9VFMfZxPM1xZSg== + dependencies: + ansi-styles "^4.1.0" + supports-color "^7.1.0" + chalk@^2.0.0: version "2.4.2" resolved "https://registry.yarnpkg.com/chalk/-/chalk-2.4.2.tgz#cd42541677a54333cf541a49108c1432b44c9424" @@ -110,7 +239,7 @@ chalk@^2.0.0: escape-string-regexp "^1.0.5" supports-color "^5.3.0" -chalk@^4.1.0: +chalk@^4.1.0, chalk@^4.1.1: version "4.1.2" resolved "https://registry.yarnpkg.com/chalk/-/chalk-4.1.2.tgz#aac4e2b7734a740867aeb16bf02aad556a1e7a01" integrity sha512-oKnbhFyRIXpUuez8iBMmyEa4nbj4IOQyuhc/wy9kY7/WVPcwIO9VA668Pu8RkO7+0G76SLROeyw9CpQ061i4mA== @@ -118,6 +247,26 @@ chalk@^4.1.0: ansi-styles "^4.1.0" supports-color "^7.1.0" +chardet@^0.7.0: + version "0.7.0" + resolved "https://registry.yarnpkg.com/chardet/-/chardet-0.7.0.tgz#90094849f0937f2eedc2425d0d28a9e5f0cbad9e" + integrity sha512-mT8iDcrh03qDGRRmoA2hmBJnxpllMR+0/0qlzjqZES6NdiWDcZkCNAk4rPFZ9Q85r27unkiNNg8ZOiwZXBHwcA== + +chokidar@^3.4.2: + version "3.5.3" + resolved "https://registry.yarnpkg.com/chokidar/-/chokidar-3.5.3.tgz#1cf37c8707b932bd1af1ae22c0432e2acd1903bd" + integrity sha512-Dr3sfKRP6oTcjf2JmUmFJfeVMvXBdegxB0iVQ5eb2V10uFJUCAS8OByZdVAyVb8xXNz3GjjTgj9kLWsZTqE6kw== + dependencies: + anymatch "~3.1.2" + braces "~3.0.2" + glob-parent "~5.1.2" + is-binary-path "~2.1.0" + is-glob "~4.0.1" + normalize-path "~3.0.0" + readdirp "~3.6.0" + optionalDependencies: + fsevents "~2.3.2" + clean-stack@^2.0.0: version "2.2.0" resolved "https://registry.yarnpkg.com/clean-stack/-/clean-stack-2.2.0.tgz#ee8472dbb129e727b31e8a10a427dee9dfe4008b" @@ -130,6 +279,11 @@ cli-cursor@^3.1.0: dependencies: restore-cursor "^3.1.0" +cli-spinners@^2.5.0: + version "2.7.0" + resolved "https://registry.yarnpkg.com/cli-spinners/-/cli-spinners-2.7.0.tgz#f815fd30b5f9eaac02db604c7a231ed7cb2f797a" + integrity sha512-qu3pN8Y3qHNgE2AFweciB1IfMnmZ/fsNTEE+NOFjmGB2F/7rLhnhzppvpCnN4FovtP26k8lHyy9ptEbNwWFLzw== + cli-truncate@^2.1.0: version "2.1.0" resolved "https://registry.yarnpkg.com/cli-truncate/-/cli-truncate-2.1.0.tgz#c39e28bf05edcde5be3b98992a22deed5a2b93c7" @@ -138,6 +292,25 @@ cli-truncate@^2.1.0: slice-ansi "^3.0.0" string-width "^4.2.0" +cli-width@^3.0.0: + version "3.0.0" + resolved "https://registry.yarnpkg.com/cli-width/-/cli-width-3.0.0.tgz#a2f48437a2caa9a22436e794bf071ec9e61cedf6" + integrity sha512-FxqpkPPwu1HjuN93Omfm4h8uIanXofW0RxVEW3k5RKx+mJJYSthzNhp32Kzxxy3YAEZ/Dc/EWN1vZRY0+kOhbw== + +cliui@^8.0.1: + version "8.0.1" + resolved "https://registry.yarnpkg.com/cliui/-/cliui-8.0.1.tgz#0c04b075db02cbfe60dc8e6cf2f5486b1a3608aa" + integrity sha512-BSeNnyus75C4//NQ9gQt1/csTXyo/8Sb+afLAkzAptFuMsod9HFokGNudZpi/oQV73hnVK+sR+5PVRMd+Dr7YQ== + dependencies: + string-width "^4.2.0" + strip-ansi "^6.0.1" + wrap-ansi "^7.0.0" + +clone@^1.0.2: + version "1.0.4" + resolved "https://registry.yarnpkg.com/clone/-/clone-1.0.4.tgz#da309cc263df15994c688ca902179ca3c7cd7c7e" + integrity sha512-JQHZ2QMW6l3aH/j6xCqQThY/9OH4D/9ls34cgkUBiEeocRTU04tHfKPBsUK1PqZCUQM7GiA0IIXJSuXHI64Kbg== + color-convert@^1.9.0: version "1.9.3" resolved "https://registry.yarnpkg.com/color-convert/-/color-convert-1.9.3.tgz#bb71850690e1f136567de629d2d5471deda4c1e8" @@ -172,6 +345,11 @@ commander@^6.2.0: resolved "https://registry.yarnpkg.com/commander/-/commander-6.2.1.tgz#0792eb682dfbc325999bb2b84fddddba110ac73c" integrity sha512-U7VdrJFnJgo4xjrHpTzu0yrHPGImdsmD95ZlgYSEajAn2JKzDhDTPG9kBTefmObL2w/ngeZnilk+OV9CG3d7UA== +cookie@^0.4.2: + version "0.4.2" + resolved "https://registry.yarnpkg.com/cookie/-/cookie-0.4.2.tgz#0e41f24de5ecf317947c82fc789e06a884824432" + integrity sha512-aSWTXFzaKWkvHO1Ny/s+ePFpvKsPnjc551iI41v3ny/ow6tBG5Vd+FuqGNhh1LxOmVzOlGUriIlOaokOvhaStA== + cosmiconfig@^7.0.0: version "7.0.1" resolved "https://registry.yarnpkg.com/cosmiconfig/-/cosmiconfig-7.0.1.tgz#714d756522cace867867ccb4474c5d01bbae5d6d" @@ -192,7 +370,7 @@ cross-spawn@^7.0.0: shebang-command "^2.0.0" which "^2.0.1" -debug@^4.2.0: +debug@^4.2.0, debug@^4.3.3: version "4.3.4" resolved "https://registry.yarnpkg.com/debug/-/debug-4.3.4.tgz#1319f6579357f2338d3337d2cdd4914bb5dcc865" integrity sha512-PRWFHuSU3eDtQJPvnNY7Jcket1j0t5OuOsFzPPzsekD52Zl8qUfFIPEiswXqIvHWGVHOgX+7G/vCNNhehwxfkQ== @@ -204,6 +382,13 @@ dedent@^0.7.0: resolved "https://registry.yarnpkg.com/dedent/-/dedent-0.7.0.tgz#2495ddbaf6eb874abb0e1be9df22d2e5a544326c" integrity sha512-Q6fKUPqnAHAyhiUgFU7BUzLiv0kd8saH9al7tnu5Q/okj6dnupxyTgFIBjVzJATdfIAm9NAsvXNzjaKa+bxVyA== +defaults@^1.0.3: + version "1.0.4" + resolved "https://registry.yarnpkg.com/defaults/-/defaults-1.0.4.tgz#b0b02062c1e2aa62ff5d9528f0f98baa90978d7a" + integrity sha512-eFuaLoy/Rxalv2kr+lqMlUnrDWV+3j4pljOIJgLIhI058IQfWJ7vXhyEIHu+HtC738klGALYxOKDO0bQP3tg8A== + dependencies: + clone "^1.0.2" + emoji-regex@^8.0.0: version "8.0.0" resolved "https://registry.yarnpkg.com/emoji-regex/-/emoji-regex-8.0.0.tgz#e818fd69ce5ccfcb404594f842963bf53164cc37" @@ -230,11 +415,21 @@ error-ex@^1.3.1: dependencies: is-arrayish "^0.2.1" +escalade@^3.1.1: + version "3.1.1" + resolved "https://registry.yarnpkg.com/escalade/-/escalade-3.1.1.tgz#d8cfdc7000965c5a0174b4a82eaa5c0552742e40" + integrity sha512-k0er2gUkLf8O0zKJiAhmkTnJlTvINGv7ygDNPbeIsX/TJjGJZHuh9B2UxbsaEkmlEo9MfhrSzmhIlhRlI2GXnw== + escape-string-regexp@^1.0.5: version "1.0.5" resolved "https://registry.yarnpkg.com/escape-string-regexp/-/escape-string-regexp-1.0.5.tgz#1b61c0562190a8dff6ae3bb2cf0200ca130b86d4" integrity sha512-vbRorB5FUQWvla16U8R/qgaFIya2qGzwDrNmCZuYKrbdSUMG6I1ZCGQRefkRVhuOkIGVne7BQ35DSfo1qvJqFg== +events@^3.3.0: + version "3.3.0" + resolved "https://registry.yarnpkg.com/events/-/events-3.3.0.tgz#31a95ad0a924e2d2c419a813aeb2c4e878ea7400" + integrity sha512-mQw+2fkQbALzQ7V0MY0IqdnXNOeTtP4r0lN9z7AAawCXgqea7bDii20AYrIBrFd/Hx0M2Ocz6S111CaFkUcb0Q== + execa@^4.1.0: version "4.1.0" resolved "https://registry.yarnpkg.com/execa/-/execa-4.1.0.tgz#4e5491ad1572f2f17a77d388c6c857135b22847a" @@ -250,6 +445,29 @@ execa@^4.1.0: signal-exit "^3.0.2" strip-final-newline "^2.0.0" +external-editor@^3.0.3: + version "3.1.0" + resolved "https://registry.yarnpkg.com/external-editor/-/external-editor-3.1.0.tgz#cb03f740befae03ea4d283caed2741a83f335495" + integrity sha512-hMQ4CX1p1izmuLYyZqLMO/qGNw10wSv9QDCPfzXfyFrOaCSSoRfqE1Kf1s5an66J5JZC62NewG+mK49jOCtQew== + dependencies: + chardet "^0.7.0" + iconv-lite "^0.4.24" + tmp "^0.0.33" + +figures@^3.0.0: + version "3.2.0" + resolved "https://registry.yarnpkg.com/figures/-/figures-3.2.0.tgz#625c18bd293c604dc4a8ddb2febf0c88341746af" + integrity sha512-yaduQFRKLXYOGgEn6AZau90j3ggSOyiqXU0F9JZfeXYhNa+Jk4X+s45A2zg5jns87GAFa34BBm2kXw4XpNcbdg== + dependencies: + escape-string-regexp "^1.0.5" + +fill-range@^7.0.1: + version "7.0.1" + resolved "https://registry.yarnpkg.com/fill-range/-/fill-range-7.0.1.tgz#1919a6a7c75fe38b2c7c77e5198535da9acdda40" + integrity sha512-qOo9F+dMUmC2Lcb4BbVvnKJxTPjCm+RRpe4gDuGrzkL7mEVl/djYSu2OdQ2Pa302N4oqkSg9ir6jaLWJ2USVpQ== + dependencies: + to-regex-range "^5.0.1" + fill-range@^7.1.1: version "7.1.1" resolved "https://registry.yarnpkg.com/fill-range/-/fill-range-7.1.1.tgz#44265d3cac07e3ea7dc247516380643754a05292" @@ -257,6 +475,37 @@ fill-range@^7.1.1: dependencies: to-regex-range "^5.0.1" +for-each@^0.3.3: + version "0.3.3" + resolved "https://registry.yarnpkg.com/for-each/-/for-each-0.3.3.tgz#69b447e88a0a5d32c3e7084f3f1710034b21376e" + integrity sha512-jqYfLp7mo9vIyQf8ykW2v7A+2N4QjeCeI5+Dz9XraiO1ign81wjiH7Fb9vSOWvQfNtmSa4H2RoQTrrXivdUZmw== + dependencies: + is-callable "^1.1.3" + +fsevents@~2.3.2: + version "2.3.2" + resolved "https://registry.yarnpkg.com/fsevents/-/fsevents-2.3.2.tgz#8a526f78b8fdf4623b709e0b975c52c24c02fd1a" + integrity sha512-xiqMQR4xAeHTuB9uWm+fFRcIOgKBMiOBP+eXiyT7jsgVCq1bkVygt00oASowB7EdtpOHaaPgKt812P9ab+DDKA== + +function-bind@^1.1.1: + version "1.1.1" + resolved "https://registry.yarnpkg.com/function-bind/-/function-bind-1.1.1.tgz#a56899d3ea3c9bab874bb9773b7c5ede92f4895d" + integrity sha512-yIovAzMX49sF8Yl58fSCWJ5svSLuaibPxXQJFLmBObTuCr0Mf1KiPopGM9NiFjiYBCbfaa2Fh6breQ6ANVTI0A== + +get-caller-file@^2.0.5: + version "2.0.5" + resolved "https://registry.yarnpkg.com/get-caller-file/-/get-caller-file-2.0.5.tgz#4f94412a82db32f36e3b0b9741f8a97feb031f7e" + integrity sha512-DyFP3BM/3YHTQOCUL/w0OZHR0lpKeGrxotcHWcqNEdnltqFwXVfhEBQ94eIo34AfQpo0rGki4cyIiftY06h2Fg== + +get-intrinsic@^1.0.2, get-intrinsic@^1.1.3: + version "1.2.0" + resolved "https://registry.yarnpkg.com/get-intrinsic/-/get-intrinsic-1.2.0.tgz#7ad1dc0535f3a2904bba075772763e5051f6d05f" + integrity sha512-L049y6nFOuom5wGyRc3/gdTLO94dySVKRACj1RmJZBQXlbTMhtNIgkWkUHq+jYmZvKf14EW1EoJnnjbmoHij0Q== + dependencies: + function-bind "^1.1.1" + has "^1.0.3" + has-symbols "^1.0.3" + get-own-enumerable-property-symbols@^3.0.0: version "3.0.2" resolved "https://registry.yarnpkg.com/get-own-enumerable-property-symbols/-/get-own-enumerable-property-symbols-3.0.2.tgz#b5fde77f22cbe35f390b4e089922c50bce6ef664" @@ -269,6 +518,25 @@ get-stream@^5.0.0: dependencies: pump "^3.0.0" +glob-parent@~5.1.2: + version "5.1.2" + resolved "https://registry.yarnpkg.com/glob-parent/-/glob-parent-5.1.2.tgz#869832c58034fe68a4093c17dc15e8340d8401c4" + integrity sha512-AOIgSQCepiJYwP3ARnGx+5VnTu2HBYdzbGP45eLw1vr3zB3vZLeyed1sC9hnbcOc9/SrMyM5RPQrkGz4aS9Zow== + dependencies: + is-glob "^4.0.1" + +gopd@^1.0.1: + version "1.0.1" + resolved "https://registry.yarnpkg.com/gopd/-/gopd-1.0.1.tgz#29ff76de69dac7489b7c0918a5788e56477c332c" + integrity sha512-d65bNlIadxvpb/A2abVdlqKqV563juRnZ1Wtk6s1sIR8uNsXR70xqIzVqxVf1eTqDunwT2MkczEeaezCKTZhwA== + dependencies: + get-intrinsic "^1.1.3" + +"graphql@^15.0.0 || ^16.0.0": + version "16.6.0" + resolved "https://registry.yarnpkg.com/graphql/-/graphql-16.6.0.tgz#c2dcffa4649db149f6282af726c8c83f1c7c5fdb" + integrity sha512-KPIBPDlW7NxrbT/eh4qPXz5FiFdL5UbaA0XUNz2Rp3Z3hqBSkbj0GVjwFDztsWVauZUWsbKHgMg++sk8UX0bkw== + has-flag@^3.0.0: version "3.0.0" resolved "https://registry.yarnpkg.com/has-flag/-/has-flag-3.0.0.tgz#b5d454dc2199ae225699f3467e5a07f3b955bafd" @@ -279,11 +547,35 @@ has-flag@^4.0.0: resolved "https://registry.yarnpkg.com/has-flag/-/has-flag-4.0.0.tgz#944771fd9c81c81265c4d6941860da06bb59479b" integrity sha512-EykJT/Q1KjTWctppgIAgfSO0tKVuZUjhgMr17kqTumMl6Afv3EISleU7qZUzoXDFTAHTDC4NOoG/ZxU3EvlMPQ== +has-symbols@^1.0.2, has-symbols@^1.0.3: + version "1.0.3" + resolved "https://registry.yarnpkg.com/has-symbols/-/has-symbols-1.0.3.tgz#bb7b2c4349251dce87b125f7bdf874aa7c8b39f8" + integrity sha512-l3LCuF6MgDNwTDKkdYGEihYjt5pRPbEg46rtlmnSPlUbgmB8LOIrKJbYYFBSbnPaJexMKtiPO8hmeRjRz2Td+A== + +has-tostringtag@^1.0.0: + version "1.0.0" + resolved "https://registry.yarnpkg.com/has-tostringtag/-/has-tostringtag-1.0.0.tgz#7e133818a7d394734f941e73c3d3f9291e658b25" + integrity sha512-kFjcSNhnlGV1kyoGk7OXKSawH5JOb/LzUc5w9B02hOTO0dfFRjbHQKvg1d6cf3HbeUmtU9VbbV3qzZ2Teh97WQ== + dependencies: + has-symbols "^1.0.2" + +has@^1.0.3: + version "1.0.3" + resolved "https://registry.yarnpkg.com/has/-/has-1.0.3.tgz#722d7cbfc1f6aa8241f16dd814e011e1f41e8796" + integrity sha512-f2dvO0VU6Oej7RkWJGrehjbzMAjFp5/VKPp5tTpWIV4JHHZK1/BxbFRtf/siA2SWTe09caDmVtYYzWEIbBS4zw== + dependencies: + function-bind "^1.1.1" + hash-wasm@^4.9.0: version "4.9.0" resolved "https://registry.yarnpkg.com/hash-wasm/-/hash-wasm-4.9.0.tgz#7e9dcc9f7d6bd0cc802f2a58f24edce999744206" integrity sha512-7SW7ejyfnRxuOc7ptQHSf4LDoZaWOivfzqw+5rpcQku0nHfmicPKE51ra9BiRLAmT8+gGLestr1XroUkqdjL6w== +headers-polyfill@^3.1.0: + version "3.1.2" + resolved "https://registry.yarnpkg.com/headers-polyfill/-/headers-polyfill-3.1.2.tgz#9a4dcb545c5b95d9569592ef7ec0708aab763fbe" + integrity sha512-tWCK4biJ6hcLqTviLXVR9DTRfYGQMXEIUj3gwJ2rZ5wO/at3XtkI4g8mCvFdUF9l1KMBNCfmNAdnahm1cgavQA== + human-signals@^1.1.1: version "1.1.1" resolved "https://registry.yarnpkg.com/human-signals/-/human-signals-1.1.1.tgz#c5b1cd14f50aeae09ab6c59fe63ba3395fe4dfa3" @@ -294,6 +586,18 @@ husky@^8.0.0: resolved "https://registry.yarnpkg.com/husky/-/husky-8.0.3.tgz#4936d7212e46d1dea28fef29bb3a108872cd9184" integrity sha512-+dQSyqPh4x1hlO1swXBiNb2HzTDN1I2IGLQx1GrBuiqFJfoMrnZWwVmatvSiO+Iz8fBUnf+lekwNo4c2LlXItg== +iconv-lite@^0.4.24: + version "0.4.24" + resolved "https://registry.yarnpkg.com/iconv-lite/-/iconv-lite-0.4.24.tgz#2022b4b25fbddc21d2f524974a474aafe733908b" + integrity sha512-v3MXnZAcvnywkTUEZomIActle7RXXeedOR31wwl7VlyoXO4Qi9arvSenNQWne1TcRwhCL1HwLI21bEqdpj8/rA== + dependencies: + safer-buffer ">= 2.1.2 < 3" + +ieee754@^1.1.13: + version "1.2.1" + resolved "https://registry.yarnpkg.com/ieee754/-/ieee754-1.2.1.tgz#8eb7a10a63fff25d15a57b001586d177d1b0d352" + integrity sha512-dcyqhDvX1C46lXZcVqCpK+FtMRQVdIMN6/Df5js2zouUsqG7I6sFxitIC+7KYK29KdXOLHdu9zL4sFnoVQnqaA== + import-fresh@^3.2.1: version "3.3.0" resolved "https://registry.yarnpkg.com/import-fresh/-/import-fresh-3.3.0.tgz#37162c25fcb9ebaa2e6e53d5b4d88ce17d9e0c2b" @@ -307,16 +611,91 @@ indent-string@^4.0.0: resolved "https://registry.yarnpkg.com/indent-string/-/indent-string-4.0.0.tgz#624f8f4497d619b2d9768531d58f4122854d7251" integrity sha512-EdDDZu4A2OyIK7Lr/2zG+w5jmbuk1DVBnEwREQvBzspBJkCEbRa8GxU1lghYcaGJCnRWibjDXlq779X1/y5xwg== +inherits@^2.0.3, inherits@^2.0.4: + version "2.0.4" + resolved "https://registry.yarnpkg.com/inherits/-/inherits-2.0.4.tgz#0fa2c64f932917c3433a0ded55363aae37416b7c" + integrity sha512-k/vGaX4/Yla3WzyMCvTQOXYeIHvqOKtnqBduzTHpzpQZzAskKMhZ2K+EnBiSM9zGSoIFeMpXKxa4dYeZIQqewQ== + +inquirer@^8.2.0: + version "8.2.5" + resolved "https://registry.yarnpkg.com/inquirer/-/inquirer-8.2.5.tgz#d8654a7542c35a9b9e069d27e2df4858784d54f8" + integrity sha512-QAgPDQMEgrDssk1XiwwHoOGYF9BAbUcc1+j+FhEvaOt8/cKRqyLn0U5qA6F74fGhTMGxf92pOvPBeh29jQJDTQ== + dependencies: + ansi-escapes "^4.2.1" + chalk "^4.1.1" + cli-cursor "^3.1.0" + cli-width "^3.0.0" + external-editor "^3.0.3" + figures "^3.0.0" + lodash "^4.17.21" + mute-stream "0.0.8" + ora "^5.4.1" + run-async "^2.4.0" + rxjs "^7.5.5" + string-width "^4.1.0" + strip-ansi "^6.0.0" + through "^2.3.6" + wrap-ansi "^7.0.0" + +is-arguments@^1.0.4: + version "1.1.1" + resolved "https://registry.yarnpkg.com/is-arguments/-/is-arguments-1.1.1.tgz#15b3f88fda01f2a97fec84ca761a560f123efa9b" + integrity sha512-8Q7EARjzEnKpt/PCD7e1cgUS0a6X8u5tdSiMqXhojOdoV9TsMsiO+9VLC5vAmO8N7/GmXn7yjR8qnA6bVAEzfA== + dependencies: + call-bind "^1.0.2" + has-tostringtag "^1.0.0" + is-arrayish@^0.2.1: version "0.2.1" resolved "https://registry.yarnpkg.com/is-arrayish/-/is-arrayish-0.2.1.tgz#77c99840527aa8ecb1a8ba697b80645a7a926a9d" integrity sha512-zz06S8t0ozoDXMG+ube26zeCTNXcKIPJZJi8hBrF4idCLms4CG9QtK7qBl1boi5ODzFpjswb5JPmHCbMpjaYzg== +is-binary-path@~2.1.0: + version "2.1.0" + resolved "https://registry.yarnpkg.com/is-binary-path/-/is-binary-path-2.1.0.tgz#ea1f7f3b80f064236e83470f86c09c254fb45b09" + integrity sha512-ZMERYes6pDydyuGidse7OsHxtbI7WVeUEozgR/g7rd0xUimYNlvZRE/K2MgZTjWy725IfelLeVcEM97mmtRGXw== + dependencies: + binary-extensions "^2.0.0" + +is-callable@^1.1.3: + version "1.2.7" + resolved "https://registry.yarnpkg.com/is-callable/-/is-callable-1.2.7.tgz#3bc2a85ea742d9e36205dcacdd72ca1fdc51b055" + integrity sha512-1BC0BVFhS/p0qtw6enp8e+8OD0UrK0oFLztSjNzhcKA3WDuJxxAPXzPuPtKkjEY9UUoEWlX/8fgKeu2S8i9JTA== + +is-extglob@^2.1.1: + version "2.1.1" + resolved "https://registry.yarnpkg.com/is-extglob/-/is-extglob-2.1.1.tgz#a88c02535791f02ed37c76a1b9ea9773c833f8c2" + integrity sha512-SbKbANkN603Vi4jEZv49LeVJMn4yGwsbzZworEoyEiutsN3nJYdbO36zfhGJ6QEDpOZIFkDtnq5JRxmvl3jsoQ== + is-fullwidth-code-point@^3.0.0: version "3.0.0" resolved "https://registry.yarnpkg.com/is-fullwidth-code-point/-/is-fullwidth-code-point-3.0.0.tgz#f116f8064fe90b3f7844a38997c0b75051269f1d" integrity sha512-zymm5+u+sCsSWyD9qNaejV3DFvhCKclKdizYaJUuHA83RLjb7nSuGnddCHGv0hk+KY7BMAlsWeK4Ueg6EV6XQg== +is-generator-function@^1.0.7: + version "1.0.10" + resolved "https://registry.yarnpkg.com/is-generator-function/-/is-generator-function-1.0.10.tgz#f1558baf1ac17e0deea7c0415c438351ff2b3c72" + integrity sha512-jsEjy9l3yiXEQ+PsXdmBwEPcOxaXWLspKdplFUVI9vq1iZgIekeC0L167qeu86czQaxed3q/Uzuw0swL0irL8A== + dependencies: + has-tostringtag "^1.0.0" + +is-glob@^4.0.1, is-glob@~4.0.1: + version "4.0.3" + resolved "https://registry.yarnpkg.com/is-glob/-/is-glob-4.0.3.tgz#64f61e42cbbb2eec2071a9dac0b28ba1e65d5084" + integrity sha512-xelSayHH36ZgE7ZWhli7pW34hNbNl8Ojv5KVmkJD4hBdD3th8Tfk9vYasLM+mXWOZhFkgZfxhLSnrwRr4elSSg== + dependencies: + is-extglob "^2.1.1" + +is-interactive@^1.0.0: + version "1.0.0" + resolved "https://registry.yarnpkg.com/is-interactive/-/is-interactive-1.0.0.tgz#cea6e6ae5c870a7b0a0004070b7b587e0252912e" + integrity sha512-2HvIEKRoqS62guEC+qBjpvRubdX910WCMuJTZ+I9yvqKU2/12eSL549HMwtabb4oupdj2sMP50k+XJfB/8JE6w== + +is-node-process@^1.0.1: + version "1.0.1" + resolved "https://registry.yarnpkg.com/is-node-process/-/is-node-process-1.0.1.tgz#4fc7ac3a91e8aac58175fe0578abbc56f2831b23" + integrity sha512-5IcdXuf++TTNt3oGl9EBdkvndXA8gmc4bz/Y+mdEpWh3Mcn/+kOw6hI7LD5CocqJWMzeb0I0ClndRVNdEPuJXQ== + is-number@^7.0.0: version "7.0.0" resolved "https://registry.yarnpkg.com/is-number/-/is-number-7.0.0.tgz#7535345b896734d5f80c4d06c50955527a14f12b" @@ -337,6 +716,17 @@ is-stream@^2.0.0: resolved "https://registry.yarnpkg.com/is-stream/-/is-stream-2.0.1.tgz#fac1e3d53b97ad5a9d0ae9cef2389f5810a5c077" integrity sha512-hFoiJiTl63nn+kstHGBtewWSKnQLpyb155KHheA1l39uvtO9nWIop1p3udqPcUd/xbF1VLMO4n7OI6p7RbngDg== +is-typed-array@^1.1.10, is-typed-array@^1.1.3: + version "1.1.10" + resolved "https://registry.yarnpkg.com/is-typed-array/-/is-typed-array-1.1.10.tgz#36a5b5cb4189b575d1a3e4b08536bfb485801e3f" + integrity sha512-PJqgEHiWZvMpaFZ3uTc8kHPM4+4ADTlDniuQL7cU/UDA0Ql7F70yGfHph3cLNe+c9toaigv+DFzTJKhc2CtO6A== + dependencies: + available-typed-arrays "^1.0.5" + call-bind "^1.0.2" + for-each "^0.3.3" + gopd "^1.0.1" + has-tostringtag "^1.0.0" + is-unicode-supported@^0.1.0: version "0.1.0" resolved "https://registry.yarnpkg.com/is-unicode-supported/-/is-unicode-supported-0.1.0.tgz#3f26c76a809593b52bfa2ecb5710ed2779b522a7" @@ -347,6 +737,11 @@ isexe@^2.0.0: resolved "https://registry.yarnpkg.com/isexe/-/isexe-2.0.0.tgz#e8fbf374dc556ff8947a10dcb0572d633f2cfa10" integrity sha512-RHxMLp9lnKHGHRng9QFhRCMbYAcVpn69smSGcq3f36xjgVVWThj4qqLbTLlq7Ssj8B+fIQ1EuCEGI2lKsyQeIw== +js-levenshtein@^1.1.6: + version "1.1.6" + resolved "https://registry.yarnpkg.com/js-levenshtein/-/js-levenshtein-1.1.6.tgz#c6cee58eb3550372df8deb85fad5ce66ce01d59d" + integrity sha512-X2BB11YZtrRqY4EnQcLX5Rh373zbK4alC1FW7D7MBhL2gtcC17cTnr6DmfHZeS0s2rTHjUTMMHfG7gO8SSdw+g== + js-tokens@^4.0.0: version "4.0.0" resolved "https://registry.yarnpkg.com/js-tokens/-/js-tokens-4.0.0.tgz#19203fb59991df98e3a287050d4647cdeaf32499" @@ -397,7 +792,12 @@ listr2@^3.2.2: through "^2.3.8" wrap-ansi "^7.0.0" -log-symbols@^4.0.0: +lodash@^4.17.21: + version "4.17.21" + resolved "https://registry.yarnpkg.com/lodash/-/lodash-4.17.21.tgz#679591c564c3bffaae8454cf0b3df370c3d6911c" + integrity sha512-v2kDEe57lecTulaDIuNTPy3Ry4gLGJ6Z1O3vE1krgXZNrsQ+LFTGHVxVjcXPs17LhbZVGedAJv8XZ1tvj5FvSg== + +log-symbols@^4.0.0, log-symbols@^4.1.0: version "4.1.0" resolved "https://registry.yarnpkg.com/log-symbols/-/log-symbols-4.1.0.tgz#3fbdbb95b4683ac9fc785111e792e558d4abd503" integrity sha512-8XPvpAA8uyhfteu8pIvQxpJZ7SYYdpUivZpGy6sFsBuKRY/7rQGavedeB8aK+Zkyq6upMFVL/9AW6vOYzfRyLg== @@ -438,7 +838,44 @@ ms@2.1.2: resolved "https://registry.yarnpkg.com/ms/-/ms-2.1.2.tgz#d09d1f357b443f493382a8eb3ccd183872ae6009" integrity sha512-sGkPx+VjMtmA6MX27oA4FBFELFCZZ4S4XqeGOXCv68tT+jb3vk/RyaKWP0PTKyWtmLSM0b+adUTEvbs1PEaH2w== -normalize-path@^3.0.0: +msw@^1.1.0: + version "1.1.0" + resolved "https://registry.yarnpkg.com/msw/-/msw-1.1.0.tgz#f88806b7ce4cade89b5bf629fa98c17218a4f036" + integrity sha512-oqMvUXm1bMbwvGpoXAQVz8vXXQyQyx52HBDg3EDOK+dFXkQHssgkXEG4LfMwwZyr2Qt18I/w04XPaY4BkFTkzA== + dependencies: + "@mswjs/cookies" "^0.2.2" + "@mswjs/interceptors" "^0.17.5" + "@open-draft/until" "^1.0.3" + "@types/cookie" "^0.4.1" + "@types/js-levenshtein" "^1.1.1" + chalk "4.1.1" + chokidar "^3.4.2" + cookie "^0.4.2" + graphql "^15.0.0 || ^16.0.0" + headers-polyfill "^3.1.0" + inquirer "^8.2.0" + is-node-process "^1.0.1" + js-levenshtein "^1.1.6" + node-fetch "^2.6.7" + outvariant "^1.3.0" + path-to-regexp "^6.2.0" + strict-event-emitter "^0.4.3" + type-fest "^2.19.0" + yargs "^17.3.1" + +mute-stream@0.0.8: + version "0.0.8" + resolved "https://registry.yarnpkg.com/mute-stream/-/mute-stream-0.0.8.tgz#1630c42b2251ff81e2a283de96a5497ea92e5e0d" + integrity sha512-nnbWWOkoWyUsTjKrhgD0dcz22mdkSnpYqbEjIm2nhwhuxlSkpywJmBo8h0ZqJdkp73mb90SssHkN4rsRaBAfAA== + +node-fetch@^2.6.7: + version "2.6.9" + resolved "https://registry.yarnpkg.com/node-fetch/-/node-fetch-2.6.9.tgz#7c7f744b5cc6eb5fd404e0c7a9fec630a55657e6" + integrity sha512-DJm/CJkZkRjKKj4Zi4BsKVZh3ValV5IR5s7LVZnW+6YMh0W1BfNA8XSs6DLMGYlId5F3KnA70uu2qepcR08Qqg== + dependencies: + whatwg-url "^5.0.0" + +normalize-path@^3.0.0, normalize-path@~3.0.0: version "3.0.0" resolved "https://registry.yarnpkg.com/normalize-path/-/normalize-path-3.0.0.tgz#0dcd69ff23a1c9b11fd0978316644a0388216a65" integrity sha512-6eZs5Ls3WtCisHWp9S2GUy8dqkpGi4BVSz3GaqiE6ezub0512ESztXUwUB6C6IKbQkY2Pnb/mD4WYojCRwcwLA== @@ -464,6 +901,31 @@ onetime@^5.1.0: dependencies: mimic-fn "^2.1.0" +ora@^5.4.1: + version "5.4.1" + resolved "https://registry.yarnpkg.com/ora/-/ora-5.4.1.tgz#1b2678426af4ac4a509008e5e4ac9e9959db9e18" + integrity sha512-5b6Y85tPxZZ7QytO+BQzysW31HJku27cRIlkbAXaNx+BdcVi+LlRFmVXzeF6a7JCwJpyw5c4b+YSVImQIrBpuQ== + dependencies: + bl "^4.1.0" + chalk "^4.1.0" + cli-cursor "^3.1.0" + cli-spinners "^2.5.0" + is-interactive "^1.0.0" + is-unicode-supported "^0.1.0" + log-symbols "^4.1.0" + strip-ansi "^6.0.0" + wcwidth "^1.0.1" + +os-tmpdir@~1.0.2: + version "1.0.2" + resolved "https://registry.yarnpkg.com/os-tmpdir/-/os-tmpdir-1.0.2.tgz#bbe67406c79aa85c5cfec766fe5734555dfa1274" + integrity sha512-D2FR03Vir7FIu45XBY20mTb+/ZSWB00sjU9jdQXt83gDrI4Ztz5Fs7/yy74g2N5SVQY4xY1qDr4rNddwYRVX0g== + +outvariant@^1.2.1, outvariant@^1.3.0: + version "1.3.0" + resolved "https://registry.yarnpkg.com/outvariant/-/outvariant-1.3.0.tgz#c39723b1d2cba729c930b74bf962317a81b9b1c9" + integrity sha512-yeWM9k6UPfG/nzxdaPlJkB2p08hCg4xP6Lx99F+vP8YF7xyZVfTmJjrrNalkmzudD4WFvNLVudQikqUmF8zhVQ== + p-map@^4.0.0: version "4.0.0" resolved "https://registry.yarnpkg.com/p-map/-/p-map-4.0.0.tgz#bb2f95a5eda2ec168ec9274e06a747c3e2904d2b" @@ -498,12 +960,17 @@ path-key@^3.0.0, path-key@^3.1.0: resolved "https://registry.yarnpkg.com/path-key/-/path-key-3.1.1.tgz#581f6ade658cbba65a0d3380de7753295054f375" integrity sha512-ojmeN0qd+y0jszEtoY48r0Peq5dwMEkIlCOu6Q5f41lfkswXuKtYrhgoTpLnyIcHm24Uhqx+5Tqm2InSwLhE6Q== +path-to-regexp@^6.2.0: + version "6.2.1" + resolved "https://registry.yarnpkg.com/path-to-regexp/-/path-to-regexp-6.2.1.tgz#d54934d6798eb9e5ef14e7af7962c945906918e5" + integrity sha512-JLyh7xT1kizaEvcaXOQwOc2/Yhw6KZOvPf1S8401UyLk86CU79LN3vl7ztXGm/pZ+YjoyAJ4rxmHwbkBXJX+yw== + path-type@^4.0.0: version "4.0.0" resolved "https://registry.yarnpkg.com/path-type/-/path-type-4.0.0.tgz#84ed01c0a7ba380afe09d90a8c180dcd9d03043b" integrity sha512-gDKb8aZMDeD/tZWs9P6+q0J9Mwkdl6xMV8TjnGP3qJVJ06bdMgkbBlLU8IdfOsIsFz2BW1rNVT3XuNEl8zPAvw== -picomatch@^2.3.1: +picomatch@^2.0.4, picomatch@^2.2.1, picomatch@^2.3.1: version "2.3.1" resolved "https://registry.yarnpkg.com/picomatch/-/picomatch-2.3.1.tgz#3ba3833733646d9d3e4995946c1365a67fb07a42" integrity sha512-JU3teHTNjmE2VCGFzuY8EXzCDVwEqB2a8fsIvwaStHhAWJEeVd1o1QD80CU6+ZdEXXSLbSsuLwJjkCBWqRQUVA== @@ -528,6 +995,27 @@ react-hook-form@^7.39.4: resolved "https://registry.yarnpkg.com/react-hook-form/-/react-hook-form-7.39.4.tgz#7d9edf4e778a0cec4383f0119cd0699e3826a14a" integrity sha512-B0e78r9kR9L2M4A4AXGbHoA/vyv34sB/n8QWJAw33TFz8f5t9helBbYAeqnbvcQf1EYzJxKX/bGQQh9K+evCyQ== +readable-stream@^3.4.0: + version "3.6.2" + resolved "https://registry.yarnpkg.com/readable-stream/-/readable-stream-3.6.2.tgz#56a9b36ea965c00c5a93ef31eb111a0f11056967" + integrity sha512-9u/sniCrY3D5WdsERHzHE4G2YCXqoG5FTHUiCC4SIbr6XcLZBY05ya9EKjYek9O5xOAwjGq+1JdGBAS7Q9ScoA== + dependencies: + inherits "^2.0.3" + string_decoder "^1.1.1" + util-deprecate "^1.0.1" + +readdirp@~3.6.0: + version "3.6.0" + resolved "https://registry.yarnpkg.com/readdirp/-/readdirp-3.6.0.tgz#74a370bd857116e245b29cc97340cd431a02a6c7" + integrity sha512-hOS089on8RduqdbhvQ5Z37A0ESjsqz6qnRcffsMU3495FuTdqSm+7bhJ29JvIOsBDEEnan5DPu9t3To9VRlMzA== + dependencies: + picomatch "^2.2.1" + +require-directory@^2.1.1: + version "2.1.1" + resolved "https://registry.yarnpkg.com/require-directory/-/require-directory-2.1.1.tgz#8c64ad5fd30dab1c976e2344ffe7f792a6a6df42" + integrity sha512-fGxEI7+wsG9xrvdjsrlmL22OMTTiHRwAMroiEeMgq8gzoLC/PQr7RsRDSTLUg/bZAZtF+TVIkHc6/4RIKrui+Q== + resolve-from@^4.0.0: version "4.0.0" resolved "https://registry.yarnpkg.com/resolve-from/-/resolve-from-4.0.0.tgz#4abcd852ad32dd7baabfe9b40e00a36db5f392e6" @@ -546,6 +1034,11 @@ rfdc@^1.3.0: resolved "https://registry.yarnpkg.com/rfdc/-/rfdc-1.3.0.tgz#d0b7c441ab2720d05dc4cf26e01c89631d9da08b" integrity sha512-V2hovdzFbOi77/WajaSMXk2OLm+xNIeQdMMuB7icj7bk6zi2F8GGAxigcnDFpJHbNyNcgyJDiP+8nOrY5cZGrA== +run-async@^2.4.0: + version "2.4.1" + resolved "https://registry.yarnpkg.com/run-async/-/run-async-2.4.1.tgz#8440eccf99ea3e70bd409d49aab88e10c189a455" + integrity sha512-tvVnVv01b8c1RrA6Ep7JkStj85Guv/YrMcwqYQnwjsAS2cTmmPGBBjAjpCW7RrSodNSoE2/qg9O4bceNvUuDgQ== + rxjs@^6.5.5: version "6.6.7" resolved "https://registry.yarnpkg.com/rxjs/-/rxjs-6.6.7.tgz#90ac018acabf491bf65044235d5863c4dab804c9" @@ -560,11 +1053,33 @@ rxjs@^7.5.1: dependencies: tslib "^2.1.0" +rxjs@^7.5.5: + version "7.8.0" + resolved "https://registry.yarnpkg.com/rxjs/-/rxjs-7.8.0.tgz#90a938862a82888ff4c7359811a595e14e1e09a4" + integrity sha512-F2+gxDshqmIub1KdvZkaEfGDwLNpPvk9Fs6LD/MyQxNgMds/WH9OdDDXOmxUZpME+iSK3rQCctkL0DYyytUqMg== + dependencies: + tslib "^2.1.0" + +safe-buffer@~5.2.0: + version "5.2.1" + resolved "https://registry.yarnpkg.com/safe-buffer/-/safe-buffer-5.2.1.tgz#1eaf9fa9bdb1fdd4ec75f58f9cdb4e6b7827eec6" + integrity sha512-rp3So07KcdmmKbGvgaNxQSJr7bGVSVk5S9Eq1F+ppbRo70+YeaDxkw5Dd8NPN+GD6bjnYm2VuPuCXmpuYvmCXQ== + +"safer-buffer@>= 2.1.2 < 3": + version "2.1.2" + resolved "https://registry.yarnpkg.com/safer-buffer/-/safer-buffer-2.1.2.tgz#44fa161b0187b9549dd84bb91802f9bd8385cd6a" + integrity sha512-YZo3K82SD7Riyi0E1EQPojLz7kpepnSQI9IyPbHHg1XXXevb5dJI7tpyN2ADxGcQbHG7vcyRHk0cbwqcQriUtg== + semver-compare@^1.0.0: version "1.0.0" resolved "https://registry.yarnpkg.com/semver-compare/-/semver-compare-1.0.0.tgz#0dee216a1c941ab37e9efb1788f6afc5ff5537fc" integrity sha512-YM3/ITh2MJ5MtzaM429anh+x2jiLVjqILF4m4oyQB18W7Ggea7BfqdH/wGMK7dDiMghv/6WG7znWMwUDzJiXow== +set-cookie-parser@^2.4.6: + version "2.6.0" + resolved "https://registry.yarnpkg.com/set-cookie-parser/-/set-cookie-parser-2.6.0.tgz#131921e50f62ff1a66a461d7d62d7b21d5d15a51" + integrity sha512-RVnVQxTXuerk653XfuliOxBP81Sf0+qfQE73LIYKcyMYHG94AuH0kgrQpRDuTZnSmjpysHmzxJXKNfa6PjFhyQ== + shebang-command@^2.0.0: version "2.0.0" resolved "https://registry.yarnpkg.com/shebang-command/-/shebang-command-2.0.0.tgz#ccd0af4f8835fbdc265b82461aaf0c36663f34ea" @@ -600,12 +1115,24 @@ slice-ansi@^4.0.0: astral-regex "^2.0.0" is-fullwidth-code-point "^3.0.0" +strict-event-emitter@^0.2.4: + version "0.2.8" + resolved "https://registry.yarnpkg.com/strict-event-emitter/-/strict-event-emitter-0.2.8.tgz#b4e768927c67273c14c13d20e19d5e6c934b47ca" + integrity sha512-KDf/ujU8Zud3YaLtMCcTI4xkZlZVIYxTLr+XIULexP+77EEVWixeXroLUXQXiVtH4XH2W7jr/3PT1v3zBuvc3A== + dependencies: + events "^3.3.0" + +strict-event-emitter@^0.4.3: + version "0.4.6" + resolved "https://registry.yarnpkg.com/strict-event-emitter/-/strict-event-emitter-0.4.6.tgz#ff347c8162b3e931e3ff5f02cfce6772c3b07eb3" + integrity sha512-12KWeb+wixJohmnwNFerbyiBrAlq5qJLwIt38etRtKtmmHyDSoGlIqFE9wx+4IwG0aDjI7GV8tc8ZccjWZZtTg== + string-argv@0.3.1: version "0.3.1" resolved "https://registry.yarnpkg.com/string-argv/-/string-argv-0.3.1.tgz#95e2fbec0427ae19184935f816d74aaa4c5c19da" integrity sha512-a1uQGz7IyVy9YwhqjZIZu1c8JO8dNIe20xBmSS6qu9kv++k3JGzCVmprbNN5Kn+BgzD5E7YYwg1CcjuJMRNsvg== -string-width@^4.1.0, string-width@^4.2.0: +string-width@^4.1.0, string-width@^4.2.0, string-width@^4.2.3: version "4.2.3" resolved "https://registry.yarnpkg.com/string-width/-/string-width-4.2.3.tgz#269c7117d27b05ad2e536830a8ec895ef9c6d010" integrity sha512-wKyQRQpjJ0sIp62ErSZdGsjMJWsap5oRNihHhu6G7JVO/9jIB6UyevL+tXuOqrng8j/cxKTWyWUwvSTriiZz/g== @@ -614,6 +1141,13 @@ string-width@^4.1.0, string-width@^4.2.0: is-fullwidth-code-point "^3.0.0" strip-ansi "^6.0.1" +string_decoder@^1.1.1: + version "1.3.0" + resolved "https://registry.yarnpkg.com/string_decoder/-/string_decoder-1.3.0.tgz#42f114594a46cf1a8e30b0a84f56c78c3edac21e" + integrity sha512-hkRX8U1WjJFd8LsDJ2yQ/wWWxaopEsABU1XfkM8A+j0+85JAGppt16cr1Whg6KIbb4okU6Mql6BOj+uup/wKeA== + dependencies: + safe-buffer "~5.2.0" + stringify-object@^3.3.0: version "3.3.0" resolved "https://registry.yarnpkg.com/stringify-object/-/stringify-object-3.3.0.tgz#703065aefca19300d3ce88af4f5b3956d7556629" @@ -649,11 +1183,18 @@ supports-color@^7.1.0: dependencies: has-flag "^4.0.0" -through@^2.3.8: +through@^2.3.6, through@^2.3.8: version "2.3.8" resolved "https://registry.yarnpkg.com/through/-/through-2.3.8.tgz#0dd4c9ffaabc357960b1b724115d7e0e86a2e1f5" integrity sha512-w89qg7PI8wAdvX60bMDP+bFoD5Dvhm9oLheFp5O4a2QF0cSBGsBX4qZmadPMvVqlLJBBci+WqGGOAPvcDeNSVg== +tmp@^0.0.33: + version "0.0.33" + resolved "https://registry.yarnpkg.com/tmp/-/tmp-0.0.33.tgz#6d34335889768d21b2bcda0aa277ced3b1bfadf9" + integrity sha512-jRCJlojKnZ3addtTOjdIqoRuPEKBvNXcGYqzO6zWZX8KfKEpnGY5jfggJQ3EjKuu8D4bJRr0y+cYJFmYbImXGw== + dependencies: + os-tmpdir "~1.0.2" + to-regex-range@^5.0.1: version "5.0.1" resolved "https://registry.yarnpkg.com/to-regex-range/-/to-regex-range-5.0.1.tgz#1648c44aae7c8d988a326018ed72f5b4dd0392e4" @@ -661,6 +1202,11 @@ to-regex-range@^5.0.1: dependencies: is-number "^7.0.0" +tr46@~0.0.3: + version "0.0.3" + resolved "https://registry.yarnpkg.com/tr46/-/tr46-0.0.3.tgz#8184fd347dac9cdc185992f3a6622e14b9d9ab6a" + integrity sha512-N3WMsuqV66lT30CrXNbEjx4GEwlow3v6rr4mCcv6prnfwhS01rkgyFdjPNBYd9br7LpXV1+Emh01fHnq2Gdgrw== + tslib@^1.9.0: version "1.14.1" resolved "https://registry.yarnpkg.com/tslib/-/tslib-1.14.1.tgz#cf2d38bdc34a134bcaf1091c41f6619e2f672d00" @@ -676,6 +1222,68 @@ type-fest@^0.21.3: resolved "https://registry.yarnpkg.com/type-fest/-/type-fest-0.21.3.tgz#d260a24b0198436e133fa26a524a6d65fa3b2e37" integrity sha512-t0rzBq87m3fVcduHDUFhKmyyX+9eo6WQjZvf51Ea/M0Q7+T374Jp1aUiyUl0GKxp8M/OETVHSDvmkyPgvX+X2w== +type-fest@^2.19.0: + version "2.19.0" + resolved "https://registry.yarnpkg.com/type-fest/-/type-fest-2.19.0.tgz#88068015bb33036a598b952e55e9311a60fd3a9b" + integrity sha512-RAH822pAdBgcNMAfWnCBU3CFZcfZ/i1eZjwFU/dsLKumyuuP3niueg2UAukXYF0E2AAoc82ZSSf9J0WQBinzHA== + +util-deprecate@^1.0.1: + version "1.0.2" + resolved "https://registry.yarnpkg.com/util-deprecate/-/util-deprecate-1.0.2.tgz#450d4dc9fa70de732762fbd2d4a28981419a0ccf" + integrity sha512-EPD5q1uXyFxJpCrLnCc1nHnq3gOa6DZBocAIiI2TaSCA7VCJ1UJDMagCzIkXNsUYfD1daK//LTEQ8xiIbrHtcw== + +util@^0.12.3: + version "0.12.5" + resolved "https://registry.yarnpkg.com/util/-/util-0.12.5.tgz#5f17a6059b73db61a875668781a1c2b136bd6fbc" + integrity sha512-kZf/K6hEIrWHI6XqOFUiiMa+79wE/D8Q+NCNAWclkyg3b4d2k7s0QGepNjiABc+aR3N1PAyHL7p6UcLY6LmrnA== + dependencies: + inherits "^2.0.3" + is-arguments "^1.0.4" + is-generator-function "^1.0.7" + is-typed-array "^1.1.3" + which-typed-array "^1.1.2" + +wcwidth@^1.0.1: + version "1.0.1" + resolved "https://registry.yarnpkg.com/wcwidth/-/wcwidth-1.0.1.tgz#f0b0dcf915bc5ff1528afadb2c0e17b532da2fe8" + integrity sha512-XHPEwS0q6TaxcvG85+8EYkbiCux2XtWG2mkc47Ng2A77BQu9+DqIOJldST4HgPkuea7dvKSj5VgX3P1d4rW8Tg== + dependencies: + defaults "^1.0.3" + +web-encoding@^1.1.5: + version "1.1.5" + resolved "https://registry.yarnpkg.com/web-encoding/-/web-encoding-1.1.5.tgz#fc810cf7667364a6335c939913f5051d3e0c4864" + integrity sha512-HYLeVCdJ0+lBYV2FvNZmv3HJ2Nt0QYXqZojk3d9FJOLkwnuhzM9tmamh8d7HPM8QqjKH8DeHkFTx+CFlWpZZDA== + dependencies: + util "^0.12.3" + optionalDependencies: + "@zxing/text-encoding" "0.9.0" + +webidl-conversions@^3.0.0: + version "3.0.1" + resolved "https://registry.yarnpkg.com/webidl-conversions/-/webidl-conversions-3.0.1.tgz#24534275e2a7bc6be7bc86611cc16ae0a5654871" + integrity sha512-2JAn3z8AR6rjK8Sm8orRC0h/bcl/DqL7tRPdGZ4I1CjdF+EaMLmYxBHyXuKL849eucPFhvBoxMsflfOb8kxaeQ== + +whatwg-url@^5.0.0: + version "5.0.0" + resolved "https://registry.yarnpkg.com/whatwg-url/-/whatwg-url-5.0.0.tgz#966454e8765462e37644d3626f6742ce8b70965d" + integrity sha512-saE57nupxk6v3HY35+jzBwYa0rKSy0XR8JSxZPwgLr7ys0IBzhGviA1/TUGJLmSVqs8pb9AnvICXEuOHLprYTw== + dependencies: + tr46 "~0.0.3" + webidl-conversions "^3.0.0" + +which-typed-array@^1.1.2: + version "1.1.9" + resolved "https://registry.yarnpkg.com/which-typed-array/-/which-typed-array-1.1.9.tgz#307cf898025848cf995e795e8423c7f337efbde6" + integrity sha512-w9c4xkx6mPidwp7180ckYWfMmvxpjlZuIudNtDf4N/tTAUB8VJbX25qZoAsrtGuYNnGw3pa0AXgbGKRB8/EceA== + dependencies: + available-typed-arrays "^1.0.5" + call-bind "^1.0.2" + for-each "^0.3.3" + gopd "^1.0.1" + has-tostringtag "^1.0.0" + is-typed-array "^1.1.10" + which@^2.0.1: version "2.0.2" resolved "https://registry.yarnpkg.com/which/-/which-2.0.2.tgz#7c6a8dd0a636a0327e10b59c9286eee93f3f51b1" @@ -706,7 +1314,30 @@ wrappy@1: resolved "https://registry.yarnpkg.com/wrappy/-/wrappy-1.0.2.tgz#b5243d8f3ec1aa35f1364605bc0d1036e30ab69f" integrity sha512-l4Sp/DRseor9wL6EvV2+TuQn63dMkPjZ/sp9XkghTEbV9KlPS1xUsZ3u7/IQO4wxtcFB4bgpQPRcR3QCvezPcQ== +y18n@^5.0.5: + version "5.0.8" + resolved "https://registry.yarnpkg.com/y18n/-/y18n-5.0.8.tgz#7f4934d0f7ca8c56f95314939ddcd2dd91ce1d55" + integrity sha512-0pfFzegeDWJHJIAmTLRP2DwHjdF5s7jo9tuztdQxAhINCdvS+3nGINqPd00AphqJR/0LhANUS6/+7SCb98YOfA== + yaml@^1.10.0: version "1.10.2" resolved "https://registry.yarnpkg.com/yaml/-/yaml-1.10.2.tgz#2301c5ffbf12b467de8da2333a459e29e7920e4b" integrity sha512-r3vXyErRCYJ7wg28yvBY5VSoAF8ZvlcW9/BwUzEtUsjvX/DKs24dIkuwjtuprwJJHsbyUbLApepYTR1BN4uHrg== + +yargs-parser@^21.1.1: + version "21.1.1" + resolved "https://registry.yarnpkg.com/yargs-parser/-/yargs-parser-21.1.1.tgz#9096bceebf990d21bb31fa9516e0ede294a77d35" + integrity sha512-tVpsJW7DdjecAiFpbIB1e3qxIQsE6NoPc5/eTdrbbIC4h0LVsWhnoa3g+m2HclBIujHzsxZ4VJVA+GUuc2/LBw== + +yargs@^17.3.1: + version "17.7.1" + resolved "https://registry.yarnpkg.com/yargs/-/yargs-17.7.1.tgz#34a77645201d1a8fc5213ace787c220eabbd0967" + integrity sha512-cwiTb08Xuv5fqF4AovYacTFNxk62th7LKJ6BL9IGUpTJrWoU7/7WdQGTP2SjKf1dUNBGzDd28p/Yfs/GI6JrLw== + dependencies: + cliui "^8.0.1" + escalade "^3.1.1" + get-caller-file "^2.0.5" + require-directory "^2.1.1" + string-width "^4.2.3" + y18n "^5.0.5" + yargs-parser "^21.1.1" From 7b1fea3b692b0f3eabeffca882eb98c4cb198760 Mon Sep 17 00:00:00 2001 From: Lin Wang Date: Thu, 6 Apr 2023 15:33:16 +0800 Subject: [PATCH 36/75] Feature/add model version detail mock page (#153) * test: add model msw handlers Signed-off-by: Lin Wang * feat: add model VersionToggler Signed-off-by: Lin Wang * feat: add model version page component Signed-off-by: Lin Wang * feat: add model version page to router Signed-off-by: Lin Wang * test: remove fake timer for model version loading indicator test Signed-off-by: Lin Wang --------- Signed-off-by: Lin Wang --- common/router.ts | 7 ++ common/router_paths.ts | 1 + .../__tests__/model_version.test.tsx | 65 +++++++++++++ .../__tests__/version_toggler.test.tsx | 45 +++++++++ public/components/model_version/index.ts | 6 ++ .../model_version/model_version.tsx | 84 +++++++++++++++++ .../model_version/version_toggler.tsx | 92 +++++++++++++++++++ test/mocks/handlers.ts | 2 + test/mocks/model_handlers.ts | 44 +++++++++ test/test_utils.tsx | 28 ++++++ 10 files changed, 374 insertions(+) create mode 100644 public/components/model_version/__tests__/model_version.test.tsx create mode 100644 public/components/model_version/__tests__/version_toggler.test.tsx create mode 100644 public/components/model_version/index.ts create mode 100644 public/components/model_version/model_version.tsx create mode 100644 public/components/model_version/version_toggler.tsx create mode 100644 test/mocks/model_handlers.ts diff --git a/common/router.ts b/common/router.ts index 9f5dd420..1b3c5243 100644 --- a/common/router.ts +++ b/common/router.ts @@ -7,6 +7,7 @@ import { ModelGroup } from '../public/components/model_group'; import { ModelList } from '../public/components/model_list'; import { Monitoring } from '../public/components/monitoring'; import { RegisterModelForm } from '../public/components/register_model/register_model'; +import { ModelVersion } from '../public/components/model_version'; import { routerPaths } from './router_paths'; interface RouteConfig { @@ -46,6 +47,12 @@ export const ROUTES: RouteConfig[] = [ Component: ModelGroup, nav: false, }, + { + path: routerPaths.modelVersion, + label: 'Model Version', + Component: ModelVersion, + nav: false, + }, ]; /* export const ROUTES1 = [ diff --git a/common/router_paths.ts b/common/router_paths.ts index 105cd17f..168806a0 100644 --- a/common/router_paths.ts +++ b/common/router_paths.ts @@ -10,4 +10,5 @@ export const routerPaths = { registerModel: '/model-registry/register-model/:id?', modelList: '/model-registry/model-list', modelGroup: '/model-registry/model/:id', + modelVersion: '/model-registry/model-version/:id', }; diff --git a/public/components/model_version/__tests__/model_version.test.tsx b/public/components/model_version/__tests__/model_version.test.tsx new file mode 100644 index 00000000..91d403ad --- /dev/null +++ b/public/components/model_version/__tests__/model_version.test.tsx @@ -0,0 +1,65 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React from 'react'; +import { generatePath, Route } from 'react-router-dom'; +import userEvent from '@testing-library/user-event'; + +import { ModelVersion } from '../model_version'; +import { render, screen, waitFor, mockOffsetMethods } from '../../../../test/test_utils'; +import { routerPaths } from '../../../../common/router_paths'; + +const setup = () => + render( + + + , + { route: generatePath(routerPaths.modelVersion, { id: '1' }) } + ); + +describe('', () => { + afterEach(() => { + jest.clearAllMocks(); + }); + + it('should display consistent model name and version', async () => { + setup(); + + await waitFor(() => { + expect(screen.getByText('model1')).toBeInTheDocument(); + expect(screen.getByText('v1.0.0')).toBeInTheDocument(); + }); + }); + + it('should display loading screen during Model.getOne calling', async () => { + setup(); + + expect(screen.getByTestId('modelVersionLoadingSpinner')).toBeInTheDocument(); + await waitFor(() => { + expect(screen.queryByTestId('modelVersionLoadingSpinner')).not.toBeInTheDocument(); + }); + }); + + it('should display v1.0.1 and update location.pathname after version selected', async () => { + const mockRest = mockOffsetMethods(); + const user = userEvent.setup(); + + setup(); + + await waitFor(() => { + expect(screen.getByText('v1.0.0')).toBeInTheDocument(); + }); + await user.click(screen.getByText('v1.0.0')); + + await user.click(screen.getByText('1.0.1')); + + await waitFor(() => { + expect(screen.getByText('v1.0.1')).toBeInTheDocument(); + }); + expect(location.pathname).toBe('/model-registry/model-version/2'); + + mockRest(); + }); +}); diff --git a/public/components/model_version/__tests__/version_toggler.test.tsx b/public/components/model_version/__tests__/version_toggler.test.tsx new file mode 100644 index 00000000..e0526c3e --- /dev/null +++ b/public/components/model_version/__tests__/version_toggler.test.tsx @@ -0,0 +1,45 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +import React from 'react'; +import userEvent from '@testing-library/user-event'; +import { render, screen, within, mockOffsetMethods } from '../../../../test/test_utils'; +import { VersionToggler } from '../version_toggler'; + +describe('', () => { + let mockReset: Function; + beforeEach(() => { + mockReset = mockOffsetMethods(); + }); + + afterEach(() => { + mockReset(); + }); + + it('should show currentVersion and version list', async () => { + const user = userEvent.setup(); + render( + + ); + + expect(screen.getByText('v1.0.0')).toBeInTheDocument(); + await user.click(screen.getByText('v1.0.0')); + expect(within(screen.getAllByRole('option')[0]).getByText('1.0.0')).toBeInTheDocument(); + }); + + it('should call onVersionChange with consistent params', async () => { + const user = userEvent.setup(); + const onVersionChange = jest.fn(); + render( + + ); + + await user.click(screen.getByText('v1.0.0')); + await user.click(screen.getByText('1.0.1')); + expect(onVersionChange).toHaveBeenCalledWith({ + newVersion: '1.0.1', + newId: '2', + }); + }); +}); diff --git a/public/components/model_version/index.ts b/public/components/model_version/index.ts new file mode 100644 index 00000000..aa0eb126 --- /dev/null +++ b/public/components/model_version/index.ts @@ -0,0 +1,6 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +export { ModelVersion } from './model_version'; diff --git a/public/components/model_version/model_version.tsx b/public/components/model_version/model_version.tsx new file mode 100644 index 00000000..7b0c92df --- /dev/null +++ b/public/components/model_version/model_version.tsx @@ -0,0 +1,84 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React, { useState, useEffect, useCallback } from 'react'; +import { + EuiButton, + EuiPageHeader, + EuiFlexGroup, + EuiFlexItem, + EuiLoadingSpinner, +} from '@elastic/eui'; +import { generatePath, useHistory, useParams } from 'react-router-dom'; + +import { useFetcher } from '../../hooks'; +import { APIProvider } from '../../apis/api_provider'; +import { routerPaths } from '../../../common/router_paths'; +import { VersionToggler } from './version_toggler'; + +export const ModelVersion = () => { + const { id: modelId } = useParams<{ id: string }>(); + const { data: model } = useFetcher(APIProvider.getAPI('model').getOne, modelId); + const [modelInfo, setModelInfo] = useState<{ version: string; name: string }>(); + const history = useHistory(); + const modelName = model?.name; + const modelVersion = model?.model_version; + + const onVersionChange = useCallback( + ({ newVersion, newId }: { newVersion: string; newId: string }) => { + setModelInfo((prevModelInfo) => + prevModelInfo ? { ...prevModelInfo, version: newVersion } : prevModelInfo + ); + history.push(generatePath(routerPaths.modelVersion, { id: newId })); + }, + [history] + ); + + useEffect(() => { + if (!modelName || !modelVersion) { + return; + } + setModelInfo((prevModelInfo) => { + if (prevModelInfo?.name === modelName && prevModelInfo?.version === modelVersion) { + return prevModelInfo; + } + return { + name: modelName, + version: modelVersion, + }; + }); + }, [modelName, modelVersion]); + + if (!modelInfo) { + return ; + } + return ( + <> + + {modelInfo.name} + + + + + } + rightSideGroupProps={{ + gutterSize: 'm', + }} + rightSideItems={[ + Register version, + Edit, + Deploy, + Delete, + ]} + /> + + ); +}; diff --git a/public/components/model_version/version_toggler.tsx b/public/components/model_version/version_toggler.tsx new file mode 100644 index 00000000..7bab0c9f --- /dev/null +++ b/public/components/model_version/version_toggler.tsx @@ -0,0 +1,92 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React, { useState, useCallback, useMemo } from 'react'; +import { + EuiIcon, + EuiText, + EuiFlexGroup, + EuiFlexItem, + EuiPopover, + EuiSelectable, + EuiSelectableOption, +} from '@elastic/eui'; + +import { useFetcher } from '../../hooks'; +import { APIProvider } from '../../apis/api_provider'; + +interface VersionTogglerProps { + modelName: string; + currentVersion: string; + onVersionChange: (version: { newVersion: string; newId: string }) => void; +} + +export const VersionToggler = ({ + modelName, + currentVersion, + onVersionChange, +}: VersionTogglerProps) => { + const [isPopoverOpen, setIsPopoverOpen] = useState(false); + const { data: versions } = useFetcher(APIProvider.getAPI('model').search, { + name: modelName, + from: 0, + // TODO: Implement scroll bottom load more once version toggler UX confirmed + size: 50, + }); + + const options = useMemo(() => { + return versions?.data.map(({ model_version: modelVersion, id }) => ({ + label: modelVersion, + checked: modelVersion === currentVersion ? ('on' as const) : undefined, + key: id, + })); + }, [versions, currentVersion]); + + const openPopover = useCallback(() => { + setIsPopoverOpen((isOpen) => !isOpen); + }, []); + + const closePopover = useCallback(() => { + setIsPopoverOpen(false); + }, []); + + const onVersionSelectableChange = useCallback( + (newOptions: Array>) => { + const checkedOption = newOptions.find(({ checked }) => checked === 'on'); + if (!checkedOption || !checkedOption.key || !checkedOption.label) { + return; + } + onVersionChange({ newVersion: checkedOption.label, newId: checkedOption.key }); + }, + [onVersionChange] + ); + + return ( + + + v{currentVersion} + + + + + + } + closePopover={closePopover} + > + + {(list) =>
    {list}
    } +
    +
    + ); +}; diff --git a/test/mocks/handlers.ts b/test/mocks/handlers.ts index ec2ed385..dcdf9c38 100644 --- a/test/mocks/handlers.ts +++ b/test/mocks/handlers.ts @@ -6,6 +6,7 @@ import { rest } from 'msw'; import { modelConfig } from './data/model_config'; import { modelRepositoryResponse } from './data/model_repository'; +import { modelHandlers } from './model_handlers'; export const handlers = [ rest.get('/api/ml-commons/model-repository', (req, res, ctx) => { @@ -14,4 +15,5 @@ export const handlers = [ rest.get('/api/ml-commons/model-repository/config-url/:config_url', (req, res, ctx) => { return res(ctx.status(200), ctx.json(modelConfig)); }), + ...modelHandlers, ]; diff --git a/test/mocks/model_handlers.ts b/test/mocks/model_handlers.ts new file mode 100644 index 00000000..f81f3f5e --- /dev/null +++ b/test/mocks/model_handlers.ts @@ -0,0 +1,44 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import { rest } from 'msw'; + +import { MODEL_API_ENDPOINT } from '../../server/routes/constants'; + +const models = [ + { + id: '1', + name: 'model1', + model_version: '1.0.0', + }, + { + id: '2', + model: 'model1', + model_version: '1.0.1', + }, + { + id: '3', + model: 'model2', + model_version: '1.0.0', + }, +]; + +export const modelHandlers = [ + rest.get(MODEL_API_ENDPOINT, (req, res, ctx) => { + const data = models.filter((model) => !req.params.name || model.name === req.params.name); + return res( + ctx.status(200), + ctx.json({ + data, + total_models: data.length, + }) + ); + }), + + rest.get(`${MODEL_API_ENDPOINT}/:modelId`, (req, res, ctx) => { + const [modelId, ..._restParts] = req.url.pathname.split('/').reverse(); + return res(ctx.status(200), ctx.json(models.find((model) => model.id === modelId))); + }), +]; diff --git a/test/test_utils.tsx b/test/test_utils.tsx index e8aa475d..95c08ff4 100644 --- a/test/test_utils.tsx +++ b/test/test_utils.tsx @@ -49,3 +49,31 @@ const customRender = ( export * from '@testing-library/react'; export { customRender as render }; + +export const mockOffsetMethods = () => { + const originalOffsetHeight = Object.getOwnPropertyDescriptor( + HTMLElement.prototype, + 'offsetHeight' + ); + const originalOffsetWidth = Object.getOwnPropertyDescriptor(HTMLElement.prototype, 'offsetWidth'); + Object.defineProperty(HTMLElement.prototype, 'offsetHeight', { + configurable: true, + value: 600, + }); + Object.defineProperty(HTMLElement.prototype, 'offsetWidth', { + configurable: true, + value: 600, + }); + return () => { + Object.defineProperty( + HTMLElement.prototype, + 'offsetHeight', + originalOffsetHeight as PropertyDescriptor + ); + Object.defineProperty( + HTMLElement.prototype, + 'offsetWidth', + originalOffsetWidth as PropertyDescriptor + ); + }; +}; From 60b11296886efe35d74f7a9c8e6ab334b0dbe99a Mon Sep 17 00:00:00 2001 From: Yulong Ruan Date: Wed, 19 Apr 2023 17:38:47 +0800 Subject: [PATCH 37/75] feat: support for adding tag types (#161) User can select tag type from `number` or `string` --------- Signed-off-by: Yulong Ruan Signed-off-by: Lin Wang --- .../__tests__/error_call_out.test.tsx | 4 - .../register_model_artifact.test.tsx | 4 - .../register_model_configuration.test.tsx | 4 - .../__tests__/register_model_details.test.tsx | 4 - .../register_model_file_format.test.tsx | 4 - .../__tests__/register_model_tags.test.tsx | 127 +++++++++++-- .../register_model_version_notes.test.tsx | 4 - .../register_model/__tests__/setup.tsx | 4 +- .../__tests__/tag_type_popover.test.tsx | 43 +++++ .../components/register_model/model_tags.tsx | 9 +- .../register_model/register_model.hooks.ts | 35 +++- .../register_model/register_model.tsx | 2 +- .../register_model/register_model.types.ts | 1 + .../components/register_model/tag_field.tsx | 179 ++++++++++++++---- .../register_model/tag_type_popover.tsx | 77 ++++++++ 15 files changed, 412 insertions(+), 89 deletions(-) create mode 100644 public/components/register_model/__tests__/tag_type_popover.test.tsx create mode 100644 public/components/register_model/tag_type_popover.tsx diff --git a/public/components/register_model/__tests__/error_call_out.test.tsx b/public/components/register_model/__tests__/error_call_out.test.tsx index 4821cc6f..726351d8 100644 --- a/public/components/register_model/__tests__/error_call_out.test.tsx +++ b/public/components/register_model/__tests__/error_call_out.test.tsx @@ -5,7 +5,6 @@ import { screen } from '../../../../test/test_utils'; import { setup } from './setup'; -import * as formHooks from '../register_model.hooks'; import * as formAPI from '../register_model_api'; import { ModelFileUploadManager } from '../model_file_upload_manager'; @@ -15,9 +14,6 @@ describe(' ErrorCallOut', () => { const uploadMock = jest.fn(); beforeEach(() => { - jest - .spyOn(formHooks, 'useModelTags') - .mockReturnValue([false, { keys: ['Key1', 'Key2'], values: ['Value1', 'Value2'] }]); jest.spyOn(formAPI, 'submitModelWithFile').mockImplementation(onSubmitWithFileMock); jest.spyOn(formAPI, 'submitModelWithURL').mockImplementation(onSubmitWithURLMock); jest.spyOn(ModelFileUploadManager.prototype, 'upload').mockImplementation(uploadMock); diff --git a/public/components/register_model/__tests__/register_model_artifact.test.tsx b/public/components/register_model/__tests__/register_model_artifact.test.tsx index c22f6a2b..932f71e4 100644 --- a/public/components/register_model/__tests__/register_model_artifact.test.tsx +++ b/public/components/register_model/__tests__/register_model_artifact.test.tsx @@ -5,7 +5,6 @@ import { screen } from '../../../../test/test_utils'; import { setup } from './setup'; -import * as formHooks from '../register_model.hooks'; import { ModelFileUploadManager } from '../model_file_upload_manager'; import * as formAPI from '../register_model_api'; import { ONE_GB } from '../../../../common/constant'; @@ -16,9 +15,6 @@ describe(' Artifact', () => { const uploadMock = jest.fn(); beforeEach(() => { - jest - .spyOn(formHooks, 'useModelTags') - .mockReturnValue([false, { keys: ['Key1', 'Key2'], values: ['Value1', 'Value2'] }]); jest.spyOn(formAPI, 'submitModelWithFile').mockImplementation(onSubmitWithFileMock); jest.spyOn(formAPI, 'submitModelWithURL').mockImplementation(onSubmitWithURLMock); jest.spyOn(ModelFileUploadManager.prototype, 'upload').mockImplementation(uploadMock); diff --git a/public/components/register_model/__tests__/register_model_configuration.test.tsx b/public/components/register_model/__tests__/register_model_configuration.test.tsx index 25b62bfa..dad7f926 100644 --- a/public/components/register_model/__tests__/register_model_configuration.test.tsx +++ b/public/components/register_model/__tests__/register_model_configuration.test.tsx @@ -6,7 +6,6 @@ import { screen, within } from '../../../../test/test_utils'; import { setup } from './setup'; import * as formAPI from '../register_model_api'; -import * as formHooks from '../register_model.hooks'; import { ModelFileUploadManager } from '../model_file_upload_manager'; describe(' Configuration', () => { @@ -15,9 +14,6 @@ describe(' Configuration', () => { const uploadMock = jest.fn(); beforeEach(() => { - jest - .spyOn(formHooks, 'useModelTags') - .mockReturnValue([false, { keys: ['Key1', 'Key2'], values: ['Value1', 'Value2'] }]); jest.spyOn(formAPI, 'submitModelWithFile').mockImplementation(onSubmitWithFileMock); jest.spyOn(formAPI, 'submitModelWithURL').mockImplementation(onSubmitWithURLMock); jest.spyOn(ModelFileUploadManager.prototype, 'upload').mockImplementation(uploadMock); diff --git a/public/components/register_model/__tests__/register_model_details.test.tsx b/public/components/register_model/__tests__/register_model_details.test.tsx index 64beaf72..c441eb00 100644 --- a/public/components/register_model/__tests__/register_model_details.test.tsx +++ b/public/components/register_model/__tests__/register_model_details.test.tsx @@ -4,7 +4,6 @@ */ import { setup } from './setup'; -import * as formHooks from '../register_model.hooks'; import * as formAPI from '../register_model_api'; import { Model } from '../../../apis/model'; @@ -12,9 +11,6 @@ describe(' Details', () => { const onSubmitMock = jest.fn().mockResolvedValue('model_id'); beforeEach(() => { - jest - .spyOn(formHooks, 'useModelTags') - .mockReturnValue([false, { keys: ['Key1', 'Key2'], values: ['Value1', 'Value2'] }]); jest.spyOn(formAPI, 'submitModelWithFile').mockImplementation(onSubmitMock); }); diff --git a/public/components/register_model/__tests__/register_model_file_format.test.tsx b/public/components/register_model/__tests__/register_model_file_format.test.tsx index e7b2e865..0ecd53f5 100644 --- a/public/components/register_model/__tests__/register_model_file_format.test.tsx +++ b/public/components/register_model/__tests__/register_model_file_format.test.tsx @@ -5,7 +5,6 @@ import { screen } from '../../../../test/test_utils'; import { setup } from './setup'; -import * as formHooks from '../register_model.hooks'; import { ModelFileUploadManager } from '../model_file_upload_manager'; import * as formAPI from '../register_model_api'; @@ -15,9 +14,6 @@ describe(' Artifact', () => { const uploadMock = jest.fn(); beforeEach(() => { - jest - .spyOn(formHooks, 'useModelTags') - .mockReturnValue([false, { keys: ['Key1', 'Key2'], values: ['Value1', 'Value2'] }]); jest.spyOn(formAPI, 'submitModelWithFile').mockImplementation(onSubmitWithFileMock); jest.spyOn(formAPI, 'submitModelWithURL').mockImplementation(onSubmitWithURLMock); jest.spyOn(ModelFileUploadManager.prototype, 'upload').mockImplementation(uploadMock); diff --git a/public/components/register_model/__tests__/register_model_tags.test.tsx b/public/components/register_model/__tests__/register_model_tags.test.tsx index b54f3158..1d1caff2 100644 --- a/public/components/register_model/__tests__/register_model_tags.test.tsx +++ b/public/components/register_model/__tests__/register_model_tags.test.tsx @@ -12,9 +12,13 @@ describe(' Tags', () => { const onSubmitMock = jest.fn().mockResolvedValue('model_id'); beforeEach(() => { - jest - .spyOn(formHooks, 'useModelTags') - .mockReturnValue([false, { keys: ['Key1', 'Key2'], values: ['Value1', 'Value2'] }]); + jest.spyOn(formHooks, 'useModelTags').mockReturnValue([ + false, + [ + { name: 'Key1', type: 'string', values: ['Value1'] }, + { name: 'Key2', type: 'number', values: [0.95] }, + ], + ]); jest.spyOn(formAPI, 'submitModelWithFile').mockImplementation(onSubmitMock); }); @@ -39,6 +43,20 @@ describe(' Tags', () => { expect(onSubmitMock).toHaveBeenCalled(); }); + it('tag value input should be disabled if tag key is empty', async () => { + const result = await setup(); + + const keyContainer = screen.getByTestId('ml-tagKey1'); + const keyInput = within(keyContainer).getByRole('textbox'); + + const valueContainer = screen.getByTestId('ml-tagValue1'); + const valueInput = within(valueContainer).getByRole('textbox'); + expect(valueInput).toBeDisabled(); + + await result.user.type(keyInput, 'Key1{enter}'); + expect(valueInput).toBeEnabled(); + }); + it('should submit the form with selected tags', async () => { const result = await setup(); @@ -48,13 +66,13 @@ describe(' Tags', () => { const valueContainer = screen.getByTestId('ml-tagValue1'); const valueInput = within(valueContainer).getByRole('textbox'); - await result.user.type(keyInput, 'Key1'); - await result.user.type(valueInput, 'Value1'); + await result.user.type(keyInput, 'Key1{enter}'); + await result.user.type(valueInput, 'Value1{enter}'); await result.user.click(result.submitButton); expect(onSubmitMock).toHaveBeenCalledWith( - expect.objectContaining({ tags: [{ key: 'Key1', value: 'Value1' }] }) + expect.objectContaining({ tags: [{ key: 'Key1', value: 'Value1', type: 'string' }] }) ); }); @@ -74,9 +92,9 @@ describe(' Tags', () => { expect(onSubmitMock).toHaveBeenCalledWith( expect.objectContaining({ tags: [ - { key: '', value: '' }, - { key: '', value: '' }, - { key: '', value: '' }, + { key: '', value: '', type: 'string' }, + { key: '', value: '', type: 'string' }, + { key: '', value: '', type: 'string' }, ], }) ); @@ -103,16 +121,20 @@ describe(' Tags', () => { it('should NOT allow to submit tag which does NOT have key', async () => { const result = await setup(); + const keyContainer = screen.getByTestId('ml-tagKey1'); + const keyInput = within(keyContainer).getByRole('textbox'); + await result.user.type(keyInput, 'Key1{enter}'); + const valueContainer = screen.getByTestId('ml-tagValue1'); const valueInput = within(valueContainer).getByRole('textbox'); - // only input value, but NOT key + // Input value, then clear key await result.user.type(valueInput, 'Value 1'); + await result.user.click(within(keyContainer).getByLabelText('Clear input')); + await result.user.click(result.submitButton); // tag key input should be invalid - const keyContainer = screen.getByTestId('ml-tagKey1'); - const keyInput = within(keyContainer).queryByText('A key is required. Enter a key.'); - expect(keyInput).toBeInTheDocument(); + expect(within(keyContainer).queryByText('A key is required. Enter a key.')).toBeInTheDocument(); // it should not submit the form expect(onSubmitMock).not.toHaveBeenCalled(); @@ -199,12 +221,20 @@ describe(' Tags', () => { expect(onSubmitMock).toHaveBeenCalledWith( expect.objectContaining({ - tags: [{ key: '', value: '' }], + tags: [{ key: '', value: '', type: 'string' }], }) ); }); it('should allow adding one more tag when registering new version if model group has only two tags', async () => { + jest.spyOn(formHooks, 'useModelTags').mockReturnValue([ + false, + [ + { name: 'Key1', type: 'string', values: ['Value1'] }, + { name: 'Key2', type: 'number', values: [0.95] }, + ], + ]); + const result = await setup({ route: '/foo', mode: 'version', @@ -251,6 +281,10 @@ describe(' Tags', () => { it('should display error when creating new tag value with more than 80 characters', async () => { const result = await setup(); + const keyContainer = screen.getByTestId('ml-tagKey1'); + const keyInput = within(keyContainer).getByRole('textbox'); + await result.user.type(keyInput, 'dummy key{enter}'); + const valueContainer = screen.getByTestId('ml-tagValue1'); const valueInput = within(valueContainer).getByRole('textbox'); await result.user.type(valueInput, `${'x'.repeat(81)}{enter}`); @@ -260,7 +294,7 @@ describe(' Tags', () => { }); it('should display "No keys found" and "No values found" if no tag keys and no tag values are provided', async () => { - jest.spyOn(formHooks, 'useModelTags').mockReturnValue([false, { keys: [], values: [] }]); + jest.spyOn(formHooks, 'useModelTags').mockReturnValue([false, []]); const result = await setup(); @@ -269,26 +303,79 @@ describe(' Tags', () => { await result.user.click(keyInput); expect(screen.getByText('No keys found. Add a key.')).toBeInTheDocument(); + await result.user.type(keyInput, 'dummy key{enter}'); + const valueContainer = screen.getByTestId('ml-tagValue1'); const valueInput = within(valueContainer).getByRole('textbox'); await result.user.click(valueInput); expect(screen.getByText('No values found. Add a value.')).toBeInTheDocument(); }); - it('should only display "Key2" in the option list after "Key1" selected', async () => { + it('should NOT display "Key1" in the option list after "Key1" selected', async () => { const result = await setup(); const keyContainer = screen.getByTestId('ml-tagKey1'); const keyInput = within(keyContainer).getByRole('textbox'); await result.user.click(keyInput); const optionListContainer = screen.getByTestId('comboBoxOptionsList'); - - expect(within(optionListContainer).getByTitle('Key2')).toBeInTheDocument(); expect(within(optionListContainer).getByTitle('Key1')).toBeInTheDocument(); await result.user.click(within(optionListContainer).getByTitle('Key1')); - - expect(within(optionListContainer).getByTitle('Key2')).toBeInTheDocument(); expect(within(optionListContainer).queryByTitle('Key1')).toBe(null); }); + + it('should not allow to select tag type if selected an existed tag', async () => { + const result = await setup(); + + const keyContainer = screen.getByTestId('ml-tagKey1'); + const valueContainer = screen.getByTestId('ml-tagValue1'); + const keyInput = within(keyContainer).getByRole('textbox'); + + await result.user.click(keyInput); + // selected an existed tag + await result.user.click(within(screen.getByTestId('comboBoxOptionsList')).getByTitle('Key1')); + + expect(within(valueContainer).queryByLabelText('select tag type')).not.toBeInTheDocument(); + }); + + it('should display a list of tag value for selection after selecting a tag key', async () => { + const result = await setup(); + + const keyContainer = screen.getByTestId('ml-tagKey1'); + const valueContainer = screen.getByTestId('ml-tagValue1'); + const keyInput = within(keyContainer).getByRole('textbox'); + const valueInput = within(valueContainer).getByRole('textbox'); + + await result.user.click(keyInput); + // selected an existed tag + await result.user.click(within(screen.getByTestId('comboBoxOptionsList')).getByTitle('Key1')); + + await result.user.click(valueInput); + expect( + within(screen.getByTestId('comboBoxOptionsList')).queryByTitle('Value1') + ).toBeInTheDocument(); + }); + + it('should clear the tag input when click remove button if there is only one tag', async () => { + const result = await setup(); + + const keyContainer = screen.getByTestId('ml-tagKey1'); + const valueContainer = screen.getByTestId('ml-tagValue1'); + const keyInput = within(keyContainer).getByRole('textbox'); + const valueInput = within(valueContainer).getByRole('textbox'); + + await result.user.click(keyInput); + // selected an existed tag + await result.user.click(within(screen.getByTestId('comboBoxOptionsList')).getByTitle('Key1')); + + await result.user.click(valueInput); + await result.user.click(within(screen.getByTestId('comboBoxOptionsList')).getByTitle('Value1')); + + expect(screen.queryByText('Key1', { selector: '.euiComboBoxPill' })).toBeInTheDocument(); + expect(screen.queryByText('Value1', { selector: '.euiComboBoxPill' })).toBeInTheDocument(); + + await result.user.click(screen.getByLabelText(/remove tag at row 1/i)); + expect(screen.queryByText('Key1', { selector: '.euiComboBoxPill' })).not.toBeInTheDocument(); + expect(screen.queryByText('Value1', { selector: '.euiComboBoxPill' })).not.toBeInTheDocument(); + }); }); diff --git a/public/components/register_model/__tests__/register_model_version_notes.test.tsx b/public/components/register_model/__tests__/register_model_version_notes.test.tsx index 304e1570..27977e1c 100644 --- a/public/components/register_model/__tests__/register_model_version_notes.test.tsx +++ b/public/components/register_model/__tests__/register_model_version_notes.test.tsx @@ -4,16 +4,12 @@ */ import { setup } from './setup'; -import * as formHooks from '../register_model.hooks'; import * as formAPI from '../register_model_api'; describe(' Version notes', () => { const onSubmitMock = jest.fn().mockResolvedValue('model_id'); beforeEach(() => { - jest - .spyOn(formHooks, 'useModelTags') - .mockReturnValue([false, { keys: ['Key1', 'Key2'], values: ['Value1', 'Value2'] }]); jest.spyOn(formAPI, 'submitModelWithFile').mockImplementation(onSubmitMock); }); diff --git a/public/components/register_model/__tests__/setup.tsx b/public/components/register_model/__tests__/setup.tsx index 650cd938..c35b78aa 100644 --- a/public/components/register_model/__tests__/setup.tsx +++ b/public/components/register_model/__tests__/setup.tsx @@ -10,7 +10,7 @@ import { UserEvent } from '@testing-library/user-event/dist/types/setup/setup'; import { RegisterModelForm } from '../register_model'; import { Model } from '../../../apis/model'; -import { render, RenderWithRouteProps, screen, waitFor, within } from '../../../../test/test_utils'; +import { render, RenderWithRouteProps, screen, waitFor } from '../../../../test/test_utils'; import { ModelFileFormData, ModelUrlFormData } from '../register_model.types'; jest.mock('../../../apis/task'); @@ -40,7 +40,7 @@ const DEFAULT_VALUES = { description: '', version: '1', configuration: CONFIGURATION, - tags: [{ key: '', value: '' }], + tags: [{ key: '', value: '', type: 'string' as const }], }; export async function setup(options: { diff --git a/public/components/register_model/__tests__/tag_type_popover.test.tsx b/public/components/register_model/__tests__/tag_type_popover.test.tsx new file mode 100644 index 00000000..87030bf2 --- /dev/null +++ b/public/components/register_model/__tests__/tag_type_popover.test.tsx @@ -0,0 +1,43 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React from 'react'; +import userEvent from '@testing-library/user-event'; + +import { render, screen } from '../../../../test/test_utils'; +import { TagTypePopover } from '../tag_type_popover'; + +describe('', () => { + it('should display tag type popover when clicking', async () => { + const user = userEvent.setup(); + + render(); + expect(screen.queryByText('String')).toBeInTheDocument(); + + // click button to show the popover + await user.click(screen.getByText('String')); + + expect(screen.getByLabelText('String')).toBeChecked(); + expect(screen.getByLabelText('Number')).not.toBeChecked(); + expect(screen.queryByText('Apply')).toBeInTheDocument(); + }); + + it('should call onApply', async () => { + const user = userEvent.setup(); + const onApplyMock = jest.fn(); + + render(); + expect(screen.queryByText('String')).toBeInTheDocument(); + + // click button to show the popover + await user.click(screen.getByText('String')); + + // select Number + await user.click(screen.getByLabelText('Number')); + await user.click(screen.getByText('Apply')); + + expect(onApplyMock).toHaveBeenCalledWith('number'); + }); +}); diff --git a/public/components/register_model/model_tags.tsx b/public/components/register_model/model_tags.tsx index df22fb28..391ae0a9 100644 --- a/public/components/register_model/model_tags.tsx +++ b/public/components/register_model/model_tags.tsx @@ -17,16 +17,16 @@ const MAX_TAG_NUM = 10; export const ModelTagsPanel = () => { const { control } = useFormContext(); const { id: latestVersionId } = useParams<{ id: string | undefined }>(); - const [, { keys, values }] = useModelTags(); + const [, tags] = useModelTags(); const { fields, append, remove } = useFieldArray({ name: 'tags', control, }); const isRegisterNewVersion = !!latestVersionId; - const maxTagNum = isRegisterNewVersion ? keys.length : MAX_TAG_NUM; + const maxTagNum = isRegisterNewVersion ? tags.length : MAX_TAG_NUM; const addNewTag = useCallback(() => { - append({ key: '', value: '' }); + append({ key: '', value: '', type: 'string' }); }, [append]); return ( @@ -62,8 +62,7 @@ export const ModelTagsPanel = () => { diff --git a/public/components/register_model/register_model.hooks.ts b/public/components/register_model/register_model.hooks.ts index 0cc31c56..7a56fe37 100644 --- a/public/components/register_model/register_model.hooks.ts +++ b/public/components/register_model/register_model.hooks.ts @@ -8,6 +8,39 @@ import { useEffect, useState } from 'react'; const keys = ['tag1', 'tag2']; const values = ['value1', 'value2']; +const results = [ + { + name: 'Accuracy: test', + type: 'number' as const, + values: [0.9, 0.8, 0.75], + }, + { + name: 'Accuracy: training', + type: 'number' as const, + values: [0.9, 0.8, 0.75], + }, + { + name: 'Accuracy: validation', + type: 'number' as const, + values: [0.9, 0.8, 0.75], + }, + { + name: 'Task', + type: 'string' as const, + values: [ + 'Computer vision', + 'Image classification', + 'Image-to-image', + 'Natural language processing', + ], + }, + { + name: 'Team', + type: 'string' as const, + values: ['IT', 'Finance', 'HR'], + }, +]; + /** * TODO: implement this function so that it retrieve tags from BE */ @@ -24,5 +57,5 @@ export const useModelTags = () => { }; }, []); - return [loading, { keys, values }] as const; + return [loading, results] as const; }; diff --git a/public/components/register_model/register_model.tsx b/public/components/register_model/register_model.tsx index d3ec94f1..619f3925 100644 --- a/public/components/register_model/register_model.tsx +++ b/public/components/register_model/register_model.tsx @@ -49,7 +49,7 @@ const DEFAULT_VALUES = { description: '', version: '1', configuration: '', - tags: [{ key: '', value: '' }], + tags: [{ key: '', value: '', type: 'string' as const }], modelFileFormat: '', }; diff --git a/public/components/register_model/register_model.types.ts b/public/components/register_model/register_model.types.ts index b359eb8d..27064f9d 100644 --- a/public/components/register_model/register_model.types.ts +++ b/public/components/register_model/register_model.types.ts @@ -6,6 +6,7 @@ export interface Tag { key: string; value: string; + type: 'number' | 'string'; } interface ModelFormBase { diff --git a/public/components/register_model/tag_field.tsx b/public/components/register_model/tag_field.tsx index e076a955..484a5d85 100644 --- a/public/components/register_model/tag_field.tsx +++ b/public/components/register_model/tag_field.tsx @@ -4,25 +4,35 @@ */ import { - EuiButton, EuiComboBox, EuiComboBoxOptionOption, EuiFlexGroup, EuiFlexItem, EuiFormRow, EuiContext, + EuiButtonIcon, + EuiFieldNumber, + EuiText, + EuiToken, + EuiToolTip, } from '@elastic/eui'; import React, { useCallback, useMemo, useRef } from 'react'; import { useController, useWatch, useFormContext } from 'react-hook-form'; import { FORM_ITEM_WIDTH } from './form_constants'; import { ModelFileFormData, ModelUrlFormData } from './register_model.types'; +import { TagTypePopover } from './tag_type_popover'; + +interface TagGroup { + name: string; + type: 'string' | 'number'; + values: string[] | number[]; +} interface ModelTagFieldProps { index: number; onDelete: (index: number) => void; - tagKeys: string[]; - tagValues: string[]; allowKeyCreate?: boolean; + tagGroups: TagGroup[]; } const MAX_TAG_LENGTH = 80; @@ -39,7 +49,7 @@ const VALUE_COMBOBOX_I18N = { }, }; -function getComboBoxValue(data: EuiComboBoxOptionOption[]) { +function getComboBoxValue(data: Array>) { if (data.length === 0) { return ''; } else { @@ -49,8 +59,7 @@ function getComboBoxValue(data: EuiComboBoxOptionOption[]) { export const ModelTagField = ({ index, - tagKeys, - tagValues, + tagGroups, allowKeyCreate, onDelete, }: ModelTagFieldProps) => { @@ -113,20 +122,51 @@ export const ModelTagField = ({ }, }); + const selectedTagGroup = useMemo( + () => tagGroups.find((t) => t.name === tagKeyController.field.value), + [tagGroups, tagKeyController] + ); + + const tagTypeController = useController({ + name: `tags.${index}.type` as const, + control, + }); + const onKeyChange = useCallback( - (data: EuiComboBoxOptionOption[]) => { - tagKeyController.field.onChange(getComboBoxValue(data)); + (data: Array>) => { + const tagKey = getComboBoxValue(data); + tagKeyController.field.onChange(tagKey); + + // update tag type if selected an existed tag + const tagGroup = tagGroups.find((t) => t.name === tagKey); + if (tagGroup) { + tagTypeController.field.onChange(tagGroup.type); + } }, - [tagKeyController.field] + [tagKeyController.field, tagTypeController.field, tagGroups] ); - const onValueChange = useCallback( - (data: EuiComboBoxOptionOption[]) => { + const onStringValueChange = useCallback( + (data: Array>) => { tagValueController.field.onChange(getComboBoxValue(data)); }, [tagValueController.field] ); + const onNumberValueChange = useCallback( + (e: React.ChangeEvent) => { + tagValueController.field.onChange(e.target.value); + }, + [tagValueController.field] + ); + + const onApplyType = useCallback( + (type: 'number' | 'string') => { + tagTypeController.field.onChange(type); + }, + [tagTypeController.field] + ); + const onKeyCreate = useCallback( (value: string) => { tagKeyController.field.onChange(value); @@ -142,14 +182,17 @@ export const ModelTagField = ({ ); const keyOptions = useMemo(() => { - return tagKeys - .filter((key) => !tags?.find((tag) => tag.key === key)) - .map((key) => ({ label: key })); - }, [tagKeys, tags]); + return tagGroups + .filter((group) => !tags?.find((tag) => tag.key === group.name)) + .map((group) => ({ label: group.name, value: group })); + }, [tagGroups, tags]); const valueOptions = useMemo(() => { - return tagValues.map((value) => ({ label: value })); - }, [tagValues]); + if (selectedTagGroup) { + return selectedTagGroup.values.map((v) => ({ label: `${v}` })); + } + return []; + }, [selectedTagGroup]); const onBlur = useCallback( (e: React.FocusEvent) => { @@ -172,8 +215,38 @@ export const ModelTagField = ({ [trigger] ); + const renderOption = useCallback( + (option: EuiComboBoxOptionOption, searchValue: string, contentClassName: string) => { + return ( +
    + + {option.label} + + {option.value?.type} + +
    + ); + }, + [] + ); + + const onRemove = useCallback( + (idx: number) => { + if (tags?.length && tags.length > 1) { + onDelete(idx); + } else { + tagValueController.field.onChange(''); + tagKeyController.field.onChange(''); + } + }, + [tags, onDelete, tagKeyController.field, tagValueController.field] + ); + return ( - + - placeholder="Select or add a key" isInvalid={Boolean(tagKeyController.fieldState.error)} singleSelection={{ asPlainText: true }} options={keyOptions} + renderOption={renderOption} selectedOptions={ tagKeyController.field.value ? [{ label: tagKeyController.field.value }] : [] } @@ -207,27 +281,60 @@ export const ModelTagField = ({ isInvalid={Boolean(tagValueController.fieldState.error)} error={tagValueController.fieldState.error?.message} > - + {tagTypeController.field.value === 'string' ? ( + + ) + } + placeholder="Select or add a value" + isInvalid={Boolean(tagValueController.fieldState.error)} + singleSelection={{ asPlainText: true }} + options={valueOptions} + selectedOptions={ + tagValueController.field.value ? [{ label: tagValueController.field.value }] : [] + } + onChange={onStringValueChange} + onCreateOption={onValueCreate} + customOptionText="Add {searchValue} as a value." + onBlur={tagValueController.field.onBlur} + inputRef={tagValueController.field.ref} + isDisabled={!Boolean(tagKeyController.field.value)} + /> + ) : ( + + ) + } + placeholder="Add a value" + value={tagValueController.field.value} + isInvalid={Boolean(tagValueController.fieldState.error)} + onChange={onNumberValueChange} + onBlur={tagValueController.field.onBlur} + inputRef={tagValueController.field.ref} + disabled={!Boolean(tagKeyController.field.value)} + /> + )} - onDelete(index)}> - Remove - + 1 ? 'Remove' : 'Clear'}> + onRemove(index)} + /> + ); diff --git a/public/components/register_model/tag_type_popover.tsx b/public/components/register_model/tag_type_popover.tsx new file mode 100644 index 00000000..5bb5bbbd --- /dev/null +++ b/public/components/register_model/tag_type_popover.tsx @@ -0,0 +1,77 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React, { useState, useCallback } from 'react'; +import { + EuiButton, + EuiButtonEmpty, + EuiPopover, + EuiPopoverFooter, + EuiPopoverTitle, + EuiRadio, + htmlIdGenerator, +} from '@elastic/eui'; + +type TagValueType = 'number' | 'string'; + +interface TagTypePopoverProps { + value: TagValueType; + onApply: (type: TagValueType) => void; + disabled?: boolean; + className?: string; +} + +export const TagTypePopover = ({ value, onApply, disabled, className }: TagTypePopoverProps) => { + const [tagType, setTagType] = useState(value); + const [isPopoverOpen, setIsPopoverOpen] = useState(false); + + const onApplyType = useCallback(() => { + onApply(tagType); + setIsPopoverOpen(false); + }, [tagType, onApply]); + + return ( + setIsPopoverOpen(!isPopoverOpen)} + size="xs" + iconType="arrowDown" + iconSide="right" + disabled={disabled} + > + {value === 'number' ? 'Number' : 'String'} + + } + closePopover={() => setIsPopoverOpen(false)} + isOpen={isPopoverOpen} + panelPaddingSize="s" + > + TAG TYPE + setTagType('string')} + /> + setTagType('number')} + /> + + + Apply + + + + ); +}; From 9121c4d92f69b8d1b2cd51b989908c179ae0a4fa Mon Sep 17 00:00:00 2001 From: Lin Wang Date: Thu, 20 Apr 2023 14:33:11 +0800 Subject: [PATCH 38/75] Feature/replace model list stage filter with deployment toggle (#163) * feat: replace stage_filter with deployment toggle Signed-off-by: Lin Wang * chore: remove stage_filter in model list Signed-off-by: Lin Wang --------- Signed-off-by: Lin Wang --- .../__tests__/model_list_filter.test.tsx | 37 +++++++++++++++--- .../__tests__/stage_filter.test.tsx | 21 ---------- .../model_list/model_list_filter.tsx | 39 ++++++++++++++++--- public/components/model_list/stage_filter.tsx | 19 --------- 4 files changed, 65 insertions(+), 51 deletions(-) delete mode 100644 public/components/model_list/__tests__/stage_filter.test.tsx delete mode 100644 public/components/model_list/stage_filter.tsx diff --git a/public/components/model_list/__tests__/model_list_filter.test.tsx b/public/components/model_list/__tests__/model_list_filter.test.tsx index 020c5cf7..4cd7e0c9 100644 --- a/public/components/model_list/__tests__/model_list_filter.test.tsx +++ b/public/components/model_list/__tests__/model_list_filter.test.tsx @@ -4,28 +4,55 @@ */ import React from 'react'; +import userEvent from '@testing-library/user-event'; import { ModelListFilter } from '../model_list_filter'; import { render, screen } from '../../../../test/test_utils'; describe('', () => { - it('should render default search bar with tag, stage and owner filter', () => { - render( {}} />); + it('should render default search bar with tag, deployed and owner filter', () => { + render( {}} />); expect(screen.getByPlaceholderText('Search by name, person, or keyword')).toBeInTheDocument(); expect(screen.getByText('Tags')).toBeInTheDocument(); - expect(screen.getByText('Stage')).toBeInTheDocument(); expect(screen.getByText('Owner')).toBeInTheDocument(); + expect(screen.getByText('Deployed')).toBeInTheDocument(); + expect(screen.getByText('Undeployed')).toBeInTheDocument(); }); it('should render default search value and filter value', () => { render( {}} /> ); expect(screen.getByDisplayValue('foo')).toBeInTheDocument(); - expect(screen.queryAllByText('1')).toHaveLength(3); + expect(screen.queryAllByText('1')).toHaveLength(2); + expect(screen.getByText('Deployed')).not.toHaveClass('euiFilterButton-hasActiveFilters'); + expect(screen.getByText('Undeployed')).not.toHaveClass('euiFilterButton-hasActiveFilters'); + }); + + it('should call onChange with consistent deployed value', async () => { + const onChangeMock = jest.fn(); + const user = userEvent.setup(); + const { rerender } = render( + + ); + + await user.click(screen.getByText('Deployed')); + expect(onChangeMock).toHaveBeenCalledWith(expect.objectContaining({ deployed: true })); + + rerender( + + ); + await user.click(screen.getByText('Undeployed')); + expect(onChangeMock).toHaveBeenCalledWith(expect.objectContaining({ deployed: false })); + + rerender( + + ); + await user.click(screen.getByText('Undeployed')); + expect(onChangeMock).toHaveBeenCalledWith(expect.objectContaining({ deployed: undefined })); }); }); diff --git a/public/components/model_list/__tests__/stage_filter.test.tsx b/public/components/model_list/__tests__/stage_filter.test.tsx deleted file mode 100644 index 9992b02c..00000000 --- a/public/components/model_list/__tests__/stage_filter.test.tsx +++ /dev/null @@ -1,21 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -import React from 'react'; - -import { render, screen } from '../../../../test/test_utils'; -import { StageFilter } from '../stage_filter'; - -describe('', () => { - afterEach(() => { - jest.resetAllMocks(); - }); - - it('should render "Stage" with 0 active filter for normal', () => { - render( {}} />); - expect(screen.getByText('Stage')).toBeInTheDocument(); - expect(screen.getByText('0')).toBeInTheDocument(); - }); -}); diff --git a/public/components/model_list/model_list_filter.tsx b/public/components/model_list/model_list_filter.tsx index fb621b03..4e92bb65 100644 --- a/public/components/model_list/model_list_filter.tsx +++ b/public/components/model_list/model_list_filter.tsx @@ -3,18 +3,23 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { EuiFlexItem, EuiFlexGroup, EuiFieldSearch, EuiFilterGroup } from '@elastic/eui'; +import { + EuiFlexItem, + EuiFlexGroup, + EuiFieldSearch, + EuiFilterGroup, + EuiFilterButton, +} from '@elastic/eui'; import React, { useCallback, useRef } from 'react'; import { TagFilter } from './tag_filter'; import { OwnerFilter } from './owner_filter'; -import { StageFilter } from './stage_filter'; export interface ModelListFilterFilterValue { search?: string; tag: string[]; owner: string[]; - stage: string[]; + deployed?: boolean; } export const ModelListFilter = ({ @@ -43,8 +48,18 @@ export const ModelListFilter = ({ onChangeRef.current({ ...valueRef.current, owner }); }, []); - const handleStageChange = useCallback((stage: string[]) => { - onChangeRef.current({ ...valueRef.current, stage }); + const handleDeployedClick = useCallback(() => { + onChangeRef.current({ + ...valueRef.current, + deployed: valueRef.current.deployed ? undefined : true, + }); + }, []); + + const handleUnDeployedClick = useCallback(() => { + onChangeRef.current({ + ...valueRef.current, + deployed: valueRef.current.deployed === false ? undefined : false, + }); }, []); return ( @@ -62,7 +77,19 @@ export const ModelListFilter = ({ - + + Deployed + + + Undeployed + diff --git a/public/components/model_list/stage_filter.tsx b/public/components/model_list/stage_filter.tsx deleted file mode 100644 index 9acb92e4..00000000 --- a/public/components/model_list/stage_filter.tsx +++ /dev/null @@ -1,19 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -import React from 'react'; -import { ModelFilter, ModelFilterProps } from './model_filter'; - -export const StageFilter = ({ value, onChange }: Pick) => { - return ( - - ); -}; From cb0c0e834635ac9de3a27a61cbcb3d95a51ed886 Mon Sep 17 00:00:00 2001 From: Lin Wang Date: Sun, 23 Apr 2023 17:58:18 +0800 Subject: [PATCH 39/75] Feature/update tag filter (#162) * feat: add tag_filter_popover_content Signed-off-by: Lin Wang * feat: update tag_filter use tag_filter_popover_content Signed-off-by: Lin Wang * feat: remove resetAfterSaveOrCancel property in tag_filter_popover_content Signed-off-by: Lin Wang * feat: move tagKeys and tagKeysLoading to model_list_filter level Signed-off-by: Lin Wang * feat: add selected_tag_filter_panel Signed-off-by: Lin Wang * feat: add selected_tag_filter_panel to model_list_filter Signed-off-by: Lin Wang * test: increase timeout avoid model_list_filter test fail Signed-off-by: Lin Wang * test: increase timeout to fix tag_filter and selected_tag_filter_panel test failed Signed-off-by: Lin Wang * chore: address PR comments Signed-off-by: Lin Wang * test: add test case for tag_filter_popover_content when tagFilter provided Signed-off-by: Lin Wang * feat: add type to tagFilter Signed-off-by: Lin Wang * feat: update tag filter popover button UI and move right Signed-off-by: Lin Wang --------- Signed-off-by: Lin Wang --- .../selected_tag_filter_panel.test.tsx | 155 ++++++++++++ public/components/common/index.ts | 2 + .../common/selected_tag_filter_panel.tsx | 159 ++++++++++++ .../tag_filter_popover_content.test.tsx | 184 ++++++++++++++ .../__tests__/tag_value_selector.test.tsx | 59 +++++ .../tag_filter_popover_content/index.ts | 11 + .../tag_filter_popover_content.tsx | 233 ++++++++++++++++++ .../tag_value_selector.tsx | 46 ++++ .../__tests__/model_list_filter.test.tsx | 73 +++++- .../model_list/__tests__/tag_filter.test.tsx | 121 ++++++++- public/components/model_list/model_filter.tsx | 4 +- .../components/model_list/model_list.hooks.ts | 31 +++ .../model_list/model_list_filter.tsx | 44 +++- public/components/model_list/tag_filter.tsx | 71 +++++- 14 files changed, 1165 insertions(+), 28 deletions(-) create mode 100644 public/components/common/__tests__/selected_tag_filter_panel.test.tsx create mode 100644 public/components/common/selected_tag_filter_panel.tsx create mode 100644 public/components/common/tag_filter_popover_content/__tests__/tag_filter_popover_content.test.tsx create mode 100644 public/components/common/tag_filter_popover_content/__tests__/tag_value_selector.test.tsx create mode 100644 public/components/common/tag_filter_popover_content/index.ts create mode 100644 public/components/common/tag_filter_popover_content/tag_filter_popover_content.tsx create mode 100644 public/components/common/tag_filter_popover_content/tag_value_selector.tsx create mode 100644 public/components/model_list/model_list.hooks.ts diff --git a/public/components/common/__tests__/selected_tag_filter_panel.test.tsx b/public/components/common/__tests__/selected_tag_filter_panel.test.tsx new file mode 100644 index 00000000..2fd2c64b --- /dev/null +++ b/public/components/common/__tests__/selected_tag_filter_panel.test.tsx @@ -0,0 +1,155 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React from 'react'; +import userEvent from '@testing-library/user-event'; + +import { render, screen } from '../../../../test/test_utils'; +import { SelectedTagFiltersPanel } from '../selected_tag_filter_panel'; +import { TagFilterOperator } from '../tag_filter_popover_content'; + +describe('', () => { + it('should render regular filter items', () => { + render( + + ); + + expect(screen.getByText('Task: Computer vision')); + expect(screen.getByText('Task is one of Computer vision, Image classification')); + expect(screen.getByText('F1: 0.98')); + expect(screen.getByText('F1 weighted > 0.99')); + expect(screen.getByText('F2 < 0.97')); + }); + + it('should render "NOT" filter items', () => { + render( + + ); + + expect(screen.getByTitle('NOT Task: Computer vision')); + expect(screen.getByTitle('NOT Task is one of Computer vision, Image classification')); + expect(screen.getByTitle('NOT F1: 0.98')); + }); + + it('should call onTagFiltersChange after filter item removed', async () => { + const user = userEvent.setup(); + const onTagFiltersChangeMock = jest.fn(); + render( + + ); + expect(onTagFiltersChangeMock).not.toHaveBeenCalled(); + + await user.click(screen.getByLabelText('Remove filter')); + expect(onTagFiltersChangeMock).toHaveBeenCalledWith([]); + }); + + it( + 'should onTagFiltersChange after filter item updated', + async () => { + const user = userEvent.setup(); + const onTagFiltersChangeMock = jest.fn(); + render( + + ); + expect(onTagFiltersChangeMock).not.toHaveBeenCalled(); + + await user.click(screen.getByTitle('NOT Task: Computer vision')); + await user.click(screen.getByText('Computer vision')); + await user.click(screen.getByRole('option', { name: 'Image classification' })); + await user.click(screen.getByText('Save')); + + expect(onTagFiltersChangeMock).toHaveBeenCalledWith([ + { + name: 'Task', + operator: TagFilterOperator.IsNot, + value: 'Image classification', + type: 'string', + }, + ]); + }, + // There are too many operations, need to increase timeout + 10 * 1000 + ); +}); diff --git a/public/components/common/index.ts b/public/components/common/index.ts index 8ece445d..5040dd86 100644 --- a/public/components/common/index.ts +++ b/public/components/common/index.ts @@ -5,3 +5,5 @@ export * from './custom'; export * from './copyable_text'; +export * from './tag_filter_popover_content'; +export * from './selected_tag_filter_panel'; diff --git a/public/components/common/selected_tag_filter_panel.tsx b/public/components/common/selected_tag_filter_panel.tsx new file mode 100644 index 00000000..4d233303 --- /dev/null +++ b/public/components/common/selected_tag_filter_panel.tsx @@ -0,0 +1,159 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React, { useCallback, useState } from 'react'; +import { EuiFlexGroup, EuiFlexItem, EuiBadge, EuiTextColor, EuiPopover } from '@elastic/eui'; + +import { + TagFilterValue, + TagFilterOperator, + TagFilterPopoverContent, + TagFilterPopoverContentProps, +} from './tag_filter_popover_content'; + +const generateFilterLabel = (tag: TagFilterValue) => { + const texts = [tag.name]; + texts.push( + ({ + [TagFilterOperator.IsLessThan]: ' < ', + [TagFilterOperator.IsGreaterThan]: ' > ', + [TagFilterOperator.IsOneOf]: ' is one of ', + [TagFilterOperator.IsNotOneOf]: ' is one of ', + } as Record)[tag.operator] || ': ' + ); + const text = `${texts.join('')}${ + Array.isArray(tag.value) ? `${tag.value.join(', ')}` : tag.value + }`; + + return tag.operator === TagFilterOperator.IsNot || + tag.operator === TagFilterOperator.IsNotOneOf ? ( + <> + {'NOT '} + {text} + + ) : ( + text + ); +}; + +interface SelectedTagFilterItemProps extends Pick { + filter: TagFilterValue; + index: number; + onRemove: (index: number) => void; + onChange: (index: number, newValue: TagFilterValue) => void; +} + +const SelectedTagFilterItem = ({ + filter, + tagKeys, + index, + onRemove, + onChange, +}: SelectedTagFilterItemProps) => { + const [isPopoverOpen, setIsPopoverOpen] = useState(false); + + const handleClose = useCallback(() => { + onRemove(index); + }, [onRemove, index]); + + const closePopover = useCallback(() => { + setIsPopoverOpen(false); + }, []); + + const handleClick = useCallback(() => { + setIsPopoverOpen((prev) => !prev); + }, []); + + const handleSave = useCallback( + (newValue: TagFilterValue) => { + closePopover(); + onChange(index, newValue); + }, + [closePopover, onChange, index] + ); + + return ( + + + {generateFilterLabel(filter)} + + } + isOpen={isPopoverOpen} + closePopover={closePopover} + initialFocus={false} + > + + + + ); +}; + +interface SelectedTagFiltersPanelProps extends Pick { + tagFilters: TagFilterValue[]; + onTagFiltersChange: (newFilters: TagFilterValue[]) => void; +} + +export const SelectedTagFiltersPanel = ({ + tagKeys, + tagFilters, + onTagFiltersChange, +}: SelectedTagFiltersPanelProps) => { + const handleChange = useCallback( + (index: number, newFilterValue: TagFilterValue) => { + onTagFiltersChange([ + ...tagFilters.slice(0, index), + newFilterValue, + ...tagFilters.slice(index + 1), + ]); + }, + [tagFilters, onTagFiltersChange] + ); + const handleRemove = useCallback( + (index: number) => { + onTagFiltersChange([...tagFilters.slice(0, index), ...tagFilters.slice(index + 1)]); + }, + [tagFilters, onTagFiltersChange] + ); + return ( + + {tagFilters.map((tagFilter, index) => ( + + ))} + + ); +}; diff --git a/public/components/common/tag_filter_popover_content/__tests__/tag_filter_popover_content.test.tsx b/public/components/common/tag_filter_popover_content/__tests__/tag_filter_popover_content.test.tsx new file mode 100644 index 00000000..e7d1efbe --- /dev/null +++ b/public/components/common/tag_filter_popover_content/__tests__/tag_filter_popover_content.test.tsx @@ -0,0 +1,184 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React from 'react'; +import userEvent from '@testing-library/user-event'; + +import { + TagFilterOperator, + TagFilterPopoverContent, + TagFilterPopoverContentProps, +} from '../tag_filter_popover_content'; +import { render, screen } from '../../../../../test/test_utils'; + +const setup = (options?: Partial) => { + const user = userEvent.setup(); + const onCancelMock = jest.fn(); + const onSaveMock = jest.fn(); + const renderResult = render( + + ); + const tagKeySelector = screen.getByText('Select a tag key'); + const operatorSelector = screen.getByText('Select operator'); + const cancelButton = screen.getByRole('button', { name: 'Cancel' }); + const saveButton = screen.getByRole('button', { name: 'Save' }); + + return { + user, + renderResult, + tagKeySelector, + operatorSelector, + cancelButton, + saveButton, + onCancelMock, + onSaveMock, + }; +}; + +describe('', () => { + it('should show tag key selector, operator selector, cancel button and save button by default', async () => { + const { tagKeySelector, operatorSelector, cancelButton, saveButton } = setup(); + + expect(tagKeySelector).toBeInTheDocument(); + expect(operatorSelector).toBeInTheDocument(); + expect(cancelButton).toBeInTheDocument(); + expect(saveButton).toBeInTheDocument(); + expect(saveButton).toBeDisabled(); + }); + + it('should display string tag key operators after string tag key selected', async () => { + const { user, tagKeySelector, operatorSelector } = setup(); + + await user.click(tagKeySelector); + await user.click(screen.getByRole('option', { name: 'foo' })); + + await user.click(operatorSelector); + expect(screen.getByRole('option', { name: 'is' })); + expect(screen.getByRole('option', { name: 'is not' })); + expect(screen.getByRole('option', { name: 'is one of' })); + expect(screen.getByRole('option', { name: 'is not one of' })); + }); + + it('should display number tag key operators after number tag key selected', async () => { + const { user, tagKeySelector, operatorSelector } = setup(); + + await user.click(tagKeySelector); + await user.click(screen.getByRole('option', { name: 'bar' })); + + await user.click(operatorSelector); + expect(screen.getByRole('option', { name: 'is' })); + expect(screen.getByRole('option', { name: 'is not' })); + expect(screen.getByRole('option', { name: 'is greater than' })); + expect(screen.getByRole('option', { name: 'is less than' })); + }); + + it('should display number input and call onSave with number value', async () => { + const { user, tagKeySelector, operatorSelector, saveButton, onSaveMock } = setup(); + await user.click(tagKeySelector); + await user.click(screen.getByRole('option', { name: 'bar' })); + + await user.click(operatorSelector); + await user.click(screen.getByRole('option', { name: 'is greater than' })); + + const valueInput = screen.getByPlaceholderText('Add a value'); + + expect(valueInput).toBeInTheDocument(); + await user.type(valueInput, '0.98'); + await user.click(saveButton); + + expect(onSaveMock).toHaveBeenCalledWith({ + name: 'bar', + operator: 'is greater than', + value: 0.98, + type: 'number', + }); + }); + + it('should display value selector and call onSave with string value', async () => { + const { user, tagKeySelector, operatorSelector, saveButton, onSaveMock } = setup(); + await user.click(tagKeySelector); + await user.click(screen.getByRole('option', { name: 'foo' })); + + await user.click(operatorSelector); + await user.click(screen.getByRole('option', { name: 'is not' })); + + const valueSelector = screen.getByText('Select a value'); + + expect(valueSelector).toBeInTheDocument(); + await user.click(valueSelector); + await user.click(screen.getByRole('option', { name: 'Computer vision' })); + await user.click(saveButton); + + expect(onSaveMock).toHaveBeenCalledWith({ + name: 'foo', + operator: 'is not', + value: 'Computer vision', + type: 'string', + }); + }); + + it('should display value selector and call onSave with string array value', async () => { + const { user, tagKeySelector, operatorSelector, saveButton, onSaveMock } = setup(); + await user.click(tagKeySelector); + await user.click(screen.getByRole('option', { name: 'foo' })); + + await user.click(operatorSelector); + await user.click(screen.getByRole('option', { name: 'is one of' })); + + const valueSelector = screen.getByText('Select a value'); + + expect(valueSelector).toBeInTheDocument(); + await user.click(valueSelector); + await user.click(screen.getByRole('option', { name: 'Computer vision' })); + await user.click(screen.getByRole('option', { name: 'Image classification' })); + await user.click(saveButton); + + expect(onSaveMock).toHaveBeenCalledWith({ + name: 'foo', + operator: 'is one of', + value: ['Computer vision', 'Image classification'], + type: 'string', + }); + }); + + it('should call onCancel after cancel button clicked', async () => { + const { user, cancelButton, onCancelMock } = setup(); + + await user.click(cancelButton); + expect(onCancelMock).toHaveBeenCalled(); + }); + + it('should render edit title and value if tagFilter provided', async () => { + render( + + ); + + expect(screen.getByText('EDIT TAG FILTER')).toBeInTheDocument(); + expect(screen.getByText('bar')).toBeInTheDocument(); + expect(screen.getByText('is greater than')).toBeInTheDocument(); + expect(screen.getByPlaceholderText('Add a value')).toHaveValue(0.98); + }); +}); diff --git a/public/components/common/tag_filter_popover_content/__tests__/tag_value_selector.test.tsx b/public/components/common/tag_filter_popover_content/__tests__/tag_value_selector.test.tsx new file mode 100644 index 00000000..67a60c62 --- /dev/null +++ b/public/components/common/tag_filter_popover_content/__tests__/tag_value_selector.test.tsx @@ -0,0 +1,59 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React from 'react'; +import userEvent from '@testing-library/user-event'; + +import { TagValueSelector } from '../tag_value_selector'; +import { render, screen } from '../../../../../test/test_utils'; + +describe('', () => { + it('should display value selector and tag value options', async () => { + const user = userEvent.setup(); + render(); + const selector = screen.getByText('Select a value'); + + expect(selector).toBeInTheDocument(); + + await user.click(selector); + expect(screen.getByRole('option', { name: 'Computer vision' })).toBeInTheDocument(); + expect(screen.getByRole('option', { name: 'Image classification' })).toBeInTheDocument(); + }); + + it('should remove selected value in options list and display in the selector', async () => { + const user = userEvent.setup(); + render(); + + expect(screen.getByText('Computer vision')).toBeInTheDocument(); + + await user.click(screen.getByTestId('comboBoxToggleListButton')); + expect(screen.queryByRole('option', { name: 'Computer vision' })).toBeNull(); + expect(screen.getByRole('option', { name: 'Image classification' })).toBeInTheDocument(); + }); + + it('should call onChange with selected values', async () => { + const user = userEvent.setup(); + const onChangeMock = jest.fn(); + render(); + + await user.click(screen.getByTestId('comboBoxToggleListButton')); + await user.click(screen.getByRole('option', { name: 'Computer vision' })); + + expect(onChangeMock).toHaveBeenCalledWith(['Image classification', 'Computer vision']); + }); + + it('should call onChange with string value', async () => { + const user = userEvent.setup(); + const onChangeMock = jest.fn(); + render( + + ); + + await user.click(screen.getByTestId('comboBoxToggleListButton')); + await user.click(screen.getByRole('option', { name: 'Computer vision' })); + + expect(onChangeMock).toHaveBeenCalledWith('Computer vision'); + }); +}); diff --git a/public/components/common/tag_filter_popover_content/index.ts b/public/components/common/tag_filter_popover_content/index.ts new file mode 100644 index 00000000..42baa474 --- /dev/null +++ b/public/components/common/tag_filter_popover_content/index.ts @@ -0,0 +1,11 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +export { + TagFilterPopoverContent, + TagFilterValue, + TagFilterOperator, + TagFilterPopoverContentProps, +} from './tag_filter_popover_content'; diff --git a/public/components/common/tag_filter_popover_content/tag_filter_popover_content.tsx b/public/components/common/tag_filter_popover_content/tag_filter_popover_content.tsx new file mode 100644 index 00000000..2319b845 --- /dev/null +++ b/public/components/common/tag_filter_popover_content/tag_filter_popover_content.tsx @@ -0,0 +1,233 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React, { useState, useMemo, useCallback } from 'react'; +import { + EuiPopoverTitle, + EuiFlexGroup, + EuiFlexItem, + EuiFormRow, + EuiComboBox, + EuiComboBoxOptionOption, + EuiToken, + EuiText, + EuiSpacer, + EuiButtonEmpty, + EuiButton, + EuiFieldNumber, +} from '@elastic/eui'; +import { TagValueSelector } from './tag_value_selector'; + +interface TagKey { + name: string; + type: 'string' | 'number'; +} + +export enum TagFilterOperator { + Is = 'is', + IsNot = 'is not', + IsGreaterThan = 'is greater than', + IsLessThan = 'is less than', + IsOneOf = 'is one of', + IsNotOneOf = 'is not one of', +} + +export interface TagFilterValue { + name: string; + operator: TagFilterOperator; + type: 'string' | 'number'; + value: string | string[] | number; +} + +export interface TagFilterPopoverContentProps { + tagFilter?: TagFilterValue; + tagKeys: TagKey[]; + onCancel: () => void; + onSave: (tagFilter: TagFilterValue) => void; +} + +const getValueInput = ( + operator: TagFilterOperator, + valueType: 'string' | 'number', + value: string | string[] | number | undefined, + onChange: (value: string | string[] | number | undefined) => void +) => { + if (valueType === 'string') { + return ( + + ); + } + return ( + { + const newValue = e.currentTarget.value; + onChange(newValue === '' ? undefined : parseFloat(newValue)); + }} + fullWidth + placeholder="Add a value" + /> + ); +}; + +export const TagFilterPopoverContent = ({ + onSave, + tagKeys, + onCancel, + tagFilter, +}: TagFilterPopoverContentProps) => { + const [value, setValue] = useState(tagFilter?.value); + const [selectedTagOptions, setSelectedTagOptions] = useState< + Array> + >(() => { + if (!tagFilter) { + return []; + } + return [ + { + label: tagFilter.name, + value: { + name: tagFilter.name, + type: tagFilter.type, + }, + }, + ]; + }); + const [selectedOperatorOptions, setSelectedOperatorOptions] = useState< + Array> + >(() => { + if (!tagFilter) { + return []; + } + return [ + { + label: tagFilter.operator, + }, + ]; + }); + const selectedTag = selectedTagOptions[0]?.value; + const selectedTagType = selectedTag?.type; + + const tagOptions = useMemo(() => tagKeys.map((item) => ({ label: item.name, value: item })), [ + tagKeys, + ]); + + const operatorsOptions = useMemo(() => { + if (!selectedTagType) { + return []; + } + return [ + TagFilterOperator.Is, + TagFilterOperator.IsNot, + ...(selectedTagType === 'string' + ? [TagFilterOperator.IsOneOf, TagFilterOperator.IsNotOneOf] + : []), + ...(selectedTagType === 'number' + ? [TagFilterOperator.IsGreaterThan, TagFilterOperator.IsLessThan] + : []), + ].map((label) => ({ + label, + })); + }, [selectedTagType]); + const operator = selectedOperatorOptions[0]?.label as TagFilterOperator; + + const tagKeyOptionRenderer = useCallback( + (option: EuiComboBoxOptionOption, _searchValue: string, contentClassName: string) => { + return ( +
    + + {option.label} + + {option.value?.type} + +
    + ); + }, + [] + ); + + const handleSave = useCallback(() => { + if (!selectedTag || !operator || !value) { + return; + } + onSave({ name: selectedTag.name, value, operator, type: selectedTag.type }); + }, [selectedTag, value, operator, onSave]); + + return ( + <> + {tagFilter ? 'EDIT' : 'ADD'} TAG FILTER +
    + + + + { + setSelectedTagOptions(e); + setSelectedOperatorOptions([]); + setValue(undefined); + }} + singleSelection={{ asPlainText: true }} + renderOption={tagKeyOptionRenderer} + /> + + + + + { + setSelectedOperatorOptions(e); + setValue(undefined); + }} + singleSelection={{ asPlainText: true }} + /> + + + + {operator && selectedTagType && ( + <> + + + {getValueInput(operator, selectedTagType, value, setValue)} + + + )} + + + + + Cancel + + + + + Save + + + +
    + + ); +}; diff --git a/public/components/common/tag_filter_popover_content/tag_value_selector.tsx b/public/components/common/tag_filter_popover_content/tag_value_selector.tsx new file mode 100644 index 00000000..db5ca01f --- /dev/null +++ b/public/components/common/tag_filter_popover_content/tag_value_selector.tsx @@ -0,0 +1,46 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React, { useState, useMemo, useCallback } from 'react'; +import { EuiComboBox } from '@elastic/eui'; + +interface TagValueSelectorProps { + value: string | string[] | undefined; + onChange: (value: string | string[] | undefined) => void; + singleSelection?: boolean; +} + +export const TagValueSelector = ({ value, onChange, singleSelection }: TagValueSelectorProps) => { + // TODO: Change to fetch value options via API + const [valueOptions] = useState([ + { + label: 'Computer vision', + }, + { + label: 'Image classification', + }, + ]); + const selectedValueOptions = useMemo( + () => valueOptions.filter((item) => item.label === value || value?.includes(item.label)), + [value, valueOptions] + ); + const handleChange = useCallback( + (options) => { + onChange(singleSelection ? options[0].label : options.map((item) => item.label)); + }, + [onChange, singleSelection] + ); + return ( + + ); +}; diff --git a/public/components/model_list/__tests__/model_list_filter.test.tsx b/public/components/model_list/__tests__/model_list_filter.test.tsx index 4cd7e0c9..2ab87eb8 100644 --- a/public/components/model_list/__tests__/model_list_filter.test.tsx +++ b/public/components/model_list/__tests__/model_list_filter.test.tsx @@ -7,30 +7,35 @@ import React from 'react'; import userEvent from '@testing-library/user-event'; import { ModelListFilter } from '../model_list_filter'; -import { render, screen } from '../../../../test/test_utils'; +import { act, render, screen } from '../../../../test/test_utils'; +import { TagFilterOperator } from '../../common'; describe('', () => { it('should render default search bar with tag, deployed and owner filter', () => { render( {}} />); expect(screen.getByPlaceholderText('Search by name, person, or keyword')).toBeInTheDocument(); - expect(screen.getByText('Tags')).toBeInTheDocument(); + expect(screen.getByText('Add tag filter')).toBeInTheDocument(); expect(screen.getByText('Owner')).toBeInTheDocument(); expect(screen.getByText('Deployed')).toBeInTheDocument(); expect(screen.getByText('Undeployed')).toBeInTheDocument(); }); - it('should render default search value and filter value', () => { + it('should render default search value, filter value and selected tags panel', () => { render( {}} /> ); expect(screen.getByDisplayValue('foo')).toBeInTheDocument(); - expect(screen.queryAllByText('1')).toHaveLength(2); + expect(screen.queryAllByText('1')).toHaveLength(1); expect(screen.getByText('Deployed')).not.toHaveClass('euiFilterButton-hasActiveFilters'); expect(screen.getByText('Undeployed')).not.toHaveClass('euiFilterButton-hasActiveFilters'); + expect(screen.getByTitle('NOT tag1: 123')).toBeInTheDocument(); }); it('should call onChange with consistent deployed value', async () => { @@ -55,4 +60,62 @@ describe('', () => { await user.click(screen.getByText('Undeployed')); expect(onChangeMock).toHaveBeenCalledWith(expect.objectContaining({ deployed: undefined })); }); + + it( + 'should call onChange with unique tags after tag filter updated', + async () => { + jest.useFakeTimers(); + const onChangeMock = jest.fn(); + const user = userEvent.setup({ advanceTimers: jest.advanceTimersByTime }); + + render( + + ); + + act(() => { + jest.advanceTimersByTime(1000); + }); + + await user.click(screen.getByTitle('Accuracy: test: Image classification')); + + await user.click(screen.getByText('Image classification')); + await user.click(screen.getByRole('option', { name: 'Computer vision' })); + await user.click(screen.getByText('Save')); + + expect(onChangeMock).toHaveBeenCalledWith({ + tag: [ + { + name: 'Accuracy: test', + operator: TagFilterOperator.Is, + value: 'Computer vision', + type: 'string', + }, + ], + owner: ['owner1'], + }); + + jest.useRealTimers(); + }, + 10 * 1000 + ); }); diff --git a/public/components/model_list/__tests__/tag_filter.test.tsx b/public/components/model_list/__tests__/tag_filter.test.tsx index 9280b727..5b693ea5 100644 --- a/public/components/model_list/__tests__/tag_filter.test.tsx +++ b/public/components/model_list/__tests__/tag_filter.test.tsx @@ -4,18 +4,125 @@ */ import React from 'react'; +import userEvent from '@testing-library/user-event'; -import { render, screen } from '../../../../test/test_utils'; +import { render, screen, waitFor } from '../../../../test/test_utils'; import { TagFilter } from '../tag_filter'; describe('', () => { - afterEach(() => { - jest.resetAllMocks(); + it('should render "Add tag filter" button by default', () => { + render(); + expect(screen.queryByText('Add tag filter')).toBeInTheDocument(); }); - it('should render "Tags" with 0 active filter for normal', () => { - render( {}} />); - expect(screen.queryByText('Tags')).toBeInTheDocument(); - expect(screen.queryByText('0')).toBeInTheDocument(); + it( + 'should call onChange when applying tag filter', + async () => { + const user = userEvent.setup({ advanceTimers: jest.advanceTimersByTime }); + const onChangeMock = jest.fn(); + render( + + ); + + await user.click(screen.getByText('Add tag filter')); + await waitFor(() => { + expect(screen.getByText('Select a tag key')).toBeInTheDocument(); + }); + await user.click(screen.getByText('Select a tag key')); + await user.click(screen.getByRole('option', { name: 'F1' })); + await user.click(screen.getByText('Select operator')); + await user.click(screen.getByRole('option', { name: 'is' })); + await user.type(screen.getByPlaceholderText('Add a value'), '0.92', {}); + await user.click(screen.getByText('Save')); + + expect(onChangeMock).toHaveBeenCalledWith([ + { + name: 'F1', + operator: 'is', + value: 0.92, + type: 'number', + }, + ]); + }, + // There are too many operations, need to increase timeout + 10 * 1000 + ); + + it('should render an empty tag list if no tags', async () => { + const user = userEvent.setup(); + render(); + await user.click(screen.getByText('Add tag filter')); + + expect(screen.getByText('No options found')).toBeInTheDocument(); }); + + it('should render loading screen when tags are loading', async () => { + const user = userEvent.setup({ advanceTimers: jest.advanceTimersByTime }); + + const { rerender } = render( + + ); + await user.click(screen.getByText('Add tag filter')); + rerender( + + ); + + expect(screen.getByText('Loading filters')).toBeInTheDocument(); + }); + + it( + 'should reset input after popover re-open', + async () => { + const user = userEvent.setup({ advanceTimers: jest.advanceTimersByTime }); + const onChangeMock = jest.fn(); + render( + + ); + + await user.click(screen.getByText('Add tag filter')); + + await waitFor(() => { + expect(screen.getByText('Select a tag key')).toBeInTheDocument(); + }); + await user.click(screen.getByText('Select a tag key')); + await user.click(screen.getByRole('option', { name: 'F1' })); + await user.click(screen.getByRole('button', { name: 'Cancel' })); + + await waitFor(() => { + expect(screen.queryByRole('dialog')).toBeNull(); + }); + + await user.click(screen.getByText('Add tag filter')); + expect(screen.getByText('Select a tag key')).toBeInTheDocument(); + expect(screen.getByText('Select operator').closest('[role="combobox"]')).toHaveClass( + 'euiComboBox-isDisabled' + ); + }, + // There are too many operations, need to increase timeout + 10 * 1000 + ); }); diff --git a/public/components/model_list/model_filter.tsx b/public/components/model_list/model_filter.tsx index d8506891..c2ef3187 100644 --- a/public/components/model_list/model_filter.tsx +++ b/public/components/model_list/model_filter.tsx @@ -49,7 +49,7 @@ export const ModelFilter = ({ [searchText, options] ); - const hadleButtonClick = useCallback(() => { + const handleButtonClick = useCallback(() => { setIsPopoverOpen((prevState) => !prevState); }, []); @@ -71,7 +71,7 @@ export const ModelFilter = ({ button={ 0} diff --git a/public/components/model_list/model_list.hooks.ts b/public/components/model_list/model_list.hooks.ts new file mode 100644 index 00000000..d05891ed --- /dev/null +++ b/public/components/model_list/model_list.hooks.ts @@ -0,0 +1,31 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import { useEffect, useState } from 'react'; + +/** + * TODO: implement this function so that it retrieve tags from BE + */ +export const useModelTagKeys = () => { + const [loading, setLoading] = useState(true); + + useEffect(() => { + const timeoutId = window.setTimeout(() => { + setLoading(false); + }, 1000); + + return () => { + window.clearTimeout(timeoutId); + }; + }, []); + + return [ + loading, + [ + { name: 'Accuracy: test', type: 'string' as const }, + { name: 'F1', type: 'number' as const }, + ] as Array<{ name: string; type: 'string' | 'number' }>, + ] as const; +}; diff --git a/public/components/model_list/model_list_filter.tsx b/public/components/model_list/model_list_filter.tsx index 4e92bb65..2769777b 100644 --- a/public/components/model_list/model_list_filter.tsx +++ b/public/components/model_list/model_list_filter.tsx @@ -9,15 +9,34 @@ import { EuiFieldSearch, EuiFilterGroup, EuiFilterButton, + EuiSpacer, } from '@elastic/eui'; import React, { useCallback, useRef } from 'react'; +import { TagFilterValue, SelectedTagFiltersPanel } from '../common'; + import { TagFilter } from './tag_filter'; import { OwnerFilter } from './owner_filter'; +import { useModelTagKeys } from './model_list.hooks'; + +const removeDuplicateTag = (tagFilters: TagFilterValue[]) => { + const generateTagKey = (tagFilter: TagFilterValue) => + `${tagFilter.name}${tagFilter.operator}${tagFilter.value.toString()}`; + const existsTagMap: { [key: string]: boolean } = {}; + return tagFilters.filter((tagFilter) => { + const key = generateTagKey(tagFilter); + if (!existsTagMap[key]) { + existsTagMap[key] = true; + return true; + } + + return false; + }); +}; export interface ModelListFilterFilterValue { search?: string; - tag: string[]; + tag: TagFilterValue[]; owner: string[]; deployed?: boolean; } @@ -31,6 +50,8 @@ export const ModelListFilter = ({ value: Omit; onChange: (value: ModelListFilterFilterValue) => void; }) => { + // TODO: Change to model tags API + const [tagKeysLoading, tagKeys] = useModelTagKeys(); const valueRef = useRef(value); valueRef.current = value; const onChangeRef = useRef(onChange); @@ -40,8 +61,8 @@ export const ModelListFilter = ({ onChangeRef.current({ ...valueRef.current, search }); }, []); - const handleTagChange = useCallback((tag: string[]) => { - onChangeRef.current({ ...valueRef.current, tag }); + const handleTagChange = useCallback((tag: TagFilterValue[]) => { + onChangeRef.current({ ...valueRef.current, tag: removeDuplicateTag(tag) }); }, []); const handleOwnerChange = useCallback((owner: string[]) => { @@ -75,7 +96,6 @@ export const ModelListFilter = ({ - Undeployed + + {value.tag.length > 0 && ( + <> + + + + )} ); }; diff --git a/public/components/model_list/tag_filter.tsx b/public/components/model_list/tag_filter.tsx index 7f7d679d..7d27fefa 100644 --- a/public/components/model_list/tag_filter.tsx +++ b/public/components/model_list/tag_filter.tsx @@ -3,17 +3,68 @@ * SPDX-License-Identifier: Apache-2.0 */ -import React from 'react'; -import { ModelFilter, ModelFilterProps } from './model_filter'; +import React, { useState, useCallback } from 'react'; +import { EuiFilterButton, EuiIcon, EuiLoadingChart, EuiPopover, EuiSpacer } from '@elastic/eui'; + +import { TagFilterPopoverContent, TagFilterValue, TagFilterPopoverContentProps } from '../common'; + +interface TagFilterProps extends Pick { + tagKeysLoading: boolean; + value: TagFilterValue[]; + onChange: (value: TagFilterValue[]) => void; +} + +export const TagFilter = ({ value, onChange, tagKeys, tagKeysLoading }: TagFilterProps) => { + const [isPopoverOpen, setIsPopoverOpen] = useState(false); + + const closePopover = useCallback(() => { + setIsPopoverOpen(false); + }, []); + + const handleFilterButtonClick = useCallback(() => { + setIsPopoverOpen((prevOpenState) => !prevOpenState); + }, []); + + const handleSave = useCallback( + (tagFilter) => { + onChange([...value, tagFilter]); + closePopover(); + }, + [value, onChange, closePopover] + ); -export const TagFilter = ({ value, onChange }: Pick) => { return ( - + + Add tag filter + + } + isOpen={isPopoverOpen} + closePopover={closePopover} + initialFocus={false} + > + {!tagKeysLoading && tagKeys.length > 0 && ( + + )} + {tagKeysLoading && ( +
    +
    + + +

    Loading filters

    +
    +
    + )} + {tagKeys.length === 0 && ( +
    +
    + + +

    No options found

    +
    +
    + )} + ); }; From 03e554bed6c122a4e73d24ad27076ab8d6241ed6 Mon Sep 17 00:00:00 2001 From: Lin Wang Date: Wed, 26 Apr 2023 10:59:22 +0800 Subject: [PATCH 40/75] Feature/update model detail page layout (#166) * feat: add model_group related panels and cards Signed-off-by: Lin Wang * feat: add cards and panels to model group detail Signed-off-by: Lin Wang * test: add TZ=UTC to avoid timezone error in runner Signed-off-by: Lin Wang --------- Signed-off-by: Lin Wang --- package.json | 4 +- public/apis/model.ts | 2 + .../__tests__/model_group.test.tsx | 61 +++++++++++++ .../model_group_overview_card.test.tsx | 40 +++++++++ public/components/model_group/model_group.tsx | 76 +++++++++++++--- .../model_group/model_group_details_panel.tsx | 19 ++++ .../model_group/model_group_overview_card.tsx | 86 +++++++++++++++++++ .../model_group/model_group_tags_panel.tsx | 19 ++++ .../model_group_versions_panel.tsx | 17 ++++ 9 files changed, 310 insertions(+), 14 deletions(-) create mode 100644 public/components/model_group/__tests__/model_group.test.tsx create mode 100644 public/components/model_group/__tests__/model_group_overview_card.test.tsx create mode 100644 public/components/model_group/model_group_details_panel.tsx create mode 100644 public/components/model_group/model_group_overview_card.tsx create mode 100644 public/components/model_group/model_group_tags_panel.tsx create mode 100644 public/components/model_group/model_group_versions_panel.tsx diff --git a/package.json b/package.json index 9475571a..32e3ba23 100644 --- a/package.json +++ b/package.json @@ -10,8 +10,8 @@ "plugin-helpers": "node ../../scripts/plugin_helpers", "osd": "node ../../scripts/osd", "lint:es": "node ../../scripts/eslint", - "test:jest": "../../node_modules/.bin/jest --config ./test/jest.config.js", - "test:watch": "../../node_modules/.bin/jest --config ./test/jest.config.js --watch", + "test:jest": "TZ=UTC ../../node_modules/.bin/jest --config ./test/jest.config.js", + "test:watch": "TZ=UTC ../../node_modules/.bin/jest --config ./test/jest.config.js --watch", "prepare": "husky install" }, "husky": { diff --git a/public/apis/model.ts b/public/apis/model.ts index 6ea9be65..4d63b238 100644 --- a/public/apis/model.ts +++ b/public/apis/model.ts @@ -37,6 +37,8 @@ export interface ModelSearchItem { export interface ModelDetail extends ModelSearchItem { content: string; + last_updated_time: number; + created_time: number; } export interface ModelSearchResponse { diff --git a/public/components/model_group/__tests__/model_group.test.tsx b/public/components/model_group/__tests__/model_group.test.tsx new file mode 100644 index 00000000..2df4b7e7 --- /dev/null +++ b/public/components/model_group/__tests__/model_group.test.tsx @@ -0,0 +1,61 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React from 'react'; +import userEvent from '@testing-library/user-event'; + +import { render, screen, waitFor, within } from '../../../../test/test_utils'; +import { ModelGroup } from '../model_group'; +import { routerPaths } from '../../../../common/router_paths'; +import { Route, generatePath } from 'react-router-dom'; + +const setup = () => { + const renderResult = render( + + + , + { route: generatePath(routerPaths.modelGroup, { id: '1' }) } + ); + + return { + renderResult, + }; +}; + +describe('', () => { + it('should display model name, action buttons, overview-card, tabs and tabpanel after data loaded', async () => { + setup(); + + await waitFor(() => { + expect(screen.queryByTestId('model-group-loading-indicator')).toBeNull(); + }); + expect(screen.getByText('model1')).toBeInTheDocument(); + expect(screen.getByText('Delete')).toBeInTheDocument(); + expect(screen.getByText('Register version')).toBeInTheDocument(); + expect(screen.queryByTestId('model-group-overview-card')).toBeInTheDocument(); + expect(screen.queryByTestId('model-group-overview-card')).toBeInTheDocument(); + expect(screen.queryByTestId('model-group-overview-card')).toBeInTheDocument(); + expect(screen.getByRole('tab', { name: 'Versions' })).toHaveClass('euiTab-isSelected'); + expect(within(screen.getByRole('tabpanel')).getByText('Versions')).toBeInTheDocument(); + }); + + it('should display consistent tabs content after tab clicked', async () => { + setup(); + + await waitFor(() => { + expect(screen.queryByTestId('model-group-loading-indicator')).toBeNull(); + }); + expect(screen.getByRole('tab', { name: 'Versions' })).toHaveClass('euiTab-isSelected'); + expect(within(screen.getByRole('tabpanel')).getByText('Versions')).toBeInTheDocument(); + + await userEvent.click(screen.getByRole('tab', { name: 'Details' })); + expect(screen.getByRole('tab', { name: 'Details' })).toHaveClass('euiTab-isSelected'); + expect(within(screen.getByRole('tabpanel')).getByText('Details')).toBeInTheDocument(); + + await userEvent.click(screen.getByRole('tab', { name: 'Tags' })); + expect(screen.getByRole('tab', { name: 'Tags' })).toHaveClass('euiTab-isSelected'); + expect(within(screen.getByRole('tabpanel')).getByText('Tags')).toBeInTheDocument(); + }); +}); diff --git a/public/components/model_group/__tests__/model_group_overview_card.test.tsx b/public/components/model_group/__tests__/model_group_overview_card.test.tsx new file mode 100644 index 00000000..addc35ab --- /dev/null +++ b/public/components/model_group/__tests__/model_group_overview_card.test.tsx @@ -0,0 +1,40 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React from 'react'; + +import { render, screen, within } from '../../../../test/test_utils'; +import { ModelGroupOverviewCard } from '../model_group_overview_card'; + +describe('', () => { + it('should model group overview information according passed data', () => { + render( + + ); + + expect(screen.getByText('Model description of model 1')).toBeInTheDocument(); + expect(screen.getByText('Foo (you)')).toBeInTheDocument(); + expect(screen.getByText('Created')).toBeInTheDocument(); + expect( + within(screen.getByText('Created').closest('dl')!).getByText('Apr 24, 2023 8:18 AM') + ).toBeInTheDocument(); + expect(screen.getByText('Last updated')).toBeInTheDocument(); + expect( + within(screen.getByText('Last updated').closest('dl')!).getByText('Apr 24, 2023 1:18 PM') + ).toBeInTheDocument(); + + expect(screen.getByText('model-1-id')).toBeInTheDocument(); + expect( + within(screen.getByText('model-1-id')).getByTestId('copy-id-button') + ).toBeInTheDocument(); + }); +}); diff --git a/public/components/model_group/model_group.tsx b/public/components/model_group/model_group.tsx index 80f30f61..b1a4505f 100644 --- a/public/components/model_group/model_group.tsx +++ b/public/components/model_group/model_group.tsx @@ -3,24 +3,72 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { EuiButton, EuiLoadingSpinner, EuiPageHeader, EuiPanel, EuiText } from '@elastic/eui'; -import React from 'react'; +import { + EuiButton, + EuiLoadingSpinner, + EuiPageHeader, + EuiSpacer, + EuiTabbedContent, + EuiTabbedContentTab, + EuiText, +} from '@elastic/eui'; +import React, { useState, useMemo } from 'react'; import { useParams } from 'react-router-dom'; import { useFetcher } from '../../hooks'; import { APIProvider } from '../../apis/api_provider'; +import { ModelGroupOverviewCard } from './model_group_overview_card'; +import { ModelGroupVersionsPanel } from './model_group_versions_panel'; +import { ModelGroupDetailsPanel } from './model_group_details_panel'; +import { ModelGroupTagsPanel } from './model_group_tags_panel'; export const ModelGroup = () => { const { id: modelId } = useParams<{ id: string }>(); const { data, loading, error } = useFetcher(APIProvider.getAPI('model').getOne, modelId); + const tabs = useMemo( + () => [ + { + name: 'Versions', + id: 'versions', + content: ( + <> + + + + ), + }, + { + name: 'Details', + id: 'details', + content: ( + <> + + + + ), + }, + { + name: 'Tags', + id: 'tags', + content: ( + <> + + + + ), + }, + ], + [] + ); + const [selectedTab, setSelectedTab] = useState(tabs[0]); if (loading) { // TODO: need to update per design - return ; + return ; } - if (error) { + if (error || !data) { // TODO: need to update per design - return 'Error happened while loading the model'; + return <>Error happened while loading the model; } return ( @@ -28,20 +76,24 @@ export const ModelGroup = () => { -

    {data?.name}

    +

    {data.name}

    } rightSideItems={[ Register version, - Edit, Delete, ]} /> - - -

    Versions

    -
    -
    + + + ); }; diff --git a/public/components/model_group/model_group_details_panel.tsx b/public/components/model_group/model_group_details_panel.tsx new file mode 100644 index 00000000..6484f034 --- /dev/null +++ b/public/components/model_group/model_group_details_panel.tsx @@ -0,0 +1,19 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React from 'react'; +import { EuiHorizontalRule, EuiPanel, EuiSpacer, EuiTitle } from '@elastic/eui'; + +export const ModelGroupDetailsPanel = () => { + return ( + + +

    Details

    +
    + + +
    + ); +}; diff --git a/public/components/model_group/model_group_overview_card.tsx b/public/components/model_group/model_group_overview_card.tsx new file mode 100644 index 00000000..48b054e8 --- /dev/null +++ b/public/components/model_group/model_group_overview_card.tsx @@ -0,0 +1,86 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import { EuiDescriptionList, EuiFlexGroup, EuiFlexItem, EuiPanel, EuiSpacer } from '@elastic/eui'; +import React from 'react'; +import { CopyableText } from '../common'; +import { renderTime } from '../../utils'; + +interface ModelGroupOverviewCardProps { + id: string; + description?: string; + owner: string; + isModelOwner: boolean; + createdTime: number; + updatedTime: number; +} + +export const ModelGroupOverviewCard = ({ + id, + owner, + createdTime, + updatedTime, + description, + isModelOwner, +}: ModelGroupOverviewCardProps) => { + return ( + + + + + + + + + + + + + + + + , + }, + ]} + /> + + + + + ); +}; diff --git a/public/components/model_group/model_group_tags_panel.tsx b/public/components/model_group/model_group_tags_panel.tsx new file mode 100644 index 00000000..735558ad --- /dev/null +++ b/public/components/model_group/model_group_tags_panel.tsx @@ -0,0 +1,19 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React from 'react'; +import { EuiHorizontalRule, EuiPanel, EuiSpacer, EuiTitle } from '@elastic/eui'; + +export const ModelGroupTagsPanel = () => { + return ( + + +

    Tags

    +
    + + +
    + ); +}; diff --git a/public/components/model_group/model_group_versions_panel.tsx b/public/components/model_group/model_group_versions_panel.tsx new file mode 100644 index 00000000..1cb3f0f4 --- /dev/null +++ b/public/components/model_group/model_group_versions_panel.tsx @@ -0,0 +1,17 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React from 'react'; +import { EuiPanel, EuiTitle } from '@elastic/eui'; + +export const ModelGroupVersionsPanel = () => { + return ( + + +

    Versions

    +
    +
    + ); +}; From a9c1fd63331e9ac1b11d697b2b01c9e3c7a1f750 Mon Sep 17 00:00:00 2001 From: Lin Wang Date: Thu, 27 Apr 2023 22:30:10 +0800 Subject: [PATCH 41/75] Feature/update global breadcrumbs (#164) * feat: change property to onBreadcrumbsChange in GlobalBreadcrumbs Signed-off-by: Lin Wang * feat: add async load breadcrumbs for model-registry Signed-off-by: Lin Wang --------- Signed-off-by: Lin Wang --- common/index.ts | 1 + common/router.ts | 2 +- .../__tests__/global_breadcrumbs.test.tsx | 155 ++++++++++++++++ public/components/app.tsx | 4 +- public/components/global_breadcrumbs.tsx | 166 ++++++++++++++++-- test/test_utils.tsx | 2 +- 6 files changed, 315 insertions(+), 15 deletions(-) create mode 100644 public/components/__tests__/global_breadcrumbs.test.tsx diff --git a/common/index.ts b/common/index.ts index 40f8005c..c06353fe 100644 --- a/common/index.ts +++ b/common/index.ts @@ -10,3 +10,4 @@ export const PLUGIN_DESC = `ML Commons for OpenSearch eases the development of m export * from './constant'; export * from './status'; export * from './model'; +export * from './router_paths'; diff --git a/common/router.ts b/common/router.ts index 1b3c5243..0022f445 100644 --- a/common/router.ts +++ b/common/router.ts @@ -36,7 +36,7 @@ export const ROUTES: RouteConfig[] = [ }, { path: routerPaths.modelList, - label: 'Model List', + label: 'Model Registry', Component: ModelList, nav: true, }, diff --git a/public/components/__tests__/global_breadcrumbs.test.tsx b/public/components/__tests__/global_breadcrumbs.test.tsx new file mode 100644 index 00000000..ef5cfff9 --- /dev/null +++ b/public/components/__tests__/global_breadcrumbs.test.tsx @@ -0,0 +1,155 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React from 'react'; +import { GlobalBreadcrumbs } from '../global_breadcrumbs'; +import { history, render, waitFor, act } from '../../../test/test_utils'; +import { Model, ModelDetail } from '../../apis/model'; + +describe('', () => { + it('should call onBreadcrumbsChange with overview title', () => { + const onBreadcrumbsChange = jest.fn(); + render(, { + route: '/overview', + }); + + expect(onBreadcrumbsChange).toHaveBeenCalledWith([ + { text: 'Machine Learning', href: '/foo' }, + { text: 'Overview' }, + ]); + }); + + it('should call onBreadcrumbsChange with register model breadcrumbs', () => { + const onBreadcrumbsChange = jest.fn(); + render(, { + route: '/model-registry/register-model', + }); + + expect(onBreadcrumbsChange).toHaveBeenCalledWith([ + { text: 'Machine Learning', href: '/' }, + { text: 'Model Registry', href: '/model-registry/model-list' }, + { text: 'Register model' }, + ]); + }); + + it('should call onBreadcrumbsChange with register version breadcrumbs', async () => { + const onBreadcrumbsChange = jest.fn(); + render(, { + route: '/model-registry/register-model/1', + }); + + expect(onBreadcrumbsChange).toHaveBeenCalledWith([ + { text: 'Machine Learning', href: '/' }, + { text: 'Model Registry', href: '/model-registry/model-list' }, + ]); + + await waitFor(() => { + expect(onBreadcrumbsChange).toBeCalledTimes(2); + expect(onBreadcrumbsChange).toHaveBeenLastCalledWith([ + { text: 'Machine Learning', href: '/' }, + { text: 'Model Registry', href: '/model-registry/model-list' }, + { text: 'model1', href: '/model-registry/model/1' }, + { text: 'Register version' }, + ]); + }); + }); + + it('should call onBreadcrumbsChange with model group breadcrumbs', async () => { + const onBreadcrumbsChange = jest.fn(); + render(, { + route: '/model-registry/model/1', + }); + + expect(onBreadcrumbsChange).toHaveBeenCalledWith([ + { text: 'Machine Learning', href: '/' }, + { text: 'Model Registry', href: '/model-registry/model-list' }, + ]); + + await waitFor(() => { + expect(onBreadcrumbsChange).toBeCalledTimes(2); + expect(onBreadcrumbsChange).toHaveBeenLastCalledWith([ + { text: 'Machine Learning', href: '/' }, + { text: 'Model Registry', href: '/model-registry/model-list' }, + { text: 'model1' }, + ]); + }); + }); + + it('should call onBreadcrumbsChange with model version breadcrumbs', async () => { + const onBreadcrumbsChange = jest.fn(); + render(, { + route: '/model-registry/model-version/1', + }); + + expect(onBreadcrumbsChange).toHaveBeenCalledWith([ + { text: 'Machine Learning', href: '/' }, + { text: 'Model Registry', href: '/model-registry/model-list' }, + ]); + + await waitFor(() => { + expect(onBreadcrumbsChange).toBeCalledTimes(2); + expect(onBreadcrumbsChange).toHaveBeenLastCalledWith([ + { text: 'Machine Learning', href: '/' }, + { text: 'Model Registry', href: '/model-registry/model-list' }, + { text: 'model1', href: '/model-registry/model/1' }, + { text: 'Version 1.0.0' }, + ]); + }); + }); + + it('should NOT call onBreadcrumbs with steal breadcrumbs after pathname changed', async () => { + jest.useFakeTimers(); + const onBreadcrumbsChange = jest.fn(); + const modelGetOneMock = jest.spyOn(Model.prototype, 'getOne').mockImplementation( + (id) => + new Promise((resolve) => { + setTimeout( + () => { + resolve({ + id, + name: `model${id}`, + model_version: `1.0.${id}`, + } as ModelDetail); + }, + id === '2' ? 1000 : 0 + ); + }) + ); + render(, { + route: '/model-registry/model-version/2', + }); + + expect(onBreadcrumbsChange).toHaveBeenLastCalledWith([ + { text: 'Machine Learning', href: '/' }, + { text: 'Model Registry', href: '/model-registry/model-list' }, + ]); + + history.current.push('/model-registry/model/1'); + + await act(async () => { + jest.advanceTimersByTime(200); + }); + + expect(onBreadcrumbsChange).toHaveBeenLastCalledWith([ + { text: 'Machine Learning', href: '/' }, + { text: 'Model Registry', href: '/model-registry/model-list' }, + { text: 'model1' }, + ]); + + await act(async () => { + jest.advanceTimersByTime(1000); + }); + + expect(onBreadcrumbsChange).not.toHaveBeenLastCalledWith([ + { text: 'Machine Learning', href: '/' }, + { text: 'Model Registry', href: '/model-registry/model-list' }, + { text: 'model2', href: '/model-registry/model/2' }, + { text: 'Version 1.0.2' }, + ]); + + modelGetOneMock.mockRestore(); + jest.useRealTimers(); + }); +}); diff --git a/public/components/app.tsx b/public/components/app.tsx index dc6df48b..5e8e4f76 100644 --- a/public/components/app.tsx +++ b/public/components/app.tsx @@ -103,7 +103,9 @@ export const MlCommonsPluginApp = ({ {/* Breadcrumbs will contains dynamic content in new page header, should be provided by each page self*/} - {!useNewPageHeader && } + {!useNewPageHeader && ( + + )} {dataSourceEnabled && ( { - const breadcrumbs: ChromeBreadcrumb[] = [{ text: PLUGIN_NAME, href: basename }]; - const matchedRoute = ROUTES.find((route) => - matchPath(pathname, { path: route.path, exact: route.exact }) - ); +type RouteConfig = typeof ROUTES[number]; + +const joinUrl = (basename: string, pathname: string) => + `${basename.endsWith('/') ? basename.slice(0, -1) : basename}${pathname}`; + +const getBasicBreadcrumbs = (basename: string): ChromeBreadcrumb[] => { + return [{ text: PLUGIN_NAME, href: basename }]; +}; + +const getRouteMatchedBreadcrumbs = (basename: string, matchedRoute: RouteConfig | undefined) => { + const breadcrumbs: ChromeBreadcrumb[] = getBasicBreadcrumbs(basename); if (!matchedRoute?.label) { return breadcrumbs; } @@ -23,18 +30,153 @@ const getBreadcrumbs = (pathname: string, basename: string) => { }); }; +const getBasicModelRegistryBreadcrumbs = (basename: string) => { + const breadcrumbs = getRouteMatchedBreadcrumbs( + basename, + ROUTES.find((item) => item.path === routerPaths.modelList) + ); + breadcrumbs[breadcrumbs.length - 1].href = joinUrl(basename, routerPaths.modelList); + return breadcrumbs; +}; + +const getModelRegisterBreadcrumbs = (basename: string, matchedParams: {}) => { + const baseModelRegistryBreadcrumbs = getBasicModelRegistryBreadcrumbs(basename); + if ('id' in matchedParams && typeof matchedParams.id === 'string') { + const modelId = matchedParams.id; + return { + staticBreadcrumbs: baseModelRegistryBreadcrumbs, + // TODO: Change to model group API + asyncBreadcrumbsLoader: () => + APIProvider.getAPI('model') + .getOne(modelId) + .then( + (model) => + [ + { + text: model.name, + href: joinUrl(basename, generatePath(routerPaths.modelGroup, { id: modelId })), + }, + { + text: 'Register version', + }, + ] as ChromeBreadcrumb[] + ), + }; + } + return { + staticBreadcrumbs: [ + ...baseModelRegistryBreadcrumbs, + { + text: 'Register model', + }, + ], + }; +}; + +const getModelGroupBreadcrumbs = (basename: string, matchedParams: {}) => { + const baseModelRegistryBreadcrumbs = getBasicModelRegistryBreadcrumbs(basename); + if ('id' in matchedParams && typeof matchedParams.id === 'string') { + const modelId = matchedParams.id; + return { + staticBreadcrumbs: baseModelRegistryBreadcrumbs, + // TODO: Change to model group API + asyncBreadcrumbsLoader: () => { + return APIProvider.getAPI('model') + .getOne(modelId) + .then( + (model) => + [ + { + text: model.name, + }, + ] as ChromeBreadcrumb[] + ); + }, + }; + } + return { + staticBreadcrumbs: baseModelRegistryBreadcrumbs, + }; +}; + +const getModelVersionBreadcrumbs = (basename: string, matchedParams: {}) => { + const baseModelRegistryBreadcrumbs = getBasicModelRegistryBreadcrumbs(basename); + if ('id' in matchedParams && typeof matchedParams.id === 'string') { + const modelId = matchedParams.id; + return { + staticBreadcrumbs: baseModelRegistryBreadcrumbs, + // TODO: Change to model group API + asyncBreadcrumbsLoader: () => + APIProvider.getAPI('model') + .getOne(modelId) + .then( + (model) => + [ + { + text: model.name, + // TODO: Change to use model group id + href: joinUrl(basename, generatePath(routerPaths.modelGroup, { id: modelId })), + }, + { + text: `Version ${model.model_version}`, + }, + ] as ChromeBreadcrumb[] + ), + }; + } + return { + staticBreadcrumbs: baseModelRegistryBreadcrumbs, + }; +}; + +const routerPathBreadcrumbsMap = { + [routerPaths.registerModel]: getModelRegisterBreadcrumbs, + [routerPaths.modelGroup]: getModelGroupBreadcrumbs, + [routerPaths.modelVersion]: getModelVersionBreadcrumbs, +}; + export const GlobalBreadcrumbs = ({ - chrome, + onBreadcrumbsChange, basename, }: { - chrome: CoreStart['chrome']; + onBreadcrumbsChange: CoreStart['chrome']['setBreadcrumbs']; basename: string; }) => { const location = useLocation(); - const { setBreadcrumbs } = chrome; useEffect(() => { - setBreadcrumbs(getBreadcrumbs(location.pathname, basename)); - }, [location.pathname, setBreadcrumbs, basename]); + let matchedRoute: typeof ROUTES[number] | undefined; + let matchedParams = {}; + for (let i = 0; i < ROUTES.length; i++) { + const route = ROUTES[i]; + const matchedResult = matchPath(location.pathname, { path: route.path, exact: route.exact }); + if (matchedResult) { + matchedParams = matchedResult.params; + matchedRoute = route; + break; + } + } + + if (!matchedRoute || !(matchedRoute.path in routerPathBreadcrumbsMap)) { + onBreadcrumbsChange(getRouteMatchedBreadcrumbs(basename, matchedRoute)); + return; + } + + let changed = false; + const { staticBreadcrumbs, asyncBreadcrumbsLoader } = routerPathBreadcrumbsMap[ + matchedRoute.path + ](basename, matchedParams); + + onBreadcrumbsChange(staticBreadcrumbs); + if (asyncBreadcrumbsLoader) + asyncBreadcrumbsLoader().then((asyncBreadcrumbs) => { + if (!changed) { + onBreadcrumbsChange([...staticBreadcrumbs, ...asyncBreadcrumbs]); + } + }); + return () => { + changed = true; + }; + }, [location.pathname, onBreadcrumbsChange, basename]); return null; }; diff --git a/test/test_utils.tsx b/test/test_utils.tsx index 95c08ff4..7950ae60 100644 --- a/test/test_utils.tsx +++ b/test/test_utils.tsx @@ -14,7 +14,7 @@ export interface RenderWithRouteProps { route: string; } -const history = { +export const history = { current: createBrowserHistory(), }; From 067a7beb9bca74ef6d0ac9428588f1b2ac2cd52a Mon Sep 17 00:00:00 2001 From: Lin Wang Date: Sat, 6 May 2023 09:58:22 +0800 Subject: [PATCH 42/75] Feature/add model loading empty failed screen for model list (#165) * feat: add model list description Signed-off-by: Lin Wang * feat: update register model button in model list Signed-off-by: Lin Wang * feat: remove uploading banner in model list Signed-off-by: Lin Wang * feat: set loading true after stringify params changed in use fetcher hook Signed-off-by: Lin Wang * feat: add model list empty panel Signed-off-by: Lin Wang * feat: avoid update loading status after params changed Signed-off-by: Lin Wang * feat: add empty, loading, error and no result screen to model list Signed-off-by: Lin Wang * feat: update button props to register_new_model_button Signed-off-by: Lin Wang * feat: change to use register_new_model_button button in model_list_empty Signed-off-by: Lin Wang * refactor: move monitoring search_bar to common debounced_search_bar Signed-off-by: Lin Wang * feat: change to debounced search bar in model list Signed-off-by: Lin Wang * chore: address PR comments Signed-off-by: Lin Wang * fix: loading and error screen not show if models provided Signed-off-by: Lin Wang --------- Signed-off-by: Lin Wang --- .../__tests__/doubounced_search_bar.test.tsx} | 12 +- .../debounced_search_bar.tsx} | 19 +++- public/components/common/index.ts | 1 + .../model_list/__tests__/model_list.test.tsx | 80 +++++++++++++ .../__tests__/model_list_empty.test.tsx | 21 ++++ .../__tests__/model_list_filter.test.tsx | 4 +- .../model_list/__tests__/model_table.test.tsx | 77 ++++++++++++- public/components/model_list/index.tsx | 107 ++++++++++++------ .../model_list/model_list_empty.tsx | 32 ++++++ .../model_list/model_list_filter.tsx | 12 +- public/components/model_list/model_table.tsx | 85 ++++++++++++-- ...tton.tsx => register_new_model_button.tsx} | 13 ++- public/components/monitoring/index.tsx | 9 +- ...e_fetcher.test.ts => use_fetcher.test.tsx} | 81 ++++++++++++- public/hooks/use_fetcher.ts | 17 ++- test/mocks/data/model_aggregate.ts | 18 +++ test/mocks/handlers.ts | 5 + 17 files changed, 512 insertions(+), 81 deletions(-) rename public/components/{monitoring/__tests__/search_bar.test.tsx => common/__tests__/doubounced_search_bar.test.tsx} (58%) rename public/components/{monitoring/search_bar.tsx => common/debounced_search_bar.tsx} (56%) create mode 100644 public/components/model_list/__tests__/model_list.test.tsx create mode 100644 public/components/model_list/__tests__/model_list_empty.test.tsx create mode 100644 public/components/model_list/model_list_empty.tsx rename public/components/model_list/{regsister_new_model_button.tsx => register_new_model_button.tsx} (59%) rename public/hooks/tests/{use_fetcher.test.ts => use_fetcher.test.tsx} (66%) create mode 100644 test/mocks/data/model_aggregate.ts diff --git a/public/components/monitoring/__tests__/search_bar.test.tsx b/public/components/common/__tests__/doubounced_search_bar.test.tsx similarity index 58% rename from public/components/monitoring/__tests__/search_bar.test.tsx rename to public/components/common/__tests__/doubounced_search_bar.test.tsx index a7549a56..7b76711e 100644 --- a/public/components/monitoring/__tests__/search_bar.test.tsx +++ b/public/components/common/__tests__/doubounced_search_bar.test.tsx @@ -5,13 +5,13 @@ import React from 'react'; import userEvent from '@testing-library/user-event'; -import { SearchBar } from '../search_bar'; +import { DebouncedSearchBar } from '../debounced_search_bar'; import { render, screen } from '../../../../test/test_utils'; -describe('', () => { +describe('', () => { it('should render default search bar', () => { - render(); - expect(screen.getByPlaceholderText('Search by model name or ID')).toBeInTheDocument(); + render(); + expect(screen.getByPlaceholderText('Search by name or ID')).toBeInTheDocument(); }); it('should call onSearch with 400ms debounce', async () => { @@ -19,9 +19,9 @@ describe('', () => { jest.useFakeTimers(); const onSearch = jest.fn(); - render(); + render(); - await user.type(screen.getByPlaceholderText('Search by model name or ID'), 'foo'); + await user.type(screen.getByPlaceholderText('Search by name or ID'), 'foo'); expect(onSearch).not.toHaveBeenCalled(); jest.advanceTimersByTime(400); expect(onSearch).toHaveBeenCalled(); diff --git a/public/components/monitoring/search_bar.tsx b/public/components/common/debounced_search_bar.tsx similarity index 56% rename from public/components/monitoring/search_bar.tsx rename to public/components/common/debounced_search_bar.tsx index 39eec9d6..113c901f 100644 --- a/public/components/monitoring/search_bar.tsx +++ b/public/components/common/debounced_search_bar.tsx @@ -3,14 +3,22 @@ * SPDX-License-Identifier: Apache-2.0 */ import React, { useMemo, useCallback } from 'react'; -import { EuiCompressedFieldSearch } from '@elastic/eui'; +import { EuiCompressedFieldSearch, EuiCompressedFieldSearchProps } from '@elastic/eui'; import { debounce } from 'lodash'; -interface SearchBarProps { + +interface DebouncedSearchBarProps + extends Pick { onSearch: (value: string) => void; + debounceMs?: number; inputRef?: (node: HTMLInputElement | null) => void; } -export const SearchBar = ({ onSearch, inputRef }: SearchBarProps) => { - const onSearchDebounce = useMemo(() => debounce(onSearch, 400), [onSearch]); +export const DebouncedSearchBar = ({ + onSearch, + inputRef, + debounceMs = 400, + ...resetProps +}: DebouncedSearchBarProps) => { + const onSearchDebounce = useMemo(() => debounce(onSearch, debounceMs), [onSearch, debounceMs]); const onChange = useCallback( (e: React.ChangeEvent) => { @@ -23,12 +31,11 @@ export const SearchBar = ({ onSearch, inputRef }: SearchBarProps) => { ); }; diff --git a/public/components/common/index.ts b/public/components/common/index.ts index 5040dd86..b9493677 100644 --- a/public/components/common/index.ts +++ b/public/components/common/index.ts @@ -7,3 +7,4 @@ export * from './custom'; export * from './copyable_text'; export * from './tag_filter_popover_content'; export * from './selected_tag_filter_panel'; +export * from './debounced_search_bar'; diff --git a/public/components/model_list/__tests__/model_list.test.tsx b/public/components/model_list/__tests__/model_list.test.tsx new file mode 100644 index 00000000..bb1858fe --- /dev/null +++ b/public/components/model_list/__tests__/model_list.test.tsx @@ -0,0 +1,80 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React from 'react'; + +import { ModelAggregate } from '../../../apis/model_aggregate'; +import { render, screen, waitFor, within } from '../../../../test/test_utils'; + +import { ModelList } from '../index'; + +const setup = () => { + const notificationsMock = { + toasts: { + get$: jest.fn(), + add: jest.fn(), + remove: jest.fn(), + addSuccess: jest.fn(), + addWarning: jest.fn(), + addDanger: jest.fn(), + addError: jest.fn(), + addInfo: jest.fn(), + }, + }; + const renderResult = render(); + return { + renderResult, + notificationsMock, + }; +}; + +describe('', () => { + it('should show empty screen if no models in system', async () => { + const modelAggregateMock = jest + .spyOn(ModelAggregate.prototype, 'search') + .mockImplementation(() => + Promise.resolve({ + data: [], + pagination: { + currentPage: 1, + pageSize: 15, + totalPages: 0, + totalRecords: 0, + }, + }) + ); + + setup(); + + await waitFor(() => { + expect(screen.getByText('Registered models will appear here.')); + }); + + modelAggregateMock.mockRestore(); + }); + + it('should show model total count and model table after model data loaded', async () => { + setup(); + + await waitFor(() => { + expect(within(screen.getByTestId('modelTotalCount')).getByText('(1)')).toBeInTheDocument(); + expect( + screen.getByText('traced_small_model').closest('.euiTableRowCell') + ).toBeInTheDocument(); + expect(screen.getByText('1.0.5').closest('.euiTableRowCell')).toBeInTheDocument(); + }); + }); + + it('should render model list filter by default', () => { + setup(); + + expect(screen.getByPlaceholderText('Search by name, person, or keyword')).toBeInTheDocument(); + expect( + screen.getByText( + (text, node) => text === 'Owner' && !!node?.className.includes('euiFilterButton__textShift') + ) + ).toBeInTheDocument(); + }); +}); diff --git a/public/components/model_list/__tests__/model_list_empty.test.tsx b/public/components/model_list/__tests__/model_list_empty.test.tsx new file mode 100644 index 00000000..0673451d --- /dev/null +++ b/public/components/model_list/__tests__/model_list_empty.test.tsx @@ -0,0 +1,21 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React from 'react'; + +import { render, screen } from '../../../../test/test_utils'; + +import { ModelListEmpty } from '../model_list_empty'; + +describe('', () => { + it('should show tips, register model and Read documentation by default', () => { + render(); + + expect(screen.getByText('Registered models will appear here.')); + expect(screen.getByText('Register model')); + expect(screen.getByText('Read documentation')); + expect(screen.getByText('Read documentation')).toHaveProperty('href', 'http://localhost/todo'); + }); +}); diff --git a/public/components/model_list/__tests__/model_list_filter.test.tsx b/public/components/model_list/__tests__/model_list_filter.test.tsx index 2ab87eb8..537757a2 100644 --- a/public/components/model_list/__tests__/model_list_filter.test.tsx +++ b/public/components/model_list/__tests__/model_list_filter.test.tsx @@ -20,10 +20,9 @@ describe('', () => { expect(screen.getByText('Undeployed')).toBeInTheDocument(); }); - it('should render default search value, filter value and selected tags panel', () => { + it('should render filter value and selected tags panel', () => { render( ', () => { onChange={() => {}} /> ); - expect(screen.getByDisplayValue('foo')).toBeInTheDocument(); expect(screen.queryAllByText('1')).toHaveLength(1); expect(screen.getByText('Deployed')).not.toHaveClass('euiFilterButton-hasActiveFilters'); expect(screen.getByText('Undeployed')).not.toHaveClass('euiFilterButton-hasActiveFilters'); diff --git a/public/components/model_list/__tests__/model_table.test.tsx b/public/components/model_list/__tests__/model_table.test.tsx index c2bb7e80..f0afb550 100644 --- a/public/components/model_list/__tests__/model_table.test.tsx +++ b/public/components/model_list/__tests__/model_table.test.tsx @@ -7,7 +7,7 @@ import React from 'react'; import moment from 'moment'; import userEvent from '@testing-library/user-event'; -import { ModelTable } from '../model_table'; +import { ModelTable, ModelTableProps } from '../model_table'; import { render, screen, within } from '../../../../test/test_utils'; import { MODEL_STATE } from '../../../../common/model'; @@ -32,9 +32,10 @@ const tableData = [ }, ]; -const setup = () => { +const setup = (options?: Partial) => { const onChangeMock = jest.fn(); const onModelNameClickMock = jest.fn(); + const onResetClickMock = jest.fn(); const renderResult = render( { onChange={onChangeMock} onModelNameClick={onModelNameClickMock} pagination={{ currentPage: 1, pageSize: 15, totalRecords: 300 }} + loading={false} + error={false} + onResetClick={onResetClickMock} + {...options} /> ); return { renderResult, onChangeMock, onModelNameClickMock, + onResetClickMock, }; }; @@ -152,4 +158,71 @@ describe('', () => { await userEvent.click(renderResult.getByText('model1')); expect(onModelNameClickMock).toHaveBeenCalledWith('model1'); }); + + it('should show loading screen if property loading equal true', () => { + setup({ + loading: true, + error: false, + models: [], + }); + + expect(screen.getByText('Loading models')).toBeInTheDocument(); + }); + + it('should show error screen if property error equal true', () => { + setup({ + loading: false, + error: true, + models: [], + }); + + expect(screen.getByText('Failed to load models')).toBeInTheDocument(); + }); + + it('should show no result screen if load empty data', () => { + setup({ + loading: false, + error: false, + models: [], + }); + + expect(screen.getByText('Reset search and filters')).toBeInTheDocument(); + expect( + screen.getByText( + 'There are no results for your search. Reset the search criteria to view registered models.' + ) + ).toBeInTheDocument(); + }); + + it('should show loading screen even models provided', () => { + setup({ + loading: true, + error: false, + models: tableData, + }); + + expect(screen.getByText('Loading models')).toBeInTheDocument(); + }); + + it('should show error screen even models provided', () => { + setup({ + loading: false, + error: true, + models: tableData, + }); + + expect(screen.getByText('Failed to load models')).toBeInTheDocument(); + }); + + it('should call onRestClick after reset button clicked', async () => { + const { onResetClickMock } = setup({ + loading: false, + error: false, + models: [], + }); + + expect(onResetClickMock).not.toHaveBeenCalled(); + await userEvent.click(screen.getByText('Reset search and filters')); + expect(onResetClickMock).toHaveBeenCalled(); + }); }); diff --git a/public/components/model_list/index.tsx b/public/components/model_list/index.tsx index f8d0d17f..5ed97120 100644 --- a/public/components/model_list/index.tsx +++ b/public/components/model_list/index.tsx @@ -3,19 +3,19 @@ * SPDX-License-Identifier: Apache-2.0 */ import React, { useState, useCallback, useMemo, useRef } from 'react'; -import { EuiPageHeader, EuiSpacer, EuiPanel } from '@elastic/eui'; +import { EuiPageHeader, EuiSpacer, EuiPanel, EuiTextColor } from '@elastic/eui'; import { CoreStart } from '../../../../../src/core/public'; import { APIProvider } from '../../apis/api_provider'; import { useFetcher } from '../../hooks/use_fetcher'; import { ModelDrawer } from '../model_drawer'; -import { ModelTable, ModelTableSort } from './model_table'; +import { ModelTable, ModelTableCriteria, ModelTableSort } from './model_table'; import { ModelListFilter, ModelListFilterFilterValue } from './model_list_filter'; -import { RegisterNewModelButton } from './regsister_new_model_button'; +import { RegisterNewModelButton } from './register_new_model_button'; import { ModelConfirmDeleteModal, ModelConfirmDeleteModalInstance, } from './model_confirm_delete_modal'; -import { UploadCallout } from './upload_callout'; +import { ModelListEmpty } from './model_list_empty'; export const ModelList = ({ notifications }: { notifications: CoreStart['notifications'] }) => { const confirmModelDeleteRef = useRef(null); @@ -27,12 +27,17 @@ export const ModelList = ({ notifications }: { notifications: CoreStart['notific }>({ currentPage: 1, pageSize: 15, - filterValue: { tag: [], owner: [], stage: [] }, + filterValue: { tag: [], owner: [] }, sort: { field: 'created_time', direction: 'desc' }, }); const [drawerModelName, setDrawerModelName] = useState(''); + const searchInputRef = useRef(); - const { data, reload } = useFetcher(APIProvider.getAPI('modelAggregate').search, { + const setSearchInputRef = useCallback((node: HTMLInputElement | null) => { + searchInputRef.current = node; + }, []); + + const { data, reload, loading, error } = useFetcher(APIProvider.getAPI('modelAggregate').search, { from: Math.max(0, (params.currentPage - 1) * params.pageSize), size: params.pageSize, sort: params.sort?.field, @@ -46,10 +51,11 @@ export const ModelList = ({ notifications }: { notifications: CoreStart['notific () => ({ currentPage: params.currentPage, pageSize: params.pageSize, - totalRecords: totalModelCounts, + totalRecords: totalModelCounts || 0, }), [totalModelCounts, params.currentPage, params.pageSize] ); + const showEmptyScreen = !loading && totalModelCounts === 0 && !params.filterValue.search; const handleModelDeleted = useCallback(async () => { reload(); @@ -64,50 +70,83 @@ export const ModelList = ({ notifications }: { notifications: CoreStart['notific setDrawerModelName(name); }, []); - const handleTableChange = useCallback((criteria) => { - const { - pagination: { currentPage, pageSize }, - sort, - } = criteria; + const handleTableChange = useCallback((criteria: ModelTableCriteria) => { + const { pagination: newPagination, sort } = criteria; setParams((previousValue) => { - if ( - currentPage === previousValue.currentPage && - pageSize === previousValue.pageSize && - (!sort || sort === previousValue.sort) - ) { + const criteriaConsistent = + newPagination?.currentPage === previousValue.currentPage && + newPagination?.pageSize === previousValue.pageSize && + (!sort || sort === previousValue.sort); + + if (criteriaConsistent) { return previousValue; } return { ...previousValue, - currentPage, - pageSize, + ...(newPagination + ? { currentPage: newPagination.currentPage, pageSize: newPagination.pageSize } + : {}), ...(sort ? { sort } : {}), }; }); }, []); + const handleReset = useCallback(() => { + setParams((prevParams) => ({ + ...prevParams, + filterValue: { tag: [], owner: [] }, + })); + if (searchInputRef.current) { + searchInputRef.current.value = ''; + } + }, [setParams]); + const handleFilterChange = useCallback((filterValue: ModelListFilterFilterValue) => { setParams((prevValue) => ({ ...prevValue, filterValue, currentPage: 1 })); }, []); + return ( - Models} rightSideItems={[]} /> - - - - - - + Models  + {typeof totalModelCounts === 'number' && ( + + ({totalModelCounts}) + + )} + + } + description="Discover, manage, and track machine learning models across your organization." + rightSideItems={[]} /> - - {drawerModelName && ( - setDrawerModelName('')} name={drawerModelName} /> + + {!showEmptyScreen && ( + <> + + + + + {drawerModelName && ( + setDrawerModelName('')} name={drawerModelName} /> + )} + )} + {showEmptyScreen && } ); }; diff --git a/public/components/model_list/model_list_empty.tsx b/public/components/model_list/model_list_empty.tsx new file mode 100644 index 00000000..e0512d4b --- /dev/null +++ b/public/components/model_list/model_list_empty.tsx @@ -0,0 +1,32 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import { EuiEmptyPrompt, EuiLink, EuiSpacer, EuiButtonEmpty } from '@elastic/eui'; +import React from 'react'; + +import { RegisterNewModelButton } from './register_new_model_button'; + +export const ModelListEmpty = () => { + return ( +
    + } + body="Registered models will appear here." + actions={ + <> + + + + + + Read documentation + + + + } + /> +
    + ); +}; diff --git a/public/components/model_list/model_list_filter.tsx b/public/components/model_list/model_list_filter.tsx index 2769777b..0e8616d6 100644 --- a/public/components/model_list/model_list_filter.tsx +++ b/public/components/model_list/model_list_filter.tsx @@ -6,14 +6,13 @@ import { EuiFlexItem, EuiFlexGroup, - EuiFieldSearch, EuiFilterGroup, EuiFilterButton, EuiSpacer, } from '@elastic/eui'; import React, { useCallback, useRef } from 'react'; -import { TagFilterValue, SelectedTagFiltersPanel } from '../common'; +import { TagFilterValue, SelectedTagFiltersPanel, DebouncedSearchBar } from '../common'; import { TagFilter } from './tag_filter'; import { OwnerFilter } from './owner_filter'; @@ -44,9 +43,9 @@ export interface ModelListFilterFilterValue { export const ModelListFilter = ({ value, onChange, - defaultSearch, + searchInputRef, }: { - defaultSearch?: string; + searchInputRef?: (node: HTMLInputElement | null) => void; value: Omit; onChange: (value: ModelListFilterFilterValue) => void; }) => { @@ -87,11 +86,10 @@ export const ModelListFilter = ({ <> - diff --git a/public/components/model_list/model_table.tsx b/public/components/model_list/model_table.tsx index 16fee0bc..3023214e 100644 --- a/public/components/model_list/model_table.tsx +++ b/public/components/model_list/model_table.tsx @@ -5,13 +5,19 @@ import React, { useMemo, useCallback, useRef } from 'react'; import { - CriteriaWithPagination, EuiBasicTable, EuiBasicTableColumn, EuiText, Direction, + EuiLoadingSpinner, + EuiTitle, + EuiEmptyPrompt, + EuiButton, + EuiSpacer, + EuiIcon, } from '@elastic/eui'; +import { Criteria } from '@elastic/eui'; import { renderTime } from '../../utils'; import { ModelOwner } from './model_owner'; import { ModelDeployedVersions } from './model_deployed_versions'; @@ -24,7 +30,7 @@ export interface ModelTableSort { } export interface ModelTableCriteria { - pagination: { currentPage: number; pageSize: number }; + pagination?: { currentPage: number; pageSize: number }; sort?: ModelTableSort; } @@ -38,10 +44,13 @@ export interface ModelTableProps { sort: ModelTableSort; onChange: (criteria: ModelTableCriteria) => void; onModelNameClick: (name: string) => void; + loading: boolean; + error: boolean; + onResetClick: () => void; } export function ModelTable(props: ModelTableProps) { - const { models, sort, onChange, onModelNameClick } = props; + const { models, sort, onChange, onModelNameClick, loading, onResetClick, error } = props; const onChangeRef = useRef(onChange); onChangeRef.current = onChange; @@ -160,11 +169,70 @@ export function ModelTable(props: ModelTableProps) { const sorting = useMemo(() => ({ sort }), [sort]); - const handleChange = useCallback((criteria: CriteriaWithPagination) => { - const newPagination = { currentPage: criteria.page.index + 1, pageSize: criteria.page.size }; + const noItemsMessage = useMemo( + () => ( +
    + {loading && ( + + + + + +

    + Loading models +

    +
    + + } + /> + )} + {!loading && error && ( + + + + + +

    Failed to load models

    +
    + + Check your internet connection + + } + /> + )} + {!loading && !error && ( + + + + There are no results for your search. Reset the search criteria to view registered + models. + + + } + actions={ + <> + Reset search and filters + + + } + /> + )} +
    + ), + [onResetClick, loading, error] + ); + const handleChange = useCallback((criteria: Criteria) => { onChangeRef.current({ - pagination: newPagination, + ...(criteria.page + ? { pagination: { currentPage: criteria.page.index + 1, pageSize: criteria.page.size } } + : {}), ...(criteria.sort ? { sort: criteria.sort as ModelTableSort } : {}), }); }, []); @@ -172,11 +240,12 @@ export function ModelTable(props: ModelTableProps) { return ( columns={columns} - items={models} - pagination={pagination} + items={loading || error ? [] : models} + pagination={models.length > 0 ? pagination : undefined} onChange={handleChange} sorting={sorting} hasActions + noItemsMessage={noItemsMessage} /> ); } diff --git a/public/components/model_list/regsister_new_model_button.tsx b/public/components/model_list/register_new_model_button.tsx similarity index 59% rename from public/components/model_list/regsister_new_model_button.tsx rename to public/components/model_list/register_new_model_button.tsx index ad0fa9d2..10d831a1 100644 --- a/public/components/model_list/regsister_new_model_button.tsx +++ b/public/components/model_list/register_new_model_button.tsx @@ -3,9 +3,14 @@ * SPDX-License-Identifier: Apache-2.0 */ import React, { useState, useCallback } from 'react'; -import { EuiButton } from '@elastic/eui'; +import { EuiButton, EuiButtonProps } from '@elastic/eui'; import { RegisterModelTypeModal } from '../register_model_type_modal'; -export function RegisterNewModelButton() { + +interface RegisterNewModelButtonProps { + buttonProps?: Partial; +} + +export function RegisterNewModelButton({ buttonProps }: RegisterNewModelButtonProps) { const [isModalVisible, setIsModalVisible] = useState(false); const showModal = useCallback(() => { setIsModalVisible(true); @@ -15,7 +20,9 @@ export function RegisterNewModelButton() { }, []); return ( <> - Register new model + + Register model + {isModalVisible && } ); diff --git a/public/components/monitoring/index.tsx b/public/components/monitoring/index.tsx index 798ab021..42995b11 100644 --- a/public/components/monitoring/index.tsx +++ b/public/components/monitoring/index.tsx @@ -16,6 +16,7 @@ import React, { useState, useRef, useCallback } from 'react'; import { FormattedMessage } from '@osd/i18n/react'; import { ModelDeploymentProfile } from '../../apis/profile'; +import { DebouncedSearchBar } from '../common'; import { PreviewPanel } from '../preview_panel'; import { ApplicationStart, ChromeStart } from '../../../../../src/core/public'; import { NavigationPublicPluginStart } from '../../../../../src/plugins/navigation/public'; @@ -23,7 +24,6 @@ import { NavigationPublicPluginStart } from '../../../../../src/plugins/navigati import { ModelDeploymentItem, ModelDeploymentTable } from './model_deployment_table'; import { useMonitoring } from './use_monitoring'; import { ModelStatusFilter } from './model_status_filter'; -import { SearchBar } from './search_bar'; import { ModelSourceFilter } from './model_source_filter'; import { ModelConnectorFilter } from './model_connector_filter'; import { MonitoringPageHeader } from './monitoring_page_header'; @@ -128,7 +128,12 @@ export const Monitoring = (props: MonitoringProps) => { <> - + diff --git a/public/hooks/tests/use_fetcher.test.ts b/public/hooks/tests/use_fetcher.test.tsx similarity index 66% rename from public/hooks/tests/use_fetcher.test.ts rename to public/hooks/tests/use_fetcher.test.tsx index 64e20d30..32c5a7a1 100644 --- a/public/hooks/tests/use_fetcher.test.ts +++ b/public/hooks/tests/use_fetcher.test.tsx @@ -2,15 +2,17 @@ * Copyright OpenSearch Contributors * SPDX-License-Identifier: Apache-2.0 */ +import React from 'react'; import { act, renderHook } from '@testing-library/react-hooks'; import { DO_NOT_FETCH, useFetcher } from '../use_fetcher'; +import { render, waitFor } from '../../../test/test_utils'; describe('useFetcher', () => { it('should call fetcher with consistent params and return consistent result', async () => { const data = { foo: 'bar' }; const fetcher = jest.fn((_arg1: string) => Promise.resolve(data)); - const { result, waitFor } = renderHook(() => useFetcher(fetcher, 'foo')); + const { result } = renderHook(() => useFetcher(fetcher, 'foo')); await waitFor(() => result.current.data !== null); expect(result.current.data).toBe(data); @@ -20,7 +22,7 @@ describe('useFetcher', () => { it('should call fetcher only once if params content not change', async () => { const fetcher = jest.fn((_arg1: any) => Promise.resolve()); - const { result, waitFor, rerender } = renderHook(({ params }) => useFetcher(fetcher, params), { + const { result, rerender } = renderHook(({ params }) => useFetcher(fetcher, params), { initialProps: { params: { foo: 'bar' } }, }); @@ -107,7 +109,7 @@ describe('useFetcher', () => { it('should return consistent updated data', async () => { const fetcher = () => Promise.resolve('foo'); - const { result, waitFor } = renderHook(() => useFetcher(fetcher)); + const { result } = renderHook(() => useFetcher(fetcher)); await waitFor(() => result.current.data === 'foo'); await act(async () => { @@ -120,7 +122,7 @@ describe('useFetcher', () => { it('should return consistent mutated data', async () => { const fetcher = () => Promise.resolve('foo'); - const { result, waitFor } = renderHook(() => useFetcher(fetcher)); + const { result } = renderHook(() => useFetcher(fetcher)); await waitFor(() => result.current.data === 'foo'); @@ -150,7 +152,7 @@ describe('useFetcher', () => { it('should call fetcher after first parameter changed from DO_NOT_FETCH', async () => { const fetcher = jest.fn(async (...params) => params); - const { result, rerender, waitFor } = renderHook( + const { result, rerender, waitFor: hookWaitFor } = renderHook( ({ params }) => useFetcher(fetcher, ...params), { initialProps: { @@ -163,9 +165,76 @@ describe('useFetcher', () => { expect(result.current.loading).toBe(true); expect(fetcher).toHaveBeenCalled(); - await waitFor(() => { + await hookWaitFor(() => { expect(result.current.loading).toBe(false); expect(result.current.data).toEqual([]); }); }); + + it('should return loading true immediately after params change', async () => { + const testLoadingFetcher = (payload: string) => Promise.resolve(payload); + const loadingAndParams: Array<[boolean, string]> = []; + const collectLoadingAndParams = (loading: boolean, params: string) => { + loadingAndParams.push([loading, params]); + }; + + const TestLoading = ({ + params, + onRender, + }: { + params: string; + onRender: (loading: boolean, params: string) => void; + }) => { + const { loading } = useFetcher(testLoadingFetcher, params); + onRender(loading, params); + return <>{loading.toString()}; + }; + + const { getByText, rerender } = render( + + ); + await waitFor(() => { + expect(getByText('false')).toBeInTheDocument(); + }); + rerender(); + await waitFor(() => { + expect(getByText('false')).toBeInTheDocument(); + }); + + expect(loadingAndParams).toEqual([ + [true, 'foo'], // For first rendering + [false, 'foo'], // For first data load complete + [true, 'bar'], // For params modified rendering + [false, 'bar'], // For params modified data load complete + ]); + }); + + it('should return loading true after params changed and not response', async () => { + jest.useFakeTimers(); + + const fetcher = jest.fn( + (params: string) => + new Promise((resolve) => { + setTimeout(() => { + resolve(params); + }, 1000); + }) + ); + const { result, rerender } = renderHook(({ params }) => useFetcher(fetcher, params), { + initialProps: { params: 'foo' }, + }); + expect(result.current.loading).toBe(true); + + await act(async () => { + jest.advanceTimersByTime(500); + rerender({ params: 'bar' }); + }); + + await act(async () => { + jest.advanceTimersByTime(600); + }); + expect(result.current.loading).toBe(true); + + jest.useRealTimers(); + }); }); diff --git a/public/hooks/use_fetcher.ts b/public/hooks/use_fetcher.ts index f3e5d684..55cfb25d 100644 --- a/public/hooks/use_fetcher.ts +++ b/public/hooks/use_fetcher.ts @@ -34,6 +34,7 @@ export const useFetcher = ( const paramsRef = useRef(params); paramsRef.current = params; const paramsKey = isDoNotFetch(params) ? params : JSON.stringify(params); + const lastLoadStringifyParamsRef = useRef(); const forceUpdate = useCallback(() => { setCount((prevCount) => (prevCount === Number.MAX_SAFE_INTEGER ? 0 : prevCount + 1)); @@ -41,8 +42,9 @@ export const useFetcher = ( const loadData = useCallback( async (fetcherParams: TParams, shouldUpdateResult: () => boolean = () => true) => { + const shouldUpdateLoading = loadingRef.current !== true; loadingRef.current = true; - if (usedRef.current.loading) { + if (shouldUpdateLoading && usedRef.current.loading) { forceUpdate(); } let shouldUpdate = false; @@ -61,10 +63,12 @@ export const useFetcher = ( dataRef.current = null; } } finally { - loadingRef.current = false; + if (shouldUpdate) { + loadingRef.current = false; + } if ( - usedRef.current.loading || - (shouldUpdate && (usedRef.current.data || usedRef.current.error)) + shouldUpdate && + (usedRef.current.loading || usedRef.current.data || usedRef.current.error) ) { forceUpdate(); } @@ -99,12 +103,17 @@ export const useFetcher = ( return; } let changed = false; + lastLoadStringifyParamsRef.current = paramsKey; loadData(JSON.parse(paramsKey), () => !changed); return () => { changed = true; }; }, [paramsKey, loadData]); + if (!isDoNotFetch(paramsKey) && paramsKey !== lastLoadStringifyParamsRef.current) { + loadingRef.current = true; + } + return Object.defineProperties( { data: dataRef.current, diff --git a/test/mocks/data/model_aggregate.ts b/test/mocks/data/model_aggregate.ts new file mode 100644 index 00000000..74f617f3 --- /dev/null +++ b/test/mocks/data/model_aggregate.ts @@ -0,0 +1,18 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +export const modelAggregateResponse = { + data: [ + { + name: 'traced_small_model', + deployed_versions: ['1.0.1'], + owner: 'traced_small_model', + latest_version: '1.0.5', + latest_version_state: 'DEPLOYED', + created_time: 1681887678282, + }, + ], + pagination: { currentPage: 1, pageSize: 15, totalRecords: 1, totalPages: 1 }, +}; diff --git a/test/mocks/handlers.ts b/test/mocks/handlers.ts index dcdf9c38..dc292be0 100644 --- a/test/mocks/handlers.ts +++ b/test/mocks/handlers.ts @@ -7,6 +7,8 @@ import { rest } from 'msw'; import { modelConfig } from './data/model_config'; import { modelRepositoryResponse } from './data/model_repository'; import { modelHandlers } from './model_handlers'; +import { modelAggregateResponse } from './data/model_aggregate'; +import { MODEL_AGGREGATE_API_ENDPOINT } from '../../server/routes/constants'; export const handlers = [ rest.get('/api/ml-commons/model-repository', (req, res, ctx) => { @@ -16,4 +18,7 @@ export const handlers = [ return res(ctx.status(200), ctx.json(modelConfig)); }), ...modelHandlers, + rest.get(MODEL_AGGREGATE_API_ENDPOINT, (_req, res, ctx) => { + return res(ctx.status(200), ctx.json(modelAggregateResponse)); + }), ]; From 4fab3799d58fb79e97dc5b42ecb5b3f144d36706 Mon Sep 17 00:00:00 2001 From: Lin Wang Date: Tue, 9 May 2023 17:05:07 +0800 Subject: [PATCH 43/75] Feature/add versions table in model group detail (#170) * feat: move model list tag filter to common folder Signed-off-by: Lin Wang * feat: move options filter to common folder Signed-off-by: Lin Wang * feat: add searchWidth and value generic to options filter Signed-off-by: Lin Wang * feat: add task search API Signed-off-by: Lin Wang * test: add api mocks for task Signed-off-by: Lin Wang * feat: add registerFailed state and time field for model version Signed-off-by: Lin Wang * test: add global._isJest to avoid data-grid failed in test casees Signed-off-by: Lin Wang * feat: add filter and data-grid to versions detail panel Signed-off-by: Lin Wang * test: increase timeout to 10*1000 ms for complicated render Signed-off-by: Lin Wang * test: increase timeout for test failed in github runner Signed-off-by: Lin Wang * test: increase test case timeout to 40s Signed-off-by: Lin Wang * test: update case description and add status details test Signed-off-by: Lin Wang * feat: rename model group to model Signed-off-by: Lin Wang * chore: address PR comments Signed-off-by: Lin Wang * test: increase timeout for table sort waitFor Signed-off-by: Lin Wang --------- Signed-off-by: Lin Wang --- common/model.ts | 1 + common/router.ts | 8 +- common/router_paths.ts | 2 +- public/apis/model.ts | 4 +- public/apis/task.ts | 25 +++ .../__tests__/global_breadcrumbs.test.tsx | 2 +- .../__tests__/tag_filter.test.tsx | 8 +- public/components/common/index.ts | 2 + .../{model_list => common}/tag_filter.tsx | 0 public/components/global_breadcrumbs.tsx | 8 +- .../components/model/__tests__/model.test.tsx | 69 ++++++ .../__tests__/model_overview_card.test.tsx} | 8 +- .../{model_group => model}/index.ts | 2 +- .../model_group.tsx => model/model.tsx} | 20 +- .../model_details_panel.tsx} | 2 +- .../model_overview_card.tsx} | 6 +- .../model_tags_panel.tsx} | 2 +- .../__tests__/model_version_cell.test.tsx | 121 +++++++++++ ...model_version_error_details_modal.test.tsx | 75 +++++++ .../model_version_list_filter.test.tsx | 127 +++++++++++ .../model_version_status_cell.test.tsx | 51 +++++ .../model_version_status_detail.test.tsx | 86 ++++++++ .../__tests__/model_version_table.test.tsx | 160 ++++++++++++++ .../model_version_table_row_actions.test.tsx | 92 ++++++++ .../__tests__/model_versions_panel.test.tsx | 96 +++++++++ .../model/model_versions_panel/index.ts | 6 + .../model_version_cell.tsx | 55 +++++ .../model_version_error_details_modal.tsx | 72 +++++++ .../model_version_list_filter.tsx | 128 ++++++++++++ .../model_version_status_cell.tsx | 29 +++ .../model_version_status_detail.tsx | 197 ++++++++++++++++++ .../model_version_table.tsx | 112 ++++++++++ .../model_version_table_row_actions.tsx | 96 +++++++++ .../model_versions_panel.tsx | 171 +++++++++++++++ public/components/model/types.ts | 16 ++ .../__tests__/model_group.test.tsx | 61 ------ .../model_group_versions_panel.tsx | 17 -- .../__tests__/model_filter.test.tsx | 148 ------------- .../__tests__/model_filter_item.test.tsx | 34 --- .../__tests__/owner_filter.test.tsx | 5 +- public/components/model_list/model_filter.tsx | 107 ---------- .../model_list/model_filter_item.tsx | 24 --- .../model_list/model_list_filter.tsx | 3 +- public/components/model_list/owner_filter.tsx | 9 +- .../__tests__/register_model_form.test.tsx | 2 +- .../__tests__/register_model_tags.test.tsx | 2 +- .../register_model/register_model.tsx | 2 +- server/routes/model_router.ts | 9 +- server/routes/task_router.ts | 43 ++++ server/services/task_service.ts | 53 +++++ test/jest.config.js | 4 + test/mocks/handlers.ts | 6 +- test/mocks/task_handlers.ts | 44 ++++ 53 files changed, 1990 insertions(+), 442 deletions(-) rename public/components/{model_list => common}/__tests__/tag_filter.test.tsx (93%) rename public/components/{model_list => common}/tag_filter.tsx (100%) create mode 100644 public/components/model/__tests__/model.test.tsx rename public/components/{model_group/__tests__/model_group_overview_card.test.tsx => model/__tests__/model_overview_card.test.tsx} (83%) rename public/components/{model_group => model}/index.ts (73%) rename public/components/{model_group/model_group.tsx => model/model.tsx} (82%) rename public/components/{model_group/model_group_details_panel.tsx => model/model_details_panel.tsx} (89%) rename public/components/{model_group/model_group_overview_card.tsx => model/model_overview_card.tsx} (94%) rename public/components/{model_group/model_group_tags_panel.tsx => model/model_tags_panel.tsx} (90%) create mode 100644 public/components/model/model_versions_panel/__tests__/model_version_cell.test.tsx create mode 100644 public/components/model/model_versions_panel/__tests__/model_version_error_details_modal.test.tsx create mode 100644 public/components/model/model_versions_panel/__tests__/model_version_list_filter.test.tsx create mode 100644 public/components/model/model_versions_panel/__tests__/model_version_status_cell.test.tsx create mode 100644 public/components/model/model_versions_panel/__tests__/model_version_status_detail.test.tsx create mode 100644 public/components/model/model_versions_panel/__tests__/model_version_table.test.tsx create mode 100644 public/components/model/model_versions_panel/__tests__/model_version_table_row_actions.test.tsx create mode 100644 public/components/model/model_versions_panel/__tests__/model_versions_panel.test.tsx create mode 100644 public/components/model/model_versions_panel/index.ts create mode 100644 public/components/model/model_versions_panel/model_version_cell.tsx create mode 100644 public/components/model/model_versions_panel/model_version_error_details_modal.tsx create mode 100644 public/components/model/model_versions_panel/model_version_list_filter.tsx create mode 100644 public/components/model/model_versions_panel/model_version_status_cell.tsx create mode 100644 public/components/model/model_versions_panel/model_version_status_detail.tsx create mode 100644 public/components/model/model_versions_panel/model_version_table.tsx create mode 100644 public/components/model/model_versions_panel/model_version_table_row_actions.tsx create mode 100644 public/components/model/model_versions_panel/model_versions_panel.tsx create mode 100644 public/components/model/types.ts delete mode 100644 public/components/model_group/__tests__/model_group.test.tsx delete mode 100644 public/components/model_group/model_group_versions_panel.tsx delete mode 100644 public/components/model_list/__tests__/model_filter.test.tsx delete mode 100644 public/components/model_list/__tests__/model_filter_item.test.tsx delete mode 100644 public/components/model_list/model_filter.tsx delete mode 100644 public/components/model_list/model_filter_item.tsx create mode 100644 test/mocks/task_handlers.ts diff --git a/common/model.ts b/common/model.ts index 590dd497..f81f1e95 100644 --- a/common/model.ts +++ b/common/model.ts @@ -13,6 +13,7 @@ export enum MODEL_STATE { loading = 'DEPLOYING', partiallyLoaded = 'PARTIALLY_DEPLOYED', loadFailed = 'DEPLOY_FAILED', + registerFailed = 'REGISTER_FAILED', } export interface OpenSearchModelBase { diff --git a/common/router.ts b/common/router.ts index 0022f445..e3bc2c8f 100644 --- a/common/router.ts +++ b/common/router.ts @@ -3,7 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { ModelGroup } from '../public/components/model_group'; +import { Model } from '../public/components/model'; import { ModelList } from '../public/components/model_list'; import { Monitoring } from '../public/components/monitoring'; import { RegisterModelForm } from '../public/components/register_model/register_model'; @@ -41,10 +41,10 @@ export const ROUTES: RouteConfig[] = [ nav: true, }, { - path: routerPaths.modelGroup, + path: routerPaths.model, // TODO: refactor label to be dynamic so that we can display group name in breadcrumb - label: 'Model Group', - Component: ModelGroup, + label: 'Model', + Component: Model, nav: false, }, { diff --git a/common/router_paths.ts b/common/router_paths.ts index 168806a0..706377d5 100644 --- a/common/router_paths.ts +++ b/common/router_paths.ts @@ -9,6 +9,6 @@ export const routerPaths = { monitoring: '/monitoring', registerModel: '/model-registry/register-model/:id?', modelList: '/model-registry/model-list', - modelGroup: '/model-registry/model/:id', + model: '/model-registry/model/:id', modelVersion: '/model-registry/model-version/:id', }; diff --git a/public/apis/model.ts b/public/apis/model.ts index 4d63b238..7a6742a2 100644 --- a/public/apis/model.ts +++ b/public/apis/model.ts @@ -17,7 +17,7 @@ export interface ModelSearchItem { id: string; name: string; algorithm: string; - model_state: string; + model_state: MODEL_STATE; model_version: string; current_worker_node_count: number; planning_worker_node_count: number; @@ -33,6 +33,8 @@ export interface ModelSearchItem { framework_type: string; model_type: string; }; + last_updated_time: number; + created_time: number; } export interface ModelDetail extends ModelSearchItem { diff --git a/public/apis/task.ts b/public/apis/task.ts index e0a42014..05dea325 100644 --- a/public/apis/task.ts +++ b/public/apis/task.ts @@ -6,6 +6,8 @@ import { TASK_API_ENDPOINT } from '../../server/routes/constants'; import { InnerHttpProvider } from './inner_http_provider'; +type TaskSearchSortItem = 'last_update_time-desc' | 'last_update_time-asc'; + export interface TaskGetOneResponse { error?: string; last_update_time: number; @@ -18,8 +20,31 @@ export interface TaskGetOneResponse { worker_node: string[]; } +export interface TaskSearchResponse { + data: TaskGetOneResponse[]; + total_tasks: number; +} + export class Task { public getOne(taskId: string) { return InnerHttpProvider.getHttp().get(`${TASK_API_ENDPOINT}/${taskId}`); } + + public search(query: { + from: number; + size: number; + modelId?: string; + taskType?: string; + state?: string; + sort?: TaskSearchSortItem | [TaskSearchSortItem]; + }) { + const { modelId, taskType, ...restQuery } = query; + return InnerHttpProvider.getHttp().get(TASK_API_ENDPOINT, { + query: { + ...restQuery, + model_id: modelId, + task_type: taskType, + }, + }); + } } diff --git a/public/components/__tests__/global_breadcrumbs.test.tsx b/public/components/__tests__/global_breadcrumbs.test.tsx index ef5cfff9..e2fe5cb8 100644 --- a/public/components/__tests__/global_breadcrumbs.test.tsx +++ b/public/components/__tests__/global_breadcrumbs.test.tsx @@ -56,7 +56,7 @@ describe('', () => { }); }); - it('should call onBreadcrumbsChange with model group breadcrumbs', async () => { + it('should call onBreadcrumbsChange with model breadcrumbs', async () => { const onBreadcrumbsChange = jest.fn(); render(, { route: '/model-registry/model/1', diff --git a/public/components/model_list/__tests__/tag_filter.test.tsx b/public/components/common/__tests__/tag_filter.test.tsx similarity index 93% rename from public/components/model_list/__tests__/tag_filter.test.tsx rename to public/components/common/__tests__/tag_filter.test.tsx index 5b693ea5..bdfafe13 100644 --- a/public/components/model_list/__tests__/tag_filter.test.tsx +++ b/public/components/common/__tests__/tag_filter.test.tsx @@ -18,7 +18,7 @@ describe('', () => { it( 'should call onChange when applying tag filter', async () => { - const user = userEvent.setup({ advanceTimers: jest.advanceTimersByTime }); + const user = userEvent.setup(); const onChangeMock = jest.fn(); render( ', () => { ]); }, // There are too many operations, need to increase timeout - 10 * 1000 + 20 * 1000 ); it('should render an empty tag list if no tags', async () => { @@ -65,7 +65,7 @@ describe('', () => { }); it('should render loading screen when tags are loading', async () => { - const user = userEvent.setup({ advanceTimers: jest.advanceTimersByTime }); + const user = userEvent.setup(); const { rerender } = render( @@ -89,7 +89,7 @@ describe('', () => { it( 'should reset input after popover re-open', async () => { - const user = userEvent.setup({ advanceTimers: jest.advanceTimersByTime }); + const user = userEvent.setup(); const onChangeMock = jest.fn(); render( { [ { text: model.name, - href: joinUrl(basename, generatePath(routerPaths.modelGroup, { id: modelId })), + href: joinUrl(basename, generatePath(routerPaths.model, { id: modelId })), }, { text: 'Register version', @@ -73,7 +73,7 @@ const getModelRegisterBreadcrumbs = (basename: string, matchedParams: {}) => { }; }; -const getModelGroupBreadcrumbs = (basename: string, matchedParams: {}) => { +const getModelBreadcrumbs = (basename: string, matchedParams: {}) => { const baseModelRegistryBreadcrumbs = getBasicModelRegistryBreadcrumbs(basename); if ('id' in matchedParams && typeof matchedParams.id === 'string') { const modelId = matchedParams.id; @@ -115,7 +115,7 @@ const getModelVersionBreadcrumbs = (basename: string, matchedParams: {}) => { { text: model.name, // TODO: Change to use model group id - href: joinUrl(basename, generatePath(routerPaths.modelGroup, { id: modelId })), + href: joinUrl(basename, generatePath(routerPaths.model, { id: modelId })), }, { text: `Version ${model.model_version}`, @@ -131,7 +131,7 @@ const getModelVersionBreadcrumbs = (basename: string, matchedParams: {}) => { const routerPathBreadcrumbsMap = { [routerPaths.registerModel]: getModelRegisterBreadcrumbs, - [routerPaths.modelGroup]: getModelGroupBreadcrumbs, + [routerPaths.model]: getModelBreadcrumbs, [routerPaths.modelVersion]: getModelVersionBreadcrumbs, }; diff --git a/public/components/model/__tests__/model.test.tsx b/public/components/model/__tests__/model.test.tsx new file mode 100644 index 00000000..ab743271 --- /dev/null +++ b/public/components/model/__tests__/model.test.tsx @@ -0,0 +1,69 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React from 'react'; +import userEvent from '@testing-library/user-event'; + +import { render, screen, waitFor, within } from '../../../../test/test_utils'; +import { Model } from '../model'; +import { routerPaths } from '../../../../common/router_paths'; +import { Route, generatePath } from 'react-router-dom'; + +const setup = () => { + const renderResult = render( + + + , + { route: generatePath(routerPaths.model, { id: '1' }) } + ); + + return { + renderResult, + }; +}; + +describe('', () => { + it( + 'should display model name, action buttons, overview-card, tabs and tabpanel after data loaded', + async () => { + setup(); + + await waitFor(() => { + expect(screen.queryByTestId('model-group-loading-indicator')).toBeNull(); + }); + expect(screen.getByText('model1')).toBeInTheDocument(); + expect(screen.getByText('Delete')).toBeInTheDocument(); + expect(screen.getByText('Register version')).toBeInTheDocument(); + expect(screen.queryByTestId('model-group-overview-card')).toBeInTheDocument(); + expect(screen.queryByTestId('model-group-overview-card')).toBeInTheDocument(); + expect(screen.queryByTestId('model-group-overview-card')).toBeInTheDocument(); + expect(screen.getByRole('tab', { name: 'Versions' })).toHaveClass('euiTab-isSelected'); + expect(within(screen.getByRole('tabpanel')).getByText('Versions')).toBeInTheDocument(); + }, + 10 * 1000 + ); + + it( + 'should display consistent tabs content after tab clicked', + async () => { + setup(); + + await waitFor(() => { + expect(screen.queryByTestId('model-group-loading-indicator')).toBeNull(); + }); + expect(screen.getByRole('tab', { name: 'Versions' })).toHaveClass('euiTab-isSelected'); + expect(within(screen.getByRole('tabpanel')).getByText('Versions')).toBeInTheDocument(); + + await userEvent.click(screen.getByRole('tab', { name: 'Details' })); + expect(screen.getByRole('tab', { name: 'Details' })).toHaveClass('euiTab-isSelected'); + expect(within(screen.getByRole('tabpanel')).getByText('Details')).toBeInTheDocument(); + + await userEvent.click(screen.getByRole('tab', { name: 'Tags' })); + expect(screen.getByRole('tab', { name: 'Tags' })).toHaveClass('euiTab-isSelected'); + expect(within(screen.getByRole('tabpanel')).getByText('Tags')).toBeInTheDocument(); + }, + 10 * 1000 + ); +}); diff --git a/public/components/model_group/__tests__/model_group_overview_card.test.tsx b/public/components/model/__tests__/model_overview_card.test.tsx similarity index 83% rename from public/components/model_group/__tests__/model_group_overview_card.test.tsx rename to public/components/model/__tests__/model_overview_card.test.tsx index addc35ab..56809433 100644 --- a/public/components/model_group/__tests__/model_group_overview_card.test.tsx +++ b/public/components/model/__tests__/model_overview_card.test.tsx @@ -6,12 +6,12 @@ import React from 'react'; import { render, screen, within } from '../../../../test/test_utils'; -import { ModelGroupOverviewCard } from '../model_group_overview_card'; +import { ModelOverviewCard } from '../model_overview_card'; -describe('', () => { - it('should model group overview information according passed data', () => { +describe('', () => { + it('should model overview information according passed data', () => { render( - { +export const Model = () => { const { id: modelId } = useParams<{ id: string }>(); const { data, loading, error } = useFetcher(APIProvider.getAPI('model').getOne, modelId); const tabs = useMemo( @@ -32,7 +32,7 @@ export const ModelGroup = () => { content: ( <> - + ), }, @@ -42,7 +42,7 @@ export const ModelGroup = () => { content: ( <> - + ), }, @@ -52,12 +52,12 @@ export const ModelGroup = () => { content: ( <> - + ), }, ], - [] + [modelId] ); const [selectedTab, setSelectedTab] = useState(tabs[0]); @@ -84,7 +84,7 @@ export const ModelGroup = () => { Delete, ]} /> - { +export const ModelDetailsPanel = () => { return ( diff --git a/public/components/model_group/model_group_overview_card.tsx b/public/components/model/model_overview_card.tsx similarity index 94% rename from public/components/model_group/model_group_overview_card.tsx rename to public/components/model/model_overview_card.tsx index 48b054e8..ddb76a7f 100644 --- a/public/components/model_group/model_group_overview_card.tsx +++ b/public/components/model/model_overview_card.tsx @@ -8,7 +8,7 @@ import React from 'react'; import { CopyableText } from '../common'; import { renderTime } from '../../utils'; -interface ModelGroupOverviewCardProps { +interface ModelOverviewCardProps { id: string; description?: string; owner: string; @@ -17,14 +17,14 @@ interface ModelGroupOverviewCardProps { updatedTime: number; } -export const ModelGroupOverviewCard = ({ +export const ModelOverviewCard = ({ id, owner, createdTime, updatedTime, description, isModelOwner, -}: ModelGroupOverviewCardProps) => { +}: ModelOverviewCardProps) => { return ( diff --git a/public/components/model_group/model_group_tags_panel.tsx b/public/components/model/model_tags_panel.tsx similarity index 90% rename from public/components/model_group/model_group_tags_panel.tsx rename to public/components/model/model_tags_panel.tsx index 735558ad..e4d080a4 100644 --- a/public/components/model_group/model_group_tags_panel.tsx +++ b/public/components/model/model_tags_panel.tsx @@ -6,7 +6,7 @@ import React from 'react'; import { EuiHorizontalRule, EuiPanel, EuiSpacer, EuiTitle } from '@elastic/eui'; -export const ModelGroupTagsPanel = () => { +export const ModelTagsPanel = () => { return ( diff --git a/public/components/model/model_versions_panel/__tests__/model_version_cell.test.tsx b/public/components/model/model_versions_panel/__tests__/model_version_cell.test.tsx new file mode 100644 index 00000000..4c069b32 --- /dev/null +++ b/public/components/model/model_versions_panel/__tests__/model_version_cell.test.tsx @@ -0,0 +1,121 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React from 'react'; + +import { render, screen } from '../../../../../test/test_utils'; +import { ModelVersionCell } from '../model_version_cell'; +import { MODEL_STATE } from '../../../../../common'; + +const setup = (options: { columnId: string; isDetails?: boolean }) => + render( + + ); + +describe('', () => { + it('should render consistent version', () => { + setup({ + columnId: 'version', + }); + + expect(screen.getByText('1.0.0')).toBeInTheDocument(); + }); + + it('should render consistent deploy state', () => { + const { rerender } = setup({ + columnId: 'state', + }); + + expect(screen.getByText('Not deployed')).toBeInTheDocument(); + + rerender( + + ); + expect(screen.getByText('Deployed')).toBeInTheDocument(); + + rerender( + + ); + expect(screen.getByText('Deployed')).toBeInTheDocument(); + }); + + it('should render consistent status', () => { + setup({ + columnId: 'status', + }); + + expect(screen.getByText('In progress...')).toBeInTheDocument(); + }); + + it('should render status details', () => { + setup({ + columnId: 'status', + isDetails: true, + }); + + expect(screen.getByText('In progress...')).toBeInTheDocument(); + expect(screen.getByText('Upload initiated on:')).toBeInTheDocument(); + }); + + it('should render consistent last updated', () => { + setup({ + columnId: 'lastUpdated', + }); + + expect(screen.getByText('Apr 27, 2023 2:15 PM')).toBeInTheDocument(); + }); + + it('should render "model-1" for name column', () => { + setup({ + columnId: 'name', + }); + + expect(screen.getByText('model-1')).toBeInTheDocument(); + }); + + it('should render "-" for unknown columId', () => { + setup({ + columnId: 'unknown', + }); + + expect(screen.getByText('-')).toBeInTheDocument(); + }); +}); diff --git a/public/components/model/model_versions_panel/__tests__/model_version_error_details_modal.test.tsx b/public/components/model/model_versions_panel/__tests__/model_version_error_details_modal.test.tsx new file mode 100644 index 00000000..b1c1d21e --- /dev/null +++ b/public/components/model/model_versions_panel/__tests__/model_version_error_details_modal.test.tsx @@ -0,0 +1,75 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React from 'react'; +import userEvent from '@testing-library/user-event'; + +import { render, screen } from '../../../../../test/test_utils'; +import { ModelVersionErrorDetailsModal } from '../model_version_error_details_modal'; + +describe('', () => { + it('should render model artifact upload failed screen', () => { + render( + + ); + + expect(screen.getByText('model-1-name version 3')).toBeInTheDocument(); + expect(screen.getByText('artifact upload failed')).toBeInTheDocument(); + expect(screen.getByText('model-1-name version 3')).toHaveAttribute( + 'href', + '/model-registry/model-version/model-1-id' + ); + expect(screen.getByText('Error message')).toBeInTheDocument(); + }); + + it('should render deployment failed screen', () => { + render( + + ); + + expect(screen.getByText('model-1-name version 3')).toBeInTheDocument(); + expect(screen.getByText('deployment failed')).toBeInTheDocument(); + expect(screen.getByText('model-1-name version 3')).toHaveAttribute( + 'href', + '/model-registry/model-version/model-1-id' + ); + expect(screen.getByText('{"foo": "bar"}')).toBeInTheDocument(); + expect(screen.getByLabelText('Copy')).toBeInTheDocument(); + }); + + it('should call closeModal after Close button clicked', async () => { + const closeModalMock = jest.fn(); + render( + + ); + + expect(closeModalMock).not.toHaveBeenCalled(); + await userEvent.click(screen.getByText('Close')); + expect(closeModalMock).toHaveBeenCalledTimes(1); + + await userEvent.click(screen.getByLabelText('Closes this modal window')); + expect(closeModalMock).toHaveBeenCalledTimes(2); + }); +}); diff --git a/public/components/model/model_versions_panel/__tests__/model_version_list_filter.test.tsx b/public/components/model/model_versions_panel/__tests__/model_version_list_filter.test.tsx new file mode 100644 index 00000000..2a8cb482 --- /dev/null +++ b/public/components/model/model_versions_panel/__tests__/model_version_list_filter.test.tsx @@ -0,0 +1,127 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React from 'react'; +import userEvent from '@testing-library/user-event'; + +import { act, render, screen, within } from '../../../../../test/test_utils'; +import { ModelVersionListFilter } from '../model_version_list_filter'; +import { TagFilterOperator } from '../../../common'; + +describe('', () => { + it('should render default search bar, state, status and Add tag filter', () => { + render( + + ); + expect(screen.getByPlaceholderText('Search by version number, or keyword')).toBeInTheDocument(); + + expect(screen.getByText('State')).toBeInTheDocument(); + expect(screen.getByText('Status')).toBeInTheDocument(); + expect(screen.getByText('Add tag filter')).toBeInTheDocument(); + }); + + it('should render activate filter count and tags after value provided', () => { + render( + + ); + + expect(within(screen.getByText('State').parentElement!).getByText('1')).toBeInTheDocument(); + expect(within(screen.getByText('Status').parentElement!).getByText('2')).toBeInTheDocument(); + expect(screen.getByTitle('NOT tag1: 123')).toBeInTheDocument(); + }); + + it('should call onChangeMock with new state after state filter changed', async () => { + const onChangeMock = jest.fn(); + const user = userEvent.setup(); + render( + + ); + + await user.click(screen.getByText('State')); + + expect(onChangeMock).not.toHaveBeenCalled(); + await user.click(screen.getByText('Deployed')); + expect(onChangeMock).toHaveBeenCalledWith(expect.objectContaining({ state: ['Deployed'] })); + }); + + it('should call onChangeMock with new status after status filter changed', async () => { + const onChangeMock = jest.fn(); + const user = userEvent.setup(); + render( + + ); + + await user.click(screen.getByText('Status')); + + expect(onChangeMock).not.toHaveBeenCalled(); + await user.click(screen.getByText('Success')); + expect(onChangeMock).toHaveBeenCalledWith(expect.objectContaining({ status: ['Success'] })); + }); + + it( + 'should call onChangeMock with new tag after tag filter changed', + async () => { + jest.useFakeTimers(); + const onChangeMock = jest.fn(); + const user = userEvent.setup({ advanceTimers: jest.advanceTimersByTime }); + render( + + ); + + act(() => { + jest.advanceTimersByTime(1000); + }); + + await user.click(screen.getByTitle('Accuracy: test: Image classification')); + + await user.click(screen.getByText('Image classification')); + await user.click(screen.getByRole('option', { name: 'Computer vision' })); + await user.click(screen.getByText('Save')); + + expect(onChangeMock).toHaveBeenCalledWith( + expect.objectContaining({ + tag: [ + { + name: 'Accuracy: test', + operator: TagFilterOperator.Is, + value: 'Computer vision', + type: 'string', + }, + ], + }) + ); + + jest.useRealTimers(); + }, + 10 * 1000 + ); +}); diff --git a/public/components/model/model_versions_panel/__tests__/model_version_status_cell.test.tsx b/public/components/model/model_versions_panel/__tests__/model_version_status_cell.test.tsx new file mode 100644 index 00000000..735bd371 --- /dev/null +++ b/public/components/model/model_versions_panel/__tests__/model_version_status_cell.test.tsx @@ -0,0 +1,51 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React from 'react'; + +import { render, screen } from '../../../../../test/test_utils'; +import { ModelVersionStatusCell } from '../model_version_status_cell'; +import { MODEL_STATE } from '../../../../../common'; + +describe('', () => { + it('should display "-" if unsupported state provided', async () => { + render(); + + expect(screen.getByText('-')).toBeInTheDocument(); + }); + + it('should display "In progress..." when state is "uploading" or "loading"', async () => { + const { rerender } = render(); + + expect(screen.getByText('In progress...')).toBeInTheDocument(); + + rerender(); + expect(screen.getByText('In progress...')).toBeInTheDocument(); + }); + + it('should display "Success" when state is "uploaded" or "loaded"', async () => { + const { rerender } = render(); + + expect(screen.getByText('Success')).toBeInTheDocument(); + + rerender(); + expect(screen.getByText('Success')).toBeInTheDocument(); + }); + + it('should display "Error" when state is "registerFailed" or "loadedFailed"', async () => { + const { rerender } = render(); + + expect(screen.getByText('Error')).toBeInTheDocument(); + + rerender(); + expect(screen.getByText('Error')).toBeInTheDocument(); + }); + + it('should display "Warning" when state is "partialLoaded"', async () => { + render(); + + expect(screen.getByText('Warning')).toBeInTheDocument(); + }); +}); diff --git a/public/components/model/model_versions_panel/__tests__/model_version_status_detail.test.tsx b/public/components/model/model_versions_panel/__tests__/model_version_status_detail.test.tsx new file mode 100644 index 00000000..f0d92ba3 --- /dev/null +++ b/public/components/model/model_versions_panel/__tests__/model_version_status_detail.test.tsx @@ -0,0 +1,86 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React from 'react'; +import userEvent from '@testing-library/user-event'; + +import { render, screen, waitFor } from '../../../../../test/test_utils'; +import { ModelVersionStatusDetail } from '../model_version_status_detail'; +import { MODEL_STATE } from '../../../../../common'; + +describe('', () => { + it('should render "In progress...", uploading tip and upload initialized time ', async () => { + render( + + ); + + expect(screen.getByText('In progress...')).toBeInTheDocument(); + expect(screen.getByText(/The model artifact for.*is uploading./)).toBeInTheDocument(); + expect(screen.getByText('model-1 version 1')).toHaveAttribute( + 'href', + '/model-registry/model-version/1' + ); + + expect(screen.getByText('Upload initiated on:')).toBeInTheDocument(); + expect(screen.getByText('May 5, 2023 @ 08:52:53.541')).toBeInTheDocument(); + }); + + it('should render "-" if state not supported', async () => { + render( + + ); + + expect(screen.getByText('-')).toBeInTheDocument(); + }); + + it('should render "See full error" button for "loadFailed" state', async () => { + render( + + ); + + expect(screen.getByText('Error')).toBeInTheDocument(); + expect(screen.getByText(/.*deployment failed./)).toBeInTheDocument(); + expect(screen.getByText('Deployment failed on:')).toBeInTheDocument(); + expect(screen.getByText('See full error')).toBeInTheDocument(); + }); + + it('should display error detail after "See full error" button clicked', async () => { + const user = userEvent.setup(); + render( + + ); + + expect(screen.getByText('Error')).toBeInTheDocument(); + await user.hover(screen.getByText('Error')); + await user.click(screen.getByText('See full error')); + await waitFor(() => { + expect(screen.getByText('The artifact url is in valid')).toBeInTheDocument(); + }); + }); +}); diff --git a/public/components/model/model_versions_panel/__tests__/model_version_table.test.tsx b/public/components/model/model_versions_panel/__tests__/model_version_table.test.tsx new file mode 100644 index 00000000..da7a2e02 --- /dev/null +++ b/public/components/model/model_versions_panel/__tests__/model_version_table.test.tsx @@ -0,0 +1,160 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React from 'react'; +import userEvent from '@testing-library/user-event'; + +import { render, screen, waitFor } from '../../../../../test/test_utils'; +import { ModelVersionTable } from '../model_version_table'; +import { MODEL_STATE } from '../../../../../common'; +import { within } from '@testing-library/dom'; + +const versions = [ + { + id: '1', + name: 'model-1', + version: '1.0.0', + state: MODEL_STATE.uploading, + tags: { 'Accuracy: test': 0.98, 'Accuracy: train': 0.99 }, + lastUpdated: 1682676759143, + createdTime: 1682676759143, + }, +]; + +describe('', () => { + it('should render consistent columns header ', async () => { + render(); + + await waitFor(() => { + expect(screen.getByTestId('dataGridHeaderCell-version')).toBeInTheDocument(); + expect(screen.getByTestId('dataGridHeaderCell-state')).toBeInTheDocument(); + expect(screen.getByTestId('dataGridHeaderCell-status')).toBeInTheDocument(); + expect(screen.getByTestId('dataGridHeaderCell-lastUpdated')).toBeInTheDocument(); + expect(screen.getByTestId('dataGridHeaderCell-tags.Accuracy: test')).toBeInTheDocument(); + expect(screen.getByTestId('dataGridHeaderCell-tags.Accuracy: train')).toBeInTheDocument(); + }); + }); + + it( + 'should render sorted column and call onSort after sort change', + async () => { + const user = userEvent.setup(); + const onSortMock = jest.fn(); + render( + + ); + + await waitFor( + async () => { + expect(screen.getByTestId('dataGridHeaderCellSortingIcon-version')).toBeInTheDocument(); + }, + { + timeout: 2000, + } + ); + await user.click(screen.getByText('Version')); + await waitFor(async () => { + expect(screen.getByText('Sort A-Z').closest('li')).toHaveClass( + 'euiDataGridHeader__action--selected' + ); + }); + + expect(onSortMock).not.toHaveBeenCalled(); + await user.click(screen.getByText('Sort Z-A')); + expect(onSortMock).toHaveBeenCalledWith([{ direction: 'desc', id: 'version' }]); + }, + 40 * 1000 + ); + + it( + 'should NOT render sort button for state and status column', + async () => { + const user = userEvent.setup(); + render(); + + await user.click(screen.getByText('State')); + expect(screen.queryByTitle('Sort A-Z')).toBeNull(); + + await user.click(screen.getByText('Status')); + expect(screen.queryByTitle('Sort A-Z')).toBeNull(); + }, + 20 * 1000 + ); + + it('should render consistent versions values', () => { + render(); + + const gridCells = screen.getAllByRole('gridcell'); + expect(gridCells.length).toBe(7); + expect(within(gridCells[0]).getByText('1.0.0')).toBeInTheDocument(); + expect(within(gridCells[1]).getByText('Not deployed')).toBeInTheDocument(); + expect(within(gridCells[2]).getByText('In progress...')).toBeInTheDocument(); + expect(within(gridCells[3]).getByText('Apr 28, 2023 10:12 AM')).toBeInTheDocument(); + expect(within(gridCells[4]).getByText('0.98')).toBeInTheDocument(); + expect(within(gridCells[5]).getByText('0.99')).toBeInTheDocument(); + expect(within(gridCells[6]).getByLabelText('show actions')).toBeInTheDocument(); + }); + + it( + 'should render pagination and call onChangePageMock and onChangeItemsPerPageMock if pagination provided', + async () => { + const user = userEvent.setup(); + const onChangePageMock = jest.fn(); + const onChangeItemsPerPageMock = jest.fn(); + render( + + ); + + expect(screen.getByText('Rows per page: 25')).toBeInTheDocument(); + expect(screen.getByLabelText('Page 1 of 5')).toHaveClass('euiPaginationButton-isActive'); + + await user.click(screen.getByText('Rows per page: 25')); + + expect(onChangeItemsPerPageMock).not.toHaveBeenCalled(); + await user.click(screen.getByText('10 rows')); + expect(onChangeItemsPerPageMock).toHaveBeenCalledWith(10); + await user.click(screen.getByText('Rows per page: 25')); + + expect(onChangePageMock).not.toHaveBeenCalled(); + await user.click(screen.getByLabelText('Page 2 of 5')); + expect(onChangePageMock).toHaveBeenCalledWith(1); + }, + 20 * 1000 + ); + + it( + 'should show status details after status cell expand button clicked', + async () => { + const user = userEvent.setup(); + render(); + + await user.hover(screen.getByText('In progress...')); + await user.click( + screen.getByText('In progress...').closest('div[role="gridcell"]')!.querySelector('button')! + ); + + expect(screen.getByText(/The model artifact for.*is uploading./)).toBeInTheDocument(); + }, + 10 * 1000 + ); +}); diff --git a/public/components/model/model_versions_panel/__tests__/model_version_table_row_actions.test.tsx b/public/components/model/model_versions_panel/__tests__/model_version_table_row_actions.test.tsx new file mode 100644 index 00000000..7cc0848b --- /dev/null +++ b/public/components/model/model_versions_panel/__tests__/model_version_table_row_actions.test.tsx @@ -0,0 +1,92 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React from 'react'; +import userEvent from '@testing-library/user-event'; +import * as euiExports from '@elastic/eui'; + +import { render, screen, waitFor } from '../../../../../test/test_utils'; +import { ModelVersionTableRowActions } from '../model_version_table_row_actions'; +import { MODEL_STATE } from '../../../../../common'; + +jest.mock('@elastic/eui', () => { + return { + __esModule: true, + ...jest.requireActual('@elastic/eui'), + }; +}); + +describe('', () => { + it('should render actions icon and "Copy ID" and "Delete" button after clicked', async () => { + const user = userEvent.setup(); + render(); + + expect(screen.getByLabelText('show actions')).toBeInTheDocument(); + await user.click(screen.getByLabelText('show actions')); + + expect(screen.getByText('Copy ID')).toBeInTheDocument(); + expect(screen.getByText('Delete')).toBeInTheDocument(); + }); + + it('should render "Upload new artifact" button for REGISTER_FAILED state', async () => { + const user = userEvent.setup(); + render(); + await user.click(screen.getByLabelText('show actions')); + + expect(screen.getByText('Upload new artifact')).toBeInTheDocument(); + }); + + it('should render "Deploy" button for REGISTERED and UNDEPLOYED state', async () => { + const user = userEvent.setup(); + const { rerender } = render( + + ); + await user.click(screen.getByLabelText('show actions')); + + expect(screen.getByText('Deploy')).toBeInTheDocument(); + + rerender(); + expect(screen.getByText('Deploy')).toBeInTheDocument(); + }); + + it('should render "Undeploy" button for DEPLOYED and PARTIALLY_DEPLOYED state', async () => { + const user = userEvent.setup(); + const { rerender } = render(); + await user.click(screen.getByLabelText('show actions')); + + expect(screen.getByText('Undeploy')).toBeInTheDocument(); + + rerender(); + expect(screen.getByText('Undeploy')).toBeInTheDocument(); + }); + + it('should call close popover after menuitem click', async () => { + const user = userEvent.setup(); + render(); + + await user.click(screen.getByLabelText('show actions')); + await user.click(screen.getByText('Delete')); + + await waitFor(() => { + expect(screen.queryByText('Delete')).toBeNull(); + }); + }); + + it('should call copyToClipboard with "1" after "Copy ID" button clicked', async () => { + const copyToClipboardMock = jest + .spyOn(euiExports, 'copyToClipboard') + .mockImplementation(jest.fn()); + const user = userEvent.setup(); + render(); + + await user.click(screen.getByLabelText('show actions')); + + expect(copyToClipboardMock).not.toHaveBeenCalled(); + await user.click(screen.getByText('Copy ID')); + expect(copyToClipboardMock).toHaveBeenCalledWith('1'); + + copyToClipboardMock.mockRestore(); + }); +}); diff --git a/public/components/model/model_versions_panel/__tests__/model_versions_panel.test.tsx b/public/components/model/model_versions_panel/__tests__/model_versions_panel.test.tsx new file mode 100644 index 00000000..4449ef4f --- /dev/null +++ b/public/components/model/model_versions_panel/__tests__/model_versions_panel.test.tsx @@ -0,0 +1,96 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React from 'react'; +import userEvent from '@testing-library/user-event'; + +import { render, screen, waitFor, within } from '../../../../../test/test_utils'; +import { Model } from '../../../../apis/model'; +import { ModelVersionsPanel } from '../model_versions_panel'; + +describe('', () => { + it( + 'should render version count, refresh button, filter and table by default', + async () => { + render(); + + expect( + screen.getByPlaceholderText('Search by version number, or keyword') + ).toBeInTheDocument(); + expect(screen.getByTitle('State')).toBeInTheDocument(); + expect(screen.getByTitle('Status')).toBeInTheDocument(); + expect(screen.getByTitle('Add tag filter')).toBeInTheDocument(); + expect( + screen.getByPlaceholderText('Search by version number, or keyword') + ).toBeInTheDocument(); + expect(screen.getByText('Refresh')).toBeInTheDocument(); + + await waitFor(() => { + expect( + screen.getByText((text, node) => { + return text === 'Versions' && !!node?.childNodes[1]?.textContent?.includes('(3)'); + }) + ).toBeInTheDocument(); + }); + + expect( + within( + within(screen.getByLabelText('Model versions')).getAllByRole('gridcell')[0] + ).getByText('1.0.0') + ).toBeInTheDocument(); + }, + 10 * 1000 + ); + + it( + 'should call model search API again after refresh button clicked', + async () => { + const searchMock = jest.spyOn(Model.prototype, 'search'); + + render(); + + expect(searchMock).toHaveBeenCalledTimes(1); + + await userEvent.click(screen.getByText('Refresh')); + expect(searchMock).toHaveBeenCalledTimes(2); + + searchMock.mockRestore(); + }, + 10 * 1000 + ); + + it( + 'should call model search with consistent state parameters after deployed state filter applied', + async () => { + const searchMock = jest.spyOn(Model.prototype, 'search'); + + render(); + + await userEvent.click(screen.getByTitle('State')); + await userEvent.click(screen.getByRole('option', { name: 'Deployed' })); + + await waitFor(() => { + expect(searchMock).toHaveBeenLastCalledWith( + expect.objectContaining({ + states: ['DEPLOYED', 'PARTIALLY_DEPLOYED'], + }) + ); + }); + + await userEvent.click(screen.getByRole('option', { name: 'Deployed' })); + await userEvent.click(screen.getByRole('option', { name: 'Not deployed' })); + await waitFor(() => { + expect(searchMock).toHaveBeenLastCalledWith( + expect.objectContaining({ + states: ['DEPLOYING', 'REGISTERING', 'REGISTERED', 'DEPLOY_FAILED', 'REGISTER_FAILED'], + }) + ); + }); + + searchMock.mockRestore(); + }, + 10 * 10000 + ); +}); diff --git a/public/components/model/model_versions_panel/index.ts b/public/components/model/model_versions_panel/index.ts new file mode 100644 index 00000000..1ca895de --- /dev/null +++ b/public/components/model/model_versions_panel/index.ts @@ -0,0 +1,6 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +export { ModelVersionsPanel } from './model_versions_panel'; diff --git a/public/components/model/model_versions_panel/model_version_cell.tsx b/public/components/model/model_versions_panel/model_version_cell.tsx new file mode 100644 index 00000000..b9f80075 --- /dev/null +++ b/public/components/model/model_versions_panel/model_version_cell.tsx @@ -0,0 +1,55 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React from 'react'; +import { get } from 'lodash'; +import { EuiBadge, EuiText } from '@elastic/eui'; + +import { renderTime } from '../../../utils/table'; +import { MODEL_STATE } from '../../../../common'; +import { VersionTableDataItem } from '../types'; + +import { ModelVersionStatusCell } from './model_version_status_cell'; +import { ModelVersionStatusDetail } from './model_version_status_detail'; + +interface ModelVersionCellProps { + columnId: string; + data: VersionTableDataItem; + isDetails: boolean; +} + +export const ModelVersionCell = ({ data, columnId, isDetails }: ModelVersionCellProps) => { + if (columnId === 'status' && isDetails) { + return ( + + ); + } + switch (columnId) { + case 'version': + return {data.version}; + case 'status': { + return ; + } + case 'state': { + const deployed = + data.state === MODEL_STATE.loaded || data.state === MODEL_STATE.partiallyLoaded; + return ( + + {deployed ? 'Deployed' : 'Not deployed'} + + ); + } + case 'lastUpdated': + return renderTime(data.lastUpdated, 'MMM D, YYYY h:m A'); + default: + return get(data, columnId, '-'); + } +}; diff --git a/public/components/model/model_versions_panel/model_version_error_details_modal.tsx b/public/components/model/model_versions_panel/model_version_error_details_modal.tsx new file mode 100644 index 00000000..955c7124 --- /dev/null +++ b/public/components/model/model_versions_panel/model_version_error_details_modal.tsx @@ -0,0 +1,72 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React from 'react'; + +import { generatePath, Link } from 'react-router-dom'; +import { + EuiModal, + EuiModalHeader, + EuiModalBody, + EuiModalHeaderTitle, + EuiModalFooter, + EuiTitle, + EuiButtonEmpty, + EuiText, + EuiLink, + EuiCodeBlock, + EuiSpacer, +} from '@elastic/eui'; + +import { routerPaths } from '../../../../common/router_paths'; + +export const ModelVersionErrorDetailsModal = ({ + id, + name, + version, + closeModal, + errorDetails, + isDeployFailed, +}: { + name: string; + id: string; + version: string; + closeModal: () => void; + errorDetails: string; + isDeployFailed?: boolean; +}) => { + return ( + + + + +

    + + {name} version {version} + {' '} + {isDeployFailed ? 'deployment failed' : 'artifact upload failed'} +

    +
    +
    +
    + + {isDeployFailed ? ( + <> + Error message: + + {errorDetails} + + ) : ( + +

    {errorDetails}

    +
    + )} +
    + + Close + +
    + ); +}; diff --git a/public/components/model/model_versions_panel/model_version_list_filter.tsx b/public/components/model/model_versions_panel/model_version_list_filter.tsx new file mode 100644 index 00000000..dc998ab0 --- /dev/null +++ b/public/components/model/model_versions_panel/model_version_list_filter.tsx @@ -0,0 +1,128 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React, { useCallback, useRef } from 'react'; +import { + EuiFilterGroup, + EuiFlexGroup, + EuiFlexItem, + EuiFieldSearch, + EuiIcon, + EuiSpacer, +} from '@elastic/eui'; + +import { TagFilterValue, TagFilter, OptionsFilter, SelectedTagFiltersPanel } from '../../common'; +import { useModelTagKeys } from '../../model_list/model_list.hooks'; + +const statusOptions = [ + { + name: 'In progress...', + value: 'InProgress' as const, + prepend: , + }, + { name: 'Success', value: 'Success' as const, prepend: }, + { name: 'Warning', value: 'Warning' as const, prepend: }, + { name: 'Error', value: 'Error' as const, prepend: }, +]; + +const stateOptions = ['Not deployed', 'Deployed']; + +const removeDuplicateTag = (tagFilters: TagFilterValue[]) => { + const existsTagMap: { [key: string]: boolean } = {}; + return tagFilters.filter((tagFilter) => { + const key = `${tagFilter.name}${tagFilter.operator}${tagFilter.value.toString()}`; + if (!existsTagMap[key]) { + existsTagMap[key] = true; + return true; + } + + return false; + }); +}; + +export interface ModelVersionListFilterValue { + status: Array; + state: Array; + tag: TagFilterValue[]; +} + +interface ModelVersionListFilterProps { + value: ModelVersionListFilterValue; + onChange: (value: ModelVersionListFilterValue) => void; +} + +export const ModelVersionListFilter = ({ value, onChange }: ModelVersionListFilterProps) => { + // TODO: Change to model tags API and pass model group id here + const [tagKeysLoading, tagKeys] = useModelTagKeys(); + const valueRef = useRef(value); + valueRef.current = value; + + const handleStateChange = useCallback( + (state: ModelVersionListFilterValue['state']) => { + onChange({ ...valueRef.current, state }); + }, + [onChange] + ); + + const handleStatusChange = useCallback( + (status: ModelVersionListFilterValue['status']) => { + onChange({ ...valueRef.current, status }); + }, + [onChange] + ); + + const handleTagChange = useCallback( + (tag: ModelVersionListFilterValue['tag']) => { + onChange({ ...valueRef.current, tag: removeDuplicateTag(tag) }); + }, + [onChange] + ); + + return ( + <> + + + + + + + + + + + + + {value.tag.length > 0 && ( + <> + + + + )} + + ); +}; diff --git a/public/components/model/model_versions_panel/model_version_status_cell.tsx b/public/components/model/model_versions_panel/model_version_status_cell.tsx new file mode 100644 index 00000000..92f3dafe --- /dev/null +++ b/public/components/model/model_versions_panel/model_version_status_cell.tsx @@ -0,0 +1,29 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React from 'react'; +import { EuiHealth } from '@elastic/eui'; + +import { MODEL_STATE } from '../../../../common'; + +const state2StatusContentMap: { [key in MODEL_STATE]?: [string, string] } = { + [MODEL_STATE.uploading]: ['#AFB0B3', 'In progress...'], + [MODEL_STATE.loading]: ['#AFB0B3', 'In progress...'], + [MODEL_STATE.uploaded]: ['success', 'Success'], + [MODEL_STATE.loaded]: ['success', 'Success'], + [MODEL_STATE.unloaded]: ['success', 'Success'], + [MODEL_STATE.loadFailed]: ['danger', 'Error'], + [MODEL_STATE.registerFailed]: ['danger', 'Error'], + [MODEL_STATE.partiallyLoaded]: ['warning', 'Warning'], +}; + +export const ModelVersionStatusCell = ({ state }: { state: MODEL_STATE }) => { + const statusContent = state2StatusContentMap[state]; + if (!statusContent) { + return <>-; + } + const [color, text] = statusContent; + return {text}; +}; diff --git a/public/components/model/model_versions_panel/model_version_status_detail.tsx b/public/components/model/model_versions_panel/model_version_status_detail.tsx new file mode 100644 index 00000000..d3475a67 --- /dev/null +++ b/public/components/model/model_versions_panel/model_version_status_detail.tsx @@ -0,0 +1,197 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React, { useState, useCallback } from 'react'; +import { + EuiSpacer, + EuiFlexGroup, + EuiText, + EuiButton, + EuiPopoverTitle, + EuiLink, +} from '@elastic/eui'; + +import { Link, generatePath } from 'react-router-dom'; +import { MODEL_STATE, routerPaths } from '../../../../common'; +import { APIProvider } from '../../../apis/api_provider'; +import { renderTime } from '../../../utils'; + +import { ModelVersionErrorDetailsModal } from './model_version_error_details_modal'; + +// TODO: Change to related time field after confirmed +export const state2DetailContentMap: { + [key in MODEL_STATE]?: { + title: string; + description: (versionLink: React.ReactNode) => React.ReactNode; + timeTitle: string; + timeField: 'createdTime'; + }; +} = { + [MODEL_STATE.uploading]: { + title: 'In progress...', + description: (versionLink: React.ReactNode) => ( + <>The model artifact for {versionLink} is uploading. + ), + timeTitle: 'Upload initiated on', + timeField: 'createdTime', + }, + [MODEL_STATE.loading]: { + title: 'In progress...', + description: (versionLink: React.ReactNode) => ( + <>The model artifact for {versionLink} is deploying. + ), + timeTitle: 'Deployment initiated on', + timeField: 'createdTime', + }, + [MODEL_STATE.uploaded]: { + title: 'Success', + description: (versionLink: React.ReactNode) => ( + <>The model artifact for {versionLink} uploaded. + ), + timeTitle: 'Uploaded on', + timeField: 'createdTime', + }, + [MODEL_STATE.loaded]: { + title: 'Success', + description: (versionLink: React.ReactNode) => <>{versionLink} deployed., + timeTitle: 'Deployed on', + timeField: 'createdTime', + }, + [MODEL_STATE.unloaded]: { + title: 'Success', + description: (versionLink: React.ReactNode) => <>{versionLink} undeployed., + timeTitle: 'Undeployed on', + timeField: 'createdTime', + }, + [MODEL_STATE.loadFailed]: { + title: 'Error', + description: (versionLink: React.ReactNode) => <>{versionLink} deployment failed., + timeTitle: 'Deployment failed on', + timeField: 'createdTime', + }, + [MODEL_STATE.registerFailed]: { + title: 'Error', + description: (versionLink: React.ReactNode) => <>{versionLink} artifact upload failed., + timeTitle: 'Upload failed on', + timeField: 'createdTime', + }, + [MODEL_STATE.partiallyLoaded]: { + title: 'Warning', + description: (versionLink: React.ReactNode) => ( + <>{versionLink} is deployed and partially responding. + ), + timeTitle: 'Last responded on', + timeField: 'createdTime', + }, +}; + +export const ModelVersionStatusDetail = ({ + id, + name, + state, + version, + ...restProps +}: { + id: string; + state: MODEL_STATE; + name: string; + version: string; + createdTime: number; +}) => { + const [isErrorDetailsModalShowed, setIsErrorDetailsModalShowed] = useState(false); + const [isLoadingErrorDetails, setIsLoadingErrorDetails] = useState(false); + const [errorDetails, setErrorDetails] = useState(); + + const handleSeeFullErrorClick = useCallback(async () => { + const state2TaskTypeMap: { [key in MODEL_STATE]?: string } = { + [MODEL_STATE.loadFailed]: 'DEPLOY_MODEL', + [MODEL_STATE.registerFailed]: 'REGISTER_MODEL', + }; + if (!(state in state2TaskTypeMap)) { + return; + } + if (errorDetails) { + setIsErrorDetailsModalShowed(true); + return; + } + setIsLoadingErrorDetails(true); + try { + const { data } = await APIProvider.getAPI('task').search({ + modelId: id, + taskType: state2TaskTypeMap[state], + from: 0, + size: 1, + sort: 'last_update_time-desc', + }); + if (data[0]?.error) { + setErrorDetails(data[0].error); + setIsErrorDetailsModalShowed(true); + } + } finally { + setIsLoadingErrorDetails(false); + } + }, [state, id, errorDetails]); + + const handleCloseModal = useCallback(() => { + setIsErrorDetailsModalShowed(false); + }, []); + + const statusContent = state2DetailContentMap[state]; + if (!statusContent) { + return <>-; + } + const { title, description, timeTitle, timeField } = statusContent; + + return ( + <> +
    + + {title} + +
    + + {description( + + {name} version {version} + + )} + + + + {timeTitle}: {renderTime(restProps[timeField], 'MMM d, yyyy @ HH:mm:ss.SSS')} + + {(state === MODEL_STATE.loadFailed || state === MODEL_STATE.registerFailed) && ( + <> + + + + See full error + + + + )} +
    +
    + {isErrorDetailsModalShowed && errorDetails && ( + + )} + + ); +}; diff --git a/public/components/model/model_versions_panel/model_version_table.tsx b/public/components/model/model_versions_panel/model_version_table.tsx new file mode 100644 index 00000000..29a4e2f4 --- /dev/null +++ b/public/components/model/model_versions_panel/model_version_table.tsx @@ -0,0 +1,112 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React, { useCallback, useState, useMemo } from 'react'; +import { + EuiDataGrid, + EuiDataGridProps, + EuiTablePagination, + EuiDataGridCellValueElementProps, +} from '@elastic/eui'; + +import { VersionTableDataItem } from '../types'; + +import { ModelVersionCell } from './model_version_cell'; +import { ModelVersionTableRowActions } from './model_version_table_row_actions'; + +interface VersionTableProps extends Pick { + tags: string[]; + versions: VersionTableDataItem[]; + totalVersionCount?: number; +} + +export const ModelVersionTable = ({ + tags, + sorting, + versions, + pagination, + totalVersionCount, +}: VersionTableProps) => { + const columns = useMemo( + () => [ + { + id: 'version', + displayAsText: 'Version', + defaultSortDirection: 'asc' as const, + }, + { + id: 'state', + displayAsText: 'State', + isSortable: false, + }, + { + id: 'status', + schema: 'status', + displayAsText: 'Status', + isSortable: false, + }, + { + id: 'lastUpdated', + displayAsText: 'Last updated', + }, + ...tags.map((tag) => ({ + id: `tags.${tag}`, + displayAsText: `Tag: ${tag}`, + })), + ], + [tags] + ); + const trailingControlColumns = useMemo( + () => [ + { + id: 'actions', + width: 40, + headerCellRender: () => null, + rowCellRender: ({ rowIndex }: EuiDataGridCellValueElementProps) => { + const version = versions[rowIndex]; + return ; + }, + }, + ], + [versions] + ); + const [visibleColumns, setVisibleColumns] = useState(() => columns.map(({ id }) => id)); + const columnVisibility = useMemo(() => ({ visibleColumns, setVisibleColumns }), [visibleColumns]); + + const renderCellValue = useCallback( + ({ rowIndex, columnId, isDetails }) => ( + + ), + [versions] + ); + + return ( +
    + + {pagination && typeof totalVersionCount === 'number' && ( +
    + +
    + )} +
    + ); +}; diff --git a/public/components/model/model_versions_panel/model_version_table_row_actions.tsx b/public/components/model/model_versions_panel/model_version_table_row_actions.tsx new file mode 100644 index 00000000..e4511130 --- /dev/null +++ b/public/components/model/model_versions_panel/model_version_table_row_actions.tsx @@ -0,0 +1,96 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React, { useState, useCallback } from 'react'; +import { + EuiPopover, + EuiButtonIcon, + EuiContextMenuPanel, + EuiContextMenuItem, + copyToClipboard, +} from '@elastic/eui'; + +import { MODEL_STATE } from '../../../../common'; + +export const ModelVersionTableRowActions = ({ state, id }: { state: MODEL_STATE; id: string }) => { + const [isPopoverOpen, setIsPopoverOpen] = useState(false); + + const handleShowActionsClick = useCallback(() => { + setIsPopoverOpen((flag) => !flag); + }, []); + + const closePopover = useCallback(() => { + setIsPopoverOpen(false); + }, []); + + return ( + + } + closePopover={closePopover} + ownFocus={false} + > +
    + { + copyToClipboard(id); + }} + style={{ padding: 8 }} + > + Copy ID + , + ...(state === MODEL_STATE.registerFailed + ? [ + + Upload new artifact + , + ] + : []), + ...(state === MODEL_STATE.uploaded || state === MODEL_STATE.unloaded + ? [ + + Deploy + , + ] + : []), + ...(state === MODEL_STATE.loaded || state === MODEL_STATE.partiallyLoaded + ? [ + + Undeploy + , + ] + : []), + + Delete + , + ]} + /> +
    +
    + ); +}; diff --git a/public/components/model/model_versions_panel/model_versions_panel.tsx b/public/components/model/model_versions_panel/model_versions_panel.tsx new file mode 100644 index 00000000..aacd6d75 --- /dev/null +++ b/public/components/model/model_versions_panel/model_versions_panel.tsx @@ -0,0 +1,171 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React, { useMemo, useState, useCallback } from 'react'; +import { + EuiButton, + EuiFlexGroup, + EuiFlexItem, + EuiPanel, + EuiSpacer, + EuiTextColor, + EuiTitle, +} from '@elastic/eui'; + +import { useFetcher } from '../../../hooks'; +import { APIProvider } from '../../../apis/api_provider'; +import { MODEL_STATE } from '../../../../common'; + +import { ModelVersionTable } from './model_version_table'; +import { ModelVersionListFilter, ModelVersionListFilterValue } from './model_version_list_filter'; + +// TODO: Use tags from model group +const tags = ['Tag1', 'Tag2']; + +const modelState2StatusMap: { + [key in MODEL_STATE]?: ModelVersionListFilterValue['status'][number]; +} = { + [MODEL_STATE.loading]: 'InProgress', + [MODEL_STATE.uploading]: 'InProgress', + [MODEL_STATE.uploaded]: 'Success', + [MODEL_STATE.loaded]: 'Success', + [MODEL_STATE.unloaded]: 'Success', + [MODEL_STATE.partiallyLoaded]: 'Warning', + [MODEL_STATE.loadFailed]: 'Error', + [MODEL_STATE.registerFailed]: 'Error', +}; + +const getStatesParam = ({ + status: statuses, + state: states, +}: Pick) => { + if (statuses.length === 0 && states.length === 0) { + return undefined; + } + return [ + MODEL_STATE.loading, + MODEL_STATE.uploading, + MODEL_STATE.uploaded, + MODEL_STATE.loaded, + MODEL_STATE.partiallyLoaded, + MODEL_STATE.loadFailed, + MODEL_STATE.registerFailed, + ].filter((modelState) => { + const stateRelatedStatus = modelState2StatusMap[modelState]; + if (stateRelatedStatus && statuses.includes(stateRelatedStatus)) { + return true; + } + if (modelState === MODEL_STATE.loaded || modelState === MODEL_STATE.partiallyLoaded) { + return states.includes('Deployed'); + } + return states.includes('Not deployed'); + }); +}; + +interface ModelVersionsPanelProps { + groupId: string; +} + +export const ModelVersionsPanel = ({ groupId }: ModelVersionsPanelProps) => { + const [params, setParams] = useState<{ + pageIndex: number; + pageSize: number; + sort: Array<{ id: string; direction: 'asc' | 'desc' }>; + filter: ModelVersionListFilterValue; + }>({ + pageIndex: 0, + pageSize: 25, + sort: [], + filter: { + state: [], + status: [], + tag: [], + }, + }); + const { data: versionsData, reload } = useFetcher(APIProvider.getAPI('model').search, { + // TODO: Change to model group id + ids: [groupId], + from: params.pageIndex * params.pageSize, + size: params.pageSize, + states: getStatesParam(params.filter), + }); + const totalVersionCount = versionsData?.total_models; + + const versions = useMemo(() => { + if (!versionsData) { + return []; + } + return versionsData.data.map((item) => ({ + id: item.id, + name: item.name, + version: item.model_version, + state: item.model_state, + lastUpdated: item.last_updated_time, + // TODO: Change to use tags in model version once structure finalized + tags: {}, + createdTime: item.created_time, + })); + }, [versionsData]); + + const pagination = useMemo(() => { + if (!totalVersionCount) { + return undefined; + } + return { + pageIndex: params.pageIndex, + pageSize: params.pageSize, + pageSizeOptions: [10, 25, 50], + onChangeItemsPerPage: (pageSize: number) => { + setParams((previousParams) => ({ ...previousParams, pageSize })); + }, + onChangePage: (pageIndex: number) => { + setParams((previousParams) => ({ ...previousParams, pageIndex })); + }, + }; + }, [params.pageIndex, params.pageSize, totalVersionCount]); + + const versionsSorting = useMemo( + () => ({ + columns: params.sort, + onSort: (sort) => { + setParams((previousParams) => ({ ...previousParams, sort })); + }, + }), + [params] + ); + + const handleFilterChange = useCallback((filter: ModelVersionListFilterValue) => { + setParams((previousParams) => ({ ...previousParams, filter })); + }, []); + + return ( + + + + +

    + Versions + {typeof totalVersionCount === 'number' && ( +  ({totalVersionCount}) + )} +

    +
    +
    + + Refresh + +
    + + + +
    + ); +}; diff --git a/public/components/model/types.ts b/public/components/model/types.ts new file mode 100644 index 00000000..77d1f06f --- /dev/null +++ b/public/components/model/types.ts @@ -0,0 +1,16 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import { MODEL_STATE } from '../../../common'; + +export interface VersionTableDataItem { + id: string; + name: string; + version: string; + state: MODEL_STATE; + lastUpdated: number; + tags: { [key: string]: string | number }; + createdTime: number; +} diff --git a/public/components/model_group/__tests__/model_group.test.tsx b/public/components/model_group/__tests__/model_group.test.tsx deleted file mode 100644 index 2df4b7e7..00000000 --- a/public/components/model_group/__tests__/model_group.test.tsx +++ /dev/null @@ -1,61 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -import React from 'react'; -import userEvent from '@testing-library/user-event'; - -import { render, screen, waitFor, within } from '../../../../test/test_utils'; -import { ModelGroup } from '../model_group'; -import { routerPaths } from '../../../../common/router_paths'; -import { Route, generatePath } from 'react-router-dom'; - -const setup = () => { - const renderResult = render( - - - , - { route: generatePath(routerPaths.modelGroup, { id: '1' }) } - ); - - return { - renderResult, - }; -}; - -describe('', () => { - it('should display model name, action buttons, overview-card, tabs and tabpanel after data loaded', async () => { - setup(); - - await waitFor(() => { - expect(screen.queryByTestId('model-group-loading-indicator')).toBeNull(); - }); - expect(screen.getByText('model1')).toBeInTheDocument(); - expect(screen.getByText('Delete')).toBeInTheDocument(); - expect(screen.getByText('Register version')).toBeInTheDocument(); - expect(screen.queryByTestId('model-group-overview-card')).toBeInTheDocument(); - expect(screen.queryByTestId('model-group-overview-card')).toBeInTheDocument(); - expect(screen.queryByTestId('model-group-overview-card')).toBeInTheDocument(); - expect(screen.getByRole('tab', { name: 'Versions' })).toHaveClass('euiTab-isSelected'); - expect(within(screen.getByRole('tabpanel')).getByText('Versions')).toBeInTheDocument(); - }); - - it('should display consistent tabs content after tab clicked', async () => { - setup(); - - await waitFor(() => { - expect(screen.queryByTestId('model-group-loading-indicator')).toBeNull(); - }); - expect(screen.getByRole('tab', { name: 'Versions' })).toHaveClass('euiTab-isSelected'); - expect(within(screen.getByRole('tabpanel')).getByText('Versions')).toBeInTheDocument(); - - await userEvent.click(screen.getByRole('tab', { name: 'Details' })); - expect(screen.getByRole('tab', { name: 'Details' })).toHaveClass('euiTab-isSelected'); - expect(within(screen.getByRole('tabpanel')).getByText('Details')).toBeInTheDocument(); - - await userEvent.click(screen.getByRole('tab', { name: 'Tags' })); - expect(screen.getByRole('tab', { name: 'Tags' })).toHaveClass('euiTab-isSelected'); - expect(within(screen.getByRole('tabpanel')).getByText('Tags')).toBeInTheDocument(); - }); -}); diff --git a/public/components/model_group/model_group_versions_panel.tsx b/public/components/model_group/model_group_versions_panel.tsx deleted file mode 100644 index 1cb3f0f4..00000000 --- a/public/components/model_group/model_group_versions_panel.tsx +++ /dev/null @@ -1,17 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -import React from 'react'; -import { EuiPanel, EuiTitle } from '@elastic/eui'; - -export const ModelGroupVersionsPanel = () => { - return ( - - -

    Versions

    -
    -
    - ); -}; diff --git a/public/components/model_list/__tests__/model_filter.test.tsx b/public/components/model_list/__tests__/model_filter.test.tsx deleted file mode 100644 index e4b63c33..00000000 --- a/public/components/model_list/__tests__/model_filter.test.tsx +++ /dev/null @@ -1,148 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ -jest.mock('../../../apis/security'); - -import React from 'react'; -import userEvent from '@testing-library/user-event'; - -import { ModelFilter } from '../model_filter'; -import { render, screen } from '../../../../test/test_utils'; - -describe('', () => { - afterEach(() => { - jest.resetAllMocks(); - }); - - it('should render "Tags" with 0 active filter', () => { - render( - {}} - /> - ); - expect(screen.getByText('Tags')).toBeInTheDocument(); - expect(screen.getByText('0')).toBeInTheDocument(); - }); - - it('should render Tags with 2 active filter', () => { - render( - {}} - /> - ); - expect(screen.getByText('Tags')).toBeInTheDocument(); - expect(screen.getByText('2')).toBeInTheDocument(); - }); - - it('should render options filter after filter button clicked', async () => { - render( - {}} - /> - ); - expect(screen.queryByText('foo')).not.toBeInTheDocument(); - expect(screen.queryByPlaceholderText('Search Tags')).not.toBeInTheDocument(); - - await userEvent.click(screen.getByText('Tags')); - - expect(screen.getByText('foo')).toBeInTheDocument(); - expect(screen.getByPlaceholderText('Search Tags')).toBeInTheDocument(); - }); - - it('should render passed footer after filter button clicked', async () => { - const { getByText, queryByText } = render( - {}} - footer="footer" - /> - ); - expect(queryByText('footer')).not.toBeInTheDocument(); - - await userEvent.click(screen.getByText('Tags')); - expect(getByText('footer')).toBeInTheDocument(); - }); - - it('should only show "bar" after search', async () => { - render( - {}} - /> - ); - - await userEvent.click(screen.getByText('Tags')); - expect(screen.getByText('foo')).toBeInTheDocument(); - - await userEvent.type(screen.getByPlaceholderText('Search Tags'), 'bAr{enter}'); - expect(screen.queryByText('foo')).not.toBeInTheDocument(); - expect(screen.getByText('bar')).toBeInTheDocument(); - }); - - it('should call onChange with consistent value after option click', async () => { - const onChangeMock = jest.fn(); - const { rerender } = render( - - ); - - expect(onChangeMock).not.toHaveBeenCalled(); - - await userEvent.click(screen.getByText('Tags')); - await userEvent.click(screen.getByText('foo')); - expect(onChangeMock).toHaveBeenCalledWith(['foo']); - onChangeMock.mockClear(); - - rerender( - - ); - - await userEvent.click(screen.getByText('bar')); - expect(onChangeMock).toHaveBeenCalledWith(['foo', 'bar']); - onChangeMock.mockClear(); - - rerender( - - ); - - await userEvent.click(screen.getByText('bar')); - expect(onChangeMock).toHaveBeenCalledWith(['foo']); - onChangeMock.mockClear(); - }); -}); diff --git a/public/components/model_list/__tests__/model_filter_item.test.tsx b/public/components/model_list/__tests__/model_filter_item.test.tsx deleted file mode 100644 index 8b219960..00000000 --- a/public/components/model_list/__tests__/model_filter_item.test.tsx +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -import React from 'react'; -import userEvent from '@testing-library/user-event'; - -import { ModelFilterItem } from '../model_filter_item'; - -import { render, screen } from '../../../../test/test_utils'; - -describe('', () => { - it('should render passed children and check icon', () => { - render( - {}}> - foo - - ); - expect(screen.getByText('foo')).toBeInTheDocument(); - expect(screen.getByRole('img', { hidden: true })).toBeInTheDocument(); - }); - - it('should call onClick with "foo" after click', async () => { - const onClickMock = jest.fn(); - render( - - foo - - ); - await userEvent.click(screen.getByRole('option')); - expect(onClickMock).toHaveBeenCalledWith('foo'); - }); -}); diff --git a/public/components/model_list/__tests__/owner_filter.test.tsx b/public/components/model_list/__tests__/owner_filter.test.tsx index 3e6deccc..93cd6577 100644 --- a/public/components/model_list/__tests__/owner_filter.test.tsx +++ b/public/components/model_list/__tests__/owner_filter.test.tsx @@ -16,10 +16,9 @@ describe('', () => { jest.resetAllMocks(); }); - it('should render "Owner" with 3 filter for normal', async () => { - const { getByText, findByText } = render( {}} />); + it('should render "Owner" by default', async () => { + const { getByText } = render( {}} />); expect(getByText('Owner')).toBeInTheDocument(); - expect(await findByText('3')).toBeInTheDocument(); }); it('should render three options with 1 checked option and 1 active filter', async () => { diff --git a/public/components/model_list/model_filter.tsx b/public/components/model_list/model_filter.tsx deleted file mode 100644 index c2ef3187..00000000 --- a/public/components/model_list/model_filter.tsx +++ /dev/null @@ -1,107 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -import React, { useCallback, useMemo, useRef, useState } from 'react'; -import { - EuiPopover, - EuiPopoverTitle, - EuiFieldSearch, - EuiFilterButton, - EuiPopoverFooter, -} from '@elastic/eui'; -import { ModelFilterItem } from './model_filter_item'; - -export interface ModelFilterProps { - name: string; - searchPlaceholder: string; - options: Array; - value: string[]; - onChange: (value: string[]) => void; - footer?: React.ReactNode; -} - -export const ModelFilter = ({ - name, - value, - footer, - options, - searchPlaceholder, - onChange, -}: ModelFilterProps) => { - const valueRef = useRef(value); - valueRef.current = value; - const onChangeRef = useRef(onChange); - onChangeRef.current = onChange; - const [isPopoverOpen, setIsPopoverOpen] = useState(false); - const [searchText, setSearchText] = useState(); - - const filteredOptions = useMemo( - () => - searchText - ? options.filter((option) => - (typeof option === 'string' ? option : option.name) - .toLowerCase() - .includes(searchText.toLowerCase()) - ) - : options, - [searchText, options] - ); - - const handleButtonClick = useCallback(() => { - setIsPopoverOpen((prevState) => !prevState); - }, []); - - const closePopover = useCallback(() => { - setIsPopoverOpen(false); - }, []); - - const handleFilterItemClick = useCallback((clickItemValue: string) => { - onChangeRef.current( - valueRef.current.includes(clickItemValue) - ? valueRef.current.filter((item) => item !== clickItemValue) - : valueRef.current.concat(clickItemValue) - ); - }, []); - - return ( - 0} - numActiveFilters={value.length} - > - {name} - - } - isOpen={isPopoverOpen} - closePopover={closePopover} - panelPaddingSize="none" - > - - - - {filteredOptions.map((item, index) => { - const itemValue = typeof item === 'string' ? item : item.value; - const checked = value.includes(itemValue) ? 'on' : undefined; - return ( - - {typeof item === 'string' ? item : item.name} - - ); - })} - {footer && {footer}} - - ); -}; diff --git a/public/components/model_list/model_filter_item.tsx b/public/components/model_list/model_filter_item.tsx deleted file mode 100644 index c3fbf809..00000000 --- a/public/components/model_list/model_filter_item.tsx +++ /dev/null @@ -1,24 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -import React, { useCallback } from 'react'; -import { EuiFilterSelectItem, EuiFilterSelectItemProps } from '@elastic/eui'; - -export interface ModelFilterItemProps - extends Pick { - value: string; - onClick: (value: string) => void; -} - -export const ModelFilterItem = ({ checked, children, onClick, value }: ModelFilterItemProps) => { - const handleClick = useCallback(() => { - onClick(value); - }, [onClick, value]); - return ( - - {children} - - ); -}; diff --git a/public/components/model_list/model_list_filter.tsx b/public/components/model_list/model_list_filter.tsx index 0e8616d6..40b0da4d 100644 --- a/public/components/model_list/model_list_filter.tsx +++ b/public/components/model_list/model_list_filter.tsx @@ -12,9 +12,8 @@ import { } from '@elastic/eui'; import React, { useCallback, useRef } from 'react'; -import { TagFilterValue, SelectedTagFiltersPanel, DebouncedSearchBar } from '../common'; +import { TagFilterValue, SelectedTagFiltersPanel, DebouncedSearchBar, TagFilter } from '../common'; -import { TagFilter } from './tag_filter'; import { OwnerFilter } from './owner_filter'; import { useModelTagKeys } from './model_list.hooks'; diff --git a/public/components/model_list/owner_filter.tsx b/public/components/model_list/owner_filter.tsx index 67f7eeb4..177802c1 100644 --- a/public/components/model_list/owner_filter.tsx +++ b/public/components/model_list/owner_filter.tsx @@ -5,13 +5,16 @@ import React, { useCallback, useMemo } from 'react'; import { EuiButton } from '@elastic/eui'; -import { ModelFilter, ModelFilterProps } from './model_filter'; import { useFetcher } from '../../hooks/use_fetcher'; import { APIProvider } from '../../apis/api_provider'; +import { OptionsFilter, OptionsFilterProps } from '../common'; const ownerFetcher = () => Promise.resolve(['admin', 'owner-1', 'owner-2']); -export const OwnerFilter = ({ value, onChange }: Pick) => { +export const OwnerFilter = ({ + value, + onChange, +}: Pick) => { const { data: accountData } = useFetcher(APIProvider.getAPI('security').getAccount); const { data: ownerData } = useFetcher(ownerFetcher); const currentAccountName = accountData?.user_name; @@ -32,7 +35,7 @@ export const OwnerFilter = ({ value, onChange }: Pick Form', () => { expect(addSuccessMock).toHaveBeenCalled(); }); - it('should navigate to model group page when submit succeed', async () => { + it('should navigate to model page when submit succeed', async () => { const { user } = await setup(); await user.click(screen.getByRole('button', { name: /register model/i })); expect(location.href).toContain(`model-registry/model/${MOCKED_MODEL_ID}`); diff --git a/public/components/register_model/__tests__/register_model_tags.test.tsx b/public/components/register_model/__tests__/register_model_tags.test.tsx index 1d1caff2..935e8336 100644 --- a/public/components/register_model/__tests__/register_model_tags.test.tsx +++ b/public/components/register_model/__tests__/register_model_tags.test.tsx @@ -226,7 +226,7 @@ describe(' Tags', () => { ); }); - it('should allow adding one more tag when registering new version if model group has only two tags', async () => { + it('should allow adding one more tag when registering new version if model has only two tags', async () => { jest.spyOn(formHooks, 'useModelTags').mockReturnValue([ false, [ diff --git a/public/components/register_model/register_model.tsx b/public/components/register_model/register_model.tsx index 619f3925..e8bc05ff 100644 --- a/public/components/register_model/register_model.tsx +++ b/public/components/register_model/register_model.tsx @@ -116,7 +116,7 @@ export const RegisterModelForm = ({ defaultValues = DEFAULT_VALUES }: RegisterMo try { const onComplete = (modelId: string) => { // Navigate to model group page - history.push(generatePath(routerPaths.modelGroup, { id: modelId })); + history.push(generatePath(routerPaths.model, { id: modelId })); notifications?.toasts.addSuccess({ title: mountReactNode( diff --git a/server/routes/model_router.ts b/server/routes/model_router.ts index 38c0d035..97a1f7d2 100644 --- a/server/routes/model_router.ts +++ b/server/routes/model_router.ts @@ -28,14 +28,15 @@ const modelSortQuerySchema = schema.oneOf([ ]); const modelStateSchema = schema.oneOf([ - schema.literal(MODEL_STATE.loadFailed), schema.literal(MODEL_STATE.loaded), - schema.literal(MODEL_STATE.loading), - schema.literal(MODEL_STATE.partiallyLoaded), schema.literal(MODEL_STATE.trained), - schema.literal(MODEL_STATE.uploaded), schema.literal(MODEL_STATE.unloaded), + schema.literal(MODEL_STATE.uploaded), schema.literal(MODEL_STATE.uploading), + schema.literal(MODEL_STATE.loading), + schema.literal(MODEL_STATE.partiallyLoaded), + schema.literal(MODEL_STATE.loadFailed), + schema.literal(MODEL_STATE.registerFailed), ]); const modelUploadBaseSchema = { diff --git a/server/routes/task_router.ts b/server/routes/task_router.ts index 23521af2..ec1ddea6 100644 --- a/server/routes/task_router.ts +++ b/server/routes/task_router.ts @@ -8,6 +8,11 @@ import { IRouter, opensearchDashboardsResponseFactory } from '../../../../src/co import { TaskService } from '../services'; import { TASK_API_ENDPOINT } from './constants'; +const taskSearchSortItemSchema = schema.oneOf([ + schema.literal('last_update_time-desc'), + schema.literal('last_update_time-asc'), +]); + export const taskRouter = (router: IRouter) => { router.get( { @@ -30,4 +35,42 @@ export const taskRouter = (router: IRouter) => { } } ); + router.get( + { + path: TASK_API_ENDPOINT, + validate: { + query: schema.object({ + from: schema.number(), + size: schema.number(), + sort: schema.maybe( + schema.oneOf([ + taskSearchSortItemSchema, + schema.arrayOf(taskSearchSortItemSchema, { maxSize: 1 }), + ]) + ), + model_id: schema.maybe(schema.string()), + task_type: schema.maybe(schema.string()), + state: schema.maybe(schema.string()), + }), + }, + }, + async (context, request) => { + const { model_id: modelId, task_type: taskType, sort, ...restQuery } = request.query; + try { + const body = await TaskService.search({ + client: context.core.opensearch.client, + modelId, + taskType, + sort: + typeof sort === 'string' + ? [sort] + : (sort as ['last_update_time-desc' | 'last_update_time-asc']), + ...restQuery, + }); + return opensearchDashboardsResponseFactory.ok({ body }); + } catch (err) { + return opensearchDashboardsResponseFactory.badRequest({ body: err.message }); + } + } + ); }; diff --git a/server/services/task_service.ts b/server/services/task_service.ts index 45476da9..ec0233c5 100644 --- a/server/services/task_service.ts +++ b/server/services/task_service.ts @@ -19,7 +19,9 @@ */ import { IScopedClusterClient } from '../../../../src/core/server'; + import { TASK_BASE_API } from './utils/constants'; +import { generateTermQuery, generateMustQueries } from './utils/query'; export class TaskService { public static async getOne({ client, taskId }: { client: IScopedClusterClient; taskId: string }) { @@ -30,4 +32,55 @@ export class TaskService { }) ).body; } + + public static async search({ + client, + sort, + from, + size, + modelId, + taskType, + state, + }: { + client: IScopedClusterClient; + from: number; + size: number; + modelId?: string; + taskType?: string; + sort?: ['last_update_time-desc' | 'last_update_time-asc']; + state?: string; + }) { + const { + body: { hits }, + } = await client.asCurrentUser.transport.request({ + method: 'POST', + path: `${TASK_BASE_API}/_search`, + body: { + query: generateMustQueries([ + ...(modelId ? [generateTermQuery('model_id', modelId)] : []), + ...(taskType ? [generateTermQuery('task_type', taskType)] : []), + ...(state ? [generateTermQuery('state', state)] : []), + ]), + from, + size, + ...(sort + ? { + sort: sort.map((sorting) => { + const [field, direction] = sorting.split('-'); + return { + [field]: direction, + }; + }), + } + : {}), + }, + }); + return { + data: hits.hits.map(({ _id, _source }) => ({ + id: _id, + ..._source, + })), + total_tasks: hits.total.value, + }; + } } diff --git a/test/jest.config.js b/test/jest.config.js index 065a2f01..63666efe 100644 --- a/test/jest.config.js +++ b/test/jest.config.js @@ -28,4 +28,8 @@ module.exports = { '^.+\\.(js|tsx?)$': '/../../src/dev/jest/babel_transform.js', }, testEnvironment: 'jsdom', + globals: { + // Add this variable here, to avoid EuiDataGrid render failed. See more: https://github.com/opensearch-project/oui/blob/2229dd44ca4d1270b4b8d95c5ffbf5d99297a253/scripts/jest/setup/polyfills.js#L17 + _isJest: true, + }, }; diff --git a/test/mocks/handlers.ts b/test/mocks/handlers.ts index dc292be0..3fe67e5d 100644 --- a/test/mocks/handlers.ts +++ b/test/mocks/handlers.ts @@ -4,11 +4,14 @@ */ import { rest } from 'msw'; + +import { MODEL_AGGREGATE_API_ENDPOINT } from '../../server/routes/constants'; + import { modelConfig } from './data/model_config'; import { modelRepositoryResponse } from './data/model_repository'; import { modelHandlers } from './model_handlers'; import { modelAggregateResponse } from './data/model_aggregate'; -import { MODEL_AGGREGATE_API_ENDPOINT } from '../../server/routes/constants'; +import { taskHandlers } from './task_handlers'; export const handlers = [ rest.get('/api/ml-commons/model-repository', (req, res, ctx) => { @@ -21,4 +24,5 @@ export const handlers = [ rest.get(MODEL_AGGREGATE_API_ENDPOINT, (_req, res, ctx) => { return res(ctx.status(200), ctx.json(modelAggregateResponse)); }), + ...taskHandlers, ]; diff --git a/test/mocks/task_handlers.ts b/test/mocks/task_handlers.ts new file mode 100644 index 00000000..6e887f9e --- /dev/null +++ b/test/mocks/task_handlers.ts @@ -0,0 +1,44 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import { rest } from 'msw'; + +import { TASK_API_ENDPOINT } from '../../server/routes/constants'; + +const tasks = [ + { + name: 'model1', + model_id: '1', + task_type: 'REGISTER_MODEL', + error: 'The artifact url is in valid', + }, +]; + +export const taskHandlers = [ + rest.get(TASK_API_ENDPOINT, (req, res, ctx) => { + const filteredData = tasks.filter((task) => { + const { + params: { model_id: modelId, task_type: taskType }, + } = req; + if (modelId && modelId !== task.model_id) { + return false; + } + if (taskType && taskType !== task.task_type) { + return false; + } + return true; + }); + const start = typeof req.params.from === 'number' ? req.params.from : 0; + const end = typeof req.params.size === 'number' ? start + req.params.size : filteredData.length; + + return res( + ctx.status(200), + ctx.json({ + data: filteredData.slice(start, end), + total_tasks: filteredData.length, + }) + ); + }), +]; From 635dffeda2cda47564bd7bee426fced514689315 Mon Sep 17 00:00:00 2001 From: Lin Wang Date: Thu, 11 May 2023 15:37:35 +0800 Subject: [PATCH 44/75] Feature/add id column for model versions table (#177) * feat: add ID column and hide by default Signed-off-by: Lin Wang * feat: remove Copy ID button in model version table action column Signed-off-by: Lin Wang --------- Signed-off-by: Lin Wang --- .../__tests__/model_version_table.test.tsx | 73 ++++++++++++++++++- .../model_version_table_row_actions.test.tsx | 27 +------ .../model_version_table.tsx | 55 +++++++++++++- .../model_version_table_row_actions.tsx | 18 +---- 4 files changed, 126 insertions(+), 47 deletions(-) diff --git a/public/components/model/model_versions_panel/__tests__/model_version_table.test.tsx b/public/components/model/model_versions_panel/__tests__/model_version_table.test.tsx index da7a2e02..8413263d 100644 --- a/public/components/model/model_versions_panel/__tests__/model_version_table.test.tsx +++ b/public/components/model/model_versions_panel/__tests__/model_version_table.test.tsx @@ -5,11 +5,11 @@ import React from 'react'; import userEvent from '@testing-library/user-event'; +import { within } from '@testing-library/dom'; import { render, screen, waitFor } from '../../../../../test/test_utils'; import { ModelVersionTable } from '../model_version_table'; import { MODEL_STATE } from '../../../../../common'; -import { within } from '@testing-library/dom'; const versions = [ { @@ -157,4 +157,75 @@ describe('', () => { }, 10 * 1000 ); + + it( + 'should show ID column after ID column checked', + async () => { + render(); + + await userEvent.click(screen.getByTestId('dataGridColumnSelectorButton')); + await userEvent.click( + within(screen.getByRole('dialog')).getByRole('switch', { checked: false, name: 'ID' }) + ); + + expect(screen.getByTestId('dataGridHeaderCell-id')).toBeInTheDocument(); + expect( + within(screen.getAllByTestId('dataGridRowCell')[4]).getByText('1') + ).toBeInTheDocument(); + }, + 10 * 1000 + ); + + it( + 'should call document.execCommand after ID column copy button clicked', + async () => { + const execCommandMock = jest.fn(); + const execCommandOrigin = document.execCommand; + document.execCommand = execCommandMock; + + render(); + const idCell = screen.getAllByTestId('dataGridRowCell')[4]; + + await userEvent.click(screen.getByTestId('dataGridColumnSelectorButton')); + await userEvent.click(screen.getByRole('switch', { checked: false, name: 'ID' })); + await userEvent.hover(idCell); + + expect(execCommandMock).not.toHaveBeenCalled(); + await userEvent.click(within(idCell).getByLabelText('Copy ID')); + expect(execCommandMock).toHaveBeenCalledWith('copy'); + + await userEvent.click(idCell); + + document.execCommand = execCommandOrigin; + }, + 10 * 1000 + ); + + it( + 'should call document.execCommand after ID column expand copy button clicked', + async () => { + const execCommandMock = jest.fn(); + const execCommandOrigin = document.execCommand; + document.execCommand = execCommandMock; + + render(); + const idCell = screen.getAllByTestId('dataGridRowCell')[4]; + + await userEvent.click(screen.getByTestId('dataGridColumnSelectorButton')); + await userEvent.click(screen.getByRole('switch', { checked: false, name: 'ID' })); + await userEvent.hover(idCell); + await userEvent.click( + within(idCell).getByTitle('Click or hit enter to interact with cell content') + ); + const copyButton = within(screen.getByRole('dialog')).getByText('Copy ID'); + + expect(execCommandMock).not.toHaveBeenCalled(); + await userEvent.click(copyButton); + expect(execCommandMock).toHaveBeenCalledWith('copy'); + expect(copyButton).toHaveTextContent('Copied'); + + document.execCommand = execCommandOrigin; + }, + 20 * 1000 + ); }); diff --git a/public/components/model/model_versions_panel/__tests__/model_version_table_row_actions.test.tsx b/public/components/model/model_versions_panel/__tests__/model_version_table_row_actions.test.tsx index 7cc0848b..213afbeb 100644 --- a/public/components/model/model_versions_panel/__tests__/model_version_table_row_actions.test.tsx +++ b/public/components/model/model_versions_panel/__tests__/model_version_table_row_actions.test.tsx @@ -5,28 +5,19 @@ import React from 'react'; import userEvent from '@testing-library/user-event'; -import * as euiExports from '@elastic/eui'; import { render, screen, waitFor } from '../../../../../test/test_utils'; import { ModelVersionTableRowActions } from '../model_version_table_row_actions'; import { MODEL_STATE } from '../../../../../common'; -jest.mock('@elastic/eui', () => { - return { - __esModule: true, - ...jest.requireActual('@elastic/eui'), - }; -}); - describe('', () => { - it('should render actions icon and "Copy ID" and "Delete" button after clicked', async () => { + it('should render "actions icon" and "Delete" button after clicked', async () => { const user = userEvent.setup(); render(); expect(screen.getByLabelText('show actions')).toBeInTheDocument(); await user.click(screen.getByLabelText('show actions')); - expect(screen.getByText('Copy ID')).toBeInTheDocument(); expect(screen.getByText('Delete')).toBeInTheDocument(); }); @@ -73,20 +64,4 @@ describe('', () => { expect(screen.queryByText('Delete')).toBeNull(); }); }); - - it('should call copyToClipboard with "1" after "Copy ID" button clicked', async () => { - const copyToClipboardMock = jest - .spyOn(euiExports, 'copyToClipboard') - .mockImplementation(jest.fn()); - const user = userEvent.setup(); - render(); - - await user.click(screen.getByLabelText('show actions')); - - expect(copyToClipboardMock).not.toHaveBeenCalled(); - await user.click(screen.getByText('Copy ID')); - expect(copyToClipboardMock).toHaveBeenCalledWith('1'); - - copyToClipboardMock.mockRestore(); - }); }); diff --git a/public/components/model/model_versions_panel/model_version_table.tsx b/public/components/model/model_versions_panel/model_version_table.tsx index 29a4e2f4..43a6ea59 100644 --- a/public/components/model/model_versions_panel/model_version_table.tsx +++ b/public/components/model/model_versions_panel/model_version_table.tsx @@ -9,6 +9,11 @@ import { EuiDataGridProps, EuiTablePagination, EuiDataGridCellValueElementProps, + EuiIcon, + EuiDataGridColumn, + EuiCopy, + EuiButtonEmpty, + copyToClipboard, } from '@elastic/eui'; import { VersionTableDataItem } from '../types'; @@ -16,6 +21,23 @@ import { VersionTableDataItem } from '../types'; import { ModelVersionCell } from './model_version_cell'; import { ModelVersionTableRowActions } from './model_version_table_row_actions'; +const ExpandCopyIDButton = ({ textToCopy }: { textToCopy: string }) => { + const [isCopied, setIsCopied] = useState(false); + + return ( + { + copyToClipboard(textToCopy); + setIsCopied(true); + }} + iconType="copy" + style={{ width: 102 }} + > + {isCopied ? 'Copied' : 'Copy ID'} + + ); +}; + interface VersionTableProps extends Pick { tags: string[]; versions: VersionTableDataItem[]; @@ -29,7 +51,7 @@ export const ModelVersionTable = ({ pagination, totalVersionCount, }: VersionTableProps) => { - const columns = useMemo( + const columns = useMemo( () => [ { id: 'version', @@ -51,12 +73,37 @@ export const ModelVersionTable = ({ id: 'lastUpdated', displayAsText: 'Last updated', }, + { + id: 'id', + displayAsText: 'ID', + cellActions: [ + ({ rowIndex, isExpanded }) => { + const textToCopy = versions[rowIndex].id; + if (isExpanded) { + return ; + } + return ( + + {(copy) => ( + + )} + + ); + }, + ], + }, ...tags.map((tag) => ({ id: `tags.${tag}`, displayAsText: `Tag: ${tag}`, })), ], - [tags] + [tags, versions] ); const trailingControlColumns = useMemo( () => [ @@ -72,7 +119,9 @@ export const ModelVersionTable = ({ ], [versions] ); - const [visibleColumns, setVisibleColumns] = useState(() => columns.map(({ id }) => id)); + const [visibleColumns, setVisibleColumns] = useState(() => + columns.map(({ id }) => id).filter((columnId) => columnId !== 'id') + ); const columnVisibility = useMemo(() => ({ visibleColumns, setVisibleColumns }), [visibleColumns]); const renderCellValue = useCallback( diff --git a/public/components/model/model_versions_panel/model_version_table_row_actions.tsx b/public/components/model/model_versions_panel/model_version_table_row_actions.tsx index e4511130..e08f37bf 100644 --- a/public/components/model/model_versions_panel/model_version_table_row_actions.tsx +++ b/public/components/model/model_versions_panel/model_version_table_row_actions.tsx @@ -4,13 +4,7 @@ */ import React, { useState, useCallback } from 'react'; -import { - EuiPopover, - EuiButtonIcon, - EuiContextMenuPanel, - EuiContextMenuItem, - copyToClipboard, -} from '@elastic/eui'; +import { EuiPopover, EuiButtonIcon, EuiContextMenuPanel, EuiContextMenuItem } from '@elastic/eui'; import { MODEL_STATE } from '../../../../common'; @@ -45,16 +39,6 @@ export const ModelVersionTableRowActions = ({ state, id }: { state: MODEL_STATE; { - copyToClipboard(id); - }} - style={{ padding: 8 }} - > - Copy ID - , ...(state === MODEL_STATE.registerFailed ? [ Date: Fri, 12 May 2023 12:15:11 +0800 Subject: [PATCH 45/75] Version details page mockup (#179) feat(ui): version details page mockup Signed-off-by: Yulong Ruan --------- Signed-off-by: Yulong Ruan Signed-off-by: Lin Wang --- common/constant.ts | 2 + public/apis/model.ts | 3 + .../__tests__/model_version_table.test.tsx | 2 +- .../model_version_status_detail.tsx | 4 +- .../__tests__/model_version.test.tsx | 7 +- .../model_version/model_version.tsx | 123 ++++++++++++++---- .../model_version/version_artifact.tsx | 34 +++++ .../model_version/version_callout.tsx | 83 ++++++++++++ .../model_version/version_details.tsx | 85 ++++++++++++ .../model_version/version_information.tsx | 34 +++++ .../components/model_version/version_tags.tsx | 34 +++++ test/mocks/model_handlers.ts | 43 +++++- 12 files changed, 420 insertions(+), 34 deletions(-) create mode 100644 public/components/model_version/version_artifact.tsx create mode 100644 public/components/model_version/version_callout.tsx create mode 100644 public/components/model_version/version_details.tsx create mode 100644 public/components/model_version/version_information.tsx create mode 100644 public/components/model_version/version_tags.tsx diff --git a/common/constant.ts b/common/constant.ts index 742b1475..2f828a38 100644 --- a/common/constant.ts +++ b/common/constant.ts @@ -7,3 +7,5 @@ export const ONE_MB = 1000 * 1000; export const ONE_GB = 1000 * ONE_MB; export const MAX_MODEL_CHUNK_SIZE = 10 * ONE_MB; + +export const DATE_FORMAT = 'MMM d, yyyy @ HH:mm:ss.SSS'; diff --git a/public/apis/model.ts b/public/apis/model.ts index 7a6742a2..753001da 100644 --- a/public/apis/model.ts +++ b/public/apis/model.ts @@ -16,6 +16,9 @@ import { InnerHttpProvider } from './inner_http_provider'; export interface ModelSearchItem { id: string; name: string; + // TODO: the new version details API may not have this field, because model description is on model group level + // we should fix this when integrating the new API changes + description: string; algorithm: string; model_state: MODEL_STATE; model_version: string; diff --git a/public/components/model/model_versions_panel/__tests__/model_version_table.test.tsx b/public/components/model/model_versions_panel/__tests__/model_version_table.test.tsx index 8413263d..f94f7e24 100644 --- a/public/components/model/model_versions_panel/__tests__/model_version_table.test.tsx +++ b/public/components/model/model_versions_panel/__tests__/model_version_table.test.tsx @@ -198,7 +198,7 @@ describe('', () => { document.execCommand = execCommandOrigin; }, - 10 * 1000 + 20 * 1000 ); it( diff --git a/public/components/model/model_versions_panel/model_version_status_detail.tsx b/public/components/model/model_versions_panel/model_version_status_detail.tsx index d3475a67..5d3d0c5a 100644 --- a/public/components/model/model_versions_panel/model_version_status_detail.tsx +++ b/public/components/model/model_versions_panel/model_version_status_detail.tsx @@ -14,7 +14,7 @@ import { } from '@elastic/eui'; import { Link, generatePath } from 'react-router-dom'; -import { MODEL_STATE, routerPaths } from '../../../../common'; +import { DATE_FORMAT, MODEL_STATE, routerPaths } from '../../../../common'; import { APIProvider } from '../../../apis/api_provider'; import { renderTime } from '../../../utils'; @@ -163,7 +163,7 @@ export const ModelVersionStatusDetail = ({ - {timeTitle}: {renderTime(restProps[timeField], 'MMM d, yyyy @ HH:mm:ss.SSS')} + {timeTitle}: {renderTime(restProps[timeField], DATE_FORMAT)} {(state === MODEL_STATE.loadFailed || state === MODEL_STATE.registerFailed) && ( <> diff --git a/public/components/model_version/__tests__/model_version.test.tsx b/public/components/model_version/__tests__/model_version.test.tsx index 91d403ad..746f5c98 100644 --- a/public/components/model_version/__tests__/model_version.test.tsx +++ b/public/components/model_version/__tests__/model_version.test.tsx @@ -37,13 +37,16 @@ describe('', () => { setup(); expect(screen.getByTestId('modelVersionLoadingSpinner')).toBeInTheDocument(); + expect(screen.queryAllByTestId('ml-versionDetailsLoading')).toBeTruthy(); + await waitFor(() => { expect(screen.queryByTestId('modelVersionLoadingSpinner')).not.toBeInTheDocument(); + expect(screen.queryByTestId('ml-versionDetailsLoading')).not.toBeInTheDocument(); }); }); it('should display v1.0.1 and update location.pathname after version selected', async () => { - const mockRest = mockOffsetMethods(); + const mockReset = mockOffsetMethods(); const user = userEvent.setup(); setup(); @@ -60,6 +63,6 @@ describe('', () => { }); expect(location.pathname).toBe('/model-registry/model-version/2'); - mockRest(); + mockReset(); }); }); diff --git a/public/components/model_version/model_version.tsx b/public/components/model_version/model_version.tsx index 7b0c92df..4328da61 100644 --- a/public/components/model_version/model_version.tsx +++ b/public/components/model_version/model_version.tsx @@ -10,6 +10,10 @@ import { EuiFlexGroup, EuiFlexItem, EuiLoadingSpinner, + EuiSpacer, + EuiPanel, + EuiLoadingContent, + EuiTabbedContent, } from '@elastic/eui'; import { generatePath, useHistory, useParams } from 'react-router-dom'; @@ -17,10 +21,16 @@ import { useFetcher } from '../../hooks'; import { APIProvider } from '../../apis/api_provider'; import { routerPaths } from '../../../common/router_paths'; import { VersionToggler } from './version_toggler'; +import { ModelVersionCallout } from './version_callout'; +import { MODEL_STATE } from '../../../common/model'; +import { ModelVersionDetails } from './version_details'; +import { ModelVersionInformation } from './version_information'; +import { ModelVersionArtifact } from './version_artifact'; +import { ModelVersionTags } from './version_tags'; export const ModelVersion = () => { const { id: modelId } = useParams<{ id: string }>(); - const { data: model } = useFetcher(APIProvider.getAPI('model').getOne, modelId); + const { data: model, loading } = useFetcher(APIProvider.getAPI('model').getOne, modelId); const [modelInfo, setModelInfo] = useState<{ version: string; name: string }>(); const history = useHistory(); const modelName = model?.name; @@ -51,34 +61,93 @@ export const ModelVersion = () => { }); }, [modelName, modelVersion]); - if (!modelInfo) { - return ; - } + const tabs = [ + { + id: 'version-information', + name: 'Version information', + content: loading ? ( + <> + + + + + + ) : ( + <> + + + + + + ), + }, + { + id: 'artifact-configuration', + name: 'Artifact and configuration', + content: loading ? ( + <> + + + + + + ) : ( + <> + + + + ), + }, + ]; + return ( <> - - {modelInfo.name} - - - -
    - } - rightSideGroupProps={{ - gutterSize: 'm', - }} - rightSideItems={[ - Register version, - Edit, - Deploy, - Delete, - ]} - /> + {!modelInfo ? ( + <> + + + + ) : ( + + {modelInfo.name} + + + +
    + } + rightSideGroupProps={{ + gutterSize: 'm', + }} + rightSideItems={[ + Register version, + Deploy, + Delete, + ]} + /> + )} + + + + {loading ? ( + + + + ) : ( + + )} + + ); }; diff --git a/public/components/model_version/version_artifact.tsx b/public/components/model_version/version_artifact.tsx new file mode 100644 index 00000000..f3c29cf1 --- /dev/null +++ b/public/components/model_version/version_artifact.tsx @@ -0,0 +1,34 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React from 'react'; +import { + EuiButton, + EuiFlexGroup, + EuiFlexItem, + EuiHorizontalRule, + EuiPanel, + EuiSpacer, + EuiTitle, +} from '@elastic/eui'; + +export const ModelVersionArtifact = () => { + return ( + + + + +

    Artifact and configuration

    +
    +
    + + Edit + +
    + + +
    + ); +}; diff --git a/public/components/model_version/version_callout.tsx b/public/components/model_version/version_callout.tsx new file mode 100644 index 00000000..8e2206db --- /dev/null +++ b/public/components/model_version/version_callout.tsx @@ -0,0 +1,83 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React, { useEffect } from 'react'; +import { EuiCallOut, EuiLoadingSpinner } from '@elastic/eui'; +import { MODEL_STATE } from '../../../common/model'; + +interface ModelVersionCalloutProps { + modelVersionId: string; + modelState: MODEL_STATE; +} + +const MODEL_STATE_MAPPING: { + [K in MODEL_STATE]?: { + title: React.ReactNode; + color: 'danger' | 'warning' | 'primary'; + iconType?: string; + }; +} = { + [MODEL_STATE.registerFailed]: { + title: 'Artifact upload failed', + color: 'danger' as const, + iconType: 'alert', + }, + [MODEL_STATE.loadFailed]: { + title: 'Deployment failed', + color: 'danger' as const, + iconType: 'alert', + }, + [MODEL_STATE.partiallyLoaded]: { + title: 'Model partially responding', + color: 'warning' as const, + iconType: 'alert', + }, + [MODEL_STATE.uploading]: { + title: ( + + + Model artifact upload in progress + + ), + color: 'primary' as const, + }, + [MODEL_STATE.loading]: { + title: ( + + + Model deployment in progress + + ), + color: 'primary' as const, + }, +}; + +export const ModelVersionCallout = ({ modelState, modelVersionId }: ModelVersionCalloutProps) => { + const calloutProps = MODEL_STATE_MAPPING[modelState]; + + useEffect(() => { + if (calloutProps) { + if (modelState === MODEL_STATE.loadFailed) { + // TODO: call task API to get the error details + } else if (modelState === MODEL_STATE.registerFailed) { + // TODO: call task API to get the error details + } + } + }, [modelVersionId, modelState, calloutProps]); + + if (!calloutProps) { + return null; + } + + return ( + + Error details: TODO + + ); +}; diff --git a/public/components/model_version/version_details.tsx b/public/components/model_version/version_details.tsx new file mode 100644 index 00000000..ac6b1795 --- /dev/null +++ b/public/components/model_version/version_details.tsx @@ -0,0 +1,85 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React from 'react'; +import { + EuiFlexGroup, + EuiFlexItem, + EuiPanel, + EuiTitle, + EuiText, + EuiSpacer, + EuiCopy, + EuiIcon, +} from '@elastic/eui'; +import { renderTime } from '../../utils'; +import { DATE_FORMAT } from '../../../common/constant'; + +interface Props { + description?: string; + createdTime?: number; + lastUpdatedTime?: number; + modelId?: string; +} + +export const ModelVersionDetails = ({ + description, + createdTime, + lastUpdatedTime, + modelId, +}: Props) => { + return ( + + +

    Model description

    +
    + + {description || '-'} + + + +

    Version notes

    +
    + + TODO + + + + + +

    Owner

    +
    + TODO +
    + + +

    Created

    +
    + {createdTime ? renderTime(createdTime, DATE_FORMAT) : '-'} +
    + + +

    Last updated

    +
    + + {lastUpdatedTime ? renderTime(lastUpdatedTime, DATE_FORMAT) : '-'} + +
    + + +

    ID

    +
    + + {(copy) => ( + + {modelId ?? '-'} + + )} + +
    +
    +
    + ); +}; diff --git a/public/components/model_version/version_information.tsx b/public/components/model_version/version_information.tsx new file mode 100644 index 00000000..23794182 --- /dev/null +++ b/public/components/model_version/version_information.tsx @@ -0,0 +1,34 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import { + EuiFlexGroup, + EuiFlexItem, + EuiHorizontalRule, + EuiPanel, + EuiSpacer, + EuiTitle, + EuiButton, +} from '@elastic/eui'; +import React from 'react'; + +export const ModelVersionInformation = () => { + return ( + + + + +

    Version Information

    +
    +
    + + Edit + +
    + + +
    + ); +}; diff --git a/public/components/model_version/version_tags.tsx b/public/components/model_version/version_tags.tsx new file mode 100644 index 00000000..948fc84d --- /dev/null +++ b/public/components/model_version/version_tags.tsx @@ -0,0 +1,34 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import { + EuiFlexGroup, + EuiFlexItem, + EuiHorizontalRule, + EuiPanel, + EuiSpacer, + EuiTitle, + EuiButton, +} from '@elastic/eui'; +import React from 'react'; + +export const ModelVersionTags = () => { + return ( + + + + +

    Tags

    +
    +
    + + Edit + +
    + + +
    + ); +}; diff --git a/test/mocks/model_handlers.ts b/test/mocks/model_handlers.ts index f81f3f5e..68358c28 100644 --- a/test/mocks/model_handlers.ts +++ b/test/mocks/model_handlers.ts @@ -12,16 +12,55 @@ const models = [ id: '1', name: 'model1', model_version: '1.0.0', + description: 'model1 description', + created_time: 1683699467964, + last_registered_time: 1683699499632, + last_updated_time: 1683699499637, + model_config: { + all_config: '', + embedding_dimension: 768, + framework_type: 'SENTENCE_TRANSFORMERS', + model_type: 'roberta', + }, + model_format: 'TORCH_SCRIPT', + model_state: 'REGISTERED', + total_chunks: 34, }, { id: '2', - model: 'model1', + name: 'model2', model_version: '1.0.1', + description: 'model2 description', + created_time: 1683699467964, + last_registered_time: 1683699499632, + last_updated_time: 1683699499637, + model_config: { + all_config: '', + embedding_dimension: 768, + framework_type: 'SENTENCE_TRANSFORMERS', + model_type: 'roberta', + }, + model_format: 'TORCH_SCRIPT', + model_state: 'REGISTERED', + total_chunks: 34, }, { id: '3', - model: 'model2', + name: 'model3', model_version: '1.0.0', + description: 'model3 description', + created_time: 1683699467964, + last_registered_time: 1683699499632, + last_updated_time: 1683699499637, + model_config: { + all_config: '', + embedding_dimension: 768, + framework_type: 'SENTENCE_TRANSFORMERS', + model_type: 'roberta', + }, + model_format: 'TORCH_SCRIPT', + model_state: 'DEPLOYED', + total_chunks: 34, }, ]; From b6e4dae198317359d62250aaf6667d5c41cfe145 Mon Sep 17 00:00:00 2001 From: Lin Wang Date: Fri, 12 May 2023 18:24:14 +0800 Subject: [PATCH 46/75] Feature/add details tab content in model group page (#176) * feat: change to use tab id store selected tab Signed-off-by: Lin Wang * feat: display - for empty description Signed-off-by: Lin Wang * feat: add bottom form action bar Signed-off-by: Lin Wang * feat: add name and description in details panel Signed-off-by: Lin Wang * feat: separate ModelNameField, ModelDescriptionField and ErrorCallOut to common folder Signed-off-by: Lin Wang * feat: add readonly and originalModelname to ModelNameField Signed-off-by: Lin Wang * feat: add readonly to ModelDescriptionField Signed-off-by: Lin Wang * feat: update error call out according new design Signed-off-by: Lin Wang * test: update model_handlers error field and name search Signed-off-by: Lin Wang * feat: change to use model name description field in common folder Signed-off-by: Lin Wang --------- Signed-off-by: Lin Wang --- .../forms/__tests__/error_call_out.test.tsx | 58 ++++++ .../model_description_field.test.tsx | 58 ++++++ .../forms/__tests__/model_name_field.test.tsx | 101 ++++++++++ .../forms}/error_call_out.tsx | 30 +-- public/components/common/forms/index.ts | 8 + .../common/forms/model_description_field.tsx | 58 ++++++ .../common/forms/model_name_field.tsx | 106 +++++++++++ public/components/common/index.ts | 1 + .../__tests__/bottom_form_action_bar.test.tsx | 65 +++++++ .../components/model/__tests__/model.test.tsx | 15 ++ .../__tests__/model_details_panel.test.tsx | 148 +++++++++++++++ .../__tests__/model_overview_card.test.tsx | 15 ++ .../model/bottom_form_action_bar.tsx | 86 +++++++++ public/components/model/model.tsx | 19 +- .../components/model/model_details_panel.tsx | 179 +++++++++++++++++- .../components/model/model_overview_card.tsx | 2 +- .../__tests__/model_versions_panel.test.tsx | 2 +- .../__tests__/model_version.test.tsx | 2 +- .../__tests__/version_toggler.test.tsx | 2 +- public/components/register_model/constants.ts | 5 +- .../register_model/model_details.tsx | 98 +--------- .../register_model/register_model.tsx | 28 +-- test/mocks/model_handlers.ts | 31 ++- 23 files changed, 983 insertions(+), 134 deletions(-) create mode 100644 public/components/common/forms/__tests__/error_call_out.test.tsx create mode 100644 public/components/common/forms/__tests__/model_description_field.test.tsx create mode 100644 public/components/common/forms/__tests__/model_name_field.test.tsx rename public/components/{register_model => common/forms}/error_call_out.tsx (68%) create mode 100644 public/components/common/forms/index.ts create mode 100644 public/components/common/forms/model_description_field.tsx create mode 100644 public/components/common/forms/model_name_field.tsx create mode 100644 public/components/model/__tests__/bottom_form_action_bar.test.tsx create mode 100644 public/components/model/__tests__/model_details_panel.test.tsx create mode 100644 public/components/model/bottom_form_action_bar.tsx diff --git a/public/components/common/forms/__tests__/error_call_out.test.tsx b/public/components/common/forms/__tests__/error_call_out.test.tsx new file mode 100644 index 00000000..59ac499c --- /dev/null +++ b/public/components/common/forms/__tests__/error_call_out.test.tsx @@ -0,0 +1,58 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +import React from 'react'; +import { screen, render } from '../../../../../test/test_utils'; + +import { ErrorCallOut } from '../error_call_out'; + +describe('', () => { + it('should NOT render call out if errors is empty', () => { + render(); + + expect(screen.queryByLabelText('Address errors in the form')).toBeNull(); + }); + + it('should render error call out if errors is not empty', () => { + render( + + ); + + expect(screen.getByLabelText('Address errors in the form')).toBeInTheDocument(); + expect(screen.getByText(/Name: Enter a name./)).toBeInTheDocument(); + expect(screen.getByText(/Name: Use a unique name./)).toBeInTheDocument(); + expect(screen.getByText(/File: Add a file./)).toBeInTheDocument(); + }); +}); diff --git a/public/components/common/forms/__tests__/model_description_field.test.tsx b/public/components/common/forms/__tests__/model_description_field.test.tsx new file mode 100644 index 00000000..afd3e107 --- /dev/null +++ b/public/components/common/forms/__tests__/model_description_field.test.tsx @@ -0,0 +1,58 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React from 'react'; +import { useForm } from 'react-hook-form'; +import userEvent from '@testing-library/user-event'; + +import { render, screen } from '../../../../../test/test_utils'; + +import { ModelDescriptionField } from '../model_description_field'; + +const setup = (description: string = '', readOnly = false) => { + const DescriptionForm = () => { + const { control } = useForm({ + defaultValues: { + description, + }, + }); + return ; + }; + + render(); + + const input = screen.getByLabelText(/description/i); + + return { + input, + getHelpTextNode: () => input.nextSibling, + }; +}; + +describe('', () => { + it('should render "Description" title, content and "200 characters allowed" by default', () => { + const { input, getHelpTextNode } = setup(); + expect(screen.getByText(/Description/i)).toBeInTheDocument(); + expect(input).toHaveValue(''); + expect(getHelpTextNode()).toHaveTextContent('200 characters allowed'); + }); + + it('should display 200 characters and show "0 characters left" after input 201 characters', async () => { + const { input, getHelpTextNode } = setup(); + + await userEvent.type(input, 'x'.repeat(201)); + + expect(input.value).toHaveLength(200); + + expect(getHelpTextNode()).toHaveTextContent('0 characters left'); + }); + + it('should set textarea to readOnly and hide help text', async () => { + const { input, getHelpTextNode } = setup('foo', true); + + expect(input).toHaveAttribute('readonly'); + expect(getHelpTextNode()).not.toBeInTheDocument(); + }); +}); diff --git a/public/components/common/forms/__tests__/model_name_field.test.tsx b/public/components/common/forms/__tests__/model_name_field.test.tsx new file mode 100644 index 00000000..7bbaf90b --- /dev/null +++ b/public/components/common/forms/__tests__/model_name_field.test.tsx @@ -0,0 +1,101 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React from 'react'; +import { useForm } from 'react-hook-form'; +import userEvent from '@testing-library/user-event'; + +import { render, screen } from '../../../../../test/test_utils'; + +import { ModelNameField } from '../model_name_field'; + +const setup = ({ + name = '', + readOnly = false, + originalModelName, +}: { name?: string; readOnly?: boolean; originalModelName?: string } = {}) => { + const NameForm = () => { + const { control, trigger } = useForm({ + mode: 'onChange', + defaultValues: { + name, + }, + }); + return ( + + ); + }; + + render(); + + const input = screen.getByLabelText(/name/i); + + return { + input, + getHelpTextNode: () => input.closest('.euiFormRow')?.querySelector('.euiFormHelpText'), + getErrorMessageNode: () => input.closest('.euiFormRow')?.querySelector('.euiFormErrorText'), + }; +}; + +describe('', () => { + it('should render "Name" title, content and "80 characters allowed" by default', () => { + const { input, getHelpTextNode } = setup(); + expect(screen.getByText('Name')).toBeInTheDocument(); + expect(input).toHaveValue(''); + expect(getHelpTextNode()).toHaveTextContent('80 characters allowed'); + }); + + it('should show 80 characters and "0 characters left" after 81 characters input', async () => { + const { input, getHelpTextNode } = setup(); + + await userEvent.type(input, 'x'.repeat(81)); + expect(input.value).toHaveLength(80); + expect(getHelpTextNode()).toHaveTextContent('0 characters left'); + }); + + it('should show "Name can not be empty" error message after name be cleared', async () => { + const { input, getErrorMessageNode } = setup({ name: '12345' }); + + await userEvent.clear(input); + + expect(getErrorMessageNode()).toHaveTextContent('Name can not be empty'); + }); + + it('should show "This name is already in use. Use a unique name for the model." after name duplicated', async () => { + const { input, getErrorMessageNode } = setup({ name: '12345' }); + + await userEvent.clear(input); + await userEvent.type(input, 'model1'); + // mock user blur + await userEvent.click(screen.getByText('Name')); + + expect(getErrorMessageNode()).toHaveTextContent( + 'This name is already in use. Use a unique name for the model.' + ); + }); + + it('should NOT show name duplicate error if changed name equal original name', async () => { + const { input, getErrorMessageNode } = setup({ name: 'model1', originalModelName: 'model1' }); + + await userEvent.clear(input); + await userEvent.type(input, 'model1'); + // mock user blur + await userEvent.click(screen.getByText('Name')); + + expect(getErrorMessageNode()).not.toBeInTheDocument(); + }); + + it('should set input to readOnly and hide help text', async () => { + const { input, getHelpTextNode } = setup({ name: 'foo', readOnly: true }); + + expect(input).toHaveAttribute('readonly'); + expect(getHelpTextNode()).not.toBeInTheDocument(); + }); +}); diff --git a/public/components/register_model/error_call_out.tsx b/public/components/common/forms/error_call_out.tsx similarity index 68% rename from public/components/register_model/error_call_out.tsx rename to public/components/common/forms/error_call_out.tsx index 84607aaa..4591f4d3 100644 --- a/public/components/register_model/error_call_out.tsx +++ b/public/components/common/forms/error_call_out.tsx @@ -5,24 +5,28 @@ import { EuiCallOut, EuiText } from '@elastic/eui'; import React, { useMemo } from 'react'; -import { useFormContext } from 'react-hook-form'; -import { FORM_ERRORS } from './constants'; +import { FieldErrors } from 'react-hook-form'; -import type { ModelFileFormData, ModelUrlFormData } from './register_model.types'; - -export const ErrorCallOut = () => { - const form = useFormContext(); +interface ErrorCallOutProps { + formErrors: FieldErrors; + errorMessages: Array<{ + field: string; + type: string; + message: string; + }>; +} +export const ErrorCallOut = ({ formErrors, errorMessages }: ErrorCallOutProps) => { const errors = useMemo(() => { const messages: string[] = []; - Object.keys(form.formState.errors).forEach((errorField) => { - const error = form.formState.errors[errorField as keyof typeof form.formState.errors]; + Object.keys(formErrors).forEach((errorField) => { + const error = formErrors[errorField as keyof typeof formErrors]; // If form have: criteriaMode: 'all', error.types will be set a value // error.types will contain all the errors of each field // In this case, we will display all the errors in the callout if (error?.types) { Object.keys(error.types).forEach((k) => { - const errorMessage = FORM_ERRORS.find((e) => e.field === errorField && e.type === k); + const errorMessage = errorMessages.find((e) => e.field === errorField && e.type === k); if (errorMessage) { messages.push(errorMessage.message); } @@ -32,7 +36,7 @@ export const ErrorCallOut = () => { // to only produce the first error, even if a field has multiple errors. // In this case, error.types won't be set, and error.type and error.field represent the // first error - const errorMessage = FORM_ERRORS.find( + const errorMessage = errorMessages.find( (e) => e.field === errorField && e.type === error?.type ); if (errorMessage) { @@ -41,7 +45,7 @@ export const ErrorCallOut = () => { } }); return messages; - }, [form]); + }, [formErrors, errorMessages]); if (errors.length === 0) { return null; @@ -55,9 +59,9 @@ export const ErrorCallOut = () => { iconType="iInCircle" > -
      +
        {errors.map((e) => ( -
      • - {e}
      • +
      • {e}
      • ))}
      diff --git a/public/components/common/forms/index.ts b/public/components/common/forms/index.ts new file mode 100644 index 00000000..5039a354 --- /dev/null +++ b/public/components/common/forms/index.ts @@ -0,0 +1,8 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +export * from './model_description_field'; +export * from './model_name_field'; +export * from './error_call_out'; diff --git a/public/components/common/forms/model_description_field.tsx b/public/components/common/forms/model_description_field.tsx new file mode 100644 index 00000000..13c44f44 --- /dev/null +++ b/public/components/common/forms/model_description_field.tsx @@ -0,0 +1,58 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React from 'react'; + +import { EuiFormRow, EuiTextArea } from '@elastic/eui'; +import { Control, FieldPathByValue, useController } from 'react-hook-form'; + +interface ModeDescriptionFormData { + description: string; +} + +const DESCRIPTION_MAX_LENGTH = 200; + +interface ModelDescriptionFieldProps { + control: Control; + readOnly?: boolean; +} + +export const ModelDescriptionField = ({ + control, + readOnly, +}: ModelDescriptionFieldProps) => { + const descriptionFieldController = useController({ + name: 'description' as FieldPathByValue, + control, + }); + + const { ref: descriptionInputRef, ...descriptionField } = descriptionFieldController.field; + + return ( + + Description - optional + + } + isInvalid={Boolean(descriptionFieldController.fieldState.error)} + error={descriptionFieldController.fieldState.error?.message} + helpText={ + !readOnly && + `${Math.max(DESCRIPTION_MAX_LENGTH - descriptionField.value.length, 0)} characters ${ + descriptionField.value.length ? 'left' : 'allowed' + }.` + } + > + + + ); +}; diff --git a/public/components/common/forms/model_name_field.tsx b/public/components/common/forms/model_name_field.tsx new file mode 100644 index 00000000..633e9bee --- /dev/null +++ b/public/components/common/forms/model_name_field.tsx @@ -0,0 +1,106 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React, { useCallback, useRef } from 'react'; +import { EuiFieldText, EuiFormRow, EuiText } from '@elastic/eui'; +import { Control, FieldPathByValue, UseFormTrigger, useController } from 'react-hook-form'; + +import { APIProvider } from '../../../apis/api_provider'; + +export const MODEL_NAME_FIELD_DUPLICATE_NAME_ERROR = 'duplicateName'; + +const NAME_MAX_LENGTH = 80; + +interface ModelNameFormData { + name: string; +} + +interface ModelNameFieldProps { + control: Control; + trigger: UseFormTrigger; + readOnly?: boolean; + originalModelName?: string; +} + +const isDuplicateModelName = async (name: string) => { + const searchResult = await APIProvider.getAPI('model').search({ + name, + from: 0, + size: 1, + }); + return searchResult.total_models >= 1; +}; + +export const ModelNameField = ({ + control, + trigger, + readOnly, + originalModelName, +}: ModelNameFieldProps) => { + const modelNameFocusedRef = useRef(false); + const nameFieldController = useController({ + name: 'name' as FieldPathByValue, + control, + rules: { + required: { value: true, message: 'Name can not be empty' }, + validate: { + [MODEL_NAME_FIELD_DUPLICATE_NAME_ERROR]: async (name) => { + if ( + modelNameFocusedRef.current || + !name || + (originalModelName !== undefined && originalModelName === name) + ) { + return undefined; + } + const result = await isDuplicateModelName(name); + if (result) { + return 'This name is already in use. Use a unique name for the model.'; + } + return undefined; + }, + }, + }, + }); + + const { ref: nameInputRef, ...nameField } = nameFieldController.field; + + const handleModelNameFocus = useCallback(() => { + modelNameFocusedRef.current = true; + }, []); + + const handleModelNameBlur = useCallback(() => { + nameField.onBlur(); + modelNameFocusedRef.current = false; + trigger('name' as FieldPathByValue, {}); + }, [nameField, trigger]); + + return ( + + {Math.max(NAME_MAX_LENGTH - nameField.value.length, 0)} characters{' '} + {nameField.value.length ? 'left' : 'allowed'}. +
      + Use a unique name for the model. + + ) + } + > + +
      + ); +}; diff --git a/public/components/common/index.ts b/public/components/common/index.ts index 983d8f42..a4a67f3f 100644 --- a/public/components/common/index.ts +++ b/public/components/common/index.ts @@ -10,3 +10,4 @@ export * from './selected_tag_filter_panel'; export * from './debounced_search_bar'; export * from './tag_filter'; export * from './options_filter'; +export * from './forms'; diff --git a/public/components/model/__tests__/bottom_form_action_bar.test.tsx b/public/components/model/__tests__/bottom_form_action_bar.test.tsx new file mode 100644 index 00000000..47b9ae92 --- /dev/null +++ b/public/components/model/__tests__/bottom_form_action_bar.test.tsx @@ -0,0 +1,65 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React from 'react'; +import userEvent from '@testing-library/user-event'; + +import { render, screen } from '../../../../test/test_utils'; +import { BottomFormActionBar } from '../bottom_form_action_bar'; + +describe('', () => { + it('should display consistent unsaved changes and error count', () => { + render( + + ); + + expect(screen.getByText('2 unsaved change(s)')).toBeInTheDocument(); + expect(screen.getByText('1 error(s)')).toBeInTheDocument(); + }); + + it('should call onDiscardButtonClick when discard button is clicked', async () => { + const onDiscardButtonClickMock = jest.fn(); + render( + + ); + + expect(onDiscardButtonClickMock).not.toHaveBeenCalled(); + await userEvent.click(screen.getByText('Discard change(s)')); + expect(onDiscardButtonClickMock).toHaveBeenCalledTimes(1); + }); + + it('should submit form after save button clicked', async () => { + const formId = 'test-form-id'; + const onSubmitMock = jest.fn((e) => e.preventDefault()); + render( +
      + + + ); + + expect(onSubmitMock).not.toHaveBeenCalled(); + await userEvent.click(screen.getByText('Save')); + expect(onSubmitMock).toHaveBeenCalledTimes(1); + }); + + it('should disabled save button and show loading indicator', () => { + render( + + ); + + expect(screen.getByText('Save').closest('button')).toBeDisabled(); + expect(screen.getByText('Save').previousSibling).toHaveClass('euiLoadingSpinner'); + }); +}); diff --git a/public/components/model/__tests__/model.test.tsx b/public/components/model/__tests__/model.test.tsx index ab743271..1255d19e 100644 --- a/public/components/model/__tests__/model.test.tsx +++ b/public/components/model/__tests__/model.test.tsx @@ -66,4 +66,19 @@ describe('', () => { }, 10 * 1000 ); + + it( + 'should display model name in details tab', + async () => { + setup(); + + await waitFor(() => { + expect(screen.queryByTestId('model-group-loading-indicator')).toBeNull(); + }); + await userEvent.click(screen.getByRole('tab', { name: 'Details' })); + + expect(within(screen.getByRole('tabpanel')).getByDisplayValue('model1')).toBeInTheDocument(); + }, + 10 * 1000 + ); }); diff --git a/public/components/model/__tests__/model_details_panel.test.tsx b/public/components/model/__tests__/model_details_panel.test.tsx new file mode 100644 index 00000000..f2486c36 --- /dev/null +++ b/public/components/model/__tests__/model_details_panel.test.tsx @@ -0,0 +1,148 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React from 'react'; +import userEvent from '@testing-library/user-event'; + +import { act, render, screen, within } from '../../../../test/test_utils'; +import { ModelDetailsPanel } from '../model_details_panel'; + +import * as PluginContext from '../../../../../../src/plugins/opensearch_dashboards_react/public'; + +// Cannot spyOn(PluginContext, 'useOpenSearchDashboards') directly as it results in error: +// TypeError: Cannot redefine property: useOpenSearchDashboards +// So we have to mock the entire module first as a workaround +jest.mock('../../../../../../src/plugins/opensearch_dashboards_react/public', () => { + return { + __esModule: true, + ...jest.requireActual('../../../../../../src/plugins/opensearch_dashboards_react/public'), + }; +}); + +describe('', () => { + it('should render edit button and name, description in read-only mode by default', () => { + render(); + + expect(screen.getByText('Edit')).toBeInTheDocument(); + expect(screen.getByDisplayValue('model-1')).toBeInTheDocument(); + expect(screen.getByDisplayValue('model-1')).toHaveAttribute('readonly'); + expect(screen.getByText('model-1 description')).toBeInTheDocument(); + expect(screen.getByDisplayValue('model-1 description')).toHaveAttribute('readonly'); + }); + + it('should turn edit mode on when edit button is clicked', async () => { + render(); + + await userEvent.click(screen.getByText('Edit')); + expect(screen.getByText('Cancel')).toBeInTheDocument(); + + expect(screen.getByDisplayValue('model-1')).toBeInTheDocument(); + expect(screen.getByDisplayValue('model-1')).not.toHaveAttribute('readonly'); + expect(screen.getByText(/Use a unique name for the model./)).toBeInTheDocument(); + + expect(screen.getByText('model-1 description')).toBeInTheDocument(); + expect(screen.getByDisplayValue('model-1 description')).not.toHaveAttribute('readonly'); + }); + + it('should NOT allow type more than 80 characters and show 0 characters left for name', async () => { + render(); + + const nameInput = screen.getByLabelText(/Name/); + + await userEvent.click(screen.getByText('Edit')); + await userEvent.clear(nameInput); + await userEvent.type(nameInput, 'x'.repeat(81)); + expect((nameInput as HTMLInputElement).value).toHaveLength(80); + expect( + within(nameInput.closest('.euiFormRow')!).getByText(/0 characters left./) + ).toBeInTheDocument(); + }); + + it('should NOT allow input characters more than 200 and show 0 characters left for description', async () => { + render(); + + const descriptionInput = screen.getByLabelText(/Description/); + + await userEvent.click(screen.getByText('Edit')); + await userEvent.clear(descriptionInput); + await userEvent.type(descriptionInput, 'x'.repeat(201)); + expect((descriptionInput as HTMLInputElement).value).toHaveLength(200); + expect( + within(descriptionInput.closest('.euiFormRow')!).getByText(/0 characters left./) + ).toBeInTheDocument(); + }); + + it('should show unsaved changes count, error count, discard changes and save button after name input value cleared', async () => { + render(); + + await userEvent.click(screen.getByText('Edit')); + await userEvent.clear(screen.getByDisplayValue('model-1')); + + expect(screen.getByText('1 error(s)')).toBeInTheDocument(); + expect(screen.getByText('1 unsaved change(s)')).toBeInTheDocument(); + expect(screen.getByText('Discard change(s)')).toBeInTheDocument(); + expect(screen.getByText('Save')).toBeInTheDocument(); + }); + + it('should reset to default name description after discard changes button clicked', async () => { + render(); + + const nameInput = screen.getByLabelText(/Name/); + const descriptionInput = screen.getByLabelText(/Description/); + + await userEvent.click(screen.getByText('Edit')); + await userEvent.type(nameInput, 'updated'); + await userEvent.type(descriptionInput, 'description of model-1'); + + await userEvent.click(screen.getByText('Discard change(s)')); + expect(nameInput).toHaveValue('model-1'); + expect(descriptionInput).toHaveValue(''); + }); + + it('should show error callout after save button clicked', async () => { + render(); + + await userEvent.click(screen.getByText('Edit')); + await userEvent.clear(screen.getByDisplayValue('model-1')); + await userEvent.click(screen.getByText('Save')); + + expect(screen.getByText('Address the following error(s) in the form')).toBeInTheDocument(); + expect(screen.getByText('Name: Enter a name.')).toBeInTheDocument(); + }); + + it('should call addSuccessToast after form submit successfully', async () => { + const addSuccessMock = jest.fn(); + const user = userEvent.setup({ advanceTimers: jest.advanceTimersByTime }); + jest.useFakeTimers(); + + const opensearchDashboardsMock = jest + .spyOn(PluginContext, 'useOpenSearchDashboards') + .mockReturnValue({ + services: { + notifications: { + toasts: { + addSuccess: addSuccessMock, + }, + }, + }, + }); + + render(); + + await user.click(screen.getByText('Edit')); + await user.type(screen.getByDisplayValue('model-1'), 'updated'); + await user.click(screen.getByText('Save')); + + // TODO: Remove it after integrated real model update API + await act(async () => { + jest.advanceTimersByTime(2000); + }); + + expect(addSuccessMock).toHaveBeenCalledTimes(1); + + jest.useRealTimers(); + opensearchDashboardsMock.mockRestore(); + }); +}); diff --git a/public/components/model/__tests__/model_overview_card.test.tsx b/public/components/model/__tests__/model_overview_card.test.tsx index 56809433..25e690ce 100644 --- a/public/components/model/__tests__/model_overview_card.test.tsx +++ b/public/components/model/__tests__/model_overview_card.test.tsx @@ -37,4 +37,19 @@ describe('', () => { within(screen.getByText('model-1-id')).getByTestId('copy-id-button') ).toBeInTheDocument(); }); + + it('should display "-" for empty description', () => { + render( + + ); + expect( + within(screen.getByText('Description').closest('dl')!).getByText('-') + ).toBeInTheDocument(); + }); }); diff --git a/public/components/model/bottom_form_action_bar.tsx b/public/components/model/bottom_form_action_bar.tsx new file mode 100644 index 00000000..6238132b --- /dev/null +++ b/public/components/model/bottom_form_action_bar.tsx @@ -0,0 +1,86 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React from 'react'; +import { + EuiBottomBar, + EuiButton, + EuiButtonEmpty, + EuiFlexGroup, + EuiFlexItem, + EuiText, +} from '@elastic/eui'; +import useObservable from 'react-use/lib/useObservable'; +import { from } from 'rxjs'; + +import { useOpenSearchDashboards } from '../../../../../src/plugins/opensearch_dashboards_react/public'; + +interface BottomFormActionBarProps { + formId: string; + errorCount?: number; + unSavedChangeCount?: number; + onDiscardButtonClick: () => void; + isSaveButtonLoading?: boolean; + isSaveButtonDisabled?: boolean; +} + +export const BottomFormActionBar = ({ + formId, + errorCount = 0, + unSavedChangeCount = 0, + onDiscardButtonClick, + isSaveButtonDisabled, + isSaveButtonLoading, +}: BottomFormActionBarProps) => { + const { + services: { chrome }, + } = useOpenSearchDashboards(); + const isLocked = useObservable(chrome?.getIsNavDrawerLocked$() ?? from([false])); + + return ( + + + + {errorCount > 0 && ( + + + {errorCount} error(s) + + + )} + {unSavedChangeCount > 0 && ( + + + {unSavedChangeCount} unsaved change(s) + + + )} + + + + + Discard change(s) + + + + + Save + + + + + + ); +}; diff --git a/public/components/model/model.tsx b/public/components/model/model.tsx index 870acd12..3e41eede 100644 --- a/public/components/model/model.tsx +++ b/public/components/model/model.tsx @@ -12,7 +12,7 @@ import { EuiTabbedContentTab, EuiText, } from '@elastic/eui'; -import React, { useState, useMemo } from 'react'; +import React, { useState, useMemo, useCallback } from 'react'; import { useParams } from 'react-router-dom'; import { useFetcher } from '../../hooks'; import { APIProvider } from '../../apis/api_provider'; @@ -39,10 +39,11 @@ export const Model = () => { { name: 'Details', id: 'details', + // TODO: Add description property here content: ( <> - + ), }, @@ -57,9 +58,13 @@ export const Model = () => { ), }, ], - [modelId] + [modelId, data] ); - const [selectedTab, setSelectedTab] = useState(tabs[0]); + const [selectedTabId, setSelectedTabId] = useState(tabs[0].id); + + const handleTabClick = useCallback((tab: EuiTabbedContentTab) => { + setSelectedTabId(tab.id); + }, []); if (loading) { // TODO: need to update per design @@ -93,7 +98,11 @@ export const Model = () => { // TODO: Add description property here /> - + tab.id === selectedTabId)} + /> ); }; diff --git a/public/components/model/model_details_panel.tsx b/public/components/model/model_details_panel.tsx index f7aa2d2f..fb6c8298 100644 --- a/public/components/model/model_details_panel.tsx +++ b/public/components/model/model_details_panel.tsx @@ -3,17 +3,184 @@ * SPDX-License-Identifier: Apache-2.0 */ -import React from 'react'; -import { EuiHorizontalRule, EuiPanel, EuiSpacer, EuiTitle } from '@elastic/eui'; +import React, { useState, useCallback, useEffect, useRef, useMemo } from 'react'; +import { + EuiButton, + EuiDescribedFormGroup, + EuiFlexGroup, + EuiFlexItem, + EuiForm, + EuiHorizontalRule, + EuiLink, + EuiPanel, + EuiSpacer, + EuiText, + EuiTitle, + htmlIdGenerator, +} from '@elastic/eui'; +import { useForm } from 'react-hook-form'; +import { generatePath, useHistory } from 'react-router-dom'; + +import { useOpenSearchDashboards } from '../../../../../src/plugins/opensearch_dashboards_react/public'; +import { mountReactNode } from '../../../../../src/core/public/utils'; +import { routerPaths } from '../../../common'; +import { + ErrorCallOut, + MODEL_NAME_FIELD_DUPLICATE_NAME_ERROR, + ModelDescriptionField, + ModelNameField, +} from '../common'; + +import { BottomFormActionBar } from './bottom_form_action_bar'; + +const formErrorMessages = [ + { + field: 'name', + type: 'required', + message: 'Name: Enter a name.', + }, + { + field: 'name', + type: MODEL_NAME_FIELD_DUPLICATE_NAME_ERROR, + message: 'Name: Use a unique name.', + }, +]; + +interface ModelDetailsProps { + id: string; + name?: string; + description?: string; + onDetailsUpdate?: (formData: { name: string; description?: string }) => void; +} + +export const ModelDetailsPanel = ({ + id, + name, + description, + onDetailsUpdate, +}: ModelDetailsProps) => { + const formIdRef = useRef(htmlIdGenerator()()); + const history = useHistory(); + const { control, resetField, formState, handleSubmit, trigger } = useForm({ + mode: 'onChange', + defaultValues: { + name: '', + description: '', + }, + }); + const [isEditMode, setIsEditMode] = useState(false); + const { + services: { notifications }, + } = useOpenSearchDashboards(); + + // formState.errors won't change after formState updated, need to update errors object manually + const formErrors = useMemo(() => ({ ...formState.errors }), [formState]); + + const handleFormSubmit = useMemo( + () => + handleSubmit(async (formData) => { + // TODO: Just for mock form submit, need to use model update API after integrated + await new Promise((resolve) => { + window.setTimeout(resolve, 1000); + }); + notifications?.toasts.addSuccess({ + title: mountReactNode( + + Updated{' '} + { + history.push( + generatePath(routerPaths.model, { + id, + }) + ); + }} + > + {formData.name} + + + ), + }); + setIsEditMode(false); + onDetailsUpdate?.(formData); + resetField('name', { defaultValue: formData.name || '' }); + resetField('description', { defaultValue: formData.description || '' }); + }), + [id, history, notifications, resetField, handleSubmit, setIsEditMode, onDetailsUpdate] + ); + + const resetAllFields = useCallback(() => { + resetField('name', { defaultValue: name || '' }); + resetField('description', { defaultValue: description || '' }); + }, [name, description, resetField]); + + const handleEditClick = useCallback(() => { + setIsEditMode(true); + }, []); + + const handleCancelClick = useCallback(() => { + resetAllFields(); + setIsEditMode(false); + }, [resetAllFields]); + + useEffect(() => { + resetField('name', { defaultValue: name || '' }); + resetField('description', { defaultValue: description || '' }); + }, [name, description, resetField]); -export const ModelDetailsPanel = () => { return ( - -

      Details

      -
      + + + +

      Details

      +
      +
      + + + {isEditMode ? 'Cancel' : 'Edit'} + + +
      + + {formState.isSubmitted && ( + <> + + + + )} + + Name}> + + + + Description - optional + + } + description="Describe the model." + > + + + {(formState.dirtyFields.name || formState.dirtyFields.description) && ( + + )} +
      ); }; diff --git a/public/components/model/model_overview_card.tsx b/public/components/model/model_overview_card.tsx index ddb76a7f..3046d5ad 100644 --- a/public/components/model/model_overview_card.tsx +++ b/public/components/model/model_overview_card.tsx @@ -33,7 +33,7 @@ export const ModelOverviewCard = ({ listItems={[ { title: 'Description', - description: description || '', + description: description || '-', }, ]} /> diff --git a/public/components/model/model_versions_panel/__tests__/model_versions_panel.test.tsx b/public/components/model/model_versions_panel/__tests__/model_versions_panel.test.tsx index 4449ef4f..5bb9c011 100644 --- a/public/components/model/model_versions_panel/__tests__/model_versions_panel.test.tsx +++ b/public/components/model/model_versions_panel/__tests__/model_versions_panel.test.tsx @@ -30,7 +30,7 @@ describe('', () => { await waitFor(() => { expect( screen.getByText((text, node) => { - return text === 'Versions' && !!node?.childNodes[1]?.textContent?.includes('(3)'); + return text === 'Versions' && !!node?.childNodes[1]?.textContent?.includes('(1)'); }) ).toBeInTheDocument(); }); diff --git a/public/components/model_version/__tests__/model_version.test.tsx b/public/components/model_version/__tests__/model_version.test.tsx index 746f5c98..f223c5b3 100644 --- a/public/components/model_version/__tests__/model_version.test.tsx +++ b/public/components/model_version/__tests__/model_version.test.tsx @@ -61,7 +61,7 @@ describe('', () => { await waitFor(() => { expect(screen.getByText('v1.0.1')).toBeInTheDocument(); }); - expect(location.pathname).toBe('/model-registry/model-version/2'); + expect(location.pathname).toBe('/model-registry/model-version/4'); mockReset(); }); diff --git a/public/components/model_version/__tests__/version_toggler.test.tsx b/public/components/model_version/__tests__/version_toggler.test.tsx index e0526c3e..9348a475 100644 --- a/public/components/model_version/__tests__/version_toggler.test.tsx +++ b/public/components/model_version/__tests__/version_toggler.test.tsx @@ -39,7 +39,7 @@ describe('', () => { await user.click(screen.getByText('1.0.1')); expect(onVersionChange).toHaveBeenCalledWith({ newVersion: '1.0.1', - newId: '2', + newId: '4', }); }); }); diff --git a/public/components/register_model/constants.ts b/public/components/register_model/constants.ts index 6c653d98..d0d8ac33 100644 --- a/public/components/register_model/constants.ts +++ b/public/components/register_model/constants.ts @@ -3,12 +3,13 @@ * SPDX-License-Identifier: Apache-2.0 */ import { ONE_GB } from '../../../common/constant'; +import { MODEL_NAME_FIELD_DUPLICATE_NAME_ERROR } from '../../components/common'; export const MAX_CHUNK_SIZE = 10 * 1000 * 1000; export const MAX_MODEL_FILE_SIZE = 4 * ONE_GB; export enum CUSTOM_FORM_ERROR_TYPES { - DUPLICATE_NAME = 'duplicateName', + DUPLICATE_NAME = MODEL_NAME_FIELD_DUPLICATE_NAME_ERROR, FILE_SIZE_EXCEED_LIMIT = 'fileSizeExceedLimit', INVALID_CONFIGURATION = 'invalidConfiguration', CONFIGURATION_MISSING_MODEL_TYPE = 'configurationMissingModelType', @@ -25,7 +26,7 @@ export const FORM_ERRORS = [ }, { field: 'name', - type: CUSTOM_FORM_ERROR_TYPES.DUPLICATE_NAME, + type: MODEL_NAME_FIELD_DUPLICATE_NAME_ERROR, message: 'Name: Use a unique name.', }, { diff --git a/public/components/register_model/model_details.tsx b/public/components/register_model/model_details.tsx index 01594f93..7d1d60d4 100644 --- a/public/components/register_model/model_details.tsx +++ b/public/components/register_model/model_details.tsx @@ -3,108 +3,24 @@ * SPDX-License-Identifier: Apache-2.0 */ -import React, { useCallback, useRef } from 'react'; -import { EuiFieldText, EuiFormRow, EuiTextArea, EuiText } from '@elastic/eui'; -import { useController, useFormContext } from 'react-hook-form'; -import { ModelFileFormData, ModelUrlFormData } from './register_model.types'; -import { APIProvider } from '../../apis/api_provider'; -import { CUSTOM_FORM_ERROR_TYPES } from './constants'; +import React from 'react'; +import { EuiText } from '@elastic/eui'; +import { useFormContext } from 'react-hook-form'; -const NAME_MAX_LENGTH = 80; -const DESCRIPTION_MAX_LENGTH = 200; +import { ModelNameField, ModelDescriptionField } from '../../components/common'; -const isDuplicateModelName = async (name: string) => { - const searchResult = await APIProvider.getAPI('model').search({ - name, - from: 0, - size: 1, - }); - return searchResult.total_models >= 1; -}; +import { ModelFileFormData, ModelUrlFormData } from './register_model.types'; export const ModelDetailsPanel = () => { const { control, trigger } = useFormContext(); - const modelNameFocusedRef = useRef(false); - const nameFieldController = useController({ - name: 'name', - control, - rules: { - required: { value: true, message: 'Name can not be empty' }, - validate: { - [CUSTOM_FORM_ERROR_TYPES.DUPLICATE_NAME]: async (name) => { - return !modelNameFocusedRef.current && !!name && (await isDuplicateModelName(name)) - ? 'This name is already in use. Use a unique name for the model.' - : undefined; - }, - }, - }, - }); - - const descriptionFieldController = useController({ - name: 'description', - control, - }); - - const { ref: nameInputRef, ...nameField } = nameFieldController.field; - const { ref: descriptionInputRef, ...descriptionField } = descriptionFieldController.field; - - const handleModelNameFocus = useCallback(() => { - modelNameFocusedRef.current = true; - }, []); - - const handleModelNameBlur = useCallback(() => { - nameField.onBlur(); - modelNameFocusedRef.current = false; - trigger('name'); - }, [nameField, trigger]); return (

      Details

      - - {Math.max(NAME_MAX_LENGTH - nameField.value.length, 0)} characters{' '} - {nameField.value.length ? 'left' : 'allowed'}. -
      - Use a unique for the model. - - } - > - -
      - - Description - optional - - } - isInvalid={Boolean(descriptionFieldController.fieldState.error)} - error={descriptionFieldController.fieldState.error?.message} - helpText={`${Math.max( - DESCRIPTION_MAX_LENGTH - descriptionField.value.length, - 0 - )} characters ${descriptionField.value.length ? 'left' : 'allowed'}.`} - > - - + +
      ); }; diff --git a/public/components/register_model/register_model.tsx b/public/components/register_model/register_model.tsx index e8bc05ff..88a7f49b 100644 --- a/public/components/register_model/register_model.tsx +++ b/public/components/register_model/register_model.tsx @@ -3,7 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -import React, { useCallback, useEffect, useState } from 'react'; +import React, { useCallback, useEffect, useState, useMemo } from 'react'; import { FieldErrors, useForm, FormProvider } from 'react-hook-form'; import { generatePath, useHistory, useParams } from 'react-router-dom'; import { @@ -24,25 +24,26 @@ import { import useObservable from 'react-use/lib/useObservable'; import { from } from 'rxjs'; -import { ModelDetailsPanel } from './model_details'; -import type { ModelFileFormData, ModelUrlFormData } from './register_model.types'; -import { ArtifactPanel } from './artifact'; -import { ConfigurationPanel } from './model_configuration'; -import { ModelTagsPanel } from './model_tags'; -import { submitModelWithFile, submitModelWithURL } from './register_model_api'; import { APIProvider } from '../../apis/api_provider'; import { upgradeModelVersion } from '../../utils'; import { useSearchParams } from '../../hooks/use_search_params'; import { isValidModelRegisterFormType } from './utils'; import { useOpenSearchDashboards } from '../../../../../src/plugins/opensearch_dashboards_react/public'; import { mountReactNode } from '../../../../../src/core/public/utils'; -import { modelFileUploadManager } from './model_file_upload_manager'; -import { MAX_CHUNK_SIZE } from './constants'; import { routerPaths } from '../../../common/router_paths'; +import { ErrorCallOut } from '../../components/common'; +import { modelRepositoryManager } from '../../utils/model_repository_manager'; + import { modelTaskManager } from './model_task_manager'; import { ModelVersionNotesPanel } from './model_version_notes'; -import { modelRepositoryManager } from '../../utils/model_repository_manager'; -import { ErrorCallOut } from './error_call_out'; +import { modelFileUploadManager } from './model_file_upload_manager'; +import { MAX_CHUNK_SIZE, FORM_ERRORS } from './constants'; +import { ModelDetailsPanel } from './model_details'; +import type { ModelFileFormData, ModelUrlFormData } from './register_model.types'; +import { ArtifactPanel } from './artifact'; +import { ConfigurationPanel } from './model_configuration'; +import { ModelTagsPanel } from './model_tags'; +import { submitModelWithFile, submitModelWithURL } from './register_model_api'; const DEFAULT_VALUES = { name: '', @@ -111,6 +112,9 @@ export const RegisterModelForm = ({ defaultValues = DEFAULT_VALUES }: RegisterMo criteriaMode: 'all', }); + // formState.errors won't change after formState updated, need to update errors object manually + const formErrors = useMemo(() => ({ ...form.formState.errors }), [form.formState]); + const onSubmit = useCallback( async (data: ModelFileFormData | ModelUrlFormData) => { try { @@ -312,7 +316,7 @@ export const RegisterModelForm = ({ defaultValues = DEFAULT_VALUES }: RegisterMo {isSubmitted && !form.formState.isValid && ( <> - + )} diff --git a/test/mocks/model_handlers.ts b/test/mocks/model_handlers.ts index 68358c28..8093ee67 100644 --- a/test/mocks/model_handlers.ts +++ b/test/mocks/model_handlers.ts @@ -62,11 +62,40 @@ const models = [ model_state: 'DEPLOYED', total_chunks: 34, }, + { + id: '4', + name: 'model1', + model_version: '1.0.1', + description: 'model1 version 1.0.1 description', + created_time: 1683699469964, + last_registered_time: 1683699599632, + last_updated_time: 1683699599637, + model_config: { + all_config: '', + embedding_dimension: 768, + framework_type: 'SENTENCE_TRANSFORMERS', + model_type: 'roberta', + }, + model_format: 'TORCH_SCRIPT', + model_state: 'DEPLOYED', + total_chunks: 34, + }, ]; export const modelHandlers = [ rest.get(MODEL_API_ENDPOINT, (req, res, ctx) => { - const data = models.filter((model) => !req.params.name || model.name === req.params.name); + const { searchParams } = req.url; + const name = searchParams.get('name'); + const ids = searchParams.getAll('ids'); + const data = models.filter((model) => { + if (name) { + return model.name === name; + } + if (ids) { + return ids.includes(model.id); + } + return true; + }); return res( ctx.status(200), ctx.json({ From ec603baae76514e9edcf4514c1b9ba64ebab2e69 Mon Sep 17 00:00:00 2001 From: Yulong Ruan Date: Wed, 17 May 2023 10:23:56 +0800 Subject: [PATCH 47/75] feat(ui): version information edit component (#180) Signed-off-by: Yulong Ruan Signed-off-by: Lin Wang --- public/apis/model.ts | 1 + .../model_version_notes_field.test.tsx | 41 +++++++++++++ .../forms/model_version_notes_field.tsx | 54 +++++++++++++++++ public/components/model/types.ts | 6 ++ .../__tests__/version_information.test.tsx | 58 +++++++++++++++++++ .../model_version/model_version.tsx | 19 +++++- public/components/model_version/types.ts | 21 +++++++ .../model_version/version_information.tsx | 51 +++++++++++++++- .../register_model/model_version_notes.tsx | 42 ++++---------- .../register_model/register_model.types.ts | 6 +- 10 files changed, 260 insertions(+), 39 deletions(-) create mode 100644 public/components/common/forms/__tests__/model_version_notes_field.test.tsx create mode 100644 public/components/common/forms/model_version_notes_field.tsx create mode 100644 public/components/model_version/__tests__/version_information.test.tsx create mode 100644 public/components/model_version/types.ts diff --git a/public/apis/model.ts b/public/apis/model.ts index 753001da..47645000 100644 --- a/public/apis/model.ts +++ b/public/apis/model.ts @@ -44,6 +44,7 @@ export interface ModelDetail extends ModelSearchItem { content: string; last_updated_time: number; created_time: number; + model_format: string; } export interface ModelSearchResponse { diff --git a/public/components/common/forms/__tests__/model_version_notes_field.test.tsx b/public/components/common/forms/__tests__/model_version_notes_field.test.tsx new file mode 100644 index 00000000..aea9c4d3 --- /dev/null +++ b/public/components/common/forms/__tests__/model_version_notes_field.test.tsx @@ -0,0 +1,41 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React from 'react'; +import { useForm } from 'react-hook-form'; +import userEvent from '@testing-library/user-event'; + +import { render, screen } from '../../../../../test/test_utils'; +import { ModelVersionNotesField } from '../model_version_notes_field'; + +const TestApp = ({ readOnly = false }: { readOnly?: boolean }) => { + const form = useForm({ + defaultValues: { versionNotes: '' }, + }); + + return ( + + ); +}; + +describe('', () => { + it('should render a version notes textarea field', () => { + render(); + expect(screen.queryByRole('textbox')).toBeInTheDocument(); + expect(screen.getByRole('textbox')).toBeEnabled(); + }); + + it('should render a readonly version notes input', () => { + render(); + expect(screen.getByRole('textbox')).toBeDisabled(); + }); + + it('should only allow maximum 200 characters', async () => { + const user = userEvent.setup(); + render(); + await user.type(screen.getByRole('textbox'), 'x'.repeat(201)); + expect(screen.getByRole('textbox').value).toHaveLength(200); + }); +}); diff --git a/public/components/common/forms/model_version_notes_field.tsx b/public/components/common/forms/model_version_notes_field.tsx new file mode 100644 index 00000000..4c45340c --- /dev/null +++ b/public/components/common/forms/model_version_notes_field.tsx @@ -0,0 +1,54 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React from 'react'; +import { EuiFormRow, EuiTextArea } from '@elastic/eui'; +import { FieldPathByValue, useController } from 'react-hook-form'; +import type { Control } from 'react-hook-form'; + +interface VersionNotesFormData { + versionNotes?: string; +} + +interface Props { + label: React.ReactNode; + control: Control; + readOnly?: boolean; +} + +const VERSION_NOTES_MAX_LENGTH = 200; + +export const ModelVersionNotesField = ({ + control, + label, + readOnly = false, +}: Props) => { + const fieldController = useController({ + name: 'versionNotes' as FieldPathByValue, + control, + }); + const { ref, ...versionNotesField } = fieldController.field; + + return ( + + + + ); +}; diff --git a/public/components/model/types.ts b/public/components/model/types.ts index 77d1f06f..d063b98f 100644 --- a/public/components/model/types.ts +++ b/public/components/model/types.ts @@ -14,3 +14,9 @@ export interface VersionTableDataItem { tags: { [key: string]: string | number }; createdTime: number; } + +export interface Tag { + key: string; + value: string; + type: 'number' | 'string'; +} diff --git a/public/components/model_version/__tests__/version_information.test.tsx b/public/components/model_version/__tests__/version_information.test.tsx new file mode 100644 index 00000000..ecae38c4 --- /dev/null +++ b/public/components/model_version/__tests__/version_information.test.tsx @@ -0,0 +1,58 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import userEvent from '@testing-library/user-event'; +import React from 'react'; +import { FormProvider, useForm } from 'react-hook-form'; +import { render, screen } from '../../../../test/test_utils'; +import { ModelFileFormData, ModelUrlFormData } from '../types'; +import { ModelVersionInformation } from '../version_information'; + +const TestApp = () => { + const form = useForm({ + defaultValues: { versionNotes: 'test_version_notes' }, + }); + + return ( + + + + ); +}; + +describe('', () => { + it('should display version notes as readonly by default', () => { + render(); + expect(screen.getByLabelText('edit version notes')).toBeEnabled(); + expect(screen.getByDisplayValue('test_version_notes')).toBeDisabled(); + }); + + it('should allow to edit version notes after clicking edit button', async () => { + const user = userEvent.setup(); + render(); + expect(screen.getByDisplayValue('test_version_notes')).toBeDisabled(); + + await user.click(screen.getByLabelText('edit version notes')); + expect(screen.getByDisplayValue('test_version_notes')).toBeEnabled(); + expect(screen.getByLabelText('cancel edit version notes')).toBeInTheDocument(); + }); + + it('should reset the version notes changes and set the input to disabled after clicking cancel button', async () => { + const user = userEvent.setup(); + render(); + await user.click(screen.getByLabelText('edit version notes')); + const versionNotesInput = screen.getByLabelText('Version notes'); + expect(versionNotesInput).toBeEnabled(); + + await user.clear(versionNotesInput); + await user.type(versionNotesInput, 'new_test_version_notes'); + // version notes input updated + expect(screen.getByDisplayValue('new_test_version_notes')).toBeInTheDocument(); + + await user.click(screen.getByLabelText('cancel edit version notes')); + // reset to default value after clicking cancel button + expect(screen.getByDisplayValue('test_version_notes')).toBeDisabled(); + }); +}); diff --git a/public/components/model_version/model_version.tsx b/public/components/model_version/model_version.tsx index 4328da61..6353c4ce 100644 --- a/public/components/model_version/model_version.tsx +++ b/public/components/model_version/model_version.tsx @@ -17,6 +17,7 @@ import { } from '@elastic/eui'; import { generatePath, useHistory, useParams } from 'react-router-dom'; +import { FormProvider, useForm } from 'react-hook-form'; import { useFetcher } from '../../hooks'; import { APIProvider } from '../../apis/api_provider'; import { routerPaths } from '../../../common/router_paths'; @@ -27,6 +28,7 @@ import { ModelVersionDetails } from './version_details'; import { ModelVersionInformation } from './version_information'; import { ModelVersionArtifact } from './version_artifact'; import { ModelVersionTags } from './version_tags'; +import { ModelFileFormData, ModelUrlFormData } from './types'; export const ModelVersion = () => { const { id: modelId } = useParams<{ id: string }>(); @@ -35,6 +37,7 @@ export const ModelVersion = () => { const history = useHistory(); const modelName = model?.name; const modelVersion = model?.model_version; + const form = useForm(); const onVersionChange = useCallback( ({ newVersion, newId }: { newVersion: string; newId: string }) => { @@ -61,6 +64,18 @@ export const ModelVersion = () => { }); }, [modelName, modelVersion]); + useEffect(() => { + if (model) { + form.reset({ + versionNotes: 'TODO', // TODO: read from model.versionNotes + tags: [], // TODO: read from model.tags + configuration: JSON.stringify(model.model_config), + modelFileFormat: model.model_format, + // TODO: read model url or model filename + }); + } + }, [model, form]); + const tabs = [ { id: 'version-information', @@ -101,7 +116,7 @@ export const ModelVersion = () => { ]; return ( - <> + {!modelInfo ? ( <> @@ -148,6 +163,6 @@ export const ModelVersion = () => { )} - + ); }; diff --git a/public/components/model_version/types.ts b/public/components/model_version/types.ts new file mode 100644 index 00000000..8d080e1d --- /dev/null +++ b/public/components/model_version/types.ts @@ -0,0 +1,21 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import type { Tag } from '../model/types'; + +interface FormDataBase { + versionNotes?: string; + tags?: Tag[]; + configuration: string; + modelFileFormat: string; +} + +export interface ModelFileFormData extends FormDataBase { + modelFile: File; +} + +export interface ModelUrlFormData extends FormDataBase { + modelURL: string; +} diff --git a/public/components/model_version/version_information.tsx b/public/components/model_version/version_information.tsx index 23794182..25a66059 100644 --- a/public/components/model_version/version_information.tsx +++ b/public/components/model_version/version_information.tsx @@ -11,10 +11,27 @@ import { EuiSpacer, EuiTitle, EuiButton, + EuiText, } from '@elastic/eui'; -import React from 'react'; +import React, { useState, useCallback } from 'react'; +import { useFormContext, useFormState } from 'react-hook-form'; +import { ModelVersionNotesField } from '../common/forms/model_version_notes_field'; +import { ModelFileFormData, ModelUrlFormData } from './types'; export const ModelVersionInformation = () => { + const form = useFormContext(); + const formState = useFormState({ control: form.control }); + const [readOnly, setReadOnly] = useState(true); + + const onCancel = useCallback(() => { + form.resetField('versionNotes'); + setReadOnly(true); + }, [form]); + + // Whether edit button is disabled or not + // The edit button should be disabled if there were changes in other form fields + const isEditDisabled = formState.isDirty && !formState.dirtyFields.versionNotes; + return ( @@ -24,11 +41,41 @@ export const ModelVersionInformation = () => { - Edit + {readOnly ? ( + setReadOnly(false)} + > + Edit + + ) : ( + + Cancel + + )} + + + + +

      + Version notes - optional +

      + {"Describe what's new about this version."} +
      +
      + + + +
      ); }; diff --git a/public/components/register_model/model_version_notes.tsx b/public/components/register_model/model_version_notes.tsx index 23b2fd7c..69a4ce6e 100644 --- a/public/components/register_model/model_version_notes.tsx +++ b/public/components/register_model/model_version_notes.tsx @@ -4,47 +4,29 @@ */ import React from 'react'; -import { EuiText, EuiSpacer, EuiFormRow, EuiTextArea } from '@elastic/eui'; -import { useFormContext, useController } from 'react-hook-form'; +import { EuiText, EuiSpacer } from '@elastic/eui'; +import { useFormContext } from 'react-hook-form'; import type { ModelFileFormData, ModelUrlFormData } from './register_model.types'; - -const VERSION_NOTES_MAX_LENGTH = 200; +import { ModelVersionNotesField } from '../common/forms/model_version_notes_field'; export const ModelVersionNotesPanel = () => { const { control } = useFormContext(); - const fieldController = useController({ - name: 'versionNotes', - control, - }); - const { ref, ...versionNotesField } = fieldController.field; - return (
      -

      - Version notes - optional -

      +

      Version information

      - - - + + Version notes - optional{' '} + + } + />
      ); }; diff --git a/public/components/register_model/register_model.types.ts b/public/components/register_model/register_model.types.ts index 27064f9d..2dc44fa6 100644 --- a/public/components/register_model/register_model.types.ts +++ b/public/components/register_model/register_model.types.ts @@ -3,11 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -export interface Tag { - key: string; - value: string; - type: 'number' | 'string'; -} +import type { Tag } from '../model/types'; interface ModelFormBase { name: string; From 4df7a5139c75f23793ce3b63f777cf59b9aa24b3 Mon Sep 17 00:00:00 2001 From: Lin Wang Date: Thu, 18 May 2023 10:14:23 +0800 Subject: [PATCH 48/75] Feature/add tags tab content model group page (#181) * feat: separate tag_key render and types Signed-off-by: Lin Wang * feat: add model detail tags panel content Signed-off-by: Lin Wang * fix: new added tag input not aligned Signed-off-by: Lin Wang * refactor: use dirtyFields.tagKeys length directly Signed-off-by: Lin Wang * refactor: disable saved tag key by name Signed-off-by: Lin Wang --------- Signed-off-by: Lin Wang --- .../common/__tests__/tag_key.test.tsx | 15 ++ public/components/common/index.ts | 1 + .../tag_filter_popover_content.tsx | 28 +- public/components/common/tag_key.tsx | 40 +++ public/components/model/model.tsx | 3 +- public/components/model/model_tags_panel.tsx | 19 -- .../__tests__/model_saved_tag_key.test.tsx | 107 ++++++++ .../__tests__/model_tags_panel.test.tsx | 247 ++++++++++++++++++ .../model/model_tags_panel/index.ts | 6 + .../model_tags_panel/model_saved_tag_key.tsx | 113 ++++++++ .../model_tags_panel/model_tag_key_field.tsx | 245 +++++++++++++++++ .../model_tags_panel/model_tags_panel.tsx | 193 ++++++++++++++ public/components/model/types.ts | 5 + .../components/register_model/tag_field.tsx | 25 +- 14 files changed, 981 insertions(+), 66 deletions(-) create mode 100644 public/components/common/__tests__/tag_key.test.tsx create mode 100644 public/components/common/tag_key.tsx delete mode 100644 public/components/model/model_tags_panel.tsx create mode 100644 public/components/model/model_tags_panel/__tests__/model_saved_tag_key.test.tsx create mode 100644 public/components/model/model_tags_panel/__tests__/model_tags_panel.test.tsx create mode 100644 public/components/model/model_tags_panel/index.ts create mode 100644 public/components/model/model_tags_panel/model_saved_tag_key.tsx create mode 100644 public/components/model/model_tags_panel/model_tag_key_field.tsx create mode 100644 public/components/model/model_tags_panel/model_tags_panel.tsx diff --git a/public/components/common/__tests__/tag_key.test.tsx b/public/components/common/__tests__/tag_key.test.tsx new file mode 100644 index 00000000..e887a266 --- /dev/null +++ b/public/components/common/__tests__/tag_key.test.tsx @@ -0,0 +1,15 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import { render, screen } from '../../../../test/test_utils'; +import { tagKeyOptionRenderer } from '../tag_key'; + +describe('tagKeyOptionRenderer', () => { + it('should render consistent name and type', () => { + render(tagKeyOptionRenderer({ label: 'F1', value: { name: 'F1', type: 'number' } }, '', '')); + expect(screen.getByText('F1')).toBeInTheDocument(); + expect(screen.getByText('Number')).toBeInTheDocument(); + }); +}); diff --git a/public/components/common/index.ts b/public/components/common/index.ts index a4a67f3f..ac1b4681 100644 --- a/public/components/common/index.ts +++ b/public/components/common/index.ts @@ -11,3 +11,4 @@ export * from './debounced_search_bar'; export * from './tag_filter'; export * from './options_filter'; export * from './forms'; +export * from './tag_key'; diff --git a/public/components/common/tag_filter_popover_content/tag_filter_popover_content.tsx b/public/components/common/tag_filter_popover_content/tag_filter_popover_content.tsx index 2319b845..62526825 100644 --- a/public/components/common/tag_filter_popover_content/tag_filter_popover_content.tsx +++ b/public/components/common/tag_filter_popover_content/tag_filter_popover_content.tsx @@ -11,19 +11,15 @@ import { EuiFormRow, EuiComboBox, EuiComboBoxOptionOption, - EuiToken, - EuiText, EuiSpacer, EuiButtonEmpty, EuiButton, EuiFieldNumber, } from '@elastic/eui'; -import { TagValueSelector } from './tag_value_selector'; -interface TagKey { - name: string; - type: 'string' | 'number'; -} +import { tagKeyOptionRenderer, TagKey } from '../tag_key'; + +import { TagValueSelector } from './tag_value_selector'; export enum TagFilterOperator { Is = 'is', @@ -137,24 +133,6 @@ export const TagFilterPopoverContent = ({ }, [selectedTagType]); const operator = selectedOperatorOptions[0]?.label as TagFilterOperator; - const tagKeyOptionRenderer = useCallback( - (option: EuiComboBoxOptionOption, _searchValue: string, contentClassName: string) => { - return ( -
      - - {option.label} - - {option.value?.type} - -
      - ); - }, - [] - ); - const handleSave = useCallback(() => { if (!selectedTag || !operator || !value) { return; diff --git a/public/components/common/tag_key.tsx b/public/components/common/tag_key.tsx new file mode 100644 index 00000000..c2f16054 --- /dev/null +++ b/public/components/common/tag_key.tsx @@ -0,0 +1,40 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React from 'react'; +import { EuiComboBoxOptionOption, EuiText, EuiToken } from '@elastic/eui'; + +export const MAX_TAG_LENGTH = 80; + +export const tagKeyTypeOptions = [ + { value: 'string', label: 'String' }, + { value: 'number', label: 'Number' }, +]; + +export type TagKeyType = 'string' | 'number'; + +export interface TagKey { + name: string; + type: TagKeyType; +} + +export const tagKeyOptionRenderer = ( + option: EuiComboBoxOptionOption, + _searchValue: string, + contentClassName: string +) => { + return ( +
      + + {option.label} + + {tagKeyTypeOptions.find((item) => item.value === option.value?.type)?.label} + +
      + ); +}; diff --git a/public/components/model/model.tsx b/public/components/model/model.tsx index 3e41eede..916677f4 100644 --- a/public/components/model/model.tsx +++ b/public/components/model/model.tsx @@ -50,10 +50,11 @@ export const Model = () => { { name: 'Tags', id: 'tags', + // TODO: Change tagKeys property from backend and update tagKeys after change content: ( <> - + {}} /> ), }, diff --git a/public/components/model/model_tags_panel.tsx b/public/components/model/model_tags_panel.tsx deleted file mode 100644 index e4d080a4..00000000 --- a/public/components/model/model_tags_panel.tsx +++ /dev/null @@ -1,19 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -import React from 'react'; -import { EuiHorizontalRule, EuiPanel, EuiSpacer, EuiTitle } from '@elastic/eui'; - -export const ModelTagsPanel = () => { - return ( - - -

      Tags

      -
      - - -
      - ); -}; diff --git a/public/components/model/model_tags_panel/__tests__/model_saved_tag_key.test.tsx b/public/components/model/model_tags_panel/__tests__/model_saved_tag_key.test.tsx new file mode 100644 index 00000000..d2cb4037 --- /dev/null +++ b/public/components/model/model_tags_panel/__tests__/model_saved_tag_key.test.tsx @@ -0,0 +1,107 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React from 'react'; +import userEvent from '@testing-library/user-event'; + +import { render, screen } from '../../../../../test/test_utils'; +import { ModelSavedTagKey } from '../model_saved_tag_key'; + +describe('', () => { + it('should render tag name and type be read-only by default', () => { + render( + + ); + + expect(screen.getByDisplayValue('Accuracy: Test')).toBeInTheDocument(); + expect(screen.getByDisplayValue('Accuracy: Test')).toHaveAttribute('readonly'); + expect(screen.getByDisplayValue('Number')).toBeInTheDocument(); + expect(screen.getByDisplayValue('Number')).toHaveAttribute('readonly'); + }); + + it('should render tag key and type label if showLabel equal true', () => { + render( + + ); + + expect(screen.getByText('Key')).toBeInTheDocument(); + expect(screen.getByText('Type')).toBeInTheDocument(); + }); + + it('should render remove button and show delete confirmation modal after button clicked', async () => { + render( + + ); + + expect(screen.getByLabelText('Remove saved tag key at row 1')).toBeInTheDocument(); + + await userEvent.click(screen.getByLabelText('Remove saved tag key at row 1')); + expect(screen.getByText('Delete tag key?')).toBeInTheDocument(); + expect( + screen.getByText( + "Deleting this tag key will remove the tag and the tag's values from all versions of this model. This action is irreversible." + ) + ).toBeInTheDocument(); + }); + + it('should call onRemove with index after confirm button clicked', async () => { + const onRemoveMock = jest.fn(); + render( + + ); + await userEvent.click(screen.getByLabelText('Remove saved tag key at row 1')); + + expect(onRemoveMock).not.toHaveBeenCalled(); + await userEvent.click(screen.getByText('Delete tag key')); + expect(onRemoveMock).toHaveBeenCalledWith(0); + }); + + it('should hide confirmation modal after cancel button clicked', async () => { + const onRemoveMock = jest.fn(); + render( + + ); + await userEvent.click(screen.getByLabelText('Remove saved tag key at row 1')); + + await userEvent.click(screen.getByText('Cancel')); + expect(onRemoveMock).not.toHaveBeenCalled(); + expect(screen.queryByText('Delete tag key?')).toBeNull(); + }); +}); diff --git a/public/components/model/model_tags_panel/__tests__/model_tags_panel.test.tsx b/public/components/model/model_tags_panel/__tests__/model_tags_panel.test.tsx new file mode 100644 index 00000000..3686748b --- /dev/null +++ b/public/components/model/model_tags_panel/__tests__/model_tags_panel.test.tsx @@ -0,0 +1,247 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React from 'react'; +import userEvent from '@testing-library/user-event'; + +import { render, screen, waitFor, within } from '../../../../../test/test_utils'; +import { ModelTagsPanel } from '../model_tags_panel'; + +describe('', () => { + it('should render saved tag keys in read-only mode and edit button by default', () => { + render( + + ); + + expect(screen.getByDisplayValue('Accuracy: Test')).toBeInTheDocument(); + expect(screen.getByDisplayValue('Accuracy: Test')).toHaveAttribute('readonly'); + expect(screen.getByDisplayValue('Number')).toBeInTheDocument(); + expect(screen.getByDisplayValue('Number')).toHaveAttribute('readonly'); + expect(screen.getByText('Edit')).toBeInTheDocument(); + }); + + it('should show remove saved tag key and add tag key button after edit button clicked', async () => { + render( + + ); + + await userEvent.click(screen.getByText('Edit')); + expect(screen.getByLabelText('Remove saved tag key at row 1')).toBeInTheDocument(); + }); + + it('should call onTagKeysChange after delete tag key confirmation modal confirmed', async () => { + const onTagKeysChangeMock = jest.fn(); + render( + + ); + + await userEvent.click(screen.getByText('Edit')); + await userEvent.click(screen.getByLabelText('Remove saved tag key at row 1')); + + expect(onTagKeysChangeMock).not.toHaveBeenCalled(); + await userEvent.click(screen.getByText('Delete tag key')); + expect(onTagKeysChangeMock).toHaveBeenCalledWith([]); + }); + + it('should show "1 unsaved change(s)", discard button and save button after add tag key button clicked', async () => { + render(); + + await userEvent.click(screen.getByText('Edit')); + await userEvent.click(screen.getByText('Add tag key')); + + expect(screen.getByText('1 unsaved change(s)')).toBeInTheDocument(); + expect(screen.getByText('Discard change(s)')).toBeInTheDocument(); + expect(screen.getByText('Save')).toBeInTheDocument(); + }); + + it('should disable saved tag key in the new tag key dropdown', async () => { + render( + + ); + + await userEvent.click(screen.getByText('Edit')); + await userEvent.click(screen.getByText('Add tag key')); + expect(screen.getByRole('option', { name: 'Accuracy: test' })).toBeDisabled(); + }); + + it('should fill tag key type and set readonly after system exists tag key selected', async () => { + render(); + + await userEvent.click(screen.getByText('Edit')); + await userEvent.click(screen.getByText('Add tag key')); + + expect( + within(screen.getByTestId('modelTagKeyType1')).getByText('Select a type') + ).toBeInTheDocument(); + await userEvent.click(screen.getByText('Select or add a key')); + await userEvent.click(screen.getByRole('option', { name: 'F1' })); + expect( + within(screen.getByTestId('modelTagKeyType1')).getByDisplayValue('Number') + ).toBeInTheDocument(); + expect( + within(screen.getByTestId('modelTagKeyType1')).getByDisplayValue('Number') + ).toHaveAttribute('readonly'); + }); + + it('should show "Tag keys must be unique. Use a unique key." if tag key already exists', async () => { + render(); + + await userEvent.click(screen.getByText('Edit')); + await userEvent.click(screen.getByText('Add tag key')); + await userEvent.click(screen.getByText('Select or add a key')); + await userEvent.click(screen.getByRole('option', { name: 'F1' })); + + await userEvent.click(screen.getByText('Add tag key')); + await userEvent.click(screen.getByText('Select or add a key')); + await userEvent.click(screen.getByRole('option', { name: 'F1' })); + + expect( + within(screen.getByTestId('modelTagKeyName2')).getByText( + 'Tag keys must be unique. Use a unique key.' + ) + ).toBeInTheDocument(); + }); + + it('should show "Tag keys must be unique. Use a unique key." if outer saved tag keys update and already exists', async () => { + const { rerender } = render(); + + await userEvent.click(screen.getByText('Edit')); + await userEvent.click(screen.getByText('Add tag key')); + await userEvent.click(screen.getByText('Select or add a key')); + await userEvent.click(screen.getByRole('option', { name: 'F1' })); + + rerender( + + ); + + await waitFor(async () => { + expect( + within(screen.getByTestId('modelTagKeyName1')).getByText( + 'Tag keys must be unique. Use a unique key.' + ) + ).toBeInTheDocument(); + }); + }); + + it('should able to create not exists key and change tag type', async () => { + render(); + + await userEvent.click(screen.getByText('Edit')); + await userEvent.click(screen.getByText('Add tag key')); + + await userEvent.type(screen.getByText('Select or add a key'), 'not exists key{Enter}'); + expect(screen.getByText('not exists key')).toBeInTheDocument(); + expect(within(screen.getByTestId('modelTagKeyType1')).getByText('String')).toBeInTheDocument(); + + await userEvent.click(within(screen.getByTestId('modelTagKeyType1')).getByText('String')); + await userEvent.click(screen.getByRole('option', { name: 'Number' })); + expect(within(screen.getByTestId('modelTagKeyType1')).getByText('Number')).toBeInTheDocument(); + }); + + it('should show "80 characters allowed. Use 80 characters or less." if tag key name over 80 characters', async () => { + render(); + + await userEvent.click(screen.getByText('Edit')); + await userEvent.click(screen.getByText('Add tag key')); + + await userEvent.type(screen.getByText('Select or add a key'), `${'x'.repeat(81)}{Enter}`); + expect( + within(screen.getByTestId('modelTagKeyName1')).getByText( + '80 characters allowed. Use 80 characters or less.' + ) + ).toBeInTheDocument(); + }); + + it('should call onTagKeysChange with new added tag keys, hide save button and turn to readonly mode after save button clicked', async () => { + const onTagKeysChangeMock = jest.fn(); + render( + + ); + + await userEvent.click(screen.getByText('Edit')); + await userEvent.click(screen.getByText('Add tag key')); + await userEvent.click(screen.getByText('Select or add a key')); + await userEvent.click(screen.getByRole('option', { name: 'F1' })); + await userEvent.click(screen.getByText('Add tag key')); + await userEvent.type(screen.getByText('Select or add a key'), 'not exists key'); + + await userEvent.click(screen.getByText('Save')); + expect(onTagKeysChangeMock).toHaveBeenCalledWith([ + { name: 'Accuracy: test', type: 'number' }, + { name: 'F1', type: 'number' }, + { name: 'not exists key', type: 'string' }, + ]); + expect(screen.queryByText('Save')).toBeNull(); + expect(screen.queryByText('Edit')).toBeInTheDocument(); + }); + + it('should discard new added tag key and hide bottom bar after discard button clicked', async () => { + render(); + + await userEvent.click(screen.getByText('Edit')); + await userEvent.click(screen.getByText('Add tag key')); + await userEvent.click(screen.getByText('Select or add a key')); + await userEvent.click(screen.getByRole('option', { name: 'F1' })); + + const discardButton = screen.getByText('Discard change(s)'); + + await userEvent.click(discardButton); + expect(screen.queryByText('F1')).toBeNull(); + expect(discardButton).not.toBeInTheDocument(); + }); + + it('should hide bottom bar and remove new added tag key after remove one tag key button clicked', async () => { + render(); + + await userEvent.click(screen.getByText('Edit')); + await userEvent.click(screen.getByText('Add tag key')); + await userEvent.click(screen.getByText('Select or add a key')); + await userEvent.click(screen.getByRole('option', { name: 'F1' })); + await userEvent.click(screen.getByLabelText('Remove tag key at row 1')); + + expect(screen.queryByText('F1')).toBeNull(); + expect(screen.queryByText('Discard change(s)')).toBeNull(); + }); + + it('should show "You can add up to 0 more tags." and disable add more tag key button after 10 tag keys added', async () => { + render(); + + await userEvent.click(screen.getByText('Edit')); + for (let i = 0; i < 10; i++) { + await userEvent.click(screen.getByText('Add tag key')); + } + expect(screen.getByText('You can add up to 0 more tags.')).toBeInTheDocument(); + expect(screen.getByText('Add tag key').closest('button')).toBeDisabled(); + }); + + it('should discard new tag key and turn to read only mode after cancel button clicked', async () => { + render(); + + await userEvent.click(screen.getByText('Edit')); + await userEvent.click(screen.getByText('Add tag key')); + await userEvent.click(screen.getByText('Select or add a key')); + await userEvent.click(screen.getByRole('option', { name: 'F1' })); + + await userEvent.click(screen.getByText('Cancel')); + expect(screen.queryByText('Discard change(s)')).toBeNull(); + expect(screen.queryByText('F1')).toBeNull(); + expect(screen.queryByText('Add tag key')).toBeNull(); + }); +}); diff --git a/public/components/model/model_tags_panel/index.ts b/public/components/model/model_tags_panel/index.ts new file mode 100644 index 00000000..c9358e38 --- /dev/null +++ b/public/components/model/model_tags_panel/index.ts @@ -0,0 +1,6 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +export { ModelTagsPanel } from './model_tags_panel'; diff --git a/public/components/model/model_tags_panel/model_saved_tag_key.tsx b/public/components/model/model_tags_panel/model_saved_tag_key.tsx new file mode 100644 index 00000000..717654c6 --- /dev/null +++ b/public/components/model/model_tags_panel/model_saved_tag_key.tsx @@ -0,0 +1,113 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React, { useCallback, useState } from 'react'; +import { + EuiButtonIcon, + EuiConfirmModal, + EuiFieldText, + EuiFlexGroup, + EuiFlexItem, + EuiFormRow, + EuiText, +} from '@elastic/eui'; + +import { tagKeyTypeOptions, TagKeyType } from '../../common'; + +const FORM_ITEM_WIDTH = 400; + +const TagKeyDeleteConfirmModal = ({ closeModal }: { closeModal: (confirmed: boolean) => void }) => { + const handleConfirm = useCallback(() => { + closeModal(true); + }, [closeModal]); + + const handleCancel = useCallback(() => { + closeModal(false); + }, [closeModal]); + + return ( + Delete tag key?} + cancelButtonText="Cancel" + confirmButtonText="Delete tag key" + maxWidth={500} + buttonColor="danger" + onConfirm={handleConfirm} + onCancel={handleCancel} + > + + Deleting this tag key will remove the tag and the tag's values from all versions of + this model. This action is irreversible. + + + ); +}; + +export const ModelSavedTagKey = ({ + name, + type, + index, + showRemoveButton, + onRemove, + showLabel, +}: { + name: string; + type: TagKeyType; + index: number; + showRemoveButton: boolean; + onRemove: (index: number) => void; + showLabel: boolean; +}) => { + const [isDeleteConfirmModalVisible, setIsDeleteConfirmModalVisible] = useState(false); + + const handleRemoveClick = useCallback(() => { + setIsDeleteConfirmModalVisible(true); + }, []); + + const closeModal = useCallback( + (confirmed: boolean) => { + setIsDeleteConfirmModalVisible(false); + if (confirmed) { + onRemove(index); + } + }, + [onRemove, index] + ); + + return ( + <> + + + + + + + + + item.value === type)?.label} + /> + + + {showRemoveButton && ( + + + + )} + + {isDeleteConfirmModalVisible && } + + ); +}; diff --git a/public/components/model/model_tags_panel/model_tag_key_field.tsx b/public/components/model/model_tags_panel/model_tag_key_field.tsx new file mode 100644 index 00000000..cefc6f29 --- /dev/null +++ b/public/components/model/model_tags_panel/model_tag_key_field.tsx @@ -0,0 +1,245 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React, { useCallback, useMemo } from 'react'; +import { + EuiButtonIcon, + EuiComboBox, + EuiComboBoxOptionOption, + EuiContext, + EuiFieldText, + EuiFlexGroup, + EuiFlexItem, + EuiFormRow, + EuiText, + EuiToken, +} from '@elastic/eui'; +import { Control, useController, useWatch } from 'react-hook-form'; + +import { TagKeyFormData } from '../types'; +import { TagKey, tagKeyTypeOptions } from '../../common'; + +const FORM_ITEM_WIDTH = 400; + +const MAX_TAG_LENGTH = 80; + +const KEY_COMBOBOX_I18N = { + mapping: { + 'euiComboBoxOptionsList.noAvailableOptions': 'No keys found. Add a key.', + }, +}; + +const singleSelection = { asPlainText: true }; + +interface ModelTagKeyFieldProps { + control: Control; + index: number; + onRemove: (index: number) => void; + allTagKeys: TagKey[]; + savedTagKeys: TagKey[]; + showLabel: boolean; +} + +export const ModelTagKeyField = ({ + control, + index, + onRemove, + allTagKeys, + showLabel, + savedTagKeys, +}: ModelTagKeyFieldProps) => { + const tagKeysInForm = useWatch({ + name: 'tagKeys', + control, + }); + const nameController = useController({ + name: `tagKeys.${index}.name`, + control, + rules: { + maxLength: { + value: MAX_TAG_LENGTH, + message: '80 characters allowed. Use 80 characters or less.', + }, + validate: (tagKey) => { + if (tagKeysInForm) { + // If a tag has key, validate if the same tag key was added before + if (tagKey) { + // Find if tag key already exists in saved tag key list, this is a rare case may caused by the saved tag keys updated + for (let i = 0; i < savedTagKeys.length; i++) { + if (savedTagKeys[i].name === tagKey) { + return 'Tag keys must be unique. Use a unique key.'; + } + } + // Find if the same tag key appears before the current tag key + for (let i = 0; i < index; i++) { + // If found the same tag key, then the current tag key is invalid + if (tagKeysInForm[i].name === tagKey) { + return 'Tag keys must be unique. Use a unique key.'; + } + } + } + } + return true; + }, + }, + }); + + const typeController = useController({ + name: `tagKeys.${index}.type`, + control, + }); + + const { ref: nameFieldRef, ...restNameFieldProps } = nameController.field; + const { ref: typeFieldRef, ...restTypeFieldProps } = typeController.field; + + const keyOptions = useMemo(() => { + const savedTagKeyMap = savedTagKeys.reduce<{ [key: string]: boolean }>( + (pValue, cValue) => ({ + ...pValue, + [cValue.name]: true, + }), + {} + ); + return allTagKeys.map((item) => ({ + label: item.name, + value: item, + disabled: savedTagKeyMap[item.name], + })); + }, [allTagKeys, savedTagKeys]); + + const keySelectedOptions = useMemo( + () => (nameController.field.value ? [{ label: nameController.field.value }] : []), + [nameController.field.value] + ); + + const typeSelectedOptions = useMemo(() => { + const option = tagKeyTypeOptions.find((item) => item.value === typeController.field.value); + return option && nameController.field.value ? [option] : []; + }, [typeController.field.value, nameController.field.value]); + + const isUsingSystemExistsTagKey = useMemo(() => { + return allTagKeys.findIndex((item) => item.name === nameController.field.value) !== -1; + }, [allTagKeys, nameController.field.value]); + + const handleRemoveClick = useCallback(() => { + onRemove(index); + }, [index, onRemove]); + + const handleNameKeyCreate = useCallback( + (value: string) => { + nameController.field.onChange(value); + typeController.field.onChange('string'); + }, + [nameController.field, typeController.field] + ); + + const handleNameChange = useCallback( + (data: Array>) => { + if (data[0]) { + nameController.field.onChange(data[0].label); + typeController.field.onChange(data[0].value?.type); + } else { + nameController.field.onChange(''); + typeController.field.onChange('string'); + } + }, + [nameController.field, typeController.field] + ); + + const handleValueChange = useCallback( + (data: EuiComboBoxOptionOption[]) => { + if (data?.[0].value) { + typeController.field.onChange(data[0].value); + } + }, + [typeController.field] + ); + + const renderKeyOption = useCallback( + (option: EuiComboBoxOptionOption, _searchValue: string, contentClassName: string) => { + return ( +
      + + {option.label} + + {option.value?.type} + +
      + ); + }, + [] + ); + + return ( + + + + + + placeholder="Select or add a key" + isInvalid={Boolean(nameController.fieldState.error)} + singleSelection={singleSelection} + options={keyOptions} + renderOption={renderKeyOption} + selectedOptions={keySelectedOptions} + onCreateOption={handleNameKeyCreate} + customOptionText="Add {searchValue} as a key." + inputRef={nameFieldRef} + {...restNameFieldProps} + onChange={handleNameChange} + /> + + + + + + {isUsingSystemExistsTagKey ? ( + item.value === restTypeFieldProps.value)?.label + } + /> + ) : ( + + )} + + + + + + + ); +}; diff --git a/public/components/model/model_tags_panel/model_tags_panel.tsx b/public/components/model/model_tags_panel/model_tags_panel.tsx new file mode 100644 index 00000000..7006b05b --- /dev/null +++ b/public/components/model/model_tags_panel/model_tags_panel.tsx @@ -0,0 +1,193 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React, { useMemo, useRef, useState, useCallback, useEffect } from 'react'; +import { + EuiButton, + EuiDescribedFormGroup, + EuiFlexGroup, + EuiFlexItem, + EuiForm, + EuiHorizontalRule, + EuiLink, + EuiPanel, + EuiSpacer, + EuiText, + EuiTitle, + htmlIdGenerator, +} from '@elastic/eui'; +import { useForm, useFieldArray } from 'react-hook-form'; + +import { TagKeyFormData } from '../types'; +import { BottomFormActionBar } from '../bottom_form_action_bar'; +import { TagKey } from '../../common'; +import { useModelTagKeys } from '../../model_list/model_list.hooks'; + +import { ModelTagKeyField } from './model_tag_key_field'; +import { ModelSavedTagKey } from './model_saved_tag_key'; + +const MAX_TAG_NUM = 10; + +interface ModelTagsPanelProps { + tagKeys: TagKey[]; + onTagKeysChange?: (tagKeys: TagKey[]) => void; +} + +export const ModelTagsPanel = ({ tagKeys, onTagKeysChange }: ModelTagsPanelProps) => { + const formIdRef = useRef(htmlIdGenerator()()); + const [isEditMode, setIsEditMode] = useState(false); + // TODO: change to real fetch all tags from backend in API integration phase + const [, allTagKeys] = useModelTagKeys(); + + const { control, resetField, handleSubmit, formState, trigger } = useForm({ + mode: 'onChange', + defaultValues: { + tagKeys: [], + }, + }); + const unSavedChangeCount = (formState.dirtyFields.tagKeys || []).length; + + const { fields, append, remove } = useFieldArray({ + name: 'tagKeys', + control, + }); + + const totalTagKeys = fields.length + tagKeys.length; + + const clearUnSavedChanges = useCallback(() => { + resetField('tagKeys', { defaultValue: [] }); + }, [resetField]); + + const handleFormSubmit = useMemo( + () => + handleSubmit(async (formData) => { + // TODO: Add model tags update to backend logic here + clearUnSavedChanges(); + onTagKeysChange?.([...tagKeys, ...formData.tagKeys]); + setIsEditMode(false); + }), + [tagKeys, onTagKeysChange, clearUnSavedChanges, handleSubmit] + ); + + const handleEditClick = useCallback(() => { + setIsEditMode(true); + }, []); + + const handleCancelClick = useCallback(() => { + setIsEditMode(false); + clearUnSavedChanges(); + }, [clearUnSavedChanges]); + + const handleAddClick = useCallback(() => { + append({ name: '', type: 'string' }); + }, [append]); + + const handleSavedTagRemove = useCallback( + (index: number) => { + onTagKeysChange?.([...tagKeys.slice(0, index), ...tagKeys.slice(index + 1)]); + }, + [onTagKeysChange, tagKeys] + ); + + useEffect(() => { + trigger('tagKeys'); + }, [tagKeys, trigger]); + + return ( + <> + + + + +

      Tags

      +
      + + Tags help your organization discover, compare, and track information related to + models. + +
      + + + {isEditMode ? 'Cancel' : 'Edit'} + + +
      + + + + + + Tag keys - optional + + } + description={ + <> + Manage the tag keys for all versions of this model. Deleting a tag key will remove + the tag from all versions. Adding a tag key will make the key available for all + versions. {/* TODO: fill out link address once confirmed */} + Learn more + + } + descriptionFlexItemProps={{ style: { maxWidth: 372 } }} + > + + {tagKeys.map(({ name, type }, index) => ( + + + + ))} + {fields.map((field, index) => ( + + + + ))} + {isEditMode && ( + +
      + = MAX_TAG_NUM} + fullWidth={false} + > + Add tag key + + + + {`You can add up to ${MAX_TAG_NUM - totalTagKeys} more tags.`} + +
      +
      + )} +
      +
      +
      +
      + {unSavedChangeCount > 0 && ( + + )} + + ); +}; diff --git a/public/components/model/types.ts b/public/components/model/types.ts index d063b98f..537fa9d4 100644 --- a/public/components/model/types.ts +++ b/public/components/model/types.ts @@ -4,6 +4,7 @@ */ import { MODEL_STATE } from '../../../common'; +import { TagKey } from '../common'; export interface VersionTableDataItem { id: string; @@ -20,3 +21,7 @@ export interface Tag { value: string; type: 'number' | 'string'; } + +export interface TagKeyFormData { + tagKeys: TagKey[]; +} diff --git a/public/components/register_model/tag_field.tsx b/public/components/register_model/tag_field.tsx index 484a5d85..5c1bd957 100644 --- a/public/components/register_model/tag_field.tsx +++ b/public/components/register_model/tag_field.tsx @@ -12,12 +12,13 @@ import { EuiContext, EuiButtonIcon, EuiFieldNumber, - EuiText, - EuiToken, EuiToolTip, } from '@elastic/eui'; import React, { useCallback, useMemo, useRef } from 'react'; import { useController, useWatch, useFormContext } from 'react-hook-form'; + +import { tagKeyOptionRenderer } from '../common'; + import { FORM_ITEM_WIDTH } from './form_constants'; import { ModelFileFormData, ModelUrlFormData } from './register_model.types'; import { TagTypePopover } from './tag_type_popover'; @@ -215,24 +216,6 @@ export const ModelTagField = ({ [trigger] ); - const renderOption = useCallback( - (option: EuiComboBoxOptionOption, searchValue: string, contentClassName: string) => { - return ( -
      - - {option.label} - - {option.value?.type} - -
      - ); - }, - [] - ); - const onRemove = useCallback( (idx: number) => { if (tags?.length && tags.length > 1) { @@ -260,7 +243,7 @@ export const ModelTagField = ({ isInvalid={Boolean(tagKeyController.fieldState.error)} singleSelection={{ asPlainText: true }} options={keyOptions} - renderOption={renderOption} + renderOption={tagKeyOptionRenderer} selectedOptions={ tagKeyController.field.value ? [{ label: tagKeyController.field.value }] : [] } From 9aa8391b42765f45b820b38248f1d2505258f916 Mon Sep 17 00:00:00 2001 From: Yulong Ruan Date: Mon, 22 May 2023 13:53:26 +0800 Subject: [PATCH 49/75] feat: version tag edit (#182) On model version page, a new version tag edit panel is add to allow user to manage tags of a model version --------- Signed-off-by: Yulong Ruan Signed-off-by: Lin Wang --- .../model_tag_array_field.test.tsx | 85 +++++++ .../model_tag_array_field/tag_field.test.tsx | 218 ++++++++++++++++++ .../tag_type_popover.test.tsx | 4 +- .../model_version_notes_field.test.tsx | 9 +- .../forms/model_tag_array_field/index.ts | 6 + .../model_tag_array_field.tsx | 67 ++++++ .../model_tag_array_field}/tag_field.tsx | 202 +++++++++------- .../tag_type_popover.tsx | 0 .../forms/model_version_notes_field.tsx | 21 +- .../__tests__/model_tags.test.tsx | 92 ++++++++ .../__tests__/model_version.test.tsx | 51 +++- .../__tests__/version_information.test.tsx | 6 +- .../model_version/model_version.tsx | 6 +- .../model_version/version_information.tsx | 33 ++- .../components/model_version/version_tags.tsx | 71 +++++- .../__tests__/register_model_tags.test.tsx | 139 +---------- .../components/register_model/model_tags.tsx | 37 +-- .../register_model/model_version_notes.tsx | 5 - .../register_model/register_model.tsx | 2 +- 19 files changed, 768 insertions(+), 286 deletions(-) create mode 100644 public/components/common/forms/__tests__/model_tag_array_field/model_tag_array_field.test.tsx create mode 100644 public/components/common/forms/__tests__/model_tag_array_field/tag_field.test.tsx rename public/components/{register_model/__tests__ => common/forms/__tests__/model_tag_array_field}/tag_type_popover.test.tsx (89%) create mode 100644 public/components/common/forms/model_tag_array_field/index.ts create mode 100644 public/components/common/forms/model_tag_array_field/model_tag_array_field.tsx rename public/components/{register_model => common/forms/model_tag_array_field}/tag_field.tsx (61%) rename public/components/{register_model => common/forms/model_tag_array_field}/tag_type_popover.tsx (100%) create mode 100644 public/components/model_version/__tests__/model_tags.test.tsx diff --git a/public/components/common/forms/__tests__/model_tag_array_field/model_tag_array_field.test.tsx b/public/components/common/forms/__tests__/model_tag_array_field/model_tag_array_field.test.tsx new file mode 100644 index 00000000..fcab299a --- /dev/null +++ b/public/components/common/forms/__tests__/model_tag_array_field/model_tag_array_field.test.tsx @@ -0,0 +1,85 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React from 'react'; +import { FormProvider, useForm } from 'react-hook-form'; +import userEvent from '@testing-library/user-event'; + +import { Tag } from '../../../../../components/model/types'; +import { ModelTagArrayField } from '../../model_tag_array_field'; +import { TagGroup } from '../../model_tag_array_field/tag_field'; +import { render, screen } from '../../../../../../test/test_utils'; + +const TEST_TAG_GROUPS = [ + { name: 'Key1', type: 'string' as const, values: ['Value1'] }, + { name: 'Key2', type: 'number' as const, values: [0.95] }, +]; + +const TestApp = ({ + allowKeyCreate = true, + tagGroups = TEST_TAG_GROUPS, + defaultValues = { + tags: [ + { key: 'DefaultKey1', value: 'DefaultValue1', type: 'string' }, + { key: 'DefaultKey2', value: '1', type: 'number' }, + ], + }, + readOnly = false, +}: { + index?: number; + onDelete?: (index: number) => void; + allowKeyCreate?: boolean; + tagGroups?: TagGroup[]; + defaultValues?: { tags: Tag[] }; + readOnly?: boolean; +}) => { + const form = useForm({ + mode: 'onChange', + defaultValues, + }); + + return ( + + + + ); +}; + +describe('', () => { + it('should display the default tags', () => { + render(); + expect(screen.getAllByTestId('ml-versionTagRow')).toHaveLength(2); + expect(screen.getByDisplayValue('DefaultKey1')).toBeInTheDocument(); + expect(screen.getByDisplayValue('DefaultKey2')).toBeInTheDocument(); + }); + + it('should allow to add multiple tags', async () => { + const user = userEvent.setup(); + render(); + + // Add two tags + await user.click(screen.getByRole('button', { name: /add tag/i })); + await user.click(screen.getByRole('button', { name: /add tag/i })); + + // two new tag fields + two default tags + expect(screen.getAllByTestId('ml-versionTagRow')).toHaveLength(4); + }); + + it('should delete tag from the tag list', async () => { + const user = userEvent.setup(); + render(); + expect(screen.getAllByTestId('ml-versionTagRow')).toHaveLength(2); + + await user.click(screen.getByLabelText(/remove tag at row 1/i)); + expect(screen.getAllByTestId('ml-versionTagRow')).toHaveLength(1); + expect(screen.getByDisplayValue('DefaultKey2')).toBeInTheDocument(); + expect(screen.queryByDisplayValue('DefaultKey1')).not.toBeInTheDocument(); + }); + + it('should not allow to add new tag if it is readOnly', () => { + render(); + expect(screen.queryByRole('button', { name: /add tag/i })).not.toBeInTheDocument(); + }); +}); diff --git a/public/components/common/forms/__tests__/model_tag_array_field/tag_field.test.tsx b/public/components/common/forms/__tests__/model_tag_array_field/tag_field.test.tsx new file mode 100644 index 00000000..3993e1fe --- /dev/null +++ b/public/components/common/forms/__tests__/model_tag_array_field/tag_field.test.tsx @@ -0,0 +1,218 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import userEvent from '@testing-library/user-event'; +import { Tag } from '../../../../../components/model/types'; +import React from 'react'; +import { FormProvider, useForm } from 'react-hook-form'; +import { render, screen, within } from '../../../../../../test/test_utils'; +import { TagField, TagGroup } from '../../model_tag_array_field/tag_field'; + +const TEST_TAG_GROUPS = [ + { name: 'Key1', type: 'string' as const, values: ['Value1'] }, + { name: 'Key2', type: 'number' as const, values: [0.95] }, +]; + +const TestApp = ({ + index = 0, + onDelete = jest.fn(), + allowKeyCreate = true, + tagGroups = TEST_TAG_GROUPS, + defaultValues = { tags: [{ key: '', value: '', type: 'string' }] }, + readOnly = false, +}: { + index?: number; + onDelete?: (index: number) => void; + allowKeyCreate?: boolean; + tagGroups?: TagGroup[]; + defaultValues?: { tags: Tag[] }; + readOnly?: boolean; +}) => { + const form = useForm({ + mode: 'onChange', + defaultValues, + }); + + return ( + + + + ); +}; + +describe('', () => { + it('should render tag field with key and value input', () => { + render(); + const keyContainer = screen.queryByTestId('ml-tagKey1'); + const valueContainer = screen.queryByTestId('ml-tagValue1'); + + expect(keyContainer).toBeInTheDocument(); + expect(valueContainer).toBeInTheDocument(); + }); + + it('tag value input should be disabled if tag key is empty', async () => { + const user = userEvent.setup(); + render(); + + const keyContainer = screen.getByTestId('ml-tagKey1'); + const keyInput = within(keyContainer).getByRole('textbox'); + + const valueContainer = screen.getByTestId('ml-tagValue1'); + const valueInput = within(valueContainer).getByRole('textbox'); + expect(valueInput).toBeDisabled(); + + await user.type(keyInput, 'Key1{enter}'); + expect(valueInput).toBeEnabled(); + }); + + it('should display tag key and value as readOnly', () => { + render( + + ); + expect(screen.getByDisplayValue('Key1')).toHaveAttribute('readonly'); + expect(screen.getByDisplayValue('Value1')).toHaveAttribute('readonly'); + }); + + it('should NOT display delete button if tag field is readOnly', () => { + render( + + ); + expect(screen.queryByLabelText(/remove tag at row/i)).not.toBeInTheDocument(); + }); + + it('should NOT allow to edit the tag key if the tag is a default tag', () => { + render( + + ); + // tag key should be readonly + expect(screen.getByDisplayValue('Key1')).toHaveAttribute('readonly'); + // but we allow to edit tag value + const valueContainer = screen.getByTestId('ml-tagValue1'); + const valueInput = within(valueContainer).getByRole('textbox'); + expect(valueInput).not.toHaveAttribute('readonly'); + expect(valueInput).toBeEnabled(); + }); + + it('should NOT display value type prepend', () => { + render( + + ); + // tag key should be readonly + expect(screen.getByDisplayValue('Key1')).toHaveAttribute('readonly'); + // but we allow to edit tag value + const valueContainer = screen.getByTestId('ml-tagValue1'); + const prepend = within(valueContainer).queryByText('String'); + expect(prepend).not.toBeInTheDocument(); + }); + + it('should display error when creating new tag key with more than 80 characters', async () => { + const user = userEvent.setup(); + render(); + + const keyContainer = screen.getByTestId('ml-tagKey1'); + const keyInput = within(keyContainer).getByRole('textbox'); + await user.type(keyInput, `${'x'.repeat(81)}{enter}`); + expect( + within(keyContainer).queryByText('80 characters allowed. Use 80 characters or less.') + ).toBeInTheDocument(); + }); + + it('should display error when creating new tag value with more than 80 characters', async () => { + const user = userEvent.setup(); + render(); + + const keyContainer = screen.getByTestId('ml-tagKey1'); + const keyInput = within(keyContainer).getByRole('textbox'); + await user.type(keyInput, 'dummy key{enter}'); + + const valueContainer = screen.getByTestId('ml-tagValue1'); + const valueInput = within(valueContainer).getByRole('textbox'); + await user.type(valueInput, `${'x'.repeat(81)}{enter}`); + expect( + within(valueContainer).queryByText('80 characters allowed. Use 80 characters or less.') + ).toBeInTheDocument(); + }); + + it('should display "No keys found" and "No values found" if no tag keys and no tag values are provided', async () => { + const user = userEvent.setup(); + render(); + + const keyContainer = screen.getByTestId('ml-tagKey1'); + const keyInput = within(keyContainer).getByRole('textbox'); + await user.click(keyInput); + expect(screen.getByText('No keys found. Add a key.')).toBeInTheDocument(); + + await user.type(keyInput, 'dummy key{enter}'); + + const valueContainer = screen.getByTestId('ml-tagValue1'); + const valueInput = within(valueContainer).getByRole('textbox'); + await user.click(valueInput); + expect(screen.getByText('No values found. Add a value.')).toBeInTheDocument(); + }); + + it('should NOT display "Key1" in the option list after "Key1" selected', async () => { + const user = userEvent.setup(); + render(); + + const keyContainer = screen.getByTestId('ml-tagKey1'); + const keyInput = within(keyContainer).getByRole('textbox'); + await user.click(keyInput); + const optionListContainer = screen.getByTestId('comboBoxOptionsList'); + expect(within(optionListContainer).getByTitle('Key1')).toBeInTheDocument(); + + await user.click(within(optionListContainer).getByTitle('Key1')); + expect(within(optionListContainer).queryByTitle('Key1')).toBe(null); + }); + + it('should not allow to select tag type if selected an existed tag', async () => { + const user = userEvent.setup(); + render(); + + const keyContainer = screen.getByTestId('ml-tagKey1'); + const valueContainer = screen.getByTestId('ml-tagValue1'); + const keyInput = within(keyContainer).getByRole('textbox'); + + await user.click(keyInput); + // selected an existed tag + await user.click(within(screen.getByTestId('comboBoxOptionsList')).getByTitle('Key1')); + + expect(within(valueContainer).queryByLabelText('select tag type')).not.toBeInTheDocument(); + }); + + it('should display a list of tag value for selection after selecting a tag key', async () => { + const user = userEvent.setup(); + render(); + + const keyContainer = screen.getByTestId('ml-tagKey1'); + const valueContainer = screen.getByTestId('ml-tagValue1'); + const keyInput = within(keyContainer).getByRole('textbox'); + const valueInput = within(valueContainer).getByRole('textbox'); + + await user.click(keyInput); + // selected an existed tag + await user.click(within(screen.getByTestId('comboBoxOptionsList')).getByTitle('Key1')); + + await user.click(valueInput); + expect( + within(screen.getByTestId('comboBoxOptionsList')).queryByTitle('Value1') + ).toBeInTheDocument(); + }); +}); diff --git a/public/components/register_model/__tests__/tag_type_popover.test.tsx b/public/components/common/forms/__tests__/model_tag_array_field/tag_type_popover.test.tsx similarity index 89% rename from public/components/register_model/__tests__/tag_type_popover.test.tsx rename to public/components/common/forms/__tests__/model_tag_array_field/tag_type_popover.test.tsx index 87030bf2..e24e2140 100644 --- a/public/components/register_model/__tests__/tag_type_popover.test.tsx +++ b/public/components/common/forms/__tests__/model_tag_array_field/tag_type_popover.test.tsx @@ -6,8 +6,8 @@ import React from 'react'; import userEvent from '@testing-library/user-event'; -import { render, screen } from '../../../../test/test_utils'; -import { TagTypePopover } from '../tag_type_popover'; +import { render, screen } from '../../../../../../test/test_utils'; +import { TagTypePopover } from '../../model_tag_array_field/tag_type_popover'; describe('', () => { it('should display tag type popover when clicking', async () => { diff --git a/public/components/common/forms/__tests__/model_version_notes_field.test.tsx b/public/components/common/forms/__tests__/model_version_notes_field.test.tsx index aea9c4d3..f5cd0e2b 100644 --- a/public/components/common/forms/__tests__/model_version_notes_field.test.tsx +++ b/public/components/common/forms/__tests__/model_version_notes_field.test.tsx @@ -4,7 +4,7 @@ */ import React from 'react'; -import { useForm } from 'react-hook-form'; +import { FormProvider, useForm } from 'react-hook-form'; import userEvent from '@testing-library/user-event'; import { render, screen } from '../../../../../test/test_utils'; @@ -16,7 +16,9 @@ const TestApp = ({ readOnly = false }: { readOnly?: boolean }) => { }); return ( - + + + ); }; @@ -25,11 +27,12 @@ describe('', () => { render(); expect(screen.queryByRole('textbox')).toBeInTheDocument(); expect(screen.getByRole('textbox')).toBeEnabled(); + expect(screen.getByRole('textbox')).not.toHaveAttribute('readonly'); }); it('should render a readonly version notes input', () => { render(); - expect(screen.getByRole('textbox')).toBeDisabled(); + expect(screen.getByRole('textbox')).toHaveAttribute('readonly'); }); it('should only allow maximum 200 characters', async () => { diff --git a/public/components/common/forms/model_tag_array_field/index.ts b/public/components/common/forms/model_tag_array_field/index.ts new file mode 100644 index 00000000..67113894 --- /dev/null +++ b/public/components/common/forms/model_tag_array_field/index.ts @@ -0,0 +1,6 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +export * from './model_tag_array_field'; diff --git a/public/components/common/forms/model_tag_array_field/model_tag_array_field.tsx b/public/components/common/forms/model_tag_array_field/model_tag_array_field.tsx new file mode 100644 index 00000000..22d2b780 --- /dev/null +++ b/public/components/common/forms/model_tag_array_field/model_tag_array_field.tsx @@ -0,0 +1,67 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React, { useCallback } from 'react'; +import { useFieldArray, useFormContext } from 'react-hook-form'; +import { EuiButton, EuiSpacer, EuiText } from '@elastic/eui'; +import { Tag } from '../../../model/types'; +import { TagField, TagGroup } from './tag_field'; + +interface Props { + tags: TagGroup[]; + readOnly?: boolean; + allowKeyCreate?: boolean; + maxTagNum?: number; +} + +export const ModelTagArrayField = ({ + tags, + readOnly = false, + allowKeyCreate = true, + maxTagNum = 10, +}: Props) => { + const { control } = useFormContext<{ tags?: Tag[] }>(); + const { fields, append, remove } = useFieldArray({ + name: 'tags', + control, + }); + + const addNewTag = useCallback(() => { + append({ key: '', value: '', type: 'string' }); + }, [append]); + + return ( + <> + {fields.map((field, index) => { + return ( + + ); + })} + + {!readOnly && ( + <> + = maxTagNum} + onClick={addNewTag} + > + Add tag + + + + {`You can add up to ${maxTagNum - fields.length} more tags.`} + + + )} + + ); +}; diff --git a/public/components/register_model/tag_field.tsx b/public/components/common/forms/model_tag_array_field/tag_field.tsx similarity index 61% rename from public/components/register_model/tag_field.tsx rename to public/components/common/forms/model_tag_array_field/tag_field.tsx index 5c1bd957..bdec5bda 100644 --- a/public/components/register_model/tag_field.tsx +++ b/public/components/common/forms/model_tag_array_field/tag_field.tsx @@ -13,27 +13,27 @@ import { EuiButtonIcon, EuiFieldNumber, EuiToolTip, + EuiFieldText, } from '@elastic/eui'; import React, { useCallback, useMemo, useRef } from 'react'; -import { useController, useWatch, useFormContext } from 'react-hook-form'; +import { useController, useWatch, useFormContext, useFormState } from 'react-hook-form'; -import { tagKeyOptionRenderer } from '../common'; - -import { FORM_ITEM_WIDTH } from './form_constants'; -import { ModelFileFormData, ModelUrlFormData } from './register_model.types'; +import { tagKeyOptionRenderer } from '../../../common'; +import { Tag } from '../../../model/types'; import { TagTypePopover } from './tag_type_popover'; -interface TagGroup { +export interface TagGroup { name: string; type: 'string' | 'number'; values: string[] | number[]; } -interface ModelTagFieldProps { +interface TagFieldProps { index: number; onDelete: (index: number) => void; allowKeyCreate?: boolean; tagGroups: TagGroup[]; + readOnly?: boolean; } const MAX_TAG_LENGTH = 80; @@ -58,14 +58,16 @@ function getComboBoxValue(data: Array>) { } } -export const ModelTagField = ({ +export const TagField = ({ index, tagGroups, allowKeyCreate, onDelete, -}: ModelTagFieldProps) => { + readOnly, +}: TagFieldProps) => { const rowEleRef = useRef(null); - const { trigger, control } = useFormContext(); + const { trigger, control } = useFormContext<{ tags?: Tag[] }>(); + const { defaultValues } = useFormState({ control }); const tags = useWatch({ control, name: 'tags', @@ -123,6 +125,16 @@ export const ModelTagField = ({ }, }); + // we don't allow to change tag type and tag key if it is a default tag + const isDefaultTag = useMemo(() => { + if (defaultValues && defaultValues.tags) { + return Boolean( + defaultValues.tags.find((t) => t?.key && t.key === tagKeyController.field.value) + ); + } + return false; + }, [defaultValues, tagKeyController]); + const selectedTagGroup = useMemo( () => tagGroups.find((t) => t.name === tagKeyController.field.value), [tagGroups, tagKeyController] @@ -228,9 +240,70 @@ export const ModelTagField = ({ [tags, onDelete, tagKeyController.field, tagValueController.field] ); + const createValueField = () => { + if (readOnly) { + return ; + } + + if (tagTypeController.field.value === 'string') { + const prepend = + selectedTagGroup || !tagKeyController.field.value ? ( + 'String' + ) : ( + + ); + + return ( + + ); + } + + const prepend = + selectedTagGroup || !tagKeyController.field.value ? ( + 'Number' + ) : ( + + ); + + return ( + + ); + }; + return ( - - + + - - placeholder="Select or add a key" - isInvalid={Boolean(tagKeyController.fieldState.error)} - singleSelection={{ asPlainText: true }} - options={keyOptions} - renderOption={tagKeyOptionRenderer} - selectedOptions={ - tagKeyController.field.value ? [{ label: tagKeyController.field.value }] : [] - } - onChange={onKeyChange} - onCreateOption={allowKeyCreate ? onKeyCreate : undefined} - customOptionText="Add {searchValue} as a key." - onBlur={tagKeyController.field.onBlur} - inputRef={tagKeyController.field.ref} - /> + {readOnly || isDefaultTag ? ( + + ) : ( + + placeholder="Select or add a key" + isInvalid={Boolean(tagKeyController.fieldState.error)} + singleSelection={{ asPlainText: true }} + options={keyOptions} + renderOption={tagKeyOptionRenderer} + selectedOptions={ + tagKeyController.field.value ? [{ label: tagKeyController.field.value }] : [] + } + onChange={onKeyChange} + onCreateOption={allowKeyCreate ? onKeyCreate : undefined} + customOptionText="Add {searchValue} as a key." + onBlur={tagKeyController.field.onBlur} + inputRef={tagKeyController.field.ref} + /> + )} - + - {tagTypeController.field.value === 'string' ? ( - - ) - } - placeholder="Select or add a value" - isInvalid={Boolean(tagValueController.fieldState.error)} - singleSelection={{ asPlainText: true }} - options={valueOptions} - selectedOptions={ - tagValueController.field.value ? [{ label: tagValueController.field.value }] : [] - } - onChange={onStringValueChange} - onCreateOption={onValueCreate} - customOptionText="Add {searchValue} as a value." - onBlur={tagValueController.field.onBlur} - inputRef={tagValueController.field.ref} - isDisabled={!Boolean(tagKeyController.field.value)} - /> - ) : ( - - ) - } - placeholder="Add a value" - value={tagValueController.field.value} - isInvalid={Boolean(tagValueController.fieldState.error)} - onChange={onNumberValueChange} - onBlur={tagValueController.field.onBlur} - inputRef={tagValueController.field.ref} - disabled={!Boolean(tagKeyController.field.value)} - /> - )} + {createValueField()} - - 1 ? 'Remove' : 'Clear'}> - onRemove(index)} - /> - - + {!readOnly && ( + + 1 ? 'Remove' : 'Clear'}> + onRemove(index)} + /> + + + )} ); }; diff --git a/public/components/register_model/tag_type_popover.tsx b/public/components/common/forms/model_tag_array_field/tag_type_popover.tsx similarity index 100% rename from public/components/register_model/tag_type_popover.tsx rename to public/components/common/forms/model_tag_array_field/tag_type_popover.tsx diff --git a/public/components/common/forms/model_version_notes_field.tsx b/public/components/common/forms/model_version_notes_field.tsx index 4c45340c..f477822f 100644 --- a/public/components/common/forms/model_version_notes_field.tsx +++ b/public/components/common/forms/model_version_notes_field.tsx @@ -5,28 +5,19 @@ import React from 'react'; import { EuiFormRow, EuiTextArea } from '@elastic/eui'; -import { FieldPathByValue, useController } from 'react-hook-form'; -import type { Control } from 'react-hook-form'; +import { useController, useFormContext } from 'react-hook-form'; -interface VersionNotesFormData { - versionNotes?: string; -} - -interface Props { +interface Props { label: React.ReactNode; - control: Control; readOnly?: boolean; } const VERSION_NOTES_MAX_LENGTH = 200; -export const ModelVersionNotesField = ({ - control, - label, - readOnly = false, -}: Props) => { +export const ModelVersionNotesField = ({ label, readOnly = false }: Props) => { + const { control } = useFormContext<{ versionNotes?: string }>(); const fieldController = useController({ - name: 'versionNotes' as FieldPathByValue, + name: 'versionNotes', control, }); const { ref, ...versionNotesField } = fieldController.field; @@ -42,7 +33,7 @@ export const ModelVersionNotesField = ({ label={label} > { + const form = useForm({ + mode: 'onChange', + defaultValues, + }); + + return ( + + + + ); +}; + +describe('', () => { + it('should display tags as readonly by default', () => { + render(); + expect(screen.getByDisplayValue('DefaultKey1')).toHaveAttribute('readonly'); + expect(screen.getByDisplayValue('DefaultKey2')).toHaveAttribute('readonly'); + expect(screen.getByDisplayValue('DefaultValue1')).toHaveAttribute('readonly'); + expect(screen.getByDisplayValue('0.85')).toHaveAttribute('readonly'); + }); + + it('should enable tag list editing after clicking edit button', async () => { + const user = userEvent.setup(); + render(); + + await user.click(screen.getByLabelText('edit tags')); + // Edit button becomes Cancel button + expect(screen.getByLabelText('cancel edit tags')).toBeInTheDocument(); + + // Add tag button should be displayed + expect(screen.getByRole('button', { name: /add tag/i })).toBeInTheDocument(); + + // Delete button should be displayed + expect(screen.queryAllByLabelText(/^remove tag at row .*$/i)).toHaveLength(2); + }); + + it('should reset tag array field after clicking cancel button', async () => { + const user = userEvent.setup(); + render(); + + await user.click(screen.getByLabelText('edit tags')); + await user.click(screen.getByRole('button', { name: /add tag/i })); + + const keyContainer = screen.getByTestId('ml-tagKey3'); + const keyInput = within(keyContainer).getByRole('textbox'); + + const valueContainer = screen.getByTestId('ml-tagValue3'); + const valueInput = within(valueContainer).getByRole('textbox'); + + // add a new tag + await user.type(keyInput, 'Key1{enter}'); + await user.type(valueInput, 'Value1{enter}'); + + // change the value of an existing tag + const numberValueInput = screen.getByDisplayValue('0.85'); + await user.clear(numberValueInput); + await user.type(numberValueInput, '0.95'); + + // now we have 3 tags: 2 default tags + 1 new tag + expect(screen.getAllByTestId('ml-versionTagRow')).toHaveLength(3); + + // click cancel button + await user.click(screen.getByLabelText('cancel edit tags')); + // reset to 2 tags, newly added tag is removed + expect(screen.getAllByTestId('ml-versionTagRow')).toHaveLength(2); + // tag value change is reset + expect(screen.queryByDisplayValue('0.85')).toBeInTheDocument(); + }); +}); diff --git a/public/components/model_version/__tests__/model_version.test.tsx b/public/components/model_version/__tests__/model_version.test.tsx index f223c5b3..c2ea1646 100644 --- a/public/components/model_version/__tests__/model_version.test.tsx +++ b/public/components/model_version/__tests__/model_version.test.tsx @@ -8,7 +8,7 @@ import { generatePath, Route } from 'react-router-dom'; import userEvent from '@testing-library/user-event'; import { ModelVersion } from '../model_version'; -import { render, screen, waitFor, mockOffsetMethods } from '../../../../test/test_utils'; +import { render, screen, waitFor, within, mockOffsetMethods } from '../../../../test/test_utils'; import { routerPaths } from '../../../../common/router_paths'; const setup = () => @@ -65,4 +65,53 @@ describe('', () => { mockReset(); }); + + it('should NOT allow to edit tags if version notes is edited', async () => { + setup(); + + const user = userEvent.setup(); + + // wait for data loading finished + await waitFor(() => { + expect(screen.queryByTestId('modelVersionLoadingSpinner')).not.toBeInTheDocument(); + expect(screen.queryByTestId('ml-versionDetailsLoading')).not.toBeInTheDocument(); + }); + + const editVersionNotesButton = within( + screen.getByTestId('ml-versionInformationPanel') + ).getByLabelText('edit version notes'); + + await user.click(editVersionNotesButton); + // version notes changed + await user.type(screen.getByLabelText('Version notes'), 'test notes'); + + // edit tags button should be disable as user can NOT edit notes and tags at the same time + expect( + within(screen.getByTestId('ml-versionTagPanel')).getByLabelText('edit tags') + ).toBeDisabled(); + }); + + it('should NOT allow to edit version notes if tags is edited', async () => { + setup(); + + const user = userEvent.setup(); + + // wait for data loading finished + await waitFor(() => { + expect(screen.queryByTestId('modelVersionLoadingSpinner')).not.toBeInTheDocument(); + expect(screen.queryByTestId('ml-versionDetailsLoading')).not.toBeInTheDocument(); + }); + + const editTagsButton = within(screen.getByTestId('ml-versionTagPanel')).getByLabelText( + 'edit tags' + ); + // enable editing + await user.click(editTagsButton); + // add a new tag + await user.click(screen.getByRole('button', { name: /add tag/i })); + // version information edit button should be disabled + expect( + within(screen.getByTestId('ml-versionInformationPanel')).getByLabelText('edit version notes') + ).toBeDisabled(); + }); }); diff --git a/public/components/model_version/__tests__/version_information.test.tsx b/public/components/model_version/__tests__/version_information.test.tsx index ecae38c4..a857d940 100644 --- a/public/components/model_version/__tests__/version_information.test.tsx +++ b/public/components/model_version/__tests__/version_information.test.tsx @@ -26,13 +26,13 @@ describe('', () => { it('should display version notes as readonly by default', () => { render(); expect(screen.getByLabelText('edit version notes')).toBeEnabled(); - expect(screen.getByDisplayValue('test_version_notes')).toBeDisabled(); + expect(screen.getByDisplayValue('test_version_notes')).toHaveAttribute('readonly'); }); it('should allow to edit version notes after clicking edit button', async () => { const user = userEvent.setup(); render(); - expect(screen.getByDisplayValue('test_version_notes')).toBeDisabled(); + expect(screen.getByDisplayValue('test_version_notes')).toHaveAttribute('readonly'); await user.click(screen.getByLabelText('edit version notes')); expect(screen.getByDisplayValue('test_version_notes')).toBeEnabled(); @@ -53,6 +53,6 @@ describe('', () => { await user.click(screen.getByLabelText('cancel edit version notes')); // reset to default value after clicking cancel button - expect(screen.getByDisplayValue('test_version_notes')).toBeDisabled(); + expect(screen.getByDisplayValue('test_version_notes')).toHaveAttribute('readonly'); }); }); diff --git a/public/components/model_version/model_version.tsx b/public/components/model_version/model_version.tsx index 6353c4ce..83a82e9f 100644 --- a/public/components/model_version/model_version.tsx +++ b/public/components/model_version/model_version.tsx @@ -68,7 +68,11 @@ export const ModelVersion = () => { if (model) { form.reset({ versionNotes: 'TODO', // TODO: read from model.versionNotes - tags: [], // TODO: read from model.tags + tags: [ + { key: 'Accuracy', value: '0.85', type: 'number' as const }, + { key: 'Precision', value: '0.64', type: 'number' as const }, + { key: 'Task', value: 'Image classification', type: 'string' as const }, + ], // TODO: read from model.tags configuration: JSON.stringify(model.model_config), modelFileFormat: model.model_format, // TODO: read model url or model filename diff --git a/public/components/model_version/version_information.tsx b/public/components/model_version/version_information.tsx index 25a66059..c2515422 100644 --- a/public/components/model_version/version_information.tsx +++ b/public/components/model_version/version_information.tsx @@ -13,27 +13,42 @@ import { EuiButton, EuiText, } from '@elastic/eui'; -import React, { useState, useCallback } from 'react'; +import React, { useRef, useEffect, useState, useCallback } from 'react'; import { useFormContext, useFormState } from 'react-hook-form'; import { ModelVersionNotesField } from '../common/forms/model_version_notes_field'; import { ModelFileFormData, ModelUrlFormData } from './types'; export const ModelVersionInformation = () => { const form = useFormContext(); - const formState = useFormState({ control: form.control }); + // Returned formState is wrapped with Proxy to improve render performance and + // skip extra computation if specific state is not subscribed, so make sure you + // deconstruct or read it before render in order to enable the subscription. + const { isDirty, dirtyFields } = useFormState({ control: form.control }); const [readOnly, setReadOnly] = useState(true); + const formRef = useRef(form); + formRef.current = form; const onCancel = useCallback(() => { form.resetField('versionNotes'); setReadOnly(true); }, [form]); + useEffect(() => { + // reset form value to default when component unmounted, this makes sure + // the unsaved changes are dropped when the component unmounted + return () => { + if (formRef.current.formState.dirtyFields.versionNotes) { + formRef.current.resetField('versionNotes'); + } + }; + }, []); + // Whether edit button is disabled or not // The edit button should be disabled if there were changes in other form fields - const isEditDisabled = formState.isDirty && !formState.dirtyFields.versionNotes; + const isEditDisabled = isDirty && !dirtyFields.versionNotes; return ( - + @@ -41,7 +56,7 @@ export const ModelVersionInformation = () => { - {readOnly ? ( + {readOnly || isEditDisabled ? ( { - +

      Version notes - optional @@ -69,11 +84,7 @@ export const ModelVersionInformation = () => { - + diff --git a/public/components/model_version/version_tags.tsx b/public/components/model_version/version_tags.tsx index 948fc84d..657ebbce 100644 --- a/public/components/model_version/version_tags.tsx +++ b/public/components/model_version/version_tags.tsx @@ -11,12 +11,41 @@ import { EuiSpacer, EuiTitle, EuiButton, + EuiText, + EuiLink, } from '@elastic/eui'; -import React from 'react'; +import React, { useEffect, useRef, useState, useCallback } from 'react'; +import { useFormContext, useFormState } from 'react-hook-form'; +import { ModelTagArrayField } from '../common/forms/model_tag_array_field'; +import { useModelTags } from '../register_model/register_model.hooks'; +import { ModelFileFormData, ModelUrlFormData } from './types'; export const ModelVersionTags = () => { + const [, tags] = useModelTags(); + const form = useFormContext(); + const { isDirty, dirtyFields } = useFormState({ control: form.control }); + const [readOnly, setReadOnly] = useState(true); + const formRef = useRef(form); + formRef.current = form; + + const onCancel = useCallback(() => { + form.resetField('tags'); + setReadOnly(true); + }, [form]); + + useEffect(() => { + // reset form value to default when unmount + return () => { + if (formRef.current.formState.dirtyFields.tags) { + formRef.current.resetField('tags'); + } + }; + }, []); + + const isEditDisabled = isDirty && !dirtyFields.tags; + return ( - + @@ -24,11 +53,47 @@ export const ModelVersionTags = () => { - Edit + {readOnly || isEditDisabled ? ( + setReadOnly(false)} + > + Edit + + ) : ( + + Cancel + + )} + + + + +

      + Tags - optional +

      + + Tags help your organization discover and compare models, and track information related + to model versions, such as evaluation metrics.{' '} + + Learn more + + +
      +
      + + + +
      ); }; diff --git a/public/components/register_model/__tests__/register_model_tags.test.tsx b/public/components/register_model/__tests__/register_model_tags.test.tsx index 935e8336..e1f361b1 100644 --- a/public/components/register_model/__tests__/register_model_tags.test.tsx +++ b/public/components/register_model/__tests__/register_model_tags.test.tsx @@ -26,16 +26,6 @@ describe(' Tags', () => { jest.clearAllMocks(); }); - it('should render a tags panel', async () => { - await setup(); - - const keyContainer = screen.queryByTestId('ml-tagKey1'); - const valueContainer = screen.queryByTestId('ml-tagValue1'); - - expect(keyContainer).toBeInTheDocument(); - expect(valueContainer).toBeInTheDocument(); - }); - it('should submit the form without selecting tags', async () => { const result = await setup(); await result.user.click(result.submitButton); @@ -43,20 +33,6 @@ describe(' Tags', () => { expect(onSubmitMock).toHaveBeenCalled(); }); - it('tag value input should be disabled if tag key is empty', async () => { - const result = await setup(); - - const keyContainer = screen.getByTestId('ml-tagKey1'); - const keyInput = within(keyContainer).getByRole('textbox'); - - const valueContainer = screen.getByTestId('ml-tagValue1'); - const valueInput = within(valueContainer).getByRole('textbox'); - expect(valueInput).toBeDisabled(); - - await result.user.type(keyInput, 'Key1{enter}'); - expect(valueInput).toBeEnabled(); - }); - it('should submit the form with selected tags', async () => { const result = await setup(); @@ -80,8 +56,8 @@ describe(' Tags', () => { const result = await setup(); // Add two tags - await result.user.click(screen.getByText(/add new tag/i)); - await result.user.click(screen.getByText(/add new tag/i)); + await result.user.click(screen.getByRole('button', { name: /add tag/i })); + await result.user.click(screen.getByRole('button', { name: /add tag/i })); expect( screen.getAllByText(/select or add a key/i, { selector: '.euiComboBoxPlaceholder' }) @@ -154,7 +130,7 @@ describe(' Tags', () => { await result.user.type(valueInput1, 'Value 1'); // Add a new tag, and input the same tag key and value - await result.user.click(screen.getByText(/add new tag/i)); + await result.user.click(screen.getByRole('button', { name: /add tag/i })); // input tag key: 'Key 1' const keyContainer2 = screen.getByTestId('ml-tagKey2'); const keyInput2 = within(keyContainer2).getByRole('textbox'); @@ -182,19 +158,17 @@ describe(' Tags', () => { const MAX_TAG_NUM = 10; // It has one tag by default, we can add 24 more tags - const addNewTagButton = screen.getByText(/add new tag/i); + const addNewTagButton = screen.getByRole('button', { name: /add tag/i }); for (let i = 1; i < MAX_TAG_NUM; i++) { await result.user.click(addNewTagButton); } // 10 tags are displayed await waitFor(() => expect(screen.queryAllByTestId(/ml-tagKey/i)).toHaveLength(10)); - // add new tag button should not be displayed - await waitFor(() => - expect(screen.getByRole('button', { name: /add new tag/i })).toBeDisabled() - ); + // add tag button should not be displayed + await waitFor(() => expect(screen.getByRole('button', { name: /add tag/i })).toBeDisabled()); }, - // The test will fail due to timeout as we interact with the page a lot(24 button click to add new tags) + // The test will fail due to timeout as we interact with the page a lot(24 button click to add tags) // So we try to increase test running timeout to 60000ms to mitigate the timeout issue 60 * 1000 ); @@ -203,8 +177,8 @@ describe(' Tags', () => { const result = await setup(); // Add two tags - await result.user.click(screen.getByText(/add new tag/i)); - await result.user.click(screen.getByText(/add new tag/i)); + await result.user.click(screen.getByRole('button', { name: /add tag/i })); + await result.user.click(screen.getByRole('button', { name: /add tag/i })); expect( screen.getAllByText(/select or add a key/i, { selector: '.euiComboBoxPlaceholder' }) @@ -240,11 +214,9 @@ describe(' Tags', () => { mode: 'version', }); - await result.user.click(screen.getByText(/add new tag/i)); + await result.user.click(screen.getByRole('button', { name: /add tag/i })); - await waitFor(() => - expect(screen.getByRole('button', { name: /add new tag/i })).toBeDisabled() - ); + await waitFor(() => expect(screen.getByRole('button', { name: /add tag/i })).toBeDisabled()); }); it('should prevent creating new tag key when registering new version', async () => { @@ -267,95 +239,6 @@ describe(' Tags', () => { ).toBeInTheDocument(); }); - it('should display error when creating new tag key with more than 80 characters', async () => { - const result = await setup(); - - const keyContainer = screen.getByTestId('ml-tagKey1'); - const keyInput = within(keyContainer).getByRole('textbox'); - await result.user.type(keyInput, `${'x'.repeat(81)}{enter}`); - expect( - within(keyContainer).queryByText('80 characters allowed. Use 80 characters or less.') - ).toBeInTheDocument(); - }); - - it('should display error when creating new tag value with more than 80 characters', async () => { - const result = await setup(); - - const keyContainer = screen.getByTestId('ml-tagKey1'); - const keyInput = within(keyContainer).getByRole('textbox'); - await result.user.type(keyInput, 'dummy key{enter}'); - - const valueContainer = screen.getByTestId('ml-tagValue1'); - const valueInput = within(valueContainer).getByRole('textbox'); - await result.user.type(valueInput, `${'x'.repeat(81)}{enter}`); - expect( - within(valueContainer).queryByText('80 characters allowed. Use 80 characters or less.') - ).toBeInTheDocument(); - }); - - it('should display "No keys found" and "No values found" if no tag keys and no tag values are provided', async () => { - jest.spyOn(formHooks, 'useModelTags').mockReturnValue([false, []]); - - const result = await setup(); - - const keyContainer = screen.getByTestId('ml-tagKey1'); - const keyInput = within(keyContainer).getByRole('textbox'); - await result.user.click(keyInput); - expect(screen.getByText('No keys found. Add a key.')).toBeInTheDocument(); - - await result.user.type(keyInput, 'dummy key{enter}'); - - const valueContainer = screen.getByTestId('ml-tagValue1'); - const valueInput = within(valueContainer).getByRole('textbox'); - await result.user.click(valueInput); - expect(screen.getByText('No values found. Add a value.')).toBeInTheDocument(); - }); - - it('should NOT display "Key1" in the option list after "Key1" selected', async () => { - const result = await setup(); - - const keyContainer = screen.getByTestId('ml-tagKey1'); - const keyInput = within(keyContainer).getByRole('textbox'); - await result.user.click(keyInput); - const optionListContainer = screen.getByTestId('comboBoxOptionsList'); - expect(within(optionListContainer).getByTitle('Key1')).toBeInTheDocument(); - - await result.user.click(within(optionListContainer).getByTitle('Key1')); - expect(within(optionListContainer).queryByTitle('Key1')).toBe(null); - }); - - it('should not allow to select tag type if selected an existed tag', async () => { - const result = await setup(); - - const keyContainer = screen.getByTestId('ml-tagKey1'); - const valueContainer = screen.getByTestId('ml-tagValue1'); - const keyInput = within(keyContainer).getByRole('textbox'); - - await result.user.click(keyInput); - // selected an existed tag - await result.user.click(within(screen.getByTestId('comboBoxOptionsList')).getByTitle('Key1')); - - expect(within(valueContainer).queryByLabelText('select tag type')).not.toBeInTheDocument(); - }); - - it('should display a list of tag value for selection after selecting a tag key', async () => { - const result = await setup(); - - const keyContainer = screen.getByTestId('ml-tagKey1'); - const valueContainer = screen.getByTestId('ml-tagValue1'); - const keyInput = within(keyContainer).getByRole('textbox'); - const valueInput = within(valueContainer).getByRole('textbox'); - - await result.user.click(keyInput); - // selected an existed tag - await result.user.click(within(screen.getByTestId('comboBoxOptionsList')).getByTitle('Key1')); - - await result.user.click(valueInput); - expect( - within(screen.getByTestId('comboBoxOptionsList')).queryByTitle('Value1') - ).toBeInTheDocument(); - }); - it('should clear the tag input when click remove button if there is only one tag', async () => { const result = await setup(); diff --git a/public/components/register_model/model_tags.tsx b/public/components/register_model/model_tags.tsx index 391ae0a9..d0fddcb8 100644 --- a/public/components/register_model/model_tags.tsx +++ b/public/components/register_model/model_tags.tsx @@ -3,32 +3,21 @@ * SPDX-License-Identifier: Apache-2.0 */ -import React, { useCallback } from 'react'; -import { EuiButton, EuiSpacer, EuiText, EuiLink } from '@elastic/eui'; -import { useFieldArray, useFormContext } from 'react-hook-form'; +import React from 'react'; +import { EuiSpacer, EuiText, EuiLink } from '@elastic/eui'; import { useParams } from 'react-router-dom'; -import type { ModelFileFormData, ModelUrlFormData } from './register_model.types'; -import { ModelTagField } from './tag_field'; import { useModelTags } from './register_model.hooks'; +import { ModelTagArrayField } from '../common/forms/model_tag_array_field'; const MAX_TAG_NUM = 10; export const ModelTagsPanel = () => { - const { control } = useFormContext(); const { id: latestVersionId } = useParams<{ id: string | undefined }>(); const [, tags] = useModelTags(); - const { fields, append, remove } = useFieldArray({ - name: 'tags', - control, - }); const isRegisterNewVersion = !!latestVersionId; const maxTagNum = isRegisterNewVersion ? tags.length : MAX_TAG_NUM; - const addNewTag = useCallback(() => { - append({ key: '', value: '', type: 'string' }); - }, [append]); - return (
      @@ -57,25 +46,7 @@ export const ModelTagsPanel = () => { - {fields.map((field, index) => { - return ( - - ); - })} - - = maxTagNum} onClick={addNewTag}> - Add new tag - - - - {`You can add up to ${maxTagNum - fields.length} more tags.`} - +
      ); }; diff --git a/public/components/register_model/model_version_notes.tsx b/public/components/register_model/model_version_notes.tsx index 69a4ce6e..6cd8d191 100644 --- a/public/components/register_model/model_version_notes.tsx +++ b/public/components/register_model/model_version_notes.tsx @@ -5,14 +5,10 @@ import React from 'react'; import { EuiText, EuiSpacer } from '@elastic/eui'; -import { useFormContext } from 'react-hook-form'; -import type { ModelFileFormData, ModelUrlFormData } from './register_model.types'; import { ModelVersionNotesField } from '../common/forms/model_version_notes_field'; export const ModelVersionNotesPanel = () => { - const { control } = useFormContext(); - return (
      @@ -20,7 +16,6 @@ export const ModelVersionNotesPanel = () => { Version notes - optional{' '} diff --git a/public/components/register_model/register_model.tsx b/public/components/register_model/register_model.tsx index 88a7f49b..526ae4fb 100644 --- a/public/components/register_model/register_model.tsx +++ b/public/components/register_model/register_model.tsx @@ -302,7 +302,7 @@ export const RegisterModelForm = ({ defaultValues = DEFAULT_VALUES }: RegisterMo verticalPosition="center" horizontalPosition="center" paddingSize="none" - style={{ maxWidth: 1000 }} + style={{ width: 1000 }} > Date: Tue, 23 May 2023 14:16:04 +0800 Subject: [PATCH 50/75] feat(ui): artifact and configuration edit (#187) feat: edit model artifact and configuration on version details page Enable user to edit model version artifact and configuration --------- Signed-off-by: Yulong Ruan Signed-off-by: Lin Wang --- .../forms/__tests__/artifact_file.test.tsx | 66 ++++++ .../forms/__tests__/artifact_url.test.tsx | 91 ++++++++ .../__tests__/model_file_format.test.tsx | 89 ++++++++ .../components/common/forms/artifact_file.tsx | 80 +++++++ .../components/common/forms/artifact_url.tsx | 67 ++++++ .../forms/form_constants.ts} | 8 +- .../__tests__/model_configuration.test.tsx | 110 ++++++++++ .../model_configuration}/error_message.tsx | 0 .../model_configuration}/help_flyout.tsx | 0 .../forms/model_configuration/index.ts} | 2 +- .../model_configuration.tsx | 142 +++++++++++++ .../common/forms/model_file_format.tsx | 75 +++++++ .../common/forms/model_name_field.tsx | 3 +- .../__tests__}/model_tag_array_field.test.tsx | 4 +- .../__tests__}/tag_field.test.tsx | 0 .../__tests__}/tag_type_popover.test.tsx | 0 .../components/model/model_details_panel.tsx | 8 +- .../__tests__/version_artifact.test.tsx | 200 ++++++++++++++++++ .../version_artifact_source.test.tsx | 39 ++++ .../__tests__/version_information.test.tsx | 4 +- .../model_version/model_version.tsx | 9 +- public/components/model_version/types.ts | 17 +- .../model_version/version_artifact.tsx | 125 ++++++++++- .../model_version/version_artifact_source.tsx | 65 ++++++ .../model_version/version_information.tsx | 12 +- .../components/model_version/version_tags.tsx | 12 +- .../register_model_configuration.test.tsx | 8 - public/components/register_model/artifact.tsx | 89 +------- .../register_model/artifact_file.tsx | 51 ----- .../register_model/artifact_url.tsx | 45 ---- .../register_model/model_configuration.tsx | 128 +---------- .../register_model/register_model.tsx | 2 +- .../register_model/register_model_api.ts | 2 +- 33 files changed, 1194 insertions(+), 359 deletions(-) create mode 100644 public/components/common/forms/__tests__/artifact_file.test.tsx create mode 100644 public/components/common/forms/__tests__/artifact_url.test.tsx create mode 100644 public/components/common/forms/__tests__/model_file_format.test.tsx create mode 100644 public/components/common/forms/artifact_file.tsx create mode 100644 public/components/common/forms/artifact_url.tsx rename public/components/{register_model/constants.ts => common/forms/form_constants.ts} (92%) create mode 100644 public/components/common/forms/model_configuration/__tests__/model_configuration.test.tsx rename public/components/{register_model => common/forms/model_configuration}/error_message.tsx (100%) rename public/components/{register_model => common/forms/model_configuration}/help_flyout.tsx (100%) rename public/components/{register_model/form_constants.ts => common/forms/model_configuration/index.ts} (68%) create mode 100644 public/components/common/forms/model_configuration/model_configuration.tsx create mode 100644 public/components/common/forms/model_file_format.tsx rename public/components/common/forms/{__tests__/model_tag_array_field => model_tag_array_field/__tests__}/model_tag_array_field.test.tsx (95%) rename public/components/common/forms/{__tests__/model_tag_array_field => model_tag_array_field/__tests__}/tag_field.test.tsx (100%) rename public/components/common/forms/{__tests__/model_tag_array_field => model_tag_array_field/__tests__}/tag_type_popover.test.tsx (100%) create mode 100644 public/components/model_version/__tests__/version_artifact.test.tsx create mode 100644 public/components/model_version/__tests__/version_artifact_source.test.tsx create mode 100644 public/components/model_version/version_artifact_source.tsx delete mode 100644 public/components/register_model/artifact_file.tsx delete mode 100644 public/components/register_model/artifact_url.tsx diff --git a/public/components/common/forms/__tests__/artifact_file.test.tsx b/public/components/common/forms/__tests__/artifact_file.test.tsx new file mode 100644 index 00000000..c8508804 --- /dev/null +++ b/public/components/common/forms/__tests__/artifact_file.test.tsx @@ -0,0 +1,66 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React from 'react'; +import { FormProvider, useForm } from 'react-hook-form'; +import userEvent from '@testing-library/user-event'; + +import { ModelFileUploader } from '../artifact_file'; +import { render, screen } from '../../../../../test/test_utils'; +import { ONE_GB } from '../../../../../common/constant'; + +const TestApp = ({ readOnly = false }: { readOnly?: boolean }) => { + const form = useForm({ + defaultValues: {}, + mode: 'onChange', + }); + + return ( + + + + ); +}; + +describe('', () => { + it('should render a file upload input field', () => { + render(); + expect(screen.getByLabelText(/file/i, { selector: 'input[type="file"]' })).toBeInTheDocument(); + }); + + it('should display error if selected file size > 4GB', async () => { + const user = userEvent.setup(); + render(); + const modelFileInput = screen.getByLabelText(/^file$/i); + // File size can not exceed 4GB + const invalidFile = new File(['test model file'], 'model.zip', { type: 'application/zip' }); + Object.defineProperty(invalidFile, 'size', { value: 4 * ONE_GB + 1 }); + await user.upload(modelFileInput, invalidFile); + + expect(screen.getByLabelText(/file/i, { selector: 'input[type="file"]' })).toBeInvalid(); + expect(screen.getByText('Maximum file size exceeded. Add a smaller file.')).toBeInTheDocument(); + }); + + it('should allow to upload file <= 4GB', async () => { + const user = userEvent.setup(); + render(); + const modelFileInput = screen.getByLabelText(/^file$/i); + // File size can not exceed 4GB + const invalidFile = new File(['test model file'], 'model.zip', { type: 'application/zip' }); + Object.defineProperty(invalidFile, 'size', { value: 4 * ONE_GB }); + await user.upload(modelFileInput, invalidFile); + + expect(screen.getByLabelText(/file/i, { selector: 'input[type="file"]' })).toBeValid(); + expect( + screen.queryByText('Maximum file size exceeded. Add a smaller file.') + ).not.toBeInTheDocument(); + }); + + it('should display a readonly file uploader', () => { + render(); + const modelFileInput = screen.getByLabelText(/^file$/i); + expect(modelFileInput).toHaveAttribute('readonly'); + }); +}); diff --git a/public/components/common/forms/__tests__/artifact_url.test.tsx b/public/components/common/forms/__tests__/artifact_url.test.tsx new file mode 100644 index 00000000..b4e72571 --- /dev/null +++ b/public/components/common/forms/__tests__/artifact_url.test.tsx @@ -0,0 +1,91 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React from 'react'; +import { FormProvider, useForm } from 'react-hook-form'; +import userEvent from '@testing-library/user-event'; + +import { render, screen } from '../../../../../test/test_utils'; +import { ModelArtifactUrl } from '../artifact_url'; + +const TestApp = ({ + readOnly = false, + onSubmit = jest.fn(), + onError = jest.fn(), +}: { + readOnly?: boolean; + onSubmit?: () => any; + onError?: () => any; +}) => { + const form = useForm({ + defaultValues: {}, + mode: 'onChange', + }); + + return ( + +
      + + + +
      + ); +}; + +describe('', () => { + it('should render a URL input field', () => { + render(); + + const urlInput = screen.getByLabelText(/url/i, { + selector: 'input[type="text"]', + }); + expect(urlInput).toBeInTheDocument(); + }); + + it('should display URL input as readonly', () => { + render(); + const urlInput = screen.getByLabelText(/url/i, { + selector: 'input[type="text"]', + }); + expect(urlInput).toHaveAttribute('readonly'); + }); + + it('should display error message if URL is invalid', async () => { + const user = userEvent.setup(); + render(); + + const urlInput = screen.getByLabelText(/url/i, { + selector: 'input[type="text"]', + }); + await user.type(urlInput, 'invalid_url'); + expect(urlInput).toBeInvalid(); + expect(screen.getByText('URL is invalid. Enter a valid URL.')).toBeInTheDocument(); + }); + + it('should NOT display error message if URL is valid', async () => { + const user = userEvent.setup(); + render(); + + const urlInput = screen.getByLabelText(/url/i, { + selector: 'input[type="text"]', + }); + await user.type(urlInput, 'https://url.to/artifact.zip'); + expect(urlInput).toBeValid(); + expect(screen.queryByText('URL is invalid. Enter a valid URL.')).not.toBeInTheDocument(); + }); + + it('should display error message if URL is empty', async () => { + const user = userEvent.setup(); + render(); + + const urlInput = screen.getByLabelText(/url/i, { + selector: 'input[type="text"]', + }); + await user.clear(urlInput); + await user.click(screen.getByText('Submit')); + expect(urlInput).toBeInvalid(); + expect(screen.getByText('URL is required. Enter a URL.')).toBeInTheDocument(); + }); +}); diff --git a/public/components/common/forms/__tests__/model_file_format.test.tsx b/public/components/common/forms/__tests__/model_file_format.test.tsx new file mode 100644 index 00000000..f399bddc --- /dev/null +++ b/public/components/common/forms/__tests__/model_file_format.test.tsx @@ -0,0 +1,89 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React from 'react'; +import { FormProvider, useForm } from 'react-hook-form'; +import userEvent from '@testing-library/user-event'; + +import { render, screen, within } from '../../../../../test/test_utils'; +import { ModelFileFormatSelect } from '../model_file_format'; + +const TestApp = ({ + readOnly = false, + onSubmit = jest.fn(), + onError = jest.fn(), +}: { + readOnly?: boolean; + onSubmit?: () => any; + onError?: () => any; +}) => { + const form = useForm({ + defaultValues: { + modelFileFormat: 'TORCH_SCRIPT', + }, + }); + + return ( + +
      + + + +
      + ); +}; + +describe('', () => { + it('should render a model file format select', async () => { + const user = userEvent.setup(); + render(); + + const input = screen.getByLabelText(/model file format/i); + await user.click(input); + + const listBox = screen.getByRole('listBox'); + + expect(within(listBox).getByText('Torchscript(.pt)')).toBeInTheDocument(); + expect(within(listBox).getByText('ONNX(.onnx)')).toBeInTheDocument(); + }); + + it('should select model file format from the dropdown list', async () => { + const user = userEvent.setup(); + render(); + + const comboBox = screen.getByRole('combobox'); + // the default value + expect(within(comboBox).getByText('Torchscript(.pt)')).toBeInTheDocument(); + + const input = screen.getByLabelText(/model file format/i); + await user.click(input); + + const listBox = screen.getByRole('listBox'); + // select another value + await user.click(within(listBox).getByText('ONNX(.onnx)')); + expect(within(comboBox).getByText('ONNX(.onnx)')).toBeInTheDocument(); + }); + + it('should display model file format input as readonly', () => { + render(); + const input = screen.getByLabelText(/model file format/i); + expect(input).toHaveAttribute('readonly'); + // render default value in a readonly input + expect(screen.getByDisplayValue('Torchscript(.pt)')).toBeInTheDocument(); + }); + + it('should display error message if model file format select is empty', async () => { + const user = userEvent.setup(); + render(); + + // clear default value + await user.click(screen.getByLabelText('Clear input')); + await user.click(screen.getByText('Submit')); + + expect( + screen.getByText('Model file format is required. Select a model file format.') + ).toBeInTheDocument(); + }); +}); diff --git a/public/components/common/forms/artifact_file.tsx b/public/components/common/forms/artifact_file.tsx new file mode 100644 index 00000000..2ff29729 --- /dev/null +++ b/public/components/common/forms/artifact_file.tsx @@ -0,0 +1,80 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React, { useEffect } from 'react'; +import { EuiFormRow, EuiFilePicker, EuiText, EuiFieldText } from '@elastic/eui'; +import { useController, useFormContext } from 'react-hook-form'; + +import { CUSTOM_FORM_ERROR_TYPES, MAX_MODEL_FILE_SIZE } from './form_constants'; +import { ONE_GB } from '../../../../common/constant'; + +function validateFileSize(file?: File) { + if (file && file.size > MAX_MODEL_FILE_SIZE) { + return 'Maximum file size exceeded. Add a smaller file.'; + } + return true; +} + +export const UploadHelpText = () => ( + <> + + Accepted file format: ZIP (.zip). Maximum size, {MAX_MODEL_FILE_SIZE / ONE_GB}GB. + + + The ZIP mush include the following contents: +
        +
      • Model File, accepted formats: Torchscript(.pt), ONNX(.onnx)
      • +
      • Tokenizer file, accepted format: JSON(.json)
      • +
      +
      + +); + +interface Props { + label?: string; + readOnly?: boolean; +} + +export const ModelFileUploader = ({ readOnly = false, label = 'File' }: Props) => { + const { control, unregister } = useFormContext<{ modelFile?: File }>(); + const modelFileFieldController = useController({ + name: 'modelFile', + control, + rules: { + required: { value: true, message: 'A file is required. Add a file.' }, + validate: { + [CUSTOM_FORM_ERROR_TYPES.FILE_SIZE_EXCEED_LIMIT]: validateFileSize, + }, + }, + }); + + useEffect(() => { + return () => { + unregister('modelFile', { keepDefaultValue: true }); + }; + }, [unregister]); + + return readOnly ? ( + + + + ) : ( + + { + modelFileFieldController.field.onChange(fileList?.item(0)); + }} + /> + + ); +}; diff --git a/public/components/common/forms/artifact_url.tsx b/public/components/common/forms/artifact_url.tsx new file mode 100644 index 00000000..972828ac --- /dev/null +++ b/public/components/common/forms/artifact_url.tsx @@ -0,0 +1,67 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React, { useEffect } from 'react'; +import { EuiFormRow, htmlIdGenerator, EuiFieldText, EuiCopy, EuiIcon } from '@elastic/eui'; +import { useController, useFormContext } from 'react-hook-form'; + +import { URL_REGEX } from '../../../utils/regex'; + +interface Props { + label?: string; + readOnly?: boolean; +} + +export const ModelArtifactUrl = ({ label = 'URL', readOnly = false }: Props) => { + const { control, unregister } = useFormContext<{ modelURL?: string }>(); + const modelUrlFieldController = useController({ + name: 'modelURL', + control, + rules: { + required: { value: true, message: 'URL is required. Enter a URL.' }, + pattern: { value: URL_REGEX, message: 'URL is invalid. Enter a valid URL.' }, + }, + }); + + useEffect(() => { + return () => { + unregister('modelURL', { keepDefaultValue: true }); + }; + }, [unregister]); + + return readOnly ? ( +
      + + + + + {(copy) => ( + + )} + +
      + ) : ( + + + + ); +}; diff --git a/public/components/register_model/constants.ts b/public/components/common/forms/form_constants.ts similarity index 92% rename from public/components/register_model/constants.ts rename to public/components/common/forms/form_constants.ts index d0d8ac33..e8c85e3c 100644 --- a/public/components/register_model/constants.ts +++ b/public/components/common/forms/form_constants.ts @@ -2,14 +2,16 @@ * Copyright OpenSearch Contributors * SPDX-License-Identifier: Apache-2.0 */ -import { ONE_GB } from '../../../common/constant'; -import { MODEL_NAME_FIELD_DUPLICATE_NAME_ERROR } from '../../components/common'; + +import { ONE_GB } from '../../../../common/constant'; + +export const MODEL_NAME_FIELD_DUPLICATE_NAME_ERROR: string = 'duplicateName'; export const MAX_CHUNK_SIZE = 10 * 1000 * 1000; export const MAX_MODEL_FILE_SIZE = 4 * ONE_GB; export enum CUSTOM_FORM_ERROR_TYPES { - DUPLICATE_NAME = MODEL_NAME_FIELD_DUPLICATE_NAME_ERROR, + DUPLICATE_NAME = 'duplicateName', FILE_SIZE_EXCEED_LIMIT = 'fileSizeExceedLimit', INVALID_CONFIGURATION = 'invalidConfiguration', CONFIGURATION_MISSING_MODEL_TYPE = 'configurationMissingModelType', diff --git a/public/components/common/forms/model_configuration/__tests__/model_configuration.test.tsx b/public/components/common/forms/model_configuration/__tests__/model_configuration.test.tsx new file mode 100644 index 00000000..f0eec7b5 --- /dev/null +++ b/public/components/common/forms/model_configuration/__tests__/model_configuration.test.tsx @@ -0,0 +1,110 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React from 'react'; +import { FormProvider, useForm } from 'react-hook-form'; +import userEvent from '@testing-library/user-event'; + +import { render, screen } from '../../../../../../test/test_utils'; +import { ModelConfiguration } from '../model_configuration'; + +const TestApp = ({ + configuration = '{}', + readOnly = false, + onSubmit = jest.fn(), + onError = jest.fn(), +}: { + configuration?: string; + readOnly?: boolean; + onSubmit?: () => any; + onError?: () => any; +}) => { + const form = useForm({ + defaultValues: { + configuration, + }, + }); + + return ( + +
      + + + +
      + ); +}; + +describe('', () => { + it('should render a help flyout when click help button', async () => { + const user = userEvent.setup(); + render(); + + expect(screen.getByLabelText('JSON configuration')).toBeInTheDocument(); + await user.click(screen.getByTestId('model-configuration-help-button')); + expect(screen.getByRole('dialog')).toBeInTheDocument(); + }); + + it('should display error if configuration is empty', async () => { + const user = userEvent.setup(); + render(); + await user.click(screen.getByText('Submit')); + expect(screen.getByText('Configuration is required.')).toBeInTheDocument(); + }); + + it('should display error if configuration is NOT a valid JSON string', async () => { + const user = userEvent.setup(); + render(); + await user.click(screen.getByText('Submit')); + expect(screen.getByText('JSON is invalid. Enter a valid JSON')).toBeInTheDocument(); + }); + + it('should display error if model_type is missing', async () => { + const configuration = { + embedding_dimension: 768, + framework_type: 'SENTENCE_TRANSFORMERS', + }; + const user = userEvent.setup(); + render(); + await user.click(screen.getByText('Submit')); + expect(screen.getByText('model_type is required. Specify the model_type')).toBeInTheDocument(); + }); + + it('should display error if model_type is invalid', async () => { + const configuration = { + model_type: 768, + embedding_dimension: 768, + framework_type: 'SENTENCE_TRANSFORMERS', + }; + const user = userEvent.setup(); + render(); + await user.click(screen.getByText('Submit')); + expect(screen.getByText('model_type must be a string')).toBeInTheDocument(); + }); + + it('should display error if embedding_dimension is invalid', async () => { + const configuration = { + model_type: 'roberta', + embedding_dimension: 'invalid_value', + framework_type: 'SENTENCE_TRANSFORMERS', + }; + const user = userEvent.setup(); + render(); + await user.click(screen.getByText('Submit')); + expect(screen.getByText('embedding_dimension must be a number')).toBeInTheDocument(); + }); + + it('should display error if framework_type is invalid', async () => { + const configuration = { + model_type: 'roberta', + embedding_dimension: 768, + framework_type: 0, + }; + const user = userEvent.setup(); + render(); + await user.click(screen.getByText('Submit')); + expect(screen.getByText('framework_type must be a string')).toBeInTheDocument(); + }); +}); diff --git a/public/components/register_model/error_message.tsx b/public/components/common/forms/model_configuration/error_message.tsx similarity index 100% rename from public/components/register_model/error_message.tsx rename to public/components/common/forms/model_configuration/error_message.tsx diff --git a/public/components/register_model/help_flyout.tsx b/public/components/common/forms/model_configuration/help_flyout.tsx similarity index 100% rename from public/components/register_model/help_flyout.tsx rename to public/components/common/forms/model_configuration/help_flyout.tsx diff --git a/public/components/register_model/form_constants.ts b/public/components/common/forms/model_configuration/index.ts similarity index 68% rename from public/components/register_model/form_constants.ts rename to public/components/common/forms/model_configuration/index.ts index c8068177..7037e42c 100644 --- a/public/components/register_model/form_constants.ts +++ b/public/components/common/forms/model_configuration/index.ts @@ -3,4 +3,4 @@ * SPDX-License-Identifier: Apache-2.0 */ -export const FORM_ITEM_WIDTH = 400; +export * from './model_configuration'; diff --git a/public/components/common/forms/model_configuration/model_configuration.tsx b/public/components/common/forms/model_configuration/model_configuration.tsx new file mode 100644 index 00000000..bce558db --- /dev/null +++ b/public/components/common/forms/model_configuration/model_configuration.tsx @@ -0,0 +1,142 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React, { useState } from 'react'; +import { EuiButtonEmpty, EuiCodeBlock, EuiCodeEditor, EuiFormRow } from '@elastic/eui'; +import { useController, useFormContext } from 'react-hook-form'; + +import { ErrorMessage } from './error_message'; +import { HelpFlyout } from './help_flyout'; +import { CUSTOM_FORM_ERROR_TYPES } from '../form_constants'; + +function validateConfigurationObject(value: string) { + try { + JSON.parse(value.trim()); + } catch { + return 'JSON is invalid. Enter a valid JSON'; + } + return true; +} + +function validateModelType(value: string) { + try { + const config = JSON.parse(value.trim()); + if (!('model_type' in config)) { + return 'model_type is required. Specify the model_type'; + } + } catch { + return true; + } + return true; +} + +function validateModelTypeValue(value: string) { + try { + const config = JSON.parse(value.trim()); + if ('model_type' in config && typeof config.model_type !== 'string') { + return 'model_type must be a string'; + } + } catch { + return true; + } + return true; +} + +function validateEmbeddingDimensionValue(value: string) { + try { + const config = JSON.parse(value.trim()); + if ('embedding_dimension' in config && typeof config.embedding_dimension !== 'number') { + return 'embedding_dimension must be a number'; + } + } catch { + return true; + } + + return true; +} + +function validateFrameworkTypeValue(value: string) { + try { + const config = JSON.parse(value.trim()); + if ('framework_type' in config && typeof config.framework_type !== 'string') { + return 'framework_type must be a string'; + } + } catch { + return true; + } + return true; +} + +interface Props { + readOnly?: boolean; +} + +export const ModelConfiguration = ({ readOnly = false }: Props) => { + const [isHelpVisible, setIsHelpVisible] = useState(false); + const { control } = useFormContext<{ configuration: string }>(); + const configurationFieldController = useController({ + name: 'configuration', + control, + rules: { + required: { value: true, message: 'Configuration is required.' }, + validate: { + [CUSTOM_FORM_ERROR_TYPES.INVALID_CONFIGURATION]: validateConfigurationObject, + [CUSTOM_FORM_ERROR_TYPES.CONFIGURATION_MISSING_MODEL_TYPE]: validateModelType, + [CUSTOM_FORM_ERROR_TYPES.INVALID_MODEL_TYPE_VALUE]: validateModelTypeValue, + [CUSTOM_FORM_ERROR_TYPES.INVALID_EMBEDDING_DIMENSION_VALUE]: validateEmbeddingDimensionValue, + [CUSTOM_FORM_ERROR_TYPES.INVALID_FRAMEWORK_TYPE_VALUE]: validateFrameworkTypeValue, + }, + }, + }); + + return ( + <> + } + labelAppend={ + setIsHelpVisible(true)} + size="xs" + color="primary" + data-test-subj="model-configuration-help-button" + > + Help + + } + > + {readOnly ? ( + + {configurationFieldController.field.value} + + ) : ( + configurationFieldController.field.onChange(value)} + setOptions={{ + fontSize: '14px', + enableBasicAutocompletion: true, + enableLiveAutocompletion: true, + }} + /> + )} + + {isHelpVisible && setIsHelpVisible(false)} />} + + ); +}; diff --git a/public/components/common/forms/model_file_format.tsx b/public/components/common/forms/model_file_format.tsx new file mode 100644 index 00000000..4da864da --- /dev/null +++ b/public/components/common/forms/model_file_format.tsx @@ -0,0 +1,75 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import { EuiComboBox, EuiComboBoxOptionOption, EuiFieldText, EuiFormRow } from '@elastic/eui'; +import { useController, useFormContext } from 'react-hook-form'; +import React, { useMemo, useCallback } from 'react'; + +export const FILE_FORMAT_OPTIONS = [ + { + label: 'ONNX(.onnx)', + value: 'ONNX', + }, + { + label: 'Torchscript(.pt)', + value: 'TORCH_SCRIPT', + }, +]; + +interface Props { + readOnly?: boolean; +} + +export const ModelFileFormatSelect = ({ readOnly = false }: Props) => { + const { control } = useFormContext<{ modelFileFormat: string }>(); + + const modelFileFormatController = useController({ + name: 'modelFileFormat', + control, + rules: { + required: { + value: true, + message: 'Model file format is required. Select a model file format.', + }, + }, + }); + + const { ref: fileFormatInputRef, ...fileFormatField } = modelFileFormatController.field; + + const selectedFileFormatOption = useMemo(() => { + if (fileFormatField.value) { + return FILE_FORMAT_OPTIONS.find((fmt) => fmt.value === fileFormatField.value); + } + }, [fileFormatField]); + + const onFileFormatChange = useCallback( + (options: Array>) => { + const value = options[0]?.value; + fileFormatField.onChange(value); + }, + [fileFormatField] + ); + + return ( + + {readOnly ? ( + + ) : ( + + )} + + ); +}; diff --git a/public/components/common/forms/model_name_field.tsx b/public/components/common/forms/model_name_field.tsx index 633e9bee..658e4ed8 100644 --- a/public/components/common/forms/model_name_field.tsx +++ b/public/components/common/forms/model_name_field.tsx @@ -8,8 +8,7 @@ import { EuiFieldText, EuiFormRow, EuiText } from '@elastic/eui'; import { Control, FieldPathByValue, UseFormTrigger, useController } from 'react-hook-form'; import { APIProvider } from '../../../apis/api_provider'; - -export const MODEL_NAME_FIELD_DUPLICATE_NAME_ERROR = 'duplicateName'; +import { MODEL_NAME_FIELD_DUPLICATE_NAME_ERROR } from './form_constants'; const NAME_MAX_LENGTH = 80; diff --git a/public/components/common/forms/__tests__/model_tag_array_field/model_tag_array_field.test.tsx b/public/components/common/forms/model_tag_array_field/__tests__/model_tag_array_field.test.tsx similarity index 95% rename from public/components/common/forms/__tests__/model_tag_array_field/model_tag_array_field.test.tsx rename to public/components/common/forms/model_tag_array_field/__tests__/model_tag_array_field.test.tsx index fcab299a..ea05a0ec 100644 --- a/public/components/common/forms/__tests__/model_tag_array_field/model_tag_array_field.test.tsx +++ b/public/components/common/forms/model_tag_array_field/__tests__/model_tag_array_field.test.tsx @@ -8,9 +8,9 @@ import { FormProvider, useForm } from 'react-hook-form'; import userEvent from '@testing-library/user-event'; import { Tag } from '../../../../../components/model/types'; -import { ModelTagArrayField } from '../../model_tag_array_field'; -import { TagGroup } from '../../model_tag_array_field/tag_field'; import { render, screen } from '../../../../../../test/test_utils'; +import { ModelTagArrayField } from '../model_tag_array_field'; +import { TagGroup } from '../tag_field'; const TEST_TAG_GROUPS = [ { name: 'Key1', type: 'string' as const, values: ['Value1'] }, diff --git a/public/components/common/forms/__tests__/model_tag_array_field/tag_field.test.tsx b/public/components/common/forms/model_tag_array_field/__tests__/tag_field.test.tsx similarity index 100% rename from public/components/common/forms/__tests__/model_tag_array_field/tag_field.test.tsx rename to public/components/common/forms/model_tag_array_field/__tests__/tag_field.test.tsx diff --git a/public/components/common/forms/__tests__/model_tag_array_field/tag_type_popover.test.tsx b/public/components/common/forms/model_tag_array_field/__tests__/tag_type_popover.test.tsx similarity index 100% rename from public/components/common/forms/__tests__/model_tag_array_field/tag_type_popover.test.tsx rename to public/components/common/forms/model_tag_array_field/__tests__/tag_type_popover.test.tsx diff --git a/public/components/model/model_details_panel.tsx b/public/components/model/model_details_panel.tsx index fb6c8298..8bc35773 100644 --- a/public/components/model/model_details_panel.tsx +++ b/public/components/model/model_details_panel.tsx @@ -24,12 +24,8 @@ import { generatePath, useHistory } from 'react-router-dom'; import { useOpenSearchDashboards } from '../../../../../src/plugins/opensearch_dashboards_react/public'; import { mountReactNode } from '../../../../../src/core/public/utils'; import { routerPaths } from '../../../common'; -import { - ErrorCallOut, - MODEL_NAME_FIELD_DUPLICATE_NAME_ERROR, - ModelDescriptionField, - ModelNameField, -} from '../common'; +import { ErrorCallOut, ModelDescriptionField, ModelNameField } from '../common'; +import { MODEL_NAME_FIELD_DUPLICATE_NAME_ERROR } from '../common/forms/form_constants'; import { BottomFormActionBar } from './bottom_form_action_bar'; diff --git a/public/components/model_version/__tests__/version_artifact.test.tsx b/public/components/model_version/__tests__/version_artifact.test.tsx new file mode 100644 index 00000000..4ef961a2 --- /dev/null +++ b/public/components/model_version/__tests__/version_artifact.test.tsx @@ -0,0 +1,200 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import userEvent from '@testing-library/user-event'; +import React from 'react'; +import { FormProvider, useForm } from 'react-hook-form'; + +import { render, screen, within } from '../../../../test/test_utils'; +import { ModelVersionFormData } from '../types'; +import { ModelVersionArtifact } from '../version_artifact'; + +const DEFAULT_VALUES = { + artifactSource: 'source_not_changed' as const, + modelURL: 'http://url.to/artifact.zip', + modelFileFormat: 'TORCH_SCRIPT', + configuration: '{}', +}; + +const TestApp = ({ defaultValues = DEFAULT_VALUES }: { defaultValues?: ModelVersionFormData }) => { + const form = useForm({ + defaultValues, + }); + + return ( + + + + ); +}; + +describe('', () => { + it('should display a readonly "Artifact and configuration" panel', () => { + render(); + expect(screen.getByLabelText('Uploaded artifact details(URL)')).toHaveAttribute('readonly'); + expect(screen.getByLabelText('Model file format')).toHaveAttribute('readonly'); + expect(screen.getByLabelText('readonly configuration')).toBeInTheDocument(); + + expect(screen.getByDisplayValue('http://url.to/artifact.zip')).toBeInTheDocument(); + expect(screen.getByDisplayValue('Torchscript(.pt)')).toBeInTheDocument(); + }); + + it('should display a readonly input with file name', () => { + const defaultValues = { + artifactSource: 'source_not_changed' as const, + modelFile: new File([], 'artifact.zip'), + modelFileFormat: 'TORCH_SCRIPT', + configuration: '{}', + }; + render(); + expect(screen.getByDisplayValue('artifact.zip')).toBeInTheDocument(); + expect(screen.getByDisplayValue('artifact.zip')).toHaveAttribute('readonly'); + }); + + it('should display "Artifact source select" after clicking "Edit" button', async () => { + const user = userEvent.setup(); + render(); + + await user.click(screen.getByLabelText('edit version artifact')); + expect(screen.getByText('Artifact source')).toBeInTheDocument(); + expect(screen.getByLabelText('Upload new from local file')).toBeInTheDocument(); + expect(screen.getByLabelText('Upload new from URL')).toBeInTheDocument(); + }); + + it('should select "Keep existing" by default after click "Edit" button', async () => { + const user = userEvent.setup(); + render(); + + await user.click(screen.getByLabelText('edit version artifact')); + expect(screen.getByLabelText('Keep existing')).toBeChecked(); + // model url is readonly + expect(screen.getByLabelText('Uploaded artifact details(URL)')).toHaveAttribute('readonly'); + // model file format is editable + expect(screen.getByLabelText('Model file format')).not.toHaveAttribute('readonly'); + + // configuration json input is editable + expect(screen.getByLabelText('JSON configuration')).toBeEnabled(); + expect(screen.getByLabelText('JSON configuration')).not.toHaveAttribute('readonly'); + }); + + it('should display model file upload input when selecting "Upload new from local file"', async () => { + const user = userEvent.setup(); + render(); + + await user.click(screen.getByLabelText('edit version artifact')); + await user.click(screen.getByLabelText('Upload new from local file')); + + expect(screen.getByLabelText(/file/i, { selector: 'input[type="file"]' })).toBeInTheDocument(); + }); + + it('should display model URL input when selecting "Upload new from URL"', async () => { + const user = userEvent.setup(); + render(); + + await user.click(screen.getByLabelText('edit version artifact')); + await user.click(screen.getByLabelText('Upload new from URL')); + + const urlInput = screen.getByLabelText(/url/i, { + selector: 'input[type="text"]', + }); + expect(urlInput).toBeInTheDocument(); + }); + + it('should display an editable "Model file format" when NOT selecting "Keep existing"', async () => { + const user = userEvent.setup(); + render(); + + await user.click(screen.getByLabelText('edit version artifact')); + + await user.click(screen.getByLabelText('Upload new from local file')); + expect(screen.getByLabelText('Model file format')).not.toHaveAttribute('readonly'); + + await user.click(screen.getByLabelText('Upload new from URL')); + expect(screen.getByLabelText('Model file format')).not.toHaveAttribute('readonly'); + }); + + it('should reset form changes after clicking "Cancel" button', async () => { + const user = userEvent.setup(); + render(); + + // model url field should display the default url + expect(screen.getByLabelText('Uploaded artifact details(URL)')).toHaveDisplayValue( + 'http://url.to/artifact.zip' + ); + // model file format should display the default file format + expect(screen.getByLabelText('Model file format')).toHaveDisplayValue('Torchscript(.pt)'); + + await user.click(screen.getByLabelText('edit version artifact')); + await user.click(screen.getByLabelText('Upload new from URL')); + + // update model url + const urlInput = screen.getByLabelText(/url/i, { + selector: 'input[type="text"]', + }); + await user.clear(urlInput); + await user.type(urlInput, 'http://ur.new/artifact.zip'); + + // select another model file format + const input = screen.getByLabelText(/model file format/i); + await user.click(input); + const listBox = screen.getByRole('listBox'); + await user.click(within(listBox).getByText('ONNX(.onnx)')); + + await user.click(screen.getByLabelText('cancel edit version artifact')); + expect(screen.getByLabelText('Uploaded artifact details(URL)')).toHaveDisplayValue( + 'http://url.to/artifact.zip' + ); + expect(screen.getByLabelText('Model file format')).toHaveDisplayValue('Torchscript(.pt)'); + }); + + it('should NOT change artifact(via URL) after selecting "Keep existing"', async () => { + const user = userEvent.setup(); + render(); + + await user.click(screen.getByLabelText('edit version artifact')); + + // model url field should display the default url + expect(screen.getByLabelText('Uploaded artifact details(URL)')).toHaveDisplayValue( + 'http://url.to/artifact.zip' + ); + + // starts to edit url + await user.click(screen.getByLabelText('Upload new from URL')); + // update model url + const urlInput = screen.getByLabelText(/url/i, { + selector: 'input[type="text"]', + }); + await user.clear(urlInput); + await user.type(urlInput, 'http://ur.new/artifact.zip'); + + await user.click(screen.getByLabelText('Keep existing')); + expect(screen.getByLabelText('Uploaded artifact details(URL)')).toHaveDisplayValue( + 'http://url.to/artifact.zip' + ); + }); + + it('should NOT change artifact(via file upload) after selecting "Keep existing"', async () => { + const user = userEvent.setup(); + render(); + + await user.click(screen.getByLabelText('edit version artifact')); + + // model url field should display the default url + expect(screen.getByLabelText('Uploaded artifact details(URL)')).toHaveDisplayValue( + 'http://url.to/artifact.zip' + ); + + // starts to upload model file + await user.click(screen.getByLabelText('Upload new from local file')); + const modelFileInput = screen.getByLabelText(/^file$/i); + const file = new File(['test model file'], 'model.zip', { type: 'application/zip' }); + await user.upload(modelFileInput, file); + + await user.click(screen.getByLabelText('Keep existing')); + expect(screen.getByLabelText('Uploaded artifact details(URL)')).toHaveDisplayValue( + 'http://url.to/artifact.zip' + ); + }); +}); diff --git a/public/components/model_version/__tests__/version_artifact_source.test.tsx b/public/components/model_version/__tests__/version_artifact_source.test.tsx new file mode 100644 index 00000000..3f07568b --- /dev/null +++ b/public/components/model_version/__tests__/version_artifact_source.test.tsx @@ -0,0 +1,39 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import userEvent from '@testing-library/user-event'; +import React from 'react'; +import { FormProvider, useForm } from 'react-hook-form'; +import { render, screen } from '../../../../test/test_utils'; +import { VersionArtifactSource } from '../version_artifact_source'; + +const TestApp = () => { + const form = useForm({ + defaultValues: { artifactSource: 'source_not_changed' }, + }); + + return ( + + + + ); +}; + +describe('', () => { + it('should display artifact sources', () => { + render(); + expect(screen.getByLabelText('Keep existing')).toBeInTheDocument(); + expect(screen.getByLabelText('Keep existing')).toBeChecked(); + expect(screen.getByLabelText('Upload new from local file')).toBeInTheDocument(); + expect(screen.getByLabelText('Upload new from URL')).toBeInTheDocument(); + }); + + it('should change selected artifact source', async () => { + const user = userEvent.setup(); + render(); + await user.click(screen.getByLabelText('Upload new from local file')); + expect(screen.getByLabelText('Upload new from local file')).toBeChecked(); + }); +}); diff --git a/public/components/model_version/__tests__/version_information.test.tsx b/public/components/model_version/__tests__/version_information.test.tsx index a857d940..762f81d2 100644 --- a/public/components/model_version/__tests__/version_information.test.tsx +++ b/public/components/model_version/__tests__/version_information.test.tsx @@ -7,11 +7,11 @@ import userEvent from '@testing-library/user-event'; import React from 'react'; import { FormProvider, useForm } from 'react-hook-form'; import { render, screen } from '../../../../test/test_utils'; -import { ModelFileFormData, ModelUrlFormData } from '../types'; +import { ModelVersionFormData } from '../types'; import { ModelVersionInformation } from '../version_information'; const TestApp = () => { - const form = useForm({ + const form = useForm({ defaultValues: { versionNotes: 'test_version_notes' }, }); diff --git a/public/components/model_version/model_version.tsx b/public/components/model_version/model_version.tsx index 83a82e9f..9562e819 100644 --- a/public/components/model_version/model_version.tsx +++ b/public/components/model_version/model_version.tsx @@ -28,7 +28,7 @@ import { ModelVersionDetails } from './version_details'; import { ModelVersionInformation } from './version_information'; import { ModelVersionArtifact } from './version_artifact'; import { ModelVersionTags } from './version_tags'; -import { ModelFileFormData, ModelUrlFormData } from './types'; +import { ModelVersionFormData } from './types'; export const ModelVersion = () => { const { id: modelId } = useParams<{ id: string }>(); @@ -37,7 +37,7 @@ export const ModelVersion = () => { const history = useHistory(); const modelName = model?.name; const modelVersion = model?.model_version; - const form = useForm(); + const form = useForm(); const onVersionChange = useCallback( ({ newVersion, newId }: { newVersion: string; newId: string }) => { @@ -73,9 +73,12 @@ export const ModelVersion = () => { { key: 'Precision', value: '0.64', type: 'number' as const }, { key: 'Task', value: 'Image classification', type: 'string' as const }, ], // TODO: read from model.tags - configuration: JSON.stringify(model.model_config), + configuration: JSON.stringify(model.model_config, undefined, 2), modelFileFormat: model.model_format, // TODO: read model url or model filename + artifactSource: 'source_not_changed', + // modelFile: new File([], 'artifact.zip'), + modelURL: 'http://url.to/artifact.zip', }); } }, [model, form]); diff --git a/public/components/model_version/types.ts b/public/components/model_version/types.ts index 8d080e1d..a99ab7b5 100644 --- a/public/components/model_version/types.ts +++ b/public/components/model_version/types.ts @@ -5,17 +5,12 @@ import type { Tag } from '../model/types'; -interface FormDataBase { +export interface ModelVersionFormData { + artifactSource?: 'source_not_changed' | 'source_from_computer' | 'source_from_url'; versionNotes?: string; tags?: Tag[]; - configuration: string; - modelFileFormat: string; -} - -export interface ModelFileFormData extends FormDataBase { - modelFile: File; -} - -export interface ModelUrlFormData extends FormDataBase { - modelURL: string; + configuration?: string; + modelFileFormat?: string; + modelFile?: File; + modelURL?: string; } diff --git a/public/components/model_version/version_artifact.tsx b/public/components/model_version/version_artifact.tsx index f3c29cf1..c2b8a0ae 100644 --- a/public/components/model_version/version_artifact.tsx +++ b/public/components/model_version/version_artifact.tsx @@ -3,20 +3,81 @@ * SPDX-License-Identifier: Apache-2.0 */ -import React from 'react'; import { - EuiButton, EuiFlexGroup, EuiFlexItem, EuiHorizontalRule, EuiPanel, EuiSpacer, EuiTitle, + EuiButton, + EuiText, + EuiLink, + EuiCode, } from '@elastic/eui'; +import React, { useRef, useEffect, useState, useCallback } from 'react'; +import { useFormContext, useFormState, useWatch } from 'react-hook-form'; +import { ModelFileUploader, UploadHelpText } from '../common/forms/artifact_file'; +import { ModelArtifactUrl } from '../common/forms/artifact_url'; +import { ModelConfiguration } from '../common/forms/model_configuration'; +import { ModelFileFormatSelect } from '../common/forms/model_file_format'; +import { ModelVersionFormData } from './types'; +import { VersionArtifactSource } from './version_artifact_source'; export const ModelVersionArtifact = () => { + const form = useFormContext(); + // Returned formState is wrapped with Proxy to improve render performance and + // skip extra computation if specific state is not subscribed, so make sure you + // deconstruct or read it before render in order to enable the subscription. + const { defaultValues } = useFormState({ control: form.control }); + const [readOnly, setReadOnly] = useState(true); + const formRef = useRef(form); + formRef.current = form; + + const artifactSource = useWatch({ + name: 'artifactSource', + control: form.control, + }); + + const onCancel = useCallback(() => { + formRef.current.reset(); + setReadOnly(true); + }, []); + + const onSourceChange = useCallback((source: string) => { + if (source === 'source_not_changed') { + formRef.current.resetField('modelFile'); + formRef.current.resetField('modelURL'); + } + }, []); + + useEffect(() => { + // reset form value to default when component unmounted, this makes sure + // the unsaved changes are dropped when the component unmounted + return () => { + formRef.current.reset(); + }; + }, []); + + const renderArtifactInput = () => { + if (artifactSource === 'source_not_changed') { + if (defaultValues && defaultValues.modelFile) { + return ; + } + if (defaultValues && defaultValues.modelURL) { + return ; + } + } + if (artifactSource === 'source_from_computer') { + return ; + } + if (artifactSource === 'source_from_url') { + return ; + } + }; + return ( - + @@ -24,11 +85,67 @@ export const ModelVersionArtifact = () => { - Edit + {readOnly ? ( + setReadOnly(false)}> + Edit + + ) : ( + + Cancel + + )} + + + + +

      Artifact

      + + The zipped artifact must include a model file and a tokenizer file. If uploading with + a local file, keep this browser open until the upload completes.{' '} + + Learn more + + +
      +
      + + {!readOnly && ( + <> + + + + )} + {renderArtifactInput()} + {!readOnly && artifactSource !== 'source_not_changed' && ( + <> + + + + )} + + + +
      + + + + +

      Configuration

      + + The model configuration specifies the model_type,{' '} + embedding_dimension, and framework_type of the + model. + +
      +
      + + + +
      ); }; diff --git a/public/components/model_version/version_artifact_source.tsx b/public/components/model_version/version_artifact_source.tsx new file mode 100644 index 00000000..5ee03d18 --- /dev/null +++ b/public/components/model_version/version_artifact_source.tsx @@ -0,0 +1,65 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import { EuiRadioGroup, EuiText } from '@elastic/eui'; +import React, { useCallback } from 'react'; +import { useController, useFormContext } from 'react-hook-form'; + +type Source = 'source_not_changed' | 'source_from_computer' | 'source_from_url'; + +const OPTIONS = [ + { + id: 'source_not_changed', + label: 'Keep existing', + }, + { + id: 'source_from_computer', + label: 'Upload new from local file', + }, + { + id: 'source_from_url', + label: 'Upload new from URL', + }, +]; + +interface Props { + onChange?: (source: Source) => void; +} + +export const VersionArtifactSource = ({ onChange }: Props) => { + const { control } = useFormContext<{ + artifactSource: Source; + }>(); + + const sourceArtifactControl = useController({ + name: 'artifactSource', + control, + }); + + const onSourceChange = useCallback( + (id: string) => { + sourceArtifactControl.field.onChange(id); + if (onChange) { + onChange(id as Source); + } + }, + [sourceArtifactControl.field, onChange] + ); + + return ( + +

      Artifact source

      + + ), + }} + /> + ); +}; diff --git a/public/components/model_version/version_information.tsx b/public/components/model_version/version_information.tsx index c2515422..9322e520 100644 --- a/public/components/model_version/version_information.tsx +++ b/public/components/model_version/version_information.tsx @@ -16,10 +16,10 @@ import { import React, { useRef, useEffect, useState, useCallback } from 'react'; import { useFormContext, useFormState } from 'react-hook-form'; import { ModelVersionNotesField } from '../common/forms/model_version_notes_field'; -import { ModelFileFormData, ModelUrlFormData } from './types'; +import { ModelVersionFormData } from './types'; export const ModelVersionInformation = () => { - const form = useFormContext(); + const form = useFormContext(); // Returned formState is wrapped with Proxy to improve render performance and // skip extra computation if specific state is not subscribed, so make sure you // deconstruct or read it before render in order to enable the subscription. @@ -29,17 +29,15 @@ export const ModelVersionInformation = () => { formRef.current = form; const onCancel = useCallback(() => { - form.resetField('versionNotes'); + formRef.current.resetField('versionNotes'); setReadOnly(true); - }, [form]); + }, []); useEffect(() => { // reset form value to default when component unmounted, this makes sure // the unsaved changes are dropped when the component unmounted return () => { - if (formRef.current.formState.dirtyFields.versionNotes) { - formRef.current.resetField('versionNotes'); - } + formRef.current.resetField('versionNotes'); }; }, []); diff --git a/public/components/model_version/version_tags.tsx b/public/components/model_version/version_tags.tsx index 657ebbce..016839fe 100644 --- a/public/components/model_version/version_tags.tsx +++ b/public/components/model_version/version_tags.tsx @@ -18,27 +18,25 @@ import React, { useEffect, useRef, useState, useCallback } from 'react'; import { useFormContext, useFormState } from 'react-hook-form'; import { ModelTagArrayField } from '../common/forms/model_tag_array_field'; import { useModelTags } from '../register_model/register_model.hooks'; -import { ModelFileFormData, ModelUrlFormData } from './types'; +import { ModelVersionFormData } from './types'; export const ModelVersionTags = () => { const [, tags] = useModelTags(); - const form = useFormContext(); + const form = useFormContext(); const { isDirty, dirtyFields } = useFormState({ control: form.control }); const [readOnly, setReadOnly] = useState(true); const formRef = useRef(form); formRef.current = form; const onCancel = useCallback(() => { - form.resetField('tags'); + formRef.current.resetField('tags'); setReadOnly(true); - }, [form]); + }, []); useEffect(() => { // reset form value to default when unmount return () => { - if (formRef.current.formState.dirtyFields.tags) { - formRef.current.resetField('tags'); - } + formRef.current.resetField('tags'); }; }, []); diff --git a/public/components/register_model/__tests__/register_model_configuration.test.tsx b/public/components/register_model/__tests__/register_model_configuration.test.tsx index dad7f926..6ffa00be 100644 --- a/public/components/register_model/__tests__/register_model_configuration.test.tsx +++ b/public/components/register_model/__tests__/register_model_configuration.test.tsx @@ -23,14 +23,6 @@ describe(' Configuration', () => { jest.clearAllMocks(); }); - it('should render a help flyout when click help button', async () => { - const { user } = await setup(); - - expect(screen.getByLabelText('Configuration in JSON')).toBeInTheDocument(); - await user.click(screen.getByTestId('model-configuration-help-button')); - expect(screen.getByRole('dialog')).toBeInTheDocument(); - }); - it('should not allow to submit form if model_type is missing', async () => { // Missing model_type const invalidConfiguration = `{}`; diff --git a/public/components/register_model/artifact.tsx b/public/components/register_model/artifact.tsx index 9498f7fe..17e71779 100644 --- a/public/components/register_model/artifact.tsx +++ b/public/components/register_model/artifact.tsx @@ -3,69 +3,18 @@ * SPDX-License-Identifier: Apache-2.0 */ -import React, { useCallback, useMemo, useState } from 'react'; -import { - htmlIdGenerator, - EuiSpacer, - EuiText, - EuiRadio, - EuiLink, - EuiFormRow, - EuiComboBox, - EuiComboBoxOptionOption, -} from '@elastic/eui'; -import { useController, useFormContext } from 'react-hook-form'; +import React, { useState } from 'react'; +import { htmlIdGenerator, EuiSpacer, EuiText, EuiRadio, EuiLink } from '@elastic/eui'; -import { ModelFileUploader } from './artifact_file'; -import { ArtifactUrl } from './artifact_url'; -import { ONE_GB } from '../../../common/constant'; -import { MAX_MODEL_FILE_SIZE } from './constants'; -import { ModelFileFormData, ModelUrlFormData } from './register_model.types'; - -const FILE_FORMAT_OPTIONS = [ - { - label: 'ONNX(.onnx)', - value: 'ONNX', - }, - { - label: 'Torchscript(.pt)', - value: 'TORCH_SCRIPT', - }, -]; +import { ModelFileUploader, UploadHelpText } from '../common/forms/artifact_file'; +import { ModelArtifactUrl } from '../common/forms/artifact_url'; +import { ModelFileFormatSelect } from '../common/forms/model_file_format'; export const ArtifactPanel = () => { - const { control } = useFormContext(); const [selectedSource, setSelectedSource] = useState<'source_from_computer' | 'source_from_url'>( 'source_from_computer' ); - const modelFileFormatController = useController({ - name: 'modelFileFormat', - control, - rules: { - required: { - value: true, - message: 'Model file format is required. Select a model file format.', - }, - }, - }); - - const { ref: fileFormatInputRef, ...fileFormatField } = modelFileFormatController.field; - - const selectedFileFormatOption = useMemo(() => { - if (fileFormatField.value) { - return FILE_FORMAT_OPTIONS.find((fmt) => fmt.value === fileFormatField.value); - } - }, [fileFormatField]); - - const onFileFormatChange = useCallback( - (options: Array>) => { - const value = options[0]?.value; - fileFormatField.onChange(value); - }, - [fileFormatField] - ); - return (
      @@ -99,33 +48,11 @@ export const ArtifactPanel = () => { /> {selectedSource === 'source_from_computer' && } - {selectedSource === 'source_from_url' && } + {selectedSource === 'source_from_url' && } - - Accepted file format: ZIP (.zip). Maximum size, {MAX_MODEL_FILE_SIZE / ONE_GB}GB. - - - The ZIP mush include the following contents: -
        -
      • Model File, accepted formats: Torchscript(.pt), ONNX(.onnx)
      • -
      • Tokenizer file, accepted format: JSON(.json)
      • -
      -
      + - - - +
      ); }; diff --git a/public/components/register_model/artifact_file.tsx b/public/components/register_model/artifact_file.tsx deleted file mode 100644 index b5a12b2d..00000000 --- a/public/components/register_model/artifact_file.tsx +++ /dev/null @@ -1,51 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -import React from 'react'; -import { EuiFormRow, EuiFilePicker } from '@elastic/eui'; -import { useController, useFormContext } from 'react-hook-form'; - -import { ModelFileFormData, ModelUrlFormData } from './register_model.types'; -import { CUSTOM_FORM_ERROR_TYPES, MAX_MODEL_FILE_SIZE } from './constants'; - -function validateFileSize(file?: File) { - if (file && file.size > MAX_MODEL_FILE_SIZE) { - return 'Maximum file size exceeded. Add a smaller file.'; - } - return true; -} - -export const ModelFileUploader = () => { - const { control } = useFormContext(); - const modelFileFieldController = useController({ - name: 'modelFile', - control, - rules: { - required: { value: true, message: 'A file is required. Add a file.' }, - validate: { - [CUSTOM_FORM_ERROR_TYPES.FILE_SIZE_EXCEED_LIMIT]: validateFileSize, - }, - }, - shouldUnregister: true, - }); - - return ( - - { - modelFileFieldController.field.onChange(fileList?.item(0)); - }} - /> - - ); -}; diff --git a/public/components/register_model/artifact_url.tsx b/public/components/register_model/artifact_url.tsx deleted file mode 100644 index cde55f5c..00000000 --- a/public/components/register_model/artifact_url.tsx +++ /dev/null @@ -1,45 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -import React from 'react'; -import { EuiFormRow, htmlIdGenerator, EuiFieldText } from '@elastic/eui'; -import { useController, useFormContext } from 'react-hook-form'; - -import { FORM_ITEM_WIDTH } from './form_constants'; -import type { ModelFileFormData, ModelUrlFormData } from './register_model.types'; -import { URL_REGEX } from '../../utils/regex'; - -export const ArtifactUrl = () => { - const { control } = useFormContext(); - const modelUrlFieldController = useController({ - name: 'modelURL', - control, - rules: { - required: { value: true, message: 'URL is required. Enter a URL.' }, - pattern: { value: URL_REGEX, message: 'URL is invalid. Enter a valid URL.' }, - }, - shouldUnregister: true, - }); - - return ( - - - - ); -}; diff --git a/public/components/register_model/model_configuration.tsx b/public/components/register_model/model_configuration.tsx index 73436742..c89b84e4 100644 --- a/public/components/register_model/model_configuration.tsx +++ b/public/components/register_model/model_configuration.tsx @@ -3,101 +3,13 @@ * SPDX-License-Identifier: Apache-2.0 */ -import React, { useState } from 'react'; -import { - EuiFormRow, - EuiCodeEditor, - EuiText, - EuiTextColor, - EuiCode, - EuiSpacer, - EuiButtonEmpty, -} from '@elastic/eui'; -import { useController, useFormContext } from 'react-hook-form'; +import React from 'react'; +import { EuiText, EuiTextColor, EuiCode, EuiSpacer } from '@elastic/eui'; import '../../ace-themes/sql_console.js'; -import { FORM_ITEM_WIDTH } from './form_constants'; -import type { ModelFileFormData, ModelUrlFormData } from './register_model.types'; -import { HelpFlyout } from './help_flyout'; -import { CUSTOM_FORM_ERROR_TYPES } from './constants'; -import { ErrorMessage } from './error_message'; - -function validateConfigurationObject(value: string) { - try { - JSON.parse(value.trim()); - } catch { - return 'JSON is invalid. Enter a valid JSON'; - } - return true; -} - -function validateModelType(value: string) { - try { - const config = JSON.parse(value.trim()); - if (!('model_type' in config)) { - return 'model_type is required. Specify the model_type'; - } - } catch { - return true; - } - return true; -} - -function validateModelTypeValue(value: string) { - try { - const config = JSON.parse(value.trim()); - if ('model_type' in config && typeof config.model_type !== 'string') { - return 'model_type must be a string'; - } - } catch { - return true; - } - return true; -} - -function validateEmbeddingDimensionValue(value: string) { - try { - const config = JSON.parse(value.trim()); - if ('embedding_dimension' in config && typeof config.embedding_dimension !== 'number') { - return 'embedding_dimension must be a number'; - } - } catch { - return true; - } - - return true; -} - -function validateFrameworkTypeValue(value: string) { - try { - const config = JSON.parse(value.trim()); - if ('framework_type' in config && typeof config.framework_type !== 'string') { - return 'framework_type must be a string'; - } - } catch { - return true; - } - return true; -} +import { ModelConfiguration } from '../common/forms/model_configuration'; export const ConfigurationPanel = () => { - const { control } = useFormContext(); - const [isHelpVisible, setIsHelpVisible] = useState(false); - const configurationFieldController = useController({ - name: 'configuration', - control, - rules: { - required: { value: true, message: 'Configuration is required.' }, - validate: { - [CUSTOM_FORM_ERROR_TYPES.INVALID_CONFIGURATION]: validateConfigurationObject, - [CUSTOM_FORM_ERROR_TYPES.CONFIGURATION_MISSING_MODEL_TYPE]: validateModelType, - [CUSTOM_FORM_ERROR_TYPES.INVALID_MODEL_TYPE_VALUE]: validateModelTypeValue, - [CUSTOM_FORM_ERROR_TYPES.INVALID_EMBEDDING_DIMENSION_VALUE]: validateEmbeddingDimensionValue, - [CUSTOM_FORM_ERROR_TYPES.INVALID_FRAMEWORK_TYPE_VALUE]: validateFrameworkTypeValue, - }, - }, - }); - return (
      @@ -121,39 +33,7 @@ export const ConfigurationPanel = () => { - } - labelAppend={ - setIsHelpVisible(true)} - size="xs" - color="primary" - data-test-subj="model-configuration-help-button" - > - Help - - } - > - configurationFieldController.field.onChange(value)} - setOptions={{ - fontSize: '14px', - enableBasicAutocompletion: true, - enableLiveAutocompletion: true, - }} - /> - - {isHelpVisible && setIsHelpVisible(false)} />} +
      ); }; diff --git a/public/components/register_model/register_model.tsx b/public/components/register_model/register_model.tsx index 526ae4fb..7153c81d 100644 --- a/public/components/register_model/register_model.tsx +++ b/public/components/register_model/register_model.tsx @@ -37,7 +37,7 @@ import { modelRepositoryManager } from '../../utils/model_repository_manager'; import { modelTaskManager } from './model_task_manager'; import { ModelVersionNotesPanel } from './model_version_notes'; import { modelFileUploadManager } from './model_file_upload_manager'; -import { MAX_CHUNK_SIZE, FORM_ERRORS } from './constants'; +import { MAX_CHUNK_SIZE, FORM_ERRORS } from '../common/forms/form_constants'; import { ModelDetailsPanel } from './model_details'; import type { ModelFileFormData, ModelUrlFormData } from './register_model.types'; import { ArtifactPanel } from './artifact'; diff --git a/public/components/register_model/register_model_api.ts b/public/components/register_model/register_model_api.ts index f4333398..87ba345c 100644 --- a/public/components/register_model/register_model_api.ts +++ b/public/components/register_model/register_model_api.ts @@ -4,7 +4,7 @@ */ import { APIProvider } from '../../apis/api_provider'; -import { MAX_CHUNK_SIZE } from './constants'; +import { MAX_CHUNK_SIZE } from '../common/forms/form_constants'; import { getModelContentHashValue } from './get_model_content_hash_value'; import { ModelFileFormData, ModelUrlFormData } from './register_model.types'; From b7289bd6ec7ef7b45feba960d438a1b80dcf54d0 Mon Sep 17 00:00:00 2001 From: Lin Wang Date: Wed, 24 May 2023 11:13:32 +0800 Subject: [PATCH 51/75] Feature/add model version loading empty error screens (#185) * feat: add loading, empty, failed and no result screen for model detail version table Signed-off-by: Lin Wang * feat: add versionOrKeyword parameter for model search Signed-off-by: Lin Wang * feat: add search logic to model detail version table Signed-off-by: Lin Wang * refactor: update lastUpdated to lastUpdatedTime in model detail version table Signed-off-by: Lin Wang * feat: add last_registered_time, last_deployed_time and last_undeployed_time to model detail version status cell Signed-off-by: Lin Wang * feat: only show 3 tag columns by default Signed-off-by: Lin Wang * feat: change to use dateFormat in ui settings Signed-off-by: Lin Wang * feat: hide time title for in progress version Signed-off-by: Lin Wang * feat: add sort to model detail version table Signed-off-by: Lin Wang * test: increase timeout to fix test case error in github runner Signed-off-by: Lin Wang * test: change to wait for version header exists Signed-off-by: Lin Wang * chore: address PR comments Signed-off-by: Lin Wang * feat: add UiSettingDateFormatTime Signed-off-by: Lin Wang * feat: change to use UiSettingDateFormatTime Signed-off-by: Lin Wang * fix: model search sort validate with sort pair Signed-off-by: Lin Wang --------- Signed-off-by: Lin Wang --- common/constant.ts | 2 +- common/model.ts | 10 - public/apis/model.ts | 9 +- .../ui_setting_date_format_time.test.tsx | 49 +++++ public/components/common/index.ts | 1 + .../common/ui_setting_date_format_time.tsx | 19 ++ .../__tests__/model_overview_card.test.tsx | 6 +- .../components/model/model_overview_card.tsx | 7 +- .../__tests__/model_version_cell.test.tsx | 11 +- .../model_version_list_filter.test.tsx | 60 +++++ .../model_version_status_detail.test.tsx | 25 ++- .../__tests__/model_version_table.test.tsx | 15 +- .../__tests__/model_versions_panel.test.tsx | 169 +++++++++++++- .../model_version_cell.tsx | 9 +- .../model_version_list_filter.tsx | 31 ++- .../model_version_status_detail.tsx | 38 ++-- .../model_version_table.tsx | 17 +- .../model_versions_panel.tsx | 207 ++++++++++++++++-- public/components/model/types.ts | 5 +- .../model_version/version_details.tsx | 9 +- server/routes/model_router.ts | 46 +++- server/services/model_service.ts | 9 +- server/services/utils/model.ts | 20 ++ 23 files changed, 670 insertions(+), 104 deletions(-) create mode 100644 public/components/common/__tests__/ui_setting_date_format_time.test.tsx create mode 100644 public/components/common/ui_setting_date_format_time.tsx diff --git a/common/constant.ts b/common/constant.ts index 2f828a38..ee0b73b2 100644 --- a/common/constant.ts +++ b/common/constant.ts @@ -8,4 +8,4 @@ export const ONE_GB = 1000 * ONE_MB; export const MAX_MODEL_CHUNK_SIZE = 10 * ONE_MB; -export const DATE_FORMAT = 'MMM d, yyyy @ HH:mm:ss.SSS'; +export const DATE_FORMAT = 'MMM D, yyyy @ HH:mm:ss.SSS'; diff --git a/common/model.ts b/common/model.ts index f81f1e95..eba23bed 100644 --- a/common/model.ts +++ b/common/model.ts @@ -49,13 +49,3 @@ export interface OpenSearchCustomerModel extends OpenSearchModelBase { version: number; planning_worker_nodes: string[]; } - -export type ModelSearchSort = - | 'name-asc' - | 'name-desc' - | 'id-asc' - | 'model_state-asc' - | 'model_state-desc' - | 'id-desc' - | 'version-desc' - | 'version-asc'; diff --git a/public/apis/model.ts b/public/apis/model.ts index 47645000..ed9aaec9 100644 --- a/public/apis/model.ts +++ b/public/apis/model.ts @@ -3,7 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { MODEL_STATE, ModelSearchSort } from '../../common'; +import { MODEL_STATE } from '../../common'; import { MODEL_API_ENDPOINT, MODEL_LOAD_API_ENDPOINT, @@ -38,6 +38,9 @@ export interface ModelSearchItem { }; last_updated_time: number; created_time: number; + last_registered_time?: number; + last_deployed_time?: number; + last_undeployed_time?: number; } export interface ModelDetail extends ModelSearchItem { @@ -100,13 +103,15 @@ export class Model { public search(query: { algorithms?: string[]; ids?: string[]; - sort?: ModelSearchSort[]; + sort?: string[]; + name?: string; from: number; size: number; states?: MODEL_STATE[]; nameOrId?: string; extraQuery?: Record; dataSourceId?: string; + versionOrKeyword?: string; }) { const { extraQuery, dataSourceId, ...restQuery } = query; return InnerHttpProvider.getHttp().get(MODEL_API_ENDPOINT, { diff --git a/public/components/common/__tests__/ui_setting_date_format_time.test.tsx b/public/components/common/__tests__/ui_setting_date_format_time.test.tsx new file mode 100644 index 00000000..32e77649 --- /dev/null +++ b/public/components/common/__tests__/ui_setting_date_format_time.test.tsx @@ -0,0 +1,49 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React from 'react'; + +import * as PluginContext from '../../../../../../src/plugins/opensearch_dashboards_react/public'; +import { render, screen } from '../../../../test/test_utils'; +import { UiSettingDateFormatTime } from '../ui_setting_date_format_time'; + +// Cannot spyOn(PluginContext, 'useOpenSearchDashboards') directly as it results in error: +// TypeError: Cannot redefine property: useOpenSearchDashboards +// So we have to mock the entire module first as a workaround +jest.mock('../../../../../../src/plugins/opensearch_dashboards_react/public', () => { + return { + __esModule: true, + ...jest.requireActual('../../../../../../src/plugins/opensearch_dashboards_react/public'), + }; +}); + +describe('', () => { + it('should render "-" if time was undefined', () => { + render(); + expect(screen.getByText('-')).toBeInTheDocument(); + }); + + it('should render consistent time text based ui setting', () => { + const opensearchDashboardsMock = jest + .spyOn(PluginContext, 'useOpenSearchDashboards') + .mockReturnValue({ + services: { + uiSettings: { + get: () => 'MMM D, yyyy @ HH:mm:ss', + }, + }, + }); + + render(); + expect(screen.getByText('Apr 28, 2023 @ 10:12:39')).toBeInTheDocument(); + + opensearchDashboardsMock.mockRestore(); + }); + + it('should render consistent time text based default time format', () => { + render(); + expect(screen.getByText('Apr 28, 2023 @ 10:12:39.143')).toBeInTheDocument(); + }); +}); diff --git a/public/components/common/index.ts b/public/components/common/index.ts index ac1b4681..569035d3 100644 --- a/public/components/common/index.ts +++ b/public/components/common/index.ts @@ -12,3 +12,4 @@ export * from './tag_filter'; export * from './options_filter'; export * from './forms'; export * from './tag_key'; +export * from './ui_setting_date_format_time'; diff --git a/public/components/common/ui_setting_date_format_time.tsx b/public/components/common/ui_setting_date_format_time.tsx new file mode 100644 index 00000000..1431c1dc --- /dev/null +++ b/public/components/common/ui_setting_date_format_time.tsx @@ -0,0 +1,19 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React from 'react'; + +import { useOpenSearchDashboards } from '../../../../../src/plugins/opensearch_dashboards_react/public'; +import { DATE_FORMAT } from '../../../common'; +import { renderTime } from '../../utils'; + +export const UiSettingDateFormatTime = ({ time }: { time: number | undefined }) => { + const { + services: { uiSettings }, + } = useOpenSearchDashboards(); + const dateFormat = uiSettings?.get('dateFormat'); + + return <>{time ? renderTime(time, dateFormat || DATE_FORMAT) : '-'}; +}; diff --git a/public/components/model/__tests__/model_overview_card.test.tsx b/public/components/model/__tests__/model_overview_card.test.tsx index 25e690ce..4946073a 100644 --- a/public/components/model/__tests__/model_overview_card.test.tsx +++ b/public/components/model/__tests__/model_overview_card.test.tsx @@ -25,11 +25,13 @@ describe('', () => { expect(screen.getByText('Foo (you)')).toBeInTheDocument(); expect(screen.getByText('Created')).toBeInTheDocument(); expect( - within(screen.getByText('Created').closest('dl')!).getByText('Apr 24, 2023 8:18 AM') + within(screen.getByText('Created').closest('dl')!).getByText('Apr 24, 2023 @ 08:18:30.318') ).toBeInTheDocument(); expect(screen.getByText('Last updated')).toBeInTheDocument(); expect( - within(screen.getByText('Last updated').closest('dl')!).getByText('Apr 24, 2023 1:18 PM') + within(screen.getByText('Last updated').closest('dl')!).getByText( + 'Apr 24, 2023 @ 13:18:30.318' + ) ).toBeInTheDocument(); expect(screen.getByText('model-1-id')).toBeInTheDocument(); diff --git a/public/components/model/model_overview_card.tsx b/public/components/model/model_overview_card.tsx index 3046d5ad..958ca39e 100644 --- a/public/components/model/model_overview_card.tsx +++ b/public/components/model/model_overview_card.tsx @@ -5,8 +5,7 @@ import { EuiDescriptionList, EuiFlexGroup, EuiFlexItem, EuiPanel, EuiSpacer } from '@elastic/eui'; import React from 'react'; -import { CopyableText } from '../common'; -import { renderTime } from '../../utils'; +import { CopyableText, UiSettingDateFormatTime } from '../common'; interface ModelOverviewCardProps { id: string; @@ -54,7 +53,7 @@ export const ModelOverviewCard = ({ listItems={[ { title: 'Created', - description: renderTime(createdTime, 'MMM D, YYYY h:m A'), + description: , }, ]} /> @@ -64,7 +63,7 @@ export const ModelOverviewCard = ({ listItems={[ { title: 'Last updated', - description: renderTime(updatedTime, 'MMM D, YYYY h:m A'), + description: , }, ]} /> diff --git a/public/components/model/model_versions_panel/__tests__/model_version_cell.test.tsx b/public/components/model/model_versions_panel/__tests__/model_version_cell.test.tsx index 4c069b32..4e5a24aa 100644 --- a/public/components/model/model_versions_panel/__tests__/model_version_cell.test.tsx +++ b/public/components/model/model_versions_panel/__tests__/model_version_cell.test.tsx @@ -18,7 +18,7 @@ const setup = (options: { columnId: string; isDetails?: boolean }) => version: '1.0.0', state: MODEL_STATE.uploading, tags: {}, - lastUpdated: 1682604957236, + lastUpdatedTime: 1682604957236, createdTime: 1682604957236, }} isDetails={false} @@ -50,7 +50,7 @@ describe('', () => { version: '1.0.0', state: MODEL_STATE.loaded, tags: {}, - lastUpdated: 1682604957236, + lastUpdatedTime: 1682604957236, createdTime: 1682604957236, }} isDetails={false} @@ -67,7 +67,7 @@ describe('', () => { version: '1.0.0', state: MODEL_STATE.partiallyLoaded, tags: {}, - lastUpdated: 1682604957236, + lastUpdatedTime: 1682604957236, createdTime: 1682604957236, }} isDetails={false} @@ -92,15 +92,14 @@ describe('', () => { }); expect(screen.getByText('In progress...')).toBeInTheDocument(); - expect(screen.getByText('Upload initiated on:')).toBeInTheDocument(); }); it('should render consistent last updated', () => { setup({ - columnId: 'lastUpdated', + columnId: 'lastUpdatedTime', }); - expect(screen.getByText('Apr 27, 2023 2:15 PM')).toBeInTheDocument(); + expect(screen.getByText('Apr 27, 2023 @ 14:15:57.236')).toBeInTheDocument(); }); it('should render "model-1" for name column', () => { diff --git a/public/components/model/model_versions_panel/__tests__/model_version_list_filter.test.tsx b/public/components/model/model_versions_panel/__tests__/model_version_list_filter.test.tsx index 2a8cb482..8678c378 100644 --- a/public/components/model/model_versions_panel/__tests__/model_version_list_filter.test.tsx +++ b/public/components/model/model_versions_panel/__tests__/model_version_list_filter.test.tsx @@ -124,4 +124,64 @@ describe('', () => { }, 10 * 1000 ); + + it('should call onChangeMock with search text after search text typed', async () => { + jest.useFakeTimers(); + const onChangeMock = jest.fn(); + const user = userEvent.setup({ advanceTimers: jest.advanceTimersByTime }); + render( + + ); + + expect(onChangeMock).not.toHaveBeenCalled(); + + await user.type( + screen.getByPlaceholderText('Search by version number, or keyword'), + 'search text' + ); + act(() => { + jest.advanceTimersByTime(500); + }); + expect(onChangeMock).toHaveBeenCalledWith( + expect.objectContaining({ + search: 'search text', + }) + ); + + jest.useRealTimers(); + }); + + it('should bind searchInput and able to clear search text by searchInputRef', async () => { + const onChangeMock = jest.fn(); + let searchInput: HTMLInputElement | null = null; + + render( + { + searchInput = input; + }} + /> + ); + expect(searchInput).not.toBeNull(); + + const searchTextInput = screen.getByPlaceholderText('Search by version number, or keyword'); + await userEvent.type(searchTextInput, 'search text'); + + expect(searchTextInput).toHaveValue('search text'); + searchInput!.value = ''; + expect(searchTextInput).toHaveValue(''); + }); }); diff --git a/public/components/model/model_versions_panel/__tests__/model_version_status_detail.test.tsx b/public/components/model/model_versions_panel/__tests__/model_version_status_detail.test.tsx index f0d92ba3..d620910b 100644 --- a/public/components/model/model_versions_panel/__tests__/model_version_status_detail.test.tsx +++ b/public/components/model/model_versions_panel/__tests__/model_version_status_detail.test.tsx @@ -11,7 +11,7 @@ import { ModelVersionStatusDetail } from '../model_version_status_detail'; import { MODEL_STATE } from '../../../../../common'; describe('', () => { - it('should render "In progress...", uploading tip and upload initialized time ', async () => { + it('should render "In progress..." and uploading tip', async () => { render( ', () => { 'href', '/model-registry/model-version/1' ); + }); + + it('should render "Success", deployment tip and deployed time ', async () => { + render( + + ); + + expect(screen.getByText('Success')).toBeInTheDocument(); + expect(screen.getByText(/.*deployed./)).toBeInTheDocument(); + expect(screen.getByText('model-1 version 1')).toHaveAttribute( + 'href', + '/model-registry/model-version/1' + ); - expect(screen.getByText('Upload initiated on:')).toBeInTheDocument(); + expect(screen.getByText('Deployed on:')).toBeInTheDocument(); expect(screen.getByText('May 5, 2023 @ 08:52:53.541')).toBeInTheDocument(); }); @@ -55,6 +75,7 @@ describe('', () => { version="1.0.0" state={MODEL_STATE.loadFailed} createdTime={1683276773541} + lastDeployedTime={1683276773541} /> ); diff --git a/public/components/model/model_versions_panel/__tests__/model_version_table.test.tsx b/public/components/model/model_versions_panel/__tests__/model_version_table.test.tsx index f94f7e24..0b2940e4 100644 --- a/public/components/model/model_versions_panel/__tests__/model_version_table.test.tsx +++ b/public/components/model/model_versions_panel/__tests__/model_version_table.test.tsx @@ -18,22 +18,27 @@ const versions = [ version: '1.0.0', state: MODEL_STATE.uploading, tags: { 'Accuracy: test': 0.98, 'Accuracy: train': 0.99 }, - lastUpdated: 1682676759143, + lastUpdatedTime: 1682676759143, createdTime: 1682676759143, }, ]; describe('', () => { - it('should render consistent columns header ', async () => { - render(); + it('should render consistent columns header and hide id, tags columns by default', async () => { + render( + + ); await waitFor(() => { expect(screen.getByTestId('dataGridHeaderCell-version')).toBeInTheDocument(); expect(screen.getByTestId('dataGridHeaderCell-state')).toBeInTheDocument(); expect(screen.getByTestId('dataGridHeaderCell-status')).toBeInTheDocument(); - expect(screen.getByTestId('dataGridHeaderCell-lastUpdated')).toBeInTheDocument(); + expect(screen.getByTestId('dataGridHeaderCell-lastUpdatedTime')).toBeInTheDocument(); expect(screen.getByTestId('dataGridHeaderCell-tags.Accuracy: test')).toBeInTheDocument(); expect(screen.getByTestId('dataGridHeaderCell-tags.Accuracy: train')).toBeInTheDocument(); + expect(screen.getByTestId('dataGridHeaderCell-tags.F1')).toBeInTheDocument(); + expect(screen.queryByTestId('dataGridHeaderCell-id')).not.toBeInTheDocument(); + expect(screen.queryByTestId('dataGridHeaderCell-tags.F2')).not.toBeInTheDocument(); }); }); @@ -98,7 +103,7 @@ describe('', () => { expect(within(gridCells[0]).getByText('1.0.0')).toBeInTheDocument(); expect(within(gridCells[1]).getByText('Not deployed')).toBeInTheDocument(); expect(within(gridCells[2]).getByText('In progress...')).toBeInTheDocument(); - expect(within(gridCells[3]).getByText('Apr 28, 2023 10:12 AM')).toBeInTheDocument(); + expect(within(gridCells[3]).getByText('Apr 28, 2023 @ 10:12:39.143')).toBeInTheDocument(); expect(within(gridCells[4]).getByText('0.98')).toBeInTheDocument(); expect(within(gridCells[5]).getByText('0.99')).toBeInTheDocument(); expect(within(gridCells[6]).getByLabelText('show actions')).toBeInTheDocument(); diff --git a/public/components/model/model_versions_panel/__tests__/model_versions_panel.test.tsx b/public/components/model/model_versions_panel/__tests__/model_versions_panel.test.tsx index 5bb9c011..547c1a0b 100644 --- a/public/components/model/model_versions_panel/__tests__/model_versions_panel.test.tsx +++ b/public/components/model/model_versions_panel/__tests__/model_versions_panel.test.tsx @@ -9,6 +9,14 @@ import userEvent from '@testing-library/user-event'; import { render, screen, waitFor, within } from '../../../../../test/test_utils'; import { Model } from '../../../../apis/model'; import { ModelVersionsPanel } from '../model_versions_panel'; +import * as PluginContext from '../../../../../../../src/plugins/opensearch_dashboards_react/public'; + +jest.mock('../../../../../../../src/plugins/opensearch_dashboards_react/public', () => { + return { + __esModule: true, + ...jest.requireActual('../../../../../../../src/plugins/opensearch_dashboards_react/public'), + }; +}); describe('', () => { it( @@ -47,7 +55,12 @@ describe('', () => { it( 'should call model search API again after refresh button clicked', async () => { - const searchMock = jest.spyOn(Model.prototype, 'search'); + const searchMock = jest.spyOn(Model.prototype, 'search').mockImplementation(async () => { + return { + data: [], + total_models: 0, + }; + }); render(); @@ -74,6 +87,8 @@ describe('', () => { await waitFor(() => { expect(searchMock).toHaveBeenLastCalledWith( expect.objectContaining({ + from: 0, + size: 25, states: ['DEPLOYED', 'PARTIALLY_DEPLOYED'], }) ); @@ -93,4 +108,156 @@ describe('', () => { }, 10 * 10000 ); + + it('should render loading screen when calling model search API', async () => { + const searchMock = jest + .spyOn(Model.prototype, 'search') + .mockImplementation(() => new Promise(() => {})); + render(); + + expect(screen.getByText('Loading versions')).toBeInTheDocument(); + + searchMock.mockRestore(); + }); + + it('should render error screen and show error toast after call model search failed', async () => { + const searchMock = jest.spyOn(Model.prototype, 'search').mockImplementation(async () => { + throw new Error(); + }); + const dangerMock = jest.fn(); + const pluginMock = jest.spyOn(PluginContext, 'useOpenSearchDashboards').mockReturnValue({ + notifications: { + toasts: { + danger: dangerMock, + }, + }, + }); + render(); + + await waitFor(() => { + expect(screen.getByText('Failed to load versions')).toBeInTheDocument(); + expect(dangerMock).toHaveBeenCalledWith( + expect.objectContaining({ + title: 'Failed to load data', + }) + ); + }); + + searchMock.mockRestore(); + pluginMock.mockRestore(); + }); + + it('should render empty screen if model no versions', async () => { + const searchMock = jest.spyOn(Model.prototype, 'search').mockImplementation(async () => { + return { + data: [], + total_models: 0, + }; + }); + + render(); + + await waitFor(() => { + expect(screen.getByText('Registered versions will appear here.')).toBeInTheDocument(); + expect(screen.getByText('Register new version')).toBeInTheDocument(); + expect(screen.getByText('Register new version').closest('a')).toHaveAttribute( + 'href', + '/model-registry/register-model/1' + ); + expect(screen.getByText('Read documentation')).toBeInTheDocument(); + }); + + searchMock.mockRestore(); + }); + + it('should render no-result screen and reset search button if no result for specific condition', async () => { + render(); + await waitFor(() => { + expect(screen.getByTitle('Status')).toBeInTheDocument(); + }); + + const searchMock = jest.spyOn(Model.prototype, 'search').mockImplementation(async () => { + return { + data: [], + total_models: 0, + }; + }); + await userEvent.click(screen.getByTitle('Status')); + await userEvent.click(screen.getByRole('option', { name: 'In progress...' })); + + expect( + screen.getByText( + 'There are no results for your search. Reset the search criteria to view registered versions.' + ) + ).toBeInTheDocument(); + expect(screen.getByText('Reset search criteria')).toBeInTheDocument(); + + searchMock.mockRestore(); + }); + + it( + 'should call model search without filter condition after reset button clicked', + async () => { + render(); + await waitFor(() => { + expect(screen.getByTitle('Status')).toBeInTheDocument(); + }); + + const searchMock = jest.spyOn(Model.prototype, 'search').mockImplementation(async () => { + return { + data: [], + total_models: 0, + }; + }); + await userEvent.click(screen.getByTitle('Status')); + await userEvent.click(screen.getByRole('option', { name: 'In progress...' })); + + expect(searchMock).toHaveBeenCalledTimes(1); + await userEvent.click(screen.getByText('Reset search criteria')); + expect(searchMock).toHaveBeenCalledTimes(2); + expect(searchMock).toHaveBeenCalledWith({ + from: 0, + size: 25, + // TODO: Change to model group id once parameter added + ids: expect.any(Array), + }); + + searchMock.mockRestore(); + }, + 10 * 1000 + ); + + it( + 'should only sort by Last updated column after column sort button clicked', + async () => { + render(); + await waitFor(() => { + expect(screen.getByTestId('dataGridHeaderCell-version')).toBeInTheDocument(); + }); + const searchMock = jest.spyOn(Model.prototype, 'search'); + + await userEvent.click( + within(screen.getByTestId('dataGridHeaderCell-version')).getByText('Version') + ); + await userEvent.click(screen.getByTitle('Sort A-Z')); + expect(searchMock).toHaveBeenLastCalledWith( + expect.objectContaining({ + sort: ['version-asc'], + }) + ); + + await userEvent.click( + within(screen.getByTestId('dataGridHeaderCell-lastUpdatedTime')).getByText('Last updated') + ); + await userEvent.click(screen.getByTitle('Sort Z-A')); + expect(searchMock).toHaveBeenLastCalledWith( + expect.objectContaining({ + sort: ['last_updated_time-desc'], + }) + ); + + searchMock.mockRestore(); + }, + 20 * 1000 + ); }); diff --git a/public/components/model/model_versions_panel/model_version_cell.tsx b/public/components/model/model_versions_panel/model_version_cell.tsx index b9f80075..6445baab 100644 --- a/public/components/model/model_versions_panel/model_version_cell.tsx +++ b/public/components/model/model_versions_panel/model_version_cell.tsx @@ -7,9 +7,9 @@ import React from 'react'; import { get } from 'lodash'; import { EuiBadge, EuiText } from '@elastic/eui'; -import { renderTime } from '../../../utils/table'; import { MODEL_STATE } from '../../../../common'; import { VersionTableDataItem } from '../types'; +import { UiSettingDateFormatTime } from '../../common'; import { ModelVersionStatusCell } from './model_version_status_cell'; import { ModelVersionStatusDetail } from './model_version_status_detail'; @@ -29,6 +29,9 @@ export const ModelVersionCell = ({ data, columnId, isDetails }: ModelVersionCell version={data.version} state={data.state} createdTime={data.createdTime} + lastRegisteredTime={data.lastRegisteredTime} + lastDeployedTime={data.lastDeployedTime} + lastUndeployedTime={data.lastUndeployedTime} /> ); } @@ -47,8 +50,8 @@ export const ModelVersionCell = ({ data, columnId, isDetails }: ModelVersionCell ); } - case 'lastUpdated': - return renderTime(data.lastUpdated, 'MMM D, YYYY h:m A'); + case 'lastUpdatedTime': + return ; default: return get(data, columnId, '-'); } diff --git a/public/components/model/model_versions_panel/model_version_list_filter.tsx b/public/components/model/model_versions_panel/model_version_list_filter.tsx index dc998ab0..6a90b2b0 100644 --- a/public/components/model/model_versions_panel/model_version_list_filter.tsx +++ b/public/components/model/model_versions_panel/model_version_list_filter.tsx @@ -13,7 +13,13 @@ import { EuiSpacer, } from '@elastic/eui'; -import { TagFilterValue, TagFilter, OptionsFilter, SelectedTagFiltersPanel } from '../../common'; +import { + TagFilterValue, + TagFilter, + OptionsFilter, + SelectedTagFiltersPanel, + DebouncedSearchBar, +} from '../../common'; import { useModelTagKeys } from '../../model_list/model_list.hooks'; const statusOptions = [ @@ -46,14 +52,20 @@ export interface ModelVersionListFilterValue { status: Array; state: Array; tag: TagFilterValue[]; + search?: string; } interface ModelVersionListFilterProps { - value: ModelVersionListFilterValue; + searchInputRef?: (input: HTMLInputElement | null) => void; + value: Omit; onChange: (value: ModelVersionListFilterValue) => void; } -export const ModelVersionListFilter = ({ value, onChange }: ModelVersionListFilterProps) => { +export const ModelVersionListFilter = ({ + searchInputRef, + value, + onChange, +}: ModelVersionListFilterProps) => { // TODO: Change to model tags API and pass model group id here const [tagKeysLoading, tagKeys] = useModelTagKeys(); const valueRef = useRef(value); @@ -80,11 +92,22 @@ export const ModelVersionListFilter = ({ value, onChange }: ModelVersionListFilt [onChange] ); + const handleSearch = useCallback( + (search: string) => { + onChange({ ...valueRef.current, search }); + }, + [onChange] + ); + return ( <> - + diff --git a/public/components/model/model_versions_panel/model_version_status_detail.tsx b/public/components/model/model_versions_panel/model_version_status_detail.tsx index 5d3d0c5a..79de4339 100644 --- a/public/components/model/model_versions_panel/model_version_status_detail.tsx +++ b/public/components/model/model_versions_panel/model_version_status_detail.tsx @@ -12,11 +12,11 @@ import { EuiPopoverTitle, EuiLink, } from '@elastic/eui'; - import { Link, generatePath } from 'react-router-dom'; -import { DATE_FORMAT, MODEL_STATE, routerPaths } from '../../../../common'; + +import { MODEL_STATE, routerPaths } from '../../../../common'; +import { UiSettingDateFormatTime } from '../../common'; import { APIProvider } from '../../../apis/api_provider'; -import { renderTime } from '../../../utils'; import { ModelVersionErrorDetailsModal } from './model_version_error_details_modal'; @@ -25,8 +25,8 @@ export const state2DetailContentMap: { [key in MODEL_STATE]?: { title: string; description: (versionLink: React.ReactNode) => React.ReactNode; - timeTitle: string; - timeField: 'createdTime'; + timeTitle?: string; + timeField?: 'createdTime' | 'lastRegisteredTime' | 'lastDeployedTime' | 'lastUndeployedTime'; }; } = { [MODEL_STATE.uploading]: { @@ -34,16 +34,12 @@ export const state2DetailContentMap: { description: (versionLink: React.ReactNode) => ( <>The model artifact for {versionLink} is uploading. ), - timeTitle: 'Upload initiated on', - timeField: 'createdTime', }, [MODEL_STATE.loading]: { title: 'In progress...', description: (versionLink: React.ReactNode) => ( <>The model artifact for {versionLink} is deploying. ), - timeTitle: 'Deployment initiated on', - timeField: 'createdTime', }, [MODEL_STATE.uploaded]: { title: 'Success', @@ -51,25 +47,25 @@ export const state2DetailContentMap: { <>The model artifact for {versionLink} uploaded. ), timeTitle: 'Uploaded on', - timeField: 'createdTime', + timeField: 'lastRegisteredTime', }, [MODEL_STATE.loaded]: { title: 'Success', description: (versionLink: React.ReactNode) => <>{versionLink} deployed., timeTitle: 'Deployed on', - timeField: 'createdTime', + timeField: 'lastDeployedTime', }, [MODEL_STATE.unloaded]: { title: 'Success', description: (versionLink: React.ReactNode) => <>{versionLink} undeployed., timeTitle: 'Undeployed on', - timeField: 'createdTime', + timeField: 'lastUndeployedTime', }, [MODEL_STATE.loadFailed]: { title: 'Error', description: (versionLink: React.ReactNode) => <>{versionLink} deployment failed., timeTitle: 'Deployment failed on', - timeField: 'createdTime', + timeField: 'lastDeployedTime', }, [MODEL_STATE.registerFailed]: { title: 'Error', @@ -99,6 +95,9 @@ export const ModelVersionStatusDetail = ({ name: string; version: string; createdTime: number; + lastRegisteredTime?: number; + lastDeployedTime?: number; + lastUndeployedTime?: number; }) => { const [isErrorDetailsModalShowed, setIsErrorDetailsModalShowed] = useState(false); const [isLoadingErrorDetails, setIsLoadingErrorDetails] = useState(false); @@ -143,6 +142,7 @@ export const ModelVersionStatusDetail = ({ return <>-; } const { title, description, timeTitle, timeField } = statusContent; + const timeValue = timeField ? restProps[timeField] : undefined; return ( <> @@ -161,10 +161,14 @@ export const ModelVersionStatusDetail = ({ )} - - - {timeTitle}: {renderTime(restProps[timeField], DATE_FORMAT)} - + {timeTitle && ( + <> + + + {timeTitle}: + + + )} {(state === MODEL_STATE.loadFailed || state === MODEL_STATE.registerFailed) && ( <> diff --git a/public/components/model/model_versions_panel/model_version_table.tsx b/public/components/model/model_versions_panel/model_version_table.tsx index 43a6ea59..ff26ed34 100644 --- a/public/components/model/model_versions_panel/model_version_table.tsx +++ b/public/components/model/model_versions_panel/model_version_table.tsx @@ -70,7 +70,7 @@ export const ModelVersionTable = ({ isSortable: false, }, { - id: 'lastUpdated', + id: 'lastUpdatedTime', displayAsText: 'Last updated', }, { @@ -119,9 +119,18 @@ export const ModelVersionTable = ({ ], [versions] ); - const [visibleColumns, setVisibleColumns] = useState(() => - columns.map(({ id }) => id).filter((columnId) => columnId !== 'id') - ); + const [visibleColumns, setVisibleColumns] = useState(() => { + const tagHiddenByDefaultColumns = tags.slice(3); + return columns + .map(({ id }) => id) + .filter((columnId) => { + if (columnId.startsWith('tags.')) { + const [_prefix, tag] = columnId.split('.'); + return !tagHiddenByDefaultColumns.includes(tag); + } + return columnId !== 'id'; + }); + }); const columnVisibility = useMemo(() => ({ visibleColumns, setVisibleColumns }), [visibleColumns]); const renderCellValue = useCallback( diff --git a/public/components/model/model_versions_panel/model_versions_panel.tsx b/public/components/model/model_versions_panel/model_versions_panel.tsx index aacd6d75..9a576e51 100644 --- a/public/components/model/model_versions_panel/model_versions_panel.tsx +++ b/public/components/model/model_versions_panel/model_versions_panel.tsx @@ -3,26 +3,36 @@ * SPDX-License-Identifier: Apache-2.0 */ -import React, { useMemo, useState, useCallback } from 'react'; +import React, { useMemo, useState, useCallback, useEffect, useRef } from 'react'; import { EuiButton, + EuiButtonEmpty, + EuiDataGridSorting, + EuiEmptyPrompt, EuiFlexGroup, EuiFlexItem, + EuiIcon, + EuiLink, + EuiLoadingSpinner, EuiPanel, EuiSpacer, + EuiText, EuiTextColor, EuiTitle, } from '@elastic/eui'; +import { generatePath } from 'react-router-dom'; import { useFetcher } from '../../../hooks'; import { APIProvider } from '../../../apis/api_provider'; -import { MODEL_STATE } from '../../../../common'; +import { MODEL_STATE, routerPaths } from '../../../../common'; +import { useOpenSearchDashboards } from '../../../../../../src/plugins/opensearch_dashboards_react/public'; import { ModelVersionTable } from './model_version_table'; import { ModelVersionListFilter, ModelVersionListFilterValue } from './model_version_list_filter'; // TODO: Use tags from model group const tags = ['Tag1', 'Tag2']; +const emptyPromptStyle = { maxWidth: 528 }; const modelState2StatusMap: { [key in MODEL_STATE]?: ModelVersionListFilterValue['status'][number]; @@ -64,11 +74,21 @@ const getStatesParam = ({ }); }; +const getSortParam = (sort: Array<{ id: string; direction: 'asc' | 'desc' }>) => { + const id2fieldMap: { [key: string]: string } = { + lastUpdatedTime: 'last_updated_time', + }; + return sort.length > 0 + ? sort.map(({ id, direction }) => `${id2fieldMap[id] || id}-${direction}`) + : undefined; +}; + interface ModelVersionsPanelProps { groupId: string; } export const ModelVersionsPanel = ({ groupId }: ModelVersionsPanelProps) => { + const searchInputRef = useRef(null); const [params, setParams] = useState<{ pageIndex: number; pageSize: number; @@ -84,14 +104,20 @@ export const ModelVersionsPanel = ({ groupId }: ModelVersionsPanelProps) => { tag: [], }, }); - const { data: versionsData, reload } = useFetcher(APIProvider.getAPI('model').search, { - // TODO: Change to model group id - ids: [groupId], - from: params.pageIndex * params.pageSize, - size: params.pageSize, - states: getStatesParam(params.filter), - }); + const { data: versionsData, reload, loading, error } = useFetcher( + APIProvider.getAPI('model').search, + { + // TODO: Change to model group id + ids: [groupId], + from: params.pageIndex * params.pageSize, + size: params.pageSize, + states: getStatesParam(params.filter), + versionOrKeyword: params.filter.search, + sort: getSortParam(params.sort), + } + ); const totalVersionCount = versionsData?.total_models; + const { notifications } = useOpenSearchDashboards(); const versions = useMemo(() => { if (!versionsData) { @@ -102,10 +128,13 @@ export const ModelVersionsPanel = ({ groupId }: ModelVersionsPanelProps) => { name: item.name, version: item.model_version, state: item.model_state, - lastUpdated: item.last_updated_time, + lastUpdatedTime: item.last_updated_time, // TODO: Change to use tags in model version once structure finalized tags: {}, createdTime: item.created_time, + lastRegisteredTime: item.last_registered_time, + lastDeployedTime: item.last_deployed_time, + lastUndeployedTime: item.last_undeployed_time, })); }, [versionsData]); @@ -126,20 +155,70 @@ export const ModelVersionsPanel = ({ groupId }: ModelVersionsPanelProps) => { }; }, [params.pageIndex, params.pageSize, totalVersionCount]); - const versionsSorting = useMemo( + const versionsSorting = useMemo( () => ({ columns: params.sort, onSort: (sort) => { - setParams((previousParams) => ({ ...previousParams, sort })); + setParams((previousParams) => ({ + ...previousParams, + sort: sort.filter( + (item) => !previousParams.sort.find((previousItem) => previousItem.id === item.id) + ), + })); }, }), [params] ); + const panelStatus = useMemo(() => { + if (loading) { + return 'loading'; + } + if (error) { + return 'error'; + } + const { tag, state, status, search } = params.filter; + if ( + totalVersionCount === 0 && + (tag.length > 0 || state.length > 0 || status.length > 0 || !!search) + ) { + return 'no-result'; + } + if (totalVersionCount === 0) { + return 'empty'; + } + return 'normal'; + }, [totalVersionCount, loading, error, params]); + const handleFilterChange = useCallback((filter: ModelVersionListFilterValue) => { - setParams((previousParams) => ({ ...previousParams, filter })); + setParams((previousParams) => ({ ...previousParams, pageIndex: 0, filter })); }, []); + const handleResetSearch = useCallback(() => { + setParams((previousParams) => ({ + ...previousParams, + filter: { tag: [], state: [], status: [] }, + })); + if (searchInputRef.current) { + searchInputRef.current.value = ''; + } + }, []); + + const bindSearchInputSearch = useCallback((input: HTMLInputElement | null) => { + if (searchInputRef) { + searchInputRef.current = input; + } + }, []); + + useEffect(() => { + if (error) { + notifications.toasts.danger({ + title: 'Failed to load data', + body: 'Check your internet connection.', + }); + } + }, [error, notifications.toasts]); + return ( @@ -147,7 +226,7 @@ export const ModelVersionsPanel = ({ groupId }: ModelVersionsPanelProps) => {

      Versions - {typeof totalVersionCount === 'number' && ( + {typeof totalVersionCount === 'number' && panelStatus !== 'empty' && (  ({totalVersionCount}) )}

      @@ -157,15 +236,97 @@ export const ModelVersionsPanel = ({ groupId }: ModelVersionsPanelProps) => { Refresh
      - - - + {panelStatus !== 'empty' && ( + + )} + + {(panelStatus === 'normal' || panelStatus === 'no-result') && ( + + )} + {panelStatus === 'loading' && ( + + + + + +

      + Loading versions +

      +
      + + } + /> + )} + {panelStatus === 'error' && ( + + + + + +

      Failed to load versions

      +
      + + Check your internet connection + + + + } + /> + )} + {panelStatus === 'no-result' && ( + + + + There are no results for your search. Reset the search criteria to view registered + versions. + + + Reset search criteria + + + } + /> + )} + {panelStatus === 'empty' && ( + + + Registered versions will appear here. + + + Register new version + + + {/* TODO: Update to real link after confirmed */} + + Read documentation + + + + } + /> + )} ); }; diff --git a/public/components/model/types.ts b/public/components/model/types.ts index 537fa9d4..01fe3ea0 100644 --- a/public/components/model/types.ts +++ b/public/components/model/types.ts @@ -11,9 +11,12 @@ export interface VersionTableDataItem { name: string; version: string; state: MODEL_STATE; - lastUpdated: number; + lastUpdatedTime: number; tags: { [key: string]: string | number }; createdTime: number; + lastRegisteredTime?: number; + lastDeployedTime?: number; + lastUndeployedTime?: number; } export interface Tag { diff --git a/public/components/model_version/version_details.tsx b/public/components/model_version/version_details.tsx index ac6b1795..24b02645 100644 --- a/public/components/model_version/version_details.tsx +++ b/public/components/model_version/version_details.tsx @@ -14,8 +14,7 @@ import { EuiCopy, EuiIcon, } from '@elastic/eui'; -import { renderTime } from '../../utils'; -import { DATE_FORMAT } from '../../../common/constant'; +import { UiSettingDateFormatTime } from '../common'; interface Props { description?: string; @@ -57,14 +56,16 @@ export const ModelVersionDetails = ({

      Created

      - {createdTime ? renderTime(createdTime, DATE_FORMAT) : '-'} + + +

      Last updated

      - {lastUpdatedTime ? renderTime(lastUpdatedTime, DATE_FORMAT) : '-'} +
      diff --git a/server/routes/model_router.ts b/server/routes/model_router.ts index 97a1f7d2..f6264e97 100644 --- a/server/routes/model_router.ts +++ b/server/routes/model_router.ts @@ -16,16 +16,29 @@ import { } from './constants'; import { getOpenSearchClientTransport } from './utils'; -const modelSortQuerySchema = schema.oneOf([ - schema.literal('version-desc'), - schema.literal('version-asc'), - schema.literal('name-asc'), - schema.literal('name-desc'), - schema.literal('model_state-asc'), - schema.literal('model_state-desc'), - schema.literal('id-asc'), - schema.literal('id-desc'), -]); +const validateSortItem = (sort: string) => { + const [key, direction] = sort.split('-'); + if (key === undefined || direction === undefined) { + return 'Invalidate sort'; + } + if (direction !== 'asc' && direction !== 'desc') { + return 'Invalidate sort'; + } + const availableSortKeys = ['id', 'version', 'last_updated_time', 'name', 'model_state']; + + if (!availableSortKeys.includes(key) && !key.startsWith('tags.')) { + return 'Invalidate sort'; + } + return undefined; +}; + +const validateUniqueSort = (sort: string[]) => { + const uniqueSortKeys = new Set(sort.map((item) => item.split('-')[0])); + if (uniqueSortKeys.size < sort.length) { + return 'Invalidate sort'; + } + return undefined; +}; const modelStateSchema = schema.oneOf([ schema.literal(MODEL_STATE.loaded), @@ -66,13 +79,20 @@ export const modelRouter = (services: { modelService: ModelService }, router: IR path: MODEL_API_ENDPOINT, validate: { query: schema.object({ + name: schema.maybe(schema.string()), from: schema.number({ min: 0 }), size: schema.number({ max: 50 }), sort: schema.maybe( - schema.oneOf([modelSortQuerySchema, schema.arrayOf(modelSortQuerySchema)]) + schema.oneOf([ + schema.string({ validate: validateSortItem }), + schema.arrayOf(schema.string({ validate: validateSortItem }), { + validate: validateUniqueSort, + }), + ]) ), states: schema.maybe(schema.oneOf([schema.arrayOf(modelStateSchema), modelStateSchema])), nameOrId: schema.maybe(schema.string()), + versionOrKeyword: schema.maybe(schema.string()), extra_query: schema.maybe(schema.recordOf(schema.string(), schema.any())), data_source_id: schema.maybe(schema.string()), }), @@ -83,10 +103,12 @@ export const modelRouter = (services: { modelService: ModelService }, router: IR from, size, sort, + name, states, nameOrId, extra_query: extraQuery, data_source_id: dataSourceId, + versionOrKeyword, } = request.query; try { const payload = await ModelService.search({ @@ -97,9 +119,11 @@ export const modelRouter = (services: { modelService: ModelService }, router: IR from, size, sort: typeof sort === 'string' ? [sort] : sort, + name, states: typeof states === 'string' ? [states] : states, nameOrId, extraQuery, + versionOrKeyword, }); return response.ok({ body: payload }); } catch (err) { diff --git a/server/services/model_service.ts b/server/services/model_service.ts index efd80d09..18afd5b6 100644 --- a/server/services/model_service.ts +++ b/server/services/model_service.ts @@ -24,11 +24,11 @@ import { ScopeableRequest, ILegacyClusterClient, } from '../../../../src/core/server'; -import { MODEL_STATE, ModelSearchSort } from '../../common'; +import { MODEL_STATE } from '../../common'; -import { convertModelSource, generateModelSearchQuery } from './utils/model'; -import { MODEL_BASE_API, MODEL_META_API, MODEL_UPLOAD_API } from './utils/constants'; +import { generateModelSearchQuery } from './utils/model'; import { RecordNotFoundError } from './errors'; +import { MODEL_BASE_API, MODEL_META_API, MODEL_UPLOAD_API } from './utils/constants'; const modelSortFieldMapping: { [key: string]: string } = { version: 'model_version', @@ -84,11 +84,12 @@ export class ModelService { transport: OpenSearchClient['transport']; from: number; size: number; - sort?: ModelSearchSort[]; + sort?: string[]; name?: string; states?: MODEL_STATE[]; extraQuery?: Record; nameOrId?: string; + versionOrKeyword?: string; }) { const { body: { hits }, diff --git a/server/services/utils/model.ts b/server/services/utils/model.ts index 708fdfc6..c53fd7f5 100644 --- a/server/services/utils/model.ts +++ b/server/services/utils/model.ts @@ -27,6 +27,7 @@ export const generateModelSearchQuery = ({ states, nameOrId, extraQuery, + versionOrKeyword, }: { ids?: string[]; algorithms?: string[]; @@ -34,6 +35,7 @@ export const generateModelSearchQuery = ({ states?: MODEL_STATE[]; nameOrId?: string; extraQuery?: Record; + versionOrKeyword?: string; }) => ({ bool: { must: [ @@ -66,6 +68,24 @@ export const generateModelSearchQuery = ({ ] : []), ...(extraQuery ? [extraQuery] : []), + ...(versionOrKeyword + ? [ + { + bool: { + should: [ + { + wildcard: { + model_version: { + value: `*${versionOrKeyword}*`, + case_insensitive: true, + }, + }, + }, + ], + }, + }, + ] + : []), ], must_not: { exists: { From 8ab8a3cabe867f72bcd17da9bac697f98e334b36 Mon Sep 17 00:00:00 2001 From: Lin Wang Date: Wed, 24 May 2023 15:27:04 +0800 Subject: [PATCH 52/75] Feature/add deploy confirmation modal (#186) * feat: move ModelVersionErrorDetailsModal to common folder Signed-off-by: Lin Wang * feat: add undeployment-failed mode and plainVersionLink Signed-off-by: Lin Wang * feat: add model version deployment confirm modal Signed-off-by: Lin Wang * feat: add deploy confirm modal in model version table row actions Signed-off-by: Lin Wang * feat: address PR comments Signed-off-by: Lin Wang * refactor: update mode to errorType Signed-off-by: Lin Wang * test: increase model version panel test timeout Signed-off-by: Lin Wang --------- Signed-off-by: Lin Wang --- public/components/common/index.ts | 1 + ..._version_deployment_confirm_modal.test.tsx | 347 ++++++++++++++++++ ...model_version_error_details_modal.test.tsx | 47 ++- public/components/common/modals/index.ts | 7 + ...model_version_deployment_confirm_modal.tsx | 168 +++++++++ .../model_version_error_details_modal.tsx | 31 +- .../model_version_table_row_actions.test.tsx | 89 ++++- .../__tests__/model_versions_panel.test.tsx | 46 +-- .../model_version_status_detail.tsx | 6 +- .../model_version_table.tsx | 6 +- .../model_version_table_row_actions.tsx | 171 ++++++--- 11 files changed, 819 insertions(+), 100 deletions(-) create mode 100644 public/components/common/modals/__tests__/model_version_deployment_confirm_modal.test.tsx rename public/components/{model/model_versions_panel => common/modals}/__tests__/model_version_error_details_modal.test.tsx (60%) create mode 100644 public/components/common/modals/index.ts create mode 100644 public/components/common/modals/model_version_deployment_confirm_modal.tsx rename public/components/{model/model_versions_panel => common/modals}/model_version_error_details_modal.tsx (63%) diff --git a/public/components/common/index.ts b/public/components/common/index.ts index 569035d3..beb7a86c 100644 --- a/public/components/common/index.ts +++ b/public/components/common/index.ts @@ -13,3 +13,4 @@ export * from './options_filter'; export * from './forms'; export * from './tag_key'; export * from './ui_setting_date_format_time'; +export * from './modals'; diff --git a/public/components/common/modals/__tests__/model_version_deployment_confirm_modal.test.tsx b/public/components/common/modals/__tests__/model_version_deployment_confirm_modal.test.tsx new file mode 100644 index 00000000..b4b3016a --- /dev/null +++ b/public/components/common/modals/__tests__/model_version_deployment_confirm_modal.test.tsx @@ -0,0 +1,347 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React from 'react'; +import userEvent from '@testing-library/user-event'; +import { EuiToast } from '@elastic/eui'; + +import { render, screen, waitFor } from '../../../../../test/test_utils'; +import { ModelVersionDeploymentConfirmModal } from '../model_version_deployment_confirm_modal'; +import { Model } from '../../../../apis/model'; + +import * as PluginContext from '../../../../../../../src/plugins/opensearch_dashboards_react/public'; +import { MountWrapper } from '../../../../../../../src/core/public/utils'; +import { MountPoint } from 'opensearch-dashboards/public'; +import { OverlayModalOpenOptions } from 'src/core/public/overlays'; + +// Cannot spyOn(PluginContext, 'useOpenSearchDashboards') directly as it results in error: +// TypeError: Cannot redefine property: useOpenSearchDashboards +// So we have to mock the entire module first as a workaround +jest.mock('../../../../../../../src/plugins/opensearch_dashboards_react/public', () => { + return { + __esModule: true, + ...jest.requireActual('../../../../../../../src/plugins/opensearch_dashboards_react/public'), + }; +}); + +const generateToastMock = () => + jest.fn((toastInput) => { + render( + + ) + } + > + {typeof toastInput !== 'string' && + (typeof toastInput.text !== 'string' && toastInput.text ? ( + + ) : ( + toastInput.text + ))} + + ); + }); + +const mockAddDangerAndOverlay = () => { + return jest.spyOn(PluginContext, 'useOpenSearchDashboards').mockReturnValue({ + services: { + notifications: { + toasts: { + addDanger: generateToastMock(), + }, + }, + overlays: { + openModal: jest.fn((modelMountPoint: MountPoint, options?: OverlayModalOpenOptions) => { + const { unmount } = render(); + return { + onClose: Promise.resolve(), + close: async () => { + unmount(); + }, + }; + }), + }, + }, + }); +}; + +describe('', () => { + describe('model=deploy', () => { + it('should render deploy title and confirm message', () => { + render( + + ); + + expect(screen.getByTestId('confirmModalTitleText')).toHaveTextContent( + 'Deploy model-1 version 1' + ); + expect(screen.getByText('This version will begin deploying.')).toBeInTheDocument(); + expect(screen.getByText('model-1 version 1')).toHaveAttribute( + 'href', + '/model-registry/model-version/1' + ); + }); + + it('should call model load after deploy button clicked', async () => { + const modelLoadMock = jest + .spyOn(Model.prototype, 'load') + .mockReturnValue(Promise.resolve({ task_id: 'foo', status: 'succeeded' })); + render( + + ); + + expect(modelLoadMock).not.toHaveBeenCalled(); + await userEvent.click(screen.getByRole('button', { name: 'Deploy' })); + expect(modelLoadMock).toHaveBeenCalledTimes(1); + + modelLoadMock.mockRestore(); + }); + + it('should show error toast if model load throw error', async () => { + const useOpenSearchDashboardsMock = mockAddDangerAndOverlay(); + const modelLoadMock = jest + .spyOn(Model.prototype, 'load') + .mockRejectedValue(new Error('error')); + render( + + ); + + await userEvent.click(screen.getByRole('button', { name: 'Deploy' })); + + expect(screen.getByText('deployment failed.')).toBeInTheDocument(); + expect(screen.getByText('See full error')).toBeInTheDocument(); + + modelLoadMock.mockRestore(); + useOpenSearchDashboardsMock.mockRestore(); + }); + + it('should show full error after "See full error" clicked', async () => { + const useOpenSearchDashboardsMock = mockAddDangerAndOverlay(); + const modelLoadMock = jest + .spyOn(Model.prototype, 'load') + .mockRejectedValue(new Error('This is a full error message.')); + render( + + ); + + await userEvent.click(screen.getByRole('button', { name: 'Deploy' })); + await userEvent.click(screen.getByText('See full error')); + + expect(screen.getByText('Error message:')).toBeInTheDocument(); + expect(screen.getByText('This is a full error message.')).toBeInTheDocument(); + + modelLoadMock.mockRestore(); + useOpenSearchDashboardsMock.mockRestore(); + }); + + it('should hide full error after close button clicked', async () => { + const useOpenSearchDashboardsMock = mockAddDangerAndOverlay(); + const modelLoadMock = jest + .spyOn(Model.prototype, 'load') + .mockRejectedValue(new Error('This is a full error message.')); + render( + + ); + + await userEvent.click(screen.getByRole('button', { name: 'Deploy' })); + await userEvent.click(screen.getByText('See full error')); + await userEvent.click(screen.getByText('Close')); + + expect(screen.queryByText('This is a full error message.')).not.toBeInTheDocument(); + + modelLoadMock.mockRestore(); + useOpenSearchDashboardsMock.mockRestore(); + }); + }); + + describe('model=undeploy', () => { + it('should render undeploy title and confirm message', () => { + render( + + ); + + expect(screen.getByTestId('confirmModalTitleText')).toHaveTextContent( + 'Undeploy model-1 version 1' + ); + expect( + screen.getByText('This version will be undeployed. You can deploy it again later.') + ).toBeInTheDocument(); + expect(screen.getByText('model-1 version 1')).toHaveAttribute( + 'href', + '/model-registry/model-version/1' + ); + }); + + it('should call model unload after undeploy button clicked', async () => { + const modelLoadMock = jest.spyOn(Model.prototype, 'unload').mockImplementation(); + render( + + ); + + expect(modelLoadMock).not.toHaveBeenCalled(); + await userEvent.click(screen.getByRole('button', { name: 'Undeploy' })); + expect(modelLoadMock).toHaveBeenCalledTimes(1); + + modelLoadMock.mockRestore(); + }); + + it('should show success toast after modal unload success', async () => { + const useOpenSearchDashboardsMock = jest + .spyOn(PluginContext, 'useOpenSearchDashboards') + .mockReturnValue({ + services: { + notifications: { + toasts: { + addSuccess: generateToastMock(), + }, + }, + }, + }); + const modelLoadMock = jest.spyOn(Model.prototype, 'unload').mockImplementation(); + render( + + ); + + await userEvent.click(screen.getByRole('button', { name: 'Undeploy' })); + + await waitFor(() => { + expect(screen.getByTestId('euiToastHeader')).toHaveTextContent( + 'Undeployed model-1 version 1' + ); + }); + + modelLoadMock.mockRestore(); + useOpenSearchDashboardsMock.mockRestore(); + }); + + it('should show error toast if model unload throw error', async () => { + const useOpenSearchDashboardsMock = mockAddDangerAndOverlay(); + const modelLoadMock = jest + .spyOn(Model.prototype, 'unload') + .mockRejectedValue(new Error('error')); + render( + + ); + + await userEvent.click(screen.getByRole('button', { name: 'Undeploy' })); + + expect(screen.getByText('undeployment failed.')).toBeInTheDocument(); + expect(screen.getByText('See full error')).toBeInTheDocument(); + + modelLoadMock.mockRestore(); + useOpenSearchDashboardsMock.mockRestore(); + }); + + it('should show full error after "See full error" clicked', async () => { + const useOpenSearchDashboardsMock = mockAddDangerAndOverlay(); + const modelLoadMock = jest + .spyOn(Model.prototype, 'unload') + .mockRejectedValue(new Error('This is a full error message.')); + render( + + ); + + await userEvent.click(screen.getByRole('button', { name: 'Undeploy' })); + await userEvent.click(screen.getByText('See full error')); + + expect(screen.getByText('Error message:')).toBeInTheDocument(); + expect(screen.getByText('This is a full error message.')).toBeInTheDocument(); + + modelLoadMock.mockRestore(); + useOpenSearchDashboardsMock.mockRestore(); + }); + + it('should hide full error after close button clicked', async () => { + const useOpenSearchDashboardsMock = mockAddDangerAndOverlay(); + const modelLoadMock = jest + .spyOn(Model.prototype, 'unload') + .mockRejectedValue(new Error('This is a full error message.')); + render( + + ); + + await userEvent.click(screen.getByRole('button', { name: 'Undeploy' })); + await userEvent.click(screen.getByText('See full error')); + await userEvent.click(screen.getByText('Close')); + + expect(screen.queryByText('This is a full error message.')).not.toBeInTheDocument(); + + modelLoadMock.mockRestore(); + useOpenSearchDashboardsMock.mockRestore(); + }); + }); +}); diff --git a/public/components/model/model_versions_panel/__tests__/model_version_error_details_modal.test.tsx b/public/components/common/modals/__tests__/model_version_error_details_modal.test.tsx similarity index 60% rename from public/components/model/model_versions_panel/__tests__/model_version_error_details_modal.test.tsx rename to public/components/common/modals/__tests__/model_version_error_details_modal.test.tsx index b1c1d21e..4a90cbef 100644 --- a/public/components/model/model_versions_panel/__tests__/model_version_error_details_modal.test.tsx +++ b/public/components/common/modals/__tests__/model_version_error_details_modal.test.tsx @@ -5,6 +5,7 @@ import React from 'react'; import userEvent from '@testing-library/user-event'; +import { render as originRender } from '@testing-library/react'; import { render, screen } from '../../../../../test/test_utils'; import { ModelVersionErrorDetailsModal } from '../model_version_error_details_modal'; @@ -18,6 +19,7 @@ describe('', () => { version="3" errorDetails="Error message" closeModal={jest.fn()} + errorType="artifact-upload-failed" /> ); @@ -38,7 +40,7 @@ describe('', () => { version="3" errorDetails={'{"foo": "bar"}'} closeModal={jest.fn()} - isDeployFailed + errorType="deployment-failed" /> ); @@ -61,7 +63,7 @@ describe('', () => { version="3" errorDetails={'{"foo": "bar"}'} closeModal={closeModalMock} - isDeployFailed + errorType="deployment-failed" /> ); @@ -72,4 +74,45 @@ describe('', () => { await userEvent.click(screen.getByLabelText('Closes this modal window')); expect(closeModalMock).toHaveBeenCalledTimes(2); }); + + it('should render undeployment failed screen', () => { + render( + + ); + + expect(screen.getByText('model-1-name version 3')).toBeInTheDocument(); + expect(screen.getByText('undeployment failed')).toBeInTheDocument(); + expect(screen.getByText('model-1-name version 3')).toHaveAttribute( + 'href', + '/model-registry/model-version/model-1-id' + ); + expect(screen.getByText('{"foo": "bar"}')).toBeInTheDocument(); + expect(screen.getByLabelText('Copy')).toBeInTheDocument(); + }); + + it('should render consistent plain model version link without react-router provider', () => { + originRender( + + ); + + expect(screen.getByText('model-1-name version 3')).toHaveAttribute( + 'href', + '/foo/model-registry/model-version/model-1-id' + ); + }); }); diff --git a/public/components/common/modals/index.ts b/public/components/common/modals/index.ts new file mode 100644 index 00000000..fce83901 --- /dev/null +++ b/public/components/common/modals/index.ts @@ -0,0 +1,7 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +export { ModelVersionErrorDetailsModal } from './model_version_error_details_modal'; +export { ModelVersionDeploymentConfirmModal } from './model_version_deployment_confirm_modal'; diff --git a/public/components/common/modals/model_version_deployment_confirm_modal.tsx b/public/components/common/modals/model_version_deployment_confirm_modal.tsx new file mode 100644 index 00000000..3d4b36d1 --- /dev/null +++ b/public/components/common/modals/model_version_deployment_confirm_modal.tsx @@ -0,0 +1,168 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React, { useCallback, useState } from 'react'; +import { + EuiButton, + EuiConfirmModal, + EuiFlexGroup, + EuiFlexItem, + EuiLink, + EuiSpacer, + EuiText, +} from '@elastic/eui'; +import { Link, generatePath, useHistory } from 'react-router-dom'; + +import { useOpenSearchDashboards } from '../../../../../../src/plugins/opensearch_dashboards_react/public'; +import { mountReactNode } from '../../../../../../src/core/public/utils'; +import { routerPaths } from '../../../../common'; +import { APIProvider } from '../../../apis/api_provider'; + +import { ModelVersionErrorDetailsModal } from './model_version_error_details_modal'; + +export const ModelVersionDeploymentConfirmModal = ({ + id, + mode, + name, + version, + closeModal, +}: { + id: string; + mode: 'deploy' | 'undeploy'; + name: string; + version: string; + closeModal: () => void; +}) => { + const [isSubmitting, setIsSubmitting] = useState(false); + const { + services: { notifications, overlays }, + } = useOpenSearchDashboards(); + const history = useHistory(); + const mapping = { + deploy: { + title: 'Deploy', + description: 'This version will begin deploying.', + errorMessage: 'deployment failed.', + errorType: 'deployment-failed' as const, + action: APIProvider.getAPI('model').load, + }, + undeploy: { + title: 'Undeploy', + description: 'This version will be undeployed. You can deploy it again later.', + errorMessage: 'undeployment failed.', + errorType: 'undeployment-failed' as const, + action: APIProvider.getAPI('model').unload, + }, + }; + const { title, description, errorMessage, errorType, action } = mapping[mode]; + + const handleConfirm = useCallback(async () => { + setIsSubmitting(true); + const modelVersionUrl = history.createHref({ + pathname: generatePath(routerPaths.modelVersion, { id }), + }); + try { + await action(id); + } catch (e) { + notifications?.toasts.addDanger({ + title: mountReactNode( + <> + + {name} version {version} + + . + + ), + text: mountReactNode( + <> + {errorMessage} + + + + + { + const overlayRef = overlays?.openModal( + mountReactNode( + { + overlayRef?.close(); + }} + errorDetails={e instanceof Error ? e.message : JSON.stringify(e)} + /> + ) + ); + }} + > + See full error + + + + + ), + }); + return; + } finally { + setIsSubmitting(false); + closeModal(); + } + // The undeploy API call is sync, we can show error message after immediately + if (mode === 'undeploy') { + notifications?.toasts.addSuccess({ + title: mountReactNode( + <> + Undeployed{' '} + + {name} version {version} + + . + + ), + }); + return; + } + // TODO: Implement model version table status updated after integrate model version table automatic refresh status column + }, [ + id, + notifications, + action, + closeModal, + overlays, + history, + name, + version, + errorType, + errorMessage, + mode, + ]); + + return ( + + {title}{' '} + + {name} version {version} + + + } + confirmButtonText={title} + cancelButtonText="Cancel" + onCancel={closeModal} + onConfirm={handleConfirm} + confirmButtonDisabled={isSubmitting} + isLoading={isSubmitting} + maxWidth={500} + > + {description} + + ); +}; diff --git a/public/components/model/model_versions_panel/model_version_error_details_modal.tsx b/public/components/common/modals/model_version_error_details_modal.tsx similarity index 63% rename from public/components/model/model_versions_panel/model_version_error_details_modal.tsx rename to public/components/common/modals/model_version_error_details_modal.tsx index 955c7124..6f6debae 100644 --- a/public/components/model/model_versions_panel/model_version_error_details_modal.tsx +++ b/public/components/common/modals/model_version_error_details_modal.tsx @@ -22,37 +22,52 @@ import { import { routerPaths } from '../../../../common/router_paths'; +const errorType2ErrorTitleMap = { + 'deployment-failed': 'deployment failed', + 'artifact-upload-failed': 'artifact upload failed', + 'undeployment-failed': 'undeployment failed', +}; + export const ModelVersionErrorDetailsModal = ({ id, name, version, + errorType, closeModal, errorDetails, - isDeployFailed, + plainVersionLink, }: { - name: string; id: string; + name: string; version: string; + errorType: 'deployment-failed' | 'artifact-upload-failed' | 'undeployment-failed'; closeModal: () => void; errorDetails: string; - isDeployFailed?: boolean; + plainVersionLink?: string; }) => { + const errorTitle = errorType2ErrorTitleMap[errorType]; + const linkText = `${name} version ${version}`; + return (

      - - {name} version {version} - {' '} - {isDeployFailed ? 'deployment failed' : 'artifact upload failed'} + {plainVersionLink ? ( + {linkText} + ) : ( + + {linkText} + + )}{' '} + {errorTitle}

      - {isDeployFailed ? ( + {errorType === 'deployment-failed' || errorType === 'undeployment-failed' ? ( <> Error message: diff --git a/public/components/model/model_versions_panel/__tests__/model_version_table_row_actions.test.tsx b/public/components/model/model_versions_panel/__tests__/model_version_table_row_actions.test.tsx index 213afbeb..c479e069 100644 --- a/public/components/model/model_versions_panel/__tests__/model_version_table_row_actions.test.tsx +++ b/public/components/model/model_versions_panel/__tests__/model_version_table_row_actions.test.tsx @@ -10,10 +10,14 @@ import { render, screen, waitFor } from '../../../../../test/test_utils'; import { ModelVersionTableRowActions } from '../model_version_table_row_actions'; import { MODEL_STATE } from '../../../../../common'; +const setup = (state: MODEL_STATE) => { + return render(); +}; + describe('', () => { it('should render "actions icon" and "Delete" button after clicked', async () => { const user = userEvent.setup(); - render(); + setup(MODEL_STATE.uploading); expect(screen.getByLabelText('show actions')).toBeInTheDocument(); await user.click(screen.getByLabelText('show actions')); @@ -23,39 +27,56 @@ describe('', () => { it('should render "Upload new artifact" button for REGISTER_FAILED state', async () => { const user = userEvent.setup(); - render(); + setup(MODEL_STATE.registerFailed); await user.click(screen.getByLabelText('show actions')); expect(screen.getByText('Upload new artifact')).toBeInTheDocument(); }); - it('should render "Deploy" button for REGISTERED and UNDEPLOYED state', async () => { + it('should render "Deploy" button for REGISTERED, DEPLOY_FAILED and UNDEPLOYED state', async () => { const user = userEvent.setup(); - const { rerender } = render( - - ); + const { rerender } = setup(MODEL_STATE.uploaded); await user.click(screen.getByLabelText('show actions')); expect(screen.getByText('Deploy')).toBeInTheDocument(); - rerender(); + rerender( + + ); + expect(screen.getByText('Deploy')).toBeInTheDocument(); + + rerender( + + ); expect(screen.getByText('Deploy')).toBeInTheDocument(); }); it('should render "Undeploy" button for DEPLOYED and PARTIALLY_DEPLOYED state', async () => { const user = userEvent.setup(); - const { rerender } = render(); + const { rerender } = setup(MODEL_STATE.loaded); await user.click(screen.getByLabelText('show actions')); expect(screen.getByText('Undeploy')).toBeInTheDocument(); - rerender(); + rerender( + + ); expect(screen.getByText('Undeploy')).toBeInTheDocument(); }); it('should call close popover after menuitem click', async () => { const user = userEvent.setup(); - render(); + setup(MODEL_STATE.loaded); await user.click(screen.getByLabelText('show actions')); await user.click(screen.getByText('Delete')); @@ -64,4 +85,52 @@ describe('', () => { expect(screen.queryByText('Delete')).toBeNull(); }); }); + + it('should show deploy confirm modal after "Deploy" button clicked', async () => { + const user = userEvent.setup(); + setup(MODEL_STATE.uploaded); + await user.click(screen.getByLabelText('show actions')); + await user.click(screen.getByText('Deploy')); + + expect(screen.getByTestId('confirmModalTitleText')).toHaveTextContent( + 'Deploy model-1 version 1' + ); + expect(screen.getByText('This version will begin deploying.')).toBeInTheDocument(); + }); + + it('should hide deploy confirm modal after "Cancel" button clicked', async () => { + const user = userEvent.setup(); + setup(MODEL_STATE.uploaded); + await user.click(screen.getByLabelText('show actions')); + await user.click(screen.getByText('Deploy')); + await user.click(screen.getByText('Cancel')); + + expect(screen.queryByText('This version will begin deploying.')).not.toBeInTheDocument(); + }); + + it('should show undeploy confirm modal after "Deploy" button clicked', async () => { + const user = userEvent.setup(); + setup(MODEL_STATE.loaded); + await user.click(screen.getByLabelText('show actions')); + await user.click(screen.getByText('Undeploy')); + + expect(screen.getByTestId('confirmModalTitleText')).toHaveTextContent( + 'Undeploy model-1 version 1' + ); + expect( + screen.getByText('This version will be undeployed. You can deploy it again later.') + ).toBeInTheDocument(); + }); + + it('should hide undeploy confirm modal after "Cancel" button clicked', async () => { + const user = userEvent.setup(); + setup(MODEL_STATE.loaded); + await user.click(screen.getByLabelText('show actions')); + await user.click(screen.getByText('Undeploy')); + await user.click(screen.getByText('Cancel')); + + expect( + screen.queryByText('This version will be undeployed. You can deploy it again later.') + ).not.toBeInTheDocument(); + }); }); diff --git a/public/components/model/model_versions_panel/__tests__/model_versions_panel.test.tsx b/public/components/model/model_versions_panel/__tests__/model_versions_panel.test.tsx index 547c1a0b..0449f8a7 100644 --- a/public/components/model/model_versions_panel/__tests__/model_versions_panel.test.tsx +++ b/public/components/model/model_versions_panel/__tests__/model_versions_panel.test.tsx @@ -170,30 +170,34 @@ describe('', () => { searchMock.mockRestore(); }); - it('should render no-result screen and reset search button if no result for specific condition', async () => { - render(); - await waitFor(() => { - expect(screen.getByTitle('Status')).toBeInTheDocument(); - }); + it( + 'should render no-result screen and reset search button if no result for specific condition', + async () => { + render(); + await waitFor(() => { + expect(screen.getByTitle('Status')).toBeInTheDocument(); + }); - const searchMock = jest.spyOn(Model.prototype, 'search').mockImplementation(async () => { - return { - data: [], - total_models: 0, - }; - }); - await userEvent.click(screen.getByTitle('Status')); - await userEvent.click(screen.getByRole('option', { name: 'In progress...' })); + const searchMock = jest.spyOn(Model.prototype, 'search').mockImplementation(async () => { + return { + data: [], + total_models: 0, + }; + }); + await userEvent.click(screen.getByTitle('Status')); + await userEvent.click(screen.getByRole('option', { name: 'In progress...' })); - expect( - screen.getByText( - 'There are no results for your search. Reset the search criteria to view registered versions.' - ) - ).toBeInTheDocument(); - expect(screen.getByText('Reset search criteria')).toBeInTheDocument(); + expect( + screen.getByText( + 'There are no results for your search. Reset the search criteria to view registered versions.' + ) + ).toBeInTheDocument(); + expect(screen.getByText('Reset search criteria')).toBeInTheDocument(); - searchMock.mockRestore(); - }); + searchMock.mockRestore(); + }, + 10 * 1000 + ); it( 'should call model search without filter condition after reset button clicked', diff --git a/public/components/model/model_versions_panel/model_version_status_detail.tsx b/public/components/model/model_versions_panel/model_version_status_detail.tsx index 79de4339..0508319b 100644 --- a/public/components/model/model_versions_panel/model_version_status_detail.tsx +++ b/public/components/model/model_versions_panel/model_version_status_detail.tsx @@ -18,7 +18,7 @@ import { MODEL_STATE, routerPaths } from '../../../../common'; import { UiSettingDateFormatTime } from '../../common'; import { APIProvider } from '../../../apis/api_provider'; -import { ModelVersionErrorDetailsModal } from './model_version_error_details_modal'; +import { ModelVersionErrorDetailsModal } from '../../common'; // TODO: Change to related time field after confirmed export const state2DetailContentMap: { @@ -192,7 +192,9 @@ export const ModelVersionStatusDetail = ({ name={name} version={version} errorDetails={errorDetails} - isDeployFailed={state === MODEL_STATE.loadFailed} + errorType={ + state === MODEL_STATE.loadFailed ? 'deployment-failed' : 'artifact-upload-failed' + } closeModal={handleCloseModal} /> )} diff --git a/public/components/model/model_versions_panel/model_version_table.tsx b/public/components/model/model_versions_panel/model_version_table.tsx index ff26ed34..3d3347e3 100644 --- a/public/components/model/model_versions_panel/model_version_table.tsx +++ b/public/components/model/model_versions_panel/model_version_table.tsx @@ -112,8 +112,10 @@ export const ModelVersionTable = ({ width: 40, headerCellRender: () => null, rowCellRender: ({ rowIndex }: EuiDataGridCellValueElementProps) => { - const version = versions[rowIndex]; - return ; + const { id, name, version, state } = versions[rowIndex]; + return ( + + ); }, }, ], diff --git a/public/components/model/model_versions_panel/model_version_table_row_actions.tsx b/public/components/model/model_versions_panel/model_version_table_row_actions.tsx index e08f37bf..61b4bdc9 100644 --- a/public/components/model/model_versions_panel/model_version_table_row_actions.tsx +++ b/public/components/model/model_versions_panel/model_version_table_row_actions.tsx @@ -7,9 +7,22 @@ import React, { useState, useCallback } from 'react'; import { EuiPopover, EuiButtonIcon, EuiContextMenuPanel, EuiContextMenuItem } from '@elastic/eui'; import { MODEL_STATE } from '../../../../common'; +import { ModelVersionDeploymentConfirmModal } from '../../common'; -export const ModelVersionTableRowActions = ({ state, id }: { state: MODEL_STATE; id: string }) => { +export const ModelVersionTableRowActions = ({ + state, + id, + name, + version, +}: { + state: MODEL_STATE; + id: string; + name: string; + version: string; +}) => { const [isPopoverOpen, setIsPopoverOpen] = useState(false); + const [isDeployConfirmModalShow, setIsDeployConfirmModalShow] = useState(false); + const [isUndeployConfirmModalShow, setIsUndeployConfirmModalShow] = useState(false); const handleShowActionsClick = useCallback(() => { setIsPopoverOpen((flag) => !flag); @@ -19,62 +32,110 @@ export const ModelVersionTableRowActions = ({ state, id }: { state: MODEL_STATE; setIsPopoverOpen(false); }, []); + const handleDeployClick = useCallback(() => { + setIsDeployConfirmModalShow(true); + }, []); + + const handleUndeployClick = useCallback(() => { + setIsUndeployConfirmModalShow(true); + }, []); + + const closeDeployConfirmModal = useCallback(() => { + setIsDeployConfirmModalShow(false); + }, []); + + const closeUndeployConfirmModal = useCallback(() => { + setIsUndeployConfirmModalShow(false); + }, []); + return ( - + + } + closePopover={closePopover} + ownFocus={false} + > +
      + + Upload new artifact + , + ] + : []), + ...(state === MODEL_STATE.uploaded || + state === MODEL_STATE.unloaded || + state === MODEL_STATE.loadFailed + ? [ + + Deploy + , + ] + : []), + ...(state === MODEL_STATE.loaded || state === MODEL_STATE.partiallyLoaded + ? [ + + Undeploy + , + ] + : []), + + Delete + , + ]} + /> +
      +
      + {isDeployConfirmModalShow && ( + - } - closePopover={closePopover} - ownFocus={false} - > -
      - - Upload new artifact - , - ] - : []), - ...(state === MODEL_STATE.uploaded || state === MODEL_STATE.unloaded - ? [ - - Deploy - , - ] - : []), - ...(state === MODEL_STATE.loaded || state === MODEL_STATE.partiallyLoaded - ? [ - - Undeploy - , - ] - : []), - - Delete - , - ]} + )} + {isUndeployConfirmModalShow && ( + -
      -
      + )} + ); }; From edec198e2dcccf677e0c640e400b32f78cec9402 Mon Sep 17 00:00:00 2001 From: Lin Wang Date: Fri, 26 May 2023 09:09:31 +0800 Subject: [PATCH 53/75] Feature/create model group before register (#192) * feat: add model group related API Signed-off-by: Lin Wang * feat: call model group register and delete when model register Signed-off-by: Lin Wang * feat: check name unique from model group search Signed-off-by: Lin Wang * fix: register model group call in register model version Signed-off-by: Lin Wang * fix: model delete after model version register failed Signed-off-by: Lin Wang * feat: add model access control related fields for create model group API Signed-off-by: Lin Wang --------- Signed-off-by: Lin Wang --- public/apis/api_provider.ts | 9 ++ public/apis/model.ts | 5 +- public/apis/model_group.ts | 70 +++++++++ .../common/forms/model_name_field.tsx | 4 +- public/components/global_breadcrumbs.tsx | 2 +- .../__tests__/register_model_api.test.ts | 144 ++++++++++++++++++ .../__tests__/register_model_details.test.tsx | 7 +- .../__tests__/register_model_form.test.tsx | 49 ++---- .../register_model/__tests__/setup.tsx | 3 - .../register_model/register_model.tsx | 39 +++-- .../register_model/register_model.types.ts | 4 +- .../register_model/register_model_api.ts | 94 ++++++++---- server/plugin.ts | 2 + server/routes/constants.ts | 1 + server/routes/index.ts | 1 + server/routes/model_group_router.ts | 139 +++++++++++++++++ server/routes/model_router.ts | 5 +- server/services/model_group_service.ts | 132 ++++++++++++++++ server/services/model_service.ts | 8 +- server/services/utils/constants.ts | 4 + test/mocks/handlers.ts | 2 + test/mocks/model_group_handlers.ts | 45 ++++++ 22 files changed, 664 insertions(+), 105 deletions(-) create mode 100644 public/apis/model_group.ts create mode 100644 public/components/register_model/__tests__/register_model_api.test.ts create mode 100644 server/routes/model_group_router.ts create mode 100644 server/services/model_group_service.ts create mode 100644 test/mocks/model_group_handlers.ts diff --git a/public/apis/api_provider.ts b/public/apis/api_provider.ts index bcb8b109..ebbc6465 100644 --- a/public/apis/api_provider.ts +++ b/public/apis/api_provider.ts @@ -6,6 +6,7 @@ import { Connector } from './connector'; import { Model } from './model'; import { ModelAggregate } from './model_aggregate'; +import { ModelGroup } from './model_group'; import { ModelRepository } from './model_repository'; import { Profile } from './profile'; import { Security } from './security'; @@ -19,6 +20,7 @@ const apiInstanceStore: { security: Security | undefined; task: Task | undefined; modelRepository: ModelRepository | undefined; + modelGroup: ModelGroup | undefined; } = { model: undefined, modelAggregate: undefined, @@ -27,6 +29,7 @@ const apiInstanceStore: { security: undefined, task: undefined, modelRepository: undefined, + modelGroup: undefined, }; export class APIProvider { @@ -37,6 +40,7 @@ export class APIProvider { public static getAPI(type: 'connector'): Connector; public static getAPI(type: 'security'): Security; public static getAPI(type: 'modelRepository'): ModelRepository; + public static getAPI(type: 'modelGroup'): ModelGroup; public static getAPI(type: keyof typeof apiInstanceStore) { if (apiInstanceStore[type]) { return apiInstanceStore[type]!; @@ -77,6 +81,11 @@ export class APIProvider { apiInstanceStore.modelRepository = newInstance; return newInstance; } + case 'modelGroup': { + const newInstance = new ModelGroup(); + apiInstanceStore.modelGroup = newInstance; + return newInstance; + } } } public static clear() { diff --git a/public/apis/model.ts b/public/apis/model.ts index ed9aaec9..89e66355 100644 --- a/public/apis/model.ts +++ b/public/apis/model.ts @@ -84,10 +84,11 @@ export interface ModelProfileResponse { interface UploadModelBase { name: string; - version: string; - description: string; + version?: string; + description?: string; modelFormat: string; modelConfig: Record; + modelGroupId: string; } export interface UploadModelByURL extends UploadModelBase { diff --git a/public/apis/model_group.ts b/public/apis/model_group.ts new file mode 100644 index 00000000..b5ca09d5 --- /dev/null +++ b/public/apis/model_group.ts @@ -0,0 +1,70 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import { MODEL_GROUP_API_ENDPOINT } from '../../server/routes/constants'; +import { InnerHttpProvider } from './inner_http_provider'; + +interface ModelGroupSearchItem { + id: string; + owner: { + backend_roles: string[]; + roles: string[]; + name: string; + }; + latest_version: number; + last_updated_time: number; + name: string; + description?: string; +} + +export interface ModelGroupSearchResponse { + data: ModelGroupSearchItem[]; + total_model_groups: number; +} + +export class ModelGroup { + public register(body: { + name: string; + description?: string; + modelAccessMode: 'public' | 'restricted' | 'private'; + backendRoles?: string[]; + addAllBackendRoles?: boolean; + }) { + return InnerHttpProvider.getHttp().post<{ model_group_id: string; status: 'CREATED' }>( + MODEL_GROUP_API_ENDPOINT, + { + body: JSON.stringify(body), + } + ); + } + + public update({ id, name, description }: { id: string; name?: string; description?: string }) { + return InnerHttpProvider.getHttp().put<{ status: 'success' }>( + `${MODEL_GROUP_API_ENDPOINT}/${id}`, + { + body: JSON.stringify({ + name, + description, + }), + } + ); + } + + public delete(id: string) { + return InnerHttpProvider.getHttp().delete<{ status: 'success' }>( + `${MODEL_GROUP_API_ENDPOINT}/${id}` + ); + } + + public search(query: { id?: string; name?: string; from: number; size: number }) { + return InnerHttpProvider.getHttp().get(MODEL_GROUP_API_ENDPOINT, { + query, + }); + } + + public getOne = async (id: string) => { + return (await this.search({ id, from: 0, size: 1 })).data[0]; + }; +} diff --git a/public/components/common/forms/model_name_field.tsx b/public/components/common/forms/model_name_field.tsx index 658e4ed8..514c42ab 100644 --- a/public/components/common/forms/model_name_field.tsx +++ b/public/components/common/forms/model_name_field.tsx @@ -24,12 +24,12 @@ interface ModelNameFieldProps { } const isDuplicateModelName = async (name: string) => { - const searchResult = await APIProvider.getAPI('model').search({ + const searchResult = await APIProvider.getAPI('modelGroup').search({ name, from: 0, size: 1, }); - return searchResult.total_models >= 1; + return searchResult.total_model_groups >= 1; }; export const ModelNameField = ({ diff --git a/public/components/global_breadcrumbs.tsx b/public/components/global_breadcrumbs.tsx index 946efbbd..cebdc3bd 100644 --- a/public/components/global_breadcrumbs.tsx +++ b/public/components/global_breadcrumbs.tsx @@ -47,7 +47,7 @@ const getModelRegisterBreadcrumbs = (basename: string, matchedParams: {}) => { staticBreadcrumbs: baseModelRegistryBreadcrumbs, // TODO: Change to model group API asyncBreadcrumbsLoader: () => - APIProvider.getAPI('model') + APIProvider.getAPI('modelGroup') .getOne(modelId) .then( (model) => diff --git a/public/components/register_model/__tests__/register_model_api.test.ts b/public/components/register_model/__tests__/register_model_api.test.ts new file mode 100644 index 00000000..a8682c45 --- /dev/null +++ b/public/components/register_model/__tests__/register_model_api.test.ts @@ -0,0 +1,144 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import { ModelGroup } from '../../../apis/model_group'; +import { Model } from '../../../apis/model'; +import { submitModelWithFile, submitModelWithURL } from '../register_model_api'; + +describe('register model api', () => { + beforeEach(() => { + jest + .spyOn(ModelGroup.prototype, 'register') + .mockResolvedValue({ model_group_id: 'foo', status: 'success' }); + jest.spyOn(ModelGroup.prototype, 'delete').mockResolvedValue({ status: 'success' }); + jest.spyOn(Model.prototype, 'upload').mockResolvedValue({ task_id: 'foo', model_id: 'bar' }); + }); + + afterEach(() => { + jest.spyOn(ModelGroup.prototype, 'register').mockRestore(); + jest.spyOn(ModelGroup.prototype, 'delete').mockRestore(); + jest.spyOn(Model.prototype, 'upload').mockRestore(); + }); + + it('should not call register model group API if modelId provided', async () => { + expect(ModelGroup.prototype.register).not.toHaveBeenCalled(); + + await submitModelWithFile({ + name: 'foo', + description: 'bar', + configuration: '{}', + modelFileFormat: '', + modelId: 'a-exists-model-id', + modelFile: new File([], 'artifact.zip'), + }); + + expect(ModelGroup.prototype.register).not.toHaveBeenCalled(); + }); + + it('should not call delete model group API if modelId provided and model upload failed', async () => { + const uploadError = new Error(); + const uploadMock = jest.spyOn(Model.prototype, 'upload').mockRejectedValue(uploadError); + + try { + await submitModelWithFile({ + name: 'foo', + description: 'bar', + configuration: '{}', + modelFileFormat: '', + modelId: 'a-exists-model-id', + modelFile: new File([], 'artifact.zip'), + }); + } catch (error) { + expect(error).toBe(uploadError); + } + expect(ModelGroup.prototype.delete).not.toHaveBeenCalled(); + + uploadMock.mockRestore(); + }); + + describe('submitModelWithFile', () => { + it('should call register model group API with name and description', async () => { + expect(ModelGroup.prototype.register).not.toHaveBeenCalled(); + + await submitModelWithFile({ + name: 'foo', + description: 'bar', + configuration: '{}', + modelFileFormat: '', + modelFile: new File([], 'artifact.zip'), + }); + + expect(ModelGroup.prototype.register).toHaveBeenCalledWith( + expect.objectContaining({ + name: 'foo', + description: 'bar', + }) + ); + }); + + it('should delete created model group API upload failed', async () => { + const uploadError = new Error(); + const uploadMock = jest.spyOn(Model.prototype, 'upload').mockRejectedValue(uploadError); + + expect(ModelGroup.prototype.delete).not.toHaveBeenCalled(); + try { + await submitModelWithFile({ + name: 'foo', + description: 'bar', + configuration: '{}', + modelFileFormat: '', + modelFile: new File([], 'artifact.zip'), + }); + } catch (error) { + expect(uploadError).toBe(error); + } + expect(ModelGroup.prototype.delete).toHaveBeenCalledWith('foo'); + + uploadMock.mockRestore(); + }); + }); + + describe('submitModelWithURL', () => { + it('should call register model group API with name and description', async () => { + expect(ModelGroup.prototype.register).not.toHaveBeenCalled(); + + await submitModelWithURL({ + name: 'foo', + description: 'bar', + configuration: '{}', + modelFileFormat: '', + modelURL: 'https://address.to/artifact.zip', + }); + + expect(ModelGroup.prototype.register).toHaveBeenCalledWith( + expect.objectContaining({ + name: 'foo', + description: 'bar', + }) + ); + }); + + it('should delete created model group API upload failed', async () => { + const uploadError = new Error(); + const uploadMock = jest.spyOn(Model.prototype, 'upload').mockRejectedValue(uploadError); + + expect(ModelGroup.prototype.delete).not.toHaveBeenCalled(); + try { + await submitModelWithURL({ + name: 'foo', + description: 'bar', + configuration: '{}', + modelFileFormat: '', + modelURL: 'https://address.to/artifact.zip', + }); + } catch (error) { + expect(uploadError).toBe(error); + } + expect(ModelGroup.prototype.delete).toHaveBeenCalledWith('foo'); + + uploadMock.mockRestore(); + }); + }); +}); diff --git a/public/components/register_model/__tests__/register_model_details.test.tsx b/public/components/register_model/__tests__/register_model_details.test.tsx index c441eb00..4cb42b67 100644 --- a/public/components/register_model/__tests__/register_model_details.test.tsx +++ b/public/components/register_model/__tests__/register_model_details.test.tsx @@ -5,7 +5,6 @@ import { setup } from './setup'; import * as formAPI from '../register_model_api'; -import { Model } from '../../../apis/model'; describe(' Details', () => { const onSubmitMock = jest.fn().mockResolvedValue('model_id'); @@ -53,13 +52,9 @@ describe(' Details', () => { it('should NOT submit the register model form if model name is duplicated', async () => { const result = await setup(); - jest.spyOn(Model.prototype, 'search').mockResolvedValue({ - data: [], - total_models: 1, - }); await result.user.clear(result.nameInput); - await result.user.type(result.nameInput, 'a-duplicated-model-name'); + await result.user.type(result.nameInput, 'model1'); await result.user.click(result.submitButton); expect(result.nameInput).toBeInvalid(); diff --git a/public/components/register_model/__tests__/register_model_form.test.tsx b/public/components/register_model/__tests__/register_model_form.test.tsx index 2569f3d5..5763d851 100644 --- a/public/components/register_model/__tests__/register_model_form.test.tsx +++ b/public/components/register_model/__tests__/register_model_form.test.tsx @@ -21,30 +21,6 @@ jest.mock('../../../../../../src/plugins/opensearch_dashboards_react/public', () }; }); -const MOCKED_DATA = { - id: 'C7jN0YQBjgpeQQ_RmiDE', - model_version: '1.0.7', - created_time: 1669967223491, - model_config: { - all_config: - '{"_name_or_path":"nreimers/MiniLM-L6-H384-uncased","architectures":["BertModel"],"attention_probs_dropout_prob":0.1,"gradient_checkpointing":false,"hidden_act":"gelu","hidden_dropout_prob":0.1,"hidden_size":384,"initializer_range":0.02,"intermediate_size":1536,"layer_norm_eps":1e-12,"max_position_embeddings":512,"model_type":"bert","num_attention_heads":12,"num_hidden_layers":6,"pad_token_id":0,"position_embedding_type":"absolute","transformers_version":"4.8.2","type_vocab_size":2,"use_cache":true,"vocab_size":30522}', - model_type: 'bert', - embedding_dimension: 384, - framework_type: 'SENTENCE_TRANSFORMERS', - }, - last_loaded_time: 1672895017422, - model_format: 'TORCH_SCRIPT', - last_uploaded_time: 1669967226531, - name: 'all-MiniLM-L6-v2', - model_state: 'LOADED', - total_chunks: 9, - model_content_size_in_bytes: 83408741, - algorithm: 'TEXT_EMBEDDING', - model_content_hash_value: '9376c2ebd7c83f99ec2526323786c348d2382e6d86576f750c89ea544d6bbb14', - current_worker_node_count: 1, - planning_worker_node_count: 1, -}; - describe(' Form', () => { const MOCKED_MODEL_ID = 'model_id'; const addDangerMock = jest.fn(); @@ -71,21 +47,15 @@ describe(' Form', () => { }); it('should init form when id param in url route', async () => { - const mockResult = MOCKED_DATA; - jest.spyOn(Model.prototype, 'getOne').mockResolvedValue(mockResult); - await setup({ route: '/test_model_id', mode: 'version' }); - - const { name } = mockResult; + await setup({ route: '/1', mode: 'version' }); await waitFor(() => { - expect(screen.getByText(name)).toBeInTheDocument(); + expect(screen.getByText('model1')).toBeInTheDocument(); }); }); it('submit button label should be `Register version` when register new version', async () => { - jest.spyOn(Model.prototype, 'getOne').mockResolvedValue(MOCKED_DATA); - - await setup({ route: '/test_model_id', mode: 'version' }); + await setup({ route: '/1', mode: 'version' }); expect(screen.getByRole('button', { name: /register version/i })).toBeInTheDocument(); }); @@ -156,4 +126,17 @@ describe(' Form', () => { await user.click(screen.getByRole('button', { name: /register model/i })); expect(addDangerMock).toHaveBeenCalled(); }); + + it('should call submit with file with provided model id and name', async () => { + jest.spyOn(formAPI, 'submitModelWithFile').mockImplementation(onSubmitMock); + const { user } = await setup({ route: '/1', mode: 'version' }); + await user.click(screen.getByRole('button', { name: /register version/i })); + + expect(onSubmitMock).toHaveBeenCalledWith( + expect.objectContaining({ + name: 'model1', + modelId: '1', + }) + ); + }); }); diff --git a/public/components/register_model/__tests__/setup.tsx b/public/components/register_model/__tests__/setup.tsx index c35b78aa..5095b1bc 100644 --- a/public/components/register_model/__tests__/setup.tsx +++ b/public/components/register_model/__tests__/setup.tsx @@ -9,7 +9,6 @@ import { Route } from 'react-router-dom'; import { UserEvent } from '@testing-library/user-event/dist/types/setup/setup'; import { RegisterModelForm } from '../register_model'; -import { Model } from '../../../apis/model'; import { render, RenderWithRouteProps, screen, waitFor } from '../../../../test/test_utils'; import { ModelFileFormData, ModelUrlFormData } from '../register_model.types'; @@ -106,8 +105,6 @@ export async function setup( throw new Error('Description input not found'); } - // Mock model name unique - jest.spyOn(Model.prototype, 'search').mockResolvedValue({ data: [], total_models: 0 }); // fill model name if (mode === 'model') { await user.type(nameInput, 'test model name'); diff --git a/public/components/register_model/register_model.tsx b/public/components/register_model/register_model.tsx index 7153c81d..22992d12 100644 --- a/public/components/register_model/register_model.tsx +++ b/public/components/register_model/register_model.tsx @@ -25,7 +25,6 @@ import useObservable from 'react-use/lib/useObservable'; import { from } from 'rxjs'; import { APIProvider } from '../../apis/api_provider'; -import { upgradeModelVersion } from '../../utils'; import { useSearchParams } from '../../hooks/use_search_params'; import { isValidModelRegisterFormType } from './utils'; import { useOpenSearchDashboards } from '../../../../../src/plugins/opensearch_dashboards_react/public'; @@ -79,7 +78,7 @@ const FileAndVersionTitle = () => { export const RegisterModelForm = ({ defaultValues = DEFAULT_VALUES }: RegisterModelFormProps) => { const history = useHistory(); const [isSubmitted, setIsSubmitted] = useState(false); - const { id: latestVersionId } = useParams<{ id: string | undefined }>(); + const { id: registerToModelId } = useParams<{ id: string | undefined }>(); const [modelGroupName, setModelGroupName] = useState(); const searchParams = useSearchParams(); const typeParams = searchParams.get('type'); @@ -96,13 +95,13 @@ export const RegisterModelForm = ({ defaultValues = DEFAULT_VALUES }: RegisterMo formType === 'import' ? [ModelDetailsPanel, ModelTagsPanel, ModelVersionNotesPanel] : [ - ...(latestVersionId ? [] : [ModelOverviewTitle]), - ...(latestVersionId ? [] : [ModelDetailsPanel]), - ...(latestVersionId ? [] : [ModelTagsPanel]), - ...(latestVersionId ? [] : [FileAndVersionTitle]), + ...(registerToModelId ? [] : [ModelOverviewTitle]), + ...(registerToModelId ? [] : [ModelDetailsPanel]), + ...(registerToModelId ? [] : [ModelTagsPanel]), + ...(registerToModelId ? [] : [FileAndVersionTitle]), ArtifactPanel, ConfigurationPanel, - ...(latestVersionId ? [ModelTagsPanel] : []), + ...(registerToModelId ? [ModelTagsPanel] : []), ModelVersionNotesPanel, ]; @@ -162,7 +161,7 @@ export const RegisterModelForm = ({ defaultValues = DEFAULT_VALUES }: RegisterMo // Navigate to model list if form submit successfully history.push(routerPaths.modelList); - if (latestVersionId) { + if (data.modelId) { notifications?.toasts.addSuccess({ title: mountReactNode( @@ -198,27 +197,25 @@ export const RegisterModelForm = ({ defaultValues = DEFAULT_VALUES }: RegisterMo } } }, - [notifications, form, latestVersionId, history] + [notifications, form, history] ); useEffect(() => { - if (!latestVersionId) return; + if (!registerToModelId) return; const initializeForm = async () => { + form.setValue('modelId', registerToModelId); try { - const data = await APIProvider.getAPI('model').getOne(latestVersionId); + const data = await APIProvider.getAPI('modelGroup').getOne(registerToModelId); // TODO: clarify which fields to pre-populate - const { model_version: modelVersion, name, model_config: modelConfig } = data; - const newVersion = upgradeModelVersion(modelVersion); + const { name } = data; form.setValue('name', name); - form.setValue('version', newVersion); - form.setValue('configuration', modelConfig?.all_config ?? ''); setModelGroupName(name); } catch (e) { // TODO: handle error here } }; initializeForm(); - }, [latestVersionId, form]); + }, [registerToModelId, form]); useEffect(() => { if (!nameParams) { @@ -262,10 +259,10 @@ export const RegisterModelForm = ({ defaultValues = DEFAULT_VALUES }: RegisterMo const errorCount = Object.keys(form.formState.errors).length; const formHeader = ( <> - + - {latestVersionId && ( + {registerToModelId && ( <> Register a new version of {modelGroupName}. The version number will be automatically incremented.  @@ -275,8 +272,8 @@ export const RegisterModelForm = ({ defaultValues = DEFAULT_VALUES }: RegisterMo . )} - {formType === 'import' && !latestVersionId && <>Register a pre-trained model.} - {formType === 'upload' && !latestVersionId && ( + {formType === 'import' && !registerToModelId && <>Register a pre-trained model.} + {formType === 'upload' && !registerToModelId && ( <> Register your model to manage its life cycle, and facilitate model discovery across your organization. @@ -357,7 +354,7 @@ export const RegisterModelForm = ({ defaultValues = DEFAULT_VALUES }: RegisterMo onClick={() => setIsSubmitted(true)} fill > - {latestVersionId ? 'Register version' : 'Register model'} + {registerToModelId ? 'Register version' : 'Register model'}
      diff --git a/public/components/register_model/register_model.types.ts b/public/components/register_model/register_model.types.ts index 2dc44fa6..d8f57f0a 100644 --- a/public/components/register_model/register_model.types.ts +++ b/public/components/register_model/register_model.types.ts @@ -7,8 +7,8 @@ import type { Tag } from '../model/types'; interface ModelFormBase { name: string; - version: string; - description: string; + modelId?: string; + description?: string; configuration: string; modelFileFormat: string; tags?: Tag[]; diff --git a/public/components/register_model/register_model_api.ts b/public/components/register_model/register_model_api.ts index 87ba345c..65b0b44d 100644 --- a/public/components/register_model/register_model_api.ts +++ b/public/components/register_model/register_model_api.ts @@ -8,44 +8,78 @@ import { MAX_CHUNK_SIZE } from '../common/forms/form_constants'; import { getModelContentHashValue } from './get_model_content_hash_value'; import { ModelFileFormData, ModelUrlFormData } from './register_model.types'; +const getModelUploadBase = ({ + name, + versionNotes, + modelFileFormat, + configuration, +}: ModelFileFormData | ModelUrlFormData) => ({ + name, + description: versionNotes, + modelFormat: modelFileFormat, + modelConfig: JSON.parse(configuration), +}); + +const createModelIfNeedAndUploadVersion = async ({ + name, + modelId, + description, + uploader, +}: { + name: string; + modelId?: string; + description?: string; + uploader: (modelId: string) => Promise; +}) => { + if (modelId) { + return await uploader(modelId); + } + modelId = ( + await APIProvider.getAPI('modelGroup').register({ + name, + description, + // TODO: This value should follow form data, need to be updated after UI design confirmed + modelAccessMode: 'public', + }) + ).model_group_id; + + try { + return await uploader(modelId); + } catch (error) { + APIProvider.getAPI('modelGroup').delete(modelId); + throw error; + } +}; + export async function submitModelWithFile(model: ModelFileFormData) { - const modelUploadBase = { - name: model.name, - version: model.version, - description: model.description, - // TODO: Need to confirm if we have the model format input - modelFormat: 'TORCH_SCRIPT', - modelConfig: JSON.parse(model.configuration), - }; const { modelFile } = model; const totalChunks = Math.ceil(modelFile.size / MAX_CHUNK_SIZE); const modelContentHashValue = await getModelContentHashValue(modelFile); - const modelId = ( - await APIProvider.getAPI('model').upload({ - ...modelUploadBase, - totalChunks, - modelContentHashValue, + return ( + await createModelIfNeedAndUploadVersion({ + ...model, + uploader: (modelId: string) => + APIProvider.getAPI('model').upload({ + ...getModelUploadBase(model), + modelGroupId: modelId, + totalChunks, + modelContentHashValue, + }), }) ).model_id; - - return modelId; } export async function submitModelWithURL(model: ModelUrlFormData) { - const modelUploadBase = { - name: model.name, - version: model.version, - description: model.description, - // TODO: Need to confirm if we have the model format input - modelFormat: 'TORCH_SCRIPT', - modelConfig: JSON.parse(model.configuration), - }; - - const { task_id: taskId } = await APIProvider.getAPI('model').upload({ - ...modelUploadBase, - url: model.modelURL, - }); - - return taskId; + return ( + await createModelIfNeedAndUploadVersion({ + ...model, + uploader: (modelId: string) => + APIProvider.getAPI('model').upload({ + ...getModelUploadBase(model), + modelGroupId: modelId, + url: model.modelURL, + }), + }) + ).task_id; } diff --git a/server/plugin.ts b/server/plugin.ts index 7c8d4b56..d5116fb9 100644 --- a/server/plugin.ts +++ b/server/plugin.ts @@ -21,6 +21,7 @@ import { securityRouter, taskRouter, modelRepositoryRouter, + modelGroupRouter, } from './routes'; import { ModelService } from './services'; @@ -50,6 +51,7 @@ export class MlCommonsPlugin implements Plugin { + router.post( + { + path: MODEL_GROUP_API_ENDPOINT, + validate: { + body: schema.object({ + name: schema.string(), + description: schema.maybe(schema.string()), + modelAccessMode: schema.oneOf([ + schema.literal('public'), + schema.literal('private'), + schema.literal('restricted'), + ]), + backendRoles: schema.maybe(schema.arrayOf(schema.string())), + addAllBackendRoles: schema.maybe(schema.boolean()), + }), + }, + }, + async (context, request) => { + const { name, description, modelAccessMode, backendRoles, addAllBackendRoles } = request.body; + try { + const payload = await ModelGroupService.register({ + client: context.core.opensearch.client, + name, + description, + modelAccessMode, + backendRoles, + addAllBackendRoles, + }); + return opensearchDashboardsResponseFactory.ok({ body: payload }); + } catch (error) { + return opensearchDashboardsResponseFactory.badRequest({ + body: error instanceof Error ? error.message : JSON.stringify(error), + }); + } + } + ); + + router.put( + { + path: `${MODEL_GROUP_API_ENDPOINT}/{groupId}`, + validate: { + params: schema.object({ + groupId: schema.string(), + }), + body: schema.object({ + name: schema.maybe(schema.string()), + description: schema.maybe(schema.string()), + }), + }, + }, + async (context, request) => { + const { + params: { groupId }, + body: { name, description }, + } = request; + try { + const payload = await ModelGroupService.update({ + client: context.core.opensearch.client, + id: groupId, + name, + description, + }); + return opensearchDashboardsResponseFactory.ok({ body: payload }); + } catch (error) { + return opensearchDashboardsResponseFactory.badRequest({ + body: error instanceof Error ? error.message : JSON.stringify(error), + }); + } + } + ); + + router.delete( + { + path: `${MODEL_GROUP_API_ENDPOINT}/{groupId}`, + validate: { + params: schema.object({ + groupId: schema.string(), + }), + }, + }, + async (context, request) => { + try { + const payload = await ModelGroupService.delete({ + client: context.core.opensearch.client, + id: request.params.groupId, + }); + return opensearchDashboardsResponseFactory.ok({ body: payload }); + } catch (error) { + return opensearchDashboardsResponseFactory.badRequest({ + body: error instanceof Error ? error.message : JSON.stringify(error), + }); + } + } + ); + + router.get( + { + path: MODEL_GROUP_API_ENDPOINT, + validate: { + query: schema.object({ + id: schema.maybe(schema.string()), + name: schema.maybe(schema.string()), + from: schema.number({ min: 0 }), + size: schema.number({ max: 100 }), + }), + }, + }, + async (context, request) => { + const { id, name, from, size } = request.query; + try { + const payload = await ModelGroupService.search({ + client: context.core.opensearch.client, + id, + name, + from, + size, + }); + return opensearchDashboardsResponseFactory.ok({ body: payload }); + } catch (error) { + return opensearchDashboardsResponseFactory.badRequest({ + body: error instanceof Error ? error.message : JSON.stringify(error), + }); + } + } + ); +}; diff --git a/server/routes/model_router.ts b/server/routes/model_router.ts index f6264e97..65609cd4 100644 --- a/server/routes/model_router.ts +++ b/server/routes/model_router.ts @@ -54,10 +54,11 @@ const modelStateSchema = schema.oneOf([ const modelUploadBaseSchema = { name: schema.string(), - version: schema.string(), - description: schema.string(), + version: schema.maybe(schema.string()), + description: schema.maybe(schema.string()), modelFormat: schema.string(), modelConfig: schema.object({}, { unknowns: 'allow' }), + modelGroupId: schema.string(), }; const modelUploadByURLSchema = schema.object({ diff --git a/server/services/model_group_service.ts b/server/services/model_group_service.ts new file mode 100644 index 00000000..7ad26cfd --- /dev/null +++ b/server/services/model_group_service.ts @@ -0,0 +1,132 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +/* + * Copyright OpenSearch Contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +import { IScopedClusterClient } from '../../../../src/core/server'; + +import { + MODEL_GROUP_BASE_API, + MODEL_GROUP_REGISTER_API, + MODEL_GROUP_SEARCH_API, + MODEL_GROUP_UPDATE_API, +} from './utils/constants'; +import { generateMustQueries, generateTermQuery } from './utils/query'; + +export class ModelGroupService { + public static async register(params: { + client: IScopedClusterClient; + name: string; + description?: string; + modelAccessMode: 'public' | 'restricted' | 'private'; + backendRoles?: string[]; + addAllBackendRoles?: boolean; + }) { + const { client, name, description, modelAccessMode, backendRoles, addAllBackendRoles } = params; + const result = ( + await client.asCurrentUser.transport.request({ + method: 'POST', + path: MODEL_GROUP_REGISTER_API, + body: { + name, + description, + model_access_mode: modelAccessMode, + backend_roles: backendRoles, + add_all_backend_roles: addAllBackendRoles, + }, + }) + ).body as { + model_group_id: string; + status: 'CREATED'; + }; + return result; + } + + public static async update({ + client, + id, + name, + description, + }: { + client: IScopedClusterClient; + id: string; + name?: string; + description?: string; + }) { + const result = ( + await client.asCurrentUser.transport.request({ + method: 'PUT', + path: MODEL_GROUP_UPDATE_API.replace('', id), + body: { + name, + description, + }, + }) + ).body as { + status: 'UPDATED'; + }; + return result; + } + + public static async delete({ client, id }: { client: IScopedClusterClient; id: string }) { + const result = ( + await client.asCurrentUser.transport.request({ + method: 'DELETE', + path: `${MODEL_GROUP_BASE_API}/${id}`, + }) + ).body; + return result; + } + + public static async search({ + client, + id, + name, + from, + size, + }: { + client: IScopedClusterClient; + id?: string; + name?: string; + from: number; + size: number; + }) { + const { + body: { hits }, + } = await client.asCurrentUser.transport.request({ + method: 'GET', + path: MODEL_GROUP_SEARCH_API, + body: { + query: generateMustQueries([ + ...(id ? [generateTermQuery('_id', id)] : []), + ...(name ? [generateTermQuery('name', name)] : []), + ]), + from, + size, + }, + }); + + return { + data: hits.hits.map(({ _id, _source }) => ({ + id: _id, + ..._source, + })), + total_model_groups: hits.total.value, + }; + } +} diff --git a/server/services/model_service.ts b/server/services/model_service.ts index 18afd5b6..2df551f0 100644 --- a/server/services/model_service.ts +++ b/server/services/model_service.ts @@ -38,10 +38,11 @@ const modelSortFieldMapping: { [key: string]: string } = { interface UploadModelBase { name: string; - version: string; - description: string; + version?: string; + description?: string; modelFormat: string; modelConfig: Record; + modelGroupId: string; } interface UploadModelByURL extends UploadModelBase { @@ -178,13 +179,14 @@ export class ModelService { client: IScopedClusterClient; model: T; }): UploadResult { - const { name, version, description, modelFormat, modelConfig } = model; + const { name, version, description, modelFormat, modelConfig, modelGroupId } = model; const uploadModelBase = { name, version, description, model_format: modelFormat, model_config: modelConfig, + model_group_id: modelGroupId, }; if (isUploadModelByURL(model)) { const { task_id: taskId, status } = ( diff --git a/server/services/utils/constants.ts b/server/services/utils/constants.ts index 847d0ead..f0359cb6 100644 --- a/server/services/utils/constants.ts +++ b/server/services/utils/constants.ts @@ -26,6 +26,10 @@ export const MODEL_UPLOAD_API = `${MODEL_BASE_API}/_upload`; export const MODEL_META_API = `${MODEL_BASE_API}/meta`; export const MODEL_PROFILE_API = `${PROFILE_BASE_API}/models`; export const PREDICT_BASE_API = `${ML_COMMONS_API_PREFIX}/_predict`; +export const MODEL_GROUP_BASE_API = `${ML_COMMONS_API_PREFIX}/model_groups`; +export const MODEL_GROUP_REGISTER_API = `${MODEL_GROUP_BASE_API}/_register`; +export const MODEL_GROUP_UPDATE_API = `${MODEL_GROUP_BASE_API}//_update`; +export const MODEL_GROUP_SEARCH_API = `${MODEL_GROUP_BASE_API}/_search`; export const SECURITY_API_PREFIX = '/_plugins/_security/api'; export const SECURITY_ACCOUNT_API = `${SECURITY_API_PREFIX}/account`; diff --git a/test/mocks/handlers.ts b/test/mocks/handlers.ts index 3fe67e5d..1f0af718 100644 --- a/test/mocks/handlers.ts +++ b/test/mocks/handlers.ts @@ -12,6 +12,7 @@ import { modelRepositoryResponse } from './data/model_repository'; import { modelHandlers } from './model_handlers'; import { modelAggregateResponse } from './data/model_aggregate'; import { taskHandlers } from './task_handlers'; +import { modelGroupHandlers } from './model_group_handlers'; export const handlers = [ rest.get('/api/ml-commons/model-repository', (req, res, ctx) => { @@ -25,4 +26,5 @@ export const handlers = [ return res(ctx.status(200), ctx.json(modelAggregateResponse)); }), ...taskHandlers, + ...modelGroupHandlers, ]; diff --git a/test/mocks/model_group_handlers.ts b/test/mocks/model_group_handlers.ts new file mode 100644 index 00000000..b0e273ef --- /dev/null +++ b/test/mocks/model_group_handlers.ts @@ -0,0 +1,45 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import { rest } from 'msw'; + +import { MODEL_GROUP_API_ENDPOINT } from '../../server/routes/constants'; + +const modelGroups = [ + { + name: 'model1', + id: '1', + latest_version: 1, + description: 'foo bar', + }, +]; + +export const modelGroupHandlers = [ + rest.get(MODEL_GROUP_API_ENDPOINT, (req, res, ctx) => { + const { searchParams } = req.url; + const name = searchParams.get('name'); + const id = searchParams.get('id'); + const from = parseInt(searchParams.get('from') || '0', 10); + const size = parseInt(searchParams.get('size') || `${modelGroups.length}`, 10); + const filteredData = modelGroups.filter((modelGroup) => { + if (id && id !== modelGroup.id) { + return false; + } + if (name && name !== modelGroup.name) { + return false; + } + return true; + }); + const end = size ? from + size : filteredData.length; + + return res( + ctx.status(200), + ctx.json({ + data: filteredData.slice(from, end), + total_model_groups: filteredData.length, + }) + ); + }), +]; From 100af2e06d45701df2f665843c30450fe2f3a9f1 Mon Sep 17 00:00:00 2001 From: Lin Wang Date: Fri, 26 May 2023 17:21:30 +0800 Subject: [PATCH 54/75] Feature/jump to model detail page with correct (#193) * feat: mock EuiDataGrid to speedup case running Signed-off-by: Lin Wang * feat: add modelGroupId to model search Signed-off-by: Lin Wang * test: support modelGroupId search mock Signed-off-by: Lin Wang * test: mock EuiDataGrid to speedup model test Signed-off-by: Lin Wang * fix: redirect to model page with correct model id Signed-off-by: Lin Wang * feat: update to use created_time Signed-off-by: Lin Wang --------- Signed-off-by: Lin Wang --- public/apis/model.ts | 1 + public/apis/model_group.ts | 2 + public/components/global_breadcrumbs.tsx | 3 +- .../components/model/__tests__/model.test.tsx | 18 ++++++ public/components/model/model.tsx | 17 +++-- .../__tests__/model_versions_panel.test.tsx | 46 ++++++++++---- .../model_versions_panel.tsx | 9 ++- .../__tests__/register_model_api.test.ts | 36 ++++++++++- .../__tests__/register_model_form.test.tsx | 4 +- .../register_model/register_model.tsx | 23 ++++--- .../register_model/register_model_api.ts | 63 +++++++++++-------- server/routes/model_router.ts | 3 + server/services/model_service.ts | 1 + server/services/utils/model.ts | 3 + test/mocks/model_group_handlers.ts | 16 +++++ test/mocks/model_handlers.ts | 10 ++- 16 files changed, 189 insertions(+), 66 deletions(-) diff --git a/public/apis/model.ts b/public/apis/model.ts index 89e66355..28f2e83a 100644 --- a/public/apis/model.ts +++ b/public/apis/model.ts @@ -113,6 +113,7 @@ export class Model { extraQuery?: Record; dataSourceId?: string; versionOrKeyword?: string; + modelGroupId?: string; }) { const { extraQuery, dataSourceId, ...restQuery } = query; return InnerHttpProvider.getHttp().get(MODEL_API_ENDPOINT, { diff --git a/public/apis/model_group.ts b/public/apis/model_group.ts index b5ca09d5..dd9f7d97 100644 --- a/public/apis/model_group.ts +++ b/public/apis/model_group.ts @@ -14,9 +14,11 @@ interface ModelGroupSearchItem { name: string; }; latest_version: number; + created_time: number; last_updated_time: number; name: string; description?: string; + access: 'public' | 'restricted' | 'private'; } export interface ModelGroupSearchResponse { diff --git a/public/components/global_breadcrumbs.tsx b/public/components/global_breadcrumbs.tsx index cebdc3bd..3a651e5a 100644 --- a/public/components/global_breadcrumbs.tsx +++ b/public/components/global_breadcrumbs.tsx @@ -79,9 +79,8 @@ const getModelBreadcrumbs = (basename: string, matchedParams: {}) => { const modelId = matchedParams.id; return { staticBreadcrumbs: baseModelRegistryBreadcrumbs, - // TODO: Change to model group API asyncBreadcrumbsLoader: () => { - return APIProvider.getAPI('model') + return APIProvider.getAPI('modelGroup') .getOne(modelId) .then( (model) => diff --git a/public/components/model/__tests__/model.test.tsx b/public/components/model/__tests__/model.test.tsx index 1255d19e..91128f4c 100644 --- a/public/components/model/__tests__/model.test.tsx +++ b/public/components/model/__tests__/model.test.tsx @@ -5,6 +5,7 @@ import React from 'react'; import userEvent from '@testing-library/user-event'; +import * as euiExports from '@elastic/eui'; import { render, screen, waitFor, within } from '../../../../test/test_utils'; import { Model } from '../model'; @@ -24,10 +25,19 @@ const setup = () => { }; }; +jest.mock('@elastic/eui', () => ({ + __esModule: true, + ...jest.requireActual('@elastic/eui'), +})); + +const mockEuiDataGrid = () => + jest.spyOn(euiExports, 'EuiDataGrid').mockImplementation(() =>
      EuiDataGrid
      ); + describe('', () => { it( 'should display model name, action buttons, overview-card, tabs and tabpanel after data loaded', async () => { + const euiDataGridMock = mockEuiDataGrid(); setup(); await waitFor(() => { @@ -41,6 +51,8 @@ describe('', () => { expect(screen.queryByTestId('model-group-overview-card')).toBeInTheDocument(); expect(screen.getByRole('tab', { name: 'Versions' })).toHaveClass('euiTab-isSelected'); expect(within(screen.getByRole('tabpanel')).getByText('Versions')).toBeInTheDocument(); + + euiDataGridMock.mockRestore(); }, 10 * 1000 ); @@ -48,6 +60,7 @@ describe('', () => { it( 'should display consistent tabs content after tab clicked', async () => { + const euiDataGridMock = mockEuiDataGrid(); setup(); await waitFor(() => { @@ -63,6 +76,8 @@ describe('', () => { await userEvent.click(screen.getByRole('tab', { name: 'Tags' })); expect(screen.getByRole('tab', { name: 'Tags' })).toHaveClass('euiTab-isSelected'); expect(within(screen.getByRole('tabpanel')).getByText('Tags')).toBeInTheDocument(); + + euiDataGridMock.mockRestore(); }, 10 * 1000 ); @@ -70,6 +85,7 @@ describe('', () => { it( 'should display model name in details tab', async () => { + const euiDataGridMock = mockEuiDataGrid(); setup(); await waitFor(() => { @@ -78,6 +94,8 @@ describe('', () => { await userEvent.click(screen.getByRole('tab', { name: 'Details' })); expect(within(screen.getByRole('tabpanel')).getByDisplayValue('model1')).toBeInTheDocument(); + + euiDataGridMock.mockRestore(); }, 10 * 1000 ); diff --git a/public/components/model/model.tsx b/public/components/model/model.tsx index 916677f4..41c4b34f 100644 --- a/public/components/model/model.tsx +++ b/public/components/model/model.tsx @@ -13,9 +13,12 @@ import { EuiText, } from '@elastic/eui'; import React, { useState, useMemo, useCallback } from 'react'; -import { useParams } from 'react-router-dom'; +import { Link, generatePath, useParams } from 'react-router-dom'; + +import { routerPaths } from '../../../common'; import { useFetcher } from '../../hooks'; import { APIProvider } from '../../apis/api_provider'; + import { ModelOverviewCard } from './model_overview_card'; import { ModelVersionsPanel } from './model_versions_panel'; import { ModelDetailsPanel } from './model_details_panel'; @@ -23,7 +26,7 @@ import { ModelTagsPanel } from './model_tags_panel'; export const Model = () => { const { id: modelId } = useParams<{ id: string }>(); - const { data, loading, error } = useFetcher(APIProvider.getAPI('model').getOne, modelId); + const { data, loading, error } = useFetcher(APIProvider.getAPI('modelGroup').getOne, modelId); const tabs = useMemo( () => [ { @@ -32,7 +35,7 @@ export const Model = () => { content: ( <> - + ), }, @@ -86,17 +89,19 @@ export const Model = () => { } rightSideItems={[ - Register version, + + Register version + , Delete, ]} /> ({ + __esModule: true, + ...jest.requireActual('@elastic/eui'), +})); + +const mockEuiDataGrid = () => + jest.spyOn(euiExports, 'EuiDataGrid').mockImplementation(() =>
      EuiDataGrid
      ); + describe('', () => { it( 'should render version count, refresh button, filter and table by default', async () => { - render(); + render(); expect( screen.getByPlaceholderText('Search by version number, or keyword') @@ -38,7 +47,7 @@ describe('', () => { await waitFor(() => { expect( screen.getByText((text, node) => { - return text === 'Versions' && !!node?.childNodes[1]?.textContent?.includes('(1)'); + return text === 'Versions' && !!node?.childNodes[1]?.textContent?.includes('(2)'); }) ).toBeInTheDocument(); }); @@ -55,6 +64,7 @@ describe('', () => { it( 'should call model search API again after refresh button clicked', async () => { + const euiDataGridMock = mockEuiDataGrid(); const searchMock = jest.spyOn(Model.prototype, 'search').mockImplementation(async () => { return { data: [], @@ -62,7 +72,7 @@ describe('', () => { }; }); - render(); + render(); expect(searchMock).toHaveBeenCalledTimes(1); @@ -70,6 +80,7 @@ describe('', () => { expect(searchMock).toHaveBeenCalledTimes(2); searchMock.mockRestore(); + euiDataGridMock.mockRestore(); }, 10 * 1000 ); @@ -77,9 +88,10 @@ describe('', () => { it( 'should call model search with consistent state parameters after deployed state filter applied', async () => { + const euiDataGridMock = mockEuiDataGrid(); const searchMock = jest.spyOn(Model.prototype, 'search'); - render(); + render(); await userEvent.click(screen.getByTitle('State')); await userEvent.click(screen.getByRole('option', { name: 'Deployed' })); @@ -105,22 +117,26 @@ describe('', () => { }); searchMock.mockRestore(); + euiDataGridMock.mockRestore(); }, 10 * 10000 ); it('should render loading screen when calling model search API', async () => { + const euiDataGridMock = mockEuiDataGrid(); const searchMock = jest .spyOn(Model.prototype, 'search') .mockImplementation(() => new Promise(() => {})); - render(); + render(); expect(screen.getByText('Loading versions')).toBeInTheDocument(); searchMock.mockRestore(); + euiDataGridMock.mockRestore(); }); it('should render error screen and show error toast after call model search failed', async () => { + const euiDataGridMock = mockEuiDataGrid(); const searchMock = jest.spyOn(Model.prototype, 'search').mockImplementation(async () => { throw new Error(); }); @@ -132,7 +148,7 @@ describe('', () => { }, }, }); - render(); + render(); await waitFor(() => { expect(screen.getByText('Failed to load versions')).toBeInTheDocument(); @@ -145,9 +161,11 @@ describe('', () => { searchMock.mockRestore(); pluginMock.mockRestore(); + euiDataGridMock.mockRestore(); }); it('should render empty screen if model no versions', async () => { + const euiDataGridMock = mockEuiDataGrid(); const searchMock = jest.spyOn(Model.prototype, 'search').mockImplementation(async () => { return { data: [], @@ -155,7 +173,7 @@ describe('', () => { }; }); - render(); + render(); await waitFor(() => { expect(screen.getByText('Registered versions will appear here.')).toBeInTheDocument(); @@ -168,12 +186,14 @@ describe('', () => { }); searchMock.mockRestore(); + euiDataGridMock.mockRestore(); }); it( 'should render no-result screen and reset search button if no result for specific condition', async () => { - render(); + const euiDataGridMock = mockEuiDataGrid(); + render(); await waitFor(() => { expect(screen.getByTitle('Status')).toBeInTheDocument(); }); @@ -195,6 +215,7 @@ describe('', () => { expect(screen.getByText('Reset search criteria')).toBeInTheDocument(); searchMock.mockRestore(); + euiDataGridMock.mockRestore(); }, 10 * 1000 ); @@ -202,7 +223,8 @@ describe('', () => { it( 'should call model search without filter condition after reset button clicked', async () => { - render(); + const euiDataGridMock = mockEuiDataGrid(); + render(); await waitFor(() => { expect(screen.getByTitle('Status')).toBeInTheDocument(); }); @@ -222,11 +244,11 @@ describe('', () => { expect(searchMock).toHaveBeenCalledWith({ from: 0, size: 25, - // TODO: Change to model group id once parameter added - ids: expect.any(Array), + modelGroupId: '1', }); searchMock.mockRestore(); + euiDataGridMock.mockRestore(); }, 10 * 1000 ); @@ -234,7 +256,7 @@ describe('', () => { it( 'should only sort by Last updated column after column sort button clicked', async () => { - render(); + render(); await waitFor(() => { expect(screen.getByTestId('dataGridHeaderCell-version')).toBeInTheDocument(); }); diff --git a/public/components/model/model_versions_panel/model_versions_panel.tsx b/public/components/model/model_versions_panel/model_versions_panel.tsx index 9a576e51..8699e58b 100644 --- a/public/components/model/model_versions_panel/model_versions_panel.tsx +++ b/public/components/model/model_versions_panel/model_versions_panel.tsx @@ -84,10 +84,10 @@ const getSortParam = (sort: Array<{ id: string; direction: 'asc' | 'desc' }>) => }; interface ModelVersionsPanelProps { - groupId: string; + modelId: string; } -export const ModelVersionsPanel = ({ groupId }: ModelVersionsPanelProps) => { +export const ModelVersionsPanel = ({ modelId }: ModelVersionsPanelProps) => { const searchInputRef = useRef(null); const [params, setParams] = useState<{ pageIndex: number; @@ -107,8 +107,7 @@ export const ModelVersionsPanel = ({ groupId }: ModelVersionsPanelProps) => { const { data: versionsData, reload, loading, error } = useFetcher( APIProvider.getAPI('model').search, { - // TODO: Change to model group id - ids: [groupId], + modelGroupId: modelId, from: params.pageIndex * params.pageSize, size: params.pageSize, states: getStatesParam(params.filter), @@ -314,7 +313,7 @@ export const ModelVersionsPanel = ({ groupId }: ModelVersionsPanelProps) => { Registered versions will appear here. - + Register new version diff --git a/public/components/register_model/__tests__/register_model_api.test.ts b/public/components/register_model/__tests__/register_model_api.test.ts index a8682c45..f2dc7196 100644 --- a/public/components/register_model/__tests__/register_model_api.test.ts +++ b/public/components/register_model/__tests__/register_model_api.test.ts @@ -11,7 +11,7 @@ describe('register model api', () => { beforeEach(() => { jest .spyOn(ModelGroup.prototype, 'register') - .mockResolvedValue({ model_group_id: 'foo', status: 'success' }); + .mockResolvedValue({ model_group_id: '1', status: 'CREATED' }); jest.spyOn(ModelGroup.prototype, 'delete').mockResolvedValue({ status: 'success' }); jest.spyOn(Model.prototype, 'upload').mockResolvedValue({ task_id: 'foo', model_id: 'bar' }); }); @@ -94,10 +94,25 @@ describe('register model api', () => { } catch (error) { expect(uploadError).toBe(error); } - expect(ModelGroup.prototype.delete).toHaveBeenCalledWith('foo'); + expect(ModelGroup.prototype.delete).toHaveBeenCalledWith('1'); uploadMock.mockRestore(); }); + + it('should return task id and model id after submit successful', async () => { + const result = await submitModelWithFile({ + name: 'foo', + description: 'bar', + configuration: '{}', + modelFileFormat: '', + modelFile: new File([], 'artifact.zip'), + }); + + expect(result).toEqual({ + modelId: '1', + modelVersionId: 'bar', + }); + }); }); describe('submitModelWithURL', () => { @@ -136,9 +151,24 @@ describe('register model api', () => { } catch (error) { expect(uploadError).toBe(error); } - expect(ModelGroup.prototype.delete).toHaveBeenCalledWith('foo'); + expect(ModelGroup.prototype.delete).toHaveBeenCalledWith('1'); uploadMock.mockRestore(); }); + + it('should return model id and model version id after submit successful', async () => { + const result = await submitModelWithURL({ + name: 'foo', + description: 'bar', + configuration: '{}', + modelFileFormat: '', + modelURL: 'https://address.to/artifact.zip', + }); + + expect(result).toEqual({ + modelId: '1', + taskId: 'foo', + }); + }); }); }); diff --git a/public/components/register_model/__tests__/register_model_form.test.tsx b/public/components/register_model/__tests__/register_model_form.test.tsx index 5763d851..67496418 100644 --- a/public/components/register_model/__tests__/register_model_form.test.tsx +++ b/public/components/register_model/__tests__/register_model_form.test.tsx @@ -25,7 +25,9 @@ describe(' Form', () => { const MOCKED_MODEL_ID = 'model_id'; const addDangerMock = jest.fn(); const addSuccessMock = jest.fn(); - const onSubmitMock = jest.fn().mockResolvedValue(MOCKED_MODEL_ID); + const onSubmitMock = jest + .fn() + .mockResolvedValue({ modelId: MOCKED_MODEL_ID, modelVersionId: 'model_version_id' }); beforeEach(() => { jest.spyOn(PluginContext, 'useOpenSearchDashboards').mockReturnValue({ diff --git a/public/components/register_model/register_model.tsx b/public/components/register_model/register_model.tsx index 22992d12..1b1ba5a9 100644 --- a/public/components/register_model/register_model.tsx +++ b/public/components/register_model/register_model.tsx @@ -117,10 +117,7 @@ export const RegisterModelForm = ({ defaultValues = DEFAULT_VALUES }: RegisterMo const onSubmit = useCallback( async (data: ModelFileFormData | ModelUrlFormData) => { try { - const onComplete = (modelId: string) => { - // Navigate to model group page - history.push(generatePath(routerPaths.model, { id: modelId })); - + const onComplete = () => { notifications?.toasts.addSuccess({ title: mountReactNode( @@ -143,23 +140,29 @@ export const RegisterModelForm = ({ defaultValues = DEFAULT_VALUES }: RegisterMo text: 'The new version was not created.', }); }; - + let modelId; if ('modelFile' in data) { - const modelId = await submitModelWithFile(data); + const result = await submitModelWithFile(data); modelFileUploadManager.upload({ file: data.modelFile, - modelId, + modelId: result.modelVersionId, chunkSize: MAX_CHUNK_SIZE, onComplete, onError, }); + modelId = result.modelId; } else { - const taskId = await submitModelWithURL(data); - modelTaskManager.query({ taskId, onComplete, onError }); + const result = await submitModelWithURL(data); + modelTaskManager.query({ + taskId: result.taskId, + onComplete, + onError, + }); + modelId = result.modelId; } // Navigate to model list if form submit successfully - history.push(routerPaths.modelList); + history.push(generatePath(routerPaths.model, { id: modelId })); if (data.modelId) { notifications?.toasts.addSuccess({ diff --git a/public/components/register_model/register_model_api.ts b/public/components/register_model/register_model_api.ts index 65b0b44d..d58e3f25 100644 --- a/public/components/register_model/register_model_api.ts +++ b/public/components/register_model/register_model_api.ts @@ -30,9 +30,12 @@ const createModelIfNeedAndUploadVersion = async ({ modelId?: string; description?: string; uploader: (modelId: string) => Promise; -}) => { +}): Promise<{ uploadResult: T; modelId: string }> => { if (modelId) { - return await uploader(modelId); + return { + uploadResult: await uploader(modelId), + modelId, + }; } modelId = ( await APIProvider.getAPI('modelGroup').register({ @@ -44,7 +47,10 @@ const createModelIfNeedAndUploadVersion = async ({ ).model_group_id; try { - return await uploader(modelId); + return { + uploadResult: await uploader(modelId), + modelId, + }; } catch (error) { APIProvider.getAPI('modelGroup').delete(modelId); throw error; @@ -55,31 +61,36 @@ export async function submitModelWithFile(model: ModelFileFormData) { const { modelFile } = model; const totalChunks = Math.ceil(modelFile.size / MAX_CHUNK_SIZE); const modelContentHashValue = await getModelContentHashValue(modelFile); + const result = await createModelIfNeedAndUploadVersion({ + ...model, + uploader: (modelId: string) => + APIProvider.getAPI('model').upload({ + ...getModelUploadBase(model), + modelGroupId: modelId, + totalChunks, + modelContentHashValue, + }), + }); - return ( - await createModelIfNeedAndUploadVersion({ - ...model, - uploader: (modelId: string) => - APIProvider.getAPI('model').upload({ - ...getModelUploadBase(model), - modelGroupId: modelId, - totalChunks, - modelContentHashValue, - }), - }) - ).model_id; + return { + modelId: result.modelId, + modelVersionId: result.uploadResult.model_id, + }; } export async function submitModelWithURL(model: ModelUrlFormData) { - return ( - await createModelIfNeedAndUploadVersion({ - ...model, - uploader: (modelId: string) => - APIProvider.getAPI('model').upload({ - ...getModelUploadBase(model), - modelGroupId: modelId, - url: model.modelURL, - }), - }) - ).task_id; + const result = await createModelIfNeedAndUploadVersion({ + ...model, + uploader: (modelId: string) => + APIProvider.getAPI('model').upload({ + ...getModelUploadBase(model), + modelGroupId: modelId, + url: model.modelURL, + }), + }); + + return { + modelId: result.modelId, + taskId: result.uploadResult.task_id, + }; } diff --git a/server/routes/model_router.ts b/server/routes/model_router.ts index 65609cd4..996f3c48 100644 --- a/server/routes/model_router.ts +++ b/server/routes/model_router.ts @@ -96,6 +96,7 @@ export const modelRouter = (services: { modelService: ModelService }, router: IR versionOrKeyword: schema.maybe(schema.string()), extra_query: schema.maybe(schema.recordOf(schema.string(), schema.any())), data_source_id: schema.maybe(schema.string()), + modelGroupId: schema.maybe(schema.string()), }), }, }, @@ -109,6 +110,7 @@ export const modelRouter = (services: { modelService: ModelService }, router: IR nameOrId, extra_query: extraQuery, data_source_id: dataSourceId, + modelGroupId, versionOrKeyword, } = request.query; try { @@ -124,6 +126,7 @@ export const modelRouter = (services: { modelService: ModelService }, router: IR states: typeof states === 'string' ? [states] : states, nameOrId, extraQuery, + modelGroupId, versionOrKeyword, }); return response.ok({ body: payload }); diff --git a/server/services/model_service.ts b/server/services/model_service.ts index 2df551f0..3793e1e0 100644 --- a/server/services/model_service.ts +++ b/server/services/model_service.ts @@ -91,6 +91,7 @@ export class ModelService { extraQuery?: Record; nameOrId?: string; versionOrKeyword?: string; + modelGroupId?: string; }) { const { body: { hits }, diff --git a/server/services/utils/model.ts b/server/services/utils/model.ts index c53fd7f5..ef280d7d 100644 --- a/server/services/utils/model.ts +++ b/server/services/utils/model.ts @@ -27,6 +27,7 @@ export const generateModelSearchQuery = ({ states, nameOrId, extraQuery, + modelGroupId, versionOrKeyword, }: { ids?: string[]; @@ -36,6 +37,7 @@ export const generateModelSearchQuery = ({ nameOrId?: string; extraQuery?: Record; versionOrKeyword?: string; + modelGroupId?: string; }) => ({ bool: { must: [ @@ -86,6 +88,7 @@ export const generateModelSearchQuery = ({ }, ] : []), + ...(modelGroupId ? [generateTermQuery('model_group_id.keyword', modelGroupId)] : []), ], must_not: { exists: { diff --git a/test/mocks/model_group_handlers.ts b/test/mocks/model_group_handlers.ts index b0e273ef..2b26dff2 100644 --- a/test/mocks/model_group_handlers.ts +++ b/test/mocks/model_group_handlers.ts @@ -13,6 +13,13 @@ const modelGroups = [ id: '1', latest_version: 1, description: 'foo bar', + owner: { + backend_roles: ['admin'], + name: 'admin', + roles: ['admin'], + }, + created_time: 1683699499637, + last_updated_time: 1685073391256, }, ]; @@ -42,4 +49,13 @@ export const modelGroupHandlers = [ }) ); }), + + rest.post(MODEL_GROUP_API_ENDPOINT, (req, res, ctx) => { + return res( + ctx.status(200), + ctx.json({ + model_group_id: '1', + }) + ); + }), ]; diff --git a/test/mocks/model_handlers.ts b/test/mocks/model_handlers.ts index 8093ee67..244f6265 100644 --- a/test/mocks/model_handlers.ts +++ b/test/mocks/model_handlers.ts @@ -25,6 +25,7 @@ const models = [ model_format: 'TORCH_SCRIPT', model_state: 'REGISTERED', total_chunks: 34, + model_group_id: '1', }, { id: '2', @@ -43,6 +44,7 @@ const models = [ model_format: 'TORCH_SCRIPT', model_state: 'REGISTERED', total_chunks: 34, + model_group_id: '2', }, { id: '3', @@ -61,6 +63,7 @@ const models = [ model_format: 'TORCH_SCRIPT', model_state: 'DEPLOYED', total_chunks: 34, + model_group_id: '3', }, { id: '4', @@ -79,6 +82,7 @@ const models = [ model_format: 'TORCH_SCRIPT', model_state: 'DEPLOYED', total_chunks: 34, + model_group_id: '1', }, ]; @@ -87,13 +91,17 @@ export const modelHandlers = [ const { searchParams } = req.url; const name = searchParams.get('name'); const ids = searchParams.getAll('ids'); + const modelGroupId = searchParams.get('modelGroupId'); const data = models.filter((model) => { if (name) { return model.name === name; } - if (ids) { + if (ids && ids.length > 0) { return ids.includes(model.id); } + if (modelGroupId) { + return model.model_group_id === modelGroupId; + } return true; }); return res( From c0831d56adbe4f7a8488a5c180bf1d1b32ebd348 Mon Sep 17 00:00:00 2001 From: Lin Wang Date: Tue, 30 May 2023 10:20:44 +0800 Subject: [PATCH 55/75] Feature/update model list with model group (#197) * feat: update model aggregate service according new design Signed-off-by: Lin Wang * feat: add link to the model version page Signed-off-by: Lin Wang * feat: remove model to model version and remove legacy opensearch client usage Signed-off-by: Lin Wang * refactor: update MODEL_STATE to MODEL_VERSION_STATE Signed-off-by: Lin Wang * update model-version API endpoint address Signed-off-by: Lin Wang * refactor: rename model group to model Signed-off-by: Lin Wang * refactor: replace queryString with extraQuery Signed-off-by: Lin Wang * feat: rename modelId to id in model version Signed-off-by: Lin Wang * fix: model group id not exists in model register API Signed-off-by: Lin Wang * feat: update GET to POST for model search Signed-off-by: Lin Wang --------- Signed-off-by: Lin Wang --- common/index.ts | 4 +- common/model.ts | 66 ++-- common/model_aggregate.ts | 27 ++ common/model_version.ts | 17 + common/profile.ts | 4 +- public/apis/__mocks__/model.ts | 31 -- public/apis/api_provider.ts | 28 +- public/apis/model.ts | 190 +++-------- public/apis/model_aggregate.ts | 23 +- public/apis/model_group.ts | 72 ---- public/apis/model_version.ts | 177 ++++++++++ .../__tests__/global_breadcrumbs.test.tsx | 6 +- .../common/forms/model_name_field.tsx | 4 +- ..._version_deployment_confirm_modal.test.tsx | 20 +- ...model_version_deployment_confirm_modal.tsx | 4 +- public/components/global_breadcrumbs.tsx | 6 +- public/components/model/model.tsx | 2 +- .../__tests__/model_version_cell.test.tsx | 8 +- .../model_version_status_cell.test.tsx | 20 +- .../model_version_status_detail.test.tsx | 12 +- .../__tests__/model_version_table.test.tsx | 4 +- .../model_version_table_row_actions.test.tsx | 33 +- .../__tests__/model_versions_panel.test.tsx | 58 ++-- .../model_version_cell.tsx | 14 +- .../model_version_status_cell.tsx | 22 +- .../model_version_status_detail.tsx | 35 +- .../model_version_table_row_actions.tsx | 15 +- .../model_versions_panel.tsx | 45 +-- public/components/model/types.ts | 4 +- public/components/model_drawer/index.tsx | 4 +- .../components/model_drawer/version_table.tsx | 12 +- .../model_list/__tests__/model_list.test.tsx | 88 +++-- .../model_list/__tests__/model_table.test.tsx | 67 ++-- .../model_table_uploading_cell.test.tsx | 16 +- public/components/model_list/index.tsx | 128 ++++--- .../model_list/model_confirm_delete_modal.tsx | 6 +- public/components/model_list/model_table.tsx | 109 ++---- .../model_list/model_table_uploading_cell.tsx | 6 +- .../model_version/model_version.tsx | 12 +- .../model_version/version_callout.tsx | 24 +- .../model_version/version_toggler.tsx | 2 +- .../monitoring/model_deployment_table.tsx | 4 +- .../components/monitoring/use_monitoring.ts | 20 +- .../model_file_uploader_manager.test.ts | 6 +- .../__tests__/register_model_api.test.ts | 48 +-- .../__tests__/register_model_form.test.tsx | 9 +- .../model_file_upload_manager.ts | 2 +- .../register_model/register_model.tsx | 2 +- .../register_model/register_model_api.ts | 16 +- server/clusters/create_model_cluster.ts | 15 - server/clusters/model_plugin.ts | 90 ----- server/plugin.ts | 16 +- server/routes/constants.ts | 13 +- server/routes/index.ts | 4 +- server/routes/model_aggregate_router.ts | 30 +- server/routes/model_group_router.ts | 139 -------- server/routes/model_router.ts | 316 +++++------------- server/routes/model_version_router.ts | 311 +++++++++++++++++ server/services/index.ts | 3 +- server/services/model_aggregate_service.ts | 191 ++++------- server/services/model_group_service.ts | 132 -------- server/services/model_service.ts | 283 ++++++---------- server/services/model_version_service.ts | 240 +++++++++++++ server/services/utils/model.ts | 12 +- test/mocks/data/model_aggregate.ts | 10 +- test/mocks/handlers.ts | 8 +- test/mocks/model_group_handlers.ts | 61 ---- test/mocks/model_handlers.ts | 111 ++---- test/mocks/model_version_handlers.ts | 123 +++++++ 69 files changed, 1790 insertions(+), 1850 deletions(-) create mode 100644 common/model_aggregate.ts create mode 100644 common/model_version.ts delete mode 100644 public/apis/__mocks__/model.ts delete mode 100644 public/apis/model_group.ts create mode 100644 public/apis/model_version.ts delete mode 100644 server/clusters/create_model_cluster.ts delete mode 100644 server/clusters/model_plugin.ts delete mode 100644 server/routes/model_group_router.ts create mode 100644 server/routes/model_version_router.ts delete mode 100644 server/services/model_group_service.ts create mode 100644 server/services/model_version_service.ts delete mode 100644 test/mocks/model_group_handlers.ts create mode 100644 test/mocks/model_version_handlers.ts diff --git a/common/index.ts b/common/index.ts index c06353fe..b852f396 100644 --- a/common/index.ts +++ b/common/index.ts @@ -9,5 +9,7 @@ export const PLUGIN_DESC = `ML Commons for OpenSearch eases the development of m export * from './constant'; export * from './status'; -export * from './model'; +export * from './model_version'; export * from './router_paths'; +export * from './model'; +export * from './model_aggregate'; diff --git a/common/model.ts b/common/model.ts index eba23bed..be717653 100644 --- a/common/model.ts +++ b/common/model.ts @@ -3,49 +3,29 @@ * SPDX-License-Identifier: Apache-2.0 */ -// TODO: rename the enum keys accordingly -export enum MODEL_STATE { - loaded = 'DEPLOYED', - trained = 'TRAINED', - unloaded = 'UNDEPLOYED', - uploaded = 'REGISTERED', - uploading = 'REGISTERING', - loading = 'DEPLOYING', - partiallyLoaded = 'PARTIALLY_DEPLOYED', - loadFailed = 'DEPLOY_FAILED', - registerFailed = 'REGISTER_FAILED', -} - -export interface OpenSearchModelBase { +export interface OpenSearchModel { + id: string; + owner: { + backend_roles: string[]; + roles: string[]; + name: string; + }; + latest_version: number; + created_time: number; + last_updated_time: number; name: string; - model_id: string; - model_state: MODEL_STATE; - model_version: string; + description?: string; + access: 'public' | 'restricted' | 'private'; } -export interface OpenSearchSelfTrainedModel extends OpenSearchModelBase { - algorithm: string; -} - -export interface OpenSearchCustomerModel extends OpenSearchModelBase { - algorithm: string; - chunk_number: number; - created_time: number; - description: string; - last_loaded_time?: number; - last_unloaded_time?: number; - last_uploaded_time: number; - model_config: { - all_config?: string; - embedding_dimension: number; - framework_type: string; - model_type: string; - }; - model_content: string; - model_content_hash_value: string; - model_content_size_in_bytes: string; - model_format: string; - total_chunks: number; - version: number; - planning_worker_nodes: string[]; -} +export type ModelSort = + | 'name-asc' + | 'name-desc' + | 'latest_version-asc' + | 'latest_version-desc' + | 'description-asc' + | 'description-desc' + | 'owner.name-asc' + | 'owner.name-desc' + | 'last_updated_time-asc' + | 'last_updated_time-desc'; diff --git a/common/model_aggregate.ts b/common/model_aggregate.ts new file mode 100644 index 00000000..ce1d9148 --- /dev/null +++ b/common/model_aggregate.ts @@ -0,0 +1,27 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +export interface ModelAggregateItem { + id: string; + name: string; + description?: string; + latest_version: number; + deployed_versions: string[]; + owner_name: string; + created_time?: number; + last_updated_time: number; +} + +export type ModelAggregateSort = + | 'name-asc' + | 'name-desc' + | 'latest_version-asc' + | 'latest_version-desc' + | 'description-asc' + | 'description-desc' + | 'owner_name-asc' + | 'owner_name-desc' + | 'last_updated_time-asc' + | 'last_updated_time-desc'; diff --git a/common/model_version.ts b/common/model_version.ts new file mode 100644 index 00000000..2fedbe30 --- /dev/null +++ b/common/model_version.ts @@ -0,0 +1,17 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +// TODO: rename the enum keys accordingly +export enum MODEL_VERSION_STATE { + deployed = 'DEPLOYED', + trained = 'TRAINED', + undeployed = 'UNDEPLOYED', + registered = 'REGISTERED', + registering = 'REGISTERING', + deploying = 'DEPLOYING', + partiallyDeployed = 'PARTIALLY_DEPLOYED', + deployFailed = 'DEPLOY_FAILED', + registerFailed = 'REGISTER_FAILED', +} diff --git a/common/profile.ts b/common/profile.ts index 1585cc5f..fdfc1803 100644 --- a/common/profile.ts +++ b/common/profile.ts @@ -3,7 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { MODEL_STATE } from './model'; +import { MODEL_VERSION_STATE } from './model'; export interface OpenSearchMLCommonsProfile { nodes: { @@ -12,7 +12,7 @@ export interface OpenSearchMLCommonsProfile { [key: string]: { worker_nodes: string[]; predictor: string; - model_state: MODEL_STATE; + model_state: MODEL_VERSION_STATE; predict_request_stats: { count: number; max: number; diff --git a/public/apis/__mocks__/model.ts b/public/apis/__mocks__/model.ts deleted file mode 100644 index 1435b938..00000000 --- a/public/apis/__mocks__/model.ts +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -export class Model { - public search() { - return Promise.resolve({ - data: [ - { - id: 'model-1-id', - name: 'model-1-name', - current_worker_node_count: 1, - planning_worker_node_count: 3, - planning_worker_nodes: ['node1', 'node2', 'node3'], - }, - ], - total_models: 1, - }); - } - - public upload({ url }: { url?: string }) { - return Promise.resolve( - url === undefined ? { model_id: 'model-id-1' } : { task_id: 'task-id-1' } - ); - } - - public uploadChunk() { - return Promise.resolve(); - } -} diff --git a/public/apis/api_provider.ts b/public/apis/api_provider.ts index ebbc6465..9f37044a 100644 --- a/public/apis/api_provider.ts +++ b/public/apis/api_provider.ts @@ -4,51 +4,51 @@ */ import { Connector } from './connector'; -import { Model } from './model'; +import { ModelVersion } from './model_version'; import { ModelAggregate } from './model_aggregate'; -import { ModelGroup } from './model_group'; +import { Model } from './model'; import { ModelRepository } from './model_repository'; import { Profile } from './profile'; import { Security } from './security'; import { Task } from './task'; const apiInstanceStore: { - model: Model | undefined; + modelVersion: ModelVersion | undefined; modelAggregate: ModelAggregate | undefined; profile: Profile | undefined; connector: Connector | undefined; security: Security | undefined; task: Task | undefined; modelRepository: ModelRepository | undefined; - modelGroup: ModelGroup | undefined; + model: Model | undefined; } = { - model: undefined, + modelVersion: undefined, modelAggregate: undefined, profile: undefined, connector: undefined, security: undefined, task: undefined, modelRepository: undefined, - modelGroup: undefined, + model: undefined, }; export class APIProvider { public static getAPI(type: 'task'): Task; - public static getAPI(type: 'model'): Model; + public static getAPI(type: 'modelVersion'): ModelVersion; public static getAPI(type: 'modelAggregate'): ModelAggregate; public static getAPI(type: 'profile'): Profile; public static getAPI(type: 'connector'): Connector; public static getAPI(type: 'security'): Security; public static getAPI(type: 'modelRepository'): ModelRepository; - public static getAPI(type: 'modelGroup'): ModelGroup; + public static getAPI(type: 'model'): Model; public static getAPI(type: keyof typeof apiInstanceStore) { if (apiInstanceStore[type]) { return apiInstanceStore[type]!; } switch (type) { - case 'model': { - const newInstance = new Model(); - apiInstanceStore.model = newInstance; + case 'modelVersion': { + const newInstance = new ModelVersion(); + apiInstanceStore.modelVersion = newInstance; return newInstance; } case 'modelAggregate': { @@ -81,9 +81,9 @@ export class APIProvider { apiInstanceStore.modelRepository = newInstance; return newInstance; } - case 'modelGroup': { - const newInstance = new ModelGroup(); - apiInstanceStore.modelGroup = newInstance; + case 'model': { + const newInstance = new Model(); + apiInstanceStore.model = newInstance; return newInstance; } } diff --git a/public/apis/model.ts b/public/apis/model.ts index 28f2e83a..8ff6b635 100644 --- a/public/apis/model.ts +++ b/public/apis/model.ts @@ -3,172 +3,64 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { MODEL_STATE } from '../../common'; -import { - MODEL_API_ENDPOINT, - MODEL_LOAD_API_ENDPOINT, - MODEL_UNLOAD_API_ENDPOINT, - MODEL_UPLOAD_API_ENDPOINT, - MODEL_PROFILE_API_ENDPOINT, -} from '../../server/routes/constants'; +import { OpenSearchModel } from '../../common'; +import { MODEL_API_ENDPOINT } from '../../server/routes/constants'; import { InnerHttpProvider } from './inner_http_provider'; -export interface ModelSearchItem { - id: string; - name: string; - // TODO: the new version details API may not have this field, because model description is on model group level - // we should fix this when integrating the new API changes - description: string; - algorithm: string; - model_state: MODEL_STATE; - model_version: string; - current_worker_node_count: number; - planning_worker_node_count: number; - planning_worker_nodes: string[]; - connector_id?: string; - connector?: { - name: string; - description?: string; - }; - model_config?: { - all_config?: string; - embedding_dimension: number; - framework_type: string; - model_type: string; - }; - last_updated_time: number; - created_time: number; - last_registered_time?: number; - last_deployed_time?: number; - last_undeployed_time?: number; -} - -export interface ModelDetail extends ModelSearchItem { - content: string; - last_updated_time: number; - created_time: number; - model_format: string; -} - export interface ModelSearchResponse { - data: ModelSearchItem[]; + data: OpenSearchModel[]; total_models: number; } -export interface ModelLoadResponse { - task_id: string; - status: string; -} - -export interface ModelUnloadResponse { - [nodeId: string]: { - stats: { - [modelId: string]: string; - }; - }; -} - -export interface ModelProfileResponse { - nodes: { - [nodeId: string]: { - models: { - [modelId: string]: { - model_state: string; - predictor: string; - worker_nodes: string[]; - }; - }; - }; - }; -} - -interface UploadModelBase { - name: string; - version?: string; - description?: string; - modelFormat: string; - modelConfig: Record; - modelGroupId: string; -} +export class Model { + public register(body: { + name: string; + description?: string; + modelAccessMode: 'public' | 'restricted' | 'private'; + backendRoles?: string[]; + addAllBackendRoles?: boolean; + }) { + return InnerHttpProvider.getHttp().post<{ model_id: string; status: string }>( + MODEL_API_ENDPOINT, + { + body: JSON.stringify(body), + } + ); + } -export interface UploadModelByURL extends UploadModelBase { - url: string; -} + public update({ id, name, description }: { id: string; name?: string; description?: string }) { + return InnerHttpProvider.getHttp().put<{ status: 'success' }>(`${MODEL_API_ENDPOINT}/${id}`, { + body: JSON.stringify({ + name, + description, + }), + }); + } -export interface UploadModelByChunk extends UploadModelBase { - modelContentHashValue: string; - totalChunks: number; -} + public delete(id: string) { + return InnerHttpProvider.getHttp().delete<{ status: 'success' }>(`${MODEL_API_ENDPOINT}/${id}`); + } -export class Model { public search(query: { - algorithms?: string[]; ids?: string[]; - sort?: string[]; name?: string; from: number; size: number; - states?: MODEL_STATE[]; - nameOrId?: string; - extraQuery?: Record; - dataSourceId?: string; - versionOrKeyword?: string; - modelGroupId?: string; + extraQuery?: string; }) { - const { extraQuery, dataSourceId, ...restQuery } = query; + const { extraQuery, ...restQuery } = query; return InnerHttpProvider.getHttp().get(MODEL_API_ENDPOINT, { - query: extraQuery - ? { ...restQuery, extra_query: JSON.stringify(extraQuery), data_source_id: dataSourceId } - : { ...restQuery, data_source_id: dataSourceId }, - }); - } - - public delete(modelId: string) { - return InnerHttpProvider.getHttp().delete(`${MODEL_API_ENDPOINT}/${modelId}`); - } - - public getOne(modelId: string) { - return InnerHttpProvider.getHttp().get(`${MODEL_API_ENDPOINT}/${modelId}`); - } - - public load(modelId: string) { - return InnerHttpProvider.getHttp().post( - `${MODEL_LOAD_API_ENDPOINT}/${modelId}` - ); - } - - public unload(modelId: string) { - return InnerHttpProvider.getHttp().post( - `${MODEL_UNLOAD_API_ENDPOINT}/${modelId}` - ); - } - - public profile(modelId: string) { - return InnerHttpProvider.getHttp().get( - `${MODEL_PROFILE_API_ENDPOINT}/${modelId}` - ); - } - - public upload( - model: T - ): Promise< - T extends UploadModelByURL - ? { task_id: string } - : T extends UploadModelByChunk - ? { model_id: string } - : never - > { - return InnerHttpProvider.getHttp().post(MODEL_UPLOAD_API_ENDPOINT, { - body: JSON.stringify(model), + query: extraQuery ? { ...restQuery, extra_query: JSON.stringify(extraQuery) } : restQuery, }); } - public uploadChunk(modelId: string, chunkId: string, chunkContent: Blob) { - return InnerHttpProvider.getHttp().post(`${MODEL_API_ENDPOINT}/${modelId}/chunk/${chunkId}`, { - body: chunkContent, - headers: { - 'Content-Type': 'application/octet-stream', - }, - }); - } + public getOne = async (id: string) => { + return ( + await this.search({ + ids: [id], + from: 0, + size: 1, + }) + ).data[0]; + }; } diff --git a/public/apis/model_aggregate.ts b/public/apis/model_aggregate.ts index 24ca58c5..2c53a71f 100644 --- a/public/apis/model_aggregate.ts +++ b/public/apis/model_aggregate.ts @@ -3,32 +3,23 @@ * SPDX-License-Identifier: Apache-2.0 */ +import { MODEL_VERSION_STATE, ModelAggregateItem, ModelAggregateSort } from '../../common'; import { MODEL_AGGREGATE_API_ENDPOINT } from '../../server/routes/constants'; -import { InnerHttpProvider } from './inner_http_provider'; -import { MODEL_STATE } from '../../common/model'; -export interface ModelAggregateSearchItem { - name: string; - description?: string; - latest_version: string; - latest_version_state: MODEL_STATE; - deployed_versions: string[]; - owner: string; - created_time?: number; -} +import { InnerHttpProvider } from './inner_http_provider'; interface ModelAggregateSearchResponse { - data: ModelAggregateSearchItem[]; + data: ModelAggregateItem[]; total_models: number; } export class ModelAggregate { public search(query: { - size: number; from: number; - sort: 'created_time'; - order: 'desc' | 'asc'; - name?: string; + size: number; + sort?: ModelAggregateSort; + states?: MODEL_VERSION_STATE[]; + extraQuery?: string; }) { return InnerHttpProvider.getHttp().get( MODEL_AGGREGATE_API_ENDPOINT, diff --git a/public/apis/model_group.ts b/public/apis/model_group.ts deleted file mode 100644 index dd9f7d97..00000000 --- a/public/apis/model_group.ts +++ /dev/null @@ -1,72 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -import { MODEL_GROUP_API_ENDPOINT } from '../../server/routes/constants'; -import { InnerHttpProvider } from './inner_http_provider'; - -interface ModelGroupSearchItem { - id: string; - owner: { - backend_roles: string[]; - roles: string[]; - name: string; - }; - latest_version: number; - created_time: number; - last_updated_time: number; - name: string; - description?: string; - access: 'public' | 'restricted' | 'private'; -} - -export interface ModelGroupSearchResponse { - data: ModelGroupSearchItem[]; - total_model_groups: number; -} - -export class ModelGroup { - public register(body: { - name: string; - description?: string; - modelAccessMode: 'public' | 'restricted' | 'private'; - backendRoles?: string[]; - addAllBackendRoles?: boolean; - }) { - return InnerHttpProvider.getHttp().post<{ model_group_id: string; status: 'CREATED' }>( - MODEL_GROUP_API_ENDPOINT, - { - body: JSON.stringify(body), - } - ); - } - - public update({ id, name, description }: { id: string; name?: string; description?: string }) { - return InnerHttpProvider.getHttp().put<{ status: 'success' }>( - `${MODEL_GROUP_API_ENDPOINT}/${id}`, - { - body: JSON.stringify({ - name, - description, - }), - } - ); - } - - public delete(id: string) { - return InnerHttpProvider.getHttp().delete<{ status: 'success' }>( - `${MODEL_GROUP_API_ENDPOINT}/${id}` - ); - } - - public search(query: { id?: string; name?: string; from: number; size: number }) { - return InnerHttpProvider.getHttp().get(MODEL_GROUP_API_ENDPOINT, { - query, - }); - } - - public getOne = async (id: string) => { - return (await this.search({ id, from: 0, size: 1 })).data[0]; - }; -} diff --git a/public/apis/model_version.ts b/public/apis/model_version.ts new file mode 100644 index 00000000..20de2070 --- /dev/null +++ b/public/apis/model_version.ts @@ -0,0 +1,177 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import { MODEL_VERSION_STATE } from '../../common'; +import { + MODEL_VERSION_API_ENDPOINT, + MODEL_VERSION_LOAD_API_ENDPOINT, + MODEL_VERSION_UNLOAD_API_ENDPOINT, + MODEL_VERSION_UPLOAD_API_ENDPOINT, + MODEL_VERSION_PROFILE_API_ENDPOINT, +} from '../../server/routes/constants'; +import { InnerHttpProvider } from './inner_http_provider'; + +export interface ModelVersionSearchItem { + id: string; + name: string; + // TODO: the new version details API may not have this field, because model description is on model group level + // we should fix this when integrating the new API changes + description: string; + algorithm: string; + model_state: MODEL_VERSION_STATE; + model_version: string; + current_worker_node_count: number; + planning_worker_node_count: number; + planning_worker_nodes: string[]; + model_config?: { + all_config?: string; + embedding_dimension: number; + framework_type: string; + model_type: string; + }; + last_updated_time: number; + created_time: number; + last_registered_time?: number; + last_deployed_time?: number; + last_undeployed_time?: number; +} + +export interface ModelVersionDetail extends ModelVersionSearchItem { + content: string; + last_updated_time: number; + created_time: number; + model_format: string; +} + +export interface ModelVersionSearchResponse { + data: ModelVersionSearchItem[]; + total_model_versions: number; +} + +export interface ModelVersionLoadResponse { + task_id: string; + status: string; +} + +export interface ModelVersionUnloadResponse { + [nodeId: string]: { + stats: { + [id: string]: string; + }; + }; +} + +export interface ModelVersionProfileResponse { + nodes: { + [nodeId: string]: { + models: { + [id: string]: { + model_state: string; + predictor: string; + worker_nodes: string[]; + }; + }; + }; + }; +} + +interface UploadModelBase { + name: string; + version?: string; + description?: string; + modelFormat: string; + modelConfig: Record; + id: string; +} + +export interface UploadModelByURL extends UploadModelBase { + url: string; +} + +export interface UploadModelByChunk extends UploadModelBase { + modelContentHashValue: string; + totalChunks: number; +} + +export class ModelVersion { + public search({ + extraQuery, + dataSourceId, + ...restQuery + }: { + algorithms?: string[]; + ids?: string[]; + sort?: string[]; + name?: string; + from: number; + size: number; + states?: MODEL_VERSION_STATE[]; + nameOrId?: string; + versionOrKeyword?: string; + modelIds?: string[]; + extraQuery?: Record; + dataSourceId?: string; + }) { + return InnerHttpProvider.getHttp().get(MODEL_VERSION_API_ENDPOINT, { + query: extraQuery + ? { ...restQuery, extra_query: JSON.stringify(extraQuery), data_source_id: dataSourceId } + : { ...restQuery, data_source_id: dataSourceId }, + }); + } + + public delete(id: string) { + return InnerHttpProvider.getHttp().delete(`${MODEL_VERSION_API_ENDPOINT}/${id}`); + } + + public getOne(id: string) { + return InnerHttpProvider.getHttp().get( + `${MODEL_VERSION_API_ENDPOINT}/${id}` + ); + } + + public load(id: string) { + return InnerHttpProvider.getHttp().post( + `${MODEL_VERSION_LOAD_API_ENDPOINT}/${id}` + ); + } + + public unload(id: string) { + return InnerHttpProvider.getHttp().post( + `${MODEL_VERSION_UNLOAD_API_ENDPOINT}/${id}` + ); + } + + public profile(id: string) { + return InnerHttpProvider.getHttp().get( + `${MODEL_VERSION_PROFILE_API_ENDPOINT}/${id}` + ); + } + + public upload( + model: T + ): Promise< + T extends UploadModelByURL + ? { task_id: string } + : T extends UploadModelByChunk + ? { model_version_id: string } + : never + > { + return InnerHttpProvider.getHttp().post(MODEL_VERSION_UPLOAD_API_ENDPOINT, { + body: JSON.stringify(model), + }); + } + + public uploadChunk(id: string, chunkId: string, chunkContent: Blob) { + return InnerHttpProvider.getHttp().post( + `${MODEL_VERSION_API_ENDPOINT}/${id}/chunk/${chunkId}`, + { + body: chunkContent, + headers: { + 'Content-Type': 'application/octet-stream', + }, + } + ); + } +} diff --git a/public/components/__tests__/global_breadcrumbs.test.tsx b/public/components/__tests__/global_breadcrumbs.test.tsx index e2fe5cb8..181c2213 100644 --- a/public/components/__tests__/global_breadcrumbs.test.tsx +++ b/public/components/__tests__/global_breadcrumbs.test.tsx @@ -6,7 +6,7 @@ import React from 'react'; import { GlobalBreadcrumbs } from '../global_breadcrumbs'; import { history, render, waitFor, act } from '../../../test/test_utils'; -import { Model, ModelDetail } from '../../apis/model'; +import { ModelVersion, ModelVersionDetail } from '../../apis/model_version'; describe('', () => { it('should call onBreadcrumbsChange with overview title', () => { @@ -102,7 +102,7 @@ describe('', () => { it('should NOT call onBreadcrumbs with steal breadcrumbs after pathname changed', async () => { jest.useFakeTimers(); const onBreadcrumbsChange = jest.fn(); - const modelGetOneMock = jest.spyOn(Model.prototype, 'getOne').mockImplementation( + const modelGetOneMock = jest.spyOn(ModelVersion.prototype, 'getOne').mockImplementation( (id) => new Promise((resolve) => { setTimeout( @@ -111,7 +111,7 @@ describe('', () => { id, name: `model${id}`, model_version: `1.0.${id}`, - } as ModelDetail); + } as ModelVersionDetail); }, id === '2' ? 1000 : 0 ); diff --git a/public/components/common/forms/model_name_field.tsx b/public/components/common/forms/model_name_field.tsx index 514c42ab..658e4ed8 100644 --- a/public/components/common/forms/model_name_field.tsx +++ b/public/components/common/forms/model_name_field.tsx @@ -24,12 +24,12 @@ interface ModelNameFieldProps { } const isDuplicateModelName = async (name: string) => { - const searchResult = await APIProvider.getAPI('modelGroup').search({ + const searchResult = await APIProvider.getAPI('model').search({ name, from: 0, size: 1, }); - return searchResult.total_model_groups >= 1; + return searchResult.total_models >= 1; }; export const ModelNameField = ({ diff --git a/public/components/common/modals/__tests__/model_version_deployment_confirm_modal.test.tsx b/public/components/common/modals/__tests__/model_version_deployment_confirm_modal.test.tsx index b4b3016a..0a0fb16f 100644 --- a/public/components/common/modals/__tests__/model_version_deployment_confirm_modal.test.tsx +++ b/public/components/common/modals/__tests__/model_version_deployment_confirm_modal.test.tsx @@ -9,7 +9,7 @@ import { EuiToast } from '@elastic/eui'; import { render, screen, waitFor } from '../../../../../test/test_utils'; import { ModelVersionDeploymentConfirmModal } from '../model_version_deployment_confirm_modal'; -import { Model } from '../../../../apis/model'; +import { ModelVersion } from '../../../../apis/model_version'; import * as PluginContext from '../../../../../../../src/plugins/opensearch_dashboards_react/public'; import { MountWrapper } from '../../../../../../../src/core/public/utils'; @@ -98,7 +98,7 @@ describe('', () => { it('should call model load after deploy button clicked', async () => { const modelLoadMock = jest - .spyOn(Model.prototype, 'load') + .spyOn(ModelVersion.prototype, 'load') .mockReturnValue(Promise.resolve({ task_id: 'foo', status: 'succeeded' })); render( ', () => { it('should show error toast if model load throw error', async () => { const useOpenSearchDashboardsMock = mockAddDangerAndOverlay(); const modelLoadMock = jest - .spyOn(Model.prototype, 'load') + .spyOn(ModelVersion.prototype, 'load') .mockRejectedValue(new Error('error')); render( ', () => { it('should show full error after "See full error" clicked', async () => { const useOpenSearchDashboardsMock = mockAddDangerAndOverlay(); const modelLoadMock = jest - .spyOn(Model.prototype, 'load') + .spyOn(ModelVersion.prototype, 'load') .mockRejectedValue(new Error('This is a full error message.')); render( ', () => { it('should hide full error after close button clicked', async () => { const useOpenSearchDashboardsMock = mockAddDangerAndOverlay(); const modelLoadMock = jest - .spyOn(Model.prototype, 'load') + .spyOn(ModelVersion.prototype, 'load') .mockRejectedValue(new Error('This is a full error message.')); render( ', () => { }); it('should call model unload after undeploy button clicked', async () => { - const modelLoadMock = jest.spyOn(Model.prototype, 'unload').mockImplementation(); + const modelLoadMock = jest.spyOn(ModelVersion.prototype, 'unload').mockImplementation(); render( ', () => { }, }, }); - const modelLoadMock = jest.spyOn(Model.prototype, 'unload').mockImplementation(); + const modelLoadMock = jest.spyOn(ModelVersion.prototype, 'unload').mockImplementation(); render( ', () => { it('should show error toast if model unload throw error', async () => { const useOpenSearchDashboardsMock = mockAddDangerAndOverlay(); const modelLoadMock = jest - .spyOn(Model.prototype, 'unload') + .spyOn(ModelVersion.prototype, 'unload') .mockRejectedValue(new Error('error')); render( ', () => { it('should show full error after "See full error" clicked', async () => { const useOpenSearchDashboardsMock = mockAddDangerAndOverlay(); const modelLoadMock = jest - .spyOn(Model.prototype, 'unload') + .spyOn(ModelVersion.prototype, 'unload') .mockRejectedValue(new Error('This is a full error message.')); render( ', () => { it('should hide full error after close button clicked', async () => { const useOpenSearchDashboardsMock = mockAddDangerAndOverlay(); const modelLoadMock = jest - .spyOn(Model.prototype, 'unload') + .spyOn(ModelVersion.prototype, 'unload') .mockRejectedValue(new Error('This is a full error message.')); render( { staticBreadcrumbs: baseModelRegistryBreadcrumbs, // TODO: Change to model group API asyncBreadcrumbsLoader: () => - APIProvider.getAPI('modelGroup') + APIProvider.getAPI('model') .getOne(modelId) .then( (model) => @@ -80,7 +80,7 @@ const getModelBreadcrumbs = (basename: string, matchedParams: {}) => { return { staticBreadcrumbs: baseModelRegistryBreadcrumbs, asyncBreadcrumbsLoader: () => { - return APIProvider.getAPI('modelGroup') + return APIProvider.getAPI('model') .getOne(modelId) .then( (model) => @@ -106,7 +106,7 @@ const getModelVersionBreadcrumbs = (basename: string, matchedParams: {}) => { staticBreadcrumbs: baseModelRegistryBreadcrumbs, // TODO: Change to model group API asyncBreadcrumbsLoader: () => - APIProvider.getAPI('model') + APIProvider.getAPI('modelVersion') .getOne(modelId) .then( (model) => diff --git a/public/components/model/model.tsx b/public/components/model/model.tsx index 41c4b34f..ca5acff3 100644 --- a/public/components/model/model.tsx +++ b/public/components/model/model.tsx @@ -26,7 +26,7 @@ import { ModelTagsPanel } from './model_tags_panel'; export const Model = () => { const { id: modelId } = useParams<{ id: string }>(); - const { data, loading, error } = useFetcher(APIProvider.getAPI('modelGroup').getOne, modelId); + const { data, loading, error } = useFetcher(APIProvider.getAPI('model').getOne, modelId); const tabs = useMemo( () => [ { diff --git a/public/components/model/model_versions_panel/__tests__/model_version_cell.test.tsx b/public/components/model/model_versions_panel/__tests__/model_version_cell.test.tsx index 4e5a24aa..ee6a8c61 100644 --- a/public/components/model/model_versions_panel/__tests__/model_version_cell.test.tsx +++ b/public/components/model/model_versions_panel/__tests__/model_version_cell.test.tsx @@ -7,7 +7,7 @@ import React from 'react'; import { render, screen } from '../../../../../test/test_utils'; import { ModelVersionCell } from '../model_version_cell'; -import { MODEL_STATE } from '../../../../../common'; +import { MODEL_VERSION_STATE } from '../../../../../common'; const setup = (options: { columnId: string; isDetails?: boolean }) => render( @@ -16,7 +16,7 @@ const setup = (options: { columnId: string; isDetails?: boolean }) => id: '1', name: 'model-1', version: '1.0.0', - state: MODEL_STATE.uploading, + state: MODEL_VERSION_STATE.registering, tags: {}, lastUpdatedTime: 1682604957236, createdTime: 1682604957236, @@ -48,7 +48,7 @@ describe('', () => { id: '1', name: 'model-1', version: '1.0.0', - state: MODEL_STATE.loaded, + state: MODEL_VERSION_STATE.deployed, tags: {}, lastUpdatedTime: 1682604957236, createdTime: 1682604957236, @@ -65,7 +65,7 @@ describe('', () => { id: '1', name: 'model-1', version: '1.0.0', - state: MODEL_STATE.partiallyLoaded, + state: MODEL_VERSION_STATE.partiallyDeployed, tags: {}, lastUpdatedTime: 1682604957236, createdTime: 1682604957236, diff --git a/public/components/model/model_versions_panel/__tests__/model_version_status_cell.test.tsx b/public/components/model/model_versions_panel/__tests__/model_version_status_cell.test.tsx index 735bd371..ce344fb7 100644 --- a/public/components/model/model_versions_panel/__tests__/model_version_status_cell.test.tsx +++ b/public/components/model/model_versions_panel/__tests__/model_version_status_cell.test.tsx @@ -7,44 +7,46 @@ import React from 'react'; import { render, screen } from '../../../../../test/test_utils'; import { ModelVersionStatusCell } from '../model_version_status_cell'; -import { MODEL_STATE } from '../../../../../common'; +import { MODEL_VERSION_STATE } from '../../../../../common'; describe('', () => { it('should display "-" if unsupported state provided', async () => { - render(); + render(); expect(screen.getByText('-')).toBeInTheDocument(); }); it('should display "In progress..." when state is "uploading" or "loading"', async () => { - const { rerender } = render(); + const { rerender } = render(); expect(screen.getByText('In progress...')).toBeInTheDocument(); - rerender(); + rerender(); expect(screen.getByText('In progress...')).toBeInTheDocument(); }); it('should display "Success" when state is "uploaded" or "loaded"', async () => { - const { rerender } = render(); + const { rerender } = render(); expect(screen.getByText('Success')).toBeInTheDocument(); - rerender(); + rerender(); expect(screen.getByText('Success')).toBeInTheDocument(); }); it('should display "Error" when state is "registerFailed" or "loadedFailed"', async () => { - const { rerender } = render(); + const { rerender } = render( + + ); expect(screen.getByText('Error')).toBeInTheDocument(); - rerender(); + rerender(); expect(screen.getByText('Error')).toBeInTheDocument(); }); it('should display "Warning" when state is "partialLoaded"', async () => { - render(); + render(); expect(screen.getByText('Warning')).toBeInTheDocument(); }); diff --git a/public/components/model/model_versions_panel/__tests__/model_version_status_detail.test.tsx b/public/components/model/model_versions_panel/__tests__/model_version_status_detail.test.tsx index d620910b..87abf046 100644 --- a/public/components/model/model_versions_panel/__tests__/model_version_status_detail.test.tsx +++ b/public/components/model/model_versions_panel/__tests__/model_version_status_detail.test.tsx @@ -8,7 +8,7 @@ import userEvent from '@testing-library/user-event'; import { render, screen, waitFor } from '../../../../../test/test_utils'; import { ModelVersionStatusDetail } from '../model_version_status_detail'; -import { MODEL_STATE } from '../../../../../common'; +import { MODEL_VERSION_STATE } from '../../../../../common'; describe('', () => { it('should render "In progress..." and uploading tip', async () => { @@ -17,7 +17,7 @@ describe('', () => { id="1" name="model-1" version="1" - state={MODEL_STATE.uploading} + state={MODEL_VERSION_STATE.registering} createdTime={1683276773541} /> ); @@ -36,7 +36,7 @@ describe('', () => { id="1" name="model-1" version="1" - state={MODEL_STATE.loaded} + state={MODEL_VERSION_STATE.deployed} createdTime={1683276773541} lastDeployedTime={1683276773541} /> @@ -60,7 +60,7 @@ describe('', () => { name="model-1" version="1" createdTime={1683276773541} - state={MODEL_STATE.trained} + state={MODEL_VERSION_STATE.trained} /> ); @@ -73,7 +73,7 @@ describe('', () => { id="1" name="model-1" version="1.0.0" - state={MODEL_STATE.loadFailed} + state={MODEL_VERSION_STATE.deployFailed} createdTime={1683276773541} lastDeployedTime={1683276773541} /> @@ -92,7 +92,7 @@ describe('', () => { id="1" name="model-1" version="1.0.0" - state={MODEL_STATE.registerFailed} + state={MODEL_VERSION_STATE.registerFailed} createdTime={1683276773541} /> ); diff --git a/public/components/model/model_versions_panel/__tests__/model_version_table.test.tsx b/public/components/model/model_versions_panel/__tests__/model_version_table.test.tsx index 0b2940e4..759e0338 100644 --- a/public/components/model/model_versions_panel/__tests__/model_version_table.test.tsx +++ b/public/components/model/model_versions_panel/__tests__/model_version_table.test.tsx @@ -9,14 +9,14 @@ import { within } from '@testing-library/dom'; import { render, screen, waitFor } from '../../../../../test/test_utils'; import { ModelVersionTable } from '../model_version_table'; -import { MODEL_STATE } from '../../../../../common'; +import { MODEL_VERSION_STATE } from '../../../../../common'; const versions = [ { id: '1', name: 'model-1', version: '1.0.0', - state: MODEL_STATE.uploading, + state: MODEL_VERSION_STATE.registering, tags: { 'Accuracy: test': 0.98, 'Accuracy: train': 0.99 }, lastUpdatedTime: 1682676759143, createdTime: 1682676759143, diff --git a/public/components/model/model_versions_panel/__tests__/model_version_table_row_actions.test.tsx b/public/components/model/model_versions_panel/__tests__/model_version_table_row_actions.test.tsx index c479e069..39c49d96 100644 --- a/public/components/model/model_versions_panel/__tests__/model_version_table_row_actions.test.tsx +++ b/public/components/model/model_versions_panel/__tests__/model_version_table_row_actions.test.tsx @@ -8,16 +8,16 @@ import userEvent from '@testing-library/user-event'; import { render, screen, waitFor } from '../../../../../test/test_utils'; import { ModelVersionTableRowActions } from '../model_version_table_row_actions'; -import { MODEL_STATE } from '../../../../../common'; +import { MODEL_VERSION_STATE } from '../../../../../common'; -const setup = (state: MODEL_STATE) => { +const setup = (state: MODEL_VERSION_STATE) => { return render(); }; describe('', () => { it('should render "actions icon" and "Delete" button after clicked', async () => { const user = userEvent.setup(); - setup(MODEL_STATE.uploading); + setup(MODEL_VERSION_STATE.registering); expect(screen.getByLabelText('show actions')).toBeInTheDocument(); await user.click(screen.getByLabelText('show actions')); @@ -27,7 +27,7 @@ describe('', () => { it('should render "Upload new artifact" button for REGISTER_FAILED state', async () => { const user = userEvent.setup(); - setup(MODEL_STATE.registerFailed); + setup(MODEL_VERSION_STATE.registerFailed); await user.click(screen.getByLabelText('show actions')); expect(screen.getByText('Upload new artifact')).toBeInTheDocument(); @@ -35,19 +35,24 @@ describe('', () => { it('should render "Deploy" button for REGISTERED, DEPLOY_FAILED and UNDEPLOYED state', async () => { const user = userEvent.setup(); - const { rerender } = setup(MODEL_STATE.uploaded); + const { rerender } = setup(MODEL_VERSION_STATE.registered); await user.click(screen.getByLabelText('show actions')); expect(screen.getByText('Deploy')).toBeInTheDocument(); rerender( - + ); expect(screen.getByText('Deploy')).toBeInTheDocument(); rerender( ', () => { it('should render "Undeploy" button for DEPLOYED and PARTIALLY_DEPLOYED state', async () => { const user = userEvent.setup(); - const { rerender } = setup(MODEL_STATE.loaded); + const { rerender } = setup(MODEL_VERSION_STATE.deployed); await user.click(screen.getByLabelText('show actions')); expect(screen.getByText('Undeploy')).toBeInTheDocument(); rerender( ', () => { it('should call close popover after menuitem click', async () => { const user = userEvent.setup(); - setup(MODEL_STATE.loaded); + setup(MODEL_VERSION_STATE.deployed); await user.click(screen.getByLabelText('show actions')); await user.click(screen.getByText('Delete')); @@ -88,7 +93,7 @@ describe('', () => { it('should show deploy confirm modal after "Deploy" button clicked', async () => { const user = userEvent.setup(); - setup(MODEL_STATE.uploaded); + setup(MODEL_VERSION_STATE.registered); await user.click(screen.getByLabelText('show actions')); await user.click(screen.getByText('Deploy')); @@ -100,7 +105,7 @@ describe('', () => { it('should hide deploy confirm modal after "Cancel" button clicked', async () => { const user = userEvent.setup(); - setup(MODEL_STATE.uploaded); + setup(MODEL_VERSION_STATE.registered); await user.click(screen.getByLabelText('show actions')); await user.click(screen.getByText('Deploy')); await user.click(screen.getByText('Cancel')); @@ -110,7 +115,7 @@ describe('', () => { it('should show undeploy confirm modal after "Deploy" button clicked', async () => { const user = userEvent.setup(); - setup(MODEL_STATE.loaded); + setup(MODEL_VERSION_STATE.deployed); await user.click(screen.getByLabelText('show actions')); await user.click(screen.getByText('Undeploy')); @@ -124,7 +129,7 @@ describe('', () => { it('should hide undeploy confirm modal after "Cancel" button clicked', async () => { const user = userEvent.setup(); - setup(MODEL_STATE.loaded); + setup(MODEL_VERSION_STATE.deployed); await user.click(screen.getByLabelText('show actions')); await user.click(screen.getByText('Undeploy')); await user.click(screen.getByText('Cancel')); diff --git a/public/components/model/model_versions_panel/__tests__/model_versions_panel.test.tsx b/public/components/model/model_versions_panel/__tests__/model_versions_panel.test.tsx index acb6a109..c77a9e2d 100644 --- a/public/components/model/model_versions_panel/__tests__/model_versions_panel.test.tsx +++ b/public/components/model/model_versions_panel/__tests__/model_versions_panel.test.tsx @@ -8,7 +8,7 @@ import userEvent from '@testing-library/user-event'; import * as euiExports from '@elastic/eui'; import { render, screen, waitFor, within } from '../../../../../test/test_utils'; -import { Model } from '../../../../apis/model'; +import { ModelVersion } from '../../../../apis/model_version'; import { ModelVersionsPanel } from '../model_versions_panel'; import * as PluginContext from '../../../../../../../src/plugins/opensearch_dashboards_react/public'; @@ -65,12 +65,14 @@ describe('', () => { 'should call model search API again after refresh button clicked', async () => { const euiDataGridMock = mockEuiDataGrid(); - const searchMock = jest.spyOn(Model.prototype, 'search').mockImplementation(async () => { - return { - data: [], - total_models: 0, - }; - }); + const searchMock = jest + .spyOn(ModelVersion.prototype, 'search') + .mockImplementation(async () => { + return { + data: [], + total_model_versions: 0, + }; + }); render(); @@ -89,7 +91,7 @@ describe('', () => { 'should call model search with consistent state parameters after deployed state filter applied', async () => { const euiDataGridMock = mockEuiDataGrid(); - const searchMock = jest.spyOn(Model.prototype, 'search'); + const searchMock = jest.spyOn(ModelVersion.prototype, 'search'); render(); @@ -125,7 +127,7 @@ describe('', () => { it('should render loading screen when calling model search API', async () => { const euiDataGridMock = mockEuiDataGrid(); const searchMock = jest - .spyOn(Model.prototype, 'search') + .spyOn(ModelVersion.prototype, 'search') .mockImplementation(() => new Promise(() => {})); render(); @@ -137,7 +139,7 @@ describe('', () => { it('should render error screen and show error toast after call model search failed', async () => { const euiDataGridMock = mockEuiDataGrid(); - const searchMock = jest.spyOn(Model.prototype, 'search').mockImplementation(async () => { + const searchMock = jest.spyOn(ModelVersion.prototype, 'search').mockImplementation(async () => { throw new Error(); }); const dangerMock = jest.fn(); @@ -166,10 +168,10 @@ describe('', () => { it('should render empty screen if model no versions', async () => { const euiDataGridMock = mockEuiDataGrid(); - const searchMock = jest.spyOn(Model.prototype, 'search').mockImplementation(async () => { + const searchMock = jest.spyOn(ModelVersion.prototype, 'search').mockImplementation(async () => { return { data: [], - total_models: 0, + total_model_versions: 0, }; }); @@ -198,12 +200,14 @@ describe('', () => { expect(screen.getByTitle('Status')).toBeInTheDocument(); }); - const searchMock = jest.spyOn(Model.prototype, 'search').mockImplementation(async () => { - return { - data: [], - total_models: 0, - }; - }); + const searchMock = jest + .spyOn(ModelVersion.prototype, 'search') + .mockImplementation(async () => { + return { + data: [], + total_model_versions: 0, + }; + }); await userEvent.click(screen.getByTitle('Status')); await userEvent.click(screen.getByRole('option', { name: 'In progress...' })); @@ -229,12 +233,14 @@ describe('', () => { expect(screen.getByTitle('Status')).toBeInTheDocument(); }); - const searchMock = jest.spyOn(Model.prototype, 'search').mockImplementation(async () => { - return { - data: [], - total_models: 0, - }; - }); + const searchMock = jest + .spyOn(ModelVersion.prototype, 'search') + .mockImplementation(async () => { + return { + data: [], + total_model_versions: 0, + }; + }); await userEvent.click(screen.getByTitle('Status')); await userEvent.click(screen.getByRole('option', { name: 'In progress...' })); @@ -244,7 +250,7 @@ describe('', () => { expect(searchMock).toHaveBeenCalledWith({ from: 0, size: 25, - modelGroupId: '1', + modelIds: ['1'], }); searchMock.mockRestore(); @@ -260,7 +266,7 @@ describe('', () => { await waitFor(() => { expect(screen.getByTestId('dataGridHeaderCell-version')).toBeInTheDocument(); }); - const searchMock = jest.spyOn(Model.prototype, 'search'); + const searchMock = jest.spyOn(ModelVersion.prototype, 'search'); await userEvent.click( within(screen.getByTestId('dataGridHeaderCell-version')).getByText('Version') diff --git a/public/components/model/model_versions_panel/model_version_cell.tsx b/public/components/model/model_versions_panel/model_version_cell.tsx index 6445baab..78a0d71e 100644 --- a/public/components/model/model_versions_panel/model_version_cell.tsx +++ b/public/components/model/model_versions_panel/model_version_cell.tsx @@ -5,9 +5,10 @@ import React from 'react'; import { get } from 'lodash'; -import { EuiBadge, EuiText } from '@elastic/eui'; +import { EuiBadge, EuiLink, EuiText } from '@elastic/eui'; +import { Link, generatePath } from 'react-router-dom'; -import { MODEL_STATE } from '../../../../common'; +import { MODEL_VERSION_STATE, routerPaths } from '../../../../common'; import { VersionTableDataItem } from '../types'; import { UiSettingDateFormatTime } from '../../common'; @@ -37,13 +38,18 @@ export const ModelVersionCell = ({ data, columnId, isDetails }: ModelVersionCell } switch (columnId) { case 'version': - return {data.version}; + return ( + + {data.version} + + ); case 'status': { return ; } case 'state': { const deployed = - data.state === MODEL_STATE.loaded || data.state === MODEL_STATE.partiallyLoaded; + data.state === MODEL_VERSION_STATE.deployed || + data.state === MODEL_VERSION_STATE.partiallyDeployed; return ( {deployed ? 'Deployed' : 'Not deployed'} diff --git a/public/components/model/model_versions_panel/model_version_status_cell.tsx b/public/components/model/model_versions_panel/model_version_status_cell.tsx index 92f3dafe..a96296d4 100644 --- a/public/components/model/model_versions_panel/model_version_status_cell.tsx +++ b/public/components/model/model_versions_panel/model_version_status_cell.tsx @@ -6,20 +6,20 @@ import React from 'react'; import { EuiHealth } from '@elastic/eui'; -import { MODEL_STATE } from '../../../../common'; +import { MODEL_VERSION_STATE } from '../../../../common'; -const state2StatusContentMap: { [key in MODEL_STATE]?: [string, string] } = { - [MODEL_STATE.uploading]: ['#AFB0B3', 'In progress...'], - [MODEL_STATE.loading]: ['#AFB0B3', 'In progress...'], - [MODEL_STATE.uploaded]: ['success', 'Success'], - [MODEL_STATE.loaded]: ['success', 'Success'], - [MODEL_STATE.unloaded]: ['success', 'Success'], - [MODEL_STATE.loadFailed]: ['danger', 'Error'], - [MODEL_STATE.registerFailed]: ['danger', 'Error'], - [MODEL_STATE.partiallyLoaded]: ['warning', 'Warning'], +const state2StatusContentMap: { [key in MODEL_VERSION_STATE]?: [string, string] } = { + [MODEL_VERSION_STATE.registering]: ['#AFB0B3', 'In progress...'], + [MODEL_VERSION_STATE.deploying]: ['#AFB0B3', 'In progress...'], + [MODEL_VERSION_STATE.registered]: ['success', 'Success'], + [MODEL_VERSION_STATE.deployed]: ['success', 'Success'], + [MODEL_VERSION_STATE.undeployed]: ['success', 'Success'], + [MODEL_VERSION_STATE.deployFailed]: ['danger', 'Error'], + [MODEL_VERSION_STATE.registerFailed]: ['danger', 'Error'], + [MODEL_VERSION_STATE.partiallyDeployed]: ['warning', 'Warning'], }; -export const ModelVersionStatusCell = ({ state }: { state: MODEL_STATE }) => { +export const ModelVersionStatusCell = ({ state }: { state: MODEL_VERSION_STATE }) => { const statusContent = state2StatusContentMap[state]; if (!statusContent) { return <>-; diff --git a/public/components/model/model_versions_panel/model_version_status_detail.tsx b/public/components/model/model_versions_panel/model_version_status_detail.tsx index 0508319b..c83a1491 100644 --- a/public/components/model/model_versions_panel/model_version_status_detail.tsx +++ b/public/components/model/model_versions_panel/model_version_status_detail.tsx @@ -14,7 +14,7 @@ import { } from '@elastic/eui'; import { Link, generatePath } from 'react-router-dom'; -import { MODEL_STATE, routerPaths } from '../../../../common'; +import { MODEL_VERSION_STATE, routerPaths } from '../../../../common'; import { UiSettingDateFormatTime } from '../../common'; import { APIProvider } from '../../../apis/api_provider'; @@ -22,26 +22,26 @@ import { ModelVersionErrorDetailsModal } from '../../common'; // TODO: Change to related time field after confirmed export const state2DetailContentMap: { - [key in MODEL_STATE]?: { + [key in MODEL_VERSION_STATE]?: { title: string; description: (versionLink: React.ReactNode) => React.ReactNode; timeTitle?: string; timeField?: 'createdTime' | 'lastRegisteredTime' | 'lastDeployedTime' | 'lastUndeployedTime'; }; } = { - [MODEL_STATE.uploading]: { + [MODEL_VERSION_STATE.registering]: { title: 'In progress...', description: (versionLink: React.ReactNode) => ( <>The model artifact for {versionLink} is uploading. ), }, - [MODEL_STATE.loading]: { + [MODEL_VERSION_STATE.deploying]: { title: 'In progress...', description: (versionLink: React.ReactNode) => ( <>The model artifact for {versionLink} is deploying. ), }, - [MODEL_STATE.uploaded]: { + [MODEL_VERSION_STATE.registered]: { title: 'Success', description: (versionLink: React.ReactNode) => ( <>The model artifact for {versionLink} uploaded. @@ -49,31 +49,31 @@ export const state2DetailContentMap: { timeTitle: 'Uploaded on', timeField: 'lastRegisteredTime', }, - [MODEL_STATE.loaded]: { + [MODEL_VERSION_STATE.deployed]: { title: 'Success', description: (versionLink: React.ReactNode) => <>{versionLink} deployed., timeTitle: 'Deployed on', timeField: 'lastDeployedTime', }, - [MODEL_STATE.unloaded]: { + [MODEL_VERSION_STATE.undeployed]: { title: 'Success', description: (versionLink: React.ReactNode) => <>{versionLink} undeployed., timeTitle: 'Undeployed on', timeField: 'lastUndeployedTime', }, - [MODEL_STATE.loadFailed]: { + [MODEL_VERSION_STATE.deployFailed]: { title: 'Error', description: (versionLink: React.ReactNode) => <>{versionLink} deployment failed., timeTitle: 'Deployment failed on', timeField: 'lastDeployedTime', }, - [MODEL_STATE.registerFailed]: { + [MODEL_VERSION_STATE.registerFailed]: { title: 'Error', description: (versionLink: React.ReactNode) => <>{versionLink} artifact upload failed., timeTitle: 'Upload failed on', timeField: 'createdTime', }, - [MODEL_STATE.partiallyLoaded]: { + [MODEL_VERSION_STATE.partiallyDeployed]: { title: 'Warning', description: (versionLink: React.ReactNode) => ( <>{versionLink} is deployed and partially responding. @@ -91,7 +91,7 @@ export const ModelVersionStatusDetail = ({ ...restProps }: { id: string; - state: MODEL_STATE; + state: MODEL_VERSION_STATE; name: string; version: string; createdTime: number; @@ -104,9 +104,9 @@ export const ModelVersionStatusDetail = ({ const [errorDetails, setErrorDetails] = useState(); const handleSeeFullErrorClick = useCallback(async () => { - const state2TaskTypeMap: { [key in MODEL_STATE]?: string } = { - [MODEL_STATE.loadFailed]: 'DEPLOY_MODEL', - [MODEL_STATE.registerFailed]: 'REGISTER_MODEL', + const state2TaskTypeMap: { [key in MODEL_VERSION_STATE]?: string } = { + [MODEL_VERSION_STATE.deployFailed]: 'DEPLOY_MODEL', + [MODEL_VERSION_STATE.registerFailed]: 'REGISTER_MODEL', }; if (!(state in state2TaskTypeMap)) { return; @@ -169,7 +169,8 @@ export const ModelVersionStatusDetail = ({ )} - {(state === MODEL_STATE.loadFailed || state === MODEL_STATE.registerFailed) && ( + {(state === MODEL_VERSION_STATE.deployFailed || + state === MODEL_VERSION_STATE.registerFailed) && ( <> @@ -193,7 +194,9 @@ export const ModelVersionStatusDetail = ({ version={version} errorDetails={errorDetails} errorType={ - state === MODEL_STATE.loadFailed ? 'deployment-failed' : 'artifact-upload-failed' + state === MODEL_VERSION_STATE.deployFailed + ? 'deployment-failed' + : 'artifact-upload-failed' } closeModal={handleCloseModal} /> diff --git a/public/components/model/model_versions_panel/model_version_table_row_actions.tsx b/public/components/model/model_versions_panel/model_version_table_row_actions.tsx index 61b4bdc9..0ad57260 100644 --- a/public/components/model/model_versions_panel/model_version_table_row_actions.tsx +++ b/public/components/model/model_versions_panel/model_version_table_row_actions.tsx @@ -6,7 +6,7 @@ import React, { useState, useCallback } from 'react'; import { EuiPopover, EuiButtonIcon, EuiContextMenuPanel, EuiContextMenuItem } from '@elastic/eui'; -import { MODEL_STATE } from '../../../../common'; +import { MODEL_VERSION_STATE } from '../../../../common'; import { ModelVersionDeploymentConfirmModal } from '../../common'; export const ModelVersionTableRowActions = ({ @@ -15,7 +15,7 @@ export const ModelVersionTableRowActions = ({ name, version, }: { - state: MODEL_STATE; + state: MODEL_VERSION_STATE; id: string; name: string; version: string; @@ -69,7 +69,7 @@ export const ModelVersionTableRowActions = ({ , ] : []), - ...(state === MODEL_STATE.uploaded || - state === MODEL_STATE.unloaded || - state === MODEL_STATE.loadFailed + ...(state === MODEL_VERSION_STATE.registered || + state === MODEL_VERSION_STATE.undeployed || + state === MODEL_VERSION_STATE.deployFailed ? [ , ] : []), - ...(state === MODEL_STATE.loaded || state === MODEL_STATE.partiallyLoaded + ...(state === MODEL_VERSION_STATE.deployed || + state === MODEL_VERSION_STATE.partiallyDeployed ? [ { const stateRelatedStatus = modelState2StatusMap[modelState]; if (stateRelatedStatus && statuses.includes(stateRelatedStatus)) { return true; } - if (modelState === MODEL_STATE.loaded || modelState === MODEL_STATE.partiallyLoaded) { + if ( + modelState === MODEL_VERSION_STATE.deployed || + modelState === MODEL_VERSION_STATE.partiallyDeployed + ) { return states.includes('Deployed'); } return states.includes('Not deployed'); @@ -105,9 +108,9 @@ export const ModelVersionsPanel = ({ modelId }: ModelVersionsPanelProps) => { }, }); const { data: versionsData, reload, loading, error } = useFetcher( - APIProvider.getAPI('model').search, + APIProvider.getAPI('modelVersion').search, { - modelGroupId: modelId, + modelIds: [modelId], from: params.pageIndex * params.pageSize, size: params.pageSize, states: getStatesParam(params.filter), @@ -115,7 +118,7 @@ export const ModelVersionsPanel = ({ modelId }: ModelVersionsPanelProps) => { sort: getSortParam(params.sort), } ); - const totalVersionCount = versionsData?.total_models; + const totalVersionCount = versionsData?.total_model_versions; const { notifications } = useOpenSearchDashboards(); const versions = useMemo(() => { diff --git a/public/components/model/types.ts b/public/components/model/types.ts index 01fe3ea0..67f6f964 100644 --- a/public/components/model/types.ts +++ b/public/components/model/types.ts @@ -3,14 +3,14 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { MODEL_STATE } from '../../../common'; +import { MODEL_VERSION_STATE } from '../../../common'; import { TagKey } from '../common'; export interface VersionTableDataItem { id: string; name: string; version: string; - state: MODEL_STATE; + state: MODEL_VERSION_STATE; lastUpdatedTime: number; tags: { [key: string]: string | number }; createdTime: number; diff --git a/public/components/model_drawer/index.tsx b/public/components/model_drawer/index.tsx index 4383fe0b..e79d6d4e 100644 --- a/public/components/model_drawer/index.tsx +++ b/public/components/model_drawer/index.tsx @@ -31,8 +31,8 @@ interface Props { export const ModelDrawer = ({ onClose, name }: Props) => { const [sort, setSort] = useState('version-desc'); - const { data: model } = useFetcher(APIProvider.getAPI('model').search, { - nameOrId: name, + const { data: model } = useFetcher(APIProvider.getAPI('modelVersion').search, { + name, from: 0, size: 50, sort: [sort], diff --git a/public/components/model_drawer/version_table.tsx b/public/components/model_drawer/version_table.tsx index f3ed69ea..d6b46e32 100644 --- a/public/components/model_drawer/version_table.tsx +++ b/public/components/model_drawer/version_table.tsx @@ -6,7 +6,7 @@ import React, { useMemo, useCallback, useRef } from 'react'; import { EuiBasicTable, Direction, Criteria, EuiBasicTableColumn } from '@elastic/eui'; -import { ModelSearchItem } from '../../apis/model'; +import { ModelVersionSearchItem } from '../../apis/model_version'; import { renderTime } from '../../utils'; import type { VersionTableSort } from './'; @@ -15,7 +15,7 @@ export interface VersionTableCriteria { } export function VersionTable(props: { - models: ModelSearchItem[]; + models: ModelVersionSearchItem[]; sort: VersionTableSort; onChange: (criteria: VersionTableCriteria) => void; }) { @@ -23,7 +23,7 @@ export function VersionTable(props: { const onChangeRef = useRef(onChange); onChangeRef.current = onChange; - const columns: Array> = [ + const columns: Array> = [ { field: 'model_version', name: 'Version', @@ -60,13 +60,13 @@ export function VersionTable(props: { const [field, direction] = sort.split('-'); return { sort: { - field: field as keyof ModelSearchItem, + field: field as keyof ModelVersionSearchItem, direction: direction as Direction, }, }; }, [sort]); - const handleChange = useCallback(({ sort: newSort }: Criteria) => { + const handleChange = useCallback(({ sort: newSort }: Criteria) => { if (newSort) { onChangeRef.current({ sort: `${newSort.field}-${newSort.direction}` as VersionTableSort, @@ -75,7 +75,7 @@ export function VersionTable(props: { }, []); return ( - + columns={columns} items={models} rowProps={rowProps} diff --git a/public/components/model_list/__tests__/model_list.test.tsx b/public/components/model_list/__tests__/model_list.test.tsx index bb1858fe..67c983aa 100644 --- a/public/components/model_list/__tests__/model_list.test.tsx +++ b/public/components/model_list/__tests__/model_list.test.tsx @@ -4,6 +4,7 @@ */ import React from 'react'; +import userEvent from '@testing-library/user-event'; import { ModelAggregate } from '../../../apis/model_aggregate'; import { render, screen, waitFor, within } from '../../../../test/test_utils'; @@ -11,23 +12,7 @@ import { render, screen, waitFor, within } from '../../../../test/test_utils'; import { ModelList } from '../index'; const setup = () => { - const notificationsMock = { - toasts: { - get$: jest.fn(), - add: jest.fn(), - remove: jest.fn(), - addSuccess: jest.fn(), - addWarning: jest.fn(), - addDanger: jest.fn(), - addError: jest.fn(), - addInfo: jest.fn(), - }, - }; - const renderResult = render(); - return { - renderResult, - notificationsMock, - }; + render(); }; describe('', () => { @@ -37,12 +22,7 @@ describe('', () => { .mockImplementation(() => Promise.resolve({ data: [], - pagination: { - currentPage: 1, - pageSize: 15, - totalPages: 0, - totalRecords: 0, - }, + total_models: 0, }) ); @@ -63,7 +43,7 @@ describe('', () => { expect( screen.getByText('traced_small_model').closest('.euiTableRowCell') ).toBeInTheDocument(); - expect(screen.getByText('1.0.5').closest('.euiTableRowCell')).toBeInTheDocument(); + expect(screen.getByText('5').closest('.euiTableRowCell')).toBeInTheDocument(); }); }); @@ -77,4 +57,64 @@ describe('', () => { ) ).toBeInTheDocument(); }); + + it('should call model aggregate with filter parameters after filter applied', async () => { + setup(); + const modelAggregateSearchMock = jest.spyOn(ModelAggregate.prototype, 'search'); + + await userEvent.click(screen.getByText('Deployed')); + + expect(modelAggregateSearchMock).toHaveBeenLastCalledWith( + expect.objectContaining({ + states: ['DEPLOYED'], + }) + ); + + modelAggregateSearchMock.mockRestore(); + }); + + it('should call model aggregate with extraQuery after search text typed', async () => { + setup(); + const modelAggregateSearchMock = jest.spyOn(ModelAggregate.prototype, 'search'); + + await userEvent.type(screen.getByPlaceholderText('Search by name, person, or keyword'), 'foo'); + + await waitFor(() => { + expect(modelAggregateSearchMock).toHaveBeenLastCalledWith( + expect.objectContaining({ + extraQuery: JSON.stringify({ + bool: { + should: [ + { + match_phrase: { + name: 'foo', + }, + }, + { + match_phrase: { + description: 'foo', + }, + }, + { + nested: { + path: 'owner', + query: { + term: { + 'owner.name.keyword': { + value: 'foo', + boost: 1, + }, + }, + }, + }, + }, + ], + }, + }), + }) + ); + }); + + modelAggregateSearchMock.mockRestore(); + }); }); diff --git a/public/components/model_list/__tests__/model_table.test.tsx b/public/components/model_list/__tests__/model_table.test.tsx index f0afb550..f05c5a4a 100644 --- a/public/components/model_list/__tests__/model_table.test.tsx +++ b/public/components/model_list/__tests__/model_table.test.tsx @@ -4,31 +4,20 @@ */ import React from 'react'; -import moment from 'moment'; import userEvent from '@testing-library/user-event'; import { ModelTable, ModelTableProps } from '../model_table'; import { render, screen, within } from '../../../../test/test_utils'; -import { MODEL_STATE } from '../../../../common/model'; const tableData = [ { + id: '1', name: 'model1', - owner: 'foo', - latest_version: '5', + owner_name: 'foo', + latest_version: 5, description: 'model 1 description', - latest_version_state: MODEL_STATE.loaded, deployed_versions: ['1,2'], - created_time: Date.now(), - }, - { - name: 'model2', - owner: 'bar', - latest_version: '3', - description: 'model 2 description', - latest_version_state: MODEL_STATE.uploading, - deployed_versions: ['1,2'], - created_time: Date.now(), + last_updated_time: 1683699499637, }, ]; @@ -40,11 +29,10 @@ const setup = (options?: Partial) => { ', () => { expect(within(tableHeaders[2]).getByText('Description')).toBeInTheDocument(); expect(within(tableHeaders[3]).getByText('Owner')).toBeInTheDocument(); expect(within(tableHeaders[4]).getByText('Deployed versions')).toBeInTheDocument(); - expect(within(tableHeaders[5]).getByText('Created at')).toBeInTheDocument(); + expect(within(tableHeaders[5]).getByText('Last updated')).toBeInTheDocument(); }); it('should render consistent table body', () => { @@ -80,27 +68,13 @@ describe('', () => { expect(model1Cells).not.toBeUndefined(); expect(within(model1Cells!.item(1)).getByText(tableData[0].latest_version)).toBeInTheDocument(); expect(within(model1Cells!.item(2)).getByText(tableData[0].description)).toBeInTheDocument(); - expect( - within(model1Cells!.item(3)).getByText(tableData[0].owner.slice(0, 1)) - ).toBeInTheDocument(); + expect(within(model1Cells!.item(3)).getByText('f')).toBeInTheDocument(); expect( within(model1Cells!.item(4)).getByText(tableData[0].deployed_versions.join(', ')) ).toBeInTheDocument(); expect( - within(model1Cells!.item(5)).getByText( - moment(tableData[0].created_time).format('MMM D, YYYY') - ) + within(model1Cells!.item(5)).getByText('May 10, 2023 @ 06:18:19.637') ).toBeInTheDocument(); - - const model2FirstCellContent = renderResult.getByText('New model'); - expect(model2FirstCellContent).toBeInTheDocument(); - const model2Cells = model2FirstCellContent.closest('tr')?.querySelectorAll('td'); - expect(model2Cells).not.toBeUndefined(); - expect(within(model2Cells!.item(1)).getByRole('progressbar')).toBeInTheDocument(); - expect(within(model2Cells!.item(2)).getByText('...')).toBeInTheDocument(); - expect(within(model2Cells!.item(3)).getByRole('progressbar')).toBeInTheDocument(); - expect(within(model2Cells!.item(4)).getByText('updating')).toBeInTheDocument(); - expect(within(model2Cells!.item(5)).getByText('updating')).toBeInTheDocument(); }); it('should call onChange with consistent params after pageSize change', async () => { @@ -114,7 +88,7 @@ describe('', () => { pageSize: 50, }, sort: { - field: 'created_time', + field: 'last_updated_time', direction: 'desc', }, }); @@ -130,7 +104,7 @@ describe('', () => { pageSize: 15, }, sort: { - field: 'created_time', + field: 'last_updated_time', direction: 'desc', }, }); @@ -139,24 +113,23 @@ describe('', () => { it('should call onChange with consistent params after sort change', async () => { const { renderResult, onChangeMock } = setup(); expect(onChangeMock).not.toHaveBeenCalled(); - await userEvent.click(renderResult.getByTitle('Created at')); + await userEvent.click(renderResult.getByTitle('Last updated')); expect(onChangeMock).toHaveBeenCalledWith({ pagination: { currentPage: 1, pageSize: 15, }, sort: { - field: 'created_time', + field: 'last_updated_time', direction: 'asc', }, }); }); - it('should call onModelNameClick with consistent params after model name click', async () => { - const { renderResult, onModelNameClickMock } = setup(); - expect(onModelNameClickMock).not.toHaveBeenCalled(); + it('should redirect to model detail page after model name click', async () => { + const { renderResult } = setup(); await userEvent.click(renderResult.getByText('model1')); - expect(onModelNameClickMock).toHaveBeenCalledWith('model1'); + expect(location.href).toContain('model-registry/model/1'); }); it('should show loading screen if property loading equal true', () => { @@ -225,4 +198,14 @@ describe('', () => { await userEvent.click(screen.getByText('Reset search and filters')); expect(onResetClickMock).toHaveBeenCalled(); }); + + it('should navigate to model register version page after register version clicked', async () => { + setup(); + + await userEvent.click( + within(screen.getByText('model1').closest('tr')!).getByLabelText('Register new version') + ); + + expect(location.href).toContain('model-registry/register-model/1'); + }); }); diff --git a/public/components/model_list/__tests__/model_table_uploading_cell.test.tsx b/public/components/model_list/__tests__/model_table_uploading_cell.test.tsx index 45dc12a6..84d20909 100644 --- a/public/components/model_list/__tests__/model_table_uploading_cell.test.tsx +++ b/public/components/model_list/__tests__/model_table_uploading_cell.test.tsx @@ -7,7 +7,7 @@ import React from 'react'; import { ModelTableUploadingCell } from '../model_table_uploading_cell'; import { render, screen } from '../../../../test/test_utils'; -import { MODEL_STATE } from '../../../../common/model'; +import { MODEL_VERSION_STATE } from '../../../../common'; describe('', () => { it('should render "updating" if column is deployedVersions or createdAt', () => { @@ -15,7 +15,7 @@ describe('', () => { } - latestVersionState={MODEL_STATE.uploading} + latestVersionState={MODEL_VERSION_STATE.registering} /> ); expect(screen.getByText('updating')).toBeInTheDocument(); @@ -24,7 +24,7 @@ describe('', () => { } - latestVersionState={MODEL_STATE.uploading} + latestVersionState={MODEL_VERSION_STATE.registering} /> ); expect(screen.getByText('updating')).toBeInTheDocument(); @@ -35,7 +35,7 @@ describe('', () => { } - latestVersionState={MODEL_STATE.uploading} + latestVersionState={MODEL_VERSION_STATE.registering} /> ); expect(screen.getByRole('progressbar')).toBeInTheDocument(); @@ -44,7 +44,7 @@ describe('', () => { } - latestVersionState={MODEL_STATE.uploading} + latestVersionState={MODEL_VERSION_STATE.registering} /> ); expect(screen.getByRole('progressbar')).toBeInTheDocument(); @@ -55,7 +55,7 @@ describe('', () => { } - latestVersionState={MODEL_STATE.uploading} + latestVersionState={MODEL_VERSION_STATE.registering} /> ); expect(screen.getByText('New model')).toBeInTheDocument(); @@ -66,7 +66,7 @@ describe('', () => { } - latestVersionState={MODEL_STATE.uploading} + latestVersionState={MODEL_VERSION_STATE.registering} /> ); expect(screen.getByText('...')).toBeInTheDocument(); @@ -77,7 +77,7 @@ describe('', () => { Foo Bar} - latestVersionState={MODEL_STATE.loaded} + latestVersionState={MODEL_VERSION_STATE.deployed} /> ); expect(screen.getByText('Foo Bar')).toBeInTheDocument(); diff --git a/public/components/model_list/index.tsx b/public/components/model_list/index.tsx index 5ed97120..c96610d0 100644 --- a/public/components/model_list/index.tsx +++ b/public/components/model_list/index.tsx @@ -4,48 +4,102 @@ */ import React, { useState, useCallback, useMemo, useRef } from 'react'; import { EuiPageHeader, EuiSpacer, EuiPanel, EuiTextColor } from '@elastic/eui'; -import { CoreStart } from '../../../../../src/core/public'; + import { APIProvider } from '../../apis/api_provider'; import { useFetcher } from '../../hooks/use_fetcher'; -import { ModelDrawer } from '../model_drawer'; +import { MODEL_VERSION_STATE, ModelAggregateSort } from '../../../common'; + import { ModelTable, ModelTableCriteria, ModelTableSort } from './model_table'; import { ModelListFilter, ModelListFilterFilterValue } from './model_list_filter'; import { RegisterNewModelButton } from './register_new_model_button'; -import { - ModelConfirmDeleteModal, - ModelConfirmDeleteModalInstance, -} from './model_confirm_delete_modal'; import { ModelListEmpty } from './model_list_empty'; -export const ModelList = ({ notifications }: { notifications: CoreStart['notifications'] }) => { - const confirmModelDeleteRef = useRef(null); - const [params, setParams] = useState<{ - sort: ModelTableSort; - currentPage: number; - pageSize: number; - filterValue: ModelListFilterFilterValue; - }>({ +const getStatesParam = (deployed?: boolean) => { + if (deployed) { + return [MODEL_VERSION_STATE.deployed]; + } + if (deployed === false) { + return [ + MODEL_VERSION_STATE.deployFailed, + MODEL_VERSION_STATE.deploying, + MODEL_VERSION_STATE.partiallyDeployed, + MODEL_VERSION_STATE.registerFailed, + MODEL_VERSION_STATE.undeployed, + MODEL_VERSION_STATE.registered, + MODEL_VERSION_STATE.registering, + ]; + } + return undefined; +}; + +interface Params { + sort: ModelTableSort; + currentPage: number; + pageSize: number; + filterValue: ModelListFilterFilterValue; +} + +const getModelAggregateSearchParams = (params: Params) => { + return { + from: (params.currentPage - 1) * params.pageSize, + size: params.pageSize, + sort: params.sort + ? (`${params.sort.field}-${params.sort.direction}` as ModelAggregateSort) + : undefined, + states: getStatesParam(params.filterValue.deployed), + extraQuery: params.filterValue.search + ? JSON.stringify({ + bool: { + should: [ + { + match_phrase: { + name: params.filterValue.search, + }, + }, + { + match_phrase: { + description: params.filterValue.search, + }, + }, + { + nested: { + path: 'owner', + query: { + term: { + 'owner.name.keyword': { + value: params.filterValue.search, + boost: 1, + }, + }, + }, + }, + }, + ], + }, + }) + : undefined, + }; +}; + +export const ModelList = () => { + const [params, setParams] = useState({ currentPage: 1, pageSize: 15, filterValue: { tag: [], owner: [] }, - sort: { field: 'created_time', direction: 'desc' }, + sort: { field: 'last_updated_time', direction: 'desc' }, }); - const [drawerModelName, setDrawerModelName] = useState(''); const searchInputRef = useRef(); const setSearchInputRef = useCallback((node: HTMLInputElement | null) => { searchInputRef.current = node; }, []); - const { data, reload, loading, error } = useFetcher(APIProvider.getAPI('modelAggregate').search, { - from: Math.max(0, (params.currentPage - 1) * params.pageSize), - size: params.pageSize, - sort: params.sort?.field, - order: params.sort?.direction, - name: params.filterValue.search, - }); + const { data, loading, error } = useFetcher( + APIProvider.getAPI('modelAggregate').search, + getModelAggregateSearchParams(params) + ); const models = useMemo(() => data?.data || [], [data]); - const totalModelCounts = data?.total_models || 0; + const totalModelCounts = data?.total_models; const pagination = useMemo( () => ({ @@ -55,20 +109,13 @@ export const ModelList = ({ notifications }: { notifications: CoreStart['notific }), [totalModelCounts, params.currentPage, params.pageSize] ); - const showEmptyScreen = !loading && totalModelCounts === 0 && !params.filterValue.search; - - const handleModelDeleted = useCallback(async () => { - reload(); - notifications.toasts.addSuccess('Model has been deleted.'); - }, [reload, notifications.toasts]); - - const handleModelDelete = useCallback((modelId: string) => { - confirmModelDeleteRef.current?.show(modelId); - }, []); - - const handleViewModelDrawer = useCallback((name: string) => { - setDrawerModelName(name); - }, []); + const showEmptyScreen = + !loading && + totalModelCounts === 0 && + !params.filterValue.search && + params.filterValue.deployed === undefined && + params.filterValue.tag.length === 0 && + params.filterValue.owner.length === 0; const handleTableChange = useCallback((criteria: ModelTableCriteria) => { const { pagination: newPagination, sort } = criteria; @@ -136,14 +183,9 @@ export const ModelList = ({ notifications }: { notifications: CoreStart['notific models={models} pagination={pagination} onChange={handleTableChange} - onModelNameClick={handleViewModelDrawer} onResetClick={handleReset} error={!!error} /> - - {drawerModelName && ( - setDrawerModelName('')} name={drawerModelName} /> - )} )} {showEmptyScreen && } diff --git a/public/components/model_list/model_confirm_delete_modal.tsx b/public/components/model_list/model_confirm_delete_modal.tsx index a224699b..48e8042b 100644 --- a/public/components/model_list/model_confirm_delete_modal.tsx +++ b/public/components/model_list/model_confirm_delete_modal.tsx @@ -28,12 +28,12 @@ export const ModelConfirmDeleteModal = React.forwardRef< } return ( ( - await APIProvider.getAPI('model').search({ + await APIProvider.getAPI('modelVersion').search({ ids: [deleteIdRef.current], from: 0, size: 1, }) - ).total_models === 1 + ).total_model_versions === 1 ); }, onGiveUp: () => { @@ -54,7 +54,7 @@ export const ModelConfirmDeleteModal = React.forwardRef< } e.stopPropagation(); setIsDeleting(true); - await APIProvider.getAPI('model').delete(deleteIdRef.current); + await APIProvider.getAPI('modelVersion').delete(deleteIdRef.current); startPolling(); }, [startPolling] diff --git a/public/components/model_list/model_table.tsx b/public/components/model_list/model_table.tsx index 3023214e..478d49eb 100644 --- a/public/components/model_list/model_table.tsx +++ b/public/components/model_list/model_table.tsx @@ -15,17 +15,21 @@ import { EuiButton, EuiSpacer, EuiIcon, + EuiButtonIcon, + EuiLink, } from '@elastic/eui'; - import { Criteria } from '@elastic/eui'; -import { renderTime } from '../../utils'; +import { Link, generatePath } from 'react-router-dom'; + +import { ModelAggregateItem } from '../../../common'; +import { UiSettingDateFormatTime } from '../common'; +import { routerPaths } from '../../../common'; + import { ModelOwner } from './model_owner'; import { ModelDeployedVersions } from './model_deployed_versions'; -import { ModelTableUploadingCell } from './model_table_uploading_cell'; -import { ModelAggregateSearchItem } from '../../apis/model_aggregate'; export interface ModelTableSort { - field: 'created_time'; + field: 'name' | 'latest_version' | 'description' | 'owner_name' | 'last_updated_time'; direction: Direction; } @@ -35,7 +39,7 @@ export interface ModelTableCriteria { } export interface ModelTableProps { - models: ModelAggregateSearchItem[]; + models: ModelAggregateItem[]; pagination: { currentPage: number; pageSize: number; @@ -43,117 +47,78 @@ export interface ModelTableProps { }; sort: ModelTableSort; onChange: (criteria: ModelTableCriteria) => void; - onModelNameClick: (name: string) => void; loading: boolean; error: boolean; onResetClick: () => void; } export function ModelTable(props: ModelTableProps) { - const { models, sort, onChange, onModelNameClick, loading, onResetClick, error } = props; + const { models, sort, onChange, loading, onResetClick, error } = props; const onChangeRef = useRef(onChange); onChangeRef.current = onChange; - const columns = useMemo>>( + const columns = useMemo>>( () => [ { field: 'name', name: 'Model Name', width: '266px', render: (name: string, record) => ( - { - onModelNameClick(name); - }} - style={{ color: '#006BB4' }} - > - {name} - - } - latestVersionState={record.latest_version_state} - column="name" - /> + + {name} + ), + sortable: true, }, { field: 'latest_version', name: 'Latest version', width: '98px', align: 'center', - render: (latestVersion: string, record) => ( - {latestVersion}} - latestVersionState={record.latest_version_state} - column="latestVersion" - /> - ), + sortable: true, }, { field: 'description', name: 'Description', - render: (description: string, record) => ( - {description}} - latestVersionState={record.latest_version_state} - column="description" - /> - ), }, { - field: 'owner', + field: 'owner_name', name: 'Owner', width: '79px', - render: (owner: string, record) => ( - } - latestVersionState={record.latest_version_state} - column="owner" - /> - ), + render: (name: string) => , align: 'center', + sortable: true, }, { field: 'deployed_versions', name: 'Deployed versions', - render: (deployedVersions: string[], record) => ( - } - latestVersionState={record.latest_version_state} - column="deployedVersions" - /> + render: (deployedVersions: string[]) => ( + ), }, { - field: 'created_time', - name: 'Created at', - render: (createdTime: string, record) => ( - {renderTime(createdTime, 'MMM D, YYYY')}} - latestVersionState={record.latest_version_state} - column="createdAt" - /> - ), + field: 'last_updated_time', + name: 'Last updated', + render: (lastUpdatedTime: number) => , sortable: true, }, { name: 'Actions', actions: [ - // TODO: add a new task to update after design completed { - name: 'Prevew', - description: 'Preview model group', - type: 'icon', - icon: 'boxesHorizontal', - onClick: ({ name }) => { - onModelNameClick(name); - }, + render: ({ id }) => ( + + + + ), + }, + { + render: () => , }, ], }, ], - [onModelNameClick] + [] ); const pagination = useMemo( @@ -161,7 +126,7 @@ export function ModelTable(props: ModelTableProps) { pageIndex: props.pagination.currentPage - 1, pageSize: props.pagination.pageSize, totalItemCount: props.pagination.totalRecords || 0, - pageSizeOptions: [15, 30, 50, 100], + pageSizeOptions: [10, 20, 50], showPerPageOptions: true, }), [props.pagination] @@ -228,7 +193,7 @@ export function ModelTable(props: ModelTableProps) { [onResetClick, loading, error] ); - const handleChange = useCallback((criteria: Criteria) => { + const handleChange = useCallback((criteria: Criteria) => { onChangeRef.current({ ...(criteria.page ? { pagination: { currentPage: criteria.page.index + 1, pageSize: criteria.page.size } } @@ -238,7 +203,7 @@ export function ModelTable(props: ModelTableProps) { }, []); return ( - + columns={columns} items={loading || error ? [] : models} pagination={models.length > 0 ? pagination : undefined} diff --git a/public/components/model_list/model_table_uploading_cell.tsx b/public/components/model_list/model_table_uploading_cell.tsx index a62b015f..8180f418 100644 --- a/public/components/model_list/model_table_uploading_cell.tsx +++ b/public/components/model_list/model_table_uploading_cell.tsx @@ -5,7 +5,7 @@ import React from 'react'; import { EuiLoadingSpinner, EuiText } from '@elastic/eui'; -import { MODEL_STATE } from '../../../common/model'; +import { MODEL_VERSION_STATE } from '../../../common'; type ColumnType = | 'name' @@ -32,10 +32,10 @@ export const ModelTableUploadingCell = ({ latestVersionState, }: { column: ColumnType; - latestVersionState: MODEL_STATE; + latestVersionState: MODEL_VERSION_STATE; fallback: JSX.Element; }) => { - if (latestVersionState !== MODEL_STATE.uploading) { + if (latestVersionState !== MODEL_VERSION_STATE.registering) { return fallback; } if (column === 'latestVersion' || column === 'owner') { diff --git a/public/components/model_version/model_version.tsx b/public/components/model_version/model_version.tsx index 9562e819..ab8a2ce5 100644 --- a/public/components/model_version/model_version.tsx +++ b/public/components/model_version/model_version.tsx @@ -16,14 +16,14 @@ import { EuiTabbedContent, } from '@elastic/eui'; import { generatePath, useHistory, useParams } from 'react-router-dom'; - import { FormProvider, useForm } from 'react-hook-form'; + +import { MODEL_VERSION_STATE, routerPaths } from '../../../common'; import { useFetcher } from '../../hooks'; import { APIProvider } from '../../apis/api_provider'; -import { routerPaths } from '../../../common/router_paths'; + import { VersionToggler } from './version_toggler'; import { ModelVersionCallout } from './version_callout'; -import { MODEL_STATE } from '../../../common/model'; import { ModelVersionDetails } from './version_details'; import { ModelVersionInformation } from './version_information'; import { ModelVersionArtifact } from './version_artifact'; @@ -32,7 +32,7 @@ import { ModelVersionFormData } from './types'; export const ModelVersion = () => { const { id: modelId } = useParams<{ id: string }>(); - const { data: model, loading } = useFetcher(APIProvider.getAPI('model').getOne, modelId); + const { data: model, loading } = useFetcher(APIProvider.getAPI('modelVersion').getOne, modelId); const [modelInfo, setModelInfo] = useState<{ version: string; name: string }>(); const history = useHistory(); const modelName = model?.name; @@ -153,8 +153,8 @@ export const ModelVersion = () => { ]} /> )} - - + + {loading ? ( diff --git a/public/components/model_version/version_callout.tsx b/public/components/model_version/version_callout.tsx index 8e2206db..27f53525 100644 --- a/public/components/model_version/version_callout.tsx +++ b/public/components/model_version/version_callout.tsx @@ -5,36 +5,36 @@ import React, { useEffect } from 'react'; import { EuiCallOut, EuiLoadingSpinner } from '@elastic/eui'; -import { MODEL_STATE } from '../../../common/model'; +import { MODEL_VERSION_STATE } from '../../../common'; interface ModelVersionCalloutProps { modelVersionId: string; - modelState: MODEL_STATE; + modelState: MODEL_VERSION_STATE; } -const MODEL_STATE_MAPPING: { - [K in MODEL_STATE]?: { +const MODEL_VERSION_STATE_MAPPING: { + [K in MODEL_VERSION_STATE]?: { title: React.ReactNode; color: 'danger' | 'warning' | 'primary'; iconType?: string; }; } = { - [MODEL_STATE.registerFailed]: { + [MODEL_VERSION_STATE.registerFailed]: { title: 'Artifact upload failed', color: 'danger' as const, iconType: 'alert', }, - [MODEL_STATE.loadFailed]: { + [MODEL_VERSION_STATE.deployFailed]: { title: 'Deployment failed', color: 'danger' as const, iconType: 'alert', }, - [MODEL_STATE.partiallyLoaded]: { + [MODEL_VERSION_STATE.partiallyDeployed]: { title: 'Model partially responding', color: 'warning' as const, iconType: 'alert', }, - [MODEL_STATE.uploading]: { + [MODEL_VERSION_STATE.registering]: { title: ( @@ -43,7 +43,7 @@ const MODEL_STATE_MAPPING: { ), color: 'primary' as const, }, - [MODEL_STATE.loading]: { + [MODEL_VERSION_STATE.deploying]: { title: ( @@ -55,13 +55,13 @@ const MODEL_STATE_MAPPING: { }; export const ModelVersionCallout = ({ modelState, modelVersionId }: ModelVersionCalloutProps) => { - const calloutProps = MODEL_STATE_MAPPING[modelState]; + const calloutProps = MODEL_VERSION_STATE_MAPPING[modelState]; useEffect(() => { if (calloutProps) { - if (modelState === MODEL_STATE.loadFailed) { + if (modelState === MODEL_VERSION_STATE.deployFailed) { // TODO: call task API to get the error details - } else if (modelState === MODEL_STATE.registerFailed) { + } else if (modelState === MODEL_VERSION_STATE.registerFailed) { // TODO: call task API to get the error details } } diff --git a/public/components/model_version/version_toggler.tsx b/public/components/model_version/version_toggler.tsx index 7bab0c9f..ab14702b 100644 --- a/public/components/model_version/version_toggler.tsx +++ b/public/components/model_version/version_toggler.tsx @@ -29,7 +29,7 @@ export const VersionToggler = ({ onVersionChange, }: VersionTogglerProps) => { const [isPopoverOpen, setIsPopoverOpen] = useState(false); - const { data: versions } = useFetcher(APIProvider.getAPI('model').search, { + const { data: versions } = useFetcher(APIProvider.getAPI('modelVersion').search, { name: modelName, from: 0, // TODO: Implement scroll bottom load more once version toggler UX confirmed diff --git a/public/components/monitoring/model_deployment_table.tsx b/public/components/monitoring/model_deployment_table.tsx index 408193ba..b42e74bc 100644 --- a/public/components/monitoring/model_deployment_table.tsx +++ b/public/components/monitoring/model_deployment_table.tsx @@ -21,7 +21,7 @@ import { EuiText, } from '@elastic/eui'; -import { MODEL_STATE } from '../../../common'; +import { MODEL_VERSION_STATE } from '../../../common'; export interface ModelDeploymentTableSort { field: 'name' | 'model_state' | 'id'; @@ -36,7 +36,7 @@ export interface ModelDeploymentTableCriteria { export interface ModelDeploymentItem { id: string; name: string; - model_state?: MODEL_STATE; + model_state?: MODEL_VERSION_STATE; respondingNodesCount: number | undefined; planningNodesCount: number | undefined; notRespondingNodesCount: number | undefined; diff --git a/public/components/monitoring/use_monitoring.ts b/public/components/monitoring/use_monitoring.ts index d11d3d91..462632e6 100644 --- a/public/components/monitoring/use_monitoring.ts +++ b/public/components/monitoring/use_monitoring.ts @@ -8,7 +8,7 @@ import { useMemo, useCallback, useState, useContext, useEffect } from 'react'; import { APIProvider } from '../../apis/api_provider'; import { GetAllConnectorResponse } from '../../apis/connector'; import { DO_NOT_FETCH, useFetcher } from '../../hooks/use_fetcher'; -import { MODEL_STATE } from '../../../common'; +import { MODEL_VERSION_STATE } from '../../../common'; import { DataSourceContext } from '../../contexts'; import { ModelDeployStatus } from './types'; import { DATA_SOURCE_FETCHING_ID, DataSourceId, getDataSourceId } from '../../utils/data_source'; @@ -90,11 +90,11 @@ const fetchDeployedModels = async ( const states = params.status?.map((status) => { switch (status) { case 'not-responding': - return MODEL_STATE.loadFailed; + return MODEL_VERSION_STATE.deployFailed; case 'responding': - return MODEL_STATE.loaded; + return MODEL_VERSION_STATE.deployed; case 'partial-responding': - return MODEL_STATE.partiallyLoaded; + return MODEL_VERSION_STATE.partiallyDeployed; } }); let externalConnectorsData: GetAllConnectorResponse; @@ -105,13 +105,17 @@ const fetchDeployedModels = async ( } catch (_e) { externalConnectorsData = { data: [], total_connectors: 0 }; } - const result = await APIProvider.getAPI('model').search({ + const result = await APIProvider.getAPI('modelVersion').search({ from: (params.currentPage - 1) * params.pageSize, size: params.pageSize, nameOrId: params.nameOrId, states: !states || states.length === 0 - ? [MODEL_STATE.loadFailed, MODEL_STATE.loaded, MODEL_STATE.partiallyLoaded] + ? [ + MODEL_VERSION_STATE.deployFailed, + MODEL_VERSION_STATE.deployed, + MODEL_VERSION_STATE.partiallyDeployed, + ] : states, sort: [`${params.sort.field}-${params.sort.direction}`], extraQuery: generateExtraQuery({ @@ -136,12 +140,12 @@ const fetchDeployedModels = async ( }; }>((previousValue, currentValue) => ({ ...previousValue, [currentValue.id]: currentValue }), {}); - const totalPages = Math.ceil(result.total_models / params.pageSize); + const totalPages = Math.ceil(result.total_model_versions / params.pageSize); return { pagination: { currentPage: params.currentPage, pageSize: params.pageSize, - totalRecords: result.total_models, + totalRecords: result.total_model_versions, totalPages, }, data: result.data.map( diff --git a/public/components/register_model/__tests__/model_file_uploader_manager.test.ts b/public/components/register_model/__tests__/model_file_uploader_manager.test.ts index b0f95186..de2d7d4b 100644 --- a/public/components/register_model/__tests__/model_file_uploader_manager.test.ts +++ b/public/components/register_model/__tests__/model_file_uploader_manager.test.ts @@ -4,14 +4,14 @@ */ import { waitFor } from '@testing-library/dom'; -import { Model } from '../../../../public/apis/model'; +import { ModelVersion } from '../../../../public/apis/model_version'; import { ModelFileUploadManager } from '../model_file_upload_manager'; describe('ModelFileUploadManager', () => { const uploadChunkMock = jest.fn(); beforeEach(() => { - jest.spyOn(Model.prototype, 'uploadChunk').mockImplementation(uploadChunkMock); + jest.spyOn(ModelVersion.prototype, 'uploadChunk').mockImplementation(uploadChunkMock); }); afterEach(() => { @@ -71,7 +71,7 @@ describe('ModelFileUploadManager', () => { }); it('should call onError', async () => { - jest.spyOn(Model.prototype, 'uploadChunk').mockRejectedValue(new Error()); + jest.spyOn(ModelVersion.prototype, 'uploadChunk').mockRejectedValue(new Error()); const onErrorMock = jest.fn(); const uploader = new ModelFileUploadManager(); const file = new File(['test model file'], 'model.zip', { type: 'application/zip' }); diff --git a/public/components/register_model/__tests__/register_model_api.test.ts b/public/components/register_model/__tests__/register_model_api.test.ts index f2dc7196..c906c232 100644 --- a/public/components/register_model/__tests__/register_model_api.test.ts +++ b/public/components/register_model/__tests__/register_model_api.test.ts @@ -3,27 +3,27 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { ModelGroup } from '../../../apis/model_group'; import { Model } from '../../../apis/model'; +import { ModelVersion } from '../../../apis/model_version'; import { submitModelWithFile, submitModelWithURL } from '../register_model_api'; describe('register model api', () => { beforeEach(() => { + jest.spyOn(Model.prototype, 'register').mockResolvedValue({ model_id: '1', status: 'CREATED' }); + jest.spyOn(Model.prototype, 'delete').mockResolvedValue({ status: 'success' }); jest - .spyOn(ModelGroup.prototype, 'register') - .mockResolvedValue({ model_group_id: '1', status: 'CREATED' }); - jest.spyOn(ModelGroup.prototype, 'delete').mockResolvedValue({ status: 'success' }); - jest.spyOn(Model.prototype, 'upload').mockResolvedValue({ task_id: 'foo', model_id: 'bar' }); + .spyOn(ModelVersion.prototype, 'upload') + .mockResolvedValue({ task_id: 'foo', model_version_id: 'bar' }); }); afterEach(() => { - jest.spyOn(ModelGroup.prototype, 'register').mockRestore(); - jest.spyOn(ModelGroup.prototype, 'delete').mockRestore(); - jest.spyOn(Model.prototype, 'upload').mockRestore(); + jest.spyOn(Model.prototype, 'register').mockRestore(); + jest.spyOn(Model.prototype, 'delete').mockRestore(); + jest.spyOn(ModelVersion.prototype, 'upload').mockRestore(); }); it('should not call register model group API if modelId provided', async () => { - expect(ModelGroup.prototype.register).not.toHaveBeenCalled(); + expect(Model.prototype.register).not.toHaveBeenCalled(); await submitModelWithFile({ name: 'foo', @@ -34,12 +34,12 @@ describe('register model api', () => { modelFile: new File([], 'artifact.zip'), }); - expect(ModelGroup.prototype.register).not.toHaveBeenCalled(); + expect(Model.prototype.register).not.toHaveBeenCalled(); }); it('should not call delete model group API if modelId provided and model upload failed', async () => { const uploadError = new Error(); - const uploadMock = jest.spyOn(Model.prototype, 'upload').mockRejectedValue(uploadError); + const uploadMock = jest.spyOn(ModelVersion.prototype, 'upload').mockRejectedValue(uploadError); try { await submitModelWithFile({ @@ -53,14 +53,14 @@ describe('register model api', () => { } catch (error) { expect(error).toBe(uploadError); } - expect(ModelGroup.prototype.delete).not.toHaveBeenCalled(); + expect(Model.prototype.delete).not.toHaveBeenCalled(); uploadMock.mockRestore(); }); describe('submitModelWithFile', () => { it('should call register model group API with name and description', async () => { - expect(ModelGroup.prototype.register).not.toHaveBeenCalled(); + expect(Model.prototype.register).not.toHaveBeenCalled(); await submitModelWithFile({ name: 'foo', @@ -70,7 +70,7 @@ describe('register model api', () => { modelFile: new File([], 'artifact.zip'), }); - expect(ModelGroup.prototype.register).toHaveBeenCalledWith( + expect(Model.prototype.register).toHaveBeenCalledWith( expect.objectContaining({ name: 'foo', description: 'bar', @@ -80,9 +80,11 @@ describe('register model api', () => { it('should delete created model group API upload failed', async () => { const uploadError = new Error(); - const uploadMock = jest.spyOn(Model.prototype, 'upload').mockRejectedValue(uploadError); + const uploadMock = jest + .spyOn(ModelVersion.prototype, 'upload') + .mockRejectedValue(uploadError); - expect(ModelGroup.prototype.delete).not.toHaveBeenCalled(); + expect(Model.prototype.delete).not.toHaveBeenCalled(); try { await submitModelWithFile({ name: 'foo', @@ -94,7 +96,7 @@ describe('register model api', () => { } catch (error) { expect(uploadError).toBe(error); } - expect(ModelGroup.prototype.delete).toHaveBeenCalledWith('1'); + expect(Model.prototype.delete).toHaveBeenCalledWith('1'); uploadMock.mockRestore(); }); @@ -117,7 +119,7 @@ describe('register model api', () => { describe('submitModelWithURL', () => { it('should call register model group API with name and description', async () => { - expect(ModelGroup.prototype.register).not.toHaveBeenCalled(); + expect(Model.prototype.register).not.toHaveBeenCalled(); await submitModelWithURL({ name: 'foo', @@ -127,7 +129,7 @@ describe('register model api', () => { modelURL: 'https://address.to/artifact.zip', }); - expect(ModelGroup.prototype.register).toHaveBeenCalledWith( + expect(Model.prototype.register).toHaveBeenCalledWith( expect.objectContaining({ name: 'foo', description: 'bar', @@ -137,9 +139,11 @@ describe('register model api', () => { it('should delete created model group API upload failed', async () => { const uploadError = new Error(); - const uploadMock = jest.spyOn(Model.prototype, 'upload').mockRejectedValue(uploadError); + const uploadMock = jest + .spyOn(ModelVersion.prototype, 'upload') + .mockRejectedValue(uploadError); - expect(ModelGroup.prototype.delete).not.toHaveBeenCalled(); + expect(Model.prototype.delete).not.toHaveBeenCalled(); try { await submitModelWithURL({ name: 'foo', @@ -151,7 +155,7 @@ describe('register model api', () => { } catch (error) { expect(uploadError).toBe(error); } - expect(ModelGroup.prototype.delete).toHaveBeenCalledWith('1'); + expect(Model.prototype.delete).toHaveBeenCalledWith('1'); uploadMock.mockRestore(); }); diff --git a/public/components/register_model/__tests__/register_model_form.test.tsx b/public/components/register_model/__tests__/register_model_form.test.tsx index 67496418..90607386 100644 --- a/public/components/register_model/__tests__/register_model_form.test.tsx +++ b/public/components/register_model/__tests__/register_model_form.test.tsx @@ -2,12 +2,9 @@ * Copyright OpenSearch Contributors * SPDX-License-Identifier: Apache-2.0 */ - -import React from 'react'; - -import { render, screen, waitFor } from '../../../../test/test_utils'; +import { screen, waitFor } from '../../../../test/test_utils'; import { setup } from './setup'; -import { Model } from '../../../../public/apis/model'; +import { ModelVersion } from '../../../../public/apis/model_version'; import * as PluginContext from '../../../../../../src/plugins/opensearch_dashboards_react/public'; import * as formAPI from '../register_model_api'; @@ -41,7 +38,7 @@ describe(' Form', () => { }, }); jest.spyOn(formAPI, 'submitModelWithFile').mockImplementation(onSubmitMock); - jest.spyOn(Model.prototype, 'uploadChunk').mockResolvedValue({}); + jest.spyOn(ModelVersion.prototype, 'uploadChunk').mockResolvedValue({}); }); afterEach(() => { diff --git a/public/components/register_model/model_file_upload_manager.ts b/public/components/register_model/model_file_upload_manager.ts index 26bdc15c..91aa7bc8 100644 --- a/public/components/register_model/model_file_upload_manager.ts +++ b/public/components/register_model/model_file_upload_manager.ts @@ -39,7 +39,7 @@ export class ModelFileUploadManager { chunkSize * (i - 1), Math.min(chunkSize * i, options.file.size) ); - await APIProvider.getAPI('model').uploadChunk(options.modelId, `${i - 1}`, chunk); + await APIProvider.getAPI('modelVersion').uploadChunk(options.modelId, `${i - 1}`, chunk); return { total: totalChunks, current: i }; }) ); diff --git a/public/components/register_model/register_model.tsx b/public/components/register_model/register_model.tsx index 1b1ba5a9..a2acfca6 100644 --- a/public/components/register_model/register_model.tsx +++ b/public/components/register_model/register_model.tsx @@ -208,7 +208,7 @@ export const RegisterModelForm = ({ defaultValues = DEFAULT_VALUES }: RegisterMo const initializeForm = async () => { form.setValue('modelId', registerToModelId); try { - const data = await APIProvider.getAPI('modelGroup').getOne(registerToModelId); + const data = await APIProvider.getAPI('model').getOne(registerToModelId); // TODO: clarify which fields to pre-populate const { name } = data; form.setValue('name', name); diff --git a/public/components/register_model/register_model_api.ts b/public/components/register_model/register_model_api.ts index d58e3f25..2d283948 100644 --- a/public/components/register_model/register_model_api.ts +++ b/public/components/register_model/register_model_api.ts @@ -38,13 +38,13 @@ const createModelIfNeedAndUploadVersion = async ({ }; } modelId = ( - await APIProvider.getAPI('modelGroup').register({ + await APIProvider.getAPI('model').register({ name, description, // TODO: This value should follow form data, need to be updated after UI design confirmed modelAccessMode: 'public', }) - ).model_group_id; + ).model_id; try { return { @@ -52,7 +52,7 @@ const createModelIfNeedAndUploadVersion = async ({ modelId, }; } catch (error) { - APIProvider.getAPI('modelGroup').delete(modelId); + APIProvider.getAPI('model').delete(modelId); throw error; } }; @@ -64,9 +64,9 @@ export async function submitModelWithFile(model: ModelFileFormData) { const result = await createModelIfNeedAndUploadVersion({ ...model, uploader: (modelId: string) => - APIProvider.getAPI('model').upload({ + APIProvider.getAPI('modelVersion').upload({ ...getModelUploadBase(model), - modelGroupId: modelId, + modelId, totalChunks, modelContentHashValue, }), @@ -74,7 +74,7 @@ export async function submitModelWithFile(model: ModelFileFormData) { return { modelId: result.modelId, - modelVersionId: result.uploadResult.model_id, + modelVersionId: result.uploadResult.model_version_id, }; } @@ -82,9 +82,9 @@ export async function submitModelWithURL(model: ModelUrlFormData) { const result = await createModelIfNeedAndUploadVersion({ ...model, uploader: (modelId: string) => - APIProvider.getAPI('model').upload({ + APIProvider.getAPI('modelVersion').upload({ ...getModelUploadBase(model), - modelGroupId: modelId, + modelId, url: model.modelURL, }), }); diff --git a/server/clusters/create_model_cluster.ts b/server/clusters/create_model_cluster.ts deleted file mode 100644 index 78b1cd2c..00000000 --- a/server/clusters/create_model_cluster.ts +++ /dev/null @@ -1,15 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -import { CoreSetup } from '../../../../src/core/server'; - -import modelPlugin from './model_plugin'; -import { CLUSTER } from '../services/utils/constants'; - -export const createModelCluster = (core: CoreSetup) => { - return core.opensearch.legacy.createClient(CLUSTER.MODEL, { - plugins: [modelPlugin], - }); -}; diff --git a/server/clusters/model_plugin.ts b/server/clusters/model_plugin.ts deleted file mode 100644 index 40151a31..00000000 --- a/server/clusters/model_plugin.ts +++ /dev/null @@ -1,90 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -import { MODEL_BASE_API, MODEL_PROFILE_API } from '../services/utils/constants'; - -// eslint-disable-next-line import/no-default-export -export default function (Client: any, config: any, components: any) { - const ca = components.clientAction.factory; - - if (!Client.prototype.mlCommonsModel) { - Client.prototype.mlCommonsModel = components.clientAction.namespaceFactory(); - } - - const mlCommonsModel = Client.prototype.mlCommonsModel.prototype; - - mlCommonsModel.search = ca({ - method: 'POST', - url: { - fmt: `${MODEL_BASE_API}/_search`, - }, - needBody: true, - }); - - mlCommonsModel.getOne = ca({ - method: 'GET', - url: { - fmt: `${MODEL_BASE_API}/<%=modelId%>`, - req: { - modelId: { - type: 'string', - required: true, - }, - }, - }, - }); - - mlCommonsModel.delete = ca({ - method: 'DELETE', - url: { - fmt: `${MODEL_BASE_API}/<%=modelId%>`, - req: { - modelId: { - type: 'string', - required: true, - }, - }, - }, - }); - - mlCommonsModel.load = ca({ - method: 'POST', - url: { - fmt: `${MODEL_BASE_API}/<%=modelId%>/_load`, - req: { - modelId: { - type: 'string', - required: true, - }, - }, - }, - }); - - mlCommonsModel.unload = ca({ - method: 'POST', - url: { - fmt: `${MODEL_BASE_API}/<%=modelId%>/_unload`, - req: { - modelId: { - type: 'string', - required: true, - }, - }, - }, - }); - - mlCommonsModel.profile = ca({ - method: 'GET', - url: { - fmt: `${MODEL_PROFILE_API}/<%=modelId%>`, - req: { - modelId: { - type: 'string', - required: true, - }, - }, - }, - }); -} diff --git a/server/plugin.ts b/server/plugin.ts index d5116fb9..a964ffb2 100644 --- a/server/plugin.ts +++ b/server/plugin.ts @@ -11,19 +11,17 @@ import { Logger, } from '../../../src/core/server'; -import { createModelCluster } from './clusters/create_model_cluster'; import { MlCommonsPluginSetup, MlCommonsPluginStart } from './types'; import { connectorRouter, modelRouter, + modelVersionRouter, modelAggregateRouter, profileRouter, securityRouter, taskRouter, modelRepositoryRouter, - modelGroupRouter, } from './routes'; -import { ModelService } from './services'; export class MlCommonsPlugin implements Plugin { private readonly logger: Logger; @@ -36,22 +34,14 @@ export class MlCommonsPlugin implements Plugin { router.get( @@ -16,21 +17,40 @@ export const modelAggregateRouter = (router: IRouter) => { query: schema.object({ from: schema.number(), size: schema.number(), - sort: schema.literal('created_time'), - order: schema.oneOf([schema.literal('asc'), schema.literal('desc')]), + sort: schema.maybe( + schema.oneOf([ + schema.literal('name-asc'), + schema.literal('name-desc'), + schema.literal('latest_version-asc'), + schema.literal('latest_version-desc'), + schema.literal('description-asc'), + schema.literal('description-desc'), + schema.literal('owner_name-asc'), + schema.literal('owner_name-desc'), + schema.literal('last_updated_time-asc'), + schema.literal('last_updated_time-desc'), + ]) + ), name: schema.maybe(schema.string()), + states: schema.maybe(schema.oneOf([modelStateSchema, schema.arrayOf(modelStateSchema)])), + extraQuery: schema.maybe(schema.recordOf(schema.string(), schema.any())), }), }, }, async (context, request) => { + const { states, extraQuery, ...restQuery } = request.query; try { const payload = await ModelAggregateService.search({ client: context.core.opensearch.client, - ...request.query, + states: typeof states === 'string' ? [states] : states, + extraQuery, + ...restQuery, }); return opensearchDashboardsResponseFactory.ok({ body: payload }); - } catch (err) { - return opensearchDashboardsResponseFactory.badRequest({ body: err.message }); + } catch (error) { + return opensearchDashboardsResponseFactory.badRequest({ + body: error instanceof Error ? error.message : JSON.stringify(error), + }); } } ); diff --git a/server/routes/model_group_router.ts b/server/routes/model_group_router.ts deleted file mode 100644 index d65eb6b6..00000000 --- a/server/routes/model_group_router.ts +++ /dev/null @@ -1,139 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -import { schema } from '@osd/config-schema'; - -import { IRouter, opensearchDashboardsResponseFactory } from '../../../../src/core/server'; -import { ModelGroupService } from '../services/model_group_service'; - -import { MODEL_GROUP_API_ENDPOINT } from './constants'; - -export const modelGroupRouter = (router: IRouter) => { - router.post( - { - path: MODEL_GROUP_API_ENDPOINT, - validate: { - body: schema.object({ - name: schema.string(), - description: schema.maybe(schema.string()), - modelAccessMode: schema.oneOf([ - schema.literal('public'), - schema.literal('private'), - schema.literal('restricted'), - ]), - backendRoles: schema.maybe(schema.arrayOf(schema.string())), - addAllBackendRoles: schema.maybe(schema.boolean()), - }), - }, - }, - async (context, request) => { - const { name, description, modelAccessMode, backendRoles, addAllBackendRoles } = request.body; - try { - const payload = await ModelGroupService.register({ - client: context.core.opensearch.client, - name, - description, - modelAccessMode, - backendRoles, - addAllBackendRoles, - }); - return opensearchDashboardsResponseFactory.ok({ body: payload }); - } catch (error) { - return opensearchDashboardsResponseFactory.badRequest({ - body: error instanceof Error ? error.message : JSON.stringify(error), - }); - } - } - ); - - router.put( - { - path: `${MODEL_GROUP_API_ENDPOINT}/{groupId}`, - validate: { - params: schema.object({ - groupId: schema.string(), - }), - body: schema.object({ - name: schema.maybe(schema.string()), - description: schema.maybe(schema.string()), - }), - }, - }, - async (context, request) => { - const { - params: { groupId }, - body: { name, description }, - } = request; - try { - const payload = await ModelGroupService.update({ - client: context.core.opensearch.client, - id: groupId, - name, - description, - }); - return opensearchDashboardsResponseFactory.ok({ body: payload }); - } catch (error) { - return opensearchDashboardsResponseFactory.badRequest({ - body: error instanceof Error ? error.message : JSON.stringify(error), - }); - } - } - ); - - router.delete( - { - path: `${MODEL_GROUP_API_ENDPOINT}/{groupId}`, - validate: { - params: schema.object({ - groupId: schema.string(), - }), - }, - }, - async (context, request) => { - try { - const payload = await ModelGroupService.delete({ - client: context.core.opensearch.client, - id: request.params.groupId, - }); - return opensearchDashboardsResponseFactory.ok({ body: payload }); - } catch (error) { - return opensearchDashboardsResponseFactory.badRequest({ - body: error instanceof Error ? error.message : JSON.stringify(error), - }); - } - } - ); - - router.get( - { - path: MODEL_GROUP_API_ENDPOINT, - validate: { - query: schema.object({ - id: schema.maybe(schema.string()), - name: schema.maybe(schema.string()), - from: schema.number({ min: 0 }), - size: schema.number({ max: 100 }), - }), - }, - }, - async (context, request) => { - const { id, name, from, size } = request.query; - try { - const payload = await ModelGroupService.search({ - client: context.core.opensearch.client, - id, - name, - from, - size, - }); - return opensearchDashboardsResponseFactory.ok({ body: payload }); - } catch (error) { - return opensearchDashboardsResponseFactory.badRequest({ - body: error instanceof Error ? error.message : JSON.stringify(error), - }); - } - } - ); -}; diff --git a/server/routes/model_router.ts b/server/routes/model_router.ts index 996f3c48..05ac2c5f 100644 --- a/server/routes/model_router.ts +++ b/server/routes/model_router.ts @@ -4,301 +4,137 @@ */ import { schema } from '@osd/config-schema'; -import { MAX_MODEL_CHUNK_SIZE, MODEL_STATE } from '../../common'; -import { IRouter, opensearchDashboardsResponseFactory } from '../../../../src/core/server'; -import { ModelService, RecordNotFoundError } from '../services'; -import { - MODEL_API_ENDPOINT, - MODEL_LOAD_API_ENDPOINT, - MODEL_UNLOAD_API_ENDPOINT, - MODEL_UPLOAD_API_ENDPOINT, - MODEL_PROFILE_API_ENDPOINT, -} from './constants'; -import { getOpenSearchClientTransport } from './utils'; - -const validateSortItem = (sort: string) => { - const [key, direction] = sort.split('-'); - if (key === undefined || direction === undefined) { - return 'Invalidate sort'; - } - if (direction !== 'asc' && direction !== 'desc') { - return 'Invalidate sort'; - } - const availableSortKeys = ['id', 'version', 'last_updated_time', 'name', 'model_state']; - - if (!availableSortKeys.includes(key) && !key.startsWith('tags.')) { - return 'Invalidate sort'; - } - return undefined; -}; - -const validateUniqueSort = (sort: string[]) => { - const uniqueSortKeys = new Set(sort.map((item) => item.split('-')[0])); - if (uniqueSortKeys.size < sort.length) { - return 'Invalidate sort'; - } - return undefined; -}; -const modelStateSchema = schema.oneOf([ - schema.literal(MODEL_STATE.loaded), - schema.literal(MODEL_STATE.trained), - schema.literal(MODEL_STATE.unloaded), - schema.literal(MODEL_STATE.uploaded), - schema.literal(MODEL_STATE.uploading), - schema.literal(MODEL_STATE.loading), - schema.literal(MODEL_STATE.partiallyLoaded), - schema.literal(MODEL_STATE.loadFailed), - schema.literal(MODEL_STATE.registerFailed), -]); - -const modelUploadBaseSchema = { - name: schema.string(), - version: schema.maybe(schema.string()), - description: schema.maybe(schema.string()), - modelFormat: schema.string(), - modelConfig: schema.object({}, { unknowns: 'allow' }), - modelGroupId: schema.string(), -}; - -const modelUploadByURLSchema = schema.object({ - ...modelUploadBaseSchema, - url: schema.string(), -}); - -const modelUploadByChunkSchema = schema.object({ - ...modelUploadBaseSchema, - modelContentHashValue: schema.string(), - totalChunks: schema.number(), -}); +import { IRouter, opensearchDashboardsResponseFactory } from '../../../../src/core/server'; +import { ModelService } from '../services'; -export const modelRouter = (services: { modelService: ModelService }, router: IRouter) => { - const { modelService } = services; +import { MODEL_API_ENDPOINT } from './constants'; - router.get( +export const modelRouter = (router: IRouter) => { + router.post( { path: MODEL_API_ENDPOINT, validate: { - query: schema.object({ - name: schema.maybe(schema.string()), - from: schema.number({ min: 0 }), - size: schema.number({ max: 50 }), - sort: schema.maybe( - schema.oneOf([ - schema.string({ validate: validateSortItem }), - schema.arrayOf(schema.string({ validate: validateSortItem }), { - validate: validateUniqueSort, - }), - ]) - ), - states: schema.maybe(schema.oneOf([schema.arrayOf(modelStateSchema), modelStateSchema])), - nameOrId: schema.maybe(schema.string()), - versionOrKeyword: schema.maybe(schema.string()), - extra_query: schema.maybe(schema.recordOf(schema.string(), schema.any())), - data_source_id: schema.maybe(schema.string()), - modelGroupId: schema.maybe(schema.string()), + body: schema.object({ + name: schema.string(), + description: schema.maybe(schema.string()), + modelAccessMode: schema.oneOf([ + schema.literal('public'), + schema.literal('private'), + schema.literal('restricted'), + ]), + backendRoles: schema.maybe(schema.arrayOf(schema.string())), + addAllBackendRoles: schema.maybe(schema.boolean()), }), }, }, - async (context, request, response) => { - const { - from, - size, - sort, - name, - states, - nameOrId, - extra_query: extraQuery, - data_source_id: dataSourceId, - modelGroupId, - versionOrKeyword, - } = request.query; + async (context, request) => { + const { name, description, modelAccessMode, backendRoles, addAllBackendRoles } = request.body; try { - const payload = await ModelService.search({ - transport: await getOpenSearchClientTransport({ - dataSourceId, - context, - }), - from, - size, - sort: typeof sort === 'string' ? [sort] : sort, + const payload = await ModelService.register({ + client: context.core.opensearch.client, name, - states: typeof states === 'string' ? [states] : states, - nameOrId, - extraQuery, - modelGroupId, - versionOrKeyword, + description, + modelAccessMode, + backendRoles, + addAllBackendRoles, }); - return response.ok({ body: payload }); - } catch (err) { - return response.badRequest({ body: err.message }); - } - } - ); - - router.get( - { - path: `${MODEL_API_ENDPOINT}/{modelId}`, - validate: { - params: schema.object({ - modelId: schema.string(), - }), - }, - }, - async (_context, request) => { - try { - const model = await modelService.getOne({ - request, - modelId: request.params.modelId, + return opensearchDashboardsResponseFactory.ok({ body: payload }); + } catch (error) { + return opensearchDashboardsResponseFactory.badRequest({ + body: error instanceof Error ? error.message : JSON.stringify(error), }); - return opensearchDashboardsResponseFactory.ok({ body: model }); - } catch (err) { - return opensearchDashboardsResponseFactory.badRequest({ body: err.message }); } } ); - router.delete( + router.put( { - path: `${MODEL_API_ENDPOINT}/{modelId}`, + path: `${MODEL_API_ENDPOINT}/{id}`, validate: { params: schema.object({ - modelId: schema.string(), + id: schema.string(), }), - }, - }, - async (_context, request) => { - try { - await modelService.delete({ - request, - modelId: request.params.modelId, - }); - return opensearchDashboardsResponseFactory.ok(); - } catch (err) { - if (err instanceof RecordNotFoundError) { - return opensearchDashboardsResponseFactory.notFound(); - } - return opensearchDashboardsResponseFactory.badRequest({ body: err.message }); - } - } - ); - - router.post( - { - path: `${MODEL_LOAD_API_ENDPOINT}/{modelId}`, - validate: { - params: schema.object({ - modelId: schema.string(), + body: schema.object({ + name: schema.maybe(schema.string()), + description: schema.maybe(schema.string()), }), }, }, - async (_context, request) => { + async (context, request, response) => { + const { + params: { id }, + body: { name, description }, + } = request; try { - const result = await modelService.load({ - request, - modelId: request.params.modelId, + const payload = await ModelService.update({ + client: context.core.opensearch.client, + id, + name, + description, }); - return opensearchDashboardsResponseFactory.ok({ body: result }); - } catch (err) { - return opensearchDashboardsResponseFactory.badRequest({ body: err.message }); - } - } - ); - - router.post( - { - path: `${MODEL_UNLOAD_API_ENDPOINT}/{modelId}`, - validate: { - params: schema.object({ - modelId: schema.string(), - }), - }, - }, - async (_context, request) => { - try { - const result = await modelService.unload({ - request, - modelId: request.params.modelId, + return opensearchDashboardsResponseFactory.ok({ body: payload }); + } catch (error) { + return opensearchDashboardsResponseFactory.badRequest({ + body: error instanceof Error ? error.message : JSON.stringify(error), }); - return opensearchDashboardsResponseFactory.ok({ body: result }); - } catch (err) { - return opensearchDashboardsResponseFactory.badRequest({ body: err.message }); } } ); - router.get( + router.delete( { - path: `${MODEL_PROFILE_API_ENDPOINT}/{modelId}`, + path: `${MODEL_API_ENDPOINT}/{id}`, validate: { params: schema.object({ - modelId: schema.string(), + id: schema.string(), }), }, }, - async (_context, request) => { - try { - const result = await modelService.profile({ - request, - modelId: request.params.modelId, - }); - return opensearchDashboardsResponseFactory.ok({ body: result }); - } catch (err) { - return opensearchDashboardsResponseFactory.badRequest({ body: err.message }); - } - } - ); - - router.post( - { - path: MODEL_UPLOAD_API_ENDPOINT, - validate: { - body: schema.oneOf([modelUploadByURLSchema, modelUploadByChunkSchema]), - }, - }, async (context, request) => { try { - const body = await ModelService.upload({ + const payload = await ModelService.delete({ client: context.core.opensearch.client, - model: request.body, + id: request.params.id, }); - - return opensearchDashboardsResponseFactory.ok({ - body, + return opensearchDashboardsResponseFactory.ok({ body: payload }); + } catch (error) { + return opensearchDashboardsResponseFactory.badRequest({ + body: error instanceof Error ? error.message : JSON.stringify(error), }); - } catch (err) { - return opensearchDashboardsResponseFactory.badRequest({ body: err.message }); } } ); - router.post( + router.get( { - path: `${MODEL_API_ENDPOINT}/{modelId}/chunk/{chunkId}`, + path: MODEL_API_ENDPOINT, validate: { - params: schema.object({ - modelId: schema.string(), - chunkId: schema.string(), + query: schema.object({ + ids: schema.maybe(schema.oneOf([schema.string(), schema.arrayOf(schema.string())])), + name: schema.maybe(schema.string()), + from: schema.number({ min: 0 }), + size: schema.number({ max: 100 }), + extraQuery: schema.maybe(schema.recordOf(schema.string(), schema.any())), }), - body: schema.buffer(), - }, - options: { - body: { - maxBytes: MAX_MODEL_CHUNK_SIZE, - }, }, }, async (context, request) => { + const { ids, name, from, size, extraQuery } = request.query; try { - await ModelService.uploadModelChunk({ + const payload = await ModelService.search({ client: context.core.opensearch.client, - modelId: request.params.modelId, - chunkId: request.params.chunkId, - chunk: request.body, + ids: typeof ids === 'string' ? [ids] : ids, + name, + from, + size, + extraQuery, + }); + return opensearchDashboardsResponseFactory.ok({ body: payload }); + } catch (error) { + return opensearchDashboardsResponseFactory.badRequest({ + body: error instanceof Error ? error.message : JSON.stringify(error), }); - return opensearchDashboardsResponseFactory.ok(); - } catch (err) { - return opensearchDashboardsResponseFactory.badRequest(err.message); } } ); diff --git a/server/routes/model_version_router.ts b/server/routes/model_version_router.ts new file mode 100644 index 00000000..e48d73d0 --- /dev/null +++ b/server/routes/model_version_router.ts @@ -0,0 +1,311 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import { schema } from '@osd/config-schema'; +import { MAX_MODEL_CHUNK_SIZE, MODEL_VERSION_STATE } from '../../common'; +import { IRouter, opensearchDashboardsResponseFactory } from '../../../../src/core/server'; +import { ModelVersionService, RecordNotFoundError } from '../services'; +import { + MODEL_VERSION_API_ENDPOINT, + MODEL_VERSION_LOAD_API_ENDPOINT, + MODEL_VERSION_UNLOAD_API_ENDPOINT, + MODEL_VERSION_UPLOAD_API_ENDPOINT, + MODEL_VERSION_PROFILE_API_ENDPOINT, +} from './constants'; +import { getOpenSearchClientTransport } from './utils'; + +const validateSortItem = (sort: string) => { + const [key, direction] = sort.split('-'); + if (key === undefined || direction === undefined) { + return 'Invalidate sort'; + } + if (direction !== 'asc' && direction !== 'desc') { + return 'Invalidate sort'; + } + const availableSortKeys = ['id', 'version', 'last_updated_time', 'name', 'model_state']; + + if (!availableSortKeys.includes(key) && !key.startsWith('tags.')) { + return 'Invalidate sort'; + } + return undefined; +}; + +const validateUniqueSort = (sort: string[]) => { + const uniqueSortKeys = new Set(sort.map((item) => item.split('-')[0])); + if (uniqueSortKeys.size < sort.length) { + return 'Invalidate sort'; + } + return undefined; +}; + +export const modelStateSchema = schema.oneOf([ + schema.literal(MODEL_VERSION_STATE.deployed), + schema.literal(MODEL_VERSION_STATE.trained), + schema.literal(MODEL_VERSION_STATE.undeployed), + schema.literal(MODEL_VERSION_STATE.registered), + schema.literal(MODEL_VERSION_STATE.registering), + schema.literal(MODEL_VERSION_STATE.deploying), + schema.literal(MODEL_VERSION_STATE.partiallyDeployed), + schema.literal(MODEL_VERSION_STATE.deployFailed), + schema.literal(MODEL_VERSION_STATE.registerFailed), +]); + +const modelUploadBaseSchema = { + name: schema.string(), + version: schema.maybe(schema.string()), + description: schema.maybe(schema.string()), + modelFormat: schema.string(), + modelConfig: schema.object({}, { unknowns: 'allow' }), + modelId: schema.string(), +}; + +const modelUploadByURLSchema = schema.object({ + ...modelUploadBaseSchema, + url: schema.string(), +}); + +const modelUploadByChunkSchema = schema.object({ + ...modelUploadBaseSchema, + modelContentHashValue: schema.string(), + totalChunks: schema.number(), +}); + +export const modelVersionRouter = (router: IRouter) => { + router.get( + { + path: MODEL_VERSION_API_ENDPOINT, + validate: { + query: schema.object({ + name: schema.maybe(schema.string()), + algorithms: schema.maybe( + schema.oneOf([schema.string(), schema.arrayOf(schema.string())]) + ), + ids: schema.maybe(schema.oneOf([schema.string(), schema.arrayOf(schema.string())])), + from: schema.number({ min: 0 }), + size: schema.number({ max: 50 }), + sort: schema.maybe( + schema.oneOf([ + schema.string({ validate: validateSortItem }), + schema.arrayOf(schema.string({ validate: validateSortItem }), { + validate: validateUniqueSort, + }), + ]) + ), + states: schema.maybe(schema.oneOf([schema.arrayOf(modelStateSchema), modelStateSchema])), + nameOrId: schema.maybe(schema.string()), + versionOrKeyword: schema.maybe(schema.string()), + modelIds: schema.maybe(schema.oneOf([schema.string(), schema.arrayOf(schema.string())])), + extra_query: schema.maybe(schema.recordOf(schema.string(), schema.any())), + data_source_id: schema.maybe(schema.string()), + }), + }, + }, + async (context, request) => { + const { + algorithms, + ids, + from, + size, + sort, + name, + states, + nameOrId, + modelIds, + versionOrKeyword, + extra_query: extraQuery, + data_source_id: dataSourceId, + } = request.query; + try { + const payload = await ModelVersionService.search({ + transport: await getOpenSearchClientTransport({ + context, + dataSourceId, + }), + algorithms: typeof algorithms === 'string' ? [algorithms] : algorithms, + ids: typeof ids === 'string' ? [ids] : ids, + from, + size, + sort: typeof sort === 'string' ? [sort] : sort, + name, + states: typeof states === 'string' ? [states] : states, + nameOrId, + modelIds: typeof modelIds === 'string' ? [modelIds] : modelIds, + versionOrKeyword, + extraQuery, + }); + return opensearchDashboardsResponseFactory.ok({ body: payload }); + } catch (err) { + return opensearchDashboardsResponseFactory.badRequest({ body: err.message }); + } + } + ); + + router.get( + { + path: `${MODEL_VERSION_API_ENDPOINT}/{id}`, + validate: { + params: schema.object({ + id: schema.string(), + }), + }, + }, + async (context, request) => { + try { + const model = await ModelVersionService.getOne({ + client: context.core.opensearch.client, + id: request.params.id, + }); + return opensearchDashboardsResponseFactory.ok({ body: model }); + } catch (err) { + return opensearchDashboardsResponseFactory.badRequest({ body: err.message }); + } + } + ); + + router.delete( + { + path: `${MODEL_VERSION_API_ENDPOINT}/{id}`, + validate: { + params: schema.object({ + id: schema.string(), + }), + }, + }, + async (context, request) => { + try { + await ModelVersionService.delete({ + client: context.core.opensearch.client, + id: request.params.id, + }); + return opensearchDashboardsResponseFactory.ok(); + } catch (err) { + if (err instanceof RecordNotFoundError) { + return opensearchDashboardsResponseFactory.notFound(); + } + return opensearchDashboardsResponseFactory.badRequest({ body: err.message }); + } + } + ); + + router.post( + { + path: `${MODEL_VERSION_LOAD_API_ENDPOINT}/{id}`, + validate: { + params: schema.object({ + id: schema.string(), + }), + }, + }, + async (context, request) => { + try { + const result = await ModelVersionService.load({ + client: context.core.opensearch.client, + id: request.params.id, + }); + return opensearchDashboardsResponseFactory.ok({ body: result }); + } catch (err) { + return opensearchDashboardsResponseFactory.badRequest({ body: err.message }); + } + } + ); + + router.post( + { + path: `${MODEL_VERSION_UNLOAD_API_ENDPOINT}/{id}`, + validate: { + params: schema.object({ + id: schema.string(), + }), + }, + }, + async (context, request) => { + try { + const result = await ModelVersionService.unload({ + client: context.core.opensearch.client, + id: request.params.id, + }); + return opensearchDashboardsResponseFactory.ok({ body: result }); + } catch (err) { + return opensearchDashboardsResponseFactory.badRequest({ body: err.message }); + } + } + ); + + router.get( + { + path: `${MODEL_VERSION_PROFILE_API_ENDPOINT}/{id}`, + validate: { + params: schema.object({ + id: schema.string(), + }), + }, + }, + async (context, request) => { + try { + const result = await ModelVersionService.profile({ + client: context.core.opensearch.client, + id: request.params.id, + }); + return opensearchDashboardsResponseFactory.ok({ body: result }); + } catch (err) { + return opensearchDashboardsResponseFactory.badRequest({ body: err.message }); + } + } + ); + + router.post( + { + path: MODEL_VERSION_UPLOAD_API_ENDPOINT, + validate: { + body: schema.oneOf([modelUploadByURLSchema, modelUploadByChunkSchema]), + }, + }, + async (context, request) => { + try { + const body = await ModelVersionService.upload({ + client: context.core.opensearch.client, + model: request.body, + }); + + return opensearchDashboardsResponseFactory.ok({ + body, + }); + } catch (err) { + return opensearchDashboardsResponseFactory.badRequest({ body: err.message }); + } + } + ); + + router.post( + { + path: `${MODEL_VERSION_API_ENDPOINT}/{id}/chunk/{chunkId}`, + validate: { + params: schema.object({ + id: schema.string(), + chunkId: schema.string(), + }), + body: schema.buffer(), + }, + options: { + body: { + maxBytes: MAX_MODEL_CHUNK_SIZE, + }, + }, + }, + async (context, request) => { + try { + await ModelVersionService.uploadModelChunk({ + client: context.core.opensearch.client, + id: request.params.id, + chunkId: request.params.chunkId, + chunk: request.body, + }); + return opensearchDashboardsResponseFactory.ok(); + } catch (err) { + return opensearchDashboardsResponseFactory.badRequest(err.message); + } + } + ); +}; diff --git a/server/services/index.ts b/server/services/index.ts index 865e2b7e..e3773c7c 100644 --- a/server/services/index.ts +++ b/server/services/index.ts @@ -3,6 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -export { ModelService } from './model_service'; +export { ModelVersionService } from './model_version_service'; export { TaskService } from './task_service'; export { RecordNotFoundError } from './errors'; +export { ModelService } from './model_service'; diff --git a/server/services/model_aggregate_service.ts b/server/services/model_aggregate_service.ts index 0c773c09..aa78442d 100644 --- a/server/services/model_aggregate_service.ts +++ b/server/services/model_aggregate_service.ts @@ -18,67 +18,60 @@ * permissions and limitations under the License. */ +import { groupBy } from 'lodash'; + import { IScopedClusterClient } from '../../../../src/core/server'; -import { MODEL_STATE, OpenSearchModelBase } from '../../common/model'; +import { + MODEL_VERSION_STATE, + ModelAggregateSort, + ModelAggregateItem, + ModelSort, +} from '../../common'; +import { ModelService } from './model_service'; +import { ModelVersionService } from './model_version_service'; import { MODEL_SEARCH_API } from './utils/constants'; +import { generateModelSearchQuery } from './utils/model'; const MAX_MODEL_BUCKET_NUM = 10000; +const getModelSort = (sort: ModelAggregateSort): ModelSort => { + switch (sort) { + case 'owner_name-asc': + return 'owner.name-asc'; + case 'owner_name-desc': + return 'owner.name-desc'; + default: + return sort; + } +}; interface GetAggregateModelsParams { + client: IScopedClusterClient; + states?: MODEL_VERSION_STATE[]; +} + +interface ModelAggregateSearchParams extends GetAggregateModelsParams { client: IScopedClusterClient; from: number; size: number; - name?: string; - sort: 'created_time'; - order: 'desc' | 'asc'; + sort?: ModelAggregateSort; + extraQuery?: Record; } export class ModelAggregateService { - public static async getAggregateModels({ - client, - from, - size, - sort, - name, - order, - }: GetAggregateModelsParams) { + public static async getModelIdsByVersion({ client, states }: GetAggregateModelsParams) { const aggregateResult = await client.asCurrentUser.transport.request({ method: 'GET', path: MODEL_SEARCH_API, body: { size: 0, - query: { - bool: { - must: [...(name ? [{ match: { name } }] : [])], - must_not: { - exists: { - field: 'chunk_number', - }, - }, - }, - }, + query: generateModelSearchQuery({ states }), aggs: { models: { terms: { - field: 'name.keyword', + field: 'model_group_id.keyword', size: MAX_MODEL_BUCKET_NUM, }, - aggs: { - latest_version_hits: { - top_hits: { - sort: [ - { - created_time: { - order: 'desc', - }, - }, - ], - size: 1, - _source: ['model_version', 'model_state', 'description', 'created_time'], - }, - }, - }, }, }, }, @@ -86,97 +79,49 @@ export class ModelAggregateService { const models = aggregateResult.body.aggregations.models.buckets as Array<{ key: string; doc_count: number; - latest_version_hits: { - hits: { - hits: [ - { - _source: Pick & { - created_time: number; - description?: string; - }; - } - ]; - }; - }; }>; - return { - models: models - .sort( - (a, b) => - ((a.latest_version_hits.hits.hits[0]._source.created_time ?? 0) - - (b.latest_version_hits.hits.hits[0]._source.created_time ?? 0)) * - (sort === 'created_time' && order === 'asc' ? 1 : -1) - ) - .slice(from, from + size), - total_models: models.length, - }; + return models.map(({ key }) => key); } - public static async search(params: GetAggregateModelsParams) { - const { client } = params; - const { models, total_models: totalModels } = await ModelAggregateService.getAggregateModels( - params - ); - const { names, count } = models.reduce<{ names: string[]; count: number }>( - (previous, { key, doc_count: docCount }: { key: string; doc_count: number }) => ({ - names: previous.names.concat(key), - count: docCount + previous.count, - }), - { names: [], count: 0 } - ); - const versionResult = await client.asCurrentUser.transport.request({ - method: 'GET', - path: MODEL_SEARCH_API, - body: { - size: count, - query: { - bool: { - should: names.map((name) => ({ term: { 'name.keyword': name } })), - must_not: { - exists: { - field: 'chunk_number', - }, - }, - }, - }, - _source: ['name', 'model_version', 'model_state', 'model_id'], - }, + public static async search({ + client, + from, + size, + sort, + states, + extraQuery, + }: ModelAggregateSearchParams) { + const sourceModelIds = states + ? await ModelAggregateService.getModelIdsByVersion({ client, states }) + : undefined; + const { data: models, total_models: totalModels } = await ModelService.search({ + client, + from, + size, + sort: sort ? getModelSort(sort) : sort, + ids: sourceModelIds, + extraQuery, + }); + const modelIds = models.map(({ id }) => id); + const { data: deployedModels } = await ModelVersionService.search({ + client, + from: 0, + size: MAX_MODEL_BUCKET_NUM, + modelIds, + states: [MODEL_VERSION_STATE.deployed], }); - const versionResultMap = (versionResult.body.hits.hits as Array<{ - _id: string; - _source: OpenSearchModelBase; - }>).reduce<{ - [key: string]: Array>; - }>( - (pValue, { _source: { name, ...resetProperties } }) => ({ - ...pValue, - [name]: (pValue[name] ?? []).concat(resetProperties), - }), - {} - ); + + const modelId2Version = groupBy(deployedModels, 'model_id'); + return { - data: models.map( - ({ - key, - latest_version_hits: { - hits: { hits }, - }, - }) => { - const latestVersion = hits[0]._source; - return { - name: key, - deployed_versions: (versionResultMap[key] ?? []) - .filter((item) => item.model_state === MODEL_STATE.loaded) - .map((item) => item.model_version), - // TODO: Change to the real model owner - owner: key, - latest_version: latestVersion.model_version, - latest_version_state: latestVersion.model_state, - created_time: latestVersion.created_time, - }; - } - ), + data: models.map((model) => ({ + ...model, + owner_name: model.owner.name, + deployed_versions: (modelId2Version[model.id] || []).map( + (deployedVersion) => deployedVersion.model_version + ), + })) as ModelAggregateItem[], total_models: totalModels, }; } diff --git a/server/services/model_group_service.ts b/server/services/model_group_service.ts deleted file mode 100644 index 7ad26cfd..00000000 --- a/server/services/model_group_service.ts +++ /dev/null @@ -1,132 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -/* - * Copyright OpenSearch Contributors - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * A copy of the License is located at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * or in the "license" file accompanying this file. This file is distributed - * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either - * express or implied. See the License for the specific language governing - * permissions and limitations under the License. - */ - -import { IScopedClusterClient } from '../../../../src/core/server'; - -import { - MODEL_GROUP_BASE_API, - MODEL_GROUP_REGISTER_API, - MODEL_GROUP_SEARCH_API, - MODEL_GROUP_UPDATE_API, -} from './utils/constants'; -import { generateMustQueries, generateTermQuery } from './utils/query'; - -export class ModelGroupService { - public static async register(params: { - client: IScopedClusterClient; - name: string; - description?: string; - modelAccessMode: 'public' | 'restricted' | 'private'; - backendRoles?: string[]; - addAllBackendRoles?: boolean; - }) { - const { client, name, description, modelAccessMode, backendRoles, addAllBackendRoles } = params; - const result = ( - await client.asCurrentUser.transport.request({ - method: 'POST', - path: MODEL_GROUP_REGISTER_API, - body: { - name, - description, - model_access_mode: modelAccessMode, - backend_roles: backendRoles, - add_all_backend_roles: addAllBackendRoles, - }, - }) - ).body as { - model_group_id: string; - status: 'CREATED'; - }; - return result; - } - - public static async update({ - client, - id, - name, - description, - }: { - client: IScopedClusterClient; - id: string; - name?: string; - description?: string; - }) { - const result = ( - await client.asCurrentUser.transport.request({ - method: 'PUT', - path: MODEL_GROUP_UPDATE_API.replace('', id), - body: { - name, - description, - }, - }) - ).body as { - status: 'UPDATED'; - }; - return result; - } - - public static async delete({ client, id }: { client: IScopedClusterClient; id: string }) { - const result = ( - await client.asCurrentUser.transport.request({ - method: 'DELETE', - path: `${MODEL_GROUP_BASE_API}/${id}`, - }) - ).body; - return result; - } - - public static async search({ - client, - id, - name, - from, - size, - }: { - client: IScopedClusterClient; - id?: string; - name?: string; - from: number; - size: number; - }) { - const { - body: { hits }, - } = await client.asCurrentUser.transport.request({ - method: 'GET', - path: MODEL_GROUP_SEARCH_API, - body: { - query: generateMustQueries([ - ...(id ? [generateTermQuery('_id', id)] : []), - ...(name ? [generateTermQuery('name', name)] : []), - ]), - from, - size, - }, - }); - - return { - data: hits.hits.map(({ _id, _source }) => ({ - id: _id, - ..._source, - })), - total_model_groups: hits.total.value, - }; - } -} diff --git a/server/services/model_service.ts b/server/services/model_service.ts index 3793e1e0..671a23c8 100644 --- a/server/services/model_service.ts +++ b/server/services/model_service.ts @@ -18,98 +18,132 @@ * permissions and limitations under the License. */ -import { - IScopedClusterClient, - OpenSearchClient, - ScopeableRequest, - ILegacyClusterClient, -} from '../../../../src/core/server'; -import { MODEL_STATE } from '../../common'; - -import { generateModelSearchQuery } from './utils/model'; -import { RecordNotFoundError } from './errors'; -import { MODEL_BASE_API, MODEL_META_API, MODEL_UPLOAD_API } from './utils/constants'; +import { ModelSort, OpenSearchModel } from '../../common'; +import { IScopedClusterClient } from '../../../../src/core/server'; -const modelSortFieldMapping: { [key: string]: string } = { - version: 'model_version', - name: 'name.keyword', - id: '_id', +import { + MODEL_GROUP_BASE_API, + MODEL_GROUP_REGISTER_API, + MODEL_GROUP_SEARCH_API, + MODEL_GROUP_UPDATE_API, +} from './utils/constants'; +import { generateTermQuery } from './utils/query'; + +const getSortItem = (sort: ModelSort) => { + const [key, direction] = sort.split('-'); + const keyMapping: { [key: string]: string } = { + 'owner.name': 'owner.name.keyword', + name: 'name.keyword', + }; + + return { [keyMapping[key] || key]: direction }; }; -interface UploadModelBase { - name: string; - version?: string; - description?: string; - modelFormat: string; - modelConfig: Record; - modelGroupId: string; -} - -interface UploadModelByURL extends UploadModelBase { - url: string; -} - -interface UploadModelByChunk extends UploadModelBase { - modelContentHashValue: string; - totalChunks: number; -} - -type UploadResultInner< - T extends UploadModelByURL | UploadModelByChunk -> = T extends UploadModelByChunk - ? { model_id: string; status: string } - : T extends UploadModelByURL - ? { task_id: string; status: string } - : never; - -type UploadResult = Promise>; - -const isUploadModelByURL = ( - test: UploadModelByURL | UploadModelByChunk -): test is UploadModelByURL => (test as UploadModelByURL).url !== undefined; - export class ModelService { - private osClient: ILegacyClusterClient; + public static async register(params: { + client: IScopedClusterClient; + name: string; + description?: string; + modelAccessMode: 'public' | 'restricted' | 'private'; + backendRoles?: string[]; + addAllBackendRoles?: boolean; + }) { + const { client, name, description, modelAccessMode, backendRoles, addAllBackendRoles } = params; + const result = ( + await client.asCurrentUser.transport.request({ + method: 'POST', + path: MODEL_GROUP_REGISTER_API, + body: { + name, + description, + model_access_mode: modelAccessMode, + backend_roles: backendRoles, + add_all_backend_roles: addAllBackendRoles, + }, + }) + ).body as { + model_group_id: string; + status: 'CREATED'; + }; + return { + model_id: result.model_group_id, + status: result.status, + }; + } - constructor(osClient: ILegacyClusterClient) { - this.osClient = osClient; + public static async update({ + client, + id, + name, + description, + }: { + client: IScopedClusterClient; + id: string; + name?: string; + description?: string; + }) { + const result = ( + await client.asCurrentUser.transport.request({ + method: 'PUT', + path: MODEL_GROUP_UPDATE_API.replace('', id), + body: { + name, + description, + }, + }) + ).body as { + status: 'UPDATED'; + }; + return result; + } + + public static async delete({ client, id }: { client: IScopedClusterClient; id: string }) { + const result = ( + await client.asCurrentUser.transport.request({ + method: 'DELETE', + path: `${MODEL_GROUP_BASE_API}/${id}`, + }) + ).body; + return result; } public static async search({ + client, + ids, + name, from, size, sort, - transport, - ...restParams + extraQuery, }: { - transport: OpenSearchClient['transport']; + client: IScopedClusterClient; + ids?: string[]; + name?: string; from: number; size: number; - sort?: string[]; - name?: string; - states?: MODEL_STATE[]; + sort?: ModelSort; extraQuery?: Record; - nameOrId?: string; - versionOrKeyword?: string; - modelGroupId?: string; }) { const { body: { hits }, - } = await transport.request({ + } = await client.asCurrentUser.transport.request({ method: 'POST', - path: `${MODEL_BASE_API}/_search`, + path: MODEL_GROUP_SEARCH_API, body: { - query: generateModelSearchQuery(restParams), + query: { + bool: { + must: [ + ...(ids ? [generateTermQuery('_id', ids)] : []), + ...(name ? [generateTermQuery('name', name)] : []), + ...(extraQuery ? [extraQuery] : []), + ], + }, + }, from, size, ...(sort ? { - sort: sort.map((sorting) => { - const [field, direction] = sorting.split('-'); - return { - [modelSortFieldMapping[field] || field]: direction, - }; - }), + sort: [getSortItem(sort)], } : {}), }, @@ -119,119 +153,8 @@ export class ModelService { data: hits.hits.map(({ _id, _source }) => ({ id: _id, ..._source, - })), + })) as OpenSearchModel[], total_models: hits.total.value, }; } - - public async getOne({ request, modelId }: { request: ScopeableRequest; modelId: string }) { - const modelSource = await this.osClient - .asScoped(request) - .callAsCurrentUser('mlCommonsModel.getOne', { - modelId, - }); - return { - id: modelId, - ...modelSource, - }; - } - - public async delete({ request, modelId }: { request: ScopeableRequest; modelId: string }) { - const { result } = await this.osClient - .asScoped(request) - .callAsCurrentUser('mlCommonsModel.delete', { - modelId, - }); - if (result === 'not_found') { - throw new RecordNotFoundError(); - } - return true; - } - - public async load({ request, modelId }: { request: ScopeableRequest; modelId: string }) { - const result = await this.osClient.asScoped(request).callAsCurrentUser('mlCommonsModel.load', { - modelId, - }); - return result; - } - - public async unload({ request, modelId }: { request: ScopeableRequest; modelId: string }) { - const result = await this.osClient - .asScoped(request) - .callAsCurrentUser('mlCommonsModel.unload', { - modelId, - }); - return result; - } - - public async profile({ request, modelId }: { request: ScopeableRequest; modelId: string }) { - const result = await this.osClient - .asScoped(request) - .callAsCurrentUser('mlCommonsModel.profile', { - modelId, - }); - return result; - } - - public static async upload({ - client, - model, - }: { - client: IScopedClusterClient; - model: T; - }): UploadResult { - const { name, version, description, modelFormat, modelConfig, modelGroupId } = model; - const uploadModelBase = { - name, - version, - description, - model_format: modelFormat, - model_config: modelConfig, - model_group_id: modelGroupId, - }; - if (isUploadModelByURL(model)) { - const { task_id: taskId, status } = ( - await client.asCurrentUser.transport.request({ - method: 'POST', - path: MODEL_UPLOAD_API, - body: { - ...uploadModelBase, - url: model.url, - }, - }) - ).body; - return { task_id: taskId, status } as UploadResultInner; - } - - const { model_id: modelId, status } = ( - await client.asCurrentUser.transport.request({ - method: 'POST', - path: MODEL_META_API, - body: { - ...uploadModelBase, - model_content_hash_value: model.modelContentHashValue, - total_chunks: model.totalChunks, - }, - }) - ).body; - return { model_id: modelId, status } as UploadResultInner; - } - - public static async uploadModelChunk({ - client, - modelId, - chunkId, - chunk, - }: { - client: IScopedClusterClient; - modelId: string; - chunkId: string; - chunk: Buffer; - }) { - return client.asCurrentUser.transport.request({ - method: 'POST', - path: `${MODEL_BASE_API}/${modelId}/chunk/${chunkId}`, - body: chunk, - }); - } } diff --git a/server/services/model_version_service.ts b/server/services/model_version_service.ts new file mode 100644 index 00000000..1f109cc8 --- /dev/null +++ b/server/services/model_version_service.ts @@ -0,0 +1,240 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +/* + * Copyright OpenSearch Contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +import { IScopedClusterClient, OpenSearchClient } from '../../../../src/core/server'; +import { MODEL_VERSION_STATE } from '../../common'; + +import { generateModelSearchQuery } from './utils/model'; +import { RecordNotFoundError } from './errors'; +import { + MODEL_BASE_API, + MODEL_META_API, + MODEL_PROFILE_API, + MODEL_UPLOAD_API, +} from './utils/constants'; + +const modelSortFieldMapping: { [key: string]: string } = { + version: 'model_version', + name: 'name.keyword', + id: '_id', +}; + +interface UploadModelBase { + name: string; + version?: string; + description?: string; + modelFormat: string; + modelConfig: Record; + modelId: string; +} + +interface UploadModelByURL extends UploadModelBase { + url: string; +} + +interface UploadModelByChunk extends UploadModelBase { + modelContentHashValue: string; + totalChunks: number; +} + +type UploadResultInner< + T extends UploadModelByURL | UploadModelByChunk +> = T extends UploadModelByChunk + ? { model_version_id: string; status: string } + : T extends UploadModelByURL + ? { task_id: string; status: string } + : never; + +type UploadResult = Promise>; + +const isUploadModelByURL = ( + test: UploadModelByURL | UploadModelByChunk +): test is UploadModelByURL => (test as UploadModelByURL).url !== undefined; + +export class ModelVersionService { + constructor() {} + + public static async search({ + from, + size, + sort, + transport, + ...restParams + }: { + transport: OpenSearchClient['transport']; + algorithms?: string[]; + ids?: string[]; + from: number; + size: number; + sort?: string[]; + name?: string; + states?: MODEL_VERSION_STATE[]; + nameOrId?: string; + versionOrKeyword?: string; + modelIds?: string[]; + extraQuery?: Record; + }) { + const { + body: { hits }, + } = await transport.request({ + method: 'POST', + path: `${MODEL_BASE_API}/_search`, + body: { + query: generateModelSearchQuery(restParams), + from, + size, + ...(sort + ? { + sort: sort.map((sorting) => { + const [field, direction] = sorting.split('-'); + return { + [modelSortFieldMapping[field] || field]: direction, + }; + }), + } + : {}), + }, + }); + + return { + data: hits.hits.map(({ _id, _source: source }) => ({ + id: _id, + model_id: source.model_group_id, + ...source, + })), + total_model_versions: hits.total.value, + }; + } + + public static async getOne({ id, client }: { id: string; client: IScopedClusterClient }) { + const modelSource = ( + await client.asCurrentUser.transport.request({ + method: 'GET', + path: `${MODEL_BASE_API}/${id}`, + }) + ).body; + return { + id, + ...modelSource, + }; + } + + public static async delete({ id, client }: { id: string; client: IScopedClusterClient }) { + const { result } = ( + await client.asCurrentUser.transport.request({ + method: 'DELETE', + path: `${MODEL_BASE_API}/${id}`, + }) + ).body; + if (result === 'not_found') { + throw new RecordNotFoundError(); + } + return true; + } + + public static async load({ id, client }: { id: string; client: IScopedClusterClient }) { + return ( + await client.asCurrentUser.transport.request({ + method: 'POST', + path: `${MODEL_BASE_API}/${id}/_load`, + }) + ).body; + } + + public static async unload({ id, client }: { id: string; client: IScopedClusterClient }) { + return ( + await client.asCurrentUser.transport.request({ + method: 'POST', + path: `${MODEL_BASE_API}/${id}/_unload`, + }) + ).body; + } + + public static async profile({ client, id }: { client: IScopedClusterClient; id: string }) { + return ( + await client.asCurrentUser.transport.request({ + method: 'GET', + path: `${MODEL_PROFILE_API}/${id}`, + }) + ).body; + } + + public static async upload({ + client, + model, + }: { + client: IScopedClusterClient; + model: T; + }): UploadResult { + const { name, version, description, modelFormat, modelConfig, modelId } = model; + const uploadModelBase = { + name, + version, + description, + model_format: modelFormat, + model_config: modelConfig, + model_group_id: modelId, + }; + if (isUploadModelByURL(model)) { + const { task_id: taskId, status } = ( + await client.asCurrentUser.transport.request({ + method: 'POST', + path: MODEL_UPLOAD_API, + body: { + ...uploadModelBase, + url: model.url, + }, + }) + ).body; + return { task_id: taskId, status } as UploadResultInner; + } + + const { model_id: modelVersionId, status } = ( + await client.asCurrentUser.transport.request({ + method: 'POST', + path: MODEL_META_API, + body: { + ...uploadModelBase, + model_content_hash_value: model.modelContentHashValue, + total_chunks: model.totalChunks, + }, + }) + ).body; + return { model_version_id: modelVersionId, status } as UploadResultInner; + } + + public static async uploadModelChunk({ + client, + id, + chunkId, + chunk, + }: { + client: IScopedClusterClient; + id: string; + chunkId: string; + chunk: Buffer; + }) { + return client.asCurrentUser.transport.request({ + method: 'POST', + path: `${MODEL_BASE_API}/${id}/chunk/${chunkId}`, + body: chunk, + }); + } +} diff --git a/server/services/utils/model.ts b/server/services/utils/model.ts index ef280d7d..a1f44a6b 100644 --- a/server/services/utils/model.ts +++ b/server/services/utils/model.ts @@ -3,7 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { MODEL_STATE } from '../../../common'; +import { MODEL_VERSION_STATE } from '../../../common'; import { generateTermQuery } from './query'; export const convertModelSource = (source: { @@ -26,18 +26,18 @@ export const generateModelSearchQuery = ({ name, states, nameOrId, - extraQuery, - modelGroupId, + modelIds, versionOrKeyword, + extraQuery, }: { ids?: string[]; algorithms?: string[]; name?: string; - states?: MODEL_STATE[]; + states?: MODEL_VERSION_STATE[]; nameOrId?: string; extraQuery?: Record; versionOrKeyword?: string; - modelGroupId?: string; + modelIds?: string[]; }) => ({ bool: { must: [ @@ -88,7 +88,7 @@ export const generateModelSearchQuery = ({ }, ] : []), - ...(modelGroupId ? [generateTermQuery('model_group_id.keyword', modelGroupId)] : []), + ...(modelIds ? [generateTermQuery('model_group_id.keyword', modelIds)] : []), ], must_not: { exists: { diff --git a/test/mocks/data/model_aggregate.ts b/test/mocks/data/model_aggregate.ts index 74f617f3..16440173 100644 --- a/test/mocks/data/model_aggregate.ts +++ b/test/mocks/data/model_aggregate.ts @@ -6,13 +6,13 @@ export const modelAggregateResponse = { data: [ { + id: '1', name: 'traced_small_model', deployed_versions: ['1.0.1'], - owner: 'traced_small_model', - latest_version: '1.0.5', - latest_version_state: 'DEPLOYED', - created_time: 1681887678282, + owner_name: 'traced_small_model', + latest_version: 5, + last_updated_time: 1681887678282, }, ], - pagination: { currentPage: 1, pageSize: 15, totalRecords: 1, totalPages: 1 }, + total_models: 1, }; diff --git a/test/mocks/handlers.ts b/test/mocks/handlers.ts index 1f0af718..174b3cf1 100644 --- a/test/mocks/handlers.ts +++ b/test/mocks/handlers.ts @@ -9,10 +9,10 @@ import { MODEL_AGGREGATE_API_ENDPOINT } from '../../server/routes/constants'; import { modelConfig } from './data/model_config'; import { modelRepositoryResponse } from './data/model_repository'; -import { modelHandlers } from './model_handlers'; +import { modelVersionHandlers } from './model_version_handlers'; import { modelAggregateResponse } from './data/model_aggregate'; import { taskHandlers } from './task_handlers'; -import { modelGroupHandlers } from './model_group_handlers'; +import { modelHandlers } from './model_handlers'; export const handlers = [ rest.get('/api/ml-commons/model-repository', (req, res, ctx) => { @@ -21,10 +21,10 @@ export const handlers = [ rest.get('/api/ml-commons/model-repository/config-url/:config_url', (req, res, ctx) => { return res(ctx.status(200), ctx.json(modelConfig)); }), - ...modelHandlers, + ...modelVersionHandlers, rest.get(MODEL_AGGREGATE_API_ENDPOINT, (_req, res, ctx) => { return res(ctx.status(200), ctx.json(modelAggregateResponse)); }), ...taskHandlers, - ...modelGroupHandlers, + ...modelHandlers, ]; diff --git a/test/mocks/model_group_handlers.ts b/test/mocks/model_group_handlers.ts deleted file mode 100644 index 2b26dff2..00000000 --- a/test/mocks/model_group_handlers.ts +++ /dev/null @@ -1,61 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -import { rest } from 'msw'; - -import { MODEL_GROUP_API_ENDPOINT } from '../../server/routes/constants'; - -const modelGroups = [ - { - name: 'model1', - id: '1', - latest_version: 1, - description: 'foo bar', - owner: { - backend_roles: ['admin'], - name: 'admin', - roles: ['admin'], - }, - created_time: 1683699499637, - last_updated_time: 1685073391256, - }, -]; - -export const modelGroupHandlers = [ - rest.get(MODEL_GROUP_API_ENDPOINT, (req, res, ctx) => { - const { searchParams } = req.url; - const name = searchParams.get('name'); - const id = searchParams.get('id'); - const from = parseInt(searchParams.get('from') || '0', 10); - const size = parseInt(searchParams.get('size') || `${modelGroups.length}`, 10); - const filteredData = modelGroups.filter((modelGroup) => { - if (id && id !== modelGroup.id) { - return false; - } - if (name && name !== modelGroup.name) { - return false; - } - return true; - }); - const end = size ? from + size : filteredData.length; - - return res( - ctx.status(200), - ctx.json({ - data: filteredData.slice(from, end), - total_model_groups: filteredData.length, - }) - ); - }), - - rest.post(MODEL_GROUP_API_ENDPOINT, (req, res, ctx) => { - return res( - ctx.status(200), - ctx.json({ - model_group_id: '1', - }) - ); - }), -]; diff --git a/test/mocks/model_handlers.ts b/test/mocks/model_handlers.ts index 244f6265..ca56da1f 100644 --- a/test/mocks/model_handlers.ts +++ b/test/mocks/model_handlers.ts @@ -9,80 +9,17 @@ import { MODEL_API_ENDPOINT } from '../../server/routes/constants'; const models = [ { - id: '1', name: 'model1', - model_version: '1.0.0', - description: 'model1 description', - created_time: 1683699467964, - last_registered_time: 1683699499632, - last_updated_time: 1683699499637, - model_config: { - all_config: '', - embedding_dimension: 768, - framework_type: 'SENTENCE_TRANSFORMERS', - model_type: 'roberta', - }, - model_format: 'TORCH_SCRIPT', - model_state: 'REGISTERED', - total_chunks: 34, - model_group_id: '1', - }, - { - id: '2', - name: 'model2', - model_version: '1.0.1', - description: 'model2 description', - created_time: 1683699467964, - last_registered_time: 1683699499632, - last_updated_time: 1683699499637, - model_config: { - all_config: '', - embedding_dimension: 768, - framework_type: 'SENTENCE_TRANSFORMERS', - model_type: 'roberta', - }, - model_format: 'TORCH_SCRIPT', - model_state: 'REGISTERED', - total_chunks: 34, - model_group_id: '2', - }, - { - id: '3', - name: 'model3', - model_version: '1.0.0', - description: 'model3 description', - created_time: 1683699467964, - last_registered_time: 1683699499632, - last_updated_time: 1683699499637, - model_config: { - all_config: '', - embedding_dimension: 768, - framework_type: 'SENTENCE_TRANSFORMERS', - model_type: 'roberta', - }, - model_format: 'TORCH_SCRIPT', - model_state: 'DEPLOYED', - total_chunks: 34, - model_group_id: '3', - }, - { - id: '4', - name: 'model1', - model_version: '1.0.1', - description: 'model1 version 1.0.1 description', - created_time: 1683699469964, - last_registered_time: 1683699599632, - last_updated_time: 1683699599637, - model_config: { - all_config: '', - embedding_dimension: 768, - framework_type: 'SENTENCE_TRANSFORMERS', - model_type: 'roberta', + id: '1', + latest_version: 1, + description: 'foo bar', + owner: { + backend_roles: ['admin'], + name: 'admin', + roles: ['admin'], }, - model_format: 'TORCH_SCRIPT', - model_state: 'DEPLOYED', - total_chunks: 34, - model_group_id: '1', + created_time: 1683699499637, + last_updated_time: 1685073391256, }, ]; @@ -91,30 +28,34 @@ export const modelHandlers = [ const { searchParams } = req.url; const name = searchParams.get('name'); const ids = searchParams.getAll('ids'); - const modelGroupId = searchParams.get('modelGroupId'); - const data = models.filter((model) => { - if (name) { - return model.name === name; - } - if (ids && ids.length > 0) { + const from = parseInt(searchParams.get('from') || '0', 10); + const size = parseInt(searchParams.get('size') || `${models.length}`, 10); + const filteredData = models.filter((model) => { + if (ids.length > 0) { return ids.includes(model.id); } - if (modelGroupId) { - return model.model_group_id === modelGroupId; + if (name && name !== model.name) { + return false; } return true; }); + const end = size ? from + size : filteredData.length; + return res( ctx.status(200), ctx.json({ - data, - total_models: data.length, + data: filteredData.slice(from, end), + total_models: filteredData.length, }) ); }), - rest.get(`${MODEL_API_ENDPOINT}/:modelId`, (req, res, ctx) => { - const [modelId, ..._restParts] = req.url.pathname.split('/').reverse(); - return res(ctx.status(200), ctx.json(models.find((model) => model.id === modelId))); + rest.post(MODEL_API_ENDPOINT, (req, res, ctx) => { + return res( + ctx.status(200), + ctx.json({ + model_id: '1', + }) + ); }), ]; diff --git a/test/mocks/model_version_handlers.ts b/test/mocks/model_version_handlers.ts new file mode 100644 index 00000000..d35e5263 --- /dev/null +++ b/test/mocks/model_version_handlers.ts @@ -0,0 +1,123 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import { rest } from 'msw'; + +import { MODEL_VERSION_API_ENDPOINT } from '../../server/routes/constants'; + +const modelVersions = [ + { + id: '1', + name: 'model1', + model_version: '1.0.0', + description: 'model1 description', + created_time: 1683699467964, + last_registered_time: 1683699499632, + last_updated_time: 1683699499637, + model_config: { + all_config: '', + embedding_dimension: 768, + framework_type: 'SENTENCE_TRANSFORMERS', + model_type: 'roberta', + }, + model_format: 'TORCH_SCRIPT', + model_state: 'PARTIALLY_DEPLOYED', + total_chunks: 34, + model_id: '1', + current_worker_node_count: 1, + planning_worker_node_count: 3, + planning_worker_nodes: ['node1', 'node2', 'node3'], + }, + { + id: '2', + name: 'model2', + model_version: '1.0.1', + description: 'model2 description', + created_time: 1683699467964, + last_registered_time: 1683699499632, + last_updated_time: 1683699499637, + model_config: { + all_config: '', + embedding_dimension: 768, + framework_type: 'SENTENCE_TRANSFORMERS', + model_type: 'roberta', + }, + model_format: 'TORCH_SCRIPT', + model_state: 'REGISTERED', + total_chunks: 34, + model_id: '2', + }, + { + id: '3', + name: 'model3', + model_version: '1.0.0', + description: 'model3 description', + created_time: 1683699467964, + last_registered_time: 1683699499632, + last_updated_time: 1683699499637, + model_config: { + all_config: '', + embedding_dimension: 768, + framework_type: 'SENTENCE_TRANSFORMERS', + model_type: 'roberta', + }, + model_format: 'TORCH_SCRIPT', + model_state: 'DEPLOYED', + total_chunks: 34, + model_id: '3', + }, + { + id: '4', + name: 'model1', + model_version: '1.0.1', + description: 'model1 version 1.0.1 description', + created_time: 1683699469964, + last_registered_time: 1683699599632, + last_updated_time: 1683699599637, + model_config: { + all_config: '', + embedding_dimension: 768, + framework_type: 'SENTENCE_TRANSFORMERS', + model_type: 'roberta', + }, + model_format: 'TORCH_SCRIPT', + model_state: 'DEPLOYED', + total_chunks: 34, + model_id: '1', + }, +]; + +export const modelVersionHandlers = [ + rest.get(MODEL_VERSION_API_ENDPOINT, (req, res, ctx) => { + const { searchParams } = req.url; + const name = searchParams.get('name'); + const ids = searchParams.getAll('ids'); + const modelIds = searchParams.getAll('modelIds'); + const data = modelVersions.filter((model) => { + if (name) { + return model.name === name; + } + if (ids.length > 0) { + return ids.includes(model.id); + } + if (modelIds.length > 0) { + return modelIds.includes(model.model_id); + } + return true; + }); + return res( + ctx.status(200), + ctx.json({ + data, + total_model_versions: data.length, + }) + ); + }), + + rest.get(`${MODEL_VERSION_API_ENDPOINT}/:modelId`, (req, res, ctx) => { + const [modelId, ..._restParts] = req.url.pathname.split('/').reverse(); + return res(ctx.status(200), ctx.json(modelVersions.find((model) => model.id === modelId))); + }), +]; From 59d0d048e619ae145acb3b30a91f55c0f4e98033 Mon Sep 17 00:00:00 2001 From: Yulong Ruan Date: Wed, 31 May 2023 09:30:49 +0800 Subject: [PATCH 56/75] feat: deploy and undeploy api integration (#198) On version details page, now user can deploy/undeploy the current model by click a button --------- Signed-off-by: Yulong Ruan Signed-off-by: Lin Wang --- common/model_version.ts | 18 ++ public/apis/model_version.ts | 4 +- .../__tests__/global_breadcrumbs.test.tsx | 8 +- ..._version_deployment_confirm_modal.test.tsx | 271 ++---------------- ...model_version_deployment_confirm_modal.tsx | 120 +------- .../components/model/__tests__/model.test.tsx | 2 +- .../__tests__/toggle_deploy_button.test.tsx | 106 +++++++ .../model_version/model_version.tsx | 72 +++-- .../model_version/toggle_deploy_button.tsx | 118 ++++++++ .../model_version/version_details.tsx | 16 +- .../__tests__/register_model_form.test.tsx | 8 +- .../register_model/register_model.tsx | 1 + public/hooks/tests/use_deployment.test.ts | 131 +++++++++ public/hooks/use_deployment.tsx | 199 +++++++++++++ test/mocks/model_handlers.ts | 2 +- test/mocks/model_version_handlers.ts | 15 +- test/mocks/task_handlers.ts | 14 + test/test_utils.tsx | 11 +- 18 files changed, 712 insertions(+), 404 deletions(-) create mode 100644 public/components/model_version/__tests__/toggle_deploy_button.test.tsx create mode 100644 public/components/model_version/toggle_deploy_button.tsx create mode 100644 public/hooks/tests/use_deployment.test.ts create mode 100644 public/hooks/use_deployment.tsx diff --git a/common/model_version.ts b/common/model_version.ts index 2fedbe30..bf8d83e4 100644 --- a/common/model_version.ts +++ b/common/model_version.ts @@ -15,3 +15,21 @@ export enum MODEL_VERSION_STATE { deployFailed = 'DEPLOY_FAILED', registerFailed = 'REGISTER_FAILED', } + +export const isModelDeployable = (state: MODEL_VERSION_STATE) => { + if ( + state === MODEL_VERSION_STATE.undeployed || + state === MODEL_VERSION_STATE.registered || + state === MODEL_VERSION_STATE.deployFailed + ) { + return true; + } + return false; +}; + +export const isModelUndeployable = (state: MODEL_VERSION_STATE) => { + if (state === MODEL_VERSION_STATE.deployed || state === MODEL_VERSION_STATE.partiallyDeployed) { + return true; + } + return false; +}; diff --git a/public/apis/model_version.ts b/public/apis/model_version.ts index 20de2070..441a3b0e 100644 --- a/public/apis/model_version.ts +++ b/public/apis/model_version.ts @@ -16,9 +16,7 @@ import { InnerHttpProvider } from './inner_http_provider'; export interface ModelVersionSearchItem { id: string; name: string; - // TODO: the new version details API may not have this field, because model description is on model group level - // we should fix this when integrating the new API changes - description: string; + model_id: string; algorithm: string; model_state: MODEL_VERSION_STATE; model_version: string; diff --git a/public/components/__tests__/global_breadcrumbs.test.tsx b/public/components/__tests__/global_breadcrumbs.test.tsx index 181c2213..f4b46b0b 100644 --- a/public/components/__tests__/global_breadcrumbs.test.tsx +++ b/public/components/__tests__/global_breadcrumbs.test.tsx @@ -37,7 +37,7 @@ describe('', () => { it('should call onBreadcrumbsChange with register version breadcrumbs', async () => { const onBreadcrumbsChange = jest.fn(); render(, { - route: '/model-registry/register-model/1', + route: '/model-registry/register-model/model-id-1', }); expect(onBreadcrumbsChange).toHaveBeenCalledWith([ @@ -50,7 +50,7 @@ describe('', () => { expect(onBreadcrumbsChange).toHaveBeenLastCalledWith([ { text: 'Machine Learning', href: '/' }, { text: 'Model Registry', href: '/model-registry/model-list' }, - { text: 'model1', href: '/model-registry/model/1' }, + { text: 'model1', href: '/model-registry/model/model-id-1' }, { text: 'Register version' }, ]); }); @@ -59,7 +59,7 @@ describe('', () => { it('should call onBreadcrumbsChange with model breadcrumbs', async () => { const onBreadcrumbsChange = jest.fn(); render(, { - route: '/model-registry/model/1', + route: '/model-registry/model/model-id-1', }); expect(onBreadcrumbsChange).toHaveBeenCalledWith([ @@ -126,7 +126,7 @@ describe('', () => { { text: 'Model Registry', href: '/model-registry/model-list' }, ]); - history.current.push('/model-registry/model/1'); + history.current.push('/model-registry/model/model-id-1'); await act(async () => { jest.advanceTimersByTime(200); diff --git a/public/components/common/modals/__tests__/model_version_deployment_confirm_modal.test.tsx b/public/components/common/modals/__tests__/model_version_deployment_confirm_modal.test.tsx index 0a0fb16f..c9962575 100644 --- a/public/components/common/modals/__tests__/model_version_deployment_confirm_modal.test.tsx +++ b/public/components/common/modals/__tests__/model_version_deployment_confirm_modal.test.tsx @@ -5,75 +5,26 @@ import React from 'react'; import userEvent from '@testing-library/user-event'; -import { EuiToast } from '@elastic/eui'; import { render, screen, waitFor } from '../../../../../test/test_utils'; import { ModelVersionDeploymentConfirmModal } from '../model_version_deployment_confirm_modal'; import { ModelVersion } from '../../../../apis/model_version'; +import * as Hooks from '../../../../hooks/use_deployment'; -import * as PluginContext from '../../../../../../../src/plugins/opensearch_dashboards_react/public'; -import { MountWrapper } from '../../../../../../../src/core/public/utils'; -import { MountPoint } from 'opensearch-dashboards/public'; -import { OverlayModalOpenOptions } from 'src/core/public/overlays'; - -// Cannot spyOn(PluginContext, 'useOpenSearchDashboards') directly as it results in error: -// TypeError: Cannot redefine property: useOpenSearchDashboards -// So we have to mock the entire module first as a workaround -jest.mock('../../../../../../../src/plugins/opensearch_dashboards_react/public', () => { - return { - __esModule: true, - ...jest.requireActual('../../../../../../../src/plugins/opensearch_dashboards_react/public'), - }; -}); +describe('', () => { + const deployMock = jest.fn().mockResolvedValue(undefined); + const undeployMock = jest.fn().mockResolvedValue(undefined); -const generateToastMock = () => - jest.fn((toastInput) => { - render( - - ) - } - > - {typeof toastInput !== 'string' && - (typeof toastInput.text !== 'string' && toastInput.text ? ( - - ) : ( - toastInput.text - ))} - - ); + beforeEach(() => { + jest + .spyOn(Hooks, 'useDeployment') + .mockReturnValue({ deploy: deployMock, undeploy: undeployMock }); }); -const mockAddDangerAndOverlay = () => { - return jest.spyOn(PluginContext, 'useOpenSearchDashboards').mockReturnValue({ - services: { - notifications: { - toasts: { - addDanger: generateToastMock(), - }, - }, - overlays: { - openModal: jest.fn((modelMountPoint: MountPoint, options?: OverlayModalOpenOptions) => { - const { unmount } = render(); - return { - onClose: Promise.resolve(), - close: async () => { - unmount(); - }, - }; - }), - }, - }, + afterEach(() => { + jest.restoreAllMocks(); }); -}; -describe('', () => { describe('model=deploy', () => { it('should render deploy title and confirm message', () => { render( @@ -96,10 +47,7 @@ describe('', () => { ); }); - it('should call model load after deploy button clicked', async () => { - const modelLoadMock = jest - .spyOn(ModelVersion.prototype, 'load') - .mockReturnValue(Promise.resolve({ task_id: 'foo', status: 'succeeded' })); + it('should call deploy model after deploy button clicked', async () => { render( ', () => { /> ); - expect(modelLoadMock).not.toHaveBeenCalled(); + expect(deployMock).not.toHaveBeenCalled(); await userEvent.click(screen.getByRole('button', { name: 'Deploy' })); - expect(modelLoadMock).toHaveBeenCalledTimes(1); - - modelLoadMock.mockRestore(); - }); - - it('should show error toast if model load throw error', async () => { - const useOpenSearchDashboardsMock = mockAddDangerAndOverlay(); - const modelLoadMock = jest - .spyOn(ModelVersion.prototype, 'load') - .mockRejectedValue(new Error('error')); - render( - - ); - - await userEvent.click(screen.getByRole('button', { name: 'Deploy' })); - - expect(screen.getByText('deployment failed.')).toBeInTheDocument(); - expect(screen.getByText('See full error')).toBeInTheDocument(); - - modelLoadMock.mockRestore(); - useOpenSearchDashboardsMock.mockRestore(); - }); - - it('should show full error after "See full error" clicked', async () => { - const useOpenSearchDashboardsMock = mockAddDangerAndOverlay(); - const modelLoadMock = jest - .spyOn(ModelVersion.prototype, 'load') - .mockRejectedValue(new Error('This is a full error message.')); - render( - - ); - - await userEvent.click(screen.getByRole('button', { name: 'Deploy' })); - await userEvent.click(screen.getByText('See full error')); - - expect(screen.getByText('Error message:')).toBeInTheDocument(); - expect(screen.getByText('This is a full error message.')).toBeInTheDocument(); - - modelLoadMock.mockRestore(); - useOpenSearchDashboardsMock.mockRestore(); - }); - - it('should hide full error after close button clicked', async () => { - const useOpenSearchDashboardsMock = mockAddDangerAndOverlay(); - const modelLoadMock = jest - .spyOn(ModelVersion.prototype, 'load') - .mockRejectedValue(new Error('This is a full error message.')); - render( - - ); - - await userEvent.click(screen.getByRole('button', { name: 'Deploy' })); - await userEvent.click(screen.getByText('See full error')); - await userEvent.click(screen.getByText('Close')); - - expect(screen.queryByText('This is a full error message.')).not.toBeInTheDocument(); - - modelLoadMock.mockRestore(); - useOpenSearchDashboardsMock.mockRestore(); + await waitFor(() => { + expect(deployMock).toHaveBeenCalledTimes(1); + }); }); }); - describe('model=undeploy', () => { it('should render undeploy title and confirm message', () => { render( @@ -217,37 +90,6 @@ describe('', () => { }); it('should call model unload after undeploy button clicked', async () => { - const modelLoadMock = jest.spyOn(ModelVersion.prototype, 'unload').mockImplementation(); - render( - - ); - - expect(modelLoadMock).not.toHaveBeenCalled(); - await userEvent.click(screen.getByRole('button', { name: 'Undeploy' })); - expect(modelLoadMock).toHaveBeenCalledTimes(1); - - modelLoadMock.mockRestore(); - }); - - it('should show success toast after modal unload success', async () => { - const useOpenSearchDashboardsMock = jest - .spyOn(PluginContext, 'useOpenSearchDashboards') - .mockReturnValue({ - services: { - notifications: { - toasts: { - addSuccess: generateToastMock(), - }, - }, - }, - }); - const modelLoadMock = jest.spyOn(ModelVersion.prototype, 'unload').mockImplementation(); render( ', () => { /> ); + expect(undeployMock).not.toHaveBeenCalled(); await userEvent.click(screen.getByRole('button', { name: 'Undeploy' })); - await waitFor(() => { - expect(screen.getByTestId('euiToastHeader')).toHaveTextContent( - 'Undeployed model-1 version 1' - ); + expect(undeployMock).toHaveBeenCalledTimes(1); }); - - modelLoadMock.mockRestore(); - useOpenSearchDashboardsMock.mockRestore(); - }); - - it('should show error toast if model unload throw error', async () => { - const useOpenSearchDashboardsMock = mockAddDangerAndOverlay(); - const modelLoadMock = jest - .spyOn(ModelVersion.prototype, 'unload') - .mockRejectedValue(new Error('error')); - render( - - ); - - await userEvent.click(screen.getByRole('button', { name: 'Undeploy' })); - - expect(screen.getByText('undeployment failed.')).toBeInTheDocument(); - expect(screen.getByText('See full error')).toBeInTheDocument(); - - modelLoadMock.mockRestore(); - useOpenSearchDashboardsMock.mockRestore(); - }); - - it('should show full error after "See full error" clicked', async () => { - const useOpenSearchDashboardsMock = mockAddDangerAndOverlay(); - const modelLoadMock = jest - .spyOn(ModelVersion.prototype, 'unload') - .mockRejectedValue(new Error('This is a full error message.')); - render( - - ); - - await userEvent.click(screen.getByRole('button', { name: 'Undeploy' })); - await userEvent.click(screen.getByText('See full error')); - - expect(screen.getByText('Error message:')).toBeInTheDocument(); - expect(screen.getByText('This is a full error message.')).toBeInTheDocument(); - - modelLoadMock.mockRestore(); - useOpenSearchDashboardsMock.mockRestore(); - }); - - it('should hide full error after close button clicked', async () => { - const useOpenSearchDashboardsMock = mockAddDangerAndOverlay(); - const modelLoadMock = jest - .spyOn(ModelVersion.prototype, 'unload') - .mockRejectedValue(new Error('This is a full error message.')); - render( - - ); - - await userEvent.click(screen.getByRole('button', { name: 'Undeploy' })); - await userEvent.click(screen.getByText('See full error')); - await userEvent.click(screen.getByText('Close')); - - expect(screen.queryByText('This is a full error message.')).not.toBeInTheDocument(); - - modelLoadMock.mockRestore(); - useOpenSearchDashboardsMock.mockRestore(); }); }); }); diff --git a/public/components/common/modals/model_version_deployment_confirm_modal.tsx b/public/components/common/modals/model_version_deployment_confirm_modal.tsx index 46f1816b..54dc123c 100644 --- a/public/components/common/modals/model_version_deployment_confirm_modal.tsx +++ b/public/components/common/modals/model_version_deployment_confirm_modal.tsx @@ -4,23 +4,11 @@ */ import React, { useCallback, useState } from 'react'; -import { - EuiButton, - EuiConfirmModal, - EuiFlexGroup, - EuiFlexItem, - EuiLink, - EuiSpacer, - EuiText, -} from '@elastic/eui'; -import { Link, generatePath, useHistory } from 'react-router-dom'; +import { EuiConfirmModal, EuiLink } from '@elastic/eui'; +import { Link, generatePath } from 'react-router-dom'; -import { useOpenSearchDashboards } from '../../../../../../src/plugins/opensearch_dashboards_react/public'; -import { mountReactNode } from '../../../../../../src/core/public/utils'; import { routerPaths } from '../../../../common'; -import { APIProvider } from '../../../apis/api_provider'; - -import { ModelVersionErrorDetailsModal } from './model_version_error_details_modal'; +import { useDeployment } from '../../../hooks/use_deployment'; export const ModelVersionDeploymentConfirmModal = ({ id, @@ -35,114 +23,28 @@ export const ModelVersionDeploymentConfirmModal = ({ version: string; closeModal: () => void; }) => { + const { deploy, undeploy } = useDeployment(id); const [isSubmitting, setIsSubmitting] = useState(false); - const { - services: { notifications, overlays }, - } = useOpenSearchDashboards(); - const history = useHistory(); const mapping = { deploy: { title: 'Deploy', description: 'This version will begin deploying.', - errorMessage: 'deployment failed.', - errorType: 'deployment-failed' as const, - action: APIProvider.getAPI('modelVersion').load, + action: deploy, }, undeploy: { title: 'Undeploy', description: 'This version will be undeployed. You can deploy it again later.', - errorMessage: 'undeployment failed.', - errorType: 'undeployment-failed' as const, - action: APIProvider.getAPI('modelVersion').unload, + action: undeploy, }, }; - const { title, description, errorMessage, errorType, action } = mapping[mode]; + const { title, description, action } = mapping[mode]; const handleConfirm = useCallback(async () => { setIsSubmitting(true); - const modelVersionUrl = history.createHref({ - pathname: generatePath(routerPaths.modelVersion, { id }), - }); - try { - await action(id); - } catch (e) { - notifications?.toasts.addDanger({ - title: mountReactNode( - <> - - {name} version {version} - - . - - ), - text: mountReactNode( - <> - {errorMessage} - - - - - { - const overlayRef = overlays?.openModal( - mountReactNode( - { - overlayRef?.close(); - }} - errorDetails={e instanceof Error ? e.message : JSON.stringify(e)} - /> - ) - ); - }} - > - See full error - - - - - ), - }); - return; - } finally { - setIsSubmitting(false); - closeModal(); - } - // The undeploy API call is sync, we can show error message after immediately - if (mode === 'undeploy') { - notifications?.toasts.addSuccess({ - title: mountReactNode( - <> - Undeployed{' '} - - {name} version {version} - - . - - ), - }); - return; - } - // TODO: Implement model version table status updated after integrate model version table automatic refresh status column - }, [ - id, - notifications, - action, - closeModal, - overlays, - history, - name, - version, - errorType, - errorMessage, - mode, - ]); + await action(); + setIsSubmitting(false); + closeModal(); + }, [action, closeModal]); return ( { , - { route: generatePath(routerPaths.model, { id: '1' }) } + { route: generatePath(routerPaths.model, { id: 'model-id-1' }) } ); return { diff --git a/public/components/model_version/__tests__/toggle_deploy_button.test.tsx b/public/components/model_version/__tests__/toggle_deploy_button.test.tsx new file mode 100644 index 00000000..5953086d --- /dev/null +++ b/public/components/model_version/__tests__/toggle_deploy_button.test.tsx @@ -0,0 +1,106 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React from 'react'; + +import { render, screen } from '../../../../test/test_utils'; +import userEvent from '@testing-library/user-event'; +import { ToggleDeployButton, Props } from '../toggle_deploy_button'; +import { MODEL_VERSION_STATE } from '../../../../common'; +import * as Hooks from '../../../hooks/use_deployment'; + +function setup(props: Partial) { + render( + + ); +} + +describe('', () => { + const deployMock = jest.fn().mockResolvedValue(undefined); + const undeployMock = jest.fn().mockResolvedValue(undefined); + + beforeEach(() => { + jest + .spyOn(Hooks, 'useDeployment') + .mockReturnValue({ deploy: deployMock, undeploy: undeployMock }); + }); + + afterEach(() => { + jest.clearAllMocks(); + }); + + it('should render a "Deploy" button if model state is "REGISTERED"', () => { + setup({ modelState: MODEL_VERSION_STATE.registered }); + expect(screen.getByLabelText('deploy model')).toBeInTheDocument(); + }); + + it('should render a "Deploy" button if model state is "DEPLOY_FAILED"', () => { + setup({ modelState: MODEL_VERSION_STATE.deployFailed }); + expect(screen.getByLabelText('deploy model')).toBeInTheDocument(); + }); + + it('should render a "Deploy" button if model state is "UNDEPLOYED"', () => { + setup({ modelState: MODEL_VERSION_STATE.undeployed }); + expect(screen.getByLabelText('deploy model')).toBeInTheDocument(); + }); + + it('should render an "Undeploy" button if model state is "DEPLOYED"', () => { + setup({ modelState: MODEL_VERSION_STATE.deployed }); + expect(screen.getByLabelText('undeploy model')).toBeInTheDocument(); + }); + + it('should render an "Undeploy" button if model state is "PARTIALLY_DEPLOYED"', () => { + setup({ modelState: MODEL_VERSION_STATE.partiallyDeployed }); + expect(screen.getByLabelText('undeploy model')).toBeInTheDocument(); + }); + + it('should NOT render the button if model state is REGISTERING', () => { + setup({ modelState: MODEL_VERSION_STATE.registering }); + expect(screen.queryByLabelText('undeploy model')).toBeFalsy(); + expect(screen.queryByLabelText('deploy model')).toBeFalsy(); + }); + + it('should NOT render the button if model state is DEPLOYING', () => { + setup({ modelState: MODEL_VERSION_STATE.deploying }); + expect(screen.queryByLabelText('undeploy model')).toBeFalsy(); + expect(screen.queryByLabelText('deploy model')).toBeFalsy(); + }); + + it('should NOT render the button if model state is REGISTER_FAILED', () => { + setup({ modelState: MODEL_VERSION_STATE.registerFailed }); + expect(screen.queryByLabelText('undeploy model')).toBeFalsy(); + expect(screen.queryByLabelText('deploy model')).toBeFalsy(); + }); + + it('should display a confirmation dialog when click the "Deploy" button', async () => { + const user = userEvent.setup(); + setup({ + modelState: MODEL_VERSION_STATE.registered, + }); + await user.click(screen.getByLabelText('deploy model')); + expect(screen.queryByText('Deploy test model name version 1?')).toBeInTheDocument(); + + await user.click(screen.getByTestId('confirmModalConfirmButton')); + expect(deployMock).toHaveBeenCalled(); + }); + + it('should display a confirmation dialog when click the "Undeploy" button', async () => { + const user = userEvent.setup(); + setup({ + modelState: MODEL_VERSION_STATE.deployed, + }); + await user.click(screen.getByLabelText('undeploy model')); + expect(screen.queryByText('Undeploy test model name version 1?')).toBeInTheDocument(); + await user.click(screen.getByTestId('confirmModalConfirmButton')); + expect(undeployMock).toHaveBeenCalled(); + }); +}); diff --git a/public/components/model_version/model_version.tsx b/public/components/model_version/model_version.tsx index ab8a2ce5..420f21e1 100644 --- a/public/components/model_version/model_version.tsx +++ b/public/components/model_version/model_version.tsx @@ -14,11 +14,12 @@ import { EuiPanel, EuiLoadingContent, EuiTabbedContent, + EuiButtonIcon, } from '@elastic/eui'; import { generatePath, useHistory, useParams } from 'react-router-dom'; import { FormProvider, useForm } from 'react-hook-form'; -import { MODEL_VERSION_STATE, routerPaths } from '../../../common'; +import { OpenSearchModel, routerPaths } from '../../../common'; import { useFetcher } from '../../hooks'; import { APIProvider } from '../../apis/api_provider'; @@ -29,14 +30,19 @@ import { ModelVersionInformation } from './version_information'; import { ModelVersionArtifact } from './version_artifact'; import { ModelVersionTags } from './version_tags'; import { ModelVersionFormData } from './types'; +import { ToggleDeployButton } from './toggle_deploy_button'; export const ModelVersion = () => { - const { id: modelId } = useParams<{ id: string }>(); - const { data: model, loading } = useFetcher(APIProvider.getAPI('modelVersion').getOne, modelId); + const [modelData, setModelData] = useState(); + const { id: modelVersionId } = useParams<{ id: string }>(); + const { data: modelVersionData, loading: modelVersionLoading, reload } = useFetcher( + APIProvider.getAPI('modelVersion').getOne, + modelVersionId + ); const [modelInfo, setModelInfo] = useState<{ version: string; name: string }>(); const history = useHistory(); - const modelName = model?.name; - const modelVersion = model?.model_version; + const modelName = modelVersionData?.name; + const modelVersion = modelVersionData?.model_version; const form = useForm(); const onVersionChange = useCallback( @@ -49,6 +55,16 @@ export const ModelVersion = () => { [history] ); + useEffect(() => { + if (modelVersionData?.model_id) { + APIProvider.getAPI('model') + .getOne(modelVersionData?.model_id) + .then((res) => { + setModelData(res); + }); + } + }, [modelVersionData?.model_id]); + useEffect(() => { if (!modelName || !modelVersion) { return; @@ -65,7 +81,7 @@ export const ModelVersion = () => { }, [modelName, modelVersion]); useEffect(() => { - if (model) { + if (modelVersionData) { form.reset({ versionNotes: 'TODO', // TODO: read from model.versionNotes tags: [ @@ -73,21 +89,21 @@ export const ModelVersion = () => { { key: 'Precision', value: '0.64', type: 'number' as const }, { key: 'Task', value: 'Image classification', type: 'string' as const }, ], // TODO: read from model.tags - configuration: JSON.stringify(model.model_config, undefined, 2), - modelFileFormat: model.model_format, + configuration: JSON.stringify(modelVersionData.model_config, undefined, 2), + modelFileFormat: modelVersionData.model_format, // TODO: read model url or model filename artifactSource: 'source_not_changed', // modelFile: new File([], 'artifact.zip'), modelURL: 'http://url.to/artifact.zip', }); } - }, [model, form]); + }, [modelVersionData, form]); const tabs = [ { id: 'version-information', name: 'Version information', - content: loading ? ( + content: modelVersionLoading ? ( <> @@ -106,7 +122,7 @@ export const ModelVersion = () => { { id: 'artifact-configuration', name: 'Artifact and configuration', - content: loading ? ( + content: modelVersionLoading ? ( <> @@ -145,27 +161,43 @@ export const ModelVersion = () => { } rightSideGroupProps={{ gutterSize: 'm', + alignItems: 'center', }} rightSideItems={[ Register version, - Deploy, - Delete, + modelVersionData && ( + + ), + + Delete + , ]} /> )} - - + {modelVersionData && ( + + )} - {loading ? ( + {modelVersionLoading ? ( ) : ( )} diff --git a/public/components/model_version/toggle_deploy_button.tsx b/public/components/model_version/toggle_deploy_button.tsx new file mode 100644 index 00000000..2f6f262c --- /dev/null +++ b/public/components/model_version/toggle_deploy_button.tsx @@ -0,0 +1,118 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React, { useMemo, useState, useCallback } from 'react'; +import { EuiButton, EuiConfirmModal } from '@elastic/eui'; + +import { isModelDeployable, isModelUndeployable, MODEL_VERSION_STATE } from '../../../common'; +import { useDeployment } from '../../hooks/use_deployment'; + +export interface Props { + modelState: MODEL_VERSION_STATE | undefined; + modelVersionId: string; + modelName: string; + modelVersion: string; + // called when deploy or undeploy completed + onComplete: () => void; +} + +export const ToggleDeployButton = ({ + modelState, + modelName, + modelVersion, + modelVersionId, + onComplete, +}: Props) => { + const [loading, setLoading] = useState(false); + const [isDeployModalVisible, setIsDeployModalVisible] = useState(false); + const [isUndeployModalVisible, setIsUndeployModalVisible] = useState(false); + const { deploy, undeploy } = useDeployment(modelVersionId); + + const onConfirmDeploy = useCallback(async () => { + setIsDeployModalVisible(false); + setLoading(true); + await deploy({ + onComplete: () => { + onComplete(); + setLoading(false); + }, + onError: () => { + setLoading(false); + }, + }); + }, [deploy, onComplete]); + + const deployModal = isDeployModalVisible && ( + setIsDeployModalVisible(false)} + onConfirm={onConfirmDeploy} + confirmButtonText="Deploy" + cancelButtonText="Cancel" + > +

      This version will begin deploying.

      +
      + ); + + const onConfirmUndeploy = useCallback(async () => { + setIsUndeployModalVisible(false); + setLoading(true); + await undeploy(); + setLoading(false); + onComplete(); + }, [undeploy, onComplete]); + + const undeployModal = isUndeployModalVisible && ( + setIsUndeployModalVisible(false)} + onConfirm={onConfirmUndeploy} + confirmButtonText="Undeploy" + cancelButtonText="Cancel" + > +

      This version will be undeployed. You can deploy it again later.

      +
      + ); + + const toggleButton = useMemo(() => { + if (!modelState) { + return undefined; + } + + if (isModelDeployable(modelState)) { + return ( + setIsDeployModalVisible(true)} + > + Deploy + + ); + } + + if (isModelUndeployable(modelState)) { + return ( + setIsUndeployModalVisible(true)} + > + Undeploy + + ); + } + + return undefined; + }, [modelState, loading]); + + return ( + <> + {toggleButton} + {deployModal} + {undeployModal} + + ); +}; diff --git a/public/components/model_version/version_details.tsx b/public/components/model_version/version_details.tsx index 24b02645..3ea504fa 100644 --- a/public/components/model_version/version_details.tsx +++ b/public/components/model_version/version_details.tsx @@ -20,14 +20,18 @@ interface Props { description?: string; createdTime?: number; lastUpdatedTime?: number; - modelId?: string; + modelVersionId?: string; + owner?: string; + versionNotes?: string; } export const ModelVersionDetails = ({ description, createdTime, lastUpdatedTime, - modelId, + modelVersionId, + owner, + versionNotes, }: Props) => { return ( @@ -42,7 +46,7 @@ export const ModelVersionDetails = ({

      Version notes

      - TODO + {versionNotes ?? '-'} @@ -50,7 +54,7 @@ export const ModelVersionDetails = ({

      Owner

      - TODO + {owner ?? '-'} @@ -72,10 +76,10 @@ export const ModelVersionDetails = ({

      ID

      - + {(copy) => ( - {modelId ?? '-'} + {modelVersionId ?? '-'} )} diff --git a/public/components/register_model/__tests__/register_model_form.test.tsx b/public/components/register_model/__tests__/register_model_form.test.tsx index 90607386..40e9ea58 100644 --- a/public/components/register_model/__tests__/register_model_form.test.tsx +++ b/public/components/register_model/__tests__/register_model_form.test.tsx @@ -46,7 +46,7 @@ describe(' Form', () => { }); it('should init form when id param in url route', async () => { - await setup({ route: '/1', mode: 'version' }); + await setup({ route: '/model-id-1', mode: 'version' }); await waitFor(() => { expect(screen.getByText('model1')).toBeInTheDocument(); @@ -54,7 +54,7 @@ describe(' Form', () => { }); it('submit button label should be `Register version` when register new version', async () => { - await setup({ route: '/1', mode: 'version' }); + await setup({ route: '/model-id-1', mode: 'version' }); expect(screen.getByRole('button', { name: /register version/i })).toBeInTheDocument(); }); @@ -128,13 +128,13 @@ describe(' Form', () => { it('should call submit with file with provided model id and name', async () => { jest.spyOn(formAPI, 'submitModelWithFile').mockImplementation(onSubmitMock); - const { user } = await setup({ route: '/1', mode: 'version' }); + const { user } = await setup({ route: '/model-id-1', mode: 'version' }); await user.click(screen.getByRole('button', { name: /register version/i })); expect(onSubmitMock).toHaveBeenCalledWith( expect.objectContaining({ name: 'model1', - modelId: '1', + modelId: 'model-id-1', }) ); }); diff --git a/public/components/register_model/register_model.tsx b/public/components/register_model/register_model.tsx index a2acfca6..dbfd21b5 100644 --- a/public/components/register_model/register_model.tsx +++ b/public/components/register_model/register_model.tsx @@ -231,6 +231,7 @@ export const RegisterModelForm = ({ defaultValues = DEFAULT_VALUES }: RegisterMo // TODO: Fill model format here const { config, url } = preTrainedModel; form.setValue('modelURL', url); + form.setValue('modelFileFormat', 'TORCH_SCRIPT'); if (config.name) { form.setValue('name', config.name); } diff --git a/public/hooks/tests/use_deployment.test.ts b/public/hooks/tests/use_deployment.test.ts new file mode 100644 index 00000000..f3619016 --- /dev/null +++ b/public/hooks/tests/use_deployment.test.ts @@ -0,0 +1,131 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import { rest } from 'msw'; + +import { useDeployment } from '../use_deployment'; +import { waitFor, renderHook } from '../../../test/test_utils'; +import * as PluginContext from '../../../../../src/plugins/opensearch_dashboards_react/public'; +import { ModelVersion } from '../../../public/apis/model_version'; +import { server } from '../../../test/mocks/server'; +import { + MODEL_VERSION_UNLOAD_API_ENDPOINT, + TASK_API_ENDPOINT, +} from '../../../server/routes/constants'; + +// Cannot spyOn(PluginContext, 'useOpenSearchDashboards') directly as it results in error: +// TypeError: Cannot redefine property: useOpenSearchDashboards +// So we have to mock the entire module first as a workaround +jest.mock('../../../../../src/plugins/opensearch_dashboards_react/public', () => { + return { + __esModule: true, + ...jest.requireActual('../../../../../src/plugins/opensearch_dashboards_react/public'), + }; +}); + +describe('useDeployment hook', () => { + const addDangerMock = jest.fn(); + const addSuccessMock = jest.fn(); + const openModalMock = jest.fn(); + + beforeEach(() => { + jest.spyOn(PluginContext, 'useOpenSearchDashboards').mockReturnValue({ + services: { + notifications: { + toasts: { + addDanger: addDangerMock, + addSuccess: addSuccessMock, + }, + }, + }, + overlays: { + openModal: openModalMock, + }, + }); + jest + .spyOn(ModelVersion.prototype, 'load') + .mockResolvedValue({ task_id: 'mock_task_id', status: 'deployed' }); + }); + + afterEach(() => { + jest.clearAllMocks(); + }); + + it('should display success toast if model deployed successfully', async () => { + // '2': REGISTERED, defined in msw handler + const { result } = renderHook(() => useDeployment('2')); + await result.current.deploy(); + await waitFor(() => { + expect(addSuccessMock).toHaveBeenCalled(); + }); + }); + + it('should display success toast if model undeployed successfully', async () => { + // '3': DEPLOYED, defined in msw handler + const { result } = renderHook(() => useDeployment('3')); + await result.current.undeploy(); + await waitFor(() => { + expect(addSuccessMock).toHaveBeenCalled(); + }); + }); + + it('should display error toast if model deploy failed', async () => { + server.use( + rest.get(`${TASK_API_ENDPOINT}/:taskId`, (req, res, ctx) => { + return res( + ctx.json({ + model_id: '1', + task_type: 'DEPLOY_MODEL', + state: 'FAILED', + create_time: 1685360406270, + last_update_time: 1685360406471, + worker_node: ['node-1'], + error: 'model config error', + }) + ); + }) + ); + // '2': REGISTERED, defined in msw handler + const { result } = renderHook(() => useDeployment('2')); + await result.current.deploy(); + await waitFor(() => { + expect(addDangerMock).toHaveBeenCalled(); + }); + }); + + it('should display error toast if undeploy failed', async () => { + server.use( + rest.post(`${MODEL_VERSION_UNLOAD_API_ENDPOINT}/:modelId`, (req, res, ctx) => { + // Send invalid HTTP status code + return res(ctx.status(500)); + }) + ); + + // '3': DEPLOYED, defined in msw handler + const { result } = renderHook(() => useDeployment('3')); + await result.current.undeploy(); + + await waitFor(() => { + expect(addDangerMock).toHaveBeenCalled(); + }); + }); + + it('should display error toast if trying to deploy a model which is already deployed', async () => { + // '3': DEPLOYED, defined in msw handler + const { result } = renderHook(() => useDeployment('3')); + await result.current.deploy(); + await waitFor(() => { + expect(addDangerMock).toHaveBeenCalled(); + }); + }); + + it('should display error toast if trying to undeploy a model which is not deployed', async () => { + const { result } = renderHook(() => useDeployment('2')); + await result.current.undeploy(); + await waitFor(() => { + expect(addDangerMock).toHaveBeenCalled(); + }); + }); +}); diff --git a/public/hooks/use_deployment.tsx b/public/hooks/use_deployment.tsx new file mode 100644 index 00000000..4c0a2e25 --- /dev/null +++ b/public/hooks/use_deployment.tsx @@ -0,0 +1,199 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React, { useCallback } from 'react'; +import { takeWhile, switchMap } from 'rxjs/operators'; +import { timer } from 'rxjs'; +import { generatePath, useHistory } from 'react-router-dom'; + +import { EuiButton, EuiFlexGroup, EuiFlexItem, EuiLink, EuiText } from '@elastic/eui'; +import { isModelDeployable, isModelUndeployable, routerPaths } from '../../common'; +import { APIProvider } from '../apis/api_provider'; +import { useOpenSearchDashboards } from '../../../../src/plugins/opensearch_dashboards_react/public'; +import { mountReactNode } from '../../../../src/core/public/utils'; +import { ModelVersionErrorDetailsModal } from '../components/common/modals'; + +export const useDeployment = (modelVersionId: string) => { + const { + services: { notifications, overlays }, + } = useOpenSearchDashboards(); + const history = useHistory(); + + const deploy = useCallback( + // model deploy is async, so we pass onComplete and onError callback to handle the deployment status + async (options?: { onComplete?: () => void; onError?: () => void }) => { + const modelVersionData = await APIProvider.getAPI('modelVersion').getOne(modelVersionId); + const modelVersionUrl = history.createHref({ + pathname: generatePath(routerPaths.modelVersion, { id: modelVersionId }), + }); + + if (!isModelDeployable(modelVersionData.model_state)) { + notifications?.toasts.addDanger( + `Cannot deploy a model which is ${modelVersionData.model_state}` + ); + return; + } + + const taskData = await APIProvider.getAPI('modelVersion').load(modelVersionId); + + // Poll task api every 2s for the deployment status + timer(0, 2000) + // task may have state: CREATED, RUNNING, COMPLETED, FAILED, CANCELLED and COMPLETED_WITH_ERROR + .pipe(switchMap((_) => APIProvider.getAPI('task').getOne(taskData.task_id))) + // continue polling when task state is CREATED or RUNNING + .pipe(takeWhile((res) => res.state === 'CREATED' || res.state === 'RUNNING', true)) + .subscribe({ + error: () => { + options?.onError?.(); + notifications?.toasts.addDanger( + { + title: mountReactNode( + + + {modelVersionData.name} version {modelVersionData.model_version} + {' '} + deployment failed. + + ), + text: 'Network error', + }, + { toastLifeTimeMs: 60000 } + ); + }, + next: (res) => { + if (res.state === 'COMPLETED') { + options?.onComplete?.(); + notifications?.toasts.addSuccess({ + title: mountReactNode( + + + {modelVersionData.name} version {modelVersionData.model_version} + {' '} + has been deployed. + + ), + }); + } else if (res.state === 'FAILED') { + options?.onError?.(); + notifications?.toasts.addDanger( + { + title: mountReactNode( + + + {modelVersionData.name} version {modelVersionData.model_version} + {' '} + deployment failed. + + ), + text: mountReactNode( + + + { + const overlayRef = overlays?.openModal( + mountReactNode( + { + overlayRef?.close(); + }} + errorDetails={res.error ? res.error : 'Unknown error'} + /> + ) + ); + }} + > + See full error + + + + ), + }, + { toastLifeTimeMs: 60000 } + ); + } + }, + }); + }, + [modelVersionId, history, notifications?.toasts, overlays] + ); + + const undeploy = useCallback(async () => { + const modelVersionData = await APIProvider.getAPI('modelVersion').getOne(modelVersionId); + + if (!isModelUndeployable(modelVersionData.model_state)) { + notifications?.toasts.addDanger( + `Cannot undeploy a model which is ${modelVersionData.model_state}` + ); + return; + } + + const modelVersionUrl = history.createHref({ + pathname: generatePath(routerPaths.modelVersion, { id: modelVersionId }), + }); + try { + await APIProvider.getAPI('modelVersion').unload(modelVersionId); + notifications?.toasts.addSuccess({ + title: mountReactNode( + + + {modelVersionData.name} version {modelVersionData.model_version} + {' '} + has been undeployed + + ), + }); + } catch (e) { + notifications?.toasts.addDanger( + { + title: mountReactNode( + + + {modelVersionData.name} version {modelVersionData.model_version} + {' '} + undeployment failed + + ), + text: mountReactNode( + + + { + const overlayRef = overlays?.openModal( + mountReactNode( + { + overlayRef?.close(); + }} + errorDetails={e instanceof Error ? e.message : JSON.stringify(e)} + /> + ) + ); + }} + > + See full error + + + + ), + }, + { toastLifeTimeMs: 60000 } + ); + } + }, [modelVersionId, notifications?.toasts, history, overlays]); + + return { deploy, undeploy }; +}; diff --git a/test/mocks/model_handlers.ts b/test/mocks/model_handlers.ts index ca56da1f..93ec2da8 100644 --- a/test/mocks/model_handlers.ts +++ b/test/mocks/model_handlers.ts @@ -10,7 +10,7 @@ import { MODEL_API_ENDPOINT } from '../../server/routes/constants'; const models = [ { name: 'model1', - id: '1', + id: 'model-id-1', latest_version: 1, description: 'foo bar', owner: { diff --git a/test/mocks/model_version_handlers.ts b/test/mocks/model_version_handlers.ts index d35e5263..88a96e1e 100644 --- a/test/mocks/model_version_handlers.ts +++ b/test/mocks/model_version_handlers.ts @@ -5,7 +5,11 @@ import { rest } from 'msw'; -import { MODEL_VERSION_API_ENDPOINT } from '../../server/routes/constants'; +import { + MODEL_VERSION_API_ENDPOINT, + MODEL_VERSION_LOAD_API_ENDPOINT, + MODEL_VERSION_UNLOAD_API_ENDPOINT, +} from '../../server/routes/constants'; const modelVersions = [ { @@ -120,4 +124,13 @@ export const modelVersionHandlers = [ const [modelId, ..._restParts] = req.url.pathname.split('/').reverse(); return res(ctx.status(200), ctx.json(modelVersions.find((model) => model.id === modelId))); }), + + rest.post(`${MODEL_VERSION_LOAD_API_ENDPOINT}/:modelId`, (req, res, ctx) => { + return res(ctx.json({ task_id: 'task-id-1', status: 'CREATED' })); + }), + + rest.post(`${MODEL_VERSION_UNLOAD_API_ENDPOINT}/:modelId`, (req, res, ctx) => { + const { modelId } = req.params; + return res(ctx.json({ node_1: { stats: { [modelId as string]: 'undeployed' } } })); + }), ]; diff --git a/test/mocks/task_handlers.ts b/test/mocks/task_handlers.ts index 6e887f9e..ce87a6c3 100644 --- a/test/mocks/task_handlers.ts +++ b/test/mocks/task_handlers.ts @@ -41,4 +41,18 @@ export const taskHandlers = [ }) ); }), + + rest.get(`${TASK_API_ENDPOINT}/:taskId`, (req, res, ctx) => { + const { taskId } = req.params; + return res( + ctx.json({ + model_id: '1', + task_type: 'DEPLOY_MODEL', + state: 'COMPLETED', + create_time: 1685360406270, + last_update_time: 1685360406471, + worker_node: ['node-1'], + }) + ); + }), ]; diff --git a/test/test_utils.tsx b/test/test_utils.tsx index 7950ae60..b73bec4b 100644 --- a/test/test_utils.tsx +++ b/test/test_utils.tsx @@ -6,6 +6,7 @@ import React, { FC, ReactElement } from 'react'; import { I18nProvider } from '@osd/i18n/react'; import { render, RenderOptions } from '@testing-library/react'; +import { renderHook, RenderHookOptions } from '@testing-library/react-hooks'; import { createBrowserHistory } from 'history'; import { Router } from 'react-router-dom'; import { DataSourceContextProvider } from '../public/contexts'; @@ -18,7 +19,7 @@ export const history = { current: createBrowserHistory(), }; -const AllTheProviders: FC<{ children: React.ReactNode }> = ({ children }) => { +const AllTheProviders: FC = ({ children }) => { return ( @@ -47,8 +48,16 @@ const customRender = ( return render(ui, { wrapper: AllTheProviders, ...options }); }; +const customRenderHook = ( + callback: (props: TProps) => TResult, + options?: RenderHookOptions +) => { + return renderHook(callback, { wrapper: AllTheProviders, ...options }); +}; + export * from '@testing-library/react'; export { customRender as render }; +export { customRenderHook as renderHook }; export const mockOffsetMethods = () => { const originalOffsetHeight = Object.getOwnPropertyDescriptor( From 52d1e4ca3736509bb31257eb204e8d9ca28d7303 Mon Sep 17 00:00:00 2001 From: Yulong Ruan Date: Mon, 12 Jun 2023 10:27:26 +0800 Subject: [PATCH 57/75] fix: map model_group_id to model_id (#207) Signed-off-by: Yulong Ruan Signed-off-by: Lin Wang --- server/services/model_version_service.ts | 1 + 1 file changed, 1 insertion(+) diff --git a/server/services/model_version_service.ts b/server/services/model_version_service.ts index 1f109cc8..def751f6 100644 --- a/server/services/model_version_service.ts +++ b/server/services/model_version_service.ts @@ -133,6 +133,7 @@ export class ModelVersionService { return { id, ...modelSource, + model_id: modelSource.model_group_id, }; } From 7bf1a7aa39ed9027204f2af17ace4acde60330fb Mon Sep 17 00:00:00 2001 From: Lin Wang Date: Mon, 12 Jun 2023 14:03:59 +0800 Subject: [PATCH 58/75] Feature/add model version delete modal (#199) * test: separate mockOpenSearchDashboards Signed-off-by: Lin Wang * fix: update id to modelId in model version upload API Signed-off-by: Lin Wang * feat: add model version delete confirm and unable do action modal Signed-off-by: Lin Wang * refactor: refactor delete polling with rxjs Signed-off-by: Lin Wang * fix: update model id and owner Signed-off-by: Lin Wang * fix: request should wait previous response Signed-off-by: Lin Wang --------- Signed-off-by: Lin Wang --- common/model.ts | 2 +- public/apis/model_version.ts | 2 +- ...odel_version_delete_confirm_modal.test.tsx | 209 ++++++++++++++++++ ..._version_deployment_confirm_modal.test.tsx | 1 - ...el_version_unable_do_action_modal.test.tsx | 114 ++++++++++ .../type_text_confirm_modal.test.tsx | 49 ++++ public/components/common/modals/index.ts | 3 + .../model_version_delete_confirm_modal.tsx | 155 +++++++++++++ .../model_version_unable_do_action_modal.tsx | 108 +++++++++ .../common/modals/type_text_confirm_modal.tsx | 43 ++++ public/components/model/model.tsx | 2 +- .../model_version_table_row_actions.test.tsx | 33 ++- .../__tests__/model_versions_panel.test.tsx | 26 +++ .../model_version_table.tsx | 16 +- .../model_version_table_row_actions.tsx | 73 +++++- .../model_versions_panel.tsx | 1 + public/components/model_list/model_table.tsx | 2 +- server/services/model_aggregate_service.ts | 4 +- server/services/model_version_service.ts | 1 + server/services/utils/model.ts | 2 +- test/mock_opensearch_dashboards_react.tsx | 74 +++++++ test/mocks/model_version_handlers.ts | 13 +- 22 files changed, 909 insertions(+), 24 deletions(-) create mode 100644 public/components/common/modals/__tests__/model_version_delete_confirm_modal.test.tsx create mode 100644 public/components/common/modals/__tests__/model_version_unable_do_action_modal.test.tsx create mode 100644 public/components/common/modals/__tests__/type_text_confirm_modal.test.tsx create mode 100644 public/components/common/modals/model_version_delete_confirm_modal.tsx create mode 100644 public/components/common/modals/model_version_unable_do_action_modal.tsx create mode 100644 public/components/common/modals/type_text_confirm_modal.tsx create mode 100644 test/mock_opensearch_dashboards_react.tsx diff --git a/common/model.ts b/common/model.ts index be717653..583d16a4 100644 --- a/common/model.ts +++ b/common/model.ts @@ -5,7 +5,7 @@ export interface OpenSearchModel { id: string; - owner: { + owner?: { backend_roles: string[]; roles: string[]; name: string; diff --git a/public/apis/model_version.ts b/public/apis/model_version.ts index 441a3b0e..1828808c 100644 --- a/public/apis/model_version.ts +++ b/public/apis/model_version.ts @@ -81,7 +81,7 @@ interface UploadModelBase { description?: string; modelFormat: string; modelConfig: Record; - id: string; + modelId: string; } export interface UploadModelByURL extends UploadModelBase { diff --git a/public/components/common/modals/__tests__/model_version_delete_confirm_modal.test.tsx b/public/components/common/modals/__tests__/model_version_delete_confirm_modal.test.tsx new file mode 100644 index 00000000..80d98f5a --- /dev/null +++ b/public/components/common/modals/__tests__/model_version_delete_confirm_modal.test.tsx @@ -0,0 +1,209 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React from 'react'; +import userEvent from '@testing-library/user-event'; + +import { render, screen, waitFor, act } from '../../../../../test/test_utils'; +import { mockUseOpenSearchDashboards } from '../../../../../test/mock_opensearch_dashboards_react'; +import { ModelVersion } from '../../../../apis/model_version'; +import { ModelVersionDeleteConfirmModal } from '../model_version_delete_confirm_modal'; +const setup = () => { + const closeModalMock = jest.fn(); + const renderResult = render( + + ); + return { + renderResult, + closeModalMock, + }; +}; + +describe('', () => { + it('should render title, confirm tip, cancel and delete button by default', () => { + setup(); + + expect(screen.getByTestId('confirmModalTitleText')).toHaveTextContent( + 'Delete model1 version 1?' + ); + expect(screen.getByLabelText('Type model1 version 1 to confirm.')).toBeInTheDocument(); + expect(screen.getByText('Cancel').closest('button')).toBeInTheDocument(); + expect(screen.getByText('Delete version').closest('button')).toBeInTheDocument(); + }); + + it('should call closeModal with false when cancel button is clicked', async () => { + const { closeModalMock } = setup(); + + expect(closeModalMock).not.toHaveBeenCalled(); + await userEvent.click(screen.getByText('Cancel')); + expect(closeModalMock).toHaveBeenCalledWith(false); + }); + + it('should call model version delete API after confirm text typed and delete button clicked', async () => { + const modelDeleteMock = jest.spyOn(ModelVersion.prototype, 'delete').mockResolvedValue({}); + setup(); + + await userEvent.type(screen.getByLabelText('confirm text input'), 'model1 version 1'); + + expect(modelDeleteMock).not.toHaveBeenCalled(); + await userEvent.click(screen.getByText('Delete version')); + expect(modelDeleteMock).toHaveBeenCalledWith('1'); + + modelDeleteMock.mockRestore(); + }); + + it('should show delete success toast and call closeModal with true after delete success and can not be searched', async () => { + const modelDeleteMock = jest.spyOn(ModelVersion.prototype, 'delete').mockResolvedValue({}); + const modelSearchMock = jest + .spyOn(ModelVersion.prototype, 'search') + .mockResolvedValue({ data: [], total_model_versions: 0 }); + const openSearchDashboardsMock = mockUseOpenSearchDashboards(); + const { closeModalMock } = setup(); + + await userEvent.type(screen.getByLabelText('confirm text input'), 'model1 version 1'); + + expect(closeModalMock).not.toHaveBeenCalled(); + await userEvent.click(screen.getByText('Delete version')); + await waitFor(() => { + expect(screen.getByTestId('euiToastHeader')).toHaveTextContent( + 'model1 version 1 has been deleted' + ); + }); + expect(closeModalMock).toHaveBeenCalledWith(true); + + modelDeleteMock.mockRestore(); + modelSearchMock.mockRestore(); + openSearchDashboardsMock.mockRestore(); + }); + + it('should show unable to delete toast and call closeModal with false after delete failed', async () => { + const modelDeleteMock = jest + .spyOn(ModelVersion.prototype, 'delete') + .mockRejectedValue(new Error()); + const openSearchDashboardsMock = mockUseOpenSearchDashboards(); + const { closeModalMock } = setup(); + + await userEvent.type(screen.getByLabelText('confirm text input'), 'model1 version 1'); + + expect(closeModalMock).not.toHaveBeenCalled(); + await userEvent.click(screen.getByText('Delete version')); + expect(screen.getByTestId('euiToastHeader')).toHaveTextContent( + 'Unable to delete model1 version 1' + ); + expect(closeModalMock).toHaveBeenCalledWith(false); + + modelDeleteMock.mockRestore(); + openSearchDashboardsMock.mockRestore(); + }); + + it('should call closeModal with false after delete success and still can be searched', async () => { + jest.useFakeTimers(); + + const modelDeleteMock = jest.spyOn(ModelVersion.prototype, 'delete').mockResolvedValue({}); + const modelSearchMock = jest + .spyOn(ModelVersion.prototype, 'search') + .mockResolvedValue({ data: [], total_model_versions: 1 }); + + const { closeModalMock } = setup(); + const user = userEvent.setup({ advanceTimers: jest.advanceTimersByTime }); + + await user.type(screen.getByLabelText('confirm text input'), 'model1 version 1'); + await user.click(screen.getByText('Delete version')); + + expect(closeModalMock).not.toHaveBeenCalled(); + + for (let i = 0; i < 200; i++) { + await act(async () => { + jest.advanceTimersByTime(300); + }); + } + + expect(closeModalMock).toHaveBeenCalledWith(false); + + modelDeleteMock.mockRestore(); + modelSearchMock.mockRestore(); + + jest.useRealTimers(); + }); + + it('should not call search API anymore after modal unmount', async () => { + jest.useFakeTimers(); + + const modelDeleteMock = jest.spyOn(ModelVersion.prototype, 'delete').mockResolvedValue({}); + const modelSearchMock = jest + .spyOn(ModelVersion.prototype, 'search') + .mockResolvedValue({ data: [], total_model_versions: 1 }); + + const { renderResult } = setup(); + const user = userEvent.setup({ advanceTimers: jest.advanceTimersByTime }); + + await user.type(screen.getByLabelText('confirm text input'), 'model1 version 1'); + await user.click(screen.getByText('Delete version')); + + for (let i = 0; i < 10; i++) { + await act(async () => { + jest.advanceTimersByTime(300); + }); + } + + modelSearchMock.mockClear(); + renderResult.unmount(); + for (let i = 0; i < 10; i++) { + await act(async () => { + jest.advanceTimersByTime(300); + }); + } + expect(modelSearchMock).not.toHaveBeenCalled(); + + modelDeleteMock.mockRestore(); + modelSearchMock.mockRestore(); + + jest.useRealTimers(); + }); + + it('should not call model search if previous call not response', async () => { + jest.useFakeTimers(); + + const modelDeleteMock = jest.spyOn(ModelVersion.prototype, 'delete').mockResolvedValue({}); + const modelSearchMock = jest + .spyOn(ModelVersion.prototype, 'search') + .mockImplementation(async () => { + await new Promise((resolve) => { + setTimeout(resolve, 2000); + }); + return { data: [], total_model_versions: 1 }; + }); + + setup(); + const user = userEvent.setup({ advanceTimers: jest.advanceTimersByTime }); + + await user.type(screen.getByLabelText('confirm text input'), 'model1 version 1'); + await user.click(screen.getByText('Delete version')); + + expect(modelSearchMock).toHaveBeenCalledTimes(1); + + await act(async () => { + jest.advanceTimersByTime(2000); + }); + expect(modelSearchMock).toHaveBeenCalledTimes(1); + + // Should not call model search immediately after API response + await act(async () => { + jest.advanceTimersByTime(100); + }); + expect(modelSearchMock).toHaveBeenCalledTimes(1); + + // Should call model search again after 100+200ms delay + await act(async () => { + jest.advanceTimersByTime(200); + }); + expect(modelSearchMock).toHaveBeenCalledTimes(2); + + modelDeleteMock.mockRestore(); + modelSearchMock.mockRestore(); + + jest.useRealTimers(); + }); +}); diff --git a/public/components/common/modals/__tests__/model_version_deployment_confirm_modal.test.tsx b/public/components/common/modals/__tests__/model_version_deployment_confirm_modal.test.tsx index c9962575..47f42416 100644 --- a/public/components/common/modals/__tests__/model_version_deployment_confirm_modal.test.tsx +++ b/public/components/common/modals/__tests__/model_version_deployment_confirm_modal.test.tsx @@ -8,7 +8,6 @@ import userEvent from '@testing-library/user-event'; import { render, screen, waitFor } from '../../../../../test/test_utils'; import { ModelVersionDeploymentConfirmModal } from '../model_version_deployment_confirm_modal'; -import { ModelVersion } from '../../../../apis/model_version'; import * as Hooks from '../../../../hooks/use_deployment'; describe('', () => { diff --git a/public/components/common/modals/__tests__/model_version_unable_do_action_modal.test.tsx b/public/components/common/modals/__tests__/model_version_unable_do_action_modal.test.tsx new file mode 100644 index 00000000..e4ed3658 --- /dev/null +++ b/public/components/common/modals/__tests__/model_version_unable_do_action_modal.test.tsx @@ -0,0 +1,114 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React from 'react'; +import userEvent from '@testing-library/user-event'; + +import { render, screen } from '../../../../../test/test_utils'; +import { MODEL_VERSION_STATE } from '../../../../../common'; +import { ModelVersionUnableDoActionModal } from '../model_version_unable_do_action_modal'; + +const setup = ( + options: { actionType: 'delete' | 'edit'; state: MODEL_VERSION_STATE } = { + actionType: 'delete', + state: MODEL_VERSION_STATE.registering, + } +) => { + const closeModalMock = jest.fn(); + render( + + ); + return { + closeModalMock, + }; +}; + +describe('', () => { + it('should call closeModal after close button clicked', async () => { + const { closeModalMock } = setup(); + + expect(closeModalMock).not.toHaveBeenCalled(); + await userEvent.click(screen.getByText('Close')); + expect(closeModalMock).toHaveBeenCalled(); + }); + describe('actionType=delete', () => { + it('should display unable delete title', async () => { + setup(); + + expect(screen.getByText('Unable to delete')).toBeInTheDocument(); + }); + + it('should display unable delete message for model uploading state', async () => { + setup({ actionType: 'delete', state: MODEL_VERSION_STATE.registering }); + + expect( + screen.getByText( + 'This version is uploading. Wait for this version to complete uploading and then try again.' + ) + ).toBeInTheDocument(); + }); + + it('should display unable delete message for model deploying state', async () => { + setup({ actionType: 'delete', state: MODEL_VERSION_STATE.deploying }); + + expect( + screen.getByText( + /To delete this version, wait for it to complete deploying and then undeploy it on the.+page./ + ) + ).toBeInTheDocument(); + }); + + it('should display unable delete message for model deployed state', async () => { + setup({ actionType: 'delete', state: MODEL_VERSION_STATE.deployed }); + + expect( + screen.getByText( + /This version is currently deployed. To delete this version, undeploy it on the.+page./ + ) + ).toBeInTheDocument(); + }); + }); + describe('actionType=edit', () => { + it('should display unable edit title', async () => { + setup({ actionType: 'edit', state: MODEL_VERSION_STATE.registering }); + + expect(screen.getByText('Unable to edit')).toBeInTheDocument(); + }); + + it('should display unable edit message for model uploading state', async () => { + setup({ actionType: 'edit', state: MODEL_VERSION_STATE.registering }); + + expect( + screen.getByText('Wait for this version to complete uploading and then try again.') + ).toBeInTheDocument(); + }); + + it('should display unable edit message for model deploying state', async () => { + setup({ actionType: 'edit', state: MODEL_VERSION_STATE.deploying }); + + expect( + screen.getByText( + /To edit this version, wait for it to complete deploying and then undeploy it on the.+page./ + ) + ).toBeInTheDocument(); + }); + + it('should display unable edit message for model deployed state', async () => { + setup({ actionType: 'edit', state: MODEL_VERSION_STATE.deployed }); + + expect( + screen.getByText( + /This version is currently deployed. To edit this version, undeploy it on the.+page./ + ) + ).toBeInTheDocument(); + }); + }); +}); diff --git a/public/components/common/modals/__tests__/type_text_confirm_modal.test.tsx b/public/components/common/modals/__tests__/type_text_confirm_modal.test.tsx new file mode 100644 index 00000000..cbfe1999 --- /dev/null +++ b/public/components/common/modals/__tests__/type_text_confirm_modal.test.tsx @@ -0,0 +1,49 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React from 'react'; +import userEvent from '@testing-library/user-event'; + +import { render, screen } from '../../../../../test/test_utils'; +import { TypeTextConfirmModal } from '../type_text_confirm_modal'; + +const setup = () => { + render( + + foo + + ); +}; + +describe('', () => { + it('should render title, children, type to confirm tip, text input and disable delete button by default', () => { + setup(); + + expect(screen.getByText('Delete model version 3?')).toBeInTheDocument(); + expect(screen.getByText('foo')).toBeInTheDocument(); + expect(screen.getByLabelText('Type model version 3 to confirm.')).toBeInTheDocument(); + expect(screen.getByLabelText('confirm text input')).toBeInTheDocument(); + expect(screen.getByText('Delete').closest('button')).toBeDisabled(); + }); + + it('should disable delete button if typed text not match', async () => { + setup(); + + await userEvent.type(screen.getByLabelText('confirm text input'), 'foobar'); + expect(screen.getByText('Delete').closest('button')).toBeDisabled(); + }); + + it('should enable delete button if typed text matched', async () => { + setup(); + + await userEvent.type(screen.getByLabelText('confirm text input'), 'model version 3'); + expect(screen.getByText('Delete').closest('button')).toBeEnabled(); + }); +}); diff --git a/public/components/common/modals/index.ts b/public/components/common/modals/index.ts index fce83901..085bdfe7 100644 --- a/public/components/common/modals/index.ts +++ b/public/components/common/modals/index.ts @@ -5,3 +5,6 @@ export { ModelVersionErrorDetailsModal } from './model_version_error_details_modal'; export { ModelVersionDeploymentConfirmModal } from './model_version_deployment_confirm_modal'; +export { TypeTextConfirmModal } from './type_text_confirm_modal'; +export { ModelVersionDeleteConfirmModal } from './model_version_delete_confirm_modal'; +export { ModelVersionUnableDoActionModal } from './model_version_unable_do_action_modal'; diff --git a/public/components/common/modals/model_version_delete_confirm_modal.tsx b/public/components/common/modals/model_version_delete_confirm_modal.tsx new file mode 100644 index 00000000..bfe13409 --- /dev/null +++ b/public/components/common/modals/model_version_delete_confirm_modal.tsx @@ -0,0 +1,155 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React, { useCallback, useState, useRef, useEffect } from 'react'; +import { EuiLink, EuiSpacer, EuiText } from '@elastic/eui'; +import { Link, generatePath, useHistory } from 'react-router-dom'; +import { of } from 'rxjs'; +import { delay, scan, switchMap, takeWhile, retryWhen, map } from 'rxjs/operators'; + +import { useOpenSearchDashboards } from '../../../../../../src/plugins/opensearch_dashboards_react/public'; +import { mountReactNode } from '../../../../../../src/core/public/utils'; +import { routerPaths } from '../../../../common'; +import { APIProvider } from '../../../apis/api_provider'; + +import { TypeTextConfirmModal } from './type_text_confirm_modal'; + +const typeTextConfirmModalStyle = { width: 500 }; + +interface ModelVersionDeleteConfirmModalProps { + id: string; + name: string; + version: string; + closeModal: (versionDeleted: boolean) => void; +} + +export const ModelVersionDeleteConfirmModal = ({ + id, + name, + version, + closeModal, +}: ModelVersionDeleteConfirmModalProps) => { + const { + services: { notifications }, + } = useOpenSearchDashboards(); + const history = useHistory(); + const [isDeleting, setIsDeleting] = useState(false); + const mountedRef = useRef(false); + + const handleConfirm = useCallback(async () => { + setIsDeleting(true); + const modelVersionAddress = history.createHref({ + pathname: generatePath(routerPaths.modelVersion, { id }), + }); + try { + await APIProvider.getAPI('modelVersion').delete(id); + } catch { + notifications?.toasts.addDanger({ + title: mountReactNode( + <> + Unable to delete{' '} + + {name} version {version} + + + ), + }); + closeModal(false); + return; + } + /** + * + * Delete a model version is a sync operation, but the deleted model + * still can be searched by model search API after model version deleted. + * Add this polling here to make sure version can't searchable. + * + **/ + of(null) + .pipe(takeWhile(() => mountedRef.current)) + .pipe( + switchMap(async () => { + const result = await APIProvider.getAPI('modelVersion').search({ + ids: [id], + from: 0, + size: 1, + }); + if (result.total_model_versions > 0) { + throw new Error('Model version searchable.'); + } + return result; + }) + ) + .pipe( + retryWhen((errors) => + errors.pipe( + delay(300), + scan((acc) => acc + 1, 0), + map((times) => { + if (times >= 200) { + throw new Error('Exceed max retries.'); + } + }) + ) + ) + ) + .subscribe({ + next: () => { + notifications?.toasts.addSuccess({ + title: mountReactNode( + <> + + {name} version {version} + {' '} + has been deleted + + ), + }); + closeModal(true); + }, + error: () => { + closeModal(false); + }, + }); + }, [id, name, version, history, notifications, closeModal]); + + const handleModalCancel = useCallback(() => { + closeModal(false); + }, [closeModal]); + + useEffect(() => { + mountedRef.current = true; + return () => { + mountedRef.current = false; + }; + }, []); + + return ( + + Delete{' '} + + {name} version {version} + + ? +

      + } + textToType={`${name} version ${version}`} + confirmButtonText="Delete version" + buttonColor="danger" + cancelButtonText="Cancel" + onConfirm={handleConfirm} + onCancel={handleModalCancel} + confirmButtonDisabled={isDeleting} + isLoading={isDeleting} + > + + This action is irreversible. + + + + ); +}; diff --git a/public/components/common/modals/model_version_unable_do_action_modal.tsx b/public/components/common/modals/model_version_unable_do_action_modal.tsx new file mode 100644 index 00000000..9bbaa9c3 --- /dev/null +++ b/public/components/common/modals/model_version_unable_do_action_modal.tsx @@ -0,0 +1,108 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React from 'react'; +import { Link, generatePath } from 'react-router-dom'; +import { + EuiButtonEmpty, + EuiLink, + EuiModal, + EuiModalBody, + EuiModalFooter, + EuiModalHeader, + EuiText, + EuiTitle, +} from '@elastic/eui'; +import { MODEL_VERSION_STATE, routerPaths } from '../../../../common'; + +const mapping: { + ['delete']: { + [key in MODEL_VERSION_STATE]?: (modelVersionNode: React.ReactNode) => React.ReactNode; + }; + ['edit']: { + [key in MODEL_VERSION_STATE]?: (modelVersionNode: React.ReactNode) => React.ReactNode; + }; +} = { + delete: { + [MODEL_VERSION_STATE.deployed]: (modelVersionNode: React.ReactNode) => ( + <> + This version is currently deployed. To delete this version, undeploy it on the + {modelVersionNode} page. + + ), + [MODEL_VERSION_STATE.registering]: (_modelVersionNode: React.ReactNode) => ( + <> + This version is uploading. Wait for this version to complete uploading and then try again. + + ), + [MODEL_VERSION_STATE.deploying]: (modelVersionNode: React.ReactNode) => ( + <> + To delete this version, wait for it to complete deploying and then undeploy it on the + {modelVersionNode} page. + + ), + }, + edit: { + [MODEL_VERSION_STATE.deployed]: (modelVersionNode: React.ReactNode) => ( + <> + This version is currently deployed. To edit this version, undeploy it on the{' '} + {modelVersionNode} page. + + ), + [MODEL_VERSION_STATE.registering]: (_modelVersionNode: React.ReactNode) => ( + <>Wait for this version to complete uploading and then try again. + ), + [MODEL_VERSION_STATE.deploying]: (modelVersionNode: React.ReactNode) => ( + <> + To edit this version, wait for it to complete deploying and then undeploy it on the + {modelVersionNode} page. + + ), + }, +}; + +interface ModelVersionUnableDoActionModalProps { + id: string; + name: string; + version: string; + state: MODEL_VERSION_STATE; + actionType: 'edit' | 'delete'; + closeModal: () => void; +} + +export const ModelVersionUnableDoActionModal = ({ + id, + name, + state, + version, + closeModal, + actionType, +}: ModelVersionUnableDoActionModalProps) => { + const modeVersionLinkNode = ( + + + {name} version {version} + + + ); + + return ( + + + +

      + Unable to {actionType} {modeVersionLinkNode} +

      +
      +
      + + {mapping[actionType][state]?.(modeVersionLinkNode)} + + + Close + +
      + ); +}; diff --git a/public/components/common/modals/type_text_confirm_modal.tsx b/public/components/common/modals/type_text_confirm_modal.tsx new file mode 100644 index 00000000..382ebc19 --- /dev/null +++ b/public/components/common/modals/type_text_confirm_modal.tsx @@ -0,0 +1,43 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React, { useState, useCallback } from 'react'; +import { + EuiConfirmModal, + EuiConfirmModalProps, + EuiFieldText, + EuiSpacer, + EuiText, +} from '@elastic/eui'; +import { EuiFieldTextProps } from '@opensearch-project/oui'; + +export interface TypeTextConfirmModalProps extends EuiConfirmModalProps { + textToType: string; +} + +export const TypeTextConfirmModal = ({ + textToType, + children, + ...restProps +}: TypeTextConfirmModalProps) => { + const [typedText, setTypedText] = useState(); + const handleTextChange = useCallback['onChange']>((e) => { + setTypedText(e.target.value); + }, []); + + return ( + + {children} + + Type {textToType} to confirm. + + + + + ); +}; diff --git a/public/components/model/model.tsx b/public/components/model/model.tsx index ca5acff3..02c66599 100644 --- a/public/components/model/model.tsx +++ b/public/components/model/model.tsx @@ -97,7 +97,7 @@ export const Model = () => { /> { - return render(); + return render( + + ); }; describe('', () => { @@ -46,6 +54,7 @@ describe('', () => { id="1" name="model-1" version="1" + onDeleted={jest.fn()} /> ); expect(screen.getByText('Deploy')).toBeInTheDocument(); @@ -56,6 +65,7 @@ describe('', () => { id="1" name="model-1" version="1" + onDeleted={jest.fn()} /> ); expect(screen.getByText('Deploy')).toBeInTheDocument(); @@ -74,6 +84,7 @@ describe('', () => { id="1" name="model-1" version="1" + onDeleted={jest.fn()} /> ); expect(screen.getByText('Undeploy')).toBeInTheDocument(); @@ -138,4 +149,24 @@ describe('', () => { screen.queryByText('This version will be undeployed. You can deploy it again later.') ).not.toBeInTheDocument(); }); + + it('should show delete confirm modal after "Delete" button clicked', async () => { + const user = userEvent.setup(); + setup(MODEL_VERSION_STATE.registered); + await user.click(screen.getByLabelText('show actions')); + await user.click(screen.getByText('Delete')); + + expect(screen.getByTestId('confirmModalTitleText')).toHaveTextContent( + 'Delete model-1 version 1?' + ); + }); + + it('should show unable to delete modal after "Delete" button clicked if state was registering', async () => { + const user = userEvent.setup(); + setup(MODEL_VERSION_STATE.registering); + await user.click(screen.getByLabelText('show actions')); + await user.click(screen.getByText('Delete')); + + expect(screen.getByText('Unable to delete')).toBeInTheDocument(); + }); }); diff --git a/public/components/model/model_versions_panel/__tests__/model_versions_panel.test.tsx b/public/components/model/model_versions_panel/__tests__/model_versions_panel.test.tsx index c77a9e2d..009f3a7d 100644 --- a/public/components/model/model_versions_panel/__tests__/model_versions_panel.test.tsx +++ b/public/components/model/model_versions_panel/__tests__/model_versions_panel.test.tsx @@ -292,4 +292,30 @@ describe('', () => { }, 20 * 1000 ); + + it( + 'should reload model version data after version delete successfully', + async () => { + render(); + await waitFor(() => { + expect(screen.getAllByLabelText('show actions').length).toBeGreaterThanOrEqual(1); + }); + await userEvent.click(screen.getAllByLabelText('show actions')[0]); + await userEvent.click(within(screen.getByRole('dialog')).getByText('Delete')); + await userEvent.type(screen.getByLabelText('confirm text input'), 'model2 version 1.0.1'); + expect(screen.getByText('Delete version')).toBeEnabled(); + + const searchMock = jest + .spyOn(ModelVersion.prototype, 'search') + .mockResolvedValue({ data: [], total_model_versions: 0 }); + await userEvent.click(screen.getByText('Delete version')); + + expect(searchMock).toHaveBeenLastCalledWith( + expect.objectContaining({ + modelIds: ['2'], + }) + ); + }, + 20 * 1000 + ); }); diff --git a/public/components/model/model_versions_panel/model_version_table.tsx b/public/components/model/model_versions_panel/model_version_table.tsx index 3d3347e3..7e2af93c 100644 --- a/public/components/model/model_versions_panel/model_version_table.tsx +++ b/public/components/model/model_versions_panel/model_version_table.tsx @@ -38,10 +38,11 @@ const ExpandCopyIDButton = ({ textToCopy }: { textToCopy: string }) => { ); }; -interface VersionTableProps extends Pick { +interface ModelVersionTableProps extends Pick { tags: string[]; versions: VersionTableDataItem[]; totalVersionCount?: number; + onVersionDeleted: (id: string) => void; } export const ModelVersionTable = ({ @@ -50,7 +51,8 @@ export const ModelVersionTable = ({ versions, pagination, totalVersionCount, -}: VersionTableProps) => { + onVersionDeleted, +}: ModelVersionTableProps) => { const columns = useMemo( () => [ { @@ -114,12 +116,18 @@ export const ModelVersionTable = ({ rowCellRender: ({ rowIndex }: EuiDataGridCellValueElementProps) => { const { id, name, version, state } = versions[rowIndex]; return ( - + ); }, }, ], - [versions] + [versions, onVersionDeleted] ); const [visibleColumns, setVisibleColumns] = useState(() => { const tagHiddenByDefaultColumns = tags.slice(3); diff --git a/public/components/model/model_versions_panel/model_version_table_row_actions.tsx b/public/components/model/model_versions_panel/model_version_table_row_actions.tsx index 0ad57260..231b8d7c 100644 --- a/public/components/model/model_versions_panel/model_version_table_row_actions.tsx +++ b/public/components/model/model_versions_panel/model_version_table_row_actions.tsx @@ -7,22 +7,32 @@ import React, { useState, useCallback } from 'react'; import { EuiPopover, EuiButtonIcon, EuiContextMenuPanel, EuiContextMenuItem } from '@elastic/eui'; import { MODEL_VERSION_STATE } from '../../../../common'; -import { ModelVersionDeploymentConfirmModal } from '../../common'; +import { + ModelVersionDeleteConfirmModal, + ModelVersionDeploymentConfirmModal, + ModelVersionUnableDoActionModal, +} from '../../common'; + +interface ModelVersionTableRowActionsProps { + id: string; + name: string; + version: string; + state: MODEL_VERSION_STATE; + onDeleted: (id: string) => void; +} export const ModelVersionTableRowActions = ({ - state, id, name, + state, version, -}: { - state: MODEL_VERSION_STATE; - id: string; - name: string; - version: string; -}) => { + onDeleted, +}: ModelVersionTableRowActionsProps) => { const [isPopoverOpen, setIsPopoverOpen] = useState(false); const [isDeployConfirmModalShow, setIsDeployConfirmModalShow] = useState(false); const [isUndeployConfirmModalShow, setIsUndeployConfirmModalShow] = useState(false); + const [isDeleteConfirmModalShow, setIsDeleteConfirmModalShow] = useState(false); + const [isUnableDeleteModalShow, setIsUnableDeleteModalShow] = useState(false); const handleShowActionsClick = useCallback(() => { setIsPopoverOpen((flag) => !flag); @@ -48,6 +58,34 @@ export const ModelVersionTableRowActions = ({ setIsUndeployConfirmModalShow(false); }, []); + const handleDeleteClick = useCallback(() => { + if ( + [ + MODEL_VERSION_STATE.deployed, + MODEL_VERSION_STATE.deploying, + MODEL_VERSION_STATE.registering, + ].includes(state) + ) { + setIsUnableDeleteModalShow(true); + return; + } + setIsDeleteConfirmModalShow(true); + }, [state]); + + const handleDeleteConfirmModalClose = useCallback( + (versionDeleted: boolean) => { + setIsDeleteConfirmModalShow(false); + if (versionDeleted) { + onDeleted(id); + } + }, + [id, onDeleted] + ); + + const closeUnableDeleteModal = useCallback(() => { + setIsUnableDeleteModalShow(false); + }, []); + return ( <> Delete , @@ -137,6 +176,24 @@ export const ModelVersionTableRowActions = ({ closeModal={closeUndeployConfirmModal} /> )} + {isDeleteConfirmModalShow && ( + + )} + {isUnableDeleteModalShow && ( + + )} ); }; diff --git a/public/components/model/model_versions_panel/model_versions_panel.tsx b/public/components/model/model_versions_panel/model_versions_panel.tsx index 90e7bf25..20056f3e 100644 --- a/public/components/model/model_versions_panel/model_versions_panel.tsx +++ b/public/components/model/model_versions_panel/model_versions_panel.tsx @@ -253,6 +253,7 @@ export const ModelVersionsPanel = ({ modelId }: ModelVersionsPanelProps) => { pagination={pagination} totalVersionCount={totalVersionCount} sorting={versionsSorting} + onVersionDeleted={reload} /> )} {panelStatus === 'loading' && ( diff --git a/public/components/model_list/model_table.tsx b/public/components/model_list/model_table.tsx index 478d49eb..7a19a391 100644 --- a/public/components/model_list/model_table.tsx +++ b/public/components/model_list/model_table.tsx @@ -85,7 +85,7 @@ export function ModelTable(props: ModelTableProps) { field: 'owner_name', name: 'Owner', width: '79px', - render: (name: string) => , + render: (name: string) => (name ? : '-'), align: 'center', sortable: true, }, diff --git a/server/services/model_aggregate_service.ts b/server/services/model_aggregate_service.ts index aa78442d..df6ba923 100644 --- a/server/services/model_aggregate_service.ts +++ b/server/services/model_aggregate_service.ts @@ -69,7 +69,7 @@ export class ModelAggregateService { aggs: { models: { terms: { - field: 'model_group_id.keyword', + field: 'model_group_id', size: MAX_MODEL_BUCKET_NUM, }, }, @@ -117,7 +117,7 @@ export class ModelAggregateService { return { data: models.map((model) => ({ ...model, - owner_name: model.owner.name, + owner_name: model.owner?.name, deployed_versions: (modelId2Version[model.id] || []).map( (deployedVersion) => deployedVersion.model_version ), diff --git a/server/services/model_version_service.ts b/server/services/model_version_service.ts index def751f6..04934128 100644 --- a/server/services/model_version_service.ts +++ b/server/services/model_version_service.ts @@ -132,6 +132,7 @@ export class ModelVersionService { ).body; return { id, + model_id: modelSource.model_group_id, ...modelSource, model_id: modelSource.model_group_id, }; diff --git a/server/services/utils/model.ts b/server/services/utils/model.ts index a1f44a6b..4392e849 100644 --- a/server/services/utils/model.ts +++ b/server/services/utils/model.ts @@ -88,7 +88,7 @@ export const generateModelSearchQuery = ({ }, ] : []), - ...(modelIds ? [generateTermQuery('model_group_id.keyword', modelIds)] : []), + ...(modelIds ? [generateTermQuery('model_group_id', modelIds)] : []), ], must_not: { exists: { diff --git a/test/mock_opensearch_dashboards_react.tsx b/test/mock_opensearch_dashboards_react.tsx new file mode 100644 index 00000000..81fb7da7 --- /dev/null +++ b/test/mock_opensearch_dashboards_react.tsx @@ -0,0 +1,74 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React from 'react'; +import { EuiToast } from '@elastic/eui'; + +import * as PluginContext from '../../../src/plugins/opensearch_dashboards_react/public'; +import { MountWrapper } from '../../../src/core/public/utils'; +import { MountPoint } from '../../../src/core/public'; +import { OverlayModalOpenOptions } from '../../../src/core/public/overlays'; + +import { render } from './test_utils'; + +// Cannot spyOn(PluginContext, 'useOpenSearchDashboards') directly as it results in error: +// TypeError: Cannot redefine property: useOpenSearchDashboards +// So we have to mock the entire module first as a workaround +jest.mock('../../../src/plugins/opensearch_dashboards_react/public', () => { + return { + __esModule: true, + ...jest.requireActual('../../../src/plugins/opensearch_dashboards_react/public'), + }; +}); + +const generateToastMock = () => + jest.fn((toastInput) => { + render( + + ) + } + > + {typeof toastInput !== 'string' && + (typeof toastInput.text !== 'string' && toastInput.text ? ( + + ) : ( + toastInput.text + ))} + + ); + return { + id: '', + }; + }); + +export const mockUseOpenSearchDashboards = () => + jest.spyOn(PluginContext, 'useOpenSearchDashboards').mockReturnValue({ + services: { + notifications: { + toasts: { + addDanger: generateToastMock(), + addSuccess: generateToastMock(), + }, + }, + overlays: { + openModal: jest.fn((modelMountPoint: MountPoint, options?: OverlayModalOpenOptions) => { + const { unmount } = render(); + return { + onClose: Promise.resolve(), + close: async () => { + unmount(); + }, + }; + }), + }, + }, + }); diff --git a/test/mocks/model_version_handlers.ts b/test/mocks/model_version_handlers.ts index 88a96e1e..0b6b522c 100644 --- a/test/mocks/model_version_handlers.ts +++ b/test/mocks/model_version_handlers.ts @@ -120,9 +120,16 @@ export const modelVersionHandlers = [ ); }), - rest.get(`${MODEL_VERSION_API_ENDPOINT}/:modelId`, (req, res, ctx) => { - const [modelId, ..._restParts] = req.url.pathname.split('/').reverse(); - return res(ctx.status(200), ctx.json(modelVersions.find((model) => model.id === modelId))); + rest.get(`${MODEL_VERSION_API_ENDPOINT}/:id`, (req, res, ctx) => { + const [id, ..._restParts] = req.url.pathname.split('/').reverse(); + return res( + ctx.status(200), + ctx.json(modelVersions.find((modelVersion) => modelVersion.id === id)) + ); + }), + + rest.delete(`${MODEL_VERSION_API_ENDPOINT}/:id`, (req, res, ctx) => { + return res(ctx.status(200), ctx.json({})); }), rest.post(`${MODEL_VERSION_LOAD_API_ENDPOINT}/:modelId`, (req, res, ctx) => { From 6bd7184859ab307594dff49aec167a49c3213b3b Mon Sep 17 00:00:00 2001 From: Lin Wang Date: Mon, 19 Jun 2023 21:32:47 +0800 Subject: [PATCH 59/75] Feature/import model by name (#218) feat: import pre-trained model by name + update description field to optional fix type error --------- Signed-off-by: Lin Wang --- public/apis/model_version.ts | 13 ++---- .../common/forms/model_description_field.tsx | 2 +- .../__tests__/register_model_api.test.ts | 19 ++++++++ .../__tests__/register_model_form.test.tsx | 8 ++-- .../register_model/model_details.tsx | 5 ++- .../register_model/register_model.tsx | 19 ++++---- .../register_model/register_model.types.ts | 7 ++- .../register_model/register_model_api.ts | 12 ++--- server/routes/model_version_router.ts | 19 ++++---- server/services/model_version_service.ts | 45 ++++++++----------- 10 files changed, 82 insertions(+), 67 deletions(-) diff --git a/public/apis/model_version.ts b/public/apis/model_version.ts index 1828808c..275db379 100644 --- a/public/apis/model_version.ts +++ b/public/apis/model_version.ts @@ -80,17 +80,18 @@ interface UploadModelBase { version?: string; description?: string; modelFormat: string; - modelConfig: Record; modelId: string; } export interface UploadModelByURL extends UploadModelBase { url: string; + modelConfig: Record; } export interface UploadModelByChunk extends UploadModelBase { modelContentHashValue: string; totalChunks: number; + modelConfig: Record; } export class ModelVersion { @@ -147,15 +148,9 @@ export class ModelVersion { ); } - public upload( + public upload( model: T - ): Promise< - T extends UploadModelByURL - ? { task_id: string } - : T extends UploadModelByChunk - ? { model_version_id: string } - : never - > { + ): Promise { return InnerHttpProvider.getHttp().post(MODEL_VERSION_UPLOAD_API_ENDPOINT, { body: JSON.stringify(model), }); diff --git a/public/components/common/forms/model_description_field.tsx b/public/components/common/forms/model_description_field.tsx index 13c44f44..29469a2b 100644 --- a/public/components/common/forms/model_description_field.tsx +++ b/public/components/common/forms/model_description_field.tsx @@ -9,7 +9,7 @@ import { EuiFormRow, EuiTextArea } from '@elastic/eui'; import { Control, FieldPathByValue, useController } from 'react-hook-form'; interface ModeDescriptionFormData { - description: string; + description?: string; } const DESCRIPTION_MAX_LENGTH = 200; diff --git a/public/components/register_model/__tests__/register_model_api.test.ts b/public/components/register_model/__tests__/register_model_api.test.ts index c906c232..c3d1748f 100644 --- a/public/components/register_model/__tests__/register_model_api.test.ts +++ b/public/components/register_model/__tests__/register_model_api.test.ts @@ -174,5 +174,24 @@ describe('register model api', () => { taskId: 'foo', }); }); + + it('should call register model group API without URL and configuration', async () => { + expect(ModelVersion.prototype.upload).not.toHaveBeenCalled(); + + await submitModelWithURL({ + name: 'foo', + description: 'bar', + configuration: '{}', + modelFileFormat: '', + }); + + expect(ModelVersion.prototype.upload).toHaveBeenCalled(); + expect(ModelVersion.prototype.upload).not.toHaveBeenCalledWith( + expect.objectContaining({ + url: expect.any(String), + modelConfig: expect.anything(), + }) + ); + }); }); }); diff --git a/public/components/register_model/__tests__/register_model_form.test.tsx b/public/components/register_model/__tests__/register_model_form.test.tsx index 40e9ea58..22669eb0 100644 --- a/public/components/register_model/__tests__/register_model_form.test.tsx +++ b/public/components/register_model/__tests__/register_model_form.test.tsx @@ -75,19 +75,17 @@ describe(' Form', () => { }); await waitFor(() => expect(screen.getByLabelText(/^name$/i).value).toEqual( - 'sentence-transformers/all-distilroberta-v1' + 'huggingface/sentence-transformers/all-distilroberta-v1' ) ); expect(onSubmitMock).not.toHaveBeenCalled(); await user.click(screen.getByRole('button', { name: /register model/i })); expect(onSubmitMock).toHaveBeenCalledWith( expect.objectContaining({ - name: 'sentence-transformers/all-distilroberta-v1', + name: 'huggingface/sentence-transformers/all-distilroberta-v1', + version: '1.0.1', description: 'This is a sentence-transformers model: It maps sentences & paragraphs to a 768 dimensional dense vector space and can be used for tasks like clustering or semantic search.', - modelURL: - 'https://artifacts.opensearch.org/models/ml-models/huggingface/sentence-transformers/all-distilroberta-v1/1.0.1/torch_script/sentence-transformers_all-distilroberta-v1-1.0.1-torch_script.zip', - configuration: expect.stringContaining('sentence_transformers'), }) ); }); diff --git a/public/components/register_model/model_details.tsx b/public/components/register_model/model_details.tsx index 7d1d60d4..c7849c0d 100644 --- a/public/components/register_model/model_details.tsx +++ b/public/components/register_model/model_details.tsx @@ -12,14 +12,15 @@ import { ModelNameField, ModelDescriptionField } from '../../components/common'; import { ModelFileFormData, ModelUrlFormData } from './register_model.types'; export const ModelDetailsPanel = () => { - const { control, trigger } = useFormContext(); + const { control, trigger, watch } = useFormContext(); + const type = watch('type'); return (

      Details

      - +
      ); diff --git a/public/components/register_model/register_model.tsx b/public/components/register_model/register_model.tsx index dbfd21b5..daf09248 100644 --- a/public/components/register_model/register_model.tsx +++ b/public/components/register_model/register_model.tsx @@ -38,7 +38,7 @@ import { ModelVersionNotesPanel } from './model_version_notes'; import { modelFileUploadManager } from './model_file_upload_manager'; import { MAX_CHUNK_SIZE, FORM_ERRORS } from '../common/forms/form_constants'; import { ModelDetailsPanel } from './model_details'; -import type { ModelFileFormData, ModelUrlFormData } from './register_model.types'; +import type { ModelFileFormData, ModelFormBase, ModelUrlFormData } from './register_model.types'; import { ArtifactPanel } from './artifact'; import { ConfigurationPanel } from './model_configuration'; import { ModelTagsPanel } from './model_tags'; @@ -115,7 +115,7 @@ export const RegisterModelForm = ({ defaultValues = DEFAULT_VALUES }: RegisterMo const formErrors = useMemo(() => ({ ...form.formState.errors }), [form.formState]); const onSubmit = useCallback( - async (data: ModelFileFormData | ModelUrlFormData) => { + async (data: ModelFileFormData | ModelUrlFormData | ModelFormBase) => { try { const onComplete = () => { notifications?.toasts.addSuccess({ @@ -229,18 +229,17 @@ export const RegisterModelForm = ({ defaultValues = DEFAULT_VALUES }: RegisterMo .subscribe( (preTrainedModel) => { // TODO: Fill model format here - const { config, url } = preTrainedModel; - form.setValue('modelURL', url); + const { config } = preTrainedModel; form.setValue('modelFileFormat', 'TORCH_SCRIPT'); if (config.name) { - form.setValue('name', config.name); + form.setValue('name', `huggingface/${config.name}`); + } + if (config.version) { + form.setValue('version', config.version); } if (config.description) { form.setValue('description', config.description); } - if (config.model_config) { - form.setValue('configuration', JSON.stringify(config.model_config)); - } setPreTrainedModelLoading(false); }, (error) => { @@ -254,6 +253,10 @@ export const RegisterModelForm = ({ defaultValues = DEFAULT_VALUES }: RegisterMo }; }, [nameParams, form]); + useEffect(() => { + form.setValue('type', formType); + }, [formType, form]); + const onError = useCallback((errors: FieldErrors) => { // TODO // eslint-disable-next-line no-console diff --git a/public/components/register_model/register_model.types.ts b/public/components/register_model/register_model.types.ts index d8f57f0a..9a2d909d 100644 --- a/public/components/register_model/register_model.types.ts +++ b/public/components/register_model/register_model.types.ts @@ -5,14 +5,15 @@ import type { Tag } from '../model/types'; -interface ModelFormBase { +export interface ModelFormBase { name: string; + version?: string; modelId?: string; description?: string; - configuration: string; modelFileFormat: string; tags?: Tag[]; versionNotes?: string; + type?: 'import' | 'upload'; } /** @@ -20,6 +21,7 @@ interface ModelFormBase { */ export interface ModelFileFormData extends ModelFormBase { modelFile: File; + configuration: string; } /** @@ -27,4 +29,5 @@ export interface ModelFileFormData extends ModelFormBase { */ export interface ModelUrlFormData extends ModelFormBase { modelURL: string; + configuration: string; } diff --git a/public/components/register_model/register_model_api.ts b/public/components/register_model/register_model_api.ts index 2d283948..e090e3cb 100644 --- a/public/components/register_model/register_model_api.ts +++ b/public/components/register_model/register_model_api.ts @@ -6,18 +6,20 @@ import { APIProvider } from '../../apis/api_provider'; import { MAX_CHUNK_SIZE } from '../common/forms/form_constants'; import { getModelContentHashValue } from './get_model_content_hash_value'; -import { ModelFileFormData, ModelUrlFormData } from './register_model.types'; +import { ModelFileFormData, ModelUrlFormData, ModelFormBase } from './register_model.types'; const getModelUploadBase = ({ name, + version, versionNotes, modelFileFormat, configuration, -}: ModelFileFormData | ModelUrlFormData) => ({ +}: ModelFormBase & { configuration?: string }) => ({ name, + version, description: versionNotes, modelFormat: modelFileFormat, - modelConfig: JSON.parse(configuration), + modelConfig: configuration ? JSON.parse(configuration) : undefined, }); const createModelIfNeedAndUploadVersion = async ({ @@ -78,14 +80,14 @@ export async function submitModelWithFile(model: ModelFileFormData) { }; } -export async function submitModelWithURL(model: ModelUrlFormData) { +export async function submitModelWithURL(model: ModelUrlFormData | ModelFormBase) { const result = await createModelIfNeedAndUploadVersion({ ...model, uploader: (modelId: string) => APIProvider.getAPI('modelVersion').upload({ ...getModelUploadBase(model), modelId, - url: model.modelURL, + url: 'modelURL' in model ? model.modelURL : undefined, }), }); diff --git a/server/routes/model_version_router.ts b/server/routes/model_version_router.ts index e48d73d0..bd9b9dd7 100644 --- a/server/routes/model_version_router.ts +++ b/server/routes/model_version_router.ts @@ -52,24 +52,23 @@ export const modelStateSchema = schema.oneOf([ schema.literal(MODEL_VERSION_STATE.registerFailed), ]); -const modelUploadBaseSchema = { +const modelUploadBaseSchema = schema.object({ name: schema.string(), version: schema.maybe(schema.string()), description: schema.maybe(schema.string()), modelFormat: schema.string(), - modelConfig: schema.object({}, { unknowns: 'allow' }), modelId: schema.string(), -}; +}); -const modelUploadByURLSchema = schema.object({ - ...modelUploadBaseSchema, +const modelUploadByURLSchema = modelUploadBaseSchema.extends({ url: schema.string(), + modelConfig: schema.object({}, { unknowns: 'allow' }), }); -const modelUploadByChunkSchema = schema.object({ - ...modelUploadBaseSchema, +const modelUploadByChunkSchema = modelUploadBaseSchema.extends({ modelContentHashValue: schema.string(), totalChunks: schema.number(), + modelConfig: schema.object({}, { unknowns: 'allow' }), }); export const modelVersionRouter = (router: IRouter) => { @@ -259,7 +258,11 @@ export const modelVersionRouter = (router: IRouter) => { { path: MODEL_VERSION_UPLOAD_API_ENDPOINT, validate: { - body: schema.oneOf([modelUploadByURLSchema, modelUploadByChunkSchema]), + body: schema.oneOf([ + modelUploadByURLSchema, + modelUploadByChunkSchema, + modelUploadBaseSchema, + ]), }, }, async (context, request) => { diff --git a/server/services/model_version_service.ts b/server/services/model_version_service.ts index 04934128..970cdb1f 100644 --- a/server/services/model_version_service.ts +++ b/server/services/model_version_service.ts @@ -41,32 +41,25 @@ interface UploadModelBase { version?: string; description?: string; modelFormat: string; - modelConfig: Record; modelId: string; } interface UploadModelByURL extends UploadModelBase { url: string; + modelConfig: Record; } interface UploadModelByChunk extends UploadModelBase { modelContentHashValue: string; totalChunks: number; + modelConfig: Record; } -type UploadResultInner< - T extends UploadModelByURL | UploadModelByChunk -> = T extends UploadModelByChunk +type UploadResultInner = T extends UploadModelByChunk ? { model_version_id: string; status: string } - : T extends UploadModelByURL - ? { task_id: string; status: string } - : never; + : { task_id: string; status: string }; -type UploadResult = Promise>; - -const isUploadModelByURL = ( - test: UploadModelByURL | UploadModelByChunk -): test is UploadModelByURL => (test as UploadModelByURL).url !== undefined; +type UploadResult = Promise>; export class ModelVersionService { constructor() {} @@ -134,7 +127,6 @@ export class ModelVersionService { id, model_id: modelSource.model_group_id, ...modelSource, - model_id: modelSource.model_group_id, }; } @@ -178,48 +170,47 @@ export class ModelVersionService { ).body; } - public static async upload({ + public static async upload({ client, model, }: { client: IScopedClusterClient; model: T; }): UploadResult { - const { name, version, description, modelFormat, modelConfig, modelId } = model; + const { name, version, description, modelFormat, modelId } = model; const uploadModelBase = { name, version, description, model_format: modelFormat, - model_config: modelConfig, + model_config: 'modelConfig' in model ? model.modelConfig : undefined, model_group_id: modelId, }; - if (isUploadModelByURL(model)) { - const { task_id: taskId, status } = ( + if ('totalChunks' in model) { + const { model_id: modelVersionId, status } = ( await client.asCurrentUser.transport.request({ method: 'POST', - path: MODEL_UPLOAD_API, + path: MODEL_META_API, body: { ...uploadModelBase, - url: model.url, + model_content_hash_value: model.modelContentHashValue, + total_chunks: model.totalChunks, }, }) ).body; - return { task_id: taskId, status } as UploadResultInner; + return { model_version_id: modelVersionId, status } as UploadResultInner; } - - const { model_id: modelVersionId, status } = ( + const { task_id: taskId, status } = ( await client.asCurrentUser.transport.request({ method: 'POST', - path: MODEL_META_API, + path: MODEL_UPLOAD_API, body: { ...uploadModelBase, - model_content_hash_value: model.modelContentHashValue, - total_chunks: model.totalChunks, + url: 'url' in model ? model.url : undefined, }, }) ).body; - return { model_version_id: modelVersionId, status } as UploadResultInner; + return { task_id: taskId, status } as UploadResultInner; } public static async uploadModelChunk({ From 9361ee6e57621f5fd11a90b6d640e330d2f93c1e Mon Sep 17 00:00:00 2001 From: Lin Wang Date: Mon, 19 Jun 2023 21:42:57 +0800 Subject: [PATCH 60/75] Feat add basic model delete (#219) feat: add model delete confirm modal --------- Signed-off-by: Lin Wang --- public/components/common/modals/index.ts | 1 + .../modals/model_delete_confirm_modal.tsx | 145 ++++++++++++++++++ .../components/model/__tests__/model.test.tsx | 2 +- public/components/model/model.tsx | 4 +- .../components/model/model_delete_button.tsx | 51 ++++++ public/components/model_list/index.tsx | 3 +- public/components/model_list/model_table.tsx | 10 +- .../model_table_row_delete_button.tsx | 51 ++++++ 8 files changed, 261 insertions(+), 6 deletions(-) create mode 100644 public/components/common/modals/model_delete_confirm_modal.tsx create mode 100644 public/components/model/model_delete_button.tsx create mode 100644 public/components/model_list/model_table_row_delete_button.tsx diff --git a/public/components/common/modals/index.ts b/public/components/common/modals/index.ts index 085bdfe7..50e8679b 100644 --- a/public/components/common/modals/index.ts +++ b/public/components/common/modals/index.ts @@ -8,3 +8,4 @@ export { ModelVersionDeploymentConfirmModal } from './model_version_deployment_c export { TypeTextConfirmModal } from './type_text_confirm_modal'; export { ModelVersionDeleteConfirmModal } from './model_version_delete_confirm_modal'; export { ModelVersionUnableDoActionModal } from './model_version_unable_do_action_modal'; +export { ModelDeleteConfirmModal } from './model_delete_confirm_modal'; diff --git a/public/components/common/modals/model_delete_confirm_modal.tsx b/public/components/common/modals/model_delete_confirm_modal.tsx new file mode 100644 index 00000000..9331abcd --- /dev/null +++ b/public/components/common/modals/model_delete_confirm_modal.tsx @@ -0,0 +1,145 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React, { useCallback, useState } from 'react'; +import { Link, generatePath, useHistory } from 'react-router-dom'; +import { EuiLink, EuiSpacer, EuiText } from '@elastic/eui'; + +import { useOpenSearchDashboards } from '../../../../../../src/plugins/opensearch_dashboards_react/public'; +import { mountReactNode } from '../../../../../../src/core/public/utils'; +import { routerPaths } from '../../../../common'; +import { APIProvider } from '../../../../public/apis/api_provider'; + +import { TypeTextConfirmModal } from './type_text_confirm_modal'; + +const typeTextConfirmModalStyle = { width: 500 }; + +interface ModelDeleteConfirmModalProps { + id: string; + name: string; + closeModal: (data: { succeed: boolean; canceled: boolean }) => void; +} + +export const ModelDeleteConfirmModal = ({ id, name, closeModal }: ModelDeleteConfirmModalProps) => { + const { + services: { notifications }, + } = useOpenSearchDashboards(); + const [isDeleting, setIsDeleting] = useState(false); + const history = useHistory(); + + const handleCancel = useCallback(() => { + if (isDeleting) { + return; + } + closeModal({ canceled: true, succeed: false }); + }, [isDeleting, closeModal]); + + const handleConfirm = useCallback(async () => { + setIsDeleting(true); + try { + // TODO: move to delete in background and hide confirm modal + const modelVersionIds: string[] = []; + while (true) { + const searchResult = await APIProvider.getAPI('modelVersion').search({ + from: modelVersionIds.length, + size: 50, + modelIds: [id], + }); + searchResult.data.forEach((modelVersion) => { + modelVersionIds.push(modelVersion.id); + }); + if (modelVersionIds.length >= searchResult.total_model_versions) { + break; + } + } + for (let i = 0; i < modelVersionIds.length; i++) { + await APIProvider.getAPI('modelVersion').delete(modelVersionIds[i]); + } + /** + * Model group can't be deleted if there are model versions in it + * We need to wait for all versions can't be searchable + **/ + while (true) { + if ( + ( + await APIProvider.getAPI('modelVersion').search({ + from: 0, + size: 1, + modelIds: [id], + }) + ).total_model_versions === 0 + ) { + break; + } + await new Promise((resolve) => setTimeout(resolve, 300)); + } + await APIProvider.getAPI('model').delete(id); + /** + * Model group still can be searchable after delete + * We need to wait for model can't be searchable + **/ + while (true) { + if ( + (await APIProvider.getAPI('model').search({ ids: [id], from: 0, size: 0 })) + .total_models === 0 + ) { + break; + } + await new Promise((resolve) => setTimeout(resolve, 300)); + } + } catch (e) { + closeModal({ succeed: false, canceled: false }); + const modelLinkAddress = history.createHref({ + pathname: generatePath(routerPaths.model, { id }), + }); + notifications?.toasts.addDanger({ + title: mountReactNode( + <> + Unable to delete {name} + + ), + }); + return; + } finally { + setIsDeleting(false); + } + notifications?.toasts.addSuccess({ + title: mountReactNode( + <> + {name} has been deleted + + ), + }); + closeModal({ succeed: true, canceled: false }); + }, [id, name, closeModal, notifications, history]); + + return ( + + Delete{' '} + + {name} version + + ? + + } + textToType={name} + confirmButtonText="Delete model" + buttonColor="danger" + cancelButtonText="Cancel" + onConfirm={handleConfirm} + onCancel={handleCancel} + confirmButtonDisabled={isDeleting} + isLoading={isDeleting} + > + + This will delete all versions of this model. This action is irreversible. + + + + ); +}; diff --git a/public/components/model/__tests__/model.test.tsx b/public/components/model/__tests__/model.test.tsx index 49596598..c635014d 100644 --- a/public/components/model/__tests__/model.test.tsx +++ b/public/components/model/__tests__/model.test.tsx @@ -44,7 +44,7 @@ describe('', () => { expect(screen.queryByTestId('model-group-loading-indicator')).toBeNull(); }); expect(screen.getByText('model1')).toBeInTheDocument(); - expect(screen.getByText('Delete')).toBeInTheDocument(); + expect(screen.getByLabelText('Delete model')).toBeInTheDocument(); expect(screen.getByText('Register version')).toBeInTheDocument(); expect(screen.queryByTestId('model-group-overview-card')).toBeInTheDocument(); expect(screen.queryByTestId('model-group-overview-card')).toBeInTheDocument(); diff --git a/public/components/model/model.tsx b/public/components/model/model.tsx index 02c66599..410dae6f 100644 --- a/public/components/model/model.tsx +++ b/public/components/model/model.tsx @@ -23,6 +23,7 @@ import { ModelOverviewCard } from './model_overview_card'; import { ModelVersionsPanel } from './model_versions_panel'; import { ModelDetailsPanel } from './model_details_panel'; import { ModelTagsPanel } from './model_tags_panel'; +import { ModelDeleteButton } from './model_delete_button'; export const Model = () => { const { id: modelId } = useParams<{ id: string }>(); @@ -88,11 +89,12 @@ export const Model = () => {

      {data.name}

      } + rightSideGroupProps={{ alignItems: 'center' }} rightSideItems={[ Register version , - Delete, + , ]} /> { + const [isModalVisible, setIsModalVisible] = useState(false); + const history = useHistory(); + + const handleDeleteButtonClick = useCallback(() => { + setIsModalVisible(true); + }, []); + const handleModalClose = useCallback( + ({ succeed }: { succeed: boolean; canceled: boolean }) => { + if (succeed) { + history.push(routerPaths.modelList); + return; + } + setIsModalVisible(false); + }, + [history] + ); + + return ( + <> + + Delete + + {isModalVisible && ( + + )} + + ); +}; diff --git a/public/components/model_list/index.tsx b/public/components/model_list/index.tsx index c96610d0..6fbd1e9b 100644 --- a/public/components/model_list/index.tsx +++ b/public/components/model_list/index.tsx @@ -94,7 +94,7 @@ export const ModelList = () => { searchInputRef.current = node; }, []); - const { data, loading, error } = useFetcher( + const { data, loading, error, reload } = useFetcher( APIProvider.getAPI('modelAggregate').search, getModelAggregateSearchParams(params) ); @@ -185,6 +185,7 @@ export const ModelList = () => { onChange={handleTableChange} onResetClick={handleReset} error={!!error} + onModelDeleted={reload} /> )} diff --git a/public/components/model_list/model_table.tsx b/public/components/model_list/model_table.tsx index 7a19a391..f52338d9 100644 --- a/public/components/model_list/model_table.tsx +++ b/public/components/model_list/model_table.tsx @@ -27,6 +27,7 @@ import { routerPaths } from '../../../common'; import { ModelOwner } from './model_owner'; import { ModelDeployedVersions } from './model_deployed_versions'; +import { ModelTableRowDeleteButton } from './model_table_row_delete_button'; export interface ModelTableSort { field: 'name' | 'latest_version' | 'description' | 'owner_name' | 'last_updated_time'; @@ -50,10 +51,11 @@ export interface ModelTableProps { loading: boolean; error: boolean; onResetClick: () => void; + onModelDeleted: () => void; } export function ModelTable(props: ModelTableProps) { - const { models, sort, onChange, loading, onResetClick, error } = props; + const { models, sort, onChange, loading, onResetClick, error, onModelDeleted } = props; const onChangeRef = useRef(onChange); onChangeRef.current = onChange; @@ -113,12 +115,14 @@ export function ModelTable(props: ModelTableProps) { ), }, { - render: () => , + render: ({ id, name }) => ( + + ), }, ], }, ], - [] + [onModelDeleted] ); const pagination = useMemo( diff --git a/public/components/model_list/model_table_row_delete_button.tsx b/public/components/model_list/model_table_row_delete_button.tsx new file mode 100644 index 00000000..69e1f3a9 --- /dev/null +++ b/public/components/model_list/model_table_row_delete_button.tsx @@ -0,0 +1,51 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import { EuiButtonIcon } from '@elastic/eui'; +import React, { useState, useCallback } from 'react'; + +import { ModelDeleteConfirmModal } from '../common/modals/model_delete_confirm_modal'; + +interface ModelTableRowDeleteButtonProps { + id: string; + name: string; + onDeleted: () => void; +} + +export const ModelTableRowDeleteButton = ({ + id, + name, + onDeleted, +}: ModelTableRowDeleteButtonProps) => { + const [isModalVisible, setIsModalVisible] = useState(false); + + const handleDeleteButtonClick = useCallback(() => { + setIsModalVisible(true); + }, []); + const handleModalClose = useCallback( + ({ succeed }: { succeed: boolean; canceled: boolean }) => { + if (succeed) { + onDeleted(); + } + }, + [onDeleted] + ); + + return ( + <> + + Delete + + {isModalVisible && ( + + )} + + ); +}; From 9d7851eef431a044f563e2b1af9a22a87e28265b Mon Sep 17 00:00:00 2001 From: Lin Wang Date: Tue, 20 Jun 2023 09:33:48 +0800 Subject: [PATCH 61/75] feat: refresh model version data after deploy or undeploy complete (#216) * feat: refresh model version data after deploy or undeploy complete Signed-off-by: Lin Wang * fix: undeploy no onComplete and onError Signed-off-by: Lin Wang --------- Signed-off-by: Lin Wang --- ..._version_deployment_confirm_modal.test.tsx | 66 ++++++++++ ...model_version_deployment_confirm_modal.tsx | 34 ++++- .../model_version_table_row_actions.test.tsx | 116 +++++++++++++++++- .../model_version_table.tsx | 21 +++- .../model_version_table_row_actions.tsx | 42 +++++-- .../model_versions_panel.tsx | 4 + 6 files changed, 264 insertions(+), 19 deletions(-) diff --git a/public/components/common/modals/__tests__/model_version_deployment_confirm_modal.test.tsx b/public/components/common/modals/__tests__/model_version_deployment_confirm_modal.test.tsx index 47f42416..dcf72a6c 100644 --- a/public/components/common/modals/__tests__/model_version_deployment_confirm_modal.test.tsx +++ b/public/components/common/modals/__tests__/model_version_deployment_confirm_modal.test.tsx @@ -106,4 +106,70 @@ describe('', () => { }); }); }); + + it('should call closeModal with canceled after cancel button clicked', async () => { + const closeModalMock = jest.fn(() => {}); + render( + + ); + + await userEvent.click(screen.getByRole('button', { name: 'Cancel' })); + expect(closeModalMock).toHaveBeenCalledWith({ canceled: true, succeed: false, id: '1' }); + }); + + it('should call closeModal with succeed equal true after undeploy button clicked and deploy complete', async () => { + const closeModalMock = jest.fn(() => {}); + render( + + ); + + await userEvent.click(screen.getByRole('button', { name: 'Undeploy' })); + + await waitFor(() => { + expect(undeployMock.mock.calls.length).toBeGreaterThan(0); + expect( + undeployMock.mock.calls[undeployMock.mock.calls.length - 1][0].onComplete + ).not.toBeUndefined(); + }); + + undeployMock.mock.calls[undeployMock.mock.calls.length - 1][0].onComplete(); + expect(closeModalMock).toHaveBeenCalledWith({ canceled: false, succeed: true, id: '1' }); + }); + + it('should call closeModal with succeed equal false after undeploy button clicked and deploy complete', async () => { + const closeModalMock = jest.fn(() => {}); + render( + + ); + + await userEvent.click(screen.getByRole('button', { name: 'Undeploy' })); + + await waitFor(() => { + expect(undeployMock.mock.calls.length).toBeGreaterThan(0); + expect( + undeployMock.mock.calls[undeployMock.mock.calls.length - 1][0].onError + ).not.toBeUndefined(); + }); + + undeployMock.mock.calls[undeployMock.mock.calls.length - 1][0].onError(); + expect(closeModalMock).toHaveBeenCalledWith({ canceled: false, succeed: false, id: '1' }); + }); }); diff --git a/public/components/common/modals/model_version_deployment_confirm_modal.tsx b/public/components/common/modals/model_version_deployment_confirm_modal.tsx index 54dc123c..81a86180 100644 --- a/public/components/common/modals/model_version_deployment_confirm_modal.tsx +++ b/public/components/common/modals/model_version_deployment_confirm_modal.tsx @@ -21,7 +21,7 @@ export const ModelVersionDeploymentConfirmModal = ({ mode: 'deploy' | 'undeploy'; name: string; version: string; - closeModal: () => void; + closeModal: (data: { canceled: boolean; succeed: boolean; id: string }) => void; }) => { const { deploy, undeploy } = useDeployment(id); const [isSubmitting, setIsSubmitting] = useState(false); @@ -39,12 +39,34 @@ export const ModelVersionDeploymentConfirmModal = ({ }; const { title, description, action } = mapping[mode]; + const handleCancel = useCallback(() => { + closeModal({ canceled: true, succeed: false, id }); + }, [closeModal, id]); + const handleConfirm = useCallback(async () => { setIsSubmitting(true); - await action(); - setIsSubmitting(false); - closeModal(); - }, [action, closeModal]); + try { + await action({ + onComplete: () => { + setIsSubmitting(false); + closeModal({ canceled: false, succeed: true, id }); + }, + onError: () => { + setIsSubmitting(false); + closeModal({ canceled: false, succeed: false, id }); + }, + }); + } catch (e) { + setIsSubmitting(false); + closeModal({ canceled: false, succeed: false, id }); + return; + } + // undeploy action won't call onComplete after success, so need to handle undeploy success here + if (mode === 'undeploy') { + setIsSubmitting(false); + closeModal({ canceled: false, succeed: true, id }); + } + }, [id, mode, action, closeModal]); return ( { - return render( + const onDeployedFailedMock = jest.fn(); + const onDeployedMock = jest.fn(); + const onUndeployedFailedMock = jest.fn(); + const onUndeployedMock = jest.fn(); + const result = render( ); + return { + renderResult: result, + onDeployedMock, + onDeployedFailedMock, + onUndeployedMock, + onUndeployedFailedMock, + }; }; describe('', () => { @@ -43,7 +59,9 @@ describe('', () => { it('should render "Deploy" button for REGISTERED, DEPLOY_FAILED and UNDEPLOYED state', async () => { const user = userEvent.setup(); - const { rerender } = setup(MODEL_VERSION_STATE.registered); + const { + renderResult: { rerender }, + } = setup(MODEL_VERSION_STATE.registered); await user.click(screen.getByLabelText('show actions')); expect(screen.getByText('Deploy')).toBeInTheDocument(); @@ -55,6 +73,10 @@ describe('', () => { name="model-1" version="1" onDeleted={jest.fn()} + onDeployed={jest.fn()} + onDeployFailed={jest.fn()} + onUndeployed={jest.fn()} + onUndeployFailed={jest.fn()} /> ); expect(screen.getByText('Deploy')).toBeInTheDocument(); @@ -66,6 +88,10 @@ describe('', () => { name="model-1" version="1" onDeleted={jest.fn()} + onDeployed={jest.fn()} + onDeployFailed={jest.fn()} + onUndeployed={jest.fn()} + onUndeployFailed={jest.fn()} /> ); expect(screen.getByText('Deploy')).toBeInTheDocument(); @@ -73,7 +99,9 @@ describe('', () => { it('should render "Undeploy" button for DEPLOYED and PARTIALLY_DEPLOYED state', async () => { const user = userEvent.setup(); - const { rerender } = setup(MODEL_VERSION_STATE.deployed); + const { + renderResult: { rerender }, + } = setup(MODEL_VERSION_STATE.deployed); await user.click(screen.getByLabelText('show actions')); expect(screen.getByText('Undeploy')).toBeInTheDocument(); @@ -85,6 +113,10 @@ describe('', () => { name="model-1" version="1" onDeleted={jest.fn()} + onDeployed={jest.fn()} + onDeployFailed={jest.fn()} + onUndeployed={jest.fn()} + onUndeployFailed={jest.fn()} /> ); expect(screen.getByText('Undeploy')).toBeInTheDocument(); @@ -169,4 +201,80 @@ describe('', () => { expect(screen.getByText('Unable to delete')).toBeInTheDocument(); }); + + it('should call onDeployed after deployed', async () => { + const user = userEvent.setup(); + const useDeploymentMock = jest + .spyOn(useDeploymentExports, 'useDeployment') + .mockImplementation(() => ({ + deploy: async (options?: { onComplete?: () => void; onError?: () => void }) => { + options?.onComplete?.(); + }, + undeploy: jest.fn(), + })); + + const { onDeployedMock } = setup(MODEL_VERSION_STATE.deployFailed); + await user.click(screen.getByLabelText('show actions')); + await user.click(screen.getByText('Deploy')); + await user.click(screen.getByRole('button', { name: 'Deploy' })); + + expect(onDeployedMock).toHaveBeenCalled(); + useDeploymentMock.mockRestore(); + }); + + it('should call onDeployedFailed after deploy failed', async () => { + const user = userEvent.setup(); + const useDeploymentMock = jest + .spyOn(useDeploymentExports, 'useDeployment') + .mockImplementation(() => ({ + deploy: async (options?: { onComplete?: () => void; onError?: () => void }) => { + options?.onError?.(); + }, + undeploy: jest.fn(), + })); + + const { onDeployedFailedMock } = setup(MODEL_VERSION_STATE.deployFailed); + await user.click(screen.getByLabelText('show actions')); + await user.click(screen.getByText('Deploy')); + await user.click(screen.getByRole('button', { name: 'Deploy' })); + + expect(onDeployedFailedMock).toHaveBeenCalled(); + useDeploymentMock.mockRestore(); + }); + + it('should call onUndeployed after undeploy failed', async () => { + const user = userEvent.setup(); + const useDeploymentMock = jest + .spyOn(useDeploymentExports, 'useDeployment') + .mockImplementation(() => ({ + deploy: jest.fn(), + undeploy: jest.fn().mockResolvedValue({}), + })); + + const { onUndeployedMock } = setup(MODEL_VERSION_STATE.deployed); + await user.click(screen.getByLabelText('show actions')); + await user.click(screen.getByText('Undeploy')); + await user.click(screen.getByRole('button', { name: 'Undeploy' })); + + expect(onUndeployedMock).toHaveBeenCalled(); + useDeploymentMock.mockRestore(); + }); + + it('should call onUndeployedFailed after undeploy failed', async () => { + const user = userEvent.setup(); + const useDeploymentMock = jest + .spyOn(useDeploymentExports, 'useDeployment') + .mockImplementation(() => ({ + deploy: jest.fn(), + undeploy: jest.fn().mockRejectedValue(new Error('Undeploy failed')), + })); + + const { onUndeployedFailedMock } = setup(MODEL_VERSION_STATE.deployed); + await user.click(screen.getByLabelText('show actions')); + await user.click(screen.getByText('Undeploy')); + await user.click(screen.getByRole('button', { name: 'Undeploy' })); + + expect(onUndeployedFailedMock).toHaveBeenCalled(); + useDeploymentMock.mockRestore(); + }); }); diff --git a/public/components/model/model_versions_panel/model_version_table.tsx b/public/components/model/model_versions_panel/model_version_table.tsx index 7e2af93c..5122b262 100644 --- a/public/components/model/model_versions_panel/model_version_table.tsx +++ b/public/components/model/model_versions_panel/model_version_table.tsx @@ -43,6 +43,10 @@ interface ModelVersionTableProps extends Pick void; + onVersionDeployed: (id: string) => void; + onVersionDeployFailed: (id: string) => void; + onVersionUndeployed: (id: string) => void; + onVersionUndeployFailed: (id: string) => void; } export const ModelVersionTable = ({ @@ -52,6 +56,10 @@ export const ModelVersionTable = ({ pagination, totalVersionCount, onVersionDeleted, + onVersionDeployed, + onVersionDeployFailed, + onVersionUndeployed, + onVersionUndeployFailed, }: ModelVersionTableProps) => { const columns = useMemo( () => [ @@ -122,12 +130,23 @@ export const ModelVersionTable = ({ state={state} version={version} onDeleted={onVersionDeleted} + onDeployed={onVersionDeployed} + onDeployFailed={onVersionDeployFailed} + onUndeployed={onVersionUndeployed} + onUndeployFailed={onVersionUndeployFailed} /> ); }, }, ], - [versions, onVersionDeleted] + [ + versions, + onVersionDeleted, + onVersionDeployed, + onVersionUndeployed, + onVersionDeployFailed, + onVersionUndeployFailed, + ] ); const [visibleColumns, setVisibleColumns] = useState(() => { const tagHiddenByDefaultColumns = tags.slice(3); diff --git a/public/components/model/model_versions_panel/model_version_table_row_actions.tsx b/public/components/model/model_versions_panel/model_version_table_row_actions.tsx index 231b8d7c..260156d2 100644 --- a/public/components/model/model_versions_panel/model_version_table_row_actions.tsx +++ b/public/components/model/model_versions_panel/model_version_table_row_actions.tsx @@ -19,6 +19,10 @@ interface ModelVersionTableRowActionsProps { version: string; state: MODEL_VERSION_STATE; onDeleted: (id: string) => void; + onDeployed: (id: string) => void; + onUndeployed: (id: string) => void; + onDeployFailed: (id: string) => void; + onUndeployFailed: (id: string) => void; } export const ModelVersionTableRowActions = ({ @@ -27,6 +31,10 @@ export const ModelVersionTableRowActions = ({ state, version, onDeleted, + onDeployed, + onDeployFailed, + onUndeployed, + onUndeployFailed, }: ModelVersionTableRowActionsProps) => { const [isPopoverOpen, setIsPopoverOpen] = useState(false); const [isDeployConfirmModalShow, setIsDeployConfirmModalShow] = useState(false); @@ -50,13 +58,31 @@ export const ModelVersionTableRowActions = ({ setIsUndeployConfirmModalShow(true); }, []); - const closeDeployConfirmModal = useCallback(() => { - setIsDeployConfirmModalShow(false); - }, []); + const handleDeployConfirmModalClose = useCallback( + ({ id: versionId, succeed, canceled }: { succeed: boolean; id: string; canceled: boolean }) => { + if (succeed) { + onDeployed(versionId); + } + if (!succeed && !canceled) { + onDeployFailed(versionId); + } + setIsDeployConfirmModalShow(false); + }, + [onDeployed, onDeployFailed] + ); - const closeUndeployConfirmModal = useCallback(() => { - setIsUndeployConfirmModalShow(false); - }, []); + const handleUndeployConfirmModalClose = useCallback( + ({ id: versionId, succeed, canceled }: { canceled: boolean; succeed: boolean; id: string }) => { + if (succeed) { + onUndeployed(versionId); + } + if (!succeed && !canceled) { + onUndeployFailed(versionId); + } + setIsUndeployConfirmModalShow(false); + }, + [onUndeployed, onUndeployFailed] + ); const handleDeleteClick = useCallback(() => { if ( @@ -164,7 +190,7 @@ export const ModelVersionTableRowActions = ({ id={id} name={name} version={version} - closeModal={closeDeployConfirmModal} + closeModal={handleDeployConfirmModalClose} /> )} {isUndeployConfirmModalShow && ( @@ -173,7 +199,7 @@ export const ModelVersionTableRowActions = ({ id={id} name={name} version={version} - closeModal={closeUndeployConfirmModal} + closeModal={handleUndeployConfirmModalClose} /> )} {isDeleteConfirmModalShow && ( diff --git a/public/components/model/model_versions_panel/model_versions_panel.tsx b/public/components/model/model_versions_panel/model_versions_panel.tsx index 20056f3e..f310a39a 100644 --- a/public/components/model/model_versions_panel/model_versions_panel.tsx +++ b/public/components/model/model_versions_panel/model_versions_panel.tsx @@ -254,6 +254,10 @@ export const ModelVersionsPanel = ({ modelId }: ModelVersionsPanelProps) => { totalVersionCount={totalVersionCount} sorting={versionsSorting} onVersionDeleted={reload} + onVersionDeployed={reload} + onVersionUndeployed={reload} + onVersionDeployFailed={reload} + onVersionUndeployFailed={reload} /> )} {panelStatus === 'loading' && ( From 3ba16ab29d72a1c2dfc86fb7dfd4945f1e213075 Mon Sep 17 00:00:00 2001 From: Yulong Ruan Date: Tue, 20 Jun 2023 10:26:34 +0800 Subject: [PATCH 62/75] fix: fix error when search for model name when index hasn't created (#220) Signed-off-by: Yulong Ruan Signed-off-by: Lin Wang --- .../components/common/forms/model_name_field.tsx | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/public/components/common/forms/model_name_field.tsx b/public/components/common/forms/model_name_field.tsx index 658e4ed8..b7f12921 100644 --- a/public/components/common/forms/model_name_field.tsx +++ b/public/components/common/forms/model_name_field.tsx @@ -24,12 +24,16 @@ interface ModelNameFieldProps { } const isDuplicateModelName = async (name: string) => { - const searchResult = await APIProvider.getAPI('model').search({ - name, - from: 0, - size: 1, - }); - return searchResult.total_models >= 1; + try { + const searchResult = await APIProvider.getAPI('model').search({ + name, + from: 0, + size: 1, + }); + return searchResult.total_models >= 1; + } catch (e) { + return false; + } }; export const ModelNameField = ({ From 42d76cd3861f88f8d2f1fa2dba0ffea6710bdef2 Mon Sep 17 00:00:00 2001 From: Yulong Ruan Date: Tue, 20 Jun 2023 10:26:56 +0800 Subject: [PATCH 63/75] build: add experimental release action (#221) Signed-off-by: Yulong Ruan Signed-off-by: Lin Wang --- .github/workflows/experimental-release.yml | 65 ++++++++++++++++++++++ 1 file changed, 65 insertions(+) create mode 100644 .github/workflows/experimental-release.yml diff --git a/.github/workflows/experimental-release.yml b/.github/workflows/experimental-release.yml new file mode 100644 index 00000000..de6f891c --- /dev/null +++ b/.github/workflows/experimental-release.yml @@ -0,0 +1,65 @@ +name: experimental-release + +on: + workflow_dispatch: + inputs: + opensearch_dashboards_version: + description: 'Which version of OpenSearch-Dashboards that this build is targeting' + required: false + default: '2.x' + +jobs: + build: + name: Build + runs-on: ubuntu-latest + env: + RELEASE_TAG: ${{ github.event.inputs.opensearch_dashboards_version }}-experimental.${{ github.run_id }} + + steps: + - run: echo Target OpenSearch-Dashboards version ${{ github.event.inputs.opensearch_dashboards_version }} + - name: Checkout OpenSearch Dashboards + uses: actions/checkout@v2 + with: + repository: opensearch-project/OpenSearch-Dashboards + ref: ${{ github.event.inputs.opensearch_dashboards_version }} + path: OpenSearch-Dashboards + - name: Setup Node + uses: actions/setup-node@v3 + with: + node-version-file: './OpenSearch-Dashboards/.nvmrc' + registry-url: 'https://registry.npmjs.org' + - name: Install Yarn + # Need to use bash to avoid having a windows/linux specific step + shell: bash + run: | + YARN_VERSION=$(node -p "require('./OpenSearch-Dashboards/package.json').engines.yarn") + echo "Installing yarn@$YARN_VERSION" + npm i -g yarn@$YARN_VERSION + - run: node -v + - run: yarn -v + - name: Checkout ML Commons OpenSearch Dashboards plugin + uses: actions/checkout@v2 + with: + path: OpenSearch-Dashboards/plugins/ml-commons-dashboards + - name: Bootstrap plugin/opensearch-dashboards + run: | + cd OpenSearch-Dashboards/plugins/ml-commons-dashboards + yarn osd bootstrap + - name: Run build + run: | + cd OpenSearch-Dashboards/plugins/ml-commons-dashboards + yarn run build + - name: Create Release Tag + run: | + cd OpenSearch-Dashboards/plugins/ml-commons-dashboards + git tag ${{ env.RELEASE_TAG }} + git push origin ${{ env.RELEASE_TAG }} + - name: Release + id: release_step + uses: softprops/action-gh-release@v1 + with: + files: OpenSearch-Dashboards/plugins/ml-commons-dashboards/build/*.zip + tag_name: ${{ env.RELEASE_TAG }} + - name: Update Artifact URL + run: | + echo ${{ steps.release_step.outputs.assets }} From a2531e8935f65b0043f38eabf679db56255feb63 Mon Sep 17 00:00:00 2001 From: Lin Wang Date: Fri, 3 Jan 2025 11:25:26 +0800 Subject: [PATCH 64/75] Fix owner and transport missing Signed-off-by: Lin Wang --- public/components/model_version/model_version.tsx | 2 +- server/services/model_aggregate_service.ts | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/public/components/model_version/model_version.tsx b/public/components/model_version/model_version.tsx index 420f21e1..3b544ff1 100644 --- a/public/components/model_version/model_version.tsx +++ b/public/components/model_version/model_version.tsx @@ -197,7 +197,7 @@ export const ModelVersion = () => { modelVersionId={modelVersionData?.id} createdTime={modelVersionData?.created_time} lastUpdatedTime={modelVersionData?.last_updated_time} - owner={modelData?.owner.name} + owner={modelData?.owner?.name} /> )} diff --git a/server/services/model_aggregate_service.ts b/server/services/model_aggregate_service.ts index df6ba923..247ef48e 100644 --- a/server/services/model_aggregate_service.ts +++ b/server/services/model_aggregate_service.ts @@ -105,7 +105,7 @@ export class ModelAggregateService { }); const modelIds = models.map(({ id }) => id); const { data: deployedModels } = await ModelVersionService.search({ - client, + transport: client.asCurrentUser.transport, from: 0, size: MAX_MODEL_BUCKET_NUM, modelIds, From e767c0ce0412255dbb3f30bd34d4d1e9d2bd469f Mon Sep 17 00:00:00 2001 From: Lin Wang Date: Fri, 3 Jan 2025 11:29:06 +0800 Subject: [PATCH 65/75] Hide model registry entrances Signed-off-by: Lin Wang --- common/router.ts | 43 --------------------------------------- public/components/app.tsx | 8 +------- 2 files changed, 1 insertion(+), 50 deletions(-) diff --git a/common/router.ts b/common/router.ts index e3bc2c8f..c077134c 100644 --- a/common/router.ts +++ b/common/router.ts @@ -3,11 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { Model } from '../public/components/model'; -import { ModelList } from '../public/components/model_list'; import { Monitoring } from '../public/components/monitoring'; -import { RegisterModelForm } from '../public/components/register_model/register_model'; -import { ModelVersion } from '../public/components/model_version'; import { routerPaths } from './router_paths'; interface RouteConfig { @@ -28,43 +24,4 @@ export const ROUTES: RouteConfig[] = [ label: 'Overview', nav: true, }, - { - path: routerPaths.registerModel, - label: 'Register Model', - Component: RegisterModelForm, - nav: true, - }, - { - path: routerPaths.modelList, - label: 'Model Registry', - Component: ModelList, - nav: true, - }, - { - path: routerPaths.model, - // TODO: refactor label to be dynamic so that we can display group name in breadcrumb - label: 'Model', - Component: Model, - nav: false, - }, - { - path: routerPaths.modelVersion, - label: 'Model Version', - Component: ModelVersion, - nav: false, - }, ]; - -/* export const ROUTES1 = [ - { - path: routerPaths.modelList, - Component: ModelList, - label: 'Model List', - icon: 'createSingleMetricJob', - }, - { - path: routerPaths.registerModel, - label: 'Register Model', - Component: RegisterModelForm, - }, -];*/ diff --git a/public/components/app.tsx b/public/components/app.tsx index 5e8e4f76..22cb5ed2 100644 --- a/public/components/app.tsx +++ b/public/components/app.tsx @@ -6,7 +6,7 @@ import React from 'react'; import { I18nProvider } from '@osd/i18n/react'; import { Redirect, Route, Switch } from 'react-router-dom'; -import { EuiPage, EuiPageBody, EuiPageSideBar } from '@elastic/eui'; +import { EuiPage, EuiPageBody } from '@elastic/eui'; import { useObservable } from 'react-use'; import { ROUTES } from '../../common/router'; import { routerPaths } from '../../common/router_paths'; @@ -25,7 +25,6 @@ import { DataSourceContextProvider } from '../contexts/data_source_context'; import { GlobalBreadcrumbs } from './global_breadcrumbs'; import { DataSourceTopNavMenu } from './data_source_top_nav_menu'; -import { NavPanel } from './nav_panel'; interface MlCommonsPluginAppDeps { basename: string; @@ -73,11 +72,6 @@ export const MlCommonsPluginApp = ({ > <> - {!useNewPageHeader && ( - - - - )} {ROUTES.map(({ path, Component, exact }) => ( From d74398a7e6bc41cf402937b88207cef2a8c9a77c Mon Sep 17 00:00:00 2001 From: Lin Wang Date: Fri, 3 Jan 2025 16:31:06 +0800 Subject: [PATCH 66/75] Replace response factory with response Signed-off-by: Lin Wang --- server/__tests__/plugin.test.ts | 8 ++-- server/plugin.ts | 2 +- server/routes/model_aggregate_router.ts | 8 ++-- server/routes/model_repository_router.ts | 25 +++++++----- server/routes/model_router.ts | 24 +++++------ server/routes/model_version_router.ts | 52 ++++++++++++------------ server/routes/security_router.ts | 8 ++-- server/routes/task_router.ts | 14 +++---- 8 files changed, 72 insertions(+), 69 deletions(-) diff --git a/server/__tests__/plugin.test.ts b/server/__tests__/plugin.test.ts index b27eb587..1966a327 100644 --- a/server/__tests__/plugin.test.ts +++ b/server/__tests__/plugin.test.ts @@ -5,7 +5,7 @@ import { MlCommonsPlugin } from '../plugin'; import { coreMock, httpServiceMock } from '../../../../src/core/server/mocks'; -import * as modelRouterExports from '../routes/model_router'; +import * as modelVersionRouterExports from '../routes/model_version_router'; import * as connectorRouterExports from '../routes/connector_router'; import * as profileRouterExports from '../routes/profile_router'; @@ -22,10 +22,10 @@ describe('MlCommonsPlugin', () => { initContext = coreMock.createPluginInitializerContext(); }); - it('should register model routers', () => { - jest.spyOn(modelRouterExports, 'modelRouter'); + it('should register model version routers', () => { + jest.spyOn(modelVersionRouterExports, 'modelVersionRouter'); new MlCommonsPlugin(initContext).setup(mockCoreSetup); - expect(modelRouterExports.modelRouter).toHaveBeenCalledWith(routerMock); + expect(modelVersionRouterExports.modelVersionRouter).toHaveBeenCalledWith(routerMock); }); it('should register connector routers', () => { diff --git a/server/plugin.ts b/server/plugin.ts index a964ffb2..371ae036 100644 --- a/server/plugin.ts +++ b/server/plugin.ts @@ -13,9 +13,9 @@ import { import { MlCommonsPluginSetup, MlCommonsPluginStart } from './types'; import { + modelVersionRouter, connectorRouter, modelRouter, - modelVersionRouter, modelAggregateRouter, profileRouter, securityRouter, diff --git a/server/routes/model_aggregate_router.ts b/server/routes/model_aggregate_router.ts index ea2f0277..766a5968 100644 --- a/server/routes/model_aggregate_router.ts +++ b/server/routes/model_aggregate_router.ts @@ -4,7 +4,7 @@ */ import { schema } from '@osd/config-schema'; -import { IRouter, opensearchDashboardsResponseFactory } from '../../../../src/core/server'; +import { IRouter } from '../../../../src/core/server'; import { ModelAggregateService } from '../services/model_aggregate_service'; import { MODEL_AGGREGATE_API_ENDPOINT } from './constants'; import { modelStateSchema } from './model_version_router'; @@ -37,7 +37,7 @@ export const modelAggregateRouter = (router: IRouter) => { }), }, }, - async (context, request) => { + async (context, request, response) => { const { states, extraQuery, ...restQuery } = request.query; try { const payload = await ModelAggregateService.search({ @@ -46,9 +46,9 @@ export const modelAggregateRouter = (router: IRouter) => { extraQuery, ...restQuery, }); - return opensearchDashboardsResponseFactory.ok({ body: payload }); + return response.ok({ body: payload }); } catch (error) { - return opensearchDashboardsResponseFactory.badRequest({ + return response.badRequest({ body: error instanceof Error ? error.message : JSON.stringify(error), }); } diff --git a/server/routes/model_repository_router.ts b/server/routes/model_repository_router.ts index 8c8a286d..e74935e0 100644 --- a/server/routes/model_repository_router.ts +++ b/server/routes/model_repository_router.ts @@ -7,7 +7,7 @@ import { schema } from '@osd/config-schema'; // @ts-ignore import fetch from 'node-fetch'; -import { IRouter, opensearchDashboardsResponseFactory } from '../../../../src/core/server'; +import { IRouter } from '../../../../src/core/server'; import { MODEL_REPOSITORY_API_ENDPOINT, MODEL_REPOSITORY_CONFIG_URL_API_ENDPOINT, @@ -19,14 +19,17 @@ const PRE_TRAINED_MODELS_URL = const fetchURLAsJSONData = (url: string) => fetch(url).then((response: any) => response.json()); export const modelRepositoryRouter = (router: IRouter) => { - router.get({ path: MODEL_REPOSITORY_API_ENDPOINT, validate: false }, async () => { - try { - const data = await fetchURLAsJSONData(PRE_TRAINED_MODELS_URL); - return opensearchDashboardsResponseFactory.ok({ body: data }); - } catch (error) { - return opensearchDashboardsResponseFactory.badRequest({ body: error.message }); + router.get( + { path: MODEL_REPOSITORY_API_ENDPOINT, validate: false }, + async (_context, _request, response) => { + try { + const data = await fetchURLAsJSONData(PRE_TRAINED_MODELS_URL); + return response.ok({ body: data }); + } catch (error) { + return response.badRequest({ body: error.message }); + } } - }); + ); router.get( { @@ -37,12 +40,12 @@ export const modelRepositoryRouter = (router: IRouter) => { }), }, }, - async (_context, request) => { + async (_context, request, response) => { try { const data = await fetchURLAsJSONData(request.params.configURL); - return opensearchDashboardsResponseFactory.ok({ body: data }); + return response.ok({ body: data }); } catch (error) { - return opensearchDashboardsResponseFactory.badRequest({ body: error.message }); + return response.badRequest({ body: error.message }); } } ); diff --git a/server/routes/model_router.ts b/server/routes/model_router.ts index 05ac2c5f..f8de4c9c 100644 --- a/server/routes/model_router.ts +++ b/server/routes/model_router.ts @@ -5,7 +5,7 @@ import { schema } from '@osd/config-schema'; -import { IRouter, opensearchDashboardsResponseFactory } from '../../../../src/core/server'; +import { IRouter } from '../../../../src/core/server'; import { ModelService } from '../services'; import { MODEL_API_ENDPOINT } from './constants'; @@ -28,7 +28,7 @@ export const modelRouter = (router: IRouter) => { }), }, }, - async (context, request) => { + async (context, request, response) => { const { name, description, modelAccessMode, backendRoles, addAllBackendRoles } = request.body; try { const payload = await ModelService.register({ @@ -39,9 +39,9 @@ export const modelRouter = (router: IRouter) => { backendRoles, addAllBackendRoles, }); - return opensearchDashboardsResponseFactory.ok({ body: payload }); + return response.ok({ body: payload }); } catch (error) { - return opensearchDashboardsResponseFactory.badRequest({ + return response.badRequest({ body: error instanceof Error ? error.message : JSON.stringify(error), }); } @@ -73,9 +73,9 @@ export const modelRouter = (router: IRouter) => { name, description, }); - return opensearchDashboardsResponseFactory.ok({ body: payload }); + return response.ok({ body: payload }); } catch (error) { - return opensearchDashboardsResponseFactory.badRequest({ + return response.badRequest({ body: error instanceof Error ? error.message : JSON.stringify(error), }); } @@ -91,15 +91,15 @@ export const modelRouter = (router: IRouter) => { }), }, }, - async (context, request) => { + async (context, request, response) => { try { const payload = await ModelService.delete({ client: context.core.opensearch.client, id: request.params.id, }); - return opensearchDashboardsResponseFactory.ok({ body: payload }); + return response.ok({ body: payload }); } catch (error) { - return opensearchDashboardsResponseFactory.badRequest({ + return response.badRequest({ body: error instanceof Error ? error.message : JSON.stringify(error), }); } @@ -119,7 +119,7 @@ export const modelRouter = (router: IRouter) => { }), }, }, - async (context, request) => { + async (context, request, response) => { const { ids, name, from, size, extraQuery } = request.query; try { const payload = await ModelService.search({ @@ -130,9 +130,9 @@ export const modelRouter = (router: IRouter) => { size, extraQuery, }); - return opensearchDashboardsResponseFactory.ok({ body: payload }); + return response.ok({ body: payload }); } catch (error) { - return opensearchDashboardsResponseFactory.badRequest({ + return response.badRequest({ body: error instanceof Error ? error.message : JSON.stringify(error), }); } diff --git a/server/routes/model_version_router.ts b/server/routes/model_version_router.ts index bd9b9dd7..87e33c5e 100644 --- a/server/routes/model_version_router.ts +++ b/server/routes/model_version_router.ts @@ -5,7 +5,7 @@ import { schema } from '@osd/config-schema'; import { MAX_MODEL_CHUNK_SIZE, MODEL_VERSION_STATE } from '../../common'; -import { IRouter, opensearchDashboardsResponseFactory } from '../../../../src/core/server'; +import { IRouter } from '../../../../src/core/server'; import { ModelVersionService, RecordNotFoundError } from '../services'; import { MODEL_VERSION_API_ENDPOINT, @@ -101,7 +101,7 @@ export const modelVersionRouter = (router: IRouter) => { }), }, }, - async (context, request) => { + async (context, request, response) => { const { algorithms, ids, @@ -134,9 +134,9 @@ export const modelVersionRouter = (router: IRouter) => { versionOrKeyword, extraQuery, }); - return opensearchDashboardsResponseFactory.ok({ body: payload }); + return response.ok({ body: payload }); } catch (err) { - return opensearchDashboardsResponseFactory.badRequest({ body: err.message }); + return response.badRequest({ body: err.message }); } } ); @@ -150,15 +150,15 @@ export const modelVersionRouter = (router: IRouter) => { }), }, }, - async (context, request) => { + async (context, request, response) => { try { const model = await ModelVersionService.getOne({ client: context.core.opensearch.client, id: request.params.id, }); - return opensearchDashboardsResponseFactory.ok({ body: model }); + return response.ok({ body: model }); } catch (err) { - return opensearchDashboardsResponseFactory.badRequest({ body: err.message }); + return response.badRequest({ body: err.message }); } } ); @@ -172,18 +172,18 @@ export const modelVersionRouter = (router: IRouter) => { }), }, }, - async (context, request) => { + async (context, request, response) => { try { await ModelVersionService.delete({ client: context.core.opensearch.client, id: request.params.id, }); - return opensearchDashboardsResponseFactory.ok(); + return response.ok(); } catch (err) { if (err instanceof RecordNotFoundError) { - return opensearchDashboardsResponseFactory.notFound(); + return response.notFound(); } - return opensearchDashboardsResponseFactory.badRequest({ body: err.message }); + return response.badRequest({ body: err.message }); } } ); @@ -197,15 +197,15 @@ export const modelVersionRouter = (router: IRouter) => { }), }, }, - async (context, request) => { + async (context, request, response) => { try { const result = await ModelVersionService.load({ client: context.core.opensearch.client, id: request.params.id, }); - return opensearchDashboardsResponseFactory.ok({ body: result }); + return response.ok({ body: result }); } catch (err) { - return opensearchDashboardsResponseFactory.badRequest({ body: err.message }); + return response.badRequest({ body: err.message }); } } ); @@ -219,15 +219,15 @@ export const modelVersionRouter = (router: IRouter) => { }), }, }, - async (context, request) => { + async (context, request, response) => { try { const result = await ModelVersionService.unload({ client: context.core.opensearch.client, id: request.params.id, }); - return opensearchDashboardsResponseFactory.ok({ body: result }); + return response.ok({ body: result }); } catch (err) { - return opensearchDashboardsResponseFactory.badRequest({ body: err.message }); + return response.badRequest({ body: err.message }); } } ); @@ -241,15 +241,15 @@ export const modelVersionRouter = (router: IRouter) => { }), }, }, - async (context, request) => { + async (context, request, response) => { try { const result = await ModelVersionService.profile({ client: context.core.opensearch.client, id: request.params.id, }); - return opensearchDashboardsResponseFactory.ok({ body: result }); + return response.ok({ body: result }); } catch (err) { - return opensearchDashboardsResponseFactory.badRequest({ body: err.message }); + return response.badRequest({ body: err.message }); } } ); @@ -265,18 +265,18 @@ export const modelVersionRouter = (router: IRouter) => { ]), }, }, - async (context, request) => { + async (context, request, response) => { try { const body = await ModelVersionService.upload({ client: context.core.opensearch.client, model: request.body, }); - return opensearchDashboardsResponseFactory.ok({ + return response.ok({ body, }); } catch (err) { - return opensearchDashboardsResponseFactory.badRequest({ body: err.message }); + return response.badRequest({ body: err.message }); } } ); @@ -297,7 +297,7 @@ export const modelVersionRouter = (router: IRouter) => { }, }, }, - async (context, request) => { + async (context, request, response) => { try { await ModelVersionService.uploadModelChunk({ client: context.core.opensearch.client, @@ -305,9 +305,9 @@ export const modelVersionRouter = (router: IRouter) => { chunkId: request.params.chunkId, chunk: request.body, }); - return opensearchDashboardsResponseFactory.ok(); + return response.ok(); } catch (err) { - return opensearchDashboardsResponseFactory.badRequest(err.message); + return response.badRequest(err.message); } } ); diff --git a/server/routes/security_router.ts b/server/routes/security_router.ts index 87c0bff0..ec241880 100644 --- a/server/routes/security_router.ts +++ b/server/routes/security_router.ts @@ -3,7 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { IRouter, opensearchDashboardsResponseFactory } from '../../../../src/core/server'; +import { IRouter } from '../../../../src/core/server'; import { SecurityService } from '../services/security_service'; import { SECURITY_ACCOUNT_API_ENDPOINT } from './constants'; @@ -13,14 +13,14 @@ export const securityRouter = (router: IRouter) => { path: SECURITY_ACCOUNT_API_ENDPOINT, validate: false, }, - async (context) => { + async (context, _request, response) => { try { const body = await SecurityService.getAccount({ client: context.core.opensearch.client, }); - return opensearchDashboardsResponseFactory.ok({ body }); + return response.ok({ body }); } catch (error) { - return opensearchDashboardsResponseFactory.badRequest({ body: error as Error }); + return response.badRequest({ body: error as Error }); } } ); diff --git a/server/routes/task_router.ts b/server/routes/task_router.ts index ec1ddea6..f54cc5e9 100644 --- a/server/routes/task_router.ts +++ b/server/routes/task_router.ts @@ -4,7 +4,7 @@ */ import { schema } from '@osd/config-schema'; -import { IRouter, opensearchDashboardsResponseFactory } from '../../../../src/core/server'; +import { IRouter } from '../../../../src/core/server'; import { TaskService } from '../services'; import { TASK_API_ENDPOINT } from './constants'; @@ -23,15 +23,15 @@ export const taskRouter = (router: IRouter) => { }), }, }, - async (context, request) => { + async (context, request, response) => { try { const body = await TaskService.getOne({ client: context.core.opensearch.client, taskId: request.params.taskId, }); - return opensearchDashboardsResponseFactory.ok({ body }); + return response.ok({ body }); } catch (err) { - return opensearchDashboardsResponseFactory.badRequest({ body: err.message }); + return response.badRequest({ body: err.message }); } } ); @@ -54,7 +54,7 @@ export const taskRouter = (router: IRouter) => { }), }, }, - async (context, request) => { + async (context, request, response) => { const { model_id: modelId, task_type: taskType, sort, ...restQuery } = request.query; try { const body = await TaskService.search({ @@ -67,9 +67,9 @@ export const taskRouter = (router: IRouter) => { : (sort as ['last_update_time-desc' | 'last_update_time-asc']), ...restQuery, }); - return opensearchDashboardsResponseFactory.ok({ body }); + return response.ok({ body }); } catch (err) { - return opensearchDashboardsResponseFactory.badRequest({ body: err.message }); + return response.badRequest({ body: err.message }); } } ); From 93f17d6d5e5a97ffda7e545a3fa8e19dd1ec4ed1 Mon Sep 17 00:00:00 2001 From: Lin Wang Date: Fri, 3 Jan 2025 16:39:54 +0800 Subject: [PATCH 67/75] Renaming model version service test cases Signed-off-by: Lin Wang --- ...ervice.test.ts => model_version_service.test.ts} | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) rename server/services/__tests__/{model_service.test.ts => model_version_service.test.ts} (86%) diff --git a/server/services/__tests__/model_service.test.ts b/server/services/__tests__/model_version_service.test.ts similarity index 86% rename from server/services/__tests__/model_service.test.ts rename to server/services/__tests__/model_version_service.test.ts index 24381324..6989540b 100644 --- a/server/services/__tests__/model_service.test.ts +++ b/server/services/__tests__/model_version_service.test.ts @@ -3,7 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { ModelService } from '../model_service'; +import { ModelVersionService } from '../model_version_service'; const createTransportMock = () => ({ request: jest.fn().mockResolvedValue({ @@ -18,10 +18,10 @@ const createTransportMock = () => ({ }), }); -describe('ModelService', () => { +describe('ModelVersionService', () => { it('should call transport request with consistent params', () => { const mockTransport = createTransportMock(); - ModelService.search({ + ModelVersionService.search({ from: 0, size: 1, transport: mockTransport, @@ -54,7 +54,7 @@ describe('ModelService', () => { it('should call transport request with sort params', () => { const mockTransport = createTransportMock(); - ModelService.search({ + ModelVersionService.search({ from: 0, size: 1, transport: mockTransport, @@ -71,7 +71,7 @@ describe('ModelService', () => { }); it('should return consistent results', async () => { - const result = await ModelService.search({ + const result = await ModelVersionService.search({ from: 0, size: 1, transport: createTransportMock(), @@ -82,10 +82,11 @@ describe('ModelService', () => { "data": Array [ Object { "id": "model-1", + "model_id": undefined, "name": "Model 1", }, ], - "total_models": 1, + "total_model_versions": 1, } `); }); From 0e85397dd15a8b04d30dc0219acbba54be82348e Mon Sep 17 00:00:00 2001 From: Lin Wang Date: Fri, 3 Jan 2025 16:44:35 +0800 Subject: [PATCH 68/75] Rename model version utils Signed-off-by: Lin Wang --- server/services/model_aggregate_service.ts | 4 ++-- server/services/model_version_service.ts | 4 ++-- .../{model.test.ts => model_version.test.ts} | 17 ++++++++++------- .../utils/{model.ts => model_version.ts} | 16 +--------------- 4 files changed, 15 insertions(+), 26 deletions(-) rename server/services/utils/__tests__/{model.test.ts => model_version.test.ts} (83%) rename server/services/utils/{model.ts => model_version.ts} (84%) diff --git a/server/services/model_aggregate_service.ts b/server/services/model_aggregate_service.ts index 247ef48e..4c17e5f1 100644 --- a/server/services/model_aggregate_service.ts +++ b/server/services/model_aggregate_service.ts @@ -31,7 +31,7 @@ import { import { ModelService } from './model_service'; import { ModelVersionService } from './model_version_service'; import { MODEL_SEARCH_API } from './utils/constants'; -import { generateModelSearchQuery } from './utils/model'; +import { generateModelVersionSearchQuery } from './utils/model'; const MAX_MODEL_BUCKET_NUM = 10000; const getModelSort = (sort: ModelAggregateSort): ModelSort => { @@ -65,7 +65,7 @@ export class ModelAggregateService { path: MODEL_SEARCH_API, body: { size: 0, - query: generateModelSearchQuery({ states }), + query: generateModelVersionSearchQuery({ states }), aggs: { models: { terms: { diff --git a/server/services/model_version_service.ts b/server/services/model_version_service.ts index 970cdb1f..cf269099 100644 --- a/server/services/model_version_service.ts +++ b/server/services/model_version_service.ts @@ -21,7 +21,7 @@ import { IScopedClusterClient, OpenSearchClient } from '../../../../src/core/server'; import { MODEL_VERSION_STATE } from '../../common'; -import { generateModelSearchQuery } from './utils/model'; +import { generateModelVersionSearchQuery } from './utils/model_version'; import { RecordNotFoundError } from './errors'; import { MODEL_BASE_API, @@ -90,7 +90,7 @@ export class ModelVersionService { method: 'POST', path: `${MODEL_BASE_API}/_search`, body: { - query: generateModelSearchQuery(restParams), + query: generateModelVersionSearchQuery(restParams), from, size, ...(sort diff --git a/server/services/utils/__tests__/model.test.ts b/server/services/utils/__tests__/model_version.test.ts similarity index 83% rename from server/services/utils/__tests__/model.test.ts rename to server/services/utils/__tests__/model_version.test.ts index a49b62df..4c3fe594 100644 --- a/server/services/utils/__tests__/model.test.ts +++ b/server/services/utils/__tests__/model_version.test.ts @@ -3,13 +3,16 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { MODEL_STATE } from '../../../../common'; -import { generateModelSearchQuery } from '../model'; +import { MODEL_VERSION_STATE } from '../../../../common'; +import { generateModelVersionSearchQuery } from '../model_version'; -describe('generateModelSearchQuery', () => { +describe('generateModelVersionSearchQuery', () => { it('should generate consistent query when states provided', () => { - expect(generateModelSearchQuery({ states: [MODEL_STATE.loaded, MODEL_STATE.partiallyLoaded] })) - .toMatchInlineSnapshot(` + expect( + generateModelVersionSearchQuery({ + states: [MODEL_VERSION_STATE.deployed, MODEL_VERSION_STATE.partiallyDeployed], + }) + ).toMatchInlineSnapshot(` Object { "bool": Object { "must": Array [ @@ -32,7 +35,7 @@ describe('generateModelSearchQuery', () => { `); }); it('should generate consistent query when nameOrId provided', () => { - expect(generateModelSearchQuery({ nameOrId: 'foo' })).toMatchInlineSnapshot(` + expect(generateModelVersionSearchQuery({ nameOrId: 'foo' })).toMatchInlineSnapshot(` Object { "bool": Object { "must": Array [ @@ -69,7 +72,7 @@ describe('generateModelSearchQuery', () => { }); it('should generate consistent query when extraQuery provided', () => { expect( - generateModelSearchQuery({ + generateModelVersionSearchQuery({ extraQuery: { bool: { must: [ diff --git a/server/services/utils/model.ts b/server/services/utils/model_version.ts similarity index 84% rename from server/services/utils/model.ts rename to server/services/utils/model_version.ts index 4392e849..9933872a 100644 --- a/server/services/utils/model.ts +++ b/server/services/utils/model_version.ts @@ -6,21 +6,7 @@ import { MODEL_VERSION_STATE } from '../../../common'; import { generateTermQuery } from './query'; -export const convertModelSource = (source: { - model_content: string; - name: string; - algorithm: string; - model_state: string; - model_version: string; -}) => ({ - content: source.model_content, - name: source.name, - algorithm: source.algorithm, - state: source.model_state, - version: source.model_version, -}); - -export const generateModelSearchQuery = ({ +export const generateModelVersionSearchQuery = ({ ids, algorithms, name, From 1409e0c919844a5bb9dab41674a4d33f76c33575 Mon Sep 17 00:00:00 2001 From: Lin Wang Date: Fri, 3 Jan 2025 16:52:18 +0800 Subject: [PATCH 69/75] Renaming model version router Signed-off-by: Lin Wang --- ...r.test.ts => model_version_router.test.ts} | 30 +++++++++---------- ...el_service.ts => model_version_service.ts} | 2 +- 2 files changed, 16 insertions(+), 16 deletions(-) rename server/routes/__tests__/{model_router.test.ts => model_version_router.test.ts} (77%) rename server/services/__mocks__/{model_service.ts => model_version_service.ts} (85%) diff --git a/server/routes/__tests__/model_router.test.ts b/server/routes/__tests__/model_version_router.test.ts similarity index 77% rename from server/routes/__tests__/model_router.test.ts rename to server/routes/__tests__/model_version_router.test.ts index 7126e2f6..d701e588 100644 --- a/server/routes/__tests__/model_router.test.ts +++ b/server/routes/__tests__/model_version_router.test.ts @@ -10,9 +10,9 @@ import { Router } from '../../../../../src/core/server/http/router'; import { triggerHandler, createDataSourceEnhancedRouter } from '../router.mock'; import { httpServerMock } from '../../../../../src/core/server/http/http_server.mocks'; import { loggerMock } from '../../../../../src/core/server/logging/logger.mock'; -import { MODEL_API_ENDPOINT } from '../constants'; -import { modelRouter } from '../model_router'; -import { ModelService } from '../../services'; +import { MODEL_VERSION_API_ENDPOINT } from '../constants'; +import { modelVersionRouter } from '../model_version_router'; +import { ModelVersionService } from '../../services'; const setupRouter = () => { const mockedLogger = loggerMock.create(); @@ -22,7 +22,7 @@ const setupRouter = () => { getLatestCurrentUserTransport, } = createDataSourceEnhancedRouter(mockedLogger); - modelRouter(router); + modelVersionRouter(router); return { router, dataSourceTransportMock, @@ -36,29 +36,29 @@ const triggerModelSearch = ( ) => triggerHandler(router, { method: 'GET', - path: MODEL_API_ENDPOINT, + path: MODEL_VERSION_API_ENDPOINT, req: httpServerMock.createRawRequest({ query: { data_source_id: dataSourceId, from, size }, }), }); -jest.mock('../../services/model_service'); +jest.mock('../../services/model_version_service'); -describe('model routers', () => { +describe('model version routers', () => { beforeEach(() => { - jest.spyOn(ModelService, 'search'); + jest.spyOn(ModelVersionService, 'search'); }); afterEach(() => { jest.resetAllMocks(); }); - describe('model search', () => { + describe('model version search', () => { it('should call connector search and return consistent result', async () => { - expect(ModelService.search).not.toHaveBeenCalled(); + expect(ModelVersionService.search).not.toHaveBeenCalled(); const { router, getLatestCurrentUserTransport } = setupRouter(); const result = (await triggerModelSearch(router, { from: 0, size: 50 })) as ResponseObject; - expect(ModelService.search).toHaveBeenCalledWith( + expect(ModelVersionService.search).toHaveBeenCalledWith( expect.objectContaining({ transport: getLatestCurrentUserTransport(), from: 0, @@ -78,11 +78,11 @@ describe('model routers', () => { }); it('should call model search with data source transport', async () => { - expect(ModelService.search).not.toHaveBeenCalled(); + expect(ModelVersionService.search).not.toHaveBeenCalled(); const { router, dataSourceTransportMock } = setupRouter(); await triggerModelSearch(router, { dataSourceId: 'foo', from: 0, size: 50 }); - expect(ModelService.search).toHaveBeenCalledWith({ + expect(ModelVersionService.search).toHaveBeenCalledWith({ transport: dataSourceTransportMock, from: 0, size: 50, @@ -90,13 +90,13 @@ describe('model routers', () => { }); it('should response error message after model search throw error', async () => { - jest.spyOn(ModelService, 'search').mockImplementationOnce(() => { + jest.spyOn(ModelVersionService, 'search').mockImplementationOnce(() => { throw new Error('foo'); }); const { router, getLatestCurrentUserTransport } = setupRouter(); const result = (await triggerModelSearch(router, { from: 0, size: 50 })) as Boom; - expect(ModelService.search).toHaveBeenCalledWith( + expect(ModelVersionService.search).toHaveBeenCalledWith( expect.objectContaining({ transport: getLatestCurrentUserTransport(), from: 0, diff --git a/server/services/__mocks__/model_service.ts b/server/services/__mocks__/model_version_service.ts similarity index 85% rename from server/services/__mocks__/model_service.ts rename to server/services/__mocks__/model_version_service.ts index faedeb58..f995bf77 100644 --- a/server/services/__mocks__/model_service.ts +++ b/server/services/__mocks__/model_version_service.ts @@ -3,7 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -export class ModelService { +export class ModelVersionService { public static async search() { return { data: [{ name: 'Model 1' }], From eca42341479e40d8452f1c012e3c58226db729e3 Mon Sep 17 00:00:00 2001 From: Lin Wang Date: Fri, 3 Jan 2025 16:56:06 +0800 Subject: [PATCH 70/75] Hide global breadcrumbs model registry unit tests Signed-off-by: Lin Wang --- .../__tests__/global_breadcrumbs.test.tsx | 135 +----------------- 1 file changed, 1 insertion(+), 134 deletions(-) diff --git a/public/components/__tests__/global_breadcrumbs.test.tsx b/public/components/__tests__/global_breadcrumbs.test.tsx index f4b46b0b..b13dd264 100644 --- a/public/components/__tests__/global_breadcrumbs.test.tsx +++ b/public/components/__tests__/global_breadcrumbs.test.tsx @@ -5,8 +5,7 @@ import React from 'react'; import { GlobalBreadcrumbs } from '../global_breadcrumbs'; -import { history, render, waitFor, act } from '../../../test/test_utils'; -import { ModelVersion, ModelVersionDetail } from '../../apis/model_version'; +import { render } from '../../../test/test_utils'; describe('', () => { it('should call onBreadcrumbsChange with overview title', () => { @@ -20,136 +19,4 @@ describe('', () => { { text: 'Overview' }, ]); }); - - it('should call onBreadcrumbsChange with register model breadcrumbs', () => { - const onBreadcrumbsChange = jest.fn(); - render(, { - route: '/model-registry/register-model', - }); - - expect(onBreadcrumbsChange).toHaveBeenCalledWith([ - { text: 'Machine Learning', href: '/' }, - { text: 'Model Registry', href: '/model-registry/model-list' }, - { text: 'Register model' }, - ]); - }); - - it('should call onBreadcrumbsChange with register version breadcrumbs', async () => { - const onBreadcrumbsChange = jest.fn(); - render(, { - route: '/model-registry/register-model/model-id-1', - }); - - expect(onBreadcrumbsChange).toHaveBeenCalledWith([ - { text: 'Machine Learning', href: '/' }, - { text: 'Model Registry', href: '/model-registry/model-list' }, - ]); - - await waitFor(() => { - expect(onBreadcrumbsChange).toBeCalledTimes(2); - expect(onBreadcrumbsChange).toHaveBeenLastCalledWith([ - { text: 'Machine Learning', href: '/' }, - { text: 'Model Registry', href: '/model-registry/model-list' }, - { text: 'model1', href: '/model-registry/model/model-id-1' }, - { text: 'Register version' }, - ]); - }); - }); - - it('should call onBreadcrumbsChange with model breadcrumbs', async () => { - const onBreadcrumbsChange = jest.fn(); - render(, { - route: '/model-registry/model/model-id-1', - }); - - expect(onBreadcrumbsChange).toHaveBeenCalledWith([ - { text: 'Machine Learning', href: '/' }, - { text: 'Model Registry', href: '/model-registry/model-list' }, - ]); - - await waitFor(() => { - expect(onBreadcrumbsChange).toBeCalledTimes(2); - expect(onBreadcrumbsChange).toHaveBeenLastCalledWith([ - { text: 'Machine Learning', href: '/' }, - { text: 'Model Registry', href: '/model-registry/model-list' }, - { text: 'model1' }, - ]); - }); - }); - - it('should call onBreadcrumbsChange with model version breadcrumbs', async () => { - const onBreadcrumbsChange = jest.fn(); - render(, { - route: '/model-registry/model-version/1', - }); - - expect(onBreadcrumbsChange).toHaveBeenCalledWith([ - { text: 'Machine Learning', href: '/' }, - { text: 'Model Registry', href: '/model-registry/model-list' }, - ]); - - await waitFor(() => { - expect(onBreadcrumbsChange).toBeCalledTimes(2); - expect(onBreadcrumbsChange).toHaveBeenLastCalledWith([ - { text: 'Machine Learning', href: '/' }, - { text: 'Model Registry', href: '/model-registry/model-list' }, - { text: 'model1', href: '/model-registry/model/1' }, - { text: 'Version 1.0.0' }, - ]); - }); - }); - - it('should NOT call onBreadcrumbs with steal breadcrumbs after pathname changed', async () => { - jest.useFakeTimers(); - const onBreadcrumbsChange = jest.fn(); - const modelGetOneMock = jest.spyOn(ModelVersion.prototype, 'getOne').mockImplementation( - (id) => - new Promise((resolve) => { - setTimeout( - () => { - resolve({ - id, - name: `model${id}`, - model_version: `1.0.${id}`, - } as ModelVersionDetail); - }, - id === '2' ? 1000 : 0 - ); - }) - ); - render(, { - route: '/model-registry/model-version/2', - }); - - expect(onBreadcrumbsChange).toHaveBeenLastCalledWith([ - { text: 'Machine Learning', href: '/' }, - { text: 'Model Registry', href: '/model-registry/model-list' }, - ]); - - history.current.push('/model-registry/model/model-id-1'); - - await act(async () => { - jest.advanceTimersByTime(200); - }); - - expect(onBreadcrumbsChange).toHaveBeenLastCalledWith([ - { text: 'Machine Learning', href: '/' }, - { text: 'Model Registry', href: '/model-registry/model-list' }, - { text: 'model1' }, - ]); - - await act(async () => { - jest.advanceTimersByTime(1000); - }); - - expect(onBreadcrumbsChange).not.toHaveBeenLastCalledWith([ - { text: 'Machine Learning', href: '/' }, - { text: 'Model Registry', href: '/model-registry/model-list' }, - { text: 'model2', href: '/model-registry/model/2' }, - { text: 'Version 1.0.2' }, - ]); - - modelGetOneMock.mockRestore(); - jest.useRealTimers(); - }); }); From e1e8670e6bea2cce25d005fcaead25fe2c5b82df Mon Sep 17 00:00:00 2001 From: Lin Wang Date: Fri, 3 Jan 2025 17:31:45 +0800 Subject: [PATCH 71/75] Fix model version table unit tests Signed-off-by: Lin Wang --- .../__tests__/model_version_table.test.tsx | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/public/components/model/model_versions_panel/__tests__/model_version_table.test.tsx b/public/components/model/model_versions_panel/__tests__/model_version_table.test.tsx index 759e0338..ab829539 100644 --- a/public/components/model/model_versions_panel/__tests__/model_version_table.test.tsx +++ b/public/components/model/model_versions_panel/__tests__/model_version_table.test.tsx @@ -5,7 +5,7 @@ import React from 'react'; import userEvent from '@testing-library/user-event'; -import { within } from '@testing-library/dom'; +import { fireEvent, within } from '@testing-library/dom'; import { render, screen, waitFor } from '../../../../../test/test_utils'; import { ModelVersionTable } from '../model_version_table'; @@ -66,7 +66,7 @@ describe('', () => { timeout: 2000, } ); - await user.click(screen.getByText('Version')); + await fireEvent.click(screen.getByText('Version')); await waitFor(async () => { expect(screen.getByText('Sort A-Z').closest('li')).toHaveClass( 'euiDataGridHeader__action--selected' @@ -83,13 +83,12 @@ describe('', () => { it( 'should NOT render sort button for state and status column', async () => { - const user = userEvent.setup(); render(); - await user.click(screen.getByText('State')); + await fireEvent.click(screen.getByText('State')); expect(screen.queryByTitle('Sort A-Z')).toBeNull(); - await user.click(screen.getByText('Status')); + await fireEvent.click(screen.getByText('Status')); expect(screen.queryByTitle('Sort A-Z')).toBeNull(); }, 20 * 1000 From d9faa0da489a8b9e392728780ba446029b34c118 Mon Sep 17 00:00:00 2001 From: Lin Wang Date: Mon, 6 Jan 2025 11:25:55 +0800 Subject: [PATCH 72/75] Fix invalid import path for model util Signed-off-by: Lin Wang --- server/services/model_aggregate_service.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/services/model_aggregate_service.ts b/server/services/model_aggregate_service.ts index 4c17e5f1..47b62708 100644 --- a/server/services/model_aggregate_service.ts +++ b/server/services/model_aggregate_service.ts @@ -31,7 +31,7 @@ import { import { ModelService } from './model_service'; import { ModelVersionService } from './model_version_service'; import { MODEL_SEARCH_API } from './utils/constants'; -import { generateModelVersionSearchQuery } from './utils/model'; +import { generateModelVersionSearchQuery } from './utils/model_version'; const MAX_MODEL_BUCKET_NUM = 10000; const getModelSort = (sort: ModelAggregateSort): ModelSort => { From 717cbad1aadd8caf6b6e188820cafd2b5c53ef68 Mon Sep 17 00:00:00 2001 From: Lin Wang Date: Mon, 6 Jan 2025 13:55:24 +0800 Subject: [PATCH 73/75] Fix failed UT in monitoring page Signed-off-by: Lin Wang --- public/components/monitoring/index.tsx | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/public/components/monitoring/index.tsx b/public/components/monitoring/index.tsx index 42995b11..a32882c1 100644 --- a/public/components/monitoring/index.tsx +++ b/public/components/monitoring/index.tsx @@ -131,8 +131,8 @@ export const Monitoring = (props: MonitoringProps) => { From f392d31e928e4fad57f5a4b234afcebb695aa7d1 Mon Sep 17 00:00:00 2001 From: Lin Wang Date: Mon, 6 Jan 2025 14:12:05 +0800 Subject: [PATCH 74/75] Fix failed cases in useMonitoring Signed-off-by: Lin Wang --- .../__tests__/use_monitoring.test.tsx | 80 +++++++++---------- 1 file changed, 40 insertions(+), 40 deletions(-) diff --git a/public/components/monitoring/__tests__/use_monitoring.test.tsx b/public/components/monitoring/__tests__/use_monitoring.test.tsx index 272b85bb..d128ac05 100644 --- a/public/components/monitoring/__tests__/use_monitoring.test.tsx +++ b/public/components/monitoring/__tests__/use_monitoring.test.tsx @@ -5,7 +5,7 @@ import React, { useContext } from 'react'; import { act, renderHook } from '@testing-library/react-hooks'; -import { Model, ModelSearchResponse } from '../../../apis/model'; +import { ModelVersion, ModelVersionSearchResponse } from '../../../apis/model_version'; import { Connector } from '../../../apis/connector'; import { useMonitoring } from '../use_monitoring'; import { @@ -52,14 +52,14 @@ const setup = ({ jest.mock('../../../apis/connector'); const mockEmptyRecords = () => - jest.spyOn(Model.prototype, 'search').mockResolvedValueOnce({ + jest.spyOn(ModelVersion.prototype, 'search').mockResolvedValueOnce({ data: [], - total_models: 0, + total_model_versions: 0, }); describe('useMonitoring', () => { beforeEach(() => { - jest.spyOn(Model.prototype, 'search').mockResolvedValue({ + jest.spyOn(ModelVersion.prototype, 'search').mockResolvedValue({ data: [ { id: 'model-1-id', @@ -72,7 +72,7 @@ describe('useMonitoring', () => { planning_worker_nodes: ['node1', 'node2', 'node3'], }, ], - total_models: 500, + total_model_versions: 500, }); }); @@ -89,7 +89,7 @@ describe('useMonitoring', () => { result.current.searchByNameOrId('foo'); }); await waitFor(() => - expect(Model.prototype.search).toHaveBeenCalledWith( + expect(ModelVersion.prototype.search).toHaveBeenCalledWith( expect.objectContaining({ nameOrId: 'foo', states: ['DEPLOY_FAILED', 'DEPLOYED', 'PARTIALLY_DEPLOYED'], @@ -101,7 +101,7 @@ describe('useMonitoring', () => { result.current.searchByStatus(['responding']); }); await waitFor(() => - expect(Model.prototype.search).toHaveBeenCalledWith( + expect(ModelVersion.prototype.search).toHaveBeenCalledWith( expect.objectContaining({ nameOrId: 'foo', states: ['DEPLOYED'], @@ -118,7 +118,7 @@ describe('useMonitoring', () => { result.current.searchByStatus(['partial-responding']); }); await waitFor(() => - expect(Model.prototype.search).toHaveBeenCalledWith( + expect(ModelVersion.prototype.search).toHaveBeenCalledWith( expect.objectContaining({ states: ['PARTIALLY_DEPLOYED'], }) @@ -130,7 +130,7 @@ describe('useMonitoring', () => { const { result, waitFor } = renderHook(() => useMonitoring()); await waitFor(() => - expect(Model.prototype.search).toHaveBeenCalledWith( + expect(ModelVersion.prototype.search).toHaveBeenCalledWith( expect.objectContaining({ sort: ['model_state-asc'], from: 0, @@ -146,7 +146,7 @@ describe('useMonitoring', () => { }); }); await waitFor(() => - expect(Model.prototype.search).toHaveBeenCalledWith( + expect(ModelVersion.prototype.search).toHaveBeenCalledWith( expect.objectContaining({ sort: ['name-desc'], from: 10, @@ -159,17 +159,17 @@ describe('useMonitoring', () => { it('should call search API twice after reload called', async () => { const { result, waitFor } = renderHook(() => useMonitoring()); - await waitFor(() => expect(Model.prototype.search).toHaveBeenCalledTimes(1)); + await waitFor(() => expect(ModelVersion.prototype.search).toHaveBeenCalledTimes(1)); act(() => { result.current.reload(); }); - await waitFor(() => expect(Model.prototype.search).toHaveBeenCalledTimes(2)); + await waitFor(() => expect(ModelVersion.prototype.search).toHaveBeenCalledTimes(2)); }); it('should return consistent deployedModels', async () => { - jest.spyOn(Model.prototype, 'search').mockRestore(); - const searchMock = jest.spyOn(Model.prototype, 'search').mockResolvedValue({ + jest.spyOn(ModelVersion.prototype, 'search').mockRestore(); + const searchMock = jest.spyOn(ModelVersion.prototype, 'search').mockResolvedValue({ data: [ { id: 'model-1-id', @@ -206,7 +206,7 @@ describe('useMonitoring', () => { }, }, ], - total_models: 3, + total_model_versions: 3, }); const { result, waitFor } = renderHook(() => useMonitoring()); @@ -239,8 +239,8 @@ describe('useMonitoring', () => { }); it('should return empty connector if connector id not exists in all connectors', async () => { - jest.spyOn(Model.prototype, 'search').mockRestore(); - const searchMock = jest.spyOn(Model.prototype, 'search').mockResolvedValue({ + jest.spyOn(ModelVersion.prototype, 'search').mockRestore(); + const searchMock = jest.spyOn(ModelVersion.prototype, 'search').mockResolvedValue({ data: [ { id: 'model-1-id', @@ -254,7 +254,7 @@ describe('useMonitoring', () => { connector_id: 'not-exists-external-connector-id', }, ], - total_models: 1, + total_model_versions: 1, }); const { result, waitFor } = renderHook(() => useMonitoring()); @@ -272,13 +272,13 @@ describe('useMonitoring', () => { }); it('should return empty connector if failed to load all external connectors', async () => { - jest.spyOn(Model.prototype, 'search').mockRestore(); + jest.spyOn(ModelVersion.prototype, 'search').mockRestore(); const getAllExternalConnectorsMock = jest .spyOn(Connector.prototype, 'getAll') .mockImplementation(async () => { throw new Error(); }); - const searchMock = jest.spyOn(Model.prototype, 'search').mockResolvedValue({ + const searchMock = jest.spyOn(ModelVersion.prototype, 'search').mockResolvedValue({ data: [ { id: 'model-1-id', @@ -292,7 +292,7 @@ describe('useMonitoring', () => { connector_id: 'not-exists-external-connector-id', }, ], - total_models: 1, + total_model_versions: 1, }); const { result, waitFor } = renderHook(() => useMonitoring()); @@ -331,8 +331,8 @@ describe('useMonitoring', () => { }); await waitFor(() => { - expect(Model.prototype.search).toHaveBeenCalledTimes(3); - expect(Model.prototype.search).toHaveBeenLastCalledWith( + expect(ModelVersion.prototype.search).toHaveBeenCalledTimes(3); + expect(ModelVersion.prototype.search).toHaveBeenLastCalledWith( expect.objectContaining({ from: 0, }) @@ -361,8 +361,8 @@ describe('useMonitoring', () => { }); await waitFor(() => { - expect(Model.prototype.search).toHaveBeenCalledTimes(3); - expect(Model.prototype.search).toHaveBeenLastCalledWith( + expect(ModelVersion.prototype.search).toHaveBeenCalledTimes(3); + expect(ModelVersion.prototype.search).toHaveBeenLastCalledWith( expect.objectContaining({ from: 0, }) @@ -379,7 +379,7 @@ describe('useMonitoring', () => { result.current.searchBySource(['local']); }); await waitFor(() => - expect(Model.prototype.search).toHaveBeenLastCalledWith( + expect(ModelVersion.prototype.search).toHaveBeenLastCalledWith( expect.objectContaining({ extraQuery: { bool: { @@ -402,7 +402,7 @@ describe('useMonitoring', () => { result.current.searchBySource(['external']); }); await waitFor(() => - expect(Model.prototype.search).toHaveBeenLastCalledWith( + expect(ModelVersion.prototype.search).toHaveBeenLastCalledWith( expect.objectContaining({ extraQuery: { bool: { @@ -425,7 +425,7 @@ describe('useMonitoring', () => { result.current.searchBySource(['external', 'local']); }); await waitFor(() => - expect(Model.prototype.search).toHaveBeenLastCalledWith( + expect(ModelVersion.prototype.search).toHaveBeenLastCalledWith( expect.objectContaining({ extraQuery: undefined, }) @@ -442,7 +442,7 @@ describe('useMonitoring', () => { }); await waitFor(() => - expect(Model.prototype.search).toHaveBeenLastCalledWith( + expect(ModelVersion.prototype.search).toHaveBeenLastCalledWith( expect.objectContaining({ extraQuery: { bool: { @@ -486,7 +486,7 @@ describe('useMonitoring', () => { }, }); await waitFor(() => { - expect(Model.prototype.search).not.toHaveBeenCalled(); + expect(ModelVersion.prototype.search).not.toHaveBeenCalled(); }); }); @@ -504,7 +504,7 @@ describe('useMonitoring', () => { dataSourceId: 'foo', }); await waitFor(() => { - expect(Model.prototype.search).toHaveBeenCalledWith(dataSourceIdExpect); + expect(ModelVersion.prototype.search).toHaveBeenCalledWith(dataSourceIdExpect); expect(getAllConnectorMock).toHaveBeenCalledWith(dataSourceIdExpect); }); }); @@ -529,7 +529,7 @@ describe('useMonitoring', () => { }); }); await waitFor(() => { - expect(Model.prototype.search).toHaveBeenCalledTimes(2); + expect(ModelVersion.prototype.search).toHaveBeenCalledTimes(2); expect(result.current.params).toEqual( expect.objectContaining({ currentPage: 2, @@ -541,7 +541,7 @@ describe('useMonitoring', () => { setSelectedDataSourceOption({ id: 'bar' }); }); await waitFor(() => { - expect(Model.prototype.search).toHaveBeenCalledTimes(3); + expect(ModelVersion.prototype.search).toHaveBeenCalledTimes(3); expect(result.current.params.connector).toEqual([]); expect(result.current.params.currentPage).toEqual(1); }); @@ -577,7 +577,7 @@ describe('useMonitoring', () => { status: undefined, }) ); - expect(Model.prototype.search).toHaveBeenCalledTimes(3); + expect(ModelVersion.prototype.search).toHaveBeenCalledTimes(3); }); }); @@ -586,9 +586,9 @@ describe('useMonitoring', () => { renderHookResult: { result, waitFor }, } = setup(); await waitFor(() => { - expect(Model.prototype.search).toHaveBeenCalled(); + expect(ModelVersion.prototype.search).toHaveBeenCalled(); }); - jest.spyOn(Model.prototype, 'search').mockImplementationOnce( + jest.spyOn(ModelVersion.prototype, 'search').mockImplementationOnce( () => new Promise((resolve) => { setTimeout(() => { @@ -605,7 +605,7 @@ describe('useMonitoring', () => { planning_worker_nodes: ['node1', 'node2', 'node3'], }, ], - total_models: 1, + total_model_versions: 1, }); }, 300); }) @@ -628,7 +628,7 @@ describe('useMonitoring', () => { describe('useMonitoring.pageStatus', () => { it("should return 'loading' if data loading and will back to 'normal' after data loaded", async () => { let resolveFn: Function; - const promise = new Promise((resolve) => { + const promise = new Promise((resolve) => { resolveFn = () => { resolve({ data: [ @@ -643,11 +643,11 @@ describe('useMonitoring.pageStatus', () => { planning_worker_nodes: ['node1', 'node2', 'node3'], }, ], - total_models: 1, + total_model_versions: 1, }); }; }); - jest.spyOn(Model.prototype, 'search').mockReturnValueOnce(promise); + jest.spyOn(ModelVersion.prototype, 'search').mockReturnValueOnce(promise); const { result } = renderHook(() => useMonitoring()); expect(result.current.pageStatus).toBe('loading'); From 64bc340c987d857709df9f0ea2331fdd3105f377 Mon Sep 17 00:00:00 2001 From: Lin Wang Date: Mon, 6 Jan 2025 17:46:45 +0800 Subject: [PATCH 75/75] Change back waitFor Signed-off-by: Lin Wang --- public/hooks/tests/use_fetcher.test.tsx | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/public/hooks/tests/use_fetcher.test.tsx b/public/hooks/tests/use_fetcher.test.tsx index 32c5a7a1..5213cc20 100644 --- a/public/hooks/tests/use_fetcher.test.tsx +++ b/public/hooks/tests/use_fetcher.test.tsx @@ -6,13 +6,13 @@ import React from 'react'; import { act, renderHook } from '@testing-library/react-hooks'; import { DO_NOT_FETCH, useFetcher } from '../use_fetcher'; -import { render, waitFor } from '../../../test/test_utils'; +import { render, waitFor as testUtilWaitFor } from '../../../test/test_utils'; describe('useFetcher', () => { it('should call fetcher with consistent params and return consistent result', async () => { const data = { foo: 'bar' }; const fetcher = jest.fn((_arg1: string) => Promise.resolve(data)); - const { result } = renderHook(() => useFetcher(fetcher, 'foo')); + const { result, waitFor } = renderHook(() => useFetcher(fetcher, 'foo')); await waitFor(() => result.current.data !== null); expect(result.current.data).toBe(data); @@ -22,7 +22,7 @@ describe('useFetcher', () => { it('should call fetcher only once if params content not change', async () => { const fetcher = jest.fn((_arg1: any) => Promise.resolve()); - const { result, rerender } = renderHook(({ params }) => useFetcher(fetcher, params), { + const { result, waitFor, rerender } = renderHook(({ params }) => useFetcher(fetcher, params), { initialProps: { params: { foo: 'bar' } }, }); @@ -109,7 +109,7 @@ describe('useFetcher', () => { it('should return consistent updated data', async () => { const fetcher = () => Promise.resolve('foo'); - const { result } = renderHook(() => useFetcher(fetcher)); + const { result, waitFor } = renderHook(() => useFetcher(fetcher)); await waitFor(() => result.current.data === 'foo'); await act(async () => { @@ -122,7 +122,7 @@ describe('useFetcher', () => { it('should return consistent mutated data', async () => { const fetcher = () => Promise.resolve('foo'); - const { result } = renderHook(() => useFetcher(fetcher)); + const { result, waitFor } = renderHook(() => useFetcher(fetcher)); await waitFor(() => result.current.data === 'foo'); @@ -152,7 +152,7 @@ describe('useFetcher', () => { it('should call fetcher after first parameter changed from DO_NOT_FETCH', async () => { const fetcher = jest.fn(async (...params) => params); - const { result, rerender, waitFor: hookWaitFor } = renderHook( + const { result, rerender, waitFor } = renderHook( ({ params }) => useFetcher(fetcher, ...params), { initialProps: { @@ -165,7 +165,7 @@ describe('useFetcher', () => { expect(result.current.loading).toBe(true); expect(fetcher).toHaveBeenCalled(); - await hookWaitFor(() => { + await waitFor(() => { expect(result.current.loading).toBe(false); expect(result.current.data).toEqual([]); }); @@ -193,11 +193,11 @@ describe('useFetcher', () => { const { getByText, rerender } = render( ); - await waitFor(() => { + await testUtilWaitFor(() => { expect(getByText('false')).toBeInTheDocument(); }); rerender(); - await waitFor(() => { + await testUtilWaitFor(() => { expect(getByText('false')).toBeInTheDocument(); });