"""Gradio demo: open-vocabulary detection + segmentation, then FLUX inpainting.

Workflow:
  1. The user uploads an image and a detection prompt.
  2. Florence-2 detects matching objects; SAM refines them into masks.
  3. The detections are drawn with ``"<index>: <class name>"`` labels.
  4. The user lists indices (comma-separated) and an inpainting prompt; the
     union of the selected masks is inpainted with FLUX.1-schnell.
"""

import gradio as gr
import numpy as np
import supervision as sv
import torch
from diffusers import FluxInpaintPipeline
from PIL import Image

from utils.florence import (
    FLORENCE_OPEN_VOCABULARY_DETECTION_TASK,
    load_florence_model,
    run_florence_inference,
)
from utils.sam import load_sam_image_model, run_sam_inference

# Load models once at import time so every request reuses them.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
FLUX_PIPE = FluxInpaintPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
).to(DEVICE)
FLORENCE_MODEL, FLORENCE_PROCESSOR = load_florence_model(device=DEVICE)
SAM_MODEL = load_sam_image_model(device=DEVICE)

# Working resolution for FLUX. Without explicit height/width the pipeline
# renders at its default 1024x1024, stretching non-square inputs. The FLUX
# pipelines expect sides divisible by 16; 32 is used here for extra margin.
FLUX_MAX_DIMENSION = 1024
FLUX_DIMENSION_MULTIPLE = 32

COLORS = ['#FF1493', '#00BFFF', '#FF6347', '#FFD700', '#32CD32', '#8A2BE2']
COLOR_PALETTE = sv.ColorPalette.from_hex(COLORS)

BOX_ANNOTATOR = sv.BoxAnnotator(color=COLOR_PALETTE, color_lookup=sv.ColorLookup.INDEX)
LABEL_ANNOTATOR = sv.LabelAnnotator(
    color=COLOR_PALETTE,
    color_lookup=sv.ColorLookup.INDEX,
    text_position=sv.Position.CENTER_OF_MASS,
    text_color=sv.Color.from_hex("#000000"),
    border_radius=5,
)
MASK_ANNOTATOR = sv.MaskAnnotator(
    color=COLOR_PALETTE,
    color_lookup=sv.ColorLookup.INDEX,
)


def visualize_detections(image, detections):
    """Draw masks, boxes and labels for *detections* on a copy of *image*.

    Each label is ``"<index>: <class name>"`` (or just ``"<index>"`` when no
    class names are attached). The index is the detection's position in
    *detections*, i.e. exactly the number the user types into
    "Selected Objects".
    """
    class_names = detections.data.get("class_name")
    if class_names is None:
        labels = [str(index) for index in range(len(detections))]
    else:
        labels = [f"{index}: {name}" for index, name in enumerate(class_names)]

    output_image = image.copy()
    output_image = MASK_ANNOTATOR.annotate(output_image, detections)
    output_image = BOX_ANNOTATOR.annotate(output_image, detections)
    output_image = LABEL_ANNOTATOR.annotate(output_image, detections, labels=labels)
    return output_image


def detect_objects(image, text_prompt):
    """Detect objects matching *text_prompt* with Florence-2, then refine masks with SAM.

    Args:
        image: PIL image to search.
        text_prompt: Free-text description of the objects to detect.

    Returns:
        ``sv.Detections`` with boxes from Florence-2 and masks from SAM.
    """
    _, result = run_florence_inference(
        model=FLORENCE_MODEL,
        processor=FLORENCE_PROCESSOR,
        device=DEVICE,
        image=image,
        task=FLORENCE_OPEN_VOCABULARY_DETECTION_TASK,
        text=text_prompt,
    )
    detections = sv.Detections.from_lmm(
        lmm=sv.LMM.FLORENCE_2,
        result=result,
        resolution_wh=image.size,
    )
    # SAM turns the Florence boxes into per-object segmentation masks.
    return run_sam_inference(SAM_MODEL, image, detections)


def _flux_dimensions(
    width,
    height,
    max_dimension=FLUX_MAX_DIMENSION,
    multiple=FLUX_DIMENSION_MULTIPLE,
):
    """Return a (width, height) close to the input aspect ratio that FLUX accepts.

    The longer side is scaled down to at most *max_dimension* (never up), and
    both sides are rounded down to a multiple of *multiple* (minimum *multiple*).
    """
    scale = min(1.0, max_dimension / max(width, height))
    new_width = max(multiple, int(width * scale) // multiple * multiple)
    new_height = max(multiple, int(height * scale) // multiple * multiple)
    return new_width, new_height


def inpaint_selected_objects(
    image,
    detections,
    selected_indices,
    inpaint_prompt,
    num_inference_steps=30,
    strength=0.85,
):
    """Inpaint the union of the masks of the selected detections with FLUX.

    Args:
        image: Source PIL image (the one the detections were computed on).
        detections: ``sv.Detections`` whose ``mask`` holds one (H, W) boolean
            mask per detection, as produced by ``detect_objects``.
        selected_indices: Detection indices (0-based) to inpaint.
        inpaint_prompt: Text prompt describing the replacement content.
        num_inference_steps: Denoising steps passed to the pipeline. The
            original value of 30 is kept as the default. NOTE(review): the
            FLUX.1-schnell model card describes it as a few-step (1-4) model,
            so a lower value may be much faster — confirm quality first.
        strength: How strongly the masked region is re-noised (0..1).

    Returns:
        The inpainted PIL image, resized back to ``image.size``.

    Raises:
        gr.Error: if no masks are available or an index is out of range.
    """
    if detections.mask is None:
        raise gr.Error("No segmentation masks are available; run detection first.")

    count = len(detections)
    invalid = [index for index in selected_indices if not 0 <= index < count]
    if invalid:
        raise gr.Error(
            f"Invalid object indices {invalid}; valid indices are 0 to {count - 1}."
            if count
            else "No objects were detected, so there is nothing to inpaint."
        )

    # Union of the selected (H, W) boolean masks; all-False if nothing is selected.
    mask = np.any(detections.mask[list(selected_indices)], axis=0)
    if mask.shape != image.size[::-1]:
        mask = np.zeros(image.size[::-1], dtype=bool)
    mask_image = Image.fromarray(mask.astype(np.uint8) * 255)

    # Without explicit width/height the pipeline renders at its default square
    # resolution, which distorts non-square inputs; keep the aspect ratio instead.
    original_size = image.size
    width, height = _flux_dimensions(*original_size)

    result = FLUX_PIPE(
        prompt=inpaint_prompt,
        image=image,
        mask_image=mask_image,
        width=width,
        height=height,
        num_inference_steps=num_inference_steps,
        strength=strength,
    ).images[0]

    # Return the result at the input resolution so it lines up with the original.
    return result.resize(original_size, Image.LANCZOS)


def _parse_selected_indices(selected_objects):
    """Parse a comma-separated index string such as ``"0, 2,"`` into ints.

    Blank tokens (extra commas, surrounding spaces) are ignored, so an empty
    or whitespace-only string yields ``[]``.

    Raises:
        gr.Error: if a non-blank token is not an integer.
    """
    tokens = [token.strip() for token in (selected_objects or "").split(",")]
    try:
        return [int(token) for token in tokens if token]
    except ValueError:
        raise gr.Error(
            f"Selected Objects must be comma-separated integers, got {selected_objects!r}."
        ) from None


def process_image(input_image, detection_prompt, inpaint_prompt, selected_objects):
    """Detect objects and, if any indices are selected, inpaint them.

    Detection is re-run on every call, so the indices refer to the detections
    shown for the same image and prompt.

    Returns:
        ``(detected_image, inpainted_image)``, where ``inpainted_image`` is
        ``None`` when no indices were selected.

    Raises:
        gr.Error: on a missing image, a blank prompt, or invalid indices.
    """
    if input_image is None:
        raise gr.Error("Please upload an input image.")
    if not (detection_prompt or "").strip():
        raise gr.Error("Please enter a detection prompt.")

    detections = detect_objects(input_image, detection_prompt)
    detected_image = visualize_detections(input_image, detections)

    selected_indices = _parse_selected_indices(selected_objects)
    if not selected_indices:
        return detected_image, None

    inpainted_image = inpaint_selected_objects(
        input_image, detections, selected_indices, inpaint_prompt
    )
    return detected_image, inpainted_image


def _detect_only(input_image, detection_prompt):
    """Run detection without inpainting and return only the annotated image."""
    detected_image, _ = process_image(input_image, detection_prompt, "", "")
    return detected_image


# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Object Detection and Inpainting with FLUX, Florence, and SAM")
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="Input Image")
            detection_prompt = gr.Textbox(label="Detection Prompt", placeholder="Enter objects to detect")
            detect_button = gr.Button("Detect Objects")
        with gr.Column():
            detected_image = gr.Image(type="pil", label="Detected Objects")
            selected_objects = gr.Textbox(label="Selected Objects", placeholder="Enter indices of objects to inpaint (comma-separated)")
            inpaint_prompt = gr.Textbox(label="Inpainting Prompt", placeholder="Describe what to inpaint")
            inpaint_button = gr.Button("Inpaint Selected Objects")
            output_image = gr.Image(type="pil", label="Inpainted Result")

    detect_button.click(
        fn=_detect_only,
        inputs=[input_image, detection_prompt],
        outputs=detected_image,
    )
    inpaint_button.click(
        fn=process_image,
        inputs=[input_image, detection_prompt, inpaint_prompt, selected_objects],
        outputs=[detected_image, output_image],
    )

demo.launch(debug=False, show_error=True)