"""Gradio demo for browsing ContraCLIP latent-space traversals.

Reads pre-rendered traversal frames and a per-latent-code interactive HTML
page from an experiment results directory and serves them in a small UI.
"""

import html
import json
import os

import gradio as gr
from PIL import Image

# Global paths (relative to the current working directory).
BASE_PATH = "disentangled-image-editing-final-project/ContraCLIP/experiments/wip"
LATENT_CODES_DIR = os.path.join(BASE_PATH, "results/stylegan2_ffhq1024-4/32_0.2_6.4")
SEMANTIC_DIPOLES_FILE = os.path.join(LATENT_CODES_DIR, "semantic_dipoles.json")
DEFAULT_IMAGE = "original_image.jpg"
# Upper bound of the frame slider. NOTE(review): hard-coded; confirm it matches
# the number of frames rendered per path (files named 000000.jpg, 000001.jpg, ...).
MAX_FRAME_INDEX = 32

# Load semantic dipoles; entry i selects traversal directory paths_images/path_{i:03d}.
with open(SEMANTIC_DIPOLES_FILE, "r", encoding="utf-8") as f:
    semantic_dipoles = json.load(f)


def _dipole_label(dipole):
    """Return the display string for one semantic dipole entry.

    Strings are returned unchanged. List/tuple entries (presumably a pair of
    text prompts -- confirm against the JSON) are joined into one string, so
    Gradio, which may read a 2-element choice as a (label, value) pair, shows
    them as a single choice.
    """
    if isinstance(dipole, str):
        return dipole
    if isinstance(dipole, (list, tuple)):
        return " <-> ".join(str(part) for part in dipole)
    return str(dipole)


# Dropdown labels, index-aligned with ``semantic_dipoles``.
dipole_labels = [_dipole_label(d) for d in semantic_dipoles]

# Every sub-directory of LATENT_CODES_DIR is treated as one latent code.
latent_code_folders = sorted(
    folder
    for folder in os.listdir(LATENT_CODES_DIR)
    if os.path.isdir(os.path.join(LATENT_CODES_DIR, folder))
)


def load_dipole_paths(latent_code):
    """Return ``path_000``, ``path_001``, ... -- one name per entry found in
    ``<latent_code>/paths_images``."""
    latent_path = os.path.join(LATENT_CODES_DIR, latent_code, "paths_images")
    return [f"path_{i:03d}" for i in range(len(os.listdir(latent_path)))]


def display_image(latent_code, semantic_dipole, frame_idx):
    """Load one traversal frame from disk.

    Args:
        latent_code: Name of a folder under ``LATENT_CODES_DIR``.
        semantic_dipole: A dropdown label or a raw ``semantic_dipoles`` entry;
            its position selects ``paths_images/path_XXX``.
        frame_idx: Frame number; the file read is ``{frame_idx:06d}.jpg``.

    Returns:
        A PIL image, or ``None`` when the selection is incomplete/unknown or the
        frame file is missing. ``gr.Image`` would treat a returned string as an
        image path, so a message string is printed instead of returned.
    """
    if not latent_code or semantic_dipole is None:
        return None
    try:
        index = dipole_labels.index(_dipole_label(semantic_dipole))
    except ValueError:
        print(f"Unknown semantic dipole: {semantic_dipole!r}")
        return None

    # Slider values may arrive as floats (e.g. 3.0); the ``:06d`` format needs an int.
    frame = int(frame_idx or 0)
    frame_image_path = os.path.join(
        LATENT_CODES_DIR,
        latent_code,
        "paths_images",
        f"path_{index:03d}",
        f"{frame:06d}.jpg",
    )
    if not os.path.exists(frame_image_path):
        print(f"Image not found: {frame_image_path}")
        return None

    # Image.open is lazy and keeps the file open; copy the pixels so the
    # handle is closed when the ``with`` exits.
    with Image.open(frame_image_path) as img:
        return img.copy()


def display_interactive_plot(latent_code):
    """Return HTML that embeds the latent code's interactive latent-space page.

    The page is read from
    ``<latent_code>/interactive_latent_space_<latent_code>.html``. If it is
    missing, a plain "not found" message is returned instead.
    """
    if not latent_code:
        return ""
    html_file = os.path.join(
        LATENT_CODES_DIR, latent_code, f"interactive_latent_space_{latent_code}.html"
    )
    if not os.path.exists(html_file):
        return f"Interactive file not found: {html_file}"
    with open(html_file, "r", encoding="utf-8") as file:
        page = file.read()
    # gr.HTML inserts markup without executing <script> tags, which an
    # interactive page presumably relies on. Rendering the page as an iframe
    # srcdoc runs it as its own document; escaping keeps it inside the attribute.
    return (
        f'<iframe srcdoc="{html.escape(page, quote=True)}" '
        'style="width:100%;height:600px;border:none;"></iframe>'
    )


def build_interface():
    """Build and return the Gradio Blocks demo (not launched)."""
    # Guard against an empty results directory or an empty dipole list.
    first_latent = latent_code_folders[0] if latent_code_folders else None
    first_dipole = dipole_labels[0] if dipole_labels else None

    with gr.Blocks() as demo:
        gr.Markdown("# ContraCLIP-based Image Editing and Visualization Demo")
        with gr.Row():
            with gr.Column():
                gr.Markdown("### Select Latent Code and Semantic Dipole")
                latent_code_dropdown = gr.Dropdown(
                    latent_code_folders,
                    label="Latent Code",
                    value=first_latent,
                )
                semantic_dipole_dropdown = gr.Dropdown(
                    dipole_labels,
                    label="Semantic Dipole",
                    value=first_dipole,
                )
                frame_slider = gr.Slider(
                    0, MAX_FRAME_INDEX, step=1, value=0, label="Frame Index"
                )
            with gr.Column():
                image_display = gr.Image(label="Image Preview")
                html_display = gr.HTML(label="Interactive Latent Space")

        # The frame depends on all three inputs, so refresh it when any of
        # them changes (not only the slider).
        image_inputs = [latent_code_dropdown, semantic_dipole_dropdown, frame_slider]
        for component in image_inputs:
            component.change(display_image, image_inputs, [image_display])

        latent_code_dropdown.change(
            display_interactive_plot, [latent_code_dropdown], [html_display]
        )

        def initial_view():
            """Populate both outputs for the default selection on page load."""
            return (
                display_image(first_latent, first_dipole, 0),
                display_interactive_plot(first_latent),
            )

        demo.load(initial_view, inputs=[], outputs=[image_display, html_display])

    return demo


if __name__ == "__main__":
    interface = build_interface()
    interface.launch()