File size: 3,689 Bytes
aac338f
fca3535
 
 
b44117d
 
 
4ff15b7
2b9811d
b44117d
76e7d38
2b9811d
 
 
 
7e9b981
 
 
0824d60
2b9811d
7e9b981
2b9811d
b44117d
2b9811d
b44117d
2b9811d
b44117d
2b9811d
b44117d
2b9811d
c61580a
7e9b981
 
0824d60
2b9811d
b44117d
 
0824d60
2b9811d
0824d60
b44117d
 
7e9b981
2b9811d
7e9b981
2b9811d
 
b44117d
2b9811d
 
 
 
bbc9212
2b9811d
 
b44117d
2b9811d
 
 
 
 
 
 
7e9b981
bbc9212
2b9811d
7e9b981
a51b33a
0824d60
2b9811d
3b0a2ee
2b9811d
c61580a
3b0a2ee
2b9811d
 
7e9b981
2b9811d
 
 
 
 
 
 
 
 
 
 
7e9b981
2b9811d
7e9b981
 
b44117d
2b9811d
 
b44117d
2b9811d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import asyncio
import os
import tempfile
from random import randint
from threading import RLock

import gradio as gr

from all_models import models
from externalmod import gr_Interface_load, randomize_seed

# Re-entrant lock available to any code that must serialise work across threads.
lock = RLock()

# Hugging Face access token taken from the environment; None when HF_TOKEN is unset.
HF_TOKEN = os.getenv("HF_TOKEN")

# Function to load models
def load_fn(models):
    global models_load
    models_load = {}
    for model in models:
        if model not in models_load:
            try:
                print(f"Loading model: {model}")
                m = gr_Interface_load(f'models/{model}', hf_token=HF_TOKEN)
                print(f"Successfully loaded: {model}")
            except Exception as error:
                print(f"Error loading {model}: {error}")
                m = gr.Interface(lambda: None, ['text'], ['image'])
            models_load[model] = m

# Load every model eagerly at import time so the UI below can refer to them.
print("Loading models...")
load_fn(models)
# NOTE(review): printed even when some models fell back to placeholders in load_fn.
print("Models loaded successfully.")

# Number of side-by-side output slots: three, or fewer when fewer models exist.
num_models = min(len(models), 3)
# Initial value for the seed slider: a random integer in [1941, 2024], inclusive.
starting_seed = randint(1941, 2024)
print(f"Starting seed: {starting_seed}")

# Fit a model selection to the fixed number of output slots
def extend_choices(choices):
    """Fit the selected model ids to exactly ``num_models`` slots.

    Selections beyond the first ``num_models`` are dropped. Any remaining
    slots are filled with the placeholder string ``'NA'``, which the inference
    helpers in this module treat as "no model".

    Args:
        choices: Sliceable sequence of selected model ids. Presumably a list
            when called from the CheckboxGroup change event.

    Returns:
        A sequence of length ``num_models``.
    """
    kept = choices[:num_models]
    missing = num_models - len(kept)
    return kept + ['NA'] * missing

# Rebuild the output image components after the model selection changes
def update_imgbox(choices):
    """Return one fresh, empty ``gr.Image`` for every output slot.

    The slots come from ``extend_choices(choices)``. Each image is labelled
    with its model id and hidden when the slot holds the ``'NA'`` placeholder.
    """
    slots = extend_choices(choices)
    images = []
    for model_id in slots:
        is_real = model_id != 'NA'
        images.append(gr.Image(None, label=model_id, visible=is_real))
    return images

# Asynchronous inference function
async def infer(model_str, prompt, seed=1, timeout=600):
    """Run one model on *prompt* and return the path of the saved image.

    The loaded model's ``fn`` runs in a worker thread via ``asyncio.to_thread``.
    It receives ``prompt``, ``seed`` and the module-level ``HF_TOKEN``, and the
    wait is abandoned after *timeout* seconds.

    Each result is written to its own newly created temporary ``.png`` file.
    Writing every result to one shared ``image.png`` lets concurrent
    generations (several model slots per click, several users) overwrite each
    other's file before it is read. A slot could then show another model's
    image. Unique paths remove that race, so no lock is needed around the save.

    Args:
        model_str: Key into ``models_load``; the placeholder ``'NA'`` skips inference.
        prompt: Text prompt forwarded to the model.
        seed: Seed forwarded to the model.
        timeout: Seconds to wait for the model call before giving up.

    Returns:
        Path of the saved image. Returns ``None`` when the slot is ``'NA'``,
        the model returned a falsy result, or any step raised (the error is
        printed).
    """
    if model_str == 'NA':
        return None
    try:
        print(f"Running inference on {model_str} with prompt: '{prompt}'")
        task = asyncio.to_thread(models_load[model_str].fn, prompt=prompt, seed=seed, token=HF_TOKEN)
        result = await asyncio.wait_for(task, timeout=timeout)
        if result:
            # mkstemp atomically creates a uniquely named file, so concurrent calls
            # never share a path. Close the OS handle and let .save() write the file.
            # NOTE(review): these files are not deleted here once they have been served.
            fd, image_path = tempfile.mkstemp(prefix="image_", suffix=".png")
            os.close(fd)
            try:
                result.save(image_path)
            except Exception:
                os.remove(image_path)  # don't leave an empty/partial file behind
                raise
            return image_path
    except Exception as e:
        print(f"Error in inference for {model_str}: {e}")
    return None

# Synchronous entry point bound to the Generate button
def gen_fnseed(model_str, prompt, seed=1):
    """Run the ``infer`` coroutine to completion and return its result.

    Each call creates a new event loop with ``asyncio.run``. The ``'NA'``
    placeholder short-circuits to ``None`` without starting a loop.

    Returns:
        Whatever ``infer`` returns: an image path or ``None``.
    """
    if model_str != 'NA':
        return asyncio.run(infer(model_str, prompt, seed))
    return None

# Create Gradio interface: prompt box, seed controls, one image per model slot,
# and a model picker that rewires the slots.
print("Creating Gradio interface...")
with gr.Blocks(theme="Nymbo/Nymbo_Theme") as demo:
    gr.HTML("<center><h1>Compare-3</h1></center>")
    with gr.Tab('Compare-3'):
        # Prompt entry and the button that triggers generation for every slot.
        txt_input = gr.Textbox(label='Your prompt:', lines=4)
        gen_button = gr.Button('Generate images')
        # Seed slider starts at starting_seed; the dice button refills it via randomize_seed.
        seed = gr.Slider(label="Seed (max 3999999999)", minimum=0, maximum=3999999999, step=1, value=starting_seed)
        seed_rand = gr.Button("Randomize Seed 🎲")
        seed_rand.click(randomize_seed, None, [seed])
        
        # One output image per slot, labelled with the initially selected models.
        output = [gr.Image(label=m) for m in models[:num_models]]
        # Hidden textboxes hold each slot's current model id and feed it to gen_fnseed.
        # They are rewritten (padded with 'NA' by extend_choices) when the selection changes.
        current_models = [gr.Textbox(m, visible=False) for m in models[:num_models]]
        
        # Register one click listener per slot, pairing each hidden model id with its image.
        for m, o in zip(current_models, output):
            gen_button.click(gen_fnseed, inputs=[m, txt_input, seed], outputs=[o])
        
        with gr.Accordion('Model selection'):
            # Nothing here caps how many boxes can be checked. extend_choices keeps
            # only the first num_models selections and pads the rest with 'NA'.
            model_choice = gr.CheckboxGroup(models, label=f'Choose up to {num_models} models', value=models[:num_models])
            # On change: relabel/hide the image boxes and update the hidden model ids.
            model_choice.change(update_imgbox, model_choice, output)
            model_choice.change(extend_choices, model_choice, current_models)

# Queue settings: default_concurrency_limit=50 (presumably applied per event
# listener), max_size=100 queued requests.
# NOTE(review): an earlier comment said this "reduces concurrency to avoid T4
# overload", but 50 does not look like a reduction — confirm the intended value.
demo.queue(default_concurrency_limit=50, max_size=100)
print("Launching Gradio interface...")
demo.launch(show_api=False, max_threads=50, debug=True)