aryrk
[feat] resize output image to match input dimension
60c5500
import os
import shutil
import subprocess
import sys
import uuid

import gradio as gr
from PIL import Image
# Working directories used by the app, relative to the current working directory.
UPLOAD_DIR = "./sessions"  # per-request upload copies (one subdirectory per session id)
RESULTS_DIR = "./results"  # scanned for test.py outputs after each run
CHECKPOINTS_DIR = "./checkpoints/SingleImageReflectionRemoval"  # model weights location
SAMPLE_DIR = "./sample_images"  # example images offered in the UI

# Make sure the directories read at startup exist (created in this order).
for _required_dir in (RESULTS_DIR, CHECKPOINTS_DIR, SAMPLE_DIR):
    os.makedirs(_required_dir, exist_ok=True)
from huggingface_hub import hf_hub_download
from shutil import copyfile
# Hugging Face Hub location of the pretrained generator weights.
REPO_ID = "hasnafk/SingleImageReflectionRemoval"
MODEL_FILE = "310_net_G.pth"

# test.py is run with --name SingleImageReflectionRemoval --epoch 310, so it
# presumably loads <CHECKPOINTS_DIR>/310_net_G.pth; the weights are placed at
# exactly that path (TODO confirm against test.py's checkpoint loader).
expected_model_path = os.path.join(CHECKPOINTS_DIR, MODEL_FILE)

if os.path.exists(expected_model_path):
    # Weights are already in place: skip the Hub request entirely so startup
    # does not depend on network access (e.g. manually provisioned weights).
    model_path = expected_model_path
else:
    # Download into the Hub cache under CHECKPOINTS_DIR, then copy the file to
    # the flat path test.py expects. model_path is the cached download path.
    model_path = hf_hub_download(repo_id=REPO_ID, filename=MODEL_FILE, cache_dir=CHECKPOINTS_DIR)
    copyfile(model_path, expected_model_path)
def generate_session_id():
    """Return a new random session identifier as a canonical UUID4 string."""
    session_uuid = uuid.uuid4()
    return f"{session_uuid}"
def randomize_file_name(original_name):
    """Return a random file name that keeps the extension of ``original_name``.

    The stem is a 32-character UUID4 hex string; the extension (including the
    leading dot, or empty if there is none) comes from ``os.path.splitext``.
    """
    _, ext = os.path.splitext(original_name)
    return uuid.uuid4().hex + ext
def clear_session_files(session_id):
    """Delete the UPLOAD_DIR/<session_id> directory tree if it exists."""
    target = os.path.join(UPLOAD_DIR, session_id)
    if not os.path.exists(target):
        return
    shutil.rmtree(target)
def _build_test_command(upload_dir, preprocess_type=None):
    """Return the argv that runs test.py on the images in ``upload_dir``.

    When ``preprocess_type`` is None the ``--preprocess`` flag is omitted, so
    test.py uses its own default preprocessing.
    """
    cmd = [
        # Use the interpreter running this app so test.py sees the same
        # installed packages, rather than whatever "python" is first on PATH.
        sys.executable, "test.py",
        "--dataroot", upload_dir,
        "--name", "SingleImageReflectionRemoval",
        "--model", "test", "--netG", "unet_256",
        "--direction", "AtoB", "--dataset_mode", "single",
        "--norm", "batch", "--epoch", "310",
        "--num_test", "1",
        "--gpu_ids", "-1",
    ]
    if preprocess_type is not None:
        cmd += ["--preprocess", preprocess_type]
    return cmd


def _run_test_script(upload_dir, preprocess_type, max_attempts=3):
    """Run test.py, retrying without ``--preprocess`` after a failure.

    The first attempt passes ``preprocess_type``; every later attempt omits the
    flag. Returns ``(succeeded, used_fallback)``. ``used_fallback`` is True
    when the successful run was a retry; it is only meaningful when
    ``succeeded`` is True.
    """
    for attempt in range(max_attempts):
        requested = preprocess_type if attempt == 0 else None
        try:
            subprocess.run(_build_test_command(upload_dir, requested), check=True)
        except subprocess.CalledProcessError:
            continue
        return True, attempt > 0
    return False, True


def _collect_output(input_filename):
    """Load this request's ``*_fake.png`` from RESULTS_DIR and delete its outputs.

    Walks RESULTS_DIR and removes every ``*_fake.png`` / ``*_real.png`` whose
    name starts with ``input_filename``. Returns an in-memory PIL image of the
    ``_fake.png``, or None if none was found.
    """
    output_image = None
    for root, _, files in os.walk(RESULTS_DIR):
        for name in files:
            if not name.startswith(input_filename):
                continue
            path = os.path.join(root, name)
            if name.endswith("_fake.png"):
                # Image.open is lazy: copy the pixels into memory and close the
                # handle before deleting, so the returned image does not depend
                # on a removed file (os.remove also fails on Windows while the
                # file is still open).
                with Image.open(path) as img:
                    output_image = img.copy()
                os.remove(path)
            elif name.endswith("_real.png"):
                os.remove(path)
    return output_image


def reflection_removal(input_image, preprocess_type="resize_and_crop"):
    """Remove reflections from an uploaded image by running test.py on it.

    Args:
        input_image: Path to the uploaded image file (the Gradio input
            component in this file uses ``type="filepath"``).
        preprocess_type: Forwarded to test.py's ``--preprocess`` flag. Must be
            one of "resize_and_crop", "crop", "scale_width",
            "scale_width_and_crop", "none".

    Returns:
        The result as a PIL image, or a user-facing error string.
        NOTE(review): the strings are returned to a ``gr.Image`` output, which
        may try to treat them as file paths; consider ``raise gr.Error(...)``.
    """
    valid_types = ["resize_and_crop", "crop", "scale_width", "scale_width_and_crop", "none"]
    if preprocess_type not in valid_types:
        return "Invalid preprocessing type selected. Please choose a valid option."
    print("Preprocessing Type:", preprocess_type)
    print("Input Image:", input_image)

    # Validate before creating any session directory, so this early return
    # leaves nothing behind on disk.
    if not input_image or not os.path.exists(input_image):
        return "No image was provided or file was cleared. Please upload a valid image."

    session_id = generate_session_id()
    upload_dir = os.path.join(UPLOAD_DIR, session_id, "uploads")
    os.makedirs(upload_dir, exist_ok=True)
    try:
        # A unique file name lets _collect_output pick out this request's
        # results in the shared RESULTS_DIR by name prefix.
        randomized_name = randomize_file_name(os.path.basename(input_image))
        file_path = os.path.join(upload_dir, randomized_name)
        shutil.copy(input_image, file_path)
        input_filename = os.path.splitext(randomized_name)[0]

        succeeded, used_fallback = _run_test_script(upload_dir, preprocess_type)
        if not succeeded:
            return "No results found. Please try again with a different image."

        output_image = _collect_output(input_filename)
        if output_image is None:
            return "No results found."

        # "crop" and "none" return the model output at its own size. A retry
        # ran with test.py's default preprocessing instead of the requested
        # one (presumably resize-based — TODO confirm), so its output is
        # resized back to the input dimensions like the other modes.
        if used_fallback or preprocess_type not in ("crop", "none"):
            with Image.open(file_path) as src:
                output_image = output_image.resize(src.size)
        return output_image
    finally:
        # Remove the per-request upload copy on every path, including failures.
        clear_session_files(session_id)
def use_sample_image(sample_image_name):
    """Return the path of a file in SAMPLE_DIR, or an error string if it is missing."""
    candidate = os.path.join(SAMPLE_DIR, sample_image_name)
    return candidate if os.path.exists(candidate) else "Sample image not found."
# File names of the example images in SAMPLE_DIR. They are sorted so the
# example gallery has a stable order (os.listdir order is arbitrary), and the
# extension check is case-insensitive so files like "photo.JPG" are included.
sample_images = sorted(
    name for name in os.listdir(SAMPLE_DIR)
    if name.lower().endswith((".jpg", ".jpeg", ".png"))
)
# Preprocessing modes offered in the UI dropdown; the selected value is
# forwarded to test.py's --preprocess flag by reflection_removal.
preprocess_options = [
    "resize_and_crop",
    "crop",
    "scale_width",
    "scale_width_and_crop",
    "none",
]
def _handle_request(input_image, preprocess_type):
    """Gradio callback: use "resize_and_crop" when no preprocessing type is selected."""
    return reflection_removal(input_image, preprocess_type or "resize_and_crop")


# One example row per sample image, each preset to the default preprocessing.
_example_rows = [
    [os.path.join(SAMPLE_DIR, img), "resize_and_crop"]
    for img in sample_images
]

_input_components = [
    gr.Image(type="filepath", label="Upload Image (JPG/PNG)"),
    gr.Dropdown(choices=preprocess_options, label="Preprocessing Type", value="resize_and_crop"),
]

iface = gr.Interface(
    fn=_handle_request,
    inputs=_input_components,
    outputs=gr.Image(label="Result after Reflection Removal"),
    examples=_example_rows,
    title="Reflection Remover with Pix2Pix",
    description="Upload images to remove reflections using a Pix2Pix model. You can also try the sample images below.",
)

# Start the web UI only when run as a script, not when imported.
if __name__ == "__main__":
    iface.launch()