import gradio as gr
import torch
import os
import tempfile
import time
import polars as pl
import numpy as np
from pathlib import Path
from omegaconf import OmegaConf, DictConfig
# --- InstaNovo Imports ---
# It's good practice to handle potential import issues
try:
    from instanovo.transformer.model import InstaNovo
    from instanovo.utils import SpectrumDataFrame, ResidueSet, Metrics
    from instanovo.transformer.dataset import SpectrumDataset, collate_batch
    from instanovo.inference import (
        GreedyDecoder,
        KnapsackBeamSearchDecoder,
        Knapsack,
        ScoredSequence,
        Decoder,
    )
    from instanovo.constants import MASS_SCALE, MAX_MASS
    from torch.utils.data import DataLoader
except ImportError as e:
    print(f"Error importing InstaNovo components: {e}")
    print("Please ensure InstaNovo is installed correctly.")
    # Optionally, raise the error or exit if InstaNovo is critical
    # raise e
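
# Note (assumption): InstaNovo is distributed on PyPI, so a failed import can
# usually be resolved with `pip install instanovo` (or by adding it to the
# Space's requirements.txt).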
# --- Configuration ---
MODEL_ID = "instanovo-v1.1.0" # Use the desired pretrained model ID
KNAPSACK_DIR = Path("./knapsack_cache")
DEFAULT_CONFIG_PATH = Path("./configs/inference/default.yaml")  # Assumes InstaNovo's configs are available relative to the working directory
# Determine device
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
FP16 = DEVICE == "cuda" # Enable FP16 only on CUDA
# --- Global Variables (Load Model and Knapsack Once) ---
MODEL: InstaNovo | None = None
KNAPSACK: Knapsack | None = None
MODEL_CONFIG: DictConfig | None = None
RESIDUE_SET: ResidueSet | None = None
def load_model_and_knapsack():
    """Loads the InstaNovo model and generates/loads the knapsack."""
    global MODEL, KNAPSACK, MODEL_CONFIG, RESIDUE_SET
    if MODEL is not None:
        print("Model already loaded.")
        return

    print(f"Loading InstaNovo model: {MODEL_ID} to {DEVICE}...")
    try:
        MODEL, MODEL_CONFIG = InstaNovo.from_pretrained(MODEL_ID)
        MODEL.to(DEVICE)
        MODEL.eval()
        RESIDUE_SET = MODEL.residue_set
        print("Model loaded successfully.")
    except Exception as e:
        print(f"Error loading model: {e}")
        raise gr.Error(f"Failed to load InstaNovo model: {MODEL_ID}. Error: {e}")

    # --- Knapsack Handling ---
    KNAPSACK_DIR.mkdir(parents=True, exist_ok=True)
    knapsack_exists = (
        (KNAPSACK_DIR / "parameters.pkl").exists()
        and (KNAPSACK_DIR / "masses.npy").exists()
        and (KNAPSACK_DIR / "chart.npy").exists()
    )

    if knapsack_exists:
        print(f"Loading pre-generated knapsack from {KNAPSACK_DIR}...")
        try:
            KNAPSACK = Knapsack.from_file(str(KNAPSACK_DIR))
            print("Knapsack loaded successfully.")
        except Exception as e:
            print(f"Error loading knapsack: {e}. Will attempt to regenerate.")
            KNAPSACK = None  # Force regeneration
            knapsack_exists = False  # Ensure generation happens

    if not knapsack_exists:
        print("Knapsack not found or failed to load. Generating knapsack...")
        if RESIDUE_SET is None:
            raise gr.Error("Cannot generate knapsack because ResidueSet failed to load.")
        try:
            # Prepare residue masses for knapsack generation (handle negative/zero masses)
            residue_masses_knapsack = dict(RESIDUE_SET.residue_masses.copy())
            negative_residues = [k for k, v in residue_masses_knapsack.items() if v <= 0]
            if negative_residues:
                print(f"Warning: Non-positive masses found in residues: {negative_residues}. "
                      "Excluding from knapsack generation.")
                for res in negative_residues:
                    del residue_masses_knapsack[res]
            # Remove special tokens explicitly if they somehow got mass
            for special_token in RESIDUE_SET.special_tokens:
                if special_token in residue_masses_knapsack:
                    del residue_masses_knapsack[special_token]
            # Ensure residue indices used match those without special/negative masses
            valid_residue_indices = {
                res: idx
                for res, idx in RESIDUE_SET.residue_to_index.items()
                if res in residue_masses_knapsack
            }
            KNAPSACK = Knapsack.construct_knapsack(
                residue_masses=residue_masses_knapsack,
                residue_indices=valid_residue_indices,  # Use only valid indices
                max_mass=MAX_MASS,
                mass_scale=MASS_SCALE,
            )
            print(f"Knapsack generated. Saving to {KNAPSACK_DIR}...")
            KNAPSACK.save(str(KNAPSACK_DIR))  # Save for future runs
            print("Knapsack saved.")
        except Exception as e:
            print(f"Error generating or saving knapsack: {e}")
            gr.Warning("Failed to generate Knapsack. Knapsack Beam Search will not be available.")
            KNAPSACK = None  # Ensure it's None if generation failed

# Load the model and knapsack when the script starts
load_model_and_knapsack()
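
# Running this at import time means the (slow) model download and knapsack
# generation happen once per process; the Gradio callbacks below then reuse
# the module-level globals for every request.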
def create_inference_config(
    input_path: str,
    output_path: str,
    decoding_method: str,
) -> DictConfig:
    """Creates the OmegaConf DictConfig needed for prediction."""
    # Load default config if available, otherwise create from scratch
    if DEFAULT_CONFIG_PATH.exists():
        base_cfg = OmegaConf.load(DEFAULT_CONFIG_PATH)
    else:
        print(f"Warning: Default config not found at {DEFAULT_CONFIG_PATH}. Using minimal config.")
        # Create a minimal config if default is missing
        base_cfg = OmegaConf.create({
            "data_path": None,
            "instanovo_model": MODEL_ID,
            "output_path": None,
            "knapsack_path": str(KNAPSACK_DIR),
            "denovo": True,
            "refine": False,  # Not doing refinement here
            "num_beams": 1,
            "max_length": 40,
            "max_charge": 10,
            "isotope_error_range": [0, 1],
            "subset": 1.0,
            "use_knapsack": False,
            "save_beams": False,
            "batch_size": 64,  # Adjust as needed
            "device": DEVICE,
            "fp16": FP16,
            "log_interval": 500,  # Less relevant for Gradio app
            "use_basic_logging": True,
            "filter_precursor_ppm": 20,
            "filter_confidence": 1e-4,
            "filter_fdr_threshold": 0.05,
            "residue_remapping": {  # Add default mappings
                "M(ox)": "M[UNIMOD:35]", "M(+15.99)": "M[UNIMOD:35]",
                "S(p)": "S[UNIMOD:21]", "T(p)": "T[UNIMOD:21]", "Y(p)": "Y[UNIMOD:21]",
                "S(+79.97)": "S[UNIMOD:21]", "T(+79.97)": "T[UNIMOD:21]", "Y(+79.97)": "Y[UNIMOD:21]",
                "Q(+0.98)": "Q[UNIMOD:7]", "N(+0.98)": "N[UNIMOD:7]",
                "Q(+.98)": "Q[UNIMOD:7]", "N(+.98)": "N[UNIMOD:7]",
                "C(+57.02)": "C[UNIMOD:4]",
                "(+42.01)": "[UNIMOD:1]", "(+43.01)": "[UNIMOD:5]", "(-17.03)": "[UNIMOD:385]",
            },
            "column_map": {  # Add default mappings
                "Modified sequence": "modified_sequence", "MS/MS m/z": "precursor_mz",
                "Mass": "precursor_mass", "Charge": "precursor_charge",
                "Mass values": "mz_array", "Mass spectrum": "mz_array",
                "Intensity": "intensity_array", "Raw intensity spectrum": "intensity_array",
                "Scan number": "scan_number",
            },
            "index_columns": [
                "scan_number", "precursor_mz", "precursor_charge",
            ],
            # Add other defaults if needed based on errors
        })

    # Override specific parameters
    cfg_overrides = {
        "data_path": input_path,
        "output_path": output_path,
        "device": DEVICE,
        "fp16": FP16,
        "denovo": True,
        "refine": False,
    }
    if "Greedy" in decoding_method:
        cfg_overrides["num_beams"] = 1
        cfg_overrides["use_knapsack"] = False
    elif "Knapsack" in decoding_method:
        if KNAPSACK is None:
            raise gr.Error("Knapsack is not available. Cannot use Knapsack Beam Search.")
        cfg_overrides["num_beams"] = 5
        cfg_overrides["use_knapsack"] = True
        cfg_overrides["knapsack_path"] = str(KNAPSACK_DIR)
    else:
        raise ValueError(f"Unknown decoding method: {decoding_method}")

    # Merge base config with overrides
    final_cfg = OmegaConf.merge(base_cfg, cfg_overrides)
    return final_cfg
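
# Example (sketch, not executed): building a config for a hypothetical upload.
# The method string must contain "Greedy" or "Knapsack", matching the Radio
# labels defined in the UI below.
#
#   cfg = create_inference_config("spectra.mgf", "/tmp/preds.csv", "Greedy Search (Fast)")
#   assert cfg.num_beams == 1 and cfg.use_knapsack is False
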
def predict_peptides(input_file, decoding_method):
    """
    Main function to load data, run prediction, and return results.
    """
    if MODEL is None or RESIDUE_SET is None or MODEL_CONFIG is None:
        load_model_and_knapsack()  # Attempt to reload if None (e.g., after space restart)
        if MODEL is None:
            raise gr.Error("InstaNovo model is not loaded. Cannot perform prediction.")

    if input_file is None:
        raise gr.Error("Please upload a mass spectrometry file.")

    input_path = input_file.name  # Gradio provides the path in .name
    print(f"Processing file: {input_path}")
    print(f"Using decoding method: {decoding_method}")

    # Create a temporary file for the output CSV
    with tempfile.NamedTemporaryFile(delete=False, suffix=".csv") as temp_out:
        output_csv_path = temp_out.name

    try:
        # 1. Create Config
        config = create_inference_config(input_path, output_csv_path, decoding_method)
        print("Inference Config:\n", OmegaConf.to_yaml(config))

        # 2. Load Data using SpectrumDataFrame
        print("Loading spectrum data...")
        try:
            sdf = SpectrumDataFrame.load(
                config.data_path,
                lazy=False,  # Load eagerly for Gradio simplicity
                is_annotated=False,  # De novo mode
                column_mapping=config.get("column_map", None),
                shuffle=False,
                verbose=True,  # Print loading logs
            )
            # Apply the same charge filter as the CLI
            original_size = len(sdf)
            max_charge = config.get("max_charge", 10)
            sdf.filter_rows(
                lambda row: (row["precursor_charge"] <= max_charge) and (row["precursor_charge"] > 0)
            )
            if len(sdf) < original_size:
                print(f"Warning: Filtered {original_size - len(sdf)} spectra with charge > {max_charge} or <= 0.")
            if len(sdf) == 0:
                raise gr.Error("No valid spectra found in the uploaded file after filtering.")
            print(f"Data loaded: {len(sdf)} spectra.")
        except Exception as e:
            print(f"Error loading data: {e}")
            raise gr.Error(f"Failed to load or process the spectrum file. Error: {e}")
        # 3. Prepare Dataset and DataLoader
        ds = SpectrumDataset(
            sdf,
            RESIDUE_SET,
            MODEL_CONFIG.get("n_peaks", 200),
            return_str=True,  # Needed for greedy/beam search targets later (though not used here)
            annotated=False,
            pad_spectrum_max_length=config.get("compile_model", False) or config.get("use_flash_attention", False),
            bin_spectra=config.get("conv_peak_encoder", False),
        )
        dl = DataLoader(
            ds,
            batch_size=config.batch_size,
            num_workers=0,  # Required by SpectrumDataFrame
            shuffle=False,  # Required by SpectrumDataFrame
            collate_fn=collate_batch,
        )
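
        # Note: each batch produced by collate_batch unpacks below as
        # (spectra, precursors, spectra_mask, peptides, peptides_mask);
        # the last two entries are unused placeholders in de novo mode.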
        # 4. Select Decoder
        print("Initializing decoder...")
        decoder: Decoder
        if config.use_knapsack:
            if KNAPSACK is None:
                # This check should ideally happen earlier, but double-check here
                raise gr.Error("Knapsack is required for Knapsack Beam Search but is not available.")
            # The knapsack is loaded once globally, so pass it directly rather than
            # reloading it from disk via KnapsackBeamSearchDecoder.from_file().
            decoder = KnapsackBeamSearchDecoder(model=MODEL, knapsack=KNAPSACK)
        elif config.num_beams > 1:
            # Plain beam search is not wired up in this app; only Greedy and
            # Knapsack Beam Search are exposed, so fall back to Greedy.
            print(f"Warning: num_beams={config.num_beams} > 1, but only Greedy and Knapsack Beam Search are implemented in this app. Defaulting to Greedy.")
            decoder = GreedyDecoder(model=MODEL, mass_scale=MASS_SCALE)
        else:
            decoder = GreedyDecoder(
                model=MODEL,
                mass_scale=MASS_SCALE,
                # Pass through optional suppression settings from the config
                suppressed_residues=config.get("suppressed_residues", None),
                disable_terminal_residues_anywhere=config.get("disable_terminal_residues_anywhere", True),
            )
        print(f"Using decoder: {type(decoder).__name__}")
        # 5. Run Prediction Loop (Adapted from instanovo/transformer/predict.py)
        print("Starting prediction...")
        start_time = time.time()
        results_list: list[ScoredSequence | list] = []  # Store ScoredSequence or empty list

        for i, batch in enumerate(dl):
            spectra, precursors, spectra_mask, _, _ = batch  # Ignore peptides/masks for de novo
            spectra = spectra.to(DEVICE)
            precursors = precursors.to(DEVICE)
            spectra_mask = spectra_mask.to(DEVICE)

            with torch.no_grad(), torch.amp.autocast(DEVICE, dtype=torch.float16, enabled=FP16):
                # Beam search decoder might return list[list[ScoredSequence]] if return_beam=True
                # Greedy decoder returns list[ScoredSequence]
                # KnapsackBeamSearchDecoder returns list[ScoredSequence] or list[list[ScoredSequence]]
                batch_predictions = decoder.decode(
                    spectra=spectra,
                    precursors=precursors,
                    beam_size=config.num_beams,
                    max_length=config.max_length,
                    # Knapsack/Beam Search specific params if needed
                    mass_tolerance=config.get("filter_precursor_ppm", 20) * 1e-6,  # Convert ppm to relative
                    max_isotope=config.isotope_error_range[1] if config.isotope_error_range else 1,
                    return_beam=False,  # Only get the top prediction for simplicity
                )
            results_list.extend(batch_predictions)  # Should be list[ScoredSequence] or list[list]
            print(f"Processed batch {i+1}/{len(dl)}")

        end_time = time.time()
        print(f"Prediction finished in {end_time - start_time:.2f} seconds.")
        # 6. Format Results
        print("Formatting results...")
        output_data = []
        # Use sdf index columns + prediction results
        index_cols = [col for col in config.index_columns if col in sdf.df.columns]
        base_df_pd = sdf.df.select(index_cols).to_pandas()  # Get base info
        metrics_calc = Metrics(RESIDUE_SET, config.isotope_error_range)

        for i, res in enumerate(results_list):
            row_data = base_df_pd.iloc[i].to_dict()  # Get corresponding input data
            if isinstance(res, ScoredSequence) and res.sequence:
                sequence_str = "".join(res.sequence)
                row_data["prediction"] = sequence_str
                row_data["log_probability"] = f"{res.sequence_log_probability:.4f}"
                # Use metrics to calculate delta mass ppm for the top prediction
                try:
                    _, delta_mass_list = metrics_calc.matches_precursor(
                        res.sequence,
                        row_data["precursor_mz"],
                        row_data["precursor_charge"],
                    )
                    # Find the smallest absolute ppm error across isotopes
                    min_abs_ppm = min(abs(p) for p in delta_mass_list) if delta_mass_list else float("nan")
                    row_data["delta_mass_ppm"] = f"{min_abs_ppm:.2f}"
                except Exception as e:
                    print(f"Warning: Could not calculate delta mass for prediction {i}: {e}")
                    row_data["delta_mass_ppm"] = "N/A"
            else:
                row_data["prediction"] = ""
                row_data["log_probability"] = "N/A"
                row_data["delta_mass_ppm"] = "N/A"
            output_data.append(row_data)

        output_df = pl.DataFrame(output_data)

        # Ensure specific columns are present and ordered
        display_cols = ["scan_number", "precursor_mz", "precursor_charge", "prediction", "log_probability", "delta_mass_ppm"]
        final_display_cols = []
        for col in display_cols:
            if col in output_df.columns:
                final_display_cols.append(col)
            else:
                print(f"Warning: Expected display column '{col}' not found in results.")
        # Add any remaining index columns that weren't in display_cols
        for col in index_cols:
            if col not in final_display_cols and col in output_df.columns:
                final_display_cols.append(col)
        output_df_display = output_df.select(final_display_cols)

        # 7. Save full results to CSV
        print(f"Saving results to {output_csv_path}...")
        output_df.write_csv(output_csv_path)

        # Return DataFrame for display and path for download
        return output_df_display.to_pandas(), output_csv_path
    except Exception as e:
        print(f"An error occurred during prediction: {e}")
        # Clean up the temporary output file if it exists
        if os.path.exists(output_csv_path):
            os.remove(output_csv_path)
        # Re-raise as Gradio error
        raise gr.Error(f"Prediction failed: {e}")
# --- Gradio Interface ---
css = """
.gradio-container { font-family: sans-serif; }
.gr-button { color: white; border-color: black; background: black; }
footer { display: none !important; }
"""
with gr.Blocks(css=css, theme=gr.themes.Default(primary_hue="blue", secondary_hue="blue")) as demo:
    gr.Markdown(
        """
        # 🚀 InstaNovo _De Novo_ Peptide Sequencing
        Upload your mass spectrometry data file (.mgf, .mzml, or .mzxml) and get peptide sequence predictions using InstaNovo.
        Choose between the fast Greedy Search and the more accurate but slower Knapsack Beam Search.
        """
    )
    with gr.Row():
        with gr.Column(scale=1):
            input_file = gr.File(
                label="Upload Mass Spectrometry File (.mgf, .mzml, .mzxml)",
                file_types=[".mgf", ".mzml", ".mzxml"],
            )
            decoding_method = gr.Radio(
                ["Greedy Search (Fast)", "Knapsack Beam Search (More accurate, but slower)"],
                label="Decoding Method",
                value="Greedy Search (Fast)",  # Default to the fast method
            )
            submit_btn = gr.Button("Predict Sequences", variant="primary")
        with gr.Column(scale=2):
            output_df = gr.DataFrame(label="Prediction Results", wrap=True)
            output_file = gr.File(label="Download Full Results (CSV)")

    submit_btn.click(
        predict_peptides,
        inputs=[input_file, decoding_method],
        outputs=[output_df, output_file],
    )
    gr.Examples(
        # The decoding method here must match one of the Radio choices above.
        [["./sample_spectra.mgf", "Knapsack Beam Search (More accurate, but slower)"]],  # Requires test data fetched
        inputs=[input_file, decoding_method],
        outputs=[output_df, output_file],
        fn=predict_peptides,
        cache_examples=False,  # Re-run examples if needed
        label="Example Usage",
    )
    gr.Markdown(
        """
        **Notes:**
        * Predictions are based on the [InstaNovo](https://github.com/instadeepai/InstaNovo) model ({MODEL_ID}).
        * Knapsack Beam Search uses pre-calculated mass constraints and yields better results but takes longer.
        * 'delta_mass_ppm' shows the lowest absolute precursor mass error (in ppm) across potential isotopes (0-1 neutron).
        * Ensure your input file format is correctly specified. Large files may take time to process.
        """.format(MODEL_ID=MODEL_ID)
    )
# --- Launch the App ---
if __name__ == "__main__":
    # Set share=True for a temporary public link when running locally
    # Set server_name="0.0.0.0" to allow access from the network if needed
    # demo.launch(server_name="0.0.0.0", server_port=7860)
    # For Hugging Face Spaces, a plain demo.launch() is usually sufficient
    demo.launch(share=True)  # For local testing with a public URL