import random
from typing import Optional, Tuple

import cv2
import gradio as gr
import numpy as np
# NOTE(review): the original file imports `spaces` before torch/diffusers; keep
# that order (ZeroGPU is believed to require it — confirm before reordering).
import spaces
import torch
from PIL import Image, ImageFilter
from diffusers import FluxInpaintPipeline
from gradio_client import Client, handle_file

MARKDOWN = """
# FLUX.1 Inpainting 🔥

Shoutout to [Black Forest Labs](https://huggingface.co/black-forest-labs) team for
creating this amazing model, and a big thanks to [Gothos](https://github.com/Gothos)
for taking it to the next level by enabling inpainting with the FLUX.
"""

MAX_SEED = np.iinfo(np.int32).max
IMAGE_SIZE = 1024
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Loaded once at import time and shared by every request.
PIPE = FluxInpaintPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16).to(DEVICE)


def calculate_image_dimensions_for_flux(
    original_resolution_wh: Tuple[int, int],
    maximum_dimension: int = IMAGE_SIZE,
    multiple: int = 32,
) -> Tuple[int, int]:
    """
    Scale a resolution so its longer side equals ``maximum_dimension``.

    The aspect ratio is preserved up to rounding. Each side is then floored to
    a multiple of ``multiple``. Images smaller than ``maximum_dimension`` are
    scaled up. Each side is clamped to at least ``multiple``, so a very
    elongated image never yields a zero-sized dimension (which would make
    ``Image.resize`` fail).

    Args:
        original_resolution_wh (Tuple[int, int]): ``(width, height)`` of the
            source image.
        maximum_dimension (int): Target length of the longer side.
        multiple (int): Every returned side is a multiple of this value.

    Returns:
        Tuple[int, int]: ``(new_width, new_height)``.
    """
    width, height = original_resolution_wh
    scaling_factor = maximum_dimension / max(width, height)
    new_width = int(width * scaling_factor)
    new_height = int(height * scaling_factor)
    new_width = max(new_width - new_width % multiple, multiple)
    new_height = max(new_height - new_height % multiple, multiple)
    return new_width, new_height


def is_mask_empty(image: Image.Image) -> bool:
    """Return True when every pixel of ``image`` is 0 after grayscale conversion."""
    # Vectorised with NumPy instead of iterating over every pixel in Python.
    return not np.any(np.asarray(image.convert("L")))


def process_mask(
    mask: Image.Image,
    mask_inflation: Optional[int] = None,
    mask_blur: Optional[int] = None
) -> Image.Image:
    """
    Inflates and blurs the white regions of a mask.

    Args:
        mask (Image.Image): The input mask image.
        mask_inflation (Optional[int]): Side length of the square kernel used
            for a single dilation pass; None or 0 skips dilation.
        mask_blur (Optional[int]): The radius of the Gaussian blur to apply;
            None or 0 skips blurring.

    Returns:
        Image.Image: The processed mask with inflated and/or blurred regions.
    """
    if mask_inflation and mask_inflation > 0:
        mask_array = np.array(mask)
        kernel = np.ones((mask_inflation, mask_inflation), np.uint8)
        mask_array = cv2.dilate(mask_array, kernel, iterations=1)
        mask = Image.fromarray(mask_array)

    if mask_blur and mask_blur > 0:
        # Soft edges make the inpainted region blend into the original image.
        mask = mask.filter(ImageFilter.GaussianBlur(radius=mask_blur))

    return mask


def set_client_for_session(request: gr.Request) -> Client:
    """
    Create a per-session client for the remote masking Space.

    The caller's ``x-ip-token`` header is forwarded when present. When the
    header is missing (e.g. when running locally), the client is created
    without extra headers instead of raising KeyError.
    """
    x_ip_token = request.headers.get('x-ip-token')
    headers = {"X-IP-Token": x_ip_token} if x_ip_token else None
    return Client("SkalskiP/florence-sam-masking", headers=headers)


@spaces.GPU(duration=100)
def run_flux(
    image: Image.Image,
    mask: Image.Image,
    prompt: str,
    seed_slicer: int,
    randomize_seed_checkbox: bool,
    strength_slider: float,
    num_inference_steps_slider: int,
    resolution_wh: Tuple[int, int],
) -> Image.Image:
    """
    Run the FLUX inpainting pipeline once and return the first generated image.

    ``image`` and ``mask`` are expected to already be resized to
    ``resolution_wh``. When ``randomize_seed_checkbox`` is set, ``seed_slicer``
    is replaced with a random seed in ``[0, MAX_SEED]``.
    """
    width, height = resolution_wh
    if randomize_seed_checkbox:
        seed_slicer = random.randint(0, MAX_SEED)
    generator = torch.Generator().manual_seed(seed_slicer)
    return PIPE(
        prompt=prompt,
        image=image,
        mask_image=mask,
        width=width,
        height=height,
        strength=strength_slider,
        generator=generator,
        num_inference_steps=num_inference_steps_slider
    ).images[0]


def process(
    client,
    input_image_editor: dict,
    inpainting_prompt_text: str,
    masking_prompt_text: str,
    mask_inflation_slider: int,
    mask_blur_slider: int,
    seed_slicer: int,
    randomize_seed_checkbox: bool,
    strength_slider: float,
    num_inference_steps_slider: int,
    progress=gr.Progress(track_tqdm=True)
):
    """
    Gradio submit handler: validate the inputs, build a mask, and inpaint.

    The mask comes either from the layer the user drew in the ImageEditor or,
    when nothing was drawn, from the remote masking Space driven by
    ``masking_prompt_text``. Exactly one of the two must be provided.

    Returns:
        ``(generated_image, processed_mask)``, or ``(None, None)`` after
        showing a ``gr.Info`` message when the inputs are invalid.
    """
    # `progress` is unused directly; its presence with track_tqdm=True lets
    # Gradio display the pipeline's tqdm progress.
    if not inpainting_prompt_text:
        gr.Info("Please enter a text prompt.")
        return None, None

    # Check for an upload *before* opening it: Image.open(None) raises, and a
    # PIL Image object is always truthy, so checking after opening never fires.
    image_path = (input_image_editor or {}).get('background')
    if not image_path:
        gr.Info("Please upload an image.")
        return None, None
    image = Image.open(image_path)

    layers = input_image_editor.get('layers') or []
    # A missing drawing layer is treated the same as an empty (all-black) mask.
    mask = Image.open(layers[0]) if layers else Image.new("L", image.size, 0)

    mask_is_empty = is_mask_empty(mask)
    if mask_is_empty and not masking_prompt_text:
        gr.Info("Please draw a mask or enter a masking prompt.")
        return None, None

    if not mask_is_empty and masking_prompt_text:
        gr.Info("Both mask and masking prompt are provided. Please provide only one.")
        return None, None

    if mask_is_empty:
        # Ask the remote `florence-sam-masking` Space for a mask matching the
        # masking prompt; it returns a path to the generated mask image.
        generated_mask_path = client.predict(
            image_input=handle_file(image_path),
            text_input=masking_prompt_text,
            api_name="/process_image")
        mask = Image.open(generated_mask_path)

    width, height = calculate_image_dimensions_for_flux(original_resolution_wh=image.size)
    image = image.resize((width, height), Image.LANCZOS)
    mask = mask.resize((width, height), Image.LANCZOS)
    mask = process_mask(mask, mask_inflation=mask_inflation_slider, mask_blur=mask_blur_slider)
    image = run_flux(
        image=image,
        mask=mask,
        prompt=inpainting_prompt_text,
        seed_slicer=seed_slicer,
        randomize_seed_checkbox=randomize_seed_checkbox,
        strength_slider=strength_slider,
        num_inference_steps_slider=num_inference_steps_slider,
        resolution_wh=(width, height)
    )
    return image, mask


with gr.Blocks() as demo:
    # Holds the per-session masking client created by `set_client_for_session`.
    client_component = gr.State()
    gr.Markdown(MARKDOWN)
    with gr.Row():
        with gr.Column():
            input_image_editor_component = gr.ImageEditor(
                label='Image',
                type='filepath',
                sources=["upload", "webcam"],
                image_mode='RGB',
                layers=False,
                brush=gr.Brush(colors=["#FFFFFF"], color_mode="fixed"))

            with gr.Row():
                inpainting_prompt_text_component = gr.Text(
                    label="Prompt",
                    show_label=False,
                    max_lines=1,
                    placeholder="Enter text to generate inpainting",
                    container=False,
                )
                submit_button_component = gr.Button(
                    value='Submit', variant='primary', scale=0)

            with gr.Accordion("Advanced Settings", open=False):
                masking_prompt_text_component = gr.Text(
                    label="Prompt",
                    show_label=False,
                    max_lines=1,
                    placeholder="Enter text to generate masking",
                    container=False,
                )

                with gr.Row():
                    mask_inflation_slider_component = gr.Slider(
                        label="Mask inflation",
                        info="Adjusts the amount of mask edge expansion before "
                             "inpainting.",
                        minimum=0,
                        maximum=20,
                        step=1,
                        value=5,
                    )

                    mask_blur_slider_component = gr.Slider(
                        label="Mask blur",
                        info="Controls the intensity of the Gaussian blur applied to "
                             "the mask edges.",
                        minimum=0,
                        maximum=20,
                        step=1,
                        value=5,
                    )

                seed_slicer_component = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=MAX_SEED,
                    step=1,
                    value=42,
                )

                randomize_seed_checkbox_component = gr.Checkbox(
                    label="Randomize seed", value=True)

                with gr.Row():
                    strength_slider_component = gr.Slider(
                        label="Strength",
                        info="Indicates extent to transform the reference `image`. "
                             "Must be between 0 and 1. `image` is used as a starting "
                             "point and more noise is added the higher the `strength`.",
                        minimum=0,
                        maximum=1,
                        step=0.01,
                        value=0.85,
                    )

                    num_inference_steps_slider_component = gr.Slider(
                        label="Number of inference steps",
                        info="The number of denoising steps. More denoising steps "
                             "usually lead to a higher quality image at the "
                             "expense of slower inference.",
                        minimum=1,
                        maximum=50,
                        step=1,
                        value=20,
                    )
        with gr.Column():
            output_image_component = gr.Image(
                type='pil', image_mode='RGB', label='Generated image', format="png")
            with gr.Accordion("Debug", open=False):
                output_mask_component = gr.Image(
                    type='pil', image_mode='RGB', label='Input mask', format="png")

    submit_button_component.click(
        fn=process,
        inputs=[
            client_component,
            input_image_editor_component,
            inpainting_prompt_text_component,
            masking_prompt_text_component,
            mask_inflation_slider_component,
            mask_blur_slider_component,
            seed_slicer_component,
            randomize_seed_checkbox_component,
            strength_slider_component,
            num_inference_steps_slider_component
        ],
        outputs=[
            output_image_component,
            output_mask_component
        ]
    )

    # Build the masking client once per browser session.
    demo.load(set_client_for_session, None, client_component)

demo.launch(debug=False, show_error=True)