Commit 9bddd52

fix device mapping issue of llama gptq (#2101)
Signed-off-by: Xin He <xinhe3@habana.ai> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent: b2d019f

File tree

1 file changed, 2 insertions(+), 1 deletion(-)

neural_compressor/torch/algorithms/weight_only/gptq.py

+2 -1
@@ -116,7 +116,8 @@ def trace_gptq_target_blocks(module, module_types=[torch.nn.ModuleList, torch.nn
             gptq_related_blocks["transformers"] = m
             find_transformers = True
             # return gptq_related_blocks
-        elif is_leaf(m) and not find_transformers:
+        elif (is_leaf(m) and not find_transformers) or "Embedding" in type(m).__name__:
+            # "Embedding" in type(m).__name__ to resolve 'LlamaRotaryEmbedding'
             gptq_related_blocks["embeddings"][n] = m
         elif n.find(gptq_related_blocks["transformers_name"]) == -1 and find_transformers:
             # no longer belong to transformers
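For context, the added clause classifies any module whose class name contains "Embedding" as an embedding, which is how it catches LlamaRotaryEmbedding even though that module is not a plain torch.nn.Embedding leaf. The sketch below illustrates just that name-based check in isolation; the LlamaRotaryEmbedding class here is a hypothetical stand-in for the real transformers module, not its actual implementation.

import torch


class LlamaRotaryEmbedding(torch.nn.Module):
    """Hypothetical stand-in for the rotary-embedding module in transformers."""

    def __init__(self, dim: int = 64):
        super().__init__()
        # Rotary embeddings typically keep a non-parameter buffer rather than weights.
        self.register_buffer("inv_freq", torch.ones(dim // 2))


def looks_like_embedding(m: torch.nn.Module) -> bool:
    # Mirrors the patched condition: match on the class name,
    # not on isinstance(m, torch.nn.Embedding).
    return "Embedding" in type(m).__name__


print(looks_like_embedding(LlamaRotaryEmbedding()))     # True
print(looks_like_embedding(torch.nn.Embedding(10, 8)))  # True
print(looks_like_embedding(torch.nn.Linear(8, 8)))      # False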
