@@ -283,6 +283,7 @@ def get_fq_insertion_command(
283
283
orig_weight_shape : Tuple [int , ...],
284
284
compression_format : CompressionFormat ,
285
285
lora_adapter_rank : int ,
286
+ is_all_8bit : bool ,
286
287
) -> PTTransformationCommand :
287
288
"""
288
289
Creates a fake quantization insertion command for the given compressed weight.
@@ -291,9 +292,11 @@ def get_fq_insertion_command(
291
292
:param wc_params: Parameters for weight compression.
292
293
:param orig_weight_shape: The original shape of the weight tensor.
293
294
:param compression_format: The format of compression.
295
+ :param is_all_8bit: True if every weight is compressed to 8 bits; in that case the FQ_LORA format uses the LoRA quantization schemes for the INT8 modes as well.
294
296
:return: A PTTransformationCommand for inserting fake quantization to the model.
295
297
"""
296
298
compression_config = wc_params .compression_config
299
+ # Default mapping for 4-bit weight compression with the FQ_LORA format; by default, LoRA adapters are not added for 8-bit weights.
297
300
mode_vs_schema_map = {
298
301
CompressWeightsMode .INT4_ASYM : QuantizationScheme .ASYMMETRIC_LORA ,
299
302
CompressWeightsMode .INT4_SYM : QuantizationScheme .SYMMETRIC_LORA ,
@@ -303,6 +306,9 @@ def get_fq_insertion_command(
303
306
if compression_format == CompressionFormat .FQ :
304
307
mode_vs_schema_map [CompressWeightsMode .INT4_ASYM ] = QuantizationScheme .ASYMMETRIC
305
308
mode_vs_schema_map [CompressWeightsMode .INT4_SYM ] = QuantizationScheme .SYMMETRIC
309
+ if is_all_8bit and compression_format == CompressionFormat .FQ_LORA :
310
+ mode_vs_schema_map [CompressWeightsMode .INT8_ASYM ] = QuantizationScheme .ASYMMETRIC_LORA
311
+ mode_vs_schema_map [CompressWeightsMode .INT8_SYM ] = QuantizationScheme .SYMMETRIC_LORA
306
312
307
313
schema = mode_vs_schema_map [compression_config .mode ]
308
314
@@ -469,6 +475,7 @@ def transform_model(
469
475
model_transformer = PTModelTransformer (model )
470
476
471
477
transformation_layout = TransformationLayout ()
478
+ is_all_8bit = all (wc_params .compression_config .num_bits == 8 for wc_params in weight_compression_parameters )
472
479
for wc_params in weight_compression_parameters :
473
480
compression_config = wc_params .compression_config
474
481
if compression_config .mode in [
@@ -499,7 +506,7 @@ def transform_model(
499
506
else :
500
507
rank = advanced_parameters .lora_adapter_rank
501
508
command = self .get_fq_insertion_command (
502
- compressed_weight , wc_params , weight .shape , compression_format , rank
509
+ compressed_weight , wc_params , weight .shape , compression_format , rank , is_all_8bit
503
510
)
504
511
transformation_layout .register (command )
505
512
0 commit comments