import gradio as gr

from grace_eval import compute_sample_scores, plot_radar
from model_pipelines import load_pipelines, generate_all
|
|
|
# Built once at import time; `compare` reads this module-level object on every
# request instead of calling `load_pipelines()` again.
pipes = load_pipelines()
|
|
|
def compare(prompt: str):
    """Generate one image per model for *prompt* and return them for the Arena tab.

    Passes the module-level ``pipes`` and the prompt to ``generate_all`` and
    picks three entries out of the result by key.

    Args:
        prompt: Text typed into the Arena textbox.

    Returns:
        A 3-tuple ``(sd_v1_5, openjourney_v4, ldm_256)``; the order matches the
        ``outputs=[out1, out2, out3]`` wiring of the "Arena" button.

    Raises:
        KeyError: If the result of ``generate_all`` lacks one of the three keys
            (assumes it is a plain mapping -- TODO confirm).
    """
    imgs = generate_all(pipes, prompt)
    return imgs["sd_v1_5"], imgs["openjourney_v4"], imgs["ldm_256"]
|
|
|
def show_leaderboard(prompt: str) -> str:
    """Score *prompt* and return the path of the radar chart image.

    Args:
        prompt: Text typed into the Leaderboard textbox.

    Returns:
        The fixed relative path ``"radar.png"``, which Gradio loads into the
        leaderboard image component.
    """
    # NOTE(review): the first argument is always None here; its meaning is not
    # visible in this file -- confirm against grace_eval.compute_sample_scores.
    scores = compute_sample_scores(None, prompt)
    # The return value is discarded; presumably plot_radar writes "radar.png"
    # into the working directory -- TODO confirm the output path.
    plot_radar(scores)
    return "radar.png"
|
|
|
# UI layout: three tabs inside a single Blocks app.
with gr.Blocks() as demo:
    gr.Markdown("# 图像生成模型对比实验")

    # Tab 1: one prompt in, one generated image per model out.
    with gr.Tab("Arena"):
        prompt = gr.Textbox(label="请输入生成文本")
        btn = gr.Button("生成图像")
        out1 = gr.Image(label="StableDiffusion v1.5")
        out2 = gr.Image(label="Openjourney v4")
        out3 = gr.Image(label="LDM 256")
        # Output order must match the tuple order returned by `compare`.
        btn.click(compare, inputs=prompt, outputs=[out1, out2, out3])

    # Tab 2: radar chart produced by `show_leaderboard` for a given prompt.
    with gr.Tab("Leaderboard"):
        pm = gr.Textbox(label="统一 prompt 用于评价")
        lb_btn = gr.Button("显示 GRACE 雷达图")
        lb_img = gr.Image()
        lb_btn.click(show_leaderboard, inputs=pm, outputs=lb_img)

    # Tab 3: static report, read once while the UI is being built.
    with gr.Tab("Report"):
        # Use a context manager so the file handle is closed right after the
        # read instead of being left for the garbage collector.
        with open("report.md", "r", encoding="utf-8") as report_file:
            markdown = report_file.read()
        gr.Markdown(markdown)

# Start the web server as soon as the module is executed.
demo.launch()