1
1
import typing as T
2
- from pathlib import Path
3
2
import warnings
3
+ from pathlib import Path
4
4
5
5
import numpy as np
6
6
import torch
9
9
OFProtein ,
10
10
atom37_to_frames ,
11
11
get_backbone_frames ,
12
+ make_atom14_masks ,
13
+ make_atom14_positions ,
12
14
make_pdb_features ,
13
15
protein_from_pdb_string ,
14
- make_atom14_masks ,
15
- make_atom14_positions
16
16
)
17
17
from lobster .transforms import trim_or_pad
18
18
@@ -47,7 +47,9 @@ def _openfold_features_from_pdb(
47
47
48
48
return protein_features
49
49
50
- def _process_structure_features (self , features : T .Dict [str , np .ndarray ], seq_len : T .Optional [int ] = None ):
50
+ def _process_structure_features (
51
+ self , features : T .Dict [str , np .ndarray ], seq_len : T .Optional [int ] = None
52
+ ):
51
53
"""Process feature dtypes and pad to max length for a single sequence."""
52
54
features_requiring_padding = [
53
55
"aatype" ,
@@ -69,7 +71,7 @@ def _process_structure_features(self, features: T.Dict[str, np.ndarray], seq_len
69
71
features [k ] = torch .from_numpy (v )
70
72
71
73
# Trim or pad to a fixed length for all per-specific features
72
- if (k in features_requiring_padding ) and (not seq_len is None ):
74
+ if (k in features_requiring_padding ) and (seq_len is not None ):
73
75
features [k ] = trim_or_pad (features [k ], seq_len )
74
76
75
77
# 'seq_length' is a tensor with shape equal to the aatype array length,
@@ -83,8 +85,8 @@ def _process_structure_features(self, features: T.Dict[str, np.ndarray], seq_len
83
85
features ["mask" ] = mask .long ()
84
86
85
87
# Make sure input sequence string is also trimmed
86
- if not seq_len is None :
87
- features [' sequence' ] = features [' sequence' ][:seq_len ]
88
+ if seq_len is not None :
89
+ features [" sequence" ] = features [" sequence" ][:seq_len ]
88
90
89
91
features ["aatype" ] = features ["aatype" ].argmax (dim = - 1 )
90
92
return features
@@ -93,11 +95,11 @@ def __call__(self, pdb_str: str, seq_len: int, pdb_id: T.Optional[str] = None):
93
95
with warnings .catch_warnings ():
94
96
warnings .simplefilter ("ignore" )
95
97
features = self ._openfold_features_from_pdb (pdb_str , pdb_id )
96
-
98
+
97
99
features = self ._process_structure_features (features , seq_len )
98
100
features = atom37_to_frames (features )
99
101
features = get_backbone_frames (features )
100
102
features = make_atom14_masks (features )
101
103
features = make_atom14_positions (features )
102
104
103
- return features
105
+ return features
0 commit comments