tomrb's picture
initial yolov8to
ab854b9
raw
history blame
28.5 kB
# Ultralytics YOLO 🚀, AGPL-3.0 license
import contextlib
import hashlib
import json
import os
import random
import subprocess
import time
import zipfile
from multiprocessing.pool import ThreadPool
from pathlib import Path
from tarfile import is_tarfile
import cv2
import numpy as np
from PIL import Image, ImageOps
from tqdm import tqdm
from ultralytics.nn.autobackend import check_class_names
from ultralytics.utils import (DATASETS_DIR, LOGGER, NUM_THREADS, ROOT, SETTINGS_YAML, clean_url, colorstr, emojis,
yaml_load)
from ultralytics.utils.checks import check_file, check_font, is_ascii
from ultralytics.utils.downloads import download, safe_download, unzip_file
from ultralytics.utils.ops import segments2boxes
HELP_URL = 'See https://docs.ultralytics.com/datasets/detect for dataset formatting guidance.'
IMG_FORMATS = 'bmp', 'dng', 'jpeg', 'jpg', 'mpo', 'png', 'tif', 'tiff', 'webp', 'pfm' # image suffixes
VID_FORMATS = 'asf', 'avi', 'gif', 'm4v', 'mkv', 'mov', 'mp4', 'mpeg', 'mpg', 'ts', 'wmv', 'webm' # video suffixes
PIN_MEMORY = str(os.getenv('PIN_MEMORY', True)).lower() == 'true' # global pin_memory for dataloaders
def img2label_paths(img_paths):
"""Define label paths as a function of image paths."""
sa, sb = f'{os.sep}images{os.sep}', f'{os.sep}labels{os.sep}' # /images/, /labels/ substrings
return [sb.join(x.rsplit(sa, 1)).rsplit('.', 1)[0] + '.txt' for x in img_paths]
def get_hash(paths):
"""Returns a single hash value of a list of paths (files or dirs)."""
size = sum(os.path.getsize(p) for p in paths if os.path.exists(p)) # sizes
h = hashlib.sha256(str(size).encode()) # hash sizes
h.update(''.join(paths).encode()) # hash paths
return h.hexdigest() # return hash
def exif_size(img: Image.Image):
"""Returns exif-corrected PIL size."""
s = img.size # (width, height)
if img.format == 'JPEG': # only support JPEG images
with contextlib.suppress(Exception):
exif = img.getexif()
if exif:
rotation = exif.get(274, None) # the EXIF key for the orientation tag is 274
if rotation in [6, 8]: # rotation 270 or 90
s = s[1], s[0]
return s
def verify_image(args):
"""Verify one image."""
(im_file, cls), prefix = args
# Number (found, corrupt), message
nf, nc, msg = 0, 0, ''
try:
im = Image.open(im_file)
im.verify() # PIL verify
shape = exif_size(im) # image size
shape = (shape[1], shape[0]) # hw
assert (shape[0] > 9) & (shape[1] > 9), f'image size {shape} <10 pixels'
assert im.format.lower() in IMG_FORMATS, f'invalid image format {im.format}'
if im.format.lower() in ('jpg', 'jpeg'):
with open(im_file, 'rb') as f:
f.seek(-2, 2)
if f.read() != b'\xff\xd9': # corrupt JPEG
ImageOps.exif_transpose(Image.open(im_file)).save(im_file, 'JPEG', subsampling=0, quality=100)
msg = f'{prefix}WARNING ⚠️ {im_file}: corrupt JPEG restored and saved'
nf = 1
except Exception as e:
nc = 1
msg = f'{prefix}WARNING ⚠️ {im_file}: ignoring corrupt image/label: {e}'
return (im_file, cls), nf, nc, msg
def verify_image_label(args):
"""Verify one image-label pair."""
im_file, lb_file, prefix, keypoint, num_cls, nkpt, ndim = args
# Number (missing, found, empty, corrupt), message, segments, keypoints
nm, nf, ne, nc, msg, segments, keypoints, regression_vars = 0, 0, 0, 0, '', [], None, None
try:
# Verify images
im = Image.open(im_file)
im.verify() # PIL verify
shape = exif_size(im) # image size
shape = (shape[1], shape[0]) # hw
assert (shape[0] > 9) & (shape[1] > 9), f'image size {shape} <10 pixels'
assert im.format.lower() in IMG_FORMATS, f'invalid image format {im.format}'
if im.format.lower() in ('jpg', 'jpeg'):
with open(im_file, 'rb') as f:
f.seek(-2, 2)
if f.read() != b'\xff\xd9': # corrupt JPEG
ImageOps.exif_transpose(Image.open(im_file)).save(im_file, 'JPEG', subsampling=0, quality=100)
msg = f'{prefix}WARNING ⚠️ {im_file}: corrupt JPEG restored and saved'
# Verify labels
if os.path.isfile(lb_file):
nf = 1 # label found
with open(lb_file) as f:
lb = [x.split() for x in f.read().strip().splitlines() if len(x)]
if any(len(x) > 6 for x in lb) and (not keypoint): # is segment
classes = np.array([x[0] for x in lb], dtype=np.float32)
regression_vars = np.array([x[1:7] for x in lb], dtype=np.float32) # Extract regression variables
segments = [np.array(x[7:], dtype=np.float32).reshape(-1, 2) for x in lb]
lb = np.concatenate((classes.reshape(-1, 1), segments2boxes(segments)), 1) # (cls, xywh)
lb = np.array(lb, dtype=np.float32)
nl = len(lb)
if nl:
if keypoint:
assert lb.shape[1] == (5 + nkpt * ndim), f'labels require {(5 + nkpt * ndim)} columns each'
assert (lb[:, 5::ndim] <= 1).all(), 'non-normalized or out of bounds coordinate labels'
assert (lb[:, 6::ndim] <= 1).all(), 'non-normalized or out of bounds coordinate labels'
else:
assert lb.shape[1] == 5, f'labels require 5 columns, {lb.shape[1]} columns detected'
assert (lb[:, 1:] <= 1).all(), \
f'non-normalized or out of bounds coordinates {lb[:, 1:][lb[:, 1:] > 1]}'
assert (lb >= 0).all(), f'negative label values {lb[lb < 0]}'
# All labels
max_cls = int(lb[:, 0].max()) # max label count
assert max_cls <= num_cls, \
f'Label class {max_cls} exceeds dataset class count {num_cls}. ' \
f'Possible class labels are 0-{num_cls - 1}'
_, i = np.unique(lb, axis=0, return_index=True)
if len(i) < nl: # duplicate row check
lb = lb[i] # remove duplicates
if segments:
segments = [segments[x] for x in i]
msg = f'{prefix}WARNING ⚠️ {im_file}: {nl - len(i)} duplicate labels removed'
else:
ne = 1 # label empty
lb = np.zeros((0, (5 + nkpt * ndim)), dtype=np.float32) if keypoint else np.zeros(
(0, 5), dtype=np.float32)
else:
nm = 1 # label missing
lb = np.zeros((0, (5 + nkpt * ndim)), dtype=np.float32) if keypoint else np.zeros((0, 5), dtype=np.float32)
if keypoint:
keypoints = lb[:, 5:].reshape(-1, nkpt, ndim)
if ndim == 2:
kpt_mask = np.where((keypoints[..., 0] < 0) | (keypoints[..., 1] < 0), 0.0, 1.0).astype(np.float32)
keypoints = np.concatenate([keypoints, kpt_mask[..., None]], axis=-1) # (nl, nkpt, 3)
lb = lb[:, :5]
return im_file, lb, shape, segments, regression_vars, keypoints, nm, nf, ne, nc, msg
except Exception as e:
nc = 1
msg = f'{prefix}WARNING ⚠️ {im_file}: ignoring corrupt image/label: {e}'
return [None, None, None, None, None,None, nm, nf, ne, nc, msg]
def polygon2mask(imgsz, polygons, color=1, downsample_ratio=1):
"""
Args:
imgsz (tuple): The image size.
polygons (list[np.ndarray]): [N, M], N is the number of polygons, M is the number of points(Be divided by 2).
color (int): color
downsample_ratio (int): downsample ratio
"""
mask = np.zeros(imgsz, dtype=np.uint8)
polygons = np.asarray(polygons, dtype=np.int32)
polygons = polygons.reshape((polygons.shape[0], -1, 2))
cv2.fillPoly(mask, polygons, color=color)
nh, nw = (imgsz[0] // downsample_ratio, imgsz[1] // downsample_ratio)
# NOTE: fillPoly first then resize is trying to keep the same way of loss calculation when mask-ratio=1.
return cv2.resize(mask, (nw, nh))
def polygons2masks(imgsz, polygons, color, downsample_ratio=1):
"""
Args:
imgsz (tuple): The image size.
polygons (list[np.ndarray]): each polygon is [N, M], N is number of polygons, M is number of points (M % 2 = 0)
color (int): color
downsample_ratio (int): downsample ratio
"""
return np.array([polygon2mask(imgsz, [x.reshape(-1)], color, downsample_ratio) for x in polygons])
def polygons2masks_overlap(imgsz, segments, downsample_ratio=1):
"""Return a (640, 640) overlap mask."""
masks = np.zeros((imgsz[0] // downsample_ratio, imgsz[1] // downsample_ratio),
dtype=np.int32 if len(segments) > 255 else np.uint8)
areas = []
ms = []
for si in range(len(segments)):
mask = polygon2mask(imgsz, [segments[si].reshape(-1)], downsample_ratio=downsample_ratio, color=1)
ms.append(mask)
areas.append(mask.sum())
areas = np.asarray(areas)
index = np.argsort(-areas)
ms = np.array(ms)[index]
for i in range(len(segments)):
mask = ms[i] * (i + 1)
masks = masks + mask
masks = np.clip(masks, a_min=0, a_max=i + 1)
return masks, index
def check_det_dataset(dataset, autodownload=True):
"""
Download, verify, and/or unzip a dataset if not found locally.
This function checks the availability of a specified dataset, and if not found, it has the option to download and
unzip the dataset. It then reads and parses the accompanying YAML data, ensuring key requirements are met and also
resolves paths related to the dataset.
Args:
dataset (str): Path to the dataset or dataset descriptor (like a YAML file).
autodownload (bool, optional): Whether to automatically download the dataset if not found. Defaults to True.
Returns:
(dict): Parsed dataset information and paths.
"""
data = check_file(dataset)
# Download (optional)
extract_dir = ''
if isinstance(data, (str, Path)) and (zipfile.is_zipfile(data) or is_tarfile(data)):
new_dir = safe_download(data, dir=DATASETS_DIR, unzip=True, delete=False, curl=False)
data = next((DATASETS_DIR / new_dir).rglob('*.yaml'))
extract_dir, autodownload = data.parent, False
# Read YAML (optional)
if isinstance(data, (str, Path)):
data = yaml_load(data, append_filename=True) # dictionary
# Checks
for k in 'train', 'val':
if k not in data:
if k == 'val' and 'validation' in data:
LOGGER.info("WARNING ⚠️ renaming data YAML 'validation' key to 'val' to match YOLO format.")
data['val'] = data.pop('validation') # replace 'validation' key with 'val' key
else:
raise SyntaxError(
emojis(f"{dataset} '{k}:' key missing ❌.\n'train' and 'val' are required in all data YAMLs."))
if 'names' not in data and 'nc' not in data:
raise SyntaxError(emojis(f"{dataset} key missing ❌.\n either 'names' or 'nc' are required in all data YAMLs."))
if 'names' in data and 'nc' in data and len(data['names']) != data['nc']:
raise SyntaxError(emojis(f"{dataset} 'names' length {len(data['names'])} and 'nc: {data['nc']}' must match."))
if 'names' not in data:
data['names'] = [f'class_{i}' for i in range(data['nc'])]
else:
data['nc'] = len(data['names'])
data['names'] = check_class_names(data['names'])
# Resolve paths
path = Path(extract_dir or data.get('path') or Path(data.get('yaml_file', '')).parent) # dataset root
if not path.is_absolute():
path = (DATASETS_DIR / path).resolve()
data['path'] = path # download scripts
for k in 'train', 'val', 'test':
if data.get(k): # prepend path
if isinstance(data[k], str):
x = (path / data[k]).resolve()
if not x.exists() and data[k].startswith('../'):
x = (path / data[k][3:]).resolve()
data[k] = str(x)
else:
data[k] = [str((path / x).resolve()) for x in data[k]]
# Parse YAML
train, val, test, s = (data.get(x) for x in ('train', 'val', 'test', 'download'))
if val:
val = [Path(x).resolve() for x in (val if isinstance(val, list) else [val])] # val path
if not all(x.exists() for x in val):
name = clean_url(dataset) # dataset name with URL auth stripped
m = f"\nDataset '{name}' images not found ⚠️, missing path '{[x for x in val if not x.exists()][0]}'"
if s and autodownload:
LOGGER.warning(m)
else:
m += f"\nNote dataset download directory is '{DATASETS_DIR}'. You can update this in '{SETTINGS_YAML}'"
raise FileNotFoundError(m)
t = time.time()
r = None # success
if s.startswith('http') and s.endswith('.zip'): # URL
safe_download(url=s, dir=DATASETS_DIR, delete=True)
elif s.startswith('bash '): # bash script
LOGGER.info(f'Running {s} ...')
r = os.system(s)
else: # python script
exec(s, {'yaml': data})
dt = f'({round(time.time() - t, 1)}s)'
s = f"success ✅ {dt}, saved to {colorstr('bold', DATASETS_DIR)}" if r in (0, None) else f'failure {dt} ❌'
LOGGER.info(f'Dataset download {s}\n')
check_font('Arial.ttf' if is_ascii(data['names']) else 'Arial.Unicode.ttf') # download fonts
return data # dictionary
def check_cls_dataset(dataset, split=''):
"""
Checks a classification dataset such as Imagenet.
This function accepts a `dataset` name and attempts to retrieve the corresponding dataset information.
If the dataset is not found locally, it attempts to download the dataset from the internet and save it locally.
Args:
dataset (str | Path): The name of the dataset.
split (str, optional): The split of the dataset. Either 'val', 'test', or ''. Defaults to ''.
Returns:
(dict): A dictionary containing the following keys:
- 'train' (Path): The directory path containing the training set of the dataset.
- 'val' (Path): The directory path containing the validation set of the dataset.
- 'test' (Path): The directory path containing the test set of the dataset.
- 'nc' (int): The number of classes in the dataset.
- 'names' (dict): A dictionary of class names in the dataset.
"""
dataset = Path(dataset)
data_dir = (dataset if dataset.is_dir() else (DATASETS_DIR / dataset)).resolve()
if not data_dir.is_dir():
LOGGER.warning(f'\nDataset not found ⚠️, missing path {data_dir}, attempting download...')
t = time.time()
if str(dataset) == 'imagenet':
subprocess.run(f"bash {ROOT / 'data/scripts/get_imagenet.sh'}", shell=True, check=True)
else:
url = f'https://github.com/ultralytics/yolov5/releases/download/v1.0/{dataset}.zip'
download(url, dir=data_dir.parent)
s = f"Dataset download success ✅ ({time.time() - t:.1f}s), saved to {colorstr('bold', data_dir)}\n"
LOGGER.info(s)
train_set = data_dir / 'train'
val_set = data_dir / 'val' if (data_dir / 'val').exists() else data_dir / 'validation' if (
data_dir / 'validation').exists() else None # data/test or data/val
test_set = data_dir / 'test' if (data_dir / 'test').exists() else None # data/val or data/test
if split == 'val' and not val_set:
LOGGER.warning("WARNING ⚠️ Dataset 'split=val' not found, using 'split=test' instead.")
elif split == 'test' and not test_set:
LOGGER.warning("WARNING ⚠️ Dataset 'split=test' not found, using 'split=val' instead.")
nc = len([x for x in (data_dir / 'train').glob('*') if x.is_dir()]) # number of classes
names = [x.name for x in (data_dir / 'train').iterdir() if x.is_dir()] # class names list
names = dict(enumerate(sorted(names)))
# Print to console
for k, v in {'train': train_set, 'val': val_set, 'test': test_set}.items():
prefix = f'{colorstr(f"{k}:")} {v}...'
if v is None:
LOGGER.info(prefix)
else:
files = [path for path in v.rglob('*.*') if path.suffix[1:].lower() in IMG_FORMATS]
nf = len(files) # number of files
nd = len({file.parent for file in files}) # number of directories
if nf == 0:
if k == 'train':
raise FileNotFoundError(emojis(f"{dataset} '{k}:' no training images found ❌ "))
else:
LOGGER.warning(f'{prefix} found {nf} images in {nd} classes: WARNING ⚠️ no images found')
elif nd != nc:
LOGGER.warning(f'{prefix} found {nf} images in {nd} classes: ERROR ❌️ requires {nc} classes, not {nd}')
else:
LOGGER.info(f'{prefix} found {nf} images in {nd} classes ✅ ')
return {'train': train_set, 'val': val_set, 'test': test_set, 'nc': nc, 'names': names}
class HUBDatasetStats:
"""
A class for generating HUB dataset JSON and `-hub` dataset directory.
Args:
path (str): Path to data.yaml or data.zip (with data.yaml inside data.zip). Default is 'coco128.yaml'.
task (str): Dataset task. Options are 'detect', 'segment', 'pose', 'classify'. Default is 'detect'.
autodownload (bool): Attempt to download dataset if not found locally. Default is False.
Example:
Download *.zip files from https://github.com/ultralytics/hub/tree/main/example_datasets
i.e. https://github.com/ultralytics/hub/raw/main/example_datasets/coco8.zip for coco8.zip.
```python
from ultralytics.data.utils import HUBDatasetStats
stats = HUBDatasetStats('path/to/coco8.zip', task='detect') # detect dataset
stats = HUBDatasetStats('path/to/coco8-seg.zip', task='segment') # segment dataset
stats = HUBDatasetStats('path/to/coco8-pose.zip', task='pose') # pose dataset
stats = HUBDatasetStats('path/to/imagenet10.zip', task='classify') # classification dataset
stats.get_json(save=True)
stats.process_images()
```
"""
def __init__(self, path='coco128.yaml', task='detect', autodownload=False):
"""Initialize class."""
path = Path(path).resolve()
LOGGER.info(f'Starting HUB dataset checks for {path}....')
self.task = task # detect, segment, pose, classify
if self.task == 'classify':
unzip_dir = unzip_file(path)
data = check_cls_dataset(unzip_dir)
data['path'] = unzip_dir
else: # detect, segment, pose
zipped, data_dir, yaml_path = self._unzip(Path(path))
try:
# data = yaml_load(check_yaml(yaml_path)) # data dict
data = check_det_dataset(yaml_path, autodownload) # data dict
if zipped:
data['path'] = data_dir
except Exception as e:
raise Exception('error/HUB/dataset_stats/init') from e
self.hub_dir = Path(f'{data["path"]}-hub')
self.im_dir = self.hub_dir / 'images'
self.im_dir.mkdir(parents=True, exist_ok=True) # makes /images
self.stats = {'nc': len(data['names']), 'names': list(data['names'].values())} # statistics dictionary
self.data = data
@staticmethod
def _find_yaml(dir):
"""Return data.yaml file."""
files = list(dir.glob('*.yaml')) or list(dir.rglob('*.yaml')) # try root level first and then recursive
assert files, f"No *.yaml file found in '{dir.resolve()}'"
if len(files) > 1:
files = [f for f in files if f.stem == dir.stem] # prefer *.yaml files that match dir name
assert len(files) == 1, f"Expected 1 *.yaml file in '{dir.resolve()}', but found {len(files)}.\n{files}"
return files[0]
def _unzip(self, path):
"""Unzip data.zip."""
if not str(path).endswith('.zip'): # path is data.yaml
return False, None, path
unzip_dir = unzip_file(path, path=path.parent)
assert unzip_dir.is_dir(), f'Error unzipping {path}, {unzip_dir} not found. ' \
f'path/to/abc.zip MUST unzip to path/to/abc/'
return True, str(unzip_dir), self._find_yaml(unzip_dir) # zipped, data_dir, yaml_path
def _hub_ops(self, f):
"""Saves a compressed image for HUB previews."""
compress_one_image(f, self.im_dir / Path(f).name) # save to dataset-hub
def get_json(self, save=False, verbose=False):
"""Return dataset JSON for Ultralytics HUB."""
def _round(labels):
"""Update labels to integer class and 4 decimal place floats."""
if self.task == 'detect':
coordinates = labels['bboxes']
elif self.task == 'segment':
coordinates = [x.flatten() for x in labels['segments']]
elif self.task == 'pose':
n = labels['keypoints'].shape[0]
coordinates = np.concatenate((labels['bboxes'], labels['keypoints'].reshape(n, -1)), 1)
else:
raise ValueError('Undefined dataset task.')
zipped = zip(labels['cls'], coordinates)
return [[int(c[0]), *(round(float(x), 4) for x in points)] for c, points in zipped]
for split in 'train', 'val', 'test':
self.stats[split] = None # predefine
path = self.data.get(split)
# Check split
if path is None: # no split
continue
files = [f for f in Path(path).rglob('*.*') if f.suffix[1:].lower() in IMG_FORMATS] # image files in split
if not files: # no images
continue
# Get dataset statistics
if self.task == 'classify':
from torchvision.datasets import ImageFolder
dataset = ImageFolder(self.data[split])
x = np.zeros(len(dataset.classes)).astype(int)
for im in dataset.imgs:
x[im[1]] += 1
self.stats[split] = {
'instance_stats': {
'total': len(dataset),
'per_class': x.tolist()},
'image_stats': {
'total': len(dataset),
'unlabelled': 0,
'per_class': x.tolist()},
'labels': [{
Path(k).name: v} for k, v in dataset.imgs]}
else:
from ultralytics.data import YOLODataset
dataset = YOLODataset(img_path=self.data[split],
data=self.data,
use_segments=self.task == 'segment',
use_keypoints=self.task == 'pose')
x = np.array([
np.bincount(label['cls'].astype(int).flatten(), minlength=self.data['nc'])
for label in tqdm(dataset.labels, total=len(dataset), desc='Statistics')]) # shape(128x80)
self.stats[split] = {
'instance_stats': {
'total': int(x.sum()),
'per_class': x.sum(0).tolist()},
'image_stats': {
'total': len(dataset),
'unlabelled': int(np.all(x == 0, 1).sum()),
'per_class': (x > 0).sum(0).tolist()},
'labels': [{
Path(k).name: _round(v)} for k, v in zip(dataset.im_files, dataset.labels)]}
# Save, print and return
if save:
stats_path = self.hub_dir / 'stats.json'
LOGGER.info(f'Saving {stats_path.resolve()}...')
with open(stats_path, 'w') as f:
json.dump(self.stats, f) # save stats.json
if verbose:
LOGGER.info(json.dumps(self.stats, indent=2, sort_keys=False))
return self.stats
def process_images(self):
"""Compress images for Ultralytics HUB."""
from ultralytics.data import YOLODataset # ClassificationDataset
for split in 'train', 'val', 'test':
if self.data.get(split) is None:
continue
dataset = YOLODataset(img_path=self.data[split], data=self.data)
with ThreadPool(NUM_THREADS) as pool:
for _ in tqdm(pool.imap(self._hub_ops, dataset.im_files), total=len(dataset), desc=f'{split} images'):
pass
LOGGER.info(f'Done. All images saved to {self.im_dir}')
return self.im_dir
def compress_one_image(f, f_new=None, max_dim=1920, quality=50):
"""
Compresses a single image file to reduced size while preserving its aspect ratio and quality using either the
Python Imaging Library (PIL) or OpenCV library. If the input image is smaller than the maximum dimension, it will
not be resized.
Args:
f (str): The path to the input image file.
f_new (str, optional): The path to the output image file. If not specified, the input file will be overwritten.
max_dim (int, optional): The maximum dimension (width or height) of the output image. Default is 1920 pixels.
quality (int, optional): The image compression quality as a percentage. Default is 50%.
Example:
```python
from pathlib import Path
from ultralytics.data.utils import compress_one_image
for f in Path('path/to/dataset').rglob('*.jpg'):
compress_one_image(f)
```
"""
try: # use PIL
im = Image.open(f)
r = max_dim / max(im.height, im.width) # ratio
if r < 1.0: # image too large
im = im.resize((int(im.width * r), int(im.height * r)))
im.save(f_new or f, 'JPEG', quality=quality, optimize=True) # save
except Exception as e: # use OpenCV
LOGGER.info(f'WARNING ⚠️ HUB ops PIL failure {f}: {e}')
im = cv2.imread(f)
im_height, im_width = im.shape[:2]
r = max_dim / max(im_height, im_width) # ratio
if r < 1.0: # image too large
im = cv2.resize(im, (int(im_width * r), int(im_height * r)), interpolation=cv2.INTER_AREA)
cv2.imwrite(str(f_new or f), im)
def autosplit(path=DATASETS_DIR / 'coco8/images', weights=(0.9, 0.1, 0.0), annotated_only=False):
"""
Automatically split a dataset into train/val/test splits and save the resulting splits into autosplit_*.txt files.
Args:
path (Path, optional): Path to images directory. Defaults to DATASETS_DIR / 'coco8/images'.
weights (list | tuple, optional): Train, validation, and test split fractions. Defaults to (0.9, 0.1, 0.0).
annotated_only (bool, optional): If True, only images with an associated txt file are used. Defaults to False.
Example:
```python
from ultralytics.data.utils import autosplit
autosplit()
```
"""
path = Path(path) # images dir
files = sorted(x for x in path.rglob('*.*') if x.suffix[1:].lower() in IMG_FORMATS) # image files only
n = len(files) # number of files
random.seed(0) # for reproducibility
indices = random.choices([0, 1, 2], weights=weights, k=n) # assign each image to a split
txt = ['autosplit_train.txt', 'autosplit_val.txt', 'autosplit_test.txt'] # 3 txt files
for x in txt:
if (path.parent / x).exists():
(path.parent / x).unlink() # remove existing
LOGGER.info(f'Autosplitting images from {path}' + ', using *.txt labeled images only' * annotated_only)
for i, img in tqdm(zip(indices, files), total=n):
if not annotated_only or Path(img2label_paths([str(img)])[0]).exists(): # check label
with open(path.parent / txt[i], 'a') as f:
f.write(f'./{img.relative_to(path.parent).as_posix()}' + '\n') # add image to txt file