Spaces:
Running
Running
from typing import cast | |
from comfydeploy import ComfyDeploy | |
import asyncio | |
import os | |
import gradio as gr | |
from gradio.components.image_editor import EditorValue | |
from PIL import Image | |
import requests | |
import dotenv | |
from gradio_imageslider import ImageSlider | |
from io import BytesIO | |
import base64 | |
import glob | |
import numpy as np | |
# Load variables from a local .env file (if any) into os.environ.
dotenv.load_dotenv()

API_KEY = os.environ.get("API_KEY")
DEPLOYMENT_ID = os.environ.get("DEPLOYMENT_ID", "DEPLOYMENT_ID_NOT_SET")

# Fail fast at import time: the app cannot call ComfyDeploy without these.
if not API_KEY:
    raise ValueError(
        "Please set API_KEY and DEPLOYMENT_ID in your environment variables"
    )
# An empty DEPLOYMENT_ID is as unusable as a missing one (mirrors the
# falsy check on API_KEY above).
if not DEPLOYMENT_ID or DEPLOYMENT_ID == "DEPLOYMENT_ID_NOT_SET":
    raise ValueError("Please set DEPLOYMENT_ID in your environment variables")

# Shared ComfyDeploy API client used by process_image.
client = ComfyDeploy(bearer_auth=API_KEY)
def get_base64_from_image(image: Image.Image) -> str:
    """Serialize *image* as PNG and return those bytes base64-encoded as text."""
    with BytesIO() as png_buffer:
        image.save(png_buffer, format="PNG")
        encoded = base64.b64encode(png_buffer.getvalue())
    return encoded.decode("utf-8")
# Run statuses after which polling stops without a result ("success" is
# handled separately). "failed" comes from the original code; "timeout" and
# "cancelled" are assumed to also be terminal ComfyDeploy run statuses.
# NOTE(review): confirm against the ComfyDeploy run status enum.
# A tuple (not a set) is used so membership relies on ==, exactly like the
# `status == "success"` comparison below.
_FAILED_RUN_STATUSES = ("failed", "timeout", "cancelled")


def _download_first_output_image(outputs, timeout: float) -> Image.Image | None:
    """Download and open the first image found in a run's outputs, else None."""
    for output in outputs or []:
        if output.data and output.data.images:
            response = requests.get(output.data.images[0].url, timeout=timeout)
            # Report HTTP errors explicitly instead of letting Image.open
            # choke on an error page.
            response.raise_for_status()
            return Image.open(BytesIO(response.content))
    return None


async def process_image(
    image: Image.Image | str | None,
    mask: Image.Image | str | None,
    progress: gr.Progress = gr.Progress(),
    *,
    poll_interval: float = 2.0,
    download_timeout: float = 60.0,
) -> Image.Image | None:
    """Run the ComfyDeploy deployment on an image/mask pair and return the result.

    Both inputs are PNG-encoded into base64 data URLs and submitted as the
    ``image`` and ``mask`` inputs of ``DEPLOYMENT_ID``. The run is then polled
    every ``poll_interval`` seconds; its progress and live status are
    forwarded to ``progress`` until the run reaches a terminal status.

    Args:
        image: Input picture, as a PIL image or a path to open.
        mask: Mask picture, as a PIL image or a path to open.
        progress: Gradio progress tracker updated while polling.
        poll_interval: Seconds to wait between two status requests.
        download_timeout: Timeout in seconds for downloading the result image.

    Returns:
        The first output image of a successful run, or ``None`` when an input
        is missing, the run fails or yields no image, or an error occurs while
        talking to the API / downloading the result (such errors are printed).
    """
    progress(0, desc="Starting...")
    if image is None or mask is None:
        return None
    if isinstance(mask, str):
        mask = Image.open(mask)
    if isinstance(image, str):
        image = Image.open(image)

    inputs: dict = {
        "image": f"data:image/png;base64,{get_base64_from_image(image)}",
        "mask": f"data:image/png;base64,{get_base64_from_image(mask)}",
    }

    # NOTE: client.run.* and requests.get are synchronous calls; they block
    # the event loop while they run.
    try:
        result = client.run.create(
            request={"deployment_id": DEPLOYMENT_ID, "inputs": inputs}
        )
        if not (result and result.object):
            print("Processing failed: no run was created")
            return None

        run_id: str = result.object.run_id
        progress(0, desc="Starting processing...")

        while True:
            run = client.run.get(run_id=run_id).object
            if not run:
                # Nothing to report yet: wait before asking again instead of
                # hammering the API in a tight loop.
                await asyncio.sleep(poll_interval)
                continue

            progress_value = run.progress if run.progress is not None else 0
            live_status = (
                run.live_status if run.live_status is not None else "Cold starting..."
            )
            progress(progress_value, desc=f"Status: {live_status}")

            if run.status == "success":
                return _download_first_output_image(run.outputs, download_timeout)
            if run.status in _FAILED_RUN_STATUSES:
                print(f"Processing failed (status: {run.status})")
                return None

            await asyncio.sleep(poll_interval)
    except Exception as e:
        # UI boundary: report and show no result rather than crash the handler.
        print(f"Error: {e}")
        return None
def resize(image: "Image.Image", shortest_side: int = 768) -> "Image.Image":
    """Downscale *image* so that its shorter side equals ``shortest_side``.

    The aspect ratio is kept: the longer side is scaled by the same factor
    and truncated to an int. Images whose shorter side is already at most
    ``shortest_side`` are returned unchanged, so nothing is ever upscaled.
    (Previously only images with *both* sides <= ``shortest_side`` were left
    alone, so e.g. a 500x1000 image was enlarged to 768x1536.)

    Args:
        image: Image to resize; only ``width``, ``height`` and ``resize`` are used.
        shortest_side: Target length, in pixels, of the shorter side.

    Returns:
        The original image, or a resized copy.
    """
    width, height = image.width, image.height
    if min(width, height) <= shortest_side:
        return image
    if width < height:
        return image.resize(size=(shortest_side, int(shortest_side * height / width)))
    return image.resize(size=(int(shortest_side * width / height), shortest_side))
async def run_async(
    image_and_mask: EditorValue | None,
    progress: gr.Progress = gr.Progress(),
) -> tuple[Image.Image, Image.Image] | None:
    """Convert the image-editor value into an image/mask pair and process it.

    The mask comes from the alpha channel of the first layer: every pixel the
    user painted (alpha != 0) becomes 255, everything else 0. Both the image
    and the mask are passed through ``resize`` before being processed.

    Returns:
        ``(processed, original)`` for the comparison slider, or ``None`` when
        there is no editor value, no background image, no layer to read a
        mask from, or processing produced no result.
    """
    if not image_and_mask:
        return None
    background = image_and_mask.get("background")
    layers = image_and_mask.get("layers") or []
    # Indexing an empty layer list (e.g. an example value with "layers": [])
    # used to raise IndexError; a missing background crashed Image.fromarray.
    if background is None or not layers:
        return None

    # NOTE(review): assumes the layer is an (H, W, 4) RGBA array; the
    # [:, :, 3] index requires at least four channels.
    layer = cast(np.ndarray, layers[0])
    mask_np = np.where(layer[:, :, 3] == 0, 0, 255).astype(np.uint8)
    image_np = cast(np.ndarray, background)

    mask = resize(Image.fromarray(mask_np))
    image = resize(Image.fromarray(image_np))

    output = await process_image(image, mask, progress)
    if output is None:
        return None
    return output, image
def run_sync(
    image_and_mask: EditorValue | None,
    progress: gr.Progress = gr.Progress(),
) -> tuple[Image.Image, Image.Image] | None:
    """Synchronous entry point for the Gradio click handler; runs ``run_async``.

    ``progress`` is declared explicitly with a ``gr.Progress()`` default
    because Gradio injects its tracked progress object only into parameters
    declared that way. With the previous ``*args`` signature, no tracker was
    injected and the status updates emitted by ``process_image`` went to an
    untracked default instance.
    """
    return asyncio.run(run_async(image_and_mask, progress))
with gr.Blocks() as demo:
    gr.Markdown("""
# 🧹 Room Cleaner
Upload an image and use the pen tool (pencil icon at the bottom) to mark the areas you want to remove.
Click on the "Run" button to process the image and remove the marked areas.
""")
    with gr.Row():
        with gr.Column():
            # TODO: the editor image overflows its column; layout needs a fix.
            image_and_mask = gr.ImageMask(
                label="Input Image and Mask",
                layers=False,
                show_fullscreen_button=False,
                sources=["upload"],
                show_download_button=False,
                interactive=True,
                height="full",
                width="full",
            )
        with gr.Column():
            image_slider = ImageSlider(
                label="Compare Original and Processed",
                interactive=False,
            )
            # A ClearButton: clicking it also clears the listed components
            # (the slider), so the previous result is not left on screen.
            process_btn = gr.ClearButton(
                value="Run",
                variant="primary",
                size="lg",
                components=[image_slider],
            )

    process_btn.click(
        fn=run_sync,
        inputs=[image_and_mask],
        outputs=[image_slider],
        api_name=False,
    )

    # Examples ex1..ex4: the editor value is the local background photo plus a
    # pre-made composite (no separate layers); the slider shows hosted
    # original/result pairs.
    example_results_base = "https://dropshare.blanchon.xyz/public/dropshare"
    examples = [
        [
            {
                "background": f"./examples/ex{i}.jpg",
                "layers": [],
                "composite": f"./examples/ex{i}_mask.png",
            },
            (
                f"{example_results_base}/ex{i}.jpg",
                f"{example_results_base}/ex{i}_result.png",
            ),
        ]
        for i in range(1, 5)
    ]

    gr.Examples(
        examples=examples,
        inputs=[image_and_mask, image_slider],
        api_name=False,
    )
if __name__ == "__main__":
    # Start the Gradio app only when this file is executed as a script.
    demo.launch(share=True, debug=True)