andreysher commited on
Commit
d2d52b7
·
1 Parent(s): f7f2696

Initial commit

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ deeplabv3_mobilenet_v3_large filter=lfs diff=lfs merge=lfs -text
.gitmodules ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ [submodule "vision"]
2
+ path = vision
3
+ url = https://github.com/pytorch/vision
4
+ shallow = true
common.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import torch
4
+ import torchvision
5
+ from fvcore.nn import FlopCountAnalysis
6
+ from torch import nn
7
+ from transforms import Compose
8
+
9
+ sys.path.append("vision/references/segmentation")
10
+ from coco_utils import ConvertCocoPolysToMask
11
+ from coco_utils import FilterAndRemapCocoCategories
12
+ from coco_utils import _coco_remove_images_without_annotations
13
+ from utils import ConfusionMatrix
14
+
15
+
16
+ class NanSafeConfusionMatrix(ConfusionMatrix):
17
+ """Confusion matrix with replacement nans to zeros."""
18
+
19
+ def __init__(self, num_classes):
20
+ super().__init__(num_classes=num_classes)
21
+
22
+ def compute(self):
23
+ """Compute metrics based on confusion matrix."""
24
+ confusion_matrix = self.mat.float()
25
+ acc_global = torch.nan_to_num(torch.diag(confusion_matrix).sum() / confusion_matrix.sum())
26
+ acc = torch.nan_to_num(torch.diag(confusion_matrix) / confusion_matrix.sum(1))
27
+ intersection_over_unions = torch.nan_to_num(
28
+ torch.diag(confusion_matrix)
29
+ / (confusion_matrix.sum(1) + confusion_matrix.sum(0) - torch.diag(confusion_matrix))
30
+ )
31
+ return acc_global, acc, intersection_over_unions
32
+
33
+
34
+ def flops_calculation_function(model: nn.Module, input_sample: torch.Tensor) -> float:
35
+ """Calculate number of flops in millions."""
36
+ counter = FlopCountAnalysis(
37
+ model=model.eval(),
38
+ inputs=input_sample,
39
+ )
40
+ counter.unsupported_ops_warnings(False)
41
+ counter.uncalled_modules_warnings(False)
42
+
43
+ flops = counter.total() / input_sample.shape[0]
44
+
45
+ return flops / 1e6
46
+
47
+
48
+ def get_coco(root, image_set, transforms, use_v2=False, use_orig=False):
49
+ """Get COCO dataset with VOC or COCO classes."""
50
+ paths = {
51
+ "train": ("train2017", os.path.join("annotations", "instances_train2017.json")),
52
+ "val": ("val2017", os.path.join("annotations", "instances_val2017.json")),
53
+ # "train": ("val2017", os.path.join("annotations", "instances_val2017.json"))
54
+ }
55
+ if use_orig:
56
+ classes_list = list(range(81))
57
+ else:
58
+ classes_list = [0, 5, 2, 16, 9, 44, 6, 3, 17, 62, 21, 67, 18, 19, 4, 1, 64, 20, 63, 7, 72]
59
+
60
+ img_folder, ann_file = paths[image_set]
61
+ img_folder = os.path.join(root, img_folder)
62
+ ann_file = os.path.join(root, ann_file)
63
+
64
+ # The 2 "Compose" below achieve the same thing: converting coco detection
65
+ # samples into segmentation-compatible samples. They just do it with
66
+ # slightly different implementations. We could refactor and unify, but
67
+ # keeping them separate helps keeping the v2 version clean
68
+ if use_v2:
69
+ import v2_extras # pylint: disable=import-outside-toplevel
70
+ from torchvision.datasets import wrap_dataset_for_transforms_v2 # pylint: disable=import-outside-toplevel
71
+
72
+ transforms = Compose([v2_extras.CocoDetectionToVOCSegmentation(), transforms])
73
+ dataset = torchvision.datasets.CocoDetection(img_folder, ann_file, transforms=transforms)
74
+ dataset = wrap_dataset_for_transforms_v2(dataset, target_keys={"masks", "labels"})
75
+ else:
76
+ transforms = Compose(
77
+ [FilterAndRemapCocoCategories(classes_list, remap=True), ConvertCocoPolysToMask(), transforms]
78
+ )
79
+ dataset = torchvision.datasets.CocoDetection(img_folder, ann_file, transforms=transforms)
80
+
81
+ if image_set == "train":
82
+ dataset = _coco_remove_images_without_annotations(dataset, classes_list)
83
+
84
+ return dataset
deeplabv3_mobilenet_v3_large/deeplabv3_mobilenet_v3_large.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:366cfdd55f38a53aefe374c7f529cd05af2e4ba2c90848c202976376ff5e8c09
3
+ size 88767468
deeplabv3_mobilenet_v3_large/deeplabv3_mobilenet_v3_large_x2.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bb61548f8b66ead5a95b55ff41fa7db201fbf8340fab916f91fdac151f61d30e
3
+ size 48772992
deeplabv3_mobilenet_v3_large/deeplabv3_mobilenet_v3_large_x4.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:173a1084b1ac46643d5fb1a0f8c91a73b4ba790c25d9e2130e7b050cd23c9b22
3
+ size 27865280
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ torch==2.3.1
2
+ torchvision==0.18.1
test.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ from functools import partial
3
+
4
+ from typing import Callable
5
+ from typing import Dict
6
+ from typing import Tuple
7
+ from typing import Union
8
+ from argparse import Namespace
9
+
10
+ sys.path.append("vision/references/segmentation")
11
+
12
+ import presets
13
+ import torch
14
+ import torch.utils.data
15
+ import torchvision
16
+ import utils
17
+ from torch import nn
18
+ from common import flops_calculation_function
19
+ from common import NanSafeConfusionMatrix as ConfusionMatrix
20
+ from common import get_coco
21
+
22
+
23
+ def get_dataset(args: Namespace, is_train: bool, transform: Callable = None) -> Tuple[torch.utils.data.Dataset, int]:
24
+ def sbd(*args, **kwargs):
25
+ kwargs.pop("use_v2")
26
+ return torchvision.datasets.SBDataset(*args, mode="segmentation", **kwargs)
27
+
28
+ def voc(*args, **kwargs):
29
+ kwargs.pop("use_v2")
30
+ return torchvision.datasets.VOCSegmentation(*args, **kwargs)
31
+
32
+ paths = {
33
+ "voc": (args.data_path, voc, 21),
34
+ "voc_aug": (args.data_path, sbd, 21),
35
+ "coco": (args.data_path, get_coco, 21),
36
+ "coco_orig": (args.data_path, partial(get_coco, use_orig=True), 81)
37
+ }
38
+ p, ds_fn, num_classes = paths["coco_orig"]
39
+
40
+ if transform is None:
41
+ transform = get_transform(is_train, args)
42
+ image_set = "train" if is_train else "val"
43
+ ds = ds_fn(p, image_set=image_set, transforms=transform, use_v2=args.use_v2)
44
+ return ds, num_classes
45
+
46
+
47
+ def get_transform(is_train: bool, args: Namespace) -> Callable:
48
+ return presets.SegmentationPresetEval(base_size=520, backend=args.backend, use_v2=args.use_v2)
49
+
50
+
51
+ def criterion(inputs: Dict[str, torch.Tensor], target: Dict[str, torch.Tensor]) -> torch.Tensor:
52
+ losses = {}
53
+ for name, x in inputs.items():
54
+ losses[name] = nn.functional.cross_entropy(x, target, ignore_index=255)
55
+
56
+ if len(losses) == 1:
57
+ return losses["out"]
58
+
59
+ return losses["out"] + 0.5 * losses["aux"]
60
+
61
+
62
+ def evaluate(
63
+ model: torch.nn.Module,
64
+ data_loader: torch.utils.data.DataLoader,
65
+ device: Union[str, torch.device],
66
+ num_classes: int,
67
+ criterion: Callable,
68
+ ) -> Tuple[ConfusionMatrix, float]:
69
+ model.eval()
70
+ confmat = ConfusionMatrix(num_classes)
71
+ metric_logger = utils.MetricLogger(delimiter=" ")
72
+ header = "Test:"
73
+ num_processed_samples = 0
74
+ with torch.inference_mode():
75
+ for batch_n, (image, target) in enumerate(metric_logger.log_every(data_loader, 100, header)):
76
+ image, target = image.to(device), target.to(device)
77
+ output = model(image)
78
+ loss = criterion(output, target)
79
+ output = output["out"]
80
+
81
+ confmat.update(target.flatten(), output.argmax(1).flatten())
82
+ # FIXME need to take into account that the datasets
83
+ # could have been padded in distributed setup
84
+ num_processed_samples += image.shape[0]
85
+
86
+ metric_logger.update(loss=loss.item())
87
+
88
+ confmat.reduce_from_all_processes()
89
+
90
+ return confmat, metric_logger.loss.global_avg
91
+
92
+
93
+ def main(args):
94
+ if args.backend.lower() != "pil" and not args.use_v2:
95
+ # TODO: Support tensor backend in V1?
96
+ raise ValueError("Use --use-v2 if you want to use the tv_tensor or tensor backend.")
97
+ if args.use_v2:
98
+ raise ValueError("v2 is only supported for coco dataset for now.")
99
+
100
+ print(args)
101
+
102
+ device = torch.device(args.device)
103
+
104
+ if args.use_deterministic_algorithms:
105
+ torch.backends.cudnn.benchmark = False
106
+ torch.use_deterministic_algorithms(True)
107
+ else:
108
+ torch.backends.cudnn.benchmark = True
109
+
110
+ dataset_test, num_classes = get_dataset(args, is_train=False)
111
+
112
+ test_sampler = torch.utils.data.SequentialSampler(dataset_test)
113
+
114
+ data_loader_test = torch.utils.data.DataLoader(
115
+ dataset_test, batch_size=1, sampler=test_sampler, num_workers=args.workers, collate_fn=utils.collate_fn
116
+ )
117
+
118
+ checkpoint = torch.load(args.model_path)
119
+ model = checkpoint["model"]
120
+ model.to(device)
121
+ model_flops = flops_calculation_function(model=model, input_sample=next(iter(data_loader_test))[0].to(device))
122
+ print(f"Model Flops: {model_flops}M")
123
+
124
+ # We disable the cudnn benchmarking because it can noticeably affect the accuracy
125
+ torch.backends.cudnn.benchmark = False
126
+ torch.backends.cudnn.deterministic = True
127
+ confmat, loss = evaluate(
128
+ model=model,
129
+ data_loader=data_loader_test,
130
+ device=device,
131
+ num_classes=num_classes,
132
+ criterion=criterion,
133
+ )
134
+ print(confmat)
135
+ return
136
+
137
+ def get_args_parser(add_help=True):
138
+ import argparse
139
+
140
+ parser = argparse.ArgumentParser(description="PyTorch Segmentation Training", add_help=add_help)
141
+
142
+ parser.add_argument("--data-path", default="/datasets01/COCO/022719/", type=str, help="dataset path")
143
+ parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu Default: cuda)")
144
+ parser.add_argument(
145
+ "-b", "--batch-size", default=8, type=int, help="images per gpu, the total batch size is $NGPU x batch_size"
146
+ )
147
+ parser.add_argument("--epochs", default=30, type=int, metavar="N", help="number of total epochs to run")
148
+
149
+ parser.add_argument(
150
+ "-j", "--workers", default=16, type=int, metavar="N", help="number of data loading workers (default: 16)"
151
+ )
152
+ parser.add_argument(
153
+ "--use-deterministic-algorithms", action="store_true", help="Forces the use of deterministic algorithms only."
154
+ )
155
+ # distributed training parameters
156
+
157
+ parser.add_argument("--backend", default="PIL", type=str.lower, help="PIL or tensor - case insensitive")
158
+ parser.add_argument("--use-v2", action="store_true", help="Use V2 transforms")
159
+ parser.add_argument("--model-path", default=None, help="Path to model checkpoint.")
160
+ return parser
161
+
162
+
163
+ if __name__ == "__main__":
164
+ args = get_args_parser().parse_args()
165
+ main(args)
vision ADDED
@@ -0,0 +1 @@
 
 
1
+ Subproject commit 126fc22ce33e6c2426edcf9ed540810c178fe9ce