import torch
import numpy as np
from PIL import Image
from diffusers import FluxInpaintPipeline
from utils.florence import load_florence_model, run_florence_inference, FLORENCE_OPEN_VOCABULARY_DETECTION_TASK
from utils.sam import load_sam_image_model, run_sam_inference
import gradio as gr
import supervision as sv

# Load models
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
FLUX_PIPE = FluxInpaintPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
).to(DEVICE)
FLORENCE_MODEL, FLORENCE_PROCESSOR = load_florence_model(device=DEVICE)
SAM_MODEL = load_sam_image_model(device=DEVICE)

COLORS = ['#FF1493', '#00BFFF', '#FF6347', '#FFD700', '#32CD32', '#8A2BE2']
COLOR_PALETTE = sv.ColorPalette.from_hex(COLORS)
BOX_ANNOTATOR = sv.BoxAnnotator(color=COLOR_PALETTE, color_lookup=sv.ColorLookup.INDEX)
LABEL_ANNOTATOR = sv.LabelAnnotator(
    color=COLOR_PALETTE,
    color_lookup=sv.ColorLookup.INDEX,
    text_position=sv.Position.CENTER_OF_MASS,
    text_color=sv.Color.from_hex("#000000"),
    border_radius=5
)
MASK_ANNOTATOR = sv.MaskAnnotator(
    color=COLOR_PALETTE,
    color_lookup=sv.ColorLookup.INDEX
)


def visualize_detections(image, detections):
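    """Overlay SAM masks, bounding boxes, and index labels on a copy of the image."""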
    output_image = image.copy()
    output_image = MASK_ANNOTATOR.annotate(output_image, detections)
    output_image = BOX_ANNOTATOR.annotate(output_image, detections)
    output_image = LABEL_ANNOTATOR.annotate(output_image, detections)
    return output_image


def detect_objects(image, text_prompt):
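    """Detect objects matching text_prompt with Florence-2, then refine masks with SAM."""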
    # Use Florence for object detection
    _, result = run_florence_inference(
        model=FLORENCE_MODEL,
        processor=FLORENCE_PROCESSOR,
        device=DEVICE,
        image=image,
        task=FLORENCE_OPEN_VOCABULARY_DETECTION_TASK,
        text=text_prompt
    )
    detections = sv.Detections.from_lmm(
        lmm=sv.LMM.FLORENCE_2,
        result=result,
        resolution_wh=image.size
    )

    # Use SAM to refine masks
    detections = run_sam_inference(SAM_MODEL, image, detections)
    return detections


def inpaint_selected_objects(image, detections, selected_indices, inpaint_prompt):
    # Union the masks of all selected detections into a single binary mask.
    # PIL's image.size is (W, H); numpy expects (H, W), hence the reversal.
    mask = np.zeros(image.size[::-1], dtype=bool)
    for idx in selected_indices:
        mask |= detections.mask[idx]

    # The pipeline repaints the white (255) region of the mask image.
    mask_image = Image.fromarray((mask * 255).astype(np.uint8))

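    # strength < 1.0 blends the repaint with the original content; note that
    # FLUX.1-schnell is distilled for few-step sampling, so fewer than 30
    # steps may be enough here.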
    result = FLUX_PIPE(
        prompt=inpaint_prompt,
        image=image,
        mask_image=mask_image,
        num_inference_steps=30,
        strength=0.85,
    ).images[0]

    return result


def process_image(input_image, detection_prompt, inpaint_prompt, selected_objects):
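    """Detect objects, annotate them, and optionally inpaint the comma-selected indices."""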
    detections = detect_objects(input_image, detection_prompt)

    # Visualize detected objects
    detected_image = visualize_detections(input_image, detections)

    if selected_objects:
        # int() tolerates surrounding whitespace; skip empty entries from trailing commas
        selected_indices = [int(idx) for idx in selected_objects.split(',') if idx.strip()]
        inpainted_image = inpaint_selected_objects(input_image, detections, selected_indices, inpaint_prompt)
        return detected_image, inpainted_image
    else:
        return detected_image, None


# Gradio interface
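# Two-step workflow: detect objects first to see their indices, then inpaint by index.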
with gr.Blocks() as demo:
    gr.Markdown("# Object Detection and Inpainting with FLUX, Florence, and SAM")
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="Input Image")
            detection_prompt = gr.Textbox(label="Detection Prompt", placeholder="Enter objects to detect")
            detect_button = gr.Button("Detect Objects")
        with gr.Column():
            detected_image = gr.Image(type="pil", label="Detected Objects")
            selected_objects = gr.Textbox(label="Selected Objects",
                                          placeholder="Enter indices of objects to inpaint (comma-separated)")
            inpaint_prompt = gr.Textbox(label="Inpainting Prompt", placeholder="Describe what to inpaint")
            inpaint_button = gr.Button("Inpaint Selected Objects")
    output_image = gr.Image(type="pil", label="Inpainted Result")

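    # Detection-only path: reuse process_image but surface just the annotated image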
    detect_button.click(
        fn=lambda img, prompt: process_image(img, prompt, "", "")[0],
        inputs=[input_image, detection_prompt],
        outputs=detected_image
    )
    inpaint_button.click(
        fn=process_image,
        inputs=[input_image, detection_prompt, inpaint_prompt, selected_objects],
        outputs=[detected_image, output_image]
    )

demo.launch(debug=False, show_error=True)