import gradio as gr
import torch
import os
import shutil
import tempfile
import time
import polars as pl
import numpy as np
from pathlib import Path
from omegaconf import OmegaConf, DictConfig

# --- InstaNovo Imports ---
# Every code path below needs these names, including the type annotations on the
# module-level globals, which are evaluated at import time. A failed import is
# therefore fatal: report it and re-raise instead of failing later with a
# confusing NameError.
try:
    from instanovo.transformer.model import InstaNovo
    from instanovo.utils import SpectrumDataFrame, ResidueSet, Metrics
    from instanovo.transformer.dataset import SpectrumDataset, collate_batch
    from instanovo.inference import (
        GreedyDecoder,
        KnapsackBeamSearchDecoder,
        Knapsack,
        ScoredSequence,
        Decoder,
    )
    from instanovo.constants import MASS_SCALE, MAX_MASS
    from torch.utils.data import DataLoader
except ImportError as e:
    print(f"Error importing InstaNovo components: {e}")
    print("Please ensure InstaNovo is installed correctly.")
    raise

# --- Configuration ---
MODEL_ID = "instanovo-v1.1.0"  # Pretrained model ID passed to InstaNovo.from_pretrained
KNAPSACK_DIR = Path("./knapsack_cache")  # Cache directory for the generated knapsack
# Assuming instanovo installs configs locally relative to execution
DEFAULT_CONFIG_PATH = Path("./configs/inference/default.yaml")
# Optional example input for gr.Examples (must be fetched separately).
SAMPLE_SPECTRA_PATH = "./sample_spectra.mgf"

# Decoding-method labels shown in the UI. create_inference_config dispatches on
# the substrings "Greedy" / "Knapsack", so those words must stay in the labels.
GREEDY_CHOICE = "Greedy Search (Fast)"
KNAPSACK_CHOICE = "Knapsack Beam Search (More accurate, but slower)"

# Determine device
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
FP16 = DEVICE == "cuda"  # Enable FP16 autocast only on CUDA

# --- Global Variables (Load Model and Knapsack Once) ---
MODEL: InstaNovo | None = None
KNAPSACK: Knapsack | None = None
MODEL_CONFIG: DictConfig | None = None
RESIDUE_SET: ResidueSet | None = None

# Files that must all exist in KNAPSACK_DIR before we try to load a cached knapsack.
_KNAPSACK_FILES = ("parameters.pkl", "masses.npy", "chart.npy")


def _build_knapsack(residue_set: ResidueSet) -> Knapsack:
    """Construct a knapsack over the residues of *residue_set*.

    Residues with non-positive mass are excluded, and so are the residue set's
    special tokens. The residue-index map is restricted to the remaining residues.
    """
    residue_masses_knapsack = dict(residue_set.residue_masses.copy())
    negative_residues = [k for k, v in residue_masses_knapsack.items() if v <= 0]
    if negative_residues:
        print(
            f"Warning: Non-positive masses found in residues: {negative_residues}. "
            "Excluding from knapsack generation."
        )
        for res in negative_residues:
            del residue_masses_knapsack[res]
    # Remove special tokens explicitly if they somehow got mass
    for special_token in residue_set.special_tokens:
        residue_masses_knapsack.pop(special_token, None)
    # Ensure residue indices used match those without special/negative masses
    valid_residue_indices = {
        res: idx
        for res, idx in residue_set.residue_to_index.items()
        if res in residue_masses_knapsack
    }
    return Knapsack.construct_knapsack(
        residue_masses=residue_masses_knapsack,
        residue_indices=valid_residue_indices,  # Use only valid indices
        max_mass=MAX_MASS,
        mass_scale=MASS_SCALE,
    )


def _save_knapsack(knapsack: Knapsack, path: Path) -> None:
    """Save *knapsack* to *path*, replacing any stale or partial cache there.

    The directory is deliberately NOT pre-created. InstaNovo's Knapsack.save
    creates it itself and raises FileExistsError if it already exists (verify
    against the installed version).
    """
    if path.exists():
        # Only called when the cache was incomplete or failed to load, so
        # whatever is there is unusable.
        shutil.rmtree(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    knapsack.save(str(path))


def load_model_and_knapsack():
    """Load the InstaNovo model, then load or generate the knapsack.

    Populates the module globals MODEL, MODEL_CONFIG, RESIDUE_SET and KNAPSACK.
    Does nothing if MODEL is already set. The model globals are published only
    after the model is fully initialised, so a failed attempt can be retried.
    If the knapsack cannot be generated, KNAPSACK stays None and a warning is
    shown. A knapsack that was generated but could not be saved is still used.

    Raises:
        gr.Error: if the model fails to load.
    """
    global MODEL, KNAPSACK, MODEL_CONFIG, RESIDUE_SET
    if MODEL is not None:
        print("Model already loaded.")
        return

    print(f"Loading InstaNovo model: {MODEL_ID} to {DEVICE}...")
    try:
        model, model_config = InstaNovo.from_pretrained(MODEL_ID)
        model.to(DEVICE)
        model.eval()
        residue_set = model.residue_set
    except Exception as e:
        print(f"Error loading model: {e}")
        raise gr.Error(f"Failed to load InstaNovo model: {MODEL_ID}. Error: {e}") from e
    MODEL, MODEL_CONFIG, RESIDUE_SET = model, model_config, residue_set
    print("Model loaded successfully.")

    # --- Knapsack Handling ---
    knapsack = None
    if all((KNAPSACK_DIR / name).exists() for name in _KNAPSACK_FILES):
        print(f"Loading pre-generated knapsack from {KNAPSACK_DIR}...")
        try:
            knapsack = Knapsack.from_file(str(KNAPSACK_DIR))
            print("Knapsack loaded successfully.")
        except Exception as e:
            print(f"Error loading knapsack: {e}. Will attempt to regenerate.")
            knapsack = None  # Force regeneration

    if knapsack is None:
        print("Knapsack not found or failed to load. Generating knapsack...")
        try:
            knapsack = _build_knapsack(RESIDUE_SET)
        except Exception as e:
            print(f"Error generating or saving knapsack: {e}")
            gr.Warning("Failed to generate Knapsack. Knapsack Beam Search will not be available.")
            knapsack = None  # Ensure it's None if generation failed
        else:
            print(f"Knapsack generated. Saving to {KNAPSACK_DIR}...")
            try:
                _save_knapsack(knapsack, KNAPSACK_DIR)  # Save for future runs
                print("Knapsack saved.")
            except Exception as e:
                # The on-disk copy is only a cache; keep using the in-memory knapsack.
                print(
                    f"Warning: could not save knapsack to {KNAPSACK_DIR}: {e}. "
                    "It will be regenerated on the next start."
                )

    KNAPSACK = knapsack


# Load the model and knapsack when the script starts
load_model_and_knapsack()


def create_inference_config(
    input_path: str,
    output_path: str,
    decoding_method: str,
) -> DictConfig:
    """Create the OmegaConf DictConfig used for prediction.

    Starts from DEFAULT_CONFIG_PATH if it exists, otherwise from a built-in
    minimal config. Then overrides the I/O paths, device/precision and
    decoding settings.

    Args:
        input_path: Path of the uploaded spectrum file.
        output_path: Path where the results CSV will be written.
        decoding_method: UI label; must contain "Greedy" or "Knapsack".

    Raises:
        gr.Error: if Knapsack decoding is requested but KNAPSACK is None.
        ValueError: if decoding_method matches neither supported method.
    """
    # Load default config if available, otherwise create from scratch
    if DEFAULT_CONFIG_PATH.exists():
        base_cfg = OmegaConf.load(DEFAULT_CONFIG_PATH)
    else:
        print(f"Warning: Default config not found at {DEFAULT_CONFIG_PATH}. Using minimal config.")
        # Create a minimal config if default is missing
        base_cfg = OmegaConf.create({
            "data_path": None,
            "instanovo_model": MODEL_ID,
            "output_path": None,
            "knapsack_path": str(KNAPSACK_DIR),
            "denovo": True,
            "refine": False,  # Not doing refinement here
            "num_beams": 1,
            "max_length": 40,
            "max_charge": 10,
            "isotope_error_range": [0, 1],
            "subset": 1.0,
            "use_knapsack": False,
            "save_beams": False,
            "batch_size": 64,  # Adjust as needed
            "device": DEVICE,
            "fp16": FP16,
            "log_interval": 500,  # Less relevant for Gradio app
            "use_basic_logging": True,
            "filter_precursor_ppm": 20,
            "filter_confidence": 1e-4,
            "filter_fdr_threshold": 0.05,
            "residue_remapping": {  # Add default mappings
                "M(ox)": "M[UNIMOD:35]",
                "M(+15.99)": "M[UNIMOD:35]",
                "S(p)": "S[UNIMOD:21]",
                "T(p)": "T[UNIMOD:21]",
                "Y(p)": "Y[UNIMOD:21]",
                "S(+79.97)": "S[UNIMOD:21]",
                "T(+79.97)": "T[UNIMOD:21]",
                "Y(+79.97)": "Y[UNIMOD:21]",
                "Q(+0.98)": "Q[UNIMOD:7]",
                "N(+0.98)": "N[UNIMOD:7]",
                "Q(+.98)": "Q[UNIMOD:7]",
                "N(+.98)": "N[UNIMOD:7]",
                "C(+57.02)": "C[UNIMOD:4]",
                "(+42.01)": "[UNIMOD:1]",
                "(+43.01)": "[UNIMOD:5]",
                "(-17.03)": "[UNIMOD:385]",
            },
            "column_map": {  # Add default mappings
                "Modified sequence": "modified_sequence",
                "MS/MS m/z": "precursor_mz",
                "Mass": "precursor_mass",
                "Charge": "precursor_charge",
                "Mass values": "mz_array",
                "Mass spectrum": "mz_array",
                "Intensity": "intensity_array",
                "Raw intensity spectrum": "intensity_array",
                "Scan number": "scan_number",
            },
            "index_columns": [
                "scan_number",
                "precursor_mz",
                "precursor_charge",
            ],
            # Add other defaults if needed based on errors
        })

    # Override specific parameters
    cfg_overrides = {
        "data_path": input_path,
        "output_path": output_path,
        "device": DEVICE,
        "fp16": FP16,
        "denovo": True,
        "refine": False,
    }

    if "Greedy" in decoding_method:
        cfg_overrides["num_beams"] = 1
        cfg_overrides["use_knapsack"] = False
    elif "Knapsack" in decoding_method:
        if KNAPSACK is None:
            raise gr.Error("Knapsack is not available. Cannot use Knapsack Beam Search.")
        cfg_overrides["num_beams"] = 5
        cfg_overrides["use_knapsack"] = True
        cfg_overrides["knapsack_path"] = str(KNAPSACK_DIR)
    else:
        raise ValueError(f"Unknown decoding method: {decoding_method}")

    # Merge base config with overrides
    final_cfg = OmegaConf.merge(base_cfg, cfg_overrides)
    return final_cfg


def _resolve_input_path(input_file) -> str:
    """Return the filesystem path behind a ``gr.File`` value.

    Depending on the Gradio version, the upload arrives either as a path
    string or as a tempfile-like object exposing ``.name``. Both are accepted.
    """
    if isinstance(input_file, (str, os.PathLike)):
        return os.fspath(input_file)
    return input_file.name


def predict_peptides(input_file, decoding_method):
    """Run de novo peptide sequencing on an uploaded spectrum file.

    Args:
        input_file: Value of the ``gr.File`` component (a path string, or an
            object with ``.name``), or None if nothing was uploaded.
        decoding_method: UI label of the selected method ("Greedy"/"Knapsack").

    Returns:
        A tuple (pandas DataFrame for display, path of a temporary CSV holding
        the full results for download).

    Raises:
        gr.Error: if the model or input is missing, or if prediction fails.
            The temporary CSV is removed on failure.
    """
    if MODEL is None or RESIDUE_SET is None or MODEL_CONFIG is None:
        load_model_and_knapsack()  # Attempt to reload if None (e.g., after space restart)
        if MODEL is None:
            raise gr.Error("InstaNovo model is not loaded. Cannot perform prediction.")

    if input_file is None:
        raise gr.Error("Please upload a mass spectrometry file.")

    input_path = _resolve_input_path(input_file)
    print(f"Processing file: {input_path}")
    print(f"Using decoding method: {decoding_method}")

    # Create a temporary file for the output CSV; delete=False keeps it after
    # closing so Gradio can serve it for download.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".csv") as temp_out:
        output_csv_path = temp_out.name

    try:
        # 1. Create Config
        config = create_inference_config(input_path, output_csv_path, decoding_method)
        print("Inference Config:\n", OmegaConf.to_yaml(config))

        # 2. Load Data using SpectrumDataFrame
        print("Loading spectrum data...")
        try:
            sdf = SpectrumDataFrame.load(
                config.data_path,
                lazy=False,  # Load eagerly for Gradio simplicity
                is_annotated=False,  # De novo mode
                column_mapping=config.get("column_map", None),
                shuffle=False,
                verbose=True,  # Print loading logs
            )
            # Apply charge filter like in CLI
            original_size = len(sdf)
            max_charge = config.get("max_charge", 10)
            sdf.filter_rows(
                lambda row: (row["precursor_charge"] <= max_charge) and (row["precursor_charge"] > 0)
            )
            if len(sdf) < original_size:
                print(f"Warning: Filtered {original_size - len(sdf)} spectra with charge > {max_charge} or <= 0.")
        except Exception as e:
            print(f"Error loading data: {e}")
            raise gr.Error(f"Failed to load or process the spectrum file. Error: {e}") from e

        # Checked outside the loading try so this message is not wrapped twice.
        if len(sdf) == 0:
            raise gr.Error("No valid spectra found in the uploaded file after filtering.")
        print(f"Data loaded: {len(sdf)} spectra.")

        # 3. Prepare Dataset and DataLoader
        ds = SpectrumDataset(
            sdf,
            RESIDUE_SET,
            MODEL_CONFIG.get("n_peaks", 200),
            return_str=True,  # Needed for greedy/beam search targets later (though not used here)
            annotated=False,
            pad_spectrum_max_length=config.get("compile_model", False) or config.get("use_flash_attention", False),
            bin_spectra=config.get("conv_peak_encoder", False),
        )

        dl = DataLoader(
            ds,
            batch_size=config.batch_size,
            num_workers=0,  # Required by SpectrumDataFrame
            shuffle=False,  # Required by SpectrumDataFrame
            collate_fn=collate_batch,
        )

        # 4. Select Decoder
        print("Initializing decoder...")
        decoder: Decoder
        if config.use_knapsack:
            if KNAPSACK is None:
                # Also checked in create_inference_config; double-check here.
                raise gr.Error("Knapsack is required for Knapsack Beam Search but is not available.")
            # The knapsack is loaded globally, so pass the object directly.
            decoder = KnapsackBeamSearchDecoder(model=MODEL, knapsack=KNAPSACK)
        elif config.num_beams > 1:
            # Plain beam search is not wired up in this app; fall back to greedy.
            print(f"Warning: num_beams={config.num_beams} > 1 but only Greedy and Knapsack Beam Search are implemented in this app. Defaulting to Greedy.")
            decoder = GreedyDecoder(model=MODEL, mass_scale=MASS_SCALE)
        else:
            decoder = GreedyDecoder(
                model=MODEL,
                mass_scale=MASS_SCALE,
                # Add suppression options if needed from config
                suppressed_residues=config.get("suppressed_residues", None),
                disable_terminal_residues_anywhere=config.get("disable_terminal_residues_anywhere", True),
            )
        print(f"Using decoder: {type(decoder).__name__}")

        # 5. Run Prediction Loop (Adapted from instanovo/transformer/predict.py)
        print("Starting prediction...")
        start_time = time.time()
        results_list: list[ScoredSequence | list] = []  # Store ScoredSequence or empty list

        for i, batch in enumerate(dl):
            spectra, precursors, spectra_mask, _, _ = batch  # Ignore peptides/masks for de novo
            spectra = spectra.to(DEVICE)
            precursors = precursors.to(DEVICE)
            spectra_mask = spectra_mask.to(DEVICE)

            with torch.no_grad(), torch.amp.autocast(DEVICE, dtype=torch.float16, enabled=FP16):
                # return_beam=False: only the top prediction per spectrum is requested.
                batch_predictions = decoder.decode(
                    spectra=spectra,
                    precursors=precursors,
                    beam_size=config.num_beams,
                    max_length=config.max_length,
                    # Knapsack/Beam Search specific params if needed
                    mass_tolerance=config.get("filter_precursor_ppm", 20) * 1e-6,  # Convert ppm to relative
                    max_isotope=config.isotope_error_range[1] if config.isotope_error_range else 1,
                    return_beam=False,  # Only get the top prediction for simplicity
                )
            results_list.extend(batch_predictions)  # Should be list[ScoredSequence] or list[list]
            print(f"Processed batch {i+1}/{len(dl)}")

        end_time = time.time()
        print(f"Prediction finished in {end_time - start_time:.2f} seconds.")

        # 6. Format Results
        print("Formatting results...")
        output_data = []
        # Use sdf index columns + prediction results
        index_cols = [col for col in config.index_columns if col in sdf.df.columns]
        base_df_pd = sdf.df.select(index_cols).to_pandas()  # Get base info

        metrics_calc = Metrics(RESIDUE_SET, config.isotope_error_range)

        # Predictions are assumed to be in the same order as the rows of sdf.df.
        for i, res in enumerate(results_list):
            row_data = base_df_pd.iloc[i].to_dict()  # Get corresponding input data
            if isinstance(res, ScoredSequence) and res.sequence:
                sequence_str = "".join(res.sequence)
                row_data["prediction"] = sequence_str
                row_data["log_probability"] = f"{res.sequence_log_probability:.4f}"
                # Use metrics to calculate delta mass ppm for the top prediction
                try:
                    _, delta_mass_list = metrics_calc.matches_precursor(
                        res.sequence,
                        row_data["precursor_mz"],
                        row_data["precursor_charge"],
                    )
                    # Find the smallest absolute ppm error across isotopes
                    min_abs_ppm = min(abs(p) for p in delta_mass_list) if delta_mass_list else float('nan')
                    row_data["delta_mass_ppm"] = f"{min_abs_ppm:.2f}"
                except Exception as e:
                    print(f"Warning: Could not calculate delta mass for prediction {i}: {e}")
                    row_data["delta_mass_ppm"] = "N/A"
            else:
                row_data["prediction"] = ""
                row_data["log_probability"] = "N/A"
                row_data["delta_mass_ppm"] = "N/A"
            output_data.append(row_data)

        output_df = pl.DataFrame(output_data)

        # Ensure specific columns are present and ordered
        display_cols = ["scan_number", "precursor_mz", "precursor_charge", "prediction", "log_probability", "delta_mass_ppm"]
        final_display_cols = []
        for col in display_cols:
            if col in output_df.columns:
                final_display_cols.append(col)
            else:
                print(f"Warning: Expected display column '{col}' not found in results.")

        # Add any remaining index columns that weren't in display_cols
        for col in index_cols:
            if col not in final_display_cols and col in output_df.columns:
                final_display_cols.append(col)

        output_df_display = output_df.select(final_display_cols)

        # 7. Save full results to CSV
        print(f"Saving results to {output_csv_path}...")
        output_df.write_csv(output_csv_path)

        # Return DataFrame for display and path for download
        return output_df_display.to_pandas(), output_csv_path

    except Exception as e:
        print(f"An error occurred during prediction: {e}")
        # Clean up the temporary output file if it exists
        if os.path.exists(output_csv_path):
            os.remove(output_csv_path)
        if isinstance(e, gr.Error):
            raise  # Already a user-facing message; don't wrap it again
        # Re-raise as Gradio error
        raise gr.Error(f"Prediction failed: {e}") from e


# --- Gradio Interface ---
css = """
.gradio-container { font-family: sans-serif; }
.gr-button { color: white; border-color: black; background: black; }
footer { display: none !important; }
"""

with gr.Blocks(css=css, theme=gr.themes.Default(primary_hue="blue", secondary_hue="blue")) as demo:
    gr.Markdown(
        """
        # 🚀 InstaNovo _De Novo_ Peptide Sequencing

        Upload your mass spectrometry data file (.mgf, .mzml, or .mzxml) and get peptide sequence predictions using InstaNovo.

        Choose between fast Greedy Search or more accurate but slower Knapsack Beam Search.
        """
    )

    with gr.Row():
        with gr.Column(scale=1):
            input_file = gr.File(
                label="Upload Mass Spectrometry File (.mgf, .mzml, .mzxml)",
                file_types=[".mgf", ".mzml", ".mzxml"],
            )
            decoding_method = gr.Radio(
                [GREEDY_CHOICE, KNAPSACK_CHOICE],
                label="Decoding Method",
                value=GREEDY_CHOICE,  # Default to fast method
            )
            submit_btn = gr.Button("Predict Sequences", variant="primary")
        with gr.Column(scale=2):
            output_df = gr.DataFrame(label="Prediction Results", wrap=True)
            output_file = gr.File(label="Download Full Results (CSV)")

    submit_btn.click(
        predict_peptides,
        inputs=[input_file, decoding_method],
        outputs=[output_df, output_file],
    )

    # The example requires the sample data to be fetched; skip it when it is
    # absent rather than registering an example that points at a missing file.
    # The example value must be one of the Radio choices.
    if os.path.exists(SAMPLE_SPECTRA_PATH):
        gr.Examples(
            [[SAMPLE_SPECTRA_PATH, KNAPSACK_CHOICE]],
            inputs=[input_file, decoding_method],
            outputs=[output_df, output_file],
            fn=predict_peptides,
            cache_examples=False,  # Re-run examples if needed
            label="Example Usage",
        )

    gr.Markdown(
        """
        **Notes:**
        * Predictions are based on the [InstaNovo](https://github.com/instadeepai/InstaNovo) model ({MODEL_ID}).
        * Knapsack Beam Search uses pre-calculated mass constraints and yields better results but takes longer.
        * 'delta_mass_ppm' shows the lowest absolute precursor mass error (in ppm) across potential isotopes (0-1 neutron).
        * Ensure your input file format is correctly specified. Large files may take time to process.
        """.format(MODEL_ID=MODEL_ID)
    )


# --- Launch the App ---
if __name__ == "__main__":
    # Set share=True for temporary public link if running locally
    # Set server_name="0.0.0.0" to allow access from network if needed
    # demo.launch(server_name="0.0.0.0", server_port=7860)
    # For Hugging Face Spaces, just demo.launch() is usually sufficient
    demo.launch(share=True)  # For local testing with public URL