Spaces:
Running
Running
import asyncio
import base64
import os
import time
from io import BytesIO
from typing import cast

import dotenv
import gradio as gr
import numpy as np
import requests
from comfydeploy import ComfyDeploy
from gradio.components.image_editor import EditorValue
from gradio_imageslider import ImageSlider
from PIL import Image
# Pick up credentials from a local .env file, if one exists.
dotenv.load_dotenv()

# ComfyDeploy credentials; the app refuses to start without both of them.
API_KEY = os.environ.get("API_KEY")
DEPLOYMENT_ID = os.environ.get("DEPLOYMENT_ID", "DEPLOYMENT_ID_NOT_SET")

if API_KEY is None or API_KEY == "":
    raise ValueError("Please set API_KEY in your environment variables")
# An empty value and the placeholder default both mean "not configured".
if DEPLOYMENT_ID in ("", "DEPLOYMENT_ID_NOT_SET"):
    raise ValueError("Please set DEPLOYMENT_ID in your environment variables")

client = ComfyDeploy(bearer_auth=API_KEY)
def get_base64_from_image(image: Image.Image) -> str:
    """Encode *image* as PNG and return the PNG bytes as a base64 text string."""
    with BytesIO() as png_buffer:
        image.save(png_buffer, format="PNG")
        encoded = base64.b64encode(png_buffer.getvalue())
    return encoded.decode("utf-8")
def process_image(
    image: Image.Image | str | None,
    mask: Image.Image | str | None,
    user_data: dict,
    progress: gr.Progress = gr.Progress(),
    *,
    poll_interval: float = 1.0,
    timeout: float | None = 600.0,
    download_timeout: float = 60.0,
) -> Image.Image | None:
    """Run the ComfyDeploy deployment on an image/mask pair and return the result.

    The image and mask are PNG-encoded into base64 data URLs and submitted as
    a run of ``DEPLOYMENT_ID``. The run is then polled until it succeeds,
    fails, or ``timeout`` elapses.

    Args:
        image: Source image, or a path to one. ``None`` aborts early.
        mask: Mask image, or a path to one. ``None`` aborts early.
        user_data: Caller profile info. Currently unused: the run-metadata
            field that would carry it is disabled.
        progress: Gradio progress tracker, updated while polling.
        poll_interval: Seconds to wait between status requests.
        timeout: Maximum seconds to wait for the run to finish; ``None``
            waits indefinitely.
        download_timeout: Seconds allowed for downloading the result image.

    Returns:
        The first output image of a successful run. Returns ``None`` if:

        - an input is missing;
        - the run fails or times out;
        - the run produced no image;
        - any API/download error occurs (errors are printed, not raised).

    Raises:
        Errors from opening path inputs or PNG encoding propagate; they
        happen before the guarded API section.
    """
    progress(0, desc="Preparing inputs...")
    if image is None or mask is None:
        return None
    if isinstance(mask, str):
        mask = Image.open(mask)
    if isinstance(image, str):
        image = Image.open(image)

    image_base64 = get_base64_from_image(image)
    mask_base64 = get_base64_from_image(mask)
    inputs: dict = {
        "image": f"data:image/png;base64,{image_base64}",
        "mask": f"data:image/png;base64,{mask_base64}",
        # NOTE: run metadata (source/user) is intentionally not sent for now.
    }

    try:
        result = client.run.create(
            request={"deployment_id": DEPLOYMENT_ID, "inputs": inputs}
        )
        if not (result and result.object):
            return None
        run_id: str = result.object.run_id
        progress(0, desc="Starting processing...")

        deadline = None if timeout is None else time.monotonic() + timeout
        while True:
            if deadline is not None and time.monotonic() >= deadline:
                print(f"Processing timed out after {timeout} seconds")
                return None

            run = client.run.get(run_id=run_id).object
            if run is None:
                # No status yet: wait before retrying instead of busy-looping
                # against the API.
                time.sleep(poll_interval)
                continue

            progress_value = run.progress if run.progress is not None else 0
            status = (
                run.live_status
                if run.live_status is not None
                else "Cold starting..."
            )
            progress(progress_value, desc=f"Status: {status}")

            if run.status == "success":
                # Return the first image produced by any output node.
                for output in run.outputs or []:
                    if output.data and output.data.images:
                        image_url: str = output.data.images[0].url
                        response = requests.get(image_url, timeout=download_timeout)
                        response.raise_for_status()
                        processed_image = Image.open(BytesIO(response.content))
                        # Decode eagerly so a corrupt download is reported
                        # here rather than failing later inside Gradio.
                        processed_image.load()
                        return processed_image
                return None
            # NOTE(review): "timeout"/"cancelled" are assumed to be ComfyDeploy
            # terminal run statuses alongside "failed" — confirm against the API.
            if run.status in ("failed", "timeout", "cancelled"):
                print("Processing failed")
                return None
            time.sleep(poll_interval)
    except Exception as e:
        # Top-level boundary for the UI: report and signal failure with None.
        print(f"Error: {e}")
        return None
def make_example(background_path: str, mask_path: str) -> EditorValue:
    """Build a ``gr.ImageMask`` example value from a background and a mask file.

    Args:
        background_path: Path to the background photo. It is converted to
            RGB, so grayscale/CMYK/RGBA files also work.
        mask_path: Path to a mask image whose last channel marks the region
            to remove. This assumes the file has a channel axis (e.g. RGBA)
            — TODO confirm for new mask files.

    Returns:
        An ``EditorValue`` dict with three keys:

        - ``background``: the ``(H, W, 3)`` uint8 RGB background.
        - ``layers``: a single ``(H, W, 4)`` uint8 layer whose alpha channel
          is the mask.
        - ``composite``: the background with alpha 0 where the mask equals
          255, and 255 elsewhere.
    """
    with Image.open(background_path) as background_img:
        # Normalise to RGB so the array always fits the 3-channel slot below.
        background = np.array(background_img.convert("RGB"))
    with Image.open(mask_path) as mask_img:
        mask_alpha = np.array(mask_img)[:, :, -1]

    height, width = background.shape[:2]
    rgba_shape = (height, width, 4)

    layer = np.zeros(rgba_shape, dtype=np.uint8)
    layer[:, :, 3] = mask_alpha

    composite = np.zeros(rgba_shape, dtype=np.uint8)
    composite[:, :, :3] = background
    composite[:, :, 3] = np.where(mask_alpha == 255, 0, 255)

    return {
        "background": background,
        "layers": [layer],
        "composite": composite,
    }
def resize_image(img: Image.Image, min_side_length: int = 768) -> Image.Image:
    """Downscale *img* so that its shorter side equals ``min_side_length``.

    The aspect ratio is preserved. Images whose shorter side is already at
    most ``min_side_length`` are returned unchanged; they are never
    upscaled, since the goal is to reduce the processing workload.

    Args:
        img: Image to shrink.
        min_side_length: Target length, in pixels, of the shorter side.

    Returns:
        Either *img* itself or a resized copy.
    """
    width, height = img.width, img.height
    # Only the shorter side matters: if it is small enough, shrinking further
    # would drop below the target; the old "both sides" test upscaled here.
    if min(width, height) <= min_side_length:
        return img
    if width < height:
        new_size = (min_side_length, max(1, int(min_side_length * height / width)))
    else:
        new_size = (max(1, int(min_side_length * width / height)), min_side_length)
    return img.resize(new_size)
def get_profile(profile) -> dict:
    """Return the ``username``, ``profile`` and ``name`` attributes of *profile* as a dict.

    *profile* is presumably a ``gr.OAuthProfile``; any object exposing those
    three attributes works.
    """
    wanted_fields = ("username", "profile", "name")
    return {field: getattr(profile, field) for field in wanted_fields}
async def process(
    image_and_mask: EditorValue | None,
    progress: gr.Progress = gr.Progress(),
    profile: gr.OAuthProfile | None = None,
) -> tuple[Image.Image, Image.Image] | None:
    """Gradio click handler: validate the editor value, then run object removal.

    Args:
        image_and_mask: Value of the ``gr.ImageMask`` editor (numpy arrays).
        progress: Gradio progress tracker, injected by Gradio.
        profile: Logged-in Hugging Face profile, injected by Gradio;
            ``None`` when the user is not logged in.

    Returns:
        ``(resized_input, processed_output)`` for the image slider. Returns
        ``None`` after showing an info toast when any of these holds:

        - the input or mask is missing;
        - the user is not logged in;
        - processing fails.
    """
    if not image_and_mask:
        gr.Info("Please upload an image and draw a mask")
        return None
    if profile is None:
        gr.Info("Please log in to process the image.")
        return None

    user_data = get_profile(profile)
    print("--------- RUN ----------")
    print(user_data)
    print("--------- RUN ----------")

    background = image_and_mask.get("background")
    # No upload yields a missing/None background or an all-zero canvas.
    if background is None or np.sum(background) == 0:
        gr.Info("Please upload an image")
        return None
    image_np = cast(np.ndarray, background)

    layers = image_and_mask.get("layers") or []
    first_layer = layers[0] if layers else None
    if first_layer is None:
        gr.Info("Please mark the areas you want to remove")
        return None
    alpha_channel = cast(np.ndarray, first_layer)
    # Any painted (non-transparent) pixel of the brush layer is part of the mask.
    mask_np = np.where(alpha_channel[:, :, 3] == 0, 0, 255).astype(np.uint8)
    if not mask_np.any():
        gr.Info("Please mark the areas you want to remove")
        return None

    mask = resize_image(Image.fromarray(mask_np))
    image = resize_image(Image.fromarray(image_np))

    # process_image blocks (HTTP calls plus time.sleep polling). Run it in a
    # worker thread so this coroutine does not stall Gradio's event loop for
    # every other user. to_thread copies contextvars, so progress still works.
    output = await asyncio.to_thread(process_image, image, mask, user_data, progress)
    if output is None:
        gr.Info("Processing failed")
        return None
    # gr.Progress expects a completion fraction in [0, 1].
    progress(1.0, desc="Processing completed")
    return image, output
with gr.Blocks() as demo:
    gr.HTML("""
<div style="display: flex; justify-content: center; text-align:center; flex-direction: column;">
<h1 style="color: #333;">🧹 Room Cleaner</h1>
<div style="max-width: 800px; margin: 0 auto;">
<p style="font-size: 16px;">Upload an image and use the pencil tool (✏️ icon at the bottom) to <b>mark the areas you want to remove</b>.</p>
<p style="font-size: 16px;">
For best results, include the shadows and reflections of the objects you want to remove.
You can remove multiple objects at once.
If you forget to mask some parts of your object, it's likely that the model will reconstruct them.
</p>
<br>
<video width="640" height="360" controls style="margin: 0 auto; border-radius: 10px;">
<source src="https://dropshare.blanchon.xyz/public/dropshare/room_cleaner_demo.mp4" type="video/mp4">
</video>
<br>
<p style="font-size: 16px;">Finally, click on the <b>"Run"</b> button to process the image.</p>
<p style="font-size: 16px;">Wait for the processing to complete and compare the original and processed images using the slider.</p>
<p style="font-size: 16px;">⚠️ Note that the images are compressed to reduce the workloads of the demo. </p>
</div>
<div style="margin-top: 20px; display: flex; justify-content: center; gap: 10px;">
<a href="https://x.com/JulienBlanchon">
<img src="https://img.shields.io/badge/X-%23000000.svg?style=for-the-badge&logo=X&logoColor=white" alt="X Badge" style="border-radius: 3px;"/>
</a>
</div>
</div>
""")
    with gr.Row(equal_height=False):
        with gr.Column():
            # NOTE(review): the original comment read "The image overflow, fix";
            # presumably height/width="full" keep the editor inside its column.
            image_and_mask = gr.ImageMask(
                label="Image and Mask",
                layers=False,
                show_fullscreen_button=False,
                sources=["upload"],
                show_download_button=False,
                interactive=True,
                height="full",
                width="full",
                brush=gr.Brush(default_size=75, colors=["#000000"], color_mode="fixed"),
                transforms=[],
            )
        with gr.Column():
            image_slider = ImageSlider(
                label="Result",
                interactive=False,
            )
            login_button = gr.LoginButton(scale=8)
            # A ClearButton so that clicking "Run" also empties the previous result.
            process_btn = gr.ClearButton(
                value="Run",
                variant="primary",
                size="lg",
                components=[image_slider],
            )

    # With inputs=[] Gradio calls these handlers with no arguments, so the
    # lambdas must take none (a `lambda _:` raises TypeError at click time).
    process_btn.click(
        fn=lambda: gr.update(interactive=False, value="Processing..."),
        inputs=[],
        outputs=[process_btn],
        api_name=False,
    ).then(
        fn=process,
        inputs=[image_and_mask],
        outputs=[image_slider],
        api_name=False,
    ).then(
        fn=lambda: gr.update(interactive=True, value="Run"),
        inputs=[],
        outputs=[process_btn],
        api_name=False,
    )

    example_base_url = "https://dropshare.blanchon.xyz/public/dropshare"
    # (local example stem, hosted before/after result file name).
    # NOTE(review): ex1's result is "ex1_results.png" while the others use
    # "_result.png" — confirm that this matches the hosted file.
    example_specs = [
        ("ex1", "ex1_results.png"),
        ("ex2", "ex2_result.png"),
        ("ex3", "ex3_result.png"),
        ("ex4", "ex4_result.png"),
    ]
    # Each example pre-fills the editor (image + drawn mask) and the slider
    # (original image, processed result).
    examples = [
        [
            make_example(f"./examples/{stem}.jpg", f"./examples/{stem}_mask_only.png"),
            (f"{example_base_url}/{stem}.jpg", f"{example_base_url}/{result_name}"),
        ]
        for stem, result_name in example_specs
    ]
    gr.Examples(
        examples=examples,
        inputs=[
            image_and_mask,
            image_slider,
        ],
        api_name=False,
    )
if __name__ == "__main__":
    # Launch without debug mode, without a Gradio share tunnel, and with the
    # API usage page hidden.
    launch_options = {"debug": False, "share": False, "show_api": False}
    demo.launch(**launch_options)