
Dev #6

Merged: 14 commits, Feb 13, 2025
7 changes: 5 additions & 2 deletions fantasia/config.yaml
@@ -9,7 +9,7 @@ max_workers: 1
constants: "./fantasia/constants.yaml"

# Monitoring interval in seconds (for processes that require periodic checks).
-monitor_interval: 5
+monitor_interval: 10


# ==========================
@@ -58,7 +58,7 @@ fantasia_input_fasta: data_sample/worm_test.fasta
# Reference tag used for lookup operations. (None for Complete Reference table from information system)
lookup_reference_tag: 0

-# Maximum number of entries to process.
+# K-closest proteins to consider for the lookup.
limit_per_entry: 10

# Prefix for output file names.
@@ -73,11 +73,14 @@ redundancy_filter: 0
# Number of sequences to package in each queue batch.
sequence_queue_package: 64

+delete_queues: True

# ==========================
# 🧬 Embedding Configuration
# ==========================

embedding:
+distance_metric: "<=>" # options: "<=>" (cosine) or "<->" (euclidean)
models:
esm:
enabled: True # 🔹 Enable or disable the ESM model
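
Note on the new distance_metric key: it selects the pgvector operator that lookup.py interpolates into its similarity query (see the SQL change below). A minimal sketch of the idea; the table and column names here are illustrative assumptions, not the project's actual schema:

    # Sketch only: "sequence_embeddings" and its columns are assumed names.
    # "<=>" is pgvector's cosine-distance operator (0 = same direction, 2 = opposite);
    # "<->" is Euclidean (L2) distance (0 = identical, unbounded above).
    metric = "<=>"  # as read from config["embedding"]["distance_metric"]
    if metric not in ("<=>", "<->"):
        metric = "<->"  # mirror the code's fall-back to Euclidean

    query = f"""
        SELECT id, (embedding {metric} :query_embedding) AS distance
        FROM sequence_embeddings
        ORDER BY distance
        LIMIT :k
    """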
3 changes: 1 addition & 2 deletions fantasia/main.py
@@ -19,10 +19,9 @@ def initialize(config_path, embeddings_url=None):
conf = yaml.safe_load(config_file)
if embeddings_url:
conf["embeddings_url"] = embeddings_url
-embeddings_dir = os.path.expanduser(conf["directories"]["embeddings"])
+embeddings_dir = os.path.join(os.path.expanduser(conf["base_directory"]), "embeddings")
os.makedirs(embeddings_dir, exist_ok=True)
tar_path = os.path.join(embeddings_dir, "embeddings.tar")
print("Downloading embeddings...")
download_embeddings(conf["embeddings_url"], tar_path)
print("Loading dump into the database...")
load_dump_to_db(tar_path, conf)
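
The embeddings dump now lands under base_directory instead of a separate directories.embeddings setting. A quick sketch of the resulting paths, assuming base_directory is set to ~/fantasia (an illustrative value):

    import os

    conf = {"base_directory": "~/fantasia"}  # assumed config value, for illustration
    embeddings_dir = os.path.join(os.path.expanduser(conf["base_directory"]), "embeddings")
    tar_path = os.path.join(embeddings_dir, "embeddings.tar")
    # embeddings_dir -> /home/<user>/fantasia/embeddings
    # tar_path       -> /home/<user>/fantasia/embeddings/embeddings.tar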
64 changes: 40 additions & 24 deletions fantasia/src/lookup.py
@@ -21,6 +21,7 @@
import os

import pandas as pd
+from goatools.base import get_godag
from pycdhit import cd_hit, read_clstr
from sqlalchemy import text
import h5py
@@ -81,6 +82,15 @@ def __init__(self, conf, current_date):
if redundancy_filter > 0:
self.generate_clusters()

+self.go = get_godag('go-basic.obo', optional_attrs='relationship')
+
+self.distance_metric = self.conf.get("embedding", {}).get("distance_metric", "<->")
+
+valid_metrics = ["<->", "<=>"]
+if self.distance_metric not in valid_metrics:
+self.logger.warning(f"Invalid distance metric '{self.distance_metric}', defaulting to '<->' (Euclidean).")
+self.distance_metric = "<->"
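
Worth noting: the code-level default here is "<->" (Euclidean), while the shipped config.yaml selects "<=>" (cosine), so the fall-back only applies when the embedding block or key is absent. A minimal illustration of the lookup chain:

    conf = {}  # no "embedding" section at all
    metric = conf.get("embedding", {}).get("distance_metric", "<->")
    # metric == "<->": Euclidean is the default; cosine must be chosen explicitly in config.yaml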

def fetch_models_info(self):
"""
Retrieves and initializes embedding models based on configuration.
@@ -129,6 +139,7 @@ def generate_clusters(self):
Exception
If an error occurs during the process.
"""
+input_h5 = os.path.join(self.conf['experiment_path'], "embeddings.h5")
try:
self.reference_fasta = os.path.join(self.experiment_path, "redundancy.fasta")
filtered_fasta = os.path.join(self.experiment_path, "filtered.fasta")
Expand All @@ -141,11 +152,10 @@ def generate_clusters(self):
for row in result:
ref_file.write(f">{row.id}\n{row.sequence}\n")

-with h5py.File(os.path.expanduser(self.h5_path), "r") as h5file:
+with h5py.File(input_h5, "r") as h5file:
for accession, accession_group in h5file.items():
if "sequence" in accession_group:
sequence = accession_group["sequence"][()].decode("utf-8")
# Remove the "accession_" prefix
clean_accession = accession.removeprefix("accession_")
ref_file.write(f">{clean_accession}\n{sequence}\n")

@@ -285,7 +295,7 @@ def process(self, task_data):
annotated_results AS (
SELECT
s.sequence,
-(se.embedding <-> te.embedding) AS distance,
+(se.embedding {self.distance_metric} te.embedding) AS distance,
p.id AS protein_id,
p.gene_name AS gene_name,
p.organism AS organism,
@@ -303,7 +313,6 @@
JOIN accession ac ON p.id = ac.protein_id
WHERE
se.embedding_type_id = :embedding_type_id
-AND (se.embedding <-> te.embedding) < :max_distance
{not_in_clause}
{tag_filter}
),
@@ -313,6 +322,7 @@
MIN(distance) AS min_distance
FROM
annotated_results
+WHERE distance <= :max_distance
GROUP BY
protein_id
ORDER BY
@@ -377,36 +387,42 @@
f"Error processing task for accession {accession} and embedding type {embedding_type_id}: {e}")
raise

-def store_entry(self, go_terms):
-"""
-Stores the retrieved GO terms in a CSV file.
-
-Parameters
-----------
-go_terms : list of dict
-List of dictionaries containing GO term results.
-
-Raises
-------
-Exception
-If an error occurs while writing to the CSV file.
-"""
-if not go_terms:
+def store_entry(self, annotations):
+if not annotations:
self.logger.info("No valid GO terms to store.")
return

try:
+df = pd.DataFrame(annotations)
+
+if self.distance_metric == '<=>':
+df["reliability_index"] = 1 - df["distance"]
+if self.distance_metric == '<->':
+df["reliability_index"] = 0.5 / (0.5 + df["distance"])

-df = pd.DataFrame(go_terms)
+# Keep only the GO term with the highest reliability_index for each (accession, go_id)
+df = df.loc[df.groupby(["accession", "go_id"])["reliability_index"].idxmax()]

-# Write to file
+# Get parent terms using goatools
+parent_terms = set()
+for go_id in df["go_id"].unique():
+if go_id in self.go:  # Make sure the GO term is in the DAG
+parent_terms.update(p.id for p in self.go[go_id].parents)
+
+# Filter out parent terms when a more specific child term is present
+df = df[~df["go_id"].isin(parent_terms)]
+
+# Sort by reliability_index in descending order
+df = df.sort_values(by="reliability_index", ascending=False)
+
+# Save to CSV
results_path = self.results_path
if os.path.exists(results_path) and os.path.getsize(results_path) > 0:
-df.to_csv(self.results_path, mode='a', index=False, header=False)
-self.logger.info(f"Appended {len(go_terms)} entries to {results_path}.")
+df.to_csv(results_path, mode='a', index=False, header=False)
else:
df.to_csv(results_path, mode='w', index=False, header=True)
-self.logger.info(f"Created new file and stored {len(go_terms)} entries in {results_path}.")

+self.logger.info(f"Stored {len(df)} collapsed entries.")

except Exception as e:
self.logger.error(f"Error storing results in CSV: {e}")
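
The reliability_index maps each metric's distance onto a confidence-like score where 1 means a perfect match. A worked example of the two formulas used in store_entry above (distances are illustrative):

    # Cosine:    RI = 1 - distance            (distance in [0, 2])
    # Euclidean: RI = 0.5 / (0.5 + distance)  (distance in [0, inf))
    for d in (0.0, 0.1, 0.5):
        print(f"distance={d}: cosine RI={1 - d:.2f}, euclidean RI={0.5 / (0.5 + d):.2f}")
    # distance=0.0: cosine RI=1.00, euclidean RI=1.00
    # distance=0.1: cosine RI=0.90, euclidean RI=0.83
    # distance=0.5: cosine RI=0.50, euclidean RI=0.50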
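
The parent-term filter keeps predictions as specific as possible: when a term and one of its direct parents both survive the groupby, only the child is kept. A small sketch of that step with goatools; the GO IDs are illustrative:

    from goatools.base import get_godag

    godag = get_godag('go-basic.obo', optional_attrs='relationship')

    predicted = {"GO:0006397", "GO:0016071"}  # illustrative pair: a term and a broader relative
    parents = set()
    for go_id in predicted:
        if go_id in godag:
            parents.update(p.id for p in godag[go_id].parents)  # direct is_a parents only

    specific = predicted - parents  # any direct parent of another prediction is dropped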
4 changes: 2 additions & 2 deletions pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "FANTASIA"
version = "0.6.0"
version = "0.7.0"
description = "Functional ANnoTAtion based on embedding space SImilArity"
authors = ["Francisco Miguel Pérez Canales <frapercan1@alum.us.es>","Gemma Martínez Redondo <gemma.martinez@ibe.upf-csic.es>"]
readme = "README.md"
@@ -18,7 +18,7 @@ flake8-bugbear = "^23.2.13"
taskipy = "^1.10.3"
sphinx = "^7.2.6"
sphinx-rtd-theme = "^1.2.0"
protein-metamorphisms-is = "^1.7.0"
protein-metamorphisms-is = "^2.0.0"


[tool.coverage.run]