SkalskiP's picture
Florence-2 + SAM2 + FLUX.1
7189c81
raw
history blame
7.79 kB
from typing import Tuple
import supervision as sv
import random
import numpy as np
import gradio as gr
import spaces
import torch
from PIL import Image, ImageFilter
from diffusers import FluxInpaintPipeline
from utils.florence import load_florence_model, run_florence_inference, \
FLORENCE_OPEN_VOCABULARY_DETECTION_TASK
from utils.sam import load_sam_image_model, run_sam_inference
# Markdown header rendered at the top of the Gradio page (credits for the model
# and for the inpainting pipeline).
MARKDOWN = (
    "\n"
    "# FLUX.1 Inpainting 🔥\n"
    "Shoutout to [Black Forest Labs](https://huggingface.co/black-forest-labs) team for\n"
    "creating this amazing model, and a big thanks to [Gothos](https://github.com/Gothos)\n"
    "for taking it to the next level by enabling inpainting with the FLUX.\n"
)
# Largest value a user-selected / randomized seed may take (max signed 32-bit int).
MAX_SEED = np.iinfo(np.int32).max
# Target length, in pixels, of the longer side of the resized input image.
IMAGE_SIZE = 1024
# NOTE(review): the app hard-requires CUDA; the CPU fallback below is disabled.
# DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DEVICE = torch.device("cuda")
# Enter a bfloat16 autocast context at import time and never exit it.
# NOTE(review): autocast state is thread-local, so this presumably only covers
# the importing thread; `process` applies its own autocast decorator as well.
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
# Enable TF32 matmul / cuDNN kernels on GPUs with compute capability >= 8.
if torch.cuda.get_device_properties(0).major >= 8:
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
# Load every model once, at import time, so requests reuse them.
# Presumably the loaders place the models on DEVICE (they receive it) — see utils.
FLORENCE_MODEL, FLORENCE_PROCESSOR = load_florence_model(device=DEVICE)
SAM_IMAGE_MODEL = load_sam_image_model(device=DEVICE)
# FLUX.1-schnell in bfloat16, wrapped in the diffusers inpainting pipeline.
FLUX_INPAINTING_PIPELINE = FluxInpaintPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16).to(DEVICE)
def resize_image_dimensions(
    original_resolution_wh: Tuple[int, int],
    maximum_dimension: int = IMAGE_SIZE,
    multiple_of: int = 32,
) -> Tuple[int, int]:
    """Compute a target ``(width, height)`` whose longer side is ``maximum_dimension``.

    The aspect ratio is preserved as closely as integer sizes allow, and each
    side is then floored to a multiple of ``multiple_of`` (but never below one
    multiple, so an extremely thin image cannot collapse to a zero-sized side).
    Images smaller than ``maximum_dimension`` are scaled *up*.

    Args:
        original_resolution_wh: Source size as ``(width, height)``.
        maximum_dimension: Desired length, in pixels, of the longer side.
        multiple_of: Both output sides are floored to a multiple of this value.

    Returns:
        ``(new_width, new_height)``.

    Raises:
        ValueError: If a source dimension or ``multiple_of`` is not positive.
    """
    width, height = original_resolution_wh
    if width <= 0 or height <= 0:
        raise ValueError(
            f"Image dimensions must be positive, got {original_resolution_wh}.")
    if multiple_of <= 0:
        raise ValueError(f"multiple_of must be positive, got {multiple_of}.")

    longer_side = max(width, height)
    # Integer arithmetic instead of a float scaling factor: the float product
    # can land just below the exact value (e.g. 1023.999...), which int()
    # truncates to 1023 and the multiple-of rounding then drops to 992.
    new_width = width * maximum_dimension // longer_side
    new_height = height * maximum_dimension // longer_side

    def _floor_to_multiple(size: int) -> int:
        # Floor to the grid, but keep at least one multiple so the side is > 0.
        return max(multiple_of, size - size % multiple_of)

    return _floor_to_multiple(new_width), _floor_to_multiple(new_height)
def _is_mask_empty(mask: Image.Image) -> bool:
    """Return True if ``mask`` has no non-zero pixel after greyscale conversion."""
    # getbbox() returns None when every pixel of the image is zero, i.e. the
    # user uploaded an image but drew nothing on the (transparent) layer.
    return mask.convert("L").getbbox() is None


def _segment_first_object(image: Image.Image, segmentation_prompt_text: str):
    """Return a 0/255 mask of the first object matching the prompt, or None.

    Runs Florence-2 open-vocabulary detection on ``image`` and passes the
    detections to SAM (presumably producing one mask per detection). Returns
    None when nothing is detected.
    """
    _, result = run_florence_inference(
        model=FLORENCE_MODEL,
        processor=FLORENCE_PROCESSOR,
        device=DEVICE,
        image=image,
        task=FLORENCE_OPEN_VOCABULARY_DETECTION_TASK,
        text=segmentation_prompt_text
    )
    detections = sv.Detections.from_lmm(
        lmm=sv.LMM.FLORENCE_2,
        result=result,
        resolution_wh=image.size
    )
    # Don't run SAM at all when Florence found nothing.
    if len(detections) == 0:
        return None
    detections = run_sam_inference(SAM_IMAGE_MODEL, image, detections)
    if len(detections) == 0:
        return None
    # NOTE: only the first detection is used; other matches are ignored.
    return Image.fromarray(detections.mask[0].astype(np.uint8) * 255)


@spaces.GPU(duration=150)
@torch.inference_mode()
@torch.autocast(device_type="cuda", dtype=torch.bfloat16)
def process(
    input_image_editor: dict,
    inpainting_prompt_text: str,
    segmentation_prompt_text: str,
    seed_slicer: int,
    randomize_seed_checkbox: bool,
    strength_slider: float,
    num_inference_steps_slider: int,
    progress=gr.Progress(track_tqdm=True)
):
    """Inpaint the uploaded image inside a drawn or prompt-segmented mask.

    Exactly one mask source must be provided: either strokes drawn on the
    editor layer, or a segmentation prompt (Florence-2 + SAM).

    Args:
        input_image_editor: Gradio ImageEditor value; reads the ``'background'``
            image and the first entry of ``'layers'`` (the drawn mask).
        inpainting_prompt_text: Text prompt for the FLUX inpainting pipeline.
        segmentation_prompt_text: Optional text describing the object to mask.
        seed_slicer: Seed used when ``randomize_seed_checkbox`` is False.
        randomize_seed_checkbox: Draw a random seed in ``[0, MAX_SEED]``.
        strength_slider: Pipeline ``strength`` in ``[0, 1]``.
        num_inference_steps_slider: Number of denoising steps.
        progress: Gradio progress tracker (tracks tqdm bars).

    Returns:
        ``(inpainted_image, blurred_mask)``, or ``(None, None)`` after showing
        an info message when the inputs are invalid or nothing was detected.
    """
    if not inpainting_prompt_text or not inpainting_prompt_text.strip():
        gr.Info("Please enter a text prompt.")
        return None, None

    image = input_image_editor.get('background') if input_image_editor else None
    if image is None:
        gr.Info("Please upload an image.")
        return None, None

    # Read the layer only after the image check: with no layers, indexing
    # [0] directly would raise IndexError.
    layers = input_image_editor.get('layers') or []
    mask = layers[0] if layers else None
    # The editor supplies a layer image even when nothing was drawn, so an
    # image object alone does not mean the user drew a mask.
    has_drawn_mask = mask is not None and not _is_mask_empty(mask)
    segmentation_prompt_text = (segmentation_prompt_text or "").strip()

    if not has_drawn_mask and not segmentation_prompt_text:
        gr.Info("Please draw a mask or enter a segmentation prompt.")
        return None, None
    if has_drawn_mask and segmentation_prompt_text:
        gr.Info("Both mask and segmentation prompt are provided. Please provide only "
                "one.")
        return None, None

    width, height = resize_image_dimensions(original_resolution_wh=image.size)
    image = image.resize((width, height), Image.LANCZOS)

    if segmentation_prompt_text:
        mask = _segment_first_object(image, segmentation_prompt_text)
        if mask is None:
            gr.Info(f"{segmentation_prompt_text} prompt did not return any detections.")
            return None, None

    # Both mask sources must match the resized image; the drawn layer is
    # presumably at the uploaded image's original resolution.
    mask = mask.resize((width, height), Image.LANCZOS)
    # Soften the mask edge so the inpainted region blends into its surroundings.
    mask = mask.filter(ImageFilter.GaussianBlur(radius=10))

    if randomize_seed_checkbox:
        seed_slicer = random.randint(0, MAX_SEED)
    # Slider values may arrive as floats; the generator and the step count need ints.
    generator = torch.Generator().manual_seed(int(seed_slicer))
    result = FLUX_INPAINTING_PIPELINE(
        prompt=inpainting_prompt_text,
        image=image,
        mask_image=mask,
        width=width,
        height=height,
        strength=strength_slider,
        generator=generator,
        num_inference_steps=int(num_inference_steps_slider)
    ).images[0]
    print('INFERENCE DONE')
    return result, mask
# Gradio UI: left column collects the image, mask and settings; right column
# shows the generated image and (in a "Debug" accordion) the mask actually used.
with gr.Blocks() as demo:
    gr.Markdown(MARKDOWN)
    with gr.Row():
        with gr.Column():
            # Single-layer editor; the fixed white brush is used to paint the mask.
            input_image_editor_component = gr.ImageEditor(
                label='Image',
                type='pil',
                sources=["upload", "webcam"],
                image_mode='RGB',
                layers=False,
                brush=gr.Brush(colors=["#FFFFFF"], color_mode="fixed"))
            with gr.Row():
                inpainting_prompt_text_component = gr.Text(
                    label="Prompt",
                    show_label=False,
                    max_lines=1,
                    placeholder="Enter inpainting prompt",
                    container=False,
                )
                submit_button_component = gr.Button(
                    value='Submit', variant='primary', scale=0)
            with gr.Accordion("Advanced Settings", open=False):
                # Alternative to drawing: describe the object to mask.
                segmentation_prompt_text_component = gr.Text(
                    label="Prompt",
                    show_label=False,
                    max_lines=1,
                    placeholder="Enter segmentation prompt",
                    container=False,
                )
                seed_slicer_component = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=MAX_SEED,
                    step=1,
                    value=42,
                )
                randomize_seed_checkbox_component = gr.Checkbox(
                    label="Randomize seed", value=True)
                with gr.Row():
                    strength_slider_component = gr.Slider(
                        label="Strength",
                        info="Indicates extent to transform the reference `image`. "
                             "Must be between 0 and 1. `image` is used as a starting "
                             "point and more noise is added the higher the `strength`.",
                        minimum=0,
                        maximum=1,
                        step=0.01,
                        value=0.85,
                    )
                    num_inference_steps_slider_component = gr.Slider(
                        label="Number of inference steps",
                        # The original text stopped mid-sentence ("...at the").
                        info="The number of denoising steps. More denoising steps "
                             "usually lead to a higher quality image at the "
                             "expense of slower inference.",
                        minimum=1,
                        maximum=50,
                        step=1,
                        value=20,
                    )
        with gr.Column():
            output_image_component = gr.Image(
                type='pil', image_mode='RGB', label='Generated image', format="png")
            with gr.Accordion("Debug", open=False):
                output_mask_component = gr.Image(
                    type='pil', image_mode='RGB', label='Input mask', format="png")

    # Input order must match the positional parameters of `process`.
    submit_button_component.click(
        fn=process,
        inputs=[
            input_image_editor_component,
            inpainting_prompt_text_component,
            segmentation_prompt_text_component,
            seed_slicer_component,
            randomize_seed_checkbox_component,
            strength_slider_component,
            num_inference_steps_slider_component
        ],
        outputs=[
            output_image_component,
            output_mask_component
        ]
    )

demo.launch(debug=False, show_error=True)