# cliplama / app.py
# antoyo123's picture
# Update app.py
# 8edd020
# raw
# history blame
# 2.62 kB
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
import gradio as gr
from PIL import Image
import matplotlib.pyplot as plt
import torch
import cv2
import os
# Fetch the LaMa checkpoint (best.ckpt) into the current working directory.
# os.system() returns the command's exit status; fail right here with a clear
# message instead of letting a later step trip over a missing or partial file.
if os.system("wget https://huggingface.co/akhaliq/lama/resolve/main/best.ckpt") != 0:
    raise RuntimeError("Failed to download the LaMa checkpoint (best.ckpt) with wget.")
import paddlehub as hub
import gradio as gr
import torch
from PIL import Image, ImageOps
import numpy as np
import imageio
# Prepare the working directories and move the downloaded checkpoint into
# place. exist_ok=True keeps a restart from crashing with FileExistsError
# when the directories are already there.
os.makedirs("data", exist_ok=True)
# The move below cannot create its destination directory, so make sure
# "models" exists first.
os.makedirs("models", exist_ok=True)
# os.replace() overwrites an existing destination on every platform, whereas
# os.rename() refuses to do so on Windows.
os.replace("best.ckpt", "models/best.ckpt")
os.makedirs("dataout", exist_ok=True)
# CLIPSeg: the processor and the segmentation model are both restored from the
# same pretrained checkpoint, so the identifier is spelled out only once.
_CLIPSEG_CHECKPOINT = "CIDAS/clipseg-rd64-refined"
processor = CLIPSegProcessor.from_pretrained(_CLIPSEG_CHECKPOINT)
clipseg_model = CLIPSegForImageSegmentation.from_pretrained(_CLIPSEG_CHECKPOINT)

# PaddleHub module registered under the name 'U2Net'.
# NOTE(review): `model` is not referenced anywhere else in the visible code --
# confirm whether this load is still needed.
model = hub.Module(name='U2Net')
def process_image(image, prompt):
    """Segment the part of *image* described by *prompt* and inpaint it.

    CLIPSeg produces a soft mask for the text prompt; the image and the mask
    are written to ``./data`` and the external ``predict.py`` script is run
    over that directory, writing its result to ``./dataout``.

    Args:
        image: The input picture as a PIL image (the UI passes ``type="pil"``).
        prompt: Free-text description of the region to mask.

    Returns:
        A ``(mask_image, inpainted_image)`` tuple: the mask as a 2-D uint8
        NumPy array with the same height/width as the input image, and the
        path of the inpainted result file.

    Raises:
        ValueError: If no image was supplied.
        RuntimeError: If ``predict.py`` exits with a non-zero status or does
            not produce the expected output file.
    """
    if image is None:
        raise ValueError("An input image is required.")

    # Generate mask with CLIPSeg
    inputs = processor(text=prompt, images=image, padding="max_length", return_tensors="pt")
    with torch.no_grad():
        preds = clipseg_model(**inputs).logits

    # Convert the logits straight into an 8-bit grayscale mask. Going through
    # plt.imsave() would apply matplotlib's default false-colour colormap and
    # a shared "mask.png" scratch file, so the gray values would no longer be
    # the predicted probabilities. squeeze() drops any singleton batch
    # dimension so the result is always 2-D.
    probabilities = torch.sigmoid(preds).squeeze().cpu().numpy()
    mask_image = (probabilities * 255).astype(np.uint8)

    # Bring the mask to the resolution of the input picture; the model output
    # has its own fixed size. (cv2.resize takes the target as (width, height).)
    rgb_image = np.array(image.convert("RGB"))
    height, width = rgb_image.shape[:2]
    mask_image = cv2.resize(mask_image, (width, height), interpolation=cv2.INTER_LINEAR)

    # Perform inpainting with LAMA
    # The predictor is pointed at ./data, so the picture itself has to be
    # stored there next to its mask -- previously only the mask was written.
    # NOTE(review): assumes predict.py pairs "<name>.png" with
    # "<name>_mask.png" and treats non-zero mask pixels as the region to
    # fill, as the upstream LaMa script does -- confirm against predict.py.
    imageio.imwrite("./data/data.png", rgb_image)
    imageio.imwrite("./data/data_mask.png", mask_image)

    inpainted_image = "./dataout/data_mask.png"
    # Drop any result left over from an earlier request so that a failed run
    # can never hand back a stale picture.
    if os.path.exists(inpainted_image):
        os.remove(inpainted_image)

    status = os.system('python predict.py model.path=/home/user/app/ indir=/home/user/app/data/ outdir=/home/user/app/dataout/ device=cpu')
    if status != 0 or not os.path.exists(inpainted_image):
        raise RuntimeError(
            f"Inpainting failed: predict.py exited with status {status} "
            f"and {inpainted_image} was not produced."
        )

    return mask_image, inpainted_image
# Gradio front end: the user supplies a picture and a text prompt; the two
# outputs display the values returned by process_image (the mask, and the
# file path of the inpainted result).
_demo_inputs = [
    gr.Image(type="pil"),
    gr.Textbox(label="Please describe what you want to identify"),
]
_demo_outputs = [
    gr.Image(type="pil"),
    gr.Image(type="filepath"),
]
interface = gr.Interface(
    fn=process_image,
    inputs=_demo_inputs,
    outputs=_demo_outputs,
    title="Interactive demo: zero-shot image segmentation with CLIPSeg and inpainting with LAMA",
    description="Demo for using CLIPSeg and LAMA to perform zero- and one-shot image segmentation and inpainting. To use it, simply upload an image and add a text to mask (identify in the image), or use one of the examples below and click 'submit'. Results will show up in a few seconds.",
)
# Start the web server with debug output enabled.
interface.launch(debug=True)