BioGeek commited on
Commit
04bf12b
·
1 Parent(s): 69bd30f

feat: adding support for IN+

Browse files
Files changed (1) hide show
  1. app.py +647 -371
app.py CHANGED
@@ -14,6 +14,7 @@ from gradio_log import Log
14
  # --- InstaNovo Imports ---
15
  try:
16
  from instanovo.transformer.model import InstaNovo
 
17
  from instanovo.utils import SpectrumDataFrame, ResidueSet, Metrics
18
  from instanovo.transformer.dataset import SpectrumDataset, collate_batch
19
  from instanovo.inference import (
@@ -23,29 +24,38 @@ try:
23
  ScoredSequence,
24
  Decoder,
25
  )
26
- from instanovo.constants import MASS_SCALE, MAX_MASS
 
 
 
 
 
27
  from torch.utils.data import DataLoader
 
28
  except ImportError as e:
29
  raise ImportError(f"Failed to import InstaNovo components: {e}")
30
 
31
  # --- Configuration ---
32
- MODEL_ID = "instanovo-v1.1.0" # Use the desired pretrained model ID
 
33
  KNAPSACK_DIR = Path("./knapsack_cache")
34
  DEFAULT_CONFIG_PATH = Path(
35
  "./configs/inference/default.yaml"
36
- ) # Assuming instanovo installs configs locally relative to execution
37
 
38
  # Determine device
39
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
40
- FP16 = DEVICE == "cuda" # Enable FP16 only on CUDA
41
 
42
- # --- Global Variables (Load Model and Knapsack Once) ---
43
  MODEL: InstaNovo | None = None
44
- KNAPSACK: Knapsack | None = None
45
  MODEL_CONFIG: DictConfig | None = None
 
 
 
46
  RESIDUE_SET: ResidueSet | None = None
47
 
48
- # --- Assets ---
49
  gr.set_static_paths(paths=[Path.cwd().absolute()/"assets"])
50
 
51
  # Create gradio temporary directory
@@ -57,141 +67,165 @@ if not temp_dir.exists():
57
  log_file = "/tmp/instanovo_gradio_log.txt"
58
  Path(log_file).touch()
59
 
60
- logger = logging.getLogger("instanovo")
61
  logger.setLevel(logging.INFO)
62
- file_handler = logging.FileHandler(log_file)
63
- file_handler.setLevel(logging.INFO)
64
- formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
65
- file_handler.setFormatter(formatter)
66
- logger.addHandler(file_handler)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
 
69
- def load_model_and_knapsack():
70
- """Loads the InstaNovo model and generates/loads the knapsack."""
71
- global MODEL, KNAPSACK, MODEL_CONFIG, RESIDUE_SET
72
- if MODEL is not None:
73
- logger.info("Model already loaded.")
74
- return
75
 
76
- logger.info(f"Loading InstaNovo model: {MODEL_ID} to {DEVICE}...")
77
- try:
78
- MODEL, MODEL_CONFIG = InstaNovo.from_pretrained(MODEL_ID)
79
- MODEL.to(DEVICE)
80
- MODEL.eval()
81
- RESIDUE_SET = MODEL.residue_set
82
- logger.info("Model loaded successfully.")
83
- except Exception as e:
84
- logger.error(f"Error loading model: {e}")
85
- raise gr.Error(f"Failed to load InstaNovo model: {MODEL_ID}. Error: {e}")
 
 
 
 
 
 
 
 
 
 
86
 
87
  # --- Knapsack Handling ---
88
- knapsack_exists = (
89
- (KNAPSACK_DIR / "parameters.pkl").exists()
90
- and (KNAPSACK_DIR / "masses.npy").exists()
91
- and (KNAPSACK_DIR / "chart.npy").exists()
92
- )
 
 
93
 
94
- if knapsack_exists:
95
- logger.info(f"Loading pre-generated knapsack from {KNAPSACK_DIR}...")
96
- try:
97
- KNAPSACK = Knapsack.from_file(str(KNAPSACK_DIR))
98
- logger.info("Knapsack loaded successfully.")
99
- except Exception as e:
100
- logger.info(f"Error loading knapsack: {e}. Will attempt to regenerate.")
101
- KNAPSACK = None # Force regeneration
102
- knapsack_exists = False # Ensure generation happens
103
-
104
- if not knapsack_exists:
105
- logger.info("Knapsack not found or failed to load. Generating knapsack...")
106
- if RESIDUE_SET is None:
107
- raise gr.Error(
108
- "Cannot generate knapsack because ResidueSet failed to load."
109
- )
110
- try:
111
- # Prepare residue masses for knapsack generation (handle negative/zero masses)
112
- residue_masses_knapsack = dict(RESIDUE_SET.residue_masses.copy())
113
- negative_residues = [
114
- k for k, v in residue_masses_knapsack.items() if v <= 0
115
- ]
116
- if negative_residues:
117
- logger.info(f"Warning: Non-positive masses found in residues: {negative_residues}. "
118
- "Excluding from knapsack generation.")
119
- for res in negative_residues:
120
- del residue_masses_knapsack[res]
121
- # Remove special tokens explicitly if they somehow got mass
122
- for special_token in RESIDUE_SET.special_tokens:
123
- if special_token in residue_masses_knapsack:
124
- del residue_masses_knapsack[special_token]
125
-
126
- # Ensure residue indices used match those without special/negative masses
127
- valid_residue_indices = {
128
- res: idx
129
- for res, idx in RESIDUE_SET.residue_to_index.items()
130
- if res in residue_masses_knapsack
131
- }
132
-
133
- KNAPSACK = Knapsack.construct_knapsack(
134
- residue_masses=residue_masses_knapsack,
135
- residue_indices=valid_residue_indices, # Use only valid indices
136
- max_mass=MAX_MASS,
137
- mass_scale=MASS_SCALE,
138
- )
139
- logger.info(f"Knapsack generated. Saving to {KNAPSACK_DIR}...")
140
- KNAPSACK.save(str(KNAPSACK_DIR)) # Save for future runs
141
- logger.info("Knapsack saved.")
142
- except Exception as e:
143
- logger.info(f"Error generating or saving knapsack: {e}")
144
- gr.Warning("Failed to generate Knapsack. Knapsack Beam Search will not be available. {e}")
145
- KNAPSACK = None # Ensure it's None if generation failed
146
 
147
- # Load the model and knapsack when the script starts
148
- load_model_and_knapsack()
 
149
 
150
 
151
  def create_inference_config(
152
  input_path: str,
153
  output_path: str,
154
- decoding_method: str,
155
  ) -> DictConfig:
156
- """Creates the OmegaConf DictConfig needed for prediction."""
157
- # Load default config if available, otherwise create from scratch
158
  if DEFAULT_CONFIG_PATH.exists():
159
  base_cfg = OmegaConf.load(DEFAULT_CONFIG_PATH)
 
160
  else:
161
  logger.info(f"Warning: Default config not found at {DEFAULT_CONFIG_PATH}. Using minimal config.")
162
- # Create a minimal config if default is missing
163
  base_cfg = OmegaConf.create({
164
- "data_path": None,
165
- "instanovo_model": MODEL_ID,
166
- "output_path": None,
167
- "knapsack_path": str(KNAPSACK_DIR),
168
- "denovo": True,
169
- "refine": False, # Not doing refinement here
170
- "num_beams": 1,
171
- "max_length": 40,
172
- "max_charge": 10,
173
- "isotope_error_range": [0, 1],
174
- "subset": 1.0,
175
- "use_knapsack": False,
176
- "save_beams": False,
177
- "batch_size": 64, # Adjust as needed
178
- "device": DEVICE,
179
- "fp16": FP16,
180
- "log_interval": 500, # Less relevant for Gradio app
181
- "use_basic_logging": True,
182
- "filter_precursor_ppm": 20,
183
- "filter_confidence": 1e-4,
184
- "filter_fdr_threshold": 0.05,
185
- "residue_remapping": { # Add default mappings
186
  "M(ox)": "M[UNIMOD:35]", "M(+15.99)": "M[UNIMOD:35]",
187
  "S(p)": "S[UNIMOD:21]", "T(p)": "T[UNIMOD:21]", "Y(p)": "Y[UNIMOD:21]",
188
  "S(+79.97)": "S[UNIMOD:21]", "T(+79.97)": "T[UNIMOD:21]", "Y(+79.97)": "Y[UNIMOD:21]",
189
  "Q(+0.98)": "Q[UNIMOD:7]", "N(+0.98)": "N[UNIMOD:7]",
190
  "Q(+.98)": "Q[UNIMOD:7]", "N(+.98)": "N[UNIMOD:7]",
191
- "C(+57.02)": "C[UNIMOD:4]",
192
- "(+42.01)": "[UNIMOD:1]", "(+43.01)": "[UNIMOD:5]", "(-17.03)": "[UNIMOD:385]",
193
  },
194
- "column_map": { # Add default mappings
195
  "Modified sequence": "modified_sequence", "MS/MS m/z": "precursor_mz",
196
  "Mass": "precursor_mass", "Charge": "precursor_charge",
197
  "Mass values": "mz_array", "Mass spectrum": "mz_array",
@@ -200,256 +234,457 @@ def create_inference_config(
200
  },
201
  "index_columns": [
202
  "scan_number", "precursor_mz", "precursor_charge",
 
203
  ],
204
- # Add other defaults if needed based on errors
205
  })
206
 
207
- # Override specific parameters
208
  cfg_overrides = {
209
- "data_path": input_path,
210
- "output_path": output_path,
211
- "device": DEVICE,
212
- "fp16": FP16,
213
- "denovo": True,
214
- "refine": False,
215
  }
 
 
 
216
 
217
- if "Greedy" in decoding_method:
218
- cfg_overrides["num_beams"] = 1
219
- cfg_overrides["use_knapsack"] = False
220
- elif "Knapsack" in decoding_method:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
221
  if KNAPSACK is None:
222
- raise gr.Error(
223
- "Knapsack is not available. Cannot use Knapsack Beam Search."
224
- )
225
- cfg_overrides["num_beams"] = 5
226
- cfg_overrides["use_knapsack"] = True
227
- cfg_overrides["knapsack_path"] = str(KNAPSACK_DIR)
228
  else:
229
- raise ValueError(f"Unknown decoding method: {decoding_method}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
230
 
231
- # Merge base config with overrides
232
- final_cfg = OmegaConf.merge(base_cfg, cfg_overrides)
233
- return final_cfg
234
 
235
  @spaces.GPU
236
- def predict_peptides(input_file, decoding_method):
237
  """
238
- Main function to load data, run prediction, and return results.
239
  """
240
- if MODEL is None or RESIDUE_SET is None or MODEL_CONFIG is None:
241
- load_model_and_knapsack() # Attempt to reload if None (e.g., after space restart)
 
242
  if MODEL is None:
243
- raise gr.Error("InstaNovo model is not loaded. Cannot perform prediction.")
 
 
 
 
 
 
 
 
 
244
 
245
  if input_file is None:
246
  raise gr.Error("Please upload a mass spectrometry file.")
247
 
248
- input_path = input_file.name # Gradio provides the path in .name
249
- logger.info(f"Processing file: {input_path}")
250
- logger.info(f"Using decoding method: {decoding_method}")
 
 
 
251
 
252
- # Create a temporary file for the output CSV
253
- with tempfile.NamedTemporaryFile(delete=False, suffix=".csv") as temp_out:
254
- output_csv_path = temp_out.name
 
 
 
 
 
 
255
 
256
  try:
257
- # 1. Create Config
258
- config = create_inference_config(input_path, output_csv_path, decoding_method)
259
- logger.info(f"Inference Config:\n{OmegaConf.to_yaml(config)}")
260
 
261
- # 2. Load Data using SpectrumDataFrame
262
  logger.info("Loading spectrum data...")
263
  try:
 
264
  sdf = SpectrumDataFrame.load(
265
- config.data_path,
266
- lazy=False, # Load eagerly for Gradio simplicity
267
- is_annotated=False, # De novo mode
268
- column_mapping=config.get("column_map", None),
269
- shuffle=False,
270
- verbose=True, # Print loading logs
271
  )
272
- # Apply charge filter like in CLI
273
  original_size = len(sdf)
274
  max_charge = config.get("max_charge", 10)
275
- sdf.filter_rows(
276
- lambda row: (row["precursor_charge"] <= max_charge)
277
- and (row["precursor_charge"] > 0)
278
- )
279
- if len(sdf) < original_size:
280
- logger.info(f"Warning: Filtered {original_size - len(sdf)} spectra with charge > {max_charge} or <= 0.")
 
 
281
 
282
  if len(sdf) == 0:
283
  raise gr.Error("No valid spectra found in the uploaded file after filtering.")
284
  logger.info(f"Data loaded: {len(sdf)} spectra.")
 
 
 
285
  except Exception as e:
286
- logger.info(f"Error loading data: {e}")
287
  raise gr.Error(f"Failed to load or process the spectrum file. Error: {e}")
288
 
289
- # 3. Prepare Dataset and DataLoader
 
 
 
 
290
  ds = SpectrumDataset(
291
- sdf,
292
- RESIDUE_SET,
293
- MODEL_CONFIG.get("n_peaks", 200),
294
- return_str=True, # Needed for greedy/beam search targets later (though not used here)
295
- annotated=False,
296
- pad_spectrum_max_length=config.get("compile_model", False)
297
- or config.get("use_flash_attention", False),
298
  bin_spectra=config.get("conv_peak_encoder", False),
 
 
 
299
  )
300
- dl = DataLoader(
301
- ds,
302
- batch_size=config.batch_size,
303
- num_workers=0, # Required by SpectrumDataFrame
304
- shuffle=False, # Required by SpectrumDataFrame
305
- collate_fn=collate_batch,
306
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
307
 
308
- # 4. Select Decoder
309
- logger.info("Initializing decoder...")
310
- decoder: Decoder
311
- if config.use_knapsack:
312
- if KNAPSACK is None:
313
- # This check should ideally be earlier, but double-check
314
- raise gr.Error(
315
- "Knapsack is required for Knapsack Beam Search but is not available."
316
- )
317
- # KnapsackBeamSearchDecoder doesn't directly load from path in this version?
318
- # We load Knapsack globally, so just pass it.
319
- # If it needed path: decoder = KnapsackBeamSearchDecoder.from_file(model=MODEL, path=config.knapsack_path)
320
- decoder = KnapsackBeamSearchDecoder(model=MODEL, knapsack=KNAPSACK)
321
- elif config.num_beams > 1:
322
- # BeamSearchDecoder is available but not explicitly requested, use Greedy for num_beams=1
323
- logger.info(f"Warning: num_beams={config.num_beams} > 1 but only Greedy and Knapsack Beam Search are implemented in this app. Defaulting to Greedy.")
324
- decoder = GreedyDecoder(model=MODEL, mass_scale=MASS_SCALE)
325
  else:
326
- decoder = GreedyDecoder(
327
- model=MODEL,
328
- mass_scale=MASS_SCALE,
329
- # Add suppression options if needed from config
330
- suppressed_residues=config.get("suppressed_residues", None),
331
- disable_terminal_residues_anywhere=config.get("disable_terminal_residues_anywhere", True),
332
- )
333
- logger.info(f"Using decoder: {type(decoder).__name__}")
334
-
335
- # 5. Run Prediction Loop (Adapted from instanovo/transformer/predict.py)
336
- logger.info("Starting prediction...")
337
- start_time = time.time()
338
- results_list: list[
339
- ScoredSequence | list
340
- ] = [] # Store ScoredSequence or empty list
341
-
342
- for i, batch in enumerate(dl):
343
- spectra, precursors, spectra_mask, _, _ = (
344
- batch # Ignore peptides/masks for de novo
345
- )
346
- spectra = spectra.to(DEVICE)
347
- precursors = precursors.to(DEVICE)
348
- spectra_mask = spectra_mask.to(DEVICE)
349
-
350
- with (
351
- torch.no_grad(),
352
- torch.amp.autocast(DEVICE, dtype=torch.float16, enabled=FP16),
353
- ):
354
- # Beam search decoder might return list[list[ScoredSequence]] if return_beam=True
355
- # Greedy decoder returns list[ScoredSequence]
356
- # KnapsackBeamSearchDecoder returns list[ScoredSequence] or list[list[ScoredSequence]]
357
- batch_predictions = decoder.decode(
358
- spectra=spectra,
359
- precursors=precursors,
360
- beam_size=config.num_beams,
361
- max_length=config.max_length,
362
- # Knapsack/Beam Search specific params if needed
363
- mass_tolerance=config.get("filter_precursor_ppm", 20)
364
- * 1e-6, # Convert ppm to relative
365
- max_isotope=config.isotope_error_range[1]
366
- if config.isotope_error_range
367
- else 1,
368
- return_beam=False, # Only get the top prediction for simplicity
369
- )
370
- results_list.extend(batch_predictions) # Should be list[ScoredSequence] or list[list]
371
- logger.info(f"Processed batch {i+1}/{len(dl)}")
372
-
373
- end_time = time.time()
374
- logger.info(f"Prediction finished in {end_time - start_time:.2f} seconds.")
375
-
376
- # 6. Format Results
377
- logger.info("Formatting results...")
378
- output_data = []
379
- # Use sdf index columns + prediction results
380
- index_cols = [col for col in config.index_columns if col in sdf.df.columns]
381
- base_df_pd = sdf.df.select(index_cols).to_pandas() # Get base info
382
-
383
- metrics_calc = Metrics(RESIDUE_SET, config.isotope_error_range)
384
-
385
- for i, res in enumerate(results_list):
386
- row_data = base_df_pd.iloc[i].to_dict() # Get corresponding input data
387
- if isinstance(res, ScoredSequence) and res.sequence:
388
- sequence_str = "".join(res.sequence)
389
- row_data["prediction"] = sequence_str
390
- row_data["log_probability"] = f"{res.sequence_log_probability:.4f}"
391
- # Use metrics to calculate delta mass ppm for the top prediction
392
- try:
393
- _, delta_mass_list = metrics_calc.matches_precursor(
394
- res.sequence,
395
- row_data["precursor_mz"],
396
- row_data["precursor_charge"],
397
- )
398
- # Find the smallest absolute ppm error across isotopes
399
- min_abs_ppm = (
400
- min(abs(p) for p in delta_mass_list)
401
- if delta_mass_list
402
- else float("nan")
403
- )
404
- row_data["delta_mass_ppm"] = f"{min_abs_ppm:.2f}"
405
- except Exception as e:
406
- logger.info(f"Warning: Could not calculate delta mass for prediction {i}: {e}")
407
- row_data["delta_mass_ppm"] = "N/A"
408
 
409
- else:
410
- row_data["prediction"] = ""
411
- row_data["log_probability"] = "N/A"
412
- row_data["delta_mass_ppm"] = "N/A"
413
- output_data.append(row_data)
414
-
415
- output_df = pl.DataFrame(output_data)
416
-
417
- # Ensure specific columns are present and ordered
418
- display_cols = [
419
- "scan_number",
420
- "precursor_mz",
421
- "precursor_charge",
422
- "prediction",
423
- "log_probability",
424
- "delta_mass_ppm",
425
- ]
426
- final_display_cols = []
427
- for col in display_cols:
428
- if col in output_df.columns:
429
- final_display_cols.append(col)
430
- else:
431
- logger.info(f"Warning: Expected display column '{col}' not found in results.")
432
 
433
- # Add any remaining index columns that weren't in display_cols
434
- for col in index_cols:
435
- if col not in final_display_cols and col in output_df.columns:
436
- final_display_cols.append(col)
 
 
 
437
 
438
- output_df_display = output_df.select(final_display_cols)
 
 
 
 
 
 
 
 
 
 
439
 
440
- # 7. Save full results to CSV
441
- logger.info(f"Saving results to {output_csv_path}...")
442
- output_df.write_csv(output_csv_path)
443
 
444
- # Return DataFrame for display and path for download
445
- return output_df_display.to_pandas(), output_csv_path
446
 
447
  except Exception as e:
448
- logger.info(f"An error occurred during prediction: {e}")
449
- # Clean up the temporary output file if it exists
450
- if os.path.exists(output_csv_path):
451
- os.remove(output_csv_path)
452
- # Re-raise as Gradio error
 
 
453
  raise gr.Error(f"Prediction failed: {e}")
454
 
455
 
@@ -458,29 +693,29 @@ css = """
458
  .gradio-container { font-family: sans-serif; }
459
  .gr-button { color: white; border-color: black; background: black; }
460
  footer { display: none !important; }
461
- /* Optional: Add some margin below the logo */
462
  .logo-container img { margin-bottom: 1rem; }
 
463
  """
464
 
465
  with gr.Blocks(
466
  css=css, theme=gr.themes.Default(primary_hue="blue", secondary_hue="blue")
467
  ) as demo:
468
- # --- Logo Display ---
469
  gr.Markdown(
470
  """
471
  <div style="text-align: center;" class="logo-container">
472
  <img src='/gradio_api/file=assets/instanovo.svg' alt="InstaNovo Logo" width="300" style="display: block; margin: 0 auto;">
473
  </div>
474
  """,
475
- elem_classes="logo-container", # Optional class for CSS targeting
476
  )
477
 
478
- # --- App Content ---
479
  gr.Markdown(
480
- """
481
- # 🚀 _De Novo_ Peptide Sequencing with InstaNovo
482
- Upload your mass spectrometry data file (.mgf, .mzml, or .mzxml) and get peptide sequence predictions using InstaNovo.
483
- Choose between fast Greedy Search or more accurate but slower Knapsack Beam Search.
 
 
484
  """
485
  )
486
  with gr.Row():
@@ -489,73 +724,114 @@ with gr.Blocks(
489
  label="Upload Mass Spectrometry File (.mgf, .mzml, .mzxml)",
490
  file_types=[".mgf", ".mzml", ".mzxml"],
491
  )
492
- decoding_method = gr.Radio(
493
  [
494
- "Greedy Search (Fast, resonably accurate)",
495
- "Knapsack Beam Search (More accurate, but slower)",
 
496
  ],
497
- label="Decoding Method",
498
- value="Greedy Search (Fast, resonably accurate)", # Default to fast method
499
  )
 
 
 
 
 
 
 
 
 
 
 
 
500
  submit_btn = gr.Button("Predict Sequences", variant="primary")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
501
  with gr.Column(scale=2):
502
  output_df = gr.DataFrame(
503
- label="Prediction Results",
504
- headers=[
505
- "scan_number",
506
- "precursor_mz",
507
- "precursor_charge",
508
- "prediction",
509
- "log_probability",
510
- "delta_mass_ppm",
511
- ],
512
- wrap=True,
513
  )
514
  output_file = gr.File(label="Download Full Results (CSV)")
515
 
516
  submit_btn.click(
517
  predict_peptides,
518
- inputs=[input_file, decoding_method],
519
  outputs=[output_df, output_file],
520
  )
521
 
522
  gr.Examples(
523
  [
524
- ["assets/sample_spectra.mgf", "Greedy Search (Fast, resonably accurate)"],
525
- [
526
- "assets/sample_spectra.mgf",
527
- "Knapsack Beam Search (More accurate, but slower)",
528
- ],
529
  ],
530
- inputs=[input_file, decoding_method],
531
- outputs=[output_df, output_file],
532
- fn=predict_peptides,
533
- cache_examples=False, # Re-run examples if needed
534
- label="Example Usage",
535
  )
536
 
537
  gr.Markdown(
538
- """
539
  **Notes:**
540
- * Predictions are based on the [InstaNovo](https://github.com/instadeepai/InstaNovo) model `{MODEL_ID}`.
541
- * Knapsack Beam Search uses pre-calculated mass constraints and yields better results but takes longer.
542
- * `delta_mass_ppm` shows the lowest absolute precursor mass error (in ppm) across potential isotopes (0-1 neutron).
543
- * Ensure your input file format is correctly specified. Large files may take time to process.
544
- """.format(MODEL_ID=MODEL_ID)
 
 
 
 
545
  )
546
 
547
- # Add logging component
548
- with gr.Accordion("Application Logs", open=True):
549
  log_display = Log(log_file, dark=True, height=300)
550
-
551
- gr.Textbox(
552
  value="""
 
 
 
553
  @article{eloff_kalogeropoulos_2025_instanovo,
554
  title = {InstaNovo enables diffusion-powered de novo peptide sequencing in large-scale proteomics experiments},
555
- author = {Kevin Eloff and Konstantinos Kalogeropoulos and Amandla Mabona and Oliver Morell and Rachel Catzel and
556
- Esperanza Rivera-de-Torre and Jakob Berg Jespersen and Wesley Williams and Sam P. B. van Beljouw and
557
- Marcin J. Skwark and Andreas Hougaard Laustsen and Stan J. J. Brouns and Anne Ljungars and Erwin M.
558
- Schoof and Jeroen Van Goey and Ulrich auf dem Keller and Karim Beguir and Nicolas Lopez Carranza and
559
  Timothy P. Jenkins},
560
  year = 2025,
561
  month = {Mar},
@@ -566,8 +842,7 @@ with gr.Blocks(
566
  }
567
  """,
568
  show_copy_button=True,
569
- label="If you use InstaNovo in your research, please cite:",
570
- interactive=False,
571
  )
572
 
573
  # --- Launch the App ---
@@ -576,4 +851,5 @@ if __name__ == "__main__":
576
  # Set server_name="0.0.0.0" to allow access from network if needed
577
  # demo.launch(server_name="0.0.0.0", server_port=7860)
578
  # For Hugging Face Spaces, just demo.launch() is usually sufficient
579
- demo.launch(share=True) # For local testing with public URL
 
 
14
  # --- InstaNovo Imports ---
15
  try:
16
  from instanovo.transformer.model import InstaNovo
17
+ from instanovo.diffusion.multinomial_diffusion import InstaNovoPlus
18
  from instanovo.utils import SpectrumDataFrame, ResidueSet, Metrics
19
  from instanovo.transformer.dataset import SpectrumDataset, collate_batch
20
  from instanovo.inference import (
 
24
  ScoredSequence,
25
  Decoder,
26
  )
27
+ from instanovo.inference.diffusion import DiffusionDecoder
28
+ from instanovo.constants import (
29
+ MASS_SCALE,
30
+ MAX_MASS,
31
+ DIFFUSION_START_STEP,
32
+ )
33
  from torch.utils.data import DataLoader
34
+ import torch.nn.functional as F # For padding
35
  except ImportError as e:
36
  raise ImportError(f"Failed to import InstaNovo components: {e}")
37
 
38
  # --- Configuration ---
39
+ TRANSFORMER_MODEL_ID = "instanovo-v1.1.0"
40
+ DIFFUSION_MODEL_ID = "instanovoplus-v1.1.0-alpha"
41
  KNAPSACK_DIR = Path("./knapsack_cache")
42
  DEFAULT_CONFIG_PATH = Path(
43
  "./configs/inference/default.yaml"
44
+ )
45
 
46
  # Determine device
47
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
48
+ FP16 = DEVICE == "cuda"
49
 
50
+ # --- Global Variables (Load Models and Knapsack Once) ---
51
  MODEL: InstaNovo | None = None
 
52
  MODEL_CONFIG: DictConfig | None = None
53
+ MODEL_PLUS: InstaNovoPlus | None = None
54
+ MODEL_PLUS_CONFIG: DictConfig | None = None
55
+ KNAPSACK: Knapsack | None = None
56
  RESIDUE_SET: ResidueSet | None = None
57
 
58
+ # --- Assets ---
59
  gr.set_static_paths(paths=[Path.cwd().absolute()/"assets"])
60
 
61
  # Create gradio temporary directory
 
67
  log_file = "/tmp/instanovo_gradio_log.txt"
68
  Path(log_file).touch()
69
 
70
+ logger = logging.getLogger("instanovo_gradio")
71
  logger.setLevel(logging.INFO)
72
+ if not logger.handlers:
73
+ file_handler = logging.FileHandler(log_file)
74
+ file_handler.setLevel(logging.INFO)
75
+ formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
76
+ file_handler.setFormatter(formatter)
77
+ logger.addHandler(file_handler)
78
+ stream_handler = logging.StreamHandler()
79
+ stream_handler.setLevel(logging.INFO)
80
+ stream_handler.setFormatter(formatter)
81
+ logger.addHandler(stream_handler)
82
+
83
+
84
+ def load_models_and_knapsack():
85
+ """Loads the InstaNovo models and generates/loads the knapsack."""
86
+ global MODEL, KNAPSACK, MODEL_CONFIG, RESIDUE_SET, MODEL_PLUS, MODEL_PLUS_CONFIG
87
+ models_loaded = MODEL is not None and MODEL_PLUS is not None
88
+ if models_loaded:
89
+ logger.info("Models already loaded.")
90
+ # Still check knapsack if not loaded
91
+ if KNAPSACK is None:
92
+ logger.info("Models loaded, but knapsack needs loading/generation.")
93
+ else:
94
+ return # All loaded
95
 
96
+ # --- Load Transformer Model ---
97
+ if MODEL is None:
98
+ logger.info(f"Loading InstaNovo (Transformer) model: {TRANSFORMER_MODEL_ID} to {DEVICE}...")
99
+ try:
100
+ MODEL, MODEL_CONFIG = InstaNovo.from_pretrained(TRANSFORMER_MODEL_ID)
101
+ MODEL.to(DEVICE)
102
+ MODEL.eval()
103
+ RESIDUE_SET = MODEL.residue_set
104
+ logger.info("Transformer model loaded successfully.")
105
+ except Exception as e:
106
+ logger.error(f"Error loading Transformer model: {e}")
107
+ raise gr.Error(f"Failed to load InstaNovo Transformer model: {TRANSFORMER_MODEL_ID}. Error: {e}")
108
+ else:
109
+ logger.info("Transformer model already loaded.")
110
 
 
 
 
 
 
 
111
 
112
+ # --- Load Diffusion Model ---
113
+ if MODEL_PLUS is None:
114
+ logger.info(f"Loading InstaNovo+ (Diffusion) model: {DIFFUSION_MODEL_ID} to {DEVICE}...")
115
+ try:
116
+ MODEL_PLUS, MODEL_PLUS_CONFIG = InstaNovoPlus.from_pretrained(DIFFUSION_MODEL_ID)
117
+ MODEL_PLUS.to(DEVICE)
118
+ MODEL_PLUS.eval()
119
+ if RESIDUE_SET is not None and MODEL_PLUS.residues != RESIDUE_SET:
120
+ logger.warning("Residue sets between Transformer and Diffusion models may differ. Using Transformer's set.")
121
+ elif RESIDUE_SET is None:
122
+ RESIDUE_SET = MODEL_PLUS.residues
123
+
124
+ logger.info("Diffusion model loaded successfully.")
125
+ except Exception as e:
126
+ logger.error(f"Error loading Diffusion model: {e}")
127
+ gr.Warning(f"Failed to load InstaNovo+ Diffusion model ({DIFFUSION_MODEL_ID}): {e}. Diffusion modes will be unavailable.")
128
+ MODEL_PLUS = None
129
+ else:
130
+ logger.info("Diffusion model already loaded.")
131
+
132
 
133
  # --- Knapsack Handling ---
134
+ # Only attempt knapsack loading/generation if the Transformer model is loaded
135
+ if MODEL is not None and RESIDUE_SET is not None and KNAPSACK is None:
136
+ knapsack_exists = (
137
+ (KNAPSACK_DIR / "parameters.pkl").exists()
138
+ and (KNAPSACK_DIR / "masses.npy").exists()
139
+ and (KNAPSACK_DIR / "chart.npy").exists()
140
+ )
141
 
142
+ if knapsack_exists:
143
+ logger.info(f"Loading pre-generated knapsack from {KNAPSACK_DIR}...")
144
+ try:
145
+ KNAPSACK = Knapsack.from_file(str(KNAPSACK_DIR))
146
+ logger.info("Knapsack loaded successfully.")
147
+ except Exception as e:
148
+ logger.info(f"Error loading knapsack: {e}. Will attempt to regenerate.")
149
+ KNAPSACK = None
150
+ knapsack_exists = False
151
+
152
+ if not knapsack_exists:
153
+ logger.info("Knapsack not found or failed to load. Generating knapsack...")
154
+ try:
155
+ residue_masses_knapsack = dict(RESIDUE_SET.residue_masses.copy())
156
+ special_and_nonpositive = list(RESIDUE_SET.special_tokens) + [
157
+ k for k, v in residue_masses_knapsack.items() if v <= 0
158
+ ]
159
+ if special_and_nonpositive:
160
+ logger.info(f"Excluding special/non-positive mass residues from knapsack: {special_and_nonpositive}")
161
+ for res in set(special_and_nonpositive):
162
+ if res in residue_masses_knapsack:
163
+ del residue_masses_knapsack[res]
164
+
165
+ valid_residue_indices = {
166
+ res: idx
167
+ for res, idx in RESIDUE_SET.residue_to_index.items()
168
+ if res in residue_masses_knapsack
169
+ }
170
+
171
+ if not residue_masses_knapsack:
172
+ raise ValueError("No valid residues with positive mass found for knapsack generation.")
173
+
174
+ KNAPSACK = Knapsack.construct_knapsack(
175
+ residue_masses=residue_masses_knapsack,
176
+ residue_indices=valid_residue_indices,
177
+ max_mass=MAX_MASS,
178
+ mass_scale=MASS_SCALE,
179
+ )
180
+ logger.info(f"Knapsack generated. Saving to {KNAPSACK_DIR}...")
181
+ KNAPSACK_DIR.mkdir(parents=True, exist_ok=True)
182
+ KNAPSACK.save(str(KNAPSACK_DIR))
183
+ logger.info("Knapsack saved.")
184
+ except Exception as e:
185
+ logger.error(f"Error generating or saving knapsack: {e}", exc_info=True)
186
+ gr.Warning(f"Failed to generate Knapsack. Knapsack Beam Search will not be available. Error: {e}")
187
+ KNAPSACK = None
188
+ elif KNAPSACK is not None:
189
+ logger.info("Knapsack already loaded.")
190
+ elif MODEL is None:
191
+ logger.warning("Transformer model not loaded, skipping Knapsack loading/generation.")
 
 
192
 
193
+
194
+ # Load models and knapsack when the script starts
195
+ load_models_and_knapsack()
196
 
197
 
198
  def create_inference_config(
199
  input_path: str,
200
  output_path: str,
 
201
  ) -> DictConfig:
202
+ """Creates a base OmegaConf DictConfig for prediction environment."""
 
203
  if DEFAULT_CONFIG_PATH.exists():
204
  base_cfg = OmegaConf.load(DEFAULT_CONFIG_PATH)
205
+ logger.info(f"Loaded base config from {DEFAULT_CONFIG_PATH}")
206
  else:
207
  logger.info(f"Warning: Default config not found at {DEFAULT_CONFIG_PATH}. Using minimal config.")
 
208
  base_cfg = OmegaConf.create({
209
+ "data_path": None, "instanovo_model": TRANSFORMER_MODEL_ID,
210
+ "instanovoplus_model": DIFFUSION_MODEL_ID, "output_path": None,
211
+ "knapsack_path": str(KNAPSACK_DIR), "denovo": True, "refine": True,
212
+ "num_beams": 1, "max_length": 40, "max_charge": 10,
213
+ "isotope_error_range": [0, 1], "subset": 1.0, "use_knapsack": False,
214
+ "save_beams": False, "batch_size": 64, "device": DEVICE, "fp16": FP16,
215
+ "log_interval": 500, "use_basic_logging": True,
216
+ "filter_precursor_ppm": 20, "filter_confidence": 1e-4,
217
+ "filter_fdr_threshold": 0.05, "suppressed_residues": None,
218
+ "disable_terminal_residues_anywhere": True,
219
+ "residue_remapping": {
 
 
 
 
 
 
 
 
 
 
 
220
  "M(ox)": "M[UNIMOD:35]", "M(+15.99)": "M[UNIMOD:35]",
221
  "S(p)": "S[UNIMOD:21]", "T(p)": "T[UNIMOD:21]", "Y(p)": "Y[UNIMOD:21]",
222
  "S(+79.97)": "S[UNIMOD:21]", "T(+79.97)": "T[UNIMOD:21]", "Y(+79.97)": "Y[UNIMOD:21]",
223
  "Q(+0.98)": "Q[UNIMOD:7]", "N(+0.98)": "N[UNIMOD:7]",
224
  "Q(+.98)": "Q[UNIMOD:7]", "N(+.98)": "N[UNIMOD:7]",
225
+ "C(+57.02)": "C[UNIMOD:4]", "(+42.01)": "[UNIMOD:1]",
226
+ "(+43.01)": "[UNIMOD:5]", "(-17.03)": "[UNIMOD:385]",
227
  },
228
+ "column_map": {
229
  "Modified sequence": "modified_sequence", "MS/MS m/z": "precursor_mz",
230
  "Mass": "precursor_mass", "Charge": "precursor_charge",
231
  "Mass values": "mz_array", "Mass spectrum": "mz_array",
 
234
  },
235
  "index_columns": [
236
  "scan_number", "precursor_mz", "precursor_charge",
237
+ "retention_time", "spectrum_id", "experiment_name",
238
  ],
 
239
  })
240
 
 
241
  cfg_overrides = {
242
+ "data_path": input_path, "output_path": output_path,
243
+ "device": DEVICE, "fp16": FP16, "denovo": True,
 
 
 
 
244
  }
245
+ final_cfg = OmegaConf.merge(base_cfg, cfg_overrides)
246
+ logger.info(f"Created inference config:\n{OmegaConf.to_yaml(final_cfg)}")
247
+ return final_cfg
248
 
249
+ def _get_transformer_decoder(selection: str, config: DictConfig) -> tuple[Decoder, int, bool]:
250
+ """Helper to instantiate the correct transformer decoder based on selection."""
251
+ global MODEL, KNAPSACK
252
+ if MODEL is None:
253
+ raise gr.Error("InstaNovo Transformer model not loaded.")
254
+
255
+ num_beams = 1
256
+ use_knapsack = False
257
+ decoder: Decoder
258
+
259
+ if "Greedy" in selection:
260
+ decoder = GreedyDecoder(
261
+ model=MODEL,
262
+ mass_scale=MASS_SCALE,
263
+ suppressed_residues=config.get("suppressed_residues", None),
264
+ disable_terminal_residues_anywhere=config.get("disable_terminal_residues_anywhere", True),
265
+ )
266
+ elif "Knapsack" in selection:
267
  if KNAPSACK is None:
268
+ raise gr.Error("Knapsack is not available. Cannot use Knapsack Beam Search.")
269
+ decoder = KnapsackBeamSearchDecoder(model=MODEL, knapsack=KNAPSACK)
270
+ num_beams = 5 # Default beam size for knapsack
271
+ use_knapsack = True
 
 
272
  else:
273
+ raise ValueError(f"Unknown transformer decoder selection: {selection}")
274
+
275
+ logger.info(f"Using Transformer decoder: {type(decoder).__name__} (Num beams: {num_beams}, Use Knapsack: {use_knapsack})")
276
+ return decoder, num_beams, use_knapsack
277
+
278
+
279
+ def run_transformer_prediction(dl, config, transformer_decoder_selection):
280
+ """Runs prediction using only the transformer model."""
281
+ global RESIDUE_SET
282
+ if RESIDUE_SET is None:
283
+ raise gr.Error("ResidueSet not loaded.")
284
+
285
+ decoder, num_beams, use_knapsack = _get_transformer_decoder(transformer_decoder_selection, config)
286
+
287
+ results_list: list[ScoredSequence | list] = []
288
+ start_time = time.time()
289
+ for i, batch in enumerate(dl):
290
+ spectra, precursors, spectra_mask, _, _ = batch
291
+ spectra = spectra.to(DEVICE)
292
+ precursors = precursors.to(DEVICE)
293
+ spectra_mask = spectra_mask.to(DEVICE)
294
+
295
+ with torch.no_grad(), torch.amp.autocast(DEVICE, dtype=torch.float16, enabled=FP16):
296
+ batch_predictions = decoder.decode(
297
+ spectra=spectra,
298
+ precursors=precursors,
299
+ beam_size=num_beams,
300
+ max_length=config.max_length,
301
+ mass_tolerance=config.get("filter_precursor_ppm", 20) * 1e-6,
302
+ max_isotope=config.isotope_error_range[1] if config.isotope_error_range else 1,
303
+ return_beam=False, # Only top result
304
+ )
305
+ results_list.extend(batch_predictions)
306
+ if (i + 1) % 10 == 0 or (i + 1) == len(dl):
307
+ logger.info(f"Transformer prediction: Processed batch {i+1}/{len(dl)}")
308
+
309
+ end_time = time.time()
310
+ logger.info(f"Transformer prediction finished in {end_time - start_time:.2f} seconds.")
311
+ return results_list
312
+
313
+ def run_diffusion_prediction(dl, config):
314
+ """Runs prediction using only the diffusion model."""
315
+ global MODEL_PLUS, RESIDUE_SET
316
+ if MODEL_PLUS is None or RESIDUE_SET is None:
317
+ raise gr.Error("InstaNovo+ Diffusion model not loaded.")
318
+
319
+ diffusion_decoder = DiffusionDecoder(model=MODEL_PLUS)
320
+ logger.info(f"Using decoder: {type(diffusion_decoder).__name__}")
321
+
322
+ results_sequences = []
323
+ results_log_probs = []
324
+ start_time = time.time()
325
+
326
+ # Re-create dataloader iterator to get precursor info easily later
327
+ all_batches = list(dl)
328
+
329
+ for i, batch in enumerate(all_batches):
330
+ spectra, precursors, spectra_mask, _, _ = batch
331
+ spectra = spectra.to(DEVICE)
332
+ precursors = precursors.to(DEVICE)
333
+ spectra_mask = spectra_mask.to(DEVICE)
334
+
335
+ with torch.no_grad(), torch.amp.autocast(DEVICE, dtype=torch.float16, enabled=FP16):
336
+ batch_sequences, batch_log_probs = diffusion_decoder.decode(
337
+ spectra=spectra,
338
+ spectra_padding_mask=spectra_mask,
339
+ precursors=precursors,
340
+ initial_sequence=None,
341
+ )
342
+ results_sequences.extend(batch_sequences)
343
+ results_log_probs.extend(batch_log_probs)
344
+ if (i + 1) % 10 == 0 or (i + 1) == len(all_batches):
345
+ logger.info(f"Diffusion prediction: Processed batch {i+1}/{len(all_batches)}")
346
+
347
+ end_time = time.time()
348
+ logger.info(f"Diffusion prediction finished in {end_time - start_time:.2f} seconds.")
349
+
350
+ scored_results = []
351
+ metrics_calc = Metrics(RESIDUE_SET, config.isotope_error_range)
352
+ all_precursors = torch.cat([b[1] for b in all_batches], dim=0) # b[1] is precursors
353
+
354
+ for idx, (seq, logp) in enumerate(zip(results_sequences, results_log_probs)):
355
+ prec_mz = all_precursors[idx, 1].item()
356
+ prec_ch = int(all_precursors[idx, 2].item())
357
+ try:
358
+ _, delta_mass_list = metrics_calc.matches_precursor(seq, prec_mz, prec_ch)
359
+ min_abs_ppm = min(abs(p) for p in delta_mass_list) if delta_mass_list else float("nan")
360
+ except Exception as e:
361
+ logger.info(f"Warning: Could not calculate delta mass for diffusion prediction {idx}: {e}")
362
+ min_abs_ppm = float("nan")
363
+
364
+ scored_results.append(
365
+ ScoredSequence(sequence=seq, mass_error=min_abs_ppm, sequence_log_probability=logp, token_log_probabilities=[])
366
+ )
367
+
368
+ return scored_results
369
+
370
+
371
+ def run_refinement_prediction(dl, config, transformer_decoder_selection):
372
+ """Runs transformer prediction followed by diffusion refinement."""
373
+ global MODEL, MODEL_PLUS, RESIDUE_SET, MODEL_PLUS_CONFIG
374
+ if MODEL is None or MODEL_PLUS is None or RESIDUE_SET is None or MODEL_PLUS_CONFIG is None:
375
+ missing = [m for m, v in [("Transformer", MODEL), ("Diffusion", MODEL_PLUS)] if v is None]
376
+ raise gr.Error(f"Cannot run refinement: {', '.join(missing)} model not loaded.")
377
+
378
+ # 1. Run Transformer Prediction (using selected decoder)
379
+ logger.info(f"Running Transformer prediction ({transformer_decoder_selection}) for refinement...")
380
+ transformer_decoder, num_beams, _ = _get_transformer_decoder(transformer_decoder_selection, config) # Get selected decoder
381
+ transformer_results_list: list[ScoredSequence | list] = []
382
+
383
+ all_batches = list(dl) # Store batches
384
+
385
+ start_time_transformer = time.time()
386
+ for i, batch in enumerate(all_batches):
387
+ spectra, precursors, spectra_mask, _, _ = batch
388
+ spectra = spectra.to(DEVICE)
389
+ precursors = precursors.to(DEVICE)
390
+ spectra_mask = spectra_mask.to(DEVICE)
391
+
392
+ with torch.no_grad(), torch.amp.autocast(DEVICE, dtype=torch.float16, enabled=FP16):
393
+ batch_predictions = transformer_decoder.decode(
394
+ spectra=spectra,
395
+ precursors=precursors,
396
+ beam_size=num_beams, # Use selected beam size
397
+ max_length=config.max_length,
398
+ mass_tolerance=config.get("filter_precursor_ppm", 20) * 1e-6,
399
+ max_isotope=config.isotope_error_range[1] if config.isotope_error_range else 1,
400
+ return_beam=False, # Only top result needed for refinement
401
+ )
402
+ transformer_results_list.extend(batch_predictions)
403
+ if (i + 1) % 10 == 0 or (i + 1) == len(all_batches):
404
+ logger.info(f"Refinement (Transformer): Processed batch {i+1}/{len(all_batches)}")
405
+
406
+ logger.info(f"Transformer prediction for refinement finished in {time.time() - start_time_transformer:.2f} seconds.")
407
+
408
+ # 2. Prepare Transformer Predictions as Initial Sequences for Diffusion
409
+ logger.info("Encoding transformer predictions for diffusion input...")
410
+ encoded_transformer_preds = []
411
+ max_len_diffusion = MODEL_PLUS_CONFIG.get("max_length", 40)
412
+
413
+ for res in transformer_results_list:
414
+ if isinstance(res, ScoredSequence) and res.sequence:
415
+ # Encode sequence *without* EOS for diffusion input.
416
+ encoded = RESIDUE_SET.encode(res.sequence, add_eos=False, return_tensor='pt')
417
+ else:
418
+ # If transformer failed, provide a dummy PAD sequence
419
+ encoded = torch.full((max_len_diffusion,), RESIDUE_SET.PAD_INDEX, dtype=torch.long)
420
+
421
+
422
+ # Pad or truncate to the diffusion model's max length
423
+ current_len = encoded.shape[0]
424
+ if current_len > max_len_diffusion:
425
+ logger.warning(f"Transformer prediction exceeded diffusion max length ({max_len_diffusion}). Truncating.")
426
+ encoded = encoded[:max_len_diffusion]
427
+ elif current_len < max_len_diffusion:
428
+ padding = torch.full((max_len_diffusion - current_len,), RESIDUE_SET.PAD_INDEX, dtype=torch.long)
429
+ encoded = torch.cat((encoded, padding))
430
+
431
+ encoded_transformer_preds.append(encoded)
432
+
433
+ if not encoded_transformer_preds:
434
+ raise gr.Error("Transformer prediction yielded no results to refine.")
435
+ encoded_transformer_preds_tensor = torch.stack(encoded_transformer_preds).to(DEVICE)
436
+ logger.info(f"Encoded {encoded_transformer_preds_tensor.shape[0]} sequences for diffusion.")
437
+
438
+
439
+ # 3. Run Diffusion Refinement
440
+ logger.info("Running Diffusion refinement...")
441
+ diffusion_decoder = DiffusionDecoder(model=MODEL_PLUS)
442
+ refined_sequences = []
443
+ refined_log_probs = []
444
+ start_time_diffusion = time.time()
445
+
446
+ current_idx = 0
447
+ for i, batch in enumerate(all_batches):
448
+ spectra, precursors, spectra_mask, _, _ = batch
449
+ spectra = spectra.to(DEVICE)
450
+ precursors = precursors.to(DEVICE)
451
+ spectra_mask = spectra_mask.to(DEVICE)
452
+
453
+ batch_size = spectra.shape[0]
454
+ initial_sequences_batch = encoded_transformer_preds_tensor[current_idx : current_idx + batch_size]
455
+ current_idx += batch_size
456
+
457
+ if initial_sequences_batch.shape[0] != batch_size:
458
+ logger.error(f"Batch size mismatch during refinement: expected {batch_size}, got {initial_sequences_batch.shape[0]}")
459
+ continue # Skip batch?
460
+
461
+ with torch.no_grad(), torch.amp.autocast(DEVICE, dtype=torch.float16, enabled=FP16):
462
+ batch_refined_seqs, batch_refined_logp = diffusion_decoder.decode(
463
+ spectra=spectra,
464
+ spectra_padding_mask=spectra_mask,
465
+ precursors=precursors,
466
+ initial_sequence=initial_sequences_batch,
467
+ start_step=DIFFUSION_START_STEP,
468
+ )
469
+ refined_sequences.extend(batch_refined_seqs)
470
+ refined_log_probs.extend(batch_refined_logp)
471
+ if (i + 1) % 10 == 0 or (i + 1) == len(all_batches):
472
+ logger.info(f"Refinement (Diffusion): Processed batch {i+1}/{len(all_batches)}")
473
+
474
+ logger.info(f"Diffusion refinement finished in {time.time() - start_time_diffusion:.2f} seconds.")
475
+
476
+ # 4. Combine and Format Results
477
+ all_precursors = torch.cat([b[1] for b in all_batches], dim=0) # b[1] is precursors
478
+ metrics_calc = Metrics(RESIDUE_SET, config.isotope_error_range)
479
+ combined_results = []
480
+ for idx, (transformer_res, refined_seq, refined_logp) in enumerate(zip(transformer_results_list, refined_sequences, refined_log_probs)):
481
+ prec_mz = all_precursors[idx, 1].item()
482
+ prec_ch = int(all_precursors[idx, 2].item())
483
+ try:
484
+ _, delta_mass_list = metrics_calc.matches_precursor(refined_seq, prec_mz, prec_ch)
485
+ min_abs_ppm = min(abs(p) for p in delta_mass_list) if delta_mass_list else float("nan")
486
+ except Exception as e:
487
+ logger.info(f"Warning: Could not calculate delta mass for refined prediction {idx}: {e}")
488
+ min_abs_ppm = float("nan")
489
+
490
+ combined_data = {
491
+ "transformer_prediction": "".join(transformer_res.sequence) if isinstance(transformer_res, ScoredSequence) else "",
492
+ "transformer_log_probability": transformer_res.sequence_log_probability if isinstance(transformer_res, ScoredSequence) else float('-inf'),
493
+ "refined_prediction": "".join(refined_seq),
494
+ "refined_log_probability": refined_logp,
495
+ "refined_delta_mass_ppm": min_abs_ppm,
496
+ }
497
+ combined_results.append(combined_data)
498
+
499
+ return combined_results
500
 
 
 
 
501
 
502
  @spaces.GPU
503
+ def predict_peptides(input_file, mode_selection, transformer_decoder_selection):
504
  """
505
+ Main function to load data, select mode, run prediction, and return results.
506
  """
507
+ # Ensure models are loaded
508
+ if MODEL is None or RESIDUE_SET is None:
509
+ load_models_and_knapsack() # Try reload
510
  if MODEL is None:
511
+ raise gr.Error("InstaNovo Transformer model failed to load. Cannot perform prediction.")
512
+ if ("Refinement" in mode_selection or "InstaNovo+" in mode_selection) and MODEL_PLUS is None:
513
+ load_models_and_knapsack() # Try reload diffusion
514
+ if MODEL_PLUS is None:
515
+ raise gr.Error("InstaNovo+ Diffusion model failed to load. Cannot perform Refinement or InstaNovo+ Only prediction.")
516
+ if "Knapsack" in transformer_decoder_selection and KNAPSACK is None:
517
+ load_models_and_knapsack() # Try reload knapsack
518
+ if KNAPSACK is None:
519
+ raise gr.Error("Knapsack failed to load. Cannot use Knapsack Beam Search.")
520
+
521
 
522
  if input_file is None:
523
  raise gr.Error("Please upload a mass spectrometry file.")
524
 
525
+ input_path = input_file.name
526
+ logger.info(f"--- New Prediction Request ---")
527
+ logger.info(f"Input File: {input_path}")
528
+ logger.info(f"Selected Mode: {mode_selection}")
529
+ if "Refinement" in mode_selection or "InstaNovo Only" in mode_selection:
530
+ logger.info(f"Selected Transformer Decoder: {transformer_decoder_selection}")
531
 
532
+ # Create temp output file
533
+ gradio_tmp_dir = os.environ.get("GRADIO_TEMP_DIR", "/tmp")
534
+ try:
535
+ with tempfile.NamedTemporaryFile(dir=gradio_tmp_dir, delete=False, suffix=".csv") as temp_out:
536
+ output_csv_path = temp_out.name
537
+ logger.info(f"Temporary output path: {output_csv_path}")
538
+ except Exception as e:
539
+ logger.error(f"Failed to create temporary file in {gradio_tmp_dir}: {e}")
540
+ raise gr.Error(f"Failed to create temporary output file: {e}")
541
 
542
  try:
543
+ config = create_inference_config(input_path, output_csv_path)
 
 
544
 
 
545
  logger.info("Loading spectrum data...")
546
  try:
547
+ # Load data eagerly
548
  sdf = SpectrumDataFrame.load(
549
+ config.data_path, lazy=False, is_annotated=False,
550
+ column_mapping=config.get("column_map", None), shuffle=False, verbose=True,
 
 
 
 
551
  )
 
552
  original_size = len(sdf)
553
  max_charge = config.get("max_charge", 10)
554
+ if "precursor_charge" in sdf.df.columns:
555
+ sdf.filter_rows(
556
+ lambda row: ("precursor_charge" in row and row["precursor_charge"] is not None and 0 < row["precursor_charge"] <= max_charge)
557
+ )
558
+ if len(sdf) < original_size:
559
+ logger.info(f"Warning: Filtered {original_size - len(sdf)} spectra with invalid or out-of-range charge (<=0 or >{max_charge}).")
560
+ else:
561
+ logger.warning("Column 'precursor_charge' not found. Cannot filter by charge.")
562
 
563
  if len(sdf) == 0:
564
  raise gr.Error("No valid spectra found in the uploaded file after filtering.")
565
  logger.info(f"Data loaded: {len(sdf)} spectra.")
566
+ index_cols_present = [col for col in config.index_columns if col in sdf.df.columns]
567
+ base_df_pd = sdf.df.select(index_cols_present).to_pandas()
568
+
569
  except Exception as e:
570
+ logger.error(f"Error loading data: {e}", exc_info=True)
571
  raise gr.Error(f"Failed to load or process the spectrum file. Error: {e}")
572
 
573
+ if RESIDUE_SET is None: raise gr.Error("Residue set not loaded.") # Should not happen if model loaded
574
+
575
+ # --- Prepare DataLoader ---
576
+ # Use reverse_peptide=True for Transformer steps, False for Diffusion-only
577
+ reverse_for_transformer = "InstaNovo+ Only" not in mode_selection
578
  ds = SpectrumDataset(
579
+ sdf, RESIDUE_SET,
580
+ MODEL_CONFIG.get("n_peaks", 200) if MODEL_CONFIG else 200,
581
+ return_str=True, annotated=False,
582
+ pad_spectrum_max_length=config.get("compile_model", False) or config.get("use_flash_attention", False),
 
 
 
583
  bin_spectra=config.get("conv_peak_encoder", False),
584
+ peptide_pad_length=config.get("max_length", 40) if config.get("compile_model", False) else 0,
585
+ reverse_peptide=reverse_for_transformer, # Key change based on mode
586
+ diffusion="InstaNovo+ Only" in mode_selection # Signal if input is for diffusion
587
  )
588
+ dl = DataLoader(ds, batch_size=config.batch_size, num_workers=0, shuffle=False, collate_fn=collate_batch)
589
+
590
+ # --- Run Prediction ---
591
+ results_data = None
592
+ output_headers = index_cols_present[:]
593
+
594
+ if "InstaNovo Only" in mode_selection:
595
+ output_headers.extend(["prediction", "log_probability", "delta_mass_ppm", "token_log_probabilities"])
596
+ transformer_results = run_transformer_prediction(dl, config, transformer_decoder_selection)
597
+ results_data = []
598
+ metrics_calc = Metrics(RESIDUE_SET, config.isotope_error_range)
599
+ for i, res in enumerate(transformer_results):
600
+ row_data = {}
601
+ if isinstance(res, ScoredSequence) and res.sequence:
602
+ row_data["prediction"] = "".join(res.sequence)
603
+ row_data["log_probability"] = f"{res.sequence_log_probability:.4f}"
604
+ row_data["token_log_probabilities"] = ", ".join(f"{p:.4f}" for p in res.token_log_probabilities)
605
+ try:
606
+ prec_mz = base_df_pd.loc[i, "precursor_mz"]
607
+ prec_ch = base_df_pd.loc[i, "precursor_charge"]
608
+ _, delta_mass_list = metrics_calc.matches_precursor(res.sequence, prec_mz, prec_ch)
609
+ min_abs_ppm = min(abs(p) for p in delta_mass_list) if delta_mass_list else float("nan")
610
+ row_data["delta_mass_ppm"] = f"{min_abs_ppm:.2f}"
611
+ except Exception as e:
612
+ logger.warning(f"Could not calculate delta mass for Tx prediction {i}: {e}")
613
+ row_data["delta_mass_ppm"] = "N/A"
614
+ else:
615
+ row_data.update({k: "N/A" for k in ["prediction", "log_probability", "delta_mass_ppm", "token_log_probabilities"]})
616
+ row_data["prediction"] = "" # Ensure empty string for failed preds
617
+ row_data["token_log_probabilities"] = ""
618
+ results_data.append(row_data)
619
+
620
+ elif "InstaNovo+ Only" in mode_selection:
621
+ output_headers.extend(["prediction", "log_probability", "delta_mass_ppm"])
622
+ diffusion_results = run_diffusion_prediction(dl, config)
623
+ results_data = []
624
+ for res in diffusion_results:
625
+ row_data = {}
626
+ if isinstance(res, ScoredSequence) and res.sequence:
627
+ row_data["prediction"] = "".join(res.sequence)
628
+ row_data["log_probability"] = f"{res.sequence_log_probability:.4f}" # Avg loss
629
+ row_data["delta_mass_ppm"] = f"{res.mass_error:.2f}" if not np.isnan(res.mass_error) else "N/A" # ppm
630
+ else:
631
+ row_data.update({k: "N/A" for k in ["prediction", "log_probability", "delta_mass_ppm"]})
632
+ row_data["prediction"] = ""
633
+ results_data.append(row_data)
634
+
635
+ elif "Refinement" in mode_selection:
636
+ output_headers.extend([
637
+ "transformer_prediction", "transformer_log_probability",
638
+ "refined_prediction", "refined_log_probability", "refined_delta_mass_ppm"
639
+ ])
640
+ # Pass the selected transformer decoder to the refinement function
641
+ results_data = run_refinement_prediction(dl, config, transformer_decoder_selection)
642
+ for row in results_data:
643
+ # Format numbers after getting the list of dicts
644
+ row["transformer_log_probability"] = f"{row['transformer_log_probability']:.4f}" if isinstance(row['transformer_log_probability'], (float, int)) else "N/A"
645
+ row["refined_log_probability"] = f"{row['refined_log_probability']:.4f}" if isinstance(row['refined_log_probability'], (float, int)) else "N/A"
646
+ row["refined_delta_mass_ppm"] = f"{row['refined_delta_mass_ppm']:.2f}" if isinstance(row['refined_delta_mass_ppm'], (float, int)) and not np.isnan(row['refined_delta_mass_ppm']) else "N/A"
647
+
648
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
649
  else:
650
+ raise ValueError(f"Unknown mode selection: {mode_selection}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
651
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
652
 
653
+ # --- Combine, Save, Return ---
654
+ logger.info("Combining results...")
655
+ if results_data is None: raise gr.Error("Prediction did not produce results.")
656
+
657
+ results_df = pl.DataFrame(results_data)
658
+ # Ensure base_df_pd has unique index if using join, or just concat horizontally if order is guaranteed
659
+ base_df_pl = pl.from_pandas(base_df_pd.reset_index(drop=True))
660
 
661
+ # Simple horizontal concat assuming order is preserved by dataloader (shuffle=False)
662
+ if len(base_df_pl) == len(results_df):
663
+ final_df = pl.concat([base_df_pl, results_df], how="horizontal")
664
+ else:
665
+ logger.error(f"Length mismatch between base data ({len(base_df_pl)}) and results ({len(results_df)}). Cannot reliably combine.")
666
+ # Fallback or error? Let's just use results for now, but log error.
667
+ final_df = results_df # Display only results in case of mismatch
668
+
669
+ logger.info(f"Saving full results to {output_csv_path}...")
670
+ final_df.write_csv(output_csv_path)
671
+ logger.info("Save complete.")
672
 
673
+ # Select display columns - make sure they exist in final_df
674
+ display_cols_final = [col for col in output_headers if col in final_df.columns]
675
+ display_df = final_df.select(display_cols_final)
676
 
677
+ logger.info("--- Prediction Request Complete ---")
678
+ return display_df.to_pandas(), output_csv_path
679
 
680
  except Exception as e:
681
+ logger.error(f"An error occurred during prediction: {e}", exc_info=True)
682
+ if 'output_csv_path' in locals() and os.path.exists(output_csv_path):
683
+ try:
684
+ os.remove(output_csv_path)
685
+ logger.info(f"Removed temporary file {output_csv_path}")
686
+ except OSError:
687
+ logger.error(f"Failed to remove temporary file {output_csv_path}")
688
  raise gr.Error(f"Prediction failed: {e}")
689
 
690
 
 
693
  .gradio-container { font-family: sans-serif; }
694
  .gr-button { color: white; border-color: black; background: black; }
695
  footer { display: none !important; }
 
696
  .logo-container img { margin-bottom: 1rem; }
697
+ .feedback { font-size: 0.9rem; color: gray; }
698
  """
699
 
700
  with gr.Blocks(
701
  css=css, theme=gr.themes.Default(primary_hue="blue", secondary_hue="blue")
702
  ) as demo:
 
703
  gr.Markdown(
704
  """
705
  <div style="text-align: center;" class="logo-container">
706
  <img src='/gradio_api/file=assets/instanovo.svg' alt="InstaNovo Logo" width="300" style="display: block; margin: 0 auto;">
707
  </div>
708
  """,
709
+ elem_classes="logo-container",
710
  )
711
 
 
712
  gr.Markdown(
713
+ f"""
714
+ # 🚀 _De Novo_ Peptide Sequencing with InstaNovo
715
+ Upload your mass spectrometry data file (.mgf, .mzml, or .mzxml) and get peptide sequence predictions.
716
+ Choose your prediction method and decoding options.
717
+
718
+ **Note:** The InstaNovo+ model `{DIFFUSION_MODEL_ID}` is an alpha release.
719
  """
720
  )
721
  with gr.Row():
 
724
  label="Upload Mass Spectrometry File (.mgf, .mzml, .mzxml)",
725
  file_types=[".mgf", ".mzml", ".mzxml"],
726
  )
727
+ mode_selection = gr.Radio(
728
  [
729
+ "InstaNovo + Refinement (Default, Recommended)",
730
+ "InstaNovo Only (Transformer)",
731
+ "InstaNovo+ Only (Diffusion, Alpha)",
732
  ],
733
+ label="Prediction Mode",
734
+ value="InstaNovo + Refinement (Default, Recommended)",
735
  )
736
+ # Transformer decoder selection - visible for relevant modes
737
+ transformer_decoder_selection = gr.Radio(
738
+ [
739
+ "Greedy Search (Fast)",
740
+ # Knapsack option added dynamically based on KNAPSACK availability
741
+ ],
742
+ label="Transformer Decoding Method",
743
+ value="Greedy Search (Fast)",
744
+ visible=True, # Start visible as default mode uses it
745
+ interactive=True,
746
+ )
747
+
748
  submit_btn = gr.Button("Predict Sequences", variant="primary")
749
+
750
+ # --- Control Visibility & Choices ---
751
+ def update_transformer_options(mode):
752
+ # Show decoder selection if mode uses the transformer
753
+ show_decoder = "InstaNovo+ Only" not in mode
754
+ # Update choices based on knapsack availability
755
+ knapsack_available = KNAPSACK is not None
756
+ choices = ["Greedy Search (Fast)"]
757
+ if knapsack_available:
758
+ choices.append("Knapsack Beam Search (Accurate, Slower)")
759
+ else:
760
+ logger.info("Knapsack check: Not available, disabling Knapsack Beam Search option.")
761
+ # Reset to Greedy if Knapsack was selected but becomes unavailable
762
+ current_value = "Greedy Search (Fast)" # Default reset value
763
+ return gr.update(visible=show_decoder, choices=choices, value=current_value)
764
+
765
+ mode_selection.change(
766
+ fn=update_transformer_options,
767
+ inputs=mode_selection,
768
+ outputs=transformer_decoder_selection,
769
+ )
770
+ # Initial check in case knapsack fails on startup
771
+ # This requires JS or a different approach in Gradio.
772
+ # For simplicity, we rely on the check during prediction.
773
+ # We can set initial choices based on load status here though.
774
+ initial_choices = ["Greedy Search (Fast)"]
775
+ if KNAPSACK is not None:
776
+ initial_choices.append("Knapsack Beam Search (Accurate, Slower)")
777
+ transformer_decoder_selection.choices = initial_choices
778
+
779
+
780
  with gr.Column(scale=2):
781
  output_df = gr.DataFrame(
782
+ label="Prediction Results Preview",
783
+ headers=["scan_number", "prediction", "log_probability", "delta_mass_ppm"]
 
 
 
 
 
 
 
 
784
  )
785
  output_file = gr.File(label="Download Full Results (CSV)")
786
 
787
  submit_btn.click(
788
  predict_peptides,
789
+ inputs=[input_file, mode_selection, transformer_decoder_selection],
790
  outputs=[output_df, output_file],
791
  )
792
 
793
  gr.Examples(
794
  [
795
+ ["assets/sample_spectra.mgf", "InstaNovo + Refinement (Default, Recommended)", "Greedy Search (Fast)"],
796
+ ["assets/sample_spectra.mgf", "InstaNovo + Refinement (Default, Recommended)", "Knapsack Beam Search (Accurate, Slower)"],
797
+ ["assets/sample_spectra.mgf", "InstaNovo Only (Transformer)", "Greedy Search (Fast)"],
798
+ ["assets/sample_spectra.mgf", "InstaNovo Only (Transformer)", "Knapsack Beam Search (Accurate, Slower)"],
799
+ ["assets/sample_spectra.mgf", "InstaNovo+ Only (Diffusion, Alpha)", "Greedy Search (Fast)"],
800
  ],
801
+ inputs=[input_file, mode_selection, transformer_decoder_selection],
802
+ # outputs=[output_df, output_file],
803
+ cache_examples=False,
804
+ label="Example Usage (Note: Knapsack examples require Knapsack to be available)",
 
805
  )
806
 
807
  gr.Markdown(
808
+ f"""
809
  **Notes:**
810
+ * Predictions use `{TRANSFORMER_MODEL_ID}` (Transformer) and `{DIFFUSION_MODEL_ID}` (Diffusion, Alpha).
811
+ * **Refinement Mode:** Runs initial prediction with the selected Transformer method (Greedy/Knapsack), then refines using InstaNovo+.
812
+ * **InstaNovo Only Mode:** Uses only the Transformer with the selected decoding method.
813
+ * **InstaNovo+ Only Mode:** Predicts directly using the Diffusion model (alpha version).
814
+ * `delta_mass_ppm` shows the lowest absolute precursor mass error (ppm) across isotopes 0-1 for the final sequence.
815
+ * Knapsack Beam Search requires a pre-computed knapsack file. If unavailable, the option will be disabled.
816
+ * Check logs for progress, especially for large files or slower methods.
817
+ """,
818
+ elem_classes="feedback"
819
  )
820
 
821
+ with gr.Accordion("Application Logs", open=False):
 
822
  log_display = Log(log_file, dark=True, height=300)
823
+
824
+ gr.Markdown(
825
  value="""
826
+ If you use InstaNovo in your research, please cite:
827
+
828
+ ```bibtex
829
  @article{eloff_kalogeropoulos_2025_instanovo,
830
  title = {InstaNovo enables diffusion-powered de novo peptide sequencing in large-scale proteomics experiments},
831
+ author = {Kevin Eloff and Konstantinos Kalogeropoulos and Amandla Mabona and Oliver Morell and Rachel Catzel and
832
+ Esperanza Rivera-de-Torre and Jakob Berg Jespersen and Wesley Williams and Sam P. B. van Beljouw and
833
+ Marcin J. Skwark and Andreas Hougaard Laustsen and Stan J. J. Brouns and Anne Ljungars and Erwin M.
834
+ Schoof and Jeroen Van Goey and Ulrich auf dem Keller and Karim Beguir and Nicolas Lopez Carranza and
835
  Timothy P. Jenkins},
836
  year = 2025,
837
  month = {Mar},
 
842
  }
843
  """,
844
  show_copy_button=True,
845
+ label="If you use InstaNovo in your research, please cite:"
 
846
  )
847
 
848
  # --- Launch the App ---
 
851
  # Set server_name="0.0.0.0" to allow access from network if needed
852
  # demo.launch(server_name="0.0.0.0", server_port=7860)
853
  # For Hugging Face Spaces, just demo.launch() is usually sufficient
854
+ demo.launch()
855
+ # demo.launch(share=True) # For local testing with public URL