"""Gradio demo that removes reflections from a single image with a Pix2Pix model.

Each request copies the uploaded image into a private session directory, runs
the project's ``test.py`` script on it as a subprocess, then picks up the
generated ``*_fake.png`` from ``RESULTS_DIR`` and returns it to the UI.
"""

import os
import shutil
import subprocess
import sys
import uuid
from shutil import copyfile

import gradio as gr
from huggingface_hub import hf_hub_download
from PIL import Image

UPLOAD_DIR = "./sessions"
RESULTS_DIR = "./results"
CHECKPOINTS_DIR = "./checkpoints/SingleImageReflectionRemoval"
SAMPLE_DIR = "./sample_images"

os.makedirs(RESULTS_DIR, exist_ok=True)
os.makedirs(CHECKPOINTS_DIR, exist_ok=True)
os.makedirs(SAMPLE_DIR, exist_ok=True)

REPO_ID = "hasnafk/SingleImageReflectionRemoval"
MODEL_FILE = "310_net_G.pth"

# Download the weights into the HF cache under CHECKPOINTS_DIR, then place a
# plain copy at CHECKPOINTS_DIR/MODEL_FILE. That file name matches the
# "--name SingleImageReflectionRemoval --epoch 310" arguments passed to
# test.py below; presumably that is where test.py loads it from (confirm).
model_path = hf_hub_download(repo_id=REPO_ID, filename=MODEL_FILE, cache_dir=CHECKPOINTS_DIR)
expected_model_path = os.path.join(CHECKPOINTS_DIR, MODEL_FILE)
if not os.path.exists(expected_model_path):
    copyfile(model_path, expected_model_path)

# Values accepted for test.py's --preprocess flag (also the dropdown choices).
preprocess_options = [
    "resize_and_crop",
    "crop",
    "scale_width",
    "scale_width_and_crop",
    "none",
]

# Modes whose output is returned at the generator's resolution; every other
# mode is resized back to the input image's size.
_NO_RESIZE_PREPROCESS = ("crop", "none")

# Total test.py runs: one with the requested --preprocess, then retries that
# fall back to test.py's default preprocessing.
_MAX_ATTEMPTS = 3


def generate_session_id():
    """Return a fresh random UUID4 string naming a per-request directory."""
    return str(uuid.uuid4())


def randomize_file_name(original_name):
    """Return a random hex file name that keeps ``original_name``'s extension."""
    extension = os.path.splitext(original_name)[1]
    return f"{uuid.uuid4().hex}{extension}"


def clear_session_files(session_id):
    """Delete ``UPLOAD_DIR/<session_id>`` and everything in it, if present."""
    session_dir = os.path.join(UPLOAD_DIR, session_id)
    if os.path.exists(session_dir):
        shutil.rmtree(session_dir)


def _build_test_command(dataroot, preprocess_type=None):
    """Build the test.py argv; ``--preprocess`` is omitted when ``preprocess_type`` is None."""
    cmd = [
        # Use the interpreter running this app so test.py sees the same environment.
        sys.executable, "test.py",
        "--dataroot", dataroot,
        "--name", "SingleImageReflectionRemoval",
        "--model", "test",
        "--netG", "unet_256",
        "--direction", "AtoB",
        "--dataset_mode", "single",
        "--norm", "batch",
        "--epoch", "310",
        "--num_test", "1",
        "--gpu_ids", "-1",
    ]
    if preprocess_type is not None:
        cmd += ["--preprocess", preprocess_type]
    return cmd


def _run_inference(dataroot, preprocess_type):
    """Run test.py on ``dataroot``, retrying with its default preprocessing on failure.

    Returns:
        The preprocess mode of the successful run, or None if a fallback run
        (without ``--preprocess``) succeeded.

    Raises:
        subprocess.CalledProcessError: if all ``_MAX_ATTEMPTS`` runs fail.
    """
    last_error = None
    for attempt in range(_MAX_ATTEMPTS):
        mode = preprocess_type if attempt == 0 else None
        try:
            subprocess.run(_build_test_command(dataroot, mode), check=True)
            return mode
        except subprocess.CalledProcessError as err:
            last_error = err
    raise last_error


def _collect_result(input_filename, source_path, resize_to_input):
    """Return the generated image for ``input_filename`` and delete test.py's output files.

    Walks ``RESULTS_DIR`` for ``<input_filename>*_fake.png`` (the result) and
    ``*_real.png`` (an input copy) and removes both.

    Returns:
        A fully loaded PIL image, resized to ``source_path``'s size when
        ``resize_to_input`` is true, or None if no ``_fake.png`` was found.
    """
    output_image = None
    for root, _, files in os.walk(RESULTS_DIR):
        for name in files:
            if not name.startswith(input_filename):
                continue
            path = os.path.join(root, name)
            if name.endswith("_fake.png"):
                # Image.open is lazy; copy() forces a full load so the file
                # handle is closed before the file is deleted below.
                with Image.open(path) as generated:
                    output_image = generated.copy()
                if resize_to_input:
                    with Image.open(source_path) as original:
                        output_image = output_image.resize(original.size)
                os.remove(path)
            elif name.endswith("_real.png"):
                os.remove(path)
    return output_image


def reflection_removal(input_image, preprocess_type="resize_and_crop"):
    """Remove reflections from the image at ``input_image`` using test.py.

    Args:
        input_image: Filesystem path of the uploaded image (Gradio ``filepath``).
        preprocess_type: One of ``preprocess_options``; passed to test.py as
            ``--preprocess``.

    Returns:
        The generated PIL image.

    Raises:
        gr.Error: on an invalid preprocess type, a missing input file, a failed
            inference run, or when no result image is produced. Gradio shows
            the message to the user.
    """
    if preprocess_type not in preprocess_options:
        raise gr.Error("Invalid preprocessing type selected. Please choose a valid option.")
    print("Preprocessing Type:", preprocess_type)
    print("Input Image:", input_image)
    # Validate before creating the session directory so nothing is left behind.
    if not input_image or not os.path.exists(input_image):
        raise gr.Error("No image was provided or file was cleared. Please upload a valid image.")

    session_id = generate_session_id()
    upload_dir = os.path.join(UPLOAD_DIR, session_id, "uploads")
    os.makedirs(upload_dir, exist_ok=True)
    try:
        # A random name keeps concurrent requests from matching each other's results.
        randomized_name = randomize_file_name(os.path.basename(input_image))
        shutil.copy(input_image, os.path.join(upload_dir, randomized_name))
        input_filename = os.path.splitext(randomized_name)[0]

        try:
            used_preprocess = _run_inference(upload_dir, preprocess_type)
        except subprocess.CalledProcessError as err:
            raise gr.Error("No results found. Please try again with a different image.") from err

        # used_preprocess is None after a fallback run; test.py's default mode
        # presumably rescales (confirm), and resizing to the input size is
        # harmless when the sizes already match.
        resize = used_preprocess not in _NO_RESIZE_PREPROCESS
        output_image = _collect_result(input_filename, input_image, resize)
    finally:
        clear_session_files(session_id)

    if output_image is None:
        raise gr.Error("No results found.")
    return output_image


def use_sample_image(sample_image_name):
    """Return the path of ``sample_image_name`` in SAMPLE_DIR, or an error string if missing."""
    sample_image_path = os.path.join(SAMPLE_DIR, sample_image_name)
    if not os.path.exists(sample_image_path):
        return "Sample image not found."
    return sample_image_path


# Sorted so the example gallery order does not depend on os.listdir order.
sample_images = sorted(
    file for file in os.listdir(SAMPLE_DIR)
    if file.endswith((".jpg", ".jpeg", ".png"))
)

iface = gr.Interface(
    fn=lambda input_image, preprocess_type: reflection_removal(input_image, preprocess_type or "resize_and_crop"),
    inputs=[
        gr.Image(type="filepath", label="Upload Image (JPG/PNG)"),
        gr.Dropdown(choices=preprocess_options, label="Preprocessing Type", value="resize_and_crop"),
    ],
    outputs=gr.Image(label="Result after Reflection Removal"),
    examples=[[os.path.join(SAMPLE_DIR, img), "resize_and_crop"] for img in sample_images],
    title="Reflection Remover with Pix2Pix",
    description="Upload images to remove reflections using a Pix2Pix model. You can also try the sample images below.",
)

if __name__ == "__main__":
    iface.launch()