pg56714 committed on
Commit d28c8e3 · verified · 1 Parent(s): 45a771e

Delete efficientvit

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. efficientvit/__init__.py +0 -0
  2. efficientvit/__pycache__/__init__.cpython-310.pyc +0 -0
  3. efficientvit/__pycache__/sam_model_zoo.cpython-310.pyc +0 -0
  4. efficientvit/apps/__init__.py +0 -0
  5. efficientvit/apps/__pycache__/__init__.cpython-310.pyc +0 -0
  6. efficientvit/apps/data_provider/__init__.py +0 -7
  7. efficientvit/apps/data_provider/__pycache__/__init__.cpython-310.pyc +0 -0
  8. efficientvit/apps/data_provider/__pycache__/base.cpython-310.pyc +0 -0
  9. efficientvit/apps/data_provider/augment/__init__.py +0 -6
  10. efficientvit/apps/data_provider/augment/__pycache__/__init__.cpython-310.pyc +0 -0
  11. efficientvit/apps/data_provider/augment/__pycache__/bbox.cpython-310.pyc +0 -0
  12. efficientvit/apps/data_provider/augment/__pycache__/color_aug.cpython-310.pyc +0 -0
  13. efficientvit/apps/data_provider/augment/bbox.py +0 -30
  14. efficientvit/apps/data_provider/augment/color_aug.py +0 -84
  15. efficientvit/apps/data_provider/base.py +0 -223
  16. efficientvit/apps/data_provider/random_resolution/__init__.py +0 -7
  17. efficientvit/apps/data_provider/random_resolution/__pycache__/__init__.cpython-310.pyc +0 -0
  18. efficientvit/apps/data_provider/random_resolution/__pycache__/controller.cpython-310.pyc +0 -0
  19. efficientvit/apps/data_provider/random_resolution/_data_loader.py +0 -1603
  20. efficientvit/apps/data_provider/random_resolution/_data_worker.py +0 -378
  21. efficientvit/apps/data_provider/random_resolution/controller.py +0 -94
  22. efficientvit/apps/setup.py +0 -141
  23. efficientvit/apps/trainer/__init__.py +0 -6
  24. efficientvit/apps/trainer/__pycache__/__init__.cpython-310.pyc +0 -0
  25. efficientvit/apps/trainer/__pycache__/base.cpython-310.pyc +0 -0
  26. efficientvit/apps/trainer/__pycache__/run_config.cpython-310.pyc +0 -0
  27. efficientvit/apps/trainer/base.py +0 -297
  28. efficientvit/apps/trainer/run_config.py +0 -121
  29. efficientvit/apps/utils/__init__.py +0 -12
  30. efficientvit/apps/utils/__pycache__/__init__.cpython-310.pyc +0 -0
  31. efficientvit/apps/utils/__pycache__/dist.cpython-310.pyc +0 -0
  32. efficientvit/apps/utils/__pycache__/ema.cpython-310.pyc +0 -0
  33. efficientvit/apps/utils/__pycache__/export.cpython-310.pyc +0 -0
  34. efficientvit/apps/utils/__pycache__/init.cpython-310.pyc +0 -0
  35. efficientvit/apps/utils/__pycache__/lr.cpython-310.pyc +0 -0
  36. efficientvit/apps/utils/__pycache__/metric.cpython-310.pyc +0 -0
  37. efficientvit/apps/utils/__pycache__/misc.cpython-310.pyc +0 -0
  38. efficientvit/apps/utils/__pycache__/opt.cpython-310.pyc +0 -0
  39. efficientvit/apps/utils/dist.py +0 -73
  40. efficientvit/apps/utils/ema.py +0 -50
  41. efficientvit/apps/utils/export.py +0 -47
  42. efficientvit/apps/utils/init.py +0 -68
  43. efficientvit/apps/utils/lr.py +0 -48
  44. efficientvit/apps/utils/metric.py +0 -37
  45. efficientvit/apps/utils/misc.py +0 -111
  46. efficientvit/apps/utils/opt.py +0 -31
  47. efficientvit/models/__init__.py +0 -0
  48. efficientvit/models/__pycache__/__init__.cpython-310.pyc +0 -0
  49. efficientvit/models/efficientvit/__init__.py +0 -8
  50. efficientvit/models/efficientvit/__pycache__/__init__.cpython-310.pyc +0 -0
efficientvit/__init__.py DELETED
File without changes
efficientvit/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (168 Bytes)
 
efficientvit/__pycache__/sam_model_zoo.cpython-310.pyc DELETED
Binary file (1.48 kB)
 
efficientvit/apps/__init__.py DELETED
File without changes
efficientvit/apps/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (173 Bytes)
 
efficientvit/apps/data_provider/__init__.py DELETED
@@ -1,7 +0,0 @@
- # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
- # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
- # International Conference on Computer Vision (ICCV), 2023
-
- from .augment import *
- from .base import *
- from .random_resolution import *
efficientvit/apps/data_provider/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (256 Bytes)
 
efficientvit/apps/data_provider/__pycache__/base.cpython-310.pyc DELETED
Binary file (6.39 kB)
 
efficientvit/apps/data_provider/augment/__init__.py DELETED
@@ -1,6 +0,0 @@
- # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
- # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
- # International Conference on Computer Vision (ICCV), 2023
-
- from .bbox import *
- from .color_aug import *
efficientvit/apps/data_provider/augment/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (237 Bytes)
 
efficientvit/apps/data_provider/augment/__pycache__/bbox.cpython-310.pyc DELETED
Binary file (800 Bytes)
 
efficientvit/apps/data_provider/augment/__pycache__/color_aug.cpython-310.pyc DELETED
Binary file (3.15 kB)
 
efficientvit/apps/data_provider/augment/bbox.py DELETED
@@ -1,30 +0,0 @@
- # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
- # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
- # International Conference on Computer Vision (ICCV), 2023
-
- import numpy as np
-
- __all__ = ["rand_bbox"]
-
-
- def rand_bbox(
-     h: int,
-     w: int,
-     lam: float,
-     rand_func: callable = np.random.uniform,
- ) -> tuple[int, int, int, int]:
-     """randomly sample bbox, used in cutmix"""
-     cut_rat = np.sqrt(1.0 - lam)
-     cut_w = w * cut_rat
-     cut_h = h * cut_rat
-
-     # uniform
-     cx = rand_func(0, w)
-     cy = rand_func(0, h)
-
-     bbx1 = int(np.clip(cx - cut_w / 2, 0, w))
-     bby1 = int(np.clip(cy - cut_h / 2, 0, h))
-     bbx2 = int(np.clip(cx + cut_w / 2, 0, w))
-     bby2 = int(np.clip(cy + cut_h / 2, 0, h))
-
-     return bbx1, bby1, bbx2, bby2
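For context, here is a minimal usage sketch of the rand_bbox helper deleted above. It is not part of this repository: the cutmix_pair wrapper is an illustrative assumption, and the import path only existed before this commit.

import numpy as np

# Path as it existed before this commit; rand_bbox is the function shown in the diff above.
from efficientvit.apps.data_provider.augment.bbox import rand_bbox

def cutmix_pair(img_a: np.ndarray, img_b: np.ndarray, lam: float) -> tuple[np.ndarray, float]:
    """Paste a random crop of img_b into img_a (CutMix-style) and return the adjusted lambda."""
    h, w = img_a.shape[:2]
    x1, y1, x2, y2 = rand_bbox(h, w, lam)
    mixed = img_a.copy()
    mixed[y1:y2, x1:x2] = img_b[y1:y2, x1:x2]
    # Recompute lambda from the actual box area, as is standard for CutMix.
    lam_adjusted = 1.0 - ((x2 - x1) * (y2 - y1)) / float(h * w)
    return mixed, lam_adjusted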
efficientvit/apps/data_provider/augment/color_aug.py DELETED
@@ -1,84 +0,0 @@
- # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
- # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
- # International Conference on Computer Vision (ICCV), 2023
-
- import numpy as np
- import torchvision.transforms as transforms
- from PIL import Image
- from timm.data.auto_augment import rand_augment_transform
-
- __all__ = ["ColorAug", "RandAug"]
-
-
- class ImageAug:
-     def aug_image(self, image: Image.Image) -> Image.Image:
-         raise NotImplementedError
-
-     def __call__(
-         self, feed_dict: dict or np.ndarray or Image.Image
-     ) -> dict or np.ndarray or Image.Image:
-         if isinstance(feed_dict, dict):
-             output_dict = feed_dict
-             image = feed_dict[self.key]
-         else:
-             output_dict = None
-             image = feed_dict
-         is_ndarray = isinstance(image, np.ndarray)
-         if is_ndarray:
-             image = Image.fromarray(image)
-
-         image = self.aug_image(image)
-
-         if is_ndarray:
-             image = np.array(image)
-
-         if output_dict is None:
-             return image
-         else:
-             output_dict[self.key] = image
-             return output_dict
-
-
- class ColorAug(transforms.ColorJitter, ImageAug):
-     def __init__(self, brightness=0, contrast=0, saturation=0, hue=0, key="data"):
-         super().__init__(
-             brightness=brightness,
-             contrast=contrast,
-             saturation=saturation,
-             hue=hue,
-         )
-         self.key = key
-
-     def aug_image(self, image: Image.Image) -> Image.Image:
-         return transforms.ColorJitter.forward(self, image)
-
-     def forward(
-         self, feed_dict: dict or np.ndarray or Image.Image
-     ) -> dict or np.ndarray or Image.Image:
-         return ImageAug.__call__(self, feed_dict)
-
-
- class RandAug(ImageAug):
-     def __init__(
-         self, config: dict[str, any], mean: tuple[float, float, float], key="data"
-     ):
-         n = config.get("n", 2)
-         m = config.get("m", 9)
-         mstd = config.get("mstd", 1.0)
-         inc = config.get("inc", 1)
-         tpct = config.get("tpct", 0.45)
-         config_str = f"rand-n{n}-m{m}-mstd{mstd}-inc{inc}"
-
-         aa_params = dict(
-             translate_pct=tpct,
-             img_mean=tuple([min(255, round(255 * x)) for x in mean]),
-             interpolation=Image.BICUBIC,
-         )
-         self.aug_op = rand_augment_transform(config_str, aa_params)
-         self.key = key
-
-     def aug_image(self, image: Image.Image) -> Image.Image:
-         return self.aug_op(image)
-
-     def __repr__(self):
-         return self.aug_op.__repr__()
efficientvit/apps/data_provider/base.py DELETED
@@ -1,223 +0,0 @@
- # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
- # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
- # International Conference on Computer Vision (ICCV), 2023
-
- import copy
- import warnings
-
- import torch.utils.data
- from torch.utils.data.distributed import DistributedSampler
-
- from efficientvit.apps.data_provider.random_resolution import RRSController
- from efficientvit.models.utils import val2tuple
-
- __all__ = ["parse_image_size", "random_drop_data", "DataProvider"]
-
-
- def parse_image_size(size: int or str) -> tuple[int, int]:
-     if isinstance(size, str):
-         size = [int(val) for val in size.split("-")]
-         return size[0], size[1]
-     else:
-         return val2tuple(size, 2)
-
-
- def random_drop_data(dataset, drop_size: int, seed: int, keys=("samples",)):
-     g = torch.Generator()
-     g.manual_seed(seed)  # set random seed before sampling validation set
-     rand_indexes = torch.randperm(len(dataset), generator=g).tolist()
-
-     dropped_indexes = rand_indexes[:drop_size]
-     remaining_indexes = rand_indexes[drop_size:]
-
-     dropped_dataset = copy.deepcopy(dataset)
-     for key in keys:
-         setattr(
-             dropped_dataset,
-             key,
-             [getattr(dropped_dataset, key)[idx] for idx in dropped_indexes],
-         )
-         setattr(dataset, key, [getattr(dataset, key)[idx] for idx in remaining_indexes])
-     return dataset, dropped_dataset
-
-
- class DataProvider:
-     data_keys = ("samples",)
-     mean_std = {"mean": [0.485, 0.456, 0.406], "std": [0.229, 0.224, 0.225]}
-     SUB_SEED = 937162211  # random seed for sampling subset
-     VALID_SEED = 2147483647  # random seed for the validation set
-
-     name: str
-
-     def __init__(
-         self,
-         train_batch_size: int,
-         test_batch_size: int or None,
-         valid_size: int or float or None,
-         n_worker: int,
-         image_size: int or list[int] or str or list[str],
-         num_replicas: int or None = None,
-         rank: int or None = None,
-         train_ratio: float or None = None,
-         drop_last: bool = False,
-     ):
-         warnings.filterwarnings("ignore")
-         super().__init__()
-
-         # batch_size & valid_size
-         self.train_batch_size = train_batch_size
-         self.test_batch_size = test_batch_size or self.train_batch_size
-         self.valid_size = valid_size
-
-         # image size
-         if isinstance(image_size, list):
-             self.image_size = [parse_image_size(size) for size in image_size]
-             self.image_size.sort()  # e.g., 160 -> 224
-             RRSController.IMAGE_SIZE_LIST = copy.deepcopy(self.image_size)
-             self.active_image_size = RRSController.ACTIVE_SIZE = self.image_size[-1]
-         else:
-             self.image_size = parse_image_size(image_size)
-             RRSController.IMAGE_SIZE_LIST = [self.image_size]
-             self.active_image_size = RRSController.ACTIVE_SIZE = self.image_size
-
-         # distributed configs
-         self.num_replicas = num_replicas
-         self.rank = rank
-
-         # build datasets
-         train_dataset, val_dataset, test_dataset = self.build_datasets()
-
-         if train_ratio is not None and train_ratio < 1.0:
-             assert 0 < train_ratio < 1
-             _, train_dataset = random_drop_data(
-                 train_dataset,
-                 int(train_ratio * len(train_dataset)),
-                 self.SUB_SEED,
-                 self.data_keys,
-             )
-
-         # build data loader
-         self.train = self.build_dataloader(
-             train_dataset, train_batch_size, n_worker, drop_last=drop_last, train=True
-         )
-         self.valid = self.build_dataloader(
-             val_dataset, test_batch_size, n_worker, drop_last=False, train=False
-         )
-         self.test = self.build_dataloader(
-             test_dataset, test_batch_size, n_worker, drop_last=False, train=False
-         )
-         if self.valid is None:
-             self.valid = self.test
-         self.sub_train = None
-
-     @property
-     def data_shape(self) -> tuple[int, ...]:
-         return 3, self.active_image_size[0], self.active_image_size[1]
-
-     def build_valid_transform(self, image_size: tuple[int, int] or None = None) -> any:
-         raise NotImplementedError
-
-     def build_train_transform(self, image_size: tuple[int, int] or None = None) -> any:
-         raise NotImplementedError
-
-     def build_datasets(self) -> tuple[any, any, any]:
-         raise NotImplementedError
-
-     def build_dataloader(
-         self,
-         dataset: any or None,
-         batch_size: int,
-         n_worker: int,
-         drop_last: bool,
-         train: bool,
-     ):
-         if dataset is None:
-             return None
-         if isinstance(self.image_size, list) and train:
-             from efficientvit.apps.data_provider.random_resolution._data_loader import \
-                 RRSDataLoader
-
-             dataloader_class = RRSDataLoader
-         else:
-             dataloader_class = torch.utils.data.DataLoader
-         if self.num_replicas is None:
-             return dataloader_class(
-                 dataset=dataset,
-                 batch_size=batch_size,
-                 shuffle=True,
-                 num_workers=n_worker,
-                 pin_memory=True,
-                 drop_last=drop_last,
-             )
-         else:
-             sampler = DistributedSampler(dataset, self.num_replicas, self.rank)
-             return dataloader_class(
-                 dataset=dataset,
-                 batch_size=batch_size,
-                 sampler=sampler,
-                 num_workers=n_worker,
-                 pin_memory=True,
-                 drop_last=drop_last,
-             )
-
-     def set_epoch(self, epoch: int) -> None:
-         RRSController.set_epoch(epoch, len(self.train))
-         if isinstance(self.train.sampler, DistributedSampler):
-             self.train.sampler.set_epoch(epoch)
-
-     def assign_active_image_size(self, new_size: int or tuple[int, int]) -> None:
-         self.active_image_size = val2tuple(new_size, 2)
-         new_transform = self.build_valid_transform(self.active_image_size)
-         # change the transform of the valid and test set
-         self.valid.dataset.transform = self.test.dataset.transform = new_transform
-
-     def sample_val_dataset(self, train_dataset, valid_transform) -> tuple[any, any]:
-         if self.valid_size is not None:
-             if 0 < self.valid_size < 1:
-                 valid_size = int(self.valid_size * len(train_dataset))
-             else:
-                 assert self.valid_size >= 1
-                 valid_size = int(self.valid_size)
-             train_dataset, val_dataset = random_drop_data(
-                 train_dataset,
-                 valid_size,
-                 self.VALID_SEED,
-                 self.data_keys,
-             )
-             val_dataset.transform = valid_transform
-         else:
-             val_dataset = None
-         return train_dataset, val_dataset
-
-     def build_sub_train_loader(self, n_samples: int, batch_size: int) -> any:
-         # used for resetting BN running statistics
-         if self.sub_train is None:
-             self.sub_train = {}
-         if self.active_image_size in self.sub_train:
-             return self.sub_train[self.active_image_size]
-
-         # construct dataset and dataloader
-         train_dataset = copy.deepcopy(self.train.dataset)
-         if n_samples < len(train_dataset):
-             _, train_dataset = random_drop_data(
-                 train_dataset,
-                 n_samples,
-                 self.SUB_SEED,
-                 self.data_keys,
-             )
-         RRSController.ACTIVE_SIZE = self.active_image_size
-         train_dataset.transform = self.build_train_transform(
-             image_size=self.active_image_size
-         )
-         data_loader = self.build_dataloader(
-             train_dataset, batch_size, self.train.num_workers, True, False
-         )
-
-         # pre-fetch data
-         self.sub_train[self.active_image_size] = [
-             data
-             for data in data_loader
-             for _ in range(max(1, n_samples // len(train_dataset)))
-         ]
-
-         return self.sub_train[self.active_image_size]
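To illustrate the contract of the abstract DataProvider above, here is a hedged sketch of a minimal subclass. It is not from this repository: ToyDataProvider, the FakeData datasets, and all parameter values are illustrative assumptions; only the DataProvider interface is taken from the diff above.

import torchvision.transforms as transforms
from torchvision.datasets import FakeData

# Path as it existed before this commit.
from efficientvit.apps.data_provider.base import DataProvider


class ToyDataProvider(DataProvider):
    name = "toy"

    def build_valid_transform(self, image_size=None):
        image_size = image_size or self.active_image_size
        return transforms.Compose([transforms.Resize(image_size), transforms.ToTensor()])

    def build_train_transform(self, image_size=None):
        # A real provider would add augmentation (e.g., RandAug above); reuse the valid transform here.
        return self.build_valid_transform(image_size)

    def build_datasets(self):
        train = FakeData(size=256, image_size=(3, 224, 224), transform=self.build_train_transform())
        test = FakeData(size=64, image_size=(3, 224, 224), transform=self.build_valid_transform())
        # No held-out validation split; DataProvider then falls back to the test loader.
        return train, None, test


provider = ToyDataProvider(
    train_batch_size=32,
    test_batch_size=64,
    valid_size=None,
    n_worker=0,
    image_size=224,  # a single resolution, so the plain torch DataLoader path is taken
)
for images, labels in provider.train:
    break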
efficientvit/apps/data_provider/random_resolution/__init__.py DELETED
@@ -1,7 +0,0 @@
- """Random resolution data loader compatible with multi-processing and distributed training.
-
- Replace Pytorch's DataLoader with RRSDataLoader to support random resolution
- at the training time, resolution sampling is controlled by RRSController
- """
-
- from .controller import *
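The docstring above describes RRSDataLoader as a drop-in replacement for torch.utils.data.DataLoader, with resolution sampling driven by RRSController. Below is a hedged sketch of that substitution; the dataset, sizes, and resolution list are assumptions, the import paths existed only before this commit, and the actual per-batch resizing is expected to happen inside the dataset transforms and controller (not reproduced here).

import torch
from torch.utils.data import TensorDataset

# Paths as they existed before this commit.
from efficientvit.apps.data_provider.random_resolution import RRSController
from efficientvit.apps.data_provider.random_resolution._data_loader import RRSDataLoader

# Candidate resolutions and the active one, mirroring how DataProvider configures RRSController.
RRSController.IMAGE_SIZE_LIST = [(192, 192), (224, 224)]
RRSController.ACTIVE_SIZE = (224, 224)

dataset = TensorDataset(torch.randn(128, 3, 224, 224), torch.randint(0, 10, (128,)))
loader = RRSDataLoader(dataset, batch_size=32, shuffle=True, num_workers=0)  # same arguments as DataLoader

for epoch in range(2):
    # Sample this epoch's resolution schedule, as DataProvider.set_epoch does above.
    RRSController.set_epoch(epoch, len(loader))
    for images, labels in loader:
        ...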
efficientvit/apps/data_provider/random_resolution/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (489 Bytes)
 
efficientvit/apps/data_provider/random_resolution/__pycache__/controller.cpython-310.pyc DELETED
Binary file (3.31 kB)
 
efficientvit/apps/data_provider/random_resolution/_data_loader.py DELETED
@@ -1,1603 +0,0 @@
1
- r"""This file is based on torch/utils/data/data_loader.py
2
-
3
- Definition of the DataLoader and associated iterators that subclass _BaseDataLoaderIter
4
-
5
- To support these two classes, in `./_utils` we define many utility methods and
6
- functions to be run in multiprocessing. E.g., the data loading worker loop is
7
- in `./_utils/worker.py`.
8
- """
9
-
10
- import functools
11
- import itertools
12
- import logging
13
- import multiprocessing as python_multiprocessing
14
- import os
15
- import queue
16
- import threading
17
- import warnings
18
- from typing import (Any, Callable, Generic, Iterable, List, Optional, Sequence,
19
- TypeVar, Union)
20
-
21
- import torch
22
- import torch.distributed as dist
23
- import torch.multiprocessing as multiprocessing
24
- import torch.utils.data.graph_settings
25
- from torch._utils import ExceptionWrapper
26
- from torch.utils.data import (BatchSampler, Dataset, IterableDataset,
27
- IterDataPipe, MapDataPipe, RandomSampler,
28
- Sampler, SequentialSampler, _utils)
29
- from torch.utils.data.datapipes.datapipe import (
30
- _IterDataPipeSerializationWrapper, _MapDataPipeSerializationWrapper)
31
-
32
- from ._data_worker import _worker_loop
33
-
34
- __all__ = ["RRSDataLoader"]
35
-
36
- T_co = TypeVar("T_co", covariant=True)
37
- T = TypeVar("T")
38
- _worker_init_fn_t = Callable[[int], None]
39
-
40
- # Ideally we would parameterize `DataLoader` by the return type of `collate_fn`, but there is currently no way to have that
41
- # type parameter set to a default value if the user doesn't pass in a custom 'collate_fn'.
42
- # See https://github.com/python/mypy/issues/3737.
43
- _collate_fn_t = Callable[[List[T]], Any]
44
-
45
-
46
- # These functions used to be defined in this file. However, it was moved to
47
- # _utils/collate.py. Although it is rather hard to access this from user land
48
- # (one has to explicitly directly `import torch.utils.data.dataloader`), there
49
- # probably is user code out there using it. This aliasing maintains BC in this
50
- # aspect.
51
- default_collate: _collate_fn_t = _utils.collate.default_collate
52
- default_convert = _utils.collate.default_convert
53
-
54
- get_worker_info = _utils.worker.get_worker_info
55
-
56
- logger = logging.getLogger(__name__)
57
-
58
-
59
- class _DatasetKind:
60
- Map = 0
61
- Iterable = 1
62
-
63
- @staticmethod
64
- def create_fetcher(kind, dataset, auto_collation, collate_fn, drop_last):
65
- if kind == _DatasetKind.Map:
66
- return _utils.fetch._MapDatasetFetcher(
67
- dataset, auto_collation, collate_fn, drop_last
68
- )
69
- else:
70
- return _utils.fetch._IterableDatasetFetcher(
71
- dataset, auto_collation, collate_fn, drop_last
72
- )
73
-
74
-
75
- class _InfiniteConstantSampler(Sampler):
76
- r"""Analogous to ``itertools.repeat(None, None)``.
77
- Used as sampler for :class:`~torch.utils.data.IterableDataset`.
78
-
79
- Args:
80
- data_source (Dataset): dataset to sample from
81
- """
82
-
83
- def __init__(self):
84
- super().__init__(None)
85
-
86
- def __iter__(self):
87
- while True:
88
- yield None
89
-
90
-
91
- def _get_distributed_settings():
92
- if dist.is_available() and dist.is_initialized():
93
- return dist.get_world_size(), dist.get_rank()
94
- else:
95
- return 1, 0
96
-
97
-
98
- def _sharding_worker_init_fn(worker_init_fn, world_size, rank_id, worker_id):
99
- global_worker_id = worker_id
100
- info = torch.utils.data.get_worker_info()
101
- assert info is not None
102
- total_workers = info.num_workers
103
- datapipe = info.dataset
104
- assert isinstance(datapipe, (IterDataPipe, MapDataPipe))
105
- # To distribute elements across distributed process evenly, we should shard data on distributed
106
- # processes first then shard on worker processes
107
- total_workers *= world_size
108
- global_worker_id = global_worker_id * world_size + rank_id
109
- # For BC, use default SHARDING_PRIORITIES
110
- torch.utils.data.graph_settings.apply_sharding(
111
- datapipe, total_workers, global_worker_id
112
- )
113
- if worker_init_fn is not None:
114
- worker_init_fn(worker_id)
115
-
116
-
117
- def _share_dist_seed(generator, pg):
118
- _shared_seed = torch.empty((), dtype=torch.int64).random_(generator=generator)
119
- if isinstance(pg, dist.ProcessGroup):
120
- dist.broadcast(_shared_seed, src=0, group=pg)
121
- return _shared_seed.item()
122
-
123
-
124
- class RRSDataLoader(Generic[T_co]):
125
- r"""
126
- Data loader. Combines a dataset and a sampler, and provides an iterable over
127
- the given dataset.
128
-
129
- The :class:`~torch.utils.data.DataLoader` supports both map-style and
130
- iterable-style datasets with single- or multi-process loading, customizing
131
- loading order and optional automatic batching (collation) and memory pinning.
132
-
133
- See :py:mod:`torch.utils.data` documentation page for more details.
134
-
135
- Args:
136
- dataset (Dataset): dataset from which to load the data.
137
- batch_size (int, optional): how many samples per batch to load
138
- (default: ``1``).
139
- shuffle (bool, optional): set to ``True`` to have the data reshuffled
140
- at every epoch (default: ``False``).
141
- sampler (Sampler or Iterable, optional): defines the strategy to draw
142
- samples from the dataset. Can be any ``Iterable`` with ``__len__``
143
- implemented. If specified, :attr:`shuffle` must not be specified.
144
- batch_sampler (Sampler or Iterable, optional): like :attr:`sampler`, but
145
- returns a batch of indices at a time. Mutually exclusive with
146
- :attr:`batch_size`, :attr:`shuffle`, :attr:`sampler`,
147
- and :attr:`drop_last`.
148
- num_workers (int, optional): how many subprocesses to use for data
149
- loading. ``0`` means that the data will be loaded in the main process.
150
- (default: ``0``)
151
- collate_fn (Callable, optional): merges a list of samples to form a
152
- mini-batch of Tensor(s). Used when using batched loading from a
153
- map-style dataset.
154
- pin_memory (bool, optional): If ``True``, the data loader will copy Tensors
155
- into device/CUDA pinned memory before returning them. If your data elements
156
- are a custom type, or your :attr:`collate_fn` returns a batch that is a custom type,
157
- see the example below.
158
- drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
159
- if the dataset size is not divisible by the batch size. If ``False`` and
160
- the size of dataset is not divisible by the batch size, then the last batch
161
- will be smaller. (default: ``False``)
162
- timeout (numeric, optional): if positive, the timeout value for collecting a batch
163
- from workers. Should always be non-negative. (default: ``0``)
164
- worker_init_fn (Callable, optional): If not ``None``, this will be called on each
165
- worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as
166
- input, after seeding and before data loading. (default: ``None``)
167
- generator (torch.Generator, optional): If not ``None``, this RNG will be used
168
- by RandomSampler to generate random indexes and multiprocessing to generate
169
- `base_seed` for workers. (default: ``None``)
170
- prefetch_factor (int, optional, keyword-only arg): Number of batches loaded
171
- in advance by each worker. ``2`` means there will be a total of
172
- 2 * num_workers batches prefetched across all workers. (default value depends
173
- on the set value for num_workers. If value of num_workers=0 default is ``None``.
174
- Otherwise if value of num_workers>0 default is ``2``).
175
- persistent_workers (bool, optional): If ``True``, the data loader will not shutdown
176
- the worker processes after a dataset has been consumed once. This allows to
177
- maintain the workers `Dataset` instances alive. (default: ``False``)
178
- pin_memory_device (str, optional): the data loader will copy Tensors
179
- into device pinned memory before returning them if pin_memory is set to true.
180
-
181
-
182
- .. warning:: If the ``spawn`` start method is used, :attr:`worker_init_fn`
183
- cannot be an unpicklable object, e.g., a lambda function. See
184
- :ref:`multiprocessing-best-practices` on more details related
185
- to multiprocessing in PyTorch.
186
-
187
- .. warning:: ``len(dataloader)`` heuristic is based on the length of the sampler used.
188
- When :attr:`dataset` is an :class:`~torch.utils.data.IterableDataset`,
189
- it instead returns an estimate based on ``len(dataset) / batch_size``, with proper
190
- rounding depending on :attr:`drop_last`, regardless of multi-process loading
191
- configurations. This represents the best guess PyTorch can make because PyTorch
192
- trusts user :attr:`dataset` code in correctly handling multi-process
193
- loading to avoid duplicate data.
194
-
195
- However, if sharding results in multiple workers having incomplete last batches,
196
- this estimate can still be inaccurate, because (1) an otherwise complete batch can
197
- be broken into multiple ones and (2) more than one batch worth of samples can be
198
- dropped when :attr:`drop_last` is set. Unfortunately, PyTorch can not detect such
199
- cases in general.
200
-
201
- See `Dataset Types`_ for more details on these two types of datasets and how
202
- :class:`~torch.utils.data.IterableDataset` interacts with
203
- `Multi-process data loading`_.
204
-
205
- .. warning:: See :ref:`reproducibility`, and :ref:`dataloader-workers-random-seed`, and
206
- :ref:`data-loading-randomness` notes for random seed related questions.
207
- """
208
-
209
- dataset: Dataset[T_co]
210
- batch_size: Optional[int]
211
- num_workers: int
212
- pin_memory: bool
213
- drop_last: bool
214
- timeout: float
215
- sampler: Union[Sampler, Iterable]
216
- pin_memory_device: str
217
- prefetch_factor: Optional[int]
218
- _iterator: Optional["_BaseDataLoaderIter"]
219
- __initialized = False
220
-
221
- def __init__(
222
- self,
223
- dataset: Dataset[T_co],
224
- batch_size: Optional[int] = 1,
225
- shuffle: Optional[bool] = None,
226
- sampler: Union[Sampler, Iterable, None] = None,
227
- batch_sampler: Union[Sampler[Sequence], Iterable[Sequence], None] = None,
228
- num_workers: int = 0,
229
- collate_fn: Optional[_collate_fn_t] = None,
230
- pin_memory: bool = False,
231
- drop_last: bool = False,
232
- timeout: float = 0,
233
- worker_init_fn: Optional[_worker_init_fn_t] = None,
234
- multiprocessing_context=None,
235
- generator=None,
236
- *,
237
- prefetch_factor: Optional[int] = None,
238
- persistent_workers: bool = False,
239
- pin_memory_device: str = ""
240
- ):
241
- torch._C._log_api_usage_once("python.data_loader")
242
-
243
- if num_workers < 0:
244
- raise ValueError(
245
- "num_workers option should be non-negative; "
246
- "use num_workers=0 to disable multiprocessing."
247
- )
248
-
249
- if timeout < 0:
250
- raise ValueError("timeout option should be non-negative")
251
-
252
- if num_workers == 0 and prefetch_factor is not None:
253
- raise ValueError(
254
- "prefetch_factor option could only be specified in multiprocessing."
255
- "let num_workers > 0 to enable multiprocessing, otherwise set prefetch_factor to None."
256
- )
257
- elif num_workers > 0 and prefetch_factor is None:
258
- prefetch_factor = 2
259
- elif prefetch_factor is not None and prefetch_factor < 0:
260
- raise ValueError("prefetch_factor option should be non-negative")
261
-
262
- if persistent_workers and num_workers == 0:
263
- raise ValueError("persistent_workers option needs num_workers > 0")
264
-
265
- self.dataset = dataset
266
- self.num_workers = num_workers
267
- self.prefetch_factor = prefetch_factor
268
- self.pin_memory = pin_memory
269
- self.pin_memory_device = pin_memory_device
270
- self.timeout = timeout
271
- self.worker_init_fn = worker_init_fn
272
- self.multiprocessing_context = multiprocessing_context
273
-
274
- # Adds forward compatibilities so classic DataLoader can work with DataPipes:
275
- # _DataPipeSerializationWrapper container makes it easier to serialize without redefining pickler
276
- if isinstance(self.dataset, IterDataPipe):
277
- self.dataset = _IterDataPipeSerializationWrapper(self.dataset)
278
- elif isinstance(self.dataset, MapDataPipe):
279
- self.dataset = _MapDataPipeSerializationWrapper(self.dataset)
280
-
281
- # Arg-check dataset related before checking samplers because we want to
282
- # tell users that iterable-style datasets are incompatible with custom
283
- # samplers first, so that they don't learn that this combo doesn't work
284
- # after spending time fixing the custom sampler errors.
285
- if isinstance(dataset, IterableDataset):
286
- self._dataset_kind = _DatasetKind.Iterable
287
- # NOTE [ Custom Samplers and IterableDataset ]
288
- #
289
- # `IterableDataset` does not support custom `batch_sampler` or
290
- # `sampler` since the key is irrelevant (unless we support
291
- # generator-style dataset one day...).
292
- #
293
- # For `sampler`, we always create a dummy sampler. This is an
294
- # infinite sampler even when the dataset may have an implemented
295
- # finite `__len__` because in multi-process data loading, naive
296
- # settings will return duplicated data (which may be desired), and
297
- # thus using a sampler with length matching that of dataset will
298
- # cause data lost (you may have duplicates of the first couple
299
- # batches, but never see anything afterwards). Therefore,
300
- # `Iterabledataset` always uses an infinite sampler, an instance of
301
- # `_InfiniteConstantSampler` defined above.
302
- #
303
- # A custom `batch_sampler` essentially only controls the batch size.
304
- # However, it is unclear how useful it would be since an iterable-style
305
- # dataset can handle that within itself. Moreover, it is pointless
306
- # in multi-process data loading as the assignment order of batches
307
- # to workers is an implementation detail so users can not control
308
- # how to batchify each worker's iterable. Thus, we disable this
309
- # option. If this turns out to be useful in future, we can re-enable
310
- # this, and support custom samplers that specify the assignments to
311
- # specific workers.
312
- if isinstance(dataset, IterDataPipe):
313
- if shuffle is not None:
314
- dataset = torch.utils.data.graph_settings.apply_shuffle_settings(
315
- dataset, shuffle=shuffle
316
- )
317
- # We cannot check `shuffle is not None` here, since previously `shuffle=False` was the default.
318
- elif shuffle not in {False, None}:
319
- raise ValueError(
320
- "DataLoader with IterableDataset: expected unspecified "
321
- "shuffle option, but got shuffle={}".format(shuffle)
322
- )
323
-
324
- if sampler is not None:
325
- # See NOTE [ Custom Samplers and IterableDataset ]
326
- raise ValueError(
327
- "DataLoader with IterableDataset: expected unspecified "
328
- "sampler option, but got sampler={}".format(sampler)
329
- )
330
- elif batch_sampler is not None:
331
- # See NOTE [ Custom Samplers and IterableDataset ]
332
- raise ValueError(
333
- "DataLoader with IterableDataset: expected unspecified "
334
- "batch_sampler option, but got batch_sampler={}".format(
335
- batch_sampler
336
- )
337
- )
338
- else:
339
- shuffle = bool(shuffle)
340
- self._dataset_kind = _DatasetKind.Map
341
-
342
- if sampler is not None and shuffle:
343
- raise ValueError("sampler option is mutually exclusive with " "shuffle")
344
-
345
- if batch_sampler is not None:
346
- # auto_collation with custom batch_sampler
347
- if batch_size != 1 or shuffle or sampler is not None or drop_last:
348
- raise ValueError(
349
- "batch_sampler option is mutually exclusive "
350
- "with batch_size, shuffle, sampler, and "
351
- "drop_last"
352
- )
353
- batch_size = None
354
- drop_last = False
355
- elif batch_size is None:
356
- # no auto_collation
357
- if drop_last:
358
- raise ValueError(
359
- "batch_size=None option disables auto-batching "
360
- "and is mutually exclusive with drop_last"
361
- )
362
-
363
- if sampler is None: # give default samplers
364
- if self._dataset_kind == _DatasetKind.Iterable:
365
- # See NOTE [ Custom Samplers and IterableDataset ]
366
- sampler = _InfiniteConstantSampler()
367
- else: # map-style
368
- if shuffle:
369
- sampler = RandomSampler(dataset, generator=generator) # type: ignore[arg-type]
370
- else:
371
- sampler = SequentialSampler(dataset) # type: ignore[arg-type]
372
-
373
- if batch_size is not None and batch_sampler is None:
374
- # auto_collation without custom batch_sampler
375
- batch_sampler = BatchSampler(sampler, batch_size, drop_last)
376
-
377
- self.batch_size = batch_size
378
- self.drop_last = drop_last
379
- self.sampler = sampler
380
- self.batch_sampler = batch_sampler
381
- self.generator = generator
382
-
383
- if collate_fn is None:
384
- if self._auto_collation:
385
- collate_fn = _utils.collate.default_collate
386
- else:
387
- collate_fn = _utils.collate.default_convert
388
-
389
- self.collate_fn = collate_fn
390
- self.persistent_workers = persistent_workers
391
-
392
- self.__initialized = True
393
- self._IterableDataset_len_called = (
394
- None # See NOTE [ IterableDataset and __len__ ]
395
- )
396
-
397
- self._iterator = None
398
-
399
- self.check_worker_number_rationality()
400
-
401
- torch.set_vital("Dataloader", "enabled", "True") # type: ignore[attr-defined]
402
-
403
- def _get_iterator(self) -> "_BaseDataLoaderIter":
404
- if self.num_workers == 0:
405
- return _SingleProcessDataLoaderIter(self)
406
- else:
407
- self.check_worker_number_rationality()
408
- return _MultiProcessingDataLoaderIter(self)
409
-
410
- @property
411
- def multiprocessing_context(self):
412
- return self.__multiprocessing_context
413
-
414
- @multiprocessing_context.setter
415
- def multiprocessing_context(self, multiprocessing_context):
416
- if multiprocessing_context is not None:
417
- if self.num_workers > 0:
418
- if isinstance(multiprocessing_context, str):
419
- valid_start_methods = multiprocessing.get_all_start_methods()
420
- if multiprocessing_context not in valid_start_methods:
421
- raise ValueError(
422
- (
423
- "multiprocessing_context option "
424
- "should specify a valid start method in {!r}, but got "
425
- "multiprocessing_context={!r}"
426
- ).format(valid_start_methods, multiprocessing_context)
427
- )
428
- multiprocessing_context = multiprocessing.get_context(
429
- multiprocessing_context
430
- )
431
-
432
- if not isinstance(
433
- multiprocessing_context, python_multiprocessing.context.BaseContext
434
- ):
435
- raise TypeError(
436
- (
437
- "multiprocessing_context option should be a valid context "
438
- "object or a string specifying the start method, but got "
439
- "multiprocessing_context={}"
440
- ).format(multiprocessing_context)
441
- )
442
- else:
443
- raise ValueError(
444
- (
445
- "multiprocessing_context can only be used with "
446
- "multi-process loading (num_workers > 0), but got "
447
- "num_workers={}"
448
- ).format(self.num_workers)
449
- )
450
-
451
- self.__multiprocessing_context = multiprocessing_context
452
-
453
- def __setattr__(self, attr, val):
454
- if self.__initialized and attr in (
455
- "batch_size",
456
- "batch_sampler",
457
- "sampler",
458
- "drop_last",
459
- "dataset",
460
- "persistent_workers",
461
- ):
462
- raise ValueError(
463
- "{} attribute should not be set after {} is "
464
- "initialized".format(attr, self.__class__.__name__)
465
- )
466
-
467
- super().__setattr__(attr, val)
468
-
469
- # We quote '_BaseDataLoaderIter' since it isn't defined yet and the definition can't be moved up
470
- # since '_BaseDataLoaderIter' references 'DataLoader'.
471
- def __iter__(self) -> "_BaseDataLoaderIter":
472
- # When using a single worker the returned iterator should be
473
- # created everytime to avoid reseting its state
474
- # However, in the case of a multiple workers iterator
475
- # the iterator is only created once in the lifetime of the
476
- # DataLoader object so that workers can be reused
477
- if self.persistent_workers and self.num_workers > 0:
478
- if self._iterator is None:
479
- self._iterator = self._get_iterator()
480
- else:
481
- self._iterator._reset(self)
482
- return self._iterator
483
- else:
484
- return self._get_iterator()
485
-
486
- @property
487
- def _auto_collation(self):
488
- return self.batch_sampler is not None
489
-
490
- @property
491
- def _index_sampler(self):
492
- # The actual sampler used for generating indices for `_DatasetFetcher`
493
- # (see _utils/fetch.py) to read data at each time. This would be
494
- # `.batch_sampler` if in auto-collation mode, and `.sampler` otherwise.
495
- # We can't change `.sampler` and `.batch_sampler` attributes for BC
496
- # reasons.
497
- if self._auto_collation:
498
- return self.batch_sampler
499
- else:
500
- return self.sampler
501
-
502
- def __len__(self) -> int:
503
- if self._dataset_kind == _DatasetKind.Iterable:
504
- # NOTE [ IterableDataset and __len__ ]
505
- #
506
- # For `IterableDataset`, `__len__` could be inaccurate when one naively
507
- # does multi-processing data loading, since the samples will be duplicated.
508
- # However, no real use case should be actually using that behavior, so
509
- # it should count as a user error. We should generally trust user
510
- # code to do the proper thing (e.g., configure each replica differently
511
- # in `__iter__`), and give us the correct `__len__` if they choose to
512
- # implement it (this will still throw if the dataset does not implement
513
- # a `__len__`).
514
- #
515
- # To provide a further warning, we track if `__len__` was called on the
516
- # `DataLoader`, save the returned value in `self._len_called`, and warn
517
- # if the iterator ends up yielding more than this number of samples.
518
-
519
- # Cannot statically verify that dataset is Sized
520
- length = self._IterableDataset_len_called = len(self.dataset) # type: ignore[assignment, arg-type]
521
- if (
522
- self.batch_size is not None
523
- ): # IterableDataset doesn't allow custom sampler or batch_sampler
524
- from math import ceil
525
-
526
- if self.drop_last:
527
- length = length // self.batch_size
528
- else:
529
- length = ceil(length / self.batch_size)
530
- return length
531
- else:
532
- return len(self._index_sampler)
533
-
534
- def check_worker_number_rationality(self):
535
- # This function check whether the dataloader's worker number is rational based on
536
- # current system's resource. Current rule is that if the number of workers this
537
- # Dataloader will create is bigger than the number of logical cpus that is allowed to
538
- # use, than we will pop up a warning to let user pay attention.
539
- #
540
- # eg. If current system has 2 physical CPUs with 16 cores each. And each core support 2
541
- # threads, then the total logical cpus here is 2 * 16 * 2 = 64. Let's say current
542
- # DataLoader process can use half of them which is 32, then the rational max number of
543
- # worker that initiated from this process is 32.
544
- # Now, let's say the created DataLoader has num_works = 40, which is bigger than 32.
545
- # So the warning message is triggered to notify the user to lower the worker number if
546
- # necessary.
547
- #
548
- #
549
- # [Note] Please note that this function repects `cpuset` only when os.sched_getaffinity is
550
- # available (available in most of Linux system, but not OSX and Windows).
551
- # When os.sched_getaffinity is not available, os.cpu_count() is called instead, but
552
- # it doesn't repect cpuset.
553
- # We don't take threading into account since each worker process is single threaded
554
- # at this time.
555
- #
556
- # We don't set any threading flags (eg. OMP_NUM_THREADS, MKL_NUM_THREADS, etc)
557
- # other than `torch.set_num_threads` to 1 in the worker process, if the passing
558
- # in functions use 3rd party modules that rely on those threading flags to determine
559
- # how many thread to create (eg. numpy, etc), then it is caller's responsibility to
560
- # set those flags correctly.
561
- def _create_warning_msg(num_worker_suggest, num_worker_created, cpuset_checked):
562
-
563
- suggested_max_worker_msg = (
564
- (
565
- (
566
- "Our suggested max number of worker in current system is {}{}, which is smaller "
567
- "than what this DataLoader is going to create."
568
- ).format(
569
- num_worker_suggest,
570
- (
571
- ""
572
- if cpuset_checked
573
- else " (`cpuset` is not taken into account)"
574
- ),
575
- )
576
- )
577
- if num_worker_suggest is not None
578
- else (
579
- "DataLoader is not able to compute a suggested max number of worker in current system."
580
- )
581
- )
582
-
583
- warn_msg = (
584
- "This DataLoader will create {} worker processes in total. {} "
585
- "Please be aware that excessive worker creation might get DataLoader running slow or even freeze, "
586
- "lower the worker number to avoid potential slowness/freeze if necessary."
587
- ).format(num_worker_created, suggested_max_worker_msg)
588
- return warn_msg
589
-
590
- if not self.num_workers or self.num_workers == 0:
591
- return
592
-
593
- # try to compute a suggested max number of worker based on system's resource
594
- max_num_worker_suggest = None
595
- cpuset_checked = False
596
- if hasattr(os, "sched_getaffinity"):
597
- try:
598
- max_num_worker_suggest = len(os.sched_getaffinity(0))
599
- cpuset_checked = True
600
- except Exception:
601
- pass
602
- if max_num_worker_suggest is None:
603
- # os.cpu_count() could return Optional[int]
604
- # get cpu count first and check None in order to satify mypy check
605
- cpu_count = os.cpu_count()
606
- if cpu_count is not None:
607
- max_num_worker_suggest = cpu_count
608
-
609
- if max_num_worker_suggest is None:
610
- warnings.warn(
611
- _create_warning_msg(
612
- max_num_worker_suggest, self.num_workers, cpuset_checked
613
- )
614
- )
615
- return
616
-
617
- if self.num_workers > max_num_worker_suggest:
618
- warnings.warn(
619
- _create_warning_msg(
620
- max_num_worker_suggest, self.num_workers, cpuset_checked
621
- )
622
- )
623
-
624
-
625
- class _BaseDataLoaderIter:
626
- def __init__(self, loader: RRSDataLoader) -> None:
627
- self._dataset = loader.dataset
628
- self._shared_seed = None
629
- self._pg = None
630
- if isinstance(self._dataset, IterDataPipe):
631
- if dist.is_available() and dist.is_initialized():
632
- self._pg = dist.new_group(backend="gloo")
633
- self._shared_seed = _share_dist_seed(loader.generator, self._pg)
634
- shared_rng = torch.Generator()
635
- shared_rng.manual_seed(self._shared_seed)
636
- self._dataset = torch.utils.data.graph_settings.apply_random_seed(
637
- self._dataset, shared_rng
638
- )
639
- self._dataset_kind = loader._dataset_kind
640
- self._IterableDataset_len_called = loader._IterableDataset_len_called
641
- self._auto_collation = loader._auto_collation
642
- self._drop_last = loader.drop_last
643
- self._index_sampler = loader._index_sampler
644
- self._num_workers = loader.num_workers
645
- ws, rank = _get_distributed_settings()
646
- self._world_size = ws
647
- self._rank = rank
648
- # for other backends, pin_memory_device need to set. if not set
649
- # default behaviour is CUDA device. if pin_memory_device is selected
650
- # and pin_memory is not set, the default behaviour false.
651
- if len(loader.pin_memory_device) == 0:
652
- self._pin_memory = loader.pin_memory and torch.cuda.is_available()
653
- self._pin_memory_device = None
654
- else:
655
- if not loader.pin_memory:
656
- warn_msg = (
657
- "pin memory device is set and pin_memory flag is not used then device pinned memory won't be used"
658
- "please set pin_memory to true, if you need to use the device pin memory"
659
- )
660
- warnings.warn(warn_msg)
661
-
662
- self._pin_memory = loader.pin_memory
663
- self._pin_memory_device = loader.pin_memory_device
664
- self._timeout = loader.timeout
665
- self._collate_fn = loader.collate_fn
666
- self._sampler_iter = iter(self._index_sampler)
667
- self._base_seed = (
668
- torch.empty((), dtype=torch.int64)
669
- .random_(generator=loader.generator)
670
- .item()
671
- )
672
- self._persistent_workers = loader.persistent_workers
673
- self._num_yielded = 0
674
- self._profile_name = "enumerate(DataLoader)#{}.__next__".format(
675
- self.__class__.__name__
676
- )
677
-
678
- def __iter__(self) -> "_BaseDataLoaderIter":
679
- return self
680
-
681
- def _reset(self, loader, first_iter=False):
682
- self._sampler_iter = iter(self._index_sampler)
683
- self._num_yielded = 0
684
- self._IterableDataset_len_called = loader._IterableDataset_len_called
685
- if isinstance(self._dataset, IterDataPipe):
686
- self._shared_seed = _share_dist_seed(loader.generator, self._pg)
687
- shared_rng = torch.Generator()
688
- shared_rng.manual_seed(self._shared_seed)
689
- self._dataset = torch.utils.data.graph_settings.apply_random_seed(
690
- self._dataset, shared_rng
691
- )
692
-
693
- def _next_index(self):
694
- return next(self._sampler_iter) # may raise StopIteration
695
-
696
- def _next_data(self):
697
- raise NotImplementedError
698
-
699
- def __next__(self) -> Any:
700
- with torch.autograd.profiler.record_function(self._profile_name):
701
- if self._sampler_iter is None:
702
- # TODO(https://github.com/pytorch/pytorch/issues/76750)
703
- self._reset() # type: ignore[call-arg]
704
- data = self._next_data()
705
- self._num_yielded += 1
706
- if (
707
- self._dataset_kind == _DatasetKind.Iterable
708
- and self._IterableDataset_len_called is not None
709
- and self._num_yielded > self._IterableDataset_len_called
710
- ):
711
- warn_msg = (
712
- "Length of IterableDataset {} was reported to be {} (when accessing len(dataloader)), but {} "
713
- "samples have been fetched. "
714
- ).format(
715
- self._dataset, self._IterableDataset_len_called, self._num_yielded
716
- )
717
- if self._num_workers > 0:
718
- warn_msg += (
719
- "For multiprocessing data-loading, this could be caused by not properly configuring the "
720
- "IterableDataset replica at each worker. Please see "
721
- "https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset for examples."
722
- )
723
- warnings.warn(warn_msg)
724
- return data
725
-
726
- def __len__(self) -> int:
727
- return len(self._index_sampler)
728
-
729
- def __getstate__(self):
730
- # TODO: add limited pickling support for sharing an iterator
731
- # across multiple threads for HOGWILD.
732
- # Probably the best way to do this is by moving the sample pushing
733
- # to a separate thread and then just sharing the data queue
734
- # but signalling the end is tricky without a non-blocking API
735
- raise NotImplementedError("{} cannot be pickled", self.__class__.__name__)
736
-
737
-
738
- class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
739
- def __init__(self, loader):
740
- super().__init__(loader)
741
- assert self._timeout == 0
742
- assert self._num_workers == 0
743
-
744
- # Adds forward compatibilities so classic DataLoader can work with DataPipes:
745
- # Taking care of distributed sharding
746
- if isinstance(self._dataset, (IterDataPipe, MapDataPipe)):
747
- # For BC, use default SHARDING_PRIORITIES
748
- torch.utils.data.graph_settings.apply_sharding(
749
- self._dataset, self._world_size, self._rank
750
- )
751
-
752
- self._dataset_fetcher = _DatasetKind.create_fetcher(
753
- self._dataset_kind,
754
- self._dataset,
755
- self._auto_collation,
756
- self._collate_fn,
757
- self._drop_last,
758
- )
759
-
760
- def _next_data(self):
761
- index = self._next_index() # may raise StopIteration
762
- data = self._dataset_fetcher.fetch(index) # may raise StopIteration
763
- if self._pin_memory:
764
- data = _utils.pin_memory.pin_memory(data, self._pin_memory_device)
765
- return data
766
-
767
-
768
- class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
769
- r"""Iterates once over the DataLoader's dataset, as specified by the sampler"""
770
-
771
- # NOTE [ Data Loader Multiprocessing Shutdown Logic ]
772
- #
773
- # Preliminary:
774
- #
775
- # Our data model looks like this (queues are indicated with curly brackets):
776
- #
777
- # main process ||
778
- # | ||
779
- # {index_queue} ||
780
- # | ||
781
- # worker processes || DATA
782
- # | ||
783
- # {worker_result_queue} || FLOW
784
- # | ||
785
- # pin_memory_thread of main process || DIRECTION
786
- # | ||
787
- # {data_queue} ||
788
- # | ||
789
- # data output \/
790
- #
791
- # P.S. `worker_result_queue` and `pin_memory_thread` part may be omitted if
792
- # `pin_memory=False`.
793
- #
794
- #
795
- # Terminating multiprocessing logic requires very careful design. In
796
- # particular, we need to make sure that
797
- #
798
- # 1. The iterator gracefully exits the workers when its last reference is
799
- # gone or it is depleted.
800
- #
801
- # In this case, the workers should be gracefully exited because the
802
- # main process may still need to continue to run, and we want cleaning
803
- # up code in the workers to be executed (e.g., releasing GPU memory).
804
- # Naturally, we implement the shutdown logic in `__del__` of
805
- # DataLoaderIterator.
806
- #
807
- # We delay the discussion on the logic in this case until later.
808
- #
809
- # 2. The iterator exits the workers when the loader process and/or worker
810
- # processes exits normally or with error.
811
- #
812
- # We set all workers and `pin_memory_thread` to have `daemon=True`.
813
- #
814
- # You may ask, why can't we make the workers non-daemonic, and
815
- # gracefully exit using the same logic as we have in `__del__` when the
816
- # iterator gets deleted (see 1 above)?
817
- #
818
- # First of all, `__del__` is **not** guaranteed to be called when
819
- # interpreter exits. Even if it is called, by the time it executes,
820
- # many Python core library resources may alreay be freed, and even
821
- # simple things like acquiring an internal lock of a queue may hang.
822
- # Therefore, in this case, we actually need to prevent `__del__` from
823
- # being executed, and rely on the automatic termination of daemonic
824
- # children.
825
- #
826
- # Thus, we register an `atexit` hook that sets a global flag
827
- # `_utils.python_exit_status`. Since `atexit` hooks are executed in the
828
- # reverse order of registration, we are guaranteed that this flag is
829
- # set before library resources we use are freed (which, at least in
830
- # CPython, is done via an `atexit` handler defined in
831
- # `multiprocessing/util.py`
832
- # https://github.com/python/cpython/blob/c606624af8d4cb3b4a052fb263bb983b3f87585b/Lib/multiprocessing/util.py#L320-L362
833
- # registered when an object requiring this mechanism is first
834
- # created, e.g., `mp.Queue`
835
- # https://github.com/python/cpython/blob/c606624af8d4cb3b4a052fb263bb983b3f87585b/Lib/multiprocessing/context.py#L100-L103
836
- # https://github.com/python/cpython/blob/c606624af8d4cb3b4a052fb263bb983b3f87585b/Lib/multiprocessing/queues.py#L29
837
- # )
838
- #
839
- # So in `__del__`, we check if `_utils.python_exit_status` is set or
840
- # `None` (freed), and perform no-op if so.
841
- #
842
- # However, simply letting library clean-up codes run can also be bad,
843
- # because such codes (i.e., `multiprocessing.util._exit_function()`)
844
- # include join putting threads for `mp.Queue`, which can be blocking.
845
- # Hence, the main process putting threads are called with
846
- # `cancel_join_thread` at creation. See later section
847
- # [ 3b. A process won't hang when putting into a queue; ]
848
- # for more details.
849
- #
850
- # Here are two example cases where library clean-up codes can run
851
- # before `__del__` is called:
852
- #
853
- # 1. If we hold onto a reference to the iterator, it more often
854
- # than not tries to do `multiprocessing` library cleaning before
855
- # clearing the alive referenced objects (https://github.com/pytorch/pytorch/issues/48666)
856
- # and thus prevents our cleaning-up code to run first.
857
- #
858
- # 2. A similar issue araises when a `DataLoader` is used in a subprocess.
859
- # When a process ends, it shuts the all its daemonic children
860
- # down with a SIGTERM (instead of joining them without a timeout).
861
- # Simiarly for threads, but by a different mechanism. This fact,
862
- # together with a few implementation details of multiprocessing, forces
863
- # us to make workers daemonic. All of our problems arise when a
864
- # DataLoader is used in a subprocess, and are caused by multiprocessing
865
- # code which looks more or less like this:
866
- #
867
- # try:
868
- # your_function_using_a_dataloader()
869
- # finally:
870
- # multiprocessing.util._exit_function()
871
- #
872
- # The joining/termination mentioned above happens inside
873
- # `_exit_function()`. Now, if `your_function_using_a_dataloader()`
874
- # throws, the stack trace stored in the exception will prevent the
875
- # frame which uses `DataLoaderIter` to be freed. If the frame has any
876
- # reference to the `DataLoaderIter` (e.g., in a method of the iter),
877
- # its `__del__`, which starts the shutdown procedure, will not be
878
- # called. That, in turn, means that workers aren't notified. Attempting
879
- # to join in `_exit_function` will then result in a hang.
880
- #
881
- # For context, `_exit_function` is also registered as an `atexit` call.
882
- # So it is unclear to me (@ssnl) why this is needed in a finally block.
883
- # The code dates back to 2008 and there is no comment on the original
884
- # PEP 371 or patch https://bugs.python.org/issue3050 (containing both
885
- # the finally block and the `atexit` registration) that explains this.
886
- #
887
- #
888
- # Finally, another choice is to just shutdown workers with logic in 1
889
- # above whenever we see an error in `next`. This isn't ideal because
890
- # a. It prevents users from using try-catch to resume data loading.
891
- # b. It doesn't prevent hanging if users have references to the
892
- # iterator.
893
- #
894
- # 3. All processes exit if any of them die unexpectedly by fatal signals.
895
- #
896
- # As shown above, the workers are set as daemonic children of the main
897
- # process. However, automatic cleaning-up of such child processes only
898
- # happens if the parent process exits gracefully (e.g., not via fatal
899
- # signals like SIGKILL). So we must ensure that each process will exit
900
- # even the process that should send/receive data to/from it were
901
- # killed, i.e.,
902
- #
903
- # a. A process won't hang when getting from a queue.
904
- #
905
- # Even with carefully designed data dependencies (i.e., a `put()`
906
- # always corresponding to a `get()`), hanging on `get()` can still
907
- # happen when data in queue is corrupted (e.g., due to
908
- # `cancel_join_thread` or unexpected exit).
909
- #
910
- # For child exit, we set a timeout whenever we try to get data
911
- # from `data_queue`, and check the workers' status on each timeout
912
- # and error.
913
- # See `_DataLoaderiter._get_batch()` and
914
- # `_DataLoaderiter._try_get_data()` for details.
915
- #
916
- # Additionally, for child exit on non-Windows platforms, we also
917
- # register a SIGCHLD handler (which is supported on Windows) on
918
- # the main process, which checks if any of the workers fail in the
919
- # (Python) handler. This is more efficient and faster in detecting
920
- # worker failures, compared to only using the above mechanism.
921
- # See `DataLoader.cpp` and `_utils/signal_handling.py` for details.
922
- #
923
- # For `.get()` calls where the sender(s) is not the workers, we
924
- # guard them with timeouts, and check the status of the sender
925
- # when timeout happens:
926
- # + in the workers, the `_utils.worker.ManagerWatchdog` class
927
- # checks the status of the main process.
928
- # + if `pin_memory=True`, when getting from `pin_memory_thread`,
929
- # check `pin_memory_thread` status periodically until `.get()`
930
- # returns or see that `pin_memory_thread` died.
931
- #
932
- # b. A process won't hang when putting into a queue;
933
- #
934
- # We use `mp.Queue` which has a separate background thread to put
935
- # objects from an unbounded buffer array. The background thread is
936
- # daemonic and usually automatically joined when the process
937
- # *exits*.
938
- #
939
- # In case that the receiver has ended abruptly while
940
- # reading from the pipe, the join will hang forever. The usual
941
- # solution for this in Python is calling `q.cancel_join_thread`,
942
- # which prevents automatically joining it when finalizing
943
- # (exiting).
944
- #
945
- # Nonetheless, `cancel_join_thread` must only be called when the
946
- # queue is **not** going to be read from or written to by another
947
- # process, because it may hold onto a lock or leave corrupted data
948
- # in the queue, leading other readers/writers to hang.
949
- #
950
- # Hence,
951
- # + For worker processes, we only do so (for their output
952
- # queues, i.e., `worker_result_queue`) before exiting.
953
- # + For `pin_memory_thread`, its output queue `data_queue` is a
954
- # `queue.Queue` that does blocking `put` if the queue is full.
955
- # So there is no above problem, but as a result, in
956
- # `_pin_memory_loop`, we do need to wrap the `put` in a loop
957
- # that breaks not only upon success, but also when the main
958
- # process stops reading, i.e., is shutting down.
959
- # + For loader process, we `cancel_join_thread()` for all
960
- # `_index_queues` because the whole purpose of workers and
961
- # `pin_memory_thread` is to serve the loader process. If
962
- # loader process is already exiting, we don't really care if
963
- # the queues are corrupted.
964
- #
965
- #
966
- # Now let's get back to 1:
967
- # how we gracefully exit the workers when the last reference to the
968
- # iterator is gone.
969
- #
970
- # To achieve this, we implement the following logic along with the design
971
- # choices mentioned above:
972
- #
973
- # `workers_done_event`:
974
- # A `multiprocessing.Event` shared among the main process and all worker
975
- # processes. This is used to signal the workers that the iterator is
976
- # shutting down. After it is set, they will not send processed data to
977
- # queues anymore, and only wait for the final `None` before exiting.
978
- # `done_event` isn't strictly needed. I.e., we can just check for `None`
979
- # from the input queue, but it allows us to skip wasting resources
980
- # processing data if we are already shutting down.
981
- #
982
- # `pin_memory_thread_done_event`:
983
- # A `threading.Event` for a similar purpose to that of
984
- # `workers_done_event`, but is for the `pin_memory_thread`. The reason
985
- # that separate events are needed is that `pin_memory_thread` reads from
986
- # the output queue of the workers. But the workers, upon seeing that
987
- # `workers_done_event` is set, only want to see the final `None`, and are
988
- # not required to flush all data in the output queue (e.g., a worker may call
989
- # `cancel_join_thread` on that queue if its `IterableDataset` iterator
990
- # happens to exhaust coincidentally, which is out of the control of the
991
- # main process). Thus, since we will exit `pin_memory_thread` before the
992
- # workers (see below), two separate events are used.
993
- #
994
- # NOTE: In short, the protocol is that the main process will set these
995
- # `done_event`s and then the corresponding processes/threads a `None`,
996
- # and that they may exit at any time after receiving the `None`.
997
- #
998
- # NOTE: Using `None` as the final signal is valid, since normal data will
999
- # always be a 2-tuple with the 1st element being the index of the data
1000
- # transferred (different from dataset index/key), and the 2nd being
1001
- # either the dataset key or the data sample (depending on which part
1002
- # of the data model the queue is at).
1003
- #
1004
- # [ worker processes ]
1005
- # While loader process is alive:
1006
- # Get from `index_queue`.
1007
- # If we get anything (i.e., did not time out),
1008
- # Check `workers_done_event`.
1009
- # If set, continue to next iteration
1010
- # i.e., keep getting until see the `None`, then exit.
1011
- # Otherwise, process data:
1012
- # If is fetching from an `IterableDataset` and the iterator
1013
- # is exhausted, send an `_IterableDatasetStopIteration`
1014
- # object to signal iteration end. The main process, upon
1015
- # receiving such an object, will send `None` to this
1016
- # worker and not use the corresponding `index_queue`
1017
- # anymore.
1018
- # If timed out,
1019
- # Whether `workers_done_event` is set (still need to see `None`)
1020
- # or not, we must continue to the next iteration.
1021
- # (outside loop)
1022
- # If `workers_done_event` is set, (this can be False with `IterableDataset`)
1023
- # `data_queue.cancel_join_thread()`. (Everything is ending here:
1024
- # main process won't read from it;
1025
- # other workers will also call
1026
- # `cancel_join_thread`.)
1027
- #
1028
- # [ pin_memory_thread ]
1029
- # # No need to check main thread. If this thread is alive, the main loader
1030
- # # thread must be alive, because this thread is set as daemonic.
1031
- # While `pin_memory_thread_done_event` is not set:
1032
- # Get from `index_queue`.
1033
- # If timed out, continue to get in the next iteration.
1034
- # Otherwise, process data.
1035
- # While `pin_memory_thread_done_event` is not set:
1036
- # Put processed data to `data_queue` (a `queue.Queue` with blocking put)
1037
- # If timed out, continue to put in the next iteration.
1038
- # Otherwise, break, i.e., continue to the outer loop.
1039
- #
1040
- # NOTE: we don't check the status of the main thread because
1041
- # 1. if the process is killed by fatal signal, `pin_memory_thread`
1042
- # ends.
1043
- # 2. in other cases, either the cleaning-up in __del__ or the
1044
- # automatic exit of daemonic thread will take care of it.
1045
- # This won't busy-wait either because `.get(timeout)` does not
1046
- # busy-wait.
1047
- #
1048
- # [ main process ]
1049
- # In the DataLoader Iter's `__del__`
1050
- # b. Exit `pin_memory_thread`
1051
- # i. Set `pin_memory_thread_done_event`.
1052
- # ii. Put `None` in `worker_result_queue`.
1053
- # iii. Join the `pin_memory_thread`.
1054
- # iv. `worker_result_queue.cancel_join_thread()`.
1055
- #
1056
- # c. Exit the workers.
1057
- # i. Set `workers_done_event`.
1058
- # ii. Put `None` in each worker's `index_queue`.
1059
- # iii. Join the workers.
1060
- # iv. Call `.cancel_join_thread()` on each worker's `index_queue`.
1061
- #
1062
- # NOTE: (c) is better placed after (b) because it may leave corrupted
1063
- # data in `worker_result_queue`, which `pin_memory_thread`
1064
- # reads from, in which case the `pin_memory_thread`'s exit can only
1065
- # happen at timing out, which is slow. Nonetheless, the same thing
1066
- # happens if a worker is killed by signal at unfortunate times,
1067
- # but in other cases, we are better off having a non-corrupted
1068
- # `worker_result_queue` for `pin_memory_thread`.
1069
- #
1070
- # NOTE: If `pin_memory=False`, there is no `pin_memory_thread` and (b)
1071
- # can be omitted
1072
- #
1073
- # NB: `done_event`s aren't strictly needed. E.g., we can just check for
1074
- # `None` from `index_queue`, but it allows us to skip wasting resources
1075
- # processing indices already in `index_queue` if we are already shutting
1076
- # down.
1077
-
1078
- def __init__(self, loader):
1079
- super().__init__(loader)
1080
-
1081
- self._prefetch_factor = loader.prefetch_factor
1082
-
1083
- assert self._num_workers > 0
1084
- assert self._prefetch_factor > 0
1085
-
1086
- if loader.multiprocessing_context is None:
1087
- multiprocessing_context = multiprocessing
1088
- else:
1089
- multiprocessing_context = loader.multiprocessing_context
1090
-
1091
- self._worker_init_fn = loader.worker_init_fn
1092
-
1093
- # Adds forward compatibilities so classic DataLoader can work with DataPipes:
1094
- # Additional worker init function will take care of sharding in MP and Distributed
1095
- if isinstance(self._dataset, (IterDataPipe, MapDataPipe)):
1096
- self._worker_init_fn = functools.partial(
1097
- _sharding_worker_init_fn,
1098
- self._worker_init_fn,
1099
- self._world_size,
1100
- self._rank,
1101
- )
1102
-
1103
- # No certainty which module multiprocessing_context is
1104
- self._worker_result_queue = multiprocessing_context.Queue() # type: ignore[var-annotated]
1105
- self._worker_pids_set = False
1106
- self._shutdown = False
1107
- self._workers_done_event = multiprocessing_context.Event()
1108
-
1109
- self._index_queues = []
1110
- self._workers = []
1111
- for i in range(self._num_workers):
1112
- # No certainty which module multiprocessing_context is
1113
- index_queue = multiprocessing_context.Queue() # type: ignore[var-annotated]
1114
- # Need to `cancel_join_thread` here!
1115
- # See sections (2) and (3b) above.
1116
- index_queue.cancel_join_thread()
1117
- w = multiprocessing_context.Process(
1118
- target=_worker_loop,
1119
- args=(
1120
- self._dataset_kind,
1121
- self._dataset,
1122
- index_queue,
1123
- self._worker_result_queue,
1124
- self._workers_done_event,
1125
- self._auto_collation,
1126
- self._collate_fn,
1127
- self._drop_last,
1128
- self._base_seed,
1129
- self._worker_init_fn,
1130
- i,
1131
- self._num_workers,
1132
- self._persistent_workers,
1133
- self._shared_seed,
1134
- ),
1135
- )
1136
- w.daemon = True
1137
- # NB: Process.start() actually takes some time as it needs to
1138
- # start a process and pass the arguments over via a pipe.
1139
- # Therefore, we only add a worker to self._workers list after
1140
- # it started, so that we do not call .join() if program dies
1141
- # before it starts, and __del__ tries to join but will get:
1142
- # AssertionError: can only join a started process.
1143
- w.start()
1144
- self._index_queues.append(index_queue)
1145
- self._workers.append(w)
1146
-
1147
- if self._pin_memory:
1148
- self._pin_memory_thread_done_event = threading.Event()
1149
-
1150
- # Queue is not type-annotated
1151
- self._data_queue = queue.Queue() # type: ignore[var-annotated]
1152
- if self._pin_memory_device == "xpu":
1153
- current_device = torch.xpu.current_device() # type: ignore[attr-defined]
1154
- else:
1155
- current_device = torch.cuda.current_device() # choose cuda for default
1156
- pin_memory_thread = threading.Thread(
1157
- target=_utils.pin_memory._pin_memory_loop,
1158
- args=(
1159
- self._worker_result_queue,
1160
- self._data_queue,
1161
- current_device,
1162
- self._pin_memory_thread_done_event,
1163
- self._pin_memory_device,
1164
- ),
1165
- )
1166
- pin_memory_thread.daemon = True
1167
- pin_memory_thread.start()
1168
- # Similar to workers (see comment above), we only register
1169
- # pin_memory_thread once it is started.
1170
- self._pin_memory_thread = pin_memory_thread
1171
- else:
1172
- self._data_queue = self._worker_result_queue
1173
-
1174
- # In some rare cases, persistent workers (daemonic processes)
1175
- # would be terminated before `__del__` of iterator is invoked
1176
- # when main process exits
1177
- # It would cause failure when pin_memory_thread tries to read
1178
- # corrupted data from worker_result_queue
1179
- # atexit is used to shutdown thread and child processes in the
1180
- # right sequence before main process exits
1181
- if self._persistent_workers and self._pin_memory:
1182
- import atexit
1183
-
1184
- for w in self._workers:
1185
- atexit.register(_MultiProcessingDataLoaderIter._clean_up_worker, w)
1186
-
1187
- # .pid can be None only before process is spawned (not the case, so ignore)
1188
- _utils.signal_handling._set_worker_pids(id(self), tuple(w.pid for w in self._workers)) # type: ignore[misc]
1189
- _utils.signal_handling._set_SIGCHLD_handler()
1190
- self._worker_pids_set = True
1191
- self._reset(loader, first_iter=True)
1192
-
1193
- def _reset(self, loader, first_iter=False):
1194
- super()._reset(loader, first_iter)
1195
- self._send_idx = 0 # idx of the next task to be sent to workers
1196
- self._rcvd_idx = 0 # idx of the next task to be returned in __next__
1197
- # information about data not yet yielded, i.e., tasks w/ indices in range [rcvd_idx, send_idx).
1198
- # map: task idx => - (worker_id,) if data isn't fetched (outstanding)
1199
- # \ (worker_id, data) if data is already fetched (out-of-order)
1200
- self._task_info = {}
1201
- self._tasks_outstanding = (
1202
- 0 # always equal to count(v for v in task_info.values() if len(v) == 1)
1203
- )
1204
- # A list of booleans representing whether each worker still has work to
1205
- # do, i.e., not having exhausted its iterable dataset object. It always
1206
- # contains all `True`s if not using an iterable-style dataset
1207
- # (i.e., if kind != Iterable).
1208
- # Note that this indicates that a worker still has work to do *for this epoch*.
1209
- # It does not mean that a worker is dead. In case of `_persistent_workers`,
1210
- # the worker will be reset to available in the next epoch.
1211
- self._workers_status = [True for i in range(self._num_workers)]
1212
- # Reset the worker queue cycle so it resumes next epoch at worker 0
1213
- self._worker_queue_idx_cycle = itertools.cycle(range(self._num_workers))
1214
- # We resume the prefetching in case it was enabled
1215
- if not first_iter:
1216
- for idx in range(self._num_workers):
1217
- self._index_queues[idx].put(
1218
- _utils.worker._ResumeIteration(self._shared_seed)
1219
- )
1220
- resume_iteration_cnt = self._num_workers
1221
- while resume_iteration_cnt > 0:
1222
- return_idx, return_data = self._get_data()
1223
- if isinstance(return_idx, _utils.worker._ResumeIteration):
1224
- assert return_data is None
1225
- resume_iteration_cnt -= 1
1226
- # prime the prefetch loop
1227
- for _ in range(self._prefetch_factor * self._num_workers):
1228
- self._try_put_index()
1229
-
1230
- def _try_get_data(self, timeout=_utils.MP_STATUS_CHECK_INTERVAL):
1231
- # Tries to fetch data from `self._data_queue` once for a given timeout.
1232
- # This can also be used as inner loop of fetching without timeout, with
1233
- # the sender status as the loop condition.
1234
- #
1235
- # This raises a `RuntimeError` if any worker died unexpectedly. This error
1236
- # can come from either the SIGCHLD handler in `_utils/signal_handling.py`
1237
- # (only for non-Windows platforms), or the manual check below on errors
1238
- # and timeouts.
1239
- #
1240
- # Returns a 2-tuple:
1241
- # (bool: whether successfully get data, any: data if successful else None)
1242
- try:
1243
- data = self._data_queue.get(timeout=timeout)
1244
- return (True, data)
1245
- except Exception as e:
1246
- # At timeout and error, we manually check whether any worker has
1247
- # failed. Note that this is the only mechanism for Windows to detect
1248
- # worker failures.
1249
- failed_workers = []
1250
- for worker_id, w in enumerate(self._workers):
1251
- if self._workers_status[worker_id] and not w.is_alive():
1252
- failed_workers.append(w)
1253
- self._mark_worker_as_unavailable(worker_id)
1254
- if len(failed_workers) > 0:
1255
- pids_str = ", ".join(str(w.pid) for w in failed_workers)
1256
- raise RuntimeError(
1257
- "DataLoader worker (pid(s) {}) exited unexpectedly".format(pids_str)
1258
- ) from e
1259
- if isinstance(e, queue.Empty):
1260
- return (False, None)
1261
- import errno
1262
- import tempfile
1263
-
1264
- try:
1265
- # Raise an exception if we are this close to the FDs limit.
1266
- # Apparently, trying to open only one file is not a sufficient
1267
- # test.
1268
- # See NOTE [ DataLoader on Linux and open files limit ]
1269
- fds_limit_margin = 10
1270
- fs = [tempfile.NamedTemporaryFile() for i in range(fds_limit_margin)]
1271
- except OSError as e:
1272
- if e.errno == errno.EMFILE:
1273
- raise RuntimeError(
1274
- "Too many open files. Communication with the"
1275
- " workers is no longer possible. Please increase the"
1276
- " limit using `ulimit -n` in the shell or change the"
1277
- " sharing strategy by calling"
1278
- " `torch.multiprocessing.set_sharing_strategy('file_system')`"
1279
- " at the beginning of your code"
1280
- ) from None
1281
- raise
1282
-
1283
- # NOTE [ DataLoader on Linux and open files limit ]
1284
- #
1285
- # On Linux when DataLoader is used with multiprocessing we pass the data between
1286
- # the root process and the workers through SHM files. We remove those files from
1287
- # the filesystem as soon as they are created and keep them alive by
1288
- # passing around their file descriptors through AF_UNIX sockets. (See
1289
- # docs/source/multiprocessing.rst and 'Multiprocessing Technical Notes` in
1290
- # the wiki (https://github.com/pytorch/pytorch/wiki).)
1291
- #
1292
- # This sometimes leads us to exceeding the open files limit. When that happens,
1293
- # and the offending file descriptor is coming over a socket, the `socket` Python
1294
- # package silently strips the file descriptor from the message, setting only the
1295
- # `MSG_CTRUNC` flag (which might be a bit misleading since the manpage says that
1296
- # it _indicates that some control data were discarded due to lack of space in
1297
- # the buffer for ancillary data_). This might reflect the C implementation of
1298
- # AF_UNIX sockets.
1299
- #
1300
- # This behaviour can be reproduced with the script and instructions at the
1301
- # bottom of this note.
1302
- #
1303
- # When that happens, the standard Python `multiprocessing` (and not
1304
- # `torch.multiprocessing`) raises a `RuntimeError: received 0 items of ancdata`
1305
- #
1306
- # Sometimes, instead of the FD being stripped, you may get an `OSError:
1307
- # Too many open files`, both in the script below and in DataLoader. However,
1308
- # this is rare and seems to be nondeterministic.
1309
- #
1310
- #
1311
- # #!/usr/bin/env python3
1312
- # import sys
1313
- # import socket
1314
- # import os
1315
- # import array
1316
- # import shutil
1317
- # import socket
1318
- #
1319
- #
1320
- # if len(sys.argv) != 4:
1321
- # print("Usage: ", sys.argv[0], " tmp_dirname iteration (send|recv)")
1322
- # sys.exit(1)
1323
- #
1324
- # if __name__ == '__main__':
1325
- # dirname = sys.argv[1]
1326
- # sock_path = dirname + "/sock"
1327
- # iterations = int(sys.argv[2])
1328
- # def dummy_path(i):
1329
- # return dirname + "/" + str(i) + ".dummy"
1330
- #
1331
- #
1332
- # if sys.argv[3] == 'send':
1333
- # while not os.path.exists(sock_path):
1334
- # pass
1335
- # client = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM)
1336
- # client.connect(sock_path)
1337
- # for i in range(iterations):
1338
- # fd = os.open(dummy_path(i), os.O_WRONLY | os.O_CREAT)
1339
- # ancdata = array.array('i', [fd])
1340
- # msg = bytes([i % 256])
1341
- # print("Sending fd ", fd, " (iteration #", i, ")")
1342
- # client.sendmsg([msg], [(socket.SOL_SOCKET, socket.SCM_RIGHTS, ancdata)])
1343
- #
1344
- #
1345
- # else:
1346
- # assert sys.argv[3] == 'recv'
1347
- #
1348
- # if os.path.exists(dirname):
1349
- # raise Exception("Directory exists")
1350
- #
1351
- # os.mkdir(dirname)
1352
- #
1353
- # print("Opening socket...")
1354
- # server = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM)
1355
- # server.bind(sock_path)
1356
- #
1357
- # print("Listening...")
1358
- # for i in range(iterations):
1359
- # a = array.array('i')
1360
- # msg, ancdata, flags, addr = server.recvmsg(1, socket.CMSG_SPACE(a.itemsize))
1361
- # assert(len(ancdata) == 1)
1362
- # cmsg_level, cmsg_type, cmsg_data = ancdata[0]
1363
- # a.frombytes(cmsg_data)
1364
- # print("Received fd ", a[0], " (iteration #", i, ")")
1365
- #
1366
- # shutil.rmtree(dirname)
1367
- #
1368
- # Steps to reproduce:
1369
- #
1370
- # 1. Run two shells and set lower file descriptor limit in the receiving one:
1371
- # (shell1) ulimit -n 1020
1372
- # (shell2) ulimit -n 1022
1373
- #
1374
- # 2. Run the script above with the `recv` option in the first shell
1375
- # (shell1) ./test_socket.py sock_tmp 1017 recv
1376
- #
1377
- # 3. Run the script with the `send` option in the second shell:
1378
- # (shell2) ./test_socket.py sock_tmp 1017 send
1379
-
1380
- def _get_data(self):
1381
- # Fetches data from `self._data_queue`.
1382
- #
1383
- # We check workers' status every `MP_STATUS_CHECK_INTERVAL` seconds,
1384
- # which we achieve by running `self._try_get_data(timeout=MP_STATUS_CHECK_INTERVAL)`
1385
- # in a loop. This is the only mechanism to detect worker failures for
1386
- # Windows. For other platforms, a SIGCHLD handler is also used for
1387
- # worker failure detection.
1388
- #
1389
- # If `pin_memory=True`, we also need to check whether `pin_memory_thread` has
1390
- # died at timeouts.
1391
- if self._timeout > 0:
1392
- success, data = self._try_get_data(self._timeout)
1393
- if success:
1394
- return data
1395
- else:
1396
- raise RuntimeError(
1397
- "DataLoader timed out after {} seconds".format(self._timeout)
1398
- )
1399
- elif self._pin_memory:
1400
- while self._pin_memory_thread.is_alive():
1401
- success, data = self._try_get_data()
1402
- if success:
1403
- return data
1404
- else:
1405
- # while condition is false, i.e., pin_memory_thread died.
1406
- raise RuntimeError("Pin memory thread exited unexpectedly")
1407
- # In this case, `self._data_queue` is a `queue.Queue`. But we don't
1408
- # need to call `.task_done()` because we don't use `.join()`.
1409
- else:
1410
- while True:
1411
- success, data = self._try_get_data()
1412
- if success:
1413
- return data
1414
-
1415
- def _next_data(self):
1416
- while True:
1417
- # If the worker responsible for `self._rcvd_idx` has already ended
1418
- # and was unable to fulfill this task (due to exhausting an `IterableDataset`),
1419
- # we try to advance `self._rcvd_idx` to find the next valid index.
1420
- #
1421
- # This part needs to run in the loop because both the `self._get_data()`
1422
- # call and `_IterableDatasetStopIteration` check below can mark
1423
- # extra worker(s) as dead.
1424
- while self._rcvd_idx < self._send_idx:
1425
- info = self._task_info[self._rcvd_idx]
1426
- worker_id = info[0]
1427
- if (
1428
- len(info) == 2 or self._workers_status[worker_id]
1429
- ): # has data or is still active
1430
- break
1431
- del self._task_info[self._rcvd_idx]
1432
- self._rcvd_idx += 1
1433
- else:
1434
- # no valid `self._rcvd_idx` is found (i.e., didn't break)
1435
- if not self._persistent_workers:
1436
- self._shutdown_workers()
1437
- raise StopIteration
1438
-
1439
- # Now `self._rcvd_idx` is the batch index we want to fetch
1440
-
1441
- # Check if the next sample has already been generated
1442
- if len(self._task_info[self._rcvd_idx]) == 2:
1443
- data = self._task_info.pop(self._rcvd_idx)[1]
1444
- return self._process_data(data)
1445
-
1446
- assert not self._shutdown and self._tasks_outstanding > 0
1447
- idx, data = self._get_data()
1448
- self._tasks_outstanding -= 1
1449
- if self._dataset_kind == _DatasetKind.Iterable:
1450
- # Check for _IterableDatasetStopIteration
1451
- if isinstance(data, _utils.worker._IterableDatasetStopIteration):
1452
- if self._persistent_workers:
1453
- self._workers_status[data.worker_id] = False
1454
- else:
1455
- self._mark_worker_as_unavailable(data.worker_id)
1456
- self._try_put_index()
1457
- continue
1458
-
1459
- if idx != self._rcvd_idx:
1460
- # store out-of-order samples
1461
- self._task_info[idx] += (data,)
1462
- else:
1463
- del self._task_info[idx]
1464
- return self._process_data(data)
1465
-
1466
- def _try_put_index(self):
1467
- assert self._tasks_outstanding < self._prefetch_factor * self._num_workers
1468
-
1469
- try:
1470
- index = self._next_index()
1471
- except StopIteration:
1472
- return
1473
- for _ in range(self._num_workers): # find the next active worker, if any
1474
- worker_queue_idx = next(self._worker_queue_idx_cycle)
1475
- if self._workers_status[worker_queue_idx]:
1476
- break
1477
- else:
1478
- # not found (i.e., didn't break)
1479
- return
1480
-
1481
- self._index_queues[worker_queue_idx].put((self._send_idx, index))
1482
- self._task_info[self._send_idx] = (worker_queue_idx,)
1483
- self._tasks_outstanding += 1
1484
- self._send_idx += 1
1485
-
1486
- def _process_data(self, data):
1487
- self._rcvd_idx += 1
1488
- self._try_put_index()
1489
- if isinstance(data, ExceptionWrapper):
1490
- data.reraise()
1491
- return data
1492
-
1493
- def _mark_worker_as_unavailable(self, worker_id, shutdown=False):
1494
- # Mark a worker as having finished its work e.g., due to
1495
- # exhausting an `IterableDataset`. This should be used only when this
1496
- # `_MultiProcessingDataLoaderIter` is going to continue running.
1497
-
1498
- assert self._workers_status[worker_id] or (
1499
- self._persistent_workers and shutdown
1500
- )
1501
-
1502
- # Signal termination to that specific worker.
1503
- q = self._index_queues[worker_id]
1504
- # Indicate that no more data will be put on this queue by the current
1505
- # process.
1506
- q.put(None)
1507
-
1508
- # Note that we don't actually join the worker here, nor do we remove the
1509
- # worker's pid from C side struct because (1) joining may be slow, and
1510
- # (2) since we don't join, the worker may still raise error, and we
1511
- # prefer capturing those, rather than ignoring them, even though they
1512
- # are raised after the worker has finished its job.
1513
- # Joining is deferred to `_shutdown_workers`, which is called when
1514
- # all workers finish their jobs (e.g., `IterableDataset` replicas) or
1515
- # when this iterator is garbage collected.
1516
-
1517
- self._workers_status[worker_id] = False
1518
-
1519
- assert self._workers_done_event.is_set() == shutdown
1520
-
1521
- def _shutdown_workers(self):
1522
- # Called when shutting down this `_MultiProcessingDataLoaderIter`.
1523
- # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on
1524
- # the logic of this function.
1525
- if (
1526
- _utils is None
1527
- or _utils.python_exit_status is True
1528
- or _utils.python_exit_status is None
1529
- ):
1530
- # See (2) of the note. If Python is shutting down, do no-op.
1531
- return
1532
- # Normal exit when last reference is gone / iterator is depleted.
1533
- # See (1) and the second half of the note.
1534
- if not self._shutdown:
1535
- self._shutdown = True
1536
- try:
1537
- # Normal exit when last reference is gone / iterator is depleted.
1538
- # See (1) and the second half of the note.
1539
-
1540
- # Exit `pin_memory_thread` first because exiting workers may leave
1541
- # corrupted data in `worker_result_queue` which `pin_memory_thread`
1542
- # reads from.
1543
- if hasattr(self, "_pin_memory_thread"):
1544
- # Use hasattr in case error happens before we set the attribute.
1545
- self._pin_memory_thread_done_event.set()
1546
- # Send something to pin_memory_thread in case it is waiting
1547
- # so that it can wake up and check `pin_memory_thread_done_event`
1548
- self._worker_result_queue.put((None, None))
1549
- self._pin_memory_thread.join()
1550
- self._worker_result_queue.cancel_join_thread()
1551
- self._worker_result_queue.close()
1552
-
1553
- # Exit workers now.
1554
- self._workers_done_event.set()
1555
- for worker_id in range(len(self._workers)):
1556
- # Get number of workers from `len(self._workers)` instead of
1557
- # `self._num_workers` in case we error before starting all
1558
- # workers.
1559
- # If we are using workers_status with persistent_workers
1560
- # we have to shut it down because the worker is paused
1561
- if self._persistent_workers or self._workers_status[worker_id]:
1562
- self._mark_worker_as_unavailable(worker_id, shutdown=True)
1563
- for w in self._workers:
1564
- # We should be able to join here, but in case anything went
1565
- # wrong, we set a timeout and if the workers fail to join,
1566
- # they are killed in the `finally` block.
1567
- w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
1568
- for q in self._index_queues:
1569
- q.cancel_join_thread()
1570
- q.close()
1571
- finally:
1572
- # Even though all this function does is putting into queues that
1573
- # we have called `cancel_join_thread` on, weird things can
1574
- # happen when a worker is killed by a signal, e.g., hanging in
1575
- # `Event.set()`. So we need to guard this with SIGCHLD handler,
1576
- # and remove pids from the C side data structure only at the
1577
- # end.
1578
- #
1579
- # FIXME: Unfortunately, for Windows, we are missing a worker
1580
- # error detection mechanism here in this function, as it
1581
- # doesn't provide a SIGCHLD handler.
1582
- if self._worker_pids_set:
1583
- _utils.signal_handling._remove_worker_pids(id(self))
1584
- self._worker_pids_set = False
1585
- for w in self._workers:
1586
- if w.is_alive():
1587
- # Existing mechanisms try to make the workers exit
1588
- # peacefully, but in case that we unfortunately reach
1589
- # here, which we shouldn't, (e.g., pytorch/pytorch#39570),
1590
- # we kill the worker.
1591
- w.terminate()
1592
-
1593
- # staticmethod is used to remove reference to `_MultiProcessingDataLoaderIter`
1594
- @staticmethod
1595
- def _clean_up_worker(w):
1596
- try:
1597
- w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
1598
- finally:
1599
- if w.is_alive():
1600
- w.terminate()
1601
-
1602
- def __del__(self):
1603
- self._shutdown_workers()
 
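As an aside for readers of the long shutdown NOTE in the deleted iterator above: the handshake it describes (set a shared done event, then send a `None` sentinel, then join) can be illustrated with a minimal, self-contained sketch. The names `toy_worker`, `index_queue`, and `result_queue` below are illustrative only and are not part of the deleted module.

import multiprocessing as mp


def toy_worker(index_queue, result_queue, done_event):
    # Mirrors the worker-side protocol sketched in the NOTE: keep draining the
    # index queue; once `done_event` is set, stop producing results and only
    # wait for the final `None` before exiting.
    while True:
        task = index_queue.get()
        if task is None:
            break
        if done_event.is_set():
            continue
        result_queue.put(task * task)


if __name__ == "__main__":
    ctx = mp.get_context("spawn")
    index_queue, result_queue = ctx.Queue(), ctx.Queue()
    done_event = ctx.Event()
    w = ctx.Process(target=toy_worker, args=(index_queue, result_queue, done_event), daemon=True)
    w.start()
    for i in range(4):
        index_queue.put(i)
    print([result_queue.get() for _ in range(4)])  # [0, 1, 4, 9]
    done_event.set()        # i.   signal shutdown
    index_queue.put(None)   # ii.  send the final `None`
    w.join()                # iii. join the worker
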
efficientvit/apps/data_provider/random_resolution/_data_worker.py DELETED
@@ -1,378 +0,0 @@
1
- r""""This file is based on torch/utils/data/_utils/worker.py
2
-
3
- Contains definitions of the methods used by the _BaseDataLoaderIter workers.
4
- These **needs** to be in global scope since Py2 doesn't support serializing
5
- static methods.
6
- """
7
-
8
- import os
9
- import queue
10
- import random
11
- from dataclasses import dataclass
12
- from typing import TYPE_CHECKING, Optional, Union
13
-
14
- import torch
15
- from torch._utils import ExceptionWrapper
16
- from torch.utils.data._utils import (HAS_NUMPY, IS_WINDOWS,
17
- MP_STATUS_CHECK_INTERVAL, signal_handling)
18
-
19
- if TYPE_CHECKING:
20
- from torch.utils.data import Dataset
21
-
22
- from .controller import RRSController
23
-
24
- if IS_WINDOWS:
25
- import ctypes
26
- from ctypes.wintypes import BOOL, DWORD, HANDLE
27
-
28
- # On Windows, the parent ID of the worker process remains unchanged when the manager process
29
- # is gone, and the only way to check it through OS is to let the worker have a process handle
30
- # of the manager and ask if the process status has changed.
31
- class ManagerWatchdog:
32
- def __init__(self):
33
- self.manager_pid = os.getppid()
34
-
35
- # mypy cannot detect this code is windows only
36
- self.kernel32 = ctypes.WinDLL("kernel32", use_last_error=True) # type: ignore[attr-defined]
37
- self.kernel32.OpenProcess.argtypes = (DWORD, BOOL, DWORD)
38
- self.kernel32.OpenProcess.restype = HANDLE
39
- self.kernel32.WaitForSingleObject.argtypes = (HANDLE, DWORD)
40
- self.kernel32.WaitForSingleObject.restype = DWORD
41
-
42
- # Value obtained from https://msdn.microsoft.com/en-us/library/ms684880.aspx
43
- SYNCHRONIZE = 0x00100000
44
- self.manager_handle = self.kernel32.OpenProcess(
45
- SYNCHRONIZE, 0, self.manager_pid
46
- )
47
-
48
- if not self.manager_handle:
49
- raise ctypes.WinError(ctypes.get_last_error()) # type: ignore[attr-defined]
50
-
51
- self.manager_dead = False
52
-
53
- def is_alive(self):
54
- if not self.manager_dead:
55
- # Value obtained from https://msdn.microsoft.com/en-us/library/windows/desktop/ms687032.aspx
56
- self.manager_dead = (
57
- self.kernel32.WaitForSingleObject(self.manager_handle, 0) == 0
58
- )
59
- return not self.manager_dead
60
-
61
- else:
62
-
63
- class ManagerWatchdog: # type: ignore[no-redef]
64
- def __init__(self):
65
- self.manager_pid = os.getppid()
66
- self.manager_dead = False
67
-
68
- def is_alive(self):
69
- if not self.manager_dead:
70
- self.manager_dead = os.getppid() != self.manager_pid
71
- return not self.manager_dead
72
-
73
-
74
- _worker_info = None
75
-
76
-
77
- class WorkerInfo:
78
- id: int
79
- num_workers: int
80
- seed: int
81
- dataset: "Dataset"
82
- __initialized = False
83
-
84
- def __init__(self, **kwargs):
85
- for k, v in kwargs.items():
86
- setattr(self, k, v)
87
- self.__keys = tuple(kwargs.keys())
88
- self.__initialized = True
89
-
90
- def __setattr__(self, key, val):
91
- if self.__initialized:
92
- raise RuntimeError(
93
- "Cannot assign attributes to {} objects".format(self.__class__.__name__)
94
- )
95
- return super().__setattr__(key, val)
96
-
97
- def __repr__(self):
98
- items = []
99
- for k in self.__keys:
100
- items.append("{}={}".format(k, getattr(self, k)))
101
- return "{}({})".format(self.__class__.__name__, ", ".join(items))
102
-
103
-
104
- def get_worker_info() -> Optional[WorkerInfo]:
105
- r"""Returns the information about the current
106
- :class:`~torch.utils.data.DataLoader` iterator worker process.
107
-
108
- When called in a worker, this returns an object guaranteed to have the
109
- following attributes:
110
-
111
- * :attr:`id`: the current worker id.
112
- * :attr:`num_workers`: the total number of workers.
113
- * :attr:`seed`: the random seed set for the current worker. This value is
114
- determined by main process RNG and the worker id. See
115
- :class:`~torch.utils.data.DataLoader`'s documentation for more details.
116
- * :attr:`dataset`: the copy of the dataset object in **this** process. Note
117
- that this will be a different object in a different process than the one
118
- in the main process.
119
-
120
- When called in the main process, this returns ``None``.
121
-
122
- .. note::
123
- When used in a :attr:`worker_init_fn` passed over to
124
- :class:`~torch.utils.data.DataLoader`, this method can be useful to
125
- set up each worker process differently, for instance, using ``worker_id``
126
- to configure the ``dataset`` object to only read a specific fraction of a
127
- sharded dataset, or use ``seed`` to seed other libraries used in dataset
128
- code.
129
- """
130
- return _worker_info
131
-
132
-
133
- r"""Dummy class used to signal the end of an IterableDataset"""
134
-
135
-
136
- @dataclass(frozen=True)
137
- class _IterableDatasetStopIteration:
138
- worker_id: int
139
-
140
-
141
- r"""Dummy class used to resume the fetching when worker reuse is enabled"""
142
-
143
-
144
- @dataclass(frozen=True)
145
- class _ResumeIteration:
146
- seed: Optional[int] = None
147
-
148
-
149
- # The function `_generate_state` is adapted from `numpy.random.SeedSequence`
150
- # from https://github.com/numpy/numpy/blob/main/numpy/random/bit_generator.pyx
151
- # It's MIT licensed, here is the copyright:
152
-
153
- # Copyright (c) 2015 Melissa E. O'Neill
154
- # Copyright (c) 2019 NumPy Developers
155
- #
156
- # Permission is hereby granted, free of charge, to any person obtaining a copy
157
- # of this software and associated documentation files (the "Software"), to deal
158
- # in the Software without restriction, including without limitation the rights
159
- # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
160
- # copies of the Software, and to permit persons to whom the Software is
161
- # furnished to do so, subject to the following conditions:
162
- #
163
- # The above copyright notice and this permission notice shall be included in
164
- # all copies or substantial portions of the Software.
165
- #
166
- # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
167
- # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
168
- # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
169
- # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
170
- # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
171
- # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
172
- # SOFTWARE.
173
-
174
-
175
- # This function generates an array of int32 as the seed for
176
- # `numpy.random`, in order to prevent state collision due to same
177
- # seed and algorithm for `numpy.random` and `random` modules.
178
- # TODO: Implement `SeedSequence` like object for `torch.random`
179
- def _generate_state(base_seed, worker_id):
180
- INIT_A = 0x43B0D7E5
181
- MULT_A = 0x931E8875
182
- INIT_B = 0x8B51F9DD
183
- MULT_B = 0x58F38DED
184
- MIX_MULT_L = 0xCA01F9DD
185
- MIX_MULT_R = 0x4973F715
186
- XSHIFT = 4 * 8 // 2
187
- MASK32 = 0xFFFFFFFF
188
-
189
- entropy = [worker_id, base_seed & MASK32, base_seed >> 32, 0]
190
- pool = [0] * 4
191
-
192
- hash_const_A = INIT_A
193
-
194
- def hash(value):
195
- nonlocal hash_const_A
196
- value = (value ^ hash_const_A) & MASK32
197
- hash_const_A = (hash_const_A * MULT_A) & MASK32
198
- value = (value * hash_const_A) & MASK32
199
- value = (value ^ (value >> XSHIFT)) & MASK32
200
- return value
201
-
202
- def mix(x, y):
203
- result_x = (MIX_MULT_L * x) & MASK32
204
- result_y = (MIX_MULT_R * y) & MASK32
205
- result = (result_x - result_y) & MASK32
206
- result = (result ^ (result >> XSHIFT)) & MASK32
207
- return result
208
-
209
- # Add in the entropy to the pool.
210
- for i in range(len(pool)):
211
- pool[i] = hash(entropy[i])
212
-
213
- # Mix all bits together so late bits can affect earlier bits.
214
- for i_src in range(len(pool)):
215
- for i_dst in range(len(pool)):
216
- if i_src != i_dst:
217
- pool[i_dst] = mix(pool[i_dst], hash(pool[i_src]))
218
-
219
- hash_const_B = INIT_B
220
- state = []
221
- for i_dst in range(4):
222
- data_val = pool[i_dst]
223
- data_val = (data_val ^ hash_const_B) & MASK32
224
- hash_const_B = (hash_const_B * MULT_B) & MASK32
225
- data_val = (data_val * hash_const_B) & MASK32
226
- data_val = (data_val ^ (data_val >> XSHIFT)) & MASK32
227
- state.append(data_val)
228
- return state
229
-
230
-
231
- def _worker_loop(
232
- dataset_kind,
233
- dataset,
234
- index_queue,
235
- data_queue,
236
- done_event,
237
- auto_collation,
238
- collate_fn,
239
- drop_last,
240
- base_seed,
241
- init_fn,
242
- worker_id,
243
- num_workers,
244
- persistent_workers,
245
- shared_seed,
246
- ):
247
- # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the
248
- # logic of this function.
249
-
250
- try:
251
- # Initialize C side signal handlers for SIGBUS and SIGSEGV. Python signal
252
- # module's handlers are executed after Python returns from C low-level
253
- # handlers, likely when the same fatal signal had already happened
254
- # again.
255
- # https://docs.python.org/3/library/signal.html#execution-of-python-signal-handlers
256
- signal_handling._set_worker_signal_handlers()
257
-
258
- torch.set_num_threads(1)
259
- seed = base_seed + worker_id
260
- random.seed(seed)
261
- torch.manual_seed(seed)
262
- if HAS_NUMPY:
263
- np_seed = _generate_state(base_seed, worker_id)
264
- import numpy as np
265
-
266
- np.random.seed(np_seed)
267
-
268
- from torch.utils.data import IterDataPipe
269
- from torch.utils.data.graph_settings import apply_random_seed
270
-
271
- shared_rng = torch.Generator()
272
- if isinstance(dataset, IterDataPipe):
273
- assert shared_seed is not None
274
- shared_rng.manual_seed(shared_seed)
275
- dataset = apply_random_seed(dataset, shared_rng)
276
-
277
- global _worker_info
278
- _worker_info = WorkerInfo(
279
- id=worker_id, num_workers=num_workers, seed=seed, dataset=dataset
280
- )
281
-
282
- from torch.utils.data import _DatasetKind
283
-
284
- init_exception = None
285
-
286
- try:
287
- if init_fn is not None:
288
- init_fn(worker_id)
289
-
290
- fetcher = _DatasetKind.create_fetcher(
291
- dataset_kind, dataset, auto_collation, collate_fn, drop_last
292
- )
293
- except Exception:
294
- init_exception = ExceptionWrapper(
295
- where="in DataLoader worker process {}".format(worker_id)
296
- )
297
-
298
- # When using Iterable mode, some workers can exit earlier than others due
299
- # to the IterableDataset behaving differently for different workers.
300
- # When such things happen, an `_IterableDatasetStopIteration` object is
301
- # sent over to the main process with the ID of this worker, so that the
302
- # main process won't send more tasks to this worker, and will send
303
- # `None` to this worker to properly exit it.
304
- #
305
- # Note that we cannot set `done_event` from a worker as it is shared
306
- # among all processes. Instead, we set the `iteration_end` flag to
307
- # signify that the iterator is exhausted. When either `done_event` or
308
- # `iteration_end` is set, we skip all processing steps and just wait for
309
- # `None`.
310
- iteration_end = False
311
-
312
- watchdog = ManagerWatchdog()
313
-
314
- while watchdog.is_alive():
315
- try:
316
- r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
317
- except queue.Empty:
318
- continue
319
- if isinstance(r, _ResumeIteration):
320
- # Acknowledge the main process
321
- data_queue.put((r, None))
322
- iteration_end = False
323
-
324
- if isinstance(dataset, IterDataPipe):
325
- assert r.seed is not None
326
- shared_rng.manual_seed(r.seed)
327
- dataset = apply_random_seed(dataset, shared_rng)
328
-
329
- # Recreate the fetcher for worker-reuse policy
330
- fetcher = _DatasetKind.create_fetcher(
331
- dataset_kind, dataset, auto_collation, collate_fn, drop_last
332
- )
333
- continue
334
- elif r is None:
335
- # Received the final signal
336
- assert done_event.is_set() or iteration_end
337
- break
338
- elif done_event.is_set() or iteration_end:
339
- # `done_event` is set. But I haven't received the final signal
340
- # (None) yet. I will keep continuing until get it, and skip the
341
- # processing steps.
342
- continue
343
- idx, index = r
344
- """ Added """
345
- RRSController.sample_resolution(batch_id=idx)
346
- """ Added """
347
- data: Union[_IterableDatasetStopIteration, ExceptionWrapper]
348
- if init_exception is not None:
349
- data = init_exception
350
- init_exception = None
351
- else:
352
- try:
353
- data = fetcher.fetch(index)
354
- except Exception as e:
355
- if (
356
- isinstance(e, StopIteration)
357
- and dataset_kind == _DatasetKind.Iterable
358
- ):
359
- data = _IterableDatasetStopIteration(worker_id)
360
- # Set `iteration_end`
361
- # (1) to save future `next(...)` calls, and
362
- # (2) to avoid sending multiple `_IterableDatasetStopIteration`s.
363
- iteration_end = True
364
- else:
365
- # It is important that we don't store exc_info in a variable.
366
- # `ExceptionWrapper` does the correct thing.
367
- # See NOTE [ Python Traceback Reference Cycle Problem ]
368
- data = ExceptionWrapper(
369
- where="in DataLoader worker process {}".format(worker_id)
370
- )
371
- data_queue.put((idx, data))
372
- del data, idx, index, r # save memory
373
- except KeyboardInterrupt:
374
- # Main process will raise KeyboardInterrupt anyways.
375
- pass
376
- if done_event.is_set():
377
- data_queue.cancel_join_thread()
378
- data_queue.close()
 
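The `get_worker_info()` docstring in the deleted worker module above mentions using `worker_id` to make each worker read only its own fraction of a sharded dataset. A minimal sketch of that pattern follows, using the public `torch.utils.data` helpers; `RangeStream` and the concrete numbers are made up for illustration.

import math

import torch
from torch.utils.data import DataLoader, IterableDataset, get_worker_info


class RangeStream(IterableDataset):
    """Toy iterable dataset yielding start..end-1; the bounds get narrowed per worker."""

    def __init__(self, n):
        self.start, self.end = 0, n

    def __iter__(self):
        return iter(range(self.start, self.end))


def worker_init_fn(worker_id):
    info = get_worker_info()          # WorkerInfo(id, num_workers, seed, dataset)
    ds = info.dataset                 # this worker's own copy of the dataset
    per_worker = math.ceil((ds.end - ds.start) / info.num_workers)
    ds.start = ds.start + worker_id * per_worker
    ds.end = min(ds.start + per_worker, ds.end)


if __name__ == "__main__":
    loader = DataLoader(RangeStream(10), num_workers=2, worker_init_fn=worker_init_fn)
    # Workers cover disjoint shards, so together they yield every element once.
    print(sorted(int(x) for x in torch.cat(list(loader))))  # [0, 1, ..., 9]
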
efficientvit/apps/data_provider/random_resolution/controller.py DELETED
@@ -1,94 +0,0 @@
1
- # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
2
- # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
3
- # International Conference on Computer Vision (ICCV), 2023
4
-
5
- import copy
6
-
7
- import torch
8
- import torchvision.transforms as transforms
9
- import torchvision.transforms.functional as F
10
-
11
- from efficientvit.models.utils import torch_random_choices
12
-
13
- __all__ = [
14
- "RRSController",
15
- "get_interpolate",
16
- "MyRandomResizedCrop",
17
- ]
18
-
19
-
20
- class RRSController:
21
- ACTIVE_SIZE = (224, 224)
22
- IMAGE_SIZE_LIST = [(224, 224)]
23
-
24
- CHOICE_LIST = None
25
-
26
- @staticmethod
27
- def get_candidates() -> list[tuple[int, int]]:
28
- return copy.deepcopy(RRSController.IMAGE_SIZE_LIST)
29
-
30
- @staticmethod
31
- def sample_resolution(batch_id: int) -> None:
32
- RRSController.ACTIVE_SIZE = RRSController.CHOICE_LIST[batch_id]
33
-
34
- @staticmethod
35
- def set_epoch(epoch: int, batch_per_epoch: int) -> None:
36
- g = torch.Generator()
37
- g.manual_seed(epoch)
38
- RRSController.CHOICE_LIST = torch_random_choices(
39
- RRSController.get_candidates(),
40
- g,
41
- batch_per_epoch,
42
- )
43
-
44
-
45
- def get_interpolate(name: str) -> F.InterpolationMode:
46
- mapping = {
47
- "nearest": F.InterpolationMode.NEAREST,
48
- "bilinear": F.InterpolationMode.BILINEAR,
49
- "bicubic": F.InterpolationMode.BICUBIC,
50
- "box": F.InterpolationMode.BOX,
51
- "hamming": F.InterpolationMode.HAMMING,
52
- "lanczos": F.InterpolationMode.LANCZOS,
53
- }
54
- if name in mapping:
55
- return mapping[name]
56
- elif name == "random":
57
- return torch_random_choices(
58
- [
59
- F.InterpolationMode.NEAREST,
60
- F.InterpolationMode.BILINEAR,
61
- F.InterpolationMode.BICUBIC,
62
- F.InterpolationMode.BOX,
63
- F.InterpolationMode.HAMMING,
64
- F.InterpolationMode.LANCZOS,
65
- ],
66
- )
67
- else:
68
- raise NotImplementedError
69
-
70
-
71
- class MyRandomResizedCrop(transforms.RandomResizedCrop):
72
- def __init__(
73
- self,
74
- scale=(0.08, 1.0),
75
- ratio=(3.0 / 4.0, 4.0 / 3.0),
76
- interpolation: str = "random",
77
- ):
78
- super(MyRandomResizedCrop, self).__init__(224, scale, ratio)
79
- self.interpolation = interpolation
80
-
81
- def forward(self, img: torch.Tensor) -> torch.Tensor:
82
- i, j, h, w = self.get_params(img, list(self.scale), list(self.ratio))
83
- target_size = RRSController.ACTIVE_SIZE
84
- return F.resized_crop(
85
- img, i, j, h, w, list(target_size), get_interpolate(self.interpolation)
86
- )
87
-
88
- def __repr__(self) -> str:
89
- format_string = self.__class__.__name__
90
- format_string += f"(\n\tsize={RRSController.get_candidates()},\n"
91
- format_string += f"\tscale={tuple(round(s, 4) for s in self.scale)},\n"
92
- format_string += f"\tratio={tuple(round(r, 4) for r in self.ratio)},\n"
93
- format_string += f"\tinterpolation={self.interpolation})"
94
- return format_string
 
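A brief usage sketch of how the pieces above were meant to interact, assuming the deleted module were still importable: the trainer seeds one resolution per batch with `set_epoch` (the same epoch seed yields the same schedule in every process), the data worker calls `sample_resolution(batch_id)` (the call added in `_worker_loop` above), and `MyRandomResizedCrop` then crops to `ACTIVE_SIZE`. The candidate sizes and loop bounds below are illustrative only.

import torch

from efficientvit.apps.data_provider.random_resolution.controller import (
    MyRandomResizedCrop, RRSController)

RRSController.IMAGE_SIZE_LIST = [(160, 160), (192, 192), (224, 224)]
crop = MyRandomResizedCrop(interpolation="bicubic")

RRSController.set_epoch(epoch=0, batch_per_epoch=100)   # one target size per batch
for batch_id in range(3):
    RRSController.sample_resolution(batch_id)           # done inside the data worker in practice
    img = torch.rand(3, 256, 256)
    print(batch_id, RRSController.ACTIVE_SIZE, crop(img).shape)
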
efficientvit/apps/setup.py DELETED
@@ -1,141 +0,0 @@
1
- # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
2
- # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
3
- # International Conference on Computer Vision (ICCV), 2023
4
-
5
- import os
6
- import time
7
- from copy import deepcopy
8
-
9
- import torch.backends.cudnn
10
- import torch.distributed
11
- import torch.nn as nn
12
-
13
- from efficientvit.apps.data_provider import DataProvider
14
- from efficientvit.apps.trainer.run_config import RunConfig
15
- from efficientvit.apps.utils import (dist_init, dump_config,
16
- get_dist_local_rank, get_dist_rank,
17
- get_dist_size, init_modules, is_master,
18
- load_config, partial_update_config,
19
- zero_last_gamma)
20
- from efficientvit.models.utils import (build_kwargs_from_config,
21
- load_state_dict_from_file)
22
-
23
- __all__ = [
24
- "save_exp_config",
25
- "setup_dist_env",
26
- "setup_seed",
27
- "setup_exp_config",
28
- "setup_data_provider",
29
- "setup_run_config",
30
- "init_model",
31
- ]
32
-
33
-
34
- def save_exp_config(exp_config: dict, path: str, name="config.yaml") -> None:
35
- if not is_master():
36
- return
37
- dump_config(exp_config, os.path.join(path, name))
38
-
39
-
40
- def setup_dist_env(gpu: str or None = None) -> None:
41
- if gpu is not None:
42
- os.environ["CUDA_VISIBLE_DEVICES"] = gpu
43
- if not torch.distributed.is_initialized():
44
- dist_init()
45
- torch.backends.cudnn.benchmark = True
46
- torch.cuda.set_device(get_dist_local_rank())
47
-
48
-
49
- def setup_seed(manual_seed: int, resume: bool) -> None:
50
- if resume:
51
- manual_seed = int(time.time())
52
- manual_seed = get_dist_rank() + manual_seed
53
- torch.manual_seed(manual_seed)
54
- torch.cuda.manual_seed_all(manual_seed)
55
-
56
-
57
- def setup_exp_config(
58
- config_path: str, recursive=True, opt_args: dict or None = None
59
- ) -> dict:
60
- # load config
61
- if not os.path.isfile(config_path):
62
- raise ValueError(config_path)
63
-
64
- fpaths = [config_path]
65
- if recursive:
66
- extension = os.path.splitext(config_path)[1]
67
- while os.path.dirname(config_path) != config_path:
68
- config_path = os.path.dirname(config_path)
69
- fpath = os.path.join(config_path, "default" + extension)
70
- if os.path.isfile(fpath):
71
- fpaths.append(fpath)
72
- fpaths = fpaths[::-1]
73
-
74
- default_config = load_config(fpaths[0])
75
- exp_config = deepcopy(default_config)
76
- for fpath in fpaths[1:]:
77
- partial_update_config(exp_config, load_config(fpath))
78
- # update config via args
79
- if opt_args is not None:
80
- partial_update_config(exp_config, opt_args)
81
-
82
- return exp_config
83
-
84
-
85
- def setup_data_provider(
86
- exp_config: dict,
87
- data_provider_classes: list[type[DataProvider]],
88
- is_distributed: bool = True,
89
- ) -> DataProvider:
90
- dp_config = exp_config["data_provider"]
91
- dp_config["num_replicas"] = get_dist_size() if is_distributed else None
92
- dp_config["rank"] = get_dist_rank() if is_distributed else None
93
- dp_config["test_batch_size"] = (
94
- dp_config.get("test_batch_size", None) or dp_config["base_batch_size"] * 2
95
- )
96
- dp_config["batch_size"] = dp_config["train_batch_size"] = dp_config[
97
- "base_batch_size"
98
- ]
99
-
100
- data_provider_lookup = {
101
- provider.name: provider for provider in data_provider_classes
102
- }
103
- data_provider_class = data_provider_lookup[dp_config["dataset"]]
104
-
105
- data_provider_kwargs = build_kwargs_from_config(dp_config, data_provider_class)
106
- data_provider = data_provider_class(**data_provider_kwargs)
107
- return data_provider
108
-
109
-
110
- def setup_run_config(exp_config: dict, run_config_cls: type[RunConfig]) -> RunConfig:
111
- exp_config["run_config"]["init_lr"] = (
112
- exp_config["run_config"]["base_lr"] * get_dist_size()
113
- )
114
-
115
- run_config = run_config_cls(**exp_config["run_config"])
116
-
117
- return run_config
118
-
119
-
120
- def init_model(
121
- network: nn.Module,
122
- init_from: str or None = None,
123
- backbone_init_from: str or None = None,
124
- rand_init="trunc_normal",
125
- last_gamma=None,
126
- ) -> None:
127
- # initialization
128
- init_modules(network, init_type=rand_init)
129
- # zero gamma of last bn in each block
130
- if last_gamma is not None:
131
- zero_last_gamma(network, last_gamma)
132
-
133
- # load weight
134
- if init_from is not None and os.path.isfile(init_from):
135
- network.load_state_dict(load_state_dict_from_file(init_from))
136
- print(f"Loaded init from {init_from}")
137
- elif backbone_init_from is not None and os.path.isfile(backbone_init_from):
138
- network.backbone.load_state_dict(load_state_dict_from_file(backbone_init_from))
139
- print(f"Loaded backbone init from {backbone_init_from}")
140
- else:
141
- print(f"Random init ({rand_init}) with last gamma {last_gamma}")
 
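For context on the deleted `setup_exp_config`, here is a hypothetical usage sketch. The directory layout, file names, and override value are made up, and it assumes `partial_update_config` performs the recursive merge its use above suggests: with `recursive=True`, every `default.yaml` found in the ancestor directories of the experiment config is applied first (outermost directory first), then the experiment file, then `opt_args`.

#   configs/default.yaml        # global defaults
#   configs/cls/default.yaml    # task-level defaults
#   configs/cls/b1_r224.yaml    # experiment config
from efficientvit.apps.setup import setup_exp_config

exp_config = setup_exp_config(
    "configs/cls/b1_r224.yaml",
    recursive=True,
    opt_args={"run_config": {"base_lr": 1e-3}},  # hypothetical CLI override
)
print(exp_config["run_config"]["base_lr"])       # 1e-3, since opt_args is applied last
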
efficientvit/apps/trainer/__init__.py DELETED
@@ -1,6 +0,0 @@
1
- # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
2
- # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
3
- # International Conference on Computer Vision (ICCV), 2023
4
-
5
- from .base import *
6
- from .run_config import *
 
efficientvit/apps/trainer/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (224 Bytes)
 
efficientvit/apps/trainer/__pycache__/base.cpython-310.pyc DELETED
Binary file (8.11 kB)
 
efficientvit/apps/trainer/__pycache__/run_config.cpython-310.pyc DELETED
Binary file (4.06 kB)
 
efficientvit/apps/trainer/base.py DELETED
@@ -1,297 +0,0 @@
- # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
- # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
- # International Conference on Computer Vision (ICCV), 2023
-
- import os
-
- import torch
- import torch.nn as nn
-
- from efficientvit.apps.data_provider import DataProvider, parse_image_size
- from efficientvit.apps.trainer.run_config import RunConfig
- from efficientvit.apps.utils import (EMA, dist_barrier, get_dist_local_rank,
-                                      is_master)
- from efficientvit.models.nn.norm import reset_bn
- from efficientvit.models.utils import is_parallel, load_state_dict_from_file
-
- __all__ = ["Trainer"]
-
-
- class Trainer:
-     def __init__(self, path: str, model: nn.Module, data_provider: DataProvider):
-         self.path = os.path.realpath(os.path.expanduser(path))
-         self.model = model.cuda()
-         self.data_provider = data_provider
-
-         self.ema = None
-
-         self.checkpoint_path = os.path.join(self.path, "checkpoint")
-         self.logs_path = os.path.join(self.path, "logs")
-         for path in [self.path, self.checkpoint_path, self.logs_path]:
-             os.makedirs(path, exist_ok=True)
-
-         self.best_val = 0.0
-         self.start_epoch = 0
-
-     @property
-     def network(self) -> nn.Module:
-         return self.model.module if is_parallel(self.model) else self.model
-
-     @property
-     def eval_network(self) -> nn.Module:
-         if self.ema is None:
-             model = self.model
-         else:
-             model = self.ema.shadows
-         model = model.module if is_parallel(model) else model
-         return model
-
-     def write_log(self, log_str, prefix="valid", print_log=True, mode="a") -> None:
-         if is_master():
-             fout = open(os.path.join(self.logs_path, f"{prefix}.log"), mode)
-             fout.write(log_str + "\n")
-             fout.flush()
-             fout.close()
-             if print_log:
-                 print(log_str)
-
-     def save_model(
-         self,
-         checkpoint=None,
-         only_state_dict=True,
-         epoch=0,
-         model_name=None,
-     ) -> None:
-         if is_master():
-             if checkpoint is None:
-                 if only_state_dict:
-                     checkpoint = {"state_dict": self.network.state_dict()}
-                 else:
-                     checkpoint = {
-                         "state_dict": self.network.state_dict(),
-                         "epoch": epoch,
-                         "best_val": self.best_val,
-                         "optimizer": self.optimizer.state_dict(),
-                         "lr_scheduler": self.lr_scheduler.state_dict(),
-                         "ema": self.ema.state_dict() if self.ema is not None else None,
-                         "scaler": self.scaler.state_dict() if self.fp16 else None,
-                     }
-
-             model_name = model_name or "checkpoint.pt"
-
-             latest_fname = os.path.join(self.checkpoint_path, "latest.txt")
-             model_path = os.path.join(self.checkpoint_path, model_name)
-             with open(latest_fname, "w") as _fout:
-                 _fout.write(model_path + "\n")
-             torch.save(checkpoint, model_path)
-
-     def load_model(self, model_fname=None) -> None:
-         latest_fname = os.path.join(self.checkpoint_path, "latest.txt")
-         if model_fname is None and os.path.exists(latest_fname):
-             with open(latest_fname, "r") as fin:
-                 model_fname = fin.readline()
-                 if len(model_fname) > 0 and model_fname[-1] == "\n":
-                     model_fname = model_fname[:-1]
-         try:
-             if model_fname is None:
-                 model_fname = f"{self.checkpoint_path}/checkpoint.pt"
-             elif not os.path.exists(model_fname):
-                 model_fname = f"{self.checkpoint_path}/{os.path.basename(model_fname)}"
-                 if not os.path.exists(model_fname):
-                     model_fname = f"{self.checkpoint_path}/checkpoint.pt"
-             print(f"=> loading checkpoint {model_fname}")
-             checkpoint = load_state_dict_from_file(model_fname, False)
-         except Exception:
-             self.write_log(f"fail to load checkpoint from {self.checkpoint_path}")
-             return
-
-         # load checkpoint
-         self.network.load_state_dict(checkpoint["state_dict"], strict=False)
-         log = []
-         if "epoch" in checkpoint:
-             self.start_epoch = checkpoint["epoch"] + 1
-             self.run_config.update_global_step(self.start_epoch)
-             log.append(f"epoch={self.start_epoch - 1}")
-         if "best_val" in checkpoint:
-             self.best_val = checkpoint["best_val"]
-             log.append(f"best_val={self.best_val:.2f}")
-         if "optimizer" in checkpoint:
-             self.optimizer.load_state_dict(checkpoint["optimizer"])
-             log.append("optimizer")
-         if "lr_scheduler" in checkpoint:
-             self.lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
-             log.append("lr_scheduler")
-         if "ema" in checkpoint and self.ema is not None:
-             self.ema.load_state_dict(checkpoint["ema"])
-             log.append("ema")
-         if "scaler" in checkpoint and self.fp16:
-             self.scaler.load_state_dict(checkpoint["scaler"])
-             log.append("scaler")
-         self.write_log("Loaded: " + ", ".join(log))
-
-     """ validate """
-
-     def reset_bn(
-         self,
-         network: nn.Module or None = None,
-         subset_size: int = 16000,
-         subset_batch_size: int = 100,
-         data_loader=None,
-         progress_bar=False,
-     ) -> None:
-         network = network or self.network
-         if data_loader is None:
-             data_loader = []
-             for data in self.data_provider.build_sub_train_loader(
-                 subset_size, subset_batch_size
-             ):
-                 if isinstance(data, list):
-                     data_loader.append(data[0])
-                 elif isinstance(data, dict):
-                     data_loader.append(data["data"])
-                 elif isinstance(data, torch.Tensor):
-                     data_loader.append(data)
-                 else:
-                     raise NotImplementedError
-
-         network.eval()
-         reset_bn(
-             network,
-             data_loader,
-             sync=True,
-             progress_bar=progress_bar,
-         )
-
-     def _validate(self, model, data_loader, epoch) -> dict[str, any]:
-         raise NotImplementedError
-
-     def validate(
-         self, model=None, data_loader=None, is_test=True, epoch=0
-     ) -> dict[str, any]:
-         model = model or self.eval_network
-         if data_loader is None:
-             if is_test:
-                 data_loader = self.data_provider.test
-             else:
-                 data_loader = self.data_provider.valid
-
-         model.eval()
-         return self._validate(model, data_loader, epoch)
-
-     def multires_validate(
-         self,
-         model=None,
-         data_loader=None,
-         is_test=True,
-         epoch=0,
-         eval_image_size=None,
-     ) -> dict[str, dict[str, any]]:
-         eval_image_size = eval_image_size or self.run_config.eval_image_size
-         eval_image_size = eval_image_size or self.data_provider.image_size
-         model = model or self.eval_network
-
-         if not isinstance(eval_image_size, list):
-             eval_image_size = [eval_image_size]
-
-         output_dict = {}
-         for r in eval_image_size:
-             self.data_provider.assign_active_image_size(parse_image_size(r))
-             if self.run_config.reset_bn:
-                 self.reset_bn(
-                     network=model,
-                     subset_size=self.run_config.reset_bn_size,
-                     subset_batch_size=self.run_config.reset_bn_batch_size,
-                     progress_bar=True,
-                 )
-             output_dict[f"r{r}"] = self.validate(model, data_loader, is_test, epoch)
-         return output_dict
-
-     """ training """
-
-     def prep_for_training(
-         self, run_config: RunConfig, ema_decay: float or None = None, fp16=False
-     ) -> None:
-         self.run_config = run_config
-         self.model = nn.parallel.DistributedDataParallel(
-             self.model.cuda(),
-             device_ids=[get_dist_local_rank()],
-             static_graph=True,
-         )
-
-         self.run_config.global_step = 0
-         self.run_config.batch_per_epoch = len(self.data_provider.train)
-         assert self.run_config.batch_per_epoch > 0, "Training set is empty"
-
-         # build optimizer
-         self.optimizer, self.lr_scheduler = self.run_config.build_optimizer(self.model)
-
-         if ema_decay is not None:
-             self.ema = EMA(self.network, ema_decay)
-
-         # fp16
-         self.fp16 = fp16
-         self.scaler = torch.cuda.amp.GradScaler(enabled=self.fp16)
-
-     def sync_model(self):
-         print("Sync model")
-         self.save_model(model_name="sync.pt")
-         dist_barrier()
-         checkpoint = torch.load(
-             os.path.join(self.checkpoint_path, "sync.pt"), map_location="cpu"
-         )
-         dist_barrier()
-         if is_master():
-             os.remove(os.path.join(self.checkpoint_path, "sync.pt"))
-         dist_barrier()
-
-         # load checkpoint
-         self.network.load_state_dict(checkpoint["state_dict"], strict=False)
-         if "optimizer" in checkpoint:
-             self.optimizer.load_state_dict(checkpoint["optimizer"])
-         if "lr_scheduler" in checkpoint:
-             self.lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
-         if "ema" in checkpoint and self.ema is not None:
-             self.ema.load_state_dict(checkpoint["ema"])
-         if "scaler" in checkpoint and self.fp16:
-             self.scaler.load_state_dict(checkpoint["scaler"])
-
-     def before_step(self, feed_dict: dict[str, any]) -> dict[str, any]:
-         for key in feed_dict:
-             if isinstance(feed_dict[key], torch.Tensor):
-                 feed_dict[key] = feed_dict[key].cuda()
-         return feed_dict
-
-     def run_step(self, feed_dict: dict[str, any]) -> dict[str, any]:
-         raise NotImplementedError
-
-     def after_step(self) -> None:
-         self.scaler.unscale_(self.optimizer)
-         # gradient clip
-         if self.run_config.grad_clip is not None:
-             torch.nn.utils.clip_grad_value_(
-                 self.model.parameters(), self.run_config.grad_clip
-             )
-         # update
-         self.scaler.step(self.optimizer)
-         self.scaler.update()
-
-         self.lr_scheduler.step()
-         self.run_config.step()
-         # update ema
-         if self.ema is not None:
-             self.ema.step(self.network, self.run_config.global_step)
-
-     def _train_one_epoch(self, epoch: int) -> dict[str, any]:
-         raise NotImplementedError
-
-     def train_one_epoch(self, epoch: int) -> dict[str, any]:
-         self.model.train()
-
-         self.data_provider.set_epoch(epoch)
-
-         train_info_dict = self._train_one_epoch(epoch)
-
-         return train_info_dict
-
-     def train(self) -> None:
-         raise NotImplementedError
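For reference, the order of operations in Trainer.after_step follows the standard GradScaler recipe: scaled backward, unscale, value-clip, optimizer step, scaler update. A minimal standalone sketch of that sequence on a toy model, with the scaler disabled so it also runs without a GPU (this is not the Trainer itself):

import torch
import torch.nn as nn

# Toy setup; scaler disabled so the sketch also runs on CPU.
model = nn.Linear(8, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler(enabled=False)
grad_clip = 1.0

x, y = torch.randn(4, 8), torch.randn(4, 1)
loss = nn.functional.mse_loss(model(x), y)

scaler.scale(loss).backward()      # the scaled backward a run_step would do
scaler.unscale_(optimizer)         # after_step: unscale before clipping
torch.nn.utils.clip_grad_value_(model.parameters(), grad_clip)
scaler.step(optimizer)             # skips the step if gradients overflowed
scaler.update()
optimizer.zero_grad()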
 
efficientvit/apps/trainer/run_config.py DELETED
@@ -1,121 +0,0 @@
- # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
- # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
- # International Conference on Computer Vision (ICCV), 2023
-
- import json
-
- import numpy as np
- import torch.nn as nn
-
- from efficientvit.apps.utils import CosineLRwithWarmup, build_optimizer
-
- __all__ = ["Scheduler", "RunConfig"]
-
-
- class Scheduler:
-     PROGRESS = 0
-
-
- class RunConfig:
-     n_epochs: int
-     init_lr: float
-     warmup_epochs: int
-     warmup_lr: float
-     lr_schedule_name: str
-     lr_schedule_param: dict
-     optimizer_name: str
-     optimizer_params: dict
-     weight_decay: float
-     no_wd_keys: list
-     grad_clip: float  # allow none to turn off grad clipping
-     reset_bn: bool
-     reset_bn_size: int
-     reset_bn_batch_size: int
-     eval_image_size: list  # allow none to use image_size in data_provider
-
-     @property
-     def none_allowed(self):
-         return ["grad_clip", "eval_image_size"]
-
-     def __init__(self, **kwargs):  # arguments must be passed as kwargs
-         for k, val in kwargs.items():
-             setattr(self, k, val)
-
-         # check that all relevant configs are there
-         annotations = {}
-         for clas in type(self).mro():
-             if hasattr(clas, "__annotations__"):
-                 annotations.update(clas.__annotations__)
-         for k, k_type in annotations.items():
-             assert hasattr(
-                 self, k
-             ), f"Key {k} with type {k_type} required for initialization."
-             attr = getattr(self, k)
-             if k in self.none_allowed:
-                 k_type = (k_type, type(None))
-             assert isinstance(
-                 attr, k_type
-             ), f"Key {k} must be type {k_type}, provided={attr}."
-
-         self.global_step = 0
-         self.batch_per_epoch = 1
-
-     def build_optimizer(self, network: nn.Module) -> tuple[any, any]:
-         r"""require setting 'batch_per_epoch' before building optimizer & lr_scheduler"""
-         param_dict = {}
-         for name, param in network.named_parameters():
-             if param.requires_grad:
-                 opt_config = [self.weight_decay, self.init_lr]
-                 if self.no_wd_keys is not None and len(self.no_wd_keys) > 0:
-                     if np.any([key in name for key in self.no_wd_keys]):
-                         opt_config[0] = 0
-                 opt_key = json.dumps(opt_config)
-                 param_dict[opt_key] = param_dict.get(opt_key, []) + [param]
-
-         net_params = []
-         for opt_key, param_list in param_dict.items():
-             wd, lr = json.loads(opt_key)
-             net_params.append({"params": param_list, "weight_decay": wd, "lr": lr})
-
-         optimizer = build_optimizer(
-             net_params, self.optimizer_name, self.optimizer_params, self.init_lr
-         )
-         # build lr scheduler
-         if self.lr_schedule_name == "cosine":
-             decay_steps = []
-             for epoch in self.lr_schedule_param.get("step", []):
-                 decay_steps.append(epoch * self.batch_per_epoch)
-             decay_steps.append(self.n_epochs * self.batch_per_epoch)
-             decay_steps.sort()
-             lr_scheduler = CosineLRwithWarmup(
-                 optimizer,
-                 self.warmup_epochs * self.batch_per_epoch,
-                 self.warmup_lr,
-                 decay_steps,
-             )
-         else:
-             raise NotImplementedError
-         return optimizer, lr_scheduler
-
-     def update_global_step(self, epoch, batch_id=0) -> None:
-         self.global_step = epoch * self.batch_per_epoch + batch_id
-         Scheduler.PROGRESS = self.progress
-
-     @property
-     def progress(self) -> float:
-         warmup_steps = self.warmup_epochs * self.batch_per_epoch
-         steps = max(0, self.global_step - warmup_steps)
-         return steps / (self.n_epochs * self.batch_per_epoch)
-
-     def step(self) -> None:
-         self.global_step += 1
-         Scheduler.PROGRESS = self.progress
-
-     def get_remaining_epoch(self, epoch, post=True) -> int:
-         return self.n_epochs + self.warmup_epochs - epoch - int(post)
-
-     def epoch_format(self, epoch: int) -> str:
-         epoch_format = f"%.{len(str(self.n_epochs))}d"
-         epoch_format = f"[{epoch_format}/{epoch_format}]"
-         epoch_format = epoch_format % (epoch + 1 - self.warmup_epochs, self.n_epochs)
-         return epoch_format
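The notable detail in RunConfig.build_optimizer is that parameters are grouped by their (weight_decay, lr) pair, using the JSON-encoded pair as a dictionary key so matching groups collapse together. A minimal standalone sketch of that grouping on a toy module (the module and the no_wd_keys values are illustrative, not taken from the repo's configs):

import json

import torch.nn as nn


class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 8)
        self.norm = nn.LayerNorm(8)


model = Toy()
no_wd_keys = ["norm", "bias"]          # illustrative, mirrors the config's no_wd_keys
weight_decay, init_lr = 0.05, 1e-3

param_dict = {}
for name, param in model.named_parameters():
    if not param.requires_grad:
        continue
    wd = 0.0 if any(key in name for key in no_wd_keys) else weight_decay
    opt_key = json.dumps([wd, init_lr])  # JSON pair as a hashable group key
    param_dict.setdefault(opt_key, []).append(param)

net_params = []
for opt_key, params in param_dict.items():
    wd, lr = json.loads(opt_key)
    net_params.append({"params": params, "weight_decay": wd, "lr": lr})

print([(g["weight_decay"], len(g["params"])) for g in net_params])  # e.g. [(0.05, 1), (0.0, 3)]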
 
efficientvit/apps/utils/__init__.py DELETED
@@ -1,12 +0,0 @@
- # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
- # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
- # International Conference on Computer Vision (ICCV), 2023
-
- from .dist import *
- from .ema import *
- from .export import *
- from .init import *
- from .lr import *
- from .metric import *
- from .misc import *
- from .opt import *
 
efficientvit/apps/utils/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (312 Bytes)
 
efficientvit/apps/utils/__pycache__/dist.cpython-310.pyc DELETED
Binary file (2.14 kB)
 
efficientvit/apps/utils/__pycache__/ema.cpython-310.pyc DELETED
Binary file (1.94 kB)
 
efficientvit/apps/utils/__pycache__/export.cpython-310.pyc DELETED
Binary file (1.37 kB)
 
efficientvit/apps/utils/__pycache__/init.cpython-310.pyc DELETED
Binary file (2.01 kB)
 
efficientvit/apps/utils/__pycache__/lr.cpython-310.pyc DELETED
Binary file (1.76 kB)
 
efficientvit/apps/utils/__pycache__/metric.cpython-310.pyc DELETED
Binary file (1.62 kB)
 
efficientvit/apps/utils/__pycache__/misc.cpython-310.pyc DELETED
Binary file (2.74 kB)
 
efficientvit/apps/utils/__pycache__/opt.cpython-310.pyc DELETED
Binary file (896 Bytes)
 
efficientvit/apps/utils/dist.py DELETED
@@ -1,73 +0,0 @@
- # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
- # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
- # International Conference on Computer Vision (ICCV), 2023
-
- import os
-
- import torch
- import torch.distributed
-
- from efficientvit.models.utils.list import list_mean, list_sum
-
- __all__ = [
-     "dist_init",
-     "get_dist_rank",
-     "get_dist_size",
-     "is_master",
-     "dist_barrier",
-     "get_dist_local_rank",
-     "sync_tensor",
- ]
-
-
- def dist_init() -> None:
-     try:
-         torch.distributed.init_process_group(backend="nccl")
-         assert torch.distributed.is_initialized()
-     except Exception:
-         # use torchpack
-         from torchpack import distributed as dist
-
-         dist.init()
-         os.environ["RANK"] = f"{dist.rank()}"
-         os.environ["WORLD_SIZE"] = f"{dist.size()}"
-         os.environ["LOCAL_RANK"] = f"{dist.local_rank()}"
-
-
- def get_dist_rank() -> int:
-     return int(os.environ["RANK"])
-
-
- def get_dist_size() -> int:
-     return int(os.environ["WORLD_SIZE"])
-
-
- def is_master() -> bool:
-     return get_dist_rank() == 0
-
-
- def dist_barrier() -> None:
-     torch.distributed.barrier()
-
-
- def get_dist_local_rank() -> int:
-     return int(os.environ["LOCAL_RANK"])
-
-
- def sync_tensor(
-     tensor: torch.Tensor or float, reduce="mean"
- ) -> torch.Tensor or list[torch.Tensor]:
-     if not isinstance(tensor, torch.Tensor):
-         tensor = torch.Tensor(1).fill_(tensor).cuda()
-     tensor_list = [torch.empty_like(tensor) for _ in range(get_dist_size())]
-     torch.distributed.all_gather(tensor_list, tensor.contiguous(), async_op=False)
-     if reduce == "mean":
-         return list_mean(tensor_list)
-     elif reduce == "sum":
-         return list_sum(tensor_list)
-     elif reduce == "cat":
-         return torch.cat(tensor_list, dim=0)
-     elif reduce == "root":
-         return tensor_list[0]
-     else:
-         return tensor_list
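sync_tensor is essentially all_gather followed by a list reduction. Below is a minimal sketch under the assumption that a single-process gloo group is enough to show the call pattern; the repo itself uses the NCCL backend with CUDA tensors, and its list_mean reduction is replaced here by an equivalent torch.stack(...).mean():

import os

import torch
import torch.distributed as dist

# Single-process process group so the sketch runs on one CPU machine.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group(backend="gloo", rank=0, world_size=1)

value = torch.tensor([3.0])
gathered = [torch.empty_like(value) for _ in range(dist.get_world_size())]
dist.all_gather(gathered, value.contiguous())
mean = torch.stack(gathered).mean()   # the reduce="mean" branch of sync_tensor
print(mean.item())                    # 3.0

dist.destroy_process_group()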
 
efficientvit/apps/utils/ema.py DELETED
@@ -1,50 +0,0 @@
- # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
- # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
- # International Conference on Computer Vision (ICCV), 2023
-
- import copy
- import math
-
- import torch
- import torch.nn as nn
-
- from efficientvit.models.utils import is_parallel
-
- __all__ = ["EMA"]
-
-
- def update_ema(
-     ema: nn.Module, new_state_dict: dict[str, torch.Tensor], decay: float
- ) -> None:
-     for k, v in ema.state_dict().items():
-         if v.dtype.is_floating_point:
-             v -= (1.0 - decay) * (v - new_state_dict[k].detach())
-
-
- class EMA:
-     def __init__(self, model: nn.Module, decay: float, warmup_steps=2000):
-         self.shadows = copy.deepcopy(
-             model.module if is_parallel(model) else model
-         ).eval()
-         self.decay = decay
-         self.warmup_steps = warmup_steps
-
-         for p in self.shadows.parameters():
-             p.requires_grad = False
-
-     def step(self, model: nn.Module, global_step: int) -> None:
-         with torch.no_grad():
-             msd = (model.module if is_parallel(model) else model).state_dict()
-             update_ema(
-                 self.shadows,
-                 msd,
-                 self.decay * (1 - math.exp(-global_step / self.warmup_steps)),
-             )
-
-     def state_dict(self) -> dict[float, dict[str, torch.Tensor]]:
-         return {self.decay: self.shadows.state_dict()}
-
-     def load_state_dict(self, state_dict: dict[float, dict[str, torch.Tensor]]) -> None:
-         for decay in state_dict:
-             if decay == self.decay:
-                 self.shadows.load_state_dict(state_dict[decay])
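The key detail in EMA.step is the warmed-up decay, decay * (1 - exp(-step / warmup_steps)), which keeps the shadow weights tracking the live weights closely early in training. A minimal standalone sketch of that update on a toy module (decay and step values are illustrative):

import copy
import math

import torch
import torch.nn as nn

model = nn.Linear(4, 2)
shadow = copy.deepcopy(model).eval()
for p in shadow.parameters():
    p.requires_grad = False

base_decay, warmup_steps = 0.9998, 2000


def effective_decay(step: int) -> float:
    # Ramps the decay from 0 toward base_decay as training progresses.
    return base_decay * (1 - math.exp(-step / warmup_steps))


with torch.no_grad():
    for step in (1, 100, 2000, 20000):
        d = effective_decay(step)
        for k, v in shadow.state_dict().items():
            if v.dtype.is_floating_point:
                v -= (1.0 - d) * (v - model.state_dict()[k])
        print(step, round(d, 4))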
 
efficientvit/apps/utils/export.py DELETED
@@ -1,47 +0,0 @@
- # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
- # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
- # International Conference on Computer Vision (ICCV), 2023
-
- import io
- import os
-
- import onnx
- import torch
- import torch.nn as nn
- from onnxsim import simplify as simplify_func
-
- __all__ = ["export_onnx"]
-
-
- def export_onnx(
-     model: nn.Module, export_path: str, sample_inputs: any, simplify=True, opset=11
- ) -> None:
-     """Export a model to a platform-specific onnx format.
-
-     Args:
-         model: a torch.nn.Module object.
-         export_path: export location.
-         sample_inputs: Any.
-         simplify: a flag to turn on onnx-simplifier
-         opset: int
-     """
-     model.eval()
-
-     buffer = io.BytesIO()
-     with torch.no_grad():
-         torch.onnx.export(model, sample_inputs, buffer, opset_version=opset)
-         buffer.seek(0, 0)
-         if simplify:
-             onnx_model = onnx.load_model(buffer)
-             onnx_model, success = simplify_func(onnx_model)
-             assert success
-             new_buffer = io.BytesIO()
-             onnx.save(onnx_model, new_buffer)
-             buffer = new_buffer
-             buffer.seek(0, 0)
-
-     if buffer.getbuffer().nbytes > 0:
-         save_dir = os.path.dirname(export_path)
-         os.makedirs(save_dir, exist_ok=True)
-         with open(export_path, "wb") as f:
-             f.write(buffer.read())
 
efficientvit/apps/utils/init.py DELETED
@@ -1,68 +0,0 @@
- # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
- # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
- # International Conference on Computer Vision (ICCV), 2023
-
- import torch
- import torch.nn as nn
- from torch.nn.modules.batchnorm import _BatchNorm
-
- __all__ = ["init_modules", "zero_last_gamma"]
-
-
- def init_modules(model: nn.Module or list[nn.Module], init_type="trunc_normal") -> None:
-     _DEFAULT_INIT_PARAM = {"trunc_normal": 0.02}
-
-     if isinstance(model, list):
-         for sub_module in model:
-             init_modules(sub_module, init_type)
-     else:
-         init_params = init_type.split("@")
-         init_params = float(init_params[1]) if len(init_params) > 1 else None
-
-         if init_type.startswith("trunc_normal"):
-             init_func = lambda param: nn.init.trunc_normal_(
-                 param, std=(init_params or _DEFAULT_INIT_PARAM["trunc_normal"])
-             )
-         else:
-             raise NotImplementedError
-
-         for m in model.modules():
-             if isinstance(m, (nn.Conv2d, nn.Linear, nn.ConvTranspose2d)):
-                 init_func(m.weight)
-                 if m.bias is not None:
-                     m.bias.data.zero_()
-             elif isinstance(m, nn.Embedding):
-                 init_func(m.weight)
-             elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)):
-                 m.weight.data.fill_(1)
-                 m.bias.data.zero_()
-             else:
-                 weight = getattr(m, "weight", None)
-                 bias = getattr(m, "bias", None)
-                 if isinstance(weight, torch.nn.Parameter):
-                     init_func(weight)
-                 if isinstance(bias, torch.nn.Parameter):
-                     bias.data.zero_()
-
-
- def zero_last_gamma(model: nn.Module, init_val=0) -> None:
-     import efficientvit.models.nn.ops as ops
-
-     for m in model.modules():
-         if isinstance(m, ops.ResidualBlock) and isinstance(
-             m.shortcut, ops.IdentityLayer
-         ):
-             if isinstance(m.main, (ops.DSConv, ops.MBConv, ops.FusedMBConv)):
-                 parent_module = m.main.point_conv
-             elif isinstance(m.main, ops.ResBlock):
-                 parent_module = m.main.conv2
-             elif isinstance(m.main, ops.ConvLayer):
-                 parent_module = m.main
-             elif isinstance(m.main, (ops.LiteMLA)):
-                 parent_module = m.main.proj
-             else:
-                 parent_module = None
-             if parent_module is not None:
-                 norm = getattr(parent_module, "norm", None)
-                 if norm is not None:
-                     nn.init.constant_(norm.weight, init_val)
 
efficientvit/apps/utils/lr.py DELETED
@@ -1,48 +0,0 @@
- # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
- # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
- # International Conference on Computer Vision (ICCV), 2023
-
- import math
-
- import torch
-
- from efficientvit.models.utils.list import val2list
-
- __all__ = ["CosineLRwithWarmup"]
-
-
- class CosineLRwithWarmup(torch.optim.lr_scheduler._LRScheduler):
-     def __init__(
-         self,
-         optimizer: torch.optim.Optimizer,
-         warmup_steps: int,
-         warmup_lr: float,
-         decay_steps: int or list[int],
-         last_epoch: int = -1,
-     ) -> None:
-         self.warmup_steps = warmup_steps
-         self.warmup_lr = warmup_lr
-         self.decay_steps = val2list(decay_steps)
-         super().__init__(optimizer, last_epoch)
-
-     def get_lr(self) -> list[float]:
-         if self.last_epoch < self.warmup_steps:
-             return [
-                 (base_lr - self.warmup_lr) * (self.last_epoch + 1) / self.warmup_steps
-                 + self.warmup_lr
-                 for base_lr in self.base_lrs
-             ]
-         else:
-             current_steps = self.last_epoch - self.warmup_steps
-             decay_steps = [0] + self.decay_steps
-             idx = len(decay_steps) - 2
-             for i, decay_step in enumerate(decay_steps[:-1]):
-                 if decay_step <= current_steps < decay_steps[i + 1]:
-                     idx = i
-                     break
-             current_steps -= decay_steps[idx]
-             decay_step = decay_steps[idx + 1] - decay_steps[idx]
-             return [
-                 0.5 * base_lr * (1 + math.cos(math.pi * current_steps / decay_step))
-                 for base_lr in self.base_lrs
-             ]
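The schedule produced by CosineLRwithWarmup is a linear ramp from warmup_lr to the base LR, followed by a cosine decay within each decay segment. A minimal sketch of the single-segment case as a plain function (the parameter values are illustrative, not defaults from the repo):

import math


def lr_at(step, base_lr=0.1, warmup_steps=5, warmup_lr=0.001, total_steps=100):
    # Mirrors CosineLRwithWarmup with a single decay segment [0, total_steps].
    if step < warmup_steps:
        return (base_lr - warmup_lr) * (step + 1) / warmup_steps + warmup_lr
    cur = step - warmup_steps
    return 0.5 * base_lr * (1 + math.cos(math.pi * cur / total_steps))


for s in (0, 4, 5, 50, 104):
    print(s, round(lr_at(s), 4))  # ramps to 0.1 by step 4, decays toward 0 afterwards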
 
efficientvit/apps/utils/metric.py DELETED
@@ -1,37 +0,0 @@
- # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
- # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
- # International Conference on Computer Vision (ICCV), 2023
-
- import torch
-
- from efficientvit.apps.utils.dist import sync_tensor
-
- __all__ = ["AverageMeter"]
-
-
- class AverageMeter:
-     """Computes and stores the average and current value."""
-
-     def __init__(self, is_distributed=True):
-         self.is_distributed = is_distributed
-         self.sum = 0
-         self.count = 0
-
-     def _sync(self, val: torch.Tensor or int or float) -> torch.Tensor or int or float:
-         return sync_tensor(val, reduce="sum") if self.is_distributed else val
-
-     def update(self, val: torch.Tensor or int or float, delta_n=1):
-         self.count += self._sync(delta_n)
-         self.sum += self._sync(val * delta_n)
-
-     def get_count(self) -> torch.Tensor or int or float:
-         return (
-             self.count.item()
-             if isinstance(self.count, torch.Tensor) and self.count.numel() == 1
-             else self.count
-         )
-
-     @property
-     def avg(self):
-         avg = -1 if self.count == 0 else self.sum / self.count
-         return avg.item() if isinstance(avg, torch.Tensor) and avg.numel() == 1 else avg
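AverageMeter keeps a running sum and count (optionally reduced across workers via sync_tensor) and returns -1 when empty. A minimal non-distributed sketch of the same bookkeeping, with the sync step omitted:

class LocalAverageMeter:
    # Non-distributed stand-in for AverageMeter: same sum/count bookkeeping,
    # without the sync_tensor call.
    def __init__(self):
        self.sum = 0.0
        self.count = 0

    def update(self, val, delta_n=1):
        self.count += delta_n
        self.sum += val * delta_n

    @property
    def avg(self):
        return -1 if self.count == 0 else self.sum / self.count


meter = LocalAverageMeter()
for batch_loss, batch_size in [(0.9, 32), (0.7, 32), (0.5, 16)]:
    meter.update(batch_loss, batch_size)
print(round(meter.avg, 4))  # 0.74, the sample-weighted mean over 80 samples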
 
efficientvit/apps/utils/misc.py DELETED
@@ -1,111 +0,0 @@
- # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
- # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
- # International Conference on Computer Vision (ICCV), 2023
-
- import os
-
- import yaml
-
- __all__ = [
-     "parse_with_yaml",
-     "parse_unknown_args",
-     "partial_update_config",
-     "resolve_and_load_config",
-     "load_config",
-     "dump_config",
- ]
-
-
- def parse_with_yaml(config_str: str) -> str or dict:
-     try:
-         # add space manually for dict
-         if "{" in config_str and "}" in config_str and ":" in config_str:
-             out_str = config_str.replace(":", ": ")
-         else:
-             out_str = config_str
-         return yaml.safe_load(out_str)
-     except ValueError:
-         # return raw string if parsing fails
-         return config_str
-
-
- def parse_unknown_args(unknown: list) -> dict:
-     """Parse unknown args."""
-     index = 0
-     parsed_dict = {}
-     while index < len(unknown):
-         key, val = unknown[index], unknown[index + 1]
-         index += 2
-         if not key.startswith("--"):
-             continue
-         key = key[2:]
-
-         # try parsing with either dot notation or full yaml notation
-         # Note that the vanilla case "--key value" will be parsed the same
-         if "." in key:
-             # key == a.b.c, val == val --> parsed_dict[a][b][c] = val
-             keys = key.split(".")
-             dict_to_update = parsed_dict
-             for key in keys[:-1]:
-                 if not (
-                     key in dict_to_update and isinstance(dict_to_update[key], dict)
-                 ):
-                     dict_to_update[key] = {}
-                 dict_to_update = dict_to_update[key]
-             dict_to_update[keys[-1]] = parse_with_yaml(
-                 val
-             )  # so we can parse lists, bools, etc...
-         else:
-             parsed_dict[key] = parse_with_yaml(val)
-     return parsed_dict
-
-
- def partial_update_config(config: dict, partial_config: dict) -> dict:
-     for key in partial_config:
-         if (
-             key in config
-             and isinstance(partial_config[key], dict)
-             and isinstance(config[key], dict)
-         ):
-             partial_update_config(config[key], partial_config[key])
-         else:
-             config[key] = partial_config[key]
-     return config
-
-
- def resolve_and_load_config(path: str, config_name="config.yaml") -> dict:
-     path = os.path.realpath(os.path.expanduser(path))
-     if os.path.isdir(path):
-         config_path = os.path.join(path, config_name)
-     else:
-         config_path = path
-     if os.path.isfile(config_path):
-         pass
-     else:
-         raise Exception(f"Cannot find a valid config at {path}")
-     config = load_config(config_path)
-     return config
-
-
- class SafeLoaderWithTuple(yaml.SafeLoader):
-     """A yaml safe loader with python tuple loading capabilities."""
-
-     def construct_python_tuple(self, node):
-         return tuple(self.construct_sequence(node))
-
-
- SafeLoaderWithTuple.add_constructor(
-     "tag:yaml.org,2002:python/tuple", SafeLoaderWithTuple.construct_python_tuple
- )
-
-
- def load_config(filename: str) -> dict:
-     """Load a yaml file."""
-     filename = os.path.realpath(os.path.expanduser(filename))
-     return yaml.load(open(filename), Loader=SafeLoaderWithTuple)
-
-
- def dump_config(config: dict, filename: str) -> None:
-     """Dump a config file"""
-     filename = os.path.realpath(os.path.expanduser(filename))
-     yaml.dump(config, open(filename, "w"), sort_keys=False)
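The most reusable idea in misc.py is turning --a.b.c value command-line overrides into nested dictionaries and YAML-parsing the value string. A minimal standalone sketch of that pattern (set_by_dots is a hypothetical name; the deleted parse_unknown_args additionally walks argparse's unknown-args list, and the result is merged into the experiment config via partial_update_config):

import yaml


def set_by_dots(parsed: dict, dotted_key: str, raw_val: str) -> None:
    # "--run_config.base_lr 0.1"-style override: build nested dicts, YAML-parse the value.
    keys = dotted_key.split(".")
    node = parsed
    for k in keys[:-1]:
        node = node.setdefault(k, {})
    node[keys[-1]] = yaml.safe_load(raw_val)  # parses numbers, bools, lists, ...


overrides = {}
set_by_dots(overrides, "run_config.base_lr", "0.1")
set_by_dots(overrides, "run_config.n_epochs", "300")
print(overrides)  # {'run_config': {'base_lr': 0.1, 'n_epochs': 300}}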
 
efficientvit/apps/utils/opt.py DELETED
@@ -1,31 +0,0 @@
- # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
- # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
- # International Conference on Computer Vision (ICCV), 2023
-
- import torch
-
- __all__ = ["REGISTERED_OPTIMIZER_DICT", "build_optimizer"]
-
- # register optimizer here
- # name: optimizer, kwargs with default values
- REGISTERED_OPTIMIZER_DICT: dict[str, tuple[type, dict[str, any]]] = {
-     "sgd": (torch.optim.SGD, {"momentum": 0.9, "nesterov": True}),
-     "adam": (torch.optim.Adam, {"betas": (0.9, 0.999), "eps": 1e-8, "amsgrad": False}),
-     "adamw": (
-         torch.optim.AdamW,
-         {"betas": (0.9, 0.999), "eps": 1e-8, "amsgrad": False},
-     ),
- }
-
-
- def build_optimizer(
-     net_params, optimizer_name: str, optimizer_params: dict or None, init_lr: float
- ) -> torch.optim.Optimizer:
-     optimizer_class, default_params = REGISTERED_OPTIMIZER_DICT[optimizer_name]
-     optimizer_params = optimizer_params or {}
-
-     for key in default_params:
-         if key in optimizer_params:
-             default_params[key] = optimizer_params[key]
-     optimizer = optimizer_class(net_params, init_lr, **default_params)
-     return optimizer
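build_optimizer is a registry lookup that merges user overrides into per-optimizer defaults before constructing the optimizer. A minimal standalone sketch of the same pattern (OPTIMIZERS and make_optimizer are illustrative names; unlike the original, this copies the defaults rather than mutating the registry entry):

import torch
import torch.nn as nn

# Same registry shape as REGISTERED_OPTIMIZER_DICT: name -> (class, default kwargs).
OPTIMIZERS = {
    "sgd": (torch.optim.SGD, {"momentum": 0.9, "nesterov": True}),
    "adamw": (torch.optim.AdamW, {"betas": (0.9, 0.999), "eps": 1e-8, "amsgrad": False}),
}


def make_optimizer(params, name: str, overrides, lr: float):
    cls, defaults = OPTIMIZERS[name]
    kwargs = dict(defaults)  # copy so the registry entry stays untouched
    kwargs.update({k: v for k, v in (overrides or {}).items() if k in defaults})
    return cls(params, lr, **kwargs)


model = nn.Linear(4, 2)
opt = make_optimizer(model.parameters(), "adamw", {"eps": 1e-6}, lr=1e-3)
print(type(opt).__name__, opt.defaults["eps"])  # AdamW 1e-06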
 
efficientvit/models/__init__.py DELETED
File without changes
efficientvit/models/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (175 Bytes)
 
efficientvit/models/efficientvit/__init__.py DELETED
@@ -1,8 +0,0 @@
- # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
- # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
- # International Conference on Computer Vision (ICCV), 2023
-
- from .backbone import *
- from .cls import *
- from .sam import *
- from .seg import *
 
efficientvit/models/efficientvit/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (258 Bytes)