Skip to content

Commit

Permalink
added fetch_pdb
Browse files (browse the repository at this point in the history)
  • Loading branch information
drewschaub committed Dec 25, 2024
1 parent a323ce7 commit 38fd206
Show file tree
Hide file tree
Showing 6 changed files with 37 additions and 39 deletions.
8 changes: 0 additions & 8 deletions protein_design_tools/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +0,0 @@
# protein_design_tools/__init__.py

__version__ = '0.1.29'

from .core import Atom, Chain, Residue, ProteinStructure
from .io import read_pdb, write_pdb
from .metrics import compute_rmsd, compute_tmscore
from .utils import get_coordinates, get_masses
8 changes: 0 additions & 8 deletions protein_design_tools/core/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +0,0 @@
# protein_design_tools/core/__init__.py

from .atom import Atom
from .residue import Residue
from .chain import Chain
from .protein_structure import ProteinStructure

__all__ = ["Atom", "Residue", "Chain", "ProteinStructure"]
3 changes: 0 additions & 3 deletions protein_design_tools/io/__init__.py

This file was deleted.

17 changes: 14 additions & 3 deletions protein_design_tools/io/pdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import requests
from io import StringIO


def fetch_pdb(pdb_id: str, file_path: Optional[str] = None) -> ProteinStructure:
"""
Fetch a PDB file from RCSB PDB by its ID and optionally save it to a file.
Expand Down Expand Up @@ -37,7 +38,10 @@ def fetch_pdb(pdb_id: str, file_path: Optional[str] = None) -> ProteinStructure:
temp_path = StringIO(response.text)
return read_pdb(temp_path)
else:
raise ValueError(f"Failed to fetch PDB ID {pdb_id}: HTTP status {response.status_code}")
raise ValueError(
f"Failed to fetch PDB ID {pdb_id}: HTTP status {response.status_code}"
)


def read_pdb(
file_path: str, chains: Optional[List[str]] = None, name: Optional[str] = None
Expand Down Expand Up @@ -124,6 +128,7 @@ def read_pdb(

return structure


def write_pdb(structure: ProteinStructure, file_path: str) -> None:
"""
Write a ProteinStructure object to a PDB file.
Expand All @@ -140,7 +145,13 @@ def write_pdb(structure: ProteinStructure, file_path: str) -> None:
for chain in structure.chains:
for residue in chain.residues:
for atom in residue.atoms:
content += f"ATOM {atom.atom_id:5} {atom.name:<4} {residue.name:<3} {chain.name}{residue.res_seq:4}{residue.i_code:<1} {atom.x:8.3f}{atom.y:8.3f}{atom.z:8.3f}{atom.occupancy:6.2f}{atom.temp_factor:6.2f} {atom.element:2}{atom.charge:2}\n"
content += (
f"ATOM {atom.atom_id:5} {atom.name:<4} {residue.name:<3} "
f"{chain.name}{residue.res_seq:4}{residue.i_code:<1} "
f"{atom.x:8.3f}{atom.y:8.3f}{atom.z:8.3f}"
f"{atom.occupancy:6.2f}{atom.temp_factor:6.2f} "
f"{atom.element:2}{atom.charge:2}\n"
)
content += "TER\n"
content += "END\n"
f.write(content)
f.write(content)
2 changes: 1 addition & 1 deletion protein_design_tools/metrics/rmsd.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,4 +83,4 @@ def compute_rmsd_tensorflow(P: tnp.ndarray, Q: tnp.ndarray) -> tnp.ndarray:
RMSD between P and Q
"""
assert P.shape == Q.shape
return tnp.sqrt(tnp.mean(tnp.sum((P - Q) ** 2, axis=1)))
return tnp.sqrt(tnp.mean(tnp.sum((P - Q) ** 2, axis=1)))
38 changes: 22 additions & 16 deletions tests/io/test_pdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from protein_design_tools.core.residue import Residue
from protein_design_tools.core.atom import Atom


@pytest.fixture
def sample_pdb_content():
"""from a previous AF2 test"""
Expand Down Expand Up @@ -216,21 +217,24 @@ def test_read_pdb_malformed_lines():

def test_read_pdb_with_hetatm(sample_pdb_content):
mock_file = mock_open(read_data=sample_pdb_content)

with patch("builtins.open", mock_file):
structure = read_pdb("dummy_path.pdb", chains=['A'], name="TestProteinWithHETATM")

structure = read_pdb(
"dummy_path.pdb", chains=["A"], name="TestProteinWithHETATM"
)

# Assertions for HETATM record
chain = structure.chains[0]
residue2 = chain.residues[1]
assert len(residue2.atoms) == 6 # Including OXT from HETATM
hetatm_atom = next((atom for atom in residue2.atoms if atom.name == 'OXT'), None)
hetatm_atom = next((atom for atom in residue2.atoms if atom.name == "OXT"), None)
assert hetatm_atom is not None
assert hetatm_atom.element == 'O'
assert hetatm_atom.element == "O"
assert hetatm_atom.x == 15.604
assert hetatm_atom.y == 15.707
assert hetatm_atom.z == 6.000


def test_read_pdb_multiple_chains():
multi_chain_pdb_content = """
ATOM 1 N ALA A 1 11.104 13.207 2.100 1.00 20.00 N
Expand All @@ -247,26 +251,28 @@ def test_read_pdb_multiple_chains():
END
"""
mock_file = mock_open(read_data=multi_chain_pdb_content)

with patch("builtins.open", mock_file):
structure = read_pdb("dummy_path.pdb", chains=['A', 'B'], name="MultiChainProtein")

structure = read_pdb(
"dummy_path.pdb", chains=["A", "B"], name="MultiChainProtein"
)

# Assertions
assert len(structure.chains) == 2
chain_a = next((c for c in structure.chains if c.name == 'A'), None)
chain_b = next((c for c in structure.chains if c.name == 'B'), None)
chain_a = next((c for c in structure.chains if c.name == "A"), None)
chain_b = next((c for c in structure.chains if c.name == "B"), None)

assert chain_a is not None
assert chain_b is not None

assert len(chain_a.residues) == 1
assert len(chain_b.residues) == 1

# Check residues and atoms
residue_a = chain_a.residues[0]
assert residue_a.name == 'ALA'
assert residue_a.name == "ALA"
assert len(residue_a.atoms) == 5

residue_b = chain_b.residues[0]
assert residue_b.name == 'ARG'
assert residue_b.name == "ARG"
assert len(residue_b.atoms) == 6 # Including OXT

0 comments on commit 38fd206

Please sign in to comment.