-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathhuffman_coding.py
126 lines (110 loc) · 4.82 KB
/
huffman_coding.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
# -*- coding: utf-8 -*-
"""huffman_coding.ipynb
Automatically generated by Colaboratory.
Original file is located at
https://colab.research.google.com/drive/1wusw3AIvLH6nC030QMWi5GlJdq9pnSgq
"""
import heapq

import graphviz
import numpy as np
import torch
import torch.nn as nn
from sklearn.cluster import KMeans

from pruned_layers import *
device = 'cuda' if torch.cuda.is_available() else 'cpu'
def _huffman_coding_per_layer(weight, centers):
"""
Huffman coding at each layer
:param weight: weight parameter of the current layer.
:param centers: KMeans centroids in the quantization codebook of the current weight layer.
:return:
'encodings': Encoding map mapping each weight parameter to its Huffman coding.
'frequency': Frequency map mapping each weight parameter to the total number of its appearance.
'encodings' should be in this format:
{"0.24315": '0', "-0.2145": "100", "1.1234e-5": "101", ...
}
'frequency' should be in this format:
{"0.25235": 100, "-0.2145": 42, "1.1234e-5": 36, ...
}
'encodings' and 'frequency' does not need to be ordered in any way.
"""
frequency = {}
encodings = {}
# Step 1: Create a frequency map based on the weight values
for w in weight.flatten():
key = str(w)
if key in frequency:
frequency[key] += 1
else:
frequency[key] = 1
# Step 2: Build Huffman encodings using a priority queue (heap)
heap = [[weight, [enc, ""]] for enc, weight in frequency.items()]
heapq.heapify(heap)
while len(heap) > 1:
lo = heapq.heappop(heap)
hi = heapq.heappop(heap)
for pair in lo[1:]:
pair[1] = '0' + pair[1]
for pair in hi[1:]:
pair[1] = '1' + pair[1]
heapq.heappush(heap, [lo[0] + hi[0]] + lo[1:] + hi[1:])
huffman_tree = heap[0]
# Step 3: Extract the Huffman encodings
encodings = dict(huffman_tree[1:])
return encodings, frequency
def compute_average_bits(encodings, frequency):
    """
    Compute the average storage bits of the current layer after Huffman Coding.
    :param 'encodings': Encoding map mapping each weight parameter to its Huffman coding.
    :param 'frequency': Frequency map mapping each weight parameter to the total number of its appearance.
        'encodings' should be in this format:
        {"0.24315": '0', "-0.2145": "100", "1.1234e-5": "101", ...
        }
        'frequency' should be in this format:
        {"0.25235": 100, "-0.2145": 42, "1.1234e-5": 36, ...
        }
        'encodings' and 'frequency' does not need to be ordered in any way.
    :return (float) a floating value represents the average bits.
    """
    # Weighted code length over all occurrences, divided by the occurrence count.
    weighted_bits = sum(count * len(encodings[symbol])
                        for symbol, count in frequency.items())
    total_count = sum(frequency.values())
    return weighted_bits / total_count
def huffman_coding(net, centers):
    """
    Apply huffman coding on a 'quantized' model to save further computation cost.
    :param net: a 'nn.Module' network object.
    :param centers: KMeans centroids in the quantization codebook for Huffman coding,
        indexed by pruned-layer order (one entry per PrunedConv/PruneLinear module).
    :return: frequency map and encoding map of the whole 'net' object (lists, one
        dict per pruned layer, in module traversal order).
    """
    assert isinstance(net, nn.Module)
    layer_ind = 0
    freq_map = []
    encodings_map = []
    for n, m in net.named_modules():
        # Both pruned layer types are handled identically; only the location
        # of the underlying weight tensor differs, so select it and fall through
        # to one shared encoding path (previously two copy-pasted branches).
        if isinstance(m, PrunedConv):
            weight = m.conv.weight.data.cpu().numpy()
        elif isinstance(m, PruneLinear):
            weight = m.linear.weight.data.cpu().numpy()
        else:
            continue
        center = centers[layer_ind]
        # Baseline: fixed-length codes cost log2(#centroids) bits per parameter.
        original_avg_bits = round(np.log2(len(center)))
        print("Original storage for each parameter: %.4f bits" % original_avg_bits)
        encodings, frequency = _huffman_coding_per_layer(weight, center)
        freq_map.append(frequency)
        encodings_map.append(encodings)
        huffman_avg_bits = compute_average_bits(encodings, frequency)
        print("Average storage for each parameter after Huffman Coding: %.4f bits" % huffman_avg_bits)
        layer_ind += 1
        print("Complete %d layers for Huffman Coding..." % layer_ind)
    return freq_map, encodings_map