File size: 2,710 Bytes
59da1c6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eabf1a3
 
59da1c6
 
eabf1a3
59da1c6
 
 
eabf1a3
59da1c6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eabf1a3
 
 
 
 
 
 
 
59da1c6
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import io
import os
import tempfile

import gradio as gr

from qdhf_things import run_qdhf, many_pictures
from generate_examples import EXAMPLE_PROMPTS

# Absolute path to the bundled examples directory. It is resolved relative to
# this file rather than the process's current working directory, so the
# pre-rendered example assets are found wherever the app is launched from.
EXAMPLES_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "examples")

def generate_images(prompt, init_pop, total_itrs):
    """Run QDHF for *prompt* and write its final outputs to PNG files.

    Args:
        prompt: Text prompt from the textbox. A blank or missing prompt falls
            back to the textbox placeholder, "a duck crossing the street".
        init_pop: Initial population size (slider value; coerced to ``int``).
        total_itrs: Number of QDHF iterations (slider value; coerced to ``int``).

    Returns:
        A ``(archive_plot_path, images_path)`` tuple of PNG file paths: the
        archive plot from the last step yielded by ``run_qdhf``, and the
        figure produced by ``many_pictures`` for the final archive.

    Raises:
        gr.Error: If ``run_qdhf`` yields no steps, so there is no archive to show.
    """
    init_pop = int(init_pop)
    total_itrs = int(total_itrs)
    # Use the placeholder if the prompt is empty, blank, or missing.
    if not prompt or not prompt.strip():
        prompt = "a duck crossing the street"

    archive = None
    final_archive_plot = None
    for archive, plt_fig in run_qdhf(prompt, init_pop, total_itrs):
        # Save each step as soon as it is yielded, because the generator may
        # reuse or discard its figure later. Only the latest PNG is kept:
        # earlier frames were never used, and storing them all made memory
        # grow with the iteration count.
        buf = io.BytesIO()
        plt_fig.savefig(buf, format='png')
        final_archive_plot = buf.getvalue()

    if final_archive_plot is None:
        # Raise a readable UI error instead of an IndexError/NameError.
        raise gr.Error("QDHF produced no iterations; nothing to display.")

    generated_images = many_pictures(archive, prompt)

    # Write to unique temp files. Fixed names in the working directory let
    # concurrent requests overwrite each other's results.
    fd, temp_archive_file = tempfile.mkstemp(prefix="archive_plot_", suffix=".png")
    with os.fdopen(fd, 'wb') as f:
        f.write(final_archive_plot)

    fd, temp_images_file = tempfile.mkstemp(prefix="generated_images_", suffix=".png")
    os.close(fd)  # savefig reopens the file by path; don't leak the descriptor
    generated_images.savefig(temp_images_file)

    # NOTE(review): the UI routes the first value to a gr.Video component, but
    # it is a still PNG (the cached examples use .mp4). Confirm the intended
    # output format.
    return temp_archive_file, temp_images_file

def show_example(prompt):
    """Look up the pre-rendered outputs for one of the example prompts.

    The example's position in ``EXAMPLE_PROMPTS`` selects which files under
    ``EXAMPLES_DIR`` are returned. A prompt not in the list raises ValueError.

    Returns:
        ``(prompt, archive_video_path, archive_pictures_path)``.
    """
    position = EXAMPLE_PROMPTS.index(prompt)

    def in_examples(filename):
        return os.path.join(EXAMPLES_DIR, filename)

    return (
        prompt,
        in_examples(f"archive_{position}.mp4"),
        in_examples(f"archive_pics_{position}.png"),
    )

# Build the Gradio UI: a header with links, an input column (prompt and
# sliders) beside a wider output column, a Generate button wired to
# generate_images, and a clickable list of example prompts.
with gr.Blocks() as demo:
    gr.Markdown("# Quality Diversity through Human Feedback")
    gr.Markdown("[Paper](https://arxiv.org/abs/2310.12103) | [Project Website](https://liding.info/qdhf/)")
    
    # Component creation order inside these context managers determines the
    # on-screen layout.
    with gr.Row():
        # Left column: user inputs (half the width of the output column).
        with gr.Column(scale=1):
            prompt_input = gr.Textbox(label="Enter your prompt here", placeholder="a duck crossing the street")
            # Slider values are cast to int inside generate_images.
            init_pop = gr.Slider(minimum=10, maximum=300, value=200, step=10, label="Initial Population")
            total_itrs = gr.Slider(minimum=10, maximum=300, value=200, step=10, label="Total Iterations")
            generate_button = gr.Button("Generate", variant="primary")
        
        # Right column: results.
        with gr.Column(scale=2):
            # NOTE(review): the cached examples feed this a .mp4, but
            # generate_images returns a PNG path here. Confirm how gr.Video
            # handles a still image, or make the two paths agree.
            archive_output = gr.Video(label="Archive Plot")
            images_output = gr.Image(label="Generated Pictures")
    
    # Run the full QDHF pipeline on click; its two return values map, in
    # order, onto the archive and picture outputs.
    generate_button.click(generate_images, 
                          inputs=[prompt_input, init_pop, total_itrs],
                          outputs=[archive_output, images_output])
    
    # Example prompts are served from pre-rendered files via show_example.
    # With cache_examples=True, Gradio presumably precomputes these outputs
    # ahead of time (see the gr.Examples docs). prompt_input is also listed
    # as an output so clicking an example fills the textbox.
    gr.Examples(
        examples=EXAMPLE_PROMPTS,
        inputs=prompt_input,
        outputs=[prompt_input, archive_output, images_output],
        fn=show_example,
        cache_examples=True,
        label="Example Prompts"
    )

# Start the server only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()