File size: 2,627 Bytes
a963077
 
 
e878446
45f797c
 
e878446
45f797c
 
e878446
a963077
45f797c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a963077
45f797c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a963077
45f797c
 
 
 
 
 
 
 
 
 
a963077
45f797c
 
 
 
 
 
a963077
45f797c
a963077
45f797c
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import gradio as gr
from model_pipelines import load_pipelines, generate_all
from grace_eval import compute_sample_scores, plot_radar
import torch
import time
from functools import partial

# Inference-only app: disable autograd globally (for the importing thread) so
# no gradient bookkeeping is done during generation.
torch.set_grad_enabled(False)


def _cuda_unavailable():
    """Report that CUDA is not available, regardless of the actual hardware."""
    return False


# Force CPU. Device-selection code conventionally asks
# `torch.cuda.is_available()`, so that is the function that must be overridden
# for the override to have any effect.
torch.cuda.is_available = _cuda_unavailable
# Kept for backward compatibility with the original script, which set this
# attribute; anything that reads `torch.backends.cuda.is_available` still works.
torch.backends.cuda.is_available = _cuda_unavailable

class ModelLoader:
    """Process-wide singleton that loads the model pipelines on first use.

    Every call to ``ModelLoader()`` returns the same object. The pipelines are
    stored on ``models``, which stays ``None`` until :meth:`load` has run.
    """

    # The single shared instance; created on the first constructor call.
    _instance = None

    def __new__(cls):
        shared = cls._instance
        if shared is not None:
            return shared
        shared = super().__new__(cls)
        shared.models = None  # nothing loaded yet
        cls._instance = shared
        return shared

    def load(self):
        """Return the loaded pipelines, calling ``load_pipelines()`` if needed.

        The first call (while ``models`` is ``None``) runs ``load_pipelines()``
        and prints progress plus the elapsed time; later calls return the
        stored value directly.
        """
        if self.models is not None:
            return self.models

        print("🔄 Initializing models...")
        began = time.time()
        self.models = load_pipelines()
        elapsed = time.time() - began
        print(f"✅ Models loaded in {elapsed:.1f}s")
        return self.models

def create_interface():
    """Build the three-tab Gradio comparison UI and return it without launching.

    Tabs:
        * Arena: a prompt box and a button; clicking calls ``generate_all``
          with the loaded models bound as first argument and the prompt as
          second, filling three image outputs.
        * Leaderboard: a prompt box and a button; clicking computes sample
          scores for the prompt and shows the resulting radar chart.
        * Report: renders the contents of ``report.md`` from the current
          working directory, or a fallback message if it cannot be read.

    Note that the models are loaded while the interface is being built
    (``ModelLoader().load()`` is evaluated here), not on the first click.

    Returns:
        The constructed ``gr.Blocks`` application.
    """

    def _render_radar(eval_text):
        # Use plot_radar's return value when it is truthy; otherwise fall back
        # to the "radar.png" path.
        # NOTE(review): presumably plot_radar writes radar.png and returns
        # None - confirm against grace_eval.
        return plot_radar(compute_sample_scores(None, eval_text)) or "radar.png"

    with gr.Blocks(title="🖼️ AI Image Generator Comparison", theme=gr.themes.Soft()) as demo:
        gr.Markdown("## 🏆 图像生成模型对比实验 (CPU模式)")

        with gr.Tab("🆚 Arena"):
            with gr.Row():
                prompt = gr.Textbox(label="✨ 输入提示词", placeholder="描述您想生成的图像...")
            with gr.Row():
                generate_btn = gr.Button("🚀 生成图像", variant="primary")
            with gr.Row():
                outputs = [
                    gr.Image(label="Stable Diffusion v1.5", type="pil"),
                    gr.Image(label="Openjourney v4", type="pil"),
                    gr.Image(label="LDM 256", type="pil")
                ]
            generate_btn.click(
                # Models are loaded once here and bound as the first argument.
                partial(generate_all, ModelLoader().load()),
                inputs=prompt,
                outputs=outputs
            )

        with gr.Tab("📊 Leaderboard"):
            with gr.Column():
                eval_prompt = gr.Textbox(label="评估用提示词")
                eval_btn = gr.Button("生成雷达图")
                radar_img = gr.Image(label="GRACE评估结果")
            eval_btn.click(
                _render_radar,
                inputs=eval_prompt,
                outputs=radar_img
            )

        with gr.Tab("📝 Report"):
            # Only the file read is best-effort. Catching just I/O and decoding
            # errors keeps the fallback without swallowing KeyboardInterrupt,
            # SystemExit or unrelated programming errors.
            try:
                with open("report.md", "r", encoding="utf-8") as f:
                    report_text = f.read()
            except (OSError, UnicodeDecodeError):
                report_text = "## 实验报告\n报告加载失败"
            gr.Markdown(report_text)

    return demo

if __name__ == "__main__":
    # Request queuing is enabled with Blocks.queue(). The `enable_queue`
    # keyword of launch() is deprecated in Gradio 3.x and rejected by newer
    # releases, so it is not passed here.
    create_interface().queue().launch(
        server_name="0.0.0.0",  # listen on all network interfaces
        server_port=7860,
        show_error=True,
    )