# Provenance: copied from a Hugging Face file view (uploader: tomrb, commit ab854b9 "initial yolov8to", 13.5 kB).
# Ultralytics YOLO 🚀, AGPL-3.0 license
import os
from pathlib import Path
import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from tqdm import tqdm
from ultralytics.utils import TQDM_BAR_FORMAT
class FastSAMPrompt:
    """Select and visualise segmentation masks using box, point or text prompts.

    The prompt methods read ``self.results[0]`` and overwrite ``self.results[0].masks.data`` with the
    selected mask(s); they return ``self.results`` so calls can be chained into :meth:`plot`.

    Attributes:
        device: Device passed to ``clip.load`` and used for CLIP inference in :meth:`text_prompt`.
        results: Sequence of result objects. The code reads ``masks.data``, ``boxes.data``, ``boxes.conf``,
            ``orig_img``, ``orig_shape`` and ``path`` from them.
        source: Path of the input; prompts reject it when it is a directory.
        clip: The imported ``clip`` module.
    """

    def __init__(self, source, results, device='cuda') -> None:
        """Store the inputs and import CLIP, installing it on demand if it is missing."""
        self.device = device
        self.results = results
        self.source = source

        # CLIP is only used by the text prompt (tokenisation and model loading).
        try:
            import clip
        except ImportError:
            from ultralytics.utils.checks import check_requirements
            check_requirements('git+https://github.com/openai/CLIP.git')
            import clip
        self.clip = clip

    @staticmethod
    def _segment_image(image, bbox):
        """Return a white image of the same size as ``image`` with only the ``bbox`` region copied in.

        Args:
            image: PIL image to cut the region from.
            bbox: ``(x1, y1, x2, y2)`` pixel coordinates of the region to keep.

        Returns:
            A new RGB PIL image: the original pixels inside ``bbox``, white everywhere else.
        """
        image_array = np.array(image)
        x1, y1, x2, y2 = bbox

        segmented_array = np.zeros_like(image_array)
        segmented_array[y1:y2, x1:x2] = image_array[y1:y2, x1:x2]
        segmented_image = Image.fromarray(segmented_array)

        white_canvas = Image.new('RGB', image.size, (255, 255, 255))
        # 255 inside the box, 0 elsewhere: only the box region is pasted over the white canvas.
        paste_mask = np.zeros((image_array.shape[0], image_array.shape[1]), dtype=np.uint8)
        paste_mask[y1:y2, x1:x2] = 255
        white_canvas.paste(segmented_image, mask=Image.fromarray(paste_mask, mode='L'))
        return white_canvas

    @staticmethod
    def _format_results(result, filter=0):
        """Convert one result object into a list of per-mask annotation dicts.

        Args:
            result: Object exposing ``masks.data`` (or ``masks is None``), ``boxes.data`` and ``boxes.conf``.
            filter: Minimum number of mask pixels equal to 1.0 for a mask to be kept.

        Returns:
            A list of dicts with keys ``id``, ``segmentation`` (boolean NumPy array), ``bbox``, ``score``
            and ``area`` (number of True pixels). Empty when ``result.masks`` is None.
        """
        annotations = []
        n = len(result.masks.data) if result.masks is not None else 0
        for i in range(n):
            mask = result.masks.data[i] == 1.0
            if torch.sum(mask) < filter:
                continue
            segmentation = mask.cpu().numpy()
            annotations.append({
                'id': i,
                'segmentation': segmentation,
                'bbox': result.boxes.data[i],
                'score': result.boxes.conf[i],
                'area': segmentation.sum()})
        return annotations

    @staticmethod
    def _get_bbox_from_mask(mask):
        """Return ``[x1, y1, x2, y2]`` enclosing every external contour found in ``mask``.

        Raises:
            IndexError: If no contour is found (e.g. an all-zero mask).
        """
        mask = mask.astype(np.uint8)
        contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        x1, y1, w, h = cv2.boundingRect(contours[0])
        x2, y2 = x1 + w, y1 + h
        # Merge the bounding boxes of any further contours into a single box.
        for contour in contours[1:]:
            x_t, y_t, w_t, h_t = cv2.boundingRect(contour)
            x1 = min(x1, x_t)
            y1 = min(y1, y_t)
            x2 = max(x2, x_t + w_t)
            y2 = max(y2, y_t + h_t)
        return [x1, y1, x2, y2]

    def plot(self,
             annotations,
             output,
             bbox=None,
             points=None,
             point_label=None,
             mask_random_color=True,
             better_quality=True,
             retina=False,
             withContours=True):
        """Render each result with its masks overlaid and save it under ``output``.

        Args:
            annotations: Iterable of result objects (``path``, ``orig_img``, ``orig_shape``, ``masks``).
            output: Directory the rendered images are written to (created if needed).
            bbox: Optional ``(x1, y1, x2, y2)`` box to draw.
            points: Optional list of ``(x, y)`` points to draw.
            point_label: Labels for ``points``; 1 is drawn yellow, 0 magenta.
            mask_random_color: Use a random colour per mask instead of a fixed blue.
            better_quality: Apply morphological close/open to each mask before drawing.
            retina: When False, masks are resized to the original image size before drawing.
            withContours: Also draw mask contours.
        """
        pbar = tqdm(annotations, total=len(annotations), bar_format=TQDM_BAR_FORMAT)
        for ann in pbar:
            result_name = os.path.basename(ann.path)
            # orig_img is in OpenCV's BGR order; matplotlib expects RGB.
            image = ann.orig_img[..., ::-1]
            original_h, original_w = ann.orig_shape
            # for macOS only
            # plt.switch_backend('TkAgg')
            fig = plt.figure(figsize=(original_w / 100, original_h / 100))
            # Add subplot with no margin.
            plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
            plt.margins(0, 0)
            plt.gca().xaxis.set_major_locator(plt.NullLocator())
            plt.gca().yaxis.set_major_locator(plt.NullLocator())
            plt.imshow(image)

            if ann.masks is not None:
                masks = ann.masks.data
                # Always work on a NumPy copy: the drawing code below needs ndarray methods, and the
                # morphology step must not modify the masks stored on the result object.
                masks = np.array(masks.cpu() if isinstance(masks, torch.Tensor) else masks)
                if better_quality:
                    close_kernel = np.ones((3, 3), np.uint8)
                    open_kernel = np.ones((8, 8), np.uint8)
                    for i, mask in enumerate(masks):
                        closed = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE, close_kernel)
                        masks[i] = cv2.morphologyEx(closed, cv2.MORPH_OPEN, open_kernel)

                self.fast_show_mask(
                    masks,
                    plt.gca(),
                    random_color=mask_random_color,
                    bbox=bbox,
                    points=points,
                    pointlabel=point_label,
                    retinamask=retina,
                    target_height=original_h,
                    target_width=original_w,
                )

                if withContours:
                    contour_all = []
                    for mask in masks:
                        mask = mask.astype(np.uint8)
                        if not retina:
                            mask = cv2.resize(
                                mask,
                                (original_w, original_h),
                                interpolation=cv2.INTER_NEAREST,
                            )
                        contours, _ = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
                        contour_all.extend(contours)
                    temp = np.zeros((original_h, original_w, 1))
                    cv2.drawContours(temp, contour_all, -1, (255, 255, 255), 2)
                    color = np.array([0 / 255, 0 / 255, 1.0, 0.8])
                    contour_mask = temp / 255 * color.reshape(1, 1, -1)
                    plt.imshow(contour_mask)

            plt.axis('off')
            canvas = fig.canvas
            canvas.draw()  # make sure the pixel buffer exists and is up to date
            if hasattr(canvas, 'buffer_rgba'):
                bgr = cv2.cvtColor(np.asarray(canvas.buffer_rgba()), cv2.COLOR_RGBA2BGR)
            else:  # older canvases that only expose tostring_rgb
                cols, rows = canvas.get_width_height()
                rgb = np.frombuffer(canvas.tostring_rgb(), dtype=np.uint8).reshape(rows, cols, 3)
                bgr = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)

            save_path = Path(output) / result_name
            save_path.parent.mkdir(exist_ok=True, parents=True)
            # cv2.imwrite expects BGR, hence the conversion of the RGB canvas above.
            cv2.imwrite(str(save_path), bgr)
            plt.close()
            pbar.set_description('Saving {} to {}'.format(result_name, save_path))

    @staticmethod
    def fast_show_mask(
        annotation,
        ax,
        random_color=False,
        bbox=None,
        points=None,
        pointlabel=None,
        retinamask=True,
        target_height=960,
        target_width=960,
    ):
        """Composite all masks into one RGBA overlay and show it on ``ax``.

        Args:
            annotation: NumPy array of shape ``(n, h, w)``; non-zero entries mark mask pixels.
            ax: Axes-like object; ``imshow`` (and ``add_patch`` when ``bbox`` is given) are called on it.
            random_color: Random colour per mask instead of a fixed blue.
            bbox: Optional ``(x1, y1, x2, y2)`` rectangle to draw.
            points: Optional list of ``(x, y)`` points to scatter on the current pyplot axes.
            pointlabel: Labels for ``points``; 1 is drawn yellow, 0 magenta.
            retinamask: When False the overlay is resized to ``(target_width, target_height)``.
            target_height: Output height used when ``retinamask`` is False.
            target_width: Output width used when ``retinamask`` is False.
        """
        n, h, w = annotation.shape  # batch, height, width

        # Sort by ascending area so that, where masks overlap, the smallest one is the one shown.
        areas = np.sum(annotation, axis=(1, 2))
        annotation = annotation[np.argsort(areas)]
        # Per pixel: index of the first (smallest) mask covering it, or 0 when none does.
        index = (annotation != 0).argmax(axis=0)

        if random_color:
            color = np.random.random((n, 1, 1, 3))
        else:
            color = np.ones((n, 1, 1, 3)) * np.array([30 / 255, 144 / 255, 1.0])
        transparency = np.ones((n, 1, 1, 1)) * 0.6
        visual = np.concatenate([color, transparency], axis=-1)
        mask_image = np.expand_dims(annotation, -1) * visual

        # Pick, for every pixel, the RGBA value of its selected mask (zero where no mask covers it).
        rows, cols = np.meshgrid(np.arange(h), np.arange(w), indexing='ij')
        show = mask_image[index, rows, cols, :]

        if bbox is not None:
            x1, y1, x2, y2 = bbox
            ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor='b', linewidth=1))
        # Draw point
        if points is not None:
            plt.scatter(
                [point[0] for i, point in enumerate(points) if pointlabel[i] == 1],
                [point[1] for i, point in enumerate(points) if pointlabel[i] == 1],
                s=20,
                c='y',
            )
            plt.scatter(
                [point[0] for i, point in enumerate(points) if pointlabel[i] == 0],
                [point[1] for i, point in enumerate(points) if pointlabel[i] == 0],
                s=20,
                c='m',
            )

        if not retinamask:
            show = cv2.resize(show, (target_width, target_height), interpolation=cv2.INTER_NEAREST)
        ax.imshow(show)

    @torch.no_grad()
    def retrieve(self, model, preprocess, elements, search_text: str, device) -> torch.Tensor:
        """Score every image in ``elements`` against ``search_text``.

        Args:
            model: Object providing ``encode_image`` and ``encode_text``.
            preprocess: Callable turning one element into a tensor.
            elements: Images to score.
            search_text: Text query.
            device: Device the preprocessed images and tokens are moved to.

        Returns:
            1-D tensor with one value per element: softmax (across elements) of 100 times the
            similarity between the normalised image and text features.
        """
        preprocessed_images = [preprocess(image).to(device) for image in elements]
        tokenized_text = self.clip.tokenize([search_text]).to(device)
        stacked_images = torch.stack(preprocessed_images)
        image_features = model.encode_image(stacked_images)
        text_features = model.encode_text(tokenized_text)
        image_features /= image_features.norm(dim=-1, keepdim=True)
        text_features /= text_features.norm(dim=-1, keepdim=True)
        probs = 100.0 * image_features @ text_features.T
        return probs[:, 0].softmax(dim=0)

    def _crop_image(self, format_results):
        """Cut out every sufficiently large mask region from the first result's image.

        Args:
            format_results: Non-empty annotation list as produced by :meth:`_format_results`.

        Returns:
            Tuple ``(cropped_boxes, cropped_images, not_crop, filter_id, annotations)``. Despite the names
            (kept for compatibility), ``cropped_boxes`` holds the cut-out PIL images and ``cropped_images``
            holds their ``[x1, y1, x2, y2]`` boxes. ``not_crop`` is always empty; ``filter_id`` lists the
            indices of annotations skipped because their mask has 100 pixels or fewer.

        Raises:
            ValueError: If ``self.source`` is a directory.
        """
        if os.path.isdir(self.source):
            raise ValueError(f"'{self.source}' is a directory, not a valid source for this function.")
        image = Image.fromarray(cv2.cvtColor(self.results[0].orig_img, cv2.COLOR_BGR2RGB))
        ori_w, ori_h = image.size
        annotations = format_results
        mask_h, mask_w = annotations[0]['segmentation'].shape
        if ori_w != mask_w or ori_h != mask_h:
            image = image.resize((mask_w, mask_h))

        cropped_boxes = []
        cropped_images = []
        not_crop = []
        filter_id = []
        for idx, annotation in enumerate(annotations):
            if np.sum(annotation['segmentation']) <= 100:
                filter_id.append(idx)
                continue
            bbox = self._get_bbox_from_mask(annotation['segmentation'])  # bounding box of the mask
            cropped_boxes.append(self._segment_image(image, bbox))  # cut-out image
            cropped_images.append(bbox)  # bounding box of the cut-out
        return cropped_boxes, cropped_images, not_crop, filter_id, annotations

    def box_prompt(self, bbox):
        """Keep only the mask with the highest IoU against ``bbox``.

        Args:
            bbox: ``(x1, y1, x2, y2)`` in original-image pixels; any indexable sequence. It is not modified.

        Returns:
            ``self.results``, with ``results[0].masks.data`` replaced by the single best mask
            (unchanged when ``results[0].masks`` is None).

        Raises:
            AssertionError: If ``bbox[2]`` or ``bbox[3]`` is zero.
            ValueError: If ``self.source`` is a directory.
        """
        if self.results[0].masks is not None:
            assert (bbox[2] != 0 and bbox[3] != 0), 'bbox x2 and y2 must be non-zero'
            if os.path.isdir(self.source):
                raise ValueError(f"'{self.source}' is a directory, not a valid source for this function.")
            masks = self.results[0].masks.data
            target_height, target_width = self.results[0].orig_shape
            h = masks.shape[1]
            w = masks.shape[2]

            # Local copies: works for tuples and leaves the caller's sequence untouched.
            x1, y1, x2, y2 = bbox[0], bbox[1], bbox[2], bbox[3]
            if h != target_height or w != target_width:
                # Rescale the box from original-image coordinates to mask coordinates.
                x1 = int(x1 * w / target_width)
                y1 = int(y1 * h / target_height)
                x2 = int(x2 * w / target_width)
                y2 = int(y2 * h / target_height)
            x1 = max(round(x1), 0)
            y1 = max(round(y1), 0)
            x2 = min(round(x2), w)
            y2 = min(round(y2), h)

            bbox_area = (y2 - y1) * (x2 - x1)
            masks_area = torch.sum(masks[:, y1:y2, x1:x2], dim=(1, 2))
            orig_masks_area = torch.sum(masks, dim=(1, 2))
            union = bbox_area + orig_masks_area - masks_area
            IoUs = masks_area / union
            max_iou_index = torch.argmax(IoUs)

            self.results[0].masks.data = torch.tensor(np.array([masks[max_iou_index].cpu().numpy()]))
        return self.results

    def point_prompt(self, points, pointlabel):  # NumPy processing
        """Combine the masks hit by the given points into a single mask.

        For every mask and every point inside it, the mask is added when the point's label is 1 and
        subtracted when it is 0; pixels whose total is at least 1 form the result.

        Args:
            points: List of ``(x, y)`` points in original-image pixels.
            pointlabel: One label per point (1 = include, 0 = exclude).

        Returns:
            ``self.results``, with ``results[0].masks.data`` replaced by the combined boolean mask
            (unchanged when there are no masks).

        Raises:
            ValueError: If ``self.source`` is a directory.
        """
        if self.results[0].masks is not None:
            if os.path.isdir(self.source):
                raise ValueError(f"'{self.source}' is a directory, not a valid source for this function.")
            masks = self._format_results(self.results[0], 0)
            if not masks:  # nothing to select from
                return self.results
            target_height, target_width = self.results[0].orig_shape
            h = masks[0]['segmentation'].shape[0]
            w = masks[0]['segmentation'].shape[1]
            if h != target_height or w != target_width:
                # Rescale the points from original-image coordinates to mask coordinates.
                points = [[int(point[0] * w / target_width), int(point[1] * h / target_height)] for point in points]
            onemask = np.zeros((h, w))
            for annotation in masks:
                mask = annotation['segmentation'] if isinstance(annotation, dict) else annotation
                for j, point in enumerate(points):
                    if mask[point[1], point[0]] != 1:
                        continue
                    if pointlabel[j] == 1:
                        onemask += mask
                    elif pointlabel[j] == 0:
                        onemask -= mask
            onemask = onemask >= 1
            self.results[0].masks.data = torch.tensor(np.array([onemask]))
        return self.results

    @staticmethod
    def _original_index(position, filter_id):
        """Map a position among the kept masks back to its index in the unfiltered annotation list.

        Args:
            position: Zero-based position among the annotations that were not filtered out.
            filter_id: Indices of the annotations that were filtered out.

        Returns:
            The index of that annotation in the full list.
        """
        skipped = {int(i) for i in filter_id}
        remaining = int(position) + 1
        candidate = -1
        while remaining > 0:
            candidate += 1
            if candidate not in skipped:
                remaining -= 1
        return candidate

    def text_prompt(self, text):
        """Keep only the mask whose cut-out image best matches ``text`` according to CLIP.

        Args:
            text: Text query.

        Returns:
            ``self.results``, with ``results[0].masks.data`` replaced by the single best-matching mask.
            The results are returned unchanged when there are no masks or every mask was filtered out
            as too small.
        """
        if self.results[0].masks is not None:
            format_results = self._format_results(self.results[0], 0)
            if not format_results:
                return self.results
            cropped_boxes, _, _, filter_id, annotations = self._crop_image(format_results)
            if not cropped_boxes:  # every mask was too small to crop
                return self.results
            clip_model, preprocess = self.clip.load('ViT-B/32', device=self.device)
            scores = self.retrieve(clip_model, preprocess, cropped_boxes, text, device=self.device)
            # Scores only cover the kept crops; translate the winner back to its annotation index.
            best = self._original_index(int(scores.argmax()), filter_id)
            self.results[0].masks.data = torch.tensor(np.array([annotations[best]['segmentation']]))
        return self.results

    def everything_prompt(self):
        """Return the results unchanged (all masks)."""
        return self.results