Enhance 3.x torch WOQ load #1877

Closed · wants to merge 21 commits
Changes from 1 commit
enhance code
Signed-off-by: yuwenzho <yuwen.zhou@intel.com>
yuwenzho committed Jun 28, 2024
commit 9c41f57cd566955a1358fb91f1c17e560727a597
8 changes: 4 additions & 4 deletions neural_compressor/torch/algorithms/weight_only/modules.py
@@ -55,14 +55,15 @@ def __init__(
dtype,
bits,
group_size,
bias,
device,
):
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.dtype = dtype
self.bits = bits
self.group_size = group_size if group_size != -1 else in_features
self.device = device

@abstractmethod
def pack(self, *args, **kwargs):
@@ -110,7 +111,7 @@ def __init__(
dtype,
bits,
group_size,
bias,
device,
)
self.use_optimum_format = use_optimum_format
if "int" not in self.dtype: # for nf4, fp4
@@ -122,7 +123,6 @@ def __init__(
self.int2float_mapping = {}
for k, v in zip(int_list, float_list):
self.int2float_mapping[k] = v
self.device = device
self.compression_dim = compression_dim
assert compression_dtype in [
torch.int8,
@@ -272,7 +272,7 @@ def unpack(self):
device = scales.device
if self.g_idx is None:
# used for recovering fp32_weight
self.g_idx = torch.tensor([i // self.group_size for i in range(self.in_features)], dtype=torch.int32)
self.g_idx = torch.tensor([i // self.group_size for i in range(self.in_features)], dtype=torch.int32).to(device)
# unpack weight
if not self.use_optimum_format and self.compression_dim == 0:
qweight = qweight.T.contiguous()
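The `.to(device)` change above keeps the recovered group index on the same device as the scales tensor, which matters once unpack() runs on HPU or CUDA rather than CPU. A minimal sketch of the mapping being rebuilt, using toy sizes and only the names visible in the diff:

import torch

in_features, group_size = 8, 4        # toy sizes for illustration
scales = torch.ones(2)                # stand-in for the module's scales tensor
device = scales.device

# each input channel i belongs to quantization group i // group_size
g_idx = torch.tensor([i // group_size for i in range(in_features)], dtype=torch.int32).to(device)
print(g_idx)  # tensor([0, 0, 0, 0, 1, 1, 1, 1], dtype=torch.int32)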
2 changes: 1 addition & 1 deletion neural_compressor/torch/algorithms/weight_only/rtn.py
@@ -27,6 +27,7 @@
from neural_compressor.torch.utils import get_accelerator, is_transformers_imported, logger, set_module

from .utility import cast_fp8, quant_tensor, search_clip
from .modules import INCWeightOnlyLinear

if is_transformers_imported():
import transformers
@@ -188,7 +189,6 @@ def convert(
int_weight = int_weight.t_().contiguous()
scale = scale.t_().contiguous()
zp = zp.t_().contiguous() if zp is not None else zp
from .modules import INCWeightOnlyLinear

new_module = INCWeightOnlyLinear(
in_features,
488 changes: 274 additions & 214 deletions neural_compressor/torch/algorithms/weight_only/save_load.py
@@ -18,19 +18,22 @@
import json
import os
import re
import tempfile

import torch

from neural_compressor.common.utils import load_config_mapping, save_config_mapping
from neural_compressor.common.utils import save_config_mapping
from neural_compressor.torch.utils import (
HPU_SAFE_WEIGHTS_NAME,
HPU_WEIGHT_NAME,
QCONFIG_NAME,
WEIGHT_NAME,
LoadFormat,
set_module,
logger,
)

)
from .utility import convert_dtype_str2torch
from .modules import HPUWeightOnlyLinear, INCWeightOnlyLinear, MulLinear

format_woqlinear_mapping = {LoadFormat.HUGGINGFACE: INCWeightOnlyLinear, LoadFormat.DEFAULT: INCWeightOnlyLinear}
@@ -53,16 +56,21 @@ def save(model, output_dir="./saved_results"):
del model.save
torch.save(model.state_dict(), qmodel_weight_file_path)

logger.info("Save quantized model to {}.".format(qmodel_weight_file_path))
logger.info("Save quantized model weight to {}.".format(qmodel_weight_file_path))
logger.info("Save configuration of quantized model to {}.".format(qconfig_file_path))


def load(model_name_or_path, original_model=None, format=LoadFormat.DEFAULT, device="cpu", **kwargs):
"""Load quantized weight-only quantization model.
1. Load INC weight-only quantized model in local.
2. Load HuggingFace weight-only quantized model,
including GPTQ/AWQ models and upstreamed INC quantized models in HF model hub.
from neural_compressor.torch.quantization import load
load(model_name_or_path="saved_results", original_model=fp32_model, format="default", device="cpu")
2. Load HuggingFace weight-only quantized model, including GPTQ models and
upstreamed INC quantized models in HF model hub.
from neural_compressor.torch.quantization import load
load(model_name_or_path=model_name_or_path, format="huggingface", device="cpu")
Args:
model_name_or_path (str): torch checkpoint directory or huggingface model_name_or_path.
@@ -71,7 +79,7 @@ def load(model_name_or_path, original_model=None, format=LoadFormat.DEFAULT, dev
Parameter should not be None. It works together with the 'original_model' parameter to load an INC
weight-only quantized model locally.
original_model (torch.nn.module, optional): original model before quantization.
Needed if 'format' is set to 'default' and not TorchScript model.Defaults to None.
Needed if 'format' is set to 'default'. Defaults to None.
format (str, optional): 'default' for loading INC weight-only quantized model.
'huggingface' for loading huggingface WOQ causal language model. Defaults to "default".
kwargs (remaining dictionary of keyword arguments, optional):
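The two scenarios in the docstring correspond to call patterns like the sketch below; `fp32_model` and the hub model id are placeholders, and the import path is the one shown in the docstring:

from neural_compressor.torch.quantization import load

# 1. default format: reload an INC checkpoint saved locally (e.g. by save(model, "./saved_results"))
woq_model = load(
    model_name_or_path="saved_results",
    original_model=fp32_model,  # placeholder: the FP32 model before quantization
    format="default",
    device="cpu",
)

# 2. huggingface format: load a GPTQ or upstreamed INC WOQ model from the HF hub
woq_model = load(
    model_name_or_path="org/quantized-model",  # hypothetical model id
    format="huggingface",
    device="cpu",
)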
@@ -146,7 +154,7 @@ def load_inc_format_woq_model(self):

# get loaded state_dict
self.loaded_state_dict = torch.load(qmodel_weight_file_path)
self.loaded_state_dict_keys = list(self.loaded_state_dict.keys())
self.loaded_state_dict_keys = list(set(self.loaded_state_dict.keys()))

# get qconfig
with open(qconfig_file_path, "r") as file:
@@ -155,8 +163,8 @@ def load_inc_format_woq_model(self):
# build weight-only quantization model with WeightOnlyLinear module
model = self._build_woq_model()

# load pretrained weight to weight-only quantization model
model.load_state_dict(self.loaded_state_dict, assign=True)
# load remaining pretrained weight to weight-only quantization model
model.load_state_dict(self.loaded_state_dict, assign=True, strict=False)

# save hpu format tensor to local directory
if self._should_save_hpu_format_tensor:
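A small self-contained illustration of the `assign=True, strict=False` combination used above, with toy modules rather than the INC classes: `strict=False` tolerates keys that were already consumed when the WOQ modules were built, and `assign=True` hands the loaded tensors to the parameters directly instead of copying them element-wise (the `assign` flag needs PyTorch 2.1+):

import torch

model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.LayerNorm(4))

# pretend only part of the checkpoint is left after the quantized modules took their tensors
partial_state_dict = {"1.weight": torch.ones(4), "1.bias": torch.zeros(4)}

result = model.load_state_dict(partial_state_dict, assign=True, strict=False)
print(result.missing_keys)  # ['0.weight', '0.bias'] are reported, not raised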
@@ -180,16 +188,20 @@ def load_hf_format_woq_model(self):

# get loaded state_dict
self.loaded_state_dict = self._get_loaded_state_dict(config)
self.loaded_state_dict_keys = list(self.loaded_state_dict.keys())
self.loaded_state_dict_keys = list(set(self.loaded_state_dict.keys()))

# initiate the huggingface model (FP32 model)
# instantiate the huggingface model (FP32 empty model)
self.original_model = self._init_hf_model(model_class, config)

# build weight-only quantization model with WeightOnlyLinear module
# and load quantized weight to WeightOnlyLinear modules
model = self._build_woq_model()

# load pretrained weight to weight-only quantization model
model = self._load_pretrained_weight(model, model_class)
# clear loaded_state_dict
self.loaded_state_dict = {}

# load remaining pretrained weight to weight-only quantization model
model = self._load_remaining_pretrained_weight(model)

# save hpu format tensor to local directory
if self._should_save_hpu_format_tensor:
@@ -200,9 +212,26 @@ def load_hf_format_woq_model(self):

def _build_woq_model(self):
"""Build weight-only quantization model."""
from neural_compressor.torch.utils import set_module
self._update_format_woqlinear_mapping()

# if hpu woq linear module can be used directly, then update format mapping module to HPUWeightOnlyLinear
for name, module in self.original_model.named_modules():
# replace `torch.nn.Linear` with `WeightOnlyLinear` in self.original_model and load its quantized data
if isinstance(module, torch.nn.Linear):
# a module without qweight was not quantized, so skip it
if (
name + ".qweight" not in self.loaded_state_dict_keys
and name + ".linear.qweight" not in self.loaded_state_dict_keys
):
continue

module_quantization_config, _is_autoround = self._get_module_quantization_config(name, module)
self._replace_woqlinear_modules(name, module, module_quantization_config, _is_autoround)

woq_model = self.original_model
return woq_model
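# For illustration only (hypothetical key names, not taken from a real checkpoint):
#   a plainly quantized linear appears in the loaded state_dict as
#       "model.layers.0.self_attn.q_proj.qweight"
#   while an AWQ/TEQ linear wrapped in MulLinear appears as
#       "model.layers.0.self_attn.q_proj.linear.qweight"
# Any nn.Linear whose name matches neither pattern keeps its float weights and is skipped above.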

def _update_format_woqlinear_mapping(self):
"""Update format mapping module to HPUWeightOnlyLinear if tensor is hpu format."""
if self._use_hpu_module():
format_woqlinear_mapping.update({self.format: HPUWeightOnlyLinear})

@@ -212,115 +241,123 @@ def _build_woq_model(self):
f"Device mapping is {device_woqlinear_mapping}."
)

for name, module in self.original_model.named_modules():
_is_autoround = False
# get quantization config of module
module_quantization_config = self.quantization_config
# pattern will map (module_name, moduele_type)
pattern = rf"(\(.*{re.escape(name)}.*{re.escape(type(module).__name__)}.*\))"
for q_config_key, q_config_value in self.quantization_config.items():
if re.search(pattern, q_config_key):
if isinstance(q_config_value, dict) and [algo for algo in q_config_value.keys()][0] == "autoround":
_is_autoround = True
module_quantization_config = [config for config in q_config_value.values()][0]

# replace `torch.nn.Linear` with `WeightOnlyLinear`
if isinstance(module, torch.nn.Linear):
# module without qweight means it is not quantized, then skip it
loaded_state_dict_keys_set = set(self.loaded_state_dict_keys)
if (
name + ".qweight" not in loaded_state_dict_keys_set
and name + ".linear.qweight" not in loaded_state_dict_keys_set
):
continue
def _get_module_quantization_config(self, module_name, module):
"""Gt quantization config of current module.
# insert MulLinear module
if name + ".linear.qweight" in loaded_state_dict_keys_set:
new_module = MulLinear(module)
set_module(self.original_model, name, new_module)
name += ".linear"

zp = True if name + ".qzeros" in loaded_state_dict_keys_set else False
g_idx = True if name + ".g_idx" in loaded_state_dict_keys_set else False

WeightOnlyLinearClass = format_woqlinear_mapping[self.format]
kwargs = {}
if WeightOnlyLinearClass == INCWeightOnlyLinear:
kwargs["group_size"] = module_quantization_config.get("group_size", 32)
kwargs["g_idx"] = g_idx
if _is_autoround:
from .utility import convert_dtype_str2torch

kwargs["scale_dtype"] = convert_dtype_str2torch(
module_quantization_config.get("scale_dtype", "fp16")
)
elif WeightOnlyLinearClass == HPUWeightOnlyLinear:
# TODO: update kwargs specific to HPUWeightOnlyLinear
kwargs["group_size"] = module_quantization_config.get("group_size", 32)
kwargs["g_idx"] = g_idx

new_module = WeightOnlyLinearClass(
module.in_features,
module.out_features,
dtype=module_quantization_config.get("dtype", "int"),
bits=module_quantization_config.get("bits", 4),
zp=zp,
bias=module.bias is not None,
use_optimum_format=True,
**kwargs,
1. For an INC weight-only quantized model, quantization_config is structured at module level, e.g.:
{(module1_name, module1_type): {"rtn": {"bits": 4, ...}}, ...}
2. For a HF weight-only quantized model, quantization_config is structured at model level, e.g.:
{'bits': 4, ...}
"""
module_quantization_config = self.quantization_config
pattern = rf"(\(.*{re.escape(module_name)}.*{re.escape(type(module).__name__)}.*\))"
_is_autoround = False
# the loop looks up the quantization config of the target module in an INC weight-only quantized model
for q_config_key, q_config_value in self.quantization_config.items():
if re.search(pattern, q_config_key):
# pattern matches (module_name, module_type)
if isinstance(q_config_value, dict) and next(iter(q_config_value)) == "autoround":
_is_autoround = True
module_quantization_config = next(iter(q_config_value.values()))
return module_quantization_config, _is_autoround
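# Illustrative shapes of self.quantization_config handled above (example values, not real output):
#   INC checkpoint (module level):  {"(fc1, Linear)": {"rtn": {"dtype": "int", "bits": 4, "group_size": 32}}, ...}
#   HF checkpoint (model level):    {"bits": 4, "group_size": 128, ...}
# Only the INC form contains per-module entries, so for HF models the loop matches nothing and the
# model-level dict is returned unchanged.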

def _replace_woqlinear_modules(self, name, linear_module, module_quantization_config, _is_autoround):
"""Replace torch.nn.Linear modules with WeightOnlyLinear and load its quantized data."""
# insert MulLinear module for AWQ/TEQ algorithm
if name + ".linear.qweight" in self.loaded_state_dict_keys:
new_module = MulLinear(linear_module)
set_module(self.original_model, name, new_module)
name += ".linear"

# get format mapping module class
WeightOnlyLinearClass = format_woqlinear_mapping[self.format]

# base initialization kwargs
base_kwargs = {}
base_kwargs["in_features"] = linear_module.in_features
base_kwargs["out_features"] = linear_module.out_features
base_kwargs["dtype"] = module_quantization_config.get("dtype", "int")
base_kwargs["bits"] = module_quantization_config.get("bits", 4)
base_kwargs["group_size"] = module_quantization_config.get("group_size", 32)

# class-specific initialization kwargs
class_specific_kwargs = {}
if WeightOnlyLinearClass == INCWeightOnlyLinear:
class_specific_kwargs["g_idx"] = name + ".g_idx" in self.loaded_state_dict_keys
class_specific_kwargs["zp"] = name + ".qzeros" in self.loaded_state_dict_keys
class_specific_kwargs["use_optimum_format"] = True
class_specific_kwargs["bias"] = linear_module.bias is not None
if _is_autoround:
class_specific_kwargs["scale_dtype"] = convert_dtype_str2torch(
module_quantization_config.get("scale_dtype", "fp16")
)
elif WeightOnlyLinearClass == HPUWeightOnlyLinear:
# TODO: update kwargs specific to HPUWeightOnlyLinear
class_specific_kwargs["g_idx"] = name + ".g_idx" in self.loaded_state_dict_keys
class_specific_kwargs["zp"] = name + ".qzeros" in self.loaded_state_dict_keys
class_specific_kwargs["use_optimum_format"] = True
class_specific_kwargs["bias"] = linear_module.bias is not None

# load quantized data of current module
new_module_state_dict = {}
keys = [".qweight", ".scales", ".qzeros", ".bias", ".g_idx"]
for key in keys:
if name + key in self.loaded_state_dict:
new_module_state_dict[key[1:]] = self.loaded_state_dict[name + key]
self.load_data_to_new_module_from_state_dict(new_module, new_module_state_dict)

# if format mapping module doesn't match device mapping module, then replace to device mapping module
if format_woqlinear_mapping[self.format] != device_woqlinear_mapping[self.device]:
logger.debug(
f"Replacing {name}'s type from "
f"'{format_woqlinear_mapping[self.format].__name__}' "
f"to '{device_woqlinear_mapping[self.device].__name__}'"
)
WeightOnlyLinearClass = device_woqlinear_mapping[self.device]

# update kwargs for the device mapping WeightOnlyLinear module
kwargs = {}
if WeightOnlyLinearClass == INCWeightOnlyLinear:
kwargs["group_size"] = module_quantization_config.get("group_size", 32)
kwargs["g_idx"] = g_idx
elif WeightOnlyLinearClass == HPUWeightOnlyLinear:
# TODO: update kwargs specific to HPUWeightOnlyLinear
kwargs["group_size"] = module_quantization_config.get("group_size", 32)
kwargs["g_idx"] = g_idx

int_weight, scale, zp = new_module.unpack()
bias = new_module.bias
g_idx = new_module.g_idx
new_module = WeightOnlyLinearClass(
module.in_features,
module.out_features,
dtype=module_quantization_config.get("dtype", "int"),
bits=module_quantization_config.get("bits", 4),
zp=zp,
bias=module.bias is not None,
use_optimum_format=True,
**kwargs,
)
new_module.pack(int_weight.to(self.device), scale.to(self.device), zp.to(self.device))
new_module.bias = bias.to(self.device)
new_module.g_idx = g_idx.to(self.device)
new_module = WeightOnlyLinearClass(**base_kwargs, **class_specific_kwargs)

# if the new module is HPUWeightOnlyLinear, save hpu_model.safetensors for next loading
if not self._should_save_hpu_format_tensor and WeightOnlyLinearClass == HPUWeightOnlyLinear:
self._should_save_hpu_format_tensor = True
# load quantized data of current module
self._load_data_to_new_module(new_module, name)

set_module(self.original_model, name, new_module)
woq_model = self.original_model
return woq_model
# if the format mapping module doesn't match the device mapping module, replace it with the device mapping module
if format_woqlinear_mapping[self.format] != device_woqlinear_mapping[self.device]:
new_module = self._update_mapped_woqlinear_modules(name, new_module, base_kwargs)

set_module(self.original_model, name, new_module)

def _load_data_to_new_module(self, new_module, module_name):
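"""Move the quantized tensors belonging to module_name from the loaded state_dict into new_module."""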
new_module_state_dict = {}
for key in [".qweight", ".scales", ".qzeros", ".bias", ".g_idx"]:
full_name = module_name + key
if full_name in self.loaded_state_dict:
new_module_state_dict[key[1:]] = self.loaded_state_dict.pop(full_name)
self.loaded_state_dict_keys.remove(full_name)
new_module.load_state_dict(new_module_state_dict)

def _update_mapped_woqlinear_modules(self, name, format_woqlinear_module, base_kwargs):
"""Update the format mapping module to device mapping module."""
OldWeightOnlyLinearClass = format_woqlinear_mapping[self.format]

# get device mapping module class
NewWeightOnlyLinearClass = device_woqlinear_mapping[self.device]

logger.debug(
f"Replacing {name}'s type from "
f"'{OldWeightOnlyLinearClass.__name__}' "
f"to '{NewWeightOnlyLinearClass.__name__}'"
)

if OldWeightOnlyLinearClass == INCWeightOnlyLinear and NewWeightOnlyLinearClass == HPUWeightOnlyLinear:
# INCWeightOnlyLinear --> HPUWeightOnlyLinear
# TODO: need to rewrite once HPUWeightOnlyLinear is implemented
class_specific_kwargs = {}  # class-specific initialization kwargs
class_specific_kwargs["g_idx"] = getattr(format_woqlinear_module, "g_idx", None) is not None
class_specific_kwargs["zp"] = getattr(format_woqlinear_module, "qzeros", None) is not None
class_specific_kwargs["use_optimum_format"] = True
class_specific_kwargs["bias"] = getattr(format_woqlinear_module, "bias", None) is not None

int_weight, scale, zp = format_woqlinear_module.unpack()
bias = format_woqlinear_module.bias
g_idx = format_woqlinear_module.g_idx

# initialize the new WeightOnlyLinearClass
device_mapping_module = NewWeightOnlyLinearClass(**base_kwargs, **class_specific_kwargs)

device_mapping_module.pack(int_weight.to(self.device), scale.to(self.device), zp.to(self.device))
device_mapping_module.bias = bias.to(self.device)
device_mapping_module.g_idx = g_idx.to(self.device)

# if the new module is HPUWeightOnlyLinear, save hpu format tensor for next loading
if not self._should_save_hpu_format_tensor:
self._should_save_hpu_format_tensor = True
else:
raise RuntimeError(f"Not support update {OldWeightOnlyLinearClass.__name__} to {NewWeightOnlyLinearClass.__name__}.")

return device_mapping_module
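# Summary of the conversion path above: the INC module is unpacked into (int_weight, scale, zp), an
# HPUWeightOnlyLinear is created with the same base kwargs, and pack() re-lays the tensors out on
# self.device; the hpu-format safetensors file is then saved so later loads can skip this repacking
# step (see _use_hpu_module and _save_hpu_format_tensor).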

def _get_model_class_and_config(self):
from transformers import AutoConfig, AutoModelForCausalLM
@@ -362,17 +399,10 @@ def _get_model_class_and_config(self):

def _get_loaded_state_dict(self, config):
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_utils import _add_variant, get_checkpoint_shard_files, load_state_dict
from transformers.modeling_utils import get_checkpoint_shard_files, load_state_dict
from transformers.utils import (
SAFE_WEIGHTS_INDEX_NAME,
SAFE_WEIGHTS_NAME,
WEIGHTS_INDEX_NAME,
WEIGHTS_NAME,
cached_file,
download_url,
extract_commit_hash,
has_file,
is_remote_url,
is_safetensors_available,
)

@@ -446,9 +476,90 @@ def _get_loaded_state_dict(self, config):

self.model_name_or_path = str(self.model_name_or_path)

# get resolved weight archive file
kwargs = {
"use_safetensors": use_safetensors,
"variant": variant,
"cache_dir": cache_dir,
"force_download": force_download,
"proxies": proxies,
"resume_download": resume_download,
"local_files_only": local_files_only,
"token": token,
"user_agent": user_agent,
"revision": revision,
"subfolder": subfolder,
"_raise_exceptions_for_gated_repo": False,
"_raise_exceptions_for_missing_entries": False,
"_commit_hash": commit_hash,
}
resolved_archive_file = self._get_resolved_archive_file(**kwargs)

self._model_local_dir = os.path.abspath(os.path.expanduser(os.path.dirname(resolved_archive_file)))
# if hpu format tensor can be used directly, then update resolved_archive_file to the hpu format tensor file
if self._use_hpu_module():
resolved_archive_file = os.path.join(self._model_local_dir, HPU_SAFE_WEIGHTS_NAME)

logger.info(f"Find weight file {resolved_archive_file}")

if is_sharded: # pragma: no cover
# resolved_archive_file becomes a list of files that point to the different checkpoint shards in this case.
resolved_archive_file, sharded_metadata = get_checkpoint_shard_files(
self.model_name_or_path,
resolved_archive_file,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
local_files_only=local_files_only,
token=token,
user_agent=user_agent,
revision=revision,
subfolder=subfolder,
_commit_hash=commit_hash,
)
self.kwargs["sharded_metadata"] = sharded_metadata

# Time to load the checkpoint
state_dict = None
if not isinstance(resolved_archive_file, list):
resolved_archive_file = [resolved_archive_file]
for shard_file in resolved_archive_file:
if state_dict is None:
state_dict = load_state_dict(shard_file)
else:
state_dict.update(load_state_dict(shard_file))

# set kwargs for next functions to use
self.kwargs["is_sharded"] = is_sharded
self.kwargs["offload_folder"] = offload_folder
self.kwargs["offload_state_dict"] = offload_state_dict
self.kwargs["resolved_archive_file"] = resolved_archive_file

return state_dict

def _get_resolved_archive_file(self, **kwargs):
"""Get weight archive file of model."""
from transformers.modeling_utils import _add_variant
from transformers.utils import (
SAFE_WEIGHTS_INDEX_NAME,
SAFE_WEIGHTS_NAME,
WEIGHTS_INDEX_NAME,
WEIGHTS_NAME,
cached_file,
download_url,
has_file,
is_remote_url,
)

use_safetensors = kwargs.pop("use_safetensors")
variant = kwargs.pop("variant")
subfolder = kwargs.get("subfolder")

resolved_archive_file = None
is_local = os.path.isdir(self.model_name_or_path)
if is_local: # pragma: no cover
# self.model_name_or_path is a local directory
if os.path.isfile(
os.path.join(
self.model_name_or_path,
@@ -507,29 +618,18 @@ def _get_loaded_state_dict(self, config):
archive_file = self.model_name_or_path
is_local = True
elif is_remote_url(self.model_name_or_path): # pragma: no cover
# self.model_name_or_path is a url
filename = self.model_name_or_path
resolved_archive_file = download_url(self.model_name_or_path)
else:
# self.model_name_or_path is a model_id in huggingface
if use_safetensors is not False:
filename = _add_variant(SAFE_WEIGHTS_NAME, variant)
else:
filename = _add_variant(WEIGHTS_NAME, variant)
try:
# Load from URL or cache if already cached
cached_file_kwargs = {
"cache_dir": cache_dir,
"force_download": force_download,
"proxies": proxies,
"resume_download": resume_download,
"local_files_only": local_files_only,
"token": token,
"user_agent": user_agent,
"revision": revision,
"subfolder": subfolder,
"_raise_exceptions_for_gated_repo": False,
"_raise_exceptions_for_missing_entries": False,
"_commit_hash": commit_hash,
}
cached_file_kwargs = kwargs
resolved_archive_file = cached_file(self.model_name_or_path, filename, **cached_file_kwargs)

# Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None
@@ -573,9 +673,9 @@ def _get_loaded_state_dict(self, config):
# Otherwise, maybe there is a TF or Flax model file. We try those to give a helpful error
# message.
has_file_kwargs = {
"revision": revision,
"proxies": proxies,
"token": token,
"revision": cached_file_kwargs.get("revision"),
"proxies": cached_file_kwargs.get("proxies"),
"token": cached_file_kwargs.get("token"),
}
if variant is not None and has_file(self.model_name_or_path, WEIGHTS_NAME, **has_file_kwargs):
raise EnvironmentError(
@@ -604,45 +704,8 @@ def _get_loaded_state_dict(self, config):
if is_local:
resolved_archive_file = archive_file

self._model_local_dir = os.path.abspath(os.path.expanduser(os.path.dirname(resolved_archive_file)))
# if hpu format tensor can be used directly, then update resolved_archive_file to the hpu format tensor file
if self._use_hpu_module():
resolved_archive_file = os.path.join(self._model_local_dir, HPU_SAFE_WEIGHTS_NAME)
return resolved_archive_file

logger.info(f"Find weight file {resolved_archive_file}")

if is_sharded: # pragma: no cover
# rsolved_archive_file becomes a list of files that point to the different checkpoint shards in this case.
resolved_archive_file, sharded_metadata = get_checkpoint_shard_files(
self.model_name_or_path,
resolved_archive_file,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
local_files_only=local_files_only,
token=token,
user_agent=user_agent,
revision=revision,
subfolder=subfolder,
_commit_hash=commit_hash,
)
self.kwargs["sharded_metadata"] = sharded_metadata

if is_sharded: # pragma: no cover
state_dict = sharded_metadata["weight_map"]
else:
# Time to load the checkpoint
state_dict = load_state_dict(resolved_archive_file)

# set kwargs for next functions to use
self.kwargs["is_sharded"] = is_sharded
self.kwargs["offload_folder"] = offload_folder
self.kwargs["offload_state_dict"] = offload_state_dict
self.kwargs["resolved_archive_file"] = resolved_archive_file

# return loaded_state_dict_keys
return state_dict

def _init_hf_model(self, model_class, config):
from accelerate.big_modeling import init_empty_weights
@@ -690,50 +753,50 @@ def _init_hf_model(self, model_class, config):

# set kwargs for next functions to use
self.kwargs["resolved_archive_file"] = resolved_archive_file
self.kwargs["sharded_metadata"] = sharded_metadata
self.kwargs["torch_dtype"] = torch_dtype
self.kwargs["dtype_orig"] = dtype_orig
self.kwargs["_fast_init"] = _fast_init
self.kwargs["offload_folder"] = offload_folder
self.kwargs["offload_state_dict"] = offload_state_dict

return model

def _load_pretrained_weight(self, model, model_class):
def _load_remaining_pretrained_weight(self, model):
"""Load remaing pretrained weight.
In _build_woq_model function, linear will be replaced to weight-only quantization linear
and its quantized weight will be loaded. Remaining pretrained weight (like layernorm weight,
embedding weight or other unquantized linear weight) will be loaded in this function.
"""
from transformers.modeling_utils import load_state_dict
from transformers.modeling_utils import _load_state_dict_into_meta_model

resolved_archive_file = self.kwargs.pop("resolved_archive_file", None)
sharded_metadata = self.kwargs.pop("sharded_metadata", None)
torch_dtype = self.kwargs.pop("torch_dtype", torch.float32)
dtype_orig = self.kwargs.pop("dtype_orig", None)
_fast_init = self.kwargs.pop("_fast_init", True)
offload_folder = self.kwargs.pop("offload_folder", None)
offload_state_dict = self.kwargs.pop("offload_state_dict", False)

# restore default dtype
if dtype_orig is not None:
torch.set_default_dtype(dtype_orig)

(
model,
missing_keys,
unexpected_keys,
mismatched_keys,
offload_index,
error_msgs,
) = model_class._load_pretrained_model(
model,
None,
self.loaded_state_dict_keys,
resolved_archive_file,
self.model_name_or_path,
sharded_metadata=sharded_metadata,
_fast_init=_fast_init,
low_cpu_mem_usage=True,
offload_folder=offload_folder,
offload_state_dict=offload_state_dict,
dtype=torch_dtype,
keep_in_fp32_modules=[],
device_map={"": self.device},
)
if not isinstance(resolved_archive_file, list):
resolved_archive_file = [resolved_archive_file]
for shard_file in resolved_archive_file:
state_dict = load_state_dict(shard_file)
_load_state_dict_into_meta_model(
model=model,
state_dict=state_dict,
loaded_state_dict_keys=self.loaded_state_dict_keys,
start_prefix="",
expected_keys=list(state_dict.keys()),
device_map={"": self.device},
offload_folder=offload_folder,
state_dict_folder=tempfile.mkdtemp() if offload_state_dict else None,
state_dict_index={} if offload_state_dict else None,
dtype=torch_dtype,
keep_in_fp32_modules=[],
)

# make sure token embedding weights are still tied if needed
model.tie_weights()
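The division of labor described in the docstring above can be pictured as a plain key partition; the key names below are hypothetical and the snippet is only a sketch of the idea, not an INC API:

# hypothetical checkpoint keys for one decoder block plus embeddings
checkpoint_keys = [
    "model.embed_tokens.weight",
    "model.layers.0.input_layernorm.weight",
    "model.layers.0.self_attn.q_proj.qweight",
    "model.layers.0.self_attn.q_proj.scales",
    "model.layers.0.self_attn.q_proj.qzeros",
]

# modules that own a ".qweight" entry were already handled while building the WOQ model;
# everything else is the remaining pretrained weight loaded here
quantized_prefixes = {k[: -len(".qweight")] for k in checkpoint_keys if k.endswith(".qweight")}
remaining_keys = [k for k in checkpoint_keys if not any(k.startswith(p + ".") for p in quantized_prefixes)]
print(remaining_keys)  # ['model.embed_tokens.weight', 'model.layers.0.input_layernorm.weight']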
@@ -743,9 +806,6 @@ def _load_pretrained_weight(self, model, model_class):

return model

def load_data_to_new_module_from_state_dict(self, new_module, new_module_weight):
new_module.load_state_dict(new_module_weight)

def _save_hpu_format_tensor(self, model): # pragma: no cover
from safetensors.torch import save_file

@@ -754,11 +814,11 @@ def _save_hpu_format_tensor(self, model): # pragma: no cover

if self.format == LoadFormat.HUGGINGFACE:
filename = os.path.join(self._model_local_dir, HPU_SAFE_WEIGHTS_NAME)
save_file(model.state_dict(), filename=filename, metadata={"format": "pt"})
save_file({k: v.cpu() for k, v in model.state_dict().items()}, filename=filename, metadata={"format": "pt"})
logger.debug(f"Save hpu format tensor to {filename}")
elif self.format == LoadFormat.DEFAULT:
qmodel_weight_file_path = os.path.join(self._model_local_dir, HPU_WEIGHT_NAME)
torch.save(model.state_dict(), qmodel_weight_file_path)
torch.save({k: v.cpu() for k, v in model.state_dict().items()}, qmodel_weight_file_path)
logger.debug(f"Save hpu format tensor to {qmodel_weight_file_path}")

def _use_hpu_module(self): # pragma: no cover