
Commit cb2d548

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 49ec5a2 commit cb2d548

1 file changed (+11 -13): ds/infer_bf16.py
@@ -1,12 +1,13 @@
 # ==--------------------------------------------------------------------------==
 # Patch for loading DS models
-from typing import Union, Optional
-import torch
 import os
-from packaging import version
+from typing import Optional, Union
 from zipfile import is_zipfile
-from transformers.utils import is_safetensors_available, strtobool
+
+import torch
+from packaging import version
 from transformers.integrations import PeftAdapterMixin, deepspeed_config, is_deepspeed_zero3_enabled
+from transformers.utils import is_safetensors_available, strtobool
 
 if is_safetensors_available():
     from safetensors import safe_open
@@ -37,9 +38,7 @@ def load_state_dict(
     map_location: Optional[Union[str, torch.device]] = None,
     weights_only: bool = True,
 ):
-    """
-    Reads a PyTorch checkpoint file, returning properly formatted errors if they arise.
-    """
+    """Reads a PyTorch checkpoint file, returning properly formatted errors if they arise."""
 
     if checkpoint_file.endswith(".safetensors") and is_safetensors_available():
         # Check format of the archive
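For context, the format check referenced by that comment can be done by reading only the safetensors header metadata, before any tensor is deserialized. A minimal sketch of that idea, assuming a local file named model.safetensors (the filename and error message are illustrative, not taken from this commit):

from safetensors import safe_open

# Read only the header metadata; no tensors are loaded at this point.
with safe_open("model.safetensors", framework="pt") as f:
    metadata = f.metadata()

# Archives saved from PyTorch typically carry {"format": "pt"} in the header.
if metadata is None or metadata.get("format") != "pt":
    raise OSError("model.safetensors does not look like a PyTorch safetensors checkpoint.")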
@@ -103,10 +102,8 @@ def load_state_dict(
 
 
 def set_initialized_submodules(model, state_dict_keys):
-    """
-    Sets the `_is_hf_initialized` flag in all submodules of a given model when all its weights are in the loaded state
-    dict.
-    """
+    """Sets the `_is_hf_initialized` flag in all submodules of a given model when all its weights are in the loaded state
+    dict."""
     state_dict_keys = set(state_dict_keys)
     not_initialized_submodules = {}
     for module_name, module in model.named_modules():
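The hunk above only shows the first lines of the loop. As a rough sketch of the logic the docstring describes (not necessarily the exact body in this file), the function can be read as marking every submodule whose full set of weights is present in the loaded state dict:

def set_initialized_submodules_sketch(model, state_dict_keys):
    """Illustrative only: flag submodules whose weights all appear in the loaded state dict."""
    state_dict_keys = set(state_dict_keys)
    not_initialized_submodules = {}
    for module_name, module in model.named_modules():
        # Keys belonging to this submodule, with its prefix stripped.
        loaded_keys = {k[len(module_name) + 1:] for k in state_dict_keys if k.startswith(f"{module_name}.")}
        if loaded_keys.issuperset(module.state_dict()):
            module._is_hf_initialized = True
        else:
            not_initialized_submodules[module_name] = module
    return not_initialized_submodules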
@@ -137,16 +134,17 @@ def patch_transformers():
     logger = logging.getLogger(__name__)
     logger.setLevel(logging.INFO)
 
+
 def eval(model_path):
     import transformers
     from transformers.modeling_utils import no_init_weights
-    # from patch_for_ds import patch_transformers
 
+    # from patch_for_ds import patch_transformers
     # if not not_patch_lin:
     #     patch_lin()
 
     def _patch__initialize_weights(self, module):
-        print(f"Skipping init_weights ")
+        print("Skipping init_weights ")
         module._is_hf_initialized = True
 
     transformers.modeling_utils.PreTrainedModel._initialize_weights = _patch__initialize_weights
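The last hunk is the core of the patch: `_initialize_weights` on `PreTrainedModel` is replaced with a stub that only marks modules as initialized, so `from_pretrained` skips random weight initialization and relies entirely on the checkpoint. A self-contained sketch of applying the same monkey-patch; the model path and dtype below are placeholders, not part of this commit:

import torch
import transformers


def _skip_initialize_weights(self, module):
    # No-op replacement: the checkpoint will supply every weight, so skip init.
    print("Skipping init_weights ")
    module._is_hf_initialized = True


# Apply the patch before building the model, mirroring ds/infer_bf16.py.
transformers.modeling_utils.PreTrainedModel._initialize_weights = _skip_initialize_weights

model = transformers.AutoModelForCausalLM.from_pretrained(
    "/path/to/ds-model",  # placeholder checkpoint path
    torch_dtype=torch.bfloat16,
)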
