File size: 5,046 Bytes
a85be17
 
7b52fe5
bd1d32c
 
 
b4cc1c9
38fb96a
 
 
749bb52
38fb96a
 
bd1d32c
38fb96a
 
 
 
 
 
 
382d70a
38fb96a
 
 
 
e122d23
38fb96a
 
 
 
d3a8ff8
38fb96a
 
 
 
 
d3a8ff8
38fb96a
 
d3a8ff8
38fb96a
 
 
 
 
 
 
e122d23
38fb96a
 
 
 
 
 
 
 
 
d3a8ff8
38fb96a
 
d3a8ff8
38fb96a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d3a8ff8
38fb96a
 
 
 
 
 
 
b87e516
38fb96a
 
 
 
 
 
 
 
 
e369512
b87e516
38fb96a
64a6fb2
e122d23
2868311
38fb96a
 
2868311
38fb96a
2868311
38fb96a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bd1d32c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
import gradio as gr
from models import models
from PIL import Image
import requests
import uuid
import io 
import base64
from transforms import RGBTransform
import concurrent.futures
import time

# Per-model availability, keyed by model name (entries of `models`).
# Each value is a dict {'status': ..., 'error': ...}:
#   - 'available' / 'unavailable' are written by load_models() on every load,
#   - 'failed' is written by run_all_models() when generation returns an error.
# 'error' holds the error message string, or None when there is no error.
model_status = {}

def load_models():
    """
    Load every model listed in ``models`` through ``gr.load``.

    Side effect: (re)writes ``model_status[name]`` for every model name:
    ``{'status': 'available', 'error': None}`` on success, or
    ``{'status': 'unavailable', 'error': <message>}`` on failure.

    Returns the successfully loaded model objects in ``models`` order. Models
    that failed to load are omitted, so the list can be shorter than ``models``.
    """
    available = []
    for name in models:
        try:
            handle = gr.load(f'models/{name}')
        except Exception as exc:
            # One broken model must not prevent the rest from loading.
            model_status[name] = {'status': 'unavailable', 'error': str(exc)}
            print(f"Failed to load {name}: {exc}")
            continue
        available.append(handle)
        model_status[name] = {'status': 'available', 'error': None}
    return available

def _parse_hex_color(color):
    """Convert '#rrggbb' or '#rgb' (leading '#' optional) into an (r, g, b) tuple of ints.

    Raises ValueError for any other length or for non-hex digits.
    """
    digits = color.strip().lstrip('#')
    if len(digits) == 3:
        # Expand CSS shorthand, e.g. 'fa0' -> 'ffaa00'.
        digits = ''.join(ch * 2 for ch in digits)
    if len(digits) != 6:
        raise ValueError(f"Unsupported color value: {color!r}")
    return tuple(int(digits[i:i + 2], 16) for i in (0, 2, 4))


def generate_single_image(model_name, model, prompt, color=None, tint_strength=0.3, request_timeout=60):
    """
    Generate one image with ``model`` and optionally tint it toward ``color``.

    Args:
        model_name: Name echoed back in the result tuple.
        model: Callable taking the prompt. A string result is fetched from the
            omnibus-top-20 Space's ``file=`` endpoint; any other result is
            passed to ``PIL.Image.open``.
        prompt: Text prompt passed to the model.
        color: Optional hex color ('#rrggbb' or '#rgb'). None or '' disables tinting.
        tint_strength: Mix factor passed to ``RGBTransform.mix_with``.
        request_timeout: Seconds to wait for the HTTP download of a string result.

    Returns:
        Tuple ``(image, error_message, model_name)``. On success image is an RGB
        PIL image and error_message is None. Otherwise image is None and
        error_message describes the failure. Exceptions are caught and reported
        this way rather than propagated.
    """
    try:
        # Validate the tint color before the model call, so a bad value fails
        # fast instead of after an expensive generation.
        rgb_color = _parse_hex_color(color) if color else None

        out_img = model(prompt)

        if isinstance(out_img, str):
            # String results are treated as a file reference served by the Space.
            # The timeout keeps a stalled download from pinning a worker thread;
            # no stream=True, so the connection is released once .content is read.
            r = requests.get(
                f'https://omnibus-top-20.hf.space/file={out_img}',
                timeout=request_timeout,
            )
            if r.status_code != 200:
                return None, f"HTTP Error: {r.status_code}", model_name
            img = Image.open(io.BytesIO(r.content)).convert('RGB')
        else:
            img = Image.open(out_img).convert('RGB')

        if rgb_color is not None:
            img = RGBTransform().mix_with(rgb_color, factor=float(tint_strength)).applied_to(img)

        return img, None, model_name

    except Exception as e:
        return None, str(e), model_name

def _render_status_report(model_names):
    """Build the HTML status panel: one card per model, colored by its model_status entry.

    Names, statuses and error messages are HTML-escaped. Error text comes from
    arbitrary exception messages and must not be injected into the page as markup.
    """
    import html

    status_colors = {
        'available': 'green',
        'unavailable': 'red',
        'failed': 'orange',
    }
    cards = []
    for name in model_names:
        status = model_status[name]
        status_color = status_colors.get(status['status'], 'gray')
        error_html = (
            f"<p class='error'>Error: {html.escape(str(status['error']))}</p>"
            if status['error'] else ""
        )
        cards.append(f"""
        <div class='model-status'>
            <h3>{html.escape(str(name))}</h3>
            <p style='color: {status_color}'>Status: {html.escape(str(status['status']))}</p>
            {error_html}
        </div>
        """)
    return "<div class='results-grid'>" + "".join(cards) + "</div>"


def run_all_models(prompt, color=None, tint_strength=0.3, max_workers=5):
    """
    Generate one image per successfully loaded model, in parallel.

    Models are (re)loaded on every call via load_models(), which also resets
    model_status. Models whose generation returns an error are marked
    'failed' in model_status together with the error message.

    Args:
        prompt: Text prompt sent to every model.
        color: Optional hex tint color forwarded to generate_single_image.
        tint_strength: Tint mix factor forwarded to generate_single_image.
        max_workers: Thread-pool size for concurrent generation.

    Returns:
        ``(results, html_report)``. results is a list of ``(image, model_name)``
        tuples in ``models`` order, and html_report is the status panel HTML.
    """
    loaded_models = load_models()
    # load_models() skips models that fail to load, so its list can be shorter
    # than `models`. Zipping it with the full `models` list would pair models
    # with the wrong names. Use the names that actually loaded, in the same order.
    loaded_names = [name for name in models if model_status[name]['status'] == 'available']

    finished = {}
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        future_to_index = {
            executor.submit(
                generate_single_image,
                name,
                model,
                prompt,
                color,
                tint_strength,
            ): index
            for index, (name, model) in enumerate(zip(loaded_names, loaded_models))
        }

        for future in concurrent.futures.as_completed(future_to_index):
            img, error, model_name = future.result()
            if error:
                model_status[model_name]['status'] = 'failed'
                model_status[model_name]['error'] = error
            if img is not None:
                finished[future_to_index[future]] = (img, model_name)

    # as_completed yields in completion order; restore model order so the
    # gallery layout is stable between runs.
    results = [finished[index] for index in sorted(finished)]
    return results, _render_status_report(models)

# Gradio interface.
# Stylesheet for the status panel HTML built in run_all_models():
#   .results-grid - responsive grid wrapping all per-model cards,
#   .model-status - one bordered card per model,
#   .error        - the error-message line inside a card.
css = """
.results-grid {
    display: grid;
    grid-template-columns: repeat(auto-fill, minmax(250px, 1fr));
    gap: 1rem;
    padding: 1rem;
}
.model-status {
    border: 1px solid #ddd;
    padding: 1rem;
    border-radius: 4px;
}
.error {
    color: red;
    font-size: 0.9em;
    word-break: break-word;
}
"""

# Layout: prompt + button on the left, tint controls on the right,
# followed by the per-model status panel and the image gallery.
with gr.Blocks(css=css, theme="Nymbo/Nymbo_Theme") as app:
    with gr.Row():
        with gr.Column():
            inp = gr.Textbox(label="Prompt")
            btn = gr.Button("Generate from All Models")
        with gr.Column():
            col = gr.ColorPicker(label="Color Tint (Optional)")
            tint = gr.Slider(label="Tint Strength", minimum=0, maximum=1, step=0.01, value=0.30)

    status_html = gr.HTML(label="Model Status")
    gallery = gr.Gallery()

    def process_and_display(prompt, color, tint_strength):
        """Run every model and return ``(gallery_items, status_html)``.

        Gallery items are ``(image, caption)`` pairs whose caption is the model
        name, so each image can be traced back to the model that produced it.
        """
        results, html_report = run_all_models(prompt, color, tint_strength)
        gallery_items = [(img, model_name) for img, model_name in results]
        return gallery_items, html_report

    btn.click(
        process_and_display,
        inputs=[inp, col, tint],
        outputs=[gallery, status_html]
    )

app.launch()