import io
import os
import random
from typing import Tuple

import requests
import numpy as np
import gradio as gr
import spaces
import torch
from PIL import Image, ImageFilter
from diffusers import FluxInpaintPipeline
from gradio_client import Client, handle_file

MARKDOWN = """
# FLUX.1 Inpainting 🔥

Shoutout to [Black Forest Labs](https://huggingface.co/black-forest-labs) team for
creating this amazing model, and a big thanks to [Gothos](https://github.com/Gothos)
for taking it to the next level by enabling inpainting with the FLUX.
"""

MAX_SEED = np.iinfo(np.int32).max
IMAGE_SIZE = 1024
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
HF_TOKEN = os.environ.get("HF_TOKEN", None)

# Remote Space used to generate a mask from a text prompt when the user did not draw one.
client = Client("SkalskiP/florence-sam-masking", hf_token=HF_TOKEN)

# Seconds `requests` may wait on connect/read before giving up on an example asset.
EXAMPLE_DOWNLOAD_TIMEOUT = 30


def _load_image_from_url(url: str, timeout: float = EXAMPLE_DOWNLOAD_TIMEOUT) -> Image.Image:
    """Download ``url`` and return it as a fully decoded PIL image.

    The whole response body is buffered and the image is decoded eagerly.
    ``Image.open`` is lazy, so decoding straight from a streaming HTTP
    response would tie the image to a connection that may already be gone.

    Raises:
        requests.HTTPError: if the server answers with an error status.
        requests.Timeout: if connecting or reading stalls longer than ``timeout``.
    """
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    image = Image.open(io.BytesIO(response.content))
    image.load()
    return image


def remove_background(image: Image.Image, threshold: int = 50) -> Image.Image:
    """Make dark pixels fully transparent.

    Pixels whose mean R, G, B value is below ``threshold`` become
    ``(0, 0, 0, 0)``; every other pixel keeps its RGBA value. Here it turns
    the downloaded example mask images into editor layers (presumably light
    strokes on a dark background — NOTE(review): confirm against the assets).

    Returns a new RGBA image; the input image is not modified.
    """
    rgba = np.array(image.convert("RGBA"))
    # Integer input is averaged in float64, matching sum(rgb) / 3 exactly.
    too_dark = rgba[..., :3].mean(axis=-1) < threshold
    rgba[too_dark] = 0
    return Image.fromarray(rgba)


# Both examples share the same source image, so download it only once.
_DOGE_IMAGE = _load_image_from_url("https://media.roboflow.com/spaces/doge-2-image.png")

EXAMPLES = [
    [
        {
            "background": _DOGE_IMAGE,
            "layers": [remove_background(_load_image_from_url(
                "https://media.roboflow.com/spaces/doge-2-mask-2.png"))],
            "composite": _load_image_from_url(
                "https://media.roboflow.com/spaces/doge-2-composite-2.png"),
        },
        "little lion",
        None,
        42,
        False,
        0.85,
        30,
    ],
    [
        {
            "background": _DOGE_IMAGE.copy(),
            "layers": [remove_background(_load_image_from_url(
                "https://media.roboflow.com/spaces/doge-2-mask-3.png"))],
            "composite": _load_image_from_url(
                "https://media.roboflow.com/spaces/doge-2-composite-3.png"),
        },
        "tattoos",
        None,
        42,
        False,
        0.85,
        30,
    ],
]

pipe = FluxInpaintPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16).to(DEVICE)
def resize_image_dimensions(
    original_resolution_wh: Tuple[int, int],
    maximum_dimension: int = IMAGE_SIZE
) -> Tuple[int, int]:
    """Scale ``(width, height)`` so the longer side becomes ``maximum_dimension``.

    The aspect ratio is preserved. Images smaller than ``maximum_dimension``
    are scaled up. Each side is then rounded down to a multiple of 32 and
    clamped to at least 32, so extremely thin images never yield a zero-sized
    dimension.

    Returns:
        The new ``(width, height)``.
    """
    width, height = original_resolution_wh
    scaling_factor = maximum_dimension / max(width, height)
    new_width = int(width * scaling_factor)
    new_height = int(height * scaling_factor)
    new_width = max(32, new_width - new_width % 32)
    new_height = max(32, new_height - new_height % 32)
    return new_width, new_height


def is_image_empty(image: Image.Image) -> bool:
    """Return True when every pixel of ``image`` is 0 after grayscale conversion.

    Used to detect a mask layer on which nothing was drawn. ``getbbox`` returns
    None exactly when the grayscale image has no non-zero pixel, so no pixel
    list has to be materialized in Python.
    """
    return image.convert("L").getbbox() is None


@spaces.GPU(duration=100)
def process(
    input_image_editor: dict,
    inpainting_prompt_text: str,
    masking_prompt_text: str,
    seed_slicer: int,
    randomize_seed_checkbox: bool,
    strength_slider: float,
    num_inference_steps_slider: int,
    progress=gr.Progress(track_tqdm=True)
):
    """Inpaint the masked region of the editor image with FLUX.

    Args:
        input_image_editor: ImageEditor value with file paths under
            ``"background"`` and ``"layers"`` (the first layer is the drawn mask).
        inpainting_prompt_text: prompt describing what to paint into the mask.
        masking_prompt_text: optional prompt used to auto-generate a mask with
            the remote masking Space when no mask was drawn.
        seed_slicer: seed for the torch generator (ignored when randomizing).
        randomize_seed_checkbox: draw a fresh seed in ``[0, MAX_SEED]``.
        strength_slider: pipeline ``strength`` (0..1).
        num_inference_steps_slider: number of denoising steps.
        progress: unused directly; ``track_tqdm=True`` lets Gradio mirror the
            pipeline's tqdm progress bar.

    Returns:
        ``(generated_image, mask_used)``. When validation fails, an info
        message is shown and ``(None, None)`` is returned.
    """
    if not inpainting_prompt_text:
        gr.Info("Please enter a text prompt.")
        return None, None

    # Check the path *before* opening it: a PIL image object is always truthy,
    # so testing the opened image could never detect a missing upload.
    editor_value = input_image_editor or {}
    image_path = editor_value.get("background")
    if not image_path:
        gr.Info("Please upload an image.")
        return None, None
    layer_paths = editor_value.get("layers") or []

    # The pipeline expects 3-channel input; normalize whatever mode was uploaded.
    image = Image.open(image_path).convert("RGB")
    mask = Image.open(layer_paths[0]) if layer_paths else None
    mask_is_empty = mask is None or is_image_empty(mask)

    if mask_is_empty and not masking_prompt_text:
        gr.Info("Please draw a mask or enter a masking prompt.")
        return None, None
    if not mask_is_empty and masking_prompt_text:
        gr.Info("Both mask and masking prompt are provided. Please provide only one.")
        return None, None

    if mask_is_empty:
        generated_mask_path = client.predict(
            image_input=handle_file(image_path),
            text_input=masking_prompt_text,
            api_name="/process_image")
        mask = Image.open(generated_mask_path)

    # Soften mask edges so the inpainted region blends into its surroundings.
    mask = mask.filter(ImageFilter.GaussianBlur(radius=5))

    width, height = resize_image_dimensions(original_resolution_wh=image.size)
    resized_image = image.resize((width, height), Image.LANCZOS)
    resized_mask = mask.resize((width, height), Image.LANCZOS)

    if randomize_seed_checkbox:
        seed_slicer = random.randint(0, MAX_SEED)
    # Slider values may arrive as floats; manual_seed and the step count need ints.
    generator = torch.Generator().manual_seed(int(seed_slicer))
    result = pipe(
        prompt=inpainting_prompt_text,
        image=resized_image,
        mask_image=resized_mask,
        width=width,
        height=height,
        strength=strength_slider,
        generator=generator,
        num_inference_steps=int(num_inference_steps_slider)
    ).images[0]
    print('INFERENCE DONE')
    return result, resized_mask


with gr.Blocks() as demo:
    gr.Markdown(MARKDOWN)
    with gr.Row():
        with gr.Column():
            input_image_editor_component = gr.ImageEditor(
                label='Image',
                type='filepath',
                sources=["upload", "webcam"],
                image_mode='RGB',
                layers=False,
                brush=gr.Brush(colors=["#FFFFFF"], color_mode="fixed"))

            with gr.Row():
                inpainting_prompt_text_component = gr.Text(
                    label="Prompt",
                    show_label=False,
                    max_lines=1,
                    placeholder="Enter text to generate inpainting",
                    container=False,
                )
                submit_button_component = gr.Button(
                    value='Submit', variant='primary', scale=0)

            with gr.Accordion("Advanced Settings", open=False):
                masking_prompt_text_component = gr.Text(
                    label="Prompt",
                    show_label=False,
                    max_lines=1,
                    placeholder="Enter text to generate masking",
                    container=False,
                )
                seed_slicer_component = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=MAX_SEED,
                    step=1,
                    value=42,
                )
                randomize_seed_checkbox_component = gr.Checkbox(
                    label="Randomize seed", value=True)

                with gr.Row():
                    strength_slider_component = gr.Slider(
                        label="Strength",
                        info="Indicates extent to transform the reference `image`. "
                             "Must be between 0 and 1. `image` is used as a starting "
                             "point and more noise is added the higher the `strength`.",
                        minimum=0,
                        maximum=1,
                        step=0.01,
                        value=0.85,
                    )
                    num_inference_steps_slider_component = gr.Slider(
                        label="Number of inference steps",
                        info="The number of denoising steps. More denoising steps "
                             "usually lead to a higher quality image at the "
                             "expense of slower inference.",
                        minimum=1,
                        maximum=50,
                        step=1,
                        value=20,
                    )
        with gr.Column():
            output_image_component = gr.Image(
                type='pil', image_mode='RGB', label='Generated image', format="png")
            with gr.Accordion("Debug", open=False):
                output_mask_component = gr.Image(
                    type='pil', image_mode='RGB', label='Input mask', format="png")

    # The examples and the submit button drive `process` with the same wiring.
    process_inputs = [
        input_image_editor_component,
        inpainting_prompt_text_component,
        masking_prompt_text_component,
        seed_slicer_component,
        randomize_seed_checkbox_component,
        strength_slider_component,
        num_inference_steps_slider_component,
    ]
    process_outputs = [
        output_image_component,
        output_mask_component,
    ]

    with gr.Row():
        gr.Examples(
            fn=process,
            examples=EXAMPLES,
            inputs=process_inputs,
            outputs=process_outputs,
            run_on_click=True,
            cache_examples=True
        )

    submit_button_component.click(
        fn=process,
        inputs=process_inputs,
        outputs=process_outputs
    )

demo.launch(debug=False, show_error=True)