xiehuangbao1122 committed on
Commit 45f797c · verified · 1 Parent(s): 81fb56b

Update app.py

Files changed (1)
  1. app.py +64 -31
app.py CHANGED
@@ -2,42 +2,75 @@ import gradio as gr
 from model_pipelines import load_pipelines, generate_all
 from grace_eval import compute_sample_scores, plot_radar
 import torch
+import time
+from functools import partial
 
-# 强制使用CPU
+# Force CPU and suppress warnings
+torch.set_grad_enabled(False)
 torch.backends.cuda.is_available = lambda: False
 
-pipes = load_pipelines()
+class ModelLoader:
+    _instance = None
+
+    def __new__(cls):
+        if cls._instance is None:
+            cls._instance = super().__new__(cls)
+            cls._instance.models = None
+        return cls._instance
+
+    def load(self):
+        if self.models is None:
+            print("🔄 Initializing models...")
+            start = time.time()
+            self.models = load_pipelines()
+            print(f"✅ Models loaded in {time.time()-start:.1f}s")
+        return self.models
 
-def compare(prompt):
-    imgs = generate_all(pipes, prompt)
-    return imgs["sd_v1_5"], imgs["openjourney_v4"], imgs["ldm_256"]
+def create_interface():
+    with gr.Blocks(title="🖼️ AI Image Generator Comparison", theme=gr.themes.Soft()) as demo:
+        gr.Markdown("## 🏆 图像生成模型对比实验 (CPU模式)")
+
+        with gr.Tab("🆚 Arena"):
+            with gr.Row():
+                prompt = gr.Textbox(label="✨ 输入提示词", placeholder="描述您想生成的图像...")
+            with gr.Row():
+                generate_btn = gr.Button("🚀 生成图像", variant="primary")
+            with gr.Row():
+                outputs = [
+                    gr.Image(label="Stable Diffusion v1.5", type="pil"),
+                    gr.Image(label="Openjourney v4", type="pil"),
+                    gr.Image(label="LDM 256", type="pil")
+                ]
+            generate_btn.click(
+                partial(generate_all, ModelLoader().load()),
+                inputs=prompt,
+                outputs=outputs
+            )
 
-def show_leaderboard(prompt):
-    scores = compute_sample_scores(None, prompt)
-    plot_radar(scores)
-    return "radar.png"
+        with gr.Tab("📊 Leaderboard"):
+            with gr.Column():
+                eval_prompt = gr.Textbox(label="评估用提示词")
+                eval_btn = gr.Button("生成雷达图")
+                radar_img = gr.Image(label="GRACE评估结果")
+            eval_btn.click(
+                lambda p: (plot_radar(compute_sample_scores(None, p)) or "radar.png"),
+                inputs=eval_prompt,
+                outputs=radar_img
+            )
 
-with gr.Blocks() as demo:
-    gr.Markdown("# 图像生成模型对比实验 (CPU模式)")
+        with gr.Tab("📝 Report"):
+            try:
+                with open("report.md", "r", encoding="utf-8") as f:
+                    gr.Markdown(f.read())
+            except:
+                gr.Markdown("## 实验报告\n报告加载失败")
 
-    with gr.Tab("Arena"):
-        with gr.Row():
-            prompt = gr.Textbox(label="请输入生成文本", scale=4)
-            btn = gr.Button("生成图像", scale=1)
-        with gr.Row():
-            out1 = gr.Image(label="StableDiffusion v1.5")
-            out2 = gr.Image(label="Openjourney v4")
-            out3 = gr.Image(label="LDM 256")
-        btn.click(compare, inputs=prompt, outputs=[out1, out2, out3])
+    return demo
 
-    with gr.Tab("Leaderboard"):
-        pm = gr.Textbox(label="统一 prompt 用于评价")
-        lb_btn = gr.Button("显示 GRACE 雷达图")
-        lb_img = gr.Image()
-        lb_btn.click(show_leaderboard, inputs=pm, outputs=lb_img)
-
-    with gr.Tab("Report"):
-        markdown = open("report.md", "r", encoding="utf-8").read()
-        gr.Markdown(markdown)
-
-demo.launch()
+if __name__ == "__main__":
+    create_interface().launch(
+        server_name="0.0.0.0",
+        server_port=7860,
+        show_error=True,
+        enable_queue=True
+    )
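
Note on the new Arena handler: the removed compare() wrapper unpacked the dict returned by generate_all into the three image outputs, while the new binding partial(generate_all, ModelLoader().load()) passes generate_all's return value straight to the three gr.Image components, which Gradio fills from three values in output order. If model_pipelines.generate_all still returns a dict keyed by model name, a thin adapter along these lines would keep the mapping explicit (a sketch only, reusing the key names from the removed code):

    # Sketch: adapt generate_all's dict result to the three Arena outputs.
    # Assumes generate_all still returns {"sd_v1_5": ..., "openjourney_v4": ..., "ldm_256": ...};
    # not needed if it was changed to return a 3-tuple.
    def compare(models, prompt):
        imgs = generate_all(models, prompt)
        return imgs["sd_v1_5"], imgs["openjourney_v4"], imgs["ldm_256"]

    # generate_btn.click(partial(compare, ModelLoader().load()), inputs=prompt, outputs=outputs)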