import os
import base64
import gradio as gr
from PIL import Image
from src.util import *
from io import BytesIO
from src.pipelines import *
from threading import Thread
from dash import Dash, dcc, html, Input, Output, no_update, callback

# fig, images, dash_tunnel, whichAxisMap, and the generate_*/display_*/export_*
# helpers used below are expected to be provided by the star imports from
# src.util and src.pipelines.

app = Dash(__name__)

app.layout = html.Div(
    className="container",
    children=[
        dcc.Graph(
            id="graph", figure=fig, clear_on_unhover=True, style={"height": "90vh"}
        ),
        dcc.Tooltip(id="tooltip"),
        html.Div(id="word-emb-txt", style={"background-color": "white"}),
        html.Div(id="word-emb-vis"),
        html.Div(
            [
                html.Button(id="btn-download-image", hidden=True),
                dcc.Download(id="download-image"),
            ]
        ),
    ],
)
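

# Register the hover handler as a Dash callback. The Output/Input wiring below is
# an assumption reconstructed from the layout ids and the six values the function
# returns; the original registration is not shown in this excerpt.
@callback(
    Output("tooltip", "show"),
    Output("tooltip", "bbox"),
    Output("tooltip", "children"),
    Output("tooltip", "direction"),
    Output("word-emb-txt", "children"),
    Output("word-emb-vis", "children"),
    Input("graph", "hoverData"),
)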
def display_hover(hoverData):
    if hoverData is None:
        return False, no_update, no_update, no_update, no_update, no_update

    hover_data = hoverData["points"][0]
    bbox = hover_data["bbox"]
    direction = "left"
    index = hover_data["pointNumber"]

    children = [
        html.Img(
            src=images[index],
            style={"width": "250px"},
        ),
        html.P(
            hover_data["text"],
            style={
                "color": "black",
                "font-size": "20px",
                "text-align": "center",
                "background-color": "white",
                "margin": "5px",
            },
        ),
    ]

    emb_children = [
        html.Img(
            src=generate_word_emb_vis(hover_data["text"]),
            style={"width": "100%", "height": "25px"},
        ),
    ]

    return True, bbox, children, direction, hover_data["text"], emb_children
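

# Register the click-to-download handler. This wiring is likewise an assumption:
# clicking a point on the graph feeds the dcc.Download component declared in the
# layout above.
@callback(
    Output("download-image", "data"),
    Input("graph", "clickData"),
)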
def download_image(clickData):
    if clickData is None:
        return no_update

    click_data = clickData["points"][0]
    index = click_data["pointNumber"]
    txt = click_data["text"]

    img_encoded = images[index]
    img_decoded = base64.b64decode(img_encoded.split(",")[1])
    img = Image.open(BytesIO(img_decoded))
    img.save(f"{txt}.png")
    return dcc.send_file(f"{txt}.png")


with gr.Blocks() as demo:
    gr.Markdown("## Stable Diffusion Demo")

    with gr.Tab("Latent Space"):
        with gr.TabItem("Beginner"):
            gr.Markdown("Generate images from text.")
            with gr.Row():
                with gr.Column():
                    prompt_beginner = gr.Textbox(
                        lines=1,
                        label="Prompt",
                        value="Self-portrait oil painting, a beautiful cyborg with golden hair, 8k",
                    )
                    with gr.Row():
                        seed_beginner = gr.Slider(
                            minimum=0, maximum=100, step=1, value=14, label="Seed"
                        )
                        seed_vis_beginner = gr.Plot(
                            value=generate_seed_vis(14), label="Seed"
                        )
                    generate_images_button_beginner = gr.Button("Generate Image")
                with gr.Column():
                    images_output_beginner = gr.Image(label="Image")

            def generate_images_wrapper(prompt, seed, progress=gr.Progress()):
                images, _ = display_poke_images(
                    prompt, seed, num_inference_steps=8, poke=False, intermediate=False
                )
                return images
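
            # Assumed wiring (not shown in this excerpt): hook the button up to the
            # wrapper defined above so the Beginner tab actually generates an image.
            generate_images_button_beginner.click(
                fn=generate_images_wrapper,
                inputs=[prompt_beginner, seed_beginner],
                outputs=[images_output_beginner],
            )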
            seed_beginner.change(
                fn=generate_seed_vis, inputs=[seed_beginner], outputs=[seed_vis_beginner]
            )

        with gr.TabItem("Denoising"):
            gr.Markdown("Observe the intermediate images during denoising.")
            gr.HTML(read_html("html/denoising.html"))
            with gr.Row():
                with gr.Column():
                    prompt_denoise = gr.Textbox(
                        lines=1,
                        label="Prompt",
                        value="Self-portrait oil painting, a beautiful cyborg with golden hair, 8k",
                    )
                    num_inference_steps_denoise = gr.Slider(
                        minimum=2,
                        maximum=100,
                        step=1,
                        value=8,
                        label="Number of Inference Steps",
                    )
                    with gr.Row():
                        seed_denoise = gr.Slider(
                            minimum=0, maximum=100, step=1, value=14, label="Seed"
                        )
                        seed_vis_denoise = gr.Plot(
                            value=generate_seed_vis(14), label="Seed"
                        )
                    generate_images_button_denoise = gr.Button("Generate Images")
                with gr.Column():
                    images_output_denoise = gr.Gallery(
                        label="Images", selected_index=0, height=512
                    )
                    gif_denoise = gr.Image(label="GIF")
                    zip_output_denoise = gr.File(label="Download ZIP")

            def generate_images_wrapper(
                prompt, seed, num_inference_steps, progress=gr.Progress()
            ):
                images, _ = display_poke_images(
                    prompt, seed, num_inference_steps, poke=False, intermediate=True
                )
                fname = "denoising"
                tab_config = {
                    "Tab": "Denoising",
                    "Prompt": prompt,
                    "Number of Inference Steps": num_inference_steps,
                    "Seed": seed,
                }
                export_as_zip(images, fname, tab_config)
                progress(1, desc="Exporting as gif")
                export_as_gif(images, filename="denoising.gif")
                return images, "outputs/denoising.gif", f"outputs/{fname}.zip"
            seed_denoise.change(
                fn=generate_seed_vis, inputs=[seed_denoise], outputs=[seed_vis_denoise]
            )

        with gr.TabItem("Seeds"):
            gr.Markdown(
                "Understand how different starting points in latent space can lead to different images."
            )
            gr.HTML(read_html("html/seeds.html"))
            with gr.Row():
                with gr.Column():
                    prompt_seed = gr.Textbox(
                        lines=1,
                        label="Prompt",
                        value="Self-portrait oil painting, a beautiful cyborg with golden hair, 8k",
                    )
                    num_images_seed = gr.Slider(
                        minimum=1, maximum=100, step=1, value=5, label="Number of Seeds"
                    )
                    num_inference_steps_seed = gr.Slider(
                        minimum=2,
                        maximum=100,
                        step=1,
                        value=8,
                        label="Number of Inference Steps per Image",
                    )
                    generate_images_button_seed = gr.Button("Generate Images")
                with gr.Column():
                    images_output_seed = gr.Gallery(
                        label="Images", selected_index=0, height=512
                    )
                    zip_output_seed = gr.File(label="Download ZIP")

            generate_images_button_seed.click(
                fn=display_seed_images,
                inputs=[prompt_seed, num_inference_steps_seed, num_images_seed],
                outputs=[images_output_seed, zip_output_seed],
            )

        with gr.TabItem("Perturbations"):
            gr.Markdown("Explore different perturbations from a point in latent space.")
            gr.HTML(read_html("html/perturbations.html"))
            with gr.Row():
                with gr.Column():
                    prompt_perturb = gr.Textbox(
                        lines=1,
                        label="Prompt",
                        value="Self-portrait oil painting, a beautiful cyborg with golden hair, 8k",
                    )
                    num_images_perturb = gr.Slider(
                        minimum=0,
                        maximum=100,
                        step=1,
                        value=5,
                        label="Number of Perturbations",
                    )
                    perturbation_size_perturb = gr.Slider(
                        minimum=0,
                        maximum=1,
                        step=0.1,
                        value=0.1,
                        label="Perturbation Size",
                    )
                    num_inference_steps_perturb = gr.Slider(
                        minimum=2,
                        maximum=100,
                        step=1,
                        value=8,
                        label="Number of Inference Steps per Image",
                    )
                    with gr.Row():
                        seed_perturb = gr.Slider(
                            minimum=0, maximum=100, step=1, value=14, label="Seed"
                        )
                        seed_vis_perturb = gr.Plot(
                            value=generate_seed_vis(14), label="Seed"
                        )
                    generate_images_button_perturb = gr.Button("Generate Images")
                with gr.Column():
                    images_output_perturb = gr.Gallery(
                        label="Image", selected_index=0, height=512
                    )
                    zip_output_perturb = gr.File(label="Download ZIP")

            generate_images_button_perturb.click(
                fn=display_perturb_images,
                inputs=[
                    prompt_perturb,
                    seed_perturb,
                    num_inference_steps_perturb,
                    num_images_perturb,
                    perturbation_size_perturb,
                ],
                outputs=[images_output_perturb, zip_output_perturb],
            )
            seed_perturb.change(
                fn=generate_seed_vis, inputs=[seed_perturb], outputs=[seed_vis_perturb]
            )

        with gr.TabItem("Circular"):
            gr.Markdown(
                "Generate a circular path in latent space and observe how the images vary along the path."
            )
            gr.HTML(read_html("html/circular.html"))
            with gr.Row():
                with gr.Column():
                    prompt_circular = gr.Textbox(
                        lines=1,
                        label="Prompt",
                        value="Self-portrait oil painting, a beautiful cyborg with golden hair, 8k",
                    )
                    num_images_circular = gr.Slider(
                        minimum=1,
                        maximum=100,
                        step=1,
                        value=5,
                        label="Number of Steps around the Circle",
                    )
                    with gr.Row():
                        start_degree_circular = gr.Slider(
                            minimum=0,
                            maximum=360,
                            step=1,
                            value=0,
                            label="Start Angle",
                            info="Enter the value in degrees",
                        )
                        end_degree_circular = gr.Slider(
                            minimum=0,
                            maximum=360,
                            step=1,
                            value=360,
                            label="End Angle",
                            info="Enter the value in degrees",
                        )
                        step_size_circular = gr.Textbox(
                            label="Step Size", value=180 / 4
                        )
                    num_inference_steps_circular = gr.Slider(
                        minimum=2,
                        maximum=100,
                        step=1,
                        value=8,
                        label="Number of Inference Steps per Image",
                    )
                    with gr.Row():
                        seed_circular = gr.Slider(
                            minimum=0, maximum=100, step=1, value=14, label="Seed"
                        )
                        seed_vis_circular = gr.Plot(
                            value=generate_seed_vis(14), label="Seed"
                        )
                    generate_images_button_circular = gr.Button("Generate Images")
                with gr.Column():
                    images_output_circular = gr.Gallery(label="Image", selected_index=0)
                    gif_circular = gr.Image(label="GIF")
                    zip_output_circular = gr.File(label="Download ZIP")

            num_images_circular.change(
                fn=calculate_step_size,
                inputs=[num_images_circular, start_degree_circular, end_degree_circular],
                outputs=[step_size_circular],
            )
            start_degree_circular.change(
                fn=calculate_step_size,
                inputs=[num_images_circular, start_degree_circular, end_degree_circular],
                outputs=[step_size_circular],
            )
            end_degree_circular.change(
                fn=calculate_step_size,
                inputs=[num_images_circular, start_degree_circular, end_degree_circular],
                outputs=[step_size_circular],
            )
            # Note: a Gradio component cannot be combined with arithmetic inside the
            # inputs list, so the extra step that closes the circle ("+ 1") is applied
            # in a small wrapper instead of writing `num_images_circular + 1` as an input.
            generate_images_button_circular.click(
                fn=lambda prompt, seed, steps, num_images, start, end: display_circular_images(
                    prompt, seed, steps, num_images + 1, start, end
                ),
                inputs=[
                    prompt_circular,
                    seed_circular,
                    num_inference_steps_circular,
                    num_images_circular,
                    start_degree_circular,
                    end_degree_circular,
                ],
                outputs=[images_output_circular, gif_circular, zip_output_circular],
            )
            seed_circular.change(
                fn=generate_seed_vis, inputs=[seed_circular], outputs=[seed_vis_circular]
            )

        with gr.TabItem("Poke"):
            gr.Markdown("Perturb a region in the image and observe the effect.")
            gr.HTML(read_html("html/poke.html"))
            with gr.Row():
                with gr.Column():
                    prompt_poke = gr.Textbox(
                        lines=1,
                        label="Prompt",
                        value="Self-portrait oil painting, a beautiful cyborg with golden hair, 8k",
                    )
                    num_inference_steps_poke = gr.Slider(
                        minimum=2,
                        maximum=100,
                        step=1,
                        value=8,
                        label="Number of Inference Steps per Image",
                    )
                    with gr.Row():
                        seed_poke = gr.Slider(
                            minimum=0, maximum=100, step=1, value=14, label="Seed"
                        )
                        seed_vis_poke = gr.Plot(
                            value=generate_seed_vis(14), label="Seed"
                        )
                    pokeX = gr.Slider(
                        label="pokeX",
                        minimum=0,
                        maximum=64,
                        step=1,
                        value=32,
                        info="X coordinate of poke center",
                    )
                    pokeY = gr.Slider(
                        label="pokeY",
                        minimum=0,
                        maximum=64,
                        step=1,
                        value=32,
                        info="Y coordinate of poke center",
                    )
                    pokeHeight = gr.Slider(
                        label="pokeHeight",
                        minimum=0,
                        maximum=64,
                        step=1,
                        value=8,
                        info="Height of the poke",
                    )
                    pokeWidth = gr.Slider(
                        label="pokeWidth",
                        minimum=0,
                        maximum=64,
                        step=1,
                        value=8,
                        info="Width of the poke",
                    )
                    generate_images_button_poke = gr.Button("Generate Images")
                with gr.Column():
                    original_images_output_poke = gr.Image(
                        value=visualize_poke(32, 32, 8, 8)[0], label="Original Image"
                    )
                    poked_images_output_poke = gr.Image(
                        value=visualize_poke(32, 32, 8, 8)[1], label="Poked Image"
                    )
                    zip_output_poke = gr.File(label="Download ZIP")

            pokeX.change(
                visualize_poke,
                inputs=[pokeX, pokeY, pokeHeight, pokeWidth],
                outputs=[original_images_output_poke, poked_images_output_poke],
            )
            pokeY.change(
                visualize_poke,
                inputs=[pokeX, pokeY, pokeHeight, pokeWidth],
                outputs=[original_images_output_poke, poked_images_output_poke],
            )
            pokeHeight.change(
                visualize_poke,
                inputs=[pokeX, pokeY, pokeHeight, pokeWidth],
                outputs=[original_images_output_poke, poked_images_output_poke],
            )
            pokeWidth.change(
                visualize_poke,
                inputs=[pokeX, pokeY, pokeHeight, pokeWidth],
                outputs=[original_images_output_poke, poked_images_output_poke],
            )
            seed_poke.change(
                fn=generate_seed_vis, inputs=[seed_poke], outputs=[seed_vis_poke]
            )

            def generate_images_wrapper(
                prompt,
                seed,
                num_inference_steps,
                pokeX=pokeX,
                pokeY=pokeY,
                pokeHeight=pokeHeight,
                pokeWidth=pokeWidth,
            ):
                _, _ = display_poke_images(
                    prompt,
                    seed,
                    num_inference_steps,
                    poke=True,
                    pokeX=pokeX,
                    pokeY=pokeY,
                    pokeHeight=pokeHeight,
                    pokeWidth=pokeWidth,
                    intermediate=False,
                )
                images, modImages = visualize_poke(pokeX, pokeY, pokeHeight, pokeWidth)
                fname = "poke"
                tab_config = {
                    "Tab": "Poke",
                    "Prompt": prompt,
                    "Number of Inference Steps per Image": num_inference_steps,
                    "Seed": seed,
                    "PokeX": pokeX,
                    "PokeY": pokeY,
                    "PokeHeight": pokeHeight,
                    "PokeWidth": pokeWidth,
                }
                imgs_list = []
                imgs_list.append((images, "Original Image"))
                imgs_list.append((modImages, "Poked Image"))
                export_as_zip(imgs_list, fname, tab_config)
                return images, modImages, f"outputs/{fname}.zip"
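
            # Assumed wiring (not shown in this excerpt): connect the button to the
            # poke wrapper defined above.
            generate_images_button_poke.click(
                fn=generate_images_wrapper,
                inputs=[
                    prompt_poke,
                    seed_poke,
                    num_inference_steps_poke,
                    pokeX,
                    pokeY,
                    pokeHeight,
                    pokeWidth,
                ],
                outputs=[
                    original_images_output_poke,
                    poked_images_output_poke,
                    zip_output_poke,
                ],
            )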
| with gr.TabItem("Guidance"): | |
| gr.Markdown("Observe the effect of different guidance scales.") | |
| gr.HTML(read_html("html/guidance.html")) | |
| with gr.Row(): | |
| with gr.Column(): | |
| prompt_guidance = gr.Textbox( | |
| lines=1, | |
| label="Prompt", | |
| value="Self-portrait oil painting, a beautiful cyborg with golden hair, 8k", | |
| ) | |
| num_inference_steps_guidance = gr.Slider( | |
| minimum=2, | |
| maximum=100, | |
| step=1, | |
| value=8, | |
| label="Number of Inference Steps per Image", | |
| ) | |
| guidance_scale_values = gr.Textbox( | |
| lines=1, value="1, 8, 20, 30", label="Guidance Scale Values" | |
| ) | |
| with gr.Row(): | |
| seed_guidance = gr.Slider( | |
| minimum=0, maximum=100, step=1, value=14, label="Seed" | |
| ) | |
| seed_vis_guidance = gr.Plot( | |
| value=generate_seed_vis(14), label="Seed" | |
| ) | |
| generate_images_button_guidance = gr.Button("Generate Images") | |
| with gr.Column(): | |
| images_output_guidance = gr.Gallery( | |
| label="Images", selected_index=0, | |
| height=512, | |
| ) | |
| zip_output_guidance = gr.File(label="Download ZIP") | |
| generate_images_button_guidance.click( | |
| fn=display_guidance_images, | |
| inputs=[ | |
| prompt_guidance, | |
| seed_guidance, | |
| num_inference_steps_guidance, | |
| guidance_scale_values, | |
| ], | |
| outputs=[images_output_guidance, zip_output_guidance], | |
| ) | |
| seed_guidance.change( | |
| fn=generate_seed_vis, inputs=[seed_guidance], outputs=[seed_vis_guidance] | |
| ) | |
| with gr.TabItem("Inpainting"): | |
| gr.Markdown("Inpaint the image based on the prompt.") | |
| gr.HTML(read_html("html/inpainting.html")) | |
| with gr.Row(): | |
| with gr.Column(): | |
| uploaded_img_inpaint = gr.Sketchpad( | |
| sources="upload", brush=gr.Brush(colors=["#ffff00"]), type="pil", label="Upload" | |
| ) | |
| prompt_inpaint = gr.Textbox( | |
| lines=1, label="Prompt", value="sunglasses" | |
| ) | |
| num_inference_steps_inpaint = gr.Slider( | |
| minimum=2, | |
| maximum=100, | |
| step=1, | |
| value=8, | |
| label="Number of Inference Steps per Image", | |
| ) | |
| with gr.Row(): | |
| seed_inpaint = gr.Slider( | |
| minimum=0, maximum=100, step=1, value=14, label="Seed" | |
| ) | |
| seed_vis_inpaint = gr.Plot( | |
| value=generate_seed_vis(14), label="Seed" | |
| ) | |
| inpaint_button = gr.Button("Inpaint") | |
| with gr.Column(): | |
| images_output_inpaint = gr.Image(label="Output") | |
| zip_output_inpaint = gr.File(label="Download ZIP") | |
| inpaint_button.click( | |
| fn=inpaint, | |
| inputs=[ | |
| uploaded_img_inpaint, | |
| num_inference_steps_inpaint, | |
| seed_inpaint, | |
| prompt_inpaint, | |
| ], | |
| outputs=[images_output_inpaint, zip_output_inpaint], | |
| ) | |
| seed_inpaint.change( | |
| fn=generate_seed_vis, inputs=[seed_inpaint], outputs=[seed_vis_inpaint] | |
| ) | |
| with gr.Tab("CLIP Space"): | |
| with gr.TabItem("Embeddings"): | |
| gr.Markdown( | |
| "Visualize text embedding space in 3D with input texts and output images based on the chosen axis." | |
| ) | |
| gr.HTML(read_html("html/embeddings.html")) | |
| with gr.Row(): | |
| output = gr.HTML( | |
| f""" | |
| <iframe id="html" src="{dash_tunnel}" style="width:100%; height:700px;"></iframe> | |
| """ | |
| ) | |
| with gr.Row(): | |
| word2add_rem = gr.Textbox(lines=1, label="Add/Remove word") | |
| word2change = gr.Textbox(lines=1, label="Change image for word") | |
| clear_words_button = gr.Button(value="Clear words") | |
| with gr.Accordion("Custom Semantic Dimensions", open=False): | |
| with gr.Row(): | |
| axis_name_1 = gr.Textbox(label="Axis name", value="gender") | |
| which_axis_1 = gr.Dropdown( | |
| choices=["X - Axis", "Y - Axis", "Z - Axis", "---"], | |
| value=whichAxisMap["which_axis_1"], | |
| label="Axis direction", | |
| ) | |
| from_words_1 = gr.Textbox( | |
| lines=1, | |
| label="Positive", | |
| value="prince husband father son uncle", | |
| ) | |
| to_words_1 = gr.Textbox( | |
| lines=1, | |
| label="Negative", | |
| value="princess wife mother daughter aunt", | |
| ) | |
| submit_1 = gr.Button("Submit") | |
| with gr.Row(): | |
| axis_name_2 = gr.Textbox(label="Axis name", value="age") | |
| which_axis_2 = gr.Dropdown( | |
| choices=["X - Axis", "Y - Axis", "Z - Axis", "---"], | |
| value=whichAxisMap["which_axis_2"], | |
| label="Axis direction", | |
| ) | |
| from_words_2 = gr.Textbox( | |
| lines=1, label="Positive", value="man woman king queen father" | |
| ) | |
| to_words_2 = gr.Textbox( | |
| lines=1, label="Negative", value="boy girl prince princess son" | |
| ) | |
| submit_2 = gr.Button("Submit") | |
| with gr.Row(): | |
| axis_name_3 = gr.Textbox(label="Axis name", value="residual") | |
| which_axis_3 = gr.Dropdown( | |
| choices=["X - Axis", "Y - Axis", "Z - Axis", "---"], | |
| value=whichAxisMap["which_axis_3"], | |
| label="Axis direction", | |
| ) | |
| from_words_3 = gr.Textbox(lines=1, label="Positive") | |
| to_words_3 = gr.Textbox(lines=1, label="Negative") | |
| submit_3 = gr.Button("Submit") | |
| with gr.Row(): | |
| axis_name_4 = gr.Textbox(label="Axis name", value="number") | |
| which_axis_4 = gr.Dropdown( | |
| choices=["X - Axis", "Y - Axis", "Z - Axis", "---"], | |
| value=whichAxisMap["which_axis_4"], | |
| label="Axis direction", | |
| ) | |
| from_words_4 = gr.Textbox( | |
| lines=1, | |
| label="Positive", | |
| value="boys girls cats puppies computers", | |
| ) | |
| to_words_4 = gr.Textbox( | |
| lines=1, label="Negative", value="boy girl cat puppy computer" | |
| ) | |
| submit_4 = gr.Button("Submit") | |
| with gr.Row(): | |
| axis_name_5 = gr.Textbox(label="Axis name", value="royalty") | |
| which_axis_5 = gr.Dropdown( | |
| choices=["X - Axis", "Y - Axis", "Z - Axis", "---"], | |
| value=whichAxisMap["which_axis_5"], | |
| label="Axis direction", | |
| ) | |
| from_words_5 = gr.Textbox( | |
| lines=1, | |
| label="Positive", | |
| value="king queen prince princess duchess", | |
| ) | |
| to_words_5 = gr.Textbox( | |
| lines=1, label="Negative", value="man woman boy girl woman" | |
| ) | |
| submit_5 = gr.Button("Submit") | |
| with gr.Row(): | |
| axis_name_6 = gr.Textbox(label="Axis name") | |
| which_axis_6 = gr.Dropdown( | |
| choices=["X - Axis", "Y - Axis", "Z - Axis", "---"], | |
| value=whichAxisMap["which_axis_6"], | |
| label="Axis direction", | |
| ) | |
| from_words_6 = gr.Textbox(lines=1, label="Positive") | |
| to_words_6 = gr.Textbox(lines=1, label="Negative") | |
| submit_6 = gr.Button("Submit") | |
| def add_rem_word_and_clear(words): | |
| return add_rem_word(words), "" | |
| def change_word_and_clear(word): | |
| return change_word(word), "" | |
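
            # Assumed wiring (not shown in this excerpt): pressing Enter in either
            # textbox updates the embedding view served through the iframe above and
            # clears the textbox. The exact output targets are an assumption.
            word2add_rem.submit(
                fn=add_rem_word_and_clear,
                inputs=[word2add_rem],
                outputs=[output, word2add_rem],
            )
            word2change.submit(
                fn=change_word_and_clear,
                inputs=[word2change],
                outputs=[output, word2change],
            )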
            clear_words_button.click(fn=clear_words, outputs=[output])

            # The wrapper below is redefined once per axis (each later definition
            # shadows the earlier one), so each version must be registered on its
            # Submit button before the next definition appears.
            def set_axis_wrapper(axis_name, which_axis, from_words, to_words):
                for ax in whichAxisMap:
                    if whichAxisMap[ax] == which_axis:
                        whichAxisMap[ax] = "---"
                whichAxisMap["which_axis_1"] = which_axis
                return (
                    set_axis(axis_name, which_axis, from_words, to_words),
                    whichAxisMap["which_axis_2"],
                    whichAxisMap["which_axis_3"],
                    whichAxisMap["which_axis_4"],
                    whichAxisMap["which_axis_5"],
                    whichAxisMap["which_axis_6"],
                )
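
            # Assumed wiring (not shown in this excerpt): register this definition on
            # submit_1; the output list mirrors the tuple the wrapper returns.
            submit_1.click(
                fn=set_axis_wrapper,
                inputs=[axis_name_1, which_axis_1, from_words_1, to_words_1],
                outputs=[
                    output,
                    which_axis_2,
                    which_axis_3,
                    which_axis_4,
                    which_axis_5,
                    which_axis_6,
                ],
            )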
            def set_axis_wrapper(axis_name, which_axis, from_words, to_words):
                for ax in whichAxisMap:
                    if whichAxisMap[ax] == which_axis:
                        whichAxisMap[ax] = "---"
                whichAxisMap["which_axis_2"] = which_axis
                return (
                    set_axis(axis_name, which_axis, from_words, to_words),
                    whichAxisMap["which_axis_1"],
                    whichAxisMap["which_axis_3"],
                    whichAxisMap["which_axis_4"],
                    whichAxisMap["which_axis_5"],
                    whichAxisMap["which_axis_6"],
                )
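
            # Assumed wiring for the second axis.
            submit_2.click(
                fn=set_axis_wrapper,
                inputs=[axis_name_2, which_axis_2, from_words_2, to_words_2],
                outputs=[
                    output,
                    which_axis_1,
                    which_axis_3,
                    which_axis_4,
                    which_axis_5,
                    which_axis_6,
                ],
            )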
            def set_axis_wrapper(axis_name, which_axis, from_words, to_words):
                for ax in whichAxisMap:
                    if whichAxisMap[ax] == which_axis:
                        whichAxisMap[ax] = "---"
                whichAxisMap["which_axis_3"] = which_axis
                return (
                    set_axis(axis_name, which_axis, from_words, to_words),
                    whichAxisMap["which_axis_1"],
                    whichAxisMap["which_axis_2"],
                    whichAxisMap["which_axis_4"],
                    whichAxisMap["which_axis_5"],
                    whichAxisMap["which_axis_6"],
                )
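
            # Assumed wiring for the third axis.
            submit_3.click(
                fn=set_axis_wrapper,
                inputs=[axis_name_3, which_axis_3, from_words_3, to_words_3],
                outputs=[
                    output,
                    which_axis_1,
                    which_axis_2,
                    which_axis_4,
                    which_axis_5,
                    which_axis_6,
                ],
            )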
            def set_axis_wrapper(axis_name, which_axis, from_words, to_words):
                for ax in whichAxisMap:
                    if whichAxisMap[ax] == which_axis:
                        whichAxisMap[ax] = "---"
                whichAxisMap["which_axis_4"] = which_axis
                return (
                    set_axis(axis_name, which_axis, from_words, to_words),
                    whichAxisMap["which_axis_1"],
                    whichAxisMap["which_axis_2"],
                    whichAxisMap["which_axis_3"],
                    whichAxisMap["which_axis_5"],
                    whichAxisMap["which_axis_6"],
                )
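
            # Assumed wiring for the fourth axis.
            submit_4.click(
                fn=set_axis_wrapper,
                inputs=[axis_name_4, which_axis_4, from_words_4, to_words_4],
                outputs=[
                    output,
                    which_axis_1,
                    which_axis_2,
                    which_axis_3,
                    which_axis_5,
                    which_axis_6,
                ],
            )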
            def set_axis_wrapper(axis_name, which_axis, from_words, to_words):
                for ax in whichAxisMap:
                    if whichAxisMap[ax] == which_axis:
                        whichAxisMap[ax] = "---"
                whichAxisMap["which_axis_5"] = which_axis
                return (
                    set_axis(axis_name, which_axis, from_words, to_words),
                    whichAxisMap["which_axis_1"],
                    whichAxisMap["which_axis_2"],
                    whichAxisMap["which_axis_3"],
                    whichAxisMap["which_axis_4"],
                    whichAxisMap["which_axis_6"],
                )
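
            # Assumed wiring for the fifth axis.
            submit_5.click(
                fn=set_axis_wrapper,
                inputs=[axis_name_5, which_axis_5, from_words_5, to_words_5],
                outputs=[
                    output,
                    which_axis_1,
                    which_axis_2,
                    which_axis_3,
                    which_axis_4,
                    which_axis_6,
                ],
            )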
            def set_axis_wrapper(axis_name, which_axis, from_words, to_words):
                for ax in whichAxisMap:
                    if whichAxisMap[ax] == which_axis:
                        whichAxisMap[ax] = "---"
                whichAxisMap["which_axis_6"] = which_axis
                return (
                    set_axis(axis_name, which_axis, from_words, to_words),
                    whichAxisMap["which_axis_1"],
                    whichAxisMap["which_axis_2"],
                    whichAxisMap["which_axis_3"],
                    whichAxisMap["which_axis_4"],
                    whichAxisMap["which_axis_5"],
                )
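
            # Assumed wiring for the sixth axis.
            submit_6.click(
                fn=set_axis_wrapper,
                inputs=[axis_name_6, which_axis_6, from_words_6, to_words_6],
                outputs=[
                    output,
                    which_axis_1,
                    which_axis_2,
                    which_axis_3,
                    which_axis_4,
                    which_axis_5,
                ],
            )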
| with gr.TabItem("Interpolate"): | |
| gr.Markdown( | |
| "Interpolate between the first and the second prompt, and observe how the output changes." | |
| ) | |
| gr.HTML(read_html("html/interpolate.html")) | |
| with gr.Row(): | |
| with gr.Column(): | |
| promptA = gr.Textbox( | |
| lines=1, | |
| label="First Prompt", | |
| value="Self-portrait oil painting, a beautiful man with golden hair, 8k", | |
| ) | |
| promptB = gr.Textbox( | |
| lines=1, | |
| label="Second Prompt", | |
| value="Self-portrait oil painting, a beautiful woman with golden hair, 8k", | |
| ) | |
| num_images_interpolate = gr.Slider( | |
| minimum=0, | |
| maximum=100, | |
| step=1, | |
| value=5, | |
| label="Number of Interpolation Steps", | |
| ) | |
| num_inference_steps_interpolate = gr.Slider( | |
| minimum=2, | |
| maximum=100, | |
| step=1, | |
| value=8, | |
| label="Number of Inference Steps per Image", | |
| ) | |
| with gr.Row(): | |
| seed_interpolate = gr.Slider( | |
| minimum=0, maximum=100, step=1, value=14, label="Seed" | |
| ) | |
| seed_vis_interpolate = gr.Plot( | |
| value=generate_seed_vis(14), label="Seed" | |
| ) | |
| generate_images_button_interpolate = gr.Button("Generate Images") | |
| with gr.Column(): | |
| images_output_interpolate = gr.Gallery( | |
| label="Interpolated Images", selected_index=0, | |
| height=512, | |
| ) | |
| gif_interpolate = gr.Image(label="GIF") | |
| zip_output_interpolate = gr.File(label="Download ZIP") | |
| generate_images_button_interpolate.click( | |
| fn=display_interpolate_images, | |
| inputs=[ | |
| seed_interpolate, | |
| promptA, | |
| promptB, | |
| num_inference_steps_interpolate, | |
| num_images_interpolate, | |
| ], | |
| outputs=[ | |
| images_output_interpolate, | |
| gif_interpolate, | |
| zip_output_interpolate, | |
| ], | |
| ) | |
| seed_interpolate.change( | |
| fn=generate_seed_vis, | |
| inputs=[seed_interpolate], | |
| outputs=[seed_vis_interpolate], | |
| ) | |
| with gr.TabItem("Negative"): | |
| gr.Markdown("Observe the effect of negative prompts.") | |
| gr.HTML(read_html("html/negative.html")) | |
| with gr.Row(): | |
| with gr.Column(): | |
| prompt_negative = gr.Textbox( | |
| lines=1, | |
| label="Prompt", | |
| value="Self-portrait oil painting, a beautiful cyborg with golden hair, 8k", | |
| ) | |
| neg_prompt = gr.Textbox( | |
| lines=1, label="Negative Prompt", value="Yellow" | |
| ) | |
| num_inference_steps_negative = gr.Slider( | |
| minimum=2, | |
| maximum=100, | |
| step=1, | |
| value=8, | |
| label="Number of Inference Steps per Image", | |
| ) | |
| with gr.Row(): | |
| seed_negative = gr.Slider( | |
| minimum=0, maximum=100, step=1, value=14, label="Seed" | |
| ) | |
| seed_vis_negative = gr.Plot( | |
| value=generate_seed_vis(14), label="Seed" | |
| ) | |
| generate_images_button_negative = gr.Button("Generate Images") | |
| with gr.Column(): | |
| images_output_negative = gr.Image( | |
| label="Image without Negative Prompt" | |
| ) | |
| images_neg_output_negative = gr.Image( | |
| label="Image with Negative Prompt" | |
| ) | |
| zip_output_negative = gr.File(label="Download ZIP") | |
| seed_negative.change( | |
| fn=generate_seed_vis, inputs=[seed_negative], outputs=[seed_vis_negative] | |
| ) | |
| generate_images_button_negative.click( | |
| fn=display_negative_images, | |
| inputs=[ | |
| prompt_negative, | |
| seed_negative, | |
| num_inference_steps_negative, | |
| neg_prompt, | |
| ], | |
| outputs=[ | |
| images_output_negative, | |
| images_neg_output_negative, | |
| zip_output_negative, | |
| ], | |
| ) | |
| with gr.Tab("Credits"): | |
| gr.Markdown(""" | |
| Author: Adithya Kameswara Rao, Carnegie Mellon University. | |
| Advisor: David S. Touretzky, Carnegie Mellon University. | |
| This work was funded by a grant from NEOM Company, and by National Science Foundation award IIS-2112633. | |
| """) | |


def run_dash():
    app.run(host="127.0.0.1", port="8000")


# def run_gradio():
#     demo.queue()
#     _, _, public_url = demo.launch(share=True)
#     return public_url


if __name__ == "__main__":
    # Run the Dash app in a background thread and the Gradio demo in the main thread.
    thread = Thread(target=run_dash)
    thread.daemon = True
    thread.start()

    try:
        os.makedirs("outputs", exist_ok=True)
        demo.queue().launch(share=True)
    except KeyboardInterrupt:
        print("Server closed")