Spaces:
Runtime error
Runtime error
import gradio as gr | |
from transformers import DetrImageProcessor, DetrForObjectDetection | |
import torch | |
import PIL | |
import gradio as gr | |
from PIL import Image, ImageDraw | |
import requests | |
# DETR object detector with a ResNet-101 backbone. The "no_timm" revision of
# the checkpoint is requested so the timm package is not needed.
_DETR_CHECKPOINT = "facebook/detr-resnet-101"
processor = DetrImageProcessor.from_pretrained(_DETR_CHECKPOINT, revision="no_timm")
model = DetrForObjectDetection.from_pretrained(_DETR_CHECKPOINT, revision="no_timm")
def biggest_obj(res, id2label=None):
    """Pick the detection whose bounding box has the largest area.

    Args:
        res: one entry of ``processor.post_process_object_detection`` output --
            a mapping whose ``"boxes"`` iterates over per-detection tensors in
            ``(x1, y1, x2, y2)`` order and whose ``"labels"`` holds the matching
            class ids (each element supporting ``.item()``).
        id2label: class-id -> class-name mapping. Defaults to the loaded DETR
            model's ``model.config.id2label``.

    Returns:
        ``(ind, coords, cl)``: the index of the chosen detection, its box as a
        list of four ints (each coordinate truncated toward zero by ``int``),
        and its class name. Ties go to the earliest detection.

    Raises:
        ValueError: if ``res`` contains no boxes.
    """
    # Truncate once so the area comparison and the returned coords agree.
    boxes = [[int(v) for v in bb.tolist()] for bb in res["boxes"]]
    if not boxes:
        # Nothing to choose from; fail loudly rather than with an obscure
        # unbound-variable error.
        raise ValueError("no objects detected")

    def _area(box):
        x1, y1, x2, y2 = box
        return abs(x2 - x1) * abs(y2 - y1)

    # max() returns the first maximal index (same tie-breaking as a strict ">"
    # scan) and, unlike a scan seeded with 0, still selects a box when every
    # box has zero area.
    ind = max(range(len(boxes)), key=lambda i: _area(boxes[i]))

    if id2label is None:
        id2label = model.config.id2label
    cl = id2label[res["labels"][ind].item()]
    return ind, boxes[ind], cl
def create_mask(im_shape: tuple, mask_zone: list):
    """Build a single-channel ("L" mode) inpainting mask.

    Args:
        im_shape: PIL size tuple passed straight to ``Image.new``, i.e.
            ``(width, height)``; the caller passes ``image.size``.
        mask_zone: rectangle accepted by ``ImageDraw.rectangle``, e.g.
            ``[x1, y1, x2, y2]``.

    Returns:
        A mask that is 0 (black) everywhere except ``mask_zone``, which is
        filled with 255 (white).
    """
    canvas = Image.new("L", im_shape, 0)
    ImageDraw.Draw(canvas).rectangle(mask_zone, fill=255)
    return canvas
from diffusers import StableDiffusionInpaintPipeline | |
device = "cuda" if torch.cuda.is_available() else "cpu"

# Use half precision only on the GPU: float16 inference on CPU is slow or
# unsupported for many PyTorch ops, so fall back to float32 there.
# NOTE(review): confirm the "runwayml/stable-diffusion-inpainting" repo and its
# "fp16" revision are still available on the Hub; newer diffusers releases
# prefer `variant="fp16"` over `revision="fp16"`.
pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-inpainting",
    revision="fp16",
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
).to(device)
def predict(image, prompt):
    """Inpaint the largest detected object in ``image`` according to ``prompt``.

    The image is converted to RGB and resized to 512x512. DETR then detects
    objects scoring above 0.9, the biggest bounding box becomes the inpainting
    mask, and the Stable Diffusion inpainting pipeline repaints that region.

    Args:
        image: input PIL image (the Gradio input is declared with type="pil").
        prompt: text prompt for the inpainting pipeline.

    Returns:
        ``(inpainted, annotated)``: the pipeline's first output image, and the
        resized input with the chosen box outlined in red.

    Raises:
        gr.Error: if no object is detected with score above 0.9.
    """
    image = image.convert("RGB").resize((512, 512))

    # DETR inference only -- no gradients are needed.
    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)

    # Convert raw outputs (bounding boxes and class logits) to pixel boxes,
    # keeping detections with score > 0.9. target_sizes is (height, width).
    target_sizes = torch.tensor([image.size[::-1]])
    results = processor.post_process_object_detection(
        outputs, target_sizes=target_sizes, threshold=0.9
    )[0]
    if len(results["boxes"]) == 0:
        # Show a readable message in the UI instead of crashing inside biggest_obj.
        raise gr.Error("No object detected with confidence above 0.9; try another image.")

    # The biggest box defines the region to repaint.
    _, coords, _ = biggest_obj(results)
    mask_image = create_mask(image.size, coords)

    images = pipe(
        prompt=prompt,
        image=image,
        mask_image=mask_image,
        guidance_scale=5,
        # Seed on the device the pipeline was placed on; a hard-coded "cuda"
        # generator fails on CPU-only hardware.
        generator=torch.Generator(device=device).manual_seed(0),
        num_images_per_prompt=1,
    ).images

    # Outline the chosen box (left-top, right-bottom) on the resized input for display.
    ImageDraw.Draw(image).rectangle(coords, outline="red", width=2)
    return images[0], image
# Example (image path, prompt) pairs shown under the interface.
examples = [
    ["cats.png", "cat is smiling"],
    ["dog.jpg", "dog with big eyes"],
    ["dog1.jpg", "dog with big bone"],
    ["beaver.jpg", "big strong beaver"],
]

# Outputs follow predict's return order: the inpainted image, then the resized
# input with the detected box outlined.
demo = gr.Interface(
    predict,
    title="Stable Diffusion In-Painting",
    inputs=[
        gr.Image(type="pil"),
        gr.Textbox(label="prompt"),
    ],
    outputs=[
        gr.Image(),
        gr.Image(),
    ],
    examples=examples,
)

# Launch only when run as a script, so importing this module has no server side effect.
# NOTE(review): share=True creates a public link when run locally and is reportedly
# not supported on Hugging Face Spaces -- confirm it is intended.
if __name__ == "__main__":
    demo.launch(debug=True, share=True)