Lookup Module

Embedding LookUp Module

Overview

This module implements the EmbeddingLookUp component for transferring Gene Ontology (GO) annotations to query protein sequences based on vector similarity between sequence embeddings.

Given an HDF5 file of query embeddings, the component:
  1. Retrieves reference embeddings from a vector-aware relational database.

  2. Computes distances (GPU or CPU) to the references.

  3. Selects nearest neighbors under configurable thresholds and redundancy filters.

  4. Expands neighbors into GO-term annotations (preloaded from the DB).

  5. Persists results per query/model/layer and exports TopGO-ready TSV files.

  6. Optionally performs alignment-based post-processing to derive identity/similarity metrics.

Key features

  • Taxonomy filters (include/exclude lists; optional descendant expansion).

  • Redundancy-aware neighbor selection using MMseqs2 clusters.

  • Support for multiple embedding models with per-model thresholds and layer control.

  • Distance computation on GPU (PyTorch) or CPU (SciPy).

  • Post-processing pipeline (Polars/Pandas) with scoring and summary aggregation.

  • TopGO exports by model/layer and ensemble across models.

Inputs

  • Query embeddings in HDF5, organized by accession → type_{model_id} → layer_{k} → embedding.

  • Reference embeddings and GO annotations stored in the database.

Outputs

  • Hierarchical CSVs under <experiment_path>/raw_results/{model}/layer_{k}/<accession>.csv.

  • Global summary at <experiment_path>/summary.csv.

  • TopGO TSVs under <experiment_path>/topgo/{model}/layer_{k}/ and ensemble under topgo/ensemble/.

  • A combined FASTA of all seen sequences at <experiment_path>/<sequences_fasta|sequences.fasta>.

Configuration (selected keys)

  • experiment_path (str): Base directory for inputs/outputs.

  • embeddings_path (str): Path to HDF5 with query embeddings.

  • batch_size (int), limit_per_entry (int), precision (int).

  • lookup (dict):
    • use_gpu (bool), batch_size (int), limit_per_entry (int), topgo (bool), lookup_cache_max (int)

    • distance_metric ({"euclidean", "cosine"})

    • redundancy: identity (float), coverage (float), threads (int)

    • taxonomy: exclude (list[int]), include_only (list[int]), get_descendants (bool)

  • embedding.models (dict per logical model name):
    • enabled (bool), distance_threshold (float|None), batch_size (int|None), layer_index (list[int]|None).

  • postprocess (dict): keep_sequences (bool), store_workers (int), and summary spec with metric aggregation and weights.

  • limit_execution (int|None): Optional SQL LIMIT for reference lookup.
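
As a rough illustration, the keys listed above could be assembled into a configuration dictionary like the one below. Every value, the paths, and the model name example_model are placeholders chosen for this sketch; they are not defaults of the module, and the exact nesting is an assumption based on the key names.

```python
# Hypothetical configuration: all values are placeholders, not module defaults.
conf = {
    "experiment_path": "/tmp/fantasia_experiment",
    "embeddings_path": "/tmp/fantasia_experiment/embeddings.h5",
    "precision": 4,
    "limit_execution": None,  # no SQL LIMIT on the reference lookup
    "lookup": {
        "use_gpu": False,
        "batch_size": 64,
        "limit_per_entry": 5,
        "topgo": True,
        "lookup_cache_max": 4,
        "distance_metric": "cosine",  # or "euclidean"
        "redundancy": {"identity": 0.95, "coverage": 0.7, "threads": 4},
        "taxonomy": {"exclude": [], "include_only": [], "get_descendants": False},
    },
    "embedding": {
        "models": {
            # "example_model" is a made-up logical model name.
            "example_model": {
                "enabled": True,
                "distance_threshold": 0.3,
                "batch_size": None,
                "layer_index": [0],
            },
        },
    },
    "postprocess": {
        "keep_sequences": False,
        "store_workers": 2,
        "summary": {
            "metrics": {"reliability_index": ["max", "mean"], "identity": ["mean"]},
            "aliases": {},
            "include_counts": True,
            "weights": {"reliability_index": 0.7, "identity": 0.3},
            "weighted_prefix": "w_",
        },
    },
}
```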

Dependencies

Relies on SQLAlchemy models for sequences, proteins and embeddings; GO terms are loaded with goatools; distances can be computed with PyTorch or SciPy; aggregation uses Polars/Pandas. MMseqs2 is required for redundancy clustering.

Reference

Inspired by GoPredSim (Rostlab). See: https://github.com/Rostlab/goPredSim

class fantasia.src.lookup.EmbeddingLookUp(conf, current_date)

Bases: GPUTaskInitializer

GO annotation transfer via embedding similarity.

This component reads query embeddings (HDF5) and compares them against reference embeddings stored in a vector-aware relational database. For the closest reference sequences, it retrieves GO annotations and writes results to CSV (and optionally TopGO-ready TSV).

Features

  • Taxonomy-based filtering (include/exclude, optional descendant expansion).

  • Redundancy-aware neighbor selection (MMseqs2 clusters).

  • Multiple embedding models with per-model distance thresholds and layer control.

  • Distance computation on GPU (PyTorch) or CPU (SciPy).

  • Optional pairwise alignment post-processing (identity/similarity).

Parameters:

  • conf (dict) – Runtime configuration including paths, thresholds, model settings, and processing options.

  • current_date (str) – Timestamp-like suffix used to version output artifacts.

Notes

  • Supported distance metrics: "euclidean" and "cosine" (default: "cosine").

  • Redundancy filtering uses MMseqs2 identity/coverage thresholds when enabled.

  • GO annotations are preloaded once from the database and may be constrained by taxonomy filters.

enqueue()

Scan the query HDF5 file and publish homogeneous batches of embeddings.

This method iterates over accessions in the query HDF5, groups items by (embedding_type_id, layer_index), and publishes batches up to batch_size. Each task contains lightweight HDF5 pointers (h5_path + h5_group), the resolved model metadata, and the per-model distance threshold.

Behavior:
  • Skips accessions missing the sequence dataset.

  • Supports both layered embeddings (.../type_<id>/layer_<k>) and legacy embeddings without layers (layer_index=None).

  • Emits one payload per group of tasks with the following schema:

    {"model_id": int, "layer_index": Optional[int], "tasks": [ ... ]}.

Logging:
  • Reports the number of queries encountered.

  • Reports the number of published batches (based on group size and batch_size).

  • Warns when encountering malformed groups or missing sequences.

Raises:
  • FileNotFoundError – If the configured HDF5 file does not exist.

  • Exception – Propagates any unexpected errors during processing.
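
The grouping and batching described for enqueue() can be sketched without HDF5 at all. The helpers parse_group and build_payloads below are hypothetical names, and the assumption that a group path ends in type_<id> or type_<id>/layer_<k> is inferred from the layout described above rather than taken from the implementation.

```python
import re
from collections import defaultdict

# Matches ".../type_<id>" (legacy, no layer) or ".../type_<id>/layer_<k>".
GROUP_RE = re.compile(r"type_(\d+)(?:/layer_(\d+))?$")


def parse_group(h5_group):
    """Return (model_id, layer_index) for an HDF5 group path."""
    match = GROUP_RE.search(h5_group)
    if match is None:
        raise ValueError(f"malformed group: {h5_group!r}")
    layer = int(match.group(2)) if match.group(2) is not None else None
    return int(match.group(1)), layer


def build_payloads(h5_path, h5_groups, batch_size):
    """Group tasks by (model_id, layer_index) and split each group into batches."""
    grouped = defaultdict(list)
    for group in h5_groups:
        model_id, layer = parse_group(group)
        grouped[(model_id, layer)].append(
            {"h5_path": h5_path, "h5_group": group, "layer_index": layer}
        )
    payloads = []
    for (model_id, layer), tasks in grouped.items():
        for start in range(0, len(tasks), batch_size):
            payloads.append(
                {
                    "model_id": model_id,
                    "layer_index": layer,
                    "tasks": tasks[start:start + batch_size],
                }
            )
    return payloads
```

Each payload is homogeneous in (model_id, layer_index), so a consumer can load a single reference matrix per payload.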

export_topgo()

Export TopGO-compatible TSV files: topgo/{model}/layer_{k}/{category}.topgo

Columns: accession, go_term, reliability_index
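
A minimal sketch of this export, using only the standard library. The function write_topgo, the input row keys, the category label "P", and the choice to omit a header row are assumptions made for illustration; the real export may differ in any of these.

```python
import csv
import tempfile
from collections import defaultdict
from pathlib import Path


def write_topgo(rows, out_root):
    """Write one tab-separated file per (model, layer, category)."""
    by_file = defaultdict(list)
    for row in rows:
        layer = "legacy" if row["layer_index"] is None else row["layer_index"]
        key = (row["model_name"], f"layer_{layer}", row["category"])
        by_file[key].append(
            (row["accession"], row["go_id"], row["reliability_index"])
        )
    written = []
    for (model, layer_dir, category), triples in by_file.items():
        path = Path(out_root) / "topgo" / model / layer_dir / f"{category}.topgo"
        path.parent.mkdir(parents=True, exist_ok=True)
        with path.open("w", newline="") as handle:
            writer = csv.writer(handle, delimiter="\t", lineterminator="\n")
            writer.writerows(triples)  # assumption: no header row
        written.append(path)
    return written


# Example with made-up data.
rows = [
    {
        "accession": "Q1",
        "go_id": "GO:0008150",
        "category": "P",
        "reliability_index": 0.9,
        "model_name": "example_model",
        "layer_index": 0,
    }
]
with tempfile.TemporaryDirectory() as tmp:
    written = write_topgo(rows, tmp)
    relative = written[0].relative_to(tmp).as_posix()
    lines = written[0].read_text().splitlines()
```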

export_topgo_ensemble()

Export ensemble TopGO TSV files by collapsing across models/layers.

Keeps the best reliability_index per (accession, go_id, category) and writes: topgo/ensemble/{category}.topgo
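
The collapse step can be pictured as a keyed maximum. The sketch below assumes that "best" means the highest reliability_index and that rows are plain dictionaries; collapse_ensemble is a hypothetical helper, not the module's function.

```python
def collapse_ensemble(rows):
    """Keep the row with the highest reliability_index per (accession, go_id, category)."""
    best = {}
    for row in rows:
        key = (row["accession"], row["go_id"], row["category"])
        current = best.get(key)
        if current is None or row["reliability_index"] > current["reliability_index"]:
            best[key] = row
    return list(best.values())
```

Rows from different models or layers that agree on the key compete with each other, so the ensemble file holds at most one line per accession and GO term within a category.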

generate_clusters()

Generate non-redundant sequence clusters using MMseqs2.

Steps:
  1. Collect all protein sequences from the database and the embeddings HDF5.

  2. Write them into a temporary FASTA file.

  3. Run MMseqs2 createdb, cluster, and createtsv with the configured thresholds.

  4. Load the resulting cluster assignments into in-memory structures.

Outputs:
  • self.clusters: pandas DataFrame of raw cluster assignments.

  • self.clusters_by_id: pandas DataFrame indexed by sequence ID → cluster ID.

  • self.clusters_by_cluster: dict mapping cluster ID → set of sequence IDs.

Configuration:
  • redundancy_filter (float): sequence identity threshold.

  • alignment_coverage (float): alignment coverage threshold.

  • threads (int): number of threads to use.

Raises:

Exception – If MMseqs2 fails or the clustering pipeline encounters an error.
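
The steps above might look roughly like the following. The command lists use the standard MMseqs2 subcommands and the --min-seq-id, -c, and --threads options; the file names and the helpers mmseqs_commands and load_clusters are hypothetical. For brevity the sketch loads assignments into plain dictionaries rather than pandas DataFrames, and it only builds the commands without running them.

```python
import csv
import io
from collections import defaultdict


def mmseqs_commands(fasta, workdir, identity, coverage, threads):
    """Build the createdb, cluster, and createtsv command lines (not executed here)."""
    db = f"{workdir}/seqDB"
    clu = f"{workdir}/cluDB"
    tmp = f"{workdir}/tmp"
    tsv = f"{workdir}/clusters.tsv"
    return [
        ["mmseqs", "createdb", fasta, db],
        ["mmseqs", "cluster", db, clu, tmp,
         "--min-seq-id", str(identity), "-c", str(coverage),
         "--threads", str(threads)],
        ["mmseqs", "createtsv", db, db, clu, tsv],
    ]


def load_clusters(tsv_text):
    """Parse createtsv output (representative <TAB> member) into two lookups."""
    by_id = {}
    by_cluster = defaultdict(set)
    for row in csv.reader(io.StringIO(tsv_text), delimiter="\t"):
        if len(row) != 2:
            continue  # skip blank or malformed lines
        representative, member = row
        by_id[member] = representative
        by_cluster[representative].add(member)
    return by_id, dict(by_cluster)
```

In a real run each command list could be passed to subprocess.run(cmd, check=True), so that a failing MMseqs2 step raises an exception.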

load_model(model_type)

Placeholder: load a model into memory if required.

load_model_definitions()

Initialize self.types by matching DB embedding types with configuration.

Behavior:
  • Queries available embedding types from the database.

  • Matches DB models with those defined in the configuration.

  • Keeps only models that appear in both sources and are marked as enabled.

  • For each model, inspects available layers in the HDF5 file (if present) and determines effective layers to be used.

Logging:
  • Logs when a DB model is skipped because it is missing from the config or disabled.

  • Warns if no enabled models remain after matching.

  • For each enabled model, logs:
    • model name and DB id

    • distance_threshold (from config)

    • enabled_layers (from config, or ALL if unrestricted)

    • available_layers_in_h5 (discovered by scanning the HDF5 file)

    • effective_layers (intersection if restricted, else ALL)

Raises:

Exception – If database query fails.
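
The matching and layer-resolution rules can be sketched as two small functions. Both match_models and effective_layers are hypothetical helpers, and the shape of the inputs (a name-to-id mapping for DB types, a per-model config dict) is an assumption for illustration.

```python
def match_models(db_types, conf_models):
    """Keep models present in both the DB and the config, and enabled in the config."""
    selected = {}
    for name, db_id in db_types.items():
        cfg = conf_models.get(name)
        if cfg is None or not cfg.get("enabled", False):
            continue  # missing from config or disabled
        selected[db_id] = {
            "name": name,
            "distance_threshold": cfg.get("distance_threshold"),
            "layer_index": cfg.get("layer_index"),
        }
    return selected


def effective_layers(enabled_layers, available_layers):
    """Intersect configured layers with those found in the HDF5 file.

    enabled_layers=None means no restriction, so all available layers are used.
    """
    available = sorted(set(available_layers))
    if enabled_layers is None:
        return available
    wanted = set(enabled_layers)
    return [layer for layer in available if layer in wanted]
```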

post_processing() → str

Aggregate per-accession results, compute weighted scores, and write a global summary.

Workflow

  1. Locate all CSV shards under raw_results/** across models and layers.

  2. For each accession:
    • Concatenate all shards belonging to it.

    • Compute per-GO aggregation metrics defined in the configuration.

    • Apply weighting scheme (if provided) to compute per-metric contributions and a global final_score.

    • Collect associated proteins and counts.

    • Write results incrementally to a global summary.csv.

  3. Trigger downstream exports:
    • Per model/layer TopGO files: topgo/{model}/layer_{k}/{category}.topgo.

    • Ensemble TopGO files: topgo/ensemble/{category}.topgo.

Configuration (conf['postprocess']['summary'])

  • metrics (dict)

    {column_name: [aggregation_fn]}. Supported aggregation functions: "mean", "max", "min" (others are ignored).

  • aliases (dict)

    Optional renaming for metrics used in weighting.

  • include_counts (bool)

    If True, add a normalized neighbor/support count metric.

  • weights (dict)

    Raw weights per metric, alias, or metric+agg combination. The weights are normalized and then used to compute final_score and the weighted columns, which are prefixed with weighted_prefix.

  • weighted_prefix (str)

    Prefix for weighted metric outputs (default: "w_").

Output

  • <experiment_path>/summary.csv (incremental write).

  • TopGO export files under topgo/.
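
To illustrate how the metrics and weights settings described under Configuration might interact, here is a small sketch in plain Python. The output column naming ({fn}_{column}), the normalization of weights by their sum, and the helpers aggregate and weighted_score are all assumptions; the module's actual formula is not documented here.

```python
from statistics import mean

AGGREGATIONS = {"mean": mean, "max": max, "min": min}


def aggregate(rows, metrics):
    """Aggregate each configured column over the rows of one GO term."""
    result = {}
    for column, functions in metrics.items():
        values = [row[column] for row in rows]
        for fn in functions:
            if fn in AGGREGATIONS:  # unsupported names are ignored
                result[f"{fn}_{column}"] = AGGREGATIONS[fn](values)
    return result


def weighted_score(aggregated, weights, prefix="w_"):
    """Normalize raw weights to sum to 1 and add weighted columns plus final_score."""
    total = sum(weights.values())
    result = dict(aggregated)
    final_score = 0.0
    for name, raw_weight in weights.items():
        contribution = aggregated[name] * raw_weight / total
        result[prefix + name] = contribution
        final_score += contribution
    result["final_score"] = final_score
    return result
```

With raw weights of 3 and 1, for example, the first metric contributes three quarters of final_score and the second one quarter.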

Returns:

Absolute path to the written summary.csv, or an empty string if no raw input files were found.

Return type:

str

preload_annotations()

Preload GO annotations from the database into memory.

Behavior:
  • Queries sequence, protein, protein_go_term_annotation, and go_terms.

  • Groups annotations by sequence ID.

  • Skips entries whose taxonomy_id is in self.exclude_taxon_ids.

Each stored annotation includes:
  • sequence (str)

  • go_id (str)

  • category (str)

  • evidence_code (str)

  • go_description (str)

  • protein_id (int)

  • organism (str)

  • taxonomy_id (int)

  • gene_name (str)

Results are stored in:
self.go_annotations (dict[int, list[dict]])

Mapping from sequence_id → list of annotation dicts.
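
The grouping and taxonomy exclusion can be sketched as follows, with joined query results represented as plain dictionaries. The helper group_annotations and the presence of a sequence_id key on each row are assumptions; the database query itself is not shown.

```python
from collections import defaultdict


def group_annotations(rows, exclude_taxon_ids):
    """Group annotation rows by sequence_id, skipping excluded taxa."""
    grouped = defaultdict(list)
    for row in rows:
        if row["taxonomy_id"] in exclude_taxon_ids:
            continue  # taxonomy filter
        annotation = {k: v for k, v in row.items() if k != "sequence_id"}
        grouped[row["sequence_id"]].append(annotation)
    return dict(grouped)
```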

process(payload: dict) → list[dict]

Compute nearest neighbors for a homogeneous (model_id, layer) batch.

This method:

  • Loads query embeddings from HDF5 (or directly from payload in legacy mode).

  • Computes distances against the cached reference matrix for the given (model_id, layer_index).

  • Optionally applies redundancy filtering (exclude cluster members).

  • Selects nearest neighbors based on distance threshold and per-entry limit.

  • Returns a compact list of neighbor hits.

Expected Payload

{
  "model_id": int,
  "layer_index": Optional[int],
  "tasks": [
    {
      "h5_path": str,
      "h5_group": str,
      "model_name": str,
      "distance_threshold": Optional[float],
      "layer_index": Optional[int]
    },
    # Legacy form (discouraged):
    {"embedding": np.ndarray, "accession": str, ...}
  ]
}

Distance computation

  • Uses GPU (PyTorch) if conf['use_gpu'] is True.

  • Otherwise falls back to CPU (SciPy cdist).

  • self.distance_metric must be either "cosine" or "euclidean".

Redundancy

  • If MMseqs2 clustering is configured, neighbors belonging to the same redundancy cluster as the query accession are excluded.

Neighbor selection

  • Results are sorted by ascending distance.

  • Filtered by per-model distance_threshold (if provided).

  • Truncated to limit_per_entry.

Returns:

One record per selected reference neighbor, with fields:

  • accession (str)

  • ref_sequence_id (int)

  • distance (float)

  • model_name (str)

  • embedding_type_id (int)

  • layer_index (Optional[int])

Return type:

list[dict]

Notes

GO expansion and sequence retrieval are deferred to store_entry() to keep payloads minimal.
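
The distance and selection logic described for process() can be approximated for a single query in pure Python. This is only a sketch: the real component works on batched matrices with PyTorch or SciPy, and the helper select_neighbors, the inclusive threshold comparison, and the use of integer reference IDs for the redundancy exclusion are assumptions.

```python
import math


def cosine_distance(a, b):
    """1 minus cosine similarity; assumes non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return 1.0 - dot / (norm_a * norm_b)


METRICS = {"cosine": cosine_distance, "euclidean": math.dist}


def select_neighbors(query, references, metric="cosine",
                     distance_threshold=None, limit_per_entry=5,
                     excluded_ids=frozenset()):
    """Return the closest references, sorted by ascending distance."""
    distance_fn = METRICS[metric]
    hits = []
    for ref_id, vector in references.items():
        if ref_id in excluded_ids:
            continue  # redundancy filter: same cluster as the query
        distance = distance_fn(query, vector)
        if distance_threshold is not None and distance > distance_threshold:
            continue
        hits.append((distance, ref_id))
    hits.sort()
    return [
        {"ref_sequence_id": ref_id, "distance": distance}
        for distance, ref_id in hits[:limit_per_entry]
    ]
```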

retrieve_cluster_members(accession: str) → set

Retrieve all sequence IDs belonging to the same MMseqs2 cluster as the given accession.

Parameters:

accession (str) – Sequence ID used in clustering (must match the identifier in the FASTA header).

Returns:

Set of sequence IDs in the same cluster. Returns an empty set if the accession is not found or has no cluster members.

Return type:

set of str
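
A sketch of this lookup over plain dictionaries (the module itself keeps a pandas DataFrame for the ID-to-cluster mapping). The standalone function signature and the fact that the returned set includes the accession itself are assumptions.

```python
def retrieve_cluster_members(accession, clusters_by_id, clusters_by_cluster):
    """Return all sequence IDs sharing a cluster with the accession, or an empty set."""
    cluster_id = clusters_by_id.get(accession)
    if cluster_id is None:
        return set()
    # Copy so callers cannot mutate the cached cluster.
    return set(clusters_by_cluster.get(cluster_id, ()))
```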

store_entry(annotations_or_hits: list[dict]) → None

Persist per-(model, layer) results and update the global FASTA index.

Input:
Either
  1. compact neighbor hits produced by process() (preferred), or

  2. legacy, already-expanded annotation rows.

Pipeline:
  1. If the input consists of compact hits, expand them into per-GO rows using the preloaded self.go_annotations; lazily read the query sequence from HDF5 (once per accession) and fetch the reference sequence from the cache.

  2. Cast types with Polars, compute a reliability index from distance (metric-dependent), and optionally compute pairwise alignment metrics (identity, similarity, etc.) when both sequences are available.

  3. Append execution metadata (distance metric and per-model threshold), then write CSV shards hierarchically under: raw_results/{model_name}/layer_{k or 'legacy'}/{accession}.csv.

  4. Update a global FASTA containing all queries (>Q{idx}) and references (>R{idx}) with stable indices for downstream tools.
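
Two pieces of this pipeline lend themselves to a short sketch: the metric-dependent reliability index and the stable FASTA headers. The formulas below are the ones used by goPredSim (1 - d for cosine, 0.5 / (0.5 + d) for Euclidean) and are assumed here; this module may compute the index differently. The FastaIndex class and its zero-based numbering are hypothetical.

```python
def reliability_index(distance, metric):
    """Map a distance to a score where higher means more reliable (goPredSim-style)."""
    if metric == "cosine":
        return 1.0 - distance
    if metric == "euclidean":
        return 0.5 / (0.5 + distance)
    raise ValueError(f"unsupported metric: {metric!r}")


class FastaIndex:
    """Assign stable >Q{idx} / >R{idx} headers to queries and references."""

    def __init__(self):
        self._indices = {"Q": {}, "R": {}}

    def header(self, prefix, key):
        table = self._indices[prefix]
        if key not in table:
            table[key] = len(table)  # next free index for this prefix
        return f">{prefix}{table[key]}"
```

Because an index is assigned only the first time a key is seen, repeated calls for the same query or reference return the same header.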

Configuration:
  • postprocess.keep_sequences (bool): keep or drop raw sequences in CSV.

  • precision (int): float formatting precision for CSV output.

Raises:
  • Exception – Propagates unexpected errors after logging context. Partial writes may remain for accessions processed before the failure.

unload_model(model_type)

Placeholder: unload a model from memory if required.