Spaces:
Running
on
Zero
Running
on
Zero
File size: 21,841 Bytes
a9ee52d dd42569 a9ee52d dd42569 a9ee52d 20c665f a9ee52d dd42569 20c665f dd42569 20c665f a9ee52d dd42569 a9ee52d dd42569 a9ee52d dd42569 a9ee52d dd42569 a9ee52d dd42569 a9ee52d dd42569 a9ee52d dd42569 a9ee52d dd42569 a9ee52d dd42569 a9ee52d dd42569 a9ee52d dd42569 a9ee52d dd42569 20c665f a9ee52d dd42569 a9ee52d dd42569 a9ee52d dd42569 a9ee52d dd42569 a9ee52d dd42569 a9ee52d dd42569 a9ee52d dd42569 a9ee52d dd42569 a9ee52d dd42569 a9ee52d dd42569 a9ee52d dd42569 a9ee52d dd42569 a9ee52d dd42569 a9ee52d dd42569 a9ee52d dd42569 a9ee52d dd42569 a9ee52d dd42569 a9ee52d dd42569 a9ee52d 20c665f a9ee52d 20c665f a9ee52d 20c665f a9ee52d 20c665f a9ee52d 20c665f a9ee52d 20c665f a9ee52d 20c665f a9ee52d 20c665f a9ee52d dd42569 a9ee52d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 |
import gradio as gr
import torch
import os
import tempfile
import time
import polars as pl
import numpy as np
import logging
from pathlib import Path
from omegaconf import OmegaConf, DictConfig
from gradio_log import Log
# --- InstaNovo Imports ---
try:
    from instanovo.transformer.model import InstaNovo
    from instanovo.utils import SpectrumDataFrame, ResidueSet, Metrics
    from instanovo.transformer.dataset import SpectrumDataset, collate_batch
    from instanovo.inference import (
        GreedyDecoder,
        KnapsackBeamSearchDecoder,
        Knapsack,
        ScoredSequence,
        Decoder,
    )
    from instanovo.constants import MASS_SCALE, MAX_MASS
    from torch.utils.data import DataLoader
except ImportError as e:
    # BUG FIX: the message was missing the f-string prefix, so users saw the
    # literal "{e}" instead of the real import error. Chain the original
    # exception so the underlying failure stays in the traceback.
    raise ImportError(f"Failed to import InstaNovo components: {e}") from e
# --- Configuration ---
MODEL_ID = "instanovo-v1.1.0"  # Pretrained model ID passed to InstaNovo.from_pretrained
# Cache directory for the generated knapsack (parameters.pkl, masses.npy, chart.npy).
KNAPSACK_DIR = Path("./knapsack_cache")
DEFAULT_CONFIG_PATH = Path("./configs/inference/default.yaml")  # Assuming instanovo installs configs locally relative to execution
# Determine device: prefer CUDA when available, otherwise fall back to CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
FP16 = DEVICE == "cuda"  # Enable FP16 only on CUDA (autocast is a no-op on CPU here)
# --- Global Variables (Load Model and Knapsack Once) ---
# Populated by load_model_and_knapsack(); each stays None until loading succeeds.
MODEL: InstaNovo | None = None
KNAPSACK: Knapsack | None = None
MODEL_CONFIG: DictConfig | None = None
RESIDUE_SET: ResidueSet | None = None
# --- Assets ---
# Expose ./assets (e.g. the logo SVG referenced by the Markdown further down) as static files.
gr.set_static_paths(paths=[Path.cwd().absolute()/"assets"])
# Logging configuration.
# The Gradio `Log` accordion at the bottom of the UI tails this file, so the
# file handler must actually be attached to the logger for anything to appear.
log_file = "/tmp/instanovo_gradio_log.txt"
Path(log_file).touch()
file_handler = logging.FileHandler(log_file)
file_handler.setLevel(logging.DEBUG)
# Include timestamp/level/source so the tailed log is readable on its own.
file_handler.setFormatter(
    logging.Formatter("%(asctime)s %(levelname)s %(name)s: %(message)s")
)
logger = logging.getLogger("instanovo")
logger.setLevel(logging.DEBUG)
# BUG FIX: the handler was created but never attached, so the log file
# stayed empty and the "Application Logs" panel showed nothing.
logger.addHandler(file_handler)
def load_model_and_knapsack():
    """Load the InstaNovo model and load or generate the knapsack.

    Populates the module globals MODEL, MODEL_CONFIG, RESIDUE_SET and
    KNAPSACK. Safe to call repeatedly: returns early when the model is
    already loaded.

    Raises:
        gr.Error: if the model cannot be loaded, or the knapsack must be
            generated but the residue set is unavailable.
    """
    global MODEL, KNAPSACK, MODEL_CONFIG, RESIDUE_SET
    if MODEL is not None:
        logger.info("Model already loaded.")
        return
    logger.info(f"Loading InstaNovo model: {MODEL_ID} to {DEVICE}...")
    try:
        MODEL, MODEL_CONFIG = InstaNovo.from_pretrained(MODEL_ID)
        MODEL.to(DEVICE)
        MODEL.eval()
        RESIDUE_SET = MODEL.residue_set
        logger.info("Model loaded successfully.")
    except Exception as e:
        logger.error(f"Error loading model: {e}")
        raise gr.Error(f"Failed to load InstaNovo model: {MODEL_ID}. Error: {e}") from e
    # --- Knapsack Handling ---
    # The knapsack is considered present only when all three cache files exist.
    knapsack_exists = (
        (KNAPSACK_DIR / "parameters.pkl").exists() and
        (KNAPSACK_DIR / "masses.npy").exists() and
        (KNAPSACK_DIR / "chart.npy").exists()
    )
    if knapsack_exists:
        logger.info(f"Loading pre-generated knapsack from {KNAPSACK_DIR}...")
        try:
            KNAPSACK = Knapsack.from_file(str(KNAPSACK_DIR))
            logger.info("Knapsack loaded successfully.")
        except Exception as e:
            logger.warning(f"Error loading knapsack: {e}. Will attempt to regenerate.")
            KNAPSACK = None  # Force regeneration
            knapsack_exists = False  # Ensure generation happens
    if not knapsack_exists:
        logger.info("Knapsack not found or failed to load. Generating knapsack...")
        if RESIDUE_SET is None:
            raise gr.Error("Cannot generate knapsack because ResidueSet failed to load.")
        try:
            # Prepare residue masses for knapsack generation (handle negative/zero masses)
            residue_masses_knapsack = dict(RESIDUE_SET.residue_masses.copy())
            negative_residues = [k for k, v in residue_masses_knapsack.items() if v <= 0]
            if negative_residues:
                logger.warning(f"Non-positive masses found in residues: {negative_residues}. "
                               "Excluding from knapsack generation.")
                for res in negative_residues:
                    del residue_masses_knapsack[res]
            # Remove special tokens explicitly if they somehow got mass
            for special_token in RESIDUE_SET.special_tokens:
                if special_token in residue_masses_knapsack:
                    del residue_masses_knapsack[special_token]
            # Ensure residue indices used match those without special/negative masses
            valid_residue_indices = {
                res: idx for res, idx in RESIDUE_SET.residue_to_index.items()
                if res in residue_masses_knapsack
            }
            KNAPSACK = Knapsack.construct_knapsack(
                residue_masses=residue_masses_knapsack,
                residue_indices=valid_residue_indices,  # Use only valid indices
                max_mass=MAX_MASS,
                mass_scale=MASS_SCALE,
            )
            logger.info(f"Knapsack generated. Saving to {KNAPSACK_DIR}...")
            KNAPSACK.save(str(KNAPSACK_DIR))  # Save for future runs
            logger.info("Knapsack saved.")
        except Exception as e:
            logger.error(f"Error generating or saving knapsack: {e}")
            # BUG FIX: the warning was missing the f-string prefix, showing the
            # literal "{e}" in the UI instead of the actual error.
            gr.Warning(f"Failed to generate Knapsack. Knapsack Beam Search will not be available. {e}")
            KNAPSACK = None  # Ensure it's None if generation failed
# Load the model and knapsack when the script starts
load_model_and_knapsack()
def create_inference_config(
    input_path: str,
    output_path: str,
    decoding_method: str,
) -> DictConfig:
    """Build the OmegaConf config for a single prediction run.

    Args:
        input_path: Path to the uploaded spectrum file.
        output_path: Path where the full results CSV will be written.
        decoding_method: UI radio label; dispatched on the substrings
            "Greedy" or "Knapsack".

    Returns:
        The base config (on-disk default, or a built-in minimal fallback)
        merged with the per-run overrides.

    Raises:
        gr.Error: if Knapsack decoding is requested but the knapsack is
            unavailable.
        ValueError: if the decoding method matches neither known label.
    """
    # Load default config if available, otherwise create from scratch
    if DEFAULT_CONFIG_PATH.exists():
        base_cfg = OmegaConf.load(DEFAULT_CONFIG_PATH)
    else:
        logger.warning(f"Default config not found at {DEFAULT_CONFIG_PATH}. Using minimal config.")
        # Create a minimal config if default is missing
        base_cfg = OmegaConf.create({
            "data_path": None,
            "instanovo_model": MODEL_ID,
            "output_path": None,
            "knapsack_path": str(KNAPSACK_DIR),
            "denovo": True,
            "refine": False,  # Not doing refinement here
            "num_beams": 1,
            "max_length": 40,
            "max_charge": 10,
            "isotope_error_range": [0, 1],
            "subset": 1.0,
            "use_knapsack": False,
            "save_beams": False,
            "batch_size": 64,  # Adjust as needed
            "device": DEVICE,
            "fp16": FP16,
            "log_interval": 500,  # Less relevant for Gradio app
            "use_basic_logging": True,
            "filter_precursor_ppm": 20,
            "filter_confidence": 1e-4,
            "filter_fdr_threshold": 0.05,
            "residue_remapping": {  # Add default mappings
                "M(ox)": "M[UNIMOD:35]", "M(+15.99)": "M[UNIMOD:35]",
                "S(p)": "S[UNIMOD:21]", "T(p)": "T[UNIMOD:21]", "Y(p)": "Y[UNIMOD:21]",
                "S(+79.97)": "S[UNIMOD:21]", "T(+79.97)": "T[UNIMOD:21]", "Y(+79.97)": "Y[UNIMOD:21]",
                "Q(+0.98)": "Q[UNIMOD:7]", "N(+0.98)": "N[UNIMOD:7]",
                "Q(+.98)": "Q[UNIMOD:7]", "N(+.98)": "N[UNIMOD:7]",
                "C(+57.02)": "C[UNIMOD:4]",
                "(+42.01)": "[UNIMOD:1]", "(+43.01)": "[UNIMOD:5]", "(-17.03)": "[UNIMOD:385]",
            },
            "column_map": {  # Add default mappings
                "Modified sequence": "modified_sequence", "MS/MS m/z": "precursor_mz",
                "Mass": "precursor_mass", "Charge": "precursor_charge",
                "Mass values": "mz_array", "Mass spectrum": "mz_array",
                "Intensity": "intensity_array", "Raw intensity spectrum": "intensity_array",
                "Scan number": "scan_number"
            },
            "index_columns": [
                "scan_number", "precursor_mz", "precursor_charge",
            ],
            # Add other defaults if needed based on errors
        })
    # Override per-run parameters (paths, device, de novo mode).
    cfg_overrides = {
        "data_path": input_path,
        "output_path": output_path,
        "device": DEVICE,
        "fp16": FP16,
        "denovo": True,
        "refine": False,
    }
    if "Greedy" in decoding_method:
        cfg_overrides["num_beams"] = 1
        cfg_overrides["use_knapsack"] = False
    elif "Knapsack" in decoding_method:
        if KNAPSACK is None:
            raise gr.Error("Knapsack is not available. Cannot use Knapsack Beam Search.")
        cfg_overrides["num_beams"] = 5
        cfg_overrides["use_knapsack"] = True
        cfg_overrides["knapsack_path"] = str(KNAPSACK_DIR)
    else:
        raise ValueError(f"Unknown decoding method: {decoding_method}")
    # Merge base config with overrides (overrides win).
    final_cfg = OmegaConf.merge(base_cfg, cfg_overrides)
    return final_cfg
def predict_peptides(input_file, decoding_method):
    """Run de novo peptide prediction on an uploaded spectrum file.

    Args:
        input_file: Gradio file object; the uploaded path is in ``.name``.
        decoding_method: Radio label; matched on the substrings
            "Greedy"/"Knapsack" by ``create_inference_config``.

    Returns:
        Tuple of (pandas DataFrame for display, path to the full results CSV).

    Raises:
        gr.Error: if the model is unavailable, the input is missing or
            invalid, or any pipeline stage fails.
    """
    if MODEL is None or RESIDUE_SET is None or MODEL_CONFIG is None:
        load_model_and_knapsack()  # Attempt to reload if None (e.g., after space restart)
        if MODEL is None:
            raise gr.Error("InstaNovo model is not loaded. Cannot perform prediction.")
    if input_file is None:
        raise gr.Error("Please upload a mass spectrometry file.")
    input_path = input_file.name  # Gradio provides the path in .name
    logger.info(f"Processing file: {input_path}")
    logger.info(f"Using decoding method: {decoding_method}")
    # Create a temporary file for the output CSV (kept after close; removed on failure).
    with tempfile.NamedTemporaryFile(delete=False, suffix=".csv") as temp_out:
        output_csv_path = temp_out.name
    try:
        # 1. Create Config
        config = create_inference_config(input_path, output_csv_path, decoding_method)
        logger.info(f"Inference Config:\n{OmegaConf.to_yaml(config)}")
        # 2. Load Data using SpectrumDataFrame
        logger.info("Loading spectrum data...")
        try:
            sdf = SpectrumDataFrame.load(
                config.data_path,
                lazy=False,  # Load eagerly for Gradio simplicity
                is_annotated=False,  # De novo mode
                column_mapping=config.get("column_map", None),
                shuffle=False,
                verbose=True  # Print loading logs
            )
            # Apply charge filter like in CLI
            original_size = len(sdf)
            max_charge = config.get("max_charge", 10)
            sdf.filter_rows(
                lambda row: (row["precursor_charge"] <= max_charge) and (row["precursor_charge"] > 0)
            )
            if len(sdf) < original_size:
                logger.warning(f"Filtered {original_size - len(sdf)} spectra with charge > {max_charge} or <= 0.")
            if len(sdf) == 0:
                raise gr.Error("No valid spectra found in the uploaded file after filtering.")
            logger.info(f"Data loaded: {len(sdf)} spectra.")
        except Exception as e:
            logger.error(f"Error loading data: {e}")
            raise gr.Error(f"Failed to load or process the spectrum file. Error: {e}") from e
        # 3. Prepare Dataset and DataLoader
        ds = SpectrumDataset(
            sdf,
            RESIDUE_SET,
            MODEL_CONFIG.get("n_peaks", 200),
            return_str=True,  # Needed for greedy/beam search targets later (though not used here)
            annotated=False,
            pad_spectrum_max_length=config.get("compile_model", False) or config.get("use_flash_attention", False),
            bin_spectra=config.get("conv_peak_encoder", False),
        )
        dl = DataLoader(
            ds,
            batch_size=config.batch_size,
            num_workers=0,  # Required by SpectrumDataFrame
            shuffle=False,  # Required by SpectrumDataFrame
            collate_fn=collate_batch,
        )
        # 4. Select Decoder
        logger.info("Initializing decoder...")
        decoder: Decoder
        if config.use_knapsack:
            if KNAPSACK is None:
                # This check should ideally be earlier, but double-check
                raise gr.Error("Knapsack is required for Knapsack Beam Search but is not available.")
            # The knapsack is loaded globally, so pass the object directly.
            # If it needed a path: KnapsackBeamSearchDecoder.from_file(model=MODEL, path=config.knapsack_path)
            decoder = KnapsackBeamSearchDecoder(model=MODEL, knapsack=KNAPSACK)
        elif config.num_beams > 1:
            # Only Greedy and Knapsack Beam Search are wired up in this app.
            logger.warning(f"num_beams={config.num_beams} > 1 but only Greedy and Knapsack Beam Search are implemented in this app. Defaulting to Greedy.")
            decoder = GreedyDecoder(model=MODEL, mass_scale=MASS_SCALE)
        else:
            decoder = GreedyDecoder(
                model=MODEL,
                mass_scale=MASS_SCALE,
                # Add suppression options if needed from config
                suppressed_residues=config.get("suppressed_residues", None),
                disable_terminal_residues_anywhere=config.get("disable_terminal_residues_anywhere", True),
            )
        logger.info(f"Using decoder: {type(decoder).__name__}")
        # 5. Run Prediction Loop (Adapted from instanovo/transformer/predict.py)
        logger.info("Starting prediction...")
        start_time = time.time()
        results_list: list[ScoredSequence | list] = []  # Store ScoredSequence or empty list
        for i, batch in enumerate(dl):
            spectra, precursors, spectra_mask, _, _ = batch  # Ignore peptides/masks for de novo
            spectra = spectra.to(DEVICE)
            precursors = precursors.to(DEVICE)
            spectra_mask = spectra_mask.to(DEVICE)
            with torch.no_grad(), torch.amp.autocast(DEVICE, dtype=torch.float16, enabled=FP16):
                # Greedy decoder returns list[ScoredSequence]; beam/knapsack
                # decoders may return list[list[ScoredSequence]] if return_beam=True.
                batch_predictions = decoder.decode(
                    spectra=spectra,
                    precursors=precursors,
                    beam_size=config.num_beams,
                    max_length=config.max_length,
                    # Knapsack/Beam Search specific params if needed
                    mass_tolerance=config.get("filter_precursor_ppm", 20) * 1e-6,  # Convert ppm to relative
                    max_isotope=config.isotope_error_range[1] if config.isotope_error_range else 1,
                    return_beam=False  # Only get the top prediction for simplicity
                )
            results_list.extend(batch_predictions)  # list[ScoredSequence] or list[list]
            logger.info(f"Processed batch {i+1}/{len(dl)}")
        end_time = time.time()
        logger.info(f"Prediction finished in {end_time - start_time:.2f} seconds.")
        # 6. Format Results
        logger.info("Formatting results...")
        output_data = []
        # Use sdf index columns + prediction results; row order matches the
        # DataLoader order because shuffle=False throughout.
        index_cols = [col for col in config.index_columns if col in sdf.df.columns]
        base_df_pd = sdf.df.select(index_cols).to_pandas()  # Get base info
        metrics_calc = Metrics(RESIDUE_SET, config.isotope_error_range)
        for i, res in enumerate(results_list):
            row_data = base_df_pd.iloc[i].to_dict()  # Get corresponding input data
            if isinstance(res, ScoredSequence) and res.sequence:
                sequence_str = "".join(res.sequence)
                row_data["prediction"] = sequence_str
                row_data["log_probability"] = f"{res.sequence_log_probability:.4f}"
                # Use metrics to calculate delta mass ppm for the top prediction
                try:
                    _, delta_mass_list = metrics_calc.matches_precursor(
                        res.sequence,
                        row_data["precursor_mz"],
                        row_data["precursor_charge"]
                    )
                    # Find the smallest absolute ppm error across isotopes
                    min_abs_ppm = min(abs(p) for p in delta_mass_list) if delta_mass_list else float('nan')
                    row_data["delta_mass_ppm"] = f"{min_abs_ppm:.2f}"
                except Exception as e:
                    logger.warning(f"Could not calculate delta mass for prediction {i}: {e}")
                    row_data["delta_mass_ppm"] = "N/A"
            else:
                row_data["prediction"] = ""
                row_data["log_probability"] = "N/A"
                row_data["delta_mass_ppm"] = "N/A"
            output_data.append(row_data)
        output_df = pl.DataFrame(output_data)
        # Ensure specific columns are present and ordered
        display_cols = ["scan_number", "precursor_mz", "precursor_charge", "prediction", "log_probability", "delta_mass_ppm"]
        final_display_cols = []
        for col in display_cols:
            if col in output_df.columns:
                final_display_cols.append(col)
            else:
                logger.warning(f"Expected display column '{col}' not found in results.")
        # Add any remaining index columns that weren't in display_cols
        for col in index_cols:
            if col not in final_display_cols and col in output_df.columns:
                final_display_cols.append(col)
        output_df_display = output_df.select(final_display_cols)
        # 7. Save full results to CSV
        logger.info(f"Saving results to {output_csv_path}...")
        output_df.write_csv(output_csv_path)
        # Return DataFrame for display and path for download
        return output_df_display.to_pandas(), output_csv_path
    except Exception as e:
        # logger.exception records the full traceback, not just the message.
        logger.exception("An error occurred during prediction")
        # Clean up the temporary output file if it exists
        if os.path.exists(output_csv_path):
            os.remove(output_csv_path)
        # Re-raise as Gradio error, chained so the cause survives.
        raise gr.Error(f"Prediction failed: {e}") from e
# --- Gradio Interface ---
# Custom CSS injected into gr.Blocks below: sets the font, styles buttons,
# hides the Gradio footer, and adds spacing under the logo.
css = """
.gradio-container { font-family: sans-serif; }
.gr-button { color: white; border-color: black; background: black; }
footer { display: none !important; }
/* Optional: Add some margin below the logo */
.logo-container img { margin-bottom: 1rem; }
"""
# Build the UI. The Radio labels are dispatched on the substrings
# "Greedy"/"Knapsack" in create_inference_config, so fixing the typo
# ("resonably" -> "reasonably") does not affect routing — but the fix must be
# applied to the choices, the default value, and the examples together.
with gr.Blocks(css=css, theme=gr.themes.Default(primary_hue="blue", secondary_hue="blue")) as demo:
    # --- Logo Display ---
    gr.Markdown(
        """
        <div style="text-align: center;" class="logo-container">
            <img src='/gradio_api/file=assets/instanovo.svg' alt="InstaNovo Logo" width="300" style="display: block; margin: 0 auto;">
        </div>
        """,
        elem_classes="logo-container"  # Optional class for CSS targeting
    )
    # --- App Content ---
    # NOTE(review): the leading "π" below looks like a mis-encoded emoji — confirm intended glyph.
    gr.Markdown(
        """
        # π _De Novo_ Peptide Sequencing with InstaNovo
        Upload your mass spectrometry data file (.mgf, .mzml, or .mzxml) and get peptide sequence predictions using InstaNovo.
        Choose between fast Greedy Search or more accurate but slower Knapsack Beam Search.
        """
    )
    with gr.Row():
        with gr.Column(scale=1):
            input_file = gr.File(
                label="Upload Mass Spectrometry File (.mgf, .mzml, .mzxml)",
                file_types=[".mgf", ".mzml", ".mzxml"]
            )
            decoding_method = gr.Radio(
                ["Greedy Search (Fast, reasonably accurate)", "Knapsack Beam Search (More accurate, but slower)"],
                label="Decoding Method",
                value="Greedy Search (Fast, reasonably accurate)"  # Default to fast method; must match a choice above
            )
            submit_btn = gr.Button("Predict Sequences", variant="primary")
        with gr.Column(scale=2):
            output_df = gr.DataFrame(label="Prediction Results", headers=["scan_number", "precursor_mz", "precursor_charge", "prediction", "log_probability", "delta_mass_ppm"], wrap=True)
            output_file = gr.File(label="Download Full Results (CSV)")
    submit_btn.click(
        predict_peptides,
        inputs=[input_file, decoding_method],
        outputs=[output_df, output_file]
    )
    gr.Examples(
        [["assets/sample_spectra.mgf", "Greedy Search (Fast, reasonably accurate)"],
         ["assets/sample_spectra.mgf", "Knapsack Beam Search (More accurate, but slower)"]],
        inputs=[input_file, decoding_method],
        outputs=[output_df, output_file],
        fn=predict_peptides,
        cache_examples=False,  # Re-run examples if needed
        label="Example Usage"
    )
    gr.Markdown(
        """
        **Notes:**
        * Predictions are based on the [InstaNovo](https://github.com/instadeepai/InstaNovo) model ({MODEL_ID}).
        * Knapsack Beam Search uses pre-calculated mass constraints and yields better results but takes longer.
        * `delta_mass_ppm` shows the lowest absolute precursor mass error (in ppm) across potential isotopes (0-1 neutron).
        * Ensure your input file format is correctly specified. Large files may take time to process.
        """.format(MODEL_ID=MODEL_ID)
    )
    # Add logging component: tails the file the module-level handler writes to.
    with gr.Accordion("Application Logs", open=False):
        log_display = Log(log_file, dark=True, height=300)
# --- Launch the App ---
if __name__ == "__main__":
    # Set share=True for temporary public link if running locally
    # Set server_name="0.0.0.0" to allow access from network if needed
    # demo.launch(server_name="0.0.0.0", server_port=7860)
    # For Hugging Face Spaces, just demo.launch() is usually sufficient
    demo.launch(share=True)  # For local testing with public URL