File size: 3,405 Bytes
aac338f
0824d60
 
a51b33a
0824d60
7e9b981
0824d60
aac338f
 
c61580a
0824d60
e95bdda
4ff15b7
7e9b981
 
 
a51b33a
0824d60
e95bdda
7e9b981
e95bdda
7e9b981
e95bdda
7e9b981
 
a51b33a
e95bdda
7e9b981
c61580a
7e9b981
 
0824d60
4ff15b7
7e9b981
0824d60
c61580a
0824d60
7e9b981
0824d60
 
e95bdda
7e9b981
 
e95bdda
 
7e9b981
 
e95bdda
 
a51b33a
bbc9212
0824d60
e95bdda
 
7e9b981
 
e95bdda
 
7e9b981
e95bdda
 
7e9b981
bbc9212
7e9b981
a51b33a
0824d60
e95bdda
 
 
7e9b981
0824d60
c61580a
0824d60
7e9b981
 
 
 
e95bdda
 
 
 
 
a51b33a
e95bdda
 
a51b33a
7e9b981
e95bdda
7e9b981
 
aac338f
7e9b981
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import asyncio
import os
import tempfile
from random import randint
from threading import RLock

import gradio as gr

from all_models import models
from externalmod import gr_Interface_load, randomize_seed

# Module-level re-entrant lock.
lock = RLock()
# Hugging Face access token read from the environment; an unset or empty
# HF_TOKEN variable is normalized to None.
HF_TOKEN = os.environ.get("HF_TOKEN") or None

def load_fn(models):
    """Load every model named in *models* into the global ``models_load`` dict.

    ``models_load`` is reset on each call. Each name is loaded with
    ``gr_Interface_load(f'models/{name}', hf_token=HF_TOKEN)``. If loading
    raises, the error is printed and a placeholder interface is stored under
    that name instead, so later lookups by name still succeed. The
    placeholder's function ignores its arguments and returns None.

    Args:
        models: Iterable of model names; duplicate names are loaded once.
    """
    global models_load
    models_load = {}

    for model in models:
        if model in models_load:
            continue  # duplicate name in *models*: already loaded above
        try:
            print(f"Loading model: {model}")
            m = gr_Interface_load(f'models/{model}', hf_token=HF_TOKEN)
            print(f"Loaded model: {model}")
        except Exception as error:
            print(f"Error loading model {model}: {error}")
            # infer() calls `.fn(prompt=..., seed=..., token=...)`. A zero-argument
            # lambda would raise TypeError there on every request, so accept and
            # ignore any arguments and simply yield no image.
            m = gr.Interface(lambda *args, **kwargs: None, ['text'], ['image'])
        models_load[model] = m

# Load all models eagerly, at import time.
print("Loading models...")
load_fn(models)
print("Models loaded successfully.")

# Number of side-by-side image slots in the UI.
num_models = 6
# The first num_models models are the ones pre-selected in the UI.
default_models = models[:num_models]
# Seconds to wait for a single model call before giving up.
inference_timeout = 600
MAX_SEED = 3_999_999_999
starting_seed = randint(1941, 2024)
print(f"Starting seed: {starting_seed}")

def extend_choices(choices):
    """Fit *choices* to exactly ``num_models`` entries.

    Choices beyond the first ``num_models`` are dropped; if there are fewer,
    the list is padded with the placeholder string 'NA'.
    """
    kept = choices[:num_models]
    padding = ['NA'] * (num_models - len(choices))
    return kept + padding

def update_imgbox(choices):
    """Build one empty ``gr.Image`` per slot for the selected *choices*.

    Slots padded with the 'NA' placeholder are hidden; the others are labeled
    with the model name.
    """
    boxes = []
    for name in extend_choices(choices):
        is_real_model = name != 'NA'
        boxes.append(gr.Image(None, label=name, visible=is_real_model))
    return boxes

async def infer(model_str, prompt, seed=1, timeout=inference_timeout):
    """Run one loaded model on *prompt* in a worker thread and save its output.

    Args:
        model_str: Key into ``models_load`` selecting the model interface.
        prompt: Text prompt forwarded to the model's ``fn``.
        seed: Forwarded to the model's ``fn`` as the ``seed`` keyword.
        timeout: Seconds to wait for the model call before giving up.

    Returns:
        Path to a newly written PNG file. Returns None if the model call
        raised, timed out, or returned a falsy result.
    """
    print(f"Starting inference: {model_str} | Prompt: '{prompt}' | Seed: {seed}")
    # The model call is blocking, so run it in a worker thread off the event loop.
    task = asyncio.create_task(
        asyncio.to_thread(models_load[model_str].fn, prompt=prompt, seed=seed, token=HF_TOKEN)
    )
    try:
        result = await asyncio.wait_for(task, timeout=timeout)
    except Exception as e:
        print(f"Error during inference: {e}")
        if not task.done():
            task.cancel()
        return None
    if not result:
        return None
    # Several slots run concurrently. Saving them all to one shared "image.png"
    # let a later save overwrite an earlier image before Gradio read the path,
    # so each call now writes to its own unique file.
    # NOTE(review): assumes result is an image object with .save(), as before.
    fd, path = tempfile.mkstemp(suffix=".png")
    os.close(fd)
    result.save(path)
    return path

def gen_fnseed(model_str, prompt, seed=1):
    """Synchronously generate one image; used as a per-slot Gradio click handler.

    Returns None for the 'NA' placeholder slot. Otherwise runs ``infer`` on a
    fresh event loop and returns its result (an image path or None).
    """
    if model_str == 'NA':
        return None
    # Not asyncio.run(): that waits for the default executor to shut down, which
    # would block on a to_thread() model call that infer() already timed out on.
    loop = asyncio.new_event_loop()
    try:
        return loop.run_until_complete(infer(model_str, prompt, seed))
    finally:
        # Close the loop even if infer raises (e.g. unknown model key or
        # cancellation) so handler threads do not leak event loops.
        loop.close()

print("Creating Gradio interface...")
with gr.Blocks(theme="Nymbo/Nymbo_Theme") as demo:
    gr.HTML("<center><h1>Compare-6</h1></center>")
    with gr.Tab('Compare-6'):
        txt_input = gr.Textbox(label='Your prompt:', lines=4)
        # Derive the promised limits from the real settings; the old hard-coded
        # "3 minutes" contradicted inference_timeout (600 s per model call).
        gen_button = gr.Button(
            f'Generate up to {num_models} images in up to {inference_timeout // 60} minutes total'
        )
        # gr.Slider's first positional parameter is `minimum`. Passing "Seed"
        # positionally together with minimum=0 raised "got multiple values for
        # argument 'minimum'", so the label is passed by keyword.
        seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=starting_seed)
        seed_rand = gr.Button("Randomize Seed 🎲")
        seed_rand.click(randomize_seed, None, [seed])
        # Always create exactly num_models slots, padding with hidden 'NA' slots.
        # The component count then matches what update_imgbox / extend_choices
        # return from the .change handlers below, even when fewer than
        # num_models models exist.
        slot_models = extend_choices(list(default_models))
        output = [gr.Image(label=m, visible=(m != 'NA')) for m in slot_models]
        current_models = [gr.Textbox(m, visible=False) for m in slot_models]

        # One click handler per slot; each reads its own hidden model-name textbox.
        for m, o in zip(current_models, output):
            gen_button.click(gen_fnseed, [m, txt_input, seed], o)

        with gr.Accordion('Model selection'):
            model_choice = gr.CheckboxGroup(models, label=f'Choose up to {num_models} models', value=default_models)
            model_choice.change(update_imgbox, model_choice, output)
            model_choice.change(extend_choices, model_choice, current_models)

demo.queue(default_concurrency_limit=200, max_size=200)
demo.launch(show_api=False, max_threads=400)