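"""Gradio demo comparing three CPU-only text-to-image pipelines.

Tabs: an Arena for side-by-side generation, a Leaderboard that renders a GRACE
radar chart for a prompt, and a static experiment report loaded from report.md.
"""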
import gradio as gr
from model_pipelines import load_pipelines, generate_all
from grace_eval import compute_sample_scores, plot_radar
import torch
import time
from functools import partial
import warnings

# Force CPU-only inference: disable autograd, make CUDA look unavailable to
# downstream libraries, and silence noisy warnings.
warnings.filterwarnings("ignore")
torch.set_grad_enabled(False)
torch.cuda.is_available = lambda: False
class ModelLoader:
    """Process-wide singleton that lazily loads the generation pipelines."""

    _instance = None

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance.models = None
        return cls._instance

    def load(self):
        if self.models is None:
            print("🔄 Initializing models...")
            start = time.time()
            self.models = load_pipelines()
            print(f"✅ Models loaded in {time.time()-start:.1f}s")
        return self.models
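# Minimal usage sketch (assumes load_pipelines() returns the pipeline collection
# that generate_all() expects as its first argument):
#   models = ModelLoader().load()   # loads once; cached for subsequent calls
#   images = generate_all(models, "a watercolor fox in a forest")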

def create_interface():
    with gr.Blocks(title="🖼️ AI Image Generator Comparison", theme=gr.themes.Soft()) as demo:
        gr.Markdown("## 🏆 Image Generation Model Comparison (CPU mode)")

        with gr.Tab("🆚 Arena"):
            with gr.Row():
                prompt = gr.Textbox(label="✨ Prompt", placeholder="Describe the image you want to generate...")
            with gr.Row():
                generate_btn = gr.Button("🚀 Generate Images", variant="primary")
            with gr.Row():
                outputs = [
                    gr.Image(label="Stable Diffusion v1.5", type="pil"),
                    gr.Image(label="Openjourney v4", type="pil"),
                    gr.Image(label="LDM 256", type="pil"),
                ]
            # Models are loaded at interface-build time, so the first click
            # does not pay the loading cost.
            generate_btn.click(
                partial(generate_all, ModelLoader().load()),
                inputs=prompt,
                outputs=outputs,
            )
with gr.Tab("📊 Leaderboard"):
with gr.Column():
eval_prompt = gr.Textbox(label="评估用提示词")
eval_btn = gr.Button("生成雷达图")
radar_img = gr.Image(label="GRACE评估结果")
eval_btn.click(
lambda p: (plot_radar(compute_sample_scores(None, p)) or "radar.png"),
inputs=eval_prompt,
outputs=radar_img
)
with gr.Tab("📝 Report"):
try:
with open("report.md", "r", encoding="utf-8") as f:
gr.Markdown(f.read())
except:
gr.Markdown("## 实验报告\n报告加载失败")
return demo

if __name__ == "__main__":
    # .queue() replaces the deprecated enable_queue=True launch argument.
    create_interface().queue().launch(
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True,
    )
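# To run locally (assuming this file is saved as app.py):
#   python app.py
# The demo is then served at http://0.0.0.0:7860.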