import gradio as gr
from random import randint
from all_models import models
from externalmod import gr_Interface_load, randomize_seed
import asyncio
import os
from threading import RLock

# Create a lock to ensure thread safety when accessing shared resources
lock = RLock()

# Load Hugging Face token from environment variable
HF_TOKEN = os.environ.get("HF_TOKEN", None)
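# HF_TOKEN is optional: it is forwarded to gr_Interface_load below (and as
# `token=` at inference time), presumably so private or gated models can be
# reached; public models should work without it.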

# Load models
def load_models(models):
    global models_loaded
    models_loaded = {}
    for model in models:
        try:
            print(f"Loading model: {model}")
            m = gr_Interface_load(f'models/{model}', hf_token=HF_TOKEN)
            print(f"Successfully loaded model: {model}")
            models_loaded[model] = m
        except Exception as e:
            print(f"Error loading model {model}: {e}")
            models_loaded[model] = None

print("Loading models...")
load_models(models)
print("Models loaded.")

# Global configuration
num_models = 6           # number of image output slots in the UI
inference_timeout = 600  # per-model inference timeout, in seconds
MAX_SEED = 3999999999
starting_seed = randint(1941, 2024)

# Extend model choices to match the required number
def extend_choices(choices):
    return choices[:num_models] + (['NA'] * (num_models - len(choices[:num_models])))
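# Example: with num_models == 6, extend_choices(['m1', 'm2']) returns
# ['m1', 'm2', 'NA', 'NA', 'NA', 'NA']; 'NA' slots are skipped at generation time.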

# Function to perform inference asynchronously
async def infer(model_name, prompt, seed, batch_size, priority, timeout):
    try:
        kwargs = {"seed": seed, "batch_size": batch_size, "priority": priority}
        print(f"Running inference for model: {model_name} with prompt: {prompt} and seed: {seed}")
        task = asyncio.create_task(
            asyncio.to_thread(
                models_loaded[model_name].fn,
                prompt=prompt,
                **kwargs,
                token=HF_TOKEN
            )
        )
        result = await asyncio.wait_for(task, timeout=timeout)
        return result
    except Exception as e:
        print(f"Inference failed for model {model_name}: {e}")
        return None
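
# Standalone usage sketch (hypothetical model name, outside the Gradio app):
# assuming 'someuser/some-model' was loaded into models_loaded, one could run
#   image = asyncio.run(infer('someuser/some-model', 'a watercolor fox', seed=42,
#                             batch_size=1, priority='medium', timeout=60))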

# Generate an image with a single model and save it to disk
def generate_image(model_name, prompt, seed, batch_size, output_format, priority):
    if model_name == 'NA' or models_loaded.get(model_name) is None:
        print(f"Skipping model: {model_name} (not available or 'NA')")
        return None
    # This runs in Gradio's synchronous worker thread, which has no running
    # event loop, so create a fresh one for the async call. Create it before
    # the try block so `loop` is always defined in the finally clause.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        result = loop.run_until_complete(
            infer(model_name, prompt, seed, batch_size, priority, inference_timeout)
        )
        if result:
            # Use a per-model filename: a single shared "output.png" would be
            # overwritten on each call, making every slot show the last image
            output_path = f"output_{model_name.replace('/', '_')}.{output_format.lower()}"
            result.save(output_path)
            print(f"Image saved: {output_path}")
            return output_path
    except Exception as e:
        print(f"Error generating image for model {model_name}: {e}")
    finally:
        loop.close()
    return None

# Gradio Interface
print("Creating Gradio interface...")
with gr.Blocks(theme="Nymbo/Nymbo_Theme") as demo:
    gr.HTML("<center><h1>Multi-Model Prompt-to-Image Generator</h1></center>")

    # Input area
    with gr.Tab('Generate'):
        txt_input = gr.Textbox(label='Your Prompt', lines=4)
        gen_button = gr.Button('Generate Images')
        with gr.Row():
            seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=starting_seed)
            seed_rand = gr.Button("Randomize Seed 🎲")
        batch_size = gr.Slider(label="Batch Size", minimum=1, maximum=10, step=1, value=1)
        output_format = gr.Dropdown(["PNG", "JPEG"], label="Output Format", value="PNG")
        priority = gr.Dropdown(["low", "medium", "high"], label="Priority", value="medium")
        seed_rand.click(randomize_seed, None, [seed], queue=False)

        # Outputs
        with gr.Row():
            output_images = [gr.Image(label=f"Model {i+1}") for i in range(num_models)]

        # Model selection
        with gr.Accordion("Model Selection", open=False):
            model_choice = gr.CheckboxGroup(models, label=f"Choose up to {num_models} models", value=models[:num_models])

    # Generation logic
    def generate_images(prompt, seed, batch_size, output_format, priority, selected_models):
        results = []
        selected_models = extend_choices(selected_models)
        for model in selected_models:
            result = generate_image(model, prompt, seed, batch_size, output_format, priority)
            results.append(result)
        return results

    gen_button.click(
        fn=generate_images,
        inputs=[txt_input, seed, batch_size, output_format, priority, model_choice],
        outputs=output_images
    )
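
    # Note: generate_images always returns num_models entries (None for 'NA'
    # or failed slots), so the list lines up positionally with output_images.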
print("Launching Gradio interface...")
demo.launch(show_api=False)