@@ -295,58 +295,3 @@ def compare_num_quantized_nodes_per_model(
         expected_num_weight_nodes.update({k: 0 for k in set(num_weight_nodes) - set(expected_num_weight_nodes)})
         actual_num_weights_per_model.append(num_weight_nodes)
     test_case.assertEqual(expected_num_weight_nodes_per_model, actual_num_weights_per_model)
-
-
-@contextmanager
-def mock_torch_cuda_is_available(to_patch):
-    original_is_available = torch.cuda.is_available
-    if to_patch:
-        torch.cuda.is_available = lambda: True
-    try:
-        yield
-    finally:
-        if to_patch:
-            torch.cuda.is_available = original_is_available
-
-
-@contextmanager
-def patch_awq_for_inference(to_patch):
-    orig_gemm_forward = None
-    if to_patch:
-        # patch GEMM module to allow inference without CUDA GPU
-        from awq.modules.linear.gemm import WQLinearMMFunction
-        from awq.utils.packing_utils import dequantize_gemm
-
-        def new_forward(
-            ctx,
-            x,
-            qweight,
-            qzeros,
-            scales,
-            w_bit=4,
-            group_size=128,
-            bias=None,
-            out_features=0,
-        ):
-            ctx.out_features = out_features
-
-            out_shape = x.shape[:-1] + (out_features,)
-            x = x.to(torch.float16)
-
-            out = dequantize_gemm(qweight, qzeros, scales, w_bit, group_size)
-            out = torch.matmul(x, out)
-
-            out = out + bias if bias is not None else out
-            out = out.reshape(out_shape)
-
-            if len(out.shape) == 2:
-                out = out.unsqueeze(0)
-            return out
-
-        orig_gemm_forward = WQLinearMMFunction.forward
-        WQLinearMMFunction.forward = new_forward
-    try:
-        yield
-    finally:
-        if orig_gemm_forward is not None:
-            WQLinearMMFunction.forward = orig_gemm_forward
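For reference, the two removed context managers were meant to be nested around AWQ model loading and inference so quantization tests could run on CPU-only machines. A minimal usage sketch under that assumption (`build_awq_model` and `example_input` are illustrative placeholders, not helpers from this repository):

    import torch

    def run_awq_model_on_cpu(build_awq_model, example_input):
        # Pretend CUDA is available so the AWQ loading path does not bail out on CPU-only hosts.
        with mock_torch_cuda_is_available(True):
            model = build_awq_model()  # placeholder for the actual model construction
        # Replace the CUDA GEMM kernel with the dequantize-and-matmul fallback defined above.
        with patch_awq_for_inference(True), torch.no_grad():
            return model(example_input)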