From 27d4efb59618d715a1e443f129ce75e2ea3d6a49 Mon Sep 17 00:00:00 2001
From: IlyasMoutawwakil
Date: Fri, 31 Jan 2025 11:55:10 +0100
Subject: [PATCH] fix

---
 optimum_benchmark/backends/pytorch/backend.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/optimum_benchmark/backends/pytorch/backend.py b/optimum_benchmark/backends/pytorch/backend.py
index 07e45980..bb7212f7 100644
--- a/optimum_benchmark/backends/pytorch/backend.py
+++ b/optimum_benchmark/backends/pytorch/backend.py
@@ -308,28 +308,28 @@ def process_quantization_config(self) -> None:
         self.logger.info("\t+ Processing AutoQuantization config")
         self.quantization_config = AutoQuantizationConfig.from_dict(
-            (getattr(self.pretrained_config, "quantization_config") or {}).update(self.config.quantization_config)
+            dict(**getattr(self.pretrained_config, "quantization_config", {}), **self.config.quantization_config)
         )

     @property
     def is_quantized(self) -> bool:
         return self.config.quantization_scheme is not None or (
             hasattr(self.pretrained_config, "quantization_config")
-            and self.pretrained_config.quantization_config.get("quant_method", None) is not None
+            and self.pretrained_config.quantization_config.get("quant_method") is not None
         )

     @property
     def is_gptq_quantized(self) -> bool:
         return self.config.quantization_scheme == "gptq" or (
             hasattr(self.pretrained_config, "quantization_config")
-            and self.pretrained_config.quantization_config.get("quant_method", None) == "gptq"
+            and self.pretrained_config.quantization_config.get("quant_method") == "gptq"
         )

     @property
     def is_awq_quantized(self) -> bool:
         return self.config.quantization_scheme == "awq" or (
             hasattr(self.pretrained_config, "quantization_config")
-            and self.pretrained_config.quantization_config.get("quant_method", None) == "awq"
+            and self.pretrained_config.quantization_config.get("quant_method") == "awq"
         )

     @property
@@ -341,11 +341,11 @@ def is_exllamav2(self) -> bool:
                 (
                     hasattr(self.pretrained_config, "quantization_config")
                     and hasattr(self.pretrained_config.quantization_config, "exllama_config")
-                    and self.pretrained_config.quantization_config.exllama_config.get("version", None) == 2
+                    and self.pretrained_config.quantization_config.exllama_config.get("version") == 2
                 )
                 or (
                     "exllama_config" in self.config.quantization_config
-                    and self.config.quantization_config["exllama_config"].get("version", None) == 2
+                    and self.config.quantization_config["exllama_config"].get("version") == 2
                 )
             )
         )
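
Note on the change (explanatory, not part of the patch itself): the first hunk fixes two
problems in the expression handed to AutoQuantizationConfig.from_dict(). dict.update()
mutates its receiver and returns None, so the old expression always passed None; and
getattr() without a default raises AttributeError when the pretrained config has no
quantization_config attribute. The remaining hunks only drop the explicit None default
from dict.get(), which is behavior-preserving since None is already the default. Below is
a minimal standalone sketch of the merge fix, using plain illustrative dicts rather than
the real optimum-benchmark config objects:

    # Stand-ins for the pretrained config's quantization dict and the
    # user-provided benchmark overrides (illustrative values only).
    pretrained_quant = {"quant_method": "gptq", "bits": 4}
    benchmark_quant = {"exllama_config": {"version": 2}}

    # Old expression: dict.update() returns None, so from_dict() received None.
    broken = dict(pretrained_quant).update(benchmark_quant)
    print(broken)  # None

    # New expression: build a fresh dict from both sources.
    merged = dict(**pretrained_quant, **benchmark_quant)
    print(merged)  # {'quant_method': 'gptq', 'bits': 4, 'exllama_config': {'version': 2}}

One caveat worth keeping in mind: dict(**a, **b) raises TypeError when the two dicts
share a key, so this merge assumes the benchmark overrides do not duplicate keys already
present in the pretrained quantization config.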