# Ultralytics YOLO 🚀, AGPL-3.0 license

import hashlib
import json
import os
import random
import subprocess
import time
import zipfile
from multiprocessing.pool import ThreadPool
from pathlib import Path
from tarfile import is_tarfile

import cv2
import numpy as np
from PIL import Image, ImageOps

from ultralytics.nn.autobackend import check_class_names
from ultralytics.utils import (
    DATASETS_DIR,
    LOGGER,
    NUM_THREADS,
    ROOT,
    SETTINGS_FILE,
    TQDM,
    clean_url,
    colorstr,
    emojis,
    is_dir_writeable,
    yaml_load,
    yaml_save,
)
from ultralytics.utils.checks import check_file, check_font, is_ascii
from ultralytics.utils.downloads import download, safe_download, unzip_file
from ultralytics.utils.ops import segments2boxes

HELP_URL = "See https://docs.ultralytics.com/datasets for dataset formatting guidance."
IMG_FORMATS = {"bmp", "dng", "jpeg", "jpg", "mpo", "png", "tif", "tiff", "webp", "pfm", "heic"}  # image suffixes
VID_FORMATS = {"asf", "avi", "gif", "m4v", "mkv", "mov", "mp4", "mpeg", "mpg", "ts", "wmv", "webm"}  # video suffixes
PIN_MEMORY = str(os.getenv("PIN_MEMORY", True)).lower() == "true"  # global pin_memory for dataloaders
FORMATS_HELP_MSG = f"Supported formats are:\nimages: {IMG_FORMATS}\nvideos: {VID_FORMATS}"


def img2label_paths(img_paths):
    """
    Define label paths as a function of image paths.

    The last '/images/' path component of each image path is replaced by '/labels/' and the file extension is
    replaced by '.txt'.

    Args:
        img_paths (list[str]): Image file paths.

    Returns:
        (list[str]): Label file paths, one per image path.
    """
    sa, sb = f"{os.sep}images{os.sep}", f"{os.sep}labels{os.sep}"  # /images/, /labels/ substrings
    return [sb.join(x.rsplit(sa, 1)).rsplit(".", 1)[0] + ".txt" for x in img_paths]


def get_hash(paths):
    """
    Return a single hash value of a list of paths (files or dirs).

    The SHA-256 digest is computed over the total size of the existing paths followed by the concatenated path
    strings.

    Args:
        paths (list[str]): Paths to hash.

    Returns:
        (str): Hexadecimal SHA-256 digest.
    """
    size = sum(os.path.getsize(p) for p in paths if os.path.exists(p))  # sizes
    h = hashlib.sha256(str(size).encode())  # hash sizes
    h.update("".join(paths).encode())  # hash paths
    return h.hexdigest()  # return hash


def exif_size(img: Image.Image):
    """
    Return the EXIF-corrected PIL size.

    Args:
        img (Image.Image): An opened PIL image.

    Returns:
        (tuple): (width, height), swapped when a JPEG's EXIF orientation tag is 6 or 8.
    """
    s = img.size  # (width, height)
    if img.format == "JPEG":  # only support JPEG images
        try:
            exif = img.getexif()
            if exif:
                rotation = exif.get(274, None)  # the EXIF key for the orientation tag is 274
                if rotation in {6, 8}:  # rotation 270 or 90
                    s = s[1], s[0]
        except Exception:
            # Best-effort: unreadable EXIF data falls back to the uncorrected size.
            pass
    return s


def verify_image(args):
    """
    Verify one image.

    Args:
        args (tuple): ((im_file, cls), prefix) where im_file is the image path, cls its class and prefix a string
            prepended to log messages.

    Returns:
        (tuple): ((im_file, cls), nf, nc, msg) where nf is 1 if the image was found valid, nc is 1 if it is corrupt
            and msg is a warning message (empty if none).
    """
    (im_file, cls), prefix = args
    # Number (found, corrupt), message
    nf, nc, msg = 0, 0, ""
    try:
        im = Image.open(im_file)
        im.verify()  # PIL verify
        shape = exif_size(im)  # image size
        shape = (shape[1], shape[0])  # hw
        assert (shape[0] > 9) & (shape[1] > 9), f"image size {shape} <10 pixels"
        assert im.format.lower() in IMG_FORMATS, f"Invalid image format {im.format}. {FORMATS_HELP_MSG}"
        if im.format.lower() in {"jpg", "jpeg"}:
            with open(im_file, "rb") as f:
                f.seek(-2, 2)
                if f.read() != b"\xff\xd9":  # corrupt JPEG
                    ImageOps.exif_transpose(Image.open(im_file)).save(im_file, "JPEG", subsampling=0, quality=100)
                    msg = f"{prefix}WARNING ⚠️ {im_file}: corrupt JPEG restored and saved"
        nf = 1
    except Exception as e:
        nc = 1
        msg = f"{prefix}WARNING ⚠️ {im_file}: ignoring corrupt image/label: {e}"
    return (im_file, cls), nf, nc, msg


def verify_image_label(args):
    """
    Verify one image-label pair.

    Args:
        args (tuple): (im_file, lb_file, prefix, keypoint, num_cls, nkpt, ndim) where keypoint is truthy for
            keypoint labels, num_cls is the dataset class count, and nkpt/ndim are the number of keypoints and the
            number of values per keypoint.

    Returns:
        (tuple | list): (im_file, lb, shape, segments, keypoints, nm, nf, ne, nc, msg) on success, where nm, nf, ne
            and nc flag a missing, found, empty and corrupt label respectively. On any failure the first five
            entries are None, nc is 1 and msg describes the error.
    """
    im_file, lb_file, prefix, keypoint, num_cls, nkpt, ndim = args
    # Number (missing, found, empty, corrupt), message, segments, keypoints
    nm, nf, ne, nc, msg, segments, keypoints = 0, 0, 0, 0, "", [], None
    try:
        # Verify images
        im = Image.open(im_file)
        im.verify()  # PIL verify
        shape = exif_size(im)  # image size
        shape = (shape[1], shape[0])  # hw
        assert (shape[0] > 9) & (shape[1] > 9), f"image size {shape} <10 pixels"
        assert im.format.lower() in IMG_FORMATS, f"invalid image format {im.format}. {FORMATS_HELP_MSG}"
        if im.format.lower() in {"jpg", "jpeg"}:
            with open(im_file, "rb") as f:
                f.seek(-2, 2)
                if f.read() != b"\xff\xd9":  # corrupt JPEG
                    ImageOps.exif_transpose(Image.open(im_file)).save(im_file, "JPEG", subsampling=0, quality=100)
                    msg = f"{prefix}WARNING ⚠️ {im_file}: corrupt JPEG restored and saved"

        # Verify labels
        if os.path.isfile(lb_file):
            nf = 1  # label found
            with open(lb_file) as f:
                lb = [x.split() for x in f.read().strip().splitlines() if len(x)]
                if any(len(x) > 6 for x in lb) and (not keypoint):  # is segment
                    classes = np.array([x[0] for x in lb], dtype=np.float32)
                    segments = [np.array(x[1:], dtype=np.float32).reshape(-1, 2) for x in lb]  # (cls, xy1...)
                    lb = np.concatenate((classes.reshape(-1, 1), segments2boxes(segments)), 1)  # (cls, xywh)
                lb = np.array(lb, dtype=np.float32)
            nl = len(lb)
            if nl:
                if keypoint:
                    assert lb.shape[1] == (5 + nkpt * ndim), f"labels require {(5 + nkpt * ndim)} columns each"
                    points = lb[:, 5:].reshape(-1, ndim)[:, :2]
                else:
                    assert lb.shape[1] == 5, f"labels require 5 columns, {lb.shape[1]} columns detected"
                    points = lb[:, 1:]
                assert points.max() <= 1, f"non-normalized or out of bounds coordinates {points[points > 1]}"
                assert lb.min() >= 0, f"negative label values {lb[lb < 0]}"

                # All labels
                max_cls = lb[:, 0].max()  # max label count
                # Class ids are zero-based, so the largest valid id is num_cls - 1
                assert max_cls < num_cls, (
                    f"Label class {int(max_cls)} exceeds dataset class count {num_cls}. "
                    f"Possible class labels are 0-{num_cls - 1}"
                )
                _, i = np.unique(lb, axis=0, return_index=True)
                if len(i) < nl:  # duplicate row check
                    lb = lb[i]  # remove duplicates
                    if segments:
                        segments = [segments[x] for x in i]
                    msg = f"{prefix}WARNING ⚠️ {im_file}: {nl - len(i)} duplicate labels removed"
            else:
                ne = 1  # label empty
                lb = np.zeros((0, (5 + nkpt * ndim) if keypoint else 5), dtype=np.float32)
        else:
            nm = 1  # label missing
            # Use the 'keypoint' flag here, same as the empty-label branch above
            lb = np.zeros((0, (5 + nkpt * ndim) if keypoint else 5), dtype=np.float32)
        if keypoint:
            keypoints = lb[:, 5:].reshape(-1, nkpt, ndim)
            if ndim == 2:
                # Append a mask column: 0.0 where either coordinate is negative, 1.0 otherwise
                kpt_mask = np.where((keypoints[..., 0] < 0) | (keypoints[..., 1] < 0), 0.0, 1.0).astype(np.float32)
                keypoints = np.concatenate([keypoints, kpt_mask[..., None]], axis=-1)  # (nl, nkpt, 3)
        lb = lb[:, :5]
        return im_file, lb, shape, segments, keypoints, nm, nf, ne, nc, msg
    except Exception as e:
        nc = 1
        msg = f"{prefix}WARNING ⚠️ {im_file}: ignoring corrupt image/label: {e}"
        return [None, None, None, None, None, nm, nf, ne, nc, msg]


def polygon2mask(imgsz, polygons, color=1, downsample_ratio=1):
    """
    Convert a list of polygons to a binary mask of the specified image size.

    Args:
        imgsz (tuple): The size of the image as (height, width).
        polygons (list[np.ndarray]): A list of polygons. Each polygon is an array with shape [N, M], where
                                     N is the number of polygons, and M is the number of points such that M % 2 = 0.
        color (int, optional): The color value to fill in the polygons on the mask. Defaults to 1.
        downsample_ratio (int, optional): Factor by which to downsample the mask. Defaults to 1.

    Returns:
        (np.ndarray): A binary mask of the specified image size with the polygons filled in.
    """
    mask = np.zeros(imgsz, dtype=np.uint8)
    polygons = np.asarray(polygons, dtype=np.int32)
    polygons = polygons.reshape((polygons.shape[0], -1, 2))
    cv2.fillPoly(mask, polygons, color=color)
    nh, nw = (imgsz[0] // downsample_ratio, imgsz[1] // downsample_ratio)
    # Note: fillPoly first then resize is trying to keep the same loss calculation method when mask-ratio=1
    return cv2.resize(mask, (nw, nh))


def polygons2masks(imgsz, polygons, color, downsample_ratio=1):
    """
    Convert a list of polygons to a set of binary masks of the specified image size.

    Args:
        imgsz (tuple): The size of the image as (height, width).
        polygons (list[np.ndarray]): A list of polygons. Each polygon is an array with shape [N, M], where
                                     N is the number of polygons, and M is the number of points such that M % 2 = 0.
        color (int): The color value to fill in the polygons on the masks.
        downsample_ratio (int, optional): Factor by which to downsample each mask. Defaults to 1.

    Returns:
        (np.ndarray): A set of binary masks of the specified image size with the polygons filled in.
    """
    return np.array([polygon2mask(imgsz, [x.reshape(-1)], color, downsample_ratio) for x in polygons])


def polygons2masks_overlap(imgsz, segments, downsample_ratio=1):
    """
    Return a single overlap mask in which each segment is encoded by its 1-based rank.

    Segments are rasterized individually, sorted by descending mask area, and accumulated so that the i-th
    segment in sorted order contributes the value i + 1; the running sum is clipped to i + 1 after each step.

    Args:
        imgsz (tuple): The size of the image as (height, width).
        segments (list[np.ndarray]): Polygons, one array per segment.
        downsample_ratio (int, optional): Factor by which to downsample the mask. Defaults to 1.

    Returns:
        (tuple): (masks, index) where masks has shape (imgsz[0] // downsample_ratio, imgsz[1] // downsample_ratio)
            and index is the argsort of the segments by descending area.
    """
    masks = np.zeros(
        (imgsz[0] // downsample_ratio, imgsz[1] // downsample_ratio),
        dtype=np.int32 if len(segments) > 255 else np.uint8,  # uint8 cannot hold ranks above 255
    )
    areas = []
    ms = []
    for segment in segments:
        mask = polygon2mask(imgsz, [segment.reshape(-1)], downsample_ratio=downsample_ratio, color=1)
        ms.append(mask.astype(masks.dtype))
        areas.append(mask.sum())
    areas = np.asarray(areas)
    index = np.argsort(-areas)
    ms = np.array(ms)[index]
    for i, m in enumerate(ms):
        masks = masks + m * (i + 1)
        masks = np.clip(masks, a_min=0, a_max=i + 1)
    return masks, index


def find_dataset_yaml(path: Path) -> Path:
    """
    Find and return the YAML file associated with a Detect, Segment or Pose dataset.

    This function searches for a YAML file at the root level of the provided directory first, and if not found, it
    performs a recursive search. It prefers YAML files that have the same stem as the provided path. An AssertionError
    is raised if no YAML file is found or if multiple YAML files are found.

    Args:
        path (Path): The directory path to search for the YAML file.

    Returns:
        (Path): The path of the found YAML file.
    """
    files = list(path.glob("*.yaml")) or list(path.rglob("*.yaml"))  # try root level first and then recursive
    assert files, f"No YAML file found in '{path.resolve()}'"
    if len(files) > 1:
        files = [f for f in files if f.stem == path.stem]  # prefer *.yaml files that match
    assert len(files) == 1, f"Expected 1 YAML file in '{path.resolve()}', but found {len(files)}.\n{files}"
    return files[0]


def check_det_dataset(dataset, autodownload=True):
    """
    Download, verify, and/or unzip a dataset if not found locally.

    This function checks the availability of a specified dataset, and if not found, it has the option to download and
    unzip the dataset. It then reads and parses the accompanying YAML data, ensuring key requirements are met and also
    resolves paths related to the dataset.

    Args:
        dataset (str): Path to the dataset or dataset descriptor (like a YAML file).
        autodownload (bool, optional): Whether to automatically download the dataset if not found. Defaults to True.

    Returns:
        (dict): Parsed dataset information and paths.

    Raises:
        SyntaxError: If required YAML keys are missing or 'names' and 'nc' disagree.
        FileNotFoundError: If the validation images are missing and no download is attempted.
    """
    file = check_file(dataset)

    # Download (optional)
    extract_dir = ""
    if zipfile.is_zipfile(file) or is_tarfile(file):
        new_dir = safe_download(file, dir=DATASETS_DIR, unzip=True, delete=False)
        file = find_dataset_yaml(DATASETS_DIR / new_dir)
        extract_dir, autodownload = file.parent, False

    # Read YAML
    data = yaml_load(file, append_filename=True)  # dictionary

    # Checks
    for k in "train", "val":
        if k not in data:
            if k != "val" or "validation" not in data:
                raise SyntaxError(
                    emojis(f"{dataset} '{k}:' key missing ❌.\n'train' and 'val' are required in all data YAMLs.")
                )
            LOGGER.info("WARNING ⚠️ renaming data YAML 'validation' key to 'val' to match YOLO format.")
            data["val"] = data.pop("validation")  # replace 'validation' key with 'val' key
    if "names" not in data and "nc" not in data:
        raise SyntaxError(emojis(f"{dataset} key missing ❌.\n either 'names' or 'nc' are required in all data YAMLs."))
    if "names" in data and "nc" in data and len(data["names"]) != data["nc"]:
        raise SyntaxError(emojis(f"{dataset} 'names' length {len(data['names'])} and 'nc: {data['nc']}' must match."))
    if "names" not in data:
        data["names"] = [f"class_{i}" for i in range(data["nc"])]
    else:
        data["nc"] = len(data["names"])

    data["names"] = check_class_names(data["names"])

    # Resolve paths
    path = Path(extract_dir or data.get("path") or Path(data.get("yaml_file", "")).parent)  # dataset root
    if not path.is_absolute():
        path = (DATASETS_DIR / path).resolve()

    # Set paths
    data["path"] = path  # download scripts
    for k in "train", "val", "test", "minival":
        if data.get(k):  # prepend path
            if isinstance(data[k], str):
                x = (path / data[k]).resolve()
                if not x.exists() and data[k].startswith("../"):
                    x = (path / data[k][3:]).resolve()
                data[k] = str(x)
            else:
                data[k] = [str((path / x).resolve()) for x in data[k]]

    # Parse YAML
    val, s = (data.get(x) for x in ("val", "download"))
    if val:
        val = [Path(x).resolve() for x in (val if isinstance(val, list) else [val])]  # val path
        if not all(x.exists() for x in val):
            name = clean_url(dataset)  # dataset name with URL auth stripped
            m = f"\nDataset '{name}' images not found ⚠️, missing path '{[x for x in val if not x.exists()][0]}'"
            if s and autodownload:
                LOGGER.warning(m)
            else:
                m += f"\nNote dataset download directory is '{DATASETS_DIR}'. You can update this in '{SETTINGS_FILE}'"
                raise FileNotFoundError(m)
            t = time.time()
            r = None  # success
            if s.startswith("http") and s.endswith(".zip"):  # URL
                safe_download(url=s, dir=DATASETS_DIR, delete=True)
            elif s.startswith("bash "):  # bash script
                LOGGER.info(f"Running {s} ...")
                # NOTE(review): executes a shell command taken verbatim from the dataset YAML; only use trusted YAMLs
                r = os.system(s)
            else:  # python script
                # NOTE(review): executes Python code taken verbatim from the dataset YAML; only use trusted YAMLs
                exec(s, {"yaml": data})
            dt = f"({round(time.time() - t, 1)}s)"
            s = f"success ✅ {dt}, saved to {colorstr('bold', DATASETS_DIR)}" if r in {0, None} else f"failure {dt} ❌"
            LOGGER.info(f"Dataset download {s}\n")
    check_font("Arial.ttf" if is_ascii(data["names"]) else "Arial.Unicode.ttf")  # download fonts

    return data  # dictionary


def check_cls_dataset(dataset, split=""):
    """
    Checks a classification dataset such as Imagenet.

    This function accepts a `dataset` name and attempts to retrieve the corresponding dataset information.
    If the dataset is not found locally, it attempts to download the dataset from the internet and save it locally.

    Args:
        dataset (str | Path): The name of the dataset.
        split (str, optional): The split of the dataset. Either 'val', 'test', or ''. Defaults to ''.

    Returns:
        (dict): A dictionary containing the following keys:
            - 'train' (Path): The directory path containing the training set of the dataset.
            - 'val' (Path): The directory path containing the validation set of the dataset.
            - 'test' (Path): The directory path containing the test set of the dataset.
            - 'nc' (int): The number of classes in the dataset.
            - 'names' (dict): A dictionary of class names in the dataset.

    Raises:
        FileNotFoundError: If no training images are found.
    """
    # Download (optional if dataset=https://file.zip is passed directly)
    if str(dataset).startswith(("http:/", "https:/")):
        dataset = safe_download(dataset, dir=DATASETS_DIR, unzip=True, delete=False)
    elif Path(dataset).suffix in {".zip", ".tar", ".gz"}:
        file = check_file(dataset)
        dataset = safe_download(file, dir=DATASETS_DIR, unzip=True, delete=False)

    dataset = Path(dataset)
    data_dir = (dataset if dataset.is_dir() else (DATASETS_DIR / dataset)).resolve()
    if not data_dir.is_dir():
        LOGGER.warning(f"\nDataset not found ⚠️, missing path {data_dir}, attempting download...")
        t = time.time()
        if str(dataset) == "imagenet":
            # Argument list (no shell) so a ROOT path containing spaces is passed to bash intact
            subprocess.run(["bash", str(ROOT / "data/scripts/get_imagenet.sh")], check=True)
        else:
            url = f"https://github.com/ultralytics/assets/releases/download/v0.0.0/{dataset}.zip"
            download(url, dir=data_dir.parent)
        s = f"Dataset download success ✅ ({time.time() - t:.1f}s), saved to {colorstr('bold', data_dir)}\n"
        LOGGER.info(s)
    train_set = data_dir / "train"
    val_set = (
        data_dir / "val"
        if (data_dir / "val").exists()
        else data_dir / "validation"
        if (data_dir / "validation").exists()
        else None
    )  # data/test or data/val
    test_set = data_dir / "test" if (data_dir / "test").exists() else None  # data/val or data/test
    if split == "val" and not val_set:
        LOGGER.warning("WARNING ⚠️ Dataset 'split=val' not found, using 'split=test' instead.")
    elif split == "test" and not test_set:
        LOGGER.warning("WARNING ⚠️ Dataset 'split=test' not found, using 'split=val' instead.")

    # Each sub-directory of train/ is one class
    names = [x.name for x in train_set.iterdir() if x.is_dir()]  # class names list
    nc = len(names)  # number of classes
    names = dict(enumerate(sorted(names)))

    # Print to console
    for k, v in {"train": train_set, "val": val_set, "test": test_set}.items():
        prefix = f'{colorstr(f"{k}:")} {v}...'
        if v is None:
            LOGGER.info(prefix)
        else:
            files = [path for path in v.rglob("*.*") if path.suffix[1:].lower() in IMG_FORMATS]
            nf = len(files)  # number of files
            nd = len({file.parent for file in files})  # number of directories
            if nf == 0:
                if k == "train":
                    raise FileNotFoundError(emojis(f"{dataset} '{k}:' no training images found ❌ "))
                else:
                    LOGGER.warning(f"{prefix} found {nf} images in {nd} classes: WARNING ⚠️ no images found")
            elif nd != nc:
                LOGGER.warning(f"{prefix} found {nf} images in {nd} classes: ERROR ❌️ requires {nc} classes, not {nd}")
            else:
                LOGGER.info(f"{prefix} found {nf} images in {nd} classes ✅ ")

    return {"train": train_set, "val": val_set, "test": test_set, "nc": nc, "names": names}


class HUBDatasetStats:
    """
    A class for generating HUB dataset JSON and `-hub` dataset directory.

    Args:
        path (str): Path to data.yaml or data.zip (with data.yaml inside data.zip). Default is 'coco8.yaml'.
        task (str): Dataset task. Options are 'detect', 'segment', 'pose', 'classify'. Default is 'detect'.
        autodownload (bool): Attempt to download dataset if not found locally. Default is False.

    Example:
        Download *.zip files from https://github.com/ultralytics/hub/tree/main/example_datasets
            i.e. https://github.com/ultralytics/hub/raw/main/example_datasets/coco8.zip for coco8.zip.
        ```python
        from ultralytics.data.utils import HUBDatasetStats

        stats = HUBDatasetStats("path/to/coco8.zip", task="detect")  # detect dataset
        stats = HUBDatasetStats("path/to/coco8-seg.zip", task="segment")  # segment dataset
        stats = HUBDatasetStats("path/to/coco8-pose.zip", task="pose")  # pose dataset
        stats = HUBDatasetStats("path/to/dota8.zip", task="obb")  # OBB dataset
        stats = HUBDatasetStats("path/to/imagenet10.zip", task="classify")  # classification dataset

        stats.get_json(save=True)
        stats.process_images()
        ```
    """

    def __init__(self, path="coco8.yaml", task="detect", autodownload=False):
        """Initialize class."""
        path = Path(path).resolve()
        LOGGER.info(f"Starting HUB dataset checks for {path}....")

        self.task = task  # detect, segment, pose, classify, obb
        if self.task == "classify":
            unzip_dir = unzip_file(path)
            data = check_cls_dataset(unzip_dir)
            data["path"] = unzip_dir
        else:  # detect, segment, pose, obb
            _, data_dir, yaml_path = self._unzip(Path(path))
            try:
                # Load YAML with checks
                data = yaml_load(yaml_path)
                data["path"] = ""  # strip path since YAML should be in dataset root for all HUB datasets
                yaml_save(yaml_path, data)
                data = check_det_dataset(yaml_path, autodownload)  # dict
                data["path"] = data_dir  # YAML path should be set to '' (relative) or parent (absolute)
            except Exception as e:
                raise Exception("error/HUB/dataset_stats/init") from e

        self.hub_dir = Path(f'{data["path"]}-hub')
        self.im_dir = self.hub_dir / "images"
        self.stats = {"nc": len(data["names"]), "names": list(data["names"].values())}  # statistics dictionary
        self.data = data

    @staticmethod
    def _unzip(path):
        """
        Unzip data.zip.

        Returns:
            (tuple): (zipped, data_dir, yaml_path). When path does not end in '.zip' it is treated as the YAML
                itself and (False, None, path) is returned.
        """
        if not str(path).endswith(".zip"):  # path is data.yaml
            return False, None, path
        unzip_dir = unzip_file(path, path=path.parent)
        assert unzip_dir.is_dir(), (
            f"Error unzipping {path}, {unzip_dir} not found. " f"path/to/abc.zip MUST unzip to path/to/abc/"
        )
        return True, str(unzip_dir), find_dataset_yaml(unzip_dir)  # zipped, data_dir, yaml_path

    def _hub_ops(self, f):
        """Saves a compressed image for HUB previews."""
        compress_one_image(f, self.im_dir / Path(f).name)  # save to dataset-hub

    def get_json(self, save=False, verbose=False):
        """
        Return dataset JSON for Ultralytics HUB.

        Args:
            save (bool, optional): If True, write the statistics to '<hub_dir>/stats.json'. Defaults to False.
            verbose (bool, optional): If True, log the statistics as indented JSON. Defaults to False.

        Returns:
            (dict): The statistics dictionary with one entry per split ('train', 'val', 'test'); a split without
                a path or without images is left as None.
        """

        def _round(labels):
            """Update labels to integer class and 4 decimal place floats."""
            if self.task == "detect":
                coordinates = labels["bboxes"]
            elif self.task in {"segment", "obb"}:  # Segment and OBB use segments. OBB segments are normalized xyxyxyxy
                coordinates = [x.flatten() for x in labels["segments"]]
            elif self.task == "pose":
                n, nk, nd = labels["keypoints"].shape
                coordinates = np.concatenate((labels["bboxes"], labels["keypoints"].reshape(n, nk * nd)), 1)
            else:
                raise ValueError(f"Undefined dataset task={self.task}.")
            zipped = zip(labels["cls"], coordinates)
            return [[int(c[0]), *(round(float(x), 4) for x in points)] for c, points in zipped]

        for split in "train", "val", "test":
            self.stats[split] = None  # predefine
            path = self.data.get(split)

            # Check split
            if path is None:  # no split
                continue
            files = [f for f in Path(path).rglob("*.*") if f.suffix[1:].lower() in IMG_FORMATS]  # image files in split
            if not files:  # no images
                continue

            # Get dataset statistics
            if self.task == "classify":
                from torchvision.datasets import ImageFolder

                dataset = ImageFolder(self.data[split])

                x = np.zeros(len(dataset.classes)).astype(int)
                for im in dataset.imgs:
                    x[im[1]] += 1

                self.stats[split] = {
                    "instance_stats": {"total": len(dataset), "per_class": x.tolist()},
                    "image_stats": {"total": len(dataset), "unlabelled": 0, "per_class": x.tolist()},
                    "labels": [{Path(k).name: v} for k, v in dataset.imgs],
                }
            else:
                from ultralytics.data import YOLODataset

                dataset = YOLODataset(img_path=self.data[split], data=self.data, task=self.task)
                x = np.array(
                    [
                        np.bincount(label["cls"].astype(int).flatten(), minlength=self.data["nc"])
                        for label in TQDM(dataset.labels, total=len(dataset), desc="Statistics")
                    ]
                )  # shape(128x80)
                self.stats[split] = {
                    "instance_stats": {"total": int(x.sum()), "per_class": x.sum(0).tolist()},
                    "image_stats": {
                        "total": len(dataset),
                        "unlabelled": int(np.all(x == 0, 1).sum()),
                        "per_class": (x > 0).sum(0).tolist(),
                    },
                    "labels": [{Path(k).name: _round(v)} for k, v in zip(dataset.im_files, dataset.labels)],
                }

        # Save, print and return
        if save:
            self.hub_dir.mkdir(parents=True, exist_ok=True)  # makes dataset-hub/
            stats_path = self.hub_dir / "stats.json"
            LOGGER.info(f"Saving {stats_path.resolve()}...")
            with open(stats_path, "w") as f:
                json.dump(self.stats, f)  # save stats.json
        if verbose:
            LOGGER.info(json.dumps(self.stats, indent=2, sort_keys=False))
        return self.stats

    def process_images(self):
        """
        Compress images for Ultralytics HUB.

        Returns:
            (Path): The directory the compressed images were saved to.
        """
        from ultralytics.data import YOLODataset  # ClassificationDataset

        self.im_dir.mkdir(parents=True, exist_ok=True)  # makes dataset-hub/images/
        for split in "train", "val", "test":
            if self.data.get(split) is None:
                continue
            dataset = YOLODataset(img_path=self.data[split], data=self.data)
            with ThreadPool(NUM_THREADS) as pool:
                for _ in TQDM(pool.imap(self._hub_ops, dataset.im_files), total=len(dataset), desc=f"{split} images"):
                    pass
        LOGGER.info(f"Done. All images saved to {self.im_dir}")
        return self.im_dir


def compress_one_image(f, f_new=None, max_dim=1920, quality=50):
    """
    Compresses a single image file to reduced size while preserving its aspect ratio and quality using either the Python
    Imaging Library (PIL) or OpenCV library. If the input image is smaller than the maximum dimension, it will not be
    resized.

    Args:
        f (str): The path to the input image file.
        f_new (str, optional): The path to the output image file. If not specified, the input file will be overwritten.
        max_dim (int, optional): The maximum dimension (width or height) of the output image. Default is 1920 pixels.
        quality (int, optional): The image compression quality as a percentage. Default is 50%.

    Example:
        ```python
        from pathlib import Path
        from ultralytics.data.utils import compress_one_image

        for f in Path("path/to/dataset").rglob("*.jpg"):
            compress_one_image(f)
        ```
    """
    try:  # use PIL
        im = Image.open(f)
        r = max_dim / max(im.height, im.width)  # ratio
        if r < 1.0:  # image too large
            im = im.resize((int(im.width * r), int(im.height * r)))
        im.save(f_new or f, "JPEG", quality=quality, optimize=True)  # save
    except Exception as e:  # use OpenCV
        LOGGER.info(f"WARNING ⚠️ HUB ops PIL failure {f}: {e}")
        im = cv2.imread(str(f))  # pass a str, matching the imwrite call below (f may be a Path)
        im_height, im_width = im.shape[:2]
        r = max_dim / max(im_height, im_width)  # ratio
        if r < 1.0:  # image too large
            im = cv2.resize(im, (int(im_width * r), int(im_height * r)), interpolation=cv2.INTER_AREA)
        cv2.imwrite(str(f_new or f), im)


def autosplit(path=DATASETS_DIR / "coco8/images", weights=(0.9, 0.1, 0.0), annotated_only=False):
    """
    Automatically split a dataset into train/val/test splits and save the resulting splits into autosplit_*.txt files.

    Args:
        path (Path, optional): Path to images directory. Defaults to DATASETS_DIR / 'coco8/images'.
        weights (list | tuple, optional): Train, validation, and test split fractions. Defaults to (0.9, 0.1, 0.0).
        annotated_only (bool, optional): If True, only images with an associated txt file are used. Defaults to False.

    Example:
        ```python
        from ultralytics.data.utils import autosplit

        autosplit()
        ```
    """
    path = Path(path)  # images dir
    files = sorted(x for x in path.rglob("*.*") if x.suffix[1:].lower() in IMG_FORMATS)  # image files only
    n = len(files)  # number of files
    random.seed(0)  # for reproducibility
    indices = random.choices([0, 1, 2], weights=weights, k=n)  # assign each image to a split

    txt = ["autosplit_train.txt", "autosplit_val.txt", "autosplit_test.txt"]  # 3 txt files
    for x in txt:
        if (path.parent / x).exists():
            (path.parent / x).unlink()  # remove existing

    LOGGER.info(f"Autosplitting images from {path}" + ", using *.txt labeled images only" * annotated_only)
    for i, img in TQDM(zip(indices, files), total=n):
        if not annotated_only or Path(img2label_paths([str(img)])[0]).exists():  # check label
            with open(path.parent / txt[i], "a") as f:
                f.write(f"./{img.relative_to(path.parent).as_posix()}" + "\n")  # add image to txt file


def load_dataset_cache_file(path):
    """
    Load an Ultralytics *.cache dictionary from path.

    Args:
        path (str | Path): Path to the cache file.

    Returns:
        (dict): The cached dictionary.
    """
    import gc

    gc.disable()  # reduce pickle load time https://github.com/ultralytics/ultralytics/pull/1585
    try:
        # NOTE(review): allow_pickle=True unpickles arbitrary objects; only load cache files from trusted sources
        cache = np.load(str(path), allow_pickle=True).item()  # load dict
    finally:
        gc.enable()  # re-enable even if loading fails, so GC is never left disabled
    return cache


def save_dataset_cache_file(prefix, path, x, version):
    """
    Save an Ultralytics dataset *.cache dictionary x to path.

    Args:
        prefix (str): String prepended to log messages.
        path (Path): Destination cache file path.
        x (dict): Cache dictionary; its 'version' key is set to version before saving.
        version (str): Cache version identifier.
    """
    x["version"] = version  # add cache version
    if is_dir_writeable(path.parent):
        if path.exists():
            path.unlink()  # remove *.cache file if exists
        np.save(str(path), x)  # save cache for next time
        path.with_suffix(".cache.npy").rename(path)  # remove .npy suffix
        LOGGER.info(f"{prefix}New cache created: {path}")
    else:
        LOGGER.warning(f"{prefix}WARNING ⚠️ Cache directory {path.parent} is not writeable, cache not saved.")