dattarij's picture
Update app.py
4590ef9 verified
raw
history blame
4.69 kB
import html
import json
import os
import pathlib
import urllib.parse

import gradio as gr
from PIL import Image
# Global paths. All are relative, so they resolve against the process's
# current working directory.
BASE_PATH: str = "ContraCLIP/experiments/wip/"
# The single trained ContraCLIP experiment this demo browses.
EXPERIMENT_PATH: str = os.path.join(
    BASE_PATH,
    "ContraCLIP_stylegan2_ffhq1024-W-K21-D128-lss_beta_0.1-eps0.1_0.2-nonlinear_css_beta_0.5-contrastive_0.5-5000-expressions",
)
# Results directory. Its sub-folders are the latent codes listed in the UI.
LATENT_CODES_DIR: str = os.path.join(
    EXPERIMENT_PATH, "results/stylegan2_ffhq1024-4/32_0.2_6.4"
)
# JSON list of semantic dipoles, read once at import time.
SEMANTIC_DIPOLES_FILE: str = os.path.join(LATENT_CODES_DIR, "semantic_dipoles.json")
# NOTE(review): not referenced by any code visible in this file.
DEFAULT_IMAGE: str = "original_image.jpg"
# Load semantic dipoles
with open(SEMANTIC_DIPOLES_FILE, "r") as f:
semantic_dipoles = json.load(f)
# Transform semantic_dipoles into "A -> B" format
formatted_dipoles = [f"{pair[1]} -> {pair[0]}" for pair in semantic_dipoles]
# Latent codes available in the UI: every sub-directory of LATENT_CODES_DIR,
# in lexicographic order. Plain files in that directory are ignored.
latent_code_folders = sorted(
    filter(
        lambda folder: os.path.isdir(os.path.join(LATENT_CODES_DIR, folder)),
        os.listdir(LATENT_CODES_DIR),
    )
)
# List the per-dipole traversal path directories available for a latent code.
def load_dipole_paths(latent_code, base_dir=None):
    """Return the sorted ``path_*`` sub-directory names for one latent code.

    Only real directories whose names start with ``path_`` are reported. Stray
    files in ``paths_images`` are ignored, so every returned name exists on
    disk.

    Args:
        latent_code: Name of a latent-code folder under ``base_dir``.
        base_dir: Results directory holding the latent-code folders. Defaults
            to ``LATENT_CODES_DIR``, looked up at call time.

    Returns:
        Sorted list of directory names such as ``["path_000", "path_001"]``.

    Raises:
        FileNotFoundError: If ``<base_dir>/<latent_code>/paths_images`` does
            not exist.
    """
    if base_dir is None:
        base_dir = LATENT_CODES_DIR
    latent_path = os.path.join(base_dir, latent_code, "paths_images")
    return sorted(
        name
        for name in os.listdir(latent_path)
        if name.startswith("path_")
        and os.path.isdir(os.path.join(latent_path, name))
    )
def display_image(latent_code, formatted_dipole, frame_idx):
    """Load one frame of a latent code's traversal along a semantic dipole.

    Args:
        latent_code: Name of a sub-folder of ``LATENT_CODES_DIR``.
        formatted_dipole: A dropdown entry from ``formatted_dipoles``
            (``"<pair[1]> -> <pair[0]>"``). A raw two-element pair from
            ``semantic_dipoles`` is also accepted and formatted the same way.
        frame_idx: Frame number. It is converted to ``int``, so float values
            such as ``3.0`` from a slider are accepted.

    Returns:
        A PIL image read from
        ``<latent_code>/paths_images/path_<dipole index:03d>/<frame:06d>.jpg``,
        or an error-message string when the dipole or the frame file is missing.
        NOTE(review): a Gradio image output presumably cannot render these
        strings; consider raising ``gr.Error`` instead.
    """
    # Accept a raw [A, B] pair (as stored in semantic_dipoles) by rendering it
    # in the same "B -> A" form used for the dropdown entries.
    if isinstance(formatted_dipole, (list, tuple)) and len(formatted_dipole) >= 2:
        formatted_dipole = f"{formatted_dipole[1]} -> {formatted_dipole[0]}"
    # Reverse-map the formatted string back to its index in semantic_dipoles.
    # If two dipoles format identically, the first match wins.
    try:
        index = formatted_dipoles.index(formatted_dipole)
    except ValueError:
        return f"Error: Semantic dipole '{formatted_dipole}' not found in the list."
    path_dir = os.path.join(
        LATENT_CODES_DIR, latent_code, "paths_images", f"path_{index:03d}"
    )
    # The ":06d" format spec rejects floats, so normalise the frame index first.
    frame_image_path = os.path.join(path_dir, f"{int(frame_idx):06d}.jpg")
    if not os.path.exists(frame_image_path):
        return f"Image not found: {frame_image_path}."
    # Image.open reads lazily and keeps the file open. Copy the pixels into
    # memory inside the context manager so the handle is closed deterministically.
    with Image.open(frame_image_path) as img:
        return img.copy()
# Embed the per-latent-code interactive latent-space plot (an HTML file).
def display_interactive_plot(latent_code, base_dir=None):
    """Return an ``<iframe>`` that embeds a latent code's interactive plot.

    The iframe points at
    ``/file/<base_dir>/<latent_code>/interactive_latent_space_<latent_code>.html``,
    a URL under the app's ``/file/`` route.

    Args:
        latent_code: Name of a latent-code folder under ``base_dir``.
        base_dir: Results directory holding the latent-code folders. Defaults
            to ``LATENT_CODES_DIR``, looked up at call time.

    Returns:
        The iframe HTML snippet as a string.
    """
    if base_dir is None:
        base_dir = LATENT_CODES_DIR
    # Derive the URL from the configured results directory rather than a second
    # hard-coded copy of it, so the two cannot drift apart. as_posix() keeps
    # "/" separators in the URL even where os.path.join used "\\".
    rel_path = pathlib.PurePath(
        base_dir, latent_code, f"interactive_latent_space_{latent_code}.html"
    ).as_posix()
    # latent_code arrives from the UI. Percent-encode it for the URL and escape
    # it for the HTML attribute so it cannot break out of src="...".
    file_path = "/file/" + urllib.parse.quote(rel_path)
    src = html.escape(file_path, quote=True)
    return f'<iframe src="{src}" width="800" height="600" frameborder="0"></iframe>'
# Gradio Interface
def build_interface():
    """Build the Gradio Blocks UI for browsing ContraCLIP traversal frames.

    The left column holds the latent-code and semantic-dipole dropdowns and a
    frame slider. The right column shows the selected frame and an iframe with
    the latent code's interactive latent-space plot.

    Returns:
        The assembled ``gr.Blocks`` demo (not yet launched).
    """
    with gr.Blocks() as demo:
        gr.Markdown("# ContraCLIP-based Image Editing and Visualization Demo")
        with gr.Row():
            with gr.Column():
                gr.Markdown("### Select Latent Code and Semantic Dipole")
                latent_code_dropdown = gr.Dropdown(
                    latent_code_folders,
                    label="Latent Code",
                    value=latent_code_folders[0],
                )
                semantic_dipole_dropdown = gr.Dropdown(
                    formatted_dipoles,
                    label="Semantic Dipole",
                    value=formatted_dipoles[0],  # Set default value
                )
                # NOTE(review): the upper bound 32 is hard-coded. Presumably it
                # matches the number of frames per path; verify against the data.
                frame_slider = gr.Slider(
                    0, 32, step=1, label="Frame Index"
                )
            with gr.Column():
                image_display = gr.Image(label="Image Preview")
                html_display = gr.HTML(label="Interactive Latent Space")

        # Update image based on latent code, semantic dipole, and frame index
        def update_image(latent_code, semantic_dipole, frame_idx):
            return display_image(latent_code, semantic_dipole, frame_idx)

        # Update HTML display for the selected latent code
        def update_html(latent_code):
            return display_interactive_plot(latent_code)

        # Refresh the preview whenever any of its three inputs changes, so that
        # switching a dropdown never leaves a stale frame on screen.
        image_inputs = [latent_code_dropdown, semantic_dipole_dropdown, frame_slider]
        for control in image_inputs:
            control.change(update_image, image_inputs, [image_display])
        latent_code_dropdown.change(
            update_html, [latent_code_dropdown], [html_display]
        )

        # Set up initial values to match the dropdown defaults. display_image()
        # looks the dipole up in formatted_dipoles, so it must receive the
        # formatted string, not the raw semantic_dipoles pair.
        demo.load(
            lambda: display_image(latent_code_folders[0], formatted_dipoles[0], 0),
            inputs=[],
            outputs=[image_display],
        )
        demo.load(
            lambda: display_interactive_plot(latent_code_folders[0]),
            inputs=[],
            outputs=[html_display],
        )
    return demo
if __name__ == "__main__":
    # Script entry point: assemble the Blocks UI and start the Gradio server.
    build_interface().launch()