Skip to content

Commit 8f23817

Browse files
authored
[OPENVINO-CODE] add-device-options (#895)
1 parent 8a7bf32 commit 8f23817

File tree

11 files changed

+210
-27
lines changed

11 files changed

+210
-27
lines changed

modules/openvino_code/package.json

+12
Original file line number | Diff line number | Diff line change
@@ -57,6 +57,7 @@
5757
"vsce:publish": "vsce publish",
5858
"ovsx:publish": "ovsx publish",
5959
"clear-out": "rimraf ./out"
60+
6061
},
6162
"devDependencies": {
6263
"@types/glob": "8.1.0",
@@ -200,6 +201,17 @@
200201
],
201202
"description": "Which model to use for code generation."
202203
},
204+
"openvinoCode.device": {
205+
"order": 1,
206+
"type": "string",
207+
"default": "CPU",
208+
"enum":[
209+
"CPU",
210+
"GPU",
211+
"NPU"
212+
],
213+
"description": "Which device to use for code generation."
214+
},
203215
"openvinoCode.serverUrl": {
204216
"order": 1,
205217
"type": "string",

modules/openvino_code/server/pyproject.toml

+11-6
Original file line number | Diff line number | Diff line change
@@ -11,7 +11,7 @@ dependencies = [
1111
'torch @ https://download.pytorch.org/whl/cpu-cxx11-abi/torch-2.0.1%2Bcpu.cxx11.abi-cp310-cp310-linux_x86_64.whl ; sys_platform=="linux" and python_version == "3.10"',
1212
'torch @ https://download.pytorch.org/whl/cpu-cxx11-abi/torch-2.0.1%2Bcpu.cxx11.abi-cp311-cp311-linux_x86_64.whl ; sys_platform=="linux" and python_version == "3.11"',
1313
'torch ; sys_platform != "linux"',
14-
'openvino==2023.3.0',
14+
'openvino==2024.0.0',
1515
'transformers==4.36.0',
1616
'optimum==1.17.1',
1717
'optimum-intel[openvino]==1.15.0',
@@ -27,13 +27,18 @@ build-backend = "setuptools.build_meta"
2727

2828
[tool.black]
2929
line-length = 119
30-
target-versions = ["py38", "py39", "py310", "py311"]
31-
30+
target-version = ['py38', 'py39', 'py310', 'py311']
31+
unstable = true
32+
preview = true
3233

3334
[tool.ruff]
34-
ignore = ["C901", "E501", "E741", "W605"]
35-
select = ["C", "E", "F", "I", "W"]
35+
lint.ignore = ["C901", "E501", "E741", "W605", "F401", "W292"]
36+
lint.select = ["C", "E", "F", "I", "W"]
37+
lint.extend-safe-fixes = ["F601"]
38+
lint.extend-unsafe-fixes = ["UP034"]
39+
lint.fixable = ["F401"]
3640
line-length = 119
3741

38-
[tool.ruff.isort]
42+
43+
[tool.ruff.lint.isort]
3944
lines-after-imports = 2

modules/openvino_code/server/src/app.py

+15-3
Original file line number | Diff line number | Diff line change
@@ -114,7 +114,11 @@ async def generate_stream(
114114
generation_request = TypeAdapter(GenerationRequest).validate_python(await request.json())
115115
logger.info(generation_request)
116116
return StreamingResponse(
117-
generator.generate_stream(generation_request.inputs, generation_request.parameters.model_dump(), request)
117+
generator.generate_stream(
118+
generation_request.inputs,
119+
generation_request.parameters.model_dump(),
120+
request,
121+
)
118122
)
119123

120124

@@ -127,7 +131,11 @@ async def summarize(
127131

128132
start = perf_counter()
129133
generated_text: str = generator.summarize(
130-
request.inputs, request.template, request.definition, request.format, request.parameters.model_dump()
134+
request.inputs,
135+
request.template,
136+
request.definition,
137+
request.format,
138+
request.parameters.model_dump(),
131139
)
132140
stop = perf_counter()
133141

@@ -148,6 +156,10 @@ async def summarize_stream(
148156
logger.info(request)
149157
return StreamingResponse(
150158
generator.summarize_stream(
151-
request.inputs, request.template, request.definition, request.format, request.parameters.model_dump()
159+
request.inputs,
160+
request.template,
161+
request.definition,
162+
request.format,
163+
request.parameters.model_dump(),
152164
)
153165
)

modules/openvino_code/server/src/generators.py

+76-15
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,17 @@
55
from pathlib import Path
66
from threading import Thread
77
from time import time
8-
from typing import Any, Callable, Container, Dict, Generator, List, Optional, Type, Union
8+
from typing import (
9+
Any,
10+
Callable,
11+
Container,
12+
Dict,
13+
Generator,
14+
List,
15+
Optional,
16+
Type,
17+
Union,
18+
)
919

1020
import torch
1121
from fastapi import Request
@@ -53,11 +63,20 @@ def get_model(checkpoint: str, device: str = "CPU") -> OVModel:
5363
model_class = get_model_class(checkpoint)
5464
try:
5565
model = model_class.from_pretrained(
56-
checkpoint, ov_config=ov_config, compile=False, device=device, trust_remote_code=True
66+
checkpoint,
67+
ov_config=ov_config,
68+
compile=False,
69+
device=device,
70+
trust_remote_code=True,
5771
)
5872
except EntryNotFoundError:
5973
model = model_class.from_pretrained(
60-
checkpoint, ov_config=ov_config, export=True, compile=False, device=device, trust_remote_code=True
74+
checkpoint,
75+
ov_config=ov_config,
76+
export=True,
77+
compile=False,
78+
device=device,
79+
trust_remote_code=True,
6180
)
6281
model.save_pretrained(model_path)
6382
model.compile()
@@ -75,10 +94,24 @@ def __call__(self, input_text: str, parameters: Dict[str, Any]) -> str:
7594
async def generate_stream(self, input_text: str, parameters: Dict[str, Any], request: Request):
7695
raise NotImplementedError
7796

78-
def summarize(self, input_text: str, template: str, signature: str, style: str, parameters: Dict[str, Any]):
97+
def summarize(
98+
self,
99+
input_text: str,
100+
template: str,
101+
signature: str,
102+
style: str,
103+
parameters: Dict[str, Any],
104+
):
79105
raise NotImplementedError
80106

81-
def summarize_stream(self, input_text: str, template: str, signature: str, style: str, parameters: Dict[str, Any]):
107+
def summarize_stream(
108+
self,
109+
input_text: str,
110+
template: str,
111+
signature: str,
112+
style: str,
113+
parameters: Dict[str, Any],
114+
):
82115
raise NotImplementedError
83116

84117

@@ -128,13 +161,19 @@ def __call__(self, input_text: str, parameters: Dict[str, Any]) -> str:
128161
prompt_len = input_ids.shape[-1]
129162
config = GenerationConfig.from_dict({**self.generation_config.to_dict(), **parameters})
130163
output_ids = self.model.generate(
131-
input_ids, generation_config=config, stopping_criteria=stopping_criteria, **self.assistant_model_config
164+
input_ids,
165+
generation_config=config,
166+
stopping_criteria=stopping_criteria,
167+
**self.assistant_model_config,
132168
)[0][prompt_len:]
133169
logger.info(f"Number of input tokens: {prompt_len}; generated {len(output_ids)} tokens")
134170
return self.tokenizer.decode(output_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
135171

136172
async def generate_stream(
137-
self, input_text: str, parameters: Dict[str, Any], request: Optional[Request] = None
173+
self,
174+
input_text: str,
175+
parameters: Dict[str, Any],
176+
request: Optional[Request] = None,
138177
) -> Generator[str, None, None]:
139178
input_ids = self.tokenizer.encode(input_text, return_tensors="pt")
140179
streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=True)
@@ -192,7 +231,10 @@ def generate_between(
192231
prev_len = prompt.shape[-1]
193232

194233
prompt = self.model.generate(
195-
prompt, generation_config=config, stopping_criteria=stopping_criteria, **self.assistant_model_config
234+
prompt,
235+
generation_config=config,
236+
stopping_criteria=stopping_criteria,
237+
**self.assistant_model_config,
196238
)[
197239
:, :-1
198240
] # skip the last token - stop token
@@ -219,7 +261,10 @@ async def generate_between_stream(
219261
prev_len = prompt.shape[-1]
220262

221263
prompt = self.model.generate(
222-
prompt, generation_config=config, stopping_criteria=stopping_criteria, **self.assistant_model_config
264+
prompt,
265+
generation_config=config,
266+
stopping_criteria=stopping_criteria,
267+
**self.assistant_model_config,
223268
)[
224269
:, :-1
225270
] # skip the last token - stop token
@@ -237,24 +282,40 @@ def summarization_input(function: str, signature: str, style: str) -> str:
237282
signature=signature,
238283
)
239284

240-
def summarize(self, input_text: str, template: str, signature: str, style: str, parameters: Dict[str, Any]) -> str:
285+
def summarize(
286+
self,
287+
input_text: str,
288+
template: str,
289+
signature: str,
290+
style: str,
291+
parameters: Dict[str, Any],
292+
) -> str:
241293
prompt = self.summarization_input(input_text, signature, style)
242294
splited_template = re.split(r"\$\{.*\}", template)
243295
splited_template[0] = prompt + splited_template[0]
244296

245-
return self.generate_between(splited_template, parameters, stopping_criteria=self.summarize_stopping_criteria)[
246-
len(prompt) :
247-
]
297+
return self.generate_between(
298+
splited_template,
299+
parameters,
300+
stopping_criteria=self.summarize_stopping_criteria,
301+
)[len(prompt) :]
248302

249303
async def summarize_stream(
250-
self, input_text: str, template: str, signature: str, style: str, parameters: Dict[str, Any]
304+
self,
305+
input_text: str,
306+
template: str,
307+
signature: str,
308+
style: str,
309+
parameters: Dict[str, Any],
251310
):
252311
prompt = self.summarization_input(input_text, signature, style)
253312
splited_template = re.split(r"\$\{.*\}", template)
254313
splited_template = [prompt] + splited_template
255314

256315
async for token in self.generate_between_stream(
257-
splited_template, parameters, stopping_criteria=self.summarize_stopping_criteria
316+
splited_template,
317+
parameters,
318+
stopping_criteria=self.summarize_stopping_criteria,
258319
):
259320
yield token
260321

modules/openvino_code/shared/device.ts

+25
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,25 @@
1+
import { Features } from './features';
2+
3+
enum DeviceId {
4+
CPU = 'CPU',
5+
GPU = 'GPU',
6+
NPU = 'NPU',
7+
}
8+
9+
export enum DeviceName {
10+
CPU = 'CPU',
11+
GPU = 'GPU',
12+
NPU = 'NPU',
13+
}
14+
15+
export const DEVICE_NAME_TO_ID_MAP: Record<DeviceName, DeviceId> = {
16+
[DeviceName.CPU]: DeviceId.CPU,
17+
[DeviceName.GPU]: DeviceId.GPU,
18+
[DeviceName.NPU]: DeviceId.NPU,
19+
};
20+
21+
export const DEVICE_SUPPORTED_FEATURES: Record<DeviceName, Features[]> = {
22+
[DeviceName.CPU]: [Features.CODE_COMPLETION, Features.SUMMARIZATION, Features.FIM],
23+
[DeviceName.GPU]: [Features.CODE_COMPLETION, Features.SUMMARIZATION, Features.FIM],
24+
[DeviceName.NPU]: [Features.CODE_COMPLETION, Features.SUMMARIZATION, Features.FIM],
25+
};

modules/openvino_code/shared/side-panel-message.ts

+1
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,7 @@ export enum SidePanelMessageTypes {
1010
GENERATE_COMPLETION_CLICK = `${sidePanelMessagePrefix}.generateCompletionClick`,
1111
SETTINGS_CLICK = `${sidePanelMessagePrefix}.settingsClick`,
1212
MODEL_CHANGE = `${sidePanelMessagePrefix}.modelChange`,
13+
DEVICE_CHANGE = `${sidePanelMessagePrefix}.deviceChange`,
1314
}
1415

1516
export interface ISidePanelMessage<P = unknown> {
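
modules/openvino_code/side-panel-ui/src/components/sections/ServerSection/DeviceSelect/DeviceSelect.tsx

+41
Original file line number | Diff line number | Diff line change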
@@ -0,0 +1,41 @@
1+
//import { ModelName } from '@shared/model';
2+
import { DeviceName } from '@shared/device';
3+
import { Select, SelectOptionProps } from '../../../shared/Select/Select';
4+
import { ServerStatus } from '@shared/server-state';
5+
import { Features } from '@shared/features';
6+
7+
const options: SelectOptionProps<DeviceName>[] = [
8+
{ value: DeviceName.CPU },
9+
{ value: DeviceName.GPU },
10+
{ value: DeviceName.NPU },
11+
];
12+
13+
interface DeviceSelectProps {
14+
disabled: boolean;
15+
selectedDeviceName: DeviceName;
16+
onChange: (deviceName: DeviceName) => void;
17+
supportedFeatures: Features[];
18+
serverStatus: ServerStatus;
19+
}
20+
21+
export const DeviceSelect = ({
22+
disabled,
23+
selectedDeviceName,
24+
onChange,
25+
supportedFeatures,
26+
serverStatus,
27+
}: DeviceSelectProps): JSX.Element => {
28+
const isServerStopped = serverStatus === ServerStatus.STOPPED;
29+
return (
30+
<>
31+
<Select
32+
label="Device"
33+
options={options}
34+
selectedValue={selectedDeviceName}
35+
disabled={disabled}
36+
onChange={(value) => onChange(value)}
37+
></Select>
38+
{isServerStopped && <span>Supported Features: {supportedFeatures.join(', ')}</span>}
39+
</>
40+
);
41+
};

modules/openvino_code/side-panel-ui/src/components/sections/ServerSection/ServerSection.tsx

+18
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,8 @@ import { ServerStatus } from './ServerStatus/ServerStatus';
77
import './ServerSection.css';
88
import { ModelSelect } from './ModelSelect/ModelSelect';
99
import { ModelName } from '@shared/model';
10+
import { DeviceSelect } from './DeviceSelect/DeviceSelect';
11+
import { DeviceName } from '@shared/device';
1012

1113
interface ServerSectionProps {
1214
state: IExtensionState | null;
@@ -46,6 +48,15 @@ export function ServerSection({ state }: ServerSectionProps): JSX.Element {
4648
});
4749
};
4850

51+
const handleDeviceChange = (deviceName: DeviceName) => {
52+
vscode.postMessage({
53+
type: SidePanelMessageTypes.DEVICE_CHANGE,
54+
payload: {
55+
deviceName,
56+
},
57+
});
58+
};
59+
4960
if (!state) {
5061
return <>Extension state is not available</>;
5162
}
@@ -64,6 +75,13 @@ export function ServerSection({ state }: ServerSectionProps): JSX.Element {
6475
supportedFeatures={state.features.supportedList}
6576
serverStatus={state.server.status}
6677
></ModelSelect>
78+
<DeviceSelect
79+
disabled={!isServerStopped}
80+
onChange={handleDeviceChange}
81+
selectedDeviceName={state.config.device}
82+
supportedFeatures={state.features.supportedList}
83+
serverStatus={state.server.status}
84+
></DeviceSelect>
6785
{isServerStarting && <StartingStages currentStage={state.server.stage}></StartingStages>}
6886
<div className="button-group">
6987
{isServerStopped && <button onClick={handleStartServerClick}>Start Server</button>}

modules/openvino_code/src/configuration.ts

+2
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,5 @@
11
import { ModelName } from '@shared/model';
2+
import { DeviceName } from '@shared/device';
23
import { WorkspaceConfiguration, workspace } from 'vscode';
34
import { CONFIG_KEY } from './constants';
45

@@ -7,6 +8,7 @@ import { CONFIG_KEY } from './constants';
78
*/
89
export type CustomConfiguration = {
910
model: ModelName;
11+
device: DeviceName;
1012
serverUrl: string;
1113
serverRequestTimeout: number;
1214
streamInlineCompletion: boolean;

0 commit comments

Comments
 (0)