pg56714 committed
Commit 9043dc9 · verified · Parent: d28c8e3

Upload 96 files

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. efficientvit/__init__.py +0 -0
  2. efficientvit/__pycache__/__init__.cpython-310.pyc +0 -0
  3. efficientvit/__pycache__/sam_model_zoo.cpython-310.pyc +0 -0
  4. efficientvit/apps/__init__.py +0 -0
  5. efficientvit/apps/__pycache__/__init__.cpython-310.pyc +0 -0
  6. efficientvit/apps/data_provider/__init__.py +7 -0
  7. efficientvit/apps/data_provider/__pycache__/__init__.cpython-310.pyc +0 -0
  8. efficientvit/apps/data_provider/__pycache__/base.cpython-310.pyc +0 -0
  9. efficientvit/apps/data_provider/augment/__init__.py +6 -0
  10. efficientvit/apps/data_provider/augment/__pycache__/__init__.cpython-310.pyc +0 -0
  11. efficientvit/apps/data_provider/augment/__pycache__/bbox.cpython-310.pyc +0 -0
  12. efficientvit/apps/data_provider/augment/__pycache__/color_aug.cpython-310.pyc +0 -0
  13. efficientvit/apps/data_provider/augment/bbox.py +30 -0
  14. efficientvit/apps/data_provider/augment/color_aug.py +78 -0
  15. efficientvit/apps/data_provider/base.py +199 -0
  16. efficientvit/apps/data_provider/random_resolution/__init__.py +7 -0
  17. efficientvit/apps/data_provider/random_resolution/__pycache__/__init__.cpython-310.pyc +0 -0
  18. efficientvit/apps/data_provider/random_resolution/__pycache__/controller.cpython-310.pyc +0 -0
  19. efficientvit/apps/data_provider/random_resolution/_data_loader.py +1538 -0
  20. efficientvit/apps/data_provider/random_resolution/_data_worker.py +358 -0
  21. efficientvit/apps/data_provider/random_resolution/controller.py +92 -0
  22. efficientvit/apps/setup.py +135 -0
  23. efficientvit/apps/trainer/__init__.py +6 -0
  24. efficientvit/apps/trainer/__pycache__/__init__.cpython-310.pyc +0 -0
  25. efficientvit/apps/trainer/__pycache__/base.cpython-310.pyc +0 -0
  26. efficientvit/apps/trainer/__pycache__/run_config.cpython-310.pyc +0 -0
  27. efficientvit/apps/trainer/base.py +299 -0
  28. efficientvit/apps/trainer/run_config.py +115 -0
  29. efficientvit/apps/utils/__init__.py +12 -0
  30. efficientvit/apps/utils/__pycache__/__init__.cpython-310.pyc +0 -0
  31. efficientvit/apps/utils/__pycache__/dist.cpython-310.pyc +0 -0
  32. efficientvit/apps/utils/__pycache__/ema.cpython-310.pyc +0 -0
  33. efficientvit/apps/utils/__pycache__/export.cpython-310.pyc +0 -0
  34. efficientvit/apps/utils/__pycache__/init.cpython-310.pyc +0 -0
  35. efficientvit/apps/utils/__pycache__/lr.cpython-310.pyc +0 -0
  36. efficientvit/apps/utils/__pycache__/metric.cpython-310.pyc +0 -0
  37. efficientvit/apps/utils/__pycache__/misc.cpython-310.pyc +0 -0
  38. efficientvit/apps/utils/__pycache__/opt.cpython-310.pyc +0 -0
  39. efficientvit/apps/utils/dist.py +71 -0
  40. efficientvit/apps/utils/ema.py +42 -0
  41. efficientvit/apps/utils/export.py +45 -0
  42. efficientvit/apps/utils/init.py +66 -0
  43. efficientvit/apps/utils/lr.py +44 -0
  44. efficientvit/apps/utils/metric.py +33 -0
  45. efficientvit/apps/utils/misc.py +101 -0
  46. efficientvit/apps/utils/opt.py +28 -0
  47. efficientvit/cls_model_zoo.py +79 -0
  48. efficientvit/clscore/__init__.py +0 -0
  49. efficientvit/clscore/data_provider/__init__.py +5 -0
  50. efficientvit/clscore/data_provider/imagenet.py +123 -0
efficientvit/__init__.py ADDED
File without changes
efficientvit/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (175 Bytes)
 
efficientvit/__pycache__/sam_model_zoo.cpython-310.pyc ADDED
Binary file (1.46 kB)
 
efficientvit/apps/__init__.py ADDED
File without changes
efficientvit/apps/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (180 Bytes)
 
efficientvit/apps/data_provider/__init__.py ADDED
@@ -0,0 +1,7 @@
+ # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
+ # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
+ # International Conference on Computer Vision (ICCV), 2023
+
+ from .augment import *
+ from .base import *
+ from .random_resolution import *
efficientvit/apps/data_provider/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (263 Bytes)
 
efficientvit/apps/data_provider/__pycache__/base.cpython-310.pyc ADDED
Binary file (6.35 kB)
 
efficientvit/apps/data_provider/augment/__init__.py ADDED
@@ -0,0 +1,6 @@
+ # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
+ # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
+ # International Conference on Computer Vision (ICCV), 2023
+
+ from .bbox import *
+ from .color_aug import *
efficientvit/apps/data_provider/augment/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (244 Bytes)
 
efficientvit/apps/data_provider/augment/__pycache__/bbox.cpython-310.pyc ADDED
Binary file (807 Bytes)
 
efficientvit/apps/data_provider/augment/__pycache__/color_aug.cpython-310.pyc ADDED
Binary file (3.13 kB)
 
efficientvit/apps/data_provider/augment/bbox.py ADDED
@@ -0,0 +1,30 @@
+ # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
+ # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
+ # International Conference on Computer Vision (ICCV), 2023
+
+ import numpy as np
+
+ __all__ = ["rand_bbox"]
+
+
+ def rand_bbox(
+     h: int,
+     w: int,
+     lam: float,
+     rand_func: callable = np.random.uniform,
+ ) -> tuple[int, int, int, int]:
+     """randomly sample bbox, used in cutmix"""
+     cut_rat = np.sqrt(1.0 - lam)
+     cut_w = w * cut_rat
+     cut_h = h * cut_rat
+
+     # uniform
+     cx = rand_func(0, w)
+     cy = rand_func(0, h)
+
+     bbx1 = int(np.clip(cx - cut_w / 2, 0, w))
+     bby1 = int(np.clip(cy - cut_h / 2, 0, h))
+     bbx2 = int(np.clip(cx + cut_w / 2, 0, w))
+     bby2 = int(np.clip(cy + cut_h / 2, 0, h))
+
+     return bbx1, bby1, bbx2, bby2
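
A usage sketch, not part of the commit: rand_bbox samples the CutMix crop window, where lam is the area-mixing coefficient and the returned corners index width (x) and height (y). The batch and variable names below are illustrative.

import numpy as np
import torch
from efficientvit.apps.data_provider.augment.bbox import rand_bbox

images = torch.randn(8, 3, 224, 224)   # dummy batch (N, C, H, W)
lam = float(np.random.beta(1.0, 1.0))  # CutMix mixing coefficient
perm = torch.randperm(images.size(0))  # pair each sample with a shuffled one
x1, y1, x2, y2 = rand_bbox(images.size(2), images.size(3), lam)
images[:, :, y1:y2, x1:x2] = images[perm, :, y1:y2, x1:x2]  # paste the patch
# correct lam to the exact area ratio of the pasted patch
lam = 1 - ((x2 - x1) * (y2 - y1)) / (images.size(2) * images.size(3))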
efficientvit/apps/data_provider/augment/color_aug.py ADDED
@@ -0,0 +1,78 @@
+ # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
+ # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
+ # International Conference on Computer Vision (ICCV), 2023
+
+ import numpy as np
+ import torchvision.transforms as transforms
+ from PIL import Image
+ from timm.data.auto_augment import rand_augment_transform
+
+ __all__ = ["ColorAug", "RandAug"]
+
+
+ class ImageAug:
+     def aug_image(self, image: Image.Image) -> Image.Image:
+         raise NotImplementedError
+
+     def __call__(self, feed_dict: dict or np.ndarray or Image.Image) -> dict or np.ndarray or Image.Image:
+         if isinstance(feed_dict, dict):
+             output_dict = feed_dict
+             image = feed_dict[self.key]
+         else:
+             output_dict = None
+             image = feed_dict
+         is_ndarray = isinstance(image, np.ndarray)
+         if is_ndarray:
+             image = Image.fromarray(image)
+
+         image = self.aug_image(image)
+
+         if is_ndarray:
+             image = np.array(image)
+
+         if output_dict is None:
+             return image
+         else:
+             output_dict[self.key] = image
+             return output_dict
+
+
+ class ColorAug(transforms.ColorJitter, ImageAug):
+     def __init__(self, brightness=0, contrast=0, saturation=0, hue=0, key="data"):
+         super().__init__(
+             brightness=brightness,
+             contrast=contrast,
+             saturation=saturation,
+             hue=hue,
+         )
+         self.key = key
+
+     def aug_image(self, image: Image.Image) -> Image.Image:
+         return transforms.ColorJitter.forward(self, image)
+
+     def forward(self, feed_dict: dict or np.ndarray or Image.Image) -> dict or np.ndarray or Image.Image:
+         return ImageAug.__call__(self, feed_dict)
+
+
+ class RandAug(ImageAug):
+     def __init__(self, config: dict[str, any], mean: tuple[float, float, float], key="data"):
+         n = config.get("n", 2)
+         m = config.get("m", 9)
+         mstd = config.get("mstd", 1.0)
+         inc = config.get("inc", 1)
+         tpct = config.get("tpct", 0.45)
+         config_str = f"rand-n{n}-m{m}-mstd{mstd}-inc{inc}"
+
+         aa_params = dict(
+             translate_pct=tpct,
+             img_mean=tuple([min(255, round(255 * x)) for x in mean]),
+             interpolation=Image.BICUBIC,
+         )
+         self.aug_op = rand_augment_transform(config_str, aa_params)
+         self.key = key
+
+     def aug_image(self, image: Image.Image) -> Image.Image:
+         return self.aug_op(image)
+
+     def __repr__(self):
+         return self.aug_op.__repr__()
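
For orientation, a hedged sketch of how the two augmenters above are driven; the config keys mirror the defaults read in RandAug.__init__, and the input shape is illustrative.

import numpy as np
from efficientvit.apps.data_provider.augment.color_aug import ColorAug, RandAug

color_aug = ColorAug(brightness=0.4, contrast=0.4, saturation=0.4, key="data")
rand_aug = RandAug(
    config={"n": 2, "m": 9, "mstd": 1.0, "inc": 1},
    mean=(0.485, 0.456, 0.406),  # converted to a 0-255 fill color for timm
)

image = np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8)
out = rand_aug({"data": image})  # dict in, dict out under the same key
out2 = color_aug(image)          # bare ndarray in, ndarray out

Both classes share the ImageAug feed_dict protocol: dict inputs are rewritten in place under self.key, while bare ndarrays or PIL images are returned directly.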
efficientvit/apps/data_provider/base.py ADDED
@@ -0,0 +1,199 @@
+ # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
+ # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
+ # International Conference on Computer Vision (ICCV), 2023
+
+ import copy
+ import warnings
+
+ import torch.utils.data
+ from torch.utils.data.distributed import DistributedSampler
+
+ from efficientvit.apps.data_provider.random_resolution import RRSController
+ from efficientvit.models.utils import val2tuple
+
+ __all__ = ["parse_image_size", "random_drop_data", "DataProvider"]
+
+
+ def parse_image_size(size: int or str) -> tuple[int, int]:
+     if isinstance(size, str):
+         size = [int(val) for val in size.split("-")]
+         return size[0], size[1]
+     else:
+         return val2tuple(size, 2)
+
+
+ def random_drop_data(dataset, drop_size: int, seed: int, keys=("samples",)):
+     g = torch.Generator()
+     g.manual_seed(seed)  # set random seed before sampling validation set
+     rand_indexes = torch.randperm(len(dataset), generator=g).tolist()
+
+     dropped_indexes = rand_indexes[:drop_size]
+     remaining_indexes = rand_indexes[drop_size:]
+
+     dropped_dataset = copy.deepcopy(dataset)
+     for key in keys:
+         setattr(dropped_dataset, key, [getattr(dropped_dataset, key)[idx] for idx in dropped_indexes])
+         setattr(dataset, key, [getattr(dataset, key)[idx] for idx in remaining_indexes])
+     return dataset, dropped_dataset
+
+
+ class DataProvider:
+     data_keys = ("samples",)
+     mean_std = {"mean": [0.485, 0.456, 0.406], "std": [0.229, 0.224, 0.225]}
+     SUB_SEED = 937162211  # random seed for sampling subset
+     VALID_SEED = 2147483647  # random seed for the validation set
+
+     name: str
+
+     def __init__(
+         self,
+         train_batch_size: int,
+         test_batch_size: int or None,
+         valid_size: int or float or None,
+         n_worker: int,
+         image_size: int or list[int] or str or list[str],
+         num_replicas: int or None = None,
+         rank: int or None = None,
+         train_ratio: float or None = None,
+         drop_last: bool = False,
+     ):
+         warnings.filterwarnings("ignore")
+         super().__init__()
+
+         # batch_size & valid_size
+         self.train_batch_size = train_batch_size
+         self.test_batch_size = test_batch_size or self.train_batch_size
+         self.valid_size = valid_size
+
+         # image size
+         if isinstance(image_size, list):
+             self.image_size = [parse_image_size(size) for size in image_size]
+             self.image_size.sort()  # e.g., 160 -> 224
+             RRSController.IMAGE_SIZE_LIST = copy.deepcopy(self.image_size)
+             self.active_image_size = RRSController.ACTIVE_SIZE = self.image_size[-1]
+         else:
+             self.image_size = parse_image_size(image_size)
+             RRSController.IMAGE_SIZE_LIST = [self.image_size]
+             self.active_image_size = RRSController.ACTIVE_SIZE = self.image_size
+
+         # distributed configs
+         self.num_replicas = num_replicas
+         self.rank = rank
+
+         # build datasets
+         train_dataset, val_dataset, test_dataset = self.build_datasets()
+
+         if train_ratio is not None and train_ratio < 1.0:
+             assert 0 < train_ratio < 1
+             _, train_dataset = random_drop_data(
+                 train_dataset,
+                 int(train_ratio * len(train_dataset)),
+                 self.SUB_SEED,
+                 self.data_keys,
+             )
+
+         # build data loader
+         self.train = self.build_dataloader(train_dataset, train_batch_size, n_worker, drop_last=drop_last, train=True)
+         self.valid = self.build_dataloader(val_dataset, test_batch_size, n_worker, drop_last=False, train=False)
+         self.test = self.build_dataloader(test_dataset, test_batch_size, n_worker, drop_last=False, train=False)
+         if self.valid is None:
+             self.valid = self.test
+         self.sub_train = None
+
+     @property
+     def data_shape(self) -> tuple[int, ...]:
+         return 3, self.active_image_size[0], self.active_image_size[1]
+
+     def build_valid_transform(self, image_size: tuple[int, int] or None = None) -> any:
+         raise NotImplementedError
+
+     def build_train_transform(self, image_size: tuple[int, int] or None = None) -> any:
+         raise NotImplementedError
+
+     def build_datasets(self) -> tuple[any, any, any]:
+         raise NotImplementedError
+
+     def build_dataloader(self, dataset: any or None, batch_size: int, n_worker: int, drop_last: bool, train: bool):
+         if dataset is None:
+             return None
+         if isinstance(self.image_size, list) and train:
+             from efficientvit.apps.data_provider.random_resolution._data_loader import RRSDataLoader
+
+             dataloader_class = RRSDataLoader
+         else:
+             dataloader_class = torch.utils.data.DataLoader
+         if self.num_replicas is None:
+             return dataloader_class(
+                 dataset=dataset,
+                 batch_size=batch_size,
+                 shuffle=True,
+                 num_workers=n_worker,
+                 pin_memory=True,
+                 drop_last=drop_last,
+             )
+         else:
+             sampler = DistributedSampler(dataset, self.num_replicas, self.rank)
+             return dataloader_class(
+                 dataset=dataset,
+                 batch_size=batch_size,
+                 sampler=sampler,
+                 num_workers=n_worker,
+                 pin_memory=True,
+                 drop_last=drop_last,
+             )
+
+     def set_epoch(self, epoch: int) -> None:
+         RRSController.set_epoch(epoch, len(self.train))
+         if isinstance(self.train.sampler, DistributedSampler):
+             self.train.sampler.set_epoch(epoch)
+
+     def assign_active_image_size(self, new_size: int or tuple[int, int]) -> None:
+         self.active_image_size = val2tuple(new_size, 2)
+         new_transform = self.build_valid_transform(self.active_image_size)
+         # change the transform of the valid and test set
+         self.valid.dataset.transform = self.test.dataset.transform = new_transform
+
+     def sample_val_dataset(self, train_dataset, valid_transform) -> tuple[any, any]:
+         if self.valid_size is not None:
+             if 0 < self.valid_size < 1:
+                 valid_size = int(self.valid_size * len(train_dataset))
+             else:
+                 assert self.valid_size >= 1
+                 valid_size = int(self.valid_size)
+             train_dataset, val_dataset = random_drop_data(
+                 train_dataset,
+                 valid_size,
+                 self.VALID_SEED,
+                 self.data_keys,
+             )
+             val_dataset.transform = valid_transform
+         else:
+             val_dataset = None
+         return train_dataset, val_dataset
+
+     def build_sub_train_loader(self, n_samples: int, batch_size: int) -> any:
+         # used for resetting BN running statistics
+         if self.sub_train is None:
+             self.sub_train = {}
+         if self.active_image_size in self.sub_train:
+             return self.sub_train[self.active_image_size]
+
+         # construct dataset and dataloader
+         train_dataset = copy.deepcopy(self.train.dataset)
+         if n_samples < len(train_dataset):
+             _, train_dataset = random_drop_data(
+                 train_dataset,
+                 n_samples,
+                 self.SUB_SEED,
+                 self.data_keys,
+             )
+         RRSController.ACTIVE_SIZE = self.active_image_size
+         train_dataset.transform = self.build_train_transform(image_size=self.active_image_size)
+         data_loader = self.build_dataloader(train_dataset, batch_size, self.train.num_workers, True, False)
+
+         # pre-fetch data
+         self.sub_train[self.active_image_size] = [
+             data for data in data_loader for _ in range(max(1, n_samples // len(train_dataset)))
+         ]
+
+         return self.sub_train[self.active_image_size]
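
DataProvider is abstract: build_datasets and the two transform builders raise NotImplementedError. A minimal, hypothetical subclass sketch (TinyProvider and the data/... paths are illustrative, not part of this commit):

from torchvision import transforms
from torchvision.datasets import ImageFolder
from efficientvit.apps.data_provider import DataProvider

class TinyProvider(DataProvider):
    name = "tiny"

    def build_valid_transform(self, image_size=None):
        image_size = image_size or self.active_image_size  # (H, W)
        return transforms.Compose([
            transforms.Resize(image_size),
            transforms.ToTensor(),
            transforms.Normalize(**self.mean_std),  # ImageNet mean/std from the class
        ])

    def build_train_transform(self, image_size=None):
        # a real provider would add augmentation (e.g., RandAug above) here
        return self.build_valid_transform(image_size)

    def build_datasets(self):
        train = ImageFolder("data/train", self.build_train_transform())
        test = ImageFolder("data/val", self.build_valid_transform())
        # carve a validation split out of train when valid_size is set
        train, valid = self.sample_val_dataset(train, self.build_valid_transform())
        return train, valid, test

provider = TinyProvider(train_batch_size=64, test_batch_size=64, valid_size=None,
                        n_worker=4, image_size=224)

With valid_size=None the provider falls back to self.valid = self.test; passing image_size as a list (e.g., [160, 192, 224]) would route training through RRSDataLoader instead of torch.utils.data.DataLoader.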
efficientvit/apps/data_provider/random_resolution/__init__.py ADDED
@@ -0,0 +1,7 @@
+ """Random resolution data loader compatible with multi-processing and distributed training.
+
+ Replace Pytorch's DataLoader with RRSDataLoader to support random resolution
+ at the training time, resolution sampling is controlled by RRSController
+ """
+
+ from .controller import *
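
controller.py (92 lines) is not shown in this 50-file view, but base.py above pins down its contract: RRSController.IMAGE_SIZE_LIST holds the candidate resolutions, ACTIVE_SIZE is the resolution applied to the current batch, and set_epoch(epoch, n_batches) reseeds the per-batch schedule. A hedged sketch of a training loop driving it, with num_epochs and train_loader assumed to exist:

from efficientvit.apps.data_provider.random_resolution import RRSController

RRSController.IMAGE_SIZE_LIST = [(160, 160), (192, 192), (224, 224)]  # candidates

for epoch in range(num_epochs):
    # reseed the per-batch resolution schedule (mirrors DataProvider.set_epoch)
    RRSController.set_epoch(epoch, len(train_loader))
    for images, labels in train_loader:  # an RRSDataLoader built by DataProvider
        ...  # each batch arrives at the resolution sampled for that step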
efficientvit/apps/data_provider/random_resolution/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (496 Bytes)
 
efficientvit/apps/data_provider/random_resolution/__pycache__/controller.cpython-310.pyc ADDED
Binary file (3.31 kB)
 
efficientvit/apps/data_provider/random_resolution/_data_loader.py ADDED
@@ -0,0 +1,1538 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ r"""This file is based on torch/utils/data/data_loader.py
2
+
3
+ Definition of the DataLoader and associated iterators that subclass _BaseDataLoaderIter
4
+
5
+ To support these two classes, in `./_utils` we define many utility methods and
6
+ functions to be run in multiprocessing. E.g., the data loading worker loop is
7
+ in `./_utils/worker.py`.
8
+ """
9
+
10
+ import functools
11
+ import itertools
12
+ import logging
13
+ import multiprocessing as python_multiprocessing
14
+ import os
15
+ import queue
16
+ import threading
17
+ import warnings
18
+ from typing import Any, Callable, Generic, Iterable, List, Optional, Sequence, TypeVar, Union
19
+
20
+ import torch
21
+ import torch.distributed as dist
22
+ import torch.multiprocessing as multiprocessing
23
+ import torch.utils.data.graph_settings
24
+ from torch._utils import ExceptionWrapper
25
+ from torch.utils.data import (
26
+ BatchSampler,
27
+ Dataset,
28
+ IterableDataset,
29
+ IterDataPipe,
30
+ MapDataPipe,
31
+ RandomSampler,
32
+ Sampler,
33
+ SequentialSampler,
34
+ _utils,
35
+ )
36
+ from torch.utils.data.datapipes.datapipe import _IterDataPipeSerializationWrapper, _MapDataPipeSerializationWrapper
37
+
38
+ from ._data_worker import _worker_loop
39
+
40
+ __all__ = ["RRSDataLoader"]
41
+
42
+ T_co = TypeVar("T_co", covariant=True)
43
+ T = TypeVar("T")
44
+ _worker_init_fn_t = Callable[[int], None]
45
+
46
+ # Ideally we would parameterize `DataLoader` by the return type of `collate_fn`, but there is currently no way to have that
47
+ # type parameter set to a default value if the user doesn't pass in a custom 'collate_fn'.
48
+ # See https://github.com/python/mypy/issues/3737.
49
+ _collate_fn_t = Callable[[List[T]], Any]
50
+
51
+
52
+ # These functions used to be defined in this file. However, it was moved to
53
+ # _utils/collate.py. Although it is rather hard to access this from user land
54
+ # (one has to explicitly directly `import torch.utils.data.dataloader`), there
55
+ # probably is user code out there using it. This aliasing maintains BC in this
56
+ # aspect.
57
+ default_collate: _collate_fn_t = _utils.collate.default_collate
58
+ default_convert = _utils.collate.default_convert
59
+
60
+ get_worker_info = _utils.worker.get_worker_info
61
+
62
+ logger = logging.getLogger(__name__)
63
+
64
+
65
+ class _DatasetKind:
66
+ Map = 0
67
+ Iterable = 1
68
+
69
+ @staticmethod
70
+ def create_fetcher(kind, dataset, auto_collation, collate_fn, drop_last):
71
+ if kind == _DatasetKind.Map:
72
+ return _utils.fetch._MapDatasetFetcher(dataset, auto_collation, collate_fn, drop_last)
73
+ else:
74
+ return _utils.fetch._IterableDatasetFetcher(dataset, auto_collation, collate_fn, drop_last)
75
+
76
+
77
+ class _InfiniteConstantSampler(Sampler):
78
+ r"""Analogous to ``itertools.repeat(None, None)``.
79
+ Used as sampler for :class:`~torch.utils.data.IterableDataset`.
80
+
81
+ Args:
82
+ data_source (Dataset): dataset to sample from
83
+ """
84
+
85
+ def __init__(self):
86
+ super().__init__(None)
87
+
88
+ def __iter__(self):
89
+ while True:
90
+ yield None
91
+
92
+
93
+ def _get_distributed_settings():
94
+ if dist.is_available() and dist.is_initialized():
95
+ return dist.get_world_size(), dist.get_rank()
96
+ else:
97
+ return 1, 0
98
+
99
+
100
+ def _sharding_worker_init_fn(worker_init_fn, world_size, rank_id, worker_id):
101
+ global_worker_id = worker_id
102
+ info = torch.utils.data.get_worker_info()
103
+ assert info is not None
104
+ total_workers = info.num_workers
105
+ datapipe = info.dataset
106
+ assert isinstance(datapipe, (IterDataPipe, MapDataPipe))
107
+ # To distribute elements across distributed process evenly, we should shard data on distributed
108
+ # processes first then shard on worker processes
109
+ total_workers *= world_size
110
+ global_worker_id = global_worker_id * world_size + rank_id
111
+ # For BC, use default SHARDING_PRIORITIES
112
+ torch.utils.data.graph_settings.apply_sharding(datapipe, total_workers, global_worker_id)
113
+ if worker_init_fn is not None:
114
+ worker_init_fn(worker_id)
115
+
116
+
117
+ def _share_dist_seed(generator, pg):
118
+ _shared_seed = torch.empty((), dtype=torch.int64).random_(generator=generator)
119
+ if isinstance(pg, dist.ProcessGroup):
120
+ dist.broadcast(_shared_seed, src=0, group=pg)
121
+ return _shared_seed.item()
122
+
123
+
124
+ class RRSDataLoader(Generic[T_co]):
125
+ r"""
126
+ Data loader. Combines a dataset and a sampler, and provides an iterable over
127
+ the given dataset.
128
+
129
+ The :class:`~torch.utils.data.DataLoader` supports both map-style and
130
+ iterable-style datasets with single- or multi-process loading, customizing
131
+ loading order and optional automatic batching (collation) and memory pinning.
132
+
133
+ See :py:mod:`torch.utils.data` documentation page for more details.
134
+
135
+ Args:
136
+ dataset (Dataset): dataset from which to load the data.
137
+ batch_size (int, optional): how many samples per batch to load
138
+ (default: ``1``).
139
+ shuffle (bool, optional): set to ``True`` to have the data reshuffled
140
+ at every epoch (default: ``False``).
141
+ sampler (Sampler or Iterable, optional): defines the strategy to draw
142
+ samples from the dataset. Can be any ``Iterable`` with ``__len__``
143
+ implemented. If specified, :attr:`shuffle` must not be specified.
144
+ batch_sampler (Sampler or Iterable, optional): like :attr:`sampler`, but
145
+ returns a batch of indices at a time. Mutually exclusive with
146
+ :attr:`batch_size`, :attr:`shuffle`, :attr:`sampler`,
147
+ and :attr:`drop_last`.
148
+ num_workers (int, optional): how many subprocesses to use for data
149
+ loading. ``0`` means that the data will be loaded in the main process.
150
+ (default: ``0``)
151
+ collate_fn (Callable, optional): merges a list of samples to form a
152
+ mini-batch of Tensor(s). Used when using batched loading from a
153
+ map-style dataset.
154
+ pin_memory (bool, optional): If ``True``, the data loader will copy Tensors
155
+ into device/CUDA pinned memory before returning them. If your data elements
156
+ are a custom type, or your :attr:`collate_fn` returns a batch that is a custom type,
157
+ see the example below.
158
+ drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
159
+ if the dataset size is not divisible by the batch size. If ``False`` and
160
+ the size of dataset is not divisible by the batch size, then the last batch
161
+ will be smaller. (default: ``False``)
162
+ timeout (numeric, optional): if positive, the timeout value for collecting a batch
163
+ from workers. Should always be non-negative. (default: ``0``)
164
+ worker_init_fn (Callable, optional): If not ``None``, this will be called on each
165
+ worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as
166
+ input, after seeding and before data loading. (default: ``None``)
167
+ generator (torch.Generator, optional): If not ``None``, this RNG will be used
168
+ by RandomSampler to generate random indexes and multiprocessing to generate
169
+ `base_seed` for workers. (default: ``None``)
170
+ prefetch_factor (int, optional, keyword-only arg): Number of batches loaded
171
+ in advance by each worker. ``2`` means there will be a total of
172
+ 2 * num_workers batches prefetched across all workers. (default value depends
173
+ on the set value for num_workers. If value of num_workers=0 default is ``None``.
174
+ Otherwise if value of num_workers>0 default is ``2``).
175
+ persistent_workers (bool, optional): If ``True``, the data loader will not shutdown
176
+ the worker processes after a dataset has been consumed once. This allows to
177
+ maintain the workers `Dataset` instances alive. (default: ``False``)
178
+ pin_memory_device (str, optional): the data loader will copy Tensors
179
+ into device pinned memory before returning them if pin_memory is set to true.
180
+
181
+
182
+ .. warning:: If the ``spawn`` start method is used, :attr:`worker_init_fn`
183
+ cannot be an unpicklable object, e.g., a lambda function. See
184
+ :ref:`multiprocessing-best-practices` on more details related
185
+ to multiprocessing in PyTorch.
186
+
187
+ .. warning:: ``len(dataloader)`` heuristic is based on the length of the sampler used.
188
+ When :attr:`dataset` is an :class:`~torch.utils.data.IterableDataset`,
189
+ it instead returns an estimate based on ``len(dataset) / batch_size``, with proper
190
+ rounding depending on :attr:`drop_last`, regardless of multi-process loading
191
+ configurations. This represents the best guess PyTorch can make because PyTorch
192
+ trusts user :attr:`dataset` code in correctly handling multi-process
193
+ loading to avoid duplicate data.
194
+
195
+ However, if sharding results in multiple workers having incomplete last batches,
196
+ this estimate can still be inaccurate, because (1) an otherwise complete batch can
197
+ be broken into multiple ones and (2) more than one batch worth of samples can be
198
+ dropped when :attr:`drop_last` is set. Unfortunately, PyTorch can not detect such
199
+ cases in general.
200
+
201
+ See `Dataset Types`_ for more details on these two types of datasets and how
202
+ :class:`~torch.utils.data.IterableDataset` interacts with
203
+ `Multi-process data loading`_.
204
+
205
+ .. warning:: See :ref:`reproducibility`, and :ref:`dataloader-workers-random-seed`, and
206
+ :ref:`data-loading-randomness` notes for random seed related questions.
207
+ """
208
+
209
+ dataset: Dataset[T_co]
210
+ batch_size: Optional[int]
211
+ num_workers: int
212
+ pin_memory: bool
213
+ drop_last: bool
214
+ timeout: float
215
+ sampler: Union[Sampler, Iterable]
216
+ pin_memory_device: str
217
+ prefetch_factor: Optional[int]
218
+ _iterator: Optional["_BaseDataLoaderIter"]
219
+ __initialized = False
220
+
221
+ def __init__(
222
+ self,
223
+ dataset: Dataset[T_co],
224
+ batch_size: Optional[int] = 1,
225
+ shuffle: Optional[bool] = None,
226
+ sampler: Union[Sampler, Iterable, None] = None,
227
+ batch_sampler: Union[Sampler[Sequence], Iterable[Sequence], None] = None,
228
+ num_workers: int = 0,
229
+ collate_fn: Optional[_collate_fn_t] = None,
230
+ pin_memory: bool = False,
231
+ drop_last: bool = False,
232
+ timeout: float = 0,
233
+ worker_init_fn: Optional[_worker_init_fn_t] = None,
234
+ multiprocessing_context=None,
235
+ generator=None,
236
+ *,
237
+ prefetch_factor: Optional[int] = None,
238
+ persistent_workers: bool = False,
239
+ pin_memory_device: str = ""
240
+ ):
241
+ torch._C._log_api_usage_once("python.data_loader")
242
+
243
+ if num_workers < 0:
244
+ raise ValueError(
245
+ "num_workers option should be non-negative; " "use num_workers=0 to disable multiprocessing."
246
+ )
247
+
248
+ if timeout < 0:
249
+ raise ValueError("timeout option should be non-negative")
250
+
251
+ if num_workers == 0 and prefetch_factor is not None:
252
+ raise ValueError(
253
+ "prefetch_factor option could only be specified in multiprocessing."
254
+ "let num_workers > 0 to enable multiprocessing, otherwise set prefetch_factor to None."
255
+ )
256
+ elif num_workers > 0 and prefetch_factor is None:
257
+ prefetch_factor = 2
258
+ elif prefetch_factor is not None and prefetch_factor < 0:
259
+ raise ValueError("prefetch_factor option should be non-negative")
260
+
261
+ if persistent_workers and num_workers == 0:
262
+ raise ValueError("persistent_workers option needs num_workers > 0")
263
+
264
+ self.dataset = dataset
265
+ self.num_workers = num_workers
266
+ self.prefetch_factor = prefetch_factor
267
+ self.pin_memory = pin_memory
268
+ self.pin_memory_device = pin_memory_device
269
+ self.timeout = timeout
270
+ self.worker_init_fn = worker_init_fn
271
+ self.multiprocessing_context = multiprocessing_context
272
+
273
+ # Adds forward compatibilities so classic DataLoader can work with DataPipes:
274
+ # _DataPipeSerializationWrapper container makes it easier to serialize without redefining pickler
275
+ if isinstance(self.dataset, IterDataPipe):
276
+ self.dataset = _IterDataPipeSerializationWrapper(self.dataset)
277
+ elif isinstance(self.dataset, MapDataPipe):
278
+ self.dataset = _MapDataPipeSerializationWrapper(self.dataset)
279
+
280
+ # Arg-check dataset related before checking samplers because we want to
281
+ # tell users that iterable-style datasets are incompatible with custom
282
+ # samplers first, so that they don't learn that this combo doesn't work
283
+ # after spending time fixing the custom sampler errors.
284
+ if isinstance(dataset, IterableDataset):
285
+ self._dataset_kind = _DatasetKind.Iterable
286
+ # NOTE [ Custom Samplers and IterableDataset ]
287
+ #
288
+ # `IterableDataset` does not support custom `batch_sampler` or
289
+ # `sampler` since the key is irrelevant (unless we support
290
+ # generator-style dataset one day...).
291
+ #
292
+ # For `sampler`, we always create a dummy sampler. This is an
293
+ # infinite sampler even when the dataset may have an implemented
294
+ # finite `__len__` because in multi-process data loading, naive
295
+ # settings will return duplicated data (which may be desired), and
296
+ # thus using a sampler with length matching that of dataset will
297
+ # cause data lost (you may have duplicates of the first couple
298
+ # batches, but never see anything afterwards). Therefore,
299
+ # `Iterabledataset` always uses an infinite sampler, an instance of
300
+ # `_InfiniteConstantSampler` defined above.
301
+ #
302
+ # A custom `batch_sampler` essentially only controls the batch size.
303
+ # However, it is unclear how useful it would be since an iterable-style
304
+ # dataset can handle that within itself. Moreover, it is pointless
305
+ # in multi-process data loading as the assignment order of batches
306
+ # to workers is an implementation detail so users can not control
307
+ # how to batchify each worker's iterable. Thus, we disable this
308
+ # option. If this turns out to be useful in future, we can re-enable
309
+ # this, and support custom samplers that specify the assignments to
310
+ # specific workers.
311
+ if isinstance(dataset, IterDataPipe):
312
+ if shuffle is not None:
313
+ dataset = torch.utils.data.graph_settings.apply_shuffle_settings(dataset, shuffle=shuffle)
314
+ # We cannot check `shuffle is not None` here, since previously `shuffle=False` was the default.
315
+ elif shuffle not in {False, None}:
316
+ raise ValueError(
317
+ "DataLoader with IterableDataset: expected unspecified "
318
+ "shuffle option, but got shuffle={}".format(shuffle)
319
+ )
320
+
321
+ if sampler is not None:
322
+ # See NOTE [ Custom Samplers and IterableDataset ]
323
+ raise ValueError(
324
+ "DataLoader with IterableDataset: expected unspecified "
325
+ "sampler option, but got sampler={}".format(sampler)
326
+ )
327
+ elif batch_sampler is not None:
328
+ # See NOTE [ Custom Samplers and IterableDataset ]
329
+ raise ValueError(
330
+ "DataLoader with IterableDataset: expected unspecified "
331
+ "batch_sampler option, but got batch_sampler={}".format(batch_sampler)
332
+ )
333
+ else:
334
+ shuffle = bool(shuffle)
335
+ self._dataset_kind = _DatasetKind.Map
336
+
337
+ if sampler is not None and shuffle:
338
+ raise ValueError("sampler option is mutually exclusive with " "shuffle")
339
+
340
+ if batch_sampler is not None:
341
+ # auto_collation with custom batch_sampler
342
+ if batch_size != 1 or shuffle or sampler is not None or drop_last:
343
+ raise ValueError(
344
+ "batch_sampler option is mutually exclusive " "with batch_size, shuffle, sampler, and " "drop_last"
345
+ )
346
+ batch_size = None
347
+ drop_last = False
348
+ elif batch_size is None:
349
+ # no auto_collation
350
+ if drop_last:
351
+ raise ValueError(
352
+ "batch_size=None option disables auto-batching " "and is mutually exclusive with drop_last"
353
+ )
354
+
355
+ if sampler is None: # give default samplers
356
+ if self._dataset_kind == _DatasetKind.Iterable:
357
+ # See NOTE [ Custom Samplers and IterableDataset ]
358
+ sampler = _InfiniteConstantSampler()
359
+ else: # map-style
360
+ if shuffle:
361
+ sampler = RandomSampler(dataset, generator=generator) # type: ignore[arg-type]
362
+ else:
363
+ sampler = SequentialSampler(dataset) # type: ignore[arg-type]
364
+
365
+ if batch_size is not None and batch_sampler is None:
366
+ # auto_collation without custom batch_sampler
367
+ batch_sampler = BatchSampler(sampler, batch_size, drop_last)
368
+
369
+ self.batch_size = batch_size
370
+ self.drop_last = drop_last
371
+ self.sampler = sampler
372
+ self.batch_sampler = batch_sampler
373
+ self.generator = generator
374
+
375
+ if collate_fn is None:
376
+ if self._auto_collation:
377
+ collate_fn = _utils.collate.default_collate
378
+ else:
379
+ collate_fn = _utils.collate.default_convert
380
+
381
+ self.collate_fn = collate_fn
382
+ self.persistent_workers = persistent_workers
383
+
384
+ self.__initialized = True
385
+ self._IterableDataset_len_called = None # See NOTE [ IterableDataset and __len__ ]
386
+
387
+ self._iterator = None
388
+
389
+ self.check_worker_number_rationality()
390
+
391
+ torch.set_vital("Dataloader", "enabled", "True") # type: ignore[attr-defined]
392
+
393
+ def _get_iterator(self) -> "_BaseDataLoaderIter":
394
+ if self.num_workers == 0:
395
+ return _SingleProcessDataLoaderIter(self)
396
+ else:
397
+ self.check_worker_number_rationality()
398
+ return _MultiProcessingDataLoaderIter(self)
399
+
400
+ @property
401
+ def multiprocessing_context(self):
402
+ return self.__multiprocessing_context
403
+
404
+ @multiprocessing_context.setter
405
+ def multiprocessing_context(self, multiprocessing_context):
406
+ if multiprocessing_context is not None:
407
+ if self.num_workers > 0:
408
+ if isinstance(multiprocessing_context, str):
409
+ valid_start_methods = multiprocessing.get_all_start_methods()
410
+ if multiprocessing_context not in valid_start_methods:
411
+ raise ValueError(
412
+ (
413
+ "multiprocessing_context option "
414
+ "should specify a valid start method in {!r}, but got "
415
+ "multiprocessing_context={!r}"
416
+ ).format(valid_start_methods, multiprocessing_context)
417
+ )
418
+ multiprocessing_context = multiprocessing.get_context(multiprocessing_context)
419
+
420
+ if not isinstance(multiprocessing_context, python_multiprocessing.context.BaseContext):
421
+ raise TypeError(
422
+ (
423
+ "multiprocessing_context option should be a valid context "
424
+ "object or a string specifying the start method, but got "
425
+ "multiprocessing_context={}"
426
+ ).format(multiprocessing_context)
427
+ )
428
+ else:
429
+ raise ValueError(
430
+ (
431
+ "multiprocessing_context can only be used with "
432
+ "multi-process loading (num_workers > 0), but got "
433
+ "num_workers={}"
434
+ ).format(self.num_workers)
435
+ )
436
+
437
+ self.__multiprocessing_context = multiprocessing_context
438
+
439
+ def __setattr__(self, attr, val):
440
+ if self.__initialized and attr in (
441
+ "batch_size",
442
+ "batch_sampler",
443
+ "sampler",
444
+ "drop_last",
445
+ "dataset",
446
+ "persistent_workers",
447
+ ):
448
+ raise ValueError(
449
+ "{} attribute should not be set after {} is " "initialized".format(attr, self.__class__.__name__)
450
+ )
451
+
452
+ super().__setattr__(attr, val)
453
+
454
+ # We quote '_BaseDataLoaderIter' since it isn't defined yet and the definition can't be moved up
455
+ # since '_BaseDataLoaderIter' references 'DataLoader'.
456
+ def __iter__(self) -> "_BaseDataLoaderIter":
457
+ # When using a single worker the returned iterator should be
458
+ # created everytime to avoid reseting its state
459
+ # However, in the case of a multiple workers iterator
460
+ # the iterator is only created once in the lifetime of the
461
+ # DataLoader object so that workers can be reused
462
+ if self.persistent_workers and self.num_workers > 0:
463
+ if self._iterator is None:
464
+ self._iterator = self._get_iterator()
465
+ else:
466
+ self._iterator._reset(self)
467
+ return self._iterator
468
+ else:
469
+ return self._get_iterator()
470
+
471
+ @property
472
+ def _auto_collation(self):
473
+ return self.batch_sampler is not None
474
+
475
+ @property
476
+ def _index_sampler(self):
477
+ # The actual sampler used for generating indices for `_DatasetFetcher`
478
+ # (see _utils/fetch.py) to read data at each time. This would be
479
+ # `.batch_sampler` if in auto-collation mode, and `.sampler` otherwise.
480
+ # We can't change `.sampler` and `.batch_sampler` attributes for BC
481
+ # reasons.
482
+ if self._auto_collation:
483
+ return self.batch_sampler
484
+ else:
485
+ return self.sampler
486
+
487
+ def __len__(self) -> int:
488
+ if self._dataset_kind == _DatasetKind.Iterable:
489
+ # NOTE [ IterableDataset and __len__ ]
490
+ #
491
+ # For `IterableDataset`, `__len__` could be inaccurate when one naively
492
+ # does multi-processing data loading, since the samples will be duplicated.
493
+ # However, no real use case should be actually using that behavior, so
494
+ # it should count as a user error. We should generally trust user
495
+ # code to do the proper thing (e.g., configure each replica differently
496
+ # in `__iter__`), and give us the correct `__len__` if they choose to
497
+ # implement it (this will still throw if the dataset does not implement
498
+ # a `__len__`).
499
+ #
500
+ # To provide a further warning, we track if `__len__` was called on the
501
+ # `DataLoader`, save the returned value in `self._len_called`, and warn
502
+ # if the iterator ends up yielding more than this number of samples.
503
+
504
+ # Cannot statically verify that dataset is Sized
505
+ length = self._IterableDataset_len_called = len(self.dataset) # type: ignore[assignment, arg-type]
506
+ if self.batch_size is not None: # IterableDataset doesn't allow custom sampler or batch_sampler
507
+ from math import ceil
508
+
509
+ if self.drop_last:
510
+ length = length // self.batch_size
511
+ else:
512
+ length = ceil(length / self.batch_size)
513
+ return length
514
+ else:
515
+ return len(self._index_sampler)
516
+
517
+ def check_worker_number_rationality(self):
518
+ # This function check whether the dataloader's worker number is rational based on
519
+ # current system's resource. Current rule is that if the number of workers this
520
+ # Dataloader will create is bigger than the number of logical cpus that is allowed to
521
+ # use, than we will pop up a warning to let user pay attention.
522
+ #
523
+ # eg. If current system has 2 physical CPUs with 16 cores each. And each core support 2
524
+ # threads, then the total logical cpus here is 2 * 16 * 2 = 64. Let's say current
525
+ # DataLoader process can use half of them which is 32, then the rational max number of
526
+ # worker that initiated from this process is 32.
527
+ # Now, let's say the created DataLoader has num_works = 40, which is bigger than 32.
528
+ # So the warning message is triggered to notify the user to lower the worker number if
529
+ # necessary.
530
+ #
531
+ #
532
+ # [Note] Please note that this function repects `cpuset` only when os.sched_getaffinity is
533
+ # available (available in most of Linux system, but not OSX and Windows).
534
+ # When os.sched_getaffinity is not available, os.cpu_count() is called instead, but
535
+ # it doesn't repect cpuset.
536
+ # We don't take threading into account since each worker process is single threaded
537
+ # at this time.
538
+ #
539
+ # We don't set any threading flags (eg. OMP_NUM_THREADS, MKL_NUM_THREADS, etc)
540
+ # other than `torch.set_num_threads` to 1 in the worker process, if the passing
541
+ # in functions use 3rd party modules that rely on those threading flags to determine
542
+ # how many thread to create (eg. numpy, etc), then it is caller's responsibility to
543
+ # set those flags correctly.
544
+ def _create_warning_msg(num_worker_suggest, num_worker_created, cpuset_checked):
545
+ suggested_max_worker_msg = (
546
+ (
547
+ (
548
+ "Our suggested max number of worker in current system is {}{}, which is smaller "
549
+ "than what this DataLoader is going to create."
550
+ ).format(
551
+ num_worker_suggest,
552
+ ("" if cpuset_checked else " (`cpuset` is not taken into account)"),
553
+ )
554
+ )
555
+ if num_worker_suggest is not None
556
+ else ("DataLoader is not able to compute a suggested max number of worker in current system.")
557
+ )
558
+
559
+ warn_msg = (
560
+ "This DataLoader will create {} worker processes in total. {} "
561
+ "Please be aware that excessive worker creation might get DataLoader running slow or even freeze, "
562
+ "lower the worker number to avoid potential slowness/freeze if necessary."
563
+ ).format(num_worker_created, suggested_max_worker_msg)
564
+ return warn_msg
565
+
566
+ if not self.num_workers or self.num_workers == 0:
567
+ return
568
+
569
+ # try to compute a suggested max number of worker based on system's resource
570
+ max_num_worker_suggest = None
571
+ cpuset_checked = False
572
+ if hasattr(os, "sched_getaffinity"):
573
+ try:
574
+ max_num_worker_suggest = len(os.sched_getaffinity(0))
575
+ cpuset_checked = True
576
+ except Exception:
577
+ pass
578
+ if max_num_worker_suggest is None:
579
+ # os.cpu_count() could return Optional[int]
580
+ # get cpu count first and check None in order to satify mypy check
581
+ cpu_count = os.cpu_count()
582
+ if cpu_count is not None:
583
+ max_num_worker_suggest = cpu_count
584
+
585
+ if max_num_worker_suggest is None:
586
+ warnings.warn(_create_warning_msg(max_num_worker_suggest, self.num_workers, cpuset_checked))
587
+ return
588
+
589
+ if self.num_workers > max_num_worker_suggest:
590
+ warnings.warn(_create_warning_msg(max_num_worker_suggest, self.num_workers, cpuset_checked))
591
+
592
+
593
+ class _BaseDataLoaderIter:
594
+ def __init__(self, loader: RRSDataLoader) -> None:
595
+ self._dataset = loader.dataset
596
+ self._shared_seed = None
597
+ self._pg = None
598
+ if isinstance(self._dataset, IterDataPipe):
599
+ if dist.is_available() and dist.is_initialized():
600
+ self._pg = dist.new_group(backend="gloo")
601
+ self._shared_seed = _share_dist_seed(loader.generator, self._pg)
602
+ shared_rng = torch.Generator()
603
+ shared_rng.manual_seed(self._shared_seed)
604
+ self._dataset = torch.utils.data.graph_settings.apply_random_seed(self._dataset, shared_rng)
605
+ self._dataset_kind = loader._dataset_kind
606
+ self._IterableDataset_len_called = loader._IterableDataset_len_called
607
+ self._auto_collation = loader._auto_collation
608
+ self._drop_last = loader.drop_last
609
+ self._index_sampler = loader._index_sampler
610
+ self._num_workers = loader.num_workers
611
+ ws, rank = _get_distributed_settings()
612
+ self._world_size = ws
613
+ self._rank = rank
614
+ # for other backends, pin_memory_device need to set. if not set
615
+ # default behaviour is CUDA device. if pin_memory_device is selected
616
+ # and pin_memory is not set, the default behaviour false.
617
+ if len(loader.pin_memory_device) == 0:
618
+ self._pin_memory = loader.pin_memory and torch.cuda.is_available()
619
+ self._pin_memory_device = None
620
+ else:
621
+ if not loader.pin_memory:
622
+ warn_msg = (
623
+ "pin memory device is set and pin_memory flag is not used then device pinned memory won't be used"
624
+ "please set pin_memory to true, if you need to use the device pin memory"
625
+ )
626
+ warnings.warn(warn_msg)
627
+
628
+ self._pin_memory = loader.pin_memory
629
+ self._pin_memory_device = loader.pin_memory_device
630
+ self._timeout = loader.timeout
631
+ self._collate_fn = loader.collate_fn
632
+ self._sampler_iter = iter(self._index_sampler)
633
+ self._base_seed = torch.empty((), dtype=torch.int64).random_(generator=loader.generator).item()
634
+ self._persistent_workers = loader.persistent_workers
635
+ self._num_yielded = 0
636
+ self._profile_name = "enumerate(DataLoader)#{}.__next__".format(self.__class__.__name__)
637
+
638
+ def __iter__(self) -> "_BaseDataLoaderIter":
639
+ return self
640
+
641
+ def _reset(self, loader, first_iter=False):
642
+ self._sampler_iter = iter(self._index_sampler)
643
+ self._num_yielded = 0
644
+ self._IterableDataset_len_called = loader._IterableDataset_len_called
645
+ if isinstance(self._dataset, IterDataPipe):
646
+ self._shared_seed = _share_dist_seed(loader.generator, self._pg)
647
+ shared_rng = torch.Generator()
648
+ shared_rng.manual_seed(self._shared_seed)
649
+ self._dataset = torch.utils.data.graph_settings.apply_random_seed(self._dataset, shared_rng)
650
+
651
+ def _next_index(self):
652
+ return next(self._sampler_iter) # may raise StopIteration
653
+
654
+ def _next_data(self):
655
+ raise NotImplementedError
656
+
657
+ def __next__(self) -> Any:
658
+ with torch.autograd.profiler.record_function(self._profile_name):
659
+ if self._sampler_iter is None:
660
+ # TODO(https://github.com/pytorch/pytorch/issues/76750)
661
+ self._reset() # type: ignore[call-arg]
662
+ data = self._next_data()
663
+ self._num_yielded += 1
664
+ if (
665
+ self._dataset_kind == _DatasetKind.Iterable
666
+ and self._IterableDataset_len_called is not None
667
+ and self._num_yielded > self._IterableDataset_len_called
668
+ ):
669
+ warn_msg = (
670
+ "Length of IterableDataset {} was reported to be {} (when accessing len(dataloader)), but {} "
671
+ "samples have been fetched. "
672
+ ).format(self._dataset, self._IterableDataset_len_called, self._num_yielded)
673
+ if self._num_workers > 0:
674
+ warn_msg += (
675
+ "For multiprocessing data-loading, this could be caused by not properly configuring the "
676
+ "IterableDataset replica at each worker. Please see "
677
+ "https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset for examples."
678
+ )
679
+ warnings.warn(warn_msg)
680
+ return data
681
+
682
+ def __len__(self) -> int:
683
+ return len(self._index_sampler)
684
+
685
+ def __getstate__(self):
686
+ # TODO: add limited pickling support for sharing an iterator
687
+ # across multiple threads for HOGWILD.
688
+ # Probably the best way to do this is by moving the sample pushing
689
+ # to a separate thread and then just sharing the data queue
690
+ # but signalling the end is tricky without a non-blocking API
691
+ raise NotImplementedError("{} cannot be pickled", self.__class__.__name__)
692
+
693
+
694
+ class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
695
+ def __init__(self, loader):
696
+ super().__init__(loader)
697
+ assert self._timeout == 0
698
+ assert self._num_workers == 0
699
+
700
+ # Adds forward compatibilities so classic DataLoader can work with DataPipes:
701
+ # Taking care of distributed sharding
702
+ if isinstance(self._dataset, (IterDataPipe, MapDataPipe)):
703
+ # For BC, use default SHARDING_PRIORITIES
704
+ torch.utils.data.graph_settings.apply_sharding(self._dataset, self._world_size, self._rank)
705
+
706
+ self._dataset_fetcher = _DatasetKind.create_fetcher(
707
+ self._dataset_kind,
708
+ self._dataset,
709
+ self._auto_collation,
710
+ self._collate_fn,
711
+ self._drop_last,
712
+ )
713
+
714
+ def _next_data(self):
715
+ index = self._next_index() # may raise StopIteration
716
+ data = self._dataset_fetcher.fetch(index) # may raise StopIteration
717
+ if self._pin_memory:
718
+ data = _utils.pin_memory.pin_memory(data, self._pin_memory_device)
719
+ return data
720
+
721
+
722
+ class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
723
+ r"""Iterates once over the DataLoader's dataset, as specified by the sampler"""
724
+
725
+ # NOTE [ Data Loader Multiprocessing Shutdown Logic ]
726
+ #
727
+ # Preliminary:
728
+ #
729
+ # Our data model looks like this (queues are indicated with curly brackets):
730
+ #
731
+ # main process ||
732
+ # | ||
733
+ # {index_queue} ||
734
+ # | ||
735
+ # worker processes || DATA
736
+ # | ||
737
+ # {worker_result_queue} || FLOW
738
+ # | ||
739
+ # pin_memory_thread of main process || DIRECTION
740
+ # | ||
741
+ # {data_queue} ||
742
+ # | ||
743
+ # data output \/
744
+ #
745
+ # P.S. `worker_result_queue` and `pin_memory_thread` part may be omitted if
746
+ # `pin_memory=False`.
747
+ #
748
+ #
749
+ # Terminating multiprocessing logic requires very careful design. In
750
+ # particular, we need to make sure that
751
+ #
752
+ # 1. The iterator gracefully exits the workers when its last reference is
753
+ # gone or it is depleted.
754
+ #
755
+ # In this case, the workers should be gracefully exited because the
756
+ # main process may still need to continue to run, and we want cleaning
757
+ # up code in the workers to be executed (e.g., releasing GPU memory).
758
+ # Naturally, we implement the shutdown logic in `__del__` of
759
+ # DataLoaderIterator.
760
+ #
761
+ # We delay the discussion on the logic in this case until later.
762
+ #
763
+ # 2. The iterator exits the workers when the loader process and/or worker
764
+ # processes exits normally or with error.
765
+ #
766
+ # We set all workers and `pin_memory_thread` to have `daemon=True`.
767
+ #
768
+ # You may ask, why can't we make the workers non-daemonic, and
769
+ # gracefully exit using the same logic as we have in `__del__` when the
770
+ # iterator gets deleted (see 1 above)?
771
+ #
772
+ # First of all, `__del__` is **not** guaranteed to be called when
773
+ # interpreter exits. Even if it is called, by the time it executes,
774
+ # many Python core library resources may alreay be freed, and even
775
+ # simple things like acquiring an internal lock of a queue may hang.
776
+ # Therefore, in this case, we actually need to prevent `__del__` from
777
+ # being executed, and rely on the automatic termination of daemonic
778
+ # children.
779
+ #
780
+ # Thus, we register an `atexit` hook that sets a global flag
781
+ # `_utils.python_exit_status`. Since `atexit` hooks are executed in the
782
+ # reverse order of registration, we are guaranteed that this flag is
783
+ # set before library resources we use are freed (which, at least in
784
+ # CPython, is done via an `atexit` handler defined in
785
+ # `multiprocessing/util.py`
786
+ # https://github.com/python/cpython/blob/c606624af8d4cb3b4a052fb263bb983b3f87585b/Lib/multiprocessing/util.py#L320-L362
787
+ # registered when an object requiring this mechanism is first
788
+ # created, e.g., `mp.Queue`
789
+ # https://github.com/python/cpython/blob/c606624af8d4cb3b4a052fb263bb983b3f87585b/Lib/multiprocessing/context.py#L100-L103
790
+ # https://github.com/python/cpython/blob/c606624af8d4cb3b4a052fb263bb983b3f87585b/Lib/multiprocessing/queues.py#L29
791
+ # )
792
+ #
793
+ # So in `__del__`, we check if `_utils.python_exit_status` is set or
794
+ # `None` (freed), and perform no-op if so.
795
+ #
796
+ # However, simply letting library clean-up codes run can also be bad,
797
+ # because such codes (i.e., `multiprocessing.util._exit_function()`)
798
+ # include join putting threads for `mp.Queue`, which can be blocking.
799
+ # Hence, the main process putting threads are called with
800
+ # `cancel_join_thread` at creation. See later section
801
+ # [ 3b. A process won't hang when putting into a queue; ]
802
+ # for more details.
803
+ #
804
+ # Here are two example cases where library clean-up codes can run
805
+ # before `__del__` is called:
806
+ #
807
+ # 1. If we hold onto a reference to the iterator, it more often
808
+ # than not tries to do `multiprocessing` library cleaning before
809
+ # clearing the alive referenced objects (https://github.com/pytorch/pytorch/issues/48666)
810
+ # and thus prevents our cleaning-up code to run first.
811
+ #
812
+ # 2. A similar issue araises when a `DataLoader` is used in a subprocess.
813
+ # When a process ends, it shuts the all its daemonic children
814
+ # down with a SIGTERM (instead of joining them without a timeout).
815
+ # Simiarly for threads, but by a different mechanism. This fact,
816
+ # together with a few implementation details of multiprocessing, forces
817
+ # us to make workers daemonic. All of our problems arise when a
818
+ # DataLoader is used in a subprocess, and are caused by multiprocessing
819
+ # code which looks more or less like this:
820
+ #
821
+ # try:
822
+ # your_function_using_a_dataloader()
823
+ # finally:
824
+ # multiprocessing.util._exit_function()
825
+ #
826
+ # The joining/termination mentioned above happens inside
827
+ # `_exit_function()`. Now, if `your_function_using_a_dataloader()`
828
+ # throws, the stack trace stored in the exception will prevent the
829
+ # frame which uses `DataLoaderIter` from being freed. If the frame has any
830
+ # reference to the `DataLoaderIter` (e.g., in a method of the iter),
831
+ # its `__del__`, which starts the shutdown procedure, will not be
832
+ # called. That, in turn, means that workers aren't notified. Attempting
833
+ # to join in `_exit_function` will then result in a hang.
834
+ #
835
+ # For context, `_exit_function` is also registered as an `atexit` call.
836
+ # So it is unclear to me (@ssnl) why this is needed in a finally block.
837
+ # The code dates back to 2008 and there is no comment on the original
838
+ # PEP 371 or patch https://bugs.python.org/issue3050 (containing both
839
+ # the finally block and the `atexit` registration) that explains this.
840
+ #
841
+ #
842
+ # Finally, another choice is to just shut down workers with the logic in 1
843
+ # above whenever we see an error in `next`. This isn't ideal because
844
+ # a. It prevents users from using try-catch to resume data loading.
845
+ # b. It doesn't prevent hanging if users have references to the
846
+ # iterator.
847
+ #
848
+ # 3. All processes exit if any of them die unexpectedly by fatal signals.
849
+ #
850
+ # As shown above, the workers are set as daemonic children of the main
851
+ # process. However, automatic cleaning-up of such child processes only
852
+ # happens if the parent process exits gracefully (e.g., not via fatal
853
+ # signals like SIGKILL). So we must ensure that each process will exit
854
+ # even if the processes that should send/receive data to/from it were
855
+ # killed, i.e.,
856
+ #
857
+ # a. A process won't hang when getting from a queue.
858
+ #
859
+ # Even with carefully designed data dependencies (i.e., a `put()`
860
+ # always corresponding to a `get()`), hanging on `get()` can still
861
+ # happen when data in queue is corrupted (e.g., due to
862
+ # `cancel_join_thread` or unexpected exit).
863
+ #
864
+ # For child exit, we set a timeout whenever we try to get data
865
+ # from `data_queue`, and check the workers' status on each timeout
866
+ # and error.
867
+ # See `_MultiProcessingDataLoaderIter._get_data()` and
868
+ # `_try_get_data()` below for details.
869
+ #
870
+ # Additionally, for child exit on non-Windows platforms, we also
871
+ # register a SIGCHLD handler (which is not supported on Windows) on
872
+ # the main process, which checks if any of the workers fail in the
873
+ # (Python) handler. This is more efficient and faster in detecting
874
+ # worker failures, compared to only using the above mechanism.
875
+ # See `DataLoader.cpp` and `_utils/signal_handling.py` for details.
876
+ #
877
+ # For `.get()` calls where the sender(s) is not the workers, we
878
+ # guard them with timeouts, and check the status of the sender
879
+ # when timeout happens:
880
+ # + in the workers, the `_utils.worker.ManagerWatchdog` class
881
+ # checks the status of the main process.
882
+ # + if `pin_memory=True`, when getting from `pin_memory_thread`,
883
+ # check `pin_memory_thread` status periodically until `.get()`
884
+ # returns or we see that `pin_memory_thread` has died.
885
+ #
886
+ # b. A process won't hang when putting into a queue;
887
+ #
888
+ # We use `mp.Queue` which has a separate background thread to put
889
+ # objects from an unbounded buffer array. The background thread is
890
+ # daemonic and usually automatically joined when the process
891
+ # *exits*.
892
+ #
893
+ # If the receiver has ended abruptly while
894
+ # reading from the pipe, the join will hang forever. The usual
895
+ # solution for this in Python is calling `q.cancel_join_thread`,
896
+ # which prevents automatically joining it when finalizing
897
+ # (exiting).
898
+ #
899
+ # Nonetheless, `cancel_join_thread` must only be called when the
900
+ # queue is **not** going to be read from or written into by another
901
+ # process, because it may hold onto a lock or leave corrupted data
902
+ # in the queue, leading other readers/writers to hang.
903
+ #
904
+ # Hence,
905
+ # + For worker processes, we only do so (for their output
906
+ # queues, i.e., `worker_result_queue`) before exiting.
907
+ # + For `pin_memory_thread`, its output queue `data_queue` is a
908
+ # `queue.Queue` that does blocking `put` if the queue is full.
909
+ # So there is no above problem, but as a result, in
910
+ # `_pin_memory_loop`, we do need to wrap the `put` in a loop
911
+ # that breaks not only upon success, but also when the main
912
+ # process stops reading, i.e., is shutting down.
913
+ # + For loader process, we `cancel_join_thread()` for all
914
+ # `_index_queues` because the whole purpose of workers and
915
+ # `pin_memory_thread` is to serve the loader process. If
916
+ # loader process is already exiting, we don't really care if
917
+ # the queues are corrupted.
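+ #
+ # For illustration, the stdlib pattern referenced above is simply
+ # (a sketch, not DataLoader code):
+ #
+ #     import multiprocessing as mp
+ #
+ #     q = mp.Queue()
+ #     q.cancel_join_thread()  # don't block process exit on the queue's
+ #                             # background feeder thread
+ #     q.put(1)                # exit won't hang even if the reader died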
918
+ #
919
+ #
920
+ # Now let's get back to 1:
921
+ # how we gracefully exit the workers when the last reference to the
922
+ # iterator is gone.
923
+ #
924
+ # To achieve this, we implement the following logic along with the design
925
+ # choices mentioned above:
926
+ #
927
+ # `workers_done_event`:
928
+ # A `multiprocessing.Event` shared among the main process and all worker
929
+ # processes. This is used to signal the workers that the iterator is
930
+ # shutting down. After it is set, they will not send processed data to
931
+ # queues anymore, and only wait for the final `None` before exiting.
932
+ # `done_event` isn't strictly needed. I.e., we can just check for `None`
933
+ # from the input queue, but it allows us to skip wasting resources
934
+ # processing data if we are already shutting down.
935
+ #
936
+ # `pin_memory_thread_done_event`:
937
+ # A `threading.Event` for a similar purpose to that of
938
+ # `workers_done_event`, but is for the `pin_memory_thread`. The reason
939
+ # that separate events are needed is that `pin_memory_thread` reads from
940
+ # the output queue of the workers. But the workers, upon seeing that
941
+ # `workers_done_event` is set, only want to see the final `None`, and are
942
+ # not required to flush all data in the output queue (e.g., it may call
943
+ # `cancel_join_thread` on that queue if its `IterableDataset` iterator
944
+ # happens to exhaust coincidentally, which is out of the control of the
945
+ # main process). Thus, since we will exit `pin_memory_thread` before the
946
+ # workers (see below), two separate events are used.
947
+ #
948
+ # NOTE: In short, the protocol is that the main process will set these
949
+ # `done_event`s and then send the corresponding processes/threads a `None`,
950
+ # and that they may exit at any time after receiving the `None`.
951
+ #
952
+ # NOTE: Using `None` as the final signal is valid, since normal data will
953
+ # always be a 2-tuple with the 1st element being the index of the data
954
+ # transferred (different from dataset index/key), and the 2nd being
955
+ # either the dataset key or the data sample (depending on which part
956
+ # of the data model the queue is at).
957
+ #
958
+ # [ worker processes ]
959
+ # While loader process is alive:
960
+ # Get from `index_queue`.
961
+ # If we got something (i.e., didn't time out),
962
+ # Check `workers_done_event`.
963
+ # If set, continue to next iteration
964
+ # i.e., keep getting until see the `None`, then exit.
965
+ # Otherwise, process data:
966
+ # If is fetching from an `IterableDataset` and the iterator
967
+ # is exhausted, send an `_IterableDatasetStopIteration`
968
+ # object to signal iteration end. The main process, upon
969
+ # receiving such an object, will send `None` to this
970
+ # worker and not use the corresponding `index_queue`
971
+ # anymore.
972
+ # If timed out,
973
+ # Whether or not `workers_done_event` is set (we still need to see `None`),
974
+ # we must continue to the next iteration.
975
+ # (outside loop)
976
+ # If `workers_done_event` is set, (this can be False with `IterableDataset`)
977
+ # `data_queue.cancel_join_thread()`. (Everything is ending here:
978
+ # main process won't read from it;
979
+ # other workers will also call
980
+ # `cancel_join_thread`.)
981
+ #
982
+ # [ pin_memory_thread ]
983
+ # # No need to check main thread. If this thread is alive, the main loader
984
+ # # thread must be alive, because this thread is set as daemonic.
985
+ # While `pin_memory_thread_done_event` is not set:
986
+ # Get from `worker_result_queue`.
987
+ # If timed out, continue to get in the next iteration.
988
+ # Otherwise, process data.
989
+ # While `pin_memory_thread_done_event` is not set:
990
+ # Put processed data to `data_queue` (a `queue.Queue` with blocking put)
991
+ # If timed out, continue to put in the next iteration.
992
+ # Otherwise, break, i.e., continue to the outer loop.
993
+ #
994
+ # NOTE: we don't check the status of the main thread because
995
+ # 1. if the process is killed by fatal signal, `pin_memory_thread`
996
+ # ends.
997
+ # 2. in other cases, either the cleaning-up in __del__ or the
998
+ # automatic exit of daemonic thread will take care of it.
999
+ # This won't busy-wait either because `.get(timeout)` does not
1000
+ # busy-wait.
1001
+ #
1002
+ # [ main process ]
1003
+ # In the DataLoader Iter's `__del__`
1004
+ # b. Exit `pin_memory_thread`
1005
+ # i. Set `pin_memory_thread_done_event`.
1006
+ # ii. Put `None` in `worker_result_queue`.
1007
+ # iii. Join the `pin_memory_thread`.
1008
+ # iv. `worker_result_queue.cancel_join_thread()`.
1009
+ #
1010
+ # c. Exit the workers.
1011
+ # i. Set `workers_done_event`.
1012
+ # ii. Put `None` in each worker's `index_queue`.
1013
+ # iii. Join the workers.
1014
+ # iv. Call `.cancel_join_thread()` on each worker's `index_queue`.
1015
+ #
1016
+ # NOTE: (c) is better placed after (b) because it may leave corrupted
1017
+ # data in `worker_result_queue`, which `pin_memory_thread`
1018
+ # reads from, in which case the `pin_memory_thread` exit can only
1019
+ # happen after timing out, which is slow. Nonetheless, the same thing
1020
+ # happens if a worker is killed by signal at unfortunate times,
1021
+ # but in other cases, we are better off having a non-corrupted
1022
+ # `worker_result_queue` for `pin_memory_thread`.
1023
+ #
1024
+ # NOTE: If `pin_memory=False`, there is no `pin_memory_thread` and (b)
1025
+ # can be omitted.
1026
+ #
1027
+ # NB: The `done_event`s aren't strictly needed. E.g., we can just check for
1028
+ # `None` from `index_queue`, but it allows us to skip wasting resources
1029
+ # processing indices already in `index_queue` if we are already shutting
1030
+ # down.
1031
+
1032
+ def __init__(self, loader):
1033
+ super().__init__(loader)
1034
+
1035
+ self._prefetch_factor = loader.prefetch_factor
1036
+
1037
+ assert self._num_workers > 0
1038
+ assert self._prefetch_factor > 0
1039
+
1040
+ if loader.multiprocessing_context is None:
1041
+ multiprocessing_context = multiprocessing
1042
+ else:
1043
+ multiprocessing_context = loader.multiprocessing_context
1044
+
1045
+ self._worker_init_fn = loader.worker_init_fn
1046
+
1047
+ # Adds forward compatibilities so classic DataLoader can work with DataPipes:
1048
+ # Additional worker init function will take care of sharding in MP and Distributed
1049
+ if isinstance(self._dataset, (IterDataPipe, MapDataPipe)):
1050
+ self._worker_init_fn = functools.partial(
1051
+ _sharding_worker_init_fn, self._worker_init_fn, self._world_size, self._rank
1052
+ )
1053
+
1054
+ # No certainty which module multiprocessing_context is
1055
+ self._worker_result_queue = multiprocessing_context.Queue() # type: ignore[var-annotated]
1056
+ self._worker_pids_set = False
1057
+ self._shutdown = False
1058
+ self._workers_done_event = multiprocessing_context.Event()
1059
+
1060
+ self._index_queues = []
1061
+ self._workers = []
1062
+ for i in range(self._num_workers):
1063
+ # No certainty which module multiprocessing_context is
1064
+ index_queue = multiprocessing_context.Queue() # type: ignore[var-annotated]
1065
+ # Need to `cancel_join_thread` here!
1066
+ # See sections (2) and (3b) above.
1067
+ index_queue.cancel_join_thread()
1068
+ w = multiprocessing_context.Process(
1069
+ target=_worker_loop,
1070
+ args=(
1071
+ self._dataset_kind,
1072
+ self._dataset,
1073
+ index_queue,
1074
+ self._worker_result_queue,
1075
+ self._workers_done_event,
1076
+ self._auto_collation,
1077
+ self._collate_fn,
1078
+ self._drop_last,
1079
+ self._base_seed,
1080
+ self._worker_init_fn,
1081
+ i,
1082
+ self._num_workers,
1083
+ self._persistent_workers,
1084
+ self._shared_seed,
1085
+ ),
1086
+ )
1087
+ w.daemon = True
1088
+ # NB: Process.start() actually takes some time as it needs to
1089
+ # start a process and pass the arguments over via a pipe.
1090
+ # Therefore, we only add a worker to self._workers list after
1091
+ # it started, so that we do not call .join() if the program dies
1092
+ # before it starts, and __del__ tries to join but will get:
1093
+ # AssertionError: can only join a started process.
1094
+ w.start()
1095
+ self._index_queues.append(index_queue)
1096
+ self._workers.append(w)
1097
+
1098
+ if self._pin_memory:
1099
+ self._pin_memory_thread_done_event = threading.Event()
1100
+
1101
+ # Queue is not type-annotated
1102
+ self._data_queue = queue.Queue() # type: ignore[var-annotated]
1103
+ if self._pin_memory_device == "xpu":
1104
+ current_device = torch.xpu.current_device() # type: ignore[attr-defined]
1105
+ else:
1106
+ current_device = torch.cuda.current_device() # choose cuda for default
1107
+ pin_memory_thread = threading.Thread(
1108
+ target=_utils.pin_memory._pin_memory_loop,
1109
+ args=(
1110
+ self._worker_result_queue,
1111
+ self._data_queue,
1112
+ current_device,
1113
+ self._pin_memory_thread_done_event,
1114
+ self._pin_memory_device,
1115
+ ),
1116
+ )
1117
+ pin_memory_thread.daemon = True
1118
+ pin_memory_thread.start()
1119
+ # Similar to workers (see comment above), we only register
1120
+ # pin_memory_thread once it is started.
1121
+ self._pin_memory_thread = pin_memory_thread
1122
+ else:
1123
+ self._data_queue = self._worker_result_queue
1124
+
1125
+ # In some rare cases, persistent workers (daemonic processes)
1126
+ # would be terminated before `__del__` of iterator is invoked
1127
+ # when main process exits
1128
+ # It would cause failure when pin_memory_thread tries to read
1129
+ # corrupted data from worker_result_queue
1130
+ # atexit is used to shutdown thread and child processes in the
1131
+ # right sequence before main process exits
1132
+ if self._persistent_workers and self._pin_memory:
1133
+ import atexit
1134
+
1135
+ for w in self._workers:
1136
+ atexit.register(_MultiProcessingDataLoaderIter._clean_up_worker, w)
1137
+
1138
+ # .pid can be None only before process is spawned (not the case, so ignore)
1139
+ _utils.signal_handling._set_worker_pids(id(self), tuple(w.pid for w in self._workers)) # type: ignore[misc]
1140
+ _utils.signal_handling._set_SIGCHLD_handler()
1141
+ self._worker_pids_set = True
1142
+ self._reset(loader, first_iter=True)
1143
+
1144
+ def _reset(self, loader, first_iter=False):
1145
+ super()._reset(loader, first_iter)
1146
+ self._send_idx = 0 # idx of the next task to be sent to workers
1147
+ self._rcvd_idx = 0 # idx of the next task to be returned in __next__
1148
+ # information about data not yet yielded, i.e., tasks w/ indices in range [rcvd_idx, send_idx).
1149
+ # map: task idx => - (worker_id,) if data isn't fetched (outstanding)
1150
+ # \ (worker_id, data) if data is already fetched (out-of-order)
1151
+ self._task_info = {}
1152
+ self._tasks_outstanding = 0 # always equal to count(v for v in task_info.values() if len(v) == 1)
1153
+ # A list of booleans representing whether each worker still has work to
1154
+ # do, i.e., not having exhausted its iterable dataset object. It always
1155
+ # contains all `True`s if not using an iterable-style dataset
1156
+ # (i.e., if kind != Iterable).
1157
+ # Note that this indicates that a worker still has work to do *for this epoch*.
1158
+ # It does not mean that a worker is dead. In case of `_persistent_workers`,
1159
+ # the worker will be reset to available in the next epoch.
1160
+ self._workers_status = [True for i in range(self._num_workers)]
1161
+ # Reset the worker queue cycle so it resumes next epoch at worker 0
1162
+ self._worker_queue_idx_cycle = itertools.cycle(range(self._num_workers))
1163
+ # We resume the prefetching in case it was enabled
1164
+ if not first_iter:
1165
+ for idx in range(self._num_workers):
1166
+ self._index_queues[idx].put(_utils.worker._ResumeIteration(self._shared_seed))
1167
+ resume_iteration_cnt = self._num_workers
1168
+ while resume_iteration_cnt > 0:
1169
+ return_idx, return_data = self._get_data()
1170
+ if isinstance(return_idx, _utils.worker._ResumeIteration):
1171
+ assert return_data is None
1172
+ resume_iteration_cnt -= 1
1173
+ # prime the prefetch loop
1174
+ for _ in range(self._prefetch_factor * self._num_workers):
1175
+ self._try_put_index()
1176
+
1177
+ def _try_get_data(self, timeout=_utils.MP_STATUS_CHECK_INTERVAL):
1178
+ # Tries to fetch data from `self._data_queue` once for a given timeout.
1179
+ # This can also be used as inner loop of fetching without timeout, with
1180
+ # the sender status as the loop condition.
1181
+ #
1182
+ # This raises a `RuntimeError` if any worker died unexpectedly. This error
1183
+ # can come from either the SIGCHLD handler in `_utils/signal_handling.py`
1184
+ # (only for non-Windows platforms), or the manual check below on errors
1185
+ # and timeouts.
1186
+ #
1187
+ # Returns a 2-tuple:
1188
+ # (bool: whether we successfully got data, any: the data if successful else None)
1189
+ try:
1190
+ data = self._data_queue.get(timeout=timeout)
1191
+ return (True, data)
1192
+ except Exception as e:
1193
+ # At timeout and error, we manually check whether any worker has
1194
+ # failed. Note that this is the only mechanism for Windows to detect
1195
+ # worker failures.
1196
+ failed_workers = []
1197
+ for worker_id, w in enumerate(self._workers):
1198
+ if self._workers_status[worker_id] and not w.is_alive():
1199
+ failed_workers.append(w)
1200
+ self._mark_worker_as_unavailable(worker_id)
1201
+ if len(failed_workers) > 0:
1202
+ pids_str = ", ".join(str(w.pid) for w in failed_workers)
1203
+ raise RuntimeError("DataLoader worker (pid(s) {}) exited unexpectedly".format(pids_str)) from e
1204
+ if isinstance(e, queue.Empty):
1205
+ return (False, None)
1206
+ import errno
1207
+ import tempfile
1208
+
1209
+ try:
1210
+ # Raise an exception if we are this close to the FDs limit.
1211
+ # Apparently, trying to open only one file is not a sufficient
1212
+ # test.
1213
+ # See NOTE [ DataLoader on Linux and open files limit ]
1214
+ fds_limit_margin = 10
1215
+ fs = [tempfile.NamedTemporaryFile() for i in range(fds_limit_margin)]
1216
+ except OSError as e:
1217
+ if e.errno == errno.EMFILE:
1218
+ raise RuntimeError(
1219
+ "Too many open files. Communication with the"
1220
+ " workers is no longer possible. Please increase the"
1221
+ " limit using `ulimit -n` in the shell or change the"
1222
+ " sharing strategy by calling"
1223
+ " `torch.multiprocessing.set_sharing_strategy('file_system')`"
1224
+ " at the beginning of your code"
1225
+ ) from None
1226
+ raise
1227
+
1228
+ # NOTE [ DataLoader on Linux and open files limit ]
1229
+ #
1230
+ # On Linux when DataLoader is used with multiprocessing we pass the data between
1231
+ # the root process and the workers through SHM files. We remove those files from
1232
+ # the filesystem as soon as they are created and keep them alive by
1233
+ # passing around their file descriptors through AF_UNIX sockets. (See
1234
+ # docs/source/multiprocessing.rst and 'Multiprocessing Technical Notes` in
1235
+ # the wiki (https://github.com/pytorch/pytorch/wiki).)
1236
+ #
1237
+ # This sometimes leads us to exceeding the open files limit. When that happens,
1238
+ # and the offending file descriptor is coming over a socket, the `socket` Python
1239
+ # package silently strips the file descriptor from the message, setting only the
1240
+ # `MSG_CTRUNC` flag (which might be a bit misleading since the manpage says that
1241
+ # it _indicates that some control data were discarded due to lack of space in
1242
+ # the buffer for ancillary data_). This might reflect the C implementation of
1243
+ # AF_UNIX sockets.
1244
+ #
1245
+ # This behaviour can be reproduced with the script and instructions at the
1246
+ # bottom of this note.
1247
+ #
1248
+ # When that happens, the standard Python `multiprocessing` (and not
1249
+ # `torch.multiprocessing`) raises a `RuntimeError: received 0 items of ancdata`
1250
+ #
1251
+ # Sometimes, instead of the FD being stripped, you may get an `OSError:
1252
+ # Too many open files`, both in the script below and in DataLoader. However,
1253
+ # this is rare and seems to be nondeterministic.
1254
+ #
1255
+ #
1256
+ # #!/usr/bin/env python3
1257
+ # import sys
1258
+ # import socket
1259
+ # import os
1260
+ # import array
1261
+ # import shutil
1262
+ # import socket
1263
+ #
1264
+ #
1265
+ # if len(sys.argv) != 4:
1266
+ # print("Usage: ", sys.argv[0], " tmp_dirname iteration (send|recv)")
1267
+ # sys.exit(1)
1268
+ #
1269
+ # if __name__ == '__main__':
1270
+ # dirname = sys.argv[1]
1271
+ # sock_path = dirname + "/sock"
1272
+ # iterations = int(sys.argv[2])
1273
+ # def dummy_path(i):
1274
+ # return dirname + "/" + str(i) + ".dummy"
1275
+ #
1276
+ #
1277
+ # if sys.argv[3] == 'send':
1278
+ # while not os.path.exists(sock_path):
1279
+ # pass
1280
+ # client = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM)
1281
+ # client.connect(sock_path)
1282
+ # for i in range(iterations):
1283
+ # fd = os.open(dummy_path(i), os.O_WRONLY | os.O_CREAT)
1284
+ # ancdata = array.array('i', [fd])
1285
+ # msg = bytes([i % 256])
1286
+ # print("Sending fd ", fd, " (iteration #", i, ")")
1287
+ # client.sendmsg([msg], [(socket.SOL_SOCKET, socket.SCM_RIGHTS, ancdata)])
1288
+ #
1289
+ #
1290
+ # else:
1291
+ # assert sys.argv[3] == 'recv'
1292
+ #
1293
+ # if os.path.exists(dirname):
1294
+ # raise Exception("Directory exists")
1295
+ #
1296
+ # os.mkdir(dirname)
1297
+ #
1298
+ # print("Opening socket...")
1299
+ # server = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM)
1300
+ # server.bind(sock_path)
1301
+ #
1302
+ # print("Listening...")
1303
+ # for i in range(iterations):
1304
+ # a = array.array('i')
1305
+ # msg, ancdata, flags, addr = server.recvmsg(1, socket.CMSG_SPACE(a.itemsize))
1306
+ # assert(len(ancdata) == 1)
1307
+ # cmsg_level, cmsg_type, cmsg_data = ancdata[0]
1308
+ # a.frombytes(cmsg_data)
1309
+ # print("Received fd ", a[0], " (iteration #", i, ")")
1310
+ #
1311
+ # shutil.rmtree(dirname)
1312
+ #
1313
+ # Steps to reproduce:
1314
+ #
1315
+ # 1. Run two shells and set lower file descriptor limit in the receiving one:
1316
+ # (shell1) ulimit -n 1020
1317
+ # (shell2) ulimit -n 1022
1318
+ #
1319
+ # 2. Run the script above with the `recv` option in the first shell
1320
+ # (shell1) ./test_socket.py sock_tmp 1017 recv
1321
+ #
1322
+ # 3. Run the script with the `send` option in the second shell:
1323
+ # (shell2) ./test_socket.py sock_tmp 1017 send
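+ #
+ # As the error message above suggests, the usual mitigations are raising
+ # the limit (`ulimit -n` in the shell) or switching the sharing strategy:
+ #
+ #     import torch.multiprocessing
+ #     torch.multiprocessing.set_sharing_strategy('file_system')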
1324
+
1325
+ def _get_data(self):
1326
+ # Fetches data from `self._data_queue`.
1327
+ #
1328
+ # We check workers' status every `MP_STATUS_CHECK_INTERVAL` seconds,
1329
+ # which we achieve by running `self._try_get_data(timeout=MP_STATUS_CHECK_INTERVAL)`
1330
+ # in a loop. This is the only mechanism to detect worker failures for
1331
+ # Windows. For other platforms, a SIGCHLD handler is also used for
1332
+ # worker failure detection.
1333
+ #
1334
+ # If `pin_memory=True`, we also need to check if `pin_memory_thread` has
1335
+ # died at timeouts.
1336
+ if self._timeout > 0:
1337
+ success, data = self._try_get_data(self._timeout)
1338
+ if success:
1339
+ return data
1340
+ else:
1341
+ raise RuntimeError("DataLoader timed out after {} seconds".format(self._timeout))
1342
+ elif self._pin_memory:
1343
+ while self._pin_memory_thread.is_alive():
1344
+ success, data = self._try_get_data()
1345
+ if success:
1346
+ return data
1347
+ else:
1348
+ # while condition is false, i.e., pin_memory_thread died.
1349
+ raise RuntimeError("Pin memory thread exited unexpectedly")
1350
+ # In this case, `self._data_queue` is a `queue.Queue`. But we don't
1351
+ # need to call `.task_done()` because we don't use `.join()`.
1352
+ else:
1353
+ while True:
1354
+ success, data = self._try_get_data()
1355
+ if success:
1356
+ return data
1357
+
1358
+ def _next_data(self):
1359
+ while True:
1360
+ # If the worker responsible for `self._rcvd_idx` has already ended
1361
+ # and was unable to fulfill this task (due to exhausting an `IterableDataset`),
1362
+ # we try to advance `self._rcvd_idx` to find the next valid index.
1363
+ #
1364
+ # This part needs to run in the loop because both the `self._get_data()`
1365
+ # call and `_IterableDatasetStopIteration` check below can mark
1366
+ # extra worker(s) as dead.
1367
+ while self._rcvd_idx < self._send_idx:
1368
+ info = self._task_info[self._rcvd_idx]
1369
+ worker_id = info[0]
1370
+ if len(info) == 2 or self._workers_status[worker_id]: # has data or is still active
1371
+ break
1372
+ del self._task_info[self._rcvd_idx]
1373
+ self._rcvd_idx += 1
1374
+ else:
1375
+ # no valid `self._rcvd_idx` is found (i.e., didn't break)
1376
+ if not self._persistent_workers:
1377
+ self._shutdown_workers()
1378
+ raise StopIteration
1379
+
1380
+ # Now `self._rcvd_idx` is the batch index we want to fetch
1381
+
1382
+ # Check if the next sample has already been generated
1383
+ if len(self._task_info[self._rcvd_idx]) == 2:
1384
+ data = self._task_info.pop(self._rcvd_idx)[1]
1385
+ return self._process_data(data)
1386
+
1387
+ assert not self._shutdown and self._tasks_outstanding > 0
1388
+ idx, data = self._get_data()
1389
+ self._tasks_outstanding -= 1
1390
+ if self._dataset_kind == _DatasetKind.Iterable:
1391
+ # Check for _IterableDatasetStopIteration
1392
+ if isinstance(data, _utils.worker._IterableDatasetStopIteration):
1393
+ if self._persistent_workers:
1394
+ self._workers_status[data.worker_id] = False
1395
+ else:
1396
+ self._mark_worker_as_unavailable(data.worker_id)
1397
+ self._try_put_index()
1398
+ continue
1399
+
1400
+ if idx != self._rcvd_idx:
1401
+ # store out-of-order samples
1402
+ self._task_info[idx] += (data,)
1403
+ else:
1404
+ del self._task_info[idx]
1405
+ return self._process_data(data)
1406
+
1407
+ def _try_put_index(self):
1408
+ assert self._tasks_outstanding < self._prefetch_factor * self._num_workers
1409
+
1410
+ try:
1411
+ index = self._next_index()
1412
+ except StopIteration:
1413
+ return
1414
+ for _ in range(self._num_workers): # find the next active worker, if any
1415
+ worker_queue_idx = next(self._worker_queue_idx_cycle)
1416
+ if self._workers_status[worker_queue_idx]:
1417
+ break
1418
+ else:
1419
+ # not found (i.e., didn't break)
1420
+ return
1421
+
1422
+ self._index_queues[worker_queue_idx].put((self._send_idx, index))
1423
+ self._task_info[self._send_idx] = (worker_queue_idx,)
1424
+ self._tasks_outstanding += 1
1425
+ self._send_idx += 1
1426
+
1427
+ def _process_data(self, data):
1428
+ self._rcvd_idx += 1
1429
+ self._try_put_index()
1430
+ if isinstance(data, ExceptionWrapper):
1431
+ data.reraise()
1432
+ return data
1433
+
1434
+ def _mark_worker_as_unavailable(self, worker_id, shutdown=False):
1435
+ # Mark a worker as having finished its work e.g., due to
1436
+ # exhausting an `IterableDataset`. This should be used only when this
1437
+ # `_MultiProcessingDataLoaderIter` is going to continue running.
1438
+
1439
+ assert self._workers_status[worker_id] or (self._persistent_workers and shutdown)
1440
+
1441
+ # Signal termination to that specific worker.
1442
+ q = self._index_queues[worker_id]
1443
+ # Indicate that no more data will be put on this queue by the current
1444
+ # process.
1445
+ q.put(None)
1446
+
1447
+ # Note that we don't actually join the worker here, nor do we remove the
1448
+ # worker's pid from C side struct because (1) joining may be slow, and
1449
+ # (2) since we don't join, the worker may still raise an error, and we
1450
+ # prefer capturing those, rather than ignoring them, even though they
1451
+ # are raised after the worker has finished its job.
1452
+ # Joining is deferred to `_shutdown_workers`, which is called when
1453
+ # all workers finish their jobs (e.g., `IterableDataset` replicas) or
1454
+ # when this iterator is garbage collected.
1455
+
1456
+ self._workers_status[worker_id] = False
1457
+
1458
+ assert self._workers_done_event.is_set() == shutdown
1459
+
1460
+ def _shutdown_workers(self):
1461
+ # Called when shutting down this `_MultiProcessingDataLoaderIter`.
1462
+ # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on
1463
+ # the logic of this function.
1464
+ if _utils is None or _utils.python_exit_status is True or _utils.python_exit_status is None:
1465
+ # See (2) of the note. If Python is shutting down, do no-op.
1466
+ return
1467
+ # Normal exit when last reference is gone / iterator is depleted.
1468
+ # See (1) and the second half of the note.
1469
+ if not self._shutdown:
1470
+ self._shutdown = True
1471
+ try:
1472
+ # Normal exit when last reference is gone / iterator is depleted.
1473
+ # See (1) and the second half of the note.
1474
+
1475
+ # Exit `pin_memory_thread` first because exiting workers may leave
1476
+ # corrupted data in `worker_result_queue` which `pin_memory_thread`
1477
+ # reads from.
1478
+ if hasattr(self, "_pin_memory_thread"):
1479
+ # Use hasattr in case error happens before we set the attribute.
1480
+ self._pin_memory_thread_done_event.set()
1481
+ # Send something to pin_memory_thread in case it is waiting
1482
+ # so that it can wake up and check `pin_memory_thread_done_event`
1483
+ self._worker_result_queue.put((None, None))
1484
+ self._pin_memory_thread.join()
1485
+ self._worker_result_queue.cancel_join_thread()
1486
+ self._worker_result_queue.close()
1487
+
1488
+ # Exit workers now.
1489
+ self._workers_done_event.set()
1490
+ for worker_id in range(len(self._workers)):
1491
+ # Get number of workers from `len(self._workers)` instead of
1492
+ # `self._num_workers` in case we error before starting all
1493
+ # workers.
1494
+ # If we are using `workers_status` with `persistent_workers`,
1495
+ # we have to shut the worker down because it is paused.
1496
+ if self._persistent_workers or self._workers_status[worker_id]:
1497
+ self._mark_worker_as_unavailable(worker_id, shutdown=True)
1498
+ for w in self._workers:
1499
+ # We should be able to join here, but in case anything went
1500
+ # wrong, we set a timeout and if the workers fail to join,
1501
+ # they are killed in the `finally` block.
1502
+ w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
1503
+ for q in self._index_queues:
1504
+ q.cancel_join_thread()
1505
+ q.close()
1506
+ finally:
1507
+ # Even though all this function does is putting into queues that
1508
+ # we have called `cancel_join_thread` on, weird things can
1509
+ # happen when a worker is killed by a signal, e.g., hanging in
1510
+ # `Event.set()`. So we need to guard this with SIGCHLD handler,
1511
+ # and remove pids from the C side data structure only at the
1512
+ # end.
1513
+ #
1514
+ # FIXME: Unfortunately, for Windows, we are missing a worker
1515
+ # error detection mechanism here in this function, as it
1516
+ # doesn't provide a SIGCHLD handler.
1517
+ if self._worker_pids_set:
1518
+ _utils.signal_handling._remove_worker_pids(id(self))
1519
+ self._worker_pids_set = False
1520
+ for w in self._workers:
1521
+ if w.is_alive():
1522
+ # Existing mechanisms try to make the workers exit
1523
+ # peacefully, but in case that we unfortunately reach
1524
+ # here, which we shouldn't, (e.g., pytorch/pytorch#39570),
1525
+ # we kill the worker.
1526
+ w.terminate()
1527
+
1528
+ # staticmethod is used to remove reference to `_MultiProcessingDataLoaderIter`
1529
+ @staticmethod
1530
+ def _clean_up_worker(w):
1531
+ try:
1532
+ w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
1533
+ finally:
1534
+ if w.is_alive():
1535
+ w.terminate()
1536
+
1537
+ def __del__(self):
1538
+ self._shutdown_workers()
efficientvit/apps/data_provider/random_resolution/_data_worker.py ADDED
@@ -0,0 +1,358 @@
1
+ r""""This file is based on torch/utils/data/_utils/worker.py
2
+
3
+ Contains definitions of the methods used by the _BaseDataLoaderIter workers.
4
+ These **need** to be in global scope since Py2 doesn't support serializing
5
+ static methods.
6
+ """
7
+
8
+ import os
9
+ import queue
10
+ import random
11
+ from dataclasses import dataclass
12
+ from typing import TYPE_CHECKING, Optional, Union
13
+
14
+ import torch
15
+ from torch._utils import ExceptionWrapper
16
+ from torch.utils.data._utils import HAS_NUMPY, IS_WINDOWS, MP_STATUS_CHECK_INTERVAL, signal_handling
17
+
18
+ if TYPE_CHECKING:
19
+ from torch.utils.data import Dataset
20
+
21
+ from .controller import RRSController
22
+
23
+ if IS_WINDOWS:
24
+ import ctypes
25
+ from ctypes.wintypes import BOOL, DWORD, HANDLE
26
+
27
+ # On Windows, the parent ID of the worker process remains unchanged when the manager process
28
+ # is gone, and the only way to check it through the OS is to let the worker have a process handle
29
+ # of the manager and ask if the process status has changed.
30
+ class ManagerWatchdog:
31
+ def __init__(self):
32
+ self.manager_pid = os.getppid()
33
+
34
+ # mypy cannot detect this code is windows only
35
+ self.kernel32 = ctypes.WinDLL("kernel32", use_last_error=True) # type: ignore[attr-defined]
36
+ self.kernel32.OpenProcess.argtypes = (DWORD, BOOL, DWORD)
37
+ self.kernel32.OpenProcess.restype = HANDLE
38
+ self.kernel32.WaitForSingleObject.argtypes = (HANDLE, DWORD)
39
+ self.kernel32.WaitForSingleObject.restype = DWORD
40
+
41
+ # Value obtained from https://msdn.microsoft.com/en-us/library/ms684880.aspx
42
+ SYNCHRONIZE = 0x00100000
43
+ self.manager_handle = self.kernel32.OpenProcess(SYNCHRONIZE, 0, self.manager_pid)
44
+
45
+ if not self.manager_handle:
46
+ raise ctypes.WinError(ctypes.get_last_error()) # type: ignore[attr-defined]
47
+
48
+ self.manager_dead = False
49
+
50
+ def is_alive(self):
51
+ if not self.manager_dead:
52
+ # Value obtained from https://msdn.microsoft.com/en-us/library/windows/desktop/ms687032.aspx
53
+ self.manager_dead = self.kernel32.WaitForSingleObject(self.manager_handle, 0) == 0
54
+ return not self.manager_dead
55
+
56
+ else:
57
+
58
+ class ManagerWatchdog: # type: ignore[no-redef]
59
+ def __init__(self):
60
+ self.manager_pid = os.getppid()
61
+ self.manager_dead = False
62
+
63
+ def is_alive(self):
64
+ if not self.manager_dead:
65
+ self.manager_dead = os.getppid() != self.manager_pid
66
+ return not self.manager_dead
67
+
68
+
69
+ _worker_info = None
70
+
71
+
72
+ class WorkerInfo:
73
+ id: int
74
+ num_workers: int
75
+ seed: int
76
+ dataset: "Dataset"
77
+ __initialized = False
78
+
79
+ def __init__(self, **kwargs):
80
+ for k, v in kwargs.items():
81
+ setattr(self, k, v)
82
+ self.__keys = tuple(kwargs.keys())
83
+ self.__initialized = True
84
+
85
+ def __setattr__(self, key, val):
86
+ if self.__initialized:
87
+ raise RuntimeError("Cannot assign attributes to {} objects".format(self.__class__.__name__))
88
+ return super().__setattr__(key, val)
89
+
90
+ def __repr__(self):
91
+ items = []
92
+ for k in self.__keys:
93
+ items.append("{}={}".format(k, getattr(self, k)))
94
+ return "{}({})".format(self.__class__.__name__, ", ".join(items))
95
+
96
+
97
+ def get_worker_info() -> Optional[WorkerInfo]:
98
+ r"""Returns the information about the current
99
+ :class:`~torch.utils.data.DataLoader` iterator worker process.
100
+
101
+ When called in a worker, this returns an object guaranteed to have the
102
+ following attributes:
103
+
104
+ * :attr:`id`: the current worker id.
105
+ * :attr:`num_workers`: the total number of workers.
106
+ * :attr:`seed`: the random seed set for the current worker. This value is
107
+ determined by main process RNG and the worker id. See
108
+ :class:`~torch.utils.data.DataLoader`'s documentation for more details.
109
+ * :attr:`dataset`: the copy of the dataset object in **this** process. Note
110
+ that this will be a different object in a different process than the one
111
+ in the main process.
112
+
113
+ When called in the main process, this returns ``None``.
114
+
115
+ .. note::
116
+ When used in a :attr:`worker_init_fn` passed over to
117
+ :class:`~torch.utils.data.DataLoader`, this method can be useful to
118
+ set up each worker process differently, for instance, using ``worker_id``
119
+ to configure the ``dataset`` object to only read a specific fraction of a
120
+ sharded dataset, or use ``seed`` to seed other libraries used in dataset
121
+ code.
122
+ """
123
+ return _worker_info
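+
+ # A sketch of the docstring's sharding suggestion (`files` is a
+ # hypothetical attribute of the dataset, not part of this module):
+ #
+ #     def worker_init_fn(worker_id):
+ #         info = get_worker_info()
+ #         per_worker = len(info.dataset.files) // info.num_workers
+ #         info.dataset.files = info.dataset.files[
+ #             worker_id * per_worker : (worker_id + 1) * per_worker
+ #         ]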
124
+
125
+
126
+ r"""Dummy class used to signal the end of an IterableDataset"""
127
+
128
+
129
+ @dataclass(frozen=True)
130
+ class _IterableDatasetStopIteration:
131
+ worker_id: int
132
+
133
+
134
+ r"""Dummy class used to resume the fetching when worker reuse is enabled"""
135
+
136
+
137
+ @dataclass(frozen=True)
138
+ class _ResumeIteration:
139
+ seed: Optional[int] = None
140
+
141
+
142
+ # The function `_generate_state` is adapted from `numpy.random.SeedSequence`
143
+ # from https://github.com/numpy/numpy/blob/main/numpy/random/bit_generator.pyx
144
+ # It's MIT licensed, here is the copyright:
145
+
146
+ # Copyright (c) 2015 Melissa E. O'Neill
147
+ # Copyright (c) 2019 NumPy Developers
148
+ #
149
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
150
+ # of this software and associated documentation files (the "Software"), to deal
151
+ # in the Software without restriction, including without limitation the rights
152
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
153
+ # copies of the Software, and to permit persons to whom the Software is
154
+ # furnished to do so, subject to the following conditions:
155
+ #
156
+ # The above copyright notice and this permission notice shall be included in
157
+ # all copies or substantial portions of the Software.
158
+ #
159
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
160
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
161
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
162
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
163
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
164
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
165
+ # SOFTWARE.
166
+
167
+
168
+ # This function generates an array of int32 as the seed for
169
+ # `numpy.random`, in order to prevent state collision due to same
170
+ # seed and algorithm for `numpy.random` and `random` modules.
171
+ # TODO: Implement `SeedSequence` like object for `torch.random`
172
+ def _generate_state(base_seed, worker_id):
173
+ INIT_A = 0x43B0D7E5
174
+ MULT_A = 0x931E8875
175
+ INIT_B = 0x8B51F9DD
176
+ MULT_B = 0x58F38DED
177
+ MIX_MULT_L = 0xCA01F9DD
178
+ MIX_MULT_R = 0x4973F715
179
+ XSHIFT = 4 * 8 // 2
180
+ MASK32 = 0xFFFFFFFF
181
+
182
+ entropy = [worker_id, base_seed & MASK32, base_seed >> 32, 0]
183
+ pool = [0] * 4
184
+
185
+ hash_const_A = INIT_A
186
+
187
+ def hash(value):
188
+ nonlocal hash_const_A
189
+ value = (value ^ hash_const_A) & MASK32
190
+ hash_const_A = (hash_const_A * MULT_A) & MASK32
191
+ value = (value * hash_const_A) & MASK32
192
+ value = (value ^ (value >> XSHIFT)) & MASK32
193
+ return value
194
+
195
+ def mix(x, y):
196
+ result_x = (MIX_MULT_L * x) & MASK32
197
+ result_y = (MIX_MULT_R * y) & MASK32
198
+ result = (result_x - result_y) & MASK32
199
+ result = (result ^ (result >> XSHIFT)) & MASK32
200
+ return result
201
+
202
+ # Add in the entropy to the pool.
203
+ for i in range(len(pool)):
204
+ pool[i] = hash(entropy[i])
205
+
206
+ # Mix all bits together so late bits can affect earlier bits.
207
+ for i_src in range(len(pool)):
208
+ for i_dst in range(len(pool)):
209
+ if i_src != i_dst:
210
+ pool[i_dst] = mix(pool[i_dst], hash(pool[i_src]))
211
+
212
+ hash_const_B = INIT_B
213
+ state = []
214
+ for i_dst in range(4):
215
+ data_val = pool[i_dst]
216
+ data_val = (data_val ^ hash_const_B) & MASK32
217
+ hash_const_B = (hash_const_B * MULT_B) & MASK32
218
+ data_val = (data_val * hash_const_B) & MASK32
219
+ data_val = (data_val ^ (data_val >> XSHIFT)) & MASK32
220
+ state.append(data_val)
221
+ return state
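+
+ # For example (illustrative call; assumes numpy is available), seeding
+ # numpy per worker from one shared base seed:
+ #
+ #     np.random.seed(_generate_state(base_seed=42, worker_id=3))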
222
+
223
+
224
+ def _worker_loop(
225
+ dataset_kind,
226
+ dataset,
227
+ index_queue,
228
+ data_queue,
229
+ done_event,
230
+ auto_collation,
231
+ collate_fn,
232
+ drop_last,
233
+ base_seed,
234
+ init_fn,
235
+ worker_id,
236
+ num_workers,
237
+ persistent_workers,
238
+ shared_seed,
239
+ ):
240
+ # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the
241
+ # logic of this function.
242
+
243
+ try:
244
+ # Initialize C side signal handlers for SIGBUS and SIGSEGV. Python signal
245
+ # module's handlers are executed after Python returns from C low-level
246
+ # handlers, likely when the same fatal signal had already happened
247
+ # again.
248
+ # https://docs.python.org/3/library/signal.html#execution-of-python-signal-handlers
249
+ signal_handling._set_worker_signal_handlers()
250
+
251
+ torch.set_num_threads(1)
252
+ seed = base_seed + worker_id
253
+ random.seed(seed)
254
+ torch.manual_seed(seed)
255
+ if HAS_NUMPY:
256
+ np_seed = _generate_state(base_seed, worker_id)
257
+ import numpy as np
258
+
259
+ np.random.seed(np_seed)
260
+
261
+ from torch.utils.data import IterDataPipe
262
+ from torch.utils.data.graph_settings import apply_random_seed
263
+
264
+ shared_rng = torch.Generator()
265
+ if isinstance(dataset, IterDataPipe):
266
+ assert shared_seed is not None
267
+ shared_rng.manual_seed(shared_seed)
268
+ dataset = apply_random_seed(dataset, shared_rng)
269
+
270
+ global _worker_info
271
+ _worker_info = WorkerInfo(id=worker_id, num_workers=num_workers, seed=seed, dataset=dataset)
272
+
273
+ from torch.utils.data import _DatasetKind
274
+
275
+ init_exception = None
276
+
277
+ try:
278
+ if init_fn is not None:
279
+ init_fn(worker_id)
280
+
281
+ fetcher = _DatasetKind.create_fetcher(dataset_kind, dataset, auto_collation, collate_fn, drop_last)
282
+ except Exception:
283
+ init_exception = ExceptionWrapper(where="in DataLoader worker process {}".format(worker_id))
284
+
285
+ # When using Iterable mode, some workers can exit earlier than others due
286
+ # to the IterableDataset behaving differently for different workers.
287
+ # When such things happen, an `_IterableDatasetStopIteration` object is
288
+ # sent over to the main process with the ID of this worker, so that the
289
+ # main process won't send more tasks to this worker, and will send
290
+ # `None` to this worker to properly exit it.
291
+ #
292
+ # Note that we cannot set `done_event` from a worker as it is shared
293
+ # among all processes. Instead, we set the `iteration_end` flag to
294
+ # signify that the iterator is exhausted. When either `done_event` or
295
+ # `iteration_end` is set, we skip all processing steps and just wait for
296
+ # `None`.
297
+ iteration_end = False
298
+
299
+ watchdog = ManagerWatchdog()
300
+
301
+ while watchdog.is_alive():
302
+ try:
303
+ r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
304
+ except queue.Empty:
305
+ continue
306
+ if isinstance(r, _ResumeIteration):
307
+ # Acknowledge the main process
308
+ data_queue.put((r, None))
309
+ iteration_end = False
310
+
311
+ if isinstance(dataset, IterDataPipe):
312
+ assert r.seed is not None
313
+ shared_rng.manual_seed(r.seed)
314
+ dataset = apply_random_seed(dataset, shared_rng)
315
+
316
+ # Recreate the fetcher for worker-reuse policy
317
+ fetcher = _DatasetKind.create_fetcher(dataset_kind, dataset, auto_collation, collate_fn, drop_last)
318
+ continue
319
+ elif r is None:
320
+ # Received the final signal
321
+ assert done_event.is_set() or iteration_end
322
+ break
323
+ elif done_event.is_set() or iteration_end:
324
+ # `done_event` is set. But I haven't received the final signal
325
+ # (None) yet. I will keep continuing until I get it, and skip the
326
+ # processing steps.
327
+ continue
328
+ idx, index = r
329
+ """ Added """
330
+ RRSController.sample_resolution(batch_id=idx)
331
+ """ Added """
332
+ data: Union[_IterableDatasetStopIteration, ExceptionWrapper]
333
+ if init_exception is not None:
334
+ data = init_exception
335
+ init_exception = None
336
+ else:
337
+ try:
338
+ data = fetcher.fetch(index)
339
+ except Exception as e:
340
+ if isinstance(e, StopIteration) and dataset_kind == _DatasetKind.Iterable:
341
+ data = _IterableDatasetStopIteration(worker_id)
342
+ # Set `iteration_end`
343
+ # (1) to save future `next(...)` calls, and
344
+ # (2) to avoid sending multiple `_IterableDatasetStopIteration`s.
345
+ iteration_end = True
346
+ else:
347
+ # It is important that we don't store exc_info in a variable.
348
+ # `ExceptionWrapper` does the correct thing.
349
+ # See NOTE [ Python Traceback Reference Cycle Problem ]
350
+ data = ExceptionWrapper(where="in DataLoader worker process {}".format(worker_id))
351
+ data_queue.put((idx, data))
352
+ del data, idx, index, r # save memory
353
+ except KeyboardInterrupt:
354
+ # Main process will raise KeyboardInterrupt anyways.
355
+ pass
356
+ if done_event.is_set():
357
+ data_queue.cancel_join_thread()
358
+ data_queue.close()
efficientvit/apps/data_provider/random_resolution/controller.py ADDED
@@ -0,0 +1,92 @@
1
+ # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
2
+ # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
3
+ # International Conference on Computer Vision (ICCV), 2023
4
+
5
+ import copy
6
+
7
+ import torch
8
+ import torchvision.transforms as transforms
9
+ import torchvision.transforms.functional as F
10
+
11
+ from efficientvit.models.utils import torch_random_choices
12
+
13
+ __all__ = [
14
+ "RRSController",
15
+ "get_interpolate",
16
+ "MyRandomResizedCrop",
17
+ ]
18
+
19
+
20
+ class RRSController:
21
+ ACTIVE_SIZE = (224, 224)
22
+ IMAGE_SIZE_LIST = [(224, 224)]
23
+
24
+ CHOICE_LIST = None
25
+
26
+ @staticmethod
27
+ def get_candidates() -> list[tuple[int, int]]:
28
+ return copy.deepcopy(RRSController.IMAGE_SIZE_LIST)
29
+
30
+ @staticmethod
31
+ def sample_resolution(batch_id: int) -> None:
32
+ RRSController.ACTIVE_SIZE = RRSController.CHOICE_LIST[batch_id]
33
+
34
+ @staticmethod
35
+ def set_epoch(epoch: int, batch_per_epoch: int) -> None:
36
+ g = torch.Generator()
37
+ g.manual_seed(epoch)
38
+ RRSController.CHOICE_LIST = torch_random_choices(
39
+ RRSController.get_candidates(),
40
+ g,
41
+ batch_per_epoch,
42
+ )
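+
+ # A sketch of the intended call order (the training-loop wiring here is
+ # an assumption):
+ #
+ #     RRSController.IMAGE_SIZE_LIST = [(160, 160), (192, 192), (224, 224)]
+ #     RRSController.set_epoch(epoch, batch_per_epoch=len(train_loader))
+ #     # each data worker then calls RRSController.sample_resolution(batch_id)
+ #     # before fetching (see _data_worker.py), so every sample in a batch
+ #     # shares one resolution.
+ #
+ # Seeding the generator with the epoch makes CHOICE_LIST identical in all
+ # worker processes, which is what keeps a batch's resolution consistent.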
43
+
44
+
45
+ def get_interpolate(name: str) -> F.InterpolationMode:
46
+ mapping = {
47
+ "nearest": F.InterpolationMode.NEAREST,
48
+ "bilinear": F.InterpolationMode.BILINEAR,
49
+ "bicubic": F.InterpolationMode.BICUBIC,
50
+ "box": F.InterpolationMode.BOX,
51
+ "hamming": F.InterpolationMode.HAMMING,
52
+ "lanczos": F.InterpolationMode.LANCZOS,
53
+ }
54
+ if name in mapping:
55
+ return mapping[name]
56
+ elif name == "random":
57
+ return torch_random_choices(
58
+ [
59
+ F.InterpolationMode.NEAREST,
60
+ F.InterpolationMode.BILINEAR,
61
+ F.InterpolationMode.BICUBIC,
62
+ F.InterpolationMode.BOX,
63
+ F.InterpolationMode.HAMMING,
64
+ F.InterpolationMode.LANCZOS,
65
+ ],
66
+ )
67
+ else:
68
+ raise NotImplementedError
69
+
70
+
71
+ class MyRandomResizedCrop(transforms.RandomResizedCrop):
72
+ def __init__(
73
+ self,
74
+ scale=(0.08, 1.0),
75
+ ratio=(3.0 / 4.0, 4.0 / 3.0),
76
+ interpolation: str = "random",
77
+ ):
78
+ super(MyRandomResizedCrop, self).__init__(224, scale, ratio)
79
+ self.interpolation = interpolation
80
+
81
+ def forward(self, img: torch.Tensor) -> torch.Tensor:
82
+ i, j, h, w = self.get_params(img, list(self.scale), list(self.ratio))
83
+ target_size = RRSController.ACTIVE_SIZE
84
+ return F.resized_crop(img, i, j, h, w, list(target_size), get_interpolate(self.interpolation))
85
+
86
+ def __repr__(self) -> str:
87
+ format_string = self.__class__.__name__
88
+ format_string += f"(\n\tsize={RRSController.get_candidates()},\n"
89
+ format_string += f"\tscale={tuple(round(s, 4) for s in self.scale)},\n"
90
+ format_string += f"\tratio={tuple(round(r, 4) for r in self.ratio)},\n"
91
+ format_string += f"\tinterpolation={self.interpolation})"
92
+ return format_string
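+
+ # Example pipeline (a sketch; the surrounding dataset wiring lives
+ # elsewhere):
+ #
+ #     train_transform = transforms.Compose([
+ #         MyRandomResizedCrop(interpolation="random"),
+ #         transforms.ToTensor(),
+ #     ])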
efficientvit/apps/setup.py ADDED
@@ -0,0 +1,135 @@
1
+ # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
2
+ # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
3
+ # International Conference on Computer Vision (ICCV), 2023
4
+
5
+ import os
6
+ import time
7
+ from copy import deepcopy
8
+
9
+ import torch.backends.cudnn
10
+ import torch.distributed
11
+ import torch.nn as nn
12
+
13
+ from efficientvit.apps.data_provider import DataProvider
14
+ from efficientvit.apps.trainer.run_config import RunConfig
15
+ from efficientvit.apps.utils import (
16
+ dist_init,
17
+ dump_config,
18
+ get_dist_local_rank,
19
+ get_dist_rank,
20
+ get_dist_size,
21
+ init_modules,
22
+ is_master,
23
+ load_config,
24
+ partial_update_config,
25
+ zero_last_gamma,
26
+ )
27
+ from efficientvit.models.utils import build_kwargs_from_config, load_state_dict_from_file
28
+
29
+ __all__ = [
30
+ "save_exp_config",
31
+ "setup_dist_env",
32
+ "setup_seed",
33
+ "setup_exp_config",
34
+ "setup_data_provider",
35
+ "setup_run_config",
36
+ "init_model",
37
+ ]
38
+
39
+
40
+ def save_exp_config(exp_config: dict, path: str, name="config.yaml") -> None:
41
+ if not is_master():
42
+ return
43
+ dump_config(exp_config, os.path.join(path, name))
44
+
45
+
46
+ def setup_dist_env(gpu: str | None = None) -> None:
47
+ if gpu is not None:
48
+ os.environ["CUDA_VISIBLE_DEVICES"] = gpu
49
+ if not torch.distributed.is_initialized():
50
+ dist_init()
51
+ torch.backends.cudnn.benchmark = True
52
+ torch.cuda.set_device(get_dist_local_rank())
53
+
54
+
55
+ def setup_seed(manual_seed: int, resume: bool) -> None:
56
+ if resume:
57
+ manual_seed = int(time.time())
58
+ manual_seed = get_dist_rank() + manual_seed
59
+ torch.manual_seed(manual_seed)
60
+ torch.cuda.manual_seed_all(manual_seed)
61
+
62
+
63
+ def setup_exp_config(config_path: str, recursive=True, opt_args: dict | None = None) -> dict:
64
+ # load config
65
+ if not os.path.isfile(config_path):
66
+ raise ValueError(config_path)
67
+
68
+ fpaths = [config_path]
69
+ if recursive:
70
+ extension = os.path.splitext(config_path)[1]
71
+ while os.path.dirname(config_path) != config_path:
72
+ config_path = os.path.dirname(config_path)
73
+ fpath = os.path.join(config_path, "default" + extension)
74
+ if os.path.isfile(fpath):
75
+ fpaths.append(fpath)
76
+ fpaths = fpaths[::-1]
77
+
78
+ default_config = load_config(fpaths[0])
79
+ exp_config = deepcopy(default_config)
80
+ for fpath in fpaths[1:]:
81
+ partial_update_config(exp_config, load_config(fpath))
82
+ # update config via args
83
+ if opt_args is not None:
84
+ partial_update_config(exp_config, opt_args)
85
+
86
+ return exp_config
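+
+ # For example (hypothetical paths), with recursive=True the call below
+ # first merges every `default.yaml` found in the config's ancestor
+ # directories (outermost first), then the named config, then `opt_args`:
+ #
+ #     exp_config = setup_exp_config(
+ #         "configs/cls/imagenet/b1.yaml",
+ #         opt_args={"run_config": {"base_lr": 0.0125}},
+ #     )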
87
+
88
+
89
+ def setup_data_provider(
90
+ exp_config: dict, data_provider_classes: list[type[DataProvider]], is_distributed: bool = True
91
+ ) -> DataProvider:
92
+ dp_config = exp_config["data_provider"]
93
+ dp_config["num_replicas"] = get_dist_size() if is_distributed else None
94
+ dp_config["rank"] = get_dist_rank() if is_distributed else None
95
+ dp_config["test_batch_size"] = dp_config.get("test_batch_size", None) or dp_config["base_batch_size"] * 2
96
+ dp_config["batch_size"] = dp_config["train_batch_size"] = dp_config["base_batch_size"]
97
+
98
+ data_provider_lookup = {provider.name: provider for provider in data_provider_classes}
99
+ data_provider_class = data_provider_lookup[dp_config["dataset"]]
100
+
101
+ data_provider_kwargs = build_kwargs_from_config(dp_config, data_provider_class)
102
+ data_provider = data_provider_class(**data_provider_kwargs)
103
+ return data_provider
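+
+ # Example wiring (the provider class name is hypothetical):
+ #
+ #     data_provider = setup_data_provider(exp_config, [ImageNetDataProvider])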
104
+
105
+
106
+ def setup_run_config(exp_config: dict, run_config_cls: type[RunConfig]) -> RunConfig:
107
+ exp_config["run_config"]["init_lr"] = exp_config["run_config"]["base_lr"] * get_dist_size()
108
+
109
+ run_config = run_config_cls(**exp_config["run_config"])
110
+
111
+ return run_config
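+
+ # E.g. (illustrative numbers): base_lr = 0.025 across 8 distributed
+ # processes gives init_lr = 0.025 * 8 = 0.2, i.e. the linear scaling rule.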
112
+
113
+
114
+ def init_model(
115
+ network: nn.Module,
116
+ init_from: str | None = None,
117
+ backbone_init_from: str | None = None,
118
+ rand_init="trunc_normal",
119
+ last_gamma=None,
120
+ ) -> None:
121
+ # initialization
122
+ init_modules(network, init_type=rand_init)
123
+ # zero gamma of last bn in each block
124
+ if last_gamma is not None:
125
+ zero_last_gamma(network, last_gamma)
126
+
127
+ # load weight
128
+ if init_from is not None and os.path.isfile(init_from):
129
+ network.load_state_dict(load_state_dict_from_file(init_from))
130
+ print(f"Loaded init from {init_from}")
131
+ elif backbone_init_from is not None and os.path.isfile(backbone_init_from):
132
+ network.backbone.load_state_dict(load_state_dict_from_file(backbone_init_from))
133
+ print(f"Loaded backbone init from {backbone_init_from}")
134
+ else:
135
+ print(f"Random init ({rand_init}) with last gamma {last_gamma}")
efficientvit/apps/trainer/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
2
+ # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
3
+ # International Conference on Computer Vision (ICCV), 2023
4
+
5
+ from .base import *
6
+ from .run_config import *
efficientvit/apps/trainer/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (231 Bytes). View file
 
efficientvit/apps/trainer/__pycache__/base.cpython-310.pyc ADDED
Binary file (8.47 kB). View file
 
efficientvit/apps/trainer/__pycache__/run_config.cpython-310.pyc ADDED
Binary file (4.05 kB). View file
 
efficientvit/apps/trainer/base.py ADDED
@@ -0,0 +1,299 @@
1
+ # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
2
+ # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
3
+ # International Conference on Computer Vision (ICCV), 2023
4
+
5
+ import os
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+
10
+ from efficientvit.apps.data_provider import DataProvider, parse_image_size
11
+ from efficientvit.apps.trainer.run_config import RunConfig
12
+ from efficientvit.apps.utils import EMA, dist_barrier, get_dist_local_rank, is_master
13
+ from efficientvit.models.nn.norm import reset_bn
14
+ from efficientvit.models.utils import is_parallel, load_state_dict_from_file
15
+
16
+ __all__ = ["Trainer"]
17
+
18
+
19
+ class Trainer:
20
+ def __init__(self, path: str, model: nn.Module, data_provider: DataProvider):
21
+ self.path = os.path.realpath(os.path.expanduser(path))
22
+ self.model = model.cuda()
23
+ self.data_provider = data_provider
24
+
25
+ self.ema = None
26
+
27
+ self.checkpoint_path = os.path.join(self.path, "checkpoint")
28
+ self.logs_path = os.path.join(self.path, "logs")
29
+ for path in [self.path, self.checkpoint_path, self.logs_path]:
30
+ os.makedirs(path, exist_ok=True)
31
+
32
+ self.best_val = 0.0
33
+ self.start_epoch = 0
34
+
35
+ @property
36
+ def network(self) -> nn.Module:
37
+ return self.model.module if is_parallel(self.model) else self.model
38
+
39
+ @property
40
+ def eval_network(self) -> nn.Module:
41
+ if self.ema is None:
42
+ model = self.model
43
+ else:
44
+ model = self.ema.shadows
45
+ model = model.module if is_parallel(model) else model
46
+ return model
47
+
48
+ def write_log(self, log_str, prefix="valid", print_log=True, mode="a") -> None:
49
+ if is_master():
50
+ fout = open(os.path.join(self.logs_path, f"{prefix}.log"), mode)
51
+ fout.write(log_str + "\n")
52
+ fout.flush()
53
+ fout.close()
54
+ if print_log:
55
+ print(log_str)
56
+
57
+ def save_model(
58
+ self,
59
+ checkpoint=None,
60
+ only_state_dict=True,
61
+ epoch=0,
62
+ model_name=None,
63
+ ) -> None:
64
+ if is_master():
65
+ if checkpoint is None:
66
+ if only_state_dict:
67
+ checkpoint = {"state_dict": self.network.state_dict()}
68
+ else:
69
+ checkpoint = {
70
+ "state_dict": self.network.state_dict(),
71
+ "epoch": epoch,
72
+ "best_val": self.best_val,
73
+ "optimizer": self.optimizer.state_dict(),
74
+ "lr_scheduler": self.lr_scheduler.state_dict(),
75
+ "ema": self.ema.state_dict() if self.ema is not None else None,
76
+ "scaler": self.scaler.state_dict() if self.enable_amp else None,
77
+ }
78
+
79
+ model_name = model_name or "checkpoint.pt"
80
+
81
+ latest_fname = os.path.join(self.checkpoint_path, "latest.txt")
82
+ model_path = os.path.join(self.checkpoint_path, model_name)
83
+ with open(latest_fname, "w") as _fout:
84
+ _fout.write(model_path + "\n")
85
+ torch.save(checkpoint, model_path)
86
+
87
+ def load_model(self, model_fname=None) -> None:
88
+ latest_fname = os.path.join(self.checkpoint_path, "latest.txt")
89
+ if model_fname is None and os.path.exists(latest_fname):
90
+ with open(latest_fname, "r") as fin:
91
+ model_fname = fin.readline()
92
+ if len(model_fname) > 0 and model_fname[-1] == "\n":
93
+ model_fname = model_fname[:-1]
94
+ try:
95
+ if model_fname is None:
96
+ model_fname = f"{self.checkpoint_path}/checkpoint.pt"
97
+ elif not os.path.exists(model_fname):
98
+ model_fname = f"{self.checkpoint_path}/{os.path.basename(model_fname)}"
99
+ if not os.path.exists(model_fname):
100
+ model_fname = f"{self.checkpoint_path}/checkpoint.pt"
101
+ print(f"=> loading checkpoint {model_fname}")
102
+ checkpoint = load_state_dict_from_file(model_fname, False)
103
+ except Exception:
104
+ self.write_log(f"fail to load checkpoint from {self.checkpoint_path}")
105
+ return
106
+
107
+ # load checkpoint
108
+ self.network.load_state_dict(checkpoint["state_dict"], strict=False)
109
+ log = []
110
+ if "epoch" in checkpoint:
111
+ self.start_epoch = checkpoint["epoch"] + 1
112
+ self.run_config.update_global_step(self.start_epoch)
113
+ log.append(f"epoch={self.start_epoch - 1}")
114
+ if "best_val" in checkpoint:
115
+ self.best_val = checkpoint["best_val"]
116
+ log.append(f"best_val={self.best_val:.2f}")
117
+ if "optimizer" in checkpoint:
118
+ self.optimizer.load_state_dict(checkpoint["optimizer"])
119
+ log.append("optimizer")
120
+ if "lr_scheduler" in checkpoint:
121
+ self.lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
122
+ log.append("lr_scheduler")
123
+ if "ema" in checkpoint and self.ema is not None:
124
+ self.ema.load_state_dict(checkpoint["ema"])
125
+ log.append("ema")
126
+ if "scaler" in checkpoint and self.enable_amp:
127
+ self.scaler.load_state_dict(checkpoint["scaler"])
128
+ log.append("scaler")
129
+ self.write_log("Loaded: " + ", ".join(log))
130
+
131
+ """ validate """
132
+
133
+ def reset_bn(
134
+ self,
135
+ network: nn.Module or None = None,
136
+ subset_size: int = 16000,
137
+ subset_batch_size: int = 100,
138
+ data_loader=None,
139
+ progress_bar=False,
140
+ ) -> None:
141
+ network = network or self.network
142
+ if data_loader is None:
143
+ data_loader = []
144
+ for data in self.data_provider.build_sub_train_loader(subset_size, subset_batch_size):
145
+ if isinstance(data, list):
146
+ data_loader.append(data[0])
147
+ elif isinstance(data, dict):
148
+ data_loader.append(data["data"])
149
+ elif isinstance(data, torch.Tensor):
150
+ data_loader.append(data)
151
+ else:
152
+ raise NotImplementedError
153
+
154
+ network.eval()
155
+ reset_bn(
156
+ network,
157
+ data_loader,
158
+ sync=True,
159
+ progress_bar=progress_bar,
160
+ )
161
+
162
+ def _validate(self, model, data_loader, epoch) -> dict[str, any]:
163
+ raise NotImplementedError
164
+
165
+ def validate(self, model=None, data_loader=None, is_test=True, epoch=0) -> dict[str, any]:
166
+ model = model or self.eval_network
167
+ if data_loader is None:
168
+ if is_test:
169
+ data_loader = self.data_provider.test
170
+ else:
171
+ data_loader = self.data_provider.valid
172
+
173
+ model.eval()
174
+ return self._validate(model, data_loader, epoch)
175
+
176
+ def multires_validate(
177
+ self,
178
+ model=None,
179
+ data_loader=None,
180
+ is_test=True,
181
+ epoch=0,
182
+ eval_image_size=None,
183
+ ) -> dict[str, dict[str, any]]:
184
+ eval_image_size = eval_image_size or self.run_config.eval_image_size
185
+ eval_image_size = eval_image_size or self.data_provider.image_size
186
+ model = model or self.eval_network
187
+
188
+ if not isinstance(eval_image_size, list):
189
+ eval_image_size = [eval_image_size]
190
+
191
+ output_dict = {}
192
+ for r in eval_image_size:
193
+ self.data_provider.assign_active_image_size(parse_image_size(r))
194
+ if self.run_config.reset_bn:
195
+ self.reset_bn(
196
+ network=model,
197
+ subset_size=self.run_config.reset_bn_size,
198
+ subset_batch_size=self.run_config.reset_bn_batch_size,
199
+ progress_bar=True,
200
+ )
201
+ output_dict[f"r{r}"] = self.validate(model, data_loader, is_test, epoch)
202
+ return output_dict
203
+
204
+ """ training """
205
+
206
+ def prep_for_training(self, run_config: RunConfig, ema_decay: float or None = None, amp="fp32") -> None:
207
+ self.run_config = run_config
208
+ self.model = nn.parallel.DistributedDataParallel(
209
+ self.model.cuda(),
210
+ device_ids=[get_dist_local_rank()],
211
+ static_graph=True,
212
+ )
213
+
214
+ self.run_config.global_step = 0
215
+ self.run_config.batch_per_epoch = len(self.data_provider.train)
216
+ assert self.run_config.batch_per_epoch > 0, "Training set is empty"
217
+
218
+ # build optimizer
219
+ self.optimizer, self.lr_scheduler = self.run_config.build_optimizer(self.model)
220
+
221
+ if ema_decay is not None:
222
+ self.ema = EMA(self.network, ema_decay)
223
+
224
+ # amp
225
+ self.amp = amp
226
+ self.scaler = torch.cuda.amp.GradScaler(enabled=self.enable_amp)
227
+
228
+ @property
229
+ def enable_amp(self) -> bool:
230
+ return self.amp != "fp32"
231
+
232
+ @property
233
+ def amp_dtype(self) -> torch.dtype:
234
+ if self.amp == "fp16":
235
+ return torch.float16
236
+ elif self.amp == "bf16":
237
+ return torch.bfloat16
238
+ else:
239
+ return torch.float32
240
+
241
+ def sync_model(self):
242
+ print("Sync model")
243
+ self.save_model(model_name="sync.pt")
244
+ dist_barrier()
245
+ checkpoint = torch.load(os.path.join(self.checkpoint_path, "sync.pt"), map_location="cpu")
246
+ dist_barrier()
247
+ if is_master():
248
+ os.remove(os.path.join(self.checkpoint_path, "sync.pt"))
249
+ dist_barrier()
250
+
251
+ # load checkpoint
252
+ self.network.load_state_dict(checkpoint["state_dict"], strict=False)
253
+ if "optimizer" in checkpoint:
254
+ self.optimizer.load_state_dict(checkpoint["optimizer"])
255
+ if "lr_scheduler" in checkpoint:
256
+ self.lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
257
+ if "ema" in checkpoint and self.ema is not None:
258
+ self.ema.load_state_dict(checkpoint["ema"])
259
+ if "scaler" in checkpoint and self.enable_amp:
260
+ self.scaler.load_state_dict(checkpoint["scaler"])
261
+
262
+ def before_step(self, feed_dict: dict[str, any]) -> dict[str, any]:
263
+ for key in feed_dict:
264
+ if isinstance(feed_dict[key], torch.Tensor):
265
+ feed_dict[key] = feed_dict[key].cuda()
266
+ return feed_dict
267
+
268
+ def run_step(self, feed_dict: dict[str, any]) -> dict[str, any]:
269
+ raise NotImplementedError
270
+
271
+ def after_step(self) -> None:
272
+ self.scaler.unscale_(self.optimizer)
273
+ # gradient clip
274
+ if self.run_config.grad_clip is not None:
275
+ torch.nn.utils.clip_grad_value_(self.model.parameters(), self.run_config.grad_clip)
276
+ # update
277
+ self.scaler.step(self.optimizer)
278
+ self.scaler.update()
279
+
280
+ self.lr_scheduler.step()
281
+ self.run_config.step()
282
+ # update ema
283
+ if self.ema is not None:
284
+ self.ema.step(self.network, self.run_config.global_step)
285
+
286
+ def _train_one_epoch(self, epoch: int) -> dict[str, any]:
287
+ raise NotImplementedError
288
+
289
+ def train_one_epoch(self, epoch: int) -> dict[str, any]:
290
+ self.model.train()
291
+
292
+ self.data_provider.set_epoch(epoch)
293
+
294
+ train_info_dict = self._train_one_epoch(epoch)
295
+
296
+ return train_info_dict
297
+
298
+ def train(self) -> None:
299
+ raise NotImplementedError
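Usage note: `Trainer` is an abstract base; a concrete trainer only fills in the `NotImplementedError` hooks and inherits checkpointing, EMA, AMP, and BN-reset. A sketch (illustrative, not from the upload):

```python
import torch

from efficientvit.apps.trainer import Trainer


class ClsTrainer(Trainer):
    def _validate(self, model, data_loader, epoch) -> dict:
        # top-1 accuracy over the loader; assumes (image, label) batches
        correct = total = 0
        with torch.no_grad():
            for images, labels in data_loader:
                pred = model(images.cuda()).argmax(dim=1)
                correct += (pred == labels.cuda()).sum().item()
                total += labels.numel()
        return {"top1": 100.0 * correct / total}

    # run_step() and _train_one_epoch() would be overridden analogously.
```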
efficientvit/apps/trainer/run_config.py ADDED
@@ -0,0 +1,115 @@
+ # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
+ # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
+ # International Conference on Computer Vision (ICCV), 2023
+
+ import json
+
+ import numpy as np
+ import torch.nn as nn
+
+ from efficientvit.apps.utils import CosineLRwithWarmup, build_optimizer
+
+ __all__ = ["Scheduler", "RunConfig"]
+
+
+ class Scheduler:
+     PROGRESS = 0
+
+
+ class RunConfig:
+     n_epochs: int
+     init_lr: float
+     warmup_epochs: int
+     warmup_lr: float
+     lr_schedule_name: str
+     lr_schedule_param: dict
+     optimizer_name: str
+     optimizer_params: dict
+     weight_decay: float
+     no_wd_keys: list
+     grad_clip: float  # allow none to turn off grad clipping
+     reset_bn: bool
+     reset_bn_size: int
+     reset_bn_batch_size: int
+     eval_image_size: list  # allow none to use image_size in data_provider
+
+     @property
+     def none_allowed(self):
+         return ["grad_clip", "eval_image_size"]
+
+     def __init__(self, **kwargs):  # arguments must be passed as kwargs
+         for k, val in kwargs.items():
+             setattr(self, k, val)
+
+         # check that all relevant configs are there
+         annotations = {}
+         for clas in type(self).mro():
+             if hasattr(clas, "__annotations__"):
+                 annotations.update(clas.__annotations__)
+         for k, k_type in annotations.items():
+             assert hasattr(self, k), f"Key {k} with type {k_type} required for initialization."
+             attr = getattr(self, k)
+             if k in self.none_allowed:
+                 k_type = (k_type, type(None))
+             assert isinstance(attr, k_type), f"Key {k} must be type {k_type}, provided={attr}."
+
+         self.global_step = 0
+         self.batch_per_epoch = 1
+
+     def build_optimizer(self, network: nn.Module) -> tuple[any, any]:
+         r"""require setting 'batch_per_epoch' before building optimizer & lr_scheduler"""
+         param_dict = {}
+         for name, param in network.named_parameters():
+             if param.requires_grad:
+                 opt_config = [self.weight_decay, self.init_lr]
+                 if self.no_wd_keys is not None and len(self.no_wd_keys) > 0:
+                     if np.any([key in name for key in self.no_wd_keys]):
+                         opt_config[0] = 0
+                 opt_key = json.dumps(opt_config)
+                 param_dict[opt_key] = param_dict.get(opt_key, []) + [param]
+
+         net_params = []
+         for opt_key, param_list in param_dict.items():
+             wd, lr = json.loads(opt_key)
+             net_params.append({"params": param_list, "weight_decay": wd, "lr": lr})
+
+         optimizer = build_optimizer(net_params, self.optimizer_name, self.optimizer_params, self.init_lr)
+         # build lr scheduler
+         if self.lr_schedule_name == "cosine":
+             decay_steps = []
+             for epoch in self.lr_schedule_param.get("step", []):
+                 decay_steps.append(epoch * self.batch_per_epoch)
+             decay_steps.append(self.n_epochs * self.batch_per_epoch)
+             decay_steps.sort()
+             lr_scheduler = CosineLRwithWarmup(
+                 optimizer,
+                 self.warmup_epochs * self.batch_per_epoch,
+                 self.warmup_lr,
+                 decay_steps,
+             )
+         else:
+             raise NotImplementedError
+         return optimizer, lr_scheduler
+
+     def update_global_step(self, epoch, batch_id=0) -> None:
+         self.global_step = epoch * self.batch_per_epoch + batch_id
+         Scheduler.PROGRESS = self.progress
+
+     @property
+     def progress(self) -> float:
+         warmup_steps = self.warmup_epochs * self.batch_per_epoch
+         steps = max(0, self.global_step - warmup_steps)
+         return steps / (self.n_epochs * self.batch_per_epoch)
+
+     def step(self) -> None:
+         self.global_step += 1
+         Scheduler.PROGRESS = self.progress
+
+     def get_remaining_epoch(self, epoch, post=True) -> int:
+         return self.n_epochs + self.warmup_epochs - epoch - int(post)
+
+     def epoch_format(self, epoch: int) -> str:
+         epoch_format = f"%.{len(str(self.n_epochs))}d"
+         epoch_format = f"[{epoch_format}/{epoch_format}]"
+         epoch_format = epoch_format % (epoch + 1 - self.warmup_epochs, self.n_epochs)
+         return epoch_format
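Usage note: since `__init__` validates every annotated field against its type hint, a `RunConfig` must be built with the full kwargs set; only `grad_clip` and `eval_image_size` may be `None`. Illustrative values (not a recommended recipe):

```python
from efficientvit.apps.trainer import RunConfig

run_config = RunConfig(
    n_epochs=300, init_lr=2e-3, warmup_epochs=5, warmup_lr=0.0,
    lr_schedule_name="cosine", lr_schedule_param={},
    optimizer_name="adamw", optimizer_params={"betas": (0.9, 0.999)},
    weight_decay=0.05, no_wd_keys=["bias", "norm"],
    grad_clip=None, reset_bn=False, reset_bn_size=16000,
    reset_bn_batch_size=100, eval_image_size=None,
)
```

In the training scripts, `setup_run_config` (shown earlier in this diff) derives `init_lr` as `base_lr * world_size` before this constructor runs.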
efficientvit/apps/utils/__init__.py ADDED
@@ -0,0 +1,12 @@
+ # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
+ # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
+ # International Conference on Computer Vision (ICCV), 2023
+
+ from .dist import *
+ from .ema import *
+ from .export import *
+ from .init import *
+ from .lr import *
+ from .metric import *
+ from .misc import *
+ from .opt import *
efficientvit/apps/utils/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (319 Bytes). View file
 
efficientvit/apps/utils/__pycache__/dist.cpython-310.pyc ADDED
Binary file (2.13 kB). View file
 
efficientvit/apps/utils/__pycache__/ema.cpython-310.pyc ADDED
Binary file (1.92 kB). View file
 
efficientvit/apps/utils/__pycache__/export.cpython-310.pyc ADDED
Binary file (1.35 kB). View file
 
efficientvit/apps/utils/__pycache__/init.cpython-310.pyc ADDED
Binary file (2.01 kB). View file
 
efficientvit/apps/utils/__pycache__/lr.cpython-310.pyc ADDED
Binary file (1.75 kB). View file
 
efficientvit/apps/utils/__pycache__/metric.cpython-310.pyc ADDED
Binary file (1.61 kB). View file
 
efficientvit/apps/utils/__pycache__/misc.cpython-310.pyc ADDED
Binary file (2.73 kB). View file
 
efficientvit/apps/utils/__pycache__/opt.cpython-310.pyc ADDED
Binary file (899 Bytes). View file
 
efficientvit/apps/utils/dist.py ADDED
@@ -0,0 +1,71 @@
+ # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
+ # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
+ # International Conference on Computer Vision (ICCV), 2023
+
+ import os
+
+ import torch
+ import torch.distributed
+
+ from efficientvit.models.utils.list import list_mean, list_sum
+
+ __all__ = [
+     "dist_init",
+     "get_dist_rank",
+     "get_dist_size",
+     "is_master",
+     "dist_barrier",
+     "get_dist_local_rank",
+     "sync_tensor",
+ ]
+
+
+ def dist_init() -> None:
+     try:
+         torch.distributed.init_process_group(backend="nccl")
+         assert torch.distributed.is_initialized()
+     except Exception:
+         # use torchpack
+         from torchpack import distributed as dist
+
+         dist.init()
+         os.environ["RANK"] = f"{dist.rank()}"
+         os.environ["WORLD_SIZE"] = f"{dist.size()}"
+         os.environ["LOCAL_RANK"] = f"{dist.local_rank()}"
+
+
+ def get_dist_rank() -> int:
+     return int(os.environ["RANK"])
+
+
+ def get_dist_size() -> int:
+     return int(os.environ["WORLD_SIZE"])
+
+
+ def is_master() -> bool:
+     return get_dist_rank() == 0
+
+
+ def dist_barrier() -> None:
+     torch.distributed.barrier()
+
+
+ def get_dist_local_rank() -> int:
+     return int(os.environ["LOCAL_RANK"])
+
+
+ def sync_tensor(tensor: torch.Tensor or float, reduce="mean") -> torch.Tensor or list[torch.Tensor]:
+     if not isinstance(tensor, torch.Tensor):
+         tensor = torch.Tensor(1).fill_(tensor).cuda()
+     tensor_list = [torch.empty_like(tensor) for _ in range(get_dist_size())]
+     torch.distributed.all_gather(tensor_list, tensor.contiguous(), async_op=False)
+     if reduce == "mean":
+         return list_mean(tensor_list)
+     elif reduce == "sum":
+         return list_sum(tensor_list)
+     elif reduce == "cat":
+         return torch.cat(tensor_list, dim=0)
+     elif reduce == "root":
+         return tensor_list[0]
+     else:
+         return tensor_list
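Usage note: a sketch of `sync_tensor` under an initialized NCCL process group (e.g. launched with `torchrun`); scalars are wrapped into a one-element CUDA tensor before the all-gather, so every rank must be bound to a GPU.

```python
from efficientvit.apps.utils import dist_init, is_master, sync_tensor

dist_init()
loss = 0.42  # per-rank scalar (illustrative)
mean_loss = sync_tensor(loss, reduce="mean")  # tensor averaged across ranks
if is_master():
    print(mean_loss.item())
```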
efficientvit/apps/utils/ema.py ADDED
@@ -0,0 +1,42 @@
+ # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
+ # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
+ # International Conference on Computer Vision (ICCV), 2023
+
+ import copy
+ import math
+
+ import torch
+ import torch.nn as nn
+
+ from efficientvit.models.utils import is_parallel
+
+ __all__ = ["EMA"]
+
+
+ def update_ema(ema: nn.Module, new_state_dict: dict[str, torch.Tensor], decay: float) -> None:
+     for k, v in ema.state_dict().items():
+         if v.dtype.is_floating_point:
+             v -= (1.0 - decay) * (v - new_state_dict[k].detach())
+
+
+ class EMA:
+     def __init__(self, model: nn.Module, decay: float, warmup_steps=2000):
+         self.shadows = copy.deepcopy(model.module if is_parallel(model) else model).eval()
+         self.decay = decay
+         self.warmup_steps = warmup_steps
+
+         for p in self.shadows.parameters():
+             p.requires_grad = False
+
+     def step(self, model: nn.Module, global_step: int) -> None:
+         with torch.no_grad():
+             msd = (model.module if is_parallel(model) else model).state_dict()
+             update_ema(self.shadows, msd, self.decay * (1 - math.exp(-global_step / self.warmup_steps)))
+
+     def state_dict(self) -> dict[float, dict[str, torch.Tensor]]:
+         return {self.decay: self.shadows.state_dict()}
+
+     def load_state_dict(self, state_dict: dict[float, dict[str, torch.Tensor]]) -> None:
+         for decay in state_dict:
+             if decay == self.decay:
+                 self.shadows.load_state_dict(state_dict[decay])
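Usage note: the effective decay ramps from 0 toward `decay` as `global_step` grows (reaching about 63% of it at `warmup_steps`), so early EMA updates track the live weights closely. A minimal sketch:

```python
import torch.nn as nn

from efficientvit.apps.utils import EMA

model = nn.Linear(8, 2)
ema = EMA(model, decay=0.9998)
for step in range(1, 11):
    # ... optimizer update would happen here ...
    ema.step(model, step)
eval_model = ema.shadows  # frozen shadow copy used for evaluation
```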
efficientvit/apps/utils/export.py ADDED
@@ -0,0 +1,45 @@
+ # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
+ # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
+ # International Conference on Computer Vision (ICCV), 2023
+
+ import io
+ import os
+
+ import onnx
+ import torch
+ import torch.nn as nn
+ from onnxsim import simplify as simplify_func
+
+ __all__ = ["export_onnx"]
+
+
+ def export_onnx(model: nn.Module, export_path: str, sample_inputs: any, simplify=True, opset=11) -> None:
+     """Export a model to a platform-specific onnx format.
+
+     Args:
+         model: a torch.nn.Module object.
+         export_path: export location.
+         sample_inputs: Any.
+         simplify: a flag to turn on onnx-simplifier
+         opset: int
+     """
+     model.eval()
+
+     buffer = io.BytesIO()
+     with torch.no_grad():
+         torch.onnx.export(model, sample_inputs, buffer, opset_version=opset)
+         buffer.seek(0, 0)
+         if simplify:
+             onnx_model = onnx.load_model(buffer)
+             onnx_model, success = simplify_func(onnx_model)
+             assert success
+             new_buffer = io.BytesIO()
+             onnx.save(onnx_model, new_buffer)
+             buffer = new_buffer
+             buffer.seek(0, 0)
+
+     if buffer.getbuffer().nbytes > 0:
+         save_dir = os.path.dirname(export_path)
+         os.makedirs(save_dir, exist_ok=True)
+         with open(export_path, "wb") as f:
+             f.write(buffer.read())
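Usage note: an illustrative export call; `onnx` and `onnxsim` must be installed, and `sample_inputs` only needs to match the model's forward signature (the output path here is arbitrary).

```python
import torch

from efficientvit.apps.utils import export_onnx
from efficientvit.cls_model_zoo import create_cls_model

model = create_cls_model("b1-r224", pretrained=False)
export_onnx(model, "export/b1-r224.onnx", torch.randn(1, 3, 224, 224), simplify=True, opset=11)
```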
efficientvit/apps/utils/init.py ADDED
@@ -0,0 +1,66 @@
+ # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
+ # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
+ # International Conference on Computer Vision (ICCV), 2023
+
+ import torch
+ import torch.nn as nn
+ from torch.nn.modules.batchnorm import _BatchNorm
+
+ __all__ = ["init_modules", "zero_last_gamma"]
+
+
+ def init_modules(model: nn.Module or list[nn.Module], init_type="trunc_normal") -> None:
+     _DEFAULT_INIT_PARAM = {"trunc_normal": 0.02}
+
+     if isinstance(model, list):
+         for sub_module in model:
+             init_modules(sub_module, init_type)
+     else:
+         init_params = init_type.split("@")
+         init_params = float(init_params[1]) if len(init_params) > 1 else None
+
+         if init_type.startswith("trunc_normal"):
+             init_func = lambda param: nn.init.trunc_normal_(
+                 param, std=(init_params or _DEFAULT_INIT_PARAM["trunc_normal"])
+             )
+         else:
+             raise NotImplementedError
+
+         for m in model.modules():
+             if isinstance(m, (nn.Conv2d, nn.Linear, nn.ConvTranspose2d)):
+                 init_func(m.weight)
+                 if m.bias is not None:
+                     m.bias.data.zero_()
+             elif isinstance(m, nn.Embedding):
+                 init_func(m.weight)
+             elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)):
+                 m.weight.data.fill_(1)
+                 m.bias.data.zero_()
+             else:
+                 weight = getattr(m, "weight", None)
+                 bias = getattr(m, "bias", None)
+                 if isinstance(weight, torch.nn.Parameter):
+                     init_func(weight)
+                 if isinstance(bias, torch.nn.Parameter):
+                     bias.data.zero_()
+
+
+ def zero_last_gamma(model: nn.Module, init_val=0) -> None:
+     import efficientvit.models.nn.ops as ops
+
+     for m in model.modules():
+         if isinstance(m, ops.ResidualBlock) and isinstance(m.shortcut, ops.IdentityLayer):
+             if isinstance(m.main, (ops.DSConv, ops.MBConv, ops.FusedMBConv)):
+                 parent_module = m.main.point_conv
+             elif isinstance(m.main, ops.ResBlock):
+                 parent_module = m.main.conv2
+             elif isinstance(m.main, ops.ConvLayer):
+                 parent_module = m.main
+             elif isinstance(m.main, (ops.LiteMLA)):
+                 parent_module = m.main.proj
+             else:
+                 parent_module = None
+             if parent_module is not None:
+                 norm = getattr(parent_module, "norm", None)
+                 if norm is not None:
+                     nn.init.constant_(norm.weight, init_val)
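Usage note: the `@` suffix in `init_type` overrides the default std of 0.02, e.g.:

```python
import torch.nn as nn

from efficientvit.apps.utils import init_modules

head = nn.Linear(192, 1000)
init_modules(head, init_type="trunc_normal@0.01")  # truncated normal, std=0.01
```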
efficientvit/apps/utils/lr.py ADDED
@@ -0,0 +1,44 @@
+ # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
+ # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
+ # International Conference on Computer Vision (ICCV), 2023
+
+ import math
+
+ import torch
+
+ from efficientvit.models.utils.list import val2list
+
+ __all__ = ["CosineLRwithWarmup"]
+
+
+ class CosineLRwithWarmup(torch.optim.lr_scheduler._LRScheduler):
+     def __init__(
+         self,
+         optimizer: torch.optim.Optimizer,
+         warmup_steps: int,
+         warmup_lr: float,
+         decay_steps: int or list[int],
+         last_epoch: int = -1,
+     ) -> None:
+         self.warmup_steps = warmup_steps
+         self.warmup_lr = warmup_lr
+         self.decay_steps = val2list(decay_steps)
+         super().__init__(optimizer, last_epoch)
+
+     def get_lr(self) -> list[float]:
+         if self.last_epoch < self.warmup_steps:
+             return [
+                 (base_lr - self.warmup_lr) * (self.last_epoch + 1) / self.warmup_steps + self.warmup_lr
+                 for base_lr in self.base_lrs
+             ]
+         else:
+             current_steps = self.last_epoch - self.warmup_steps
+             decay_steps = [0] + self.decay_steps
+             idx = len(decay_steps) - 2
+             for i, decay_step in enumerate(decay_steps[:-1]):
+                 if decay_step <= current_steps < decay_steps[i + 1]:
+                     idx = i
+                     break
+             current_steps -= decay_steps[idx]
+             decay_step = decay_steps[idx + 1] - decay_steps[idx]
+             return [0.5 * base_lr * (1 + math.cos(math.pi * current_steps / decay_step)) for base_lr in self.base_lrs]
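Usage note: `decay_steps` holds cumulative segment end points (counted after warmup), so a single entry yields one cosine decay following the warmup ramp. An illustrative schedule:

```python
import torch

from efficientvit.apps.utils import CosineLRwithWarmup

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=0.1)
# 500 linear-warmup steps from 1e-5 up to 0.1, then one cosine segment of 10000 steps
scheduler = CosineLRwithWarmup(optimizer, warmup_steps=500, warmup_lr=1e-5, decay_steps=[10000])
for _ in range(1000):
    optimizer.step()
    scheduler.step()
```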
efficientvit/apps/utils/metric.py ADDED
@@ -0,0 +1,33 @@
+ # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
+ # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
+ # International Conference on Computer Vision (ICCV), 2023
+
+ import torch
+
+ from efficientvit.apps.utils.dist import sync_tensor
+
+ __all__ = ["AverageMeter"]
+
+
+ class AverageMeter:
+     """Computes and stores the average and current value."""
+
+     def __init__(self, is_distributed=True):
+         self.is_distributed = is_distributed
+         self.sum = 0
+         self.count = 0
+
+     def _sync(self, val: torch.Tensor or int or float) -> torch.Tensor or int or float:
+         return sync_tensor(val, reduce="sum") if self.is_distributed else val
+
+     def update(self, val: torch.Tensor or int or float, delta_n=1):
+         self.count += self._sync(delta_n)
+         self.sum += self._sync(val * delta_n)
+
+     def get_count(self) -> torch.Tensor or int or float:
+         return self.count.item() if isinstance(self.count, torch.Tensor) and self.count.numel() == 1 else self.count
+
+     @property
+     def avg(self):
+         avg = -1 if self.count == 0 else self.sum / self.count
+         return avg.item() if isinstance(avg, torch.Tensor) and avg.numel() == 1 else avg
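Usage note: with `is_distributed=True` every `update` is all-reduced (sum) across ranks; single-process use just accumulates locally:

```python
from efficientvit.apps.utils import AverageMeter

top1 = AverageMeter(is_distributed=False)
top1.update(0.75, delta_n=128)  # (batch metric, batch size)
top1.update(0.81, delta_n=128)
print(f"running top1: {top1.avg:.4f}")  # -> 0.7800
```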
efficientvit/apps/utils/misc.py ADDED
@@ -0,0 +1,101 @@
+ # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
+ # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
+ # International Conference on Computer Vision (ICCV), 2023
+
+ import os
+
+ import yaml
+
+ __all__ = [
+     "parse_with_yaml",
+     "parse_unknown_args",
+     "partial_update_config",
+     "resolve_and_load_config",
+     "load_config",
+     "dump_config",
+ ]
+
+
+ def parse_with_yaml(config_str: str) -> str or dict:
+     try:
+         # add space manually for dict
+         if "{" in config_str and "}" in config_str and ":" in config_str:
+             out_str = config_str.replace(":", ": ")
+         else:
+             out_str = config_str
+         return yaml.safe_load(out_str)
+     except ValueError:
+         # return raw string if parsing fails
+         return config_str
+
+
+ def parse_unknown_args(unknown: list) -> dict:
+     """Parse unknown args."""
+     index = 0
+     parsed_dict = {}
+     while index < len(unknown):
+         key, val = unknown[index], unknown[index + 1]
+         index += 2
+         if not key.startswith("--"):
+             continue
+         key = key[2:]
+
+         # try parsing with either dot notation or full yaml notation
+         # Note that the vanilla case "--key value" will be parsed the same
+         if "." in key:
+             # key == a.b.c, val == val --> parsed_dict[a][b][c] = val
+             keys = key.split(".")
+             dict_to_update = parsed_dict
+             for key in keys[:-1]:
+                 if not (key in dict_to_update and isinstance(dict_to_update[key], dict)):
+                     dict_to_update[key] = {}
+                 dict_to_update = dict_to_update[key]
+             dict_to_update[keys[-1]] = parse_with_yaml(val)  # so we can parse lists, bools, etc...
+         else:
+             parsed_dict[key] = parse_with_yaml(val)
+     return parsed_dict
+
+
+ def partial_update_config(config: dict, partial_config: dict) -> dict:
+     for key in partial_config:
+         if key in config and isinstance(partial_config[key], dict) and isinstance(config[key], dict):
+             partial_update_config(config[key], partial_config[key])
+         else:
+             config[key] = partial_config[key]
+     return config
+
+
+ def resolve_and_load_config(path: str, config_name="config.yaml") -> dict:
+     path = os.path.realpath(os.path.expanduser(path))
+     if os.path.isdir(path):
+         config_path = os.path.join(path, config_name)
+     else:
+         config_path = path
+     if os.path.isfile(config_path):
+         pass
+     else:
+         raise Exception(f"Cannot find a valid config at {path}")
+     config = load_config(config_path)
+     return config
+
+
+ class SafeLoaderWithTuple(yaml.SafeLoader):
+     """A yaml safe loader with python tuple loading capabilities."""
+
+     def construct_python_tuple(self, node):
+         return tuple(self.construct_sequence(node))
+
+
+ SafeLoaderWithTuple.add_constructor("tag:yaml.org,2002:python/tuple", SafeLoaderWithTuple.construct_python_tuple)
+
+
+ def load_config(filename: str) -> dict:
+     """Load a yaml file."""
+     filename = os.path.realpath(os.path.expanduser(filename))
+     return yaml.load(open(filename), Loader=SafeLoaderWithTuple)
+
+
+ def dump_config(config: dict, filename: str) -> None:
+     """Dump a config file"""
+     filename = os.path.realpath(os.path.expanduser(filename))
+     yaml.dump(config, open(filename, "w"), sort_keys=False)
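Usage note: dotted keys expand into nested dicts and values pass through YAML, so numbers and lists come back typed. Illustrative:

```python
from efficientvit.apps.utils import parse_unknown_args

args = ["--run_config.base_lr", "0.001", "--data_provider.image_size", "[224,288]"]
print(parse_unknown_args(args))
# {'run_config': {'base_lr': 0.001}, 'data_provider': {'image_size': [224, 288]}}
```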
efficientvit/apps/utils/opt.py ADDED
@@ -0,0 +1,28 @@
+ # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
+ # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
+ # International Conference on Computer Vision (ICCV), 2023
+
+ import torch
+
+ __all__ = ["REGISTERED_OPTIMIZER_DICT", "build_optimizer"]
+
+ # register optimizer here
+ # name: optimizer, kwargs with default values
+ REGISTERED_OPTIMIZER_DICT: dict[str, tuple[type, dict[str, any]]] = {
+     "sgd": (torch.optim.SGD, {"momentum": 0.9, "nesterov": True}),
+     "adam": (torch.optim.Adam, {"betas": (0.9, 0.999), "eps": 1e-8, "amsgrad": False}),
+     "adamw": (torch.optim.AdamW, {"betas": (0.9, 0.999), "eps": 1e-8, "amsgrad": False}),
+ }
+
+
+ def build_optimizer(
+     net_params, optimizer_name: str, optimizer_params: dict or None, init_lr: float
+ ) -> torch.optim.Optimizer:
+     optimizer_class, default_params = REGISTERED_OPTIMIZER_DICT[optimizer_name]
+     optimizer_params = optimizer_params or {}
+
+     for key in default_params:
+         if key in optimizer_params:
+             default_params[key] = optimizer_params[key]
+     optimizer = optimizer_class(net_params, init_lr, **default_params)
+     return optimizer
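Usage note: names resolve through the registry and per-key overrides replace the registered defaults. One design detail worth knowing: overrides are written back into the shared defaults dict, so they persist for later calls in the same process.

```python
import torch

from efficientvit.apps.utils import build_optimizer

params = [torch.nn.Parameter(torch.zeros(4))]
optimizer = build_optimizer(params, "adamw", {"eps": 1e-6}, init_lr=2e-3)
```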
efficientvit/cls_model_zoo.py ADDED
@@ -0,0 +1,79 @@
+ # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
+ # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
+ # International Conference on Computer Vision (ICCV), 2023
+
+ from efficientvit.models.efficientvit import (
+     EfficientViTCls,
+     efficientvit_cls_b0,
+     efficientvit_cls_b1,
+     efficientvit_cls_b2,
+     efficientvit_cls_b3,
+     efficientvit_cls_l1,
+     efficientvit_cls_l2,
+     efficientvit_cls_l3,
+ )
+ from efficientvit.models.nn.norm import set_norm_eps
+ from efficientvit.models.utils import load_state_dict_from_file
+
+ __all__ = ["create_cls_model"]
+
+
+ REGISTERED_CLS_MODEL: dict[str, str] = {
+     "b0-r224": "assets/checkpoints/cls/b0-r224.pt",
+     ###############################################
+     "b1-r224": "assets/checkpoints/cls/b1-r224.pt",
+     "b1-r256": "assets/checkpoints/cls/b1-r256.pt",
+     "b1-r288": "assets/checkpoints/cls/b1-r288.pt",
+     ###############################################
+     "b2-r224": "assets/checkpoints/cls/b2-r224.pt",
+     "b2-r256": "assets/checkpoints/cls/b2-r256.pt",
+     "b2-r288": "assets/checkpoints/cls/b2-r288.pt",
+     ###############################################
+     "b3-r224": "assets/checkpoints/cls/b3-r224.pt",
+     "b3-r256": "assets/checkpoints/cls/b3-r256.pt",
+     "b3-r288": "assets/checkpoints/cls/b3-r288.pt",
+     ###############################################
+     "l1-r224": "assets/checkpoints/cls/l1-r224.pt",
+     ###############################################
+     "l2-r224": "assets/checkpoints/cls/l2-r224.pt",
+     "l2-r256": "assets/checkpoints/cls/l2-r256.pt",
+     "l2-r288": "assets/checkpoints/cls/l2-r288.pt",
+     "l2-r320": "assets/checkpoints/cls/l2-r320.pt",
+     "l2-r384": "assets/checkpoints/cls/l2-r384.pt",
+     ###############################################
+     "l3-r224": "assets/checkpoints/cls/l3-r224.pt",
+     "l3-r256": "assets/checkpoints/cls/l3-r256.pt",
+     "l3-r288": "assets/checkpoints/cls/l3-r288.pt",
+     "l3-r320": "assets/checkpoints/cls/l3-r320.pt",
+     "l3-r384": "assets/checkpoints/cls/l3-r384.pt",
+ }
+
+
+ def create_cls_model(name: str, pretrained=True, weight_url: str or None = None, **kwargs) -> EfficientViTCls:
+     model_dict = {
+         "b0": efficientvit_cls_b0,
+         "b1": efficientvit_cls_b1,
+         "b2": efficientvit_cls_b2,
+         "b3": efficientvit_cls_b3,
+         #########################
+         "l1": efficientvit_cls_l1,
+         "l2": efficientvit_cls_l2,
+         "l3": efficientvit_cls_l3,
+     }
+
+     model_id = name.split("-")[0]
+     if model_id not in model_dict:
+         raise ValueError(f"Do not find {name} in the model zoo. List of models: {list(model_dict.keys())}")
+     else:
+         model = model_dict[model_id](**kwargs)
+         if model_id in ["l1", "l2", "l3"]:
+             set_norm_eps(model, 1e-7)
+
+     if pretrained:
+         weight_url = weight_url or REGISTERED_CLS_MODEL.get(name, None)
+         if weight_url is None:
+             raise ValueError(f"Do not find the pretrained weight of {name}.")
+         else:
+             weight = load_state_dict_from_file(weight_url)
+             model.load_state_dict(weight)
+     return model
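Usage note: the variant prefix (before `-r`) selects the architecture and the full name keys the checkpoint table; the registered paths resolve relative to the repo root.

```python
from efficientvit.cls_model_zoo import create_cls_model

model = create_cls_model("l2-r256", pretrained=True)  # loads assets/checkpoints/cls/l2-r256.pt
model.eval()
```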
efficientvit/clscore/__init__.py ADDED
File without changes
efficientvit/clscore/data_provider/__init__.py ADDED
@@ -0,0 +1,5 @@
+ # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
+ # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
+ # International Conference on Computer Vision (ICCV), 2023
+
+ from .imagenet import *
efficientvit/clscore/data_provider/imagenet.py ADDED
@@ -0,0 +1,123 @@
+ # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
+ # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
+ # International Conference on Computer Vision (ICCV), 2023
+
+ import copy
+ import math
+ import os
+
+ import torchvision.transforms as transforms
+ from torchvision.datasets import ImageFolder
+
+ from efficientvit.apps.data_provider import DataProvider
+ from efficientvit.apps.data_provider.augment import RandAug
+ from efficientvit.apps.data_provider.random_resolution import MyRandomResizedCrop, get_interpolate
+ from efficientvit.apps.utils import partial_update_config
+ from efficientvit.models.utils import val2list
+
+ __all__ = ["ImageNetDataProvider"]
+
+
+ class ImageNetDataProvider(DataProvider):
+     name = "imagenet"
+
+     data_dir = "/dataset/imagenet"
+     n_classes = 1000
+     _DEFAULT_RRC_CONFIG = {
+         "train_interpolate": "random",
+         "test_interpolate": "bicubic",
+         "test_crop_ratio": 1.0,
+     }
+
+     def __init__(
+         self,
+         data_dir: str or None = None,
+         rrc_config: dict or None = None,
+         data_aug: dict or list[dict] or None = None,
+         ###########################################
+         train_batch_size=128,
+         test_batch_size=128,
+         valid_size: int or float or None = None,
+         n_worker=8,
+         image_size: int or list[int] = 224,
+         num_replicas: int or None = None,
+         rank: int or None = None,
+         train_ratio: float or None = None,
+         drop_last: bool = False,
+     ):
+         self.data_dir = data_dir or self.data_dir
+         self.rrc_config = partial_update_config(
+             copy.deepcopy(self._DEFAULT_RRC_CONFIG),
+             rrc_config or {},
+         )
+         self.data_aug = data_aug
+
+         super().__init__(
+             train_batch_size,
+             test_batch_size,
+             valid_size,
+             n_worker,
+             image_size,
+             num_replicas,
+             rank,
+             train_ratio,
+             drop_last,
+         )
+
+     def build_valid_transform(self, image_size: tuple[int, int] or None = None) -> any:
+         image_size = (image_size or self.active_image_size)[0]
+         crop_size = int(math.ceil(image_size / self.rrc_config["test_crop_ratio"]))
+         return transforms.Compose(
+             [
+                 transforms.Resize(
+                     crop_size,
+                     interpolation=get_interpolate(self.rrc_config["test_interpolate"]),
+                 ),
+                 transforms.CenterCrop(image_size),
+                 transforms.ToTensor(),
+                 transforms.Normalize(**self.mean_std),
+             ]
+         )
+
+     def build_train_transform(self, image_size: tuple[int, int] or None = None) -> any:
+         image_size = image_size or self.image_size
+
+         # random_resize_crop -> random_horizontal_flip
+         train_transforms = [
+             MyRandomResizedCrop(interpolation=self.rrc_config["train_interpolate"]),
+             transforms.RandomHorizontalFlip(),
+         ]
+
+         # data augmentation
+         post_aug = []
+         if self.data_aug is not None:
+             for aug_op in val2list(self.data_aug):
+                 if aug_op["name"] == "randaug":
+                     data_aug = RandAug(aug_op, mean=self.mean_std["mean"])
+                 elif aug_op["name"] == "erase":
+                     from timm.data.random_erasing import RandomErasing
+
+                     random_erase = RandomErasing(aug_op["p"], device="cpu")
+                     post_aug.append(random_erase)
+                     data_aug = None
+                 else:
+                     raise NotImplementedError
+                 if data_aug is not None:
+                     train_transforms.append(data_aug)
+         train_transforms = [
+             *train_transforms,
+             transforms.ToTensor(),
+             transforms.Normalize(**self.mean_std),
+             *post_aug,
+         ]
+         return transforms.Compose(train_transforms)
+
+     def build_datasets(self) -> tuple[any, any, any]:
+         train_transform = self.build_train_transform()
+         valid_transform = self.build_valid_transform()
+
+         train_dataset = ImageFolder(os.path.join(self.data_dir, "train"), train_transform)
+         test_dataset = ImageFolder(os.path.join(self.data_dir, "val"), valid_transform)
+
+         train_dataset, val_dataset = self.sample_val_dataset(train_dataset, valid_transform)
+         return train_dataset, val_dataset, test_dataset
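Usage note: a minimal single-process sketch, assuming an ImageNet-style tree with `train/` and `val/` ImageFolder splits under `data_dir`; the `train`/`valid`/`test` loader attributes come from the `DataProvider` base class shown earlier in this diff.

```python
from efficientvit.clscore.data_provider import ImageNetDataProvider

provider = ImageNetDataProvider(
    data_dir="/dataset/imagenet",
    train_batch_size=128,
    test_batch_size=128,
    image_size=224,
)
images, labels = next(iter(provider.train))
```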