Skip to content

Commit

Permalink
feat: add damage rescaling
Browse files Browse the repository at this point in the history
  • Loading branch information
maxibor committed Jul 12, 2024
1 parent e40005d commit d3cd9cf
Show file tree
Hide file tree
Showing 5 changed files with 128 additions and 19 deletions.
11 changes: 7 additions & 4 deletions pydamage/damage.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def get_damage(self, show_al):
self.G = []
self.GA = []
self.no_mut = []
self.read_dict = {}
self.read_dict = {self.reference: dict()}
for al in self.alignments:
if al.is_unmapped is False:
all_damage = damage_al(
Expand All @@ -68,7 +68,7 @@ def get_damage(self, show_al):
self.C_G_bases += all_damage["G"]
self.no_mut += all_damage["no_mut"]
if len(CT_GA) > 0:
self.read_dict[al.query_name] = np.array(CT_GA)
self.read_dict[self.reference][al.query_name] = np.array(CT_GA)

def compute_damage(self):
"""Computes the amount of damage for statistical modelling"""
Expand Down Expand Up @@ -121,6 +121,8 @@ def compute_damage(self):
for i in range(self.wlen):
if i not in CT_dict:
CT_dict[i] = 0
if i not in GA_dict:
GA_dict[i] = 0

if i not in damage_bases_dict:
damage_bases_dict[i] = 0
Expand Down Expand Up @@ -206,7 +208,7 @@ def test_damage(ref, bam, mode, wlen, show_al, process, verbose):
"""Prepare data and run LRtest to test for damage
Args:
ref (str): name of referene in alignment file
ref (str): name of reference in alignment file
bam (str): bam file
mode (str): opening mode of alignment file
wlen (int): window length
Expand Down Expand Up @@ -241,6 +243,7 @@ def test_damage(ref, bam, mode, wlen, show_al, process, verbose):

al = al_to_damage(reference=ref, al_handle=al_handle, wlen=wlen)
al.get_damage(show_al=show_al)
read_dict = al.read_dict
(
mut_count,
conserved_count,
Expand Down Expand Up @@ -284,7 +287,7 @@ def test_damage(ref, bam, mode, wlen, show_al, process, verbose):

# print(test_res)

return check_model_fit(test_res, wlen, verbose)
return (check_model_fit(test_res, wlen, verbose), read_dict)

except (ValueError, RuntimeError) as e:
if verbose:
Expand Down
42 changes: 29 additions & 13 deletions pydamage/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,14 @@
from pydamage.plot import damageplot
from pydamage.exceptions import PyDamageWarning, AlignmentFileError
from pydamage.accuracy_model import prepare_data, glm_predict
from pydamage.rescale import rescale_bam
from pydamage.models import glm_model_params
import os
import sys
from tqdm import tqdm
import warnings
from pydamage import __version__

from collections import ChainMap

# def pydamage_analyze(
# bam,
Expand Down Expand Up @@ -104,22 +106,24 @@ def pydamage_analyze(
)
print("Estimating and testing Damage")
if group:
filt_res = [
damage.test_damage(
ref=None,
bam=bam,
mode=mode,
wlen=wlen,
show_al=show_al,
process=process,
verbose=verbose,
)
]
filt_res, read_dict = damage.test_damage(
ref=None,
bam=bam,
mode=mode,
wlen=wlen,
show_al=show_al,
process=process,
verbose=verbose,
)
filt_res = [filt_res]
else:
if len(refs) > 0:
with multiprocessing.Pool(proc) as p:
res = list(tqdm(p.imap(test_damage_partial, refs), total=len(refs)))
filt_res = [i for i in res if i]
filt_res, read_dicts = zip(*res)
filt_res = [i for i in filt_res if i]
read_dicts = [i for i in read_dicts if i]
read_dict = dict(ChainMap(*read_dicts))
else:
raise AlignmentFileError(
"No reference sequences were found in alignment file"
Expand All @@ -134,6 +138,18 @@ def pydamage_analyze(
"No alignments were found, check your alignment file", PyDamageWarning
)

rescale = True
if rescale:
print("\nRescaling quality scores")
rescale_bam(
bam=bam,
threshold=0.5,
alpha=0.05,
damage_dict=filt_res,
read_dict=read_dict,
outname=os.path.join(outdir, "rescaled.bam"),
)

if plot and len(filt_res) > 0:
print("\nGenerating Pydamage plots")
plotdir = f"{outdir}/plots"
Expand Down
3 changes: 2 additions & 1 deletion pydamage/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def _geom_pmf(self, x, p):
"""
return ((1 - p) ** x) * p

def fit(self, x, p, pmin, pmax, wlen=35):
def fit(self, x, p, pmin, pmax, wlen=35, **kwargs):
"""Damage model function
Args:
Expand All @@ -53,6 +53,7 @@ def fit(self, x, p, pmin, pmax, wlen=35):
xmax = self._geom_pmf(0, p)
xmin = self._geom_pmf(wlen - 1, p)
scaled_geom = ((base_geom - xmin) / (xmax - xmin)) * (pmax - pmin) + pmin
scaled_geom[wlen:] = pmin
return scaled_geom


Expand Down
88 changes: 88 additions & 0 deletions pydamage/rescale.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import pysam
import numpy as np
from array import array
from pydamage.models import damage_model


def phred_to_prob(qual):
"""Convert Phred quality score to probability
Args:
qual (array): Array of unsigned integer Phred quality scores
Returns:
np.array(int): Array of read error probabilities
"""
return 10 ** (-np.array(qual) / 10)


def rescale_qual(read_qual, dmg_pmf, damage_bases):
"""Rescale quality scores using damage model
Args:
read_qual (array): Array of Phred quality scores
dmg_pmf (array): Array of damage model probabilities
damage_bases (array): Array of positions with damage
Returns:
np.array(int): Array of rescaled Phred quality scores
"""
e = phred_to_prob(read_qual)
d = np.zeros(len(read_qual))
d[damage_bases] = dmg_pmf[damage_bases]
return array(
"B", np.round(-10 * np.log10(1 - np.multiply(1 - e, 1 - d)), 0).astype(int)
)


def rescale_bam(bam, threshold, alpha, damage_dict, read_dict, outname):
"""Rescale quality scores in BAM file using damage model
Args:
bam (str): Path to BAM file
threshold (float): Predicted accuracy threshold
alpha (float): Q-value threshold
damage_dict (dict): Damage model parameters
read_dict (dict): Dictionary of read names
outname (str): Path to output BAM file
"""
damage_dict = {v["reference"]: v for v in damage_dict}
print(damage_dict)
with pysam.AlignmentFile(bam, "rb") as al:
refs = al.references
with pysam.AlignmentFile(outname, "wb", template=al) as out:
for ref in refs:
dmg = damage_model()
if ref in read_dict:
pass_filter = False
if (
threshold
and threshold <= damage_dict[ref]["predicted_accuracy"]
) and (alpha and alpha >= damage_dict[ref]["qvalue"]):
pass_filter = True
if pass_filter:
dmg_pmf = dmg.fit(x=np.arange(400), **damage_dict[ref])
for read in al.fetch(ref):
if read.query_name in read_dict[ref]:
qual = read.query_qualities
read.query_qualities = rescale_qual(
qual, dmg_pmf, read_dict[ref][read.query_name]
)
print(
np.mean(np.array(qual)),
np.mean(
np.array(
rescale_qual(
qual,
dmg_pmf,
read_dict[ref][read.query_name],
)
)
),
)

out.write(read)
else:
for read in al.fetch(ref):
out.write(read)
else:
for read in al.fetch(ref):
out.write(read)
3 changes: 2 additions & 1 deletion tests/test_damage.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def test_al_to_damage(bamfile):

assert al.C[:10] == [7, 11, 18, 2, 18, 9, 7, 11, 15, 4]
assert al.CT == [15, 0, 0, 2, 15, 11, 0]
assert al.GA == []
assert al.damage_bases == [15, 0, 0, 2, 15, 11, 0]
assert al.C_G_bases[:10] == [7, 11, 18, 2, 18, 9, 7, 11, 15, 4]
assert al.no_mut[:10] == [7, 11, 18, 2, 18, 9, 7, 11, 4, 16]
Expand Down Expand Up @@ -103,4 +104,4 @@ def test_test_damage():
assert dam["reference"] == "NZ_JHCB02000002.1"
assert dam["nb_reads_aligned"] == 153
assert dam["coverage"] == pytest.approx(0.18994194094919317)
assert dam["reflen"] == 48399
assert dam["reflen"] == 48399

0 comments on commit d3cd9cf

Please sign in to comment.