Skip to content

Commit 7674e33

Browse files
echarlaixPenghuiCheng
authored andcommitted
Fix OpenVINO image classification examples (huggingface#598)
1 parent b751766 commit 7674e33

File tree

1 file changed

+8
-10
lines changed

1 file changed

+8
-10
lines changed

examples/openvino/image-classification/run_image_classification.py

+8-10
Original file line numberDiff line numberDiff line change
@@ -151,12 +151,12 @@ class ModelArguments:
151151
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
152152
)
153153
feature_extractor_name: str = field(default=None, metadata={"help": "Name or path of preprocessor config."})
154-
use_auth_token: bool = field(
155-
default=False,
154+
token: str = field(
155+
default=None,
156156
metadata={
157157
"help": (
158-
"Will use the token generated when running `huggingface-cli login` (necessary to use this script "
159-
"with private models)."
158+
"The token to use as HTTP bearer authorization for remote files. If not specified, will use the token "
159+
"generated when running `huggingface-cli login` (stored in `~/.huggingface`)."
160160
)
161161
},
162162
)
@@ -239,8 +239,7 @@ def main():
239239
data_args.dataset_name,
240240
data_args.dataset_config_name,
241241
cache_dir=model_args.cache_dir,
242-
task="image-classification",
243-
use_auth_token=True if model_args.use_auth_token else None,
242+
token=model_args.token,
244243
)
245244
else:
246245
data_files = {}
@@ -252,7 +251,6 @@ def main():
252251
"imagefolder",
253252
data_files=data_files,
254253
cache_dir=model_args.cache_dir,
255-
task="image-classification",
256254
)
257255

258256
# If we don't have a validation split, split off a percentage of train as validation.
@@ -287,15 +285,15 @@ def compute_metrics(p):
287285
finetuning_task="image-classification",
288286
cache_dir=model_args.cache_dir,
289287
revision=model_args.model_revision,
290-
use_auth_token=True if model_args.use_auth_token else None,
288+
token=model_args.token,
291289
)
292290
model = AutoModelForImageClassification.from_pretrained(
293291
model_args.model_name_or_path,
294292
from_tf=bool(".ckpt" in model_args.model_name_or_path),
295293
config=config,
296294
cache_dir=model_args.cache_dir,
297295
revision=model_args.model_revision,
298-
use_auth_token=True if model_args.use_auth_token else None,
296+
token=model_args.token,
299297
ignore_mismatched_sizes=model_args.ignore_mismatched_sizes,
300298
)
301299

@@ -311,7 +309,7 @@ def compute_metrics(p):
311309
model_args.feature_extractor_name or model_args.model_name_or_path,
312310
cache_dir=model_args.cache_dir,
313311
revision=model_args.model_revision,
314-
use_auth_token=True if model_args.use_auth_token else None,
312+
token=model_args.token,
315313
)
316314

317315
# Define torchvision transforms to be applied to each image.

0 commit comments

Comments
 (0)