Update app.py
Browse files
app.py
CHANGED
@@ -2,42 +2,75 @@ import gradio as gr
|
|
2 |
from model_pipelines import load_pipelines, generate_all
|
3 |
from grace_eval import compute_sample_scores, plot_radar
|
4 |
import torch
|
|
|
|
|
5 |
|
6 |
-
#
|
|
|
7 |
torch.backends.cuda.is_available = lambda: False
|
8 |
|
9 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
|
11 |
-
def
|
12 |
-
|
13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
14 |
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
|
20 |
-
with gr.
|
21 |
-
|
|
|
|
|
|
|
|
|
22 |
|
23 |
-
|
24 |
-
with gr.Row():
|
25 |
-
prompt = gr.Textbox(label="请输入生成文本", scale=4)
|
26 |
-
btn = gr.Button("生成图像", scale=1)
|
27 |
-
with gr.Row():
|
28 |
-
out1 = gr.Image(label="StableDiffusion v1.5")
|
29 |
-
out2 = gr.Image(label="Openjourney v4")
|
30 |
-
out3 = gr.Image(label="LDM 256")
|
31 |
-
btn.click(compare, inputs=prompt, outputs=[out1, out2, out3])
|
32 |
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
markdown = open("report.md", "r", encoding="utf-8").read()
|
41 |
-
gr.Markdown(markdown)
|
42 |
-
|
43 |
-
demo.launch()
|
|
|
2 |
from model_pipelines import load_pipelines, generate_all
|
3 |
from grace_eval import compute_sample_scores, plot_radar
|
4 |
import torch
|
5 |
+
import time
|
6 |
+
from functools import partial
|
7 |
|
8 |
+
# Force CPU and suppress warnings
|
9 |
+
torch.set_grad_enabled(False)
|
10 |
torch.backends.cuda.is_available = lambda: False
|
11 |
|
12 |
+
class ModelLoader:
|
13 |
+
_instance = None
|
14 |
+
|
15 |
+
def __new__(cls):
|
16 |
+
if cls._instance is None:
|
17 |
+
cls._instance = super().__new__(cls)
|
18 |
+
cls._instance.models = None
|
19 |
+
return cls._instance
|
20 |
+
|
21 |
+
def load(self):
|
22 |
+
if self.models is None:
|
23 |
+
print("🔄 Initializing models...")
|
24 |
+
start = time.time()
|
25 |
+
self.models = load_pipelines()
|
26 |
+
print(f"✅ Models loaded in {time.time()-start:.1f}s")
|
27 |
+
return self.models
|
28 |
|
29 |
+
def create_interface():
|
30 |
+
with gr.Blocks(title="🖼️ AI Image Generator Comparison", theme=gr.themes.Soft()) as demo:
|
31 |
+
gr.Markdown("## 🏆 图像生成模型对比实验 (CPU模式)")
|
32 |
+
|
33 |
+
with gr.Tab("🆚 Arena"):
|
34 |
+
with gr.Row():
|
35 |
+
prompt = gr.Textbox(label="✨ 输入提示词", placeholder="描述您想生成的图像...")
|
36 |
+
with gr.Row():
|
37 |
+
generate_btn = gr.Button("🚀 生成图像", variant="primary")
|
38 |
+
with gr.Row():
|
39 |
+
outputs = [
|
40 |
+
gr.Image(label="Stable Diffusion v1.5", type="pil"),
|
41 |
+
gr.Image(label="Openjourney v4", type="pil"),
|
42 |
+
gr.Image(label="LDM 256", type="pil")
|
43 |
+
]
|
44 |
+
generate_btn.click(
|
45 |
+
partial(generate_all, ModelLoader().load()),
|
46 |
+
inputs=prompt,
|
47 |
+
outputs=outputs
|
48 |
+
)
|
49 |
|
50 |
+
with gr.Tab("📊 Leaderboard"):
|
51 |
+
with gr.Column():
|
52 |
+
eval_prompt = gr.Textbox(label="评估用提示词")
|
53 |
+
eval_btn = gr.Button("生成雷达图")
|
54 |
+
radar_img = gr.Image(label="GRACE评估结果")
|
55 |
+
eval_btn.click(
|
56 |
+
lambda p: (plot_radar(compute_sample_scores(None, p)) or "radar.png"),
|
57 |
+
inputs=eval_prompt,
|
58 |
+
outputs=radar_img
|
59 |
+
)
|
60 |
|
61 |
+
with gr.Tab("📝 Report"):
|
62 |
+
try:
|
63 |
+
with open("report.md", "r", encoding="utf-8") as f:
|
64 |
+
gr.Markdown(f.read())
|
65 |
+
except:
|
66 |
+
gr.Markdown("## 实验报告\n报告加载失败")
|
67 |
|
68 |
+
return demo
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
69 |
|
70 |
+
if __name__ == "__main__":
|
71 |
+
create_interface().launch(
|
72 |
+
server_name="0.0.0.0",
|
73 |
+
server_port=7860,
|
74 |
+
show_error=True,
|
75 |
+
enable_queue=True
|
76 |
+
)
|
|
|
|
|
|
|
|