Skip to content

Commit 600436e

Browse files
authored
fix device check (#2136)
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
1 parent 53240c3 commit 600436e

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

optimum/gptq/quantizer.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -520,7 +520,7 @@ def quantize_model(self, model: nn.Module, tokenizer: Optional[Any] = None):
520520
blocks = recurse_getattr(model, self.block_name_to_quantize)
521521

522522
cur_layer_device = get_device(blocks[0])
523-
if not is_gptqmodel_available():
523+
if not is_gptqmodel_available() and cur_layer_device.type == "cpu":
524524
cur_layer_device = 0
525525

526526
if not has_device_map:
@@ -592,7 +592,7 @@ def store_input_hook(_, input, *args):
592592
block = block.to(0)
593593
layers = get_layers(block)
594594
block_device = get_device(block)
595-
if not is_gptqmodel_available():
595+
if not is_gptqmodel_available() and block_device.type == "cpu":
596596
block_device = 0
597597
if isinstance(self.modules_in_block_to_quantize, list) and len(self.modules_in_block_to_quantize) > 0:
598598
if self.true_sequential:

0 commit comments

Comments
 (0)