@@ -1,12 +1,13 @@
 # ==--------------------------------------------------------------------------==
 # Patch for loading DS models
-from typing import Union, Optional
-import torch
 import os
-from packaging import version
+from typing import Optional, Union
 from zipfile import is_zipfile
-from transformers.utils import is_safetensors_available, strtobool
+
+import torch
+from packaging import version
 from transformers.integrations import PeftAdapterMixin, deepspeed_config, is_deepspeed_zero3_enabled
+from transformers.utils import is_safetensors_available, strtobool
 
 if is_safetensors_available():
     from safetensors import safe_open
@@ -37,9 +38,7 @@ def load_state_dict(
     map_location: Optional[Union[str, torch.device]] = None,
     weights_only: bool = True,
 ):
-    """
-    Reads a PyTorch checkpoint file, returning properly formatted errors if they arise.
-    """
+    """Reads a PyTorch checkpoint file, returning properly formatted errors if they arise."""
 
     if checkpoint_file.endswith(".safetensors") and is_safetensors_available():
         # Check format of the archive
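(Aside, not part of the diff: a hedged sketch of the loading pattern that the signature and the safetensors branch above suggest. The real load_state_dict has additional error handling not shown in this excerpt; the helper name load_checkpoint_sketch is made up for illustration.)

import torch
from safetensors import safe_open

def load_checkpoint_sketch(checkpoint_file, map_location=None, weights_only=True):
    # Illustrative only: dispatch between a safetensors archive and a regular torch checkpoint.
    if checkpoint_file.endswith(".safetensors"):
        state_dict = {}
        with safe_open(checkpoint_file, framework="pt", device="cpu") as f:
            for key in f.keys():
                state_dict[key] = f.get_tensor(key)
        return state_dict
    return torch.load(checkpoint_file, map_location=map_location, weights_only=weights_only)
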
@@ -103,10 +102,8 @@ def load_state_dict(
 
 
 def set_initialized_submodules(model, state_dict_keys):
-    """
-    Sets the `_is_hf_initialized` flag in all submodules of a given model when all its weights are in the loaded state
-    dict.
-    """
+    """Sets the `_is_hf_initialized` flag in all submodules of a given model when all its weights are in the loaded state
+    dict."""
     state_dict_keys = set(state_dict_keys)
    not_initialized_submodules = {}
     for module_name, module in model.named_modules():
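(Aside, not part of the diff: a minimal sketch of the behaviour that docstring describes, under the assumption that "all its weights are in the loaded state dict" means every key of a submodule's state_dict appears among the loaded keys. The helper name mark_initialized is hypothetical, not the code from this commit.)

import torch.nn as nn

def mark_initialized(model: nn.Module, loaded_keys: set) -> dict:
    # Illustrative only: flag submodules whose weights are all present in loaded_keys.
    not_initialized = {}
    for name, module in model.named_modules():
        prefix = f"{name}." if name else ""
        module_keys = {prefix + k for k in module.state_dict()}
        if module_keys and module_keys.issubset(loaded_keys):
            module._is_hf_initialized = True  # later weight init can be skipped for this module
        else:
            not_initialized[name] = module
    return not_initialized
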
@@ -137,16 +134,17 @@ def patch_transformers():
     logger = logging.getLogger(__name__)
     logger.setLevel(logging.INFO)
 
+
 def eval(model_path):
     import transformers
     from transformers.modeling_utils import no_init_weights
-    # from patch_for_ds import patch_transformers
 
+    # from patch_for_ds import patch_transformers
     # if not not_patch_lin:
     #     patch_lin()
 
     def _patch__initialize_weights(self, module):
-        print(f"Skipping init_weights ")
+        print("Skipping init_weights ")
         module._is_hf_initialized = True
 
     transformers.modeling_utils.PreTrainedModel._initialize_weights = _patch__initialize_weights
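(Aside, not part of the diff: a hedged sketch of how this patch is presumably applied. Only patch_transformers and the no-op _initialize_weights replacement are visible above; the module name patch_for_ds is taken from the commented-out import, and the model class and checkpoint path are placeholders, not anything confirmed by this commit.)

from transformers import AutoModelForCausalLM

from patch_for_ds import patch_transformers  # hypothetical import path

patch_transformers()  # install the monkey patches on transformers before any model is built
model = AutoModelForCausalLM.from_pretrained("/path/to/deepspeed-checkpoint")
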