# InstaNovo / app.py: Gradio demo for de novo peptide sequencing with InstaNovo
import gradio as gr
import torch
import os
import tempfile
import time
import polars as pl
import logging
from pathlib import Path
from omegaconf import OmegaConf, DictConfig
from gradio_log import Log
# --- InstaNovo Imports ---
try:
from instanovo.transformer.model import InstaNovo
from instanovo.utils import SpectrumDataFrame, ResidueSet, Metrics
from instanovo.transformer.dataset import SpectrumDataset, collate_batch
from instanovo.inference import (
GreedyDecoder,
KnapsackBeamSearchDecoder,
Knapsack,
ScoredSequence,
Decoder,
)
from instanovo.constants import MASS_SCALE, MAX_MASS
from torch.utils.data import DataLoader
except ImportError as e:
    raise ImportError(f"Failed to import InstaNovo components: {e}") from e
# --- Configuration ---
MODEL_ID = "instanovo-v1.1.0" # Use the desired pretrained model ID
KNAPSACK_DIR = Path("./knapsack_cache")
DEFAULT_CONFIG_PATH = Path(
    "./configs/inference/default.yaml"
)  # Assumes InstaNovo's inference configs are available relative to the working directory
# Determine device
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
FP16 = DEVICE == "cuda" # Enable FP16 only on CUDA
# --- Global Variables (Load Model and Knapsack Once) ---
MODEL: InstaNovo | None = None
KNAPSACK: Knapsack | None = None
MODEL_CONFIG: DictConfig | None = None
RESIDUE_SET: ResidueSet | None = None
# --- Assets ---
gr.set_static_paths(paths=[Path.cwd().absolute() / "assets"])
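# Gradio can then serve files from ./assets directly; the logo below is
# referenced as /gradio_api/file=assets/instanovo.svg.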
# Logging configuration
log_file = "/tmp/instanovo_gradio_log.txt"
Path(log_file).touch()
logger = logging.getLogger("instanovo")
logger.setLevel(logging.INFO)
file_handler = logging.FileHandler(log_file)
file_handler.setLevel(logging.INFO)
formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)
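# The gradio_log.Log component at the bottom of the UI tails this file, so
# everything logged here streams live into the "Application Logs" accordion.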
def load_model_and_knapsack():
"""Loads the InstaNovo model and generates/loads the knapsack."""
global MODEL, KNAPSACK, MODEL_CONFIG, RESIDUE_SET
if MODEL is not None:
logger.info("Model already loaded.")
return
logger.info(f"Loading InstaNovo model: {MODEL_ID} to {DEVICE}...")
try:
MODEL, MODEL_CONFIG = InstaNovo.from_pretrained(MODEL_ID)
MODEL.to(DEVICE)
MODEL.eval()
RESIDUE_SET = MODEL.residue_set
logger.info("Model loaded successfully.")
except Exception as e:
logger.error(f"Error loading model: {e}")
raise gr.Error(f"Failed to load InstaNovo model: {MODEL_ID}. Error: {e}")
# --- Knapsack Handling ---
knapsack_exists = (
(KNAPSACK_DIR / "parameters.pkl").exists()
and (KNAPSACK_DIR / "masses.npy").exists()
and (KNAPSACK_DIR / "chart.npy").exists()
)
if knapsack_exists:
logger.info(f"Loading pre-generated knapsack from {KNAPSACK_DIR}...")
try:
KNAPSACK = Knapsack.from_file(str(KNAPSACK_DIR))
logger.info("Knapsack loaded successfully.")
except Exception as e:
logger.info(f"Error loading knapsack: {e}. Will attempt to regenerate.")
KNAPSACK = None # Force regeneration
knapsack_exists = False # Ensure generation happens
if not knapsack_exists:
logger.info("Knapsack not found or failed to load. Generating knapsack...")
if RESIDUE_SET is None:
raise gr.Error(
"Cannot generate knapsack because ResidueSet failed to load."
)
try:
# Prepare residue masses for knapsack generation (handle negative/zero masses)
residue_masses_knapsack = dict(RESIDUE_SET.residue_masses.copy())
negative_residues = [
k for k, v in residue_masses_knapsack.items() if v <= 0
]
if negative_residues:
logger.info(f"Warning: Non-positive masses found in residues: {negative_residues}. "
"Excluding from knapsack generation.")
for res in negative_residues:
del residue_masses_knapsack[res]
# Remove special tokens explicitly if they somehow got mass
for special_token in RESIDUE_SET.special_tokens:
if special_token in residue_masses_knapsack:
del residue_masses_knapsack[special_token]
# Ensure residue indices used match those without special/negative masses
valid_residue_indices = {
res: idx
for res, idx in RESIDUE_SET.residue_to_index.items()
if res in residue_masses_knapsack
}
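            # The knapsack is a precomputed dynamic-programming chart over
            # integer-scaled residue masses (hence MASS_SCALE); during decoding it
            # lets beam search prune partial sequences that can no longer reach the
            # precursor mass. Construction is slow, so the result is cached to
            # KNAPSACK_DIR below.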
KNAPSACK = Knapsack.construct_knapsack(
residue_masses=residue_masses_knapsack,
residue_indices=valid_residue_indices, # Use only valid indices
max_mass=MAX_MASS,
mass_scale=MASS_SCALE,
)
logger.info(f"Knapsack generated. Saving to {KNAPSACK_DIR}...")
KNAPSACK.save(str(KNAPSACK_DIR)) # Save for future runs
logger.info("Knapsack saved.")
except Exception as e:
logger.info(f"Error generating or saving knapsack: {e}")
gr.Warning("Failed to generate Knapsack. Knapsack Beam Search will not be available. {e}")
KNAPSACK = None # Ensure it's None if generation failed
# Load the model and knapsack when the script starts
load_model_and_knapsack()
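# Loading at import time means the slow steps (model download, knapsack
# generation) run once per process start rather than on every request.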
def create_inference_config(
input_path: str,
output_path: str,
decoding_method: str,
) -> DictConfig:
"""Creates the OmegaConf DictConfig needed for prediction."""
# Load default config if available, otherwise create from scratch
if DEFAULT_CONFIG_PATH.exists():
base_cfg = OmegaConf.load(DEFAULT_CONFIG_PATH)
else:
logger.info(f"Warning: Default config not found at {DEFAULT_CONFIG_PATH}. Using minimal config.")
# Create a minimal config if default is missing
base_cfg = OmegaConf.create({
"data_path": None,
"instanovo_model": MODEL_ID,
"output_path": None,
"knapsack_path": str(KNAPSACK_DIR),
"denovo": True,
"refine": False, # Not doing refinement here
"num_beams": 1,
"max_length": 40,
"max_charge": 10,
"isotope_error_range": [0, 1],
"subset": 1.0,
"use_knapsack": False,
"save_beams": False,
"batch_size": 64, # Adjust as needed
"device": DEVICE,
"fp16": FP16,
"log_interval": 500, # Less relevant for Gradio app
"use_basic_logging": True,
"filter_precursor_ppm": 20,
"filter_confidence": 1e-4,
"filter_fdr_threshold": 0.05,
"residue_remapping": { # Add default mappings
"M(ox)": "M[UNIMOD:35]", "M(+15.99)": "M[UNIMOD:35]",
"S(p)": "S[UNIMOD:21]", "T(p)": "T[UNIMOD:21]", "Y(p)": "Y[UNIMOD:21]",
"S(+79.97)": "S[UNIMOD:21]", "T(+79.97)": "T[UNIMOD:21]", "Y(+79.97)": "Y[UNIMOD:21]",
"Q(+0.98)": "Q[UNIMOD:7]", "N(+0.98)": "N[UNIMOD:7]",
"Q(+.98)": "Q[UNIMOD:7]", "N(+.98)": "N[UNIMOD:7]",
"C(+57.02)": "C[UNIMOD:4]",
"(+42.01)": "[UNIMOD:1]", "(+43.01)": "[UNIMOD:5]", "(-17.03)": "[UNIMOD:385]",
},
"column_map": { # Add default mappings
"Modified sequence": "modified_sequence", "MS/MS m/z": "precursor_mz",
"Mass": "precursor_mass", "Charge": "precursor_charge",
"Mass values": "mz_array", "Mass spectrum": "mz_array",
"Intensity": "intensity_array", "Raw intensity spectrum": "intensity_array",
"Scan number": "scan_number"
},
"index_columns": [
"scan_number", "precursor_mz", "precursor_charge",
],
# Add other defaults if needed based on errors
})
# Override specific parameters
cfg_overrides = {
"data_path": input_path,
"output_path": output_path,
"device": DEVICE,
"fp16": FP16,
"denovo": True,
"refine": False,
}
if "Greedy" in decoding_method:
cfg_overrides["num_beams"] = 1
cfg_overrides["use_knapsack"] = False
elif "Knapsack" in decoding_method:
if KNAPSACK is None:
raise gr.Error(
"Knapsack is not available. Cannot use Knapsack Beam Search."
)
cfg_overrides["num_beams"] = 5
cfg_overrides["use_knapsack"] = True
cfg_overrides["knapsack_path"] = str(KNAPSACK_DIR)
else:
raise ValueError(f"Unknown decoding method: {decoding_method}")
# Merge base config with overrides
final_cfg = OmegaConf.merge(base_cfg, cfg_overrides)
return final_cfg
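# Illustrative example (not executed): a greedy run over an uploaded MGF file,
#   cfg = create_inference_config("/tmp/upload.mgf", "/tmp/out.csv",
#                                 "Greedy Search (Fast, reasonably accurate)")
# yields cfg.num_beams == 1 and cfg.use_knapsack == False.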
def predict_peptides(input_file, decoding_method):
"""
Main function to load data, run prediction, and return results.
"""
if MODEL is None or RESIDUE_SET is None or MODEL_CONFIG is None:
load_model_and_knapsack() # Attempt to reload if None (e.g., after space restart)
if MODEL is None:
raise gr.Error("InstaNovo model is not loaded. Cannot perform prediction.")
if input_file is None:
raise gr.Error("Please upload a mass spectrometry file.")
input_path = input_file.name # Gradio provides the path in .name
logger.info(f"Processing file: {input_path}")
logger.info(f"Using decoding method: {decoding_method}")
# Create a temporary file for the output CSV
with tempfile.NamedTemporaryFile(delete=False, suffix=".csv") as temp_out:
output_csv_path = temp_out.name
try:
# 1. Create Config
config = create_inference_config(input_path, output_csv_path, decoding_method)
logger.info(f"Inference Config:\n{OmegaConf.to_yaml(config)}")
# 2. Load Data using SpectrumDataFrame
logger.info("Loading spectrum data...")
try:
sdf = SpectrumDataFrame.load(
config.data_path,
lazy=False, # Load eagerly for Gradio simplicity
is_annotated=False, # De novo mode
column_mapping=config.get("column_map", None),
shuffle=False,
verbose=True, # Print loading logs
)
# Apply charge filter like in CLI
original_size = len(sdf)
max_charge = config.get("max_charge", 10)
sdf.filter_rows(
lambda row: (row["precursor_charge"] <= max_charge)
and (row["precursor_charge"] > 0)
)
if len(sdf) < original_size:
logger.info(f"Warning: Filtered {original_size - len(sdf)} spectra with charge > {max_charge} or <= 0.")
if len(sdf) == 0:
raise gr.Error("No valid spectra found in the uploaded file after filtering.")
logger.info(f"Data loaded: {len(sdf)} spectra.")
except Exception as e:
logger.info(f"Error loading data: {e}")
raise gr.Error(f"Failed to load or process the spectrum file. Error: {e}")
# 3. Prepare Dataset and DataLoader
ds = SpectrumDataset(
sdf,
RESIDUE_SET,
MODEL_CONFIG.get("n_peaks", 200),
return_str=True, # Needed for greedy/beam search targets later (though not used here)
annotated=False,
pad_spectrum_max_length=config.get("compile_model", False)
or config.get("use_flash_attention", False),
bin_spectra=config.get("conv_peak_encoder", False),
)
dl = DataLoader(
ds,
batch_size=config.batch_size,
num_workers=0, # Required by SpectrumDataFrame
shuffle=False, # Required by SpectrumDataFrame
collate_fn=collate_batch,
)
# 4. Select Decoder
logger.info("Initializing decoder...")
decoder: Decoder
if config.use_knapsack:
if KNAPSACK is None:
# This check should ideally be earlier, but double-check
raise gr.Error(
"Knapsack is required for Knapsack Beam Search but is not available."
)
            # The knapsack was already loaded (or generated) globally, so pass it in
            # directly rather than re-reading it from disk via
            # KnapsackBeamSearchDecoder.from_file(model=MODEL, path=config.knapsack_path).
decoder = KnapsackBeamSearchDecoder(model=MODEL, knapsack=KNAPSACK)
elif config.num_beams > 1:
            # Plain beam search is not wired into this app; with num_beams > 1 and
            # no knapsack, fall back to greedy decoding.
            logger.warning(
                f"num_beams={config.num_beams} > 1, but only Greedy and Knapsack Beam "
                "Search are implemented in this app. Defaulting to Greedy."
            )
decoder = GreedyDecoder(model=MODEL, mass_scale=MASS_SCALE)
else:
decoder = GreedyDecoder(
model=MODEL,
mass_scale=MASS_SCALE,
# Add suppression options if needed from config
suppressed_residues=config.get("suppressed_residues", None),
disable_terminal_residues_anywhere=config.get("disable_terminal_residues_anywhere", True),
)
logger.info(f"Using decoder: {type(decoder).__name__}")
# 5. Run Prediction Loop (Adapted from instanovo/transformer/predict.py)
logger.info("Starting prediction...")
start_time = time.time()
results_list: list[
ScoredSequence | list
] = [] # Store ScoredSequence or empty list
for i, batch in enumerate(dl):
spectra, precursors, spectra_mask, _, _ = (
batch # Ignore peptides/masks for de novo
)
spectra = spectra.to(DEVICE)
precursors = precursors.to(DEVICE)
spectra_mask = spectra_mask.to(DEVICE)
with (
torch.no_grad(),
torch.amp.autocast(DEVICE, dtype=torch.float16, enabled=FP16),
):
# Beam search decoder might return list[list[ScoredSequence]] if return_beam=True
# Greedy decoder returns list[ScoredSequence]
# KnapsackBeamSearchDecoder returns list[ScoredSequence] or list[list[ScoredSequence]]
batch_predictions = decoder.decode(
spectra=spectra,
precursors=precursors,
beam_size=config.num_beams,
max_length=config.max_length,
# Knapsack/Beam Search specific params if needed
mass_tolerance=config.get("filter_precursor_ppm", 20)
* 1e-6, # Convert ppm to relative
max_isotope=config.isotope_error_range[1]
if config.isotope_error_range
else 1,
return_beam=False, # Only get the top prediction for simplicity
)
results_list.extend(batch_predictions) # Should be list[ScoredSequence] or list[list]
logger.info(f"Processed batch {i+1}/{len(dl)}")
end_time = time.time()
logger.info(f"Prediction finished in {end_time - start_time:.2f} seconds.")
# 6. Format Results
logger.info("Formatting results...")
output_data = []
# Use sdf index columns + prediction results
index_cols = [col for col in config.index_columns if col in sdf.df.columns]
base_df_pd = sdf.df.select(index_cols).to_pandas() # Get base info
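        # results_list is aligned row-for-row with sdf: both the dataset and the
        # DataLoader use shuffle=False, so positional indexing below is safe.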
metrics_calc = Metrics(RESIDUE_SET, config.isotope_error_range)
for i, res in enumerate(results_list):
row_data = base_df_pd.iloc[i].to_dict() # Get corresponding input data
if isinstance(res, ScoredSequence) and res.sequence:
sequence_str = "".join(res.sequence)
row_data["prediction"] = sequence_str
row_data["log_probability"] = f"{res.sequence_log_probability:.4f}"
# Use metrics to calculate delta mass ppm for the top prediction
try:
_, delta_mass_list = metrics_calc.matches_precursor(
res.sequence,
row_data["precursor_mz"],
row_data["precursor_charge"],
)
# Find the smallest absolute ppm error across isotopes
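                    # ppm is relative to the precursor mass: e.g. a 10 ppm error on
                    # a 1000 Da precursor corresponds to 0.01 Da.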
min_abs_ppm = (
min(abs(p) for p in delta_mass_list)
if delta_mass_list
else float("nan")
)
row_data["delta_mass_ppm"] = f"{min_abs_ppm:.2f}"
except Exception as e:
logger.info(f"Warning: Could not calculate delta mass for prediction {i}: {e}")
row_data["delta_mass_ppm"] = "N/A"
else:
row_data["prediction"] = ""
row_data["log_probability"] = "N/A"
row_data["delta_mass_ppm"] = "N/A"
output_data.append(row_data)
output_df = pl.DataFrame(output_data)
# Ensure specific columns are present and ordered
display_cols = [
"scan_number",
"precursor_mz",
"precursor_charge",
"prediction",
"log_probability",
"delta_mass_ppm",
]
final_display_cols = []
for col in display_cols:
if col in output_df.columns:
final_display_cols.append(col)
else:
logger.info(f"Warning: Expected display column '{col}' not found in results.")
# Add any remaining index columns that weren't in display_cols
for col in index_cols:
if col not in final_display_cols and col in output_df.columns:
final_display_cols.append(col)
output_df_display = output_df.select(final_display_cols)
# 7. Save full results to CSV
logger.info(f"Saving results to {output_csv_path}...")
output_df.write_csv(output_csv_path)
# Return DataFrame for display and path for download
return output_df_display.to_pandas(), output_csv_path
except Exception as e:
logger.info(f"An error occurred during prediction: {e}")
# Clean up the temporary output file if it exists
if os.path.exists(output_csv_path):
os.remove(output_csv_path)
# Re-raise as Gradio error
raise gr.Error(f"Prediction failed: {e}")
# --- Gradio Interface ---
css = """
.gradio-container { font-family: sans-serif; }
.gr-button { color: white; border-color: black; background: black; }
footer { display: none !important; }
/* Optional: Add some margin below the logo */
.logo-container img { margin-bottom: 1rem; }
"""
with gr.Blocks(
css=css, theme=gr.themes.Default(primary_hue="blue", secondary_hue="blue")
) as demo:
# --- Logo Display ---
gr.Markdown(
"""
<div style="text-align: center;" class="logo-container">
<img src='/gradio_api/file=assets/instanovo.svg' alt="InstaNovo Logo" width="300" style="display: block; margin: 0 auto;">
</div>
""",
elem_classes="logo-container", # Optional class for CSS targeting
)
# --- App Content ---
gr.Markdown(
"""
# πŸš€ _De Novo_ Peptide Sequencing with InstaNovo
Upload your mass spectrometry data file (.mgf, .mzml, or .mzxml) and get peptide sequence predictions using InstaNovo.
        Choose between the fast Greedy Search and the more accurate, but slower, Knapsack Beam Search.
"""
)
with gr.Row():
with gr.Column(scale=1):
input_file = gr.File(
label="Upload Mass Spectrometry File (.mgf, .mzml, .mzxml)",
file_types=[".mgf", ".mzml", ".mzxml"],
)
decoding_method = gr.Radio(
[
"Greedy Search (Fast, resonably accurate)",
"Knapsack Beam Search (More accurate, but slower)",
],
label="Decoding Method",
value="Greedy Search (Fast, resonably accurate)", # Default to fast method
)
submit_btn = gr.Button("Predict Sequences", variant="primary")
with gr.Column(scale=2):
output_df = gr.DataFrame(
label="Prediction Results",
headers=[
"scan_number",
"precursor_mz",
"precursor_charge",
"prediction",
"log_probability",
"delta_mass_ppm",
],
wrap=True,
)
output_file = gr.File(label="Download Full Results (CSV)")
submit_btn.click(
predict_peptides,
inputs=[input_file, decoding_method],
outputs=[output_df, output_file],
)
gr.Examples(
[
["assets/sample_spectra.mgf", "Greedy Search (Fast, resonably accurate)"],
[
"assets/sample_spectra.mgf",
"Knapsack Beam Search (More accurate, but slower)",
],
],
inputs=[input_file, decoding_method],
outputs=[output_df, output_file],
fn=predict_peptides,
cache_examples=False, # Re-run examples if needed
label="Example Usage",
)
gr.Markdown(
"""
**Notes:**
* Predictions are based on the [InstaNovo](https://github.com/instadeepai/InstaNovo) model `{MODEL_ID}`.
* Knapsack Beam Search uses pre-calculated mass constraints and yields better results but takes longer.
* `delta_mass_ppm` shows the lowest absolute precursor mass error (in ppm) across potential isotopes (0-1 neutron).
* Ensure your input file format is correctly specified. Large files may take time to process.
""".format(MODEL_ID=MODEL_ID)
)
# Add logging component
with gr.Accordion("Application Logs", open=True):
log_display = Log(log_file, dark=True, height=300)
gr.Textbox(
value="""
@article{eloff_kalogeropoulos_2025_instanovo,
title = {InstaNovo enables diffusion-powered de novo peptide sequencing in large-scale proteomics experiments},
author = {Kevin Eloff and Konstantinos Kalogeropoulos and Amandla Mabona and Oliver Morell and Rachel Catzel and
Esperanza Rivera-de-Torre and Jakob Berg Jespersen and Wesley Williams and Sam P. B. van Beljouw and
Marcin J. Skwark and Andreas Hougaard Laustsen and Stan J. J. Brouns and Anne Ljungars and Erwin M.
Schoof and Jeroen Van Goey and Ulrich auf dem Keller and Karim Beguir and Nicolas Lopez Carranza and
Timothy P. Jenkins},
year = 2025,
month = {Mar},
day = 31,
journal = {Nature Machine Intelligence},
doi = {10.1038/s42256-025-01019-5},
url = {https://www.nature.com/articles/s42256-025-01019-5}
}
""",
show_copy_button=True,
label="If you use InstaNovo in your research, please cite:",
interactive=False,
)
# --- Launch the App ---
if __name__ == "__main__":
# Set share=True for temporary public link if running locally
# Set server_name="0.0.0.0" to allow access from network if needed
# demo.launch(server_name="0.0.0.0", server_port=7860)
# For Hugging Face Spaces, just demo.launch() is usually sufficient
demo.launch(share=True) # For local testing with public URL