|
9 | 9 | # See the License for the specific language governing permissions and
|
10 | 10 | # limitations under the License.
|
11 | 11 |
|
| 12 | +import contextlib |
12 | 13 | from collections import defaultdict
|
13 | 14 | from functools import partial
|
14 | 15 | from typing import Callable, Dict, List, Optional, Tuple
|
|
20 | 21 | from nncf.common.graph.model_transformer import ModelTransformer
|
21 | 22 | from nncf.common.graph.transformations.commands import TargetType
|
22 | 23 | from nncf.common.graph.transformations.commands import TransformationPriority
|
| 24 | +from nncf.errors import InternalError |
23 | 25 | from nncf.torch.extractor import extract_model
|
24 | 26 | from nncf.torch.graph.transformations.commands import ExtraCompressionModuleType
|
25 | 27 | from nncf.torch.graph.transformations.commands import PTBiasCorrectionCommand
|
@@ -80,15 +82,23 @@ def transform(self, transformation_layout: PTTransformationLayout) -> NNCFNetwor
|
80 | 82 | for command in transformation_layout.transformations:
|
81 | 83 | compression_module = command.fn
|
82 | 84 | if isinstance(compression_module, nn.Module):
|
83 |
| - target_point = command.target_points[0] |
84 |
| - node_with_weight = graph.get_node_by_name(target_point.target_node_name) |
85 |
| - weight_node = get_const_node(node_with_weight, target_point.input_port_id, graph) |
86 |
| - if weight_node is None: |
87 |
| - weight_node = node_with_weight # Decompression in DQ compression format is applied to const. |
88 |
| - const_data = get_const_data(weight_node, model) |
89 |
| - # Compression module and the corresponding layer may have a different device in multi-device setup |
90 |
| - # (e.g. when HF model was loaded with device_map='auto'). Need to align devices. |
91 |
| - compression_module.to(const_data.device) |
| 85 | + points = [command.target_point] |
| 86 | + if hasattr(command, "target_points"): |
| 87 | + points = command.target_points |
| 88 | + for target_point in points: |
| 89 | + target_node = graph.get_node_by_name(target_point.target_node_name) |
| 90 | + weight_node = None |
| 91 | + with contextlib.suppress(InternalError): |
| 92 | + weight_node = get_const_node(target_node, target_point.input_port_id, graph) |
| 93 | + if weight_node is None: |
| 94 | + weight_node = target_node # Decompression in DQ compression format is applied to const |
| 95 | + const_data = None |
| 96 | + with contextlib.suppress(AttributeError): |
| 97 | + const_data = get_const_data(weight_node, model) |
| 98 | + if const_data is not None: |
| 99 | + # Compression module and the corresponding layer may have a different device in multi-device |
| 100 | + # setup (e.g. when HF model was loaded with device_map='auto'). Need to align devices. |
| 101 | + compression_module.to(const_data.device) |
92 | 102 | model.nncf.rebuild_graph()
|
93 | 103 |
|
94 | 104 | return model
|
|
0 commit comments