From 6e83384225b5b015bb6c8c23fad0c98dc8bedfe4 Mon Sep 17 00:00:00 2001 From: Ilyas Moutawwakil <57442720+IlyasMoutawwakil@users.noreply.github.com> Date: Wed, 13 Mar 2024 13:45:20 +0100 Subject: [PATCH] Fix gptq exllamav2 check (#157) --- optimum_benchmark/backends/pytorch/backend.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/optimum_benchmark/backends/pytorch/backend.py b/optimum_benchmark/backends/pytorch/backend.py index 981e8baa..7dd2d7c8 100644 --- a/optimum_benchmark/backends/pytorch/backend.py +++ b/optimum_benchmark/backends/pytorch/backend.py @@ -278,13 +278,13 @@ def is_awq_quantized(self) -> bool: def is_exllamav2(self) -> bool: return (self.is_gptq_quantized or self.is_awq_quantized) and ( ( - getattr(self.pretrained_config, "quantization_config", None) is not None - and getattr(self.pretrained_config.quantization_config, "exllama_config", None) is not None + hasattr(self.pretrained_config, "quantization_config") + and hasattr(self.pretrained_config.quantization_config, "exllama_config") and self.pretrained_config.quantization_config.exllama_config.get("exllama_version", None) == 2 ) or ( - self.config.quantization_config.get("exllama_config", None) is not None - and self.config.quantization_config.exllama_config.get("exllama_version", None) == 2 + "exllama_config" in self.config.quantization_config + and self.config.quantization_config["exllama_config"].get("exllama_version", None) == 2 ) )