Spaces:
Sleeping
Sleeping
import time | |
from typing import cast | |
from comfydeploy import ComfyDeploy | |
import os | |
import gradio as gr | |
from gradio.components.image_editor import EditorValue | |
from PIL import Image | |
import requests | |
import dotenv | |
from gradio_imageslider import ImageSlider | |
from io import BytesIO | |
import base64 | |
import numpy as np | |
from loguru import logger | |
dotenv.load_dotenv()

# Credentials and ComfyDeploy deployment IDs come from the environment
# (optionally populated from a local .env file by python-dotenv above).
API_KEY = os.environ.get("API_KEY")
CLEANER_DEPLOYMENT_ID = os.environ.get(
    "CLEANER_DEPLOYMENT_ID", "CLEANER_DEPLOYMENT_ID_NOT_SET"
)
MASKER_DEPLOYMENT_ID = os.environ.get(
    "MASKER_DEPLOYMENT_ID", "MASKER_DEPLOYMENT_ID_NOT_SET"
)

# The API key and the cleaner deployment are required for the core feature.
if not API_KEY:
    raise ValueError("Please set API_KEY in your environment variables")
if (
    not CLEANER_DEPLOYMENT_ID
    or CLEANER_DEPLOYMENT_ID == "CLEANER_DEPLOYMENT_ID_NOT_SET"
):
    raise ValueError("Please set CLEANER_DEPLOYMENT_ID in your environment variables")

# The masker only powers the automatic "Get mask" step (users can still draw
# masks by hand), so a missing ID is not fatal. Warn at startup instead of
# letting every masking request fail without explanation.
if (
    not MASKER_DEPLOYMENT_ID
    or MASKER_DEPLOYMENT_ID == "MASKER_DEPLOYMENT_ID_NOT_SET"
):
    logger.warning(
        "MASKER_DEPLOYMENT_ID is not set; automatic masking will not work"
    )

# Shared ComfyDeploy API client used by the masking and removal calls.
client = ComfyDeploy(bearer_auth=API_KEY)
def get_base64_from_image(image: Image.Image) -> str:
    """Encode ``image`` as PNG and return the bytes as a base64 text string.

    The result is suitable for embedding in a ``data:image/png;base64,...`` URL.
    """
    with BytesIO() as png_buffer:
        image.save(png_buffer, format="PNG")
        encoded_bytes = base64.b64encode(png_buffer.getvalue())
    return encoded_bytes.decode("utf-8")
def compute_mask(
    image: Image.Image | str | None,
    progress: gr.Progress = gr.Progress(),
    *,
    timeout: float = 600.0,
) -> Image.Image | None:
    """Request an automatic object mask for ``image`` from the masker deployment.

    The image is downscaled with ``resize_image``, sent to ComfyDeploy as a
    base64 PNG data URL, and the run is then polled once per second until it
    finishes or ``timeout`` seconds elapse.

    Args:
        image: PIL image, path to an image file, or ``None``.
        progress: Gradio progress tracker, updated with the run's live status.
        timeout: Maximum number of seconds to wait for the run to finish.

    Returns:
        The first image output of the run (the mask). Returns ``None`` when
        there is no input, when the run fails, is cancelled, times out or
        produces no image, or when any error occurs. Errors are logged.
    """
    progress(0, desc="Preparing inputs...")
    if image is None:
        return None
    # The signature accepts a file path, but resize_image needs a PIL image.
    if isinstance(image, str):
        image = Image.open(image)
    image = resize_image(image)
    image_base64 = get_base64_from_image(image)

    # Workflow inputs expected by the masker deployment.
    inputs: dict = {
        "input_image": f"data:image/png;base64,{image_base64}",
        "dilation_1_iterations": 10,
        "dilation_2_iterations": 15,
        "mask_blur_amount": 0,
    }

    try:
        result = client.run.create(
            request={"deployment_id": MASKER_DEPLOYMENT_ID, "inputs": inputs}
        )
        if not (result and result.object):
            logger.error("Masker run could not be created")
            return None
        run_id: str = result.object.run_id
        progress(0, desc="Starting processing...")

        deadline = time.monotonic() + timeout
        while time.monotonic() < deadline:
            run = client.run.get(run_id=run_id).object
            if run:
                status = run.live_status or "Cold starting..."
                progress(run.progress or 0, desc=f"Status: {status}")
                if run.status == "success":
                    for output in run.outputs or []:
                        if output.data and output.data.images:
                            image_url: str = output.data.images[0].url
                            # Bounded download; HTTP errors raise and are logged below.
                            response = requests.get(image_url, timeout=60)
                            response.raise_for_status()
                            return Image.open(BytesIO(response.content))
                    logger.error("Masker run {} succeeded without an image output", run_id)
                    return None
                # NOTE(review): "cancelled"/"timeout" are assumed to be the other
                # terminal run statuses ComfyDeploy reports; confirm against its
                # run-status enum. The deadline above is the safety net either way.
                if run.status in ("failed", "cancelled", "timeout"):
                    logger.error("Masker run {} ended with status {}", run_id, run.status)
                    return None
            # Sleep on every iteration, including when no run object came back,
            # so polling can never degenerate into a busy loop.
            time.sleep(1)
        logger.error("Masker run {} did not finish within {} s", run_id, timeout)
        return None
    except Exception:
        logger.exception("Masking failed")
        return None
def create_editor_value(image: Image.Image, mask: Image.Image) -> EditorValue:
    """Build an ImageMask/ImageEditor value with ``mask`` pre-loaded as a layer.

    Args:
        image: Picture to use as the editor background. Images that are not RGB
            (RGBA, greyscale, palette, ...) are converted to RGB first, because
            the composite below needs exactly three colour channels.
        mask: Mask image. It is resized to ``image``'s size with nearest-neighbour
            sampling, so mask values are not blended.

    Returns:
        A dict with:
        - ``background``: the RGB image array.
        - ``layers``: a single RGBA layer whose alpha channel holds the mask.
        - ``composite``: the image with alpha 0 wherever the mask equals 255,
          and 255 elsewhere.
    """
    # A 2-D array (modes L/P) or a 4-channel array (RGBA) cannot be assigned
    # into composite[:, :, :3], so normalise to RGB up front.
    if image.mode != "RGB":
        image = image.convert("RGB")
    image_np = np.array(image)
    height, width = image_np.shape[:2]

    # Match the mask to the image dimensions (PIL sizes are (width, height)).
    mask_resized = mask.resize((width, height), Image.NEAREST)
    mask_np = np.array(mask_resized)
    # Multi-channel masks: keep only the last channel as the mask values.
    if mask_np.ndim == 3:
        mask_np = mask_np[:, :, -1]

    # Drawn layer: transparent black pixels whose alpha is the mask.
    layers = np.zeros((height, width, 4), dtype=np.uint8)
    layers[:, :, 3] = mask_np

    # Composite preview: masked pixels (value 255) become fully transparent.
    composite = np.zeros((height, width, 4), dtype=np.uint8)
    composite[:, :, :3] = image_np
    composite[:, :, 3] = np.where(mask_np == 255, 0, 255)

    return {
        "background": image_np,
        "layers": [layers],
        "composite": composite,
    }
def run_masking(
    image: np.ndarray | Image.Image | str | None,
    progress: gr.Progress = gr.Progress(),
    profile: gr.OAuthProfile | None = None,
) -> EditorValue | None:
    """Gradio handler for the "Get mask" button.

    ``progress`` and ``profile`` are not listed as event inputs. Gradio fills
    them from the parameter default and annotation, so the signature must stay
    as is.

    Returns:
        An editor value holding the image and its automatic mask, or ``None``
        when there is no image, the user is not logged in, or masking failed.
    """
    if image is None:
        return None
    if profile is None:
        gr.Info("Please log in to process the image.")
        return None

    # Normalise the accepted input forms to a PIL image.
    if isinstance(image, str):
        pil_image = Image.open(image)
    elif isinstance(image, np.ndarray):
        pil_image = Image.fromarray(image)
    else:
        pil_image = image

    mask = compute_mask(pil_image, progress)
    return None if mask is None else create_editor_value(pil_image, mask)
def remove_objects(
    image: Image.Image | str | None,
    mask: Image.Image | str | None,
    user_data: dict,
    progress: gr.Progress = gr.Progress(),
    *,
    timeout: float = 600.0,
) -> Image.Image | None:
    """Run the cleaner deployment on ``image`` using ``mask`` and return its output.

    Both inputs are sent to ComfyDeploy as base64 PNG data URLs. The run is then
    polled once per second until it finishes or ``timeout`` seconds elapse.

    Args:
        image: PIL image, path to an image file, or ``None``.
        mask: PIL mask image, path to a mask file, or ``None``.
        user_data: Profile of the requesting user. Currently not sent, because
            attaching it as run metadata is disabled.
        progress: Gradio progress tracker, updated with the run's live status.
        timeout: Maximum number of seconds to wait for the run to finish.

    Returns:
        The first image output of the run (the cleaned picture). Returns
        ``None`` when an input is missing, when the run fails, is cancelled,
        times out or produces no image, or when any error occurs. Errors are
        logged.
    """
    progress(0, desc="Preparing inputs...")
    if image is None or mask is None:
        return None
    if isinstance(mask, str):
        mask = Image.open(mask)
    if isinstance(image, str):
        image = Image.open(image)
    image_base64 = get_base64_from_image(image)
    mask_base64 = get_base64_from_image(mask)

    # Workflow inputs expected by the cleaner deployment.
    # NOTE: sending ``user_data`` as run metadata is disabled; add it here if
    # the deployment accepts a metadata input.
    inputs: dict = {
        "image": f"data:image/png;base64,{image_base64}",
        "mask": f"data:image/png;base64,{mask_base64}",
    }

    try:
        result = client.run.create(
            request={"deployment_id": CLEANER_DEPLOYMENT_ID, "inputs": inputs}
        )
        if not (result and result.object):
            logger.error("Cleaner run could not be created")
            return None
        run_id: str = result.object.run_id
        progress(0, desc="Starting processing...")

        deadline = time.monotonic() + timeout
        while time.monotonic() < deadline:
            run = client.run.get(run_id=run_id).object
            if run:
                progress_value = run.progress if run.progress is not None else 0
                status = (
                    run.live_status
                    if run.live_status is not None
                    else "Cold starting..."
                )
                progress(progress_value, desc=f"Status: {status}")
                if run.status == "success":
                    for output in run.outputs or []:
                        if output.data and output.data.images:
                            image_url: str = output.data.images[0].url
                            # Download the processed image only; bounded request,
                            # HTTP errors raise and are logged below.
                            response = requests.get(image_url, timeout=60)
                            response.raise_for_status()
                            return Image.open(BytesIO(response.content))
                    logger.error("Cleaner run {} succeeded without an image output", run_id)
                    return None
                # NOTE(review): "cancelled"/"timeout" are assumed to be the other
                # terminal run statuses ComfyDeploy reports; confirm against its
                # run-status enum. The deadline above is the safety net either way.
                if run.status in ("failed", "cancelled", "timeout"):
                    logger.error("Cleaner run {} ended with status {}", run_id, run.status)
                    return None
            # Sleep on every iteration, including when no run object came back,
            # so polling can never degenerate into a busy loop.
            time.sleep(1)
        logger.error("Cleaner run {} did not finish within {} s", run_id, timeout)
        return None
    except Exception:
        logger.exception("Object removal failed")
        return None
def resize_image(img: "Image.Image", min_side_length: int = 768) -> "Image.Image":
    """Downscale ``img`` so that its shorter side is ``min_side_length`` pixels.

    The aspect ratio is preserved. The function never upscales: an image whose
    shorter side is already at most ``min_side_length`` is returned unchanged.

    Args:
        img: Image to resize (any object with ``width``, ``height`` and
            ``resize((w, h))``, such as a PIL image).
        min_side_length: Target length of the shorter side, in pixels.

    Returns:
        The resized image, or ``img`` itself when no resize is needed.
    """
    width, height = img.width, img.height
    # Compare the *shorter* side. Requiring both sides to be small would let a
    # 1000x500 image through the guard and then upscale it to 1536x768.
    if min(width, height) <= min_side_length:
        return img
    aspect_ratio = width / height
    if width < height:
        # Portrait: the width is the shorter side.
        return img.resize((min_side_length, int(min_side_length / aspect_ratio)))
    # Landscape or square: the height is the shorter side.
    return img.resize((int(min_side_length * aspect_ratio), min_side_length))
def get_profile(profile) -> dict:
    """Collect the identifying fields of an OAuth profile into a plain dict.

    Returns a dict with the ``username``, ``profile`` and ``name`` attributes
    of ``profile``, in that order.
    """
    return {field: getattr(profile, field) for field in ("username", "profile", "name")}
async def run_removal(
    image_and_mask: EditorValue | None,
    progress: gr.Progress = gr.Progress(),
    profile: gr.OAuthProfile | None = None,
) -> tuple[Image.Image, Image.Image] | None:
    """Gradio handler for the "Run" button: remove the masked objects.

    ``progress`` and ``profile`` are not listed as event inputs. Gradio fills
    them from the parameter default and annotation, so the signature must stay
    as is.

    Args:
        image_and_mask: Editor value whose ``background`` is the picture and
            whose first layer's alpha channel marks the areas to remove.

    Returns:
        ``(resized original, processed image)`` for the before/after slider, or
        ``None`` (after showing an info toast) when input is missing, the user is
        not logged in, or processing failed.
    """
    import asyncio  # local import: only this handler needs it

    if not image_and_mask:
        gr.Info("Please upload an image and draw a mask")
        return None
    if profile is None:
        gr.Info("Please log in to process the image.")
        return None

    user_data = get_profile(profile)
    logger.debug("--------- RUN ----------")
    logger.debug(user_data)
    logger.debug("--------- RUN ----------")

    image_np = image_and_mask["background"]
    # Treat a missing background (presumably what the editor reports before an
    # upload) the same as an all-zero one.
    if image_np is None or not np.any(image_np):
        gr.Info("Please upload an image")
        return None
    image_np = cast(np.ndarray, image_np)

    layers = image_and_mask["layers"]
    if not layers:
        gr.Info("Please mark the areas you want to remove")
        return None
    alpha_channel = cast(np.ndarray, layers[0])
    # Any painted pixel (non-zero alpha) becomes part of the binary mask.
    mask_np = np.where(alpha_channel[:, :, 3] == 0, 0, 255).astype(np.uint8)
    if not mask_np.any():
        gr.Info("Please mark the areas you want to remove")
        return None

    # Image and mask go through the same resize, so their sizes stay equal.
    mask = resize_image(Image.fromarray(mask_np))
    image = resize_image(Image.fromarray(image_np))

    # remove_objects blocks (HTTP calls and time.sleep polling). Run it in a
    # worker thread so this coroutine does not stall Gradio's event loop for
    # every other user.
    output = await asyncio.to_thread(remove_objects, image, mask, user_data, progress)
    if output is None:
        gr.Info("Processing failed")
        return None
    # gr.Progress expects a completion fraction in [0, 1], not a percentage.
    progress(1.0, desc="Processing completed")
    return image, output
# Gradio UI, laid out top to bottom in three steps:
#   1. Upload the room picture (input_image).
#   2. "Get mask" fills image_and_mask_auto with an automatic mask
#      (run_masking); the user can then edit it with the brush.
#   3. "Run" sends the edited image and mask to run_removal and shows the
#      before/after pair in image_slider.
# Each button is disabled ("Processing...") while its handler runs and is
# re-enabled afterwards. ``.then`` runs the re-enable step even if the handler
# raised.
with gr.Blocks() as demo:
    gr.HTML("""
<div style="display: flex; justify-content: center; text-align:center; flex-direction: column;">
<h1 style="color: #333;">🧹 Room Cleaner</h1>
<div style="max-width: 800px; margin: 0 auto;">
<p style="font-size: 16px;">Upload an image and use the pencil tool (✏️ icon at the bottom) to <b>mark the areas you want to remove</b>.</p>
<p style="font-size: 16px;">
For best results, include the shadows and reflections of the objects you want to remove.
You can remove multiple objects at once.
If you forget to mask some parts of your object, it's likely that the model will reconstruct them.
</p>
<br>
<video width="640" height="360" controls style="margin: 0 auto; border-radius: 10px;">
<source src="https://dropshare.blanchon.xyz/public/dropshare/room_cleaner_demo.mp4" type="video/mp4">
</video>
<br>
<p style="font-size: 16px;">Finally, click on the <b>"Run"</b> button to process the image.</p>
<p style="font-size: 16px;">Wait for the processing to complete and compare the original and processed images using the slider.</p>
<p style="font-size: 16px;">⚠️ Note that the images are compressed to reduce the workloads of the demo. </p>
</div>
<div style="margin-top: 20px; display: flex; justify-content: center; gap: 10px;">
<a href="https://x.com/JulienBlanchon">
<img src="https://img.shields.io/badge/X-%23000000.svg?style=for-the-badge&logo=X&logoColor=white" alt="X Badge" style="border-radius: 3px;"/>
</a>
</div>
</div>
""")
    # Both handlers refuse to process images unless the user is logged in.
    login_button = gr.LoginButton(scale=8)

    # ------ MASKING
    with gr.Column():
        with gr.Row(equal_height=False):
            # TODO: fix the input image overflowing its container.
            input_image = gr.Image(
                label="Input Image",
                height="full",
                width="full",
            )
            gr.HTML("""
<h3 style="text-align: center;">Step 1: input image</h3>
<p style="text-align: center;">Upload an image of the room you want to clean.</p>
""")
        with gr.Row(equal_height=False):
            image_and_mask_auto = gr.ImageMask(
                label="Image and Mask",
                layers=False,
                show_fullscreen_button=False,
                sources=["upload"],
                show_download_button=False,
                interactive=True,
                height="full",
                width="full",
                brush=gr.Brush(default_size=75, colors=["#000000"], color_mode="fixed"),
                transforms=[],
            )
            with gr.Column():
                gr.HTML("""
<h3 style="text-align: center;">Step 2: Run masking</h3>
<p style="text-align: center;">Click get mask to get automatic masking and edit it after manually if needed.</p>
""")
                # ClearButton also resets the mask editor when clicked.
                compute_mask_btn = gr.ClearButton(
                    value="Get mask",
                    variant="primary",
                    size="lg",
                    components=[image_and_mask_auto],
                )
                # Handlers registered with ``inputs=[]`` are called with no
                # arguments, so the button-state lambdas must take none.
                compute_mask_btn.click(
                    fn=lambda: gr.update(interactive=False, value="Processing..."),
                    inputs=[],
                    outputs=[compute_mask_btn],
                    api_name=False,
                ).then(
                    fn=run_masking,
                    inputs=[
                        input_image,
                    ],
                    outputs=[image_and_mask_auto],
                    api_name=False,
                ).then(
                    fn=lambda: gr.update(interactive=True, value="Get mask"),
                    inputs=[],
                    outputs=[compute_mask_btn],
                    api_name=False,
                )

    # ------ REMOVAL
    with gr.Row(equal_height=False):
        image_slider = ImageSlider(
            label="Result",
            interactive=False,
        )
        with gr.Column():
            gr.HTML("""
<h3 style="text-align: center;">Step 3: Run removal</h3>
<p style="text-align: center;">Click run to remove the objects from the image.</p>
""")
            # ClearButton also resets the previous result when clicked.
            process_btn = gr.ClearButton(
                value="Run",
                variant="primary",
                size="lg",
                components=[image_slider],
            )
            # Zero-argument lambdas: these handlers have ``inputs=[]``.
            process_btn.click(
                fn=lambda: gr.update(interactive=False, value="Processing..."),
                inputs=[],
                outputs=[process_btn],
                api_name=False,
            ).then(
                fn=run_removal,
                inputs=[
                    image_and_mask_auto,
                ],
                outputs=[image_slider],
                api_name=False,
            ).then(
                fn=lambda: gr.update(interactive=True, value="Run"),
                inputs=[],
                outputs=[process_btn],
                api_name=False,
            )
if __name__ == "__main__":
    # Launch without a public share link and without exposing the API page.
    launch_options = {"debug": False, "share": False, "show_api": False}
    demo.launch(**launch_options)