# Hugging Face Spaces page status: Runtime error
import base64
import concurrent.futures
import html
import io
import os
import time
import uuid

import gradio as gr
import requests
from PIL import Image

from models import models
from transforms import RGBTransform
# Availability of each model, keyed by model name. Every value is a dict
# holding a 'status' string and an 'error' message (None when no error).
model_status = dict()
# Successfully loaded model interfaces keyed by model name. Kept at module
# level so repeated generations reuse them instead of calling gr.load() for
# every model on every click. Failed loads are not cached, so they are
# retried on the next call.
_loaded_model_cache = {}


def load_models():
    """
    Attempts to load all models and tracks their availability status.

    Models that loaded on an earlier call are reused from a module-level
    cache; models that failed before are retried. For every name in
    ``models``, ``model_status[name]`` is (re)set to
    ``{'status': 'available' | 'unavailable', 'error': None | str}``.

    Returns:
        A list of the successfully loaded models, in the same order as
        ``models``. Models that failed to load are omitted.
    """
    loaded_models = []
    for model in models:
        if model not in _loaded_model_cache:
            try:
                _loaded_model_cache[model] = gr.load(f'models/{model}')
            except Exception as e:
                # Best-effort: one broken model must not stop the others.
                model_status[model] = {'status': 'unavailable', 'error': str(e)}
                print(f"Failed to load {model}: {e}")
                continue
        loaded_models.append(_loaded_model_cache[model])
        model_status[model] = {'status': 'available', 'error': None}
    return loaded_models
# Seconds to wait for a remote image download before giving up, so a hung
# request cannot block a worker thread forever.
_DOWNLOAD_TIMEOUT = 60
# Fallback used by the original code for model outputs that are bare
# server-side file paths not present on this machine.
_REMOTE_FILE_ENDPOINT = 'https://omnibus-top-20.hf.space/file='


def _color_to_rgb(color):
    """Parse '#rrggbb', '#rrggbbaa', '#rgb' or 'rgb(a)(r, g, b[, a])' into an (r, g, b) tuple.

    Raises:
        ValueError: if the value is not in one of the supported formats.
    """
    text = color.strip()
    if text.lower().startswith('rgb'):
        inner = text[text.index('(') + 1:text.rindex(')')]
        channels = [float(part) for part in inner.split(',')[:3]]
        if len(channels) != 3:
            raise ValueError(f"Unsupported color value: {color!r}")
        return tuple(max(0, min(255, round(c))) for c in channels)
    digits = text.lstrip('#')
    if len(digits) == 3:  # CSS shorthand, e.g. 'f80' -> 'ff8800'
        digits = ''.join(ch * 2 for ch in digits)
    if len(digits) not in (6, 8):  # 8 digits = trailing alpha byte, ignored
        raise ValueError(f"Unsupported color value: {color!r}")
    return tuple(int(digits[i:i + 2], 16) for i in (0, 2, 4))


def _load_output_image(out_img):
    """Turn a model output (local path, URL, server path, or file-like) into an RGB image.

    Raises:
        RuntimeError: with message 'HTTP Error: <code>' on a non-200 download.
    """
    if not isinstance(out_img, str):
        with Image.open(out_img) as src:
            return src.convert('RGB')
    if os.path.isfile(out_img):
        # Outputs are commonly local temp-file paths; read them directly.
        with Image.open(out_img) as src:
            return src.convert('RGB')
    if out_img.startswith(('http://', 'https://')):
        url = out_img
    else:
        url = f'{_REMOTE_FILE_ENDPOINT}{out_img}'
    # No stream=True: the body is read in full, which releases the connection.
    r = requests.get(url, timeout=_DOWNLOAD_TIMEOUT)
    if r.status_code != 200:
        raise RuntimeError(f"HTTP Error: {r.status_code}")
    return Image.open(io.BytesIO(r.content)).convert('RGB')


def generate_single_image(model_name, model, prompt, color=None, tint_strength=0.3):
    """
    Generates a single image from a specific model with optional color tinting.

    Args:
        model_name: Name reported back alongside the result.
        model: Callable that takes the prompt and returns an image output.
        prompt: Text prompt passed to ``model``.
        color: Optional tint colour; None or '' means no tint.
        tint_strength: Mix factor passed to RGBTransform.mix_with.

    Returns:
        Tuple of (image, error_message, model_name). On success error_message
        is None. On any failure image is None and error_message is the
        exception text, so one failing model never aborts the others.
    """
    try:
        img = _load_output_image(model(prompt))
        if color:
            img = RGBTransform().mix_with(
                _color_to_rgb(color), factor=float(tint_strength)
            ).applied_to(img)
        return img, None, model_name
    except Exception as e:
        return None, str(e), model_name
def run_all_models(prompt, color=None, tint_strength=0.3):
    """
    Generates images from all available models in parallel.

    Args:
        prompt: Text prompt passed to every model.
        color: Optional tint colour forwarded to generate_single_image.
        tint_strength: Tint mix factor forwarded to generate_single_image.

    Returns:
        (results, html_report): results is a list of (image, model_name)
        tuples in the same order as ``models``; html_report is an HTML
        summary of every model's status, with all dynamic text escaped.
    """
    loaded_models = load_models()
    # load_models() returns only the models that loaded, but records a status
    # for every name in ``models``. Pairing zip(models, loaded_models) would
    # therefore mislabel every model after the first load failure, so pair
    # the loaded objects with the names marked 'available' instead.
    available_names = [
        name for name in models
        if model_status.get(name, {}).get('status') == 'available'
    ]
    jobs = list(zip(available_names, loaded_models))

    # Slots indexed by submission order, so the gallery follows ``models``
    # order rather than whichever request happens to finish first.
    ordered = [None] * len(jobs)
    with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
        future_to_index = {
            executor.submit(
                generate_single_image,
                name,
                model,
                prompt,
                color,
                tint_strength,
            ): index
            for index, (name, model) in enumerate(jobs)
        }
        for future in concurrent.futures.as_completed(future_to_index):
            img, error, model_name = future.result()
            if error:
                model_status[model_name]['status'] = 'failed'
                model_status[model_name]['error'] = error
            if img is not None:
                ordered[future_to_index[future]] = (img, model_name)
    results = [item for item in ordered if item is not None]

    status_colors = {
        'available': 'green',
        'unavailable': 'red',
        'failed': 'orange',
    }
    cards = []
    for model in models:
        status = model_status[model]
        status_color = status_colors.get(status['status'], 'gray')
        # Error text comes from remote services and may contain markup.
        error_html = (
            f"<p class='error'>Error: {html.escape(str(status['error']))}</p>"
            if status['error'] else ""
        )
        cards.append(f"""
<div class='model-status'>
<h3>{html.escape(str(model))}</h3>
<p style='color: {status_color}'>Status: {html.escape(str(status['status']))}</p>
{error_html}
</div>
""")
    html_report = "<div class='results-grid'>" + "".join(cards) + "</div>"
    return results, html_report
# Gradio interface
css = """
.results-grid {
    display: grid;
    grid-template-columns: repeat(auto-fill, minmax(250px, 1fr));
    gap: 1rem;
    padding: 1rem;
}
.model-status {
    border: 1px solid #ddd;
    padding: 1rem;
    border-radius: 4px;
}
.error {
    color: red;
    font-size: 0.9em;
    word-break: break-word;
}
"""

with gr.Blocks(css=css, theme="Nymbo/Nymbo_Theme") as app:
    with gr.Row():
        with gr.Column():
            inp = gr.Textbox(label="Prompt")
            btn = gr.Button("Generate from All Models")
        with gr.Column():
            col = gr.ColorPicker(label="Color Tint (Optional)")
            tint = gr.Slider(label="Tint Strength", minimum=0, maximum=1, step=0.01, value=0.30)
    status_html = gr.HTML(label="Model Status")
    gallery = gr.Gallery()

    def process_and_display(prompt, color, tint_strength):
        """Run every model and return (gallery items, status HTML) for the UI.

        Gallery items are (image, model_name) tuples, so each image is
        captioned with the model that produced it.
        """
        # An empty prompt would only send a useless request to every model.
        if not prompt or not prompt.strip():
            return [], "<p class='error'>Please enter a prompt.</p>"
        results, html_report = run_all_models(prompt, color, tint_strength)
        return list(results), html_report

    btn.click(
        process_and_display,
        inputs=[inp, col, tint],
        outputs=[gallery, status_html],
    )

# Launch only when run as a script, not when the module is imported.
if __name__ == "__main__":
    app.launch()