import gradio as gr import torch import os import tempfile import time import polars as pl import numpy as np from pathlib import Path from omegaconf import OmegaConf, DictConfig # --- InstaNovo Imports --- try: from instanovo.transformer.model import InstaNovo from instanovo.utils import SpectrumDataFrame, ResidueSet, Metrics from instanovo.transformer.dataset import SpectrumDataset, collate_batch from instanovo.inference import ( GreedyDecoder, KnapsackBeamSearchDecoder, Knapsack, ScoredSequence, Decoder, ) from instanovo.constants import MASS_SCALE, MAX_MASS from torch.utils.data import DataLoader except ImportError as e: raise ImportError("Failed to import InstaNovo components: {e}") # --- Configuration --- MODEL_ID = "instanovo-v1.1.0" # Use the desired pretrained model ID KNAPSACK_DIR = Path("./knapsack_cache") DEFAULT_CONFIG_PATH = Path("./configs/inference/default.yaml") # Assuming instanovo installs configs locally relative to execution # Determine device DEVICE = "cuda" if torch.cuda.is_available() else "cpu" FP16 = DEVICE == "cuda" # Enable FP16 only on CUDA # --- Global Variables (Load Model and Knapsack Once) --- MODEL: InstaNovo | None = None KNAPSACK: Knapsack | None = None MODEL_CONFIG: DictConfig | None = None RESIDUE_SET: ResidueSet | None = None # Assets gr.set_static_paths(paths=[Path.cwd().absolute()/"assets"]) def load_model_and_knapsack(): """Loads the InstaNovo model and generates/loads the knapsack.""" global MODEL, KNAPSACK, MODEL_CONFIG, RESIDUE_SET if MODEL is not None: print("Model already loaded.") return print(f"Loading InstaNovo model: {MODEL_ID} to {DEVICE}...") try: MODEL, MODEL_CONFIG = InstaNovo.from_pretrained(MODEL_ID) MODEL.to(DEVICE) MODEL.eval() RESIDUE_SET = MODEL.residue_set print("Model loaded successfully.") except Exception as e: print(f"Error loading model: {e}") raise gr.Error(f"Failed to load InstaNovo model: {MODEL_ID}. Error: {e}") # --- Knapsack Handling --- knapsack_exists = ( (KNAPSACK_DIR / "parameters.pkl").exists() and (KNAPSACK_DIR / "masses.npy").exists() and (KNAPSACK_DIR / "chart.npy").exists() ) if knapsack_exists: print(f"Loading pre-generated knapsack from {KNAPSACK_DIR}...") try: KNAPSACK = Knapsack.from_file(str(KNAPSACK_DIR)) print("Knapsack loaded successfully.") except Exception as e: print(f"Error loading knapsack: {e}. Will attempt to regenerate.") KNAPSACK = None # Force regeneration knapsack_exists = False # Ensure generation happens if not knapsack_exists: print("Knapsack not found or failed to load. Generating knapsack...") if RESIDUE_SET is None: raise gr.Error("Cannot generate knapsack because ResidueSet failed to load.") try: # Prepare residue masses for knapsack generation (handle negative/zero masses) residue_masses_knapsack = dict(RESIDUE_SET.residue_masses.copy()) negative_residues = [k for k, v in residue_masses_knapsack.items() if v <= 0] if negative_residues: print(f"Warning: Non-positive masses found in residues: {negative_residues}. " "Excluding from knapsack generation.") for res in negative_residues: del residue_masses_knapsack[res] # Remove special tokens explicitly if they somehow got mass for special_token in RESIDUE_SET.special_tokens: if special_token in residue_masses_knapsack: del residue_masses_knapsack[special_token] # Ensure residue indices used match those without special/negative masses valid_residue_indices = { res: idx for res, idx in RESIDUE_SET.residue_to_index.items() if res in residue_masses_knapsack } KNAPSACK = Knapsack.construct_knapsack( residue_masses=residue_masses_knapsack, residue_indices=valid_residue_indices, # Use only valid indices max_mass=MAX_MASS, mass_scale=MASS_SCALE, ) print(f"Knapsack generated. Saving to {KNAPSACK_DIR}...") KNAPSACK.save(str(KNAPSACK_DIR)) # Save for future runs print("Knapsack saved.") except Exception as e: print(f"Error generating or saving knapsack: {e}") gr.Warning("Failed to generate Knapsack. Knapsack Beam Search will not be available. {e}") KNAPSACK = None # Ensure it's None if generation failed # Load the model and knapsack when the script starts load_model_and_knapsack() def create_inference_config( input_path: str, output_path: str, decoding_method: str, ) -> DictConfig: """Creates the OmegaConf DictConfig needed for prediction.""" # Load default config if available, otherwise create from scratch if DEFAULT_CONFIG_PATH.exists(): base_cfg = OmegaConf.load(DEFAULT_CONFIG_PATH) else: print(f"Warning: Default config not found at {DEFAULT_CONFIG_PATH}. Using minimal config.") # Create a minimal config if default is missing base_cfg = OmegaConf.create({ "data_path": None, "instanovo_model": MODEL_ID, "output_path": None, "knapsack_path": str(KNAPSACK_DIR), "denovo": True, "refine": False, # Not doing refinement here "num_beams": 1, "max_length": 40, "max_charge": 10, "isotope_error_range": [0, 1], "subset": 1.0, "use_knapsack": False, "save_beams": False, "batch_size": 64, # Adjust as needed "device": DEVICE, "fp16": FP16, "log_interval": 500, # Less relevant for Gradio app "use_basic_logging": True, "filter_precursor_ppm": 20, "filter_confidence": 1e-4, "filter_fdr_threshold": 0.05, "residue_remapping": { # Add default mappings "M(ox)": "M[UNIMOD:35]", "M(+15.99)": "M[UNIMOD:35]", "S(p)": "S[UNIMOD:21]", "T(p)": "T[UNIMOD:21]", "Y(p)": "Y[UNIMOD:21]", "S(+79.97)": "S[UNIMOD:21]", "T(+79.97)": "T[UNIMOD:21]", "Y(+79.97)": "Y[UNIMOD:21]", "Q(+0.98)": "Q[UNIMOD:7]", "N(+0.98)": "N[UNIMOD:7]", "Q(+.98)": "Q[UNIMOD:7]", "N(+.98)": "N[UNIMOD:7]", "C(+57.02)": "C[UNIMOD:4]", "(+42.01)": "[UNIMOD:1]", "(+43.01)": "[UNIMOD:5]", "(-17.03)": "[UNIMOD:385]", }, "column_map": { # Add default mappings "Modified sequence": "modified_sequence", "MS/MS m/z": "precursor_mz", "Mass": "precursor_mass", "Charge": "precursor_charge", "Mass values": "mz_array", "Mass spectrum": "mz_array", "Intensity": "intensity_array", "Raw intensity spectrum": "intensity_array", "Scan number": "scan_number" }, "index_columns": [ "scan_number", "precursor_mz", "precursor_charge", ], # Add other defaults if needed based on errors }) # Override specific parameters cfg_overrides = { "data_path": input_path, "output_path": output_path, "device": DEVICE, "fp16": FP16, "denovo": True, "refine": False, } if "Greedy" in decoding_method: cfg_overrides["num_beams"] = 1 cfg_overrides["use_knapsack"] = False elif "Knapsack" in decoding_method: if KNAPSACK is None: raise gr.Error("Knapsack is not available. Cannot use Knapsack Beam Search.") cfg_overrides["num_beams"] = 5 cfg_overrides["use_knapsack"] = True cfg_overrides["knapsack_path"] = str(KNAPSACK_DIR) else: raise ValueError(f"Unknown decoding method: {decoding_method}") # Merge base config with overrides final_cfg = OmegaConf.merge(base_cfg, cfg_overrides) return final_cfg def predict_peptides(input_file, decoding_method): """ Main function to load data, run prediction, and return results. """ if MODEL is None or RESIDUE_SET is None or MODEL_CONFIG is None: load_model_and_knapsack() # Attempt to reload if None (e.g., after space restart) if MODEL is None: raise gr.Error("InstaNovo model is not loaded. Cannot perform prediction.") if input_file is None: raise gr.Error("Please upload a mass spectrometry file.") input_path = input_file.name # Gradio provides the path in .name print(f"Processing file: {input_path}") print(f"Using decoding method: {decoding_method}") # Create a temporary file for the output CSV with tempfile.NamedTemporaryFile(delete=False, suffix=".csv") as temp_out: output_csv_path = temp_out.name try: # 1. Create Config config = create_inference_config(input_path, output_csv_path, decoding_method) print("Inference Config:\n", OmegaConf.to_yaml(config)) # 2. Load Data using SpectrumDataFrame print("Loading spectrum data...") try: sdf = SpectrumDataFrame.load( config.data_path, lazy=False, # Load eagerly for Gradio simplicity is_annotated=False, # De novo mode column_mapping=config.get("column_map", None), shuffle=False, verbose=True # Print loading logs ) # Apply charge filter like in CLI original_size = len(sdf) max_charge = config.get("max_charge", 10) sdf.filter_rows( lambda row: (row["precursor_charge"] <= max_charge) and (row["precursor_charge"] > 0) ) if len(sdf) < original_size: print(f"Warning: Filtered {original_size - len(sdf)} spectra with charge > {max_charge} or <= 0.") if len(sdf) == 0: raise gr.Error("No valid spectra found in the uploaded file after filtering.") print(f"Data loaded: {len(sdf)} spectra.") except Exception as e: print(f"Error loading data: {e}") raise gr.Error(f"Failed to load or process the spectrum file. Error: {e}") # 3. Prepare Dataset and DataLoader ds = SpectrumDataset( sdf, RESIDUE_SET, MODEL_CONFIG.get("n_peaks", 200), return_str=True, # Needed for greedy/beam search targets later (though not used here) annotated=False, pad_spectrum_max_length=config.get("compile_model", False) or config.get("use_flash_attention", False), bin_spectra=config.get("conv_peak_encoder", False), ) dl = DataLoader( ds, batch_size=config.batch_size, num_workers=0, # Required by SpectrumDataFrame shuffle=False, # Required by SpectrumDataFrame collate_fn=collate_batch, ) # 4. Select Decoder print("Initializing decoder...") decoder: Decoder if config.use_knapsack: if KNAPSACK is None: # This check should ideally be earlier, but double-check raise gr.Error("Knapsack is required for Knapsack Beam Search but is not available.") # KnapsackBeamSearchDecoder doesn't directly load from path in this version? # We load Knapsack globally, so just pass it. # If it needed path: decoder = KnapsackBeamSearchDecoder.from_file(model=MODEL, path=config.knapsack_path) decoder = KnapsackBeamSearchDecoder(model=MODEL, knapsack=KNAPSACK) elif config.num_beams > 1: # BeamSearchDecoder is available but not explicitly requested, use Greedy for num_beams=1 print(f"Warning: num_beams={config.num_beams} > 1 but only Greedy and Knapsack Beam Search are implemented in this app. Defaulting to Greedy.") decoder = GreedyDecoder(model=MODEL, mass_scale=MASS_SCALE) else: decoder = GreedyDecoder( model=MODEL, mass_scale=MASS_SCALE, # Add suppression options if needed from config suppressed_residues=config.get("suppressed_residues", None), disable_terminal_residues_anywhere=config.get("disable_terminal_residues_anywhere", True), ) print(f"Using decoder: {type(decoder).__name__}") # 5. Run Prediction Loop (Adapted from instanovo/transformer/predict.py) print("Starting prediction...") start_time = time.time() results_list: list[ScoredSequence | list] = [] # Store ScoredSequence or empty list for i, batch in enumerate(dl): spectra, precursors, spectra_mask, _, _ = batch # Ignore peptides/masks for de novo spectra = spectra.to(DEVICE) precursors = precursors.to(DEVICE) spectra_mask = spectra_mask.to(DEVICE) with torch.no_grad(), torch.amp.autocast(DEVICE, dtype=torch.float16, enabled=FP16): # Beam search decoder might return list[list[ScoredSequence]] if return_beam=True # Greedy decoder returns list[ScoredSequence] # KnapsackBeamSearchDecoder returns list[ScoredSequence] or list[list[ScoredSequence]] batch_predictions = decoder.decode( spectra=spectra, precursors=precursors, beam_size=config.num_beams, max_length=config.max_length, # Knapsack/Beam Search specific params if needed mass_tolerance=config.get("filter_precursor_ppm", 20) * 1e-6, # Convert ppm to relative max_isotope=config.isotope_error_range[1] if config.isotope_error_range else 1, return_beam=False # Only get the top prediction for simplicity ) results_list.extend(batch_predictions) # Should be list[ScoredSequence] or list[list] print(f"Processed batch {i+1}/{len(dl)}") end_time = time.time() print(f"Prediction finished in {end_time - start_time:.2f} seconds.") # 6. Format Results print("Formatting results...") output_data = [] # Use sdf index columns + prediction results index_cols = [col for col in config.index_columns if col in sdf.df.columns] base_df_pd = sdf.df.select(index_cols).to_pandas() # Get base info metrics_calc = Metrics(RESIDUE_SET, config.isotope_error_range) for i, res in enumerate(results_list): row_data = base_df_pd.iloc[i].to_dict() # Get corresponding input data if isinstance(res, ScoredSequence) and res.sequence: sequence_str = "".join(res.sequence) row_data["prediction"] = sequence_str row_data["log_probability"] = f"{res.sequence_log_probability:.4f}" # Use metrics to calculate delta mass ppm for the top prediction try: _, delta_mass_list = metrics_calc.matches_precursor( res.sequence, row_data["precursor_mz"], row_data["precursor_charge"] ) # Find the smallest absolute ppm error across isotopes min_abs_ppm = min(abs(p) for p in delta_mass_list) if delta_mass_list else float('nan') row_data["delta_mass_ppm"] = f"{min_abs_ppm:.2f}" except Exception as e: print(f"Warning: Could not calculate delta mass for prediction {i}: {e}") row_data["delta_mass_ppm"] = "N/A" else: row_data["prediction"] = "" row_data["log_probability"] = "N/A" row_data["delta_mass_ppm"] = "N/A" output_data.append(row_data) output_df = pl.DataFrame(output_data) # Ensure specific columns are present and ordered display_cols = ["scan_number", "precursor_mz", "precursor_charge", "prediction", "log_probability", "delta_mass_ppm"] final_display_cols = [] for col in display_cols: if col in output_df.columns: final_display_cols.append(col) else: print(f"Warning: Expected display column '{col}' not found in results.") # Add any remaining index columns that weren't in display_cols for col in index_cols: if col not in final_display_cols and col in output_df.columns: final_display_cols.append(col) output_df_display = output_df.select(final_display_cols) # 7. Save full results to CSV print(f"Saving results to {output_csv_path}...") output_df.write_csv(output_csv_path) # Return DataFrame for display and path for download return output_df_display.to_pandas(), output_csv_path except Exception as e: print(f"An error occurred during prediction: {e}") # Clean up the temporary output file if it exists if os.path.exists(output_csv_path): os.remove(output_csv_path) # Re-raise as Gradio error raise gr.Error(f"Prediction failed: {e}") # --- Gradio Interface --- css = """ .gradio-container { font-family: sans-serif; } .gr-button { color: white; border-color: black; background: black; } footer { display: none !important; } /* Optional: Add some margin below the logo */ .logo-container img { margin-bottom: 1rem; } """ with gr.Blocks(css=css, theme=gr.themes.Default(primary_hue="blue", secondary_hue="blue")) as demo: # --- Logo Display --- gr.Markdown( """