-
Notifications
You must be signed in to change notification settings - Fork 249
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[QAT Lora 5/N] Fixes for loading/saving compression checkpoint #3341
Open
ljaljushkin
wants to merge
6
commits into
openvinotoolkit:develop
Choose a base branch
from
ljaljushkin:nl/fq_lora_pr_ckpt
base: develop
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+130
−16
Open
Changes from 3 commits
Commits
Show all changes
6 commits
Select commit
Hold shift + click to select a range
69df6d8
Fixes for loading/saving compression checkpoint
ljaljushkin a15ca14
typo
ljaljushkin 4042607
fixed lora rank test
ljaljushkin 4921c2c
Merge remote-tracking branch 'origin/develop' into nl/fq_lora_pr_ckpt
ljaljushkin 9360c0f
removed unused line
ljaljushkin 2c071ed
added todo to fix wa for new tracing
ljaljushkin File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -768,6 +768,9 @@ def signed(self, signed: bool): | |
self.set_levels() | ||
|
||
def quantize(self, x, execute_traced_op_as_identity: bool = False): | ||
with DisableTorchFunction(): | ||
# in multi-device case after loading nncf checkpoint, quantizers have a different device. | ||
self.to(x.device) | ||
return symmetric_quantize( | ||
x, self.levels, self.level_low, self.level_high, self.scale, self.eps, skip=execute_traced_op_as_identity | ||
) | ||
|
@@ -955,6 +958,9 @@ def set_levels(self): | |
self.level_low, self.level_high = calculate_asymmetric_level_ranges(self.num_bits - scaled_num_bits) | ||
|
||
def quantize(self, x, execute_traced_op_as_identity: bool = False): | ||
with DisableTorchFunction(): | ||
# in multi-device case after loading nncf checkpoint, quantizers have a different device. | ||
self.to(x.device) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Only as WA, need to be reworked before release, There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Added TODO |
||
return asymmetric_quantize( | ||
x, | ||
self.levels, | ||
|
@@ -1067,9 +1073,14 @@ class LoraMixin: | |
|
||
def init_lora(self, lspec: PTLoraSpec): | ||
self._lspec = lspec | ||
default_lora_dtype = torch.bfloat16 | ||
out_features, in_features = lspec.orig_weight_shape | ||
self.lora_A = torch.nn.Parameter(torch.ones((lspec.lora_rank, in_features), dtype=torch.bfloat16)) | ||
self.lora_B = torch.nn.Parameter(torch.zeros((out_features, lspec.lora_rank), dtype=torch.bfloat16)) | ||
rank = lspec.lora_rank | ||
if rank > out_features or rank > in_features: | ||
msg = f"Specified LoRA rank={rank} cannot exceed any dimension of the weight tensor" | ||
raise nncf.ValidationError(msg) | ||
self.lora_A = torch.nn.Parameter(torch.ones((rank, in_features), dtype=default_lora_dtype)) | ||
self.lora_B = torch.nn.Parameter(torch.zeros((out_features, rank), dtype=default_lora_dtype)) | ||
|
||
def enable_gradients(self): | ||
self.lora_A.requires_grad = True | ||
|
@@ -1097,6 +1108,9 @@ def __init__(self, qspec: PTQuantizerSpec, lspec: PTLoraSpec): | |
self.init_lora(lspec) | ||
|
||
def quantize(self, x: torch.Tensor, execute_traced_op_as_identity: bool = False): | ||
with DisableTorchFunction(): | ||
# in multi-device case after loading nncf checkpoint, quantizers have a different device. | ||
self.to(x.device) | ||
daniil-lyakhov marked this conversation as resolved.
Show resolved
Hide resolved
|
||
return asymmetric_quantize_lora( | ||
x, | ||
self._lspec.weight_shape, | ||
|
@@ -1142,6 +1156,9 @@ def __init__(self, qspec: PTQuantizerSpec, lspec: PTLoraSpec): | |
self.init_lora(lspec) | ||
|
||
def quantize(self, x, execute_traced_op_as_identity: bool = False): | ||
with DisableTorchFunction(): | ||
# in multi-device case after loading nncf checkpoint, quantizers have a different device. | ||
self.to(x.device) | ||
return symmetric_quantize_lora( | ||
x, | ||
self._lspec.weight_shape, | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fore new tracing it's not works, no need to add this hack
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
currently, it's needed to pass graph test for torch2