# Hugging Face Spaces page banner (scrape residue, not part of the program): "Spaces: Running on Zero"
import gradio as gr | |
import torch | |
import os | |
import tempfile | |
import time | |
import polars as pl | |
import numpy as np | |
from pathlib import Path | |
from omegaconf import OmegaConf, DictConfig | |
# --- InstaNovo Imports ---
# InstaNovo is a hard requirement of this script: the module-level annotations
# and the start-up model load further down reference these names directly.
# Print an installation hint and then re-raise, so the real ImportError is what
# the operator sees instead of a later, unrelated-looking NameError.
try:
    from instanovo.transformer.model import InstaNovo
    from instanovo.utils import SpectrumDataFrame, ResidueSet, Metrics
    from instanovo.transformer.dataset import SpectrumDataset, collate_batch
    from instanovo.inference import (
        GreedyDecoder,
        KnapsackBeamSearchDecoder,
        Knapsack,
        ScoredSequence,
        Decoder,
    )
    from instanovo.constants import MASS_SCALE, MAX_MASS
    from torch.utils.data import DataLoader
except ImportError as e:
    print(f"Error importing InstaNovo components: {e}")
    print("Please ensure InstaNovo is installed correctly.")
    raise
# --- Configuration ---
# Identifier of the pretrained InstaNovo checkpoint to load.
MODEL_ID = "instanovo-v1.1.0"
# Directory where the knapsack files are cached between runs.
KNAPSACK_DIR = Path("./knapsack_cache")
# Optional default inference config; assumed to live relative to the working
# directory (a minimal built-in config is used when it is missing).
DEFAULT_CONFIG_PATH = Path("./configs/inference/default.yaml")

# Pick the compute device: CUDA when torch reports it, CPU otherwise.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
# Half precision is only switched on for CUDA.
FP16 = (DEVICE == "cuda")

# --- Global Variables (Load Model and Knapsack Once) ---
# All four start empty and are filled in by load_model_and_knapsack().
MODEL: InstaNovo | None = None
KNAPSACK: Knapsack | None = None
MODEL_CONFIG: DictConfig | None = None
RESIDUE_SET: ResidueSet | None = None
def _knapsack_residue_tables(residue_set):
    """Return ``(masses, indices)`` limited to residues usable for the knapsack.

    Residues with a non-positive mass and the residue set's special tokens are
    dropped from the mass table; the index table keeps only the residues that
    remain in the mass table.
    """
    masses = dict(residue_set.residue_masses.copy())
    negative_residues = [res for res, mass in masses.items() if mass <= 0]
    if negative_residues:
        print(f"Warning: Non-positive masses found in residues: {negative_residues}. "
              "Excluding from knapsack generation.")
        for res in negative_residues:
            del masses[res]
    # Remove special tokens explicitly in case they were given a mass.
    for special_token in residue_set.special_tokens:
        masses.pop(special_token, None)
    indices = {
        res: idx
        for res, idx in residue_set.residue_to_index.items()
        if res in masses
    }
    return masses, indices


def load_model_and_knapsack():
    """Load the InstaNovo model and load (or generate and cache) the knapsack.

    Populates the module globals ``MODEL``, ``MODEL_CONFIG``, ``RESIDUE_SET``
    and ``KNAPSACK``. The three model globals are only assigned once loading,
    device placement and ``eval()`` have all succeeded, so a failed attempt
    never leaves a half-initialised state that a later call would mistake for
    "already loaded".

    A knapsack failure is not fatal: ``KNAPSACK`` is left as ``None`` and a
    Gradio warning is emitted.

    Raises:
        gr.Error: If the model cannot be loaded, or if a knapsack has to be
            generated but no residue set is available.
    """
    global MODEL, KNAPSACK, MODEL_CONFIG, RESIDUE_SET

    if MODEL is not None and MODEL_CONFIG is not None and RESIDUE_SET is not None:
        print("Model already loaded.")
        return

    print(f"Loading InstaNovo model: {MODEL_ID} to {DEVICE}...")
    try:
        model, model_config = InstaNovo.from_pretrained(MODEL_ID)
        model.to(DEVICE)
        model.eval()
        residue_set = model.residue_set
    except Exception as e:
        print(f"Error loading model: {e}")
        raise gr.Error(f"Failed to load InstaNovo model: {MODEL_ID}. Error: {e}") from e
    # Publish everything together, only after every step above succeeded.
    MODEL, MODEL_CONFIG, RESIDUE_SET = model, model_config, residue_set
    print("Model loaded successfully.")

    # --- Knapsack Handling ---
    KNAPSACK_DIR.mkdir(parents=True, exist_ok=True)
    knapsack_exists = (
        (KNAPSACK_DIR / "parameters.pkl").exists()
        and (KNAPSACK_DIR / "masses.npy").exists()
        and (KNAPSACK_DIR / "chart.npy").exists()
    )

    if knapsack_exists:
        print(f"Loading pre-generated knapsack from {KNAPSACK_DIR}...")
        try:
            KNAPSACK = Knapsack.from_file(str(KNAPSACK_DIR))
            print("Knapsack loaded successfully.")
        except Exception as e:
            print(f"Error loading knapsack: {e}. Will attempt to regenerate.")
            KNAPSACK = None
            knapsack_exists = False  # fall through to regeneration below

    if not knapsack_exists:
        print("Knapsack not found or failed to load. Generating knapsack...")
        if RESIDUE_SET is None:
            raise gr.Error("Cannot generate knapsack because ResidueSet failed to load.")
        try:
            residue_masses_knapsack, valid_residue_indices = _knapsack_residue_tables(RESIDUE_SET)
            KNAPSACK = Knapsack.construct_knapsack(
                residue_masses=residue_masses_knapsack,
                residue_indices=valid_residue_indices,
                max_mass=MAX_MASS,
                mass_scale=MASS_SCALE,
            )
            print(f"Knapsack generated. Saving to {KNAPSACK_DIR}...")
            KNAPSACK.save(str(KNAPSACK_DIR))  # cached for future runs
            print("Knapsack saved.")
        except Exception as e:
            # Best effort: the app still works with greedy decoding only.
            print(f"Error generating or saving knapsack: {e}")
            gr.Warning("Failed to generate Knapsack. Knapsack Beam Search will not be available.")
            KNAPSACK = None


# Load the model and knapsack when the script starts
load_model_and_knapsack()
def _minimal_inference_config():
    """Build the fallback config used when the default YAML file is absent."""
    return OmegaConf.create({
        "data_path": None,
        "instanovo_model": MODEL_ID,
        "output_path": None,
        "knapsack_path": str(KNAPSACK_DIR),
        "denovo": True,
        "refine": False,  # refinement is not performed by this app
        "num_beams": 1,
        "max_length": 40,
        "max_charge": 10,
        "isotope_error_range": [0, 1],
        "subset": 1.0,
        "use_knapsack": False,
        "save_beams": False,
        "batch_size": 64,
        "device": DEVICE,
        "fp16": FP16,
        "log_interval": 500,
        "use_basic_logging": True,
        "filter_precursor_ppm": 20,
        "filter_confidence": 1e-4,
        "filter_fdr_threshold": 0.05,
        "residue_remapping": {
            "M(ox)": "M[UNIMOD:35]", "M(+15.99)": "M[UNIMOD:35]",
            "S(p)": "S[UNIMOD:21]", "T(p)": "T[UNIMOD:21]", "Y(p)": "Y[UNIMOD:21]",
            "S(+79.97)": "S[UNIMOD:21]", "T(+79.97)": "T[UNIMOD:21]", "Y(+79.97)": "Y[UNIMOD:21]",
            "Q(+0.98)": "Q[UNIMOD:7]", "N(+0.98)": "N[UNIMOD:7]",
            "Q(+.98)": "Q[UNIMOD:7]", "N(+.98)": "N[UNIMOD:7]",
            "C(+57.02)": "C[UNIMOD:4]",
            "(+42.01)": "[UNIMOD:1]", "(+43.01)": "[UNIMOD:5]", "(-17.03)": "[UNIMOD:385]",
        },
        "column_map": {
            "Modified sequence": "modified_sequence", "MS/MS m/z": "precursor_mz",
            "Mass": "precursor_mass", "Charge": "precursor_charge",
            "Mass values": "mz_array", "Mass spectrum": "mz_array",
            "Intensity": "intensity_array", "Raw intensity spectrum": "intensity_array",
            "Scan number": "scan_number"
        },
        "index_columns": [
            "scan_number", "precursor_mz", "precursor_charge",
        ],
    })


def create_inference_config(
    input_path: str,
    output_path: str,
    decoding_method: str,
) -> DictConfig:
    """Assemble the OmegaConf config for one prediction run.

    The base comes from ``DEFAULT_CONFIG_PATH`` when that file exists and from
    a built-in minimal config otherwise. Run-specific values (paths, device,
    precision, decoding settings) are then merged on top.

    Args:
        input_path: Path of the spectrum file to read.
        output_path: Path where the result CSV will be written.
        decoding_method: UI label; a label containing "Greedy" selects greedy
            decoding, otherwise one containing "Knapsack" selects knapsack
            beam search.

    Raises:
        gr.Error: Knapsack decoding was requested but no knapsack is loaded.
        ValueError: The label matches neither decoding method.
    """
    if DEFAULT_CONFIG_PATH.exists():
        defaults = OmegaConf.load(DEFAULT_CONFIG_PATH)
    else:
        print(f"Warning: Default config not found at {DEFAULT_CONFIG_PATH}. Using minimal config.")
        defaults = _minimal_inference_config()

    # Validate the requested decoding method before building the overrides.
    greedy = "Greedy" in decoding_method
    if not greedy and "Knapsack" not in decoding_method:
        raise ValueError(f"Unknown decoding method: {decoding_method}")
    if not greedy and KNAPSACK is None:
        raise gr.Error("Knapsack is not available. Cannot use Knapsack Beam Search.")

    overrides = {
        "data_path": input_path,
        "output_path": output_path,
        "device": DEVICE,
        "fp16": FP16,
        "denovo": True,
        "refine": False,
        "num_beams": 1 if greedy else 5,
        "use_knapsack": not greedy,
    }
    if not greedy:
        overrides["knapsack_path"] = str(KNAPSACK_DIR)

    # Values in `overrides` take precedence over the base config.
    return OmegaConf.merge(defaults, overrides)
def _resolve_upload_path(input_file):
    """Return the filesystem path for an upload given as a path or a file wrapper."""
    if isinstance(input_file, (str, os.PathLike)):
        return os.fspath(input_file)
    # Older-style upload objects expose the temp-file path as `.name`.
    return input_file.name


def _load_filtered_spectra(config):
    """Load the spectra named by ``config.data_path`` and drop out-of-range charges."""
    print("Loading spectrum data...")
    try:
        sdf = SpectrumDataFrame.load(
            config.data_path,
            lazy=False,  # load eagerly
            is_annotated=False,  # de novo mode: no target peptides
            column_mapping=config.get("column_map", None),
            shuffle=False,
            verbose=True,
        )
        original_size = len(sdf)
        max_charge = config.get("max_charge", 10)
        sdf.filter_rows(
            lambda row: (row["precursor_charge"] <= max_charge) and (row["precursor_charge"] > 0)
        )
        remaining = len(sdf)
        if remaining < original_size:
            print(f"Warning: Filtered {original_size - remaining} spectra with charge > {max_charge} or <= 0.")
    except Exception as e:
        print(f"Error loading data: {e}")
        raise gr.Error(f"Failed to load or process the spectrum file. Error: {e}") from e

    # Checked outside the try so this message reaches the user unwrapped.
    if remaining == 0:
        raise gr.Error("No valid spectra found in the uploaded file after filtering.")
    print(f"Data loaded: {remaining} spectra.")
    return sdf


def _select_decoder(config):
    """Instantiate the decoder that matches ``config.use_knapsack`` / ``config.num_beams``."""
    print("Initializing decoder...")
    if config.use_knapsack:
        if KNAPSACK is None:
            raise gr.Error("Knapsack is required for Knapsack Beam Search but is not available.")
        decoder = KnapsackBeamSearchDecoder(model=MODEL, knapsack=KNAPSACK)
    elif config.num_beams > 1:
        # Plain beam search is not wired into this app; fall back to greedy.
        print(f"Warning: num_beams={config.num_beams} > 1 but only Greedy and Knapsack Beam Search are implemented in this app. Defaulting to Greedy.")
        decoder = GreedyDecoder(model=MODEL, mass_scale=MASS_SCALE)
    else:
        decoder = GreedyDecoder(
            model=MODEL,
            mass_scale=MASS_SCALE,
            suppressed_residues=config.get("suppressed_residues", None),
            disable_terminal_residues_anywhere=config.get("disable_terminal_residues_anywhere", True),
        )
    print(f"Using decoder: {type(decoder).__name__}")
    return decoder


def _decode_batches(decoder, dl, config):
    """Run ``decoder`` over every batch of ``dl`` and return the concatenated predictions."""
    print("Starting prediction...")
    start_time = time.time()
    results_list = []
    total_batches = len(dl)
    for i, batch in enumerate(dl):
        # The last two batch items are ignored in de novo mode.
        spectra, precursors, spectra_mask, _, _ = batch
        spectra = spectra.to(DEVICE)
        precursors = precursors.to(DEVICE)
        spectra_mask = spectra_mask.to(DEVICE)
        with torch.no_grad(), torch.amp.autocast(DEVICE, dtype=torch.float16, enabled=FP16):
            batch_predictions = decoder.decode(
                spectra=spectra,
                precursors=precursors,
                beam_size=config.num_beams,
                max_length=config.max_length,
                # ppm -> relative tolerance
                mass_tolerance=config.get("filter_precursor_ppm", 20) * 1e-6,
                max_isotope=config.isotope_error_range[1] if config.isotope_error_range else 1,
                return_beam=False,  # only the top prediction per spectrum
            )
        results_list.extend(batch_predictions)
        print(f"Processed batch {i+1}/{total_batches}")
    print(f"Prediction finished in {time.time() - start_time:.2f} seconds.")
    return results_list


def _format_predictions(results_list, sdf, config):
    """Combine predictions with the input index columns; return ``(frame, index_cols)``."""
    print("Formatting results...")
    index_cols = [col for col in config.index_columns if col in sdf.df.columns]
    # NOTE(review): assumes results_list is aligned row-for-row with sdf.df — confirm.
    base_df_pd = sdf.df.select(index_cols).to_pandas()
    metrics_calc = Metrics(RESIDUE_SET, config.isotope_error_range)

    rows = []
    for i, res in enumerate(results_list):
        row_data = base_df_pd.iloc[i].to_dict()
        if isinstance(res, ScoredSequence) and res.sequence:
            row_data["prediction"] = "".join(res.sequence)
            row_data["log_probability"] = f"{res.sequence_log_probability:.4f}"
            try:
                _, delta_mass_list = metrics_calc.matches_precursor(
                    res.sequence,
                    row_data["precursor_mz"],
                    row_data["precursor_charge"],
                )
                # Report the smallest absolute error among the returned values.
                min_abs_ppm = min(abs(p) for p in delta_mass_list) if delta_mass_list else float("nan")
                row_data["delta_mass_ppm"] = f"{min_abs_ppm:.2f}"
            except Exception as e:
                print(f"Warning: Could not calculate delta mass for prediction {i}: {e}")
                row_data["delta_mass_ppm"] = "N/A"
        else:
            # No usable prediction for this spectrum.
            row_data["prediction"] = ""
            row_data["log_probability"] = "N/A"
            row_data["delta_mass_ppm"] = "N/A"
        rows.append(row_data)
    return pl.DataFrame(rows), index_cols


def predict_peptides(input_file, decoding_method):
    """Run de novo prediction on an uploaded spectrum file.

    Args:
        input_file: The Gradio upload: either a path (``str`` / path-like) or
            an object exposing the path as ``.name``.
        decoding_method: Decoding label chosen in the UI; interpreted by
            ``create_inference_config``.

    Returns:
        A ``(pandas.DataFrame, str)`` pair: the table to display and the path
        of the CSV file holding the full results.

    Raises:
        gr.Error: If the model is unavailable, no file was uploaded, or any
            step of the prediction fails. Errors that are already ``gr.Error``
            are passed through unchanged instead of being wrapped again.
    """
    if MODEL is None or RESIDUE_SET is None or MODEL_CONFIG is None:
        load_model_and_knapsack()  # retry in case start-up loading did not complete
        if MODEL is None:
            raise gr.Error("InstaNovo model is not loaded. Cannot perform prediction.")
    if input_file is None:
        raise gr.Error("Please upload a mass spectrometry file.")

    input_path = _resolve_upload_path(input_file)
    print(f"Processing file: {input_path}")
    print(f"Using decoding method: {decoding_method}")

    # Reserve a path for the output CSV; the file must outlive this function
    # because it is handed back for download.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".csv") as temp_out:
        output_csv_path = temp_out.name

    try:
        # 1. Create config
        config = create_inference_config(input_path, output_csv_path, decoding_method)
        print("Inference Config:\n", OmegaConf.to_yaml(config))

        # 2. Load data
        sdf = _load_filtered_spectra(config)

        # 3. Prepare dataset and data loader
        ds = SpectrumDataset(
            sdf,
            RESIDUE_SET,
            MODEL_CONFIG.get("n_peaks", 200),
            return_str=True,
            annotated=False,
            pad_spectrum_max_length=config.get("compile_model", False) or config.get("use_flash_attention", False),
            bin_spectra=config.get("conv_peak_encoder", False),
        )
        dl = DataLoader(
            ds,
            batch_size=config.batch_size,
            num_workers=0,  # required by SpectrumDataFrame
            shuffle=False,  # required by SpectrumDataFrame
            collate_fn=collate_batch,
        )

        # 4. Select decoder, 5. run prediction, 6. format results
        decoder = _select_decoder(config)
        results_list = _decode_batches(decoder, dl, config)
        output_df, index_cols = _format_predictions(results_list, sdf, config)

        # Preferred display order first, then any remaining index columns.
        display_cols = ["scan_number", "precursor_mz", "precursor_charge", "prediction", "log_probability", "delta_mass_ppm"]
        final_display_cols = []
        for col in display_cols:
            if col in output_df.columns:
                final_display_cols.append(col)
            else:
                print(f"Warning: Expected display column '{col}' not found in results.")
        for col in index_cols:
            if col not in final_display_cols and col in output_df.columns:
                final_display_cols.append(col)
        output_df_display = output_df.select(final_display_cols)

        # 7. Save full results to CSV
        print(f"Saving results to {output_csv_path}...")
        output_df.write_csv(output_csv_path)

        return output_df_display.to_pandas(), output_csv_path
    except Exception as e:
        print(f"An error occurred during prediction: {e}")
        # The reserved output file is useless after a failure; remove it.
        if os.path.exists(output_csv_path):
            os.remove(output_csv_path)
        if isinstance(e, gr.Error):
            raise  # already user-facing; do not nest the message
        raise gr.Error(f"Prediction failed: {e}") from e
# --- Gradio Interface ---
# Decoding labels are defined once so the radio choices and the example row
# cannot drift apart (an example value must be one of the radio's choices).
GREEDY_CHOICE = "Greedy Search (Fast)"
KNAPSACK_CHOICE = "Knapsack Beam Search (More accurate, but slower)"

css = """
.gradio-container { font-family: sans-serif; }
.gr-button { color: white; border-color: black; background: black; }
footer { display: none !important; }
"""

with gr.Blocks(css=css, theme=gr.themes.Default(primary_hue="blue", secondary_hue="blue")) as demo:
    gr.Markdown(
        """
        # π InstaNovo _De Novo_ Peptide Sequencing
        Upload your mass spectrometry data file (.mgf, .mzml, or .mzxml) and get peptide sequence predictions using InstaNovo.
        Choose between fast Greedy Search or more accurate but slower Knapsack Beam Search.
        """
    )
    with gr.Row():
        # Left column: inputs.
        with gr.Column(scale=1):
            input_file = gr.File(
                label="Upload Mass Spectrometry File (.mgf, .mzml, .mzxml)",
                file_types=[".mgf", ".mzml", ".mzxml"],
            )
            decoding_method = gr.Radio(
                [GREEDY_CHOICE, KNAPSACK_CHOICE],
                label="Decoding Method",
                value=GREEDY_CHOICE,  # default to the fast method
            )
            submit_btn = gr.Button("Predict Sequences", variant="primary")
        # Right column: results table and downloadable CSV.
        with gr.Column(scale=2):
            output_df = gr.DataFrame(label="Prediction Results", wrap=True)
            output_file = gr.File(label="Download Full Results (CSV)")

    submit_btn.click(
        predict_peptides,
        inputs=[input_file, decoding_method],
        outputs=[output_df, output_file],
    )

    gr.Examples(
        [["./sample_spectra.mgf", KNAPSACK_CHOICE]],  # requires the sample file to be present
        inputs=[input_file, decoding_method],
        outputs=[output_df, output_file],
        fn=predict_peptides,
        cache_examples=False,
        label="Example Usage",
    )

    gr.Markdown(
        """
        **Notes:**
        * Predictions are based on the [InstaNovo](https://github.com/instadeepai/InstaNovo) model ({MODEL_ID}).
        * Knapsack Beam Search uses pre-calculated mass constraints and yields better results but takes longer.
        * 'delta_mass_ppm' shows the lowest absolute precursor mass error (in ppm) across potential isotopes (0-1 neutron).
        * Ensure your input file format is correctly specified. Large files may take time to process.
        """.format(MODEL_ID=MODEL_ID)
    )
# --- Launch the App ---
def main():
    """Start the Gradio server for the `demo` app.

    `share=True` requests a temporary public link, which is intended for local
    testing. Alternatives: `demo.launch(server_name="0.0.0.0", server_port=7860)`
    to expose the app on the local network, or a bare `demo.launch()`, which is
    usually sufficient on Hugging Face Spaces.
    """
    demo.launch(share=True)


if __name__ == "__main__":
    main()