import gradio as gr
import numpy as np
import torch
import cv2
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation

# Load the CLIPSeg processor and model once at startup.
processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")


def create_rgb_mask(mask):
    # Paint a binary (0/255) mask in a random colour so each category stands out.
    color = tuple(np.random.choice(range(256), size=3))
    rgb_mask = cv2.merge((mask, mask, mask))
    rgb_mask[mask == 255] = color
    return rgb_mask.astype(np.uint8)


def detect_using_clip(image, prompts, threshold=0.4):
    # Run CLIPSeg in a single batch: the image is repeated once per text prompt.
    inputs = processor(
        text=prompts,
        images=[image] * len(prompts),
        padding="max_length",
        return_tensors="pt",
    )
    with torch.no_grad():  # inference only, so disable gradient computation
        outputs = model(**inputs)
    preds = outputs.logits.unsqueeze(1)

    predicted_masks = []
    for i, _prompt in enumerate(prompts):
        # Sigmoid turns logits into probabilities; thresholding yields a binary mask.
        predicted_image = torch.sigmoid(preds[i][0]).cpu().numpy()
        predicted_image = np.where(predicted_image > threshold, 255, 0).astype(np.uint8)
        predicted_masks.append(create_rgb_mask(predicted_image))
    return predicted_masks


def visualize_images(image, predicted_images, brightness=15, contrast=1.8):
    # Blend each coloured mask onto the image, resized to CLIPSeg's 352x352 output size.
    alpha = 0.7
    overlay = cv2.resize(image, (352, 352))
    for mask_image in predicted_images:
        overlay = cv2.addWeighted(overlay, alpha, mask_image, 1 - alpha, 10)
    # Apply the user-selected brightness/contrast adjustment.
    return cv2.convertScaleAbs(overlay, alpha=contrast, beta=brightness)


def shot(brightness, contrast, image, labels_text):
    # Split the comma-separated label string into individual prompts.
    prompts = [label.strip() for label in labels_text.split(",")]
    predicted_images = detect_using_clip(image, prompts)
    return visualize_images(
        image=image,
        predicted_images=predicted_images,
        brightness=brightness,
        contrast=contrast,
    )


iface = gr.Interface(
    fn=shot,
    inputs=[
        gr.Slider(5, 50, value=15, label="Brightness", info="Choose between 5 and 50"),
        gr.Slider(1, 5, value=1.5, label="Contrast", info="Choose between 1 and 5"),
        "image",
        "text",
    ],
    outputs="image",
    description="Provide an image and a comma-separated list of categories to detect (at least 2).",
    title="Zero-shot Image Segmentation with Prompts",
    examples=[
        [19, 1.5, "images/seats.jpg", "door,table,chairs"],
        [20, 1.8, "images/vegetables.jpg", "carrot,white radish,brinjal,basket,potato"],
        [17, 2, "images/room2.jpg", "door, plants, dog, coffee table, table lamp, carpet"],
    ],
    # allow_flagging=False,
    # analytics_enabled=False,
)

iface.launch()