import gradio as gr
from transformers import DetrImageProcessor, DetrForObjectDetection
from diffusers import StableDiffusionInpaintPipeline
import torch
import PIL
from PIL import Image, ImageDraw
import requests

# you can specify the revision tag if you don't want the timm dependency
processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-101", revision="no_timm")
model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-101", revision="no_timm")


def biggest_obj(res):
    """Return the detection whose bounding box has the largest area.

    Args:
        res: one element of ``processor.post_process_object_detection``'s
            output: a dict with a ``"boxes"`` tensor of ``(x1, y1, x2, y2)``
            rows and a ``"labels"`` tensor of class ids.

    Returns:
        ``(ind, coords, cl)``: the index of the chosen detection, its box as a
        list of four ints, and its class name from ``model.config.id2label``.

    Raises:
        ValueError: if ``res`` contains no boxes (previously this surfaced as
            an ``UnboundLocalError``).
    """
    boxes = [list(map(int, bb.tolist())) for bb in res["boxes"]]
    if not boxes:
        raise ValueError("no bounding boxes to choose from")

    def area(box):
        x1, y1, x2, y2 = box
        return abs(x2 - x1) * abs(y2 - y1)

    # max() returns the first index on ties, like the original strict ">" scan,
    # but it no longer fails when every box has zero area.
    ind = max(range(len(boxes)), key=lambda i: area(boxes[i]))
    cl = model.config.id2label[res["labels"][ind].item()]
    return ind, boxes[ind], cl


def create_mask(im_shape: tuple, mask_zone: list):
    """Build a single-channel ("L") mask that is white inside *mask_zone*.

    Args:
        im_shape: mask size as ``(width, height)``, the order ``Image.size`` uses.
        mask_zone: rectangle ``[x1, y1, x2, y2]`` to fill with 255.

    Returns:
        A PIL image that is 0 everywhere except the filled rectangle.
    """
    mask = Image.new("L", im_shape, 0)
    draw = ImageDraw.Draw(mask)
    draw.rectangle(mask_zone, fill=255)
    return mask


device = "cuda" if torch.cuda.is_available() else "cpu"
# Half precision is only used on GPU; fp16 inference on CPU is generally
# unsupported or impractically slow, so fall back to float32 there.
dtype = torch.float16 if device == "cuda" else torch.float32
pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-inpainting",
    revision="fp16",
    torch_dtype=dtype,
).to(device)


def predict(image, prompt):
    """Inpaint the largest object DETR detects in *image*, guided by *prompt*.

    Args:
        image: input PIL image (from ``gr.Image(type="pil")``). It is converted
            to RGB and resized to 512x512 before detection and inpainting.
        prompt: text prompt passed to the inpainting pipeline.

    Returns:
        ``(inpainted, annotated)``: the first inpainted image, and a copy of the
        resized input with the chosen bounding box outlined in red.

    Raises:
        gr.Error: if no image was supplied, or if no detection scores above 0.9.
    """
    if image is None:
        raise gr.Error("Please upload an image.")
    # Both DETR and the inpainting pipeline receive this 512x512 RGB image.
    image = image.convert("RGB").resize((512, 512))

    # Detection is inference only, so skip autograd bookkeeping.
    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)

    # Convert outputs (bounding boxes and class logits) to COCO-style boxes,
    # keeping only detections with score > 0.9. target_sizes wants
    # (height, width), hence the reversal of PIL's (width, height).
    target_sizes = torch.tensor([image.size[::-1]])
    results = processor.post_process_object_detection(
        outputs, target_sizes=target_sizes, threshold=0.9
    )[0]
    if len(results["boxes"]) == 0:
        raise gr.Error("No object detected with confidence above 0.9; try another image.")

    # Mask out the biggest detected box and let the pipeline repaint it.
    ind, coords, cl = biggest_obj(results)
    mask_image = create_mask(image.size, coords)

    images = pipe(
        prompt=prompt,
        image=image,
        mask_image=mask_image,
        guidance_scale=5,
        # The generator must match the pipeline's device; hard-coding "cuda"
        # crashed on CPU-only hosts.
        generator=torch.Generator(device=device).manual_seed(0),
        num_images_per_prompt=1,
    ).images

    # Draw the selected box on a copy so the resized input stays untouched.
    annotated = image.copy()
    ImageDraw.Draw(annotated).rectangle(coords, outline="red", width=2)
    return images[0], annotated


examples = [
    ["cats.png", "cat is smiling"],
    ["dog.jpg", "dog with big eyes"],
    ["dog1.jpg", "dog with big bone"],
    ["beaver.jpg", "big strong beaver"],
]

demo = gr.Interface(
    predict,
    title="Stable Diffusion In-Painting",
    inputs=[
        gr.Image(type="pil"),
        gr.Textbox(label="prompt"),
    ],
    outputs=[
        gr.Image(),
        gr.Image(),
    ],
    examples=examples,
)
demo.launch(debug=True, share=True)