
Commit 2c81219

Fix gptq device_map = "cpu" (#1662)
* fix gptq cpu device_map
* fix test
* remove default dict
1 parent 3b4f5ac commit 2c81219

File tree: 2 files changed, +22 -12 lines changed

optimum/gptq/quantizer.py (+13 -10)
@@ -332,20 +332,23 @@ def quantize_model(self, model: nn.Module, tokenizer: Optional[Any] = None):
         use_cache = model.config.use_cache
         model.config.use_cache = False
 
+        # If the model has a device_map, we don't move to model. We have already dispatched the hook that will do the work
         if hasattr(model, "hf_device_map"):
             devices = list(model.hf_device_map.values())
+            has_device_map = True
             if "disk" in devices:
                 raise ValueError("disk offload is not supported with GPTQ quantization")
-            if "cpu" in devices and len(model.hf_device_map) > 1:
-                logger.info("Cpu offload is not recommended. There might be some issues with the memory")
-                hook = None
-                for name, device in model.hf_device_map.items():
-                    if device == "cpu":
-                        module = recurse_getattr(model, name)
-                        remove_hook_from_module(module, recurse=True)
-                        module, hook = cpu_offload_with_hook(module, prev_module_hook=hook)
-            # If the model has a device_map, we don't move to model. We have already dispatched the hook that will do the work
-            has_device_map = True
+            if "cpu" in devices or torch.device("cpu") in devices:
+                if len(model.hf_device_map) > 1:
+                    logger.info("Cpu offload is not recommended. There might be some issues with the memory")
+                    hook = None
+                    for name, device in model.hf_device_map.items():
+                        if device == "cpu":
+                            module = recurse_getattr(model, name)
+                            remove_hook_from_module(module, recurse=True)
+                            module, hook = cpu_offload_with_hook(module, prev_module_hook=hook)
+                else:
+                    has_device_map = False
 
         if hasattr(model, "dtype"):
             self.use_cuda_fp16 = model.dtype == torch.float16
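
With this change, a model loaded with device_map="cpu" (an hf_device_map whose only device is the CPU) is treated like a plain CPU model (has_device_map ends up False) rather than as a partial CPU-offload setup, so quantization proceeds as it would for a model loaded without a device map. A minimal sketch of the workflow this fix targets, assuming the public optimum.gptq API; the model id, dataset, and quantization settings below are illustrative and not taken from this commit:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from optimum.gptq import GPTQQuantizer

model_id = "facebook/opt-125m"  # illustrative small model, not from this commit
tokenizer = AutoTokenizer.from_pretrained(model_id)

# device_map="cpu" gives the model an hf_device_map whose only device is the CPU;
# after this fix the quantizer no longer mistakes that for CPU offload.
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.float16, device_map="cpu"
)

quantizer = GPTQQuantizer(bits=4, dataset="c4", model_seqlen=2048)
quantized_model = quantizer.quantize_model(model, tokenizer)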

tests/gptq/test_quantization.py (+9 -2)
@@ -54,7 +54,7 @@ class GPTQTest(unittest.TestCase):
     exllama_config = None
     cache_block_outputs = True
     modules_to_quantize_inside_block = None
-
+    device_map_for_quantization = "cuda"
     dataset = [
         "auto-gptq is an easy-to-use model quantization library with user-friendly apis, based on GPTQ algorithm."
     ]
@@ -66,7 +66,7 @@ def setUpClass(cls):
         Setup quantized model
         """
         cls.model_fp16 = AutoModelForCausalLM.from_pretrained(
-            cls.model_name, torch_dtype=torch.float16, device_map={"": 0}
+            cls.model_name, torch_dtype=torch.float16, device_map=cls.device_map_for_quantization
         )
         cls.mem_fp16 = cls.model_fp16.get_memory_footprint()
 
@@ -168,6 +168,13 @@ def test_serialization(self):
         self.check_inference_correctness(quantized_model_from_saved)
 
 
+class GPTQTestCPUInit(GPTQTest):
+    device_map_for_quantization = "cpu"
+
+    def test_generate_quality(self):
+        self.check_inference_correctness(self.quantized_model.to(0))
+
+
 class GPTQTestExllama(GPTQTest):
     disable_exllama = False
     exllama_config = {"version": 1}
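
The new GPTQTestCPUInit subclass reruns the whole GPTQTest suite with device_map_for_quantization = "cpu" and overrides only test_generate_quality, moving the quantized model to GPU 0 before checking generation. The same pattern in user code might look like this sketch, continuing from the example above (prompt and generation settings are illustrative):

# Mirrors GPTQTestCPUInit.test_generate_quality: quantize starting from a
# CPU-loaded model, then move the quantized model to a GPU for inference.
# Assumes `quantized_model` and `tokenizer` from the previous sketch and an
# available CUDA device.
quantized_model = quantized_model.to(0)

inputs = tokenizer("auto-gptq is", return_tensors="pt").to(0)
output_ids = quantized_model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))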

0 commit comments
