Commit 5d7667a

fix a bug after transformer upgrade
Signed-off-by: Xin He <xinhe3@habana.ai>
1 parent 2b220bb commit 5d7667a

File tree

1 file changed (+2, -1 lines)

  • neural_compressor/torch/algorithms/weight_only


neural_compressor/torch/algorithms/weight_only/gptq.py (+2, -1)

@@ -391,7 +391,8 @@ def forward(layer, *args, **kwargs):
         for arg in kwargs:
             # TODO: investigate include parameters
             # each outputs can be different shape, hence also use list to store
-            if isinstance(kwargs[arg], torch.Tensor) or arg == "alibi":
+            # position_embeddings is a list of two tensors in llama, handle it specially.
+            if isinstance(kwargs[arg], torch.Tensor) or arg in ["alibi", "position_embeddings"]:
                 if self.cache_key_arguments.get(arg, None) is None:
                     self.cache_key_arguments[arg] = []
                 self.cache_key_arguments[arg].append(kwargs[arg])
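To illustrate the fix: the patched hook caches forward keyword arguments during GPTQ calibration, but the `isinstance(..., torch.Tensor)` check misses kwargs whose values are containers of tensors. After the transformers upgrade, llama models pass `position_embeddings` as a pair of tensors (cos, sin), so it must be matched by name, like `alibi`. The sketch below is a hypothetical standalone version of this logic (names such as `cache_kwargs` and the module-level `cache_key_arguments` dict are illustrative; in the real code this lives inside a hooked `forward` method on `self`):

```python
import torch

# Illustrative stand-in for self.cache_key_arguments in the real hook.
cache_key_arguments = {}

def cache_kwargs(**kwargs):
    """Cache tensor-valued kwargs for later replay during calibration."""
    for arg in kwargs:
        # position_embeddings is a list/tuple of two tensors in llama, so the
        # isinstance check alone would skip it; match it (and alibi) by name.
        if isinstance(kwargs[arg], torch.Tensor) or arg in ["alibi", "position_embeddings"]:
            if cache_key_arguments.get(arg, None) is None:
                cache_key_arguments[arg] = []
            cache_key_arguments[arg].append(kwargs[arg])

# Example call shaped like a llama decoder-layer invocation (shapes are made up).
cos = torch.randn(1, 4, 8)
sin = torch.randn(1, 4, 8)
cache_kwargs(
    hidden_states=torch.randn(1, 4, 8),   # plain tensor: cached via isinstance
    position_embeddings=(cos, sin),        # tuple of tensors: cached by name
    use_cache=True,                        # non-tensor, not special-cased: skipped
)
```

Each cached kwarg accumulates one entry per forward call, since outputs can differ in shape across calls and are therefore stored in a list.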
