From 93a2e29376319ef037004bbcfc615b54c6c8af1c Mon Sep 17 00:00:00 2001
From: kumarijy <jyoti.kumari@intel.com>
Date: Mon, 4 Mar 2024 21:59:37 -0800
Subject: [PATCH] add device options on Frontend and Backend

---
 modules/openvino_code/package-lock.json       |  4 ++--
 modules/openvino_code/package.json            | 13 ++++++++++++-
 modules/openvino_code/server/main.py          |  7 +++++++
 .../openvino_code/server/src/generators.py    |  2 +-
 .../shared/side-panel-message.ts              |  1 +
 .../sections/ServerSection/ServerSection.tsx  | 19 +++++++++++++++++++
 modules/openvino_code/src/configuration.ts    |  2 ++
 .../src/python-server/python-server-runner.ts |  9 ++++++---
 .../side-panel/side-panel-message-handler.ts  |  4 ++++
 9 files changed, 54 insertions(+), 7 deletions(-)

diff --git a/modules/openvino_code/package-lock.json b/modules/openvino_code/package-lock.json
index 33d55b8d2..433dacc74 100644
--- a/modules/openvino_code/package-lock.json
+++ b/modules/openvino_code/package-lock.json
@@ -1,12 +1,12 @@
 {
   "name": "openvino-code-completion",
-  "version": "0.0.11",
+  "version": "0.0.12",
   "lockfileVersion": 3,
   "requires": true,
   "packages": {
     "": {
       "name": "openvino-code-completion",
-      "version": "0.0.11",
+      "version": "0.0.12",
       "license": "https://github.com/openvinotoolkit/openvino_contrib/blob/master/LICENSE",
       "workspaces": [
         "side-panel-ui"
diff --git a/modules/openvino_code/package.json b/modules/openvino_code/package.json
index 078a54263..95f75b363 100644
--- a/modules/openvino_code/package.json
+++ b/modules/openvino_code/package.json
@@ -1,7 +1,7 @@
 {
   "publisher": "OpenVINO",
   "name": "openvino-code-completion",
-  "version": "0.0.11",
+  "version": "0.0.12",
   "displayName": "OpenVINO Code Completion",
   "description": "VSCode extension for AI code completion with OpenVINO",
   "icon": "media/logo.png",
@@ -199,6 +199,17 @@
             ],
             "description": "Which model to use for code generation."
           },
+          "openvinoCode.device": {
+            "order": 0,
+            "type": "string",
+            "default": "GPU",
+            "enum": [
+              "CPU",
+              "GPU",
+              "NPU"
+            ],
+            "description": "Which device to use for code generation."
+          },
           "openvinoCode.serverUrl": {
             "order": 1,
             "type": "string",
diff --git a/modules/openvino_code/server/main.py b/modules/openvino_code/server/main.py
index 3d86194e6..861fc3302 100644
--- a/modules/openvino_code/server/main.py
+++ b/modules/openvino_code/server/main.py
@@ -1,8 +1,12 @@
 from src.utils import get_parser, setup_logger
+import logging
+
 
 
 # Logger should be set up before other imports to propagate logging config to other packages
 setup_logger()
+logger = logging.getLogger("")
+
 
 import uvicorn  # noqa: E402
 
@@ -14,6 +18,9 @@ def main():
     args = get_parser().parse_args()
 
     # temporary solution for cli args passing
+
+    logger.error(args.model)
+    logger.error(args.device)
     generator_dependency = get_generator_dependency(args.model, args.device, args.tokenizer_checkpoint, args.assistant)
     app.dependency_overrides[get_generator_dummy] = generator_dependency
 
diff --git a/modules/openvino_code/server/src/generators.py b/modules/openvino_code/server/src/generators.py
index 761d652e2..14b9f957a 100644
--- a/modules/openvino_code/server/src/generators.py
+++ b/modules/openvino_code/server/src/generators.py
@@ -86,7 +86,7 @@ class OVGenerator(GeneratorFunctor):
     def __init__(
         self,
         checkpoint: str,
-        device: str = "CPU",
+        device: str = "GPU",
         tokenizer_checkpoint: Optional[str] = None,
         assistant_checkpoint: Optional[str] = None,
         summarize_stop_tokens: Optional[Container[str]] = SUMMARIZE_STOP_TOKENS,
diff --git a/modules/openvino_code/shared/side-panel-message.ts b/modules/openvino_code/shared/side-panel-message.ts
index 0c0720d6d..2ebdf20b6 100644
--- a/modules/openvino_code/shared/side-panel-message.ts
+++ b/modules/openvino_code/shared/side-panel-message.ts
@@ -10,6 +10,7 @@ export enum SidePanelMessageTypes {
   GENERATE_COMPLETION_CLICK = `${sidePanelMessagePrefix}.generateCompletionClick`,
   SETTINGS_CLICK = `${sidePanelMessagePrefix}.settingsClick`,
   MODEL_CHANGE = `${sidePanelMessagePrefix}.modelChange`,
+  DEVICE_CHANGE = `${sidePanelMessagePrefix}.deviceChange`,
 }
 
 export interface ISidePanelMessage<P = unknown> {
diff --git a/modules/openvino_code/side-panel-ui/src/components/sections/ServerSection/ServerSection.tsx b/modules/openvino_code/side-panel-ui/src/components/sections/ServerSection/ServerSection.tsx
index 68ed0ea33..5be1983a1 100644
--- a/modules/openvino_code/side-panel-ui/src/components/sections/ServerSection/ServerSection.tsx
+++ b/modules/openvino_code/side-panel-ui/src/components/sections/ServerSection/ServerSection.tsx
@@ -7,6 +7,9 @@ import { ServerStatus } from './ServerStatus/ServerStatus';
 import './ServerSection.css';
 import { ModelSelect } from './ModelSelect/ModelSelect';
 import { ModelName } from '@shared/model';
+import { DeviceSelect } from './DeviceSelect/DeviceSelect';
+import { DeviceName } from '@shared/device';
+
 
 interface ServerSectionProps {
   state: IExtensionState | null;
@@ -46,6 +49,15 @@ export function ServerSection({ state }: ServerSectionProps): JSX.Element {
     });
   };
 
+  const handleDeviceChange = (deviceName: DeviceName) => {
+    vscode.postMessage({
+      type: SidePanelMessageTypes.DEVICE_CHANGE,
+      payload: {
+        deviceName,
+      },
+    });
+  };
+
   if (!state) {
     return <>Extension state is not available</>;
   }
@@ -64,6 +76,13 @@ export function ServerSection({ state }: ServerSectionProps): JSX.Element {
         supportedFeatures={state.features.supportedList}
         serverStatus={state.server.status}
       ></ModelSelect>
+      <DeviceSelect
+        disabled={!isServerStopped}
+        onChange={handleDeviceChange}
+        selectedDeviceName={state.config.device}
+        supportedFeatures={state.features.supportedList}
+        serverStatus={state.server.status}
+      ></DeviceSelect>
       {isServerStarting && <StartingStages currentStage={state.server.stage}></StartingStages>}
       <div className="button-group">
         {isServerStopped && <button onClick={handleStartServerClick}>Start Server</button>}
diff --git a/modules/openvino_code/src/configuration.ts b/modules/openvino_code/src/configuration.ts
index bd76a32eb..e41e4bc5f 100644
--- a/modules/openvino_code/src/configuration.ts
+++ b/modules/openvino_code/src/configuration.ts
@@ -1,4 +1,5 @@
 import { ModelName } from '@shared/model';
+import { DeviceName } from '@shared/device';
 import { WorkspaceConfiguration, workspace } from 'vscode';
 import { CONFIG_KEY } from './constants';
 
@@ -7,6 +8,7 @@ import { CONFIG_KEY } from './constants';
  */
 export type CustomConfiguration = {
   model: ModelName;
+  device: DeviceName;
   serverUrl: string;
   serverRequestTimeout: number;
   streamInlineCompletion: boolean;
diff --git a/modules/openvino_code/src/python-server/python-server-runner.ts b/modules/openvino_code/src/python-server/python-server-runner.ts
index a65afa10b..a13f82dc7 100644
--- a/modules/openvino_code/src/python-server/python-server-runner.ts
+++ b/modules/openvino_code/src/python-server/python-server-runner.ts
@@ -13,6 +13,7 @@ import { join } from 'path';
 import { MODEL_NAME_TO_ID_MAP, ModelName } from '@shared/model';
 import { extensionState } from '../state';
 import { clearLruCache } from '../lru-cache.decorator';
+import { DEVICE_NAME_TO_ID_MAP, DeviceName } from '@shared/device';
 
 const SERVER_STARTED_STDOUT_ANCHOR = 'OpenVINO Code Server started';
 
@@ -20,7 +21,7 @@ interface ServerHooks {
   onStarted: () => void;
 }
 
-async function runServer(modelName: ModelName, config: PythonServerConfiguration, hooks?: ServerHooks) {
+async function runServer(modelName: ModelName, deviceName: DeviceName, config: PythonServerConfiguration, hooks?: ServerHooks) {
   const { serverDir, proxyEnv, abortSignal, logger } = config;
   logger.info('Starting server...');
 
@@ -40,8 +41,9 @@ async function runServer(modelName: ModelName, config: PythonServerConfiguration
   }
 
   const model = MODEL_NAME_TO_ID_MAP[modelName];
+  const device = DEVICE_NAME_TO_ID_MAP[deviceName];
 
-  await spawnCommand(venvPython, ['main.py', '--model', model], {
+  await spawnCommand(venvPython, ['main.py', '--model', model, '--device', device], {
     logger,
     cwd: serverDir,
     abortSignal,
@@ -149,8 +151,9 @@ export class NativePythonServerRunner {
     this._stateController.setStage(ServerStartingStage.START_SERVER);
 
     const modelName = extensionState.config.model;
+    const deviceName = extensionState.config.device;
 
-    await runServer(modelName, config, {
+    await runServer(modelName, deviceName, config, {
       onStarted: () => {
         this._stateController.setStatus(ServerStatus.STARTED);
         this._stateController.setStage(null);
diff --git a/modules/openvino_code/src/side-panel/side-panel-message-handler.ts b/modules/openvino_code/src/side-panel/side-panel-message-handler.ts
index 35f42032b..b9716c5ac 100644
--- a/modules/openvino_code/src/side-panel/side-panel-message-handler.ts
+++ b/modules/openvino_code/src/side-panel/side-panel-message-handler.ts
@@ -4,6 +4,8 @@ import { Webview, commands } from 'vscode';
 import { settingsService } from '../settings/settings.service';
 import { COMMANDS } from '../constants';
 import { ModelName } from '@shared/model';
+import { DeviceName } from '@shared/device';
+
 
 type SidePanelMessageHandlerType = (webview: Webview, payload?: ISidePanelMessage['payload']) => void;
 
@@ -12,6 +14,8 @@ const sidePanelMessageHandlers: Record<SidePanelMessageTypes, SidePanelMessageHa
   [SidePanelMessageTypes.SETTINGS_CLICK]: () => settingsService.openSettings(),
   [SidePanelMessageTypes.MODEL_CHANGE]: (_, payload) =>
     settingsService.updateSetting('model', (payload as { modelName: ModelName }).modelName),
+  [SidePanelMessageTypes.DEVICE_CHANGE]: (_, payload) =>
+    settingsService.updateSetting('device', (payload as { deviceName: DeviceName }).deviceName),  
   [SidePanelMessageTypes.START_SERVER_CLICK]: () => void commands.executeCommand(COMMANDS.START_SERVER_NATIVE),
   [SidePanelMessageTypes.STOP_SERVER_CLICK]: () => void commands.executeCommand(COMMANDS.STOP_SERVER_NATIVE),
   [SidePanelMessageTypes.SHOW_SERVER_LOG_CLICK]: () => void commands.executeCommand(COMMANDS.SHOW_SERVER_LOG),