# Hugging Face Space status: Runtime error
import asyncio
import os
from random import randint
from threading import RLock

import gradio as gr

from all_models import models
from externalmod import gr_Interface_load, randomize_seed

# NOTE(review): this lock is created but never acquired anywhere in this file,
# so it currently provides no thread safety. Kept in case other code imports it.
lock = RLock()

# Hugging Face token read from the environment; None when HF_TOKEN is unset.
HF_TOKEN = os.getenv("HF_TOKEN")
# Load models
def load_models(models):
    """Load each model id in *models* into the module-level ``models_loaded`` dict.

    Every id is loaded via ``gr_Interface_load(f'models/{id}', hf_token=HF_TOKEN)``.
    A failure is printed and recorded as ``None`` so one bad model does not stop
    the rest from loading. ``models_loaded`` is replaced by a fresh dict on
    every call.
    """
    global models_loaded
    registry = models_loaded = {}
    for name in models:
        try:
            print(f"Loading model: {name}")
            loaded = gr_Interface_load(f'models/{name}', hf_token=HF_TOKEN)
            print(f"Successfully loaded model: {name}")
        except Exception as err:
            print(f"Error loading model {name}: {err}")
            loaded = None
        registry[name] = loaded
# Populate models_loaded once, at import time.
print("Loading models...")
load_models(models)
print("Models loaded.")

# Global variables
num_models = 6                       # number of output image slots in the UI
inference_timeout = 600              # seconds; passed to infer() as its timeout
MAX_SEED = 3_999_999_999             # upper bound of the Seed slider
starting_seed = randint(1941, 2024)  # initial value of the Seed slider
# Extend model choices to match the required number
def extend_choices(choices, size=None, filler='NA'):
    """Return exactly *size* model names, one per output slot.

    Keeps the first *size* entries of *choices* and pads the remainder with
    *filler* (by default ``'NA'``, which ``generate_image`` treats as an empty
    slot).

    Args:
        choices: Selected model names, as any iterable, or ``None`` for no
            selection.
        size: Non-negative number of slots; defaults to the module-level
            ``num_models``.
        filler: Placeholder used for slots with no selected model.

    Returns:
        A new list of length *size*.
    """
    if size is None:
        size = num_models
    # Copy into a list so tuples and other iterables can be concatenated below.
    picked = list(choices or [])[:size]
    return picked + [filler] * (size - len(picked))
# Function to perform inference asynchronously
async def infer(model_name, prompt, seed, batch_size, priority, timeout):
    """Call ``models_loaded[model_name].fn`` in a worker thread, bounded by *timeout*.

    ``fn`` receives the prompt, ``seed``, ``batch_size``, ``priority`` and
    ``token=HF_TOKEN`` as keyword arguments.

    Args:
        model_name: Key into ``models_loaded``.
        prompt: Text prompt for the model.
        seed, batch_size, priority: Forwarded to ``fn`` unchanged.
        timeout: Seconds to wait for ``fn`` before giving up.

    Returns:
        Whatever ``fn`` returns. Returns ``None`` if the model is unknown or not
        loaded, if ``fn`` raises, or if it does not finish within *timeout*.
        Failures are printed rather than raised.
    """
    try:
        print(f"Running inference for model: {model_name} with prompt: {prompt} and seed: {seed}")
        model = models_loaded.get(model_name)
        if model is None:
            print(f"Inference failed for model {model_name}: model is not loaded")
            return None
        # wait_for accepts the coroutine directly; an explicit create_task adds
        # nothing. On timeout the wait is cancelled, but the worker thread that
        # runs fn cannot be interrupted and keeps running in the background.
        return await asyncio.wait_for(
            asyncio.to_thread(
                model.fn,
                prompt=prompt,
                seed=seed,
                batch_size=batch_size,
                priority=priority,
                token=HF_TOKEN,
            ),
            timeout=timeout,
        )
    except asyncio.TimeoutError:
        # TimeoutError carries no message, so state the limit explicitly.
        print(f"Inference timed out for model {model_name} after {timeout} seconds")
        return None
    except Exception as e:
        print(f"Inference failed for model {model_name}: {type(e).__name__}: {e}")
        return None
# Generate images for each model
def generate_image(model_name, prompt, seed, batch_size, output_format, priority):
    """Run one model synchronously and return an image path for a ``gr.Image`` slot.

    Args:
        model_name: Key into ``models_loaded``; ``'NA'`` marks an empty slot.
        prompt, seed, batch_size, priority: Forwarded unchanged to ``infer``.
        output_format: File extension used for the saved image (e.g. ``"PNG"``).

    Returns:
        One of:
        - the path of a freshly written image file;
        - the path the model itself returned;
        - ``None`` when the slot is empty, the model is unavailable, or
          inference or saving fails. Errors are printed, not raised.
    """
    import tempfile  # local import keeps this block self-contained

    if model_name == 'NA' or models_loaded.get(model_name) is None:
        print(f"Skipping model: {model_name} (Not available or NA)")
        return None

    loop = None  # bound up front so `finally` is safe even if loop creation fails
    try:
        # Drive the async infer() coroutine on a private loop owned by this call.
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        result = loop.run_until_complete(
            infer(model_name, prompt, seed, batch_size, priority, inference_timeout)
        )
        if result is None:
            return None
        if isinstance(result, (str, os.PathLike)):
            # NOTE(review): the model's fn may already return a file path rather
            # than an image object; pass it through (output_format is not applied).
            return os.fspath(result) or None
        # Give each call its own file. generate_images collects every slot's
        # path before the UI displays them, so a shared name such as
        # "output.png" would be overwritten by later models and by concurrent
        # requests. NOTE(review): these temp files are not cleaned up here.
        fd, output_path = tempfile.mkstemp(
            prefix="output_", suffix=f".{output_format.lower()}"
        )
        os.close(fd)
        try:
            # Assumes a PIL-style image exposing .save(path) — TODO confirm.
            result.save(output_path)
        except Exception:
            os.remove(output_path)  # don't leave an empty placeholder file behind
            raise
        print(f"Image saved: {output_path}")
        return output_path
    except Exception as e:
        print(f"Error generating image for model {model_name}: {e}")
        return None
    finally:
        if loop is not None:
            loop.close()
# Gradio Interface
def generate_images(prompt, seed, batch_size, output_format, priority, selected_models):
    """Run each selected model (padded to ``num_models`` slots) and collect results.

    Returns one ``generate_image`` result per output slot, in slot order.
    """
    slots = extend_choices(selected_models)
    return [
        generate_image(slot_model, prompt, seed, batch_size, output_format, priority)
        for slot_model in slots
    ]


print("Creating Gradio interface...")
with gr.Blocks(theme="Nymbo/Nymbo_Theme") as demo:
    gr.HTML("<center><h1>Multi-Model Prompt-to-Image Generator</h1></center>")

    # Prompt, generation controls and the seed randomizer.
    with gr.Tab('Generate'):
        txt_input = gr.Textbox(lines=4, label='Your Prompt')
        gen_button = gr.Button('Generate Images')
        with gr.Row():
            seed = gr.Slider(
                minimum=0, maximum=MAX_SEED, step=1, value=starting_seed, label="Seed"
            )
            seed_rand = gr.Button("Randomize Seed 🎲")
        batch_size = gr.Slider(minimum=1, maximum=10, step=1, value=1, label="Batch Size")
        output_format = gr.Dropdown(choices=["PNG", "JPEG"], value="PNG", label="Output Format")
        priority = gr.Dropdown(
            choices=["low", "medium", "high"], value="medium", label="Priority"
        )
        seed_rand.click(fn=randomize_seed, inputs=None, outputs=[seed], queue=False)

    # One image slot per model, labelled "Model 1" .. "Model <num_models>".
    with gr.Row():
        output_images = [gr.Image(label=f"Model {slot}") for slot in range(1, num_models + 1)]

    # Model picker; the first num_models entries of `models` start selected.
    with gr.Accordion("Model Selection", open=False):
        model_choice = gr.CheckboxGroup(
            choices=models,
            value=models[:num_models],
            label=f"Choose up to {num_models} models",
        )

    gen_button.click(
        generate_images,
        inputs=[txt_input, seed, batch_size, output_format, priority, model_choice],
        outputs=output_images,
    )

print("Launching Gradio interface...")
demo.launch(show_api=False)