Spaces:
Running
Running
# GSL | |
import os | |
import torch | |
import numpy as np | |
from PIL import Image, ImageChops, ImageEnhance | |
import cv2 | |
from simple_lama_inpainting import SimpleLama | |
from segment_anything import build_sam, SamPredictor | |
from GroundingDINO.groundingdino.util import box_ops | |
from GroundingDINO.groundingdino.util.slconfig import SLConfig | |
from GroundingDINO.groundingdino.util.utils import clean_state_dict | |
from GroundingDINO.groundingdino.util.inference import annotate, load_image, predict | |
from huggingface_hub import hf_hub_download | |
# Run all models on a CUDA GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
def load_model_hf(repo_id, filename, ckpt_config_filename, device='cpu'):
    """Build a Grounding DINO model from a config and checkpoint on the Hugging Face Hub.

    Args:
        repo_id: Hub repository that holds both the config and the checkpoint.
        filename: checkpoint file name; its ``'model'`` entry is used as the
            state dict.
        ckpt_config_filename: config file name, parsed with ``SLConfig.fromfile``.
        device: stored on the config as ``args.device`` and used as
            ``map_location`` when loading the checkpoint.

    Returns:
        The model with the checkpoint weights loaded, switched to eval mode.
    """
    # build_model is not imported at module level; without this import the
    # call below raises NameError.
    from GroundingDINO.groundingdino.models import build_model

    cache_config_file = hf_hub_download(repo_id=repo_id, filename=ckpt_config_filename)
    args = SLConfig.fromfile(cache_config_file)
    args.device = device
    model = build_model(args)

    cache_file = hf_hub_download(repo_id=repo_id, filename=filename)
    # NOTE(review): torch.load unpickles the downloaded file; only point this at
    # repositories you trust.
    checkpoint = torch.load(cache_file, map_location=device)
    # strict=False: keys that do not match the built architecture are skipped
    # rather than raising; the mismatch report returned by load_state_dict is
    # not inspected here.
    model.load_state_dict(clean_state_dict(checkpoint['model']), strict=False)
    model.eval()
    return model
# Text-prompted box detector: Grounding DINO (Swin-B) config and weights from the Hub.
groundingdino_model = load_model_hf(
    device=device,
    repo_id="ShilongLiu/GroundingDINO",
    ckpt_config_filename="GroundingDINO_SwinB.cfg.py",
    filename="groundingdino_swinb_cogcoor.pth",
)
# Box-prompted segmenter. The checkpoint path is relative, so presumably the
# file is expected in the process's working directory -- confirm on deploy.
sam_predictor = SamPredictor(
    build_sam(checkpoint='sam_vit_h_4b8939.pth').to(device)
)
# LaMa inpainting model, constructed with the library defaults.
simple_lama = SimpleLama()
def detect(image, model, text_prompt='insect . flower . cloud', box_threshold=0.15,
           text_threshold=0.15, image_source=None):
    """Run Grounding DINO on one preprocessed image.

    Args:
        image: the transformed tensor that is fed to the model (the second
            value returned by GroundingDINO's ``load_image``).
        model: a Grounding DINO model.
        text_prompt: phrases to look for, separated by `` . ``.
        box_threshold: box confidence cutoff passed to ``predict``.
        text_threshold: phrase confidence cutoff passed to ``predict``.
        image_source: the original numpy image (the first value returned by
            ``load_image``). ``annotate`` draws on this array, so it cannot
            use the transformed tensor. When omitted, no preview is drawn.

    Returns:
        ``(annotated_frame, boxes, phrases)``. ``annotated_frame`` is an RGB
        numpy array with the detections drawn, or ``None`` when
        ``image_source`` is not given. ``boxes`` are in the normalized
        (cx, cy, w, h) form that ``segment`` converts to pixel xyxy.
    """
    boxes, logits, phrases = predict(
        image=image,
        model=model,
        caption=text_prompt,
        box_threshold=box_threshold,
        text_threshold=text_threshold,
    )
    if image_source is None:
        # The transformed tensor is not a drawable image; skip the preview
        # instead of handing it to annotate.
        annotated_frame = None
    else:
        annotated_frame = annotate(image_source=image_source, boxes=boxes, logits=logits, phrases=phrases)
        annotated_frame = annotated_frame[..., ::-1]  # BGR to RGB
    return annotated_frame, boxes, phrases
def segment(image, sam_model, boxes):
    """Predict SAM masks for a batch of boxes on a single image.

    Args:
        image: numpy image of shape (H, W, C), passed to ``sam_model.set_image``.
        sam_model: a ``SamPredictor``.
        boxes: normalized (cx, cy, w, h) boxes; they are converted to pixel
            xyxy by scaling with the image width and height.

    Returns:
        The masks from ``predict_torch`` (with ``multimask_output=True``),
        moved to the CPU.
    """
    sam_model.set_image(image)
    height, width, _ = image.shape

    scale = torch.Tensor([width, height, width, height])
    pixel_xyxy = box_ops.box_cxcywh_to_xyxy(boxes) * scale
    # SAM expects boxes in the coordinate frame of its resized input.
    prompt_boxes = sam_model.transform.apply_boxes_torch(pixel_xyxy.to(device), image.shape[:2])

    masks, _quality, _low_res = sam_model.predict_torch(
        point_coords=None,
        point_labels=None,
        boxes=prompt_boxes,
        multimask_output=True,
    )
    return masks.cpu()
def draw_mask(mask, image, random_color=True):
    """Overlay a mask on an image as a translucent coloured layer.

    Args:
        mask: mask whose last two dimensions are (H, W) and which holds H*W
            values, so it can be reshaped to (H, W, 1). It may be a torch
            tensor on any device or a numpy array. Values are expected to be
            0/1 (or bool); they scale both colour and alpha.
        image: numpy image accepted by ``Image.fromarray``, with the same
            height and width as ``mask`` (``alpha_composite`` needs equal sizes).
        random_color: if True, use a random RGB colour at alpha 0.8; otherwise
            use a fixed blue at alpha 0.6.

    Returns:
        The composited image as an RGBA uint8 numpy array.
    """
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.8])], axis=0)
    else:
        color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])

    # Previously only CPU tensors worked (because of `.numpy()`); also accept
    # numpy masks and tensors living on an accelerator.
    if isinstance(mask, torch.Tensor):
        mask = mask.detach().cpu().numpy()
    mask = np.asarray(mask)

    h, w = mask.shape[-2:]
    # Per-pixel RGBA: pixels where the mask is 0 get alpha 0, i.e. transparent.
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    frame = Image.fromarray(image).convert("RGBA")
    overlay = Image.fromarray((mask_image * 255).astype(np.uint8)).convert("RGBA")
    return np.array(Image.alpha_composite(frame, overlay))
def dilate_mask(mask, dilate_factor=15):
    """Dilate a mask once with a square all-ones kernel.

    Args:
        mask: numpy array; it is cast to uint8 before dilation.
        dilate_factor: side length, in pixels, of the square kernel.

    Returns:
        The uint8 result of a single ``cv2.dilate`` pass.
    """
    mask_u8 = mask.astype(np.uint8)
    square_kernel = np.ones((dilate_factor, dilate_factor), np.uint8)
    return cv2.dilate(mask_u8, square_kernel, iterations=1)
def gsl_process_image(local_image_path, *, threshold=7, background_color=(255, 236, 10)):
    """Keep only the insects of an image and paint every other pixel one colour.

    Steps:
        1. Detect the prompt phrases with Grounding DINO (``detect``).
        2. Segment the boxes whose phrase contains "insect" with SAM.
        3. Dilate the union of those masks and inpaint that area with LaMa.
        4. Keep the original pixels wherever the inpainting changed the image
           by more than ``threshold``; fill the rest with ``background_color``.

    Args:
        local_image_path: path of the image file, read with GroundingDINO's
            ``load_image``.
        threshold: grayscale difference (0-255) above which a pixel counts as
            changed by the inpainting and is kept from the original.
        background_color: RGB fill for every pixel that is not kept.

    Returns:
        An RGB ``PIL.Image`` the size of the input. When no "insect" phrase is
        detected, the whole image is ``background_color``.
    """
    image_source, image = load_image(local_image_path)
    source_pil = Image.fromarray(image_source)
    background = Image.new('RGB', source_pil.size, background_color)

    # Only boxes and phrases are used; the annotated preview is not needed.
    _, detected_boxes, phrases = detect(image, model=groundingdino_model)
    insect_indices = [i for i, phrase in enumerate(phrases) if 'insect' in phrase]
    if not insect_indices:
        # Nothing to keep: every pixel is background.
        return background

    masks = segment(image_source, sam_predictor, detected_boxes[insect_indices])
    # One stack of candidate masks per box: take candidate 0 of each box and
    # OR them together. This works for any box count, including one.
    combined = masks[:, 0].any(dim=0)

    # 0/255 uint8 mask, dilated so the inpainted area extends a little past
    # the segmentation boundary.
    mask = dilate_mask(combined.numpy().astype(np.uint8) * 255)
    inpainted = simple_lama(image_source, Image.fromarray(mask))
    # NOTE(review): SimpleLama appears to pad inputs to a multiple of 8 and may
    # return a larger image. Crop back to the source extent (a no-op when the
    # sizes already match) so the comparison below is pixel-aligned.
    inpainted = inpainted.crop((0, 0, *source_pil.size))

    # Pixels the inpainting changed noticeably are the removed objects: keep them.
    diff = ImageChops.difference(inpainted, source_pil)
    changed = diff.convert('L').point(lambda p: 255 if p > threshold else 0).convert('1')
    return Image.composite(source_pil, background, changed)