
Commit 0978e15

update with extra_repr
Signed-off-by: Xin He <xinhe3@habana.ai>
1 parent e1eb22d commit 0978e15

File tree: 2 files changed (+32 −9 lines)


neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py

+22
@@ -591,6 +591,13 @@ def __init__(self, mod, parent, mod_extra_config, *args, **kwargs):
         setattr(self, "w2_weight", None)
         self.forward = self.forward_orig
 
+    def extra_repr(self) -> str:
+        return extra_representation(
+            self.extra_repr_org(),
+            self.class_name_org,
+            get_current_repr(self),
+        )
+
 
 # This patched module is called by the vllm-mixtral FusedMoE layer
 # we wrap each expert weight with this module since FusedMoE has a single tensor for all experts weights
@@ -853,6 +860,13 @@ def update_measure(self, prev, cur, dim, idx, inp_seq_len):
         measure_output((output,), self._mod_extra_config.outputs)
         return output
 
+    def extra_repr(self) -> str:
+        return extra_representation(
+            self.extra_repr_org(),
+            self.class_name_org,
+            get_current_repr(self),
+        )
+
 
 class PatchedVLLMKVCache(PatchedModuleBase):
     # Module to patch VLLMKVCache module from llama model
@@ -891,6 +905,14 @@ def fetch_from_cache(self, cache, blocks, permutations=None):
         output_cache = self.orig_fetch_from_cache(quant_cache, blocks)
         return self.dequant_output(output_cache)
 
+    def extra_repr(self) -> str:
+        return extra_representation(
+            self.extra_repr_org(),
+            self.class_name_org,
+            get_current_repr(self),
+        )
+
+
 def init_conv(instance, mod_extra_config):
     if instance.quantization_mode in [QuantMode.QUANTIZE, QuantMode.LOAD]:
         instance.quant_input = instance._mod_extra_config.inputs[0]
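
All three patched classes above gain the same method: they delegate to the repo's extra_representation helper, passing the original module's extra_repr text, the original class name, and get_current_repr(self), presumably so the quantization details show up when the patched model is printed. For context, these patched modules are torch.nn.Module subclasses, and PyTorch splices whatever extra_repr() returns into the module's printed representation. A minimal standalone sketch of that mechanism only (DemoModule and its scale attribute are made up for illustration, not taken from this commit):

import torch.nn as nn

class DemoModule(nn.Module):
    """Hypothetical module showing how extra_repr() feeds print(module)."""

    def __init__(self, scale: float = 2.0):
        super().__init__()
        self.scale = scale  # made-up attribute, stands in for quantization details

    def extra_repr(self) -> str:
        # nn.Module inserts this string inside the parentheses of repr(self).
        return f"scale={self.scale}"

print(DemoModule())  # prints: DemoModule(scale=2.0)

The extra_repr additions in this commit use the same hook, assembling their string with the repo's helpers instead of a literal f-string.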

neural_compressor/torch/algorithms/fp8_quant/patched_module_base.py

+10 −9
@@ -175,36 +175,37 @@ def forward_quant(self, *args, **kwargs):
 
     @classmethod
     def get_module_info(cls) -> ModuleInfo:
-        """Return the module info for the module.
+        """Only necessary for a newly registered patched module that isn't in _mod_default_dict.
+        Return the module info for the module, which is used to determine the scaling methods for the module.
 
         For example, for linear module, the module info is: ModuleInfo(type="linear", patched_module=cls).
         """
         return ModuleInfo(type=cls.get_type(), patched_module=cls)
 
     @classmethod
     def get_type(cls) -> str:
-        """Return the type of the patched module.
+        """Only necessary for a newly registered patched module that isn't in _mod_default_dict.
+        Return the type of the patched module, which is used to determine the scaling methods for the module.
 
         Multiple patched modules can have the same type, and share the same scaling methods.
         """
         raise NotImplementedError("`get_type` is not implemented")
 
     @classmethod
     def get_module_type(cls) -> ModuleType:
-        """Return the module type for the module.
+        """Only necessary for a newly registered patched module that isn't in _mod_default_dict.
+        Return the module type for the module, which is used to determine the number of inputs, outputs, and parameters of the module.
 
         The module type is used to determine the number of inputs, outputs, and parameters of the module.
         For example, for linear module, the module type is: ModuleType(1, ["weight"], 1, False).
         """
         raise NotImplementedError("`get_module_type` is not implemented")
 
     def extra_repr(self):
-        try:
-            return f"quantization_mode={self.quantization_mode}, " + \
-                   f"module_info={self.get_module_info()}, " + \
-                   f"module_type={self.get_module_type()}"
-        except NotImplementedError:
-            return ""
+        """This extra_repr is only for a newly registered patched module that isn't in _mod_default_dict."""
+        return f"quantization_mode={self.quantization_mode}, " + \
+               f"module_info={self.get_module_info()}, " + \
+               f"module_type={self.get_module_type()}"
 
 
 def _create_and_register_helper_module_class(name):
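
To make the updated docstrings concrete: a newly registered patched module, one that is not in _mod_default_dict, would override get_type() and get_module_type(); get_module_info() and the default extra_repr() above are then derived from them. A hedged sketch reusing the values given in the docstring examples (the class name PatchedMyLinear is hypothetical, not part of the repo):

class PatchedMyLinear(PatchedModuleBase):
    @classmethod
    def get_type(cls) -> str:
        # Patched modules that return the same type string share scaling methods.
        return "linear"

    @classmethod
    def get_module_type(cls) -> ModuleType:
        # Values from the docstring example: 1 input, a "weight" parameter, 1 output.
        return ModuleType(1, ["weight"], 1, False)

# get_module_info() falls back to ModuleInfo(type=cls.get_type(), patched_module=cls),
# so extra_repr() can report quantization_mode, module_info, and module_type
# without raising NotImplementedError.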

0 commit comments
