File size: 6,848 Bytes
aac338f
0824d60
 
 
 
aac338f
 
c61580a
 
0824d60
4ff15b7
c61580a
4ff15b7
aac338f
c61580a
0824d60
 
 
4ff15b7
c61580a
0824d60
 
 
c61580a
0824d60
c61580a
0824d60
c61580a
0824d60
 
aac338f
c61580a
0824d60
c61580a
0824d60
4ff15b7
0824d60
c61580a
0824d60
 
c61580a
0824d60
c61580a
0824d60
c61580a
0824d60
c61580a
 
 
 
0824d60
c61580a
0824d60
c61580a
0824d60
c61580a
 
 
0824d60
c61580a
4ff15b7
0824d60
 
 
 
4ff15b7
 
 
 
 
 
 
 
 
 
 
 
bbc9212
0824d60
c61580a
0824d60
c61580a
 
 
 
0824d60
 
 
4ff15b7
0824d60
 
c61580a
bbc9212
c61580a
0824d60
bbc9212
c61580a
4ff15b7
0824d60
c61580a
0824d60
 
4ff15b7
0824d60
4ff15b7
 
 
0824d60
c61580a
0824d60
 
 
c61580a
0824d60
 
c61580a
 
0824d60
ba9775e
c61580a
aac338f
0824d60
 
4ff15b7
 
0824d60
4ff15b7
 
 
 
 
e7c4130
0824d60
c61580a
 
 
0824d60
c61580a
4ff15b7
 
 
 
 
 
 
 
0824d60
4ff15b7
0824d60
 
 
4ff15b7
aac338f
c61580a
0824d60
c61580a
 
4ff15b7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
import gradio as gr
from random import randint
from all_models import models
from externalmod import gr_Interface_load, randomize_seed
import asyncio
import os
from threading import RLock

# Re-entrant lock serialising access to shared resources across worker threads
# (it guards the saving of generated images).
lock = RLock()

# Hugging Face token from the environment; an unset or empty variable yields
# None. Only needed when private or gated models are used.
HF_TOKEN = os.environ.get("HF_TOKEN") or None

# Function to load all models specified in the 'models' list
def load_fn(models):
    """(Re)build the global ``models_load`` mapping of model id -> interface.

    Each id is loaded with ``gr_Interface_load('models/<id>')`` using
    ``HF_TOKEN``. A load failure is printed and replaced by a placeholder
    ``gr.Interface(lambda: None, ...)`` so that later lookups by id still
    succeed. Duplicate ids in ``models`` are loaded only once.

    Args:
        models: iterable of model ids.
    """
    global models_load
    models_load = {}

    for model in models:
        if model in models_load:
            continue  # duplicate id in the input; keep the first load
        try:
            print(f"Attempting to load model: {model}")
            m = gr_Interface_load(f'models/{model}', hf_token=HF_TOKEN)
            print(f"Successfully loaded model: {model}")
        except Exception as error:
            # Best effort: one bad model must not abort start-up.
            print(f"Error loading model {model}: {error}")
            m = gr.Interface(lambda: None, ['text'], ['image'])
        models_load[model] = m

# Load every model up front; inference later looks them up by id.
print("Loading models...")
load_fn(models)
# NOTE: printed even if some models failed — load_fn swallows load errors and
# stores placeholders for them.
print("Models loaded successfully.")

num_models = 6  # number of side-by-side output slots in the UI

default_models = models[:num_models]  # initial selection: the first `num_models` ids
inference_timeout = 10 * 60  # per-model inference timeout in seconds (600)
MAX_SEED = 3_999_999_999  # upper bound of the seed slider

# Initial value of the seed slider, drawn once at start-up.
starting_seed = randint(1941, 2024)
print(f"Starting seed: {starting_seed}")

# Extend the choices list to ensure it contains 'num_models' elements
def extend_choices(choices):
    """Fit a model selection to exactly ``num_models`` slots.

    Keeps at most the first ``num_models`` entries and fills any remaining
    slots with the placeholder string ``'NA'``.

    Args:
        choices: list of selected model ids.

    Returns:
        A new list of length ``num_models``.
    """
    print(f"Extending choices: {choices}")
    kept = choices[:num_models]
    filler = ['NA'] * (num_models - len(kept))
    extended = kept + filler
    print(f"Extended choices: {extended}")
    return extended

# Update the image boxes based on selected models
def update_imgbox(choices):
    """Return fresh, empty ``gr.Image`` components for the selected models.

    The selection is fitted to ``num_models`` slots with ``extend_choices``;
    each image is labelled with its model id and hidden when its slot holds
    the ``'NA'`` placeholder.
    """
    print(f"Updating image boxes with choices: {choices}")
    slots = extend_choices(choices[:num_models])
    imgboxes = [
        gr.Image(None, label=slot, visible=slot != 'NA')
        for slot in slots
    ]
    print(f"Updated image boxes: {imgboxes}")
    return imgboxes

def _save_result(result, output_format):
    """Save a generated image to a new, uniquely named file and return its path.

    Every call gets its own file from ``tempfile.mkstemp`` in the system temp
    directory, so concurrent inferences can never overwrite each other's
    output. NOTE: the files are not deleted here; they remain until the temp
    directory is cleaned.

    Args:
        result: object returned by a model's ``fn``; presumably a PIL image
            (only ``.save``, ``.mode`` and ``.convert`` are used) — TODO confirm.
        output_format: format name such as "PNG" or "JPEG"; its lower-cased
            form becomes the file extension, from which the format is inferred.

    Returns:
        Absolute path of the saved file, as a ``str``.
    """
    import tempfile
    from pathlib import Path

    ext = output_format.lower()
    fd, path = tempfile.mkstemp(prefix="image_", suffix=f".{ext}")
    os.close(fd)  # only the path is needed; save() reopens the file
    if ext in ("jpeg", "jpg") and getattr(result, "mode", "RGB") in ("RGBA", "LA", "P", "PA"):
        # JPEG stores neither alpha nor a palette; Pillow refuses to write
        # these modes as JPEG, so convert first.
        result = result.convert("RGB")
    try:
        result.save(path)
    except Exception:
        os.remove(path)  # don't leave an empty file behind
        raise
    return str(Path(path).resolve())


# Asynchronous function to perform inference on a given model
async def infer(model_str, prompt, seed=1, batch_size=1, output_format="PNG", priority="medium", timeout=inference_timeout):
    """Run one model's inference function in a worker thread and save the image.

    Args:
        model_str: key into the global ``models_load`` dict (a missing key
            raises ``KeyError`` to the caller).
        prompt: text prompt; a space and the (currently empty) ``noise``
            suffix are appended before it is sent to the model.
        seed, batch_size, priority: forwarded as keyword arguments to the
            model's ``fn`` together with ``token=HF_TOKEN``.
            NOTE(review): whether every loaded model accepts these keywords
            is not visible here; if one does not, the error is caught below
            and ``None`` is returned.
        output_format: format/extension used when saving ("PNG" or "JPEG").
        timeout: seconds to wait for the model before giving up.

    Returns:
        Absolute path (``str``) of the saved image, or ``None`` if inference
        failed, timed out, or produced no result.
    """
    noise = ""
    kwargs = {"seed": seed, "batch_size": batch_size, "priority": priority}
    print(f"Starting inference for model: {model_str} with prompt: '{prompt}' and seed: {seed}, batch_size: {batch_size}, priority: {priority}")
    task = asyncio.create_task(
        asyncio.to_thread(
            models_load[model_str].fn,
            prompt=f'{prompt} {noise}',
            **kwargs,
            token=HF_TOKEN
        )
    )
    await asyncio.sleep(0)  # yield once so the task gets started
    try:
        result = await asyncio.wait_for(task, timeout=timeout)
        print(f"Inference completed for model: {model_str}")
    except Exception as e:  # asyncio.TimeoutError is an Exception subclass
        print(f"Error during inference for model {model_str}: {e}")
        if not task.done():
            # This only stops waiting: a function already running via
            # asyncio.to_thread cannot be interrupted and keeps running in
            # its worker thread.
            task.cancel()
            print(f"Task cancelled for model: {model_str}")
        result = None
    if task.done() and result is not None:
        # Each result gets its own file. The former shared "image.<ext>" path
        # let concurrent models overwrite one another's output before it was
        # read. The lock is kept as defensive serialisation of the save.
        with lock:
            image = _save_result(result, output_format)
            print(f"Result saved as image: {image}")
        return image
    print(f"No result for model: {model_str}")
    return None

# Function to generate an image based on the given model, prompt, and seed
def gen_fnseed(model_str, prompt, seed=1, batch_size=1, output_format="PNG", priority="medium"):
    """Synchronously run ``infer`` for one model on a private event loop.

    Args:
        model_str: model id, or ``'NA'`` for an unused UI slot.
        prompt, seed, batch_size, output_format, priority: passed to ``infer``.

    Returns:
        Whatever ``infer`` returns (a saved-image path or ``None``). Returns
        ``None`` for the ``'NA'`` slot or when generation raises; the error
        is printed rather than propagated.
    """
    if model_str == 'NA':
        print("Model is 'NA', skipping generation.")
        return None
    # Create the loop before entering ``try`` so ``finally`` always has a loop
    # to close. Previously a failure here left ``loop`` unbound, and the
    # ``finally`` clause raised UnboundLocalError, masking the real error.
    loop = asyncio.new_event_loop()
    try:
        print(f"Generating image for model: {model_str} with prompt: '{prompt}', seed: {seed}, batch_size: {batch_size}, priority: {priority}")
        # asyncio.run() is deliberately avoided: on exit it waits for the
        # default executor to shut down, i.e. for a timed-out to_thread call
        # to finish, which would defeat infer()'s timeout.
        result = loop.run_until_complete(
            infer(model_str, prompt, seed, batch_size=batch_size, output_format=output_format, priority=priority)
        )
    except (Exception, asyncio.CancelledError) as e:
        print(f"Error during generation for model {model_str}: {e}")
        result = None
    finally:
        # close() shuts the default executor down without waiting for it.
        loop.close()
        print(f"Event loop closed for model: {model_str}")
    return result

# Create the Gradio Blocks interface with a custom theme
print("Creating Gradio interface...")
# Per-model timeout rounded up to whole minutes (ceiling division) for the
# button label. The label previously hard-coded "3 minutes", although
# inference_timeout allows each model 600 s.
_timeout_minutes = -(-inference_timeout // 60)
with gr.Blocks(theme="Nymbo/Nymbo_Theme") as demo:
    gr.HTML("<center><h1>Multi-models-prompt-to-image-generation</h1></center>")
    with gr.Tab(f'Compare-{num_models}'):
        txt_input = gr.Textbox(label='Your prompt:', lines=4)
        # All per-model events below are triggered by the same click/submit,
        # and each is bounded by inference_timeout.
        gen_button = gr.Button(f'Generate up to {num_models} images in up to {_timeout_minutes} minutes total')
        with gr.Row():
            seed = gr.Slider(label=f"Seed (max {MAX_SEED})", minimum=0, maximum=MAX_SEED, step=1, value=starting_seed, scale=3)
            seed_rand = gr.Button("Randomize Seed 🎲", size="sm", variant="secondary", scale=1)
        seed_rand.click(randomize_seed, None, [seed], queue=False)

        # Generation options forwarded to every model.
        batch_size_slider = gr.Slider(label="Batch Size", minimum=1, maximum=10, step=1, value=1)
        output_format_dropdown = gr.Dropdown(["PNG", "JPEG"], label="Output Format", value="PNG")
        priority_dropdown = gr.Dropdown(["low", "medium", "high"], label="Model Priority", value="medium")

        with gr.Row():
            output = [gr.Image(label=m, min_width=480) for m in default_models]
            # Hidden textboxes hold each slot's current model id for gen_fnseed.
            current_models = [gr.Textbox(m, visible=False) for m in default_models]

            # One event per slot, each running gen_fnseed for its own model.
            for m, o in zip(current_models, output):
                print(f"Setting up generation event for model: {m.value}")
                gr.on(
                    triggers=[gen_button.click, txt_input.submit],
                    fn=gen_fnseed,
                    inputs=[m, txt_input, seed, batch_size_slider, output_format_dropdown, priority_dropdown],
                    outputs=[o],
                    concurrency_limit=None,
                    queue=False
                )
        with gr.Accordion('Model selection'):
            model_choice = gr.CheckboxGroup(models, label=f'Choose up to {num_models} different models!', value=default_models, interactive=True)
            # Relabel/hide the image slots and update the hidden model ids.
            model_choice.change(update_imgbox, model_choice, output)
            model_choice.change(extend_choices, model_choice, current_models)
        with gr.Row():
            gr.HTML("<p>Additional UI elements can go here</p>")

print("Setting up queue...")
demo.queue(default_concurrency_limit=200, max_size=200)
print("Launching Gradio interface...")
# When run as a script, launch() blocks the main thread while the server runs,
# so the line after it executes only once the server has stopped. The message
# reports that, rather than claiming a successful launch at shutdown time.
demo.launch(show_api=False, max_threads=400)
print("Gradio interface stopped.")