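# Gradio Space that loads the text-to-image models listed in all_models.models,
# takes a prompt and a seed from the user, and shows the generated image(s).
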
import gradio as gr
from random import randint
from all_models import models
from externalmod import gr_Interface_load, randomize_seed
import asyncio
import os
from threading import RLock
from pathlib import Path

# Create a lock for thread safety
lock = RLock()

# Load Hugging Face token from environment variable (if available)
HF_TOKEN = os.getenv("HF_TOKEN")
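# A token is typically only required for gated or private models; public models
# can usually be loaded without one.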

# Function to load models
def load_fn(models):
    global models_load
    models_load = {}
    for model in models:
        if model not in models_load:
            try:
                print(f"Loading model: {model}")
                m = gr_Interface_load(f'models/{model}', hf_token=HF_TOKEN)
                models_load[model] = m
            except Exception as e:
                print(f"Error loading model {model}: {e}")
                # Fall back to a stub interface so the rest of the app still builds.
                models_load[model] = gr.Interface(lambda *args, **kwargs: None, ['text'], ['image'])

print("Loading models...")
load_fn(models)
print("Models loaded successfully.")

num_models = 1                       # number of image outputs shown side by side
starting_seed = randint(1941, 2024)  # initial value for the seed slider
MAX_SEED = 3999999999                # upper bound for the seed slider
inference_timeout = 600              # per-request inference timeout, in seconds

def extend_choices(choices):
    # Pad the selection with 'NA' up to num_models entries; extra picks are dropped.
    # E.g. with num_models == 1: extend_choices([]) -> ['NA'], extend_choices(['m1', 'm2']) -> ['m1'].
    return choices[:num_models] + ['NA'] * (num_models - len(choices))

def update_imgbox(choices):
    # Show an image box for each selected model and hide the unused 'NA' slots.
    choices_extended = extend_choices(choices)
    return [gr.Image(None, label=m, visible=(m != 'NA')) for m in choices_extended]

async def infer(model_str, prompt, seed=1, timeout=inference_timeout):
    if model_str not in models_load:
        return None

    kwargs = {"seed": seed}
    try:
        print(f"Running inference for model: {model_str} with prompt: '{prompt}'")
        # Run the blocking model call in a worker thread and enforce the timeout
        # on the await (the worker thread itself cannot be interrupted).
        result = await asyncio.wait_for(
            asyncio.to_thread(models_load[model_str].fn, prompt=prompt, **kwargs, token=HF_TOKEN),
            timeout=timeout,
        )
        if result:
            # All requests write to the same path; the lock keeps concurrent saves
            # from interleaving.
            with lock:
                png_path = "image.png"
                result.save(png_path)
                return str(Path(png_path).resolve())
    except asyncio.TimeoutError:
        print(f"Inference for {model_str} timed out after {timeout}s")
    except Exception as e:
        print(f"Error during inference for {model_str}: {e}")
    return None

def gen_fnseed(model_str, prompt, seed=1):
    # Synchronous wrapper used as the Gradio click handler; it drives the async
    # infer() on a private event loop.
    if model_str == 'NA':
        return None
    # Create the loop outside the try block so the finally clause never touches
    # an unbound variable if loop creation fails.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        result = loop.run_until_complete(infer(model_str, prompt, seed, inference_timeout))
    except Exception as e:
        print(f"Error generating image for {model_str}: {e}")
        result = None
    finally:
        loop.close()
    return result
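
# For quick local testing outside the UI (hypothetical usage, not part of the Space):
#   path = gen_fnseed(models[0], "a watercolor fox", seed=42)
#   print(path)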

print("Creating Gradio interface...")
with gr.Blocks(theme="Nymbo/Nymbo_Theme") as demo:
    gr.HTML("<center><h1>Compare-6</h1></center>")
    with gr.Tab('Compare-6'):
        txt_input = gr.Textbox(label='Your prompt:', lines=4)
        gen_button = gr.Button('Generate up to 6 images')
        seed = gr.Slider(label="Seed (0 to MAX)", minimum=0, maximum=MAX_SEED, value=starting_seed)
        seed_rand = gr.Button("Randomize Seed 🎲")
        
        seed_rand.click(randomize_seed, None, [seed], queue=False)
        
        # One image box per slot, plus a hidden textbox carrying the model name.
        output = [gr.Image(label=m) for m in models[:num_models]]
        current_models = [gr.Textbox(m, visible=False) for m in models[:num_models]]

        # Wire a separate click handler per (model, image) pair so each output
        # updates independently.
        for m, o in zip(current_models, output):
            gen_button.click(gen_fnseed, inputs=[m, txt_input, seed], outputs=[o], queue=False)
        
        with gr.Accordion('Model selection'):
            model_choice = gr.CheckboxGroup(models, label=f'Choose up to {num_models} models')
            # Keep both the visible image boxes and the hidden model-name textboxes
            # in sync with the current selection.
            model_choice.change(update_imgbox, model_choice, output)
            model_choice.change(extend_choices, model_choice, current_models)

demo.queue(default_concurrency_limit=50, max_size=100)
demo.launch(show_api=False)
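
# A rough sketch of running this Space locally (assuming the file is saved as
# app.py, with all_models.py and externalmod.py alongside it):
#   HF_TOKEN=<your token> python app.py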