dattarij's picture
Update app.py
c837d24 verified
raw
history blame
3.97 kB
import json
import os

import gradio as gr
from PIL import Image
# Global paths: the ContraCLIP experiment directory whose per-latent-code
# traversal results are browsed by this demo.
BASE_PATH = "disentangled-image-editing-final-project/ContraCLIP/experiments/wip"
LATENT_CODES_DIR = os.path.join(BASE_PATH, "results/stylegan2_ffhq1024-4/32_0.2_6.4")
SEMANTIC_DIPOLES_FILE = os.path.join(LATENT_CODES_DIR, "semantic_dipoles.json")
# NOTE(review): not referenced anywhere in the visible code — confirm it is
# still needed.
DEFAULT_IMAGE = "original_image.jpg"

# Load the semantic dipoles once at import time; the list order defines the
# ``path_XXX`` index used when looking up traversal frames. JSON text is UTF-8
# by specification, so decode it as such instead of relying on the locale.
# Requires ``import json`` at the top of the module.
with open(SEMANTIC_DIPOLES_FILE, "r", encoding="utf-8") as f:
    semantic_dipoles = json.load(f)
# Each sub-directory of LATENT_CODES_DIR is treated as one latent code; the
# names are kept in lexicographic order so the dropdown and its default
# (the first entry) are stable across runs.
latent_code_folders = sorted(
    entry
    for entry in os.listdir(LATENT_CODES_DIR)
    if os.path.isdir(os.path.join(LATENT_CODES_DIR, entry))
)
def load_dipole_paths(latent_code):
    """Return ``path_XXX`` directory names for the given latent code.

    One zero-padded (three digits) name is generated for every entry found in
    ``<LATENT_CODES_DIR>/<latent_code>/paths_images``. The names are
    synthesised from the entry count rather than read from disk, and are
    returned as a sorted list.
    """
    paths_images_dir = os.path.join(LATENT_CODES_DIR, latent_code, "paths_images")
    entry_count = len(os.listdir(paths_images_dir))
    return sorted("path_{:03d}".format(i) for i in range(entry_count))
# Function to display images
def display_image(latent_code, semantic_dipole, frame_idx):
    """Load one traversal frame for a latent code along a semantic dipole.

    Args:
        latent_code: Name of a sub-directory of ``LATENT_CODES_DIR``.
        semantic_dipole: An element of ``semantic_dipoles``; its position in
            that list selects the ``path_XXX`` directory. Raises ``ValueError``
            if it is not in the list.
        frame_idx: Frame number; the file read is ``<frame_idx:06d>.jpg``.
            Coerced to ``int`` because a slider may deliver a float, which the
            ``d`` format specifier rejects.

    Returns:
        A fully loaded ``PIL.Image.Image`` on success, otherwise a message
        string naming the missing file.
    """
    index = semantic_dipoles.index(semantic_dipole)
    path_dir = os.path.join(LATENT_CODES_DIR, latent_code, "paths_images", f"path_{index:03d}")
    frame_image_path = os.path.join(path_dir, f"{int(frame_idx):06d}.jpg")
    if not os.path.exists(frame_image_path):
        # NOTE(review): gr.Image may treat a returned str as a file path, so
        # this message might surface as a load error instead of text — confirm
        # with the Gradio version in use.
        return f"Image not found: {frame_image_path}"
    # Image.open is lazy and keeps the file handle open until pixel data is
    # read; load inside a context manager and return a detached copy so the
    # handle is released here.
    with Image.open(frame_image_path) as image:
        image.load()
        return image.copy()
# Function to display GAN latent space interactive plot
def display_interactive_plot(latent_code):
    """Return the interactive latent-space HTML for ``latent_code``.

    Reads ``<LATENT_CODES_DIR>/<latent_code>/interactive_latent_space_<latent_code>.html``.

    Returns:
        The file's text, or a "not found" message string when the file does
        not exist.
    """
    html_file = os.path.join(
        LATENT_CODES_DIR, latent_code, f"interactive_latent_space_{latent_code}.html"
    )
    if not os.path.exists(html_file):
        return f"Interactive file not found: {html_file}"
    # Decode explicitly as UTF-8 (assumes the HTML was written as UTF-8) so
    # the result does not depend on the host's locale default encoding.
    # NOTE(review): gr.HTML may not execute <script> tags in inserted markup,
    # so a script-driven plot might not be interactive — confirm.
    with open(html_file, "r", encoding="utf-8") as file:
        return file.read()
# Gradio Interface
def build_interface():
    """Build and return the Gradio Blocks demo.

    Left column: latent-code and semantic-dipole dropdowns plus a frame slider
    (0-32, step 1). Right column: an image preview and the interactive
    latent-space HTML. Changing any of the three inputs refreshes the image;
    changing the latent code also refreshes the HTML. Initial content is
    rendered on page load from the first latent code, the first semantic
    dipole and frame 0.
    """
    with gr.Blocks() as demo:
        gr.Markdown("# ContraCLIP-based Image Editing and Visualization Demo")
        with gr.Row():
            with gr.Column():
                gr.Markdown("### Select Latent Code and Semantic Dipole")
                latent_code_dropdown = gr.Dropdown(
                    latent_code_folders,
                    label="Latent Code",
                    value=latent_code_folders[0],
                )
                # NOTE(review): if semantic_dipoles.json holds pairs (two-item
                # lists), gr.Dropdown may interpret each pair as a
                # (label, value) choice — confirm against the JSON contents.
                semantic_dipole_dropdown = gr.Dropdown(
                    semantic_dipoles,
                    label="Semantic Dipole",
                    value=semantic_dipoles[0],
                )
                frame_slider = gr.Slider(
                    0, 32, step=1, label="Frame Index"
                )
            with gr.Column():
                image_display = gr.Image(label="Image Preview")
                html_display = gr.HTML(label="Interactive Latent Space")

        image_inputs = [latent_code_dropdown, semantic_dipole_dropdown, frame_slider]

        # All three selectors determine which frame is shown, so each of them
        # must refresh the image, not just the slider.
        for component in image_inputs:
            component.change(display_image, image_inputs, [image_display])

        # Only the latent code determines which HTML page is shown.
        latent_code_dropdown.change(
            display_interactive_plot, [latent_code_dropdown], [html_display]
        )

        # Set up initial values
        demo.load(
            lambda: display_image(latent_code_folders[0], semantic_dipoles[0], 0),
            inputs=[],
            outputs=[image_display],
        )
        demo.load(
            lambda: display_interactive_plot(latent_code_folders[0]),
            inputs=[],
            outputs=[html_display],
        )
    return demo
if __name__ == "__main__":
    # Script entry point: build the Blocks app and start the Gradio server
    # with its default launch settings.
    build_interface().launch()