import gradio as gr
from random import randint
from all_models import models
from externalmod import gr_Interface_load, randomize_seed

import asyncio
import os
import uuid
from threading import RLock

lock = RLock()
# os.environ.get already returns None when the variable is unset.
HF_TOKEN = os.environ.get("HF_TOKEN")

def load_fn(models):
    """Load each model once and cache the handle in the global models_load dict."""
    global models_load
    models_load = {}

    for model in models:
        if model not in models_load:  # skip duplicate entries in the model list
            try:
                print(f"Loading model: {model}")
                m = gr_Interface_load(f'models/{model}', hf_token=HF_TOKEN)
                if not hasattr(m, 'predict'):
                    raise ValueError(f"Model {model} does not have a 'predict' method.")
                print(f"Loaded model: {model}")
            except Exception as error:
                print(f"Error loading model {model}: {error}")
                m = None  # Store None so later lookups know this model failed to load.
            models_load[model] = m
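# gr_Interface_load comes from this Space's local externalmod helper; the code above
# assumes it returns an object exposing .predict(prompt, seed=...), which is what the
# hasattr check guards against.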

print("Loading models...")
load_fn(models)
print("Models loaded successfully.")

num_models = 6                        # number of side-by-side image slots
default_models = models[:num_models]
inference_timeout = 600               # per-model timeout, in seconds
MAX_SEED = 3999999999
starting_seed = randint(1941, 2024)
print(f"Starting seed: {starting_seed}")

def extend_choices(choices):
    """Pad (or trim) the selection to exactly num_models entries, filling with 'NA'."""
    return choices[:num_models] + (num_models - len(choices)) * ['NA']
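# For example, with num_models = 6:
#   extend_choices(['a', 'b'])  ->  ['a', 'b', 'NA', 'NA', 'NA', 'NA']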

def update_imgbox(choices):
    """Return one gr.Image per slot, hiding slots whose choice is the 'NA' filler."""
    choices_plus = extend_choices(choices)
    return [gr.Image(None, label=m, visible=(m != 'NA')) for m in choices_plus]
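# Gradio maps the returned list element-wise onto the output components, so each
# gr.Image slot is shown or hidden to match the padded selection.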

async def infer(model_str, prompt, seed=1, timeout=inference_timeout):
    if model_str not in models_load or models_load[model_str] is None:
        print(f"Model {model_str} is not available.")
        return None

    model = models_load[model_str]
    kwargs = {"seed": seed}

    print(f"Starting inference: {model_str} | Prompt: '{prompt}' | Seed: {seed}")

    # Run the blocking predict call in a worker thread so it can be awaited with a timeout.
    task = asyncio.create_task(asyncio.to_thread(model.predict, prompt, **kwargs))
    try:
        result = await asyncio.wait_for(task, timeout=timeout)
    except Exception as e:
        print(f"Error during inference: {e}")
        if not task.done():
            task.cancel()
        return None

    if result:
        # Save to a unique path so concurrent generations cannot overwrite each other.
        png_path = f"image_{uuid.uuid4().hex}.png"
        with lock:
            result.save(png_path)
        return png_path
    return None
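# Note: asyncio.wait_for cancels the pending task on timeout, but a task wrapping
# asyncio.to_thread cannot interrupt a thread already inside model.predict;
# cancellation only discards its eventual result.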

def gen_fnseed(model_str, prompt, seed=1):
    """Synchronous wrapper so Gradio click handlers can drive the async infer()."""
    if model_str == 'NA' or models_load.get(model_str) is None:
        return None
    # Each worker thread gets its own event loop; close it even if infer() raises.
    loop = asyncio.new_event_loop()
    try:
        return loop.run_until_complete(infer(model_str, prompt, seed))
    finally:
        loop.close()
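# Note: asyncio.run(infer(...)) would be a slightly simpler equivalent here, assuming
# no event loop is already running in the calling worker thread.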

print("Creating Gradio interface...")
with gr.Blocks(theme="Nymbo/Nymbo_Theme") as demo:
    gr.HTML("<center><h1>Compare-6</h1></center>")
    with gr.Tab('Compare-6'):
        txt_input = gr.Textbox(label='Your prompt:', lines=4)
        gen_button = gr.Button('Generate up to 6 images in up to 3 minutes total')
        seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=starting_seed)
        seed_rand = gr.Button("Randomize Seed 🎲")
        seed_rand.click(randomize_seed, None, [seed])
        output = [gr.Image(label=m) for m in default_models]
        current_models = [gr.Textbox(m, visible=False) for m in default_models]

        for m, o in zip(current_models, output):
            gen_button.click(gen_fnseed, [m, txt_input, seed], o)

        with gr.Accordion('Model selection'):
            model_choice = gr.CheckboxGroup(models, label=f'Choose up to {num_models} models', value=default_models)
            model_choice.change(update_imgbox, model_choice, output)
            model_choice.change(extend_choices, model_choice, current_models)
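            # The two handlers stay in sync: update_imgbox toggles the visible image
            # slots while extend_choices rewrites the hidden textboxes feeding gen_fnseed.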

demo.queue(default_concurrency_limit=200, max_size=200)
demo.launch(show_api=False, max_threads=400)
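
# To run locally (assuming this file is saved as app.py and HF_TOKEN is exported):
#   python app.py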