Spaces:
Running
Running
File size: 3,774 Bytes
6dc32ee 9fcd716 850cda3 6e9fa1f 850cda3 cfccf84 3657d52 850cda3 3657d52 81b2e04 850cda3 9fcd716 cfccf84 3657d52 9fcd716 850cda3 3657d52 850cda3 3657d52 850cda3 746e19a 3657d52 9fcd716 6dc32ee 9fcd716 6dc32ee 3657d52 850cda3 3657d52 850cda3 5e1955d 850cda3 5e1955d 850cda3 5e1955d 850cda3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 |
# GSL
import os
import spaces
import torch
import numpy as np
from PIL import Image, ImageChops, ImageEnhance
import cv2
from simple_lama_inpainting import SimpleLama
from segment_anything import build_sam, SamPredictor
from transformers import pipeline
from huggingface_hub import hf_hub_download
# Run the models on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
def load_groundingdino_model(device='cpu'):
    """Build the zero-shot object-detection pipeline backed by Grounding DINO.

    Args:
        device: Device handed to ``transformers.pipeline`` (a ``torch.device``
            or a device string).

    Returns:
        The ``transformers`` pipeline for ``IDEA-Research/grounding-dino-base``.
    """
    return pipeline(
        task="zero-shot-object-detection",
        model="IDEA-Research/grounding-dino-base",
        device=device,
    )
def load_sam_model(checkpoint_path, device='cpu'):
    """Load a Segment Anything checkpoint and wrap it in a ``SamPredictor``.

    Args:
        checkpoint_path: Path to the SAM weights file passed to ``build_sam``.
        device: Device the SAM model is moved to.

    Returns:
        A ``SamPredictor`` around the loaded model.
    """
    backbone = build_sam(checkpoint=checkpoint_path)
    backbone = backbone.to(device)
    return SamPredictor(backbone)
# Load every model once at import time; gsl_process_image reuses these
# module-level instances. Loading order is detector, then SAM, then LaMa.
groundingdino_model = load_groundingdino_model(device=device)
sam_predictor = load_sam_model(
    "models/sam_vit_h_4b8939.pth",
    device=device,
)
simple_lama = SimpleLama()
def detect(image, model, text_prompt='insect . flower . cloud',
           box_threshold=0.15, text_threshold=0.15):
    """Run zero-shot object detection on *image* for the labels in *text_prompt*.

    Args:
        image: Image handed straight to *model* (a PIL image in this file).
        model: Zero-shot object-detection callable, e.g. the pipeline returned
            by ``load_groundingdino_model``; it is called as
            ``model(image, candidate_labels=..., threshold=...)``.
        text_prompt: Period-separated list of phrases such as ``"cat . dog"``.
            Whitespace around each phrase and empty segments (for example the
            one produced by a trailing ``"."``) are ignored.
        box_threshold: Score threshold forwarded to *model* as ``threshold``.
        text_threshold: Accepted for backward compatibility but unused; the
            model call only takes a single ``threshold``.

    Returns:
        Whatever *model* returns for the parsed labels, or an empty list when
        *text_prompt* contains no phrases (the model is not called then).
    """
    labels = []
    for segment_text in text_prompt.split('.'):
        phrase = segment_text.strip()
        # Skip blanks so a trailing '.' does not turn into a bogus '.' query.
        if phrase:
            # Every phrase is re-terminated with '.', as the original code did.
            labels.append(phrase + '.')
    if not labels:
        return []
    return model(image, candidate_labels=labels, threshold=box_threshold)
def segment(image, sam_model, boxes, *, normalized=False):
    """Predict SAM masks for each bounding box.

    Args:
        image: Image array whose first two dims are ``(H, W)``; passed to
            ``sam_model.set_image`` (SAM expects an RGB uint8 image).
        sam_model: A ``SamPredictor`` (see ``load_sam_model``).
        boxes: Sequence of ``[xmin, ymin, xmax, ymax]`` boxes. By default they
            are treated as absolute pixel coordinates, which is the format the
            transformers zero-shot-object-detection pipeline returns.
        normalized: Set to True when *boxes* are relative ``[0, 1]``
            coordinates; they are then scaled by ``(W, H, W, H)`` first.

    Returns:
        CPU tensor of masks from ``sam_model.predict_torch`` with
        ``multimask_output=True`` (``(len(boxes), 3, H, W)`` per SAM's API).
        When *boxes* is empty, an empty ``(0, 3, H, W)`` boolean tensor is
        returned without running SAM.
    """
    H, W = image.shape[:2]
    if len(boxes) == 0:
        # Nothing to prompt SAM with; also skips the costly image embedding.
        return torch.zeros((0, 3, H, W), dtype=torch.bool)

    sam_model.set_image(image)
    boxes_xyxy = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
    if normalized:
        boxes_xyxy = boxes_xyxy * torch.tensor([W, H, W, H], dtype=torch.float32)
    # Map pixel boxes into SAM's resized input frame, on the predictor's own
    # device rather than a module-level global.
    transformed_boxes = sam_model.transform.apply_boxes_torch(
        boxes_xyxy.to(sam_model.device), image.shape[:2]
    )
    masks, _, _ = sam_model.predict_torch(
        point_coords=None,
        point_labels=None,
        boxes=transformed_boxes,
        multimask_output=True,
    )
    return masks.cpu()
def draw_mask(mask, image, random_color=True):
    """Overlay a single mask on *image* as a translucent colored layer.

    Args:
        mask: Mask whose last two dims are ``(H, W)``; a CPU torch tensor or a
            numpy array. Values are expected to be boolean or 0/1 (other
            values scale the overlay and may overflow uint8).
        image: uint8 image array accepted by ``PIL.Image.fromarray`` with the
            same ``(H, W)`` as *mask*.
        random_color: If True, paint with a random RGB color at alpha 0.8;
            otherwise use a fixed blue at alpha 0.6.

    Returns:
        The alpha-composited image as an RGBA uint8 numpy array.
    """
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.8])], axis=0)
    else:
        color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    h, w = mask.shape[-2:]
    # np.asarray accepts both CPU torch tensors and numpy arrays, so callers
    # are not forced to pass a tensor.
    mask_array = np.asarray(mask).reshape(h, w, 1)
    overlay = (mask_array * color.reshape(1, 1, -1) * 255).astype(np.uint8)
    base = Image.fromarray(image).convert("RGBA")
    layer = Image.fromarray(overlay).convert("RGBA")
    return np.array(Image.alpha_composite(base, layer))
def dilate_mask(mask, dilate_factor=15):
    """Grow the foreground of *mask* with a square all-ones kernel.

    Args:
        mask: Mask array (a 0/255 single-channel mask in gsl_process_image);
            cast to uint8 before dilation.
        dilate_factor: Side length, in pixels, of the square kernel.

    Returns:
        The uint8 mask produced by a single ``cv2.dilate`` iteration.
    """
    kernel = np.ones((dilate_factor, dilate_factor), np.uint8)
    return cv2.dilate(mask.astype(np.uint8), kernel, iterations=1)
@spaces.GPU
def gsl_process_image(image):
    """Detect, segment and inpaint the prompt objects, then highlight them.

    Steps: ``detect`` finds boxes for the default prompt, ``segment`` turns
    them into SAM masks, the union of those masks is dilated and inpainted
    with LaMa, and every pixel whose inpainted value differs from the input
    by more than a small threshold is kept from the original image while all
    other pixels are filled with a solid yellow.

    Args:
        image: Input image as a numpy array or anything ``np.array`` accepts
            (e.g. a PIL image). It is converted to 3-channel RGB first.

    Returns:
        A PIL ``RGB`` image the same size as the input. If nothing is
        detected, the result is entirely the yellow background.
    """
    threshold = 7  # per-pixel luminance difference that counts as "changed"
    background = (255, 236, 10)

    # Normalize to an RGB uint8 array; SAM and the diff below assume three
    # channels, so RGBA or grayscale uploads would otherwise fail.
    if not isinstance(image, np.ndarray):
        image = np.array(image)
    image_pil = Image.fromarray(image).convert('RGB')
    image = np.array(image_pil)

    detected_boxes = detect(image_pil, groundingdino_model)
    boxes = [
        [d['box']['xmin'], d['box']['ymin'], d['box']['xmax'], d['box']['ymax']]
        for d in detected_boxes
    ]
    if not boxes:
        # Nothing to inpaint, so no pixel can differ: return the plain background.
        return Image.new('RGB', image_pil.size, background)

    segmented_frame_masks = segment(image, sam_predictor, boxes)
    # Union of the first mask SAM proposes for each box; unlike a pairwise
    # loop this also works when exactly one box was detected.
    union = np.any(np.asarray(segmented_frame_masks)[:, 0], axis=0)
    mask = dilate_mask(union.astype(np.uint8) * 255)
    dilated_image_mask_pil = Image.fromarray(mask)

    result = simple_lama(image, dilated_image_mask_pil)
    if result.size != image_pil.size:
        # NOTE(review): the inpainter may pad its input (assumed bottom/right);
        # crop back so ImageChops.difference compares equal-size images.
        result = result.crop((0, 0) + image_pil.size)

    diff = ImageChops.difference(result, image_pil)
    changed = diff.convert('L').point(lambda p: 255 if p > threshold else 0).convert('1')
    backdrop = Image.new('RGB', image_pil.size, background)
    # Keep original pixels where the inpainting changed the image, yellow elsewhere.
    return Image.composite(image_pil, backdrop, changed)
|