diff --git a/optimum_benchmark/backends/pytorch/backend.py b/optimum_benchmark/backends/pytorch/backend.py index 981e8baa..7dd2d7c8 100644 --- a/optimum_benchmark/backends/pytorch/backend.py +++ b/optimum_benchmark/backends/pytorch/backend.py @@ -278,13 +278,13 @@ def is_awq_quantized(self) -> bool: def is_exllamav2(self) -> bool: return (self.is_gptq_quantized or self.is_awq_quantized) and ( ( - getattr(self.pretrained_config, "quantization_config", None) is not None - and getattr(self.pretrained_config.quantization_config, "exllama_config", None) is not None + hasattr(self.pretrained_config, "quantization_config") + and hasattr(self.pretrained_config.quantization_config, "exllama_config") and self.pretrained_config.quantization_config.exllama_config.get("exllama_version", None) == 2 ) or ( - self.config.quantization_config.get("exllama_config", None) is not None - and self.config.quantization_config.exllama_config.get("exllama_version", None) == 2 + "exllama_config" in self.config.quantization_config + and self.config.quantization_config["exllama_config"].get("exllama_version", None) == 2 ) )