Abinivesh's picture
Update app.py
fca3535 verified
raw
history blame
4.05 kB
import asyncio
import os
import tempfile
from pathlib import Path
from random import randint
from threading import RLock

import gradio as gr
import torch

from all_models import models
from externalmod import gr_Interface_load, randomize_seed
# Re-entrant lock; infer() holds it while writing a generated image to disk.
lock = RLock()
# Hugging Face access token taken from the environment; None when the
# HF_TOKEN variable is not set.
HF_TOKEN = os.environ.get("HF_TOKEN")
# Function to load models with optimized settings
def load_fn(models):
    """Load each model id in *models* into the module-level ``models_load`` dict.

    Every model is loaded via ``gr_Interface_load(f'models/{model}', ...)``.
    On failure the error is printed and the entry is stored as ``None``, which
    ``infer`` later reports as "unavailable". Duplicate ids are loaded once.

    Args:
        models: iterable of model ids.
    """
    global models_load
    models_load = {}
    for model in models:
        if model in models_load:
            continue  # duplicate id, already handled
        try:
            print(f"Loading model: {model}")
            m = gr_Interface_load(
                f'models/{model}',
                hf_token=HF_TOKEN,
                # NOTE(review): confirm gr_Interface_load accepts torch_dtype.
                torch_dtype=torch.float16,  # Reduce memory usage
            )
            # CPU offload is an optional optimisation. Only call it when the
            # loaded object actually provides it: an unconditional call raised
            # AttributeError for objects without the method, and that discarded
            # an otherwise successfully loaded model.
            offload = getattr(m, "enable_model_cpu_offload", None)
            if callable(offload):
                offload()  # Offload to CPU when not in use
            models_load[model] = m
        except Exception as e:
            print(f"Error loading model {model}: {e}")
            models_load[model] = None
print("Loading models...")
load_fn(models)
# load_fn stores None for every model that failed, so report the real outcome
# instead of unconditionally claiming success.
_loaded_count = sum(m is not None for m in models_load.values())
if _loaded_count == len(models_load):
    print("Models loaded successfully.")
else:
    print(f"Loaded {_loaded_count} of {len(models_load)} models; see errors above.")
# Constants
num_models = 1  # number of output image slots built in the UI
starting_seed = randint(1941, 2024)  # initial value of the seed slider
MAX_SEED = 3999999999  # upper bound of the seed slider
inference_timeout = 600  # per-request inference limit; presumably seconds — confirm where consumed
# Update UI components
def extend_choices(choices):
    """Return *choices* fitted to exactly ``num_models`` slots.

    Entries beyond ``num_models`` are dropped; missing slots are filled with
    the 'NA' placeholder.
    """
    head = choices[:num_models]
    shortfall = num_models - len(choices)
    # A zero or negative shortfall multiplies to an empty list: no padding.
    return head + shortfall * ['NA']
def update_imgbox(choices):
    """Build one empty gr.Image per slot, hiding the 'NA' placeholder slots."""
    slots = extend_choices(choices)
    return [gr.Image(None, label=slot, visible=slot != 'NA') for slot in slots]
# Async inference function
async def infer(model_str, prompt, seed=1, timeout=None):
    """Run one inference for *model_str* and return the saved image's path.

    Args:
        model_str: key into the module-level ``models_load`` dict.
        prompt: text prompt forwarded to the model's ``fn``.
        seed: forwarded to the model's ``fn`` as the ``seed`` keyword.
        timeout: seconds to wait for the model; ``None`` (default) uses the
            module-level ``inference_timeout``.

    Returns:
        Absolute path (str) of a PNG holding the result, or ``None`` if the
        model is unavailable, the call failed or timed out, or it returned None.
    """
    model = models_load.get(model_str)
    if model is None:
        print(f"Model {model_str} is unavailable.")
        return None
    if timeout is None:
        timeout = inference_timeout
    kwargs = {"seed": seed}
    try:
        print(f"Running inference for model: {model_str} with prompt: '{prompt}'")
        # The model call blocks, so run it in a worker thread and bound the wait.
        result = await asyncio.wait_for(
            asyncio.to_thread(model.fn, prompt=prompt, **kwargs, token=HF_TOKEN),
            timeout=timeout,
        )
        if result is not None:
            with lock:
                # One file per request. A single shared "image.png" could be
                # overwritten by a concurrent request before Gradio reads it.
                fd, png_path = tempfile.mkstemp(suffix=".png")
                os.close(fd)
                result.save(png_path)
            return str(Path(png_path).resolve())
    except asyncio.TimeoutError:
        print(f"Inference for {model_str} timed out after {timeout} s.")
    except torch.cuda.OutOfMemoryError:
        print(f"CUDA memory error for {model_str}. Try reducing image size.")
    except Exception as e:
        print(f"Error during inference for {model_str}: {e}")
    return None
# Synchronous wrapper
def gen_fnseed(model_str, prompt, seed=1):
    """Synchronously run ``infer`` for one model (the Gradio button callback).

    Returns whatever ``infer`` returns. Returns ``None`` for the 'NA'
    placeholder slot, or when running ``infer`` raises (the error is printed).
    """
    if model_str == 'NA':
        return None
    # Create the loop before the try so the finally clause can never hit an
    # unbound `loop` (which would raise UnboundLocalError and mask the error).
    loop = asyncio.new_event_loop()
    try:
        asyncio.set_event_loop(loop)
        return loop.run_until_complete(infer(model_str, prompt, seed))
    except Exception as e:
        print(f"Error generating image for {model_str}: {e}")
        return None
    finally:
        loop.close()
# Gradio UI
print("Creating Gradio interface...")
with gr.Blocks(theme="Nymbo/Nymbo_Theme") as demo:
    gr.HTML("<center><h1>Compare-6</h1></center>")
    with gr.Tab('Compare-6'):
        txt_input = gr.Textbox(label='Your prompt:', lines=4)
        # Derive the count from num_models so the label matches the number of
        # image slots actually built below (it previously claimed 6).
        _plural = '' if num_models == 1 else 's'
        gen_button = gr.Button(f'Generate up to {num_models} image{_plural}')
        seed = gr.Slider(label="Seed (0 to MAX)", minimum=0, maximum=MAX_SEED, value=starting_seed)
        seed_rand = gr.Button("Randomize Seed 🎲")
        seed_rand.click(randomize_seed, None, [seed], queue=False)
        # One output image and one hidden model-id textbox per slot.
        output = [gr.Image(label=m) for m in models[:num_models]]
        current_models = [gr.Textbox(m, visible=False) for m in models[:num_models]]
        for m, o in zip(current_models, output):
            gen_button.click(gen_fnseed, inputs=[m, txt_input, seed], outputs=[o], queue=False)
        with gr.Accordion('Model selection'):
            model_choice = gr.CheckboxGroup(models, label=f'Choose up to {num_models} models')
            model_choice.change(update_imgbox, model_choice, output)
            model_choice.change(extend_choices, model_choice, current_models)
demo.queue(default_concurrency_limit=20, max_size=50)  # Adjusted for better stability
demo.launch(show_api=False)