Spaces:
Running
on
Zero
Running
on
Zero
import spaces | |
import gradio as gr | |
import torch | |
import os | |
import tempfile | |
import time | |
import polars as pl | |
import numpy as np | |
import logging | |
from pathlib import Path | |
from omegaconf import OmegaConf, DictConfig | |
from gradio_log import Log | |
# --- InstaNovo Imports --- | |
try: | |
from instanovo.transformer.model import InstaNovo | |
from instanovo.diffusion.multinomial_diffusion import InstaNovoPlus | |
from instanovo.utils import SpectrumDataFrame, ResidueSet, Metrics | |
from instanovo.transformer.dataset import SpectrumDataset, collate_batch | |
from instanovo.inference import ( | |
GreedyDecoder, | |
KnapsackBeamSearchDecoder, | |
Knapsack, | |
ScoredSequence, | |
Decoder, | |
) | |
from instanovo.inference.diffusion import DiffusionDecoder | |
from instanovo.constants import ( | |
MASS_SCALE, | |
MAX_MASS, | |
DIFFUSION_START_STEP, | |
) | |
from torch.utils.data import DataLoader | |
import torch.nn.functional as F # For padding | |
except ImportError as e: | |
raise ImportError(f"Failed to import InstaNovo components: {e}") | |
# --- Configuration ---
# Identifiers handed to `from_pretrained` for the two models.
TRANSFORMER_MODEL_ID = "instanovo-v1.1.0"
DIFFUSION_MODEL_ID = "instanovoplus-v1.1.0-alpha"
# Directory the knapsack is loaded from / saved to.
KNAPSACK_DIR = Path("./knapsack_cache")

# Determine device; half precision is only enabled when running on CUDA.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
FP16 = DEVICE == "cuda"

# --- Global Variables (Load Models and Knapsack Once) ---
# All of these are populated by `load_models_and_knapsack()`.
INSTANOVO: InstaNovo | None = None
INSTANOVO_CONFIG: DictConfig | None = None
INSTANOVOPLUS: InstaNovoPlus | None = None
INSTANOVOPLUS_CONFIG: DictConfig | None = None
KNAPSACK: Knapsack | None = None
RESIDUE_SET: ResidueSet | None = None

# --- Assets ---
gr.set_static_paths(paths=[Path.cwd().absolute() / "assets"])

# Create gradio temporary directory.
# `parents`/`exist_ok` avoid the check-then-create race and a failure when the
# directory (or its parent) appears or is missing at start-up.
temp_dir = Path("/tmp/gradio")
temp_dir.mkdir(parents=True, exist_ok=True)

# Logging configuration
# TODO: create logfile per user/session
# see https://www.gradio.app/guides/resource-cleanup
log_file = "/tmp/instanovo_gradio_log.txt"
Path(log_file).touch()
logger = logging.getLogger("instanovo")
logger.setLevel(logging.INFO)
file_handler = logging.FileHandler(log_file)
file_handler.setLevel(logging.INFO)
formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)
def _generate_knapsack():
    """Construct a knapsack from the residues of ``RESIDUE_SET``.

    Special tokens and residues whose mass is not strictly positive are
    excluded from the masses used for construction; the full
    residue-to-index mapping is still passed along.

    Raises:
        ValueError: if no residue with a positive mass is left.
    """
    residue_masses_for_calc = dict(RESIDUE_SET.residue_masses.copy())
    special_and_nonpositive = list(RESIDUE_SET.special_tokens) + [
        k for k, v in residue_masses_for_calc.items() if v <= 0
    ]
    if special_and_nonpositive:
        logger.info(f"Excluding special/non-positive mass residues from knapsack: {special_and_nonpositive}")
        for res in set(special_and_nonpositive):
            residue_masses_for_calc.pop(res, None)
    if not residue_masses_for_calc:
        raise ValueError("No valid residues with positive mass found for knapsack generation.")
    logger.info("Generating knapsack. This will take a few minutes, please be patient.")
    return Knapsack.construct_knapsack(
        residue_masses=residue_masses_for_calc,
        residue_indices=RESIDUE_SET.residue_to_index,
        max_mass=MAX_MASS,
        mass_scale=MASS_SCALE,
    )


def load_models_and_knapsack():
    """Loads the InstaNovo models and generates/loads the knapsack.

    Populates the module-level globals. Anything that is already loaded is
    left untouched, so the function can be called again to retry a component
    that previously failed.

    Raises:
        gr.Error: if the Transformer model cannot be loaded. Failures of the
            Diffusion model or the knapsack only emit a ``gr.Warning`` and
            leave the corresponding global set to ``None``.
    """
    global INSTANOVO, KNAPSACK, INSTANOVO_CONFIG, RESIDUE_SET, INSTANOVOPLUS, INSTANOVOPLUS_CONFIG

    if INSTANOVO is not None and INSTANOVOPLUS is not None:
        logger.info("Models already loaded.")
        if KNAPSACK is not None:
            return  # All loaded
        logger.info("Models loaded, but knapsack needs loading/generation.")

    # --- Load Transformer Model ---
    if INSTANOVO is None:
        logger.info(f"Loading InstaNovo (Transformer) model: {TRANSFORMER_MODEL_ID} to {DEVICE}...")
        try:
            # Work on locals so a failure half-way never leaves a partially
            # initialised model in the globals.
            model, model_config = InstaNovo.from_pretrained(TRANSFORMER_MODEL_ID)
            model.to(DEVICE)
            model.eval()
        except Exception as e:
            logger.error(f"Error loading Transformer model: {e}", exc_info=True)
            raise gr.Error(
                f"Failed to load InstaNovo Transformer model: {TRANSFORMER_MODEL_ID}. Error: {e}"
            ) from e
        INSTANOVO, INSTANOVO_CONFIG = model, model_config
        RESIDUE_SET = INSTANOVO.residue_set
        logger.info("Transformer model loaded successfully.")
    else:
        logger.info("Transformer model already loaded.")

    # --- Load Diffusion Model ---
    if INSTANOVOPLUS is None:
        logger.info(f"Loading InstaNovo+ (Diffusion) model: {DIFFUSION_MODEL_ID} to {DEVICE}...")
        try:
            plus_model, plus_config = InstaNovoPlus.from_pretrained(DIFFUSION_MODEL_ID)
            plus_model.to(DEVICE)
            plus_model.eval()
            if RESIDUE_SET is None:
                RESIDUE_SET = plus_model.residues
            elif plus_model.residues != RESIDUE_SET:
                logger.warning("Residue sets between Transformer and Diffusion models differ. Using Transformer's set.")
            # Publish only once everything above succeeded.
            INSTANOVOPLUS, INSTANOVOPLUS_CONFIG = plus_model, plus_config
            logger.info("Diffusion model loaded successfully.")
        except Exception as e:
            logger.error(f"Error loading Diffusion model: {e}", exc_info=True)
            gr.Warning(f"Failed to load InstaNovo+ Diffusion model ({DIFFUSION_MODEL_ID}): {e}. Diffusion modes will be unavailable.")
            INSTANOVOPLUS = None
    else:
        logger.info("Diffusion model already loaded.")

    # --- Knapsack Handling ---
    if KNAPSACK is not None:
        logger.info("Knapsack already loaded.")
        return
    if INSTANOVO is None:
        logger.warning("Transformer model not loaded, skipping Knapsack loading/generation.")
        return
    if RESIDUE_SET is None:
        return

    knapsack_files = ("parameters.pkl", "masses.npy", "chart.npy")
    if all((KNAPSACK_DIR / file_name).exists() for file_name in knapsack_files):
        logger.info(f"Loading pre-generated knapsack from {KNAPSACK_DIR}...")
        try:
            KNAPSACK = Knapsack.from_file(str(KNAPSACK_DIR))
        except Exception as e:
            logger.warning(f"Error loading knapsack: {e}. Will attempt to regenerate.")
            KNAPSACK = None
        else:
            logger.info("Knapsack loaded successfully.")
            return

    logger.info("Knapsack not found or failed to load. Generating knapsack...")
    try:
        KNAPSACK = _generate_knapsack()
    except Exception as e:
        logger.error(f"Error generating knapsack: {e}", exc_info=True)
        gr.Warning(f"Failed to generate Knapsack. Knapsack Beam Search will not be available. Error: {e}")
        KNAPSACK = None
        return

    # Saving is kept separate from generation: a failed write must not throw
    # away the knapsack that was just built and is usable in memory.
    logger.info(f"Knapsack generated. Saving to {KNAPSACK_DIR}...")
    try:
        KNAPSACK.save(str(KNAPSACK_DIR))
    except Exception as e:
        logger.warning(
            f"Knapsack generated but could not be saved to {KNAPSACK_DIR}: {e}",
            exc_info=True,
        )
    else:
        logger.info("Knapsack saved.")
# Load models and knapsack when the script starts, i.e. at import time and
# before any prediction request is handled.
load_models_and_knapsack()
def create_inference_config(
    input_path: str,
    output_path: str,
) -> DictConfig:
    """Build the OmegaConf configuration for one prediction request.

    A fixed set of default settings is merged with the request-specific
    values; the merged configuration is logged as YAML before being returned.

    Args:
        input_path: Path of the spectrum file, stored as ``data_path``.
        output_path: Path of the result file, stored as ``output_path``.

    Returns:
        The merged ``DictConfig``.
    """
    # Alternative residue/modification spellings and their [UNIMOD:n] form.
    residue_remapping = {
        "M(ox)": "M[UNIMOD:35]",
        "M(+15.99)": "M[UNIMOD:35]",
        "S(p)": "S[UNIMOD:21]",
        "T(p)": "T[UNIMOD:21]",
        "Y(p)": "Y[UNIMOD:21]",
        "S(+79.97)": "S[UNIMOD:21]",
        "T(+79.97)": "T[UNIMOD:21]",
        "Y(+79.97)": "Y[UNIMOD:21]",
        "Q(+0.98)": "Q[UNIMOD:7]",
        "N(+0.98)": "N[UNIMOD:7]",
        "Q(+.98)": "Q[UNIMOD:7]",
        "N(+.98)": "N[UNIMOD:7]",
        "C(+57.02)": "C[UNIMOD:4]",
        "(+42.01)": "[UNIMOD:1]",
        "(+43.01)": "[UNIMOD:5]",
        "(-17.03)": "[UNIMOD:385]",
    }
    # Input column names and the internal column names they are mapped to.
    column_map = {
        "Modified sequence": "modified_sequence",
        "MS/MS m/z": "precursor_mz",
        "Mass": "precursor_mass",
        "Charge": "precursor_charge",
        "Mass values": "mz_array",
        "Mass spectrum": "mz_array",
        "Intensity": "intensity_array",
        "Raw intensity spectrum": "intensity_array",
        "Scan number": "scan_number",
    }
    index_columns = [
        "scan_number",
        "precursor_mz",
        "precursor_charge",
        "retention_time",
        "spectrum_id",
        "experiment_name",
    ]
    defaults = {
        "data_path": None,
        "instanovo_model": TRANSFORMER_MODEL_ID,
        "instanovoplus_model": DIFFUSION_MODEL_ID,
        "output_path": None,
        "knapsack_path": str(KNAPSACK_DIR),
        "denovo": True,
        "refine": True,
        "num_beams": 1,
        "max_length": 40,
        "max_charge": 10,
        "isotope_error_range": [0, 1],
        "subset": 1.0,
        "use_knapsack": False,
        "save_beams": False,
        "batch_size": 64,
        "device": DEVICE,
        "fp16": FP16,
        "log_interval": 500,
        "use_basic_logging": True,
        "filter_precursor_ppm": 20,
        "filter_confidence": 1e-4,
        "filter_fdr_threshold": 0.05,
        "suppressed_residues": None,
        "disable_terminal_residues_anywhere": True,
        "residue_remapping": residue_remapping,
        "column_map": column_map,
        "index_columns": index_columns,
    }
    request_values = {
        "data_path": input_path,
        "output_path": output_path,
        "device": DEVICE,
        "fp16": FP16,
        "denovo": True,
    }

    merged = OmegaConf.merge(OmegaConf.create(defaults), request_values)
    logger.info(f"Created inference config:\n{OmegaConf.to_yaml(merged)}")
    return merged
def _get_transformer_decoder(selection: str, config: DictConfig) -> tuple[Decoder, int, bool]:
    """Instantiate the transformer decoder named by the UI selection.

    Args:
        selection: Label of the chosen decoding method; matched on the
            substrings "Greedy" and "Knapsack", in that order.
        config: Inference configuration; only read for the greedy decoder.

    Returns:
        ``(decoder, num_beams, use_knapsack)``: 1 beam without knapsack for
        greedy decoding, 5 beams with knapsack for knapsack beam search.

    Raises:
        gr.Error: if the Transformer model, or the knapsack required by the
            selection, is not loaded.
        ValueError: if the selection matches neither decoding method.
    """
    if INSTANOVO is None:
        raise gr.Error("InstaNovo Transformer model not loaded.")

    wants_greedy = "Greedy" in selection
    if not wants_greedy and "Knapsack" not in selection:
        raise ValueError(f"Unknown transformer decoder selection: {selection}")

    if wants_greedy:
        beam_count, knapsack_in_use = 1, False
        decoder = GreedyDecoder(
            model=INSTANOVO,
            mass_scale=MASS_SCALE,
            suppressed_residues=config.get("suppressed_residues", None),
            disable_terminal_residues_anywhere=config.get("disable_terminal_residues_anywhere", True),
        )
    else:
        if KNAPSACK is None:
            raise gr.Error("Knapsack is not available. Cannot use Knapsack Beam Search.")
        # Default beam size for knapsack
        beam_count, knapsack_in_use = 5, True
        decoder = KnapsackBeamSearchDecoder(model=INSTANOVO, knapsack=KNAPSACK)

    logger.info(f"Using Transformer decoder: {type(decoder).__name__} (Num beams: {beam_count}, Use Knapsack: {knapsack_in_use})")
    return decoder, beam_count, knapsack_in_use
def run_transformer_prediction(dl, config, transformer_decoder_selection):
    """Runs prediction using only the transformer model.

    Args:
        dl: Sized iterable of batches; each batch is a 5-tuple whose first
            two items are the spectra and precursor tensors.
        config: Inference configuration (``max_length``,
            ``filter_precursor_ppm`` and ``isotope_error_range`` are read).
        transformer_decoder_selection: UI label of the decoding method.

    Returns:
        List with the decoder output for every spectrum, in input order.

    Raises:
        gr.Error: if the residue set or the required decoder is unavailable.
    """
    if RESIDUE_SET is None:
        raise gr.Error("ResidueSet not loaded.")
    decoder, num_beams, _ = _get_transformer_decoder(transformer_decoder_selection, config)

    # Loop-invariant decoding arguments.
    mass_tolerance = config.get("filter_precursor_ppm", 20) * 1e-6
    max_isotope = config.isotope_error_range[1] if config.isotope_error_range else 1
    total_batches = len(dl)

    results_list: list[ScoredSequence | list] = []
    start_time = time.time()
    for batch_number, batch in enumerate(dl, start=1):
        # The padding mask (third item) is not consumed by the transformer
        # decoders, so it is not copied to the device.
        spectra, precursors, _, _, _ = batch
        with torch.no_grad(), torch.amp.autocast(DEVICE, dtype=torch.float16, enabled=FP16):
            batch_predictions = decoder.decode(
                spectra=spectra.to(DEVICE),
                precursors=precursors.to(DEVICE),
                beam_size=num_beams,
                max_length=config.max_length,
                mass_tolerance=mass_tolerance,
                max_isotope=max_isotope,
                return_beam=False,  # Only top result
            )
        results_list.extend(batch_predictions)
        if batch_number % 10 == 0 or batch_number == total_batches:
            logger.info(f"Transformer prediction: Processed batch {batch_number}/{total_batches}")
    logger.info(f"Transformer prediction finished in {time.time() - start_time:.2f} seconds.")
    return results_list
def run_diffusion_prediction(dl, config):
    """Runs prediction using only the diffusion model.

    Args:
        dl: Sized iterable of batches; each batch is a 5-tuple starting with
            spectra, precursors and the spectra padding mask.
        config: Inference configuration (``isotope_error_range`` is read).

    Returns:
        One ``ScoredSequence`` per spectrum, in input order. ``mass_error``
        holds the smallest absolute value returned by
        ``Metrics.matches_precursor`` (NaN when it could not be computed);
        ``token_log_probabilities`` is always empty.

    Raises:
        gr.Error: if the diffusion model or the residue set is not loaded.
    """
    if INSTANOVOPLUS is None or RESIDUE_SET is None:
        raise gr.Error("InstaNovo+ Diffusion model not loaded.")
    diffusion_decoder = DiffusionDecoder(model=INSTANOVOPLUS)
    logger.info(f"Using decoder: {type(diffusion_decoder).__name__}")

    results_sequences = []
    results_log_probs = []
    # Only the precursor tensors are needed after decoding, so keep those
    # instead of holding every batch (spectra included) in memory.
    precursor_chunks = []
    total_batches = len(dl)
    start_time = time.time()
    for batch_number, batch in enumerate(dl, start=1):
        spectra, precursors, spectra_mask, _, _ = batch
        precursor_chunks.append(precursors)
        with torch.no_grad(), torch.amp.autocast(DEVICE, dtype=torch.float16, enabled=FP16):
            batch_sequences, batch_log_probs = diffusion_decoder.decode(
                spectra=spectra.to(DEVICE),
                spectra_padding_mask=spectra_mask.to(DEVICE),
                precursors=precursors.to(DEVICE),
                initial_sequence=None,
            )
        results_sequences.extend(batch_sequences)
        results_log_probs.extend(batch_log_probs)
        if batch_number % 10 == 0 or batch_number == total_batches:
            logger.info(f"Diffusion prediction: Processed batch {batch_number}/{total_batches}")
    logger.info(f"Diffusion prediction finished in {time.time() - start_time:.2f} seconds.")

    if not precursor_chunks:
        # Nothing was decoded; torch.cat would fail on an empty list.
        return []

    metrics_calc = Metrics(RESIDUE_SET, config.isotope_error_range)
    all_precursors = torch.cat(precursor_chunks, dim=0)
    scored_results = []
    for idx, (seq, logp) in enumerate(zip(results_sequences, results_log_probs)):
        # Columns 1 and 2 of the precursor tensor are used as m/z and charge.
        prec_mz = all_precursors[idx, 1].item()
        prec_ch = int(all_precursors[idx, 2].item())
        try:
            _, delta_mass_list = metrics_calc.matches_precursor(seq, prec_mz, prec_ch)
            min_abs_ppm = min(abs(p) for p in delta_mass_list) if delta_mass_list else float("nan")
        except Exception as e:
            logger.warning(f"Could not calculate delta mass for diffusion prediction {idx}: {e}")
            min_abs_ppm = float("nan")
        scored_results.append(
            ScoredSequence(
                sequence=seq,
                mass_error=min_abs_ppm,
                sequence_log_probability=logp,
                token_log_probabilities=[],
            )
        )
    return scored_results
def _encode_initial_sequences(transformer_results, max_len):
    """Encode transformer predictions as PAD-padded index tensors of length ``max_len``.

    Entries without a usable prediction become an all-PAD sequence; longer
    predictions are truncated (with a warning).
    """
    pad_index = RESIDUE_SET.PAD_INDEX
    encoded_rows = []
    for res in transformer_results:
        if isinstance(res, ScoredSequence) and res.sequence:
            # Encode sequence *without* EOS for diffusion input.
            encoded = RESIDUE_SET.encode(res.sequence, add_eos=False, return_tensor="pt")
        else:
            # If transformer failed, provide a dummy PAD sequence
            encoded = torch.full((max_len,), pad_index, dtype=torch.long)
        current_len = encoded.shape[0]
        if current_len > max_len:
            logger.warning(f"Transformer prediction exceeded diffusion max length ({max_len}). Truncating.")
            encoded = encoded[:max_len]
        elif current_len < max_len:
            padding = torch.full((max_len - current_len,), pad_index, dtype=torch.long)
            encoded = torch.cat((encoded, padding))
        encoded_rows.append(encoded)
    return encoded_rows


def run_refinement_prediction(dl, config, transformer_decoder_selection):
    """Runs transformer prediction followed by diffusion refinement.

    Args:
        dl: Iterable of batches; each batch is a 5-tuple starting with
            spectra, precursors and the spectra padding mask. All batches are
            materialised because they are traversed twice.
        config: Inference configuration.
        transformer_decoder_selection: UI label of the transformer decoder.

    Returns:
        One dict per spectrum with the keys ``transformer_prediction``,
        ``transformer_log_probability``, ``refined_prediction``,
        ``refined_log_probability`` and ``refined_delta_mass_ppm``.

    Raises:
        gr.Error: if a required component is not loaded, if there is nothing
            to refine, or if the number of predictions does not match the
            number of spectra (results could otherwise be paired with the
            wrong spectrum).
    """
    required = [
        ("Transformer model", INSTANOVO),
        ("Diffusion model", INSTANOVOPLUS),
        ("Residue set", RESIDUE_SET),
        ("Diffusion config", INSTANOVOPLUS_CONFIG),
    ]
    missing = [label for label, value in required if value is None]
    if missing:
        raise gr.Error(f"Cannot run refinement: {', '.join(missing)} not loaded.")

    # 1. Run Transformer Prediction (using selected decoder)
    logger.info(f"Running Transformer prediction ({transformer_decoder_selection}) for refinement...")
    transformer_decoder, num_beams, _ = _get_transformer_decoder(transformer_decoder_selection, config)
    mass_tolerance = config.get("filter_precursor_ppm", 20) * 1e-6
    max_isotope = config.isotope_error_range[1] if config.isotope_error_range else 1

    all_batches = list(dl)  # Store batches; they are reused for the diffusion pass
    total_batches = len(all_batches)
    transformer_results_list: list[ScoredSequence | list] = []
    start_time_transformer = time.time()
    for batch_number, batch in enumerate(all_batches, start=1):
        spectra, precursors, _, _, _ = batch
        with torch.no_grad(), torch.amp.autocast(DEVICE, dtype=torch.float16, enabled=FP16):
            batch_predictions = transformer_decoder.decode(
                spectra=spectra.to(DEVICE),
                precursors=precursors.to(DEVICE),
                beam_size=num_beams,  # Use selected beam size
                max_length=config.max_length,
                mass_tolerance=mass_tolerance,
                max_isotope=max_isotope,
                return_beam=False,  # Only top result needed for refinement
            )
        transformer_results_list.extend(batch_predictions)
        if batch_number % 10 == 0 or batch_number == total_batches:
            logger.info(f"Refinement (Transformer): Processed batch {batch_number}/{total_batches}")
    logger.info(f"Transformer prediction for refinement finished in {time.time() - start_time_transformer:.2f} seconds.")

    # 2. Prepare Transformer Predictions as Initial Sequences for Diffusion
    logger.info("Encoding transformer predictions for diffusion input...")
    max_len_diffusion = INSTANOVOPLUS_CONFIG.get("max_length", 40)
    encoded_transformer_preds = _encode_initial_sequences(transformer_results_list, max_len_diffusion)
    if not encoded_transformer_preds:
        raise gr.Error("Transformer prediction yielded no results to refine.")

    # Verify alignment once, up front, instead of skipping batches later on:
    # a skipped batch would shift every following result onto the wrong spectrum.
    total_spectra = sum(batch[0].shape[0] for batch in all_batches)
    if len(encoded_transformer_preds) != total_spectra:
        logger.error(
            f"Transformer produced {len(encoded_transformer_preds)} predictions for {total_spectra} spectra."
        )
        raise gr.Error(
            f"Cannot run refinement: transformer returned {len(encoded_transformer_preds)} "
            f"predictions for {total_spectra} spectra."
        )
    encoded_transformer_preds_tensor = torch.stack(encoded_transformer_preds).to(DEVICE)
    logger.info(f"Encoded {encoded_transformer_preds_tensor.shape[0]} sequences for diffusion.")

    # 3. Run Diffusion Refinement
    logger.info("Running Diffusion refinement...")
    diffusion_decoder = DiffusionDecoder(model=INSTANOVOPLUS)
    refined_sequences = []
    refined_log_probs = []
    start_time_diffusion = time.time()
    current_idx = 0
    for batch_number, batch in enumerate(all_batches, start=1):
        spectra, precursors, spectra_mask, _, _ = batch
        batch_size = spectra.shape[0]
        initial_sequences_batch = encoded_transformer_preds_tensor[current_idx : current_idx + batch_size]
        current_idx += batch_size
        with torch.no_grad(), torch.amp.autocast(DEVICE, dtype=torch.float16, enabled=FP16):
            batch_refined_seqs, batch_refined_logp = diffusion_decoder.decode(
                spectra=spectra.to(DEVICE),
                spectra_padding_mask=spectra_mask.to(DEVICE),
                precursors=precursors.to(DEVICE),
                initial_sequence=initial_sequences_batch,
                start_step=DIFFUSION_START_STEP,
            )
        refined_sequences.extend(batch_refined_seqs)
        refined_log_probs.extend(batch_refined_logp)
        if batch_number % 10 == 0 or batch_number == total_batches:
            logger.info(f"Refinement (Diffusion): Processed batch {batch_number}/{total_batches}")
    logger.info(f"Diffusion refinement finished in {time.time() - start_time_diffusion:.2f} seconds.")

    if not (len(refined_sequences) == len(refined_log_probs) == total_spectra):
        logger.error(
            f"Diffusion produced {len(refined_sequences)} sequences and {len(refined_log_probs)} "
            f"log-probabilities for {total_spectra} spectra."
        )
        raise gr.Error(
            f"Refinement returned {len(refined_sequences)} sequences for {total_spectra} spectra."
        )

    # 4. Combine and Format Results
    all_precursors = torch.cat([batch[1] for batch in all_batches], dim=0)  # batch[1] is precursors
    metrics_calc = Metrics(RESIDUE_SET, config.isotope_error_range)
    combined_results = []
    for idx, (transformer_res, refined_seq, refined_logp) in enumerate(
        zip(transformer_results_list, refined_sequences, refined_log_probs)
    ):
        prec_mz = all_precursors[idx, 1].item()
        prec_ch = int(all_precursors[idx, 2].item())
        try:
            _, delta_mass_list = metrics_calc.matches_precursor(refined_seq, prec_mz, prec_ch)
            min_abs_ppm = min(abs(p) for p in delta_mass_list) if delta_mass_list else float("nan")
        except Exception as e:
            logger.warning(f"Could not calculate delta mass for refined prediction {idx}: {e}")
            min_abs_ppm = float("nan")
        has_transformer_result = isinstance(transformer_res, ScoredSequence)
        combined_results.append(
            {
                "transformer_prediction": "".join(transformer_res.sequence) if has_transformer_result else "",
                "transformer_log_probability": (
                    transformer_res.sequence_log_probability if has_transformer_result else float("-inf")
                ),
                "refined_prediction": "".join(refined_seq),
                "refined_log_probability": refined_logp,
                "refined_delta_mass_ppm": min_abs_ppm,
            }
        )
    return combined_results
def _remove_temp_file(path):
    """Best-effort removal of a temporary output file; failures are only logged."""
    if not os.path.exists(path):
        return
    try:
        os.remove(path)
        logger.info(f"Removed temporary file {path}")
    except OSError:
        logger.error(f"Failed to remove temporary file {path}")


def _load_spectra(config):
    """Load the spectra at ``config.data_path`` and drop rows with an invalid charge.

    Returns:
        ``(sdf, index_cols_present, base_df_pd)``: the filtered spectrum data
        frame, the configured index columns that exist in it, and those
        columns as a pandas DataFrame.

    Raises:
        gr.Error: if loading fails or no spectra remain after filtering.
    """
    logger.info("Loading spectrum data...")
    try:
        # Load data eagerly
        sdf = SpectrumDataFrame.load(
            config.data_path,
            lazy=False,
            is_annotated=False,
            column_mapping=config.get("column_map", None),
            shuffle=False,
            verbose=True,
        )
        original_size = len(sdf)
        max_charge = config.get("max_charge", 10)
        if "precursor_charge" in sdf.df.columns:
            sdf.filter_rows(
                lambda row: (
                    "precursor_charge" in row
                    and row["precursor_charge"] is not None
                    and 0 < row["precursor_charge"] <= max_charge
                )
            )
            if len(sdf) < original_size:
                logger.warning(
                    f"Filtered {original_size - len(sdf)} spectra with invalid or out-of-range charge (<=0 or >{max_charge})."
                )
        else:
            logger.warning("Column 'precursor_charge' not found. Cannot filter by charge.")
        if len(sdf) == 0:
            raise gr.Error("No valid spectra found in the uploaded file after filtering.")
        logger.info(f"Data loaded: {len(sdf)} spectra.")
        index_cols_present = [col for col in config.index_columns if col in sdf.df.columns]
        base_df_pd = sdf.df.select(index_cols_present).to_pandas()
    except gr.Error:
        # Already a user-facing message; do not wrap it a second time.
        raise
    except Exception as e:
        logger.error(f"Error loading data: {e}", exc_info=True)
        raise gr.Error(f"Failed to load or process the spectrum file. Error: {e}") from e
    return sdf, index_cols_present, base_df_pd


def _format_transformer_rows(transformer_results, base_df_pd, config):
    """Turn transformer decoder output into string-valued result rows."""
    metrics_calc = Metrics(RESIDUE_SET, config.isotope_error_range)
    rows = []
    for i, res in enumerate(transformer_results):
        if not (isinstance(res, ScoredSequence) and res.sequence):
            # Failed prediction: empty strings for text columns, N/A for numbers.
            rows.append(
                {
                    "prediction": "",
                    "log_probability": "N/A",
                    "delta_mass_ppm": "N/A",
                    "token_log_probabilities": "",
                }
            )
            continue
        row_data = {
            "prediction": "".join(res.sequence),
            "log_probability": f"{res.sequence_log_probability:.4f}",
            "token_log_probabilities": ", ".join(f"{p:.4f}" for p in res.token_log_probabilities),
        }
        try:
            prec_mz = base_df_pd.loc[i, "precursor_mz"]
            prec_ch = base_df_pd.loc[i, "precursor_charge"]
            _, delta_mass_list = metrics_calc.matches_precursor(res.sequence, prec_mz, prec_ch)
            min_abs_ppm = min(abs(p) for p in delta_mass_list) if delta_mass_list else float("nan")
            row_data["delta_mass_ppm"] = f"{min_abs_ppm:.2f}"
        except Exception as e:
            logger.warning(f"Could not calculate delta mass for Tx prediction {i}: {e}")
            row_data["delta_mass_ppm"] = "N/A"
        rows.append(row_data)
    return rows


def _format_diffusion_rows(diffusion_results):
    """Turn diffusion decoder output into string-valued result rows."""
    rows = []
    for res in diffusion_results:
        if isinstance(res, ScoredSequence) and res.sequence:
            rows.append(
                {
                    "prediction": "".join(res.sequence),
                    "log_probability": f"{res.sequence_log_probability:.4f}",
                    "delta_mass_ppm": f"{res.mass_error:.2f}" if not np.isnan(res.mass_error) else "N/A",
                }
            )
        else:
            rows.append({"prediction": "", "log_probability": "N/A", "delta_mass_ppm": "N/A"})
    return rows


def _format_refinement_rows(rows):
    """Format the numeric fields of refinement rows in place and return the rows."""
    for row in rows:
        t_logp = row["transformer_log_probability"]
        r_logp = row["refined_log_probability"]
        r_ppm = row["refined_delta_mass_ppm"]
        row["transformer_log_probability"] = f"{t_logp:.4f}" if isinstance(t_logp, (float, int)) else "N/A"
        row["refined_log_probability"] = f"{r_logp:.4f}" if isinstance(r_logp, (float, int)) else "N/A"
        row["refined_delta_mass_ppm"] = (
            f"{r_ppm:.2f}" if isinstance(r_ppm, (float, int)) and not np.isnan(r_ppm) else "N/A"
        )
    return rows


def predict_peptides(input_file, mode_selection, transformer_decoder_selection):
    """
    Main function to load data, select mode, run prediction, and return results.

    Args:
        input_file: Uploaded file; either an object exposing the path as
            ``.name`` or the path itself.
        mode_selection: Label of the prediction mode chosen in the UI.
        transformer_decoder_selection: Label of the transformer decoding
            method; ignored for the diffusion-only mode.

    Returns:
        ``(preview_dataframe, csv_path)``: a pandas DataFrame with the columns
        shown in the UI and the path of the CSV holding the full results.

    Raises:
        gr.Error: for every failure. Errors that already are ``gr.Error`` are
            passed through unchanged; anything else is wrapped as
            "Prediction failed: ...". The temporary CSV is removed on failure.
    """
    transformer_decoder_selection = transformer_decoder_selection or ""
    diffusion_only = "InstaNovo+ Only" in mode_selection

    # Ensure models are loaded
    if INSTANOVO is None or RESIDUE_SET is None:
        load_models_and_knapsack()  # Try reload
        if INSTANOVO is None:
            raise gr.Error("InstaNovo Transformer model failed to load. Cannot perform prediction.")
    if ("refinement" in mode_selection or "InstaNovo+" in mode_selection) and INSTANOVOPLUS is None:
        load_models_and_knapsack()  # Try reload diffusion
        if INSTANOVOPLUS is None:
            raise gr.Error("InstaNovo+ Diffusion model failed to load. Cannot perform Refinement or InstaNovo+ Only prediction.")
    # The knapsack is only needed when the transformer decoder is actually used.
    if not diffusion_only and "Knapsack" in transformer_decoder_selection and KNAPSACK is None:
        load_models_and_knapsack()  # Try reload knapsack
        if KNAPSACK is None:
            raise gr.Error("Knapsack failed to load. Cannot use Knapsack Beam Search.")
    if input_file is None:
        raise gr.Error("Please upload a mass spectrometry file.")

    # Accept both file-like objects (path in `.name`) and plain path strings.
    input_path = str(getattr(input_file, "name", input_file))
    logger.info("--- New Prediction Request ---")
    logger.info(f"Input File: {input_path}")
    logger.info(f"Selected Mode: {mode_selection}")
    if "refinement" in mode_selection or "InstaNovo Only" in mode_selection:
        logger.info(f"Selected Transformer Decoder: {transformer_decoder_selection}")

    # Create temp output file
    gradio_tmp_dir = os.environ.get("GRADIO_TEMP_DIR", "/tmp")
    try:
        with tempfile.NamedTemporaryFile(dir=gradio_tmp_dir, delete=False, suffix=".csv") as temp_out:
            output_csv_path = temp_out.name
        logger.info(f"Temporary output path: {output_csv_path}")
    except Exception as e:
        logger.error(f"Failed to create temporary file in {gradio_tmp_dir}: {e}")
        raise gr.Error(f"Failed to create temporary output file: {e}") from e

    try:
        config = create_inference_config(input_path, output_csv_path)
        sdf, index_cols_present, base_df_pd = _load_spectra(config)
        if RESIDUE_SET is None:
            raise gr.Error("Residue set not loaded.")  # Should not happen if model loaded

        # --- Prepare DataLoader ---
        # Use reverse_peptide=True for Transformer steps, False for Diffusion-only
        ds = SpectrumDataset(
            sdf,
            RESIDUE_SET,
            INSTANOVO_CONFIG.get("n_peaks", 200) if INSTANOVO_CONFIG else 200,
            return_str=True,
            annotated=False,
            pad_spectrum_max_length=config.get("compile_model", False) or config.get("use_flash_attention", False),
            bin_spectra=config.get("conv_peak_encoder", False),
            peptide_pad_length=config.get("max_length", 40) if config.get("compile_model", False) else 0,
            reverse_peptide=not diffusion_only,  # Key change based on mode
            diffusion=diffusion_only,  # Signal if input is for diffusion
        )
        dl = DataLoader(ds, batch_size=config.batch_size, num_workers=0, shuffle=False, collate_fn=collate_batch)

        # --- Run Prediction ---
        output_headers = index_cols_present[:]
        if "InstaNovo Only" in mode_selection:
            output_headers.extend(["prediction", "log_probability", "delta_mass_ppm", "token_log_probabilities"])
            transformer_results = run_transformer_prediction(dl, config, transformer_decoder_selection)
            results_data = _format_transformer_rows(transformer_results, base_df_pd, config)
        elif diffusion_only:
            output_headers.extend(["prediction", "log_probability", "delta_mass_ppm"])
            results_data = _format_diffusion_rows(run_diffusion_prediction(dl, config))
        elif "refinement" in mode_selection:
            output_headers.extend([
                "transformer_prediction", "transformer_log_probability",
                "refined_prediction", "refined_log_probability", "refined_delta_mass_ppm",
            ])
            # Pass the selected transformer decoder to the refinement function
            results_data = _format_refinement_rows(
                run_refinement_prediction(dl, config, transformer_decoder_selection)
            )
        else:
            raise ValueError(f"Unknown mode selection: {mode_selection}")

        # --- Combine, Save, Return ---
        logger.info("Combining results...")
        if results_data is None:
            raise gr.Error("Prediction did not produce results.")
        results_df = pl.DataFrame(results_data)
        base_df_pl = pl.from_pandas(base_df_pd.reset_index(drop=True))
        # Simple horizontal concat assuming order is preserved by dataloader (shuffle=False)
        if len(base_df_pl) == len(results_df):
            final_df = pl.concat([base_df_pl, results_df], how="horizontal")
        else:
            logger.error(f"Length mismatch between base data ({len(base_df_pl)}) and results ({len(results_df)}). Cannot reliably combine.")
            final_df = results_df  # Display only results in case of mismatch

        logger.info(f"Saving full results to {output_csv_path}...")
        final_df.write_csv(output_csv_path)
        logger.info("Save complete.")

        # Select display columns - make sure they exist in final_df
        display_cols_final = [col for col in output_headers if col in final_df.columns]
        display_df = final_df.select(display_cols_final)
        logger.info("--- Prediction Request Complete ---")
        return display_df.to_pandas(), output_csv_path
    except Exception as e:
        logger.error(f"An error occurred during prediction: {e}", exc_info=True)
        _remove_temp_file(output_csv_path)
        if isinstance(e, gr.Error):
            # Keep the original user-facing message instead of nesting it.
            raise
        raise gr.Error(f"Prediction failed: {e}") from e
# --- Gradio Interface ---
# Custom stylesheet passed to `gr.Blocks(css=...)`: sets the font, restyles
# buttons, hides the footer and styles the logo container and the
# `.feedback` text.
css = """
.gradio-container { font-family: sans-serif; }
.gr-button { color: white; border-color: black; background: black; }
footer { display: none !important; }
.logo-container img { margin-bottom: 1rem; }
.feedback { font-size: 0.9rem; color: gray; }
"""
with gr.Blocks( | |
css=css, theme=gr.themes.Default(primary_hue="blue", secondary_hue="blue") | |
) as demo: | |
gr.Markdown( | |
""" | |
<div style="text-align: center;" class="logo-container"> | |
<img src='/gradio_api/file=assets/instanovo.svg' alt="InstaNovo Logo" width="300" style="display: block; margin: 0 auto;"> | |
</div> | |
""", | |
elem_classes="logo-container", | |
) | |
gr.Markdown( | |
f""" | |
# 🚀 _De Novo_ Peptide Sequencing with InstaNovo and InstaNovo+
Upload your mass spectrometry data file (.mgf, .mzml, or .mzxml) and get peptide sequence predictions. | |
Choose your prediction method and decoding options. | |
**Notes:** | |
* Predictions use version `{TRANSFORMER_MODEL_ID}` for the transformer-based InstaNovo model and version `{DIFFUSION_MODEL_ID}` for the diffusion-based InstaNovo+ model. | |
* The InstaNovo+ model `{DIFFUSION_MODEL_ID}` is an alpha release. | |
* **Prediction Modes:**
* **InstaNovo with InstaNovo+ refinement:** Runs initial prediction with the selected Transformer method (Greedy/Knapsack), then refines using InstaNovo+.
* **InstaNovo Only:** Uses only the Transformer with the selected decoding method. | |
* **InstaNovo+ Only:** Predicts directly using the Diffusion model (alpha release). | |
* **Transformer Decoding Methods:** | |
* **Greedy Search:** use this for optimal performance, has similar performance as Knapsack Beam Search at 5% FDR. | |
* **Knapsack Beam Search:** use this for the best results and highest peptide recall, but is about 10x slower than Greedy Search. | |
* Check logs for progress, especially for large files or slower methods. | |
This Hugging Face Space is powered by a [ZeroGPU](https://huggingface.co/docs/hub/en/spaces-zerogpu), which is free but **limited to 5 minutes per day per user**—so if you test with your own files, please use only small files.
""", | |
elem_classes="feedback" | |
) | |
with gr.Row(): | |
with gr.Column(scale=1): | |
input_file = gr.File( | |
label="Upload Mass Spectrometry File (.mgf, .mzml, .mzxml)", | |
file_types=[".mgf", ".mzml", ".mzxml"], | |
scale=1 | |
) | |
mode_selection = gr.Radio( | |
[ | |
"InstaNovo with InstaNovo+ refinement (Default, Recommended)", | |
"InstaNovo Only (Transformer)", | |
"InstaNovo+ Only (Diffusion, Alpha release)", | |
], | |
label="Prediction Mode", | |
value="InstaNovo with InstaNovo+ refinement (Default, Recommended)", | |
scale=1 | |
) | |
# Transformer decoder selection - visible for relevant modes | |
transformer_decoder_selection = gr.Radio( | |
[ | |
"Greedy Search (Fast)", | |
"Knapsack Beam Search (Accurate, Slower)" | |
], | |
label="Transformer Decoding Method", | |
value="Greedy Search (Fast)", | |
visible=True, # Start visible as default mode uses it | |
interactive=True, | |
scale=1 | |
) | |
submit_btn = gr.Button("Predict Sequences", variant="primary") | |
# --- Control Visibility & Choices --- | |
def update_transformer_options(mode):
    """Build the Gradio update for the transformer decoding radio.

    Args:
        mode: Label of the currently selected prediction mode.

    Returns:
        A ``gr.update`` payload that hides the radio when ``mode`` contains
        "InstaNovo+ Only" (shows it otherwise), re-sends both decoding
        choices, and resets the selection to the greedy option.
    """
    decoder_choices = [
        "Greedy Search (Fast)",
        "Knapsack Beam Search (Accurate, Slower)",
    ]
    # Only the diffusion-only mode bypasses the transformer decoder.
    uses_transformer = not ("InstaNovo+ Only" in mode)
    return gr.update(
        visible=uses_transformer,
        choices=decoder_choices,
        # Every mode change resets the selection to the greedy default.
        value=decoder_choices[0],
    )
mode_selection.change( | |
fn=update_transformer_options, | |
inputs=mode_selection, | |
outputs=transformer_decoder_selection, | |
) | |
with gr.Column(scale=2): | |
output_df = gr.DataFrame( | |
label="Prediction Results Preview", | |
headers=["scan_number", "prediction", "log_probability", "delta_mass_ppm"] | |
) | |
output_file = gr.File(label="Download Full Results (CSV)") | |
submit_btn.click( | |
predict_peptides, | |
inputs=[input_file, mode_selection, transformer_decoder_selection], | |
outputs=[output_df, output_file], | |
) | |
gr.Examples( | |
[ | |
["assets/sample_spectra.mgf", "InstaNovo with InstaNovo+ refinement (Default, Recommended)", "Greedy Search (Fast)"], | |
["assets/sample_spectra.mgf", "InstaNovo with InstaNovo+ refinement (Default, Recommended)", "Knapsack Beam Search (Accurate, Slower)"], | |
["assets/sample_spectra.mgf", "InstaNovo Only (Transformer)", "Greedy Search (Fast)"], | |
["assets/sample_spectra.mgf", "InstaNovo Only (Transformer)", "Knapsack Beam Search (Accurate, Slower)"], | |
["assets/sample_spectra.mgf", "InstaNovo+ Only (Diffusion, Alpha release)", ""], | |
], | |
inputs=[input_file, mode_selection, transformer_decoder_selection], | |
# outputs=[output_df, output_file], | |
cache_examples=False, | |
label="Example Usage:", | |
) | |
with gr.Accordion("Application Logs", open=True): | |
log_display = Log(log_file, dark=True, height=300) | |
gr.Markdown(""" **Links:** | |
* [GitHub Repository for InstaNovo](https://github.com/instadeepai/instanovo) | |
* [InstaNovo enables diffusion-powered de novo peptide sequencing in large-scale proteomics experiments](https://www.nature.com/articles/s42256-025-01019-5), Eloff, Kalogeropoulos et al. 2025, Nature Machine Intelligence. | |
If you use InstaNovo in your research, please cite:""") | |
gr.Markdown( | |
value=""" | |
``` | |
@article{eloff_kalogeropoulos_2025_instanovo, | |
title = {InstaNovo enables diffusion-powered de novo peptide sequencing in large-scale proteomics experiments}, | |
author = {Kevin Eloff and Konstantinos Kalogeropoulos and Amandla Mabona and Oliver Morell and Rachel Catzel and | |
Esperanza Rivera-de-Torre and Jakob Berg Jespersen and Wesley Williams and Sam P. B. van Beljouw and | |
Marcin J. Skwark and Andreas Hougaard Laustsen and Stan J. J. Brouns and Anne Ljungars and Erwin M. | |
Schoof and Jeroen Van Goey and Ulrich auf dem Keller and Karim Beguir and Nicolas Lopez Carranza and | |
Timothy P. Jenkins}, | |
year = 2025, | |
month = {Mar}, | |
day = 31, | |
journal = {Nature Machine Intelligence}, | |
doi = {10.1038/s42256-025-01019-5}, | |
url = {https://www.nature.com/articles/s42256-025-01019-5} | |
} | |
``` | |
""", | |
show_copy_button=True | |
) | |
# --- Launch the App --- | |
if __name__ == "__main__":
    # Queue incoming requests with a default concurrency limit of 5; see
    # https://www.gradio.app/guides/setting-up-a-demo-for-maximum-performance
    demo.queue(default_concurrency_limit=5)

    # On Hugging Face Spaces a plain launch is usually sufficient.
    # Local-run alternatives:
    #   * share=True gives a temporary public URL.
    #   * server_name="0.0.0.0" (e.g. with server_port=7860) allows access
    #     from the network.
    launch_kwargs = {"debug": True, "show_error": True}
    demo.launch(**launch_kwargs)