File size: 3,345 Bytes
aac338f
0824d60
 
a51b33a
0824d60
7e9b981
0824d60
aac338f
 
c61580a
0824d60
e95bdda
4ff15b7
7e9b981
 
 
2d762c7
0824d60
e95bdda
7e9b981
e95bdda
7e9b981
e95bdda
7e9b981
 
2d762c7
e95bdda
7e9b981
c61580a
7e9b981
 
0824d60
f36ee15
7e9b981
0824d60
c61580a
0824d60
7e9b981
0824d60
 
e95bdda
7e9b981
 
e95bdda
 
7e9b981
 
2d762c7
 
 
 
e95bdda
 
2d762c7
bbc9212
2d762c7
 
 
 
 
 
 
 
 
e95bdda
 
2d762c7
7e9b981
bbc9212
7e9b981
a51b33a
0824d60
2d762c7
0824d60
c61580a
e0e8379
957f949
2d762c7
9bfd4be
7e9b981
f1176ed
e95bdda
 
 
2d762c7
e95bdda
 
2d762c7
e95bdda
 
2d762c7
7e9b981
e95bdda
7e9b981
 
aac338f
2d762c7
7e9b981
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import gradio as gr
from random import randint
from all_models import models

from externalmod import gr_Interface_load, randomize_seed

import asyncio
import os
from threading import RLock

# Module-wide re-entrant lock.
lock = RLock()
# Hugging Face token taken from the environment; an unset or empty value becomes None.
HF_TOKEN = os.environ.get("HF_TOKEN") or None

def load_fn(models):
    """Load every model named in *models* into the global ``models_load`` dict.

    ``models_load`` is rebound to a fresh dict on each call. Each name maps to
    the object returned by ``gr_Interface_load('models/<name>')``, or to
    ``None`` when loading raised, so later code can tell failed models apart.
    Repeated names in *models* are loaded only once.
    """
    global models_load
    models_load = {}

    for name in models:
        if name in models_load:
            continue  # duplicate entry: already attempted
        try:
            print(f"Loading model: {name}")
            interface = gr_Interface_load(f'models/{name}', hf_token=HF_TOKEN)
            print(f"Loaded model: {name}")
        except Exception as error:
            print(f"Error loading model {name}: {error}")
            interface = None  # sentinel for "unavailable"; no fallback Interface
        models_load[name] = interface

# Load all models once, at import time.
print("Loading models...")
load_fn(models)
print("Models loaded successfully.")

# Number of image slots shown side by side.
num_models = 6
default_models = models[0:num_models]
# Default timeout (seconds) that infer() hands to asyncio.wait_for.
inference_timeout = 600
MAX_SEED = 3_999_999_999
starting_seed = randint(1941, 2024)
print(f"Starting seed: {starting_seed}")

def extend_choices(choices):
    """Return exactly ``num_models`` model names for the image slots.

    The first ``num_models`` entries of *choices* are kept; when fewer were
    given, the list is padded with the placeholder ``'NA'``.
    """
    shortfall = num_models - len(choices)
    # A negative shortfall multiplies to an empty list, so no padding is added.
    return choices[:num_models] + ['NA'] * shortfall

def update_imgbox(choices):
    """Build replacement ``gr.Image`` components for the current selection.

    Each slot is cleared and labelled with its model name. Placeholder
    ``'NA'`` slots are hidden.
    """
    boxes = []
    for name in extend_choices(choices):
        boxes.append(gr.Image(None, label=name, visible=name != 'NA'))
    return boxes

async def infer(model_str, prompt, seed=1, timeout=inference_timeout):
    """Generate one image with ``model_str`` and return the path of a PNG file.

    Args:
        model_str: key into ``models_load``.
        prompt: text prompt forwarded to the model's ``fn``.
        seed: forwarded to the model's ``fn`` as the ``seed`` keyword.
        timeout: seconds to wait for the model call (via ``asyncio.wait_for``).

    Returns:
        The path of a newly created ``.png`` file, or ``None`` when the model
        is unavailable, returned a falsy result, timed out, or raised.
        Errors are printed, not propagated.
    """
    import tempfile

    model = models_load.get(model_str)
    if model is None:
        print(f"Model {model_str} is not available.")
        return None

    kwargs = {"seed": seed}
    print(f"Starting inference: {model_str} | Prompt: '{prompt}' | Seed: {seed}")

    try:
        # The blocking call runs in a worker thread. On timeout, wait_for stops
        # waiting, but the thread itself cannot be cancelled and keeps running.
        result = await asyncio.wait_for(
            asyncio.to_thread(model.fn, prompt=prompt, **kwargs),
            timeout=timeout,
        )
        if result:
            # Use a unique file per call. The old fixed "image.png" was shared
            # by all concurrent requests, so one request could overwrite
            # another's output before Gradio read it, and a slot could show
            # another model's image.
            # NOTE: files accumulate in the temp dir; this assumes Gradio may
            # serve files from tempfile.gettempdir() (its default policy).
            # Confirm for the deployed Gradio version.
            fd, save_path = tempfile.mkstemp(prefix="image_", suffix=".png")
            os.close(fd)
            result.save(save_path)
            return save_path
    except Exception as e:
        print(f"Error during inference: {e}")

    return None

def gen_fnseed(model_str, prompt, seed=1):
    """Synchronous Gradio handler: run ``infer`` on a fresh event loop.

    Returns ``None`` immediately for the ``'NA'`` placeholder slot.
    """
    if model_str != 'NA':
        return asyncio.run(infer(model_str, prompt, seed))
    return None

print("Creating Gradio interface...")
with gr.Blocks(theme="gradio/soft") as demo:
    gr.HTML("<center><h1>TEXT-IMAGE-USING-MULTIMODELS</h1></center>")

    with gr.Tab():
        txt_input = gr.Textbox(label='Your prompt:', lines=4)
        gen_button = gr.Button('Generate')
        # ``minimum`` is gr.Slider's first positional parameter. The former
        # gr.Slider("Seed", minimum=0, ...) therefore raised TypeError
        # ("multiple values for argument 'minimum'"). The caption belongs in
        # ``label``.
        seed = gr.Slider(
            label="Seed",
            minimum=0,
            maximum=MAX_SEED,
            step=1,
            value=starting_seed,
        )
        seed_rand = gr.Button("Randomize Seed 🎲")
        seed_rand.click(randomize_seed, None, [seed])

        # Always build exactly num_models slots, because update_imgbox and
        # extend_choices always return num_models values. With fewer models
        # available, the extra slots are hidden 'NA' placeholders.
        slot_models = extend_choices(default_models)
        output = [gr.Image(label=m, visible=(m != 'NA')) for m in slot_models]
        # Each image slot is paired with a hidden Textbox holding its model name,
        # so a single click fans out to one gen_fnseed call per slot.
        current_models = [gr.Textbox(m, visible=False) for m in slot_models]

        for m, o in zip(current_models, output):
            gen_button.click(gen_fnseed, [m, txt_input, seed], o)

        with gr.Accordion('Model selection'):
            model_choice = gr.CheckboxGroup(
                models,
                label=f'Choose up to {num_models} models',
                value=default_models,
            )
            # When the selection changes, relabel/hide the image slots and
            # update the hidden model-name boxes.
            model_choice.change(update_imgbox, model_choice, output)
            model_choice.change(extend_choices, model_choice, current_models)

# Enable the request queue (at most 500 waiting events, default per-event
# concurrency limit of 500), then start the server with the API page hidden.
demo.queue(max_size=500, default_concurrency_limit=500)
demo.launch(max_threads=400, show_api=False)