# Source: Hugging Face file view of app.py by StKirill, commit f4d7772
# ("Update app.py"), 2.75 kB.
import gradio as gr
from transformers import DetrImageProcessor, DetrForObjectDetection
import torch
import PIL
import gradio as gr
from PIL import Image, ImageDraw
import requests
# DETR object detector used to find the region that gets inpainted.
# The "no_timm" revision is requested so the checkpoint loads without the
# optional `timm` dependency.
_DETR_CHECKPOINT = "facebook/detr-resnet-101"
processor = DetrImageProcessor.from_pretrained(_DETR_CHECKPOINT, revision="no_timm")
model = DetrForObjectDetection.from_pretrained(_DETR_CHECKPOINT, revision="no_timm")
def biggest_obj(res, id2label=None):
    """Pick the detection whose bounding box has the largest area.

    Args:
        res: one entry of ``processor.post_process_object_detection`` output,
            i.e. a mapping with ``"boxes"`` (rows of ``x1, y1, x2, y2``) and
            ``"labels"`` (tensor of label ids).
        id2label: optional label-id -> class-name mapping. Defaults to the
            module-level DETR ``model.config.id2label``.

    Returns:
        ``(index, [x1, y1, x2, y2], class_name)`` where the coordinates are
        truncated to ``int``. On equal areas the earliest detection wins.

    Raises:
        ValueError: if ``res`` contains no boxes (previously this surfaced as
            an ``UnboundLocalError``).
    """
    boxes = [list(map(int, bb.tolist())) for bb in res["boxes"]]
    if not boxes:
        raise ValueError("no objects detected")

    # Area is measured on the same int-truncated coords that are returned.
    areas = [abs(x2 - x1) * abs(y2 - y1) for x1, y1, x2, y2 in boxes]
    # list.index returns the first occurrence, so ties keep the earliest box.
    ind = areas.index(max(areas))

    if id2label is None:
        id2label = model.config.id2label
    cl = id2label[res["labels"][ind].item()]
    return ind, boxes[ind], cl
def create_mask(im_shape:tuple, mask_zone:list):
    """Return a single-channel ("L" mode) mask with one filled rectangle.

    Pixels inside ``mask_zone`` are set to 255; every other pixel is 0. The
    result is used as the inpainting ``mask_image`` in ``predict``.

    Args:
        im_shape: size handed to ``Image.new`` -- a PIL ``(width, height)``
            tuple (``predict`` passes ``image.size``).
        mask_zone: rectangle corners ``[x1, y1, x2, y2]`` accepted by
            ``ImageDraw.rectangle``.
    """
    canvas = Image.new("L", im_shape, 0)
    ImageDraw.Draw(canvas).rectangle(mask_zone, fill=255)
    return canvas
from diffusers import StableDiffusionInpaintPipeline
# Run the inpainting pipeline on the GPU when one is available.
# Half precision is only requested on CUDA: PyTorch's CPU backend lacks many
# float16 kernels, so the CPU fallback loads the weights as float32 instead.
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32
pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-inpainting",
    revision="fp16",
    torch_dtype=dtype,
).to(device)
def predict(image, prompt):
    """Inpaint the largest object DETR finds in ``image`` according to ``prompt``.

    The image is converted to RGB and resized to 512x512. DETR detections with
    score > 0.9 are kept, and the box with the largest area becomes the mask
    for the Stable Diffusion inpainting pipeline.

    Args:
        image: PIL image from the Gradio input component (``None`` if empty).
        prompt: text describing what to paint inside the detected box.

    Returns:
        ``(inpainted_image, annotated_input)``: the first image produced by
        ``pipe`` and the resized input with the chosen box outlined in red.

    Raises:
        gr.Error: if no image was supplied or no object passes the threshold.
    """
    if image is None:
        raise gr.Error("Please upload an image.")
    image = image.convert("RGB").resize((512, 512))

    # DETR is used for inference only; no_grad avoids building an autograd graph.
    with torch.no_grad():
        inputs = processor(images=image, return_tensors="pt")
        outputs = model(**inputs)

    # Convert raw outputs (class logits + boxes) to scored detections sized to
    # the resized image, keeping only those with score > 0.9.
    target_sizes = torch.tensor([image.size[::-1]])
    results = processor.post_process_object_detection(
        outputs, target_sizes=target_sizes, threshold=0.9
    )[0]
    if len(results["boxes"]) == 0:
        # Nothing to mask: report it to the user instead of crashing below.
        raise gr.Error("No object detected with confidence above 0.9; try another image.")

    # Mask the biggest bounding box on the image.
    _, coords, _ = biggest_obj(results)
    mask_image = create_mask(image.size, coords)

    images = pipe(
        prompt=prompt,
        image=image,
        mask_image=mask_image,
        guidance_scale=5,
        # Seed on the pipeline's device so CPU-only hosts work too.
        generator=torch.Generator(device=device).manual_seed(0),
        num_images_per_prompt=1,
    ).images

    # Outline the masked region on the (resized copy of the) input for display.
    ImageDraw.Draw(image).rectangle(coords, outline="red", width=2)
    return images[0], image
# Example (image file, prompt) pairs shown under the interface.
examples = [
    ["cats.png", "cat is smiling"],
    ["dog.jpg", "dog with big eyes"],
    ["dog1.jpg", "dog with big bone"],
    ["beaver.jpg", "big strong beaver"],
]

# Keep a module-level handle to the interface and only launch when run as a
# script, so importing this module does not start (and block on) the server.
demo = gr.Interface(
    predict,
    title='Stable Diffusion In-Painting',
    inputs=[
        gr.Image(type='pil'),
        gr.Textbox(label='prompt'),
    ],
    outputs=[
        gr.Image(label='Inpainted image'),
        gr.Image(label='Detected object'),
    ],
    examples=examples,
)

if __name__ == "__main__":
    demo.launch(debug=True, share=True)