"""Gradio demo for browsing ContraCLIP latent-space traversals.

For a chosen latent code and semantic dipole, shows one pre-rendered frame of
the traversal (read from disk) and embeds the interactive latent-space plot
stored as HTML in the same latent-code folder.
"""
import html
import json
import os

import gradio as gr
from PIL import Image

# Define global paths
BASE_PATH = "ContraCLIP/experiments/wip/"
EXPERIMENT_PATH = os.path.join(
    BASE_PATH,
    "ContraCLIP_stylegan2_ffhq1024-W-K21-D128-lss_beta_0.1-eps0.1_0.2-nonlinear_css_beta_0.5-contrastive_0.5-5000-expressions",
)
LATENT_CODES_DIR = os.path.join(EXPERIMENT_PATH, "results/stylegan2_ffhq1024-4/32_0.2_6.4")
SEMANTIC_DIPOLES_FILE = os.path.join(LATENT_CODES_DIR, "semantic_dipoles.json")
DEFAULT_IMAGE = "original_image.jpg"
# Highest frame index the slider can select; frame files are named
# f"{frame_idx:06d}.jpg" inside each path_XXX directory.
MAX_FRAME_INDEX = 32

# Load semantic dipoles: a JSON list of two-element text pairs.
with open(SEMANTIC_DIPOLES_FILE, "r", encoding="utf-8") as f:
    semantic_dipoles = json.load(f)

# Dropdown labels in "pair[1] -> pair[0]" form. A label's position in this
# list is the index of the matching paths_images/path_XXX directory.
formatted_dipoles = [f"{pair[1]} -> {pair[0]}" for pair in semantic_dipoles]

# Every sub-directory of LATENT_CODES_DIR is treated as one latent code.
latent_code_folders = sorted(
    folder
    for folder in os.listdir(LATENT_CODES_DIR)
    if os.path.isdir(os.path.join(LATENT_CODES_DIR, folder))
)


def load_dipole_paths(latent_code):
    """Return expected traversal-path directory names for *latent_code*.

    Returns one name ``path_000``, ``path_001``, ... per entry found in
    ``<latent_code>/paths_images``, in numeric order. The directory is
    assumed to contain exactly those contiguously numbered entries.
    """
    latent_path = os.path.join(LATENT_CODES_DIR, latent_code, "paths_images")
    return [f"path_{i:03d}" for i in range(len(os.listdir(latent_path)))]


def display_image(latent_code, formatted_dipole, frame_idx):
    """Load one traversal frame for a latent code and semantic dipole.

    Args:
        latent_code: Name of a folder listed in ``latent_code_folders``.
        formatted_dipole: A label from ``formatted_dipoles``.
        frame_idx: Frame number. Gradio sliders are float-typed, so this may
            arrive as e.g. ``3.0``.

    Returns:
        The frame as an in-memory ``PIL.Image.Image``.

    Raises:
        gr.Error: If the latent code or dipole is unknown, or the frame file
            does not exist. Gradio displays the message in the UI.
    """
    # Accept only folders discovered at startup, so a crafted request cannot
    # steer the path outside LATENT_CODES_DIR.
    if latent_code not in latent_code_folders:
        raise gr.Error(f"Latent code '{latent_code}' not found.")

    # Reverse-map the "A -> B" label back to its index in semantic_dipoles.
    try:
        index = formatted_dipoles.index(formatted_dipole)
    except ValueError:
        raise gr.Error(
            f"Semantic dipole '{formatted_dipole}' not found in the list."
        ) from None

    path_dir = os.path.join(
        LATENT_CODES_DIR, latent_code, "paths_images", f"path_{index:03d}"
    )
    # ":06d" rejects floats, so normalise the slider value first.
    frame_image_path = os.path.join(path_dir, f"{int(frame_idx):06d}.jpg")
    if not os.path.exists(frame_image_path):
        raise gr.Error(f"Image not found: {frame_image_path}.")

    # Image.open is lazy and keeps the file open. Copy the pixels into memory
    # inside the context manager so the handle is closed right away.
    with Image.open(frame_image_path) as img:
        return img.copy()


def display_interactive_plot(latent_code):
    """Return an ``<iframe>`` that embeds the interactive latent-space plot.

    The browser fetches
    ``LATENT_CODES_DIR/<latent_code>/interactive_latent_space_<latent_code>.html``
    through Gradio's ``/file/`` route.
    """
    # URLs always use "/" regardless of the host OS path separator.
    latent_codes_url = LATENT_CODES_DIR.replace(os.sep, "/")
    file_path = (
        f"/file/{latent_codes_url}/{latent_code}/"
        f"interactive_latent_space_{latent_code}.html"
    )
    # NOTE(review): "/file/" is the Gradio 3.x file route; newer Gradio
    # versions use "/file=" and restrict which paths are served. Confirm
    # against the installed version.
    # latent_code originates from the client, so escape it before putting
    # it into markup.
    src = html.escape(file_path, quote=True)
    return (
        f'<iframe src="{src}" width="100%" height="600" '
        f'style="border: none;"></iframe>'
    )


def build_interface():
    """Build the Gradio Blocks app for browsing traversals.

    Returns:
        The configured ``gr.Blocks`` instance (not yet launched).

    Raises:
        RuntimeError: If no latent-code folders or no semantic dipoles were
            found, since the selectors would have no default value.
    """
    if not latent_code_folders:
        raise RuntimeError(f"No latent code folders found in {LATENT_CODES_DIR}")
    if not formatted_dipoles:
        raise RuntimeError(f"No semantic dipoles found in {SEMANTIC_DIPOLES_FILE}")

    with gr.Blocks() as demo:
        gr.Markdown("# ContraCLIP-based Image Editing and Visualization Demo")

        with gr.Row():
            with gr.Column():
                gr.Markdown("### Select Latent Code and Semantic Dipole")
                latent_code_dropdown = gr.Dropdown(
                    latent_code_folders,
                    label="Latent Code",
                    value=latent_code_folders[0],
                )
                semantic_dipole_dropdown = gr.Dropdown(
                    formatted_dipoles,
                    label="Semantic Dipole",
                    value=formatted_dipoles[0],
                )
                frame_slider = gr.Slider(
                    0, MAX_FRAME_INDEX, value=0, step=1, label="Frame Index"
                )

            with gr.Column():
                image_display = gr.Image(label="Image Preview")
                html_display = gr.HTML(label="Interactive Latent Space")

        image_inputs = [latent_code_dropdown, semantic_dipole_dropdown, frame_slider]

        # Each of the three selectors determines which frame is shown, so any
        # change must refresh the preview.
        for control in image_inputs:
            control.change(display_image, image_inputs, [image_display])

        # The plot depends only on the latent code.
        latent_code_dropdown.change(
            display_interactive_plot, [latent_code_dropdown], [html_display]
        )

        # Populate both outputs on page load from the components' initial
        # values, so the defaults stay in sync with the widgets.
        demo.load(display_image, image_inputs, [image_display])
        demo.load(display_interactive_plot, [latent_code_dropdown], [html_display])

    return demo


if __name__ == "__main__":
    interface = build_interface()
    interface.launch()