|
import os.path |
|
from typing import Any, Callable, List, Optional, Tuple |
|
|
|
import matplotlib.pyplot as plt |
|
import numpy as np |
|
from PIL import Image |
|
from torchvision.datasets import VisionDataset |
|
|
|
|
|
class CocoDetection(VisionDataset): |
|
def __init__( |
|
self, |
|
root: str, |
|
annFile: str, |
|
class_names: Optional[List] = None, |
|
transform: Optional[Callable] = None, |
|
target_transform: Optional[Callable] = None, |
|
transforms: Optional[Callable] = None, |
|
) -> None: |
|
super().__init__(root, transforms, transform, target_transform) |
|
from pycocotools.coco import COCO |
|
|
|
self.coco = COCO(annFile) |
|
|
|
if class_names is not None: |
|
cat_ids = self._get_category_ids_from_name(category_names=class_names) |
|
self.ids = list( |
|
sorted((self._get_img_ids_for_category_ids(category_ids=cat_ids))) |
|
) |
|
|
|
else: |
|
cat_ids = self.coco.getCatIds() |
|
self.ids = list(sorted(self.coco.imgs.keys())) |
|
|
|
self.cat2idx = {cat_id: idx + 1 for idx, cat_id in enumerate(cat_ids)} |
|
self.cat2idx[0] = 0 |
|
|
|
def _load_image(self, id: int) -> Image.Image: |
|
path = self.coco.loadImgs(id)[0]["file_name"] |
|
return Image.open(os.path.join(self.root, path)).convert("RGB") |
|
|
|
def _load_target(self, id: int) -> List[Any]: |
|
return self.coco.loadAnns(self.coco.getAnnIds(id)) |
|
|
|
def __getitem__(self, index: int) -> Tuple[Any, Any]: |
|
id = self.ids[index] |
|
image = self._load_image(id) |
|
mask = self._load_target(id) |
|
|
|
mask = self._get_mask_in_channels(image, mask) |
|
|
|
if self.transform is not None: |
|
image = self.transform(image=np.array(image))["image"] |
|
|
|
if self.target_transform is not None: |
|
mask = self.target_transform(image=mask)["image"] |
|
|
|
return image, (mask != 0).int() |
|
|
|
def __len__(self) -> int: |
|
return len(self.ids) |
|
|
|
def _get_all_classes(self): |
|
catIDs = self.coco.getCatIds() |
|
return self.coco.loadCats(catIDs) |
|
|
|
def _get_category_info_from_ids(self, ids: list): |
|
all_cat = self._get_all_classes() |
|
return [category for category in all_cat if category["id"] in ids] |
|
|
|
def _get_category_ids_from_name(self, category_names: list): |
|
return self.coco.getCatIds(catNms=category_names) |
|
|
|
def _get_img_ids_for_category_ids(self, category_ids: list): |
|
img_ids = [] |
|
|
|
for catIds in category_ids: |
|
img_ids.extend(self.coco.getImgIds(catIds=catIds)) |
|
|
|
return img_ids |
|
|
|
def _get_img_ids_for_category_names(self, category_names: list): |
|
img_ids = [] |
|
category_ids = self._get_category_ids_from_name(category_names=class_names) |
|
|
|
for catIds in category_ids: |
|
img_ids.extend(self.coco.getImgIds(catIds=catIds)) |
|
|
|
return img_ids |
|
|
|
def _get_all_category_ids_in_img_id(self, img_id: int) -> List: |
|
target = self._load_target(img_id) |
|
return list({annotation["category_id"] for annotation in target}) |
|
|
|
def _get_mask_aggregated(self, image: Image, annotations: List) -> np.array: |
|
w, h = image.size |
|
mask = np.zeros((h, w)) |
|
|
|
for annotation in annotations: |
|
category_id = annotation["category_id"] |
|
|
|
if category_id in self.cat2idx: |
|
pixel_value = self.cat2idx[category_id] |
|
mask = np.maximum(self.coco.annToMask(annotation) * pixel_value, mask) |
|
|
|
return mask |
|
|
|
def _get_mask_in_channels(self, image: Image, annotations: List) -> np.array: |
|
w, h = image.size |
|
mask = np.zeros((len(self.cat2idx), h, w)) |
|
|
|
for annotation in annotations: |
|
category_id = annotation["category_id"] |
|
|
|
if category_id in self.cat2idx: |
|
pixel_value = self.cat2idx[category_id] |
|
mask[pixel_value] = np.maximum( |
|
self.coco.annToMask(annotation), mask[pixel_value] |
|
) |
|
|
|
|
|
mask = np.transpose(mask, (1, 2, 0)) |
|
return mask |
|
|
|
def _plot_image_and_mask(self, index): |
|
image, mask = self.__getitem__(index) |
|
|
|
|
|
fig, axs = plt.subplots(1, 2, figsize=(7, 3)) |
|
|
|
axs[0].imshow(image.permute(1, 2, 0)) |
|
axs[0].set_title("Image") |
|
|
|
axs[1].imshow(mask.sum(0, keepdim=True).permute(1, 2, 0)) |
|
axs[1].set_title("Mask") |
|
|
|
plt.show() |
|
|