Skip to content

Commit

Permalink
feat: add GtoA damage
Browse files Browse the repository at this point in the history
  • Loading branch information
maxibor committed Jul 11, 2024
1 parent f10da15 commit e40005d
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 56 deletions.
25 changes: 19 additions & 6 deletions pydamage/damage.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def get_damage(self, show_al):
self.G = []
self.GA = []
self.no_mut = []
self.read_dict = {}
for al in self.alignments:
if al.is_unmapped is False:
all_damage = damage_al(
Expand All @@ -61,11 +62,13 @@ def get_damage(self, show_al):
self.CT += all_damage["CT"]
self.G += all_damage["G"]
self.GA += all_damage["GA"]
self.damage_bases += all_damage["CT"]
self.damage_bases += all_damage["GA"]
CT_GA = all_damage["CT"] + all_damage["GA"]
self.damage_bases += CT_GA
self.C_G_bases += all_damage["C"]
self.C_G_bases += all_damage["G"]
self.no_mut += all_damage["no_mut"]
if len(CT_GA) > 0:
self.read_dict[al.query_name] = np.array(CT_GA)

def compute_damage(self):
"""Computes the amount of damage for statistical modelling"""
Expand All @@ -77,7 +80,7 @@ def compute_damage(self):
),
return_counts=True,
)
C_dict = dict(zip(C_pos, C_counts))
C_dict = dict(zip(C_pos, C_counts)) # {position: count at each position}

# All G in reference
G_pos, G_counts = np.unique(
Expand All @@ -90,6 +93,10 @@ def compute_damage(self):
CT_pos, CT_counts = np.unique(np.sort(self.CT), return_counts=True)
CT_dict = dict(zip(CT_pos, CT_counts))

# GtoA transitions
GA_pos, GA_counts = np.unique(np.sort(self.GA), return_counts=True)
GA_dict = dict(zip(GA_pos, GA_counts))

# All transitions
damage_bases_pos, damage_bases_counts = np.unique(
np.sort(self.damage_bases), return_counts=True
Expand All @@ -114,16 +121,23 @@ def compute_damage(self):
for i in range(self.wlen):
if i not in CT_dict:
CT_dict[i] = 0

if i not in damage_bases_dict:
damage_bases_dict[i] = 0

if i not in no_mut_dict:
no_mut_dict[i] = 0

if i not in C_dict:
CT_damage_amount.append(0)
else:
CT_damage_amount.append(CT_dict[i] / C_dict[i])

if i not in G_dict:
GA_damage_amount.append(0)
else:
CT_damage_amount.append(CT_dict[i] / C_dict[i])
GA_damage_amount.append(GA_dict[i] / G_dict[i])

if i not in C_G_bases_dict:
damage_amount.append(0)
else:
Expand Down Expand Up @@ -235,8 +249,6 @@ def test_damage(ref, bam, mode, wlen, show_al, process, verbose):
all_damage,
) = al.compute_damage()

print(CT_damage)
# if all_damage:
model_A = models.damage_model()
model_B = models.null_model()
test_res = fit_models(
Expand All @@ -258,6 +270,7 @@ def test_damage(ref, bam, mode, wlen, show_al, process, verbose):

for i in range(wlen):
CT_log[f"CtoT-{i}"] = CT_damage[i]
GA_log[f"GtoA-{i}"] = GA_damage[i]
test_res.update(CT_log)
test_res.update(GA_log)

Expand Down
98 changes: 52 additions & 46 deletions pydamage/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,52 +75,58 @@ def pydamage_analyze(
##########################
# Simple loop for debugging
##########################
filt_res = []
for ref in refs:
res = damage.test_damage(
bam=bam,
ref=ref,
wlen=wlen,
show_al=show_al,
mode=mode,
process=process,
verbose=verbose,
)
if res:
filt_res.append(res)
break
##########################
##########################

# test_damage_partial = partial(
# damage.test_damage,
# bam=bam,
# mode=mode,
# wlen=wlen,
# show_al=show_al,
# process=process,
# verbose=verbose,
# )
# print("Estimating and testing Damage")
# if group:
# filt_res = [
# damage.test_damage(
# ref=None,
# bam=bam,
# mode=mode,
# wlen=wlen,
# show_al=show_al,
# process=process,
# verbose=verbose,
# )
# ]
# else:
# if len(refs) > 0:
# with multiprocessing.Pool(proc) as p:
# res = list(tqdm(p.imap(test_damage_partial, refs), total=len(refs)))
# filt_res = [i for i in res if i]
# else:
# raise AlignmentFileError("No reference sequences were found in alignment file")
# filt_res = []
# for ref in refs:
# res = damage.test_damage(
# bam=bam,
# ref=ref,
# wlen=wlen,
# show_al=show_al,
# mode=mode,
# process=process,
# verbose=verbose,
# )
# if res:
# filt_res.append(res)
# break
######################
# Multiprocessing code
######################

test_damage_partial = partial(
damage.test_damage,
bam=bam,
mode=mode,
wlen=wlen,
show_al=show_al,
process=process,
verbose=verbose,
)
print("Estimating and testing Damage")
if group:
filt_res = [
damage.test_damage(
ref=None,
bam=bam,
mode=mode,
wlen=wlen,
show_al=show_al,
process=process,
verbose=verbose,
)
]
else:
if len(refs) > 0:
with multiprocessing.Pool(proc) as p:
res = list(tqdm(p.imap(test_damage_partial, refs), total=len(refs)))
filt_res = [i for i in res if i]
else:
raise AlignmentFileError(
"No reference sequences were found in alignment file"
)

######################
######################

print(f"{len(filt_res)} contig(s) analyzed by Pydamage")
if len(filt_res) == 0:
Expand Down
2 changes: 1 addition & 1 deletion pydamage/parse_damage.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def damage_al(
query (string): query sequence
cigartuple (tuple): cigar tuple (pysam)
wlen (int): window length
print_al (bool): print alignment
show_al (bool): print alignment
Returns:
dict : {'C': [ C pos from 5'],
'CT': [ CtoT pos from 5'],
Expand Down
6 changes: 3 additions & 3 deletions pydamage/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def RMSE(residuals: np.ndarray) -> float:
Returns:
float: RMSE
"""
return np.sqrt(np.mean(residuals ** 2))
return np.sqrt(np.mean(residuals**2))


def create_damage_dict(
Expand Down Expand Up @@ -218,6 +218,7 @@ def create_damage_dict(

return (damage_dict, non_damage_dict)


def prepare_bam(bam: str, minlen: int) -> Tuple[Tuple, str]:
"""Checks for file extension, and returns tuple of mapped refs of minlen
Expand All @@ -240,11 +241,10 @@ def prepare_bam(bam: str, minlen: int) -> Tuple[Tuple, str]:
sys.exit(1)

present_refs = set()
for ref_stat,ref_len in zip(alf.get_index_statistics(), alf.lengths):
for ref_stat, ref_len in zip(alf.get_index_statistics(), alf.lengths):
refname = ref_stat[0]
nb_mapped_reads = ref_stat[1]
if nb_mapped_reads > 0 and ref_len >= minlen:
present_refs.add(refname)
alf.close()
return tuple(present_refs), mode

0 comments on commit e40005d

Please sign in to comment.