Spaces:
Running
Running
File size: 3,341 Bytes
850cda3 3657d52 850cda3 3657d52 850cda3 3657d52 850cda3 3657d52 850cda3 3657d52 850cda3 3657d52 850cda3 3657d52 850cda3 3657d52 850cda3 3657d52 850cda3 3657d52 850cda3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 |
import os
import torch
import numpy as np
from PIL import Image, ImageChops, ImageEnhance
import cv2
from simple_lama_inpainting import SimpleLama
from transformers import pipeline
from huggingface_hub import hf_hub_download
# Prefer the first visible CUDA device; fall back to the CPU otherwise.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
def load_groundingdino_model(device='cpu'):
    """Build the Grounding DINO zero-shot object detector.

    Args:
        device: Device handed to ``transformers.pipeline`` (for example
            ``'cpu'`` or a ``torch.device``).

    Returns:
        A transformers ``zero-shot-object-detection`` pipeline wrapping
        ``IDEA-Research/grounding-dino-base``.
    """
    detector = pipeline(
        task="zero-shot-object-detection",
        model="IDEA-Research/grounding-dino-base",
        device=device,
    )
    return detector
# Module-level model instances used by gsl_process_image; all of them are
# created once, at import time, in this order.
groundingdino_model = load_groundingdino_model(device)
# NOTE(review): no SAM predictor is loaded in this file, so sam_predictor
# stays None and segment() (called from gsl_process_image) will fail on
# `None.set_image(...)` until a predictor is assigned. Confirm where it is
# meant to be initialised.
sam_predictor = None
simple_lama = SimpleLama()
def detect(image, model, text_prompt='insect . flower . cloud', box_threshold=0.15, text_threshold=0.15):
    """Run zero-shot object detection for the '.'-separated phrases in *text_prompt*.

    Args:
        image: Image passed straight through to *model* (a PIL image when
            called from gsl_process_image).
        model: Callable with the interface
            ``model(image, candidate_labels=[...], threshold=float)``.
        text_prompt: Phrases separated by '.', e.g. ``'insect . flower . cloud'``.
            Whitespace around each phrase is stripped and empty phrases are
            ignored.
        box_threshold: Forwarded to *model* as ``threshold``.
        text_threshold: Accepted for API compatibility but unused; only
            *box_threshold* is forwarded to *model*.

    Returns:
        Whatever *model* returns, or ``[]`` (without calling *model*) when
        *text_prompt* contains no non-empty phrase.
    """
    # split('.') has already removed every '.', so strip each phrase, skip
    # empty ones (e.g. from a trailing '.'), and terminate each label with
    # exactly one '.' as the prompt format requires.
    labels = [f"{phrase.strip()}." for phrase in text_prompt.split('.') if phrase.strip()]
    if not labels:
        return []
    return model(image, candidate_labels=labels, threshold=box_threshold)
def segment(image, sam_model, boxes):
    """Predict masks for each box prompt on *image* with a SAM-style predictor.

    Args:
        image: Array whose shape unpacks into exactly three axes (H, W, C).
        sam_model: Predictor object providing ``set_image``,
            ``transform.apply_boxes_torch`` and ``predict_torch``.
        boxes: Sequence of [x0, y0, x1, y1] boxes normalized to the image
            size; they are multiplied by (W, H, W, H) to obtain pixels.

    Returns:
        The first value returned by ``predict_torch``, moved to the CPU.
        ``multimask_output=True`` is requested, so presumably several mask
        proposals come back per box — confirm against the predictor.
    """
    sam_model.set_image(image)
    height, width, _ = image.shape
    pixel_scale = torch.Tensor([width, height, width, height])
    pixel_boxes = torch.Tensor(boxes) * pixel_scale
    # Map the pixel boxes into the predictor's input frame; the module-level
    # `device` decides where the box prompts are placed.
    prompt_boxes = sam_model.transform.apply_boxes_torch(
        pixel_boxes.to(device), image.shape[:2]
    )
    prediction = sam_model.predict_torch(
        point_coords=None,
        point_labels=None,
        boxes=prompt_boxes,
        multimask_output=True,
    )
    masks, _, _ = prediction
    return masks.cpu()
def draw_mask(mask, image, random_color=True):
    """Overlay *mask* on *image* as a translucent colour and return RGBA pixels.

    Args:
        mask: Mask whose last two axes are (H, W) and which holds exactly
            H * W elements; a numpy array or a CPU torch tensor (anything
            ``np.asarray`` accepts).
        image: Array accepted by ``PIL.Image.fromarray`` with the same H x W
            as *mask* (presumably uint8 RGB — confirm with callers).
        random_color: If true, use a random RGB colour with alpha 0.8;
            otherwise a fixed blue (30, 144, 255) with alpha 0.6.

    Returns:
        ``np.ndarray`` holding the alpha-composited RGBA image.
    """
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.8])], axis=0)
    else:
        color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    # np.asarray accepts numpy arrays as well as CPU torch tensors, so callers
    # are not forced to pass a tensor just to reach `.numpy()`.
    mask_array = np.asarray(mask)
    h, w = mask_array.shape[-2:]
    # Pixels where the mask is 0 get an all-zero (fully transparent) RGBA value.
    rgba = mask_array.reshape(h, w, 1) * color.reshape(1, 1, -1)
    base = Image.fromarray(image).convert("RGBA")
    overlay = Image.fromarray((rgba * 255).astype(np.uint8)).convert("RGBA")
    return np.array(Image.alpha_composite(base, overlay))
def dilate_mask(mask, dilate_factor=15):
    """Grow the non-zero region of *mask* with a square structuring element.

    Args:
        mask: Array-like mask (e.g. H x W); converted to uint8 first.
        dilate_factor: Side length, in pixels, of the square kernel. 0 or 1
            means "no dilation" and returns the uint8 mask unchanged.

    Returns:
        The (possibly dilated) mask as a uint8 array.

    Raises:
        ValueError: If *dilate_factor* is negative.
    """
    mask = mask.astype(np.uint8)
    if dilate_factor < 0:
        raise ValueError(f"dilate_factor must be >= 0, got {dilate_factor}")
    if dilate_factor <= 1:
        # A 1x1 kernel is the identity, and an empty 0x0 kernel would make
        # OpenCV substitute its default 3x3 element and silently dilate.
        return mask
    kernel = np.ones((dilate_factor, dilate_factor), np.uint8)
    return cv2.dilate(mask, kernel, iterations=1)
def _normalized_boxes(detections, width, height):
    """Return each detection's box as [xmin, ymin, xmax, ymax] fractions of the image size.

    ``segment`` multiplies its boxes by (W, H, W, H), i.e. it expects
    normalized coordinates, whereas the transformers zero-shot-object-detection
    pipeline reports ``box`` values in pixels of the input image.
    """
    return [
        [
            det['box']['xmin'] / width,
            det['box']['ymin'] / height,
            det['box']['xmax'] / width,
            det['box']['ymax'] / height,
        ]
        for det in detections
    ]


def _union_first_masks(masks):
    """OR together the first mask proposal of every box into one (H, W) bool array."""
    # np.asarray accepts the CPU tensor returned by segment() as well as numpy
    # arrays; assumes layout (num_boxes, num_proposals, H, W), matching the
    # per-box [i][0] selection this replaces.
    return np.asarray(masks)[:, 0].any(axis=0)


def gsl_process_image(image):
    """Cut the detected objects out of *image* onto a yellow background.

    Steps:
    1. Detect objects with ``groundingdino_model`` using ``detect``'s default prompt.
    2. Segment each box with ``sam_predictor``.
    3. Merge the first mask proposal of every box and dilate the result.
    4. Inpaint the masked area with ``simple_lama``.
    5. Keep the original pixels whose inpainted value differs by more than a
       small threshold; every other pixel becomes (255, 236, 10).

    Args:
        image: Image array accepted by ``PIL.Image.fromarray`` and by
            ``segment`` (which unpacks ``H, W, _ = image.shape``); presumably
            uint8 RGB — confirm with callers.

    Returns:
        An RGB ``PIL.Image`` the same size as *image*. When nothing is
        detected the result is the plain background colour.

    NOTE(review): ``sam_predictor`` is ``None`` at module level in this file,
    so ``segment`` fails unless a predictor is assigned elsewhere — confirm.
    """
    image_pil = Image.fromarray(image)
    width, height = image_pil.size
    background = Image.new('RGB', image_pil.size, (255, 236, 10))

    detections = detect(image_pil, groundingdino_model)
    if not detections:
        # Nothing to cut out; also avoids calling segment() with an empty box
        # list, which cannot be scaled by the (W, H, W, H) factor.
        return background

    boxes = _normalized_boxes(detections, width, height)
    masks = segment(image, sam_predictor, boxes)
    # Union over all boxes (works for a single box too), then 0/255 uint8.
    object_mask = _union_first_masks(masks).astype(np.uint8) * 255
    object_mask = dilate_mask(object_mask)

    inpainted = simple_lama(image, Image.fromarray(object_mask))
    # Compare only over the original frame, defensively, in case the
    # inpainter returns a padded image.
    inpainted = inpainted.crop((0, 0, width, height))
    diff = ImageChops.difference(inpainted, image_pil)
    threshold = 7
    changed = diff.convert('L').point(lambda p: 255 if p > threshold else 0).convert('1')
    # Original pixels where inpainting changed the image, background elsewhere.
    return Image.composite(image_pil, background, changed)
|