|
import os
import shutil
import subprocess
import sys
import uuid

from PIL import Image
import gradio as gr
|
|
|
# Working directories used by the app (all relative to the current directory).
UPLOAD_DIR = "./sessions"  # parent of per-request session folders
RESULTS_DIR = "./results"  # scanned for test.py output images
CHECKPOINTS_DIR = "./checkpoints/SingleImageReflectionRemoval"  # model weights
SAMPLE_DIR = "./sample_images"  # example images offered in the UI

# UPLOAD_DIR is intentionally not pre-created here; session folders beneath it
# are created on demand with os.makedirs(..., exist_ok=True).
for _required_dir in (RESULTS_DIR, CHECKPOINTS_DIR, SAMPLE_DIR):
    os.makedirs(_required_dir, exist_ok=True)
del _required_dir
|
|
|
from huggingface_hub import hf_hub_download |
|
from shutil import copyfile |
|
|
|
REPO_ID = "hasnafk/SingleImageReflectionRemoval"
MODEL_FILE = "310_net_G.pth"

# Fetch the generator weights from the Hugging Face Hub; the hub cache lives
# inside CHECKPOINTS_DIR so repeated starts reuse the download.
model_path = hf_hub_download(
    repo_id=REPO_ID,
    filename=MODEL_FILE,
    cache_dir=CHECKPOINTS_DIR,
)

# hf_hub_download returns a path inside its cache layout; place a plain copy at
# CHECKPOINTS_DIR/MODEL_FILE, which is presumably where test.py (run with
# --name SingleImageReflectionRemoval --epoch 310) looks for it — confirm.
expected_model_path = os.path.join(CHECKPOINTS_DIR, MODEL_FILE)
model_already_placed = os.path.exists(expected_model_path)
if not model_already_placed:
    copyfile(model_path, expected_model_path)
|
|
|
def generate_session_id():
    """Return a new random session identifier as a canonical UUID4 string."""
    session_uuid = uuid.uuid4()
    return f"{session_uuid}"
|
|
|
def randomize_file_name(original_name):
    """Return a random 32-hex-digit file name keeping *original_name*'s extension."""
    _, ext = os.path.splitext(original_name)
    return uuid.uuid4().hex + ext
|
|
|
def clear_session_files(session_id):
    """Recursively delete UPLOAD_DIR/<session_id>; do nothing if it is absent."""
    target = os.path.join(UPLOAD_DIR, session_id)
    if not os.path.exists(target):
        return
    shutil.rmtree(target)
|
|
|
# Values accepted for the --preprocess option by this app.
_VALID_PREPROCESS_TYPES = (
    "resize_and_crop", "crop", "scale_width", "scale_width_and_crop", "none",
)

# Preprocess modes whose result is returned at the size test.py produced,
# instead of being resized back to the uploaded image's size.
_NO_RESIZE_PREPROCESS_TYPES = ("crop", "none")

# Total number of test.py runs: the first uses the requested --preprocess,
# later ones fall back to test.py's default preprocessing.
_MAX_INFERENCE_ATTEMPTS = 3


def _build_test_command(upload_dir, preprocess_type=None):
    """Return the argv for one ``test.py`` run over the images in *upload_dir*.

    ``--preprocess`` is appended only when *preprocess_type* is not None;
    otherwise test.py's own default applies.
    """
    cmd = [
        # Run test.py under the same interpreter/virtualenv as this app rather
        # than whatever "python" happens to resolve to on PATH.
        sys.executable, "test.py",
        "--dataroot", upload_dir,
        "--name", "SingleImageReflectionRemoval",
        "--model", "test", "--netG", "unet_256",
        "--direction", "AtoB", "--dataset_mode", "single",
        "--norm", "batch", "--epoch", "310",
        "--num_test", "1",
        "--gpu_ids", "-1",
    ]
    if preprocess_type is not None:
        cmd += ["--preprocess", preprocess_type]
    return cmd


def _run_inference(upload_dir, preprocess_type):
    """Run test.py, retrying with default preprocessing after a failure.

    Returns the ``--preprocess`` value used by the successful run, or None if
    that run used test.py's default. Re-raises the last
    ``subprocess.CalledProcessError`` when every attempt fails.
    """
    last_error = None
    for attempt in range(1, _MAX_INFERENCE_ATTEMPTS + 1):
        used = preprocess_type if attempt == 1 else None
        try:
            subprocess.run(_build_test_command(upload_dir, used), check=True)
        except subprocess.CalledProcessError as err:
            print(
                f"test.py failed (attempt {attempt}/{_MAX_INFERENCE_ATTEMPTS}), "
                f"exit code {err.returncode}"
            )
            last_error = err
        else:
            return used
    raise last_error


def _collect_result(input_filename, original_path, resize_to_input):
    """Load the generated image for *input_filename* and delete test.py's files.

    Walks RESULTS_DIR for files starting with *input_filename*; the one ending
    in ``_fake.png`` is loaded, and it and any ``_real.png`` match are deleted.
    Returns the loaded PIL image (resized to *original_path*'s size when
    *resize_to_input* is true), or None if no ``_fake.png`` was found.
    """
    output_image = None
    for root, _, files in os.walk(RESULTS_DIR):
        for name in files:
            if not name.startswith(input_filename):
                continue
            path = os.path.join(root, name)
            if name.endswith("_fake.png"):
                # Image.open is lazy and keeps the file open; copy() forces the
                # pixels into memory so the file can be closed and deleted.
                with Image.open(path) as result:
                    output_image = result.copy()
                if resize_to_input:
                    with Image.open(original_path) as original:
                        output_image = output_image.resize(original.size)
                os.remove(path)
            elif name.endswith("_real.png"):
                os.remove(path)
    return output_image


def reflection_removal(input_image, preprocess_type="resize_and_crop"):
    """Remove reflections from one image by running the pix2pix ``test.py`` script.

    Args:
        input_image: Path to the uploaded image (the UI uses ``type="filepath"``).
        preprocess_type: One of ``_VALID_PREPROCESS_TYPES``; forwarded to
            test.py as ``--preprocess`` on the first attempt.

    Returns:
        The generated PIL image, or a human-readable error string on failure.
        The per-session upload directory is always removed before returning.
    """
    # NOTE(review): the Gradio output component is gr.Image, which presumably
    # cannot display these strings; raising gr.Error would surface them to the
    # user. They are kept as return values for compatibility with callers.
    if preprocess_type not in _VALID_PREPROCESS_TYPES:
        return "Invalid preprocessing type selected. Please choose a valid option."

    print("Preprocessing Type:", preprocess_type)
    print("Input Image:", input_image)

    # Validate before creating a session directory so this exit leaves nothing behind.
    if not input_image or not os.path.exists(input_image):
        return "No image was provided or file was cleared. Please upload a valid image."

    session_id = generate_session_id()
    upload_dir = os.path.join(UPLOAD_DIR, session_id, "uploads")
    os.makedirs(upload_dir, exist_ok=True)

    try:
        # A random name keeps concurrent requests from matching each other's results.
        randomized_name = randomize_file_name(os.path.basename(input_image))
        shutil.copy(input_image, os.path.join(upload_dir, randomized_name))
        input_filename = os.path.splitext(randomized_name)[0]

        try:
            used_preprocess = _run_inference(upload_dir, preprocess_type)
        except subprocess.CalledProcessError:
            return "No results found. Please try again with a different image."

        # Decide on resizing from the preprocessing that actually ran. A
        # fallback run (None) used test.py's default, which presumably rescales
        # (pix2pix's default is resize_and_crop) — confirm against test.py.
        output_image = _collect_result(
            input_filename,
            input_image,
            resize_to_input=used_preprocess not in _NO_RESIZE_PREPROCESS_TYPES,
        )
    finally:
        clear_session_files(session_id)

    if output_image is not None:
        return output_image
    return "No results found."
|
|
|
def use_sample_image(sample_image_name):
    """Resolve a sample image name to its path inside SAMPLE_DIR.

    Returns the joined path when it exists, otherwise the message
    "Sample image not found.".
    """
    candidate = os.path.join(SAMPLE_DIR, sample_image_name)
    return candidate if os.path.exists(candidate) else "Sample image not found."
|
|
|
# File extensions treated as sample images (compared case-insensitively).
_SAMPLE_EXTENSIONS = (".jpg", ".jpeg", ".png")

# Image files in SAMPLE_DIR offered as Gradio examples. Matching lowercases the
# name so "photo.JPG" is included, and sorting gives a stable example order
# (os.listdir returns entries in arbitrary order).
sample_images = sorted(
    name for name in os.listdir(SAMPLE_DIR)
    if name.lower().endswith(_SAMPLE_EXTENSIONS)
)

# Choices for the "Preprocessing Type" dropdown.
preprocess_options = [
    "resize_and_crop", "crop", "scale_width", "scale_width_and_crop", "none"
]
|
|
|
def _predict(input_image, preprocess_type):
    """Gradio callback: call reflection_removal, defaulting an empty dropdown
    selection to "resize_and_crop"."""
    return reflection_removal(input_image, preprocess_type or "resize_and_crop")


# One example row per sample image, always using the default preprocessing.
_example_rows = [
    [os.path.join(SAMPLE_DIR, sample_name), "resize_and_crop"]
    for sample_name in sample_images
]

iface = gr.Interface(
    fn=_predict,
    inputs=[
        gr.Image(type="filepath", label="Upload Image (JPG/PNG)"),
        gr.Dropdown(choices=preprocess_options, label="Preprocessing Type", value="resize_and_crop"),
    ],
    outputs=gr.Image(label="Result after Reflection Removal"),
    examples=_example_rows,
    title="Reflection Remover with Pix2Pix",
    description="Upload images to remove reflections using a Pix2Pix model. You can also try the sample images below.",
)


if __name__ == "__main__":
    iface.launch()
|
|