
Commit 4114d3b

only support model id in api now
1 parent 50fcfc0 commit 4114d3b

3 files changed: +36 -38 lines changed


optimum/fx/parallelization/api.py

+34 -36

@@ -15,10 +15,11 @@
 import importlib
 import os
 from functools import partial
-from typing import List, Union
+from typing import List

 import torch
 from torch.fx import GraphModule
+from transformers import AutoConfig

 from .core import Config, ParallelExecutionCtx
 from .passes import build_parallel_pass_pipeline
@@ -43,7 +44,7 @@ def parallelize_backend(


 def parallelize_model(
-    model: Union[torch.nn.Module, str],
+    model: str,
     parallel_ctx: ParallelExecutionCtx,
     *model_args,
     **kwargs,
@@ -52,8 +53,8 @@ def parallelize_model(
     API for automatic model parallelism through Pytorch FX.

     Args:
-        model (Union[torch.nn.Module, str]):
-            Model to parallelize, could either be a module or a model id on the Huggingface Hub.
+        model (str):
+            Model to parallelize, a model id on the Huggingface Hub.
         parallel_ctx (ParallelExecutionCtx):
             Parallel execution context containing process groups the current process belongs to.
         *model_args (Any):
@@ -80,44 +81,41 @@ def parallelize_model(
             setattr(parallel_config, k, v)
             kwargs.pop(k)

-    if isinstance(model, str):
-        from transformers import AutoConfig
-
-        is_local = os.path.isdir(model)
-        if not is_local:
-            hf_folder = download_model_from_hf(
-                model_name_or_path=model,
-                cache_dir=cache_dir,
-                revision=revision,
-                local_files_only=local_files_only,
-                skip_download_weights=skip_load_weights,
-            )
-        else:
-            hf_folder = model
-
-        # should be able to load config using only local files
-        model_config, kwargs = AutoConfig.from_pretrained(
-            hf_folder, revision=revision, local_files_only=True, return_unused_kwargs=True, **kwargs
+    is_local = os.path.isdir(model)
+    if not is_local:
+        hf_folder = download_model_from_hf(
+            model_name_or_path=model,
+            cache_dir=cache_dir,
+            revision=revision,
+            local_files_only=local_files_only,
+            skip_download_weights=skip_load_weights,
         )
+    else:
+        hf_folder = model

-        # try getting model class info from config
-        model_arch = model_config.architectures
-        model_cls = getattr(importlib.import_module("transformers"), model_arch[0])
+    # should be able to load config using only local files
+    model_config, kwargs = AutoConfig.from_pretrained(
+        hf_folder, revision=revision, local_files_only=True, return_unused_kwargs=True, **kwargs
+    )

-        if not skip_load_weights:
-            parallel_ctx.weight_map = try_collect_weight_map(model, cache_dir, hf_folder)
+    # try getting model class info from config
+    model_arch = model_config.architectures
+    model_cls = getattr(importlib.import_module("transformers"), model_arch[0])

-        torch_dtype, dtype_orig = kwargs.pop("torch_dtype", None), None
-        if torch_dtype is not None:
-            dtype_orig = model_cls._set_default_torch_dtype(torch_dtype)
+    if not skip_load_weights:
+        parallel_ctx.weight_map = try_collect_weight_map(model, cache_dir, hf_folder)

-        with MetaAwareMethodsPatcher():
-            model = model_cls(model_config, *model_args, **kwargs)
-            # TODO: remove this once support training-time trace
-            model.eval()
+    torch_dtype, dtype_orig = kwargs.pop("torch_dtype", None), None
+    if torch_dtype is not None:
+        dtype_orig = model_cls._set_default_torch_dtype(torch_dtype)

-        if dtype_orig is not None:
-            torch.set_default_dtype(dtype_orig)
+    with MetaAwareMethodsPatcher():
+        model = model_cls(model_config, *model_args, **kwargs)
+        # TODO: remove this once support training-time trace
+        model.eval()
+
+    if dtype_orig is not None:
+        torch.set_default_dtype(dtype_orig)

     move_model_to_device(model, device=parallel_ctx.current_device)
     initialize_parameter_meta(model)
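
With this commit, parallelize_model only accepts a model id string; a local checkpoint directory still works, since os.path.isdir(model) is checked before downloading, but an instantiated torch.nn.Module no longer does. Below is a minimal usage sketch, not a verbatim recipe: the ParallelExecutionCtx keyword arguments and the "gpt2" model id are illustrative assumptions, and the script presumes a torchrun launch with one CUDA device per rank.

import torch
import torch.distributed as dist

from optimum.fx.parallelization.api import parallelize_model
from optimum.fx.parallelization.core import ParallelExecutionCtx

# Assumes torchrun has set RANK/WORLD_SIZE and each rank owns one GPU.
dist.init_process_group(backend="nccl")
device = torch.device("cuda", dist.get_rank())
torch.cuda.set_device(device)

# Illustrative construction: the exact ParallelExecutionCtx signature is not
# shown in this diff; it carries the process group(s) and current device that
# the parallelization passes read (e.g. parallel_ctx.current_device).
ctx = ParallelExecutionCtx(tp_group=dist.group.WORLD, current_device=device)

# The first argument must now be a Hugging Face Hub model id (or a local
# checkpoint folder); a torch.nn.Module instance is no longer accepted.
model = parallelize_model("gpt2", ctx)

Launched with, for example: torchrun --nproc_per_node=2 example.py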

optimum/fx/parallelization/op_registry/op_handlers.py

+1 -1

@@ -439,7 +439,7 @@ def propagate(self) -> List[int]:
         # last resort, if no input is being parallelized, then we make output also not parallelized,
         # this will give us relief on writing policies for strange ops which don't actually need
         # parallelization in most cases
-        if all([self.extract_axis(arg) is None for arg in self.node.all_input_nodes]):
+        if all(self.extract_axis(arg) is None for arg in self.node.all_input_nodes):
             return [None]

         raise NotImplementedError(f"don't know how to propagate axis for {self.node.target}")
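
This change swaps a list comprehension inside all() for a generator expression: the result is identical, but all() can now short-circuit on the first parallelized input instead of evaluating extract_axis for every input node up front. A standalone illustration of the difference (not library code):

def is_none(x):
    print(f"checked {x!r}")
    return x is None

inputs = [1, None, None]

# List comprehension: every element is evaluated before all() sees any result.
print(all([is_none(x) for x in inputs]))   # prints "checked" three times, then False

# Generator expression: all() stops at the first falsy value.
print(all(is_none(x) for x in inputs))     # prints "checked" once, then False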

optimum/fx/parallelization/passes.py

+1 -1

@@ -194,7 +194,7 @@ def run(self, graph_module: GraphModule, ctx: ParallelExecutionCtx, config: Conf
         graph: Graph = decompose_and_functionalize(graph_module)(*ctx.example_inputs)
         stable_topological_sort(graph)

-        nodes = [node for node in graph.nodes]
+        nodes = list(graph.nodes)

         def search(idx: int):
             if idx == len(nodes):
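
Here list(graph.nodes) is the idiomatic way to snapshot torch.fx's linked-list node view into a plain list, which the recursive search(idx) above can then index by position with a length that stays fixed. A tiny standalone sketch (the traced function is only an example, not part of the library):

import torch
from torch.fx import symbolic_trace

def fn(x):
    return torch.relu(x) + 1

gm = symbolic_trace(fn)

# gm.graph.nodes is an iterable view over a doubly linked list of nodes;
# copying it into a list enables nodes[idx]-style access.
nodes = list(gm.graph.nodes)
print(len(nodes), nodes[0].op, nodes[-1].op)   # 4 placeholder output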
