Spaces:
Runtime error
Runtime error
File size: 3,714 Bytes
aac338f 0824d60 7e9b981 0824d60 aac338f c61580a 0824d60 e95bdda 4ff15b7 7e9b981 a9293ca 0824d60 e95bdda 7e9b981 e95bdda 7e9b981 a9293ca e95bdda 7e9b981 a9293ca e95bdda 7e9b981 c61580a 7e9b981 0824d60 4ff15b7 7e9b981 0824d60 c61580a 0824d60 7e9b981 0824d60 e95bdda 7e9b981 e95bdda 7e9b981 a9293ca e95bdda a9293ca e95bdda a9293ca bbc9212 a9293ca 0824d60 e95bdda 7e9b981 e95bdda a9293ca e95bdda 7e9b981 e95bdda 7e9b981 bbc9212 7e9b981 a9293ca 0824d60 e95bdda 7e9b981 0824d60 c61580a 0824d60 7e9b981 e95bdda a9293ca e95bdda a9293ca 7e9b981 e95bdda 7e9b981 aac338f 7e9b981 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 |
import asyncio
import os
import tempfile
from random import randint
from threading import RLock

import gradio as gr

from all_models import models
from externalmod import gr_Interface_load, randomize_seed
# Module-level re-entrant lock.
lock = RLock()
# Hugging Face token from the environment; unset or empty becomes None.
HF_TOKEN = os.environ.get("HF_TOKEN") or None
def load_fn(models):
    """Populate the global ``models_load`` dict from the model names in ``models``.

    The dict is rebuilt from scratch on every call. Each distinct name maps to
    the object returned by ``gr_Interface_load``, or to ``None`` when loading
    raised or the returned object has no ``predict`` attribute. Names that
    appear more than once are only loaded the first time.
    """
    global models_load
    models_load = {}
    for name in models:
        if name in models_load:
            continue  # already attempted (duplicate entry)
        loaded = None  # stays None on any failure, so callers can skip it
        try:
            print(f"Loading model: {name}")
            candidate = gr_Interface_load(f'models/{name}', hf_token=HF_TOKEN)
            if not hasattr(candidate, 'predict'):
                raise ValueError(f"Model {name} does not have a 'predict' method.")
            print(f"Loaded model: {name}")
            loaded = candidate
        except Exception as error:
            print(f"Error loading model {name}: {error}")
        models_load[name] = loaded
print("Loading models...")
load_fn(models)
# load_fn() stores None for models that failed to load, so report the real
# count instead of an unconditional success message.
loaded_count = sum(m is not None for m in models_load.values())
print(f"Models loaded successfully: {loaded_count}/{len(models_load)}.")
if loaded_count == 0:
    print("Warning: no models could be loaded.")
num_models = 6  # number of image slots / models compared at once
default_models = models[:num_models]  # models selected when the UI opens
inference_timeout = 600  # seconds each model may spend on one prediction
MAX_SEED = 3999999999  # upper bound of the seed slider
starting_seed = randint(1941, 2024)  # initial value of the seed slider
print(f"Starting seed: {starting_seed}")
def extend_choices(choices):
    """Fit ``choices`` to exactly ``num_models`` slots.

    Keeps the first ``num_models`` entries and pads any remaining slots with
    the placeholder string 'NA'.
    """
    selected = choices[:num_models]
    missing = num_models - len(choices)  # <= 0 when already full; padding is then empty
    return selected + ['NA'] * missing
def update_imgbox(choices):
    """Build one cleared image component per slot for the new model selection.

    Each image is labelled with its model name; 'NA' placeholder slots are
    hidden.
    """
    images = []
    for label in extend_choices(choices):
        images.append(gr.Image(None, label=label, visible=label != 'NA'))
    return images
async def infer(model_str, prompt, seed=1, timeout=inference_timeout):
    """Generate one image with ``model_str`` and save it to a unique PNG file.

    Args:
        model_str: Name of the model, looked up in ``models_load``.
        prompt: Prompt passed positionally to the model's ``predict``.
        seed: Passed to ``predict`` as the ``seed`` keyword argument.
        timeout: Seconds to wait for ``predict`` before giving up.

    Returns:
        Path of the saved PNG, or ``None`` if the model is unavailable,
        ``predict`` raised or timed out, it returned a falsy result, or the
        result could not be saved.
    """
    model = models_load.get(model_str)
    if model is None:
        print(f"Model {model_str} is not available.")
        return None
    kwargs = {"seed": seed}
    print(f"Starting inference: {model_str} | Prompt: '{prompt}' | Seed: {seed}")
    # predict() is blocking, so it runs in a worker thread while we await it
    # with a deadline. Created outside the try so `task` is always bound below.
    task = asyncio.create_task(asyncio.to_thread(model.predict, prompt, **kwargs))
    try:
        result = await asyncio.wait_for(task, timeout=timeout)
    except Exception as e:
        print(f"Error during inference: {e}")
        # Cancelling only stops the asyncio wrapper; the worker thread running
        # predict() cannot be interrupted and finishes on its own.
        if not task.done():
            task.cancel()
        return None
    if not result:
        return None
    # Give every result its own file. A single fixed filename would be shared
    # by all concurrent requests, letting one model's image overwrite another's
    # before Gradio reads it back.
    fd, path = tempfile.mkstemp(prefix="image_", suffix=".png")
    os.close(fd)
    try:
        result.save(path)
    except Exception as e:
        print(f"Error saving image: {e}")
        try:
            os.remove(path)
        except OSError:
            pass  # best-effort cleanup of the unused temp file
        return None
    return path
def gen_fnseed(model_str, prompt, seed=1):
    """Run :func:`infer` to completion from synchronous code.

    Returns the saved image path, or ``None`` for the 'NA' placeholder slot,
    a model that failed to load, or any failure reported by ``infer``.
    """
    if model_str == 'NA' or models_load.get(model_str) is None:
        return None
    loop = asyncio.new_event_loop()
    try:
        return loop.run_until_complete(infer(model_str, prompt, seed))
    finally:
        # Always close the loop, even if infer() raises, so it is not leaked.
        # asyncio.run() is deliberately not used: on shutdown it waits for the
        # default executor, i.e. for a timed-out predict() thread to finish.
        loop.close()
print("Creating Gradio interface...")
# Pad the default selection to exactly num_models slots (with 'NA'
# placeholders) so the Image/Textbox lists always match the number of values
# update_imgbox() and extend_choices() return when the selection changes.
initial_choices = extend_choices(default_models)
with gr.Blocks(theme="Nymbo/Nymbo_Theme") as demo:
    gr.HTML("<center><h1>Compare-6</h1></center>")
    with gr.Tab('Compare-6'):
        txt_input = gr.Textbox(label='Your prompt:', lines=4)
        # Derived from the actual settings instead of a hard-coded "3 minutes"
        # that contradicted inference_timeout; minutes are rounded up.
        gen_button = gr.Button(
            f'Generate up to {num_models} images in up to '
            f'{(inference_timeout + 59) // 60} minutes total'
        )
        # label must be a keyword: Slider's first positional parameter is
        # `minimum`, so gr.Slider("Seed", minimum=0, ...) raises TypeError.
        seed = gr.Slider(
            label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=starting_seed
        )
        seed_rand = gr.Button("Randomize Seed 🎲")
        seed_rand.click(randomize_seed, None, [seed])
        output = [gr.Image(label=m, visible=(m != 'NA')) for m in initial_choices]
        # Hidden textboxes hold the model name each image slot is bound to.
        current_models = [gr.Textbox(m, visible=False) for m in initial_choices]
        for m, o in zip(current_models, output):
            gen_button.click(gen_fnseed, [m, txt_input, seed], o)
        with gr.Accordion('Model selection'):
            model_choice = gr.CheckboxGroup(
                models, label=f'Choose up to {num_models} models', value=default_models
            )
            model_choice.change(update_imgbox, model_choice, output)
            model_choice.change(extend_choices, model_choice, current_models)
demo.queue(default_concurrency_limit=200, max_size=200)
demo.launch(show_api=False, max_threads=400)
|