File size: 4,702 Bytes
aac338f
0824d60
 
 
 
aac338f
 
c61580a
 
0824d60
4ff15b7
7cb0a88
 
aac338f
7cb0a88
 
 
 
0824d60
7cb0a88
 
 
 
 
 
 
 
aac338f
c61580a
7cb0a88
 
0824d60
7cb0a88
4ff15b7
0824d60
c61580a
0824d60
 
7cb0a88
0824d60
7cb0a88
 
 
 
bbc9212
7cb0a88
 
 
 
 
 
 
 
 
 
 
0824d60
7cb0a88
 
 
 
bbc9212
7cb0a88
 
 
 
0824d60
7cb0a88
0824d60
 
7cb0a88
4ff15b7
7cb0a88
4ff15b7
7cb0a88
 
 
 
 
 
 
0824d60
 
 
7cb0a88
 
 
c61580a
0824d60
7cb0a88
 
 
 
 
 
 
0824d60
7cb0a88
 
 
 
 
4ff15b7
7cb0a88
e7c4130
7cb0a88
0824d60
7cb0a88
 
 
 
 
 
 
 
 
 
c61580a
7cb0a88
 
 
 
 
 
 
 
 
 
 
aac338f
c61580a
7cb0a88
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import asyncio
import os
import tempfile
from random import randint
from threading import RLock

import gradio as gr

from all_models import models
from externalmod import gr_Interface_load, randomize_seed

# Re-entrant lock meant for thread-safe access to shared resources.
# NOTE(review): no code visible in this file acquires it; confirm it is still needed.
lock = RLock()

# Hugging Face access token taken from the environment; None when unset.
HF_TOKEN = os.getenv("HF_TOKEN")

def load_models(models):
    """Build the module-level ``models_loaded`` registry from ``models``.

    Each name is loaded with ``gr_Interface_load('models/<name>', hf_token=HF_TOKEN)``.
    A failure is printed and recorded as ``None`` for that name instead of
    being raised, so one broken model does not stop the others from loading.

    Args:
        models: Iterable of model identifiers to load.
    """
    global models_loaded
    # Rebind the global to a fresh dict first, then fill it in place.
    registry = models_loaded = {}
    for name in models:
        try:
            print(f"Loading model: {name}")
            interface = gr_Interface_load(f'models/{name}', hf_token=HF_TOKEN)
            print(f"Successfully loaded model: {name}")
            registry[name] = interface
        except Exception as exc:
            print(f"Error loading model {name}: {exc}")
            registry[name] = None


# Load every configured model once, at import time, before the UI is built.
print("Loading models...")
load_models(models)
print("Models loaded.")

# Module-wide settings shared by the UI and the inference helpers.
num_models = 6                       # number of model/image slots per generation
inference_timeout = 600              # timeout handed to infer() for each model call
MAX_SEED = 3_999_999_999             # upper bound of the seed slider
starting_seed = randint(1941, 2024)  # random initial value for the seed slider

def extend_choices(choices):
    """Return exactly ``num_models`` entries built from ``choices``.

    Extra selections beyond ``num_models`` are dropped; missing slots are
    filled with the placeholder ``'NA'``.
    """
    head = choices[:num_models]
    padding = ['NA'] * (num_models - len(head))
    return head + padding

async def infer(model_name, prompt, seed, batch_size, priority, timeout):
    """Call ``models_loaded[model_name].fn`` in a worker thread, bounded by ``timeout``.

    The model function receives ``prompt``, ``seed``, ``batch_size``,
    ``priority`` and ``token=HF_TOKEN`` as keyword arguments.
    NOTE(review): assumes the loaded ``fn`` accepts all of these keywords.

    Returns:
        Whatever the model function returns, or ``None`` if anything fails.
        This includes an unknown model name or the timeout expiring. The
        failure is printed, not raised.

    asyncio cannot interrupt a worker thread, so after a timeout the
    underlying call may keep running in the background.
    """
    try:
        print(f"Running inference for model: {model_name} with prompt: {prompt} and seed: {seed}")
        model_fn = models_loaded[model_name].fn
        # wait_for wraps the coroutine in a task itself, so no create_task is needed.
        return await asyncio.wait_for(
            asyncio.to_thread(
                model_fn,
                prompt=prompt,
                seed=seed,
                batch_size=batch_size,
                priority=priority,
                token=HF_TOKEN,
            ),
            timeout=timeout,
        )
    except Exception as err:
        print(f"Inference failed for model {model_name}: {err}")
        return None

def generate_image(model_name, prompt, seed, batch_size, output_format, priority):
    """Run one model synchronously and save its output to a fresh image file.

    Args:
        model_name: Key into ``models_loaded``; ``'NA'`` marks an empty slot.
        prompt: Text prompt forwarded to ``infer``.
        seed: Seed forwarded to ``infer``.
        batch_size: Batch size forwarded to ``infer``.
        output_format: Image format name (e.g. ``"PNG"`` / ``"JPEG"``); its
            lower-cased form is used as the file extension.
        priority: Priority forwarded to ``infer``.

    Returns:
        Path of a newly created image file, or ``None`` in three cases: the
        model is skipped, inference yields nothing, or any step fails. A
        failure is printed, not raised.
    """
    if model_name == 'NA' or models_loaded.get(model_name) is None:
        print(f"Skipping model: {model_name} (Not available or NA)")
        return None

    # This function is synchronous, so it drives infer() on a private event loop.
    # Creating the loop outside the try guarantees `loop` is bound in `finally`;
    # previously a failure here surfaced as a misleading UnboundLocalError.
    loop = asyncio.new_event_loop()
    try:
        asyncio.set_event_loop(loop)
        result = loop.run_until_complete(
            infer(model_name, prompt, seed, batch_size, priority, inference_timeout)
        )
        if result:
            # Every call gets its own file. A fixed "output.<ext>" name was shared by
            # all models in one click (and by concurrent users), so later saves
            # overwrote earlier ones before the returned paths were displayed.
            # NOTE(review): these files accumulate in the system temp dir; add
            # cleanup if disk usage becomes a concern.
            fd, output_path = tempfile.mkstemp(
                prefix="output_", suffix=f".{output_format.lower()}"
            )
            os.close(fd)
            try:
                # NOTE(review): assumes infer() returns an object with a PIL-style
                # .save(path) method — confirm against externalmod.
                result.save(output_path)
            except Exception:
                os.remove(output_path)  # don't leave an empty placeholder behind
                raise
            print(f"Image saved: {output_path}")
            return output_path
    except Exception as e:
        print(f"Error generating image for model {model_name}: {e}")
    finally:
        # Uninstall and close the private loop so this thread is not left
        # pointing at a closed event loop.
        asyncio.set_event_loop(None)
        loop.close()

    return None

# Build and launch the Gradio UI (runs at import time).
print("Creating Gradio interface...")
with gr.Blocks(theme="Nymbo/Nymbo_Theme") as demo:
    gr.HTML(value="<center><h1>Multi-Model Prompt-to-Image Generator</h1></center>")

    with gr.Tab(label='Generate'):
        # Prompt entry and the button that starts a generation run.
        txt_input = gr.Textbox(label='Your Prompt', lines=4)
        gen_button = gr.Button(value='Generate Images')

        # Generation parameters, laid out side by side.
        with gr.Row():
            seed = gr.Slider(minimum=0, maximum=MAX_SEED, step=1, value=starting_seed, label="Seed")
            seed_rand = gr.Button(value="Randomize Seed 🎲")
            batch_size = gr.Slider(minimum=1, maximum=10, step=1, value=1, label="Batch Size")
            output_format = gr.Dropdown(choices=["PNG", "JPEG"], value="PNG", label="Output Format")
            priority = gr.Dropdown(choices=["low", "medium", "high"], value="medium", label="Priority")

        seed_rand.click(fn=randomize_seed, inputs=None, outputs=[seed], queue=False)

        # One image slot per model, in model order.
        with gr.Row():
            output_images = []
            for slot in range(num_models):
                output_images.append(gr.Image(label=f"Model {slot + 1}"))

        # Collapsible picker; the first num_models models are pre-selected.
        with gr.Accordion(label="Model Selection", open=False):
            model_choice = gr.CheckboxGroup(
                choices=models,
                value=models[:num_models],
                label=f"Choose up to {num_models} models",
            )

        def generate_images(prompt, seed, batch_size, output_format, priority, selected_models):
            """Run each selected model in turn; return one result per image slot.

            The selection is normalised to exactly ``num_models`` entries
            (truncated or padded with 'NA'), so the returned list always lines
            up with ``output_images``.
            """
            return [
                generate_image(name, prompt, seed, batch_size, output_format, priority)
                for name in extend_choices(selected_models)
            ]

        gen_button.click(
            fn=generate_images,
            inputs=[txt_input, seed, batch_size, output_format, priority, model_choice],
            outputs=output_images,
        )

print("Launching Gradio interface...")
demo.launch(show_api=False)