Skip to content

Commit c500822

Browse files
Add weighted progress tracking for weight compression (#2892)
### Changes During NNCF weight compression, a `rich` progress bar is used to display the progress. In this PR, the progress bar is changed to be weighted according to the sizes of the model weights. With these changes, each weight contributes to the progress bar in proportion to its size. The iteration number was removed from the weight compression progress bar to avoid confusion caused by progress advancing at different rates in percent and in iteration coordinates. For example, a single weight might now contribute 5-10% of the whole progress. ### Reason for changes The time it takes to compress a weight is roughly proportional to its size, so incrementing the progress by 1 for each weight is not ideal, especially after #2803, when weight sorting was added. Now the largest weights come first and the smallest ones come at the end of the compression. This leads to misleading time estimation when the progress contribution from every weight is equal. Weight sizes for tinyllama-1.1b for reference: ![weight_size_hist](https://github.com/user-attachments/assets/30ba1e1b-0fc5-4d6b-84db-948362672bf2) ![weight_size_cumsum_hist](https://github.com/user-attachments/assets/b00e79e8-5000-44a4-97a5-4102c9aed0ae)
1 parent 1104f1b commit c500822

File tree

2 files changed

+90
-10
lines changed

2 files changed

+90
-10
lines changed

nncf/common/logging/track_progress.py

+88-9
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from rich.progress import ProgressColumn
1919
from rich.progress import ProgressType
2020
from rich.progress import Task
21+
from rich.progress import TaskID
2122
from rich.progress import TaskProgressColumn
2223
from rich.progress import TextColumn
2324
from rich.progress import TimeElapsedColumn
@@ -59,6 +60,65 @@ def render(self, task: "Task") -> Text:
5960
return Text(text._text[0], style=INTEL_BLUE_COLOR)
6061

6162

63+
class WeightedProgress(Progress):
    """
    A class to perform a weighted progress tracking.

    Callers keep passing integer step counts to `update()`, `advance()` and `reset()`. These are converted into
    weighted values before being forwarded to the parent `Progress`. Every task handled by this class is expected
    to carry two custom fields:
        - `weights`: a sequence holding one weight per step;
        - `completed_steps`: the number of steps that have already been accounted for.
    """

    def update(self, task_id: TaskID, **kwargs) -> None:
        """
        Update a task, translating step-based `advance` and `completed` values into weighted ones.

        :param task_id: Id of the task to update.
        :param kwargs: Keyword arguments forwarded to `Progress.update()`. `advance` and `completed`, if given,
            must be integer-valued step counts.
        """
        task = self._tasks[task_id]

        advance = kwargs.get("advance")
        if advance is not None:
            kwargs["advance"] = self.weighted_advance(task, advance)

        completed = kwargs.get("completed")
        if completed is not None:
            kwargs["completed"] = self.get_weighted_completed(task, completed)
            # Keep the step counter aligned with the explicitly set position, otherwise the next advance
            # would sum weights starting from a stale index.
            task.fields["completed_steps"] = int(completed)

        super().update(task_id, **kwargs)

    def advance(self, task_id: TaskID, advance: float = 1) -> None:
        """
        Advance a task by an integer number of steps, weighted by the corresponding task weights.

        :param task_id: Id of the task to advance.
        :param advance: Integer-valued number of steps to advance by.
        """
        if advance is not None:
            task = self._tasks[task_id]
            advance = self.weighted_advance(task, advance)
        super().advance(task_id, advance)

    def reset(self, task_id: TaskID, **kwargs) -> None:
        """
        Reset a task, translating a step-based `completed` value into a weighted one.

        :param task_id: Id of the task to reset.
        :param kwargs: Keyword arguments forwarded to `Progress.reset()`. `completed`, if given, must be an
            integer-valued step count.
        """
        task = self._tasks[task_id]

        completed = kwargs.get("completed")
        if completed is not None:
            kwargs["completed"] = self.get_weighted_completed(task, completed)

        super().reset(task_id, **kwargs)

        # `Progress.reset()` sets `completed` to 0 when it is not provided, so the step counter has to be
        # re-synchronized in every case, not only when `completed=0` was passed explicitly.
        task.fields["completed_steps"] = 0 if completed is None else int(completed)

    @staticmethod
    def weighted_advance(task: Task, advance: float) -> float:
        """
        Perform weighted advancement based on an integer step value.

        Increases the `completed_steps` field of the task by `advance` as a side effect.

        :param task: Task holding the `weights` and `completed_steps` fields.
        :param advance: Integer-valued number of steps to advance by.
        :return: Sum of the weights of the steps being advanced over.
        :raises ValueError: If `advance` is not integer-valued.
        """
        if advance % 1 != 0:
            raise ValueError(f"Unexpected `advance` value: {advance}.")
        advance = int(advance)
        current_step = task.fields["completed_steps"]
        weighted_advance = sum(task.fields["weights"][current_step : current_step + advance])
        task.fields["completed_steps"] = current_step + advance
        return weighted_advance

    @staticmethod
    def get_weighted_completed(task: Task, completed: float) -> float:
        """
        Get weighted `completed` corresponding to an integer `completed` field.

        :param task: Task holding the `weights` field.
        :param completed: Integer-valued number of completed steps.
        :return: Sum of the weights of the first `completed` steps.
        :raises ValueError: If `completed` is not integer-valued.
        """
        if completed % 1 != 0:
            raise ValueError(f"Unexpected `completed` value: {completed}.")
        return sum(task.fields["weights"][: int(completed)])
120+
121+
62122
class track:
63123
def __init__(
64124
self,
@@ -77,6 +137,7 @@ def __init__(
77137
update_period: float = 0.1,
78138
disable: bool = False,
79139
show_speed: bool = True,
140+
weights: Optional[List[float]] = None,
80141
):
81142
"""
82143
Track progress by iterating over a sequence.
@@ -98,11 +159,14 @@ def __init__(
98159
:param update_period: Minimum time (in seconds) between calls to update(). Defaults to 0.1.
99160
:param disable: Disable display of progress.
100161
:param show_speed: Show speed if the total isn't known. Defaults to True.
162+
:param weights: List of progress weights for each sequence element. Weights should be proportional to the time
163+
it takes to process sequence elements. Useful when processing time is strongly non-uniform.
101164
:return: An iterable of the values in the sequence.
102165
"""
103166

104167
self.sequence = sequence
105-
self.total = total
168+
self.weights = weights
169+
self.total = sum(self.weights) if self.weights is not None else total
106170
self.description = description
107171
self.update_period = update_period
108172
self.task = None
@@ -120,7 +184,13 @@ def __init__(
120184
bar_width=None,
121185
),
122186
TaskProgressColumn(show_speed=show_speed),
123-
IterationsColumn(),
187+
)
188+
)
189+
# Do not add iterations column for weighted tracking because steps will be in weighted coordinates
190+
if self.weights is None:
191+
self.columns.append(IterationsColumn())
192+
self.columns.extend(
193+
(
124194
SeparatorColumn(),
125195
TimeElapsedColumnWithStyle(),
126196
SeparatorColumn(disable_if_no_total=True), # disable because time remaining will be empty
@@ -130,7 +200,8 @@ def __init__(
130200

131201
disable = disable or (hasattr(sequence, "__len__") and len(sequence) == 0)
132202

133-
self.progress = Progress(
203+
progress_cls = Progress if weights is None else WeightedProgress
204+
self.progress = progress_cls(
134205
*self.columns,
135206
auto_refresh=auto_refresh,
136207
console=console,
@@ -141,16 +212,24 @@ def __init__(
141212
)
142213

143214
def __iter__(self) -> Iterable[ProgressType]:
144-
with self.progress:
215+
with self:
145216
yield from self.progress.track(
146-
self.sequence, total=self.total, description=self.description, update_period=self.update_period
217+
self.sequence,
218+
total=self.total,
219+
task_id=self.task,
220+
description=self.description,
221+
update_period=self.update_period,
147222
)
148223

149224
def __enter__(self):
150-
self.progress.start()
151-
self.task = self.progress.add_task(self.description, total=self.total)
152-
return self
225+
kwargs = {}
226+
if self.weights is not None:
227+
kwargs["weights"] = self.weights
228+
kwargs["completed_steps"] = 0
229+
self.task = self.progress.add_task(self.description, total=self.total, **kwargs)
230+
return self.progress.__enter__()
153231

154232
def __exit__(self, *args):
233+
self.progress.__exit__(*args)
234+
self.progress.remove_task(self.task)
155235
self.task = None
156-
self.progress.stop()

nncf/quantization/algorithms/weight_compression/algorithm.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -405,12 +405,13 @@ def apply(
405405

406406
# Sort weight params to start compression with the bigger constants. This lowers peak memory footprint.
407407
all_weight_params = sorted(all_weight_params, key=lambda wp: wp.num_weights, reverse=True)
408+
all_weight_sizes = [wp.num_weights for wp in all_weight_params]
408409

409410
# Compress model using weight compression parameters
410411
transformed_model = self._backend_entity.transform_model(
411412
model,
412413
graph,
413-
track(all_weight_params, description="Applying Weight Compression"),
414+
track(all_weight_params, description="Applying Weight Compression", weights=all_weight_sizes),
414415
scales,
415416
zero_points,
416417
)

0 commit comments

Comments
 (0)