Skip to content

Commit 35cb79d

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 483c219 commit 35cb79d

File tree

3 files changed

+46
-45
lines changed

3 files changed

+46
-45
lines changed

neural_compressor/torch/algorithms/weight_only/gptq.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -581,7 +581,7 @@ def tmp(_, inp, out):
581581
state_dict = torch.load(LWQ_WORKSPACE + f"/{self.get_full_layer_name(layer_name, block_idx)}.pt")
582582
Q = state_dict["weight"].data
583583
bias = state_dict["bias"] if "bias" in state_dict.keys() else None
584-
584+
585585
else:
586586
Q = sub_layers[layer_name].weight.data
587587
if weight_config_this_layer["act_order"]:
@@ -614,7 +614,7 @@ def tmp(_, inp, out):
614614

615615
if not self.use_layer_wise:
616616
bias = sub_layers[layer_name].bias
617-
617+
618618
new_module = WeightOnlyLinear(
619619
in_features,
620620
out_features,

neural_compressor/torch/algorithms/weight_only/modules.py

+39-38
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,11 @@
1919
# since the model classes inherit torch.nn.Module.
2020
import math
2121

22+
import numba
2223
import numpy as np
2324
import torch
2425
from torch.autograd import Function
2526
from torch.nn import functional as F
26-
import numba
2727

2828
from neural_compressor.torch.utils import accelerator, logger
2929

@@ -301,11 +301,11 @@ def unpack_tensor_with_torch(self, packed_tensor):
301301
unpacked_tensor[:, index].copy_(tmp.type(target_dtype))
302302
accelerator.synchronize()
303303
return unpacked_tensor
304-
304+
305305
@staticmethod
306306
@numba.jit(nopython=True, parallel=True)
307307
def pack_array_with_numba_b4_c32(
308-
raw_array: np.ndarray, packed_array:np.ndarray, n_pack: int, new_in_features:int
308+
raw_array: np.ndarray, packed_array: np.ndarray, n_pack: int, new_in_features: int
309309
) -> np.ndarray:
310310
for i in range(new_in_features):
311311
packed_array[:, i] = (
@@ -319,11 +319,11 @@ def pack_array_with_numba_b4_c32(
319319
| (raw_array[:, i * n_pack] & 0b1111)
320320
)
321321
return packed_array
322-
322+
323323
@staticmethod
324324
@numba.jit(nopython=True, parallel=True)
325325
def pack_array_with_numba_b4_c16(
326-
raw_array: np.ndarray, packed_array:np.ndarray, n_pack: int, new_in_features:int
326+
raw_array: np.ndarray, packed_array: np.ndarray, n_pack: int, new_in_features: int
327327
) -> np.ndarray:
328328
for i in range(new_in_features):
329329
packed_array[:, i] = (
@@ -333,23 +333,20 @@ def pack_array_with_numba_b4_c16(
333333
| (raw_array[:, i * n_pack] & 0b1111)
334334
)
335335
return packed_array
336-
336+
337337
@staticmethod
338338
@numba.jit(nopython=True, parallel=True)
339339
def pack_array_with_numba_b4_c8(
340-
raw_array: np.ndarray, packed_array:np.ndarray, n_pack: int, new_in_features:int
340+
raw_array: np.ndarray, packed_array: np.ndarray, n_pack: int, new_in_features: int
341341
) -> np.ndarray:
342342
for i in range(new_in_features):
343-
packed_array[:, i] = (
344-
((raw_array[:, i * n_pack + 1] & 0b1111) << 4)
345-
| (raw_array[:, i * n_pack] & 0b1111)
346-
)
343+
packed_array[:, i] = ((raw_array[:, i * n_pack + 1] & 0b1111) << 4) | (raw_array[:, i * n_pack] & 0b1111)
347344
return packed_array
348-
345+
349346
@staticmethod
350347
@numba.jit(nopython=True, parallel=True)
351348
def pack_array_with_numba_b4_c64(
352-
raw_array: np.ndarray, packed_array:np.ndarray, n_pack: int, new_in_features:int
349+
raw_array: np.ndarray, packed_array: np.ndarray, n_pack: int, new_in_features: int
353350
) -> np.ndarray:
354351
for i in range(new_in_features):
355352
packed_array[:, i] = (
@@ -372,11 +369,10 @@ def pack_array_with_numba_b4_c64(
372369
)
373370
return packed_array
374371

375-
376372
@staticmethod
377373
@numba.jit(nopython=True, parallel=True)
378374
def pack_array_with_numba_b8_c32(
379-
raw_array: np.ndarray, packed_array:np.ndarray, n_pack: int, new_in_features:int
375+
raw_array: np.ndarray, packed_array: np.ndarray, n_pack: int, new_in_features: int
380376
) -> np.ndarray:
381377
for i in range(new_in_features):
382378
packed_array[:, i] = (
@@ -386,11 +382,11 @@ def pack_array_with_numba_b8_c32(
386382
| (raw_array[:, i * n_pack] & 0b11111111)
387383
)
388384
return packed_array
389-
385+
390386
@staticmethod
391387
@numba.jit(nopython=True, parallel=True)
392388
def pack_array_with_numba_b8_c16(
393-
raw_array: np.ndarray, packed_array:np.ndarray, n_pack: int, new_in_features:int
389+
raw_array: np.ndarray, packed_array: np.ndarray, n_pack: int, new_in_features: int
394390
) -> np.ndarray:
395391
for i in range(new_in_features):
396392
packed_array[:, i] = (
@@ -400,20 +396,20 @@ def pack_array_with_numba_b8_c16(
400396
| (raw_array[:, i * n_pack] & 0b11111111)
401397
)
402398
return packed_array
403-
399+
404400
@staticmethod
405401
@numba.jit(nopython=True, parallel=True)
406402
def pack_array_with_numba_b8_c8(
407-
raw_array: np.ndarray, packed_array:np.ndarray, n_pack: int, new_in_features:int
403+
raw_array: np.ndarray, packed_array: np.ndarray, n_pack: int, new_in_features: int
408404
) -> np.ndarray:
409405
for i in range(new_in_features):
410-
packed_array[:, i] = (raw_array[:, i * n_pack] & 0b11111111)
406+
packed_array[:, i] = raw_array[:, i * n_pack] & 0b11111111
411407
return packed_array
412-
408+
413409
@staticmethod
414410
@numba.jit(nopython=True, parallel=True)
415411
def pack_array_with_numba_b8_c64(
416-
raw_array: np.ndarray, packed_array:np.ndarray, n_pack: int, new_in_features:int
412+
raw_array: np.ndarray, packed_array: np.ndarray, n_pack: int, new_in_features: int
417413
) -> np.ndarray:
418414
for i in range(new_in_features):
419415
packed_array[:, i] = (
@@ -427,11 +423,11 @@ def pack_array_with_numba_b8_c64(
427423
| (raw_array[:, i * n_pack] & 0b11111111)
428424
)
429425
return packed_array
430-
426+
431427
@staticmethod
432428
@numba.jit(nopython=True, parallel=True)
433429
def pack_array_with_numba_b2_c32(
434-
raw_array: np.ndarray, packed_array:np.ndarray, n_pack: int, new_in_features:int
430+
raw_array: np.ndarray, packed_array: np.ndarray, n_pack: int, new_in_features: int
435431
) -> np.ndarray:
436432
for i in range(new_in_features):
437433
packed_array[:, i] = (
@@ -457,7 +453,7 @@ def pack_array_with_numba_b2_c32(
457453
@staticmethod
458454
@numba.jit(nopython=True, parallel=True)
459455
def pack_array_with_numba_b2_c16(
460-
raw_array: np.ndarray, packed_array:np.ndarray, n_pack: int, new_in_features:int
456+
raw_array: np.ndarray, packed_array: np.ndarray, n_pack: int, new_in_features: int
461457
) -> np.ndarray:
462458
for i in range(new_in_features):
463459
packed_array[:, i] = (
@@ -471,11 +467,11 @@ def pack_array_with_numba_b2_c16(
471467
| (raw_array[:, i * n_pack] & 0b11)
472468
)
473469
return packed_array
474-
470+
475471
@staticmethod
476472
@numba.jit(nopython=True, parallel=True)
477473
def pack_array_with_numba_b2_c8(
478-
raw_array: np.ndarray, packed_array:np.ndarray, n_pack: int, new_in_features:int
474+
raw_array: np.ndarray, packed_array: np.ndarray, n_pack: int, new_in_features: int
479475
) -> np.ndarray:
480476
for i in range(new_in_features):
481477
packed_array[:, i] = (
@@ -485,11 +481,11 @@ def pack_array_with_numba_b2_c8(
485481
| (raw_array[:, i * n_pack] & 0b11)
486482
)
487483
return packed_array
488-
484+
489485
@staticmethod
490486
@numba.jit(nopython=True, parallel=True)
491487
def pack_array_with_numba_b2_c64(
492-
raw_array: np.ndarray, packed_array:np.ndarray, n_pack: int, new_in_features:int
488+
raw_array: np.ndarray, packed_array: np.ndarray, n_pack: int, new_in_features: int
493489
) -> np.ndarray:
494490
for i in range(new_in_features):
495491
packed_array[:, i] = (
@@ -527,7 +523,7 @@ def pack_array_with_numba_b2_c64(
527523
| (raw_array[:, i * n_pack] & 0b11)
528524
)
529525
return packed_array
530-
526+
531527
def pack_array_with_numba(
532528
self, raw_array: np.ndarray, n_pack: int, bits: int, compress_bits: int, compression_dtype=np.int32
533529
) -> np.ndarray:
@@ -547,17 +543,18 @@ def pack_array_with_numba(
547543
new_in_features = (in_features + n_pack - 1) // n_pack
548544
packed_array = np.zeros((out_features, new_in_features), dtype=compression_dtype)
549545
raw_array = raw_array.astype(compression_dtype)
550-
546+
551547
pack_method_name = f"pack_array_with_numba_b{bits}_c{compress_bits}"
552548
pack_method = getattr(self, pack_method_name)
553549
return pack_method(raw_array, packed_array, n_pack, new_in_features)
554-
550+
555551
@staticmethod
556552
@numba.jit(nopython=True)
557553
def pack_array_with_numba_yi(
558554
raw_tensor: np.ndarray, n_pack: int, bits: int, compression_dtype=np.int32
559555
) -> np.ndarray:
560556
"""Packs the input tensor by combining elements into a specified bit-width format using NumPy.
557+
561558
Args:
562559
raw_tensor (np.ndarray): The tensor to be packed. Shape: [out_features, in_features] or [1, in_features].
563560
n_pack (int): The number of elements to be packed together.
@@ -575,7 +572,7 @@ def pack_array_with_numba_yi(
575572
for i in range(new_in_features):
576573
packed_tensor[:, i] = (
577574
(raw_tensor[:, i * n_pack + 7] << 28)
578-
| (raw_tensor[:, i * n_pack + 6] << 24)
575+
| (raw_tensor[:, i * n_pack + 6] << 24)
579576
| (raw_tensor[:, i * n_pack + 5] << 20)
580577
| (raw_tensor[:, i * n_pack + 4] << 16)
581578
| (raw_tensor[:, i * n_pack + 3] << 12)
@@ -585,25 +582,29 @@ def pack_array_with_numba_yi(
585582
)
586583

587584
return packed_tensor
588-
585+
589586
def pack_tensor_with_reshape(self, raw_tensor):
590587
raw_array = raw_tensor.cpu().numpy()
591588
target_len = np.ceil(raw_array.shape[1] / self.n_pack).astype(int)
592589
target_dtype = torch.tensor(0, dtype=self.compression_dtype).numpy().dtype
593590
reshaped = raw_array.reshape(-1, self.n_pack)
594591
packed_array = np.zeros(reshaped.shape[0], dtype=target_dtype)
595592
for i in range(self.n_pack):
596-
packed_array |= (reshaped[:, i].astype(target_dtype) << (self.bits * i))
597-
598-
packed_tensor = torch.from_numpy(packed_array.reshape((raw_array.shape[0], target_len))).to(device=raw_tensor.device)
593+
packed_array |= reshaped[:, i].astype(target_dtype) << (self.bits * i)
594+
595+
packed_tensor = torch.from_numpy(packed_array.reshape((raw_array.shape[0], target_len))).to(
596+
device=raw_tensor.device
597+
)
599598
return packed_tensor
600599

601600
def pack_tensor_with_numpy(self, raw_tensor):
602601
if self.bits not in [2, 4, 8]:
603602
return self.pack_tensor_with_reshape(raw_tensor)
604603
compression_dtype = torch.tensor(0, dtype=self.compression_dtype).numpy().dtype
605604
# packed_array = self.pack_array_with_numba_yi(raw_tensor.cpu().numpy(), self.n_pack, self.bits, compression_dtype)
606-
packed_array = self.pack_array_with_numba(raw_tensor.cpu().numpy(), self.n_pack, self.bits, self.compress_bits, compression_dtype)
605+
packed_array = self.pack_array_with_numba(
606+
raw_tensor.cpu().numpy(), self.n_pack, self.bits, self.compress_bits, compression_dtype
607+
)
607608
return torch.from_numpy(packed_array).to(device=raw_tensor.device)
608609

609610
def unpack_tensor_with_numpy(self, packed_tensor):

neural_compressor/torch/algorithms/weight_only/rtn.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ def convert(
124124
"double_quant_group_size": kwargs.get("double_quant_group_size", 256),
125125
}
126126
use_optimum_format = kwargs.get("use_optimum_format", True)
127-
127+
128128
if use_layer_wise:
129129
from neural_compressor.common.utils import DEFAULT_WORKSPACE
130130
from neural_compressor.torch.algorithms.layer_wise.utils import get_path, load_module, register_weight_hooks
@@ -135,10 +135,10 @@ def convert(
135135
model_path = get_path(model_path)
136136

137137
register_weight_hooks(model, model_path, device=device, clean_weight=True)
138-
138+
139139
for name, m in model.named_modules():
140-
141-
if not isinstance(m, supported_layers):
140+
141+
if not isinstance(m, supported_layers):
142142
continue
143143
if name in weight_config: # pragma: no cover
144144
# initialize op configuration
@@ -185,7 +185,7 @@ def convert(
185185
continue
186186
logger.debug(f"RTN quantized module:{name, m}")
187187
logger.debug(log_msg)
188-
188+
189189
if use_layer_wise:
190190
load_module(model, name, model_path, device=device)
191191

0 commit comments

Comments
 (0)