llzzyy233 commited on
Commit
80914e2
·
verified ·
1 Parent(s): 4c9215f

提交项目文件

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. app.py +71 -0
  2. detect-best.pt +3 -0
  3. example1.jpg +0 -0
  4. example2.jpg +0 -0
  5. example3.jpg +0 -0
  6. requirements.txt +1 -0
  7. ultralytics/__init__.py +13 -0
  8. ultralytics/__pycache__/__init__.cpython-310.pyc +0 -0
  9. ultralytics/__pycache__/__init__.cpython-39.pyc +0 -0
  10. ultralytics/cfg/__init__.py +441 -0
  11. ultralytics/cfg/__pycache__/__init__.cpython-310.pyc +0 -0
  12. ultralytics/cfg/__pycache__/__init__.cpython-39.pyc +0 -0
  13. ultralytics/cfg/default.yaml +114 -0
  14. ultralytics/cfg/models/v8/yolov8.yaml +46 -0
  15. ultralytics/cfg/models/v8/yolov8_ECA.yaml +50 -0
  16. ultralytics/cfg/models/v8/yolov8_GAM.yaml +50 -0
  17. ultralytics/cfg/models/v8/yolov8_ResBlock_CBAM.yaml +50 -0
  18. ultralytics/cfg/models/v8/yolov8_SA.yaml +50 -0
  19. ultralytics/cfg/trackers/botsort.yaml +18 -0
  20. ultralytics/cfg/trackers/bytetrack.yaml +11 -0
  21. ultralytics/data/__init__.py +8 -0
  22. ultralytics/data/__pycache__/__init__.cpython-310.pyc +0 -0
  23. ultralytics/data/__pycache__/__init__.cpython-39.pyc +0 -0
  24. ultralytics/data/__pycache__/augment.cpython-310.pyc +0 -0
  25. ultralytics/data/__pycache__/augment.cpython-39.pyc +0 -0
  26. ultralytics/data/__pycache__/base.cpython-310.pyc +0 -0
  27. ultralytics/data/__pycache__/base.cpython-39.pyc +0 -0
  28. ultralytics/data/__pycache__/build.cpython-310.pyc +0 -0
  29. ultralytics/data/__pycache__/build.cpython-39.pyc +0 -0
  30. ultralytics/data/__pycache__/dataset.cpython-310.pyc +0 -0
  31. ultralytics/data/__pycache__/dataset.cpython-39.pyc +0 -0
  32. ultralytics/data/__pycache__/loaders.cpython-310.pyc +0 -0
  33. ultralytics/data/__pycache__/loaders.cpython-39.pyc +0 -0
  34. ultralytics/data/__pycache__/utils.cpython-310.pyc +0 -0
  35. ultralytics/data/__pycache__/utils.cpython-39.pyc +0 -0
  36. ultralytics/data/annotator.py +39 -0
  37. ultralytics/data/augment.py +906 -0
  38. ultralytics/data/base.py +287 -0
  39. ultralytics/data/build.py +170 -0
  40. ultralytics/data/converter.py +230 -0
  41. ultralytics/data/dataloaders/__init__.py +0 -0
  42. ultralytics/data/dataset.py +275 -0
  43. ultralytics/data/loaders.py +407 -0
  44. ultralytics/data/scripts/download_weights.sh +18 -0
  45. ultralytics/data/scripts/get_coco.sh +60 -0
  46. ultralytics/data/scripts/get_coco128.sh +17 -0
  47. ultralytics/data/scripts/get_imagenet.sh +51 -0
  48. ultralytics/data/utils.py +557 -0
  49. ultralytics/engine/__init__.py +0 -0
  50. ultralytics/engine/__pycache__/__init__.cpython-310.pyc +0 -0
app.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from PIL import Image
4
+ from ultralytics import YOLO
5
+ import matplotlib.pyplot as plt
6
+ import io
7
+ from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
8
+ model = YOLO('detect-best.pt')
9
+
10
+ def predict(img, conf, iou):
11
+ results = model.predict(img, conf=conf, iou=iou)
12
+ name = results[0].names
13
+ cls = results[0].boxes.cls
14
+ crazing = 0
15
+ inclusion = 0
16
+ patches = 0
17
+ pitted_surface = 0
18
+ rolled_inscale = 0
19
+ scratches = 0
20
+ for i in cls:
21
+ if i == 0:
22
+ crazing += 1
23
+ elif i == 1:
24
+ inclusion += 1
25
+ elif i == 2:
26
+ patches += 1
27
+ elif i == 3:
28
+ pitted_surface += 1
29
+ elif i == 4:
30
+ rolled_inscale += 1
31
+ elif i == 5:
32
+ scratches += 1
33
+ # 绘制柱状图
34
+ fig, ax = plt.subplots()
35
+ categories = ['crazing','inclusion', 'patches' ,'pitted_surface', 'rolled_inscale' ,'scratches']
36
+ counts = [crazing,inclusion, patches ,pitted_surface, rolled_inscale ,scratches]
37
+ ax.bar(categories, counts)
38
+ ax.set_title('Category-Count')
39
+ plt.ylim(0,5)
40
+ plt.xticks(rotation=45, ha="right")
41
+ ax.set_xlabel('Category')
42
+ ax.set_ylabel('Count')
43
+ # 将图表保存为字节流
44
+ buf = io.BytesIO()
45
+ canvas = FigureCanvas(fig)
46
+ canvas.print_png(buf)
47
+ plt.close(fig) # 关闭图形,释放资源
48
+
49
+ # 将字节流转换为PIL Image
50
+ image_png = Image.open(buf)
51
+ # 绘制并返回结果图片和类别计数图表
52
+
53
+ for i, r in enumerate(results):
54
+ # Plot results image
55
+ im_bgr = r.plot() # BGR-order numpy array
56
+ im_rgb = Image.fromarray(im_bgr[..., ::-1]) # RGB-order PIL image
57
+
58
+ # Show results to screen (in supported environments)
59
+ return im_rgb, image_png
60
+
61
+
62
+ base_conf, base_iou = 0.25, 0.45
63
+ title = "基于改进YOLOv8算法的工业瑕疵辅助检测系统"
64
+ des = "鼠标点击上传图片即可检测缺陷,可通过鼠标调整预测置信度,还可点击网页最下方示例图片进行预测"
65
+ interface = gr.Interface(
66
+ inputs=['image', gr.Slider(maximum=1, minimum=0, value=base_conf), gr.Slider(maximum=1, minimum=0, value=base_iou)],
67
+ outputs=["image", 'image'], fn=predict, title=title, description=des,
68
+ examples=[["example1.jpg", base_conf, base_iou],
69
+ ["example2.jpg", base_conf, base_iou],
70
+ ["example3.jpg", base_conf, base_iou]])
71
+ interface.launch()
detect-best.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:62b790537841a3f4a29d3cf6c3a7effcea9000cdf769e87829e8feee0f39b383
3
+ size 8385200
example1.jpg ADDED
example2.jpg ADDED
example3.jpg ADDED
requirements.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ ultralytics
ultralytics/__init__.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+
3
+ __version__ = '8.0.147'
4
+
5
+ from ultralytics.hub import start
6
+ from ultralytics.models import RTDETR, SAM, YOLO
7
+ from ultralytics.models.fastsam import FastSAM
8
+ from ultralytics.models.nas import NAS
9
+ from ultralytics.utils import SETTINGS as settings
10
+ from ultralytics.utils.checks import check_yolo as checks
11
+ from ultralytics.utils.downloads import download
12
+
13
+ __all__ = '__version__', 'YOLO', 'NAS', 'SAM', 'FastSAM', 'RTDETR', 'checks', 'download', 'start', 'settings' # allow simpler import
ultralytics/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (674 Bytes). View file
 
ultralytics/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (681 Bytes). View file
 
ultralytics/cfg/__init__.py ADDED
@@ -0,0 +1,441 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+
3
+ import contextlib
4
+ import re
5
+ import shutil
6
+ import sys
7
+ from difflib import get_close_matches
8
+ from pathlib import Path
9
+ from types import SimpleNamespace
10
+ from typing import Dict, List, Union
11
+
12
+ from ultralytics.utils import (DEFAULT_CFG, DEFAULT_CFG_DICT, DEFAULT_CFG_PATH, LOGGER, ROOT, SETTINGS, SETTINGS_YAML,
13
+ IterableSimpleNamespace, __version__, checks, colorstr, deprecation_warn, yaml_load,
14
+ yaml_print)
15
+
16
+ # Define valid tasks and modes
17
+ MODES = 'train', 'val', 'predict', 'export', 'track', 'benchmark'
18
+ TASKS = 'detect', 'segment', 'classify', 'pose'
19
+ TASK2DATA = {'detect': 'coco8.yaml', 'segment': 'coco8-seg.yaml', 'classify': 'imagenet100', 'pose': 'coco8-pose.yaml'}
20
+ TASK2MODEL = {
21
+ 'detect': 'yolov8n.pt',
22
+ 'segment': 'yolov8n-seg.pt',
23
+ 'classify': 'yolov8n-cls.pt',
24
+ 'pose': 'yolov8n-pose.pt'}
25
+ TASK2METRIC = {
26
+ 'detect': 'metrics/mAP50-95(B)',
27
+ 'segment': 'metrics/mAP50-95(M)',
28
+ 'classify': 'metrics/accuracy_top1',
29
+ 'pose': 'metrics/mAP50-95(P)'}
30
+
31
+ CLI_HELP_MSG = \
32
+ f"""
33
+ Arguments received: {str(['yolo'] + sys.argv[1:])}. Ultralytics 'yolo' commands use the following syntax:
34
+
35
+ yolo TASK MODE ARGS
36
+
37
+ Where TASK (optional) is one of {TASKS}
38
+ MODE (required) is one of {MODES}
39
+ ARGS (optional) are any number of custom 'arg=value' pairs like 'imgsz=320' that override defaults.
40
+ See all ARGS at https://docs.ultralytics.com/usage/cfg or with 'yolo cfg'
41
+
42
+ 1. Train a detection model for 10 epochs with an initial learning_rate of 0.01
43
+ yolo train data=coco128.yaml model=yolov8n.pt epochs=10 lr0=0.01
44
+
45
+ 2. Predict a YouTube video using a pretrained segmentation model at image size 320:
46
+ yolo predict model=yolov8n-seg.pt source='https://youtu.be/Zgi9g1ksQHc' imgsz=320
47
+
48
+ 3. Val a pretrained detection model at batch-size 1 and image size 640:
49
+ yolo val model=yolov8n.pt data=coco128.yaml batch=1 imgsz=640
50
+
51
+ 4. Export a YOLOv8n classification model to ONNX format at image size 224 by 128 (no TASK required)
52
+ yolo export model=yolov8n-cls.pt format=onnx imgsz=224,128
53
+
54
+ 5. Run special commands:
55
+ yolo help
56
+ yolo checks
57
+ yolo version
58
+ yolo settings
59
+ yolo copy-cfg
60
+ yolo cfg
61
+
62
+ Docs: https://docs.ultralytics.com
63
+ Community: https://community.ultralytics.com
64
+ GitHub: https://github.com/ultralytics/ultralytics
65
+ """
66
+
67
+ # Define keys for arg type checks
68
+ CFG_FLOAT_KEYS = 'warmup_epochs', 'box', 'cls', 'dfl', 'degrees', 'shear'
69
+ CFG_FRACTION_KEYS = ('dropout', 'iou', 'lr0', 'lrf', 'momentum', 'weight_decay', 'warmup_momentum', 'warmup_bias_lr',
70
+ 'label_smoothing', 'hsv_h', 'hsv_s', 'hsv_v', 'translate', 'scale', 'perspective', 'flipud',
71
+ 'fliplr', 'mosaic', 'mixup', 'copy_paste', 'conf', 'iou', 'fraction') # fraction floats 0.0 - 1.0
72
+ CFG_INT_KEYS = ('epochs', 'patience', 'batch', 'workers', 'seed', 'close_mosaic', 'mask_ratio', 'max_det', 'vid_stride',
73
+ 'line_width', 'workspace', 'nbs', 'save_period')
74
+ CFG_BOOL_KEYS = ('save', 'exist_ok', 'verbose', 'deterministic', 'single_cls', 'rect', 'cos_lr', 'overlap_mask', 'val',
75
+ 'save_json', 'save_hybrid', 'half', 'dnn', 'plots', 'show', 'save_txt', 'save_conf', 'save_crop',
76
+ 'show_labels', 'show_conf', 'visualize', 'augment', 'agnostic_nms', 'retina_masks', 'boxes', 'keras',
77
+ 'optimize', 'int8', 'dynamic', 'simplify', 'nms', 'profile')
78
+
79
+
80
+ def cfg2dict(cfg):
81
+ """
82
+ Convert a configuration object to a dictionary, whether it is a file path, a string, or a SimpleNamespace object.
83
+
84
+ Args:
85
+ cfg (str | Path | SimpleNamespace): Configuration object to be converted to a dictionary.
86
+
87
+ Returns:
88
+ cfg (dict): Configuration object in dictionary format.
89
+ """
90
+ if isinstance(cfg, (str, Path)):
91
+ cfg = yaml_load(cfg) # load dict
92
+ elif isinstance(cfg, SimpleNamespace):
93
+ cfg = vars(cfg) # convert to dict
94
+ return cfg
95
+
96
+
97
+ def get_cfg(cfg: Union[str, Path, Dict, SimpleNamespace] = DEFAULT_CFG_DICT, overrides: Dict = None):
98
+ """
99
+ Load and merge configuration data from a file or dictionary.
100
+
101
+ Args:
102
+ cfg (str | Path | Dict | SimpleNamespace): Configuration data.
103
+ overrides (str | Dict | optional): Overrides in the form of a file name or a dictionary. Default is None.
104
+
105
+ Returns:
106
+ (SimpleNamespace): Training arguments namespace.
107
+ """
108
+ cfg = cfg2dict(cfg)
109
+
110
+ # Merge overrides
111
+ if overrides:
112
+ overrides = cfg2dict(overrides)
113
+ check_dict_alignment(cfg, overrides)
114
+ cfg = {**cfg, **overrides} # merge cfg and overrides dicts (prefer overrides)
115
+
116
+ # Special handling for numeric project/name
117
+ for k in 'project', 'name':
118
+ if k in cfg and isinstance(cfg[k], (int, float)):
119
+ cfg[k] = str(cfg[k])
120
+ if cfg.get('name') == 'model': # assign model to 'name' arg
121
+ cfg['name'] = cfg.get('model', '').split('.')[0]
122
+ LOGGER.warning(f"WARNING ⚠️ 'name=model' automatically updated to 'name={cfg['name']}'.")
123
+
124
+ # Type and Value checks
125
+ for k, v in cfg.items():
126
+ if v is not None: # None values may be from optional args
127
+ if k in CFG_FLOAT_KEYS and not isinstance(v, (int, float)):
128
+ raise TypeError(f"'{k}={v}' is of invalid type {type(v).__name__}. "
129
+ f"Valid '{k}' types are int (i.e. '{k}=0') or float (i.e. '{k}=0.5')")
130
+ elif k in CFG_FRACTION_KEYS:
131
+ if not isinstance(v, (int, float)):
132
+ raise TypeError(f"'{k}={v}' is of invalid type {type(v).__name__}. "
133
+ f"Valid '{k}' types are int (i.e. '{k}=0') or float (i.e. '{k}=0.5')")
134
+ if not (0.0 <= v <= 1.0):
135
+ raise ValueError(f"'{k}={v}' is an invalid value. "
136
+ f"Valid '{k}' values are between 0.0 and 1.0.")
137
+ elif k in CFG_INT_KEYS and not isinstance(v, int):
138
+ raise TypeError(f"'{k}={v}' is of invalid type {type(v).__name__}. "
139
+ f"'{k}' must be an int (i.e. '{k}=8')")
140
+ elif k in CFG_BOOL_KEYS and not isinstance(v, bool):
141
+ raise TypeError(f"'{k}={v}' is of invalid type {type(v).__name__}. "
142
+ f"'{k}' must be a bool (i.e. '{k}=True' or '{k}=False')")
143
+
144
+ # Return instance
145
+ return IterableSimpleNamespace(**cfg)
146
+
147
+
148
+ def _handle_deprecation(custom):
149
+ """Hardcoded function to handle deprecated config keys"""
150
+
151
+ for key in custom.copy().keys():
152
+ if key == 'hide_labels':
153
+ deprecation_warn(key, 'show_labels')
154
+ custom['show_labels'] = custom.pop('hide_labels') == 'False'
155
+ if key == 'hide_conf':
156
+ deprecation_warn(key, 'show_conf')
157
+ custom['show_conf'] = custom.pop('hide_conf') == 'False'
158
+ if key == 'line_thickness':
159
+ deprecation_warn(key, 'line_width')
160
+ custom['line_width'] = custom.pop('line_thickness')
161
+
162
+ return custom
163
+
164
+
165
+ def check_dict_alignment(base: Dict, custom: Dict, e=None):
166
+ """
167
+ This function checks for any mismatched keys between a custom configuration list and a base configuration list.
168
+ If any mismatched keys are found, the function prints out similar keys from the base list and exits the program.
169
+
170
+ Args:
171
+ custom (dict): a dictionary of custom configuration options
172
+ base (dict): a dictionary of base configuration options
173
+ """
174
+ custom = _handle_deprecation(custom)
175
+ base_keys, custom_keys = (set(x.keys()) for x in (base, custom))
176
+ mismatched = [k for k in custom_keys if k not in base_keys]
177
+ if mismatched:
178
+ string = ''
179
+ for x in mismatched:
180
+ matches = get_close_matches(x, base_keys) # key list
181
+ matches = [f'{k}={base[k]}' if base.get(k) is not None else k for k in matches]
182
+ match_str = f'Similar arguments are i.e. {matches}.' if matches else ''
183
+ string += f"'{colorstr('red', 'bold', x)}' is not a valid YOLO argument. {match_str}\n"
184
+ raise SyntaxError(string + CLI_HELP_MSG) from e
185
+
186
+
187
+ def merge_equals_args(args: List[str]) -> List[str]:
188
+ """
189
+ Merges arguments around isolated '=' args in a list of strings.
190
+ The function considers cases where the first argument ends with '=' or the second starts with '=',
191
+ as well as when the middle one is an equals sign.
192
+
193
+ Args:
194
+ args (List[str]): A list of strings where each element is an argument.
195
+
196
+ Returns:
197
+ List[str]: A list of strings where the arguments around isolated '=' are merged.
198
+ """
199
+ new_args = []
200
+ for i, arg in enumerate(args):
201
+ if arg == '=' and 0 < i < len(args) - 1: # merge ['arg', '=', 'val']
202
+ new_args[-1] += f'={args[i + 1]}'
203
+ del args[i + 1]
204
+ elif arg.endswith('=') and i < len(args) - 1 and '=' not in args[i + 1]: # merge ['arg=', 'val']
205
+ new_args.append(f'{arg}{args[i + 1]}')
206
+ del args[i + 1]
207
+ elif arg.startswith('=') and i > 0: # merge ['arg', '=val']
208
+ new_args[-1] += arg
209
+ else:
210
+ new_args.append(arg)
211
+ return new_args
212
+
213
+
214
+ def handle_yolo_hub(args: List[str]) -> None:
215
+ """
216
+ Handle Ultralytics HUB command-line interface (CLI) commands.
217
+
218
+ This function processes Ultralytics HUB CLI commands such as login and logout.
219
+ It should be called when executing a script with arguments related to HUB authentication.
220
+
221
+ Args:
222
+ args (List[str]): A list of command line arguments
223
+
224
+ Example:
225
+ ```python
226
+ python my_script.py hub login your_api_key
227
+ ```
228
+ """
229
+ from ultralytics import hub
230
+
231
+ if args[0] == 'login':
232
+ key = args[1] if len(args) > 1 else ''
233
+ # Log in to Ultralytics HUB using the provided API key
234
+ hub.login(key)
235
+ elif args[0] == 'logout':
236
+ # Log out from Ultralytics HUB
237
+ hub.logout()
238
+
239
+
240
+ def handle_yolo_settings(args: List[str]) -> None:
241
+ """
242
+ Handle YOLO settings command-line interface (CLI) commands.
243
+
244
+ This function processes YOLO settings CLI commands such as reset.
245
+ It should be called when executing a script with arguments related to YOLO settings management.
246
+
247
+ Args:
248
+ args (List[str]): A list of command line arguments for YOLO settings management.
249
+
250
+ Example:
251
+ ```python
252
+ python my_script.py yolo settings reset
253
+ ```
254
+ """
255
+ if any(args):
256
+ if args[0] == 'reset':
257
+ SETTINGS_YAML.unlink() # delete the settings file
258
+ SETTINGS.reset() # create new settings
259
+ LOGGER.info('Settings reset successfully') # inform the user that settings have been reset
260
+ else: # save a new setting
261
+ new = dict(parse_key_value_pair(a) for a in args)
262
+ check_dict_alignment(SETTINGS, new)
263
+ SETTINGS.update(new)
264
+
265
+ yaml_print(SETTINGS_YAML) # print the current settings
266
+
267
+
268
+ def parse_key_value_pair(pair):
269
+ """Parse one 'key=value' pair and return key and value."""
270
+ re.sub(r' *= *', '=', pair) # remove spaces around equals sign
271
+ k, v = pair.split('=', 1) # split on first '=' sign
272
+ assert v, f"missing '{k}' value"
273
+ return k, smart_value(v)
274
+
275
+
276
+ def smart_value(v):
277
+ """Convert a string to an underlying type such as int, float, bool, etc."""
278
+ if v.lower() == 'none':
279
+ return None
280
+ elif v.lower() == 'true':
281
+ return True
282
+ elif v.lower() == 'false':
283
+ return False
284
+ else:
285
+ with contextlib.suppress(Exception):
286
+ return eval(v)
287
+ return v
288
+
289
+
290
+ def entrypoint(debug=''):
291
+ """
292
+ This function is the ultralytics package entrypoint, it's responsible for parsing the command line arguments passed
293
+ to the package.
294
+
295
+ This function allows for:
296
+ - passing mandatory YOLO args as a list of strings
297
+ - specifying the task to be performed, either 'detect', 'segment' or 'classify'
298
+ - specifying the mode, either 'train', 'val', 'test', or 'predict'
299
+ - running special modes like 'checks'
300
+ - passing overrides to the package's configuration
301
+
302
+ It uses the package's default cfg and initializes it using the passed overrides.
303
+ Then it calls the CLI function with the composed cfg
304
+ """
305
+ args = (debug.split(' ') if debug else sys.argv)[1:]
306
+ if not args: # no arguments passed
307
+ LOGGER.info(CLI_HELP_MSG)
308
+ return
309
+
310
+ special = {
311
+ 'help': lambda: LOGGER.info(CLI_HELP_MSG),
312
+ 'checks': checks.check_yolo,
313
+ 'version': lambda: LOGGER.info(__version__),
314
+ 'settings': lambda: handle_yolo_settings(args[1:]),
315
+ 'cfg': lambda: yaml_print(DEFAULT_CFG_PATH),
316
+ 'hub': lambda: handle_yolo_hub(args[1:]),
317
+ 'login': lambda: handle_yolo_hub(args),
318
+ 'copy-cfg': copy_default_cfg}
319
+ full_args_dict = {**DEFAULT_CFG_DICT, **{k: None for k in TASKS}, **{k: None for k in MODES}, **special}
320
+
321
+ # Define common mis-uses of special commands, i.e. -h, -help, --help
322
+ special.update({k[0]: v for k, v in special.items()}) # singular
323
+ special.update({k[:-1]: v for k, v in special.items() if len(k) > 1 and k.endswith('s')}) # singular
324
+ special = {**special, **{f'-{k}': v for k, v in special.items()}, **{f'--{k}': v for k, v in special.items()}}
325
+
326
+ overrides = {} # basic overrides, i.e. imgsz=320
327
+ for a in merge_equals_args(args): # merge spaces around '=' sign
328
+ if a.startswith('--'):
329
+ LOGGER.warning(f"WARNING ⚠️ '{a}' does not require leading dashes '--', updating to '{a[2:]}'.")
330
+ a = a[2:]
331
+ if a.endswith(','):
332
+ LOGGER.warning(f"WARNING ⚠️ '{a}' does not require trailing comma ',', updating to '{a[:-1]}'.")
333
+ a = a[:-1]
334
+ if '=' in a:
335
+ try:
336
+ k, v = parse_key_value_pair(a)
337
+ if k == 'cfg': # custom.yaml passed
338
+ LOGGER.info(f'Overriding {DEFAULT_CFG_PATH} with {v}')
339
+ overrides = {k: val for k, val in yaml_load(checks.check_yaml(v)).items() if k != 'cfg'}
340
+ else:
341
+ overrides[k] = v
342
+ except (NameError, SyntaxError, ValueError, AssertionError) as e:
343
+ check_dict_alignment(full_args_dict, {a: ''}, e)
344
+
345
+ elif a in TASKS:
346
+ overrides['task'] = a
347
+ elif a in MODES:
348
+ overrides['mode'] = a
349
+ elif a.lower() in special:
350
+ special[a.lower()]()
351
+ return
352
+ elif a in DEFAULT_CFG_DICT and isinstance(DEFAULT_CFG_DICT[a], bool):
353
+ overrides[a] = True # auto-True for default bool args, i.e. 'yolo show' sets show=True
354
+ elif a in DEFAULT_CFG_DICT:
355
+ raise SyntaxError(f"'{colorstr('red', 'bold', a)}' is a valid YOLO argument but is missing an '=' sign "
356
+ f"to set its value, i.e. try '{a}={DEFAULT_CFG_DICT[a]}'\n{CLI_HELP_MSG}")
357
+ else:
358
+ check_dict_alignment(full_args_dict, {a: ''})
359
+
360
+ # Check keys
361
+ check_dict_alignment(full_args_dict, overrides)
362
+
363
+ # Mode
364
+ mode = overrides.get('mode')
365
+ if mode is None:
366
+ mode = DEFAULT_CFG.mode or 'predict'
367
+ LOGGER.warning(f"WARNING ⚠️ 'mode' is missing. Valid modes are {MODES}. Using default 'mode={mode}'.")
368
+ elif mode not in MODES:
369
+ if mode not in ('checks', checks):
370
+ raise ValueError(f"Invalid 'mode={mode}'. Valid modes are {MODES}.\n{CLI_HELP_MSG}")
371
+ LOGGER.warning("WARNING ⚠️ 'yolo mode=checks' is deprecated. Use 'yolo checks' instead.")
372
+ checks.check_yolo()
373
+ return
374
+
375
+ # Task
376
+ task = overrides.pop('task', None)
377
+ if task:
378
+ if task not in TASKS:
379
+ raise ValueError(f"Invalid 'task={task}'. Valid tasks are {TASKS}.\n{CLI_HELP_MSG}")
380
+ if 'model' not in overrides:
381
+ overrides['model'] = TASK2MODEL[task]
382
+
383
+ # Model
384
+ model = overrides.pop('model', DEFAULT_CFG.model)
385
+ if model is None:
386
+ model = 'yolov8n.pt'
387
+ LOGGER.warning(f"WARNING ⚠️ 'model' is missing. Using default 'model={model}'.")
388
+ overrides['model'] = model
389
+ if 'rtdetr' in model.lower(): # guess architecture
390
+ from ultralytics import RTDETR
391
+ model = RTDETR(model) # no task argument
392
+ elif 'fastsam' in model.lower():
393
+ from ultralytics import FastSAM
394
+ model = FastSAM(model)
395
+ elif 'sam' in model.lower():
396
+ from ultralytics import SAM
397
+ model = SAM(model)
398
+ else:
399
+ from ultralytics import YOLO
400
+ model = YOLO(model, task=task)
401
+ if isinstance(overrides.get('pretrained'), str):
402
+ model.load(overrides['pretrained'])
403
+
404
+ # Task Update
405
+ if task != model.task:
406
+ if task:
407
+ LOGGER.warning(f"WARNING ⚠️ conflicting 'task={task}' passed with 'task={model.task}' model. "
408
+ f"Ignoring 'task={task}' and updating to 'task={model.task}' to match model.")
409
+ task = model.task
410
+
411
+ # Mode
412
+ if mode in ('predict', 'track') and 'source' not in overrides:
413
+ overrides['source'] = DEFAULT_CFG.source or ROOT / 'assets' if (ROOT / 'assets').exists() \
414
+ else 'https://ultralytics.com/images/bus.jpg'
415
+ LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using default 'source={overrides['source']}'.")
416
+ elif mode in ('train', 'val'):
417
+ if 'data' not in overrides:
418
+ overrides['data'] = TASK2DATA.get(task or DEFAULT_CFG.task, DEFAULT_CFG.data)
419
+ LOGGER.warning(f"WARNING ⚠️ 'data' is missing. Using default 'data={overrides['data']}'.")
420
+ elif mode == 'export':
421
+ if 'format' not in overrides:
422
+ overrides['format'] = DEFAULT_CFG.format or 'torchscript'
423
+ LOGGER.warning(f"WARNING ⚠️ 'format' is missing. Using default 'format={overrides['format']}'.")
424
+
425
+ # Run command in python
426
+ # getattr(model, mode)(**vars(get_cfg(overrides=overrides))) # default args using default.yaml
427
+ getattr(model, mode)(**overrides) # default args from model
428
+
429
+
430
+ # Special modes --------------------------------------------------------------------------------------------------------
431
+ def copy_default_cfg():
432
+ """Copy and create a new default configuration file with '_copy' appended to its name."""
433
+ new_file = Path.cwd() / DEFAULT_CFG_PATH.name.replace('.yaml', '_copy.yaml')
434
+ shutil.copy2(DEFAULT_CFG_PATH, new_file)
435
+ LOGGER.info(f'{DEFAULT_CFG_PATH} copied to {new_file}\n'
436
+ f"Example YOLO command with this new custom cfg:\n yolo cfg='{new_file}' imgsz=320 batch=8")
437
+
438
+
439
+ if __name__ == '__main__':
440
+ # Example Usage: entrypoint(debug='yolo predict model=yolov8n.pt')
441
+ entrypoint(debug='')
ultralytics/cfg/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (16.3 kB). View file
 
ultralytics/cfg/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (16.3 kB). View file
 
ultralytics/cfg/default.yaml ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+ # Default training settings and hyperparameters for medium-augmentation COCO training
3
+
4
+ task: detect # (str) YOLO task, i.e. detect, segment, classify, pose
5
+ mode: train # (str) YOLO mode, i.e. train, val, predict, export, track, benchmark
6
+
7
+ # Train settings -------------------------------------------------------------------------------------------------------
8
+ model: # (str, optional) path to model file, i.e. yolov8n.pt, yolov8n.yaml
9
+ data: # (str, optional) path to data file, i.e. coco128.yaml
10
+ epochs: 100 # (int) number of epochs to train for
11
+ patience: 50 # (int) epochs to wait for no observable improvement for early stopping of training
12
+ batch: -1 # (int) number of images per batch (-1 for AutoBatch)
13
+ imgsz: 640 # (int | list) input images size as int for train and val modes, or list[w,h] for predict and export modes
14
+ save: True # (bool) save train checkpoints and predict results
15
+ save_period: -1 # (int) Save checkpoint every x epochs (disabled if < 1)
16
+ cache: False # (bool) True/ram, disk or False. Use cache for data loading
17
+ device: cpu # (int | str | list, optional) device to run on, i.e. cuda device=0 or device=0,1,2,3 or device=cpu
18
+ workers: 2 # (int) number of worker threads for data loading (per RANK if DDP)
19
+ project: # (str, optional) project name
20
+ name: # (str, optional) experiment name, results saved to 'project/name' directory
21
+ exist_ok: True # (bool) whether to overwrite existing experiment
22
+ pretrained: True # (bool | str) whether to use a pretrained model (bool) or a model to load weights from (str)
23
+ optimizer: auto # (str) optimizer to use, choices=[SGD, Adam, Adamax, AdamW, NAdam, RAdam, RMSProp, auto]
24
+ verbose: True # (bool) whether to print verbose output
25
+ seed: 0 # (int) random seed for reproducibility
26
+ deterministic: True # (bool) whether to enable deterministic mode
27
+ single_cls: False # (bool) train multi-class data as single-class
28
+ rect: False # (bool) rectangular training if mode='train' or rectangular validation if mode='val'
29
+ cos_lr: False # (bool) use cosine learning rate scheduler
30
+ close_mosaic: 10 # (int) disable mosaic augmentation for final epochs
31
+ resume: False # (bool) resume training from last checkpoint
32
+ amp: False # (bool) Automatic Mixed Precision (AMP) training, choices=[True, False], True runs AMP check
33
+ fraction: 1.0 # (float) dataset fraction to train on (default is 1.0, all images in train set)
34
+ profile: False # (bool) profile ONNX and TensorRT speeds during training for loggers
35
+ # Segmentation
36
+ overlap_mask: True # (bool) masks should overlap during training (segment train only)
37
+ mask_ratio: 4 # (int) mask downsample ratio (segment train only)
38
+ # Classification
39
+ dropout: 0.0 # (float) use dropout regularization (classify train only)
40
+
41
+ # Val/Test settings ----------------------------------------------------------------------------------------------------
42
+ val: True # (bool) validate/test during training
43
+ split: val # (str) dataset split to use for validation, i.e. 'val', 'test' or 'train'
44
+ save_json: True # (bool) save results to JSON file
45
+ save_hybrid: False # (bool) save hybrid version of labels (labels + additional predictions)
46
+ conf: # (float, optional) object confidence threshold for detection (default 0.25 predict, 0.001 val)
47
+ iou: 0.7 # (float) intersection over union (IoU) threshold for NMS
48
+ max_det: 300 # (int) maximum number of detections per image
49
+ half: False # (bool) use half precision (FP16)
50
+ dnn: False # (bool) use OpenCV DNN for ONNX inference
51
+ plots: True # (bool) save plots during train/val
52
+
53
+ # Prediction settings --------------------------------------------------------------------------------------------------
54
+ source: # (str, optional) source directory for images or videos
55
+ show: False # (bool) show results if possible
56
+ save_txt: False # (bool) save results as .txt file
57
+ save_conf: False # (bool) save results with confidence scores
58
+ save_crop: False # (bool) save cropped images with results
59
+ show_labels: True # (bool) show object labels in plots
60
+ show_conf: True # (bool) show object confidence scores in plots
61
+ vid_stride: 1 # (int) video frame-rate stride
62
+ line_width: # (int, optional) line width of the bounding boxes, auto if missing
63
+ visualize: False # (bool) visualize model features
64
+ augment: False # (bool) apply image augmentation to prediction sources
65
+ agnostic_nms: False # (bool) class-agnostic NMS
66
+ classes: # (int | list[int], optional) filter results by class, i.e. class=0, or class=[0,2,3]
67
+ retina_masks: False # (bool) use high-resolution segmentation masks
68
+ boxes: True # (bool) Show boxes in segmentation predictions
69
+
70
+ # Export settings ------------------------------------------------------------------------------------------------------
71
+ format: torchscript # (str) format to export to, choices at https://docs.ultralytics.com/modes/export/#export-formats
72
+ keras: False # (bool) use Kera=s
73
+ optimize: False # (bool) TorchScript: optimize for mobile
74
+ int8: False # (bool) CoreML/TF INT8 quantization
75
+ dynamic: False # (bool) ONNX/TF/TensorRT: dynamic axes
76
+ simplify: False # (bool) ONNX: simplify model
77
+ opset: # (int, optional) ONNX: opset version
78
+ workspace: 4 # (int) TensorRT: workspace size (GB)
79
+ nms: False # (bool) CoreML: add NMS
80
+
81
+ # Hyperparameters ------------------------------------------------------------------------------------------------------
82
+ lr0: 0.01 # (float) initial learning rate (i.e. SGD=1E-2, Adam=1E-3)
83
+ lrf: 0.01 # (float) final learning rate (lr0 * lrf)
84
+ momentum: 0.937 # (float) SGD momentum/Adam beta1
85
+ weight_decay: 0.0005 # (float) optimizer weight decay 5e-4
86
+ warmup_epochs: 3.0 # (float) warmup epochs (fractions ok)
87
+ warmup_momentum: 0.8 # (float) warmup initial momentum
88
+ warmup_bias_lr: 0.1 # (float) warmup initial bias lr
89
+ box: 7.5 # (float) box loss gain
90
+ cls: 0.5 # (float) cls loss gain (scale with pixels)
91
+ dfl: 1.5 # (float) dfl loss gain
92
+ pose: 12.0 # (float) pose loss gain
93
+ kobj: 1.0 # (float) keypoint obj loss gain
94
+ label_smoothing: 0.0 # (float) label smoothing (fraction)
95
+ nbs: 64 # (int) nominal batch size
96
+ hsv_h: 0.015 # (float) image HSV-Hue augmentation (fraction)
97
+ hsv_s: 0.7 # (float) image HSV-Saturation augmentation (fraction)
98
+ hsv_v: 0.4 # (float) image HSV-Value augmentation (fraction)
99
+ degrees: 0.0 # (float) image rotation (+/- deg)
100
+ translate: 0.1 # (float) image translation (+/- fraction)
101
+ scale: 0.5 # (float) image scale (+/- gain)
102
+ shear: 0.0 # (float) image shear (+/- deg)
103
+ perspective: 0.0 # (float) image perspective (+/- fraction), range 0-0.001
104
+ flipud: 0.0 # (float) image flip up-down (probability)
105
+ fliplr: 0.5 # (float) image flip left-right (probability)
106
+ mosaic: 1.0 # (float) image mosaic (probability)
107
+ mixup: 0.0 # (float) image mixup (probability)
108
+ copy_paste: 0.0 # (float) segment copy-paste (probability)
109
+
110
+ # Custom config.yaml ---------------------------------------------------------------------------------------------------
111
+ cfg: # (str, optional) for overriding defaults.yaml
112
+ save_dir: ./runs/train1 # 自己设置路径
113
+ # Tracker settings ------------------------------------------------------------------------------------------------------
114
+ tracker: botsort.yaml # (str) tracker type, choices=[botsort.yaml, bytetrack.yaml]
ultralytics/cfg/models/v8/yolov8.yaml ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+ # YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect
3
+
4
+ # Parameters
5
+ nc: 1 # number of classes
6
+ scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
7
+ # [depth, width, max_channels]
8
+ n: [0.33, 0.25, 1024] # YOLOv8n summary: 225 layers, 3157200 parameters, 3157184 gradients, 8.9 GFLOPs
9
+ s: [0.33, 0.50, 1024] # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients, 28.8 GFLOPs
10
+ m: [0.67, 0.75, 768] # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients, 79.3 GFLOPs
11
+ l: [1.00, 1.00, 512] # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
12
+ x: [1.00, 1.25, 512] # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs
13
+
14
+ # YOLOv8.0n backbone
15
+ backbone:
16
+ # [from, repeats, module, args]
17
+ - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
18
+ - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
19
+ - [-1, 3, C2f, [128, True]]
20
+ - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
21
+ - [-1, 6, C2f, [256, True]]
22
+ - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
23
+ - [-1, 6, C2f, [512, True]]
24
+ - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
25
+ - [-1, 3, C2f, [1024, True]]
26
+ - [-1, 1, SPPF, [1024, 5]] # 9
27
+
28
+ # YOLOv8.0n head
29
+ head:
30
+ - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
31
+ - [[-1, 6], 1, Concat, [1]] # cat backbone P4
32
+ - [-1, 3, C2f, [512]] # 12
33
+
34
+ - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
35
+ - [[-1, 4], 1, Concat, [1]] # cat backbone P3
36
+ - [-1, 3, C2f, [256]] # 15 (P3/8-small)
37
+
38
+ - [-1, 1, Conv, [256, 3, 2]]
39
+ - [[-1, 12], 1, Concat, [1]] # cat head P4
40
+ - [-1, 3, C2f, [512]] # 18 (P4/16-medium)
41
+
42
+ - [-1, 1, Conv, [512, 3, 2]]
43
+ - [[-1, 9], 1, Concat, [1]] # cat head P5
44
+ - [-1, 3, C2f, [1024]] # 21 (P5/32-large)
45
+
46
+ - [[15, 18, 21], 1, Detect, [nc]] # Detect(P3, P4, P5)
ultralytics/cfg/models/v8/yolov8_ECA.yaml ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+ # YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect
3
+
4
+ # Parameters
5
+ nc: 9 # number of classes
6
+ scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
7
+ # [depth, width, max_channels]
8
+ n: [0.33, 0.25, 1024] # YOLOv8n summary: 225 layers, 3157200 parameters, 3157184 gradients, 8.9 GFLOPs
9
+ s: [0.33, 0.50, 1024] # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients, 28.8 GFLOPs
10
+ m: [0.67, 0.75, 768] # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients, 79.3 GFLOPs
11
+ l: [1.00, 1.00, 512] # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
12
+ x: [1.00, 1.25, 512] # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs
13
+
14
+ # YOLOv8.0n backbone
15
+ backbone:
16
+ # [from, repeats, module, args]
17
+ - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
18
+ - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
19
+ - [-1, 3, C2f, [128, True]]
20
+ - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
21
+ - [-1, 6, C2f, [256, True]]
22
+ - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
23
+ - [-1, 6, C2f, [512, True]]
24
+ - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
25
+ - [-1, 3, C2f, [1024, True]]
26
+ - [-1, 1, SPPF, [1024, 5]] # 9
27
+
28
+ # YOLOv8.0n head
29
+ head:
30
+ - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
31
+ - [[-1, 6], 1, Concat, [1]] # cat backbone P4
32
+ - [-1, 3, C2f, [512]] # 12
33
+ - [-1, 1, ECAAttention, [512]]
34
+
35
+ - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
36
+ - [[-1, 4], 1, Concat, [1]] # cat backbone P3
37
+ - [-1, 3, C2f, [256]] # 16 (P3/8-small)
38
+ - [-1, 1, ECAAttention, [256]]
39
+
40
+ - [-1, 1, Conv, [256, 3, 2]]
41
+ - [[-1, 12], 1, Concat, [1]] # cat head P4
42
+ - [-1, 3, C2f, [512]] # 20 (P4/16-medium)
43
+ - [-1, 1, ECAAttention, [512]]
44
+
45
+ - [-1, 1, Conv, [512, 3, 2]]
46
+ - [[-1, 9], 1, Concat, [1]] # cat head P5
47
+ - [-1, 3, C2f, [1024]] # 24 (P5/32-large)
48
+ - [-1, 1, ECAAttention, [1024]]
49
+
50
+ - [[17, 21, 25], 1, Detect, [nc]] # Detect(P3, P4, P5)
ultralytics/cfg/models/v8/yolov8_GAM.yaml ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+ # YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect
3
+
4
+ # Parameters
5
+ nc: 9 # number of classes
6
+ scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
7
+ # [depth, width, max_channels]
8
+ n: [0.33, 0.25, 1024] # YOLOv8n summary: 225 layers, 3157200 parameters, 3157184 gradients, 8.9 GFLOPs
9
+ s: [0.33, 0.50, 1024] # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients, 28.8 GFLOPs
10
+ m: [0.67, 0.75, 768] # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients, 79.3 GFLOPs
11
+ l: [1.00, 1.00, 512] # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
12
+ x: [1.00, 1.25, 512] # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs
13
+
14
+ # YOLOv8.0n backbone
15
+ backbone:
16
+ # [from, repeats, module, args]
17
+ - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
18
+ - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
19
+ - [-1, 3, C2f, [128, True]]
20
+ - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
21
+ - [-1, 6, C2f, [256, True]]
22
+ - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
23
+ - [-1, 6, C2f, [512, True]]
24
+ - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
25
+ - [-1, 3, C2f, [1024, True]]
26
+ - [-1, 1, SPPF, [1024, 5]] # 9
27
+
28
+ # YOLOv8.0n head
29
+ head:
30
+ - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
31
+ - [[-1, 6], 1, Concat, [1]] # cat backbone P4
32
+ - [-1, 3, C2f, [512]] # 12
33
+ - [-1, 1, GAM_Attention, [512,512]]
34
+
35
+ - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
36
+ - [[-1, 4], 1, Concat, [1]] # cat backbone P3
37
+ - [-1, 3, C2f, [256]] # 16 (P3/8-small)
38
+ - [-1, 1, GAM_Attention, [256,256]]
39
+
40
+ - [-1, 1, Conv, [256, 3, 2]]
41
+ - [[-1, 12], 1, Concat, [1]] # cat head P4
42
+ - [-1, 3, C2f, [512]] # 20 (P4/16-medium)
43
+ - [-1, 1, GAM_Attention, [512,512]]
44
+
45
+ - [-1, 1, Conv, [512, 3, 2]]
46
+ - [[-1, 9], 1, Concat, [1]] # cat head P5
47
+ - [-1, 3, C2f, [1024]] # 24 (P5/32-large)
48
+ - [-1, 1, GAM_Attention, [1024,1024]]
49
+
50
+ - [[17, 21, 25], 1, Detect, [nc]] # Detect(P3, P4, P5)
ultralytics/cfg/models/v8/yolov8_ResBlock_CBAM.yaml ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+ # YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect
3
+
4
+ # Parameters
5
+ nc: 9 # number of classes
6
+ scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
7
+ # [depth, width, max_channels]
8
+ n: [0.33, 0.25, 1024] # YOLOv8n summary: 225 layers, 3157200 parameters, 3157184 gradients, 8.9 GFLOPs
9
+ s: [0.33, 0.50, 1024] # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients, 28.8 GFLOPs
10
+ m: [0.67, 0.75, 768] # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients, 79.3 GFLOPs
11
+ l: [1.00, 1.00, 512] # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
12
+ x: [1.00, 1.25, 512] # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs
13
+
14
+ # YOLOv8.0n backbone
15
+ backbone:
16
+ # [from, repeats, module, args]
17
+ - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
18
+ - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
19
+ - [-1, 3, C2f, [128, True]]
20
+ - [-1, 1, GhostConv, [256, 3, 2]] # 3-P3/8
21
+ - [-1, 6, C2f, [256, True]]
22
+ - [-1, 1, GhostConv, [512, 3, 2]] # 5-P4/16
23
+ - [-1, 6, C2f, [512, True]]
24
+ - [-1, 1, GhostConv, [1024, 3, 2]] # 7-P5/32
25
+ - [-1, 3, C2f, [1024, True]]
26
+ - [-1, 1, SPPF, [1024, 5]] # 9
27
+
28
+ # YOLOv8.0n head
29
+ head:
30
+ - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
31
+ - [[-1, 6], 1, Concat, [1]] # cat backbone P4
32
+ - [-1, 3, C2f, [512]] # 12
33
+ - [-1, 1, ResBlock_CBAM, [512]]
34
+
35
+ - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
36
+ - [[-1, 4], 1, Concat, [1]] # cat backbone P3
37
+ - [-1, 3, C2f, [256]] # 16 (P3/8-small)
38
+ - [-1, 1, ResBlock_CBAM, [256]]
39
+
40
+ - [-1, 1, Conv, [256, 3, 2]]
41
+ - [[-1, 12], 1, Concat, [1]] # cat head P4
42
+ - [-1, 3, C2f, [512]] # 20 (P4/16-medium)
43
+ - [-1, 1, ResBlock_CBAM, [512]]
44
+
45
+ - [-1, 1, Conv, [512, 3, 2]]
46
+ - [[-1, 9], 1, Concat, [1]] # cat head P5
47
+ - [-1, 3, C2f, [1024]] # 24 (P5/32-large)
48
+ - [-1, 1, ResBlock_CBAM, [1024]]
49
+
50
+ - [[17, 21, 25], 1, Detect, [nc]] # Detect(P3, P4, P5)
ultralytics/cfg/models/v8/yolov8_SA.yaml ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+ # YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect
3
+
4
+ # Parameters
5
+ nc: 9 # number of classes
6
+ scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
7
+ # [depth, width, max_channels]
8
+ n: [0.33, 0.25, 1024] # YOLOv8n summary: 225 layers, 3157200 parameters, 3157184 gradients, 8.9 GFLOPs
9
+ s: [0.33, 0.50, 1024] # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients, 28.8 GFLOPs
10
+ m: [0.67, 0.75, 768] # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients, 79.3 GFLOPs
11
+ l: [1.00, 1.00, 512] # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
12
+ x: [1.00, 1.25, 512] # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs
13
+
14
+ # YOLOv8.0n backbone
15
+ backbone:
16
+ # [from, repeats, module, args]
17
+ - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
18
+ - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
19
+ - [-1, 3, C2f, [128, True]]
20
+ - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
21
+ - [-1, 6, C2f, [256, True]]
22
+ - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
23
+ - [-1, 6, C2f, [512, True]]
24
+ - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
25
+ - [-1, 3, C2f, [1024, True]]
26
+ - [-1, 1, SPPF, [1024, 5]] # 9
27
+
28
+ # YOLOv8.0n head
29
+ head:
30
+ - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
31
+ - [[-1, 6], 1, Concat, [1]] # cat backbone P4
32
+ - [-1, 3, C2f, [512]] # 12
33
+ - [-1, 1, ShuffleAttention, [512]]
34
+
35
+ - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
36
+ - [[-1, 4], 1, Concat, [1]] # cat backbone P3
37
+ - [-1, 3, C2f, [256]] # 16 (P3/8-small)
38
+ - [-1, 1, ShuffleAttention, [256]]
39
+
40
+ - [-1, 1, Conv, [256, 3, 2]]
41
+ - [[-1, 12], 1, Concat, [1]] # cat head P4
42
+ - [-1, 3, C2f, [512]] # 20 (P4/16-medium)
43
+ - [-1, 1, ShuffleAttention, [512]]
44
+
45
+ - [-1, 1, Conv, [512, 3, 2]]
46
+ - [[-1, 9], 1, Concat, [1]] # cat head P5
47
+ - [-1, 3, C2f, [1024]] # 24 (P5/32-large)
48
+ - [-1, 1, ShuffleAttention, [1024]]
49
+
50
+ - [[17, 21, 25], 1, Detect, [nc]] # Detect(P3, P4, P5)
ultralytics/cfg/trackers/botsort.yaml ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+ # Default YOLO tracker settings for BoT-SORT tracker https://github.com/NirAharon/BoT-SORT
3
+
4
+ tracker_type: botsort # tracker type, ['botsort', 'bytetrack']
5
+ track_high_thresh: 0.5 # threshold for the first association
6
+ track_low_thresh: 0.1 # threshold for the second association
7
+ new_track_thresh: 0.6 # threshold for init new track if the detection does not match any tracks
8
+ track_buffer: 30 # buffer to calculate the time when to remove tracks
9
+ match_thresh: 0.8 # threshold for matching tracks
10
+ # min_box_area: 10 # threshold for min box areas(for tracker evaluation, not used for now)
11
+ # mot20: False # for tracker evaluation(not used for now)
12
+
13
+ # BoT-SORT settings
14
+ cmc_method: sparseOptFlow # method of global motion compensation
15
+ # ReID model related thresh (not supported yet)
16
+ proximity_thresh: 0.5
17
+ appearance_thresh: 0.25
18
+ with_reid: False
ultralytics/cfg/trackers/bytetrack.yaml ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+ # Default YOLO tracker settings for ByteTrack tracker https://github.com/ifzhang/ByteTrack
3
+
4
+ tracker_type: bytetrack # tracker type, ['botsort', 'bytetrack']
5
+ track_high_thresh: 0.5 # threshold for the first association
6
+ track_low_thresh: 0.1 # threshold for the second association
7
+ new_track_thresh: 0.6 # threshold for init new track if the detection does not match any tracks
8
+ track_buffer: 30 # buffer to calculate the time when to remove tracks
9
+ match_thresh: 0.8 # threshold for matching tracks
10
+ # min_box_area: 10 # threshold for min box areas(for tracker evaluation, not used for now)
11
+ # mot20: False # for tracker evaluation(not used for now)
ultralytics/data/__init__.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+
3
+ from .base import BaseDataset
4
+ from .build import build_dataloader, build_yolo_dataset, load_inference_source
5
+ from .dataset import ClassificationDataset, SemanticDataset, YOLODataset
6
+
7
+ __all__ = ('BaseDataset', 'ClassificationDataset', 'SemanticDataset', 'YOLODataset', 'build_yolo_dataset',
8
+ 'build_dataloader', 'load_inference_source')
ultralytics/data/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (466 Bytes). View file
 
ultralytics/data/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (473 Bytes). View file
 
ultralytics/data/__pycache__/augment.cpython-310.pyc ADDED
Binary file (31.5 kB). View file
 
ultralytics/data/__pycache__/augment.cpython-39.pyc ADDED
Binary file (31.6 kB). View file
 
ultralytics/data/__pycache__/base.cpython-310.pyc ADDED
Binary file (11.3 kB). View file
 
ultralytics/data/__pycache__/base.cpython-39.pyc ADDED
Binary file (11.3 kB). View file
 
ultralytics/data/__pycache__/build.cpython-310.pyc ADDED
Binary file (6.33 kB). View file
 
ultralytics/data/__pycache__/build.cpython-39.pyc ADDED
Binary file (6.2 kB). View file
 
ultralytics/data/__pycache__/dataset.cpython-310.pyc ADDED
Binary file (11.3 kB). View file
 
ultralytics/data/__pycache__/dataset.cpython-39.pyc ADDED
Binary file (11.3 kB). View file
 
ultralytics/data/__pycache__/loaders.cpython-310.pyc ADDED
Binary file (15.5 kB). View file
 
ultralytics/data/__pycache__/loaders.cpython-39.pyc ADDED
Binary file (15.7 kB). View file
 
ultralytics/data/__pycache__/utils.cpython-310.pyc ADDED
Binary file (24.2 kB). View file
 
ultralytics/data/__pycache__/utils.cpython-39.pyc ADDED
Binary file (24.1 kB). View file
 
ultralytics/data/annotator.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+
3
+ from ultralytics import SAM, YOLO
4
+
5
+
6
+ def auto_annotate(data, det_model='yolov8x.pt', sam_model='sam_b.pt', device='', output_dir=None):
7
+ """
8
+ Automatically annotates images using a YOLO object detection model and a SAM segmentation model.
9
+ Args:
10
+ data (str): Path to a folder containing images to be annotated.
11
+ det_model (str, optional): Pre-trained YOLO detection model. Defaults to 'yolov8x.pt'.
12
+ sam_model (str, optional): Pre-trained SAM segmentation model. Defaults to 'sam_b.pt'.
13
+ device (str, optional): Device to run the models on. Defaults to an empty string (CPU or GPU, if available).
14
+ output_dir (str | None | optional): Directory to save the annotated results.
15
+ Defaults to a 'labels' folder in the same directory as 'data'.
16
+ """
17
+ det_model = YOLO(det_model)
18
+ sam_model = SAM(sam_model)
19
+
20
+ if not output_dir:
21
+ output_dir = Path(str(data)).parent / 'labels'
22
+ Path(output_dir).mkdir(exist_ok=True, parents=True)
23
+
24
+ det_results = det_model(data, stream=True, device=device)
25
+
26
+ for result in det_results:
27
+ boxes = result.boxes.xyxy # Boxes object for bbox outputs
28
+ class_ids = result.boxes.cls.int().tolist() # noqa
29
+ if len(class_ids):
30
+ sam_results = sam_model(result.orig_img, bboxes=boxes, verbose=False, save=False, device=device)
31
+ segments = sam_results[0].masks.xyn # noqa
32
+
33
+ with open(str(Path(output_dir) / Path(result.path).stem) + '.txt', 'w') as f:
34
+ for i in range(len(segments)):
35
+ s = segments[i]
36
+ if len(s) == 0:
37
+ continue
38
+ segment = map(str, segments[i].reshape(-1).tolist())
39
+ f.write(f'{class_ids[i]} ' + ' '.join(segment) + '\n')
ultralytics/data/augment.py ADDED
@@ -0,0 +1,906 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+
3
+ import math
4
+ import random
5
+ from copy import deepcopy
6
+
7
+ import cv2
8
+ import numpy as np
9
+ import torch
10
+ import torchvision.transforms as T
11
+
12
+ from ultralytics.utils import LOGGER, colorstr
13
+ from ultralytics.utils.checks import check_version
14
+ from ultralytics.utils.instance import Instances
15
+ from ultralytics.utils.metrics import bbox_ioa
16
+ from ultralytics.utils.ops import segment2box
17
+
18
+ from .utils import polygons2masks, polygons2masks_overlap
19
+
20
+ POSE_FLIPLR_INDEX = [0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15]
21
+
22
+
23
+ # TODO: we might need a BaseTransform to make all these augments be compatible with both classification and semantic
24
+ class BaseTransform:
25
+
26
+ def __init__(self) -> None:
27
+ pass
28
+
29
+ def apply_image(self, labels):
30
+ """Applies image transformation to labels."""
31
+ pass
32
+
33
+ def apply_instances(self, labels):
34
+ """Applies transformations to input 'labels' and returns object instances."""
35
+ pass
36
+
37
+ def apply_semantic(self, labels):
38
+ """Applies semantic segmentation to an image."""
39
+ pass
40
+
41
+ def __call__(self, labels):
42
+ """Applies label transformations to an image, instances and semantic masks."""
43
+ self.apply_image(labels)
44
+ self.apply_instances(labels)
45
+ self.apply_semantic(labels)
46
+
47
+
48
+ class Compose:
49
+
50
+ def __init__(self, transforms):
51
+ """Initializes the Compose object with a list of transforms."""
52
+ self.transforms = transforms
53
+
54
+ def __call__(self, data):
55
+ """Applies a series of transformations to input data."""
56
+ for t in self.transforms:
57
+ data = t(data)
58
+ return data
59
+
60
+ def append(self, transform):
61
+ """Appends a new transform to the existing list of transforms."""
62
+ self.transforms.append(transform)
63
+
64
+ def tolist(self):
65
+ """Converts list of transforms to a standard Python list."""
66
+ return self.transforms
67
+
68
+ def __repr__(self):
69
+ """Return string representation of object."""
70
+ format_string = f'{self.__class__.__name__}('
71
+ for t in self.transforms:
72
+ format_string += '\n'
73
+ format_string += f' {t}'
74
+ format_string += '\n)'
75
+ return format_string
76
+
77
+
78
+ class BaseMixTransform:
79
+ """This implementation is from mmyolo."""
80
+
81
+ def __init__(self, dataset, pre_transform=None, p=0.0) -> None:
82
+ self.dataset = dataset
83
+ self.pre_transform = pre_transform
84
+ self.p = p
85
+
86
+ def __call__(self, labels):
87
+ """Applies pre-processing transforms and mixup/mosaic transforms to labels data."""
88
+ if random.uniform(0, 1) > self.p:
89
+ return labels
90
+
91
+ # Get index of one or three other images
92
+ indexes = self.get_indexes()
93
+ if isinstance(indexes, int):
94
+ indexes = [indexes]
95
+
96
+ # Get images information will be used for Mosaic or MixUp
97
+ mix_labels = [self.dataset.get_image_and_label(i) for i in indexes]
98
+
99
+ if self.pre_transform is not None:
100
+ for i, data in enumerate(mix_labels):
101
+ mix_labels[i] = self.pre_transform(data)
102
+ labels['mix_labels'] = mix_labels
103
+
104
+ # Mosaic or MixUp
105
+ labels = self._mix_transform(labels)
106
+ labels.pop('mix_labels', None)
107
+ return labels
108
+
109
+ def _mix_transform(self, labels):
110
+ """Applies MixUp or Mosaic augmentation to the label dictionary."""
111
+ raise NotImplementedError
112
+
113
+ def get_indexes(self):
114
+ """Gets a list of shuffled indexes for mosaic augmentation."""
115
+ raise NotImplementedError
116
+
117
+
118
+ class Mosaic(BaseMixTransform):
119
+ """
120
+ Mosaic augmentation.
121
+
122
+ This class performs mosaic augmentation by combining multiple (4 or 9) images into a single mosaic image.
123
+ The augmentation is applied to a dataset with a given probability.
124
+
125
+ Attributes:
126
+ dataset: The dataset on which the mosaic augmentation is applied.
127
+ imgsz (int, optional): Image size (height and width) after mosaic pipeline of a single image. Default to 640.
128
+ p (float, optional): Probability of applying the mosaic augmentation. Must be in the range 0-1. Default to 1.0.
129
+ n (int, optional): The grid size, either 4 (for 2x2) or 9 (for 3x3).
130
+ """
131
+
132
+ def __init__(self, dataset, imgsz=640, p=1.0, n=4):
133
+ """Initializes the object with a dataset, image size, probability, and border."""
134
+ assert 0 <= p <= 1.0, f'The probability should be in range [0, 1], but got {p}.'
135
+ assert n in (4, 9), 'grid must be equal to 4 or 9.'
136
+ super().__init__(dataset=dataset, p=p)
137
+ self.dataset = dataset
138
+ self.imgsz = imgsz
139
+ self.border = (-imgsz // 2, -imgsz // 2) # width, height
140
+ self.n = n
141
+
142
+ def get_indexes(self, buffer=True):
143
+ """Return a list of random indexes from the dataset."""
144
+ if buffer: # select images from buffer
145
+ return random.choices(list(self.dataset.buffer), k=self.n - 1)
146
+ else: # select any images
147
+ return [random.randint(0, len(self.dataset) - 1) for _ in range(self.n - 1)]
148
+
149
+ def _mix_transform(self, labels):
150
+ """Apply mixup transformation to the input image and labels."""
151
+ assert labels.get('rect_shape', None) is None, 'rect and mosaic are mutually exclusive.'
152
+ assert len(labels.get('mix_labels', [])), 'There are no other images for mosaic augment.'
153
+ return self._mosaic4(labels) if self.n == 4 else self._mosaic9(labels)
154
+
155
+ def _mosaic4(self, labels):
156
+ """Create a 2x2 image mosaic."""
157
+ mosaic_labels = []
158
+ s = self.imgsz
159
+ yc, xc = (int(random.uniform(-x, 2 * s + x)) for x in self.border) # mosaic center x, y
160
+ for i in range(4):
161
+ labels_patch = labels if i == 0 else labels['mix_labels'][i - 1]
162
+ # Load image
163
+ img = labels_patch['img']
164
+ h, w = labels_patch.pop('resized_shape')
165
+
166
+ # Place img in img4
167
+ if i == 0: # top left
168
+ img4 = np.full((s * 2, s * 2, img.shape[2]), 114, dtype=np.uint8) # base image with 4 tiles
169
+ x1a, y1a, x2a, y2a = max(xc - w, 0), max(yc - h, 0), xc, yc # xmin, ymin, xmax, ymax (large image)
170
+ x1b, y1b, x2b, y2b = w - (x2a - x1a), h - (y2a - y1a), w, h # xmin, ymin, xmax, ymax (small image)
171
+ elif i == 1: # top right
172
+ x1a, y1a, x2a, y2a = xc, max(yc - h, 0), min(xc + w, s * 2), yc
173
+ x1b, y1b, x2b, y2b = 0, h - (y2a - y1a), min(w, x2a - x1a), h
174
+ elif i == 2: # bottom left
175
+ x1a, y1a, x2a, y2a = max(xc - w, 0), yc, xc, min(s * 2, yc + h)
176
+ x1b, y1b, x2b, y2b = w - (x2a - x1a), 0, w, min(y2a - y1a, h)
177
+ elif i == 3: # bottom right
178
+ x1a, y1a, x2a, y2a = xc, yc, min(xc + w, s * 2), min(s * 2, yc + h)
179
+ x1b, y1b, x2b, y2b = 0, 0, min(w, x2a - x1a), min(y2a - y1a, h)
180
+
181
+ img4[y1a:y2a, x1a:x2a] = img[y1b:y2b, x1b:x2b] # img4[ymin:ymax, xmin:xmax]
182
+ padw = x1a - x1b
183
+ padh = y1a - y1b
184
+
185
+ labels_patch = self._update_labels(labels_patch, padw, padh)
186
+ mosaic_labels.append(labels_patch)
187
+ final_labels = self._cat_labels(mosaic_labels)
188
+ final_labels['img'] = img4
189
+ return final_labels
190
+
191
+ def _mosaic9(self, labels):
192
+ """Create a 3x3 image mosaic."""
193
+ mosaic_labels = []
194
+ s = self.imgsz
195
+ hp, wp = -1, -1 # height, width previous
196
+ for i in range(9):
197
+ labels_patch = labels if i == 0 else labels['mix_labels'][i - 1]
198
+ # Load image
199
+ img = labels_patch['img']
200
+ h, w = labels_patch.pop('resized_shape')
201
+
202
+ # Place img in img9
203
+ if i == 0: # center
204
+ img9 = np.full((s * 3, s * 3, img.shape[2]), 114, dtype=np.uint8) # base image with 4 tiles
205
+ h0, w0 = h, w
206
+ c = s, s, s + w, s + h # xmin, ymin, xmax, ymax (base) coordinates
207
+ elif i == 1: # top
208
+ c = s, s - h, s + w, s
209
+ elif i == 2: # top right
210
+ c = s + wp, s - h, s + wp + w, s
211
+ elif i == 3: # right
212
+ c = s + w0, s, s + w0 + w, s + h
213
+ elif i == 4: # bottom right
214
+ c = s + w0, s + hp, s + w0 + w, s + hp + h
215
+ elif i == 5: # bottom
216
+ c = s + w0 - w, s + h0, s + w0, s + h0 + h
217
+ elif i == 6: # bottom left
218
+ c = s + w0 - wp - w, s + h0, s + w0 - wp, s + h0 + h
219
+ elif i == 7: # left
220
+ c = s - w, s + h0 - h, s, s + h0
221
+ elif i == 8: # top left
222
+ c = s - w, s + h0 - hp - h, s, s + h0 - hp
223
+
224
+ padw, padh = c[:2]
225
+ x1, y1, x2, y2 = (max(x, 0) for x in c) # allocate coords
226
+
227
+ # Image
228
+ img9[y1:y2, x1:x2] = img[y1 - padh:, x1 - padw:] # img9[ymin:ymax, xmin:xmax]
229
+ hp, wp = h, w # height, width previous for next iteration
230
+
231
+ # Labels assuming imgsz*2 mosaic size
232
+ labels_patch = self._update_labels(labels_patch, padw + self.border[0], padh + self.border[1])
233
+ mosaic_labels.append(labels_patch)
234
+ final_labels = self._cat_labels(mosaic_labels)
235
+
236
+ final_labels['img'] = img9[-self.border[0]:self.border[0], -self.border[1]:self.border[1]]
237
+ return final_labels
238
+
239
+ @staticmethod
240
+ def _update_labels(labels, padw, padh):
241
+ """Update labels."""
242
+ nh, nw = labels['img'].shape[:2]
243
+ labels['instances'].convert_bbox(format='xyxy')
244
+ labels['instances'].denormalize(nw, nh)
245
+ labels['instances'].add_padding(padw, padh)
246
+ return labels
247
+
248
+ def _cat_labels(self, mosaic_labels):
249
+ """Return labels with mosaic border instances clipped."""
250
+ if len(mosaic_labels) == 0:
251
+ return {}
252
+ cls = []
253
+ instances = []
254
+ imgsz = self.imgsz * 2 # mosaic imgsz
255
+ for labels in mosaic_labels:
256
+ cls.append(labels['cls'])
257
+ instances.append(labels['instances'])
258
+ final_labels = {
259
+ 'im_file': mosaic_labels[0]['im_file'],
260
+ 'ori_shape': mosaic_labels[0]['ori_shape'],
261
+ 'resized_shape': (imgsz, imgsz),
262
+ 'cls': np.concatenate(cls, 0),
263
+ 'instances': Instances.concatenate(instances, axis=0),
264
+ 'mosaic_border': self.border} # final_labels
265
+ final_labels['instances'].clip(imgsz, imgsz)
266
+ good = final_labels['instances'].remove_zero_area_boxes()
267
+ final_labels['cls'] = final_labels['cls'][good]
268
+ return final_labels
269
+
270
+
271
+ class MixUp(BaseMixTransform):
272
+
273
+ def __init__(self, dataset, pre_transform=None, p=0.0) -> None:
274
+ super().__init__(dataset=dataset, pre_transform=pre_transform, p=p)
275
+
276
+ def get_indexes(self):
277
+ """Get a random index from the dataset."""
278
+ return random.randint(0, len(self.dataset) - 1)
279
+
280
+ def _mix_transform(self, labels):
281
+ """Applies MixUp augmentation https://arxiv.org/pdf/1710.09412.pdf."""
282
+ r = np.random.beta(32.0, 32.0) # mixup ratio, alpha=beta=32.0
283
+ labels2 = labels['mix_labels'][0]
284
+ labels['img'] = (labels['img'] * r + labels2['img'] * (1 - r)).astype(np.uint8)
285
+ labels['instances'] = Instances.concatenate([labels['instances'], labels2['instances']], axis=0)
286
+ labels['cls'] = np.concatenate([labels['cls'], labels2['cls']], 0)
287
+ return labels
288
+
289
+
290
+ class RandomPerspective:
291
+
292
+ def __init__(self,
293
+ degrees=0.0,
294
+ translate=0.1,
295
+ scale=0.5,
296
+ shear=0.0,
297
+ perspective=0.0,
298
+ border=(0, 0),
299
+ pre_transform=None):
300
+ self.degrees = degrees
301
+ self.translate = translate
302
+ self.scale = scale
303
+ self.shear = shear
304
+ self.perspective = perspective
305
+ # Mosaic border
306
+ self.border = border
307
+ self.pre_transform = pre_transform
308
+
309
+ def affine_transform(self, img, border):
310
+ """Center."""
311
+ C = np.eye(3, dtype=np.float32)
312
+
313
+ C[0, 2] = -img.shape[1] / 2 # x translation (pixels)
314
+ C[1, 2] = -img.shape[0] / 2 # y translation (pixels)
315
+
316
+ # Perspective
317
+ P = np.eye(3, dtype=np.float32)
318
+ P[2, 0] = random.uniform(-self.perspective, self.perspective) # x perspective (about y)
319
+ P[2, 1] = random.uniform(-self.perspective, self.perspective) # y perspective (about x)
320
+
321
+ # Rotation and Scale
322
+ R = np.eye(3, dtype=np.float32)
323
+ a = random.uniform(-self.degrees, self.degrees)
324
+ # a += random.choice([-180, -90, 0, 90]) # add 90deg rotations to small rotations
325
+ s = random.uniform(1 - self.scale, 1 + self.scale)
326
+ # s = 2 ** random.uniform(-scale, scale)
327
+ R[:2] = cv2.getRotationMatrix2D(angle=a, center=(0, 0), scale=s)
328
+
329
+ # Shear
330
+ S = np.eye(3, dtype=np.float32)
331
+ S[0, 1] = math.tan(random.uniform(-self.shear, self.shear) * math.pi / 180) # x shear (deg)
332
+ S[1, 0] = math.tan(random.uniform(-self.shear, self.shear) * math.pi / 180) # y shear (deg)
333
+
334
+ # Translation
335
+ T = np.eye(3, dtype=np.float32)
336
+ T[0, 2] = random.uniform(0.5 - self.translate, 0.5 + self.translate) * self.size[0] # x translation (pixels)
337
+ T[1, 2] = random.uniform(0.5 - self.translate, 0.5 + self.translate) * self.size[1] # y translation (pixels)
338
+
339
+ # Combined rotation matrix
340
+ M = T @ S @ R @ P @ C # order of operations (right to left) is IMPORTANT
341
+ # Affine image
342
+ if (border[0] != 0) or (border[1] != 0) or (M != np.eye(3)).any(): # image changed
343
+ if self.perspective:
344
+ img = cv2.warpPerspective(img, M, dsize=self.size, borderValue=(114, 114, 114))
345
+ else: # affine
346
+ img = cv2.warpAffine(img, M[:2], dsize=self.size, borderValue=(114, 114, 114))
347
+ return img, M, s
348
+
349
+ def apply_bboxes(self, bboxes, M):
350
+ """
351
+ Apply affine to bboxes only.
352
+
353
+ Args:
354
+ bboxes (ndarray): list of bboxes, xyxy format, with shape (num_bboxes, 4).
355
+ M (ndarray): affine matrix.
356
+
357
+ Returns:
358
+ new_bboxes (ndarray): bboxes after affine, [num_bboxes, 4].
359
+ """
360
+ n = len(bboxes)
361
+ if n == 0:
362
+ return bboxes
363
+
364
+ xy = np.ones((n * 4, 3), dtype=bboxes.dtype)
365
+ xy[:, :2] = bboxes[:, [0, 1, 2, 3, 0, 3, 2, 1]].reshape(n * 4, 2) # x1y1, x2y2, x1y2, x2y1
366
+ xy = xy @ M.T # transform
367
+ xy = (xy[:, :2] / xy[:, 2:3] if self.perspective else xy[:, :2]).reshape(n, 8) # perspective rescale or affine
368
+
369
+ # Create new boxes
370
+ x = xy[:, [0, 2, 4, 6]]
371
+ y = xy[:, [1, 3, 5, 7]]
372
+ return np.concatenate((x.min(1), y.min(1), x.max(1), y.max(1)), dtype=bboxes.dtype).reshape(4, n).T
373
+
374
+ def apply_segments(self, segments, M):
375
+ """
376
+ Apply affine to segments and generate new bboxes from segments.
377
+
378
+ Args:
379
+ segments (ndarray): list of segments, [num_samples, 500, 2].
380
+ M (ndarray): affine matrix.
381
+
382
+ Returns:
383
+ new_segments (ndarray): list of segments after affine, [num_samples, 500, 2].
384
+ new_bboxes (ndarray): bboxes after affine, [N, 4].
385
+ """
386
+ n, num = segments.shape[:2]
387
+ if n == 0:
388
+ return [], segments
389
+
390
+ xy = np.ones((n * num, 3), dtype=segments.dtype)
391
+ segments = segments.reshape(-1, 2)
392
+ xy[:, :2] = segments
393
+ xy = xy @ M.T # transform
394
+ xy = xy[:, :2] / xy[:, 2:3]
395
+ segments = xy.reshape(n, -1, 2)
396
+ bboxes = np.stack([segment2box(xy, self.size[0], self.size[1]) for xy in segments], 0)
397
+ return bboxes, segments
398
+
399
+ def apply_keypoints(self, keypoints, M):
400
+ """
401
+ Apply affine to keypoints.
402
+
403
+ Args:
404
+ keypoints (ndarray): keypoints, [N, 17, 3].
405
+ M (ndarray): affine matrix.
406
+
407
+ Return:
408
+ new_keypoints (ndarray): keypoints after affine, [N, 17, 3].
409
+ """
410
+ n, nkpt = keypoints.shape[:2]
411
+ if n == 0:
412
+ return keypoints
413
+ xy = np.ones((n * nkpt, 3), dtype=keypoints.dtype)
414
+ visible = keypoints[..., 2].reshape(n * nkpt, 1)
415
+ xy[:, :2] = keypoints[..., :2].reshape(n * nkpt, 2)
416
+ xy = xy @ M.T # transform
417
+ xy = xy[:, :2] / xy[:, 2:3] # perspective rescale or affine
418
+ out_mask = (xy[:, 0] < 0) | (xy[:, 1] < 0) | (xy[:, 0] > self.size[0]) | (xy[:, 1] > self.size[1])
419
+ visible[out_mask] = 0
420
+ return np.concatenate([xy, visible], axis=-1).reshape(n, nkpt, 3)
421
+
422
+ def __call__(self, labels):
423
+ """
424
+ Affine images and targets.
425
+
426
+ Args:
427
+ labels (dict): a dict of `bboxes`, `segments`, `keypoints`.
428
+ """
429
+ if self.pre_transform and 'mosaic_border' not in labels:
430
+ labels = self.pre_transform(labels)
431
+ labels.pop('ratio_pad', None) # do not need ratio pad
432
+
433
+ img = labels['img']
434
+ cls = labels['cls']
435
+ instances = labels.pop('instances')
436
+ # Make sure the coord formats are right
437
+ instances.convert_bbox(format='xyxy')
438
+ instances.denormalize(*img.shape[:2][::-1])
439
+
440
+ border = labels.pop('mosaic_border', self.border)
441
+ self.size = img.shape[1] + border[1] * 2, img.shape[0] + border[0] * 2 # w, h
442
+ # M is affine matrix
443
+ # scale for func:`box_candidates`
444
+ img, M, scale = self.affine_transform(img, border)
445
+
446
+ bboxes = self.apply_bboxes(instances.bboxes, M)
447
+
448
+ segments = instances.segments
449
+ keypoints = instances.keypoints
450
+ # Update bboxes if there are segments.
451
+ if len(segments):
452
+ bboxes, segments = self.apply_segments(segments, M)
453
+
454
+ if keypoints is not None:
455
+ keypoints = self.apply_keypoints(keypoints, M)
456
+ new_instances = Instances(bboxes, segments, keypoints, bbox_format='xyxy', normalized=False)
457
+ # Clip
458
+ new_instances.clip(*self.size)
459
+
460
+ # Filter instances
461
+ instances.scale(scale_w=scale, scale_h=scale, bbox_only=True)
462
+ # Make the bboxes have the same scale with new_bboxes
463
+ i = self.box_candidates(box1=instances.bboxes.T,
464
+ box2=new_instances.bboxes.T,
465
+ area_thr=0.01 if len(segments) else 0.10)
466
+ labels['instances'] = new_instances[i]
467
+ labels['cls'] = cls[i]
468
+ labels['img'] = img
469
+ labels['resized_shape'] = img.shape[:2]
470
+ return labels
471
+
472
+ def box_candidates(self, box1, box2, wh_thr=2, ar_thr=100, area_thr=0.1, eps=1e-16): # box1(4,n), box2(4,n)
473
+ # Compute box candidates: box1 before augment, box2 after augment, wh_thr (pixels), aspect_ratio_thr, area_ratio
474
+ w1, h1 = box1[2] - box1[0], box1[3] - box1[1]
475
+ w2, h2 = box2[2] - box2[0], box2[3] - box2[1]
476
+ ar = np.maximum(w2 / (h2 + eps), h2 / (w2 + eps)) # aspect ratio
477
+ return (w2 > wh_thr) & (h2 > wh_thr) & (w2 * h2 / (w1 * h1 + eps) > area_thr) & (ar < ar_thr) # candidates
478
+
479
+
480
+ class RandomHSV:
481
+
482
+ def __init__(self, hgain=0.5, sgain=0.5, vgain=0.5) -> None:
483
+ self.hgain = hgain
484
+ self.sgain = sgain
485
+ self.vgain = vgain
486
+
487
+ def __call__(self, labels):
488
+ """Applies random horizontal or vertical flip to an image with a given probability."""
489
+ img = labels['img']
490
+ if self.hgain or self.sgain or self.vgain:
491
+ r = np.random.uniform(-1, 1, 3) * [self.hgain, self.sgain, self.vgain] + 1 # random gains
492
+ hue, sat, val = cv2.split(cv2.cvtColor(img, cv2.COLOR_BGR2HSV))
493
+ dtype = img.dtype # uint8
494
+
495
+ x = np.arange(0, 256, dtype=r.dtype)
496
+ lut_hue = ((x * r[0]) % 180).astype(dtype)
497
+ lut_sat = np.clip(x * r[1], 0, 255).astype(dtype)
498
+ lut_val = np.clip(x * r[2], 0, 255).astype(dtype)
499
+
500
+ im_hsv = cv2.merge((cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat), cv2.LUT(val, lut_val)))
501
+ cv2.cvtColor(im_hsv, cv2.COLOR_HSV2BGR, dst=img) # no return needed
502
+ return labels
503
+
504
+
505
+ class RandomFlip:
506
+
507
+ def __init__(self, p=0.5, direction='horizontal', flip_idx=None) -> None:
508
+ assert direction in ['horizontal', 'vertical'], f'Support direction `horizontal` or `vertical`, got {direction}'
509
+ assert 0 <= p <= 1.0
510
+
511
+ self.p = p
512
+ self.direction = direction
513
+ self.flip_idx = flip_idx
514
+
515
+ def __call__(self, labels):
516
+ """Resize image and padding for detection, instance segmentation, pose."""
517
+ img = labels['img']
518
+ instances = labels.pop('instances')
519
+ instances.convert_bbox(format='xywh')
520
+ h, w = img.shape[:2]
521
+ h = 1 if instances.normalized else h
522
+ w = 1 if instances.normalized else w
523
+
524
+ # Flip up-down
525
+ if self.direction == 'vertical' and random.random() < self.p:
526
+ img = np.flipud(img)
527
+ instances.flipud(h)
528
+ if self.direction == 'horizontal' and random.random() < self.p:
529
+ img = np.fliplr(img)
530
+ instances.fliplr(w)
531
+ # For keypoints
532
+ if self.flip_idx is not None and instances.keypoints is not None:
533
+ instances.keypoints = np.ascontiguousarray(instances.keypoints[:, self.flip_idx, :])
534
+ labels['img'] = np.ascontiguousarray(img)
535
+ labels['instances'] = instances
536
+ return labels
537
+
538
+
539
+ class LetterBox:
540
+ """Resize image and padding for detection, instance segmentation, pose."""
541
+
542
+ def __init__(self, new_shape=(640, 640), auto=False, scaleFill=False, scaleup=True, center=True, stride=32):
543
+ """Initialize LetterBox object with specific parameters."""
544
+ self.new_shape = new_shape
545
+ self.auto = auto
546
+ self.scaleFill = scaleFill
547
+ self.scaleup = scaleup
548
+ self.stride = stride
549
+ self.center = center # Put the image in the middle or top-left
550
+
551
+ def __call__(self, labels=None, image=None):
552
+ """Return updated labels and image with added border."""
553
+ if labels is None:
554
+ labels = {}
555
+ img = labels.get('img') if image is None else image
556
+ shape = img.shape[:2] # current shape [height, width]
557
+ new_shape = labels.pop('rect_shape', self.new_shape)
558
+ if isinstance(new_shape, int):
559
+ new_shape = (new_shape, new_shape)
560
+
561
+ # Scale ratio (new / old)
562
+ r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
563
+ if not self.scaleup: # only scale down, do not scale up (for better val mAP)
564
+ r = min(r, 1.0)
565
+
566
+ # Compute padding
567
+ ratio = r, r # width, height ratios
568
+ new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
569
+ dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] # wh padding
570
+ if self.auto: # minimum rectangle
571
+ dw, dh = np.mod(dw, self.stride), np.mod(dh, self.stride) # wh padding
572
+ elif self.scaleFill: # stretch
573
+ dw, dh = 0.0, 0.0
574
+ new_unpad = (new_shape[1], new_shape[0])
575
+ ratio = new_shape[1] / shape[1], new_shape[0] / shape[0] # width, height ratios
576
+
577
+ if self.center:
578
+ dw /= 2 # divide padding into 2 sides
579
+ dh /= 2
580
+ if labels.get('ratio_pad'):
581
+ labels['ratio_pad'] = (labels['ratio_pad'], (dw, dh)) # for evaluation
582
+
583
+ if shape[::-1] != new_unpad: # resize
584
+ img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
585
+ top, bottom = int(round(dh - 0.1)) if self.center else 0, int(round(dh + 0.1))
586
+ left, right = int(round(dw - 0.1)) if self.center else 0, int(round(dw + 0.1))
587
+ img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT,
588
+ value=(114, 114, 114)) # add border
589
+
590
+ if len(labels):
591
+ labels = self._update_labels(labels, ratio, dw, dh)
592
+ labels['img'] = img
593
+ labels['resized_shape'] = new_shape
594
+ return labels
595
+ else:
596
+ return img
597
+
598
+ def _update_labels(self, labels, ratio, padw, padh):
599
+ """Update labels."""
600
+ labels['instances'].convert_bbox(format='xyxy')
601
+ labels['instances'].denormalize(*labels['img'].shape[:2][::-1])
602
+ labels['instances'].scale(*ratio)
603
+ labels['instances'].add_padding(padw, padh)
604
+ return labels
605
+
606
+
607
+ class CopyPaste:
608
+
609
+ def __init__(self, p=0.5) -> None:
610
+ self.p = p
611
+
612
+ def __call__(self, labels):
613
+ """Implement Copy-Paste augmentation https://arxiv.org/abs/2012.07177, labels as nx5 np.array(cls, xyxy)."""
614
+ im = labels['img']
615
+ cls = labels['cls']
616
+ h, w = im.shape[:2]
617
+ instances = labels.pop('instances')
618
+ instances.convert_bbox(format='xyxy')
619
+ instances.denormalize(w, h)
620
+ if self.p and len(instances.segments):
621
+ n = len(instances)
622
+ _, w, _ = im.shape # height, width, channels
623
+ im_new = np.zeros(im.shape, np.uint8)
624
+
625
+ # Calculate ioa first then select indexes randomly
626
+ ins_flip = deepcopy(instances)
627
+ ins_flip.fliplr(w)
628
+
629
+ ioa = bbox_ioa(ins_flip.bboxes, instances.bboxes) # intersection over area, (N, M)
630
+ indexes = np.nonzero((ioa < 0.30).all(1))[0] # (N, )
631
+ n = len(indexes)
632
+ for j in random.sample(list(indexes), k=round(self.p * n)):
633
+ cls = np.concatenate((cls, cls[[j]]), axis=0)
634
+ instances = Instances.concatenate((instances, ins_flip[[j]]), axis=0)
635
+ cv2.drawContours(im_new, instances.segments[[j]].astype(np.int32), -1, (1, 1, 1), cv2.FILLED)
636
+
637
+ result = cv2.flip(im, 1) # augment segments (flip left-right)
638
+ i = cv2.flip(im_new, 1).astype(bool)
639
+ im[i] = result[i] # cv2.imwrite('debug.jpg', im) # debug
640
+
641
+ labels['img'] = im
642
+ labels['cls'] = cls
643
+ labels['instances'] = instances
644
+ return labels
645
+
646
+
647
+ class Albumentations:
648
+ """YOLOv8 Albumentations class (optional, only used if package is installed)"""
649
+
650
+ def __init__(self, p=1.0):
651
+ """Initialize the transform object for YOLO bbox formatted params."""
652
+ self.p = p
653
+ self.transform = None
654
+ prefix = colorstr('albumentations: ')
655
+ try:
656
+ import albumentations as A
657
+
658
+ check_version(A.__version__, '1.0.3', hard=True) # version requirement
659
+
660
+ T = [
661
+ A.Blur(p=0.01),
662
+ A.MedianBlur(p=0.01),
663
+ A.ToGray(p=0.01),
664
+ A.CLAHE(p=0.01),
665
+ A.RandomBrightnessContrast(p=0.0),
666
+ A.RandomGamma(p=0.0),
667
+ A.ImageCompression(quality_lower=75, p=0.0)] # transforms
668
+ self.transform = A.Compose(T, bbox_params=A.BboxParams(format='yolo', label_fields=['class_labels']))
669
+
670
+ LOGGER.info(prefix + ', '.join(f'{x}'.replace('always_apply=False, ', '') for x in T if x.p))
671
+ except ImportError: # package not installed, skip
672
+ pass
673
+ except Exception as e:
674
+ LOGGER.info(f'{prefix}{e}')
675
+
676
+ def __call__(self, labels):
677
+ """Generates object detections and returns a dictionary with detection results."""
678
+ im = labels['img']
679
+ cls = labels['cls']
680
+ if len(cls):
681
+ labels['instances'].convert_bbox('xywh')
682
+ labels['instances'].normalize(*im.shape[:2][::-1])
683
+ bboxes = labels['instances'].bboxes
684
+ # TODO: add supports of segments and keypoints
685
+ if self.transform and random.random() < self.p:
686
+ new = self.transform(image=im, bboxes=bboxes, class_labels=cls) # transformed
687
+ if len(new['class_labels']) > 0: # skip update if no bbox in new im
688
+ labels['img'] = new['image']
689
+ labels['cls'] = np.array(new['class_labels'])
690
+ bboxes = np.array(new['bboxes'], dtype=np.float32)
691
+ labels['instances'].update(bboxes=bboxes)
692
+ return labels
693
+
694
+
695
+ # TODO: technically this is not an augmentation, maybe we should put this to another files
696
+ class Format:
697
+
698
+ def __init__(self,
699
+ bbox_format='xywh',
700
+ normalize=True,
701
+ return_mask=False,
702
+ return_keypoint=False,
703
+ mask_ratio=4,
704
+ mask_overlap=True,
705
+ batch_idx=True):
706
+ self.bbox_format = bbox_format
707
+ self.normalize = normalize
708
+ self.return_mask = return_mask # set False when training detection only
709
+ self.return_keypoint = return_keypoint
710
+ self.mask_ratio = mask_ratio
711
+ self.mask_overlap = mask_overlap
712
+ self.batch_idx = batch_idx # keep the batch indexes
713
+
714
+ def __call__(self, labels):
715
+ """Return formatted image, classes, bounding boxes & keypoints to be used by 'collate_fn'."""
716
+ img = labels.pop('img')
717
+ h, w = img.shape[:2]
718
+ cls = labels.pop('cls')
719
+ instances = labels.pop('instances')
720
+ instances.convert_bbox(format=self.bbox_format)
721
+ instances.denormalize(w, h)
722
+ nl = len(instances)
723
+
724
+ if self.return_mask:
725
+ if nl:
726
+ masks, instances, cls = self._format_segments(instances, cls, w, h)
727
+ masks = torch.from_numpy(masks)
728
+ else:
729
+ masks = torch.zeros(1 if self.mask_overlap else nl, img.shape[0] // self.mask_ratio,
730
+ img.shape[1] // self.mask_ratio)
731
+ labels['masks'] = masks
732
+ if self.normalize:
733
+ instances.normalize(w, h)
734
+ labels['img'] = self._format_img(img)
735
+ labels['cls'] = torch.from_numpy(cls) if nl else torch.zeros(nl)
736
+ labels['bboxes'] = torch.from_numpy(instances.bboxes) if nl else torch.zeros((nl, 4))
737
+ if self.return_keypoint:
738
+ labels['keypoints'] = torch.from_numpy(instances.keypoints)
739
+ # Then we can use collate_fn
740
+ if self.batch_idx:
741
+ labels['batch_idx'] = torch.zeros(nl)
742
+ return labels
743
+
744
+ def _format_img(self, img):
745
+ """Format the image for YOLOv5 from Numpy array to PyTorch tensor."""
746
+ if len(img.shape) < 3:
747
+ img = np.expand_dims(img, -1)
748
+ img = np.ascontiguousarray(img.transpose(2, 0, 1)[::-1])
749
+ img = torch.from_numpy(img)
750
+ return img
751
+
752
+ def _format_segments(self, instances, cls, w, h):
753
+ """convert polygon points to bitmap."""
754
+ segments = instances.segments
755
+ if self.mask_overlap:
756
+ masks, sorted_idx = polygons2masks_overlap((h, w), segments, downsample_ratio=self.mask_ratio)
757
+ masks = masks[None] # (640, 640) -> (1, 640, 640)
758
+ instances = instances[sorted_idx]
759
+ cls = cls[sorted_idx]
760
+ else:
761
+ masks = polygons2masks((h, w), segments, color=1, downsample_ratio=self.mask_ratio)
762
+
763
+ return masks, instances, cls
764
+
765
+
766
+ def v8_transforms(dataset, imgsz, hyp, stretch=False):
767
+ """Convert images to a size suitable for YOLOv8 training."""
768
+ pre_transform = Compose([
769
+ Mosaic(dataset, imgsz=imgsz, p=hyp.mosaic),
770
+ CopyPaste(p=hyp.copy_paste),
771
+ RandomPerspective(
772
+ degrees=hyp.degrees,
773
+ translate=hyp.translate,
774
+ scale=hyp.scale,
775
+ shear=hyp.shear,
776
+ perspective=hyp.perspective,
777
+ pre_transform=None if stretch else LetterBox(new_shape=(imgsz, imgsz)),
778
+ )])
779
+ flip_idx = dataset.data.get('flip_idx', []) # for keypoints augmentation
780
+ if dataset.use_keypoints:
781
+ kpt_shape = dataset.data.get('kpt_shape', None)
782
+ if len(flip_idx) == 0 and hyp.fliplr > 0.0:
783
+ hyp.fliplr = 0.0
784
+ LOGGER.warning("WARNING ⚠️ No 'flip_idx' array defined in data.yaml, setting augmentation 'fliplr=0.0'")
785
+ elif flip_idx and (len(flip_idx) != kpt_shape[0]):
786
+ raise ValueError(f'data.yaml flip_idx={flip_idx} length must be equal to kpt_shape[0]={kpt_shape[0]}')
787
+
788
+ return Compose([
789
+ pre_transform,
790
+ MixUp(dataset, pre_transform=pre_transform, p=hyp.mixup),
791
+ Albumentations(p=1.0),
792
+ RandomHSV(hgain=hyp.hsv_h, sgain=hyp.hsv_s, vgain=hyp.hsv_v),
793
+ RandomFlip(direction='vertical', p=hyp.flipud),
794
+ RandomFlip(direction='horizontal', p=hyp.fliplr, flip_idx=flip_idx)]) # transforms
795
+
796
+
797
+ # Classification augmentations -----------------------------------------------------------------------------------------
798
+ def classify_transforms(size=224, mean=(0.0, 0.0, 0.0), std=(1.0, 1.0, 1.0)): # IMAGENET_MEAN, IMAGENET_STD
799
+ # Transforms to apply if albumentations not installed
800
+ if not isinstance(size, int):
801
+ raise TypeError(f'classify_transforms() size {size} must be integer, not (list, tuple)')
802
+ if any(mean) or any(std):
803
+ return T.Compose([CenterCrop(size), ToTensor(), T.Normalize(mean, std, inplace=True)])
804
+ else:
805
+ return T.Compose([CenterCrop(size), ToTensor()])
806
+
807
+
808
+ def hsv2colorjitter(h, s, v):
809
+ """Map HSV (hue, saturation, value) jitter into ColorJitter values (brightness, contrast, saturation, hue)"""
810
+ return v, v, s, h
811
+
812
+
813
+ def classify_albumentations(
814
+ augment=True,
815
+ size=224,
816
+ scale=(0.08, 1.0),
817
+ hflip=0.5,
818
+ vflip=0.0,
819
+ hsv_h=0.015, # image HSV-Hue augmentation (fraction)
820
+ hsv_s=0.7, # image HSV-Saturation augmentation (fraction)
821
+ hsv_v=0.4, # image HSV-Value augmentation (fraction)
822
+ mean=(0.0, 0.0, 0.0), # IMAGENET_MEAN
823
+ std=(1.0, 1.0, 1.0), # IMAGENET_STD
824
+ auto_aug=False,
825
+ ):
826
+ """YOLOv8 classification Albumentations (optional, only used if package is installed)."""
827
+ prefix = colorstr('albumentations: ')
828
+ try:
829
+ import albumentations as A
830
+ from albumentations.pytorch import ToTensorV2
831
+
832
+ check_version(A.__version__, '1.0.3', hard=True) # version requirement
833
+ if augment: # Resize and crop
834
+ T = [A.RandomResizedCrop(height=size, width=size, scale=scale)]
835
+ if auto_aug:
836
+ # TODO: implement AugMix, AutoAug & RandAug in albumentations
837
+ LOGGER.info(f'{prefix}auto augmentations are currently not supported')
838
+ else:
839
+ if hflip > 0:
840
+ T += [A.HorizontalFlip(p=hflip)]
841
+ if vflip > 0:
842
+ T += [A.VerticalFlip(p=vflip)]
843
+ if any((hsv_h, hsv_s, hsv_v)):
844
+ T += [A.ColorJitter(*hsv2colorjitter(hsv_h, hsv_s, hsv_v))] # brightness, contrast, saturation, hue
845
+ else: # Use fixed crop for eval set (reproducibility)
846
+ T = [A.SmallestMaxSize(max_size=size), A.CenterCrop(height=size, width=size)]
847
+ T += [A.Normalize(mean=mean, std=std), ToTensorV2()] # Normalize and convert to Tensor
848
+ LOGGER.info(prefix + ', '.join(f'{x}'.replace('always_apply=False, ', '') for x in T if x.p))
849
+ return A.Compose(T)
850
+
851
+ except ImportError: # package not installed, skip
852
+ pass
853
+ except Exception as e:
854
+ LOGGER.info(f'{prefix}{e}')
855
+
856
+
857
+ class ClassifyLetterBox:
858
+ """YOLOv8 LetterBox class for image preprocessing, i.e. T.Compose([LetterBox(size), ToTensor()])"""
859
+
860
+ def __init__(self, size=(640, 640), auto=False, stride=32):
861
+ """Resizes image and crops it to center with max dimensions 'h' and 'w'."""
862
+ super().__init__()
863
+ self.h, self.w = (size, size) if isinstance(size, int) else size
864
+ self.auto = auto # pass max size integer, automatically solve for short side using stride
865
+ self.stride = stride # used with auto
866
+
867
+ def __call__(self, im): # im = np.array HWC
868
+ imh, imw = im.shape[:2]
869
+ r = min(self.h / imh, self.w / imw) # ratio of new/old
870
+ h, w = round(imh * r), round(imw * r) # resized image
871
+ hs, ws = (math.ceil(x / self.stride) * self.stride for x in (h, w)) if self.auto else self.h, self.w
872
+ top, left = round((hs - h) / 2 - 0.1), round((ws - w) / 2 - 0.1)
873
+ im_out = np.full((self.h, self.w, 3), 114, dtype=im.dtype)
874
+ im_out[top:top + h, left:left + w] = cv2.resize(im, (w, h), interpolation=cv2.INTER_LINEAR)
875
+ return im_out
876
+
877
+
878
+ class CenterCrop:
879
+ """YOLOv8 CenterCrop class for image preprocessing, i.e. T.Compose([CenterCrop(size), ToTensor()])"""
880
+
881
+ def __init__(self, size=640):
882
+ """Converts an image from numpy array to PyTorch tensor."""
883
+ super().__init__()
884
+ self.h, self.w = (size, size) if isinstance(size, int) else size
885
+
886
+ def __call__(self, im): # im = np.array HWC
887
+ imh, imw = im.shape[:2]
888
+ m = min(imh, imw) # min dimension
889
+ top, left = (imh - m) // 2, (imw - m) // 2
890
+ return cv2.resize(im[top:top + m, left:left + m], (self.w, self.h), interpolation=cv2.INTER_LINEAR)
891
+
892
+
893
+ class ToTensor:
894
+ """YOLOv8 ToTensor class for image preprocessing, i.e. T.Compose([LetterBox(size), ToTensor()])."""
895
+
896
+ def __init__(self, half=False):
897
+ """Initialize YOLOv8 ToTensor object with optional half-precision support."""
898
+ super().__init__()
899
+ self.half = half
900
+
901
+ def __call__(self, im): # im = np.array HWC in BGR order
902
+ im = np.ascontiguousarray(im.transpose((2, 0, 1))[::-1]) # HWC to CHW -> BGR to RGB -> contiguous
903
+ im = torch.from_numpy(im) # to torch
904
+ im = im.half() if self.half else im.float() # uint8 to fp16/32
905
+ im /= 255.0 # 0-255 to 0.0-1.0
906
+ return im
ultralytics/data/base.py ADDED
@@ -0,0 +1,287 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+
3
+ import glob
4
+ import math
5
+ import os
6
+ import random
7
+ from copy import deepcopy
8
+ from multiprocessing.pool import ThreadPool
9
+ from pathlib import Path
10
+ from typing import Optional
11
+
12
+ import cv2
13
+ import numpy as np
14
+ import psutil
15
+ from torch.utils.data import Dataset
16
+ from tqdm import tqdm
17
+
18
+ from ultralytics.utils import DEFAULT_CFG, LOCAL_RANK, LOGGER, NUM_THREADS, TQDM_BAR_FORMAT
19
+
20
+ from .utils import HELP_URL, IMG_FORMATS
21
+
22
+
23
+ class BaseDataset(Dataset):
24
+ """
25
+ Base dataset class for loading and processing image data.
26
+
27
+ Args:
28
+ img_path (str): Path to the folder containing images.
29
+ imgsz (int, optional): Image size. Defaults to 640.
30
+ cache (bool, optional): Cache images to RAM or disk during training. Defaults to False.
31
+ augment (bool, optional): If True, data augmentation is applied. Defaults to True.
32
+ hyp (dict, optional): Hyperparameters to apply data augmentation. Defaults to None.
33
+ prefix (str, optional): Prefix to print in log messages. Defaults to ''.
34
+ rect (bool, optional): If True, rectangular training is used. Defaults to False.
35
+ batch_size (int, optional): Size of batches. Defaults to None.
36
+ stride (int, optional): Stride. Defaults to 32.
37
+ pad (float, optional): Padding. Defaults to 0.0.
38
+ single_cls (bool, optional): If True, single class training is used. Defaults to False.
39
+ classes (list): List of included classes. Default is None.
40
+ fraction (float): Fraction of dataset to utilize. Default is 1.0 (use all data).
41
+
42
+ Attributes:
43
+ im_files (list): List of image file paths.
44
+ labels (list): List of label data dictionaries.
45
+ ni (int): Number of images in the dataset.
46
+ ims (list): List of loaded images.
47
+ npy_files (list): List of numpy file paths.
48
+ transforms (callable): Image transformation function.
49
+ """
50
+
51
+ def __init__(self,
52
+ img_path,
53
+ imgsz=640,
54
+ cache=False,
55
+ augment=True,
56
+ hyp=DEFAULT_CFG,
57
+ prefix='',
58
+ rect=False,
59
+ batch_size=16,
60
+ stride=32,
61
+ pad=0.5,
62
+ single_cls=False,
63
+ classes=None,
64
+ fraction=1.0):
65
+ super().__init__()
66
+ self.img_path = img_path
67
+ self.imgsz = imgsz
68
+ self.augment = augment
69
+ self.single_cls = single_cls
70
+ self.prefix = prefix
71
+ self.fraction = fraction
72
+ self.im_files = self.get_img_files(self.img_path)
73
+ self.labels = self.get_labels()
74
+ self.update_labels(include_class=classes) # single_cls and include_class
75
+ self.ni = len(self.labels) # number of images
76
+ self.rect = rect
77
+ self.batch_size = batch_size
78
+ self.stride = stride
79
+ self.pad = pad
80
+ if self.rect:
81
+ assert self.batch_size is not None
82
+ self.set_rectangle()
83
+
84
+ # Buffer thread for mosaic images
85
+ self.buffer = [] # buffer size = batch size
86
+ self.max_buffer_length = min((self.ni, self.batch_size * 8, 1000)) if self.augment else 0
87
+
88
+ # Cache stuff
89
+ if cache == 'ram' and not self.check_cache_ram():
90
+ cache = False
91
+ self.ims, self.im_hw0, self.im_hw = [None] * self.ni, [None] * self.ni, [None] * self.ni
92
+ self.npy_files = [Path(f).with_suffix('.npy') for f in self.im_files]
93
+ if cache:
94
+ self.cache_images(cache)
95
+
96
+ # Transforms
97
+ self.transforms = self.build_transforms(hyp=hyp)
98
+
99
+ def get_img_files(self, img_path):
100
+ """Read image files."""
101
+ try:
102
+ f = [] # image files
103
+ for p in img_path if isinstance(img_path, list) else [img_path]:
104
+ p = Path(p) # os-agnostic
105
+ if p.is_dir(): # dir
106
+ f += glob.glob(str(p / '**' / '*.*'), recursive=True)
107
+ # F = list(p.rglob('*.*')) # pathlib
108
+ elif p.is_file(): # file
109
+ with open(p) as t:
110
+ t = t.read().strip().splitlines()
111
+ parent = str(p.parent) + os.sep
112
+ f += [x.replace('./', parent) if x.startswith('./') else x for x in t] # local to global path
113
+ # F += [p.parent / x.lstrip(os.sep) for x in t] # local to global path (pathlib)
114
+ else:
115
+ raise FileNotFoundError(f'{self.prefix}{p} does not exist')
116
+ im_files = sorted(x.replace('/', os.sep) for x in f if x.split('.')[-1].lower() in IMG_FORMATS)
117
+ # self.img_files = sorted([x for x in f if x.suffix[1:].lower() in IMG_FORMATS]) # pathlib
118
+ assert im_files, f'{self.prefix}No images found'
119
+ except Exception as e:
120
+ raise FileNotFoundError(f'{self.prefix}Error loading data from {img_path}\n{HELP_URL}') from e
121
+ if self.fraction < 1:
122
+ im_files = im_files[:round(len(im_files) * self.fraction)]
123
+ return im_files
124
+
125
+ def update_labels(self, include_class: Optional[list]):
126
+ """include_class, filter labels to include only these classes (optional)."""
127
+ include_class_array = np.array(include_class).reshape(1, -1)
128
+ for i in range(len(self.labels)):
129
+ if include_class is not None:
130
+ cls = self.labels[i]['cls']
131
+ bboxes = self.labels[i]['bboxes']
132
+ segments = self.labels[i]['segments']
133
+ keypoints = self.labels[i]['keypoints']
134
+ j = (cls == include_class_array).any(1)
135
+ self.labels[i]['cls'] = cls[j]
136
+ self.labels[i]['bboxes'] = bboxes[j]
137
+ if segments:
138
+ self.labels[i]['segments'] = [segments[si] for si, idx in enumerate(j) if idx]
139
+ if keypoints is not None:
140
+ self.labels[i]['keypoints'] = keypoints[j]
141
+ if self.single_cls:
142
+ self.labels[i]['cls'][:, 0] = 0
143
+
144
+ def load_image(self, i):
145
+ """Loads 1 image from dataset index 'i', returns (im, resized hw)."""
146
+ im, f, fn = self.ims[i], self.im_files[i], self.npy_files[i]
147
+ if im is None: # not cached in RAM
148
+ if fn.exists(): # load npy
149
+ im = np.load(fn)
150
+ else: # read image
151
+ im = cv2.imread(f) # BGR
152
+ if im is None:
153
+ raise FileNotFoundError(f'Image Not Found {f}')
154
+ h0, w0 = im.shape[:2] # orig hw
155
+ r = self.imgsz / max(h0, w0) # ratio
156
+ if r != 1: # if sizes are not equal
157
+ interp = cv2.INTER_LINEAR if (self.augment or r > 1) else cv2.INTER_AREA
158
+ im = cv2.resize(im, (min(math.ceil(w0 * r), self.imgsz), min(math.ceil(h0 * r), self.imgsz)),
159
+ interpolation=interp)
160
+
161
+ # Add to buffer if training with augmentations
162
+ if self.augment:
163
+ self.ims[i], self.im_hw0[i], self.im_hw[i] = im, (h0, w0), im.shape[:2] # im, hw_original, hw_resized
164
+ self.buffer.append(i)
165
+ if len(self.buffer) >= self.max_buffer_length:
166
+ j = self.buffer.pop(0)
167
+ self.ims[j], self.im_hw0[j], self.im_hw[j] = None, None, None
168
+
169
+ return im, (h0, w0), im.shape[:2]
170
+
171
+ return self.ims[i], self.im_hw0[i], self.im_hw[i]
172
+
173
+ def cache_images(self, cache):
174
+ """Cache images to memory or disk."""
175
+ b, gb = 0, 1 << 30 # bytes of cached images, bytes per gigabytes
176
+ fcn = self.cache_images_to_disk if cache == 'disk' else self.load_image
177
+ with ThreadPool(NUM_THREADS) as pool:
178
+ results = pool.imap(fcn, range(self.ni))
179
+ pbar = tqdm(enumerate(results), total=self.ni, bar_format=TQDM_BAR_FORMAT, disable=LOCAL_RANK > 0)
180
+ for i, x in pbar:
181
+ if cache == 'disk':
182
+ b += self.npy_files[i].stat().st_size
183
+ else: # 'ram'
184
+ self.ims[i], self.im_hw0[i], self.im_hw[i] = x # im, hw_orig, hw_resized = load_image(self, i)
185
+ b += self.ims[i].nbytes
186
+ pbar.desc = f'{self.prefix}Caching images ({b / gb:.1f}GB {cache})'
187
+ pbar.close()
188
+
189
+ def cache_images_to_disk(self, i):
190
+ """Saves an image as an *.npy file for faster loading."""
191
+ f = self.npy_files[i]
192
+ if not f.exists():
193
+ np.save(f.as_posix(), cv2.imread(self.im_files[i]))
194
+
195
+ def check_cache_ram(self, safety_margin=0.5):
196
+ """Check image caching requirements vs available memory."""
197
+ b, gb = 0, 1 << 30 # bytes of cached images, bytes per gigabytes
198
+ n = min(self.ni, 30) # extrapolate from 30 random images
199
+ for _ in range(n):
200
+ im = cv2.imread(random.choice(self.im_files)) # sample image
201
+ ratio = self.imgsz / max(im.shape[0], im.shape[1]) # max(h, w) # ratio
202
+ b += im.nbytes * ratio ** 2
203
+ mem_required = b * self.ni / n * (1 + safety_margin) # GB required to cache dataset into RAM
204
+ mem = psutil.virtual_memory()
205
+ cache = mem_required < mem.available # to cache or not to cache, that is the question
206
+ if not cache:
207
+ LOGGER.info(f'{self.prefix}{mem_required / gb:.1f}GB RAM required to cache images '
208
+ f'with {int(safety_margin * 100)}% safety margin but only '
209
+ f'{mem.available / gb:.1f}/{mem.total / gb:.1f}GB available, '
210
+ f"{'caching images ✅' if cache else 'not caching images ⚠️'}")
211
+ return cache
212
+
213
+ def set_rectangle(self):
214
+ """Sets the shape of bounding boxes for YOLO detections as rectangles."""
215
+ bi = np.floor(np.arange(self.ni) / self.batch_size).astype(int) # batch index
216
+ nb = bi[-1] + 1 # number of batches
217
+
218
+ s = np.array([x.pop('shape') for x in self.labels]) # hw
219
+ ar = s[:, 0] / s[:, 1] # aspect ratio
220
+ irect = ar.argsort()
221
+ self.im_files = [self.im_files[i] for i in irect]
222
+ self.labels = [self.labels[i] for i in irect]
223
+ ar = ar[irect]
224
+
225
+ # Set training image shapes
226
+ shapes = [[1, 1]] * nb
227
+ for i in range(nb):
228
+ ari = ar[bi == i]
229
+ mini, maxi = ari.min(), ari.max()
230
+ if maxi < 1:
231
+ shapes[i] = [maxi, 1]
232
+ elif mini > 1:
233
+ shapes[i] = [1, 1 / mini]
234
+
235
+ self.batch_shapes = np.ceil(np.array(shapes) * self.imgsz / self.stride + self.pad).astype(int) * self.stride
236
+ self.batch = bi # batch index of image
237
+
238
+ def __getitem__(self, index):
239
+ """Returns transformed label information for given index."""
240
+ return self.transforms(self.get_image_and_label(index))
241
+
242
+ def get_image_and_label(self, index):
243
+ """Get and return label information from the dataset."""
244
+ label = deepcopy(self.labels[index]) # requires deepcopy() https://github.com/ultralytics/ultralytics/pull/1948
245
+ label.pop('shape', None) # shape is for rect, remove it
246
+ label['img'], label['ori_shape'], label['resized_shape'] = self.load_image(index)
247
+ label['ratio_pad'] = (label['resized_shape'][0] / label['ori_shape'][0],
248
+ label['resized_shape'][1] / label['ori_shape'][1]) # for evaluation
249
+ if self.rect:
250
+ label['rect_shape'] = self.batch_shapes[self.batch[index]]
251
+ return self.update_labels_info(label)
252
+
253
+ def __len__(self):
254
+ """Returns the length of the labels list for the dataset."""
255
+ return len(self.labels)
256
+
257
+ def update_labels_info(self, label):
258
+ """custom your label format here."""
259
+ return label
260
+
261
+ def build_transforms(self, hyp=None):
262
+ """Users can custom augmentations here
263
+ like:
264
+ if self.augment:
265
+ # Training transforms
266
+ return Compose([])
267
+ else:
268
+ # Val transforms
269
+ return Compose([])
270
+ """
271
+ raise NotImplementedError
272
+
273
+ def get_labels(self):
274
+ """Users can custom their own format here.
275
+ Make sure your output is a list with each element like below:
276
+ dict(
277
+ im_file=im_file,
278
+ shape=shape, # format: (height, width)
279
+ cls=cls,
280
+ bboxes=bboxes, # xywh
281
+ segments=segments, # xy
282
+ keypoints=keypoints, # xy
283
+ normalized=True, # or False
284
+ bbox_format="xyxy", # or xywh, ltwh
285
+ )
286
+ """
287
+ raise NotImplementedError
ultralytics/data/build.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+
3
+ import os
4
+ import random
5
+ from pathlib import Path
6
+
7
+ import numpy as np
8
+ import torch
9
+ from PIL import Image
10
+ from torch.utils.data import dataloader, distributed
11
+
12
+ from ultralytics.data.loaders import (LOADERS, LoadImages, LoadPilAndNumpy, LoadScreenshots, LoadStreams, LoadTensor,
13
+ SourceTypes, autocast_list)
14
+ from ultralytics.data.utils import IMG_FORMATS, VID_FORMATS
15
+ from ultralytics.utils import RANK, colorstr
16
+ from ultralytics.utils.checks import check_file
17
+
18
+ from .dataset import YOLODataset
19
+ from .utils import PIN_MEMORY
20
+
21
+
22
+ class InfiniteDataLoader(dataloader.DataLoader):
23
+ """Dataloader that reuses workers. Uses same syntax as vanilla DataLoader."""
24
+
25
+ def __init__(self, *args, **kwargs):
26
+ """Dataloader that infinitely recycles workers, inherits from DataLoader."""
27
+ super().__init__(*args, **kwargs)
28
+ object.__setattr__(self, 'batch_sampler', _RepeatSampler(self.batch_sampler))
29
+ self.iterator = super().__iter__()
30
+
31
+ def __len__(self):
32
+ """Returns the length of the batch sampler's sampler."""
33
+ return len(self.batch_sampler.sampler)
34
+
35
+ def __iter__(self):
36
+ """Creates a sampler that repeats indefinitely."""
37
+ for _ in range(len(self)):
38
+ yield next(self.iterator)
39
+
40
+ def reset(self):
41
+ """Reset iterator.
42
+ This is useful when we want to modify settings of dataset while training.
43
+ """
44
+ self.iterator = self._get_iterator()
45
+
46
+
47
+ class _RepeatSampler:
48
+ """
49
+ Sampler that repeats forever.
50
+
51
+ Args:
52
+ sampler (Dataset.sampler): The sampler to repeat.
53
+ """
54
+
55
+ def __init__(self, sampler):
56
+ """Initializes an object that repeats a given sampler indefinitely."""
57
+ self.sampler = sampler
58
+
59
+ def __iter__(self):
60
+ """Iterates over the 'sampler' and yields its contents."""
61
+ while True:
62
+ yield from iter(self.sampler)
63
+
64
+
65
+ def seed_worker(worker_id): # noqa
66
+ """Set dataloader worker seed https://pytorch.org/docs/stable/notes/randomness.html#dataloader."""
67
+ worker_seed = torch.initial_seed() % 2 ** 32
68
+ np.random.seed(worker_seed)
69
+ random.seed(worker_seed)
70
+
71
+
72
+ def build_yolo_dataset(cfg, img_path, batch, data, mode='train', rect=False, stride=32):
73
+ """Build YOLO Dataset"""
74
+ return YOLODataset(
75
+ img_path=img_path,
76
+ imgsz=cfg.imgsz,
77
+ batch_size=batch,
78
+ augment=mode == 'train', # augmentation
79
+ hyp=cfg, # TODO: probably add a get_hyps_from_cfg function
80
+ rect=cfg.rect or rect, # rectangular batches
81
+ cache=cfg.cache or None,
82
+ single_cls=cfg.single_cls or False,
83
+ stride=int(stride),
84
+ pad=0.0 if mode == 'train' else 0.5,
85
+ prefix=colorstr(f'{mode}: '),
86
+ use_segments=cfg.task == 'segment',
87
+ use_keypoints=cfg.task == 'pose',
88
+ classes=cfg.classes,
89
+ data=data,
90
+ fraction=cfg.fraction if mode == 'train' else 1.0)
91
+
92
+
93
+ def build_dataloader(dataset, batch, workers, shuffle=True, rank=-1):
94
+ """Return an InfiniteDataLoader or DataLoader for training or validation set."""
95
+ batch = min(batch, len(dataset))
96
+ nd = torch.cuda.device_count() # number of CUDA devices
97
+ nw = min([os.cpu_count() // max(nd, 1), batch if batch > 1 else 0, workers]) # number of workers
98
+ sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
99
+ generator = torch.Generator()
100
+ generator.manual_seed(6148914691236517205 + RANK)
101
+ return InfiniteDataLoader(dataset=dataset,
102
+ batch_size=batch,
103
+ shuffle=shuffle and sampler is None,
104
+ num_workers=nw,
105
+ sampler=sampler,
106
+ pin_memory=PIN_MEMORY,
107
+ collate_fn=getattr(dataset, 'collate_fn', None),
108
+ worker_init_fn=seed_worker,
109
+ generator=generator)
110
+
111
+
112
+ def check_source(source):
113
+ """Check source type and return corresponding flag values."""
114
+ webcam, screenshot, from_img, in_memory, tensor = False, False, False, False, False
115
+ if isinstance(source, (str, int, Path)): # int for local usb camera
116
+ source = str(source)
117
+ is_file = Path(source).suffix[1:] in (IMG_FORMATS + VID_FORMATS)
118
+ is_url = source.lower().startswith(('https://', 'http://', 'rtsp://', 'rtmp://'))
119
+ webcam = source.isnumeric() or source.endswith('.streams') or (is_url and not is_file)
120
+ screenshot = source.lower() == 'screen'
121
+ if is_url and is_file:
122
+ source = check_file(source) # download
123
+ elif isinstance(source, tuple(LOADERS)):
124
+ in_memory = True
125
+ elif isinstance(source, (list, tuple)):
126
+ source = autocast_list(source) # convert all list elements to PIL or np arrays
127
+ from_img = True
128
+ elif isinstance(source, (Image.Image, np.ndarray)):
129
+ from_img = True
130
+ elif isinstance(source, torch.Tensor):
131
+ tensor = True
132
+ else:
133
+ raise TypeError('Unsupported image type. For supported types see https://docs.ultralytics.com/modes/predict')
134
+
135
+ return source, webcam, screenshot, from_img, in_memory, tensor
136
+
137
+
138
+ def load_inference_source(source=None, imgsz=640, vid_stride=1):
139
+ """
140
+ Loads an inference source for object detection and applies necessary transformations.
141
+
142
+ Args:
143
+ source (str, Path, Tensor, PIL.Image, np.ndarray): The input source for inference.
144
+ imgsz (int, optional): The size of the image for inference. Default is 640.
145
+ vid_stride (int, optional): The frame interval for video sources. Default is 1.
146
+
147
+ Returns:
148
+ dataset (Dataset): A dataset object for the specified input source.
149
+ """
150
+ source, webcam, screenshot, from_img, in_memory, tensor = check_source(source)
151
+ source_type = source.source_type if in_memory else SourceTypes(webcam, screenshot, from_img, tensor)
152
+
153
+ # Dataloader
154
+ if tensor:
155
+ dataset = LoadTensor(source)
156
+ elif in_memory:
157
+ dataset = source
158
+ elif webcam:
159
+ dataset = LoadStreams(source, imgsz=imgsz, vid_stride=vid_stride)
160
+ elif screenshot:
161
+ dataset = LoadScreenshots(source, imgsz=imgsz)
162
+ elif from_img:
163
+ dataset = LoadPilAndNumpy(source, imgsz=imgsz)
164
+ else:
165
+ dataset = LoadImages(source, imgsz=imgsz, vid_stride=vid_stride)
166
+
167
+ # Attach source types to the dataset
168
+ setattr(dataset, 'source_type', source_type)
169
+
170
+ return dataset
ultralytics/data/converter.py ADDED
@@ -0,0 +1,230 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from collections import defaultdict
3
+ from pathlib import Path
4
+
5
+ import cv2
6
+ import numpy as np
7
+ from tqdm import tqdm
8
+
9
+ from ultralytics.utils.checks import check_requirements
10
+ from ultralytics.utils.files import make_dirs
11
+
12
+
13
+ def coco91_to_coco80_class():
14
+ """Converts 91-index COCO class IDs to 80-index COCO class IDs.
15
+
16
+ Returns:
17
+ (list): A list of 91 class IDs where the index represents the 80-index class ID and the value is the
18
+ corresponding 91-index class ID.
19
+
20
+ """
21
+ return [
22
+ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, None, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, None, 24, 25, None,
23
+ None, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, None, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50,
24
+ 51, 52, 53, 54, 55, 56, 57, 58, 59, None, 60, None, None, 61, None, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72,
25
+ None, 73, 74, 75, 76, 77, 78, 79, None]
26
+
27
+
28
+ def convert_coco(labels_dir='../coco/annotations/', use_segments=False, use_keypoints=False, cls91to80=True):
29
+ """Converts COCO dataset annotations to a format suitable for training YOLOv5 models.
30
+
31
+ Args:
32
+ labels_dir (str, optional): Path to directory containing COCO dataset annotation files.
33
+ use_segments (bool, optional): Whether to include segmentation masks in the output.
34
+ use_keypoints (bool, optional): Whether to include keypoint annotations in the output.
35
+ cls91to80 (bool, optional): Whether to map 91 COCO class IDs to the corresponding 80 COCO class IDs.
36
+
37
+ Raises:
38
+ FileNotFoundError: If the labels_dir path does not exist.
39
+
40
+ Example Usage:
41
+ convert_coco(labels_dir='../coco/annotations/', use_segments=True, use_keypoints=True, cls91to80=True)
42
+
43
+ Output:
44
+ Generates output files in the specified output directory.
45
+ """
46
+
47
+ save_dir = make_dirs('yolo_labels') # output directory
48
+ coco80 = coco91_to_coco80_class()
49
+
50
+ # Import json
51
+ for json_file in sorted(Path(labels_dir).resolve().glob('*.json')):
52
+ fn = Path(save_dir) / 'labels' / json_file.stem.replace('instances_', '') # folder name
53
+ fn.mkdir(parents=True, exist_ok=True)
54
+ with open(json_file) as f:
55
+ data = json.load(f)
56
+
57
+ # Create image dict
58
+ images = {f'{x["id"]:d}': x for x in data['images']}
59
+ # Create image-annotations dict
60
+ imgToAnns = defaultdict(list)
61
+ for ann in data['annotations']:
62
+ imgToAnns[ann['image_id']].append(ann)
63
+
64
+ # Write labels file
65
+ for img_id, anns in tqdm(imgToAnns.items(), desc=f'Annotations {json_file}'):
66
+ img = images[f'{img_id:d}']
67
+ h, w, f = img['height'], img['width'], img['file_name']
68
+
69
+ bboxes = []
70
+ segments = []
71
+ keypoints = []
72
+ for ann in anns:
73
+ if ann['iscrowd']:
74
+ continue
75
+ # The COCO box format is [top left x, top left y, width, height]
76
+ box = np.array(ann['bbox'], dtype=np.float64)
77
+ box[:2] += box[2:] / 2 # xy top-left corner to center
78
+ box[[0, 2]] /= w # normalize x
79
+ box[[1, 3]] /= h # normalize y
80
+ if box[2] <= 0 or box[3] <= 0: # if w <= 0 and h <= 0
81
+ continue
82
+
83
+ cls = coco80[ann['category_id'] - 1] if cls91to80 else ann['category_id'] - 1 # class
84
+ box = [cls] + box.tolist()
85
+ if box not in bboxes:
86
+ bboxes.append(box)
87
+ if use_segments and ann.get('segmentation') is not None:
88
+ if len(ann['segmentation']) == 0:
89
+ segments.append([])
90
+ continue
91
+ if isinstance(ann['segmentation'], dict):
92
+ ann['segmentation'] = rle2polygon(ann['segmentation'])
93
+ if len(ann['segmentation']) > 1:
94
+ s = merge_multi_segment(ann['segmentation'])
95
+ s = (np.concatenate(s, axis=0) / np.array([w, h])).reshape(-1).tolist()
96
+ else:
97
+ s = [j for i in ann['segmentation'] for j in i] # all segments concatenated
98
+ s = (np.array(s).reshape(-1, 2) / np.array([w, h])).reshape(-1).tolist()
99
+ s = [cls] + s
100
+ if s not in segments:
101
+ segments.append(s)
102
+ if use_keypoints and ann.get('keypoints') is not None:
103
+ k = (np.array(ann['keypoints']).reshape(-1, 3) / np.array([w, h, 1])).reshape(-1).tolist()
104
+ k = box + k
105
+ keypoints.append(k)
106
+
107
+ # Write
108
+ with open((fn / f).with_suffix('.txt'), 'a') as file:
109
+ for i in range(len(bboxes)):
110
+ if use_keypoints:
111
+ line = *(keypoints[i]), # cls, box, keypoints
112
+ else:
113
+ line = *(segments[i]
114
+ if use_segments and len(segments[i]) > 0 else bboxes[i]), # cls, box or segments
115
+ file.write(('%g ' * len(line)).rstrip() % line + '\n')
116
+
117
+
118
+ def rle2polygon(segmentation):
119
+ """
120
+ Convert Run-Length Encoding (RLE) mask to polygon coordinates.
121
+
122
+ Args:
123
+ segmentation (dict, list): RLE mask representation of the object segmentation.
124
+
125
+ Returns:
126
+ (list): A list of lists representing the polygon coordinates for each contour.
127
+
128
+ Note:
129
+ Requires the 'pycocotools' package to be installed.
130
+ """
131
+ check_requirements('pycocotools')
132
+ from pycocotools import mask
133
+
134
+ m = mask.decode(segmentation)
135
+ m[m > 0] = 255
136
+ contours, _ = cv2.findContours(m, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_TC89_KCOS)
137
+ polygons = []
138
+ for contour in contours:
139
+ epsilon = 0.001 * cv2.arcLength(contour, True)
140
+ contour_approx = cv2.approxPolyDP(contour, epsilon, True)
141
+ polygon = contour_approx.flatten().tolist()
142
+ polygons.append(polygon)
143
+ return polygons
144
+
145
+
146
+ def min_index(arr1, arr2):
147
+ """
148
+ Find a pair of indexes with the shortest distance between two arrays of 2D points.
149
+
150
+ Args:
151
+ arr1 (np.array): A NumPy array of shape (N, 2) representing N 2D points.
152
+ arr2 (np.array): A NumPy array of shape (M, 2) representing M 2D points.
153
+
154
+ Returns:
155
+ (tuple): A tuple containing the indexes of the points with the shortest distance in arr1 and arr2 respectively.
156
+ """
157
+ dis = ((arr1[:, None, :] - arr2[None, :, :]) ** 2).sum(-1)
158
+ return np.unravel_index(np.argmin(dis, axis=None), dis.shape)
159
+
160
+
161
+ def merge_multi_segment(segments):
162
+ """
163
+ Merge multiple segments into one list by connecting the coordinates with the minimum distance between each segment.
164
+ This function connects these coordinates with a thin line to merge all segments into one.
165
+
166
+ Args:
167
+ segments (List[List]): Original segmentations in COCO's JSON file.
168
+ Each element is a list of coordinates, like [segmentation1, segmentation2,...].
169
+
170
+ Returns:
171
+ s (List[np.ndarray]): A list of connected segments represented as NumPy arrays.
172
+ """
173
+ s = []
174
+ segments = [np.array(i).reshape(-1, 2) for i in segments]
175
+ idx_list = [[] for _ in range(len(segments))]
176
+
177
+ # record the indexes with min distance between each segment
178
+ for i in range(1, len(segments)):
179
+ idx1, idx2 = min_index(segments[i - 1], segments[i])
180
+ idx_list[i - 1].append(idx1)
181
+ idx_list[i].append(idx2)
182
+
183
+ # use two round to connect all the segments
184
+ for k in range(2):
185
+ # forward connection
186
+ if k == 0:
187
+ for i, idx in enumerate(idx_list):
188
+ # middle segments have two indexes
189
+ # reverse the index of middle segments
190
+ if len(idx) == 2 and idx[0] > idx[1]:
191
+ idx = idx[::-1]
192
+ segments[i] = segments[i][::-1, :]
193
+
194
+ segments[i] = np.roll(segments[i], -idx[0], axis=0)
195
+ segments[i] = np.concatenate([segments[i], segments[i][:1]])
196
+ # deal with the first segment and the last one
197
+ if i in [0, len(idx_list) - 1]:
198
+ s.append(segments[i])
199
+ else:
200
+ idx = [0, idx[1] - idx[0]]
201
+ s.append(segments[i][idx[0]:idx[1] + 1])
202
+
203
+ else:
204
+ for i in range(len(idx_list) - 1, -1, -1):
205
+ if i not in [0, len(idx_list) - 1]:
206
+ idx = idx_list[i]
207
+ nidx = abs(idx[1] - idx[0])
208
+ s.append(segments[i][nidx:])
209
+ return s
210
+
211
+
212
+ def delete_dsstore(path='../datasets'):
213
+ """Delete Apple .DS_Store files in the specified directory and its subdirectories."""
214
+ from pathlib import Path
215
+
216
+ files = list(Path(path).rglob('.DS_store'))
217
+ print(files)
218
+ for f in files:
219
+ f.unlink()
220
+
221
+
222
+ if __name__ == '__main__':
223
+ source = 'COCO'
224
+
225
+ if source == 'COCO':
226
+ convert_coco(
227
+ '../datasets/coco/annotations', # directory with *.json
228
+ use_segments=False,
229
+ use_keypoints=True,
230
+ cls91to80=False)
ultralytics/data/dataloaders/__init__.py ADDED
File without changes
ultralytics/data/dataset.py ADDED
@@ -0,0 +1,275 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+
3
+ from itertools import repeat
4
+ from multiprocessing.pool import ThreadPool
5
+ from pathlib import Path
6
+
7
+ import cv2
8
+ import numpy as np
9
+ import torch
10
+ import torchvision
11
+ from tqdm import tqdm
12
+
13
+ from ultralytics.utils import LOCAL_RANK, NUM_THREADS, TQDM_BAR_FORMAT, is_dir_writeable
14
+
15
+ from .augment import Compose, Format, Instances, LetterBox, classify_albumentations, classify_transforms, v8_transforms
16
+ from .base import BaseDataset
17
+ from .utils import HELP_URL, LOGGER, get_hash, img2label_paths, verify_image_label
18
+
19
+
20
+ class YOLODataset(BaseDataset):
21
+ """
22
+ Dataset class for loading object detection and/or segmentation labels in YOLO format.
23
+
24
+ Args:
25
+ data (dict, optional): A dataset YAML dictionary. Defaults to None.
26
+ use_segments (bool, optional): If True, segmentation masks are used as labels. Defaults to False.
27
+ use_keypoints (bool, optional): If True, keypoints are used as labels. Defaults to False.
28
+
29
+ Returns:
30
+ (torch.utils.data.Dataset): A PyTorch dataset object that can be used for training an object detection model.
31
+ """
32
+ cache_version = '1.0.2' # dataset labels *.cache version, >= 1.0.0 for YOLOv8
33
+ rand_interp_methods = [cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_LANCZOS4]
34
+
35
+ def __init__(self, *args, data=None, use_segments=False, use_keypoints=False, **kwargs):
36
+ self.use_segments = use_segments
37
+ self.use_keypoints = use_keypoints
38
+ self.data = data
39
+ assert not (self.use_segments and self.use_keypoints), 'Can not use both segments and keypoints.'
40
+ super().__init__(*args, **kwargs)
41
+
42
+ def cache_labels(self, path=Path('./labels.cache')):
43
+ """Cache dataset labels, check images and read shapes.
44
+ Args:
45
+ path (Path): path where to save the cache file (default: Path('./labels.cache')).
46
+ Returns:
47
+ (dict): labels.
48
+ """
49
+ x = {'labels': []}
50
+ nm, nf, ne, nc, msgs = 0, 0, 0, 0, [] # number missing, found, empty, corrupt, messages
51
+ desc = f'{self.prefix}Scanning {path.parent / path.stem}...'
52
+ total = len(self.im_files)
53
+ nkpt, ndim = self.data.get('kpt_shape', (0, 0))
54
+ if self.use_keypoints and (nkpt <= 0 or ndim not in (2, 3)):
55
+ raise ValueError("'kpt_shape' in data.yaml missing or incorrect. Should be a list with [number of "
56
+ "keypoints, number of dims (2 for x,y or 3 for x,y,visible)], i.e. 'kpt_shape: [17, 3]'")
57
+ with ThreadPool(NUM_THREADS) as pool:
58
+ results = pool.imap(func=verify_image_label,
59
+ iterable=zip(self.im_files, self.label_files, repeat(self.prefix),
60
+ repeat(self.use_keypoints), repeat(len(self.data['names'])), repeat(nkpt),
61
+ repeat(ndim)))
62
+ pbar = tqdm(results, desc=desc, total=total, bar_format=TQDM_BAR_FORMAT)
63
+ for im_file, lb, shape, segments, keypoint, nm_f, nf_f, ne_f, nc_f, msg in pbar:
64
+ nm += nm_f
65
+ nf += nf_f
66
+ ne += ne_f
67
+ nc += nc_f
68
+ if im_file:
69
+ x['labels'].append(
70
+ dict(
71
+ im_file=im_file,
72
+ shape=shape,
73
+ cls=lb[:, 0:1], # n, 1
74
+ bboxes=lb[:, 1:], # n, 4
75
+ segments=segments,
76
+ keypoints=keypoint,
77
+ normalized=True,
78
+ bbox_format='xywh'))
79
+ if msg:
80
+ msgs.append(msg)
81
+ pbar.desc = f'{desc} {nf} images, {nm + ne} backgrounds, {nc} corrupt'
82
+ pbar.close()
83
+
84
+ if msgs:
85
+ LOGGER.info('\n'.join(msgs))
86
+ if nf == 0:
87
+ LOGGER.warning(f'{self.prefix}WARNING ⚠️ No labels found in {path}. {HELP_URL}')
88
+ x['hash'] = get_hash(self.label_files + self.im_files)
89
+ x['results'] = nf, nm, ne, nc, len(self.im_files)
90
+ x['msgs'] = msgs # warnings
91
+ x['version'] = self.cache_version # cache version
92
+ if is_dir_writeable(path.parent):
93
+ if path.exists():
94
+ path.unlink() # remove *.cache file if exists
95
+ np.save(str(path), x) # save cache for next time
96
+ path.with_suffix('.cache.npy').rename(path) # remove .npy suffix
97
+ LOGGER.info(f'{self.prefix}New cache created: {path}')
98
+ else:
99
+ LOGGER.warning(f'{self.prefix}WARNING ⚠️ Cache directory {path.parent} is not writeable, cache not saved.')
100
+ return x
101
+
102
+ def get_labels(self):
103
+ """Returns dictionary of labels for YOLO training."""
104
+ self.label_files = img2label_paths(self.im_files)
105
+ cache_path = Path(self.label_files[0]).parent.with_suffix('.cache')
106
+ try:
107
+ import gc
108
+ gc.disable() # reduce pickle load time https://github.com/ultralytics/ultralytics/pull/1585
109
+ cache, exists = np.load(str(cache_path), allow_pickle=True).item(), True # load dict
110
+ gc.enable()
111
+ assert cache['version'] == self.cache_version # matches current version
112
+ assert cache['hash'] == get_hash(self.label_files + self.im_files) # identical hash
113
+ except (FileNotFoundError, AssertionError, AttributeError):
114
+ cache, exists = self.cache_labels(cache_path), False # run cache ops
115
+
116
+ # Display cache
117
+ nf, nm, ne, nc, n = cache.pop('results') # found, missing, empty, corrupt, total
118
+ if exists and LOCAL_RANK in (-1, 0):
119
+ d = f'Scanning {cache_path}... {nf} images, {nm + ne} backgrounds, {nc} corrupt'
120
+ tqdm(None, desc=self.prefix + d, total=n, initial=n, bar_format=TQDM_BAR_FORMAT) # display cache results
121
+ if cache['msgs']:
122
+ LOGGER.info('\n'.join(cache['msgs'])) # display warnings
123
+ if nf == 0: # number of labels found
124
+ raise FileNotFoundError(f'{self.prefix}No labels found in {cache_path}, can not start training. {HELP_URL}')
125
+
126
+ # Read cache
127
+ [cache.pop(k) for k in ('hash', 'version', 'msgs')] # remove items
128
+ labels = cache['labels']
129
+ self.im_files = [lb['im_file'] for lb in labels] # update im_files
130
+
131
+ # Check if the dataset is all boxes or all segments
132
+ lengths = ((len(lb['cls']), len(lb['bboxes']), len(lb['segments'])) for lb in labels)
133
+ len_cls, len_boxes, len_segments = (sum(x) for x in zip(*lengths))
134
+ if len_segments and len_boxes != len_segments:
135
+ LOGGER.warning(
136
+ f'WARNING ⚠️ Box and segment counts should be equal, but got len(segments) = {len_segments}, '
137
+ f'len(boxes) = {len_boxes}. To resolve this only boxes will be used and all segments will be removed. '
138
+ 'To avoid this please supply either a detect or segment dataset, not a detect-segment mixed dataset.')
139
+ for lb in labels:
140
+ lb['segments'] = []
141
+ if len_cls == 0:
142
+ raise ValueError(f'All labels empty in {cache_path}, can not start training without labels. {HELP_URL}')
143
+ return labels
144
+
145
+ # TODO: use hyp config to set all these augmentations
146
+ def build_transforms(self, hyp=None):
147
+ """Builds and appends transforms to the list."""
148
+ if self.augment:
149
+ hyp.mosaic = hyp.mosaic if self.augment and not self.rect else 0.0
150
+ hyp.mixup = hyp.mixup if self.augment and not self.rect else 0.0
151
+ transforms = v8_transforms(self, self.imgsz, hyp)
152
+ else:
153
+ transforms = Compose([LetterBox(new_shape=(self.imgsz, self.imgsz), scaleup=False)])
154
+ transforms.append(
155
+ Format(bbox_format='xywh',
156
+ normalize=True,
157
+ return_mask=self.use_segments,
158
+ return_keypoint=self.use_keypoints,
159
+ batch_idx=True,
160
+ mask_ratio=hyp.mask_ratio,
161
+ mask_overlap=hyp.overlap_mask))
162
+ return transforms
163
+
164
+ def close_mosaic(self, hyp):
165
+ """Sets mosaic, copy_paste and mixup options to 0.0 and builds transformations."""
166
+ hyp.mosaic = 0.0 # set mosaic ratio=0.0
167
+ hyp.copy_paste = 0.0 # keep the same behavior as previous v8 close-mosaic
168
+ hyp.mixup = 0.0 # keep the same behavior as previous v8 close-mosaic
169
+ self.transforms = self.build_transforms(hyp)
170
+
171
+ def update_labels_info(self, label):
172
+ """custom your label format here."""
173
+ # NOTE: cls is not with bboxes now, classification and semantic segmentation need an independent cls label
174
+ # we can make it also support classification and semantic segmentation by add or remove some dict keys there.
175
+ bboxes = label.pop('bboxes')
176
+ segments = label.pop('segments')
177
+ keypoints = label.pop('keypoints', None)
178
+ bbox_format = label.pop('bbox_format')
179
+ normalized = label.pop('normalized')
180
+ label['instances'] = Instances(bboxes, segments, keypoints, bbox_format=bbox_format, normalized=normalized)
181
+ return label
182
+
183
+ @staticmethod
184
+ def collate_fn(batch):
185
+ """Collates data samples into batches."""
186
+ new_batch = {}
187
+ keys = batch[0].keys()
188
+ values = list(zip(*[list(b.values()) for b in batch]))
189
+ for i, k in enumerate(keys):
190
+ value = values[i]
191
+ if k == 'img':
192
+ value = torch.stack(value, 0)
193
+ if k in ['masks', 'keypoints', 'bboxes', 'cls']:
194
+ value = torch.cat(value, 0)
195
+ new_batch[k] = value
196
+ new_batch['batch_idx'] = list(new_batch['batch_idx'])
197
+ for i in range(len(new_batch['batch_idx'])):
198
+ new_batch['batch_idx'][i] += i # add target image index for build_targets()
199
+ new_batch['batch_idx'] = torch.cat(new_batch['batch_idx'], 0)
200
+ return new_batch
201
+
202
+
203
+ # Classification dataloaders -------------------------------------------------------------------------------------------
204
+ class ClassificationDataset(torchvision.datasets.ImageFolder):
205
+ """
206
+ YOLO Classification Dataset.
207
+
208
+ Args:
209
+ root (str): Dataset path.
210
+
211
+ Attributes:
212
+ cache_ram (bool): True if images should be cached in RAM, False otherwise.
213
+ cache_disk (bool): True if images should be cached on disk, False otherwise.
214
+ samples (list): List of samples containing file, index, npy, and im.
215
+ torch_transforms (callable): torchvision transforms applied to the dataset.
216
+ album_transforms (callable, optional): Albumentations transforms applied to the dataset if augment is True.
217
+ """
218
+
219
+ def __init__(self, root, args, augment=False, cache=False):
220
+ """
221
+ Initialize YOLO object with root, image size, augmentations, and cache settings.
222
+
223
+ Args:
224
+ root (str): Dataset path.
225
+ args (Namespace): Argument parser containing dataset related settings.
226
+ augment (bool, optional): True if dataset should be augmented, False otherwise. Defaults to False.
227
+ cache (bool | str | optional): Cache setting, can be True, False, 'ram' or 'disk'. Defaults to False.
228
+ """
229
+ super().__init__(root=root)
230
+ if augment and args.fraction < 1.0: # reduce training fraction
231
+ self.samples = self.samples[:round(len(self.samples) * args.fraction)]
232
+ self.cache_ram = cache is True or cache == 'ram'
233
+ self.cache_disk = cache == 'disk'
234
+ self.samples = [list(x) + [Path(x[0]).with_suffix('.npy'), None] for x in self.samples] # file, index, npy, im
235
+ self.torch_transforms = classify_transforms(args.imgsz)
236
+ self.album_transforms = classify_albumentations(
237
+ augment=augment,
238
+ size=args.imgsz,
239
+ scale=(1.0 - args.scale, 1.0), # (0.08, 1.0)
240
+ hflip=args.fliplr,
241
+ vflip=args.flipud,
242
+ hsv_h=args.hsv_h, # HSV-Hue augmentation (fraction)
243
+ hsv_s=args.hsv_s, # HSV-Saturation augmentation (fraction)
244
+ hsv_v=args.hsv_v, # HSV-Value augmentation (fraction)
245
+ mean=(0.0, 0.0, 0.0), # IMAGENET_MEAN
246
+ std=(1.0, 1.0, 1.0), # IMAGENET_STD
247
+ auto_aug=False) if augment else None
248
+
249
+ def __getitem__(self, i):
250
+ """Returns subset of data and targets corresponding to given indices."""
251
+ f, j, fn, im = self.samples[i] # filename, index, filename.with_suffix('.npy'), image
252
+ if self.cache_ram and im is None:
253
+ im = self.samples[i][3] = cv2.imread(f)
254
+ elif self.cache_disk:
255
+ if not fn.exists(): # load npy
256
+ np.save(fn.as_posix(), cv2.imread(f))
257
+ im = np.load(fn)
258
+ else: # read image
259
+ im = cv2.imread(f) # BGR
260
+ if self.album_transforms:
261
+ sample = self.album_transforms(image=cv2.cvtColor(im, cv2.COLOR_BGR2RGB))['image']
262
+ else:
263
+ sample = self.torch_transforms(im)
264
+ return {'img': sample, 'cls': j}
265
+
266
+ def __len__(self) -> int:
267
+ return len(self.samples)
268
+
269
+
270
+ # TODO: support semantic segmentation
271
+ class SemanticDataset(BaseDataset):
272
+
273
+ def __init__(self):
274
+ """Initialize a SemanticDataset object."""
275
+ super().__init__()
ultralytics/data/loaders.py ADDED
@@ -0,0 +1,407 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+
3
+ import glob
4
+ import math
5
+ import os
6
+ import time
7
+ from dataclasses import dataclass
8
+ from pathlib import Path
9
+ from threading import Thread
10
+ from urllib.parse import urlparse
11
+
12
+ import cv2
13
+ import numpy as np
14
+ import requests
15
+ import torch
16
+ from PIL import Image
17
+
18
+ from ultralytics.data.utils import IMG_FORMATS, VID_FORMATS
19
+ from ultralytics.utils import LOGGER, ROOT, is_colab, is_kaggle, ops
20
+ from ultralytics.utils.checks import check_requirements
21
+
22
+
23
+ @dataclass
24
+ class SourceTypes:
25
+ webcam: bool = False
26
+ screenshot: bool = False
27
+ from_img: bool = False
28
+ tensor: bool = False
29
+
30
+
31
+ class LoadStreams:
32
+ """YOLOv8 streamloader, i.e. `yolo predict source='rtsp://example.com/media.mp4' # RTSP, RTMP, HTTP streams`."""
33
+
34
+ def __init__(self, sources='file.streams', imgsz=640, vid_stride=1):
35
+ """Initialize instance variables and check for consistent input stream shapes."""
36
+ torch.backends.cudnn.benchmark = True # faster for fixed-size inference
37
+ self.mode = 'stream'
38
+ self.imgsz = imgsz
39
+ self.vid_stride = vid_stride # video frame-rate stride
40
+ sources = Path(sources).read_text().rsplit() if os.path.isfile(sources) else [sources]
41
+ n = len(sources)
42
+ self.sources = [ops.clean_str(x) for x in sources] # clean source names for later
43
+ self.imgs, self.fps, self.frames, self.threads, self.shape = [[]] * n, [0] * n, [0] * n, [None] * n, [None] * n
44
+ for i, s in enumerate(sources): # index, source
45
+ # Start thread to read frames from video stream
46
+ st = f'{i + 1}/{n}: {s}... '
47
+ if urlparse(s).hostname in ('www.youtube.com', 'youtube.com', 'youtu.be'): # if source is YouTube video
48
+ # YouTube format i.e. 'https://www.youtube.com/watch?v=Zgi9g1ksQHc' or 'https://youtu.be/Zgi9g1ksQHc'
49
+ s = get_best_youtube_url(s)
50
+ s = eval(s) if s.isnumeric() else s # i.e. s = '0' local webcam
51
+ if s == 0 and (is_colab() or is_kaggle()):
52
+ raise NotImplementedError("'source=0' webcam not supported in Colab and Kaggle notebooks. "
53
+ "Try running 'source=0' in a local environment.")
54
+ cap = cv2.VideoCapture(s)
55
+ if not cap.isOpened():
56
+ raise ConnectionError(f'{st}Failed to open {s}')
57
+ w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
58
+ h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
59
+ fps = cap.get(cv2.CAP_PROP_FPS) # warning: may return 0 or nan
60
+ self.frames[i] = max(int(cap.get(cv2.CAP_PROP_FRAME_COUNT)), 0) or float('inf') # infinite stream fallback
61
+ self.fps[i] = max((fps if math.isfinite(fps) else 0) % 100, 0) or 30 # 30 FPS fallback
62
+
63
+ success, im = cap.read() # guarantee first frame
64
+ if not success or im is None:
65
+ raise ConnectionError(f'{st}Failed to read images from {s}')
66
+ self.imgs[i].append(im)
67
+ self.shape[i] = im.shape
68
+ self.threads[i] = Thread(target=self.update, args=([i, cap, s]), daemon=True)
69
+ LOGGER.info(f'{st}Success ✅ ({self.frames[i]} frames of shape {w}x{h} at {self.fps[i]:.2f} FPS)')
70
+ self.threads[i].start()
71
+ LOGGER.info('') # newline
72
+
73
+ # Check for common shapes
74
+ self.bs = self.__len__()
75
+
76
+ def update(self, i, cap, stream):
77
+ """Read stream `i` frames in daemon thread."""
78
+ n, f = 0, self.frames[i] # frame number, frame array
79
+ while cap.isOpened() and n < f:
80
+ # Only read a new frame if the buffer is empty
81
+ if not self.imgs[i]:
82
+ n += 1
83
+ cap.grab() # .read() = .grab() followed by .retrieve()
84
+ if n % self.vid_stride == 0:
85
+ success, im = cap.retrieve()
86
+ if success:
87
+ self.imgs[i].append(im) # add image to buffer
88
+ else:
89
+ LOGGER.warning('WARNING ⚠️ Video stream unresponsive, please check your IP camera connection.')
90
+ self.imgs[i].append(np.zeros(self.shape[i]))
91
+ cap.open(stream) # re-open stream if signal was lost
92
+ else:
93
+ time.sleep(0.01) # wait until the buffer is empty
94
+
95
+ def __iter__(self):
96
+ """Iterates through YOLO image feed and re-opens unresponsive streams."""
97
+ self.count = -1
98
+ return self
99
+
100
+ def __next__(self):
101
+ """Returns source paths, transformed and original images for processing."""
102
+ self.count += 1
103
+
104
+ # Wait until a frame is available in each buffer
105
+ while not all(self.imgs):
106
+ if not all(x.is_alive() for x in self.threads) or cv2.waitKey(1) == ord('q'): # q to quit
107
+ cv2.destroyAllWindows()
108
+ raise StopIteration
109
+ time.sleep(1 / min(self.fps))
110
+
111
+ # Get and remove the next frame from imgs buffer
112
+ return self.sources, [x.pop(0) for x in self.imgs], None, ''
113
+
114
+ def __len__(self):
115
+ """Return the length of the sources object."""
116
+ return len(self.sources) # 1E12 frames = 32 streams at 30 FPS for 30 years
117
+
118
+
119
+ class LoadScreenshots:
120
+ """YOLOv8 screenshot dataloader, i.e. `yolo predict source=screen`."""
121
+
122
+ def __init__(self, source, imgsz=640):
123
+ """source = [screen_number left top width height] (pixels)."""
124
+ check_requirements('mss')
125
+ import mss # noqa
126
+
127
+ source, *params = source.split()
128
+ self.screen, left, top, width, height = 0, None, None, None, None # default to full screen 0
129
+ if len(params) == 1:
130
+ self.screen = int(params[0])
131
+ elif len(params) == 4:
132
+ left, top, width, height = (int(x) for x in params)
133
+ elif len(params) == 5:
134
+ self.screen, left, top, width, height = (int(x) for x in params)
135
+ self.imgsz = imgsz
136
+ self.mode = 'stream'
137
+ self.frame = 0
138
+ self.sct = mss.mss()
139
+ self.bs = 1
140
+
141
+ # Parse monitor shape
142
+ monitor = self.sct.monitors[self.screen]
143
+ self.top = monitor['top'] if top is None else (monitor['top'] + top)
144
+ self.left = monitor['left'] if left is None else (monitor['left'] + left)
145
+ self.width = width or monitor['width']
146
+ self.height = height or monitor['height']
147
+ self.monitor = {'left': self.left, 'top': self.top, 'width': self.width, 'height': self.height}
148
+
149
+ def __iter__(self):
150
+ """Returns an iterator of the object."""
151
+ return self
152
+
153
+ def __next__(self):
154
+ """mss screen capture: get raw pixels from the screen as np array."""
155
+ im0 = np.array(self.sct.grab(self.monitor))[:, :, :3] # [:, :, :3] BGRA to BGR
156
+ s = f'screen {self.screen} (LTWH): {self.left},{self.top},{self.width},{self.height}: '
157
+
158
+ self.frame += 1
159
+ return [str(self.screen)], [im0], None, s # screen, img, vid_cap, string
160
+
161
+
162
+ class LoadImages:
163
+ """YOLOv8 image/video dataloader, i.e. `yolo predict source=image.jpg/vid.mp4`."""
164
+
165
+ def __init__(self, path, imgsz=640, vid_stride=1):
166
+ """Initialize the Dataloader and raise FileNotFoundError if file not found."""
167
+ parent = None
168
+ if isinstance(path, str) and Path(path).suffix == '.txt': # *.txt file with img/vid/dir on each line
169
+ parent = Path(path).parent
170
+ path = Path(path).read_text().rsplit()
171
+ files = []
172
+ for p in sorted(path) if isinstance(path, (list, tuple)) else [path]:
173
+ a = str(Path(p).absolute()) # do not use .resolve() https://github.com/ultralytics/ultralytics/issues/2912
174
+ if '*' in a:
175
+ files.extend(sorted(glob.glob(a, recursive=True))) # glob
176
+ elif os.path.isdir(a):
177
+ files.extend(sorted(glob.glob(os.path.join(a, '*.*')))) # dir
178
+ elif os.path.isfile(a):
179
+ files.append(a) # files (absolute or relative to CWD)
180
+ elif parent and (parent / p).is_file():
181
+ files.append(str((parent / p).absolute())) # files (relative to *.txt file parent)
182
+ else:
183
+ raise FileNotFoundError(f'{p} does not exist')
184
+
185
+ images = [x for x in files if x.split('.')[-1].lower() in IMG_FORMATS]
186
+ videos = [x for x in files if x.split('.')[-1].lower() in VID_FORMATS]
187
+ ni, nv = len(images), len(videos)
188
+
189
+ self.imgsz = imgsz
190
+ self.files = images + videos
191
+ self.nf = ni + nv # number of files
192
+ self.video_flag = [False] * ni + [True] * nv
193
+ self.mode = 'image'
194
+ self.vid_stride = vid_stride # video frame-rate stride
195
+ self.bs = 1
196
+ if any(videos):
197
+ self.orientation = None # rotation degrees
198
+ self._new_video(videos[0]) # new video
199
+ else:
200
+ self.cap = None
201
+ if self.nf == 0:
202
+ raise FileNotFoundError(f'No images or videos found in {p}. '
203
+ f'Supported formats are:\nimages: {IMG_FORMATS}\nvideos: {VID_FORMATS}')
204
+
205
+ def __iter__(self):
206
+ """Returns an iterator object for VideoStream or ImageFolder."""
207
+ self.count = 0
208
+ return self
209
+
210
+ def __next__(self):
211
+ """Return next image, path and metadata from dataset."""
212
+ if self.count == self.nf:
213
+ raise StopIteration
214
+ path = self.files[self.count]
215
+
216
+ if self.video_flag[self.count]:
217
+ # Read video
218
+ self.mode = 'video'
219
+ for _ in range(self.vid_stride):
220
+ self.cap.grab()
221
+ success, im0 = self.cap.retrieve()
222
+ while not success:
223
+ self.count += 1
224
+ self.cap.release()
225
+ if self.count == self.nf: # last video
226
+ raise StopIteration
227
+ path = self.files[self.count]
228
+ self._new_video(path)
229
+ success, im0 = self.cap.read()
230
+
231
+ self.frame += 1
232
+ # im0 = self._cv2_rotate(im0) # for use if cv2 autorotation is False
233
+ s = f'video {self.count + 1}/{self.nf} ({self.frame}/{self.frames}) {path}: '
234
+
235
+ else:
236
+ # Read image
237
+ self.count += 1
238
+ im0 = cv2.imread(path) # BGR
239
+ if im0 is None:
240
+ raise FileNotFoundError(f'Image Not Found {path}')
241
+ s = f'image {self.count}/{self.nf} {path}: '
242
+
243
+ return [path], [im0], self.cap, s
244
+
245
+ def _new_video(self, path):
246
+ """Create a new video capture object."""
247
+ self.frame = 0
248
+ self.cap = cv2.VideoCapture(path)
249
+ self.frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT) / self.vid_stride)
250
+ if hasattr(cv2, 'CAP_PROP_ORIENTATION_META'): # cv2<4.6.0 compatibility
251
+ self.orientation = int(self.cap.get(cv2.CAP_PROP_ORIENTATION_META)) # rotation degrees
252
+ # Disable auto-orientation due to known issues in https://github.com/ultralytics/yolov5/issues/8493
253
+ # self.cap.set(cv2.CAP_PROP_ORIENTATION_AUTO, 0)
254
+
255
+ def _cv2_rotate(self, im):
256
+ """Rotate a cv2 video manually."""
257
+ if self.orientation == 0:
258
+ return cv2.rotate(im, cv2.ROTATE_90_CLOCKWISE)
259
+ elif self.orientation == 180:
260
+ return cv2.rotate(im, cv2.ROTATE_90_COUNTERCLOCKWISE)
261
+ elif self.orientation == 90:
262
+ return cv2.rotate(im, cv2.ROTATE_180)
263
+ return im
264
+
265
+ def __len__(self):
266
+ """Returns the number of files in the object."""
267
+ return self.nf # number of files
268
+
269
+
270
+ class LoadPilAndNumpy:
271
+
272
+ def __init__(self, im0, imgsz=640):
273
+ """Initialize PIL and Numpy Dataloader."""
274
+ if not isinstance(im0, list):
275
+ im0 = [im0]
276
+ self.paths = [getattr(im, 'filename', f'image{i}.jpg') for i, im in enumerate(im0)]
277
+ self.im0 = [self._single_check(im) for im in im0]
278
+ self.imgsz = imgsz
279
+ self.mode = 'image'
280
+ # Generate fake paths
281
+ self.bs = len(self.im0)
282
+
283
+ @staticmethod
284
+ def _single_check(im):
285
+ """Validate and format an image to numpy array."""
286
+ assert isinstance(im, (Image.Image, np.ndarray)), f'Expected PIL/np.ndarray image type, but got {type(im)}'
287
+ if isinstance(im, Image.Image):
288
+ if im.mode != 'RGB':
289
+ im = im.convert('RGB')
290
+ im = np.asarray(im)[:, :, ::-1]
291
+ im = np.ascontiguousarray(im) # contiguous
292
+ return im
293
+
294
+ def __len__(self):
295
+ """Returns the length of the 'im0' attribute."""
296
+ return len(self.im0)
297
+
298
+ def __next__(self):
299
+ """Returns batch paths, images, processed images, None, ''."""
300
+ if self.count == 1: # loop only once as it's batch inference
301
+ raise StopIteration
302
+ self.count += 1
303
+ return self.paths, self.im0, None, ''
304
+
305
+ def __iter__(self):
306
+ """Enables iteration for class LoadPilAndNumpy."""
307
+ self.count = 0
308
+ return self
309
+
310
+
311
+ class LoadTensor:
312
+
313
+ def __init__(self, im0) -> None:
314
+ self.im0 = self._single_check(im0)
315
+ self.bs = self.im0.shape[0]
316
+ self.mode = 'image'
317
+ self.paths = [getattr(im, 'filename', f'image{i}.jpg') for i, im in enumerate(im0)]
318
+
319
+ @staticmethod
320
+ def _single_check(im, stride=32):
321
+ """Validate and format an image to torch.Tensor."""
322
+ s = f'WARNING ⚠️ torch.Tensor inputs should be BCHW i.e. shape(1, 3, 640, 640) ' \
323
+ f'divisible by stride {stride}. Input shape{tuple(im.shape)} is incompatible.'
324
+ if len(im.shape) != 4:
325
+ if len(im.shape) != 3:
326
+ raise ValueError(s)
327
+ LOGGER.warning(s)
328
+ im = im.unsqueeze(0)
329
+ if im.shape[2] % stride or im.shape[3] % stride:
330
+ raise ValueError(s)
331
+ if im.max() > 1.0:
332
+ LOGGER.warning(f'WARNING ⚠️ torch.Tensor inputs should be normalized 0.0-1.0 but max value is {im.max()}. '
333
+ f'Dividing input by 255.')
334
+ im = im.float() / 255.0
335
+
336
+ return im
337
+
338
+ def __iter__(self):
339
+ """Returns an iterator object."""
340
+ self.count = 0
341
+ return self
342
+
343
+ def __next__(self):
344
+ """Return next item in the iterator."""
345
+ if self.count == 1:
346
+ raise StopIteration
347
+ self.count += 1
348
+ return self.paths, self.im0, None, ''
349
+
350
+ def __len__(self):
351
+ """Returns the batch size."""
352
+ return self.bs
353
+
354
+
355
+ def autocast_list(source):
356
+ """
357
+ Merges a list of source of different types into a list of numpy arrays or PIL images
358
+ """
359
+ files = []
360
+ for im in source:
361
+ if isinstance(im, (str, Path)): # filename or uri
362
+ files.append(Image.open(requests.get(im, stream=True).raw if str(im).startswith('http') else im))
363
+ elif isinstance(im, (Image.Image, np.ndarray)): # PIL or np Image
364
+ files.append(im)
365
+ else:
366
+ raise TypeError(f'type {type(im).__name__} is not a supported Ultralytics prediction source type. \n'
367
+ f'See https://docs.ultralytics.com/modes/predict for supported source types.')
368
+
369
+ return files
370
+
371
+
372
+ LOADERS = [LoadStreams, LoadPilAndNumpy, LoadImages, LoadScreenshots]
373
+
374
+
375
+ def get_best_youtube_url(url, use_pafy=True):
376
+ """
377
+ Retrieves the URL of the best quality MP4 video stream from a given YouTube video.
378
+
379
+ This function uses the pafy or yt_dlp library to extract the video info from YouTube. It then finds the highest
380
+ quality MP4 format that has video codec but no audio codec, and returns the URL of this video stream.
381
+
382
+ Args:
383
+ url (str): The URL of the YouTube video.
384
+ use_pafy (bool): Use the pafy package, default=True, otherwise use yt_dlp package.
385
+
386
+ Returns:
387
+ (str): The URL of the best quality MP4 video stream, or None if no suitable stream is found.
388
+ """
389
+ if use_pafy:
390
+ check_requirements(('pafy', 'youtube_dl==2020.12.2'))
391
+ import pafy # noqa
392
+ return pafy.new(url).getbest(preftype='mp4').url
393
+ else:
394
+ check_requirements('yt-dlp')
395
+ import yt_dlp
396
+ with yt_dlp.YoutubeDL({'quiet': True}) as ydl:
397
+ info_dict = ydl.extract_info(url, download=False) # extract info
398
+ for f in info_dict.get('formats', None):
399
+ if f['vcodec'] != 'none' and f['acodec'] == 'none' and f['ext'] == 'mp4':
400
+ return f.get('url', None)
401
+
402
+
403
+ if __name__ == '__main__':
404
+ img = cv2.imread(str(ROOT / 'assets/bus.jpg'))
405
+ dataset = LoadPilAndNumpy(im0=img)
406
+ for d in dataset:
407
+ print(d[0])
ultralytics/data/scripts/download_weights.sh ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
3
+ # Download latest models from https://github.com/ultralytics/assets/releases
4
+ # Example usage: bash ultralytics/data/scripts/download_weights.sh
5
+ # parent
6
+ # └── weights
7
+ # ├── yolov8n.pt ← downloads here
8
+ # ├── yolov8s.pt
9
+ # └── ...
10
+
11
+ python - <<EOF
12
+ from ultralytics.utils.downloads import attempt_download_asset
13
+
14
+ assets = [f'yolov8{size}{suffix}.pt' for size in 'nsmlx' for suffix in ('', '-cls', '-seg', '-pose')]
15
+ for x in assets:
16
+ attempt_download_asset(f'weights/{x}')
17
+
18
+ EOF
ultralytics/data/scripts/get_coco.sh ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
3
+ # Download COCO 2017 dataset http://cocodataset.org
4
+ # Example usage: bash data/scripts/get_coco.sh
5
+ # parent
6
+ # ├── ultralytics
7
+ # └── datasets
8
+ # └── coco ← downloads here
9
+
10
+ # Arguments (optional) Usage: bash data/scripts/get_coco.sh --train --val --test --segments
11
+ if [ "$#" -gt 0 ]; then
12
+ for opt in "$@"; do
13
+ case "${opt}" in
14
+ --train) train=true ;;
15
+ --val) val=true ;;
16
+ --test) test=true ;;
17
+ --segments) segments=true ;;
18
+ --sama) sama=true ;;
19
+ esac
20
+ done
21
+ else
22
+ train=true
23
+ val=true
24
+ test=false
25
+ segments=false
26
+ sama=false
27
+ fi
28
+
29
+ # Download/unzip labels
30
+ d='../datasets' # unzip directory
31
+ url=https://github.com/ultralytics/yolov5/releases/download/v1.0/
32
+ if [ "$segments" == "true" ]; then
33
+ f='coco2017labels-segments.zip' # 169 MB
34
+ elif [ "$sama" == "true" ]; then
35
+ f='coco2017labels-segments-sama.zip' # 199 MB https://www.sama.com/sama-coco-dataset/
36
+ else
37
+ f='coco2017labels.zip' # 46 MB
38
+ fi
39
+ echo 'Downloading' $url$f ' ...'
40
+ curl -L $url$f -o $f -# && unzip -q $f -d $d && rm $f &
41
+
42
+ # Download/unzip images
43
+ d='../datasets/coco/images' # unzip directory
44
+ url=http://images.cocodataset.org/zips/
45
+ if [ "$train" == "true" ]; then
46
+ f='train2017.zip' # 19G, 118k images
47
+ echo 'Downloading' $url$f '...'
48
+ curl -L $url$f -o $f -# && unzip -q $f -d $d && rm $f &
49
+ fi
50
+ if [ "$val" == "true" ]; then
51
+ f='val2017.zip' # 1G, 5k images
52
+ echo 'Downloading' $url$f '...'
53
+ curl -L $url$f -o $f -# && unzip -q $f -d $d && rm $f &
54
+ fi
55
+ if [ "$test" == "true" ]; then
56
+ f='test2017.zip' # 7G, 41k images (optional)
57
+ echo 'Downloading' $url$f '...'
58
+ curl -L $url$f -o $f -# && unzip -q $f -d $d && rm $f &
59
+ fi
60
+ wait # finish background tasks
ultralytics/data/scripts/get_coco128.sh ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
3
+ # Download COCO128 dataset https://www.kaggle.com/ultralytics/coco128 (first 128 images from COCO train2017)
4
+ # Example usage: bash data/scripts/get_coco128.sh
5
+ # parent
6
+ # ├── ultralytics
7
+ # └── datasets
8
+ # └── coco128 ← downloads here
9
+
10
+ # Download/unzip images and labels
11
+ d='../datasets' # unzip directory
12
+ url=https://github.com/ultralytics/yolov5/releases/download/v1.0/
13
+ f='coco128.zip' # or 'coco128-segments.zip', 68 MB
14
+ echo 'Downloading' $url$f ' ...'
15
+ curl -L $url$f -o $f -# && unzip -q $f -d $d && rm $f &
16
+
17
+ wait # finish background tasks
ultralytics/data/scripts/get_imagenet.sh ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
3
+ # Download ILSVRC2012 ImageNet dataset https://image-net.org
4
+ # Example usage: bash data/scripts/get_imagenet.sh
5
+ # parent
6
+ # ├── ultralytics
7
+ # └── datasets
8
+ # └── imagenet ← downloads here
9
+
10
+ # Arguments (optional) Usage: bash data/scripts/get_imagenet.sh --train --val
11
+ if [ "$#" -gt 0 ]; then
12
+ for opt in "$@"; do
13
+ case "${opt}" in
14
+ --train) train=true ;;
15
+ --val) val=true ;;
16
+ esac
17
+ done
18
+ else
19
+ train=true
20
+ val=true
21
+ fi
22
+
23
+ # Make dir
24
+ d='../datasets/imagenet' # unzip directory
25
+ mkdir -p $d && cd $d
26
+
27
+ # Download/unzip train
28
+ if [ "$train" == "true" ]; then
29
+ wget https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_train.tar # download 138G, 1281167 images
30
+ mkdir train && mv ILSVRC2012_img_train.tar train/ && cd train
31
+ tar -xf ILSVRC2012_img_train.tar && rm -f ILSVRC2012_img_train.tar
32
+ find . -name "*.tar" | while read NAME; do
33
+ mkdir -p "${NAME%.tar}"
34
+ tar -xf "${NAME}" -C "${NAME%.tar}"
35
+ rm -f "${NAME}"
36
+ done
37
+ cd ..
38
+ fi
39
+
40
+ # Download/unzip val
41
+ if [ "$val" == "true" ]; then
42
+ wget https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_val.tar # download 6.3G, 50000 images
43
+ mkdir val && mv ILSVRC2012_img_val.tar val/ && cd val && tar -xf ILSVRC2012_img_val.tar
44
+ wget -qO- https://raw.githubusercontent.com/soumith/imagenetloader.torch/master/valprep.sh | bash # move into subdirs
45
+ fi
46
+
47
+ # Delete corrupted image (optional: PNG under JPEG name that may cause dataloaders to fail)
48
+ # rm train/n04266014/n04266014_10835.JPEG
49
+
50
+ # TFRecords (optional)
51
+ # wget https://raw.githubusercontent.com/tensorflow/models/master/research/slim/datasets/imagenet_lsvrc_2015_synsets.txt
ultralytics/data/utils.py ADDED
@@ -0,0 +1,557 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+
3
+ import contextlib
4
+ import hashlib
5
+ import json
6
+ import os
7
+ import random
8
+ import subprocess
9
+ import time
10
+ import zipfile
11
+ from multiprocessing.pool import ThreadPool
12
+ from pathlib import Path
13
+ from tarfile import is_tarfile
14
+
15
+ import cv2
16
+ import numpy as np
17
+ from PIL import ExifTags, Image, ImageOps
18
+ from tqdm import tqdm
19
+
20
+ from ultralytics.nn.autobackend import check_class_names
21
+ from ultralytics.utils import (DATASETS_DIR, LOGGER, NUM_THREADS, ROOT, SETTINGS_YAML, clean_url, colorstr, emojis,
22
+ yaml_load)
23
+ from ultralytics.utils.checks import check_file, check_font, is_ascii
24
+ from ultralytics.utils.downloads import download, safe_download, unzip_file
25
+ from ultralytics.utils.ops import segments2boxes
26
+
27
+ HELP_URL = 'See https://docs.ultralytics.com/yolov5/tutorials/train_custom_data'
28
+ IMG_FORMATS = 'bmp', 'dng', 'jpeg', 'jpg', 'mpo', 'png', 'tif', 'tiff', 'webp', 'pfm' # image suffixes
29
+ VID_FORMATS = 'asf', 'avi', 'gif', 'm4v', 'mkv', 'mov', 'mp4', 'mpeg', 'mpg', 'ts', 'wmv', 'webm' # video suffixes
30
+ PIN_MEMORY = str(os.getenv('PIN_MEMORY', True)).lower() == 'true' # global pin_memory for dataloaders
31
+ IMAGENET_MEAN = 0.485, 0.456, 0.406 # RGB mean
32
+ IMAGENET_STD = 0.229, 0.224, 0.225 # RGB standard deviation
33
+
34
+ # Get orientation exif tag
35
+ for orientation in ExifTags.TAGS.keys():
36
+ if ExifTags.TAGS[orientation] == 'Orientation':
37
+ break
38
+
39
+
40
+ def img2label_paths(img_paths):
41
+ """Define label paths as a function of image paths."""
42
+ sa, sb = f'{os.sep}images{os.sep}', f'{os.sep}labels{os.sep}' # /images/, /labels/ substrings
43
+ return [sb.join(x.rsplit(sa, 1)).rsplit('.', 1)[0] + '.txt' for x in img_paths]
44
+
45
+
46
+ def get_hash(paths):
47
+ """Returns a single hash value of a list of paths (files or dirs)."""
48
+ size = sum(os.path.getsize(p) for p in paths if os.path.exists(p)) # sizes
49
+ h = hashlib.sha256(str(size).encode()) # hash sizes
50
+ h.update(''.join(paths).encode()) # hash paths
51
+ return h.hexdigest() # return hash
52
+
53
+
54
+ def exif_size(img):
55
+ """Returns exif-corrected PIL size."""
56
+ s = img.size # (width, height)
57
+ with contextlib.suppress(Exception):
58
+ rotation = dict(img._getexif().items())[orientation]
59
+ if rotation in [6, 8]: # rotation 270 or 90
60
+ s = (s[1], s[0])
61
+ return s
62
+
63
+
64
+ def verify_image_label(args):
65
+ """Verify one image-label pair."""
66
+ im_file, lb_file, prefix, keypoint, num_cls, nkpt, ndim = args
67
+ # Number (missing, found, empty, corrupt), message, segments, keypoints
68
+ nm, nf, ne, nc, msg, segments, keypoints = 0, 0, 0, 0, '', [], None
69
+ try:
70
+ # Verify images
71
+ im = Image.open(im_file)
72
+ im.verify() # PIL verify
73
+ shape = exif_size(im) # image size
74
+ shape = (shape[1], shape[0]) # hw
75
+ assert (shape[0] > 9) & (shape[1] > 9), f'image size {shape} <10 pixels'
76
+ assert im.format.lower() in IMG_FORMATS, f'invalid image format {im.format}'
77
+ if im.format.lower() in ('jpg', 'jpeg'):
78
+ with open(im_file, 'rb') as f:
79
+ f.seek(-2, 2)
80
+ if f.read() != b'\xff\xd9': # corrupt JPEG
81
+ ImageOps.exif_transpose(Image.open(im_file)).save(im_file, 'JPEG', subsampling=0, quality=100)
82
+ msg = f'{prefix}WARNING ⚠️ {im_file}: corrupt JPEG restored and saved'
83
+
84
+ # Verify labels
85
+ if os.path.isfile(lb_file):
86
+ nf = 1 # label found
87
+ with open(lb_file) as f:
88
+ lb = [x.split() for x in f.read().strip().splitlines() if len(x)]
89
+ if any(len(x) > 6 for x in lb) and (not keypoint): # is segment
90
+ classes = np.array([x[0] for x in lb], dtype=np.float32)
91
+ segments = [np.array(x[1:], dtype=np.float32).reshape(-1, 2) for x in lb] # (cls, xy1...)
92
+ lb = np.concatenate((classes.reshape(-1, 1), segments2boxes(segments)), 1) # (cls, xywh)
93
+ lb = np.array(lb, dtype=np.float32)
94
+ nl = len(lb)
95
+ if nl:
96
+ if keypoint:
97
+ assert lb.shape[1] == (5 + nkpt * ndim), f'labels require {(5 + nkpt * ndim)} columns each'
98
+ assert (lb[:, 5::ndim] <= 1).all(), 'non-normalized or out of bounds coordinate labels'
99
+ assert (lb[:, 6::ndim] <= 1).all(), 'non-normalized or out of bounds coordinate labels'
100
+ else:
101
+ assert lb.shape[1] == 5, f'labels require 5 columns, {lb.shape[1]} columns detected'
102
+ assert (lb[:, 1:] <= 1).all(), \
103
+ f'non-normalized or out of bounds coordinates {lb[:, 1:][lb[:, 1:] > 1]}'
104
+ assert (lb >= 0).all(), f'negative label values {lb[lb < 0]}'
105
+ # All labels
106
+ max_cls = int(lb[:, 0].max()) # max label count
107
+ assert max_cls <= num_cls, \
108
+ f'Label class {max_cls} exceeds dataset class count {num_cls}. ' \
109
+ f'Possible class labels are 0-{num_cls - 1}'
110
+ _, i = np.unique(lb, axis=0, return_index=True)
111
+ if len(i) < nl: # duplicate row check
112
+ lb = lb[i] # remove duplicates
113
+ if segments:
114
+ segments = [segments[x] for x in i]
115
+ msg = f'{prefix}WARNING ⚠️ {im_file}: {nl - len(i)} duplicate labels removed'
116
+ else:
117
+ ne = 1 # label empty
118
+ lb = np.zeros((0, (5 + nkpt * ndim)), dtype=np.float32) if keypoint else np.zeros(
119
+ (0, 5), dtype=np.float32)
120
+ else:
121
+ nm = 1 # label missing
122
+ lb = np.zeros((0, (5 + nkpt * ndim)), dtype=np.float32) if keypoint else np.zeros((0, 5), dtype=np.float32)
123
+ if keypoint:
124
+ keypoints = lb[:, 5:].reshape(-1, nkpt, ndim)
125
+ if ndim == 2:
126
+ kpt_mask = np.ones(keypoints.shape[:2], dtype=np.float32)
127
+ kpt_mask = np.where(keypoints[..., 0] < 0, 0.0, kpt_mask)
128
+ kpt_mask = np.where(keypoints[..., 1] < 0, 0.0, kpt_mask)
129
+ keypoints = np.concatenate([keypoints, kpt_mask[..., None]], axis=-1) # (nl, nkpt, 3)
130
+ lb = lb[:, :5]
131
+ return im_file, lb, shape, segments, keypoints, nm, nf, ne, nc, msg
132
+ except Exception as e:
133
+ nc = 1
134
+ msg = f'{prefix}WARNING ⚠️ {im_file}: ignoring corrupt image/label: {e}'
135
+ return [None, None, None, None, None, nm, nf, ne, nc, msg]
136
+
137
+
138
+ def polygon2mask(imgsz, polygons, color=1, downsample_ratio=1):
139
+ """
140
+ Args:
141
+ imgsz (tuple): The image size.
142
+ polygons (list[np.ndarray]): [N, M], N is the number of polygons, M is the number of points(Be divided by 2).
143
+ color (int): color
144
+ downsample_ratio (int): downsample ratio
145
+ """
146
+ mask = np.zeros(imgsz, dtype=np.uint8)
147
+ polygons = np.asarray(polygons)
148
+ polygons = polygons.astype(np.int32)
149
+ shape = polygons.shape
150
+ polygons = polygons.reshape(shape[0], -1, 2)
151
+ cv2.fillPoly(mask, polygons, color=color)
152
+ nh, nw = (imgsz[0] // downsample_ratio, imgsz[1] // downsample_ratio)
153
+ # NOTE: fillPoly firstly then resize is trying the keep the same way
154
+ # of loss calculation when mask-ratio=1.
155
+ mask = cv2.resize(mask, (nw, nh))
156
+ return mask
157
+
158
+
159
+ def polygons2masks(imgsz, polygons, color, downsample_ratio=1):
160
+ """
161
+ Args:
162
+ imgsz (tuple): The image size.
163
+ polygons (list[np.ndarray]): each polygon is [N, M], N is number of polygons, M is number of points (M % 2 = 0)
164
+ color (int): color
165
+ downsample_ratio (int): downsample ratio
166
+ """
167
+ masks = []
168
+ for si in range(len(polygons)):
169
+ mask = polygon2mask(imgsz, [polygons[si].reshape(-1)], color, downsample_ratio)
170
+ masks.append(mask)
171
+ return np.array(masks)
172
+
173
+
174
+ def polygons2masks_overlap(imgsz, segments, downsample_ratio=1):
175
+ """Return a (640, 640) overlap mask."""
176
+ masks = np.zeros((imgsz[0] // downsample_ratio, imgsz[1] // downsample_ratio),
177
+ dtype=np.int32 if len(segments) > 255 else np.uint8)
178
+ areas = []
179
+ ms = []
180
+ for si in range(len(segments)):
181
+ mask = polygon2mask(imgsz, [segments[si].reshape(-1)], downsample_ratio=downsample_ratio, color=1)
182
+ ms.append(mask)
183
+ areas.append(mask.sum())
184
+ areas = np.asarray(areas)
185
+ index = np.argsort(-areas)
186
+ ms = np.array(ms)[index]
187
+ for i in range(len(segments)):
188
+ mask = ms[i] * (i + 1)
189
+ masks = masks + mask
190
+ masks = np.clip(masks, a_min=0, a_max=i + 1)
191
+ return masks, index
192
+
193
+
194
+ def check_det_dataset(dataset, autodownload=True):
195
+ """Download, check and/or unzip dataset if not found locally."""
196
+ data = check_file(dataset)
197
+
198
+ # Download (optional)
199
+ extract_dir = ''
200
+ if isinstance(data, (str, Path)) and (zipfile.is_zipfile(data) or is_tarfile(data)):
201
+ new_dir = safe_download(data, dir=DATASETS_DIR, unzip=True, delete=False, curl=False)
202
+ data = next((DATASETS_DIR / new_dir).rglob('*.yaml'))
203
+ extract_dir, autodownload = data.parent, False
204
+
205
+ # Read yaml (optional)
206
+ if isinstance(data, (str, Path)):
207
+ data = yaml_load(data, append_filename=True) # dictionary
208
+
209
+ # Checks
210
+ for k in 'train', 'val':
211
+ if k not in data:
212
+ raise SyntaxError(
213
+ emojis(f"{dataset} '{k}:' key missing ❌.\n'train' and 'val' are required in all data YAMLs."))
214
+ if 'names' not in data and 'nc' not in data:
215
+ raise SyntaxError(emojis(f"{dataset} key missing ❌.\n either 'names' or 'nc' are required in all data YAMLs."))
216
+ if 'names' in data and 'nc' in data and len(data['names']) != data['nc']:
217
+ raise SyntaxError(emojis(f"{dataset} 'names' length {len(data['names'])} and 'nc: {data['nc']}' must match."))
218
+ if 'names' not in data:
219
+ data['names'] = [f'class_{i}' for i in range(data['nc'])]
220
+ else:
221
+ data['nc'] = len(data['names'])
222
+
223
+ data['names'] = check_class_names(data['names'])
224
+
225
+ # Resolve paths
226
+ path = Path(extract_dir or data.get('path') or Path(data.get('yaml_file', '')).parent) # dataset root
227
+
228
+ if not path.is_absolute():
229
+ path = (DATASETS_DIR / path).resolve()
230
+ data['path'] = path # download scripts
231
+ for k in 'train', 'val', 'test':
232
+ if data.get(k): # prepend path
233
+ if isinstance(data[k], str):
234
+ x = (path / data[k]).resolve()
235
+ if not x.exists() and data[k].startswith('../'):
236
+ x = (path / data[k][3:]).resolve()
237
+ data[k] = str(x)
238
+ else:
239
+ data[k] = [str((path / x).resolve()) for x in data[k]]
240
+
241
+ # Parse yaml
242
+ train, val, test, s = (data.get(x) for x in ('train', 'val', 'test', 'download'))
243
+ if val:
244
+ val = [Path(x).resolve() for x in (val if isinstance(val, list) else [val])] # val path
245
+ if not all(x.exists() for x in val):
246
+ name = clean_url(dataset) # dataset name with URL auth stripped
247
+ m = f"\nDataset '{name}' images not found ⚠️, missing path '{[x for x in val if not x.exists()][0]}'"
248
+ if s and autodownload:
249
+ LOGGER.warning(m)
250
+ else:
251
+ m += f"\nNote dataset download directory is '{DATASETS_DIR}'. You can update this in '{SETTINGS_YAML}'"
252
+ raise FileNotFoundError(m)
253
+ t = time.time()
254
+ if s.startswith('http') and s.endswith('.zip'): # URL
255
+ safe_download(url=s, dir=DATASETS_DIR, delete=True)
256
+ r = None # success
257
+ elif s.startswith('bash '): # bash script
258
+ LOGGER.info(f'Running {s} ...')
259
+ r = os.system(s)
260
+ else: # python script
261
+ r = exec(s, {'yaml': data}) # return None
262
+ dt = f'({round(time.time() - t, 1)}s)'
263
+ s = f"success ✅ {dt}, saved to {colorstr('bold', DATASETS_DIR)}" if r in (0, None) else f'failure {dt} ❌'
264
+ LOGGER.info(f'Dataset download {s}\n')
265
+ check_font('Arial.ttf' if is_ascii(data['names']) else 'Arial.Unicode.ttf') # download fonts
266
+
267
+ return data # dictionary
268
+
269
+
270
+ def check_cls_dataset(dataset: str, split=''):
271
+ """
272
+ Checks a classification dataset such as Imagenet.
273
+
274
+ This function accepts a `dataset` name and attempts to retrieve the corresponding dataset information.
275
+ If the dataset is not found locally, it attempts to download the dataset from the internet and save it locally.
276
+
277
+ Args:
278
+ dataset (str): The name of the dataset.
279
+ split (str, optional): The split of the dataset. Either 'val', 'test', or ''. Defaults to ''.
280
+
281
+ Returns:
282
+ (dict): A dictionary containing the following keys:
283
+ - 'train' (Path): The directory path containing the training set of the dataset.
284
+ - 'val' (Path): The directory path containing the validation set of the dataset.
285
+ - 'test' (Path): The directory path containing the test set of the dataset.
286
+ - 'nc' (int): The number of classes in the dataset.
287
+ - 'names' (dict): A dictionary of class names in the dataset.
288
+
289
+ Raises:
290
+ FileNotFoundError: If the specified dataset is not found and cannot be downloaded.
291
+ """
292
+
293
+ dataset = Path(dataset)
294
+ data_dir = (dataset if dataset.is_dir() else (DATASETS_DIR / dataset)).resolve()
295
+ if not data_dir.is_dir():
296
+ LOGGER.info(f'\nDataset not found ⚠️, missing path {data_dir}, attempting download...')
297
+ t = time.time()
298
+ if str(dataset) == 'imagenet':
299
+ subprocess.run(f"bash {ROOT / 'data/scripts/get_imagenet.sh'}", shell=True, check=True)
300
+ else:
301
+ url = f'https://github.com/ultralytics/yolov5/releases/download/v1.0/{dataset}.zip'
302
+ download(url, dir=data_dir.parent)
303
+ s = f"Dataset download success ✅ ({time.time() - t:.1f}s), saved to {colorstr('bold', data_dir)}\n"
304
+ LOGGER.info(s)
305
+ train_set = data_dir / 'train'
306
+ val_set = data_dir / 'val' if (data_dir / 'val').exists() else None # data/test or data/val
307
+ test_set = data_dir / 'test' if (data_dir / 'test').exists() else None # data/val or data/test
308
+ if split == 'val' and not val_set:
309
+ LOGGER.info("WARNING ⚠️ Dataset 'split=val' not found, using 'split=test' instead.")
310
+ elif split == 'test' and not test_set:
311
+ LOGGER.info("WARNING ⚠️ Dataset 'split=test' not found, using 'split=val' instead.")
312
+
313
+ nc = len([x for x in (data_dir / 'train').glob('*') if x.is_dir()]) # number of classes
314
+ names = [x.name for x in (data_dir / 'train').iterdir() if x.is_dir()] # class names list
315
+ names = dict(enumerate(sorted(names)))
316
+ return {'train': train_set, 'val': val_set or test_set, 'test': test_set or val_set, 'nc': nc, 'names': names}
317
+
318
+
319
+ class HUBDatasetStats():
320
+ """
321
+ A class for generating HUB dataset JSON and `-hub` dataset directory.
322
+
323
+ Args:
324
+ path (str): Path to data.yaml or data.zip (with data.yaml inside data.zip). Default is 'coco128.yaml'.
325
+ task (str): Dataset task. Options are 'detect', 'segment', 'pose', 'classify'. Default is 'detect'.
326
+ autodownload (bool): Attempt to download dataset if not found locally. Default is False.
327
+
328
+ Usage
329
+ from ultralytics.data.utils import HUBDatasetStats
330
+ stats = HUBDatasetStats('/Users/glennjocher/Downloads/coco8.zip', task='detect') # detect dataset
331
+ stats = HUBDatasetStats('/Users/glennjocher/Downloads/coco8-seg.zip', task='segment') # segment dataset
332
+ stats = HUBDatasetStats('/Users/glennjocher/Downloads/coco8-pose.zip', task='pose') # pose dataset
333
+ stats.get_json(save=False)
334
+ stats.process_images()
335
+ """
336
+
337
+ def __init__(self, path='coco128.yaml', task='detect', autodownload=False):
338
+ """Initialize class."""
339
+ LOGGER.info(f'Starting HUB dataset checks for {path}....')
340
+ zipped, data_dir, yaml_path = self._unzip(Path(path))
341
+ try:
342
+ # data = yaml_load(check_yaml(yaml_path)) # data dict
343
+ data = check_det_dataset(yaml_path, autodownload) # data dict
344
+ if zipped:
345
+ data['path'] = data_dir
346
+ except Exception as e:
347
+ raise Exception('error/HUB/dataset_stats/yaml_load') from e
348
+
349
+ self.hub_dir = Path(str(data['path']) + '-hub')
350
+ self.im_dir = self.hub_dir / 'images'
351
+ self.im_dir.mkdir(parents=True, exist_ok=True) # makes /images
352
+ self.stats = {'nc': len(data['names']), 'names': list(data['names'].values())} # statistics dictionary
353
+ self.data = data
354
+ self.task = task # detect, segment, pose, classify
355
+
356
+ @staticmethod
357
+ def _find_yaml(dir):
358
+ """Return data.yaml file."""
359
+ files = list(dir.glob('*.yaml')) or list(dir.rglob('*.yaml')) # try root level first and then recursive
360
+ assert files, f'No *.yaml file found in {dir}'
361
+ if len(files) > 1:
362
+ files = [f for f in files if f.stem == dir.stem] # prefer *.yaml files that match dir name
363
+ assert files, f'Multiple *.yaml files found in {dir}, only 1 *.yaml file allowed'
364
+ assert len(files) == 1, f'Multiple *.yaml files found: {files}, only 1 *.yaml file allowed in {dir}'
365
+ return files[0]
366
+
367
+ def _unzip(self, path):
368
+ """Unzip data.zip."""
369
+ if not str(path).endswith('.zip'): # path is data.yaml
370
+ return False, None, path
371
+ unzip_dir = unzip_file(path, path=path.parent)
372
+ assert unzip_dir.is_dir(), f'Error unzipping {path}, {unzip_dir} not found. ' \
373
+ f'path/to/abc.zip MUST unzip to path/to/abc/'
374
+ return True, str(unzip_dir), self._find_yaml(unzip_dir) # zipped, data_dir, yaml_path
375
+
376
+ def _hub_ops(self, f):
377
+ """Saves a compressed image for HUB previews."""
378
+ compress_one_image(f, self.im_dir / Path(f).name) # save to dataset-hub
379
+
380
+ def get_json(self, save=False, verbose=False):
381
+ """Return dataset JSON for Ultralytics HUB."""
382
+ from ultralytics.data import YOLODataset # ClassificationDataset
383
+
384
+ def _round(labels):
385
+ """Update labels to integer class and 4 decimal place floats."""
386
+ if self.task == 'detect':
387
+ coordinates = labels['bboxes']
388
+ elif self.task == 'segment':
389
+ coordinates = [x.flatten() for x in labels['segments']]
390
+ elif self.task == 'pose':
391
+ n = labels['keypoints'].shape[0]
392
+ coordinates = np.concatenate((labels['bboxes'], labels['keypoints'].reshape(n, -1)), 1)
393
+ else:
394
+ raise ValueError('Undefined dataset task.')
395
+ zipped = zip(labels['cls'], coordinates)
396
+ return [[int(c), *(round(float(x), 4) for x in points)] for c, points in zipped]
397
+
398
+ for split in 'train', 'val', 'test':
399
+ if self.data.get(split) is None:
400
+ self.stats[split] = None # i.e. no test set
401
+ continue
402
+
403
+ dataset = YOLODataset(img_path=self.data[split],
404
+ data=self.data,
405
+ use_segments=self.task == 'segment',
406
+ use_keypoints=self.task == 'pose')
407
+ x = np.array([
408
+ np.bincount(label['cls'].astype(int).flatten(), minlength=self.data['nc'])
409
+ for label in tqdm(dataset.labels, total=len(dataset), desc='Statistics')]) # shape(128x80)
410
+ self.stats[split] = {
411
+ 'instance_stats': {
412
+ 'total': int(x.sum()),
413
+ 'per_class': x.sum(0).tolist()},
414
+ 'image_stats': {
415
+ 'total': len(dataset),
416
+ 'unlabelled': int(np.all(x == 0, 1).sum()),
417
+ 'per_class': (x > 0).sum(0).tolist()},
418
+ 'labels': [{
419
+ Path(k).name: _round(v)} for k, v in zip(dataset.im_files, dataset.labels)]}
420
+
421
+ # Save, print and return
422
+ if save:
423
+ stats_path = self.hub_dir / 'stats.json'
424
+ LOGGER.info(f'Saving {stats_path.resolve()}...')
425
+ with open(stats_path, 'w') as f:
426
+ json.dump(self.stats, f) # save stats.json
427
+ if verbose:
428
+ LOGGER.info(json.dumps(self.stats, indent=2, sort_keys=False))
429
+ return self.stats
430
+
431
+ def process_images(self):
432
+ """Compress images for Ultralytics HUB."""
433
+ from ultralytics.data import YOLODataset # ClassificationDataset
434
+
435
+ for split in 'train', 'val', 'test':
436
+ if self.data.get(split) is None:
437
+ continue
438
+ dataset = YOLODataset(img_path=self.data[split], data=self.data)
439
+ with ThreadPool(NUM_THREADS) as pool:
440
+ for _ in tqdm(pool.imap(self._hub_ops, dataset.im_files), total=len(dataset), desc=f'{split} images'):
441
+ pass
442
+ LOGGER.info(f'Done. All images saved to {self.im_dir}')
443
+ return self.im_dir
444
+
445
+
446
+ def compress_one_image(f, f_new=None, max_dim=1920, quality=50):
447
+ """
448
+ Compresses a single image file to reduced size while preserving its aspect ratio and quality using either the
449
+ Python Imaging Library (PIL) or OpenCV library. If the input image is smaller than the maximum dimension, it will
450
+ not be resized.
451
+
452
+ Args:
453
+ f (str): The path to the input image file.
454
+ f_new (str, optional): The path to the output image file. If not specified, the input file will be overwritten.
455
+ max_dim (int, optional): The maximum dimension (width or height) of the output image. Default is 1920 pixels.
456
+ quality (int, optional): The image compression quality as a percentage. Default is 50%.
457
+
458
+ Usage:
459
+ from pathlib import Path
460
+ from ultralytics.data.utils import compress_one_image
461
+ for f in Path('/Users/glennjocher/Downloads/dataset').rglob('*.jpg'):
462
+ compress_one_image(f)
463
+ """
464
+ try: # use PIL
465
+ im = Image.open(f)
466
+ r = max_dim / max(im.height, im.width) # ratio
467
+ if r < 1.0: # image too large
468
+ im = im.resize((int(im.width * r), int(im.height * r)))
469
+ im.save(f_new or f, 'JPEG', quality=quality, optimize=True) # save
470
+ except Exception as e: # use OpenCV
471
+ LOGGER.info(f'WARNING ⚠️ HUB ops PIL failure {f}: {e}')
472
+ im = cv2.imread(f)
473
+ im_height, im_width = im.shape[:2]
474
+ r = max_dim / max(im_height, im_width) # ratio
475
+ if r < 1.0: # image too large
476
+ im = cv2.resize(im, (int(im_width * r), int(im_height * r)), interpolation=cv2.INTER_AREA)
477
+ cv2.imwrite(str(f_new or f), im)
478
+
479
+
480
+ def delete_dsstore(path):
481
+ """
482
+ Deletes all ".DS_store" files under a specified directory.
483
+
484
+ Args:
485
+ path (str, optional): The directory path where the ".DS_store" files should be deleted.
486
+
487
+ Usage:
488
+ from ultralytics.data.utils import delete_dsstore
489
+ delete_dsstore('/Users/glennjocher/Downloads/dataset')
490
+
491
+ Note:
492
+ ".DS_store" files are created by the Apple operating system and contain metadata about folders and files. They
493
+ are hidden system files and can cause issues when transferring files between different operating systems.
494
+ """
495
+ # Delete Apple .DS_store files
496
+ files = list(Path(path).rglob('.DS_store'))
497
+ LOGGER.info(f'Deleting *.DS_store files: {files}')
498
+ for f in files:
499
+ f.unlink()
500
+
501
+
502
+ def zip_directory(dir, use_zipfile_library=True):
503
+ """
504
+ Zips a directory and saves the archive to the specified output path.
505
+
506
+ Args:
507
+ dir (str): The path to the directory to be zipped.
508
+ use_zipfile_library (bool): Whether to use zipfile library or shutil for zipping.
509
+
510
+ Usage:
511
+ from ultralytics.data.utils import zip_directory
512
+ zip_directory('/Users/glennjocher/Downloads/playground')
513
+
514
+ zip -r coco8-pose.zip coco8-pose
515
+ """
516
+ delete_dsstore(dir)
517
+ if use_zipfile_library:
518
+ dir = Path(dir)
519
+ with zipfile.ZipFile(dir.with_suffix('.zip'), 'w', zipfile.ZIP_DEFLATED) as zip_file:
520
+ for file_path in dir.glob('**/*'):
521
+ if file_path.is_file():
522
+ zip_file.write(file_path, file_path.relative_to(dir))
523
+ else:
524
+ import shutil
525
+ shutil.make_archive(dir, 'zip', dir)
526
+
527
+
528
+ def autosplit(path=DATASETS_DIR / 'coco128/images', weights=(0.9, 0.1, 0.0), annotated_only=False):
529
+ """
530
+ Autosplit a dataset into train/val/test splits and save the resulting splits into autosplit_*.txt files.
531
+
532
+ Args:
533
+ path (Path, optional): Path to images directory. Defaults to DATASETS_DIR / 'coco128/images'.
534
+ weights (list | tuple, optional): Train, validation, and test split fractions. Defaults to (0.9, 0.1, 0.0).
535
+ annotated_only (bool, optional): If True, only images with an associated txt file are used. Defaults to False.
536
+
537
+ Usage:
538
+ from utils.dataloaders import autosplit
539
+ autosplit()
540
+ """
541
+
542
+ path = Path(path) # images dir
543
+ files = sorted(x for x in path.rglob('*.*') if x.suffix[1:].lower() in IMG_FORMATS) # image files only
544
+ n = len(files) # number of files
545
+ random.seed(0) # for reproducibility
546
+ indices = random.choices([0, 1, 2], weights=weights, k=n) # assign each image to a split
547
+
548
+ txt = ['autosplit_train.txt', 'autosplit_val.txt', 'autosplit_test.txt'] # 3 txt files
549
+ for x in txt:
550
+ if (path.parent / x).exists():
551
+ (path.parent / x).unlink() # remove existing
552
+
553
+ LOGGER.info(f'Autosplitting images from {path}' + ', using *.txt labeled images only' * annotated_only)
554
+ for i, img in tqdm(zip(indices, files), total=n):
555
+ if not annotated_only or Path(img2label_paths([str(img)])[0]).exists(): # check label
556
+ with open(path.parent / txt[i], 'a') as f:
557
+ f.write(f'./{img.relative_to(path.parent).as_posix()}' + '\n') # add image to txt file
ultralytics/engine/__init__.py ADDED
File without changes
ultralytics/engine/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (168 Bytes). View file