|
38 | 38 | from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand
|
39 | 39 | from nncf.torch.graph.transformations.commands import PTTargetPoint
|
40 | 40 | from nncf.torch.model_graph_manager import find_const_node_in_constant_subgraph
|
| 41 | +from nncf.torch.model_graph_manager import get_const_data |
41 | 42 | from nncf.torch.model_graph_manager import get_const_node
|
42 | 43 | from nncf.torch.model_graph_manager import get_module_by_name
|
43 | 44 | from nncf.torch.model_graph_manager import split_const_name
|
@@ -173,10 +174,8 @@ def get_weight(
|
173 | 174 | ) -> Tensor:
|
174 | 175 | weight_node = get_const_node(node_with_weight, weight_port_id, graph)
|
175 | 176 | weight_name = weight_node.layer_attributes.name
|
176 |
| - module_name, weight_attr_name = split_const_name(weight_name) |
177 |
| - module = get_module_by_name(module_name, model) |
178 |
| - weight = getattr(module, weight_attr_name) |
179 |
| - if weight is None or not isinstance(weight, torch.nn.Parameter): |
| 177 | + weight = get_const_data(weight_node, model) |
| 178 | + if weight is None: |
180 | 179 | raise nncf.InternalError(f"Could not find a torch.nn.Parameter in the model by name {weight_name}.")
|
181 | 180 |
|
182 | 181 | return Tensor(weight)
|
@@ -222,10 +221,8 @@ def transform_model(
|
222 | 221 |
|
223 | 222 | weight_node = get_const_node(wc_params.node_with_weight, wc_params.weight_port_id, graph)
|
224 | 223 | weight_name = weight_node.layer_attributes.name
|
225 |
| - module_name, weight_attr_name = split_const_name(weight_name) |
226 |
| - module = get_module_by_name(module_name, model) |
227 |
| - weight = getattr(module, weight_attr_name) |
228 |
| - if weight is None or not isinstance(weight, torch.nn.Parameter): |
| 224 | + weight = get_const_data(weight_node, model) |
| 225 | + if weight is None: |
229 | 226 | raise nncf.InternalError(f"Could not find a torch.nn.Parameter in the model by name {weight_name}.")
|
230 | 227 |
|
231 | 228 | # calculates compressed weights and decompression parameters
|
@@ -264,7 +261,14 @@ def transform_model(
|
264 | 261 | packed_tensor = decompressor.pack_weight(compressed_weight.tensor.data)
|
265 | 262 |
|
266 | 263 | # sets compressed tensor
|
| 264 | + # TODO:(AlexanderDokuchaev): update set_const_data |
267 | 265 | compressed_parameter = torch.nn.Parameter(packed_tensor, requires_grad=False)
|
| 266 | + module_name, weight_attr_name = split_const_name(weight_name) |
| 267 | + module = get_module_by_name(module_name, model) |
| 268 | + weight = getattr(module, weight_attr_name) |
| 269 | + if not isinstance(weight, torch.nn.Parameter): |
| 270 | + raise nncf.InternalError(f"Weight is not a torch.nn.Parameter in the model by name {weight_name}.") |
| 271 | + |
268 | 272 | setattr(module, weight_attr_name, compressed_parameter)
|
269 | 273 |
|
270 | 274 | consumer_nodes = graph.get_next_nodes(weight_node)
|
|
0 commit comments