File size: 23,454 Bytes
f6228f9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 |
# Ultralytics YOLO 🚀, AGPL-3.0 license
import json
from collections import defaultdict
from itertools import repeat
from multiprocessing.pool import ThreadPool
from pathlib import Path
import cv2
import numpy as np
import torch
from PIL import Image
from torch.utils.data import ConcatDataset
from ultralytics.utils import LOCAL_RANK, NUM_THREADS, TQDM, colorstr
from ultralytics.utils.ops import resample_segments
from ultralytics.utils.torch_utils import TORCHVISION_0_18
from .augment import (
Compose,
Format,
Instances,
LetterBox,
RandomLoadText,
classify_augmentations,
classify_transforms,
v8_transforms,
)
from .base import BaseDataset
from .utils import (
HELP_URL,
LOGGER,
get_hash,
img2label_paths,
load_dataset_cache_file,
save_dataset_cache_file,
verify_image,
verify_image_label,
)
# Ultralytics dataset *.cache version, >= 1.0.0 for YOLOv8
DATASET_CACHE_VERSION = "1.0.3"
class YOLODataset(BaseDataset):
"""
Dataset class for loading object detection and/or segmentation labels in YOLO format.
Args:
data (dict, optional): A dataset YAML dictionary. Defaults to None.
task (str): An explicit arg to point current task, Defaults to 'detect'.
Returns:
(torch.utils.data.Dataset): A PyTorch dataset object that can be used for training an object detection model.
"""
def __init__(self, *args, data=None, task="detect", **kwargs):
"""Initializes the YOLODataset with optional configurations for segments and keypoints."""
self.use_segments = task == "segment"
self.use_keypoints = task == "pose"
self.use_obb = task == "obb"
self.data = data
assert not (self.use_segments and self.use_keypoints), "Can not use both segments and keypoints."
super().__init__(*args, **kwargs)
def cache_labels(self, path=Path("./labels.cache")):
"""
Cache dataset labels, check images and read shapes.
Args:
path (Path): Path where to save the cache file. Default is Path('./labels.cache').
Returns:
(dict): labels.
"""
x = {"labels": []}
nm, nf, ne, nc, msgs = 0, 0, 0, 0, [] # number missing, found, empty, corrupt, messages
desc = f"{self.prefix}Scanning {path.parent / path.stem}..."
total = len(self.im_files)
nkpt, ndim = self.data.get("kpt_shape", (0, 0))
if self.use_keypoints and (nkpt <= 0 or ndim not in {2, 3}):
raise ValueError(
"'kpt_shape' in data.yaml missing or incorrect. Should be a list with [number of "
"keypoints, number of dims (2 for x,y or 3 for x,y,visible)], i.e. 'kpt_shape: [17, 3]'"
)
with ThreadPool(NUM_THREADS) as pool:
results = pool.imap(
func=verify_image_label,
iterable=zip(
self.im_files,
self.label_files,
repeat(self.prefix),
repeat(self.use_keypoints),
repeat(len(self.data["names"])),
repeat(nkpt),
repeat(ndim),
),
)
pbar = TQDM(results, desc=desc, total=total)
for im_file, lb, shape, segments, keypoint, nm_f, nf_f, ne_f, nc_f, msg in pbar:
nm += nm_f
nf += nf_f
ne += ne_f
nc += nc_f
if im_file:
x["labels"].append(
{
"im_file": im_file,
"shape": shape,
"cls": lb[:, 0:1], # n, 1
"bboxes": lb[:, 1:], # n, 4
"segments": segments,
"keypoints": keypoint,
"normalized": True,
"bbox_format": "xywh",
}
)
if msg:
msgs.append(msg)
pbar.desc = f"{desc} {nf} images, {nm + ne} backgrounds, {nc} corrupt"
pbar.close()
if msgs:
LOGGER.info("\n".join(msgs))
if nf == 0:
LOGGER.warning(f"{self.prefix}WARNING ⚠️ No labels found in {path}. {HELP_URL}")
x["hash"] = get_hash(self.label_files + self.im_files)
x["results"] = nf, nm, ne, nc, len(self.im_files)
x["msgs"] = msgs # warnings
save_dataset_cache_file(self.prefix, path, x, DATASET_CACHE_VERSION)
return x
def get_labels(self):
"""Returns dictionary of labels for YOLO training."""
self.label_files = img2label_paths(self.im_files)
cache_path = Path(self.label_files[0]).parent.with_suffix(".cache")
try:
cache, exists = load_dataset_cache_file(cache_path), True # attempt to load a *.cache file
assert cache["version"] == DATASET_CACHE_VERSION # matches current version
assert cache["hash"] == get_hash(self.label_files + self.im_files) # identical hash
except (FileNotFoundError, AssertionError, AttributeError):
cache, exists = self.cache_labels(cache_path), False # run cache ops
# Display cache
nf, nm, ne, nc, n = cache.pop("results") # found, missing, empty, corrupt, total
if exists and LOCAL_RANK in {-1, 0}:
d = f"Scanning {cache_path}... {nf} images, {nm + ne} backgrounds, {nc} corrupt"
TQDM(None, desc=self.prefix + d, total=n, initial=n) # display results
if cache["msgs"]:
LOGGER.info("\n".join(cache["msgs"])) # display warnings
# Read cache
[cache.pop(k) for k in ("hash", "version", "msgs")] # remove items
labels = cache["labels"]
if not labels:
LOGGER.warning(f"WARNING ⚠️ No images found in {cache_path}, training may not work correctly. {HELP_URL}")
self.im_files = [lb["im_file"] for lb in labels] # update im_files
# Check if the dataset is all boxes or all segments
lengths = ((len(lb["cls"]), len(lb["bboxes"]), len(lb["segments"])) for lb in labels)
len_cls, len_boxes, len_segments = (sum(x) for x in zip(*lengths))
if len_segments and len_boxes != len_segments:
LOGGER.warning(
f"WARNING ⚠️ Box and segment counts should be equal, but got len(segments) = {len_segments}, "
f"len(boxes) = {len_boxes}. To resolve this only boxes will be used and all segments will be removed. "
"To avoid this please supply either a detect or segment dataset, not a detect-segment mixed dataset."
)
for lb in labels:
lb["segments"] = []
if len_cls == 0:
LOGGER.warning(f"WARNING ⚠️ No labels found in {cache_path}, training may not work correctly. {HELP_URL}")
return labels
def build_transforms(self, hyp=None):
"""Builds and appends transforms to the list."""
if self.augment:
hyp.mosaic = hyp.mosaic if self.augment and not self.rect else 0.0
hyp.mixup = hyp.mixup if self.augment and not self.rect else 0.0
transforms = v8_transforms(self, self.imgsz, hyp)
else:
transforms = Compose([LetterBox(new_shape=(self.imgsz, self.imgsz), scaleup=False)])
transforms.append(
Format(
bbox_format="xywh",
normalize=True,
return_mask=self.use_segments,
return_keypoint=self.use_keypoints,
return_obb=self.use_obb,
batch_idx=True,
mask_ratio=hyp.mask_ratio,
mask_overlap=hyp.overlap_mask,
bgr=hyp.bgr if self.augment else 0.0, # only affect training.
)
)
return transforms
def close_mosaic(self, hyp):
"""Sets mosaic, copy_paste and mixup options to 0.0 and builds transformations."""
hyp.mosaic = 0.0 # set mosaic ratio=0.0
hyp.copy_paste = 0.0 # keep the same behavior as previous v8 close-mosaic
hyp.mixup = 0.0 # keep the same behavior as previous v8 close-mosaic
self.transforms = self.build_transforms(hyp)
def update_labels_info(self, label):
"""
Custom your label format here.
Note:
cls is not with bboxes now, classification and semantic segmentation need an independent cls label
Can also support classification and semantic segmentation by adding or removing dict keys there.
"""
bboxes = label.pop("bboxes")
segments = label.pop("segments", [])
keypoints = label.pop("keypoints", None)
bbox_format = label.pop("bbox_format")
normalized = label.pop("normalized")
# NOTE: do NOT resample oriented boxes
segment_resamples = 100 if self.use_obb else 1000
if len(segments) > 0:
# list[np.array(1000, 2)] * num_samples
# (N, 1000, 2)
segments = np.stack(resample_segments(segments, n=segment_resamples), axis=0)
else:
segments = np.zeros((0, segment_resamples, 2), dtype=np.float32)
label["instances"] = Instances(bboxes, segments, keypoints, bbox_format=bbox_format, normalized=normalized)
return label
@staticmethod
def collate_fn(batch):
"""Collates data samples into batches."""
new_batch = {}
keys = batch[0].keys()
values = list(zip(*[list(b.values()) for b in batch]))
for i, k in enumerate(keys):
value = values[i]
if k == "img":
value = torch.stack(value, 0)
if k in {"masks", "keypoints", "bboxes", "cls", "segments", "obb"}:
value = torch.cat(value, 0)
new_batch[k] = value
new_batch["batch_idx"] = list(new_batch["batch_idx"])
for i in range(len(new_batch["batch_idx"])):
new_batch["batch_idx"][i] += i # add target image index for build_targets()
new_batch["batch_idx"] = torch.cat(new_batch["batch_idx"], 0)
return new_batch
class YOLOMultiModalDataset(YOLODataset):
"""
Dataset class for loading object detection and/or segmentation labels in YOLO format.
Args:
data (dict, optional): A dataset YAML dictionary. Defaults to None.
task (str): An explicit arg to point current task, Defaults to 'detect'.
Returns:
(torch.utils.data.Dataset): A PyTorch dataset object that can be used for training an object detection model.
"""
def __init__(self, *args, data=None, task="detect", **kwargs):
"""Initializes a dataset object for object detection tasks with optional specifications."""
super().__init__(*args, data=data, task=task, **kwargs)
def update_labels_info(self, label):
"""Add texts information for multi-modal model training."""
labels = super().update_labels_info(label)
# NOTE: some categories are concatenated with its synonyms by `/`.
labels["texts"] = [v.split("/") for _, v in self.data["names"].items()]
return labels
def build_transforms(self, hyp=None):
"""Enhances data transformations with optional text augmentation for multi-modal training."""
transforms = super().build_transforms(hyp)
if self.augment:
# NOTE: hard-coded the args for now.
transforms.insert(-1, RandomLoadText(max_samples=min(self.data["nc"], 80), padding=True))
return transforms
class GroundingDataset(YOLODataset):
"""Handles object detection tasks by loading annotations from a specified JSON file, supporting YOLO format."""
def __init__(self, *args, task="detect", json_file, **kwargs):
"""Initializes a GroundingDataset for object detection, loading annotations from a specified JSON file."""
assert task == "detect", "`GroundingDataset` only support `detect` task for now!"
self.json_file = json_file
super().__init__(*args, task=task, data={}, **kwargs)
def get_img_files(self, img_path):
"""The image files would be read in `get_labels` function, return empty list here."""
return []
def get_labels(self):
"""Loads annotations from a JSON file, filters, and normalizes bounding boxes for each image."""
labels = []
LOGGER.info("Loading annotation file...")
with open(self.json_file) as f:
annotations = json.load(f)
images = {f'{x["id"]:d}': x for x in annotations["images"]}
img_to_anns = defaultdict(list)
for ann in annotations["annotations"]:
img_to_anns[ann["image_id"]].append(ann)
for img_id, anns in TQDM(img_to_anns.items(), desc=f"Reading annotations {self.json_file}"):
img = images[f"{img_id:d}"]
h, w, f = img["height"], img["width"], img["file_name"]
im_file = Path(self.img_path) / f
if not im_file.exists():
continue
self.im_files.append(str(im_file))
bboxes = []
cat2id = {}
texts = []
for ann in anns:
if ann["iscrowd"]:
continue
box = np.array(ann["bbox"], dtype=np.float32)
box[:2] += box[2:] / 2
box[[0, 2]] /= float(w)
box[[1, 3]] /= float(h)
if box[2] <= 0 or box[3] <= 0:
continue
cat_name = " ".join([img["caption"][t[0] : t[1]] for t in ann["tokens_positive"]])
if cat_name not in cat2id:
cat2id[cat_name] = len(cat2id)
texts.append([cat_name])
cls = cat2id[cat_name] # class
box = [cls] + box.tolist()
if box not in bboxes:
bboxes.append(box)
lb = np.array(bboxes, dtype=np.float32) if len(bboxes) else np.zeros((0, 5), dtype=np.float32)
labels.append(
{
"im_file": im_file,
"shape": (h, w),
"cls": lb[:, 0:1], # n, 1
"bboxes": lb[:, 1:], # n, 4
"normalized": True,
"bbox_format": "xywh",
"texts": texts,
}
)
return labels
def build_transforms(self, hyp=None):
"""Configures augmentations for training with optional text loading; `hyp` adjusts augmentation intensity."""
transforms = super().build_transforms(hyp)
if self.augment:
# NOTE: hard-coded the args for now.
transforms.insert(-1, RandomLoadText(max_samples=80, padding=True))
return transforms
class YOLOConcatDataset(ConcatDataset):
"""
Dataset as a concatenation of multiple datasets.
This class is useful to assemble different existing datasets.
"""
@staticmethod
def collate_fn(batch):
"""Collates data samples into batches."""
return YOLODataset.collate_fn(batch)
# TODO: support semantic segmentation
class SemanticDataset(BaseDataset):
"""
Semantic Segmentation Dataset.
This class is responsible for handling datasets used for semantic segmentation tasks. It inherits functionalities
from the BaseDataset class.
Note:
This class is currently a placeholder and needs to be populated with methods and attributes for supporting
semantic segmentation tasks.
"""
def __init__(self):
"""Initialize a SemanticDataset object."""
super().__init__()
class ClassificationDataset:
"""
Extends torchvision ImageFolder to support YOLO classification tasks, offering functionalities like image
augmentation, caching, and verification. It's designed to efficiently handle large datasets for training deep
learning models, with optional image transformations and caching mechanisms to speed up training.
This class allows for augmentations using both torchvision and Albumentations libraries, and supports caching images
in RAM or on disk to reduce IO overhead during training. Additionally, it implements a robust verification process
to ensure data integrity and consistency.
Attributes:
cache_ram (bool): Indicates if caching in RAM is enabled.
cache_disk (bool): Indicates if caching on disk is enabled.
samples (list): A list of tuples, each containing the path to an image, its class index, path to its .npy cache
file (if caching on disk), and optionally the loaded image array (if caching in RAM).
torch_transforms (callable): PyTorch transforms to be applied to the images.
"""
def __init__(self, root, args, augment=False, prefix=""):
"""
Initialize YOLO object with root, image size, augmentations, and cache settings.
Args:
root (str): Path to the dataset directory where images are stored in a class-specific folder structure.
args (Namespace): Configuration containing dataset-related settings such as image size, augmentation
parameters, and cache settings. It includes attributes like `imgsz` (image size), `fraction` (fraction
of data to use), `scale`, `fliplr`, `flipud`, `cache` (disk or RAM caching for faster training),
`auto_augment`, `hsv_h`, `hsv_s`, `hsv_v`, and `crop_fraction`.
augment (bool, optional): Whether to apply augmentations to the dataset. Default is False.
prefix (str, optional): Prefix for logging and cache filenames, aiding in dataset identification and
debugging. Default is an empty string.
"""
import torchvision # scope for faster 'import ultralytics'
# Base class assigned as attribute rather than used as base class to allow for scoping slow torchvision import
if TORCHVISION_0_18: # 'allow_empty' argument first introduced in torchvision 0.18
self.base = torchvision.datasets.ImageFolder(root=root, allow_empty=True)
else:
self.base = torchvision.datasets.ImageFolder(root=root)
self.samples = self.base.samples
self.root = self.base.root
# Initialize attributes
if augment and args.fraction < 1.0: # reduce training fraction
self.samples = self.samples[: round(len(self.samples) * args.fraction)]
self.prefix = colorstr(f"{prefix}: ") if prefix else ""
self.cache_ram = args.cache is True or str(args.cache).lower() == "ram" # cache images into RAM
if self.cache_ram:
LOGGER.warning(
"WARNING ⚠️ Classification `cache_ram` training has known memory leak in "
"https://github.com/ultralytics/ultralytics/issues/9824, setting `cache_ram=False`."
)
self.cache_ram = False
self.cache_disk = str(args.cache).lower() == "disk" # cache images on hard drive as uncompressed *.npy files
self.samples = self.verify_images() # filter out bad images
self.samples = [list(x) + [Path(x[0]).with_suffix(".npy"), None] for x in self.samples] # file, index, npy, im
scale = (1.0 - args.scale, 1.0) # (0.08, 1.0)
self.torch_transforms = (
classify_augmentations(
size=args.imgsz,
scale=scale,
hflip=args.fliplr,
vflip=args.flipud,
erasing=args.erasing,
auto_augment=args.auto_augment,
hsv_h=args.hsv_h,
hsv_s=args.hsv_s,
hsv_v=args.hsv_v,
)
if augment
else classify_transforms(size=args.imgsz, crop_fraction=args.crop_fraction)
)
def __getitem__(self, i):
"""Returns subset of data and targets corresponding to given indices."""
f, j, fn, im = self.samples[i] # filename, index, filename.with_suffix('.npy'), image
if self.cache_ram:
if im is None: # Warning: two separate if statements required here, do not combine this with previous line
im = self.samples[i][3] = cv2.imread(f)
elif self.cache_disk:
if not fn.exists(): # load npy
np.save(fn.as_posix(), cv2.imread(f), allow_pickle=False)
im = np.load(fn)
else: # read image
im = cv2.imread(f) # BGR
# Convert NumPy array to PIL image
im = Image.fromarray(cv2.cvtColor(im, cv2.COLOR_BGR2RGB))
sample = self.torch_transforms(im)
return {"img": sample, "cls": j}
def __len__(self) -> int:
"""Return the total number of samples in the dataset."""
return len(self.samples)
def verify_images(self):
"""Verify all images in dataset."""
desc = f"{self.prefix}Scanning {self.root}..."
path = Path(self.root).with_suffix(".cache") # *.cache file path
try:
cache = load_dataset_cache_file(path) # attempt to load a *.cache file
assert cache["version"] == DATASET_CACHE_VERSION # matches current version
assert cache["hash"] == get_hash([x[0] for x in self.samples]) # identical hash
nf, nc, n, samples = cache.pop("results") # found, missing, empty, corrupt, total
if LOCAL_RANK in {-1, 0}:
d = f"{desc} {nf} images, {nc} corrupt"
TQDM(None, desc=d, total=n, initial=n)
if cache["msgs"]:
LOGGER.info("\n".join(cache["msgs"])) # display warnings
return samples
except (FileNotFoundError, AssertionError, AttributeError):
# Run scan if *.cache retrieval failed
nf, nc, msgs, samples, x = 0, 0, [], [], {}
with ThreadPool(NUM_THREADS) as pool:
results = pool.imap(func=verify_image, iterable=zip(self.samples, repeat(self.prefix)))
pbar = TQDM(results, desc=desc, total=len(self.samples))
for sample, nf_f, nc_f, msg in pbar:
if nf_f:
samples.append(sample)
if msg:
msgs.append(msg)
nf += nf_f
nc += nc_f
pbar.desc = f"{desc} {nf} images, {nc} corrupt"
pbar.close()
if msgs:
LOGGER.info("\n".join(msgs))
x["hash"] = get_hash([x[0] for x in self.samples])
x["results"] = nf, nc, len(samples), samples
x["msgs"] = msgs # warnings
save_dataset_cache_file(self.prefix, path, x, DATASET_CACHE_VERSION)
return samples
|