diff --git a/optimum_benchmark/backends/pytorch/backend.py b/optimum_benchmark/backends/pytorch/backend.py index 07e45980..bb7212f7 100644 --- a/optimum_benchmark/backends/pytorch/backend.py +++ b/optimum_benchmark/backends/pytorch/backend.py @@ -308,28 +308,28 @@ def process_quantization_config(self) -> None: self.logger.info("\t+ Processing AutoQuantization config") self.quantization_config = AutoQuantizationConfig.from_dict( - (getattr(self.pretrained_config, "quantization_config") or {}).update(self.config.quantization_config) + dict(**getattr(self.pretrained_config, "quantization_config", {}), **self.config.quantization_config) ) @property def is_quantized(self) -> bool: return self.config.quantization_scheme is not None or ( hasattr(self.pretrained_config, "quantization_config") - and self.pretrained_config.quantization_config.get("quant_method", None) is not None + and self.pretrained_config.quantization_config.get("quant_method") is not None ) @property def is_gptq_quantized(self) -> bool: return self.config.quantization_scheme == "gptq" or ( hasattr(self.pretrained_config, "quantization_config") - and self.pretrained_config.quantization_config.get("quant_method", None) == "gptq" + and self.pretrained_config.quantization_config.get("quant_method") == "gptq" ) @property def is_awq_quantized(self) -> bool: return self.config.quantization_scheme == "awq" or ( hasattr(self.pretrained_config, "quantization_config") - and self.pretrained_config.quantization_config.get("quant_method", None) == "awq" + and self.pretrained_config.quantization_config.get("quant_method") == "awq" ) @property @@ -341,11 +341,11 @@ def is_exllamav2(self) -> bool: ( hasattr(self.pretrained_config, "quantization_config") and hasattr(self.pretrained_config.quantization_config, "exllama_config") - and self.pretrained_config.quantization_config.exllama_config.get("version", None) == 2 + and self.pretrained_config.quantization_config.exllama_config.get("version") == 2 ) or ( "exllama_config" in self.config.quantization_config - and self.config.quantization_config["exllama_config"].get("version", None) == 2 + and self.config.quantization_config["exllama_config"].get("version") == 2 ) ) )