Skip to content

Commit

Permalink
changes tf implemtnation to use experimental numpy arrays due to inco…
Browse files Browse the repository at this point in the history
…mpatiability with numpy>2
  • Loading branch information
drewschaub committed Dec 25, 2024
1 parent c81f73a commit a323ce7
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 19 deletions.
3 changes: 3 additions & 0 deletions protein_design_tools/io/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .pdb import fetch_pdb, read_pdb, write_pdb

__all__ = ["fetch_pdb", "read_pdb", "write_pdb"]
21 changes: 21 additions & 0 deletions protein_design_tools/io/pdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,3 +123,24 @@ def read_pdb(
residue.atoms.append(atom)

return structure

def write_pdb(structure: ProteinStructure, file_path: str) -> None:
"""
Write a ProteinStructure object to a PDB file.
Parameters
----------
structure : ProteinStructure
The protein structure to write.
file_path : str
The path to write the PDB file.
"""
content = ""
with open(file_path, "w") as f:
for chain in structure.chains:
for residue in chain.residues:
for atom in residue.atoms:
content += f"ATOM {atom.atom_id:5} {atom.name:<4} {residue.name:<3} {chain.name}{residue.res_seq:4}{residue.i_code:<1} {atom.x:8.3f}{atom.y:8.3f}{atom.z:8.3f}{atom.occupancy:6.2f}{atom.temp_factor:6.2f} {atom.element:2}{atom.charge:2}\n"
content += "TER\n"
content += "END\n"
f.write(content)
14 changes: 7 additions & 7 deletions protein_design_tools/metrics/rmsd.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import numpy as np
import tensorflow as tf
import tensorflow.experimental.numpy as tnp
import torch
import jax.numpy as jnp
from jax import jit
Expand Down Expand Up @@ -66,21 +66,21 @@ def compute_rmsd_pytorch(P: torch.Tensor, Q: torch.Tensor) -> torch.Tensor:
return torch.sqrt(torch.mean(torch.sum((P - Q) ** 2, dim=1)))


def compute_rmsd_tensorflow(P: tf.Tensor, Q: tf.Tensor) -> tf.Tensor:
def compute_rmsd_tensorflow(P: tnp.ndarray, Q: tnp.ndarray) -> tnp.ndarray:
"""
Compute RMSD between two NxD TensorFlow tensors.
Compute RMSD between two NxD TensorFlow tensors using tf.experimental.numpy.
Parameters
----------
P : tf.Tensor
P : tnp.ndarray
Mobile points, shape (N, D)
Q : tf.Tensor
Q : tnp.ndarray
Target points, shape (N, D)
Returns
-------
float
tnp.ndarray
RMSD between P and Q
"""
assert P.shape == Q.shape
return tf.sqrt(tf.reduce_mean(tf.reduce_sum((P - Q) ** 2, axis=1)))
return tnp.sqrt(tnp.mean(tnp.sum((P - Q) ** 2, axis=1)))
24 changes: 12 additions & 12 deletions protein_design_tools/metrics/tmscore.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import numpy as np
import tensorflow as tf
import tensorflow.experimental.numpy as tnp
import torch
import jax.numpy as jnp
from jax import jit
Expand Down Expand Up @@ -81,26 +81,26 @@ def compute_tm_score_pytorch(P: torch.Tensor, Q: torch.Tensor) -> torch.Tensor:
return tm_score


def compute_tm_score_tensorflow(P: tf.Tensor, Q: tf.Tensor) -> tf.Tensor:
def compute_tm_score_tensorflow(P: tnp.ndarray, Q: tnp.ndarray) -> tnp.ndarray:
"""
Compute TM-score between two NxD TensorFlow tensors.
Compute TM-score between two NxD TensorFlow tensors using tf.experimental.numpy.
Parameters
----------
P : tf.Tensor
P : tnp.ndarray
Mobile points, shape (N, D)
Q : tf.Tensor
Q : tnp.ndarray
Target points, shape (N, D)
Returns
-------
tf.Tensor
tnp.ndarray
TM-score between P and Q
"""
L_ref = tf.cast(tf.shape(P)[0], P.dtype)
d0 = 1.24 * tf.pow(L_ref - 15.0, 1 / 3) - 1.8
d0 = tf.maximum(d0, 1.0) # Ensure d0 is positive
distances = tf.norm(P - Q, axis=1)
tm_scores = 1.0 / (1.0 + tf.square(distances / d0))
tm_score = tf.reduce_sum(tm_scores) / L_ref
L_ref = tnp.array(P.shape[0], dtype=P.dtype)
d0 = 1.24 * tnp.power(L_ref - 15.0, 1 / 3) - 1.8
d0 = tnp.maximum(d0, 1.0) # Ensure d0 is positive
distances = tnp.linalg.norm(P - Q, axis=1)
tm_scores = 1.0 / (1.0 + (distances / d0) ** 2)
tm_score = tnp.sum(tm_scores) / L_ref
return tm_score

0 comments on commit a323ce7

Please sign in to comment.