Skip to content
This repository was archived by the owner on Aug 28, 2023. It is now read-only.

Commit 440b786

Browse files
authored
[83105] Hugging Face models are marked as Generic instead of Task Classification (#68)
1 parent 753b163 commit 440b786

File tree

12 files changed

+60
-15
lines changed

12 files changed

+60
-15
lines changed

.github/workflows/check-pr-name.yml

+11-5
Original file line numberDiff line numberDiff line change
@@ -14,20 +14,26 @@ jobs:
1414
uses: actions/github-script@v5
1515
with:
1616
script: |
17+
const prName = context.payload.pull_request.title;
18+
const prCreator = context.payload.pull_request.user.login;
1719
const prNameRegExp = /^(?:\[\d+\]\s)+\w+.*/;
1820
const skipLabel = '[skip-name]';
19-
const prName = context.payload.pull_request.title;
21+
2022
console.log(`Pull Request Name is ${prName}`);
2123
22-
if (prName.includes(skipLabel)) {
23-
console.log('Skipping PR name checks');
24+
// Skip the PRs with dependency updates
25+
const dependencyBotsNames = ['dependabot', 'snyk-bot'];
26+
27+
const shouldSkip = prName.includes(skipLabel) || dependencyBotsNames.includes(prCreator);
28+
if (shouldSkip) {
29+
console.log('Skipping PR name checks.');
2430
return;
2531
}
2632
2733
if (!prNameRegExp.test(prName)) {
2834
console.log('Template: [issue_number] ([another_issue_number] ...) Short description');
29-
core.setFailed('Your Pull Request title does not confirm to the template');
35+
core.setFailed('Your Pull Request title does not confirm to the template.');
3036
return;
3137
}
3238
33-
console.log('Your Pull Request name confirm to the provided template');
39+
console.log('Your Pull Request name confirm to the provided template.');

client/src/app/modules/model-manager/components/card/card.component.ts

+17-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,12 @@
1-
import { ChangeDetectionStrategy, Component, HostBinding, Input } from '@angular/core';
1+
import {
2+
ChangeDetectionStrategy,
3+
Component,
4+
EventEmitter,
5+
HostBinding,
6+
HostListener,
7+
Input,
8+
Output,
9+
} from '@angular/core';
210

311
@Component({
412
selector: 'wb-card',
@@ -8,6 +16,14 @@ import { ChangeDetectionStrategy, Component, HostBinding, Input } from '@angular
816
})
917
export class CardComponent {
1018
@HostBinding('class.disabled') @Input() disabled = false;
19+
20+
@Output() selected = new EventEmitter<void>();
21+
22+
@HostListener('click') onClick(): void {
23+
if (!this.disabled) {
24+
this.selected.emit();
25+
}
26+
}
1127
}
1228

1329
@Component({

client/src/app/modules/model-manager/components/hugging-face-import-ribbon-content/hugging-face-import-ribbon-content.component.html

+1-1
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
[disabled]="model.validation.disabled"
4545
[matTooltip]="model.validation.message"
4646
[matTooltipDisabled]="!model.validation.disabled"
47-
(click)="selectedModel = model"
47+
(selected)="selectedModel = model"
4848
data-test-id="model-card"
4949
>
5050
<wb-card-title-row>

client/src/app/modules/model-manager/components/hugging-face-import-ribbon-content/huggingface-model-details/huggingface-model-details.component.html

+1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
<wb-model-zoo-details-parameters>
88
<wb-parameter-details class="parameter" *ngFor="let param of parameters" [parameter]="param"></wb-parameter-details>
9+
<wb-markdown-text class="hf-card-url" [text]="huggingfaceCardUrl"></wb-markdown-text>
910
</wb-model-zoo-details-parameters>
1011

1112
<wb-model-zoo-details-description>

client/src/app/modules/model-manager/components/hugging-face-import-ribbon-content/huggingface-model-details/huggingface-model-details.component.scss

+5
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,11 @@
1717
}
1818
}
1919

20+
.hf-card-url {
21+
@include wb-text-8();
22+
margin-top: 30px;
23+
}
24+
2025
.readme-not-found {
2126
color: $brownish-grey;
2227
margin: auto;

client/src/app/modules/model-manager/components/hugging-face-import-ribbon-content/huggingface-model-details/huggingface-model-details.component.ts

+7
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ import { DatePipe } from '@angular/common';
1212

1313
import { Store } from '@ngrx/store';
1414

15+
import { MessagesService } from '@core/services/common/messages.service';
16+
1517
import { ModelDomain, modelDomainNames } from '@store/model-store/model.model';
1618
import { RootStoreState } from '@store';
1719
import { HuggingfaceModelStoreActions, HuggingfaceModelStoreSelectors } from '@store/huggingface-model-store';
@@ -35,6 +37,9 @@ export class HuggingfaceModelDetailsComponent {
3537
this._model = value;
3638
if (this._model) {
3739
this.parameters = this._extractHfModelParameters(this._model);
40+
this.huggingfaceCardUrl = this._messages.getHint('importHuggingFaceTips', 'huggingfaceModelCard', {
41+
id: this._model.id,
42+
});
3843
this._store$.dispatch(HuggingfaceModelStoreActions.loadModelReadme({ huggingfaceModelId: this._model.id }));
3944
} else {
4045
this.parameters = null;
@@ -56,12 +61,14 @@ export class HuggingfaceModelDetailsComponent {
5661
readonly error$ = this._store$.select(HuggingfaceModelStoreSelectors.selectModelReadmeError);
5762

5863
parameters: IParameter[] = [];
64+
huggingfaceCardUrl: string;
5965

6066
isImportStarted = false;
6167

6268
constructor(
6369
private readonly _cdr: ChangeDetectorRef,
6470
private readonly _store$: Store<RootStoreState.State>,
71+
private readonly _messages: MessagesService,
6572
@Inject(LOCALE_ID) private readonly _localeId: string
6673
) {}
6774

client/src/app/shared/models/huggingface/huggingface-model.ts

-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@ export interface IHuggingfaceModel {
1414
lastModified: string;
1515
tags: string[];
1616
validation: IHuggingfaceModelValidationResult;
17-
siblings?: string[];
1817
config?: IHuggingfaceModelConfig;
1918
downloads?: number;
2019
}

client/src/assets/data/hint-messages.json

+2-1
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,8 @@
106106
},
107107
"importHuggingFaceTips": {
108108
"externalResourceNotification": "Hugging Face is an external source, there may be connection problems. Contains unverified models, we do not guarantee their performance.",
109-
"shownSubsetNotification": "We show only a subset of the models from Hugging Face (Text Classification and PyTorch)."
109+
"shownSubsetNotification": "We show only a subset of the models from Hugging Face (Text Classification and PyTorch).",
110+
"huggingfaceModelCard": "[Model Card on Hugging Face](https://huggingface.co/${id})"
110111
},
111112
"login": {
112113
"loginTip": "To enter the DL Workbench, use a token that is generated once you start the application. Copy it from the console.\n [Read more](https://docs.openvino.ai/latest/workbench_docs_Workbench_DG_Troubleshooting.html#omz)"

wb/main/huggingface_api/huggingface_api.py

-1
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,6 @@ def json(self):
8686
'lastModified': self.last_modified,
8787
'tags': self.tags,
8888
'validation': self.validation.json(),
89-
'siblings': self.siblings,
9089
'config': self.config.json() if self.config else None,
9190
'downloads': self.downloads,
9291
}

wb/main/jobs/create_setup_bundle/create_setup_bundle_job.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
import tempfile
2020
from contextlib import closing
2121

22+
from werkzeug.utils import secure_filename
23+
2224
from wb.extensions_factories.database import get_db_session_for_celery
2325
from wb.main.enumerates import JobTypesEnum, StatusEnum
2426
from wb.main.jobs.interfaces.ijob import IJob
@@ -63,14 +65,15 @@ def run(self):
6365
topology_temporary_path = None
6466

6567
if self.include_model:
66-
topology_temporary_path = os.path.join(tmp_scripts_folder, self.topology_name)
68+
topology_name = secure_filename(self.topology_name)
69+
topology_temporary_path = os.path.join(tmp_scripts_folder, topology_name)
6770
os.makedirs(topology_temporary_path)
6871
xml_file = find_by_ext(self.topology_path, 'xml')
69-
tmp_xml_file = os.path.join(topology_temporary_path, f'{self.topology_name}.xml')
72+
tmp_xml_file = os.path.join(topology_temporary_path, f'{topology_name}.xml')
7073
shutil.copy(xml_file, tmp_xml_file)
7174

7275
bin_file = find_by_ext(self.topology_path, 'bin')
73-
tmp_bin_file = os.path.join(topology_temporary_path, f'{self.topology_name}.bin')
76+
tmp_bin_file = os.path.join(topology_temporary_path, f'{topology_name}.bin')
7477
shutil.copy(bin_file, tmp_bin_file)
7578

7679
setup_bundle_creator = SetupBundleCreator(

wb/main/jobs/interfaces/job_observers.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@
4848
FileMetaData)
4949
from wb.main.models.wait_model_upload_job_model import WaitModelUploadJobModel
5050
from wb.main.utils.observer_pattern import Observer
51-
from wb.main.utils.utils import get_size_of_files
51+
from wb.main.utils.utils import get_size_of_files, FileSizeConverter
5252

5353

5454
def check_existing_job_model_decorator(func: Callable[['Observer', JobState], None]):
@@ -565,7 +565,7 @@ def update(self, subject_state: JobState):
565565
file_record.status = StatusEnum.ready
566566
file_record.write_record(session)
567567

568-
topology.size = file_record.size
568+
topology.size = FileSizeConverter.bytes_to_mb(file_record.size)
569569
topology.status = StatusEnum.ready
570570
topology.progress = 100
571571
topology.write_record(session)

wb/main/pipeline_creators/model_creation/import_huggingface_model_pipeline_creator.py

+8
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
"""
1717
import os
1818

19+
import json
1920
from sqlalchemy.orm import Session
2021

2122
from config.constants import UPLOAD_FOLDER_MODELS, ORIGINAL_FOLDER
@@ -28,9 +29,11 @@
2829
from wb.main.models.huggingface.import_huggingface_model_job_model import ImportHuggingfaceJobModel, \
2930
ImportHuggingfaceJobData
3031
from wb.main.models.model_optimizer_job_model import ModelOptimizerJobData
32+
from wb.main.models.topologies_metadata_model import DEFAULT_ACCURACY_CONFIGURATION
3133
from wb.main.models.topologies_model import ModelJobData
3234
from wb.main.pipeline_creators.model_creation.base_model_creation_pipeline_creator import \
3335
BaseModelCreationPipelineCreator
36+
from wb.main.shared.enumerates import TaskEnum
3437
from wb.main.utils.utils import create_empty_dir
3538

3639

@@ -59,7 +62,12 @@ def __init__(self, model_id: int, huggingface_model_id: str, model_name: str):
5962
def create_result_model(self, session: Session):
6063
uploaded_model: TopologiesModel = session.query(TopologiesModel).get(self.model_id)
6164

65+
# consider all huggingface models as text classification
6266
metadata = TopologiesMetaDataModel()
67+
metadata.task_type = TaskEnum.text_classification.value
68+
accuracy_config = json.loads(DEFAULT_ACCURACY_CONFIGURATION)
69+
accuracy_config['taskType'] = TaskEnum.text_classification.value
70+
metadata.advanced_configuration = json.dumps(accuracy_config)
6371
metadata.write_record(session)
6472

6573
self._result_model = TopologiesModel(self.model_name, SupportedFrameworksEnum.openvino, metadata.id)

0 commit comments

Comments
 (0)