File size: 40,096 Bytes
85d497c
a9ee52d
 
 
 
 
 
 
dd42569
a9ee52d
 
dd42569
a9ee52d
 
 
 
04bf12b
a9ee52d
 
 
 
 
 
 
 
 
04bf12b
 
 
 
 
 
a9ee52d
04bf12b
a9ee52d
6851f02
a9ee52d
 
04bf12b
 
a9ee52d
 
 
 
04bf12b
a9ee52d
04bf12b
68f76b5
 
 
 
04bf12b
a9ee52d
 
04bf12b
20c665f
 
69bd30f
 
 
 
 
dd42569
68f76b5
 
dd42569
 
 
68f76b5
757d63f
e7ec3c2
 
 
 
 
04bf12b
 
 
 
68f76b5
 
04bf12b
 
 
 
 
 
 
dd42569
04bf12b
68f76b5
04bf12b
 
68f76b5
 
 
 
04bf12b
 
 
 
 
 
20c665f
a9ee52d
04bf12b
68f76b5
04bf12b
 
68f76b5
 
 
 
 
04bf12b
68f76b5
04bf12b
 
 
 
 
68f76b5
04bf12b
 
 
a9ee52d
 
04bf12b
68f76b5
04bf12b
 
 
 
 
a9ee52d
04bf12b
 
 
 
 
 
 
 
 
 
 
 
 
e7ec3c2
04bf12b
e7ec3c2
04bf12b
 
e7ec3c2
 
 
 
 
 
 
 
04bf12b
 
e7ec3c2
04bf12b
e7ec3c2
 
04bf12b
 
 
 
 
 
 
 
 
 
 
 
68f76b5
04bf12b
a9ee52d
04bf12b
 
 
a9ee52d
44b2355
a9ee52d
 
 
 
04bf12b
2ce8475
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a9ee52d
 
04bf12b
 
a9ee52d
04bf12b
 
 
a9ee52d
04bf12b
 
68f76b5
 
04bf12b
 
 
 
 
 
 
 
68f76b5
04bf12b
 
 
 
 
a9ee52d
04bf12b
68f76b5
04bf12b
 
a9ee52d
04bf12b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68f76b5
 
04bf12b
 
68f76b5
04bf12b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68f76b5
 
 
04bf12b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68f76b5
04bf12b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68f76b5
04bf12b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a9ee52d
 
85d497c
04bf12b
a9ee52d
04bf12b
a9ee52d
04bf12b
68f76b5
04bf12b
68f76b5
04bf12b
68f76b5
04bf12b
68f76b5
04bf12b
 
 
 
 
 
a9ee52d
 
 
 
04bf12b
68f76b5
04bf12b
 
2ce8475
04bf12b
a9ee52d
04bf12b
 
 
 
 
 
 
 
 
a9ee52d
 
04bf12b
a9ee52d
dd42569
a9ee52d
04bf12b
a9ee52d
04bf12b
 
a9ee52d
 
 
04bf12b
 
 
 
 
 
 
 
a9ee52d
 
 
dd42569
04bf12b
 
 
a9ee52d
04bf12b
a9ee52d
 
04bf12b
 
 
 
 
a9ee52d
04bf12b
68f76b5
04bf12b
 
a9ee52d
04bf12b
 
 
a9ee52d
04bf12b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2ce8475
04bf12b
 
 
 
 
 
 
 
 
 
 
 
a9ee52d
 
04bf12b
a9ee52d
 
04bf12b
 
 
 
 
 
 
a9ee52d
04bf12b
 
 
 
 
 
 
 
 
 
 
a9ee52d
04bf12b
 
 
a9ee52d
04bf12b
 
a9ee52d
 
04bf12b
 
 
 
 
 
 
a9ee52d
 
44b2355
a9ee52d
 
 
 
 
20c665f
04bf12b
a9ee52d
 
44b2355
 
 
20c665f
 
 
 
 
 
04bf12b
20c665f
 
a9ee52d
04bf12b
f4fb49b
04bf12b
 
 
2ce8475
 
 
 
 
 
 
 
 
 
 
e7ec3c2
f4fb49b
2ce8475
 
a9ee52d
 
 
 
 
44b2355
68f76b5
a9ee52d
04bf12b
44b2355
68f76b5
04bf12b
68f76b5
44b2355
04bf12b
68f76b5
 
a9ee52d
04bf12b
 
 
 
68f76b5
04bf12b
 
 
 
 
68f76b5
04bf12b
 
a9ee52d
04bf12b
 
 
 
 
68f76b5
04bf12b
 
 
 
 
 
 
 
 
 
a9ee52d
44b2355
04bf12b
 
44b2355
a9ee52d
 
 
 
04bf12b
44b2355
a9ee52d
 
 
44b2355
68f76b5
 
04bf12b
 
68f76b5
44b2355
04bf12b
 
 
68f76b5
a9ee52d
 
68f76b5
dd42569
04bf12b
f4fb49b
 
 
 
 
 
 
04bf12b
44b2355
f4fb49b
44b2355
 
04bf12b
 
 
 
44b2355
 
 
 
 
 
 
 
f4fb49b
44b2355
f4fb49b
44b2355
 
a9ee52d
 
68f76b5
 
a9ee52d
 
 
 
68f76b5
04bf12b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
import spaces
import gradio as gr
import torch
import os
import tempfile
import time
import polars as pl
import numpy as np
import logging
from pathlib import Path
from omegaconf import OmegaConf, DictConfig
from gradio_log import Log

# --- InstaNovo Imports ---
try:
    from instanovo.transformer.model import InstaNovo
    from instanovo.diffusion.multinomial_diffusion import InstaNovoPlus
    from instanovo.utils import SpectrumDataFrame, ResidueSet, Metrics
    from instanovo.transformer.dataset import SpectrumDataset, collate_batch
    from instanovo.inference import (
        GreedyDecoder,
        KnapsackBeamSearchDecoder,
        Knapsack,
        ScoredSequence,
        Decoder,
    )
    from instanovo.inference.diffusion import DiffusionDecoder
    from instanovo.constants import (
        MASS_SCALE,
        MAX_MASS,
        DIFFUSION_START_STEP,
    )
    from torch.utils.data import DataLoader
    import torch.nn.functional as F # For padding
except ImportError as e:
    raise ImportError(f"Failed to import InstaNovo components: {e}")

# --- Configuration ---
TRANSFORMER_MODEL_ID = "instanovo-v1.1.0"
DIFFUSION_MODEL_ID = "instanovoplus-v1.1.0-alpha"
KNAPSACK_DIR = Path("./knapsack_cache")

# Determine device; half precision (FP16) is only enabled on CUDA.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
FP16 = DEVICE == "cuda"

# --- Global Variables (Load Models and Knapsack Once) ---
# Populated by load_models_and_knapsack(); None means "not (yet) loaded".
INSTANOVO: InstaNovo | None = None
INSTANOVO_CONFIG: DictConfig | None = None
INSTANOVOPLUS: InstaNovoPlus | None = None
INSTANOVOPLUS_CONFIG: DictConfig | None = None
KNAPSACK: Knapsack | None = None
RESIDUE_SET: ResidueSet | None = None

# --- Assets ---
gr.set_static_paths(paths=[Path.cwd().absolute()/"assets"])

# Create gradio temporary directory.
# exist_ok=True avoids the check-then-create race (the directory may appear
# between an exists() check and mkdir()); parents=True creates missing parents.
temp_dir = Path('/tmp/gradio')
temp_dir.mkdir(parents=True, exist_ok=True)

# Logging configuration
# TODO: create logfile per user/session
# see https://www.gradio.app/guides/resource-cleanup
log_file = "/tmp/instanovo_gradio_log.txt"
Path(log_file).touch()

logger = logging.getLogger("instanovo")
logger.setLevel(logging.INFO)
formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
# Loggers live in a process-wide registry, so re-executing this module (e.g. on a
# reload) would otherwise attach a second handler and duplicate every record.
# Reuse an existing handler that already writes to this file.
_log_file_abspath = os.path.abspath(log_file)
file_handler = next(
    (
        handler
        for handler in logger.handlers
        if isinstance(handler, logging.FileHandler)
        and handler.baseFilename == _log_file_abspath
    ),
    None,
)
if file_handler is None:
    file_handler = logging.FileHandler(log_file)
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)


def load_models_and_knapsack():
    """Load both InstaNovo models and load (or generate) the knapsack.

    Fills the module-level globals ``INSTANOVO``/``INSTANOVO_CONFIG``,
    ``INSTANOVOPLUS``/``INSTANOVOPLUS_CONFIG``, ``RESIDUE_SET`` and
    ``KNAPSACK``. Anything already loaded is kept, so calling this again only
    retries the pieces that are still missing.

    The residue set comes from the Transformer model; the Diffusion model's set
    is only used when no residue set has been set yet.

    Raises:
        gr.Error: if the Transformer model cannot be loaded. Failures of the
            Diffusion model or of knapsack generation are reported with
            ``gr.Warning`` and leave the corresponding global as ``None``.
    """
    global INSTANOVO, KNAPSACK, INSTANOVO_CONFIG, RESIDUE_SET, INSTANOVOPLUS, INSTANOVOPLUS_CONFIG
    models_loaded = INSTANOVO is not None and INSTANOVOPLUS is not None
    if models_loaded:
        logger.info("Models already loaded.")
        if KNAPSACK is not None:
            return  # All loaded
        # Still check knapsack if not loaded
        logger.info("Models loaded, but knapsack needs loading/generation.")

    # --- Load Transformer Model ---
    if INSTANOVO is None:
        logger.info(f"Loading InstaNovo (Transformer) model: {TRANSFORMER_MODEL_ID} to {DEVICE}...")
        try:
            INSTANOVO, INSTANOVO_CONFIG = InstaNovo.from_pretrained(TRANSFORMER_MODEL_ID)
            INSTANOVO.to(DEVICE)
            INSTANOVO.eval()
            RESIDUE_SET = INSTANOVO.residue_set
            logger.info("Transformer model loaded successfully.")
        except Exception as e:
            logger.error(f"Error loading Transformer model: {e}")
            raise gr.Error(f"Failed to load InstaNovo Transformer model: {TRANSFORMER_MODEL_ID}. Error: {e}")
    else:
        logger.info("Transformer model already loaded.")

    # --- Load Diffusion Model ---
    if INSTANOVOPLUS is None:
        logger.info(f"Loading InstaNovo+ (Diffusion) model: {DIFFUSION_MODEL_ID} to {DEVICE}...")
        try:
            INSTANOVOPLUS, INSTANOVOPLUS_CONFIG = InstaNovoPlus.from_pretrained(DIFFUSION_MODEL_ID)
            INSTANOVOPLUS.to(DEVICE)
            INSTANOVOPLUS.eval()
            if RESIDUE_SET is not None and INSTANOVOPLUS.residues != RESIDUE_SET:
                logger.warning("Residue sets between Transformer and Diffusion models differ. Using Transformer's set.")
            elif RESIDUE_SET is None:
                RESIDUE_SET = INSTANOVOPLUS.residues

            logger.info("Diffusion model loaded successfully.")
        except Exception as e:
            # Diffusion is optional: degrade gracefully instead of aborting startup.
            logger.error(f"Error loading Diffusion model: {e}")
            gr.Warning(f"Failed to load InstaNovo+ Diffusion model ({DIFFUSION_MODEL_ID}): {e}. Diffusion modes will be unavailable.")
            INSTANOVOPLUS = None
    else:
        logger.info("Diffusion model already loaded.")

    # --- Knapsack Handling ---
    # Only attempt knapsack loading/generation if the Transformer model is loaded
    if INSTANOVO is not None and RESIDUE_SET is not None and KNAPSACK is None:
        knapsack_exists = all(
            (KNAPSACK_DIR / fname).exists()
            for fname in ("parameters.pkl", "masses.npy", "chart.npy")
        )

        if knapsack_exists:
            logger.info(f"Loading pre-generated knapsack from {KNAPSACK_DIR}...")
            try:
                KNAPSACK = Knapsack.from_file(str(KNAPSACK_DIR))
                logger.info("Knapsack loaded successfully.")
            except Exception as e:
                logger.warning(f"Error loading knapsack: {e}. Will attempt to regenerate.")
                KNAPSACK = None
                knapsack_exists = False

        if not knapsack_exists:
            logger.info("Knapsack not found or failed to load. Generating knapsack...")
            try:
                residue_masses_for_calc = dict(RESIDUE_SET.residue_masses.copy())
                # Special tokens and residues with non-positive mass are excluded
                # from the mass table used to build the knapsack.
                special_and_nonpositive = list(RESIDUE_SET.special_tokens) + [
                    k for k, v in residue_masses_for_calc.items() if v <= 0
                ]
                if special_and_nonpositive:
                    logger.info(f"Excluding special/non-positive mass residues from knapsack: {special_and_nonpositive}")
                    for res in set(special_and_nonpositive):
                        residue_masses_for_calc.pop(res, None)

                full_residue_indices = RESIDUE_SET.residue_to_index

                if not residue_masses_for_calc:  # Check if any residues are left for calculation
                    raise ValueError("No valid residues with positive mass found for knapsack generation.")

                logger.info("Generating knapsack. This will take a few minutes, please be patient.")
                KNAPSACK = Knapsack.construct_knapsack(
                    residue_masses=residue_masses_for_calc,
                    residue_indices=full_residue_indices,
                    max_mass=MAX_MASS,
                    mass_scale=MASS_SCALE,
                )
            except Exception as e:
                logger.error(f"Error generating knapsack: {e}", exc_info=True)
                gr.Warning(f"Failed to generate Knapsack. Knapsack Beam Search will not be available. Error: {e}")
                KNAPSACK = None
            else:
                # Saving is only a cache for the next start. A failed save must not
                # discard the knapsack that was just built and is usable in memory.
                # NOTE(review): if Knapsack.save refuses to write into an existing
                # directory (e.g. a partial cache), it will fail here; confirm.
                logger.info(f"Knapsack generated. Saving to {KNAPSACK_DIR}...")
                try:
                    KNAPSACK.save(str(KNAPSACK_DIR))
                    logger.info("Knapsack saved.")
                except Exception as e:
                    logger.warning(
                        f"Knapsack generated but could not be saved to {KNAPSACK_DIR}: {e}. "
                        "It remains available in memory but may need to be regenerated on the next start.",
                        exc_info=True,
                    )
    elif KNAPSACK is not None:
        logger.info("Knapsack already loaded.")
    elif INSTANOVO is None:
        logger.warning("Transformer model not loaded, skipping Knapsack loading/generation.")


# Load models and knapsack when the script starts
load_models_and_knapsack()


def create_inference_config(
    input_path: str,
    output_path: str,
) -> DictConfig:
    """Build the prediction config for one request.

    Static defaults are merged with the request's input/output paths and the
    runtime ``DEVICE``/``FP16`` settings. The defaults cover model ids,
    decoding and filtering settings, residue remapping, column mapping and
    index columns.

    Args:
        input_path: File to predict on; stored as ``data_path``.
        output_path: Where predictions go; stored as ``output_path``.

    Returns:
        The merged ``DictConfig``, which is also logged as YAML at INFO level.
    """
    # Alternative modification spellings -> UNIMOD notation.
    residue_remapping = {
        "M(ox)": "M[UNIMOD:35]",
        "M(+15.99)": "M[UNIMOD:35]",
        "S(p)": "S[UNIMOD:21]",
        "T(p)": "T[UNIMOD:21]",
        "Y(p)": "Y[UNIMOD:21]",
        "S(+79.97)": "S[UNIMOD:21]",
        "T(+79.97)": "T[UNIMOD:21]",
        "Y(+79.97)": "Y[UNIMOD:21]",
        "Q(+0.98)": "Q[UNIMOD:7]",
        "N(+0.98)": "N[UNIMOD:7]",
        "Q(+.98)": "Q[UNIMOD:7]",
        "N(+.98)": "N[UNIMOD:7]",
        "C(+57.02)": "C[UNIMOD:4]",
        "(+42.01)": "[UNIMOD:1]",
        "(+43.01)": "[UNIMOD:5]",
        "(-17.03)": "[UNIMOD:385]",
    }
    # Input-file column names -> column names used internally.
    column_map = {
        "Modified sequence": "modified_sequence",
        "MS/MS m/z": "precursor_mz",
        "Mass": "precursor_mass",
        "Charge": "precursor_charge",
        "Mass values": "mz_array",
        "Mass spectrum": "mz_array",
        "Intensity": "intensity_array",
        "Raw intensity spectrum": "intensity_array",
        "Scan number": "scan_number",
    }
    index_columns = [
        "scan_number",
        "precursor_mz",
        "precursor_charge",
        "retention_time",
        "spectrum_id",
        "experiment_name",
    ]

    defaults = dict(
        data_path=None,
        instanovo_model=TRANSFORMER_MODEL_ID,
        instanovoplus_model=DIFFUSION_MODEL_ID,
        output_path=None,
        knapsack_path=str(KNAPSACK_DIR),
        denovo=True,
        refine=True,
        num_beams=1,
        max_length=40,
        max_charge=10,
        isotope_error_range=[0, 1],
        subset=1.0,
        use_knapsack=False,
        save_beams=False,
        batch_size=64,
        device=DEVICE,
        fp16=FP16,
        log_interval=500,
        use_basic_logging=True,
        filter_precursor_ppm=20,
        filter_confidence=1e-4,
        filter_fdr_threshold=0.05,
        suppressed_residues=None,
        disable_terminal_residues_anywhere=True,
        residue_remapping=residue_remapping,
        column_map=column_map,
        index_columns=index_columns,
    )
    # Per-request values that take precedence over the defaults.
    request_settings = dict(
        data_path=input_path,
        output_path=output_path,
        device=DEVICE,
        fp16=FP16,
        denovo=True,
    )

    merged = OmegaConf.merge(OmegaConf.create(defaults), request_settings)
    logger.info(f"Created inference config:\n{OmegaConf.to_yaml(merged)}")
    return merged

def _get_transformer_decoder(selection: str, config: DictConfig) -> tuple[Decoder, int, bool]:
    """Instantiate the transformer decoder chosen in the UI.

    Args:
        selection: UI label, matched by substring. A label containing "Greedy"
            selects greedy decoding; otherwise one containing "Knapsack"
            selects knapsack beam search.
        config: Inference config. Only read for the greedy decoder's
            ``suppressed_residues`` and ``disable_terminal_residues_anywhere``.

    Returns:
        ``(decoder, num_beams, use_knapsack)``: either
        ``(GreedyDecoder, 1, False)`` or ``(KnapsackBeamSearchDecoder, 5, True)``.

    Raises:
        gr.Error: if the Transformer model, or for knapsack search the
            knapsack, is not loaded.
        ValueError: if ``selection`` matches neither decoder.
    """
    model = INSTANOVO
    if model is None:
        raise gr.Error("InstaNovo Transformer model not loaded.")

    if "Greedy" in selection:
        chosen: Decoder = GreedyDecoder(
            model=model,
            mass_scale=MASS_SCALE,
            suppressed_residues=config.get("suppressed_residues", None),
            disable_terminal_residues_anywhere=config.get("disable_terminal_residues_anywhere", True),
        )
        beams, knapsack_enabled = 1, False
    elif "Knapsack" in selection:
        if KNAPSACK is None:
            raise gr.Error("Knapsack is not available. Cannot use Knapsack Beam Search.")
        chosen = KnapsackBeamSearchDecoder(model=model, knapsack=KNAPSACK)
        # Default beam size for knapsack search.
        beams, knapsack_enabled = 5, True
    else:
        raise ValueError(f"Unknown transformer decoder selection: {selection}")

    logger.info(f"Using Transformer decoder: {type(chosen).__name__} (Num beams: {beams}, Use Knapsack: {knapsack_enabled})")
    return chosen, beams, knapsack_enabled


def run_transformer_prediction(dl, config, transformer_decoder_selection):
    """Decode every batch of ``dl`` with the selected transformer decoder.

    Args:
        dl: Batch iterable whose items unpack to
            ``(spectra, precursors, spectra_mask, _, _)``. ``len(dl)`` is used
            for progress logging.
        config: Inference config providing ``max_length``,
            ``isotope_error_range`` and optionally ``filter_precursor_ppm``
            (ppm; defaults to 20).
        transformer_decoder_selection: UI label forwarded to
            ``_get_transformer_decoder``.

    Returns:
        The decoder outputs for all batches, concatenated in order; only the
        top result per spectrum is requested (``return_beam=False``).

    Raises:
        gr.Error: if the residue set is not loaded, or propagated from
            ``_get_transformer_decoder``.
    """
    if RESIDUE_SET is None:
        raise gr.Error("ResidueSet not loaded.")

    decoder, beam_size, _ = _get_transformer_decoder(transformer_decoder_selection, config)

    predictions: list[ScoredSequence | list] = []
    t0 = time.time()
    for batch_no, (spectra, precursors, spectra_mask, _, _) in enumerate(dl, start=1):
        spectra, precursors, spectra_mask = (
            tensor.to(DEVICE) for tensor in (spectra, precursors, spectra_mask)
        )

        with torch.no_grad(), torch.amp.autocast(DEVICE, dtype=torch.float16, enabled=FP16):
            decoded = decoder.decode(
                spectra=spectra,
                precursors=precursors,
                beam_size=beam_size,
                max_length=config.max_length,
                # ppm -> relative tolerance
                mass_tolerance=config.get("filter_precursor_ppm", 20) * 1e-6,
                max_isotope=config.isotope_error_range[1] if config.isotope_error_range else 1,
                return_beam=False,  # Only top result
            )
        predictions.extend(decoded)

        # Progress every 10 batches and on the final batch.
        if batch_no % 10 == 0 or batch_no == len(dl):
            logger.info(f"Transformer prediction: Processed batch {batch_no}/{len(dl)}")

    elapsed = time.time() - t0
    logger.info(f"Transformer prediction finished in {elapsed:.2f} seconds.")
    return predictions

def run_diffusion_prediction(dl, config):
    """Predict peptides de novo with the InstaNovo+ diffusion model alone.

    Args:
        dl: Batch iterable whose items unpack to
            ``(spectra, precursors, spectra_mask, _, _)``. ``len(dl)`` is used
            for progress logging.
        config: Inference config; ``isotope_error_range`` is used for the
            precursor mass check.

    Returns:
        One ``ScoredSequence`` per spectrum, in input order. ``mass_error`` is
        the smallest absolute precursor mass error (ppm) over the isotope
        range, or NaN if it cannot be computed. ``token_log_probabilities`` is
        left empty. Returns ``[]`` when ``dl`` yields no batches.

    Raises:
        gr.Error: if the diffusion model or the residue set is not loaded.
    """
    if INSTANOVOPLUS is None or RESIDUE_SET is None:
        raise gr.Error("InstaNovo+ Diffusion model not loaded.")

    diffusion_decoder = DiffusionDecoder(model=INSTANOVOPLUS)
    logger.info(f"Using decoder: {type(diffusion_decoder).__name__}")

    results_sequences = []
    results_log_probs = []
    # Keep only the (CPU) precursor tensors needed for scoring, rather than
    # holding every batch (spectra included) in memory.
    precursor_chunks = []
    start_time = time.time()

    for i, batch in enumerate(dl):
        spectra, precursors, spectra_mask, _, _ = batch
        precursor_chunks.append(precursors)
        spectra = spectra.to(DEVICE)
        precursors = precursors.to(DEVICE)
        spectra_mask = spectra_mask.to(DEVICE)

        with torch.no_grad(), torch.amp.autocast(DEVICE, dtype=torch.float16, enabled=FP16):
            batch_sequences, batch_log_probs = diffusion_decoder.decode(
                spectra=spectra,
                spectra_padding_mask=spectra_mask,
                precursors=precursors,
                initial_sequence=None,
            )
        results_sequences.extend(batch_sequences)
        results_log_probs.extend(batch_log_probs)
        if (i + 1) % 10 == 0 or (i + 1) == len(dl):
            logger.info(f"Diffusion prediction: Processed batch {i+1}/{len(dl)}")

    end_time = time.time()
    logger.info(f"Diffusion prediction finished in {end_time - start_time:.2f} seconds.")

    if not precursor_chunks:
        # torch.cat would fail on an empty list; there is simply nothing to score.
        return []

    # Precursor columns as stacked by instanovo's collate_batch:
    # 0 = mass, 1 = charge, 2 = m/z.
    # NOTE(review): confirm if the collate function changes.
    charge_col, mz_col = 1, 2
    all_precursors = torch.cat(precursor_chunks, dim=0)
    metrics_calc = Metrics(RESIDUE_SET, config.isotope_error_range)

    scored_results = []
    for idx, (seq, logp) in enumerate(zip(results_sequences, results_log_probs)):
        prec_mz = all_precursors[idx, mz_col].item()
        prec_ch = int(all_precursors[idx, charge_col].item())
        try:
            _, delta_mass_list = metrics_calc.matches_precursor(seq, prec_mz, prec_ch)
            min_abs_ppm = min(abs(p) for p in delta_mass_list) if delta_mass_list else float("nan")
        except Exception as e:
            logger.warning(f"Could not calculate delta mass for diffusion prediction {idx}: {e}")
            min_abs_ppm = float("nan")

        scored_results.append(
            ScoredSequence(sequence=seq, mass_error=min_abs_ppm, sequence_log_probability=logp, token_log_probabilities=[])
        )

    return scored_results


def _encode_transformer_preds_for_diffusion(transformer_results, residue_set, max_len):
    """Encode transformer predictions as fixed-length index tensors to seed diffusion.

    Args:
        transformer_results: One decoder result per spectrum. Anything that is
            not a ``ScoredSequence`` with a non-empty ``sequence`` is treated
            as a failed prediction and becomes an all-PAD sequence.
        residue_set: ResidueSet used for ``encode`` and ``PAD_INDEX``.
        max_len: Target length. Longer encodings are truncated; shorter ones
            are right-padded with ``PAD_INDEX``.

    Returns:
        A list of tensors of length ``max_len``, in the same order as
        ``transformer_results``.
    """
    encoded_preds = []
    for res in transformer_results:
        if isinstance(res, ScoredSequence) and res.sequence:
            # Encode sequence *without* EOS for diffusion input.
            encoded = residue_set.encode(res.sequence, add_eos=False, return_tensor='pt')
        else:
            # If transformer failed, provide a dummy PAD sequence
            encoded = torch.full((max_len,), residue_set.PAD_INDEX, dtype=torch.long)

        # Pad or truncate to the diffusion model's max length
        current_len = encoded.shape[0]
        if current_len > max_len:
            logger.warning(f"Transformer prediction exceeded diffusion max length ({max_len}). Truncating.")
            encoded = encoded[:max_len]
        elif current_len < max_len:
            padding = torch.full((max_len - current_len,), residue_set.PAD_INDEX, dtype=torch.long)
            encoded = torch.cat((encoded, padding))

        encoded_preds.append(encoded)
    return encoded_preds


def run_refinement_prediction(dl, config, transformer_decoder_selection):
    """Predict with the transformer, then refine the predictions with the diffusion model.

    Args:
        dl: Batch iterable whose items unpack to
            ``(spectra, precursors, spectra_mask, _, _)``. It is materialised
            once and iterated twice: the transformer pass, then the diffusion
            pass.
        config: Inference config providing ``max_length``,
            ``isotope_error_range`` and optionally ``filter_precursor_ppm``.
        transformer_decoder_selection: UI label forwarded to
            ``_get_transformer_decoder``.

    Returns:
        One dict per spectrum, in input order, with keys
        ``transformer_prediction``, ``transformer_log_probability``,
        ``refined_prediction``, ``refined_log_probability`` and
        ``refined_delta_mass_ppm``.

    Raises:
        gr.Error: if a required model or resource is not loaded, if the
            transformer produced nothing to refine, or if the number of
            transformer predictions does not line up with the spectra.
    """
    missing = [
        name
        for name, value in (
            ("Transformer model", INSTANOVO),
            ("Diffusion model", INSTANOVOPLUS),
            ("ResidueSet", RESIDUE_SET),
            ("Diffusion config", INSTANOVOPLUS_CONFIG),
        )
        if value is None
    ]
    if missing:
        raise gr.Error(f"Cannot run refinement: {', '.join(missing)} not loaded.")

    # 1. Run Transformer Prediction (using selected decoder)
    logger.info(f"Running Transformer prediction ({transformer_decoder_selection}) for refinement...")
    transformer_decoder, num_beams, _ = _get_transformer_decoder(transformer_decoder_selection, config)
    transformer_results_list: list[ScoredSequence | list] = []

    all_batches = list(dl)  # Iterated twice below (transformer pass, then diffusion pass)

    start_time_transformer = time.time()
    for i, batch in enumerate(all_batches):
        spectra, precursors, spectra_mask, _, _ = batch
        spectra = spectra.to(DEVICE)
        precursors = precursors.to(DEVICE)
        spectra_mask = spectra_mask.to(DEVICE)

        with torch.no_grad(), torch.amp.autocast(DEVICE, dtype=torch.float16, enabled=FP16):
            batch_predictions = transformer_decoder.decode(
                spectra=spectra,
                precursors=precursors,
                beam_size=num_beams,  # Use selected beam size
                max_length=config.max_length,
                mass_tolerance=config.get("filter_precursor_ppm", 20) * 1e-6,
                max_isotope=config.isotope_error_range[1] if config.isotope_error_range else 1,
                return_beam=False,  # Only top result needed for refinement
            )
        transformer_results_list.extend(batch_predictions)
        if (i + 1) % 10 == 0 or (i + 1) == len(all_batches):
            logger.info(f"Refinement (Transformer): Processed batch {i+1}/{len(all_batches)}")

    logger.info(f"Transformer prediction for refinement finished in {time.time() - start_time_transformer:.2f} seconds.")

    # 2. Prepare Transformer Predictions as Initial Sequences for Diffusion
    logger.info("Encoding transformer predictions for diffusion input...")
    max_len_diffusion = INSTANOVOPLUS_CONFIG.get("max_length", 40)
    encoded_transformer_preds = _encode_transformer_preds_for_diffusion(
        transformer_results_list, RESIDUE_SET, max_len_diffusion
    )
    if not encoded_transformer_preds:
        raise gr.Error("Transformer prediction yielded no results to refine.")
    encoded_transformer_preds_tensor = torch.stack(encoded_transformer_preds).to(DEVICE)
    logger.info(f"Encoded {encoded_transformer_preds_tensor.shape[0]} sequences for diffusion.")

    # 3. Run Diffusion Refinement
    logger.info("Running Diffusion refinement...")
    diffusion_decoder = DiffusionDecoder(model=INSTANOVOPLUS)
    refined_sequences = []
    refined_log_probs = []
    start_time_diffusion = time.time()

    current_idx = 0
    for i, batch in enumerate(all_batches):
        spectra, precursors, spectra_mask, _, _ = batch
        spectra = spectra.to(DEVICE)
        precursors = precursors.to(DEVICE)
        spectra_mask = spectra_mask.to(DEVICE)

        batch_size = spectra.shape[0]
        initial_sequences_batch = encoded_transformer_preds_tensor[current_idx : current_idx + batch_size]
        current_idx += batch_size

        if initial_sequences_batch.shape[0] != batch_size:
            # Results are paired by position in step 4. Skipping a batch would
            # shift every later refined prediction out of alignment with its
            # transformer prediction and precursor, so fail loudly instead.
            message = (
                f"Batch size mismatch during refinement: expected {batch_size}, "
                f"got {initial_sequences_batch.shape[0]}"
            )
            logger.error(message)
            raise gr.Error(message)

        with torch.no_grad(), torch.amp.autocast(DEVICE, dtype=torch.float16, enabled=FP16):
            batch_refined_seqs, batch_refined_logp = diffusion_decoder.decode(
                spectra=spectra,
                spectra_padding_mask=spectra_mask,
                precursors=precursors,
                initial_sequence=initial_sequences_batch,
                start_step=DIFFUSION_START_STEP,
            )
        refined_sequences.extend(batch_refined_seqs)
        refined_log_probs.extend(batch_refined_logp)
        if (i + 1) % 10 == 0 or (i + 1) == len(all_batches):
            logger.info(f"Refinement (Diffusion): Processed batch {i+1}/{len(all_batches)}")

    logger.info(f"Diffusion refinement finished in {time.time() - start_time_diffusion:.2f} seconds.")

    # 4. Combine and Format Results
    # Precursor columns as stacked by instanovo's collate_batch:
    # 0 = mass, 1 = charge, 2 = m/z.
    # NOTE(review): confirm if the collate function changes.
    charge_col, mz_col = 1, 2
    all_precursors = torch.cat([b[1] for b in all_batches], dim=0)  # b[1] is precursors
    metrics_calc = Metrics(RESIDUE_SET, config.isotope_error_range)
    combined_results = []
    for idx, (transformer_res, refined_seq, refined_logp) in enumerate(zip(transformer_results_list, refined_sequences, refined_log_probs)):
        prec_mz = all_precursors[idx, mz_col].item()
        prec_ch = int(all_precursors[idx, charge_col].item())
        try:
            _, delta_mass_list = metrics_calc.matches_precursor(refined_seq, prec_mz, prec_ch)
            min_abs_ppm = min(abs(p) for p in delta_mass_list) if delta_mass_list else float("nan")
        except Exception as e:
            logger.warning(f"Could not calculate delta mass for refined prediction {idx}: {e}")
            min_abs_ppm = float("nan")

        is_scored = isinstance(transformer_res, ScoredSequence)
        combined_results.append({
            "transformer_prediction": "".join(transformer_res.sequence) if is_scored else "",
            "transformer_log_probability": transformer_res.sequence_log_probability if is_scored else float('-inf'),
            "refined_prediction": "".join(refined_seq),
            "refined_log_probability": refined_logp,
            "refined_delta_mass_ppm": min_abs_ppm,
        })

    return combined_results


@spaces.GPU
def predict_peptides(input_file, mode_selection, transformer_decoder_selection):
    """Load a spectrum file, run the selected prediction pipeline and save the results.

    This is the Gradio callback wired to the "Predict Sequences" button.

    Args:
        input_file: Uploaded file object from ``gr.File``; only its ``.name``
            (path on disk) is used. ``None`` when nothing was uploaded.
        mode_selection: Label of the selected prediction mode. Dispatch is by
            substring: "InstaNovo Only", "InstaNovo+ Only" or "refinement".
        transformer_decoder_selection: Label of the selected Transformer decoder.
            A label containing "Knapsack" requires the knapsack to be loaded.
            Not used for the Transformer in "InstaNovo+ Only" mode.

    Returns:
        Tuple ``(preview_df, csv_path)``. ``preview_df`` is a pandas DataFrame
        restricted to the available index columns plus the mode's prediction
        columns. ``csv_path`` is the path of a temporary CSV holding the full
        results.

    Raises:
        gr.Error: If no file/mode is given, a required model or the knapsack
            cannot be loaded, no valid spectra remain after filtering, or any
            step of the prediction fails. The temporary CSV is removed on failure.
    """
    # A hidden/cleared Radio may deliver None; normalise so the substring checks
    # below cannot raise TypeError.
    mode_selection = mode_selection or ""
    transformer_decoder_selection = transformer_decoder_selection or ""

    # Validate cheap user inputs before (re)loading any model.
    if input_file is None:
        raise gr.Error("Please upload a mass spectrometry file.")
    if not mode_selection:
        raise gr.Error("Please select a prediction mode.")

    # Which components does the selected pipeline need?
    uses_transformer = "InstaNovo+ Only" not in mode_selection
    uses_diffusion = "refinement" in mode_selection or "InstaNovo+" in mode_selection
    uses_knapsack = uses_transformer and "Knapsack" in transformer_decoder_selection

    # Ensure required models are loaded; retry the loader once before giving up.
    if INSTANOVO is None or RESIDUE_SET is None:
        load_models_and_knapsack()  # Try reload
        if INSTANOVO is None:
            raise gr.Error("InstaNovo Transformer model failed to load. Cannot perform prediction.")
    if uses_diffusion and INSTANOVOPLUS is None:
        load_models_and_knapsack()  # Try reload diffusion
        if INSTANOVOPLUS is None:
            raise gr.Error("InstaNovo+ Diffusion model failed to load. Cannot perform Refinement or InstaNovo+ Only prediction.")
    if uses_knapsack and KNAPSACK is None:
        load_models_and_knapsack()  # Try reload knapsack
        if KNAPSACK is None:
            raise gr.Error("Knapsack failed to load. Cannot use Knapsack Beam Search.")

    input_path = input_file.name
    logger.info("--- New Prediction Request ---")
    logger.info(f"Input File: {input_path}")
    logger.info(f"Selected Mode: {mode_selection}")
    if uses_transformer:
        logger.info(f"Selected Transformer Decoder: {transformer_decoder_selection}")

    # Create the temp output file up front (delete=False) so it survives the
    # `with` and can be handed back to Gradio for download.
    gradio_tmp_dir = os.environ.get("GRADIO_TEMP_DIR", "/tmp")
    try:
        with tempfile.NamedTemporaryFile(dir=gradio_tmp_dir, delete=False, suffix=".csv") as temp_out:
            output_csv_path = temp_out.name
        logger.info(f"Temporary output path: {output_csv_path}")
    except Exception as e:
        logger.error(f"Failed to create temporary file in {gradio_tmp_dir}: {e}")
        raise gr.Error(f"Failed to create temporary output file: {e}")

    def _discard_output_file():
        # Best-effort removal of the temporary CSV after a failed request.
        if not os.path.exists(output_csv_path):
            return
        try:
            os.remove(output_csv_path)
            logger.info(f"Removed temporary file {output_csv_path}")
        except OSError:
            logger.error(f"Failed to remove temporary file {output_csv_path}")

    try:
        config = create_inference_config(input_path, output_csv_path)

        logger.info("Loading spectrum data...")
        try:
            # Load data eagerly and keep file order (shuffle=False); results are
            # later combined with the index columns by position.
            sdf = SpectrumDataFrame.load(
                config.data_path, lazy=False, is_annotated=False,
                column_mapping=config.get("column_map", None), shuffle=False, verbose=True,
            )
            original_size = len(sdf)
            max_charge = config.get("max_charge", 10)
            if "precursor_charge" in sdf.df.columns:
                # Keep only spectra with a charge in (0, max_charge].
                sdf.filter_rows(
                    lambda row: ("precursor_charge" in row and row["precursor_charge"] is not None and 0 < row["precursor_charge"] <= max_charge)
                )
                if len(sdf) < original_size:
                    logger.warning(f"Filtered {original_size - len(sdf)} spectra with invalid or out-of-range charge (<=0 or >{max_charge}).")
            else:
                logger.warning("Column 'precursor_charge' not found. Cannot filter by charge.")

            if len(sdf) == 0:
                raise gr.Error("No valid spectra found in the uploaded file after filtering.")
            logger.info(f"Data loaded: {len(sdf)} spectra.")
            index_cols_present = [col for col in config.index_columns if col in sdf.df.columns]
            base_df_pd = sdf.df.select(index_cols_present).to_pandas()

        except gr.Error:
            raise  # Already a user-facing message; do not wrap it again.
        except Exception as e:
            logger.error(f"Error loading data: {e}", exc_info=True)
            raise gr.Error(f"Failed to load or process the spectrum file. Error: {e}")

        if RESIDUE_SET is None:  # Should not happen if the model loaded.
            raise gr.Error("Residue set not loaded.")

        # --- Prepare DataLoader ---
        # reverse_peptide=True for Transformer steps, False for Diffusion-only;
        # `diffusion` signals that the inputs are for the diffusion model.
        ds = SpectrumDataset(
            sdf, RESIDUE_SET,
            INSTANOVO_CONFIG.get("n_peaks", 200) if INSTANOVO_CONFIG else 200,
            return_str=True, annotated=False,
            pad_spectrum_max_length=config.get("compile_model", False) or config.get("use_flash_attention", False),
            bin_spectra=config.get("conv_peak_encoder", False),
            peptide_pad_length=config.get("max_length", 40) if config.get("compile_model", False) else 0,
            reverse_peptide=uses_transformer,
            diffusion=not uses_transformer,
        )
        dl = DataLoader(ds, batch_size=config.batch_size, num_workers=0, shuffle=False, collate_fn=collate_batch)

        # --- Run Prediction ---
        results_data = None
        output_headers = index_cols_present[:]

        if "InstaNovo Only" in mode_selection:
            output_headers.extend(["prediction", "log_probability", "delta_mass_ppm", "token_log_probabilities"])
            transformer_results = run_transformer_prediction(dl, config, transformer_decoder_selection)
            results_data = []
            metrics_calc = Metrics(RESIDUE_SET, config.isotope_error_range)
            for i, res in enumerate(transformer_results):
                row_data = {}
                if isinstance(res, ScoredSequence) and res.sequence:
                    row_data["prediction"] = "".join(res.sequence)
                    row_data["log_probability"] = f"{res.sequence_log_probability:.4f}"
                    row_data["token_log_probabilities"] = ", ".join(f"{p:.4f}" for p in res.token_log_probabilities)
                    try:
                        prec_mz = base_df_pd.loc[i, "precursor_mz"]
                        prec_ch = base_df_pd.loc[i, "precursor_charge"]
                        _, delta_mass_list = metrics_calc.matches_precursor(res.sequence, prec_mz, prec_ch)
                        min_abs_ppm = min(abs(p) for p in delta_mass_list) if delta_mass_list else float("nan")
                        # "N/A" for a missing value, consistent with the other modes.
                        row_data["delta_mass_ppm"] = "N/A" if np.isnan(min_abs_ppm) else f"{min_abs_ppm:.2f}"
                    except Exception as e:
                        logger.warning(f"Could not calculate delta mass for Tx prediction {i}: {e}")
                        row_data["delta_mass_ppm"] = "N/A"
                else:
                    row_data.update({k: "N/A" for k in ["prediction", "log_probability", "delta_mass_ppm", "token_log_probabilities"]})
                    row_data["prediction"] = ""  # Ensure empty string for failed preds
                    row_data["token_log_probabilities"] = ""
                results_data.append(row_data)

        elif "InstaNovo+ Only" in mode_selection:
            output_headers.extend(["prediction", "log_probability", "delta_mass_ppm"])
            diffusion_results = run_diffusion_prediction(dl, config)
            results_data = []
            for res in diffusion_results:
                row_data = {}
                if isinstance(res, ScoredSequence) and res.sequence:
                    row_data["prediction"] = "".join(res.sequence)
                    row_data["log_probability"] = f"{res.sequence_log_probability:.4f}"  # Avg loss
                    row_data["delta_mass_ppm"] = f"{res.mass_error:.2f}" if not np.isnan(res.mass_error) else "N/A"  # ppm
                else:
                    row_data.update({k: "N/A" for k in ["prediction", "log_probability", "delta_mass_ppm"]})
                    row_data["prediction"] = ""
                results_data.append(row_data)

        elif "refinement" in mode_selection:
            output_headers.extend([
                "transformer_prediction", "transformer_log_probability",
                "refined_prediction", "refined_log_probability", "refined_delta_mass_ppm"
            ])
            # Pass the selected transformer decoder to the refinement function.
            results_data = run_refinement_prediction(dl, config, transformer_decoder_selection)
            for row in results_data:
                # Format numbers after getting the list of dicts.
                row["transformer_log_probability"] = f"{row['transformer_log_probability']:.4f}" if isinstance(row['transformer_log_probability'], (float, int)) else "N/A"
                row["refined_log_probability"] = f"{row['refined_log_probability']:.4f}" if isinstance(row['refined_log_probability'], (float, int)) else "N/A"
                row["refined_delta_mass_ppm"] = f"{row['refined_delta_mass_ppm']:.2f}" if isinstance(row['refined_delta_mass_ppm'], (float, int)) and not np.isnan(row['refined_delta_mass_ppm']) else "N/A"

        else:
            raise gr.Error(f"Unknown mode selection: {mode_selection}")

        # --- Combine, Save, Return ---
        logger.info("Combining results...")
        if results_data is None:
            raise gr.Error("Prediction did not produce results.")

        results_df = pl.DataFrame(results_data)
        base_df_pl = pl.from_pandas(base_df_pd.reset_index(drop=True))

        # Horizontal concat relies on the dataloader preserving order (shuffle=False).
        if len(base_df_pl) == len(results_df):
            final_df = pl.concat([base_df_pl, results_df], how="horizontal")
        else:
            logger.error(f"Length mismatch between base data ({len(base_df_pl)}) and results ({len(results_df)}). Cannot reliably combine.")
            final_df = results_df  # Display only results in case of mismatch

        logger.info(f"Saving full results to {output_csv_path}...")
        final_df.write_csv(output_csv_path)
        logger.info("Save complete.")

        # Select display columns that actually exist in final_df.
        display_cols_final = [col for col in output_headers if col in final_df.columns]
        display_df = final_df.select(display_cols_final)

        logger.info("--- Prediction Request Complete ---")
        return display_df.to_pandas(), output_csv_path

    except gr.Error as e:
        # User-facing errors raised above: clean up, then propagate unchanged.
        logger.error(f"Prediction aborted: {e}")
        _discard_output_file()
        raise
    except Exception as e:
        logger.error(f"An error occurred during prediction: {e}", exc_info=True)
        _discard_output_file()
        raise gr.Error(f"Prediction failed: {e}")


# --- Gradio Interface ---
# Custom stylesheet passed to gr.Blocks(css=...): page font, black buttons,
# hidden footer, spacing under the logo and muted text for the notes block
# (the Markdown component tagged with elem_classes="feedback").
css = (
    "\n"
    ".gradio-container { font-family: sans-serif; }\n"
    ".gr-button { color: white; border-color: black; background: black; }\n"
    "footer { display: none !important; }\n"
    ".logo-container img { margin-bottom: 1rem; }\n"
    ".feedback { font-size: 0.9rem; color: gray; }\n"
)

with gr.Blocks(
    css=css, theme=gr.themes.Default(primary_hue="blue", secondary_hue="blue")
) as demo:
    # Centered InstaNovo logo served from the local assets directory.
    gr.Markdown(
        """
        <div style="text-align: center;" class="logo-container">
          <img src='/gradio_api/file=assets/instanovo.svg' alt="InstaNovo Logo" width="300" style="display: block; margin: 0 auto;">
        </div>
        """,
        elem_classes="logo-container",
    )

    # Introduction and usage notes; model versions are interpolated from
    # TRANSFORMER_MODEL_ID and DIFFUSION_MODEL_ID.
    gr.Markdown(
        f"""
        # 🚀 _De Novo_ Peptide Sequencing with InstaNovo and InstaNovo+
        Upload your mass spectrometry data file (.mgf, .mzml, or .mzxml) and get peptide sequence predictions.
        Choose your prediction method and decoding options.

        **Notes:**
        *   Predictions use version `{TRANSFORMER_MODEL_ID}` for the transformer-based InstaNovo model and version `{DIFFUSION_MODEL_ID}` for the diffusion-based InstaNovo+ model.
        *   The InstaNovo+ model `{DIFFUSION_MODEL_ID}` is an alpha release.
        *   **Prediction Modes:**
            *   **InstaNovo with InstaNovo+ refinement:** Runs initial prediction with the selected Transformer method (Greedy/Knapsack), then refines using InstaNovo+.
            *   **InstaNovo Only:** Uses only the Transformer with the selected decoding method.
            *   **InstaNovo+ Only:** Predicts directly using the Diffusion model (alpha release).
        *   **Transformer Decoding Methods:**
            *   **Greedy Search:** Use this for the fastest predictions; its performance is similar to Knapsack Beam Search at 5% FDR.
            *   **Knapsack Beam Search:** Use this for the best results and highest peptide recall, but it is about 10x slower than Greedy Search.
        *   Check the logs for progress, especially for large files or slower methods.

        This Hugging Face Space is powered by [ZeroGPU](https://huggingface.co/docs/hub/en/spaces-zerogpu), which is free but **limited to 5 minutes per day per user**—so if you test with your own files, please use only small files.
        """,
        elem_classes="feedback"
    )
    with gr.Row():
        with gr.Column(scale=1):
            input_file = gr.File(
                label="Upload Mass Spectrometry File (.mgf, .mzml, .mzxml)",
                file_types=[".mgf", ".mzml", ".mzxml"],
                scale=1
            )
            mode_selection = gr.Radio(
                [
                    "InstaNovo with InstaNovo+ refinement (Default, Recommended)",
                    "InstaNovo Only (Transformer)",
                    "InstaNovo+ Only (Diffusion, Alpha release)",
                ],
                label="Prediction Mode",
                value="InstaNovo with InstaNovo+ refinement (Default, Recommended)",
                scale=1
            )
            # Transformer decoder selection - visible for modes that use the Transformer.
            transformer_decoder_selection = gr.Radio(
                [
                    "Greedy Search (Fast)",
                    "Knapsack Beam Search (Accurate, Slower)"
                ],
                label="Transformer Decoding Method",
                value="Greedy Search (Fast)",
                visible=True,  # Start visible as the default mode uses it
                interactive=True,
                scale=1
            )

            submit_btn = gr.Button("Predict Sequences", variant="primary")

            # --- Control Visibility & Choices ---
            def update_transformer_options(mode):
                """Show the decoder selector only for Transformer-based modes and reset it to Greedy.

                Args:
                    mode: Currently selected prediction-mode label (may be None).

                Returns:
                    A ``gr.update`` for the decoder Radio.
                """
                show_decoder = "InstaNovo+ Only" not in (mode or "")
                choices = ["Greedy Search (Fast)", "Knapsack Beam Search (Accurate, Slower)"]
                current_value = "Greedy Search (Fast)"  # Default reset value
                return gr.update(visible=show_decoder, choices=choices, value=current_value)

            mode_selection.change(
                fn=update_transformer_options,
                inputs=mode_selection,
                outputs=transformer_decoder_selection,
            )

        with gr.Column(scale=2):
            output_df = gr.DataFrame(
                label="Prediction Results Preview",
                headers=["scan_number", "prediction", "log_probability", "delta_mass_ppm"]
            )
            output_file = gr.File(label="Download Full Results (CSV)")

    submit_btn.click(
        predict_peptides,
        inputs=[input_file, mode_selection, transformer_decoder_selection],
        outputs=[output_df, output_file],
    )

    gr.Examples(
        [
            ["assets/sample_spectra.mgf", "InstaNovo with InstaNovo+ refinement (Default, Recommended)", "Greedy Search (Fast)"],
            ["assets/sample_spectra.mgf", "InstaNovo with InstaNovo+ refinement (Default, Recommended)", "Knapsack Beam Search (Accurate, Slower)"],
            ["assets/sample_spectra.mgf", "InstaNovo Only (Transformer)", "Greedy Search (Fast)"],
            ["assets/sample_spectra.mgf", "InstaNovo Only (Transformer)", "Knapsack Beam Search (Accurate, Slower)"],
            # The decoder is not used in InstaNovo+ Only mode, but the example must
            # still supply one of the Radio's valid choices.
            ["assets/sample_spectra.mgf", "InstaNovo+ Only (Diffusion, Alpha release)", "Greedy Search (Fast)"],
        ],
        inputs=[input_file, mode_selection, transformer_decoder_selection],
        # outputs=[output_df, output_file],
        cache_examples=False,
        label="Example Usage:",
    )

    with gr.Accordion("Application Logs", open=True):
        log_display = Log(log_file, dark=True, height=300)

    gr.Markdown(""" **Links:**
        * [GitHub Repository for InstaNovo](https://github.com/instadeepai/instanovo)
        * [InstaNovo enables diffusion-powered de novo peptide sequencing in large-scale proteomics experiments](https://www.nature.com/articles/s42256-025-01019-5), Eloff, Kalogeropoulos et al. 2025, Nature Machine Intelligence.
        
        If you use InstaNovo in your research, please cite:""")

    gr.Markdown(
        value="""
```
@article{eloff_kalogeropoulos_2025_instanovo,
	title        = {InstaNovo enables diffusion-powered de novo peptide sequencing in large-scale proteomics experiments},
	author       = {Kevin Eloff and Konstantinos Kalogeropoulos and Amandla Mabona and Oliver Morell and Rachel Catzel and
                    Esperanza Rivera-de-Torre and Jakob Berg Jespersen and Wesley Williams and Sam P. B. van Beljouw and
                    Marcin J. Skwark and Andreas Hougaard Laustsen and Stan J. J. Brouns and Anne Ljungars and Erwin M.
                    Schoof and Jeroen Van Goey and Ulrich auf dem Keller and Karim Beguir and Nicolas Lopez Carranza and
                    Timothy P. Jenkins},
	year         = 2025,
	month        = {Mar},
	day          = 31,
	journal      = {Nature Machine Intelligence},
	doi          = {10.1038/s42256-025-01019-5},
	url          = {https://www.nature.com/articles/s42256-025-01019-5}
}
```
""",
        show_copy_button=True
    )

# --- Launch the App ---
if __name__ == "__main__":
    # Enable Gradio's request queue with a default concurrency limit of 5 per event.
    # See https://www.gradio.app/guides/setting-up-a-demo-for-maximum-performance
    demo.queue(default_concurrency_limit=5)

    # On Hugging Face Spaces a plain launch is usually sufficient.
    # Alternatives for local runs:
    #   demo.launch(server_name="0.0.0.0", server_port=7860)  # reachable from the network
    #   demo.launch(share=True)                               # temporary public URL
    launch_options = {"debug": True, "show_error": True}
    demo.launch(**launch_options)