 import importlib
 import os
 from functools import partial
-from typing import List, Union
+from typing import List

 import torch
 from torch.fx import GraphModule
+from transformers import AutoConfig

 from .core import Config, ParallelExecutionCtx
 from .passes import build_parallel_pass_pipeline
@@ -43,7 +44,7 @@ def parallelize_backend(


 def parallelize_model(
-    model: Union[torch.nn.Module, str],
+    model: str,
     parallel_ctx: ParallelExecutionCtx,
     *model_args,
     **kwargs,
@@ -52,8 +53,8 @@ def parallelize_model(
     API for automatic model parallelism through Pytorch FX.

     Args:
-        model (Union[torch.nn.Module, str]):
-            Model to parallelize, could either be a module or a model id on the Huggingface Hub.
+        model (str):
+            Model to parallelize, a model id on the Huggingface Hub.
         parallel_ctx (ParallelExecutionCtx):
             Parallel execution context containing process groups the current process belongs to.
         *model_args (Any):
@@ -80,44 +81,41 @@ def parallelize_model(
             setattr(parallel_config, k, v)
             kwargs.pop(k)

-    if isinstance(model, str):
-        from transformers import AutoConfig
-
-        is_local = os.path.isdir(model)
-        if not is_local:
-            hf_folder = download_model_from_hf(
-                model_name_or_path=model,
-                cache_dir=cache_dir,
-                revision=revision,
-                local_files_only=local_files_only,
-                skip_download_weights=skip_load_weights,
-            )
-        else:
-            hf_folder = model
-
-        # should be able to load config using only local files
-        model_config, kwargs = AutoConfig.from_pretrained(
-            hf_folder, revision=revision, local_files_only=True, return_unused_kwargs=True, **kwargs
+    is_local = os.path.isdir(model)
+    if not is_local:
+        hf_folder = download_model_from_hf(
+            model_name_or_path=model,
+            cache_dir=cache_dir,
+            revision=revision,
+            local_files_only=local_files_only,
+            skip_download_weights=skip_load_weights,
         )
+    else:
+        hf_folder = model

-        # try getting model class info from config
-        model_arch = model_config.architectures
-        model_cls = getattr(importlib.import_module("transformers"), model_arch[0])
+    # should be able to load config using only local files
+    model_config, kwargs = AutoConfig.from_pretrained(
+        hf_folder, revision=revision, local_files_only=True, return_unused_kwargs=True, **kwargs
+    )

-        if not skip_load_weights:
-            parallel_ctx.weight_map = try_collect_weight_map(model, cache_dir, hf_folder)
+    # try getting model class info from config
+    model_arch = model_config.architectures
+    model_cls = getattr(importlib.import_module("transformers"), model_arch[0])

-        torch_dtype, dtype_orig = kwargs.pop("torch_dtype", None), None
-        if torch_dtype is not None:
-            dtype_orig = model_cls._set_default_torch_dtype(torch_dtype)
+    if not skip_load_weights:
+        parallel_ctx.weight_map = try_collect_weight_map(model, cache_dir, hf_folder)

-        with MetaAwareMethodsPatcher():
-            model = model_cls(model_config, *model_args, **kwargs)
-            # TODO: remove this once support training-time trace
-            model.eval()
+    torch_dtype, dtype_orig = kwargs.pop("torch_dtype", None), None
+    if torch_dtype is not None:
+        dtype_orig = model_cls._set_default_torch_dtype(torch_dtype)

-        if dtype_orig is not None:
-            torch.set_default_dtype(dtype_orig)
+    with MetaAwareMethodsPatcher():
+        model = model_cls(model_config, *model_args, **kwargs)
+        # TODO: remove this once support training-time trace
+        model.eval()
+
+    if dtype_orig is not None:
+        torch.set_default_dtype(dtype_orig)

     move_model_to_device(model, device=parallel_ctx.current_device)
     initialize_parameter_meta(model)
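
After this change, parallelize_model only accepts a model id on the Hugging Face Hub (or a local checkpoint directory), not an instantiated torch.nn.Module. A minimal usage sketch under that assumption follows; the import path, the tp_group field name on ParallelExecutionCtx, the checkpoint id, and the process-group setup are illustrative and not taken from this patch:

import torch
import torch.distributed as dist

from optimum.fx.parallelization import ParallelExecutionCtx, parallelize_model

# Illustrative setup: assumes the script is launched with torchrun so that
# rank and world size are already available in the environment.
dist.init_process_group(backend="nccl")
device = torch.device("cuda", dist.get_rank())

# current_device matches the attribute used in the code above; the tp_group
# field name is an assumption about ParallelExecutionCtx.
ctx = ParallelExecutionCtx(tp_group=dist.group.WORLD, current_device=device)

# model must now be a Hub id or a local checkpoint path; passing an
# already-built nn.Module is no longer supported by this signature.
model = parallelize_model("gpt2", ctx)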