@@ -53,6 +53,17 @@ def forward(self, x):


 ##################### FP8 modules #######################
+def _map_guadi2_scale(scale):
+    USE_GUADI2_SCALE = os.environ.get("USE_GUADI2_SCALE")
+    if USE_GUADI2_SCALE:
+        scale_list = torch.tensor([16, 1, 1 / 16, 1 / 256])
+        for i in scale_list:
+            if scale > i or i == torch.tensor(1 / 256):
+                return i
+    else:
+        return scale
+
+
 class FP8DynamicLinear(torch.nn.Module):
     def __init__(self, org_module, dtype=torch.float8_e4m3fn) -> None:
         super().__init__()
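The new helper snaps an amax-derived scale onto the fixed grid [16, 1, 1/16, 1/256] whenever the `USE_GUADI2_SCALE` environment variable is set, returning the first grid value the scale exceeds and bottoming out at 1/256; with the variable unset, the scale passes through unchanged. A minimal sketch of the resulting behavior (the sample values are illustrative, not from this commit):

```python
import os
import torch

os.environ["USE_GUADI2_SCALE"] = "1"  # any non-empty string enables the mapping

for s in (20.0, 0.5, 0.001):
    mapped = _map_guadi2_scale(torch.tensor(s))
    print(s, "->", float(mapped))
# 20.0  -> 16.0        (first grid entry the scale exceeds)
# 0.5   -> 0.0625      (1/16)
# 0.001 -> 0.00390625  (1/256 acts as the floor)
```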
@@ -86,6 +97,7 @@ def __init__(self, org_module, dtype=torch.float8_e4m3fn) -> None:
         # scale = HF_max / amax
         if self.use_amax:
             self.weight_scale = self.dtype_amax / org_module.weight.data.abs().max()
+            self.weight_scale = _map_guadi2_scale(self.weight_scale)
             self.weight_scale_inv = torch.reciprocal(self.weight_scale)
         else:
             self.weight_scale = None
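For context, `weight_scale` here is the usual amax-based factor, dtype_amax / max|W|, which the inserted line then snaps to the Gaudi2 grid. A worked example, assuming E4M3_AMAX is 448 (the torch.float8_e4m3fn maximum) and an illustrative weight amax of 3.5:

```python
import torch

dtype_amax = torch.tensor(448.0)          # assumed E4M3_AMAX for float8_e4m3fn
weight_amax = torch.tensor(3.5)           # illustrative |W|.max()
weight_scale = dtype_amax / weight_amax   # 128.0
# With USE_GUADI2_SCALE set, _map_guadi2_scale(weight_scale) returns 16,
# the first entry of [16, 1, 1/16, 1/256] that 128.0 exceeds.
```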
@@ -233,9 +245,9 @@ def __init__(self, org_module, dtype) -> None:
                 dtype=torch.float32,
             ),
         )
-        self.scale_inv = torch.reciprocal(self.scale)

         self.weight_scale = self.dtype_amax / org_module.weight.data.abs().max()
+        self.weight_scale = _map_guadi2_scale(self.weight_scale)
         self.weight_scale_inv = torch.reciprocal(self.weight_scale)
         self.weight.data.copy_(
             torch.ops.hpu.cast_to_fp8_v2(org_module.weight.data, self.weight_scale, False, False, self.dtype)[0]
@@ -251,6 +263,7 @@ def forward(self, inp):
         org_middle_shape = inp.shape[1:-1]
         inp = inp.view((-1, self.in_features))
         inp = torch.ops.hpu.cast_to_fp8_v2(inp, self.scale, False, False, self.dtype)[0]
+        self.scale_inv = torch.reciprocal(self.scale)
         out = torch.ops.hpu.fp8_gemm_v2(
             inp,
             False,
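Moving the reciprocal out of `__init__` and into `forward` means `scale_inv` is re-derived from whatever value the `scale` buffer holds at call time, so a calibration pass that rewrites the buffer after construction can no longer leave a stale cached inverse. A small sketch of the hazard the change appears to avoid (the class and values are hypothetical):

```python
import torch

class Cached(torch.nn.Module):  # hypothetical stand-in for the old behavior
    def __init__(self):
        super().__init__()
        self.register_buffer("scale", torch.tensor(2.0))
        self.scale_inv = torch.reciprocal(self.scale)  # frozen at build time

m = Cached()
m.scale.fill_(8.0)                       # e.g. a measurement pass updates the scale
print(float(m.scale_inv))                # still 0.5 -- stale
print(float(torch.reciprocal(m.scale)))  # 0.125 -- what forward() now computes
```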
@@ -283,26 +296,24 @@ def __init__(self, org_module, dtype) -> None:
         self.dtype = dtype
         self.dtype_amax = E4M3_AMAX if self.dtype == torch.float8_e4m3fn else E5M2_AMAX
         self.out_dtype = torch.float32
-        scale = org_module.scale if hasattr(org_module, "scale") else 1.0
         scale1 = org_module.scale1 if hasattr(org_module, "scale1") else 1.0
+        scale2 = org_module.scale2 if hasattr(org_module, "scale2") else 1.0
         self.register_buffer(
-            "scale",
+            "scale1",
             torch.tensor(
-                scale,
+                scale1,
                 device="hpu",
                 dtype=self.out_dtype,
             ),
         )
         self.register_buffer(
-            "scale1",
+            "scale2",
             torch.tensor(
-                scale1,
+                scale2,
                 device="hpu",
                 dtype=self.out_dtype,
             ),
         )
-        self.input1_scale_inv = torch.reciprocal(self.scale)
-        self.input2_scale_inv = torch.reciprocal(self.scale1)

     def forward(self, input1, input2):
         dim1 = input1.shape[-1]
@@ -311,12 +322,14 @@ def forward(self, input1, input2):

         if input1.dtype not in [torch.float8_e4m3fn, torch.float8_e5m2]:
             self.out_dtype = input1.dtype
-            input1 = torch.ops.hpu.cast_to_fp8_v2(input1, self.scale, False, False, self.dtype)[0]
+            input1 = torch.ops.hpu.cast_to_fp8_v2(input1, self.scale1, False, False, self.dtype)[0]
+            self.input1_scale_inv = torch.reciprocal(self.scale1)
         else:
             self.input1_scale_inv = None
         if input2.dtype not in [torch.float8_e4m3fn, torch.float8_e5m2]:
             self.out_dtype = input2.dtype
-            input2 = torch.ops.hpu.cast_to_fp8_v2(input2, self.scale1, False, False, self.dtype)[0]
+            input2 = torch.ops.hpu.cast_to_fp8_v2(input2, self.scale2, False, False, self.dtype)[0]
+            self.input2_scale_inv = torch.reciprocal(self.scale2)
         else:
             self.input2_scale_inv = None
         out = torch.ops.hpu.fp8_gemm_v2(
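With the rename, each operand of the FP8 matmul gets its own scale (`scale1` for `input1`, `scale2` for `input2`), and the matching `input*_scale_inv` is only materialized when that operand actually needs the FP8 cast. Roughly, per-operand amax scaling looks like the sketch below (the shapes, the 448 amax, and the variable names are assumptions; the hpu ops are left as comments since they require Gaudi hardware):

```python
import torch

a = torch.randn(4, 8)             # input1 in high precision
b = torch.randn(8, 3)             # input2 in high precision
scale1 = 448.0 / a.abs().max()    # independent amax-based scale per operand
scale2 = 448.0 / b.abs().max()
# On HPU the forward above would cast each operand with its own scale:
#   a8 = torch.ops.hpu.cast_to_fp8_v2(a, scale1, False, False, dtype)[0]
#   b8 = torch.ops.hpu.cast_to_fp8_v2(b, scale2, False, False, dtype)[0]
# and hand torch.reciprocal(scale1) / torch.reciprocal(scale2) to
# fp8_gemm_v2 so the product is descaled back.
```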
@@ -407,10 +420,10 @@ def __init__(self, org_module, dtype) -> None:
                 dtype=torch.float32,
             ),
         )
-        self.scale_inv = 1.0 / self.scale
         # user configuration
         # scale = HF_max / amax
         self.weight_scale = self.dtype_amax / org_module.weight.data.abs().max()
+        self.weight_scale = _map_guadi2_scale(self.weight_scale)
         self.weight_scale_inv = 1.0 / self.weight_scale
         self.weight = torch.ops.hpu.cast_to_fp8_v2(org_module.weight.data, self.weight_scale, False, False, self.dtype)[
             0
@@ -432,6 +445,7 @@ def forward(self, inp):
         assert inp.shape[-1] == self.in_features, "GEMM not possible"
         inputmat = inp.view((-1, self.in_features))
         inputmat = torch.ops.hpu.cast_to_fp8_v2(inputmat, self.scale, False, False, self.dtype)[0]
+        self.scale_inv = torch.reciprocal(self.scale)
         out = torch.ops.hpu.fp8_gemm_v2(
             inputmat,
             False,
@@ -487,10 +501,10 @@ def __init__(self, org_module, dtype) -> None:
                 dtype=torch.float32,
             ),
         )
-        self.scale_inv = 1.0 / self.scale
         # user configuration
         # scale = HF_max / amax
         self.weight_scale = self.dtype_amax / org_module.weight.data.abs().max()
+        self.weight_scale = _map_guadi2_scale(self.weight_scale)
         self.weight_scale_inv = 1.0 / self.weight_scale
         self.weight = torch.ops.hpu.cast_to_fp8_v2(org_module.weight.data, self.weight_scale, False, False, self.dtype)[
             0
@@ -513,6 +527,7 @@ def forward(self, inp):
         assert inp.shape[-1] == self.in_features, "GEMM not possible"
         inputmat = inp.view((-1, self.in_features))
         inputmat = torch.ops.hpu.cast_to_fp8_v2(inputmat, self.scale, False, False, self.dtype)[0]
+        self.scale_inv = torch.reciprocal(self.scale)
         out = torch.ops.hpu.fp8_gemm_v2(
             inputmat,
             False,
@@ -572,10 +587,10 @@ def __init__(self, org_module, dtype) -> None:
                 dtype=torch.float32,
             ),
         )
-        self.scale_inv = 1.0 / self.scale
         # user configuration
         # scale = HF_max / amax
         self.weight_scale = self.dtype_amax / org_module.weight.data.abs().max()
+        self.weight_scale = _map_guadi2_scale(self.weight_scale)
         self.weight_scale_inv = 1.0 / self.weight_scale
         self.weight = torch.ops.hpu.cast_to_fp8_v2(org_module.weight.data, self.weight_scale, False, False, self.dtype)[
             0
@@ -608,6 +623,7 @@ def forward(self, inp):
         input_shard = inp.shape[-1] // self.world_size
         inputmat = inp[:, :, self.rank * input_shard : (self.rank + 1) * input_shard]
         inputmat = torch.ops.hpu.cast_to_fp8_v2(inputmat, self.scale, False, False, self.dtype)[0]
+        self.scale_inv = torch.reciprocal(self.scale)
         out = torch.ops.hpu.fp8_gemm_v2(
             inputmat,
             False,
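The tensor-parallel forward above slices this rank's contiguous shard of the hidden dimension before casting; the arithmetic is just an even split, as in this sketch (the sizes are made up):

```python
import torch

world_size, rank = 4, 1
inp = torch.randn(2, 5, 16)                  # (batch, seq, hidden)
input_shard = inp.shape[-1] // world_size    # 16 // 4 = 4 columns per rank
inputmat = inp[:, :, rank * input_shard : (rank + 1) * input_shard]
assert inputmat.shape == (2, 5, 4)           # rank 1 owns columns 4..7
```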