@@ -151,12 +151,12 @@ class ModelArguments:
151
151
metadata = {"help" : "The specific model version to use (can be a branch name, tag name or commit id)." },
152
152
)
153
153
feature_extractor_name : str = field (default = None , metadata = {"help" : "Name or path of preprocessor config." })
154
- use_auth_token : bool = field (
155
- default = False ,
154
+ token : str = field (
155
+ default = None ,
156
156
metadata = {
157
157
"help" : (
158
- "Will use the token generated when running `huggingface-cli login` (necessary to use this script "
159
- "with private models )."
158
+ "The token to use as HTTP bearer authorization for remote files. If not specified, will use the token "
159
+ "generated when running `huggingface-cli login` (stored in `~/.huggingface` )."
160
160
)
161
161
},
162
162
)
@@ -239,8 +239,7 @@ def main():
239
239
data_args .dataset_name ,
240
240
data_args .dataset_config_name ,
241
241
cache_dir = model_args .cache_dir ,
242
- task = "image-classification" ,
243
- use_auth_token = True if model_args .use_auth_token else None ,
242
+ token = model_args .token ,
244
243
)
245
244
else :
246
245
data_files = {}
@@ -252,7 +251,6 @@ def main():
252
251
"imagefolder" ,
253
252
data_files = data_files ,
254
253
cache_dir = model_args .cache_dir ,
255
- task = "image-classification" ,
256
254
)
257
255
258
256
# If we don't have a validation split, split off a percentage of train as validation.
@@ -287,15 +285,15 @@ def compute_metrics(p):
287
285
finetuning_task = "image-classification" ,
288
286
cache_dir = model_args .cache_dir ,
289
287
revision = model_args .model_revision ,
290
- use_auth_token = True if model_args .use_auth_token else None ,
288
+ token = model_args .token ,
291
289
)
292
290
model = AutoModelForImageClassification .from_pretrained (
293
291
model_args .model_name_or_path ,
294
292
from_tf = bool (".ckpt" in model_args .model_name_or_path ),
295
293
config = config ,
296
294
cache_dir = model_args .cache_dir ,
297
295
revision = model_args .model_revision ,
298
- use_auth_token = True if model_args .use_auth_token else None ,
296
+ token = model_args .token ,
299
297
ignore_mismatched_sizes = model_args .ignore_mismatched_sizes ,
300
298
)
301
299
@@ -311,7 +309,7 @@ def compute_metrics(p):
311
309
model_args .feature_extractor_name or model_args .model_name_or_path ,
312
310
cache_dir = model_args .cache_dir ,
313
311
revision = model_args .model_revision ,
314
- use_auth_token = True if model_args .use_auth_token else None ,
312
+ token = model_args .token ,
315
313
)
316
314
317
315
# Define torchvision transforms to be applied to each image.
0 commit comments