@@ -175,36 +175,37 @@ def forward_quant(self, *args, **kwargs):
175
175
176
176
@classmethod
177
177
def get_module_info (cls ) -> ModuleInfo :
178
- """Return the module info for the module.
178
+ """Only necessary for the newly registered patched module that doesn't in _mod_default_dict.
179
+ Return the module info for the module, which is used to determine the scaling methods for the module.
179
180
180
181
For example, for linear module, the module info is: ModuleInfo(type="linear", patched_module=cls).
181
182
"""
182
183
return ModuleInfo (type = cls .get_type (), patched_module = cls )
183
184
184
185
@classmethod
185
186
def get_type (cls ) -> str :
186
- """Return the type of the patched module.
187
+ """Only necessary for the newly registered patched module that doesn't in _mod_default_dict.
188
+ Return the type of the patched module, which is used to determine the scaling methods for the module.
187
189
188
190
Multiple patched modules can have the same type, and share the same scaling methods.
189
191
"""
190
192
raise NotImplementedError ("`get_type` is not implemented" )
191
193
192
194
@classmethod
193
195
def get_module_type (cls ) -> ModuleType :
194
- """Return the module type for the module.
196
+ """Only necessary for the newly registered patched module that doesn't in _mod_default_dict.
197
+ Return the module type for the module, which is used to determine the number of inputs, outputs, and parameters of the module.
195
198
196
199
The module type is used to determine the number of inputs, outputs, and parameters of the module.
197
200
For example, for linear module, the module type is: ModuleType(1, ["weight"], 1, False).
198
201
"""
199
202
raise NotImplementedError ("`get_module_type` is not implemented" )
200
203
201
204
def extra_repr (self ):
202
- try :
203
- return f"quantization_mode={ self .quantization_mode } , " + \
204
- f"module_info={ self .get_module_info ()} , " + \
205
- f"module_type={ self .get_module_type ()} "
206
- except NotImplementedError :
207
- return ""
205
+ """This extra_repr is only for the newly registered patched module that doesn't in _mod_default_dict."""
206
+ return f"quantization_mode={ self .quantization_mode } , " + \
207
+ f"module_info={ self .get_module_info ()} , " + \
208
+ f"module_type={ self .get_module_type ()} "
208
209
209
210
210
211
def _create_and_register_helper_module_class (name ):
0 commit comments