Haobo Yuan commited on
Commit
b34d1d6
1 Parent(s): 3a3cc44

add omg code

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +5 -4
  2. .gitignore +125 -0
  3. app/configs/m2_convl.py +152 -0
  4. assets/000000000139.jpg +3 -0
  5. assets/000000000285.jpg +3 -0
  6. assets/000000000632.jpg +3 -0
  7. assets/000000000724.jpg +3 -0
  8. ext/cityscapes_scripts/createPanopticImgs.py +194 -0
  9. ext/cityscapes_scripts/helpers/__init__.py +1 -0
  10. ext/cityscapes_scripts/helpers/annotation.py +441 -0
  11. ext/cityscapes_scripts/helpers/csHelpers.py +129 -0
  12. ext/cityscapes_scripts/helpers/labels.py +182 -0
  13. ext/cityscapes_scripts/helpers/labels_cityPersons.py +61 -0
  14. ext/cityscapes_scripts/helpers/version.py +9 -0
  15. ext/class_names/VIPSeg.py +261 -0
  16. ext/davis2017/__init__.py +3 -0
  17. ext/davis2017/davis.py +122 -0
  18. ext/davis2017/evaluation.py +110 -0
  19. ext/davis2017/metrics.py +197 -0
  20. ext/davis2017/results.py +52 -0
  21. ext/davis2017/utils.py +174 -0
  22. ext/meta/sam_meta.py +41 -0
  23. ext/open_clip/__init__.py +15 -0
  24. ext/open_clip/bpe_simple_vocab_16e6.txt.gz +3 -0
  25. ext/open_clip/coca_model.py +458 -0
  26. ext/open_clip/constants.py +2 -0
  27. ext/open_clip/factory.py +387 -0
  28. ext/open_clip/generation_utils.py +0 -0
  29. ext/open_clip/hf_configs.py +56 -0
  30. ext/open_clip/hf_model.py +193 -0
  31. ext/open_clip/loss.py +216 -0
  32. ext/open_clip/model.py +473 -0
  33. ext/open_clip/model_configs/EVA01-g-14-plus.json +18 -0
  34. ext/open_clip/model_configs/EVA01-g-14.json +18 -0
  35. ext/open_clip/model_configs/EVA02-B-16.json +18 -0
  36. ext/open_clip/model_configs/EVA02-E-14-plus.json +18 -0
  37. ext/open_clip/model_configs/EVA02-E-14.json +18 -0
  38. ext/open_clip/model_configs/EVA02-L-14-336.json +18 -0
  39. ext/open_clip/model_configs/EVA02-L-14.json +18 -0
  40. ext/open_clip/model_configs/RN101-quickgelu.json +22 -0
  41. ext/open_clip/model_configs/RN101.json +21 -0
  42. ext/open_clip/model_configs/RN50-quickgelu.json +22 -0
  43. ext/open_clip/model_configs/RN50.json +21 -0
  44. ext/open_clip/model_configs/RN50x16.json +21 -0
  45. ext/open_clip/model_configs/RN50x4.json +21 -0
  46. ext/open_clip/model_configs/RN50x64.json +21 -0
  47. ext/open_clip/model_configs/ViT-B-16-plus-240.json +16 -0
  48. ext/open_clip/model_configs/ViT-B-16-plus.json +16 -0
  49. ext/open_clip/model_configs/ViT-B-16.json +16 -0
  50. ext/open_clip/model_configs/ViT-B-32-plus-256.json +16 -0
.gitattributes CHANGED
@@ -17,10 +17,6 @@
17
  *.ot filter=lfs diff=lfs merge=lfs -text
18
  *.parquet filter=lfs diff=lfs merge=lfs -text
19
  *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
  *.rar filter=lfs diff=lfs merge=lfs -text
25
  *.safetensors filter=lfs diff=lfs merge=lfs -text
26
  saved_model/**/* filter=lfs diff=lfs merge=lfs -text
@@ -33,3 +29,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
17
  *.ot filter=lfs diff=lfs merge=lfs -text
18
  *.parquet filter=lfs diff=lfs merge=lfs -text
19
  *.pb filter=lfs diff=lfs merge=lfs -text
 
 
 
 
20
  *.rar filter=lfs diff=lfs merge=lfs -text
21
  *.safetensors filter=lfs diff=lfs merge=lfs -text
22
  saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 
29
  *.zip filter=lfs diff=lfs merge=lfs -text
30
  *.zst filter=lfs diff=lfs merge=lfs -text
31
  *tfevents* filter=lfs diff=lfs merge=lfs -text
32
+ *.pickle filter=lfs diff=lfs merge=lfs -text
33
+ *.pkl filter=lfs diff=lfs merge=lfs -text
34
+ *.pt filter=lfs diff=lfs merge=lfs -text
35
+ *.pth filter=lfs diff=lfs merge=lfs -text
36
+ *.jpg filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ *.egg-info/
24
+ .installed.cfg
25
+ *.egg
26
+ MANIFEST
27
+
28
+ # PyInstaller
29
+ # Usually these files are written by a python script from a template
30
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
31
+ *.manifest
32
+ *.spec
33
+
34
+ # Installer logs
35
+ pip-log.txt
36
+ pip-delete-this-directory.txt
37
+
38
+ # Unit test / coverage reports
39
+ htmlcov/
40
+ .tox/
41
+ .coverage
42
+ .coverage.*
43
+ .cache
44
+ nosetests.xml
45
+ coverage.xml
46
+ *.cover
47
+ .hypothesis/
48
+ .pytest_cache/
49
+
50
+ # Translations
51
+ *.mo
52
+ *.pot
53
+
54
+ # Django stuff:
55
+ *.log
56
+ local_settings.py
57
+ db.sqlite3
58
+
59
+ # Flask stuff:
60
+ instance/
61
+ .webassets-cache
62
+
63
+ # Scrapy stuff:
64
+ .scrapy
65
+
66
+ # Sphinx documentation
67
+ docs/en/_build/
68
+ docs/zh_cn/_build/
69
+
70
+ # PyBuilder
71
+ target/
72
+
73
+ # Jupyter Notebook
74
+ .ipynb_checkpoints
75
+
76
+ # pyenv
77
+ .python-version
78
+
79
+ # celery beat schedule file
80
+ celerybeat-schedule
81
+
82
+ # SageMath parsed files
83
+ *.sage.py
84
+
85
+ # Environments
86
+ .env
87
+ .venv
88
+ env/
89
+ venv/
90
+ ENV/
91
+ env.bak/
92
+ venv.bak/
93
+
94
+ # Spyder project settings
95
+ .spyderproject
96
+ .spyproject
97
+
98
+ # Rope project settings
99
+ .ropeproject
100
+
101
+ # mkdocs documentation
102
+ /site
103
+
104
+ # mypy
105
+ .mypy_cache/
106
+ data/
107
+ data
108
+ .vscode
109
+ .idea/
110
+ .DS_Store
111
+
112
+ # custom
113
+ *.pkl
114
+ *.pkl.json
115
+ *.log.json
116
+ docs/modelzoo_statistics.md
117
+ mmdet/.mim
118
+ work_dirs/
119
+
120
+ # Pytorch
121
+ *.py~
122
+ *.sh~
123
+
124
+ # remove tmp folder
125
+ tmp/
app/configs/m2_convl.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch.nn import GroupNorm, ReLU
2
+
3
+ from mmdet.models import MSDeformAttnPixelDecoder, CrossEntropyLoss, DiceLoss, FocalLoss
4
+ from mmdet.models.task_modules.assigners import HungarianAssigner, ClassificationCost, CrossEntropyLossCost, DiceCost
5
+ from mmdet.models.task_modules.samplers import MaskPseudoSampler
6
+
7
+ from seg.models.detectors import Mask2formerVideo
8
+ from seg.models.fusion_head import OMGFusionHead
9
+ from seg.models.heads import Mask2FormerVideoHead
10
+ from seg.models.backbones import OpenCLIPBackbone
11
+
12
+ num_things_classes = 80
13
+ num_stuff_classes = 53
14
+
15
+ ov_model_name = 'convnext_large_d_320'
16
+ ov_datasets_name = 'CocoPanopticOVDataset'
17
+ model = dict(
18
+ type=Mask2formerVideo,
19
+ data_preprocessor=None, # to fill
20
+ backbone=dict(
21
+ type=OpenCLIPBackbone,
22
+ model_name='convnext_large_d_320',
23
+ fix=True,
24
+ init_cfg=dict(
25
+ type='clip_pretrain',
26
+ checkpoint='laion2b_s29b_b131k_ft_soup'
27
+ )
28
+ ),
29
+ panoptic_head=dict(
30
+ init_cfg=dict(
31
+ type='Pretrained',
32
+ checkpoint='./models/omg_seg_convl.pth',
33
+ prefix='panoptic_head.'
34
+ ),
35
+ type=Mask2FormerVideoHead,
36
+ sphere_cls=True,
37
+ ov_classifier_name=f'{ov_model_name}_{ov_datasets_name}',
38
+ logit=None,
39
+ in_channels=[192, 384, 768, 1536], # pass to pixel_decoder inside
40
+ strides=[4, 8, 16, 32],
41
+ feat_channels=256,
42
+ out_channels=256,
43
+ num_things_classes=num_things_classes,
44
+ num_stuff_classes=num_stuff_classes,
45
+ num_queries=300,
46
+ num_transformer_feat_level=3,
47
+ pixel_decoder=dict(
48
+ type=MSDeformAttnPixelDecoder,
49
+ num_outs=3,
50
+ norm_cfg=dict(type=GroupNorm, num_groups=32),
51
+ act_cfg=dict(type=ReLU),
52
+ encoder=dict( # DeformableDetrTransformerEncoder
53
+ num_layers=6,
54
+ layer_cfg=dict( # DeformableDetrTransformerEncoderLayer
55
+ self_attn_cfg=dict( # MultiScaleDeformableAttention
56
+ embed_dims=256,
57
+ num_heads=8,
58
+ num_levels=3,
59
+ num_points=4,
60
+ dropout=0.0,
61
+ batch_first=True),
62
+ ffn_cfg=dict(
63
+ embed_dims=256,
64
+ feedforward_channels=1024,
65
+ num_fcs=2,
66
+ ffn_drop=0.0,
67
+ act_cfg=dict(type=ReLU, inplace=True)))),
68
+ positional_encoding=dict(num_feats=128, normalize=True)),
69
+ enforce_decoder_input_project=False,
70
+ positional_encoding=dict(num_feats=128, normalize=True),
71
+ transformer_decoder=dict( # Mask2FormerTransformerDecoder
72
+ return_intermediate=True,
73
+ num_layers=9,
74
+ layer_cfg=dict( # Mask2FormerTransformerDecoderLayer
75
+ self_attn_cfg=dict( # MultiheadAttention
76
+ embed_dims=256,
77
+ num_heads=8,
78
+ dropout=0.0,
79
+ batch_first=True),
80
+ cross_attn_cfg=dict( # MultiheadAttention
81
+ embed_dims=256,
82
+ num_heads=8,
83
+ dropout=0.0,
84
+ batch_first=True),
85
+ ffn_cfg=dict(
86
+ embed_dims=256,
87
+ feedforward_channels=2048,
88
+ num_fcs=2,
89
+ ffn_drop=0.0,
90
+ act_cfg=dict(type='ReLU', inplace=True))),
91
+ init_cfg=None),
92
+ loss_cls=dict(
93
+ type=CrossEntropyLoss,
94
+ use_sigmoid=False,
95
+ loss_weight=2.0,
96
+ reduction='mean',
97
+ class_weight=None # [1.0] * num_classes + [0.1]
98
+ ),
99
+ loss_mask=dict(
100
+ type=CrossEntropyLoss,
101
+ use_sigmoid=True,
102
+ reduction='mean',
103
+ loss_weight=5.0),
104
+ loss_dice=dict(
105
+ type=DiceLoss,
106
+ use_sigmoid=True,
107
+ activate=True,
108
+ reduction='mean',
109
+ naive_dice=True,
110
+ eps=1.0,
111
+ loss_weight=5.0),
112
+ loss_iou=dict(
113
+ type=FocalLoss,
114
+ use_sigmoid=True,
115
+ loss_weight=2.0,
116
+ reduction='mean'
117
+ )
118
+ ),
119
+ panoptic_fusion_head=dict(
120
+ type=OMGFusionHead,
121
+ num_things_classes=num_things_classes,
122
+ num_stuff_classes=num_stuff_classes,
123
+ loss_panoptic=None,
124
+ init_cfg=None
125
+ ),
126
+ train_cfg=dict(
127
+ num_points=12544,
128
+ oversample_ratio=3.0,
129
+ importance_sample_ratio=0.75,
130
+ assigner=dict(
131
+ type=HungarianAssigner,
132
+ match_costs=[
133
+ dict(type=ClassificationCost, weight=2.0),
134
+ dict(
135
+ type=CrossEntropyLossCost, weight=5.0, use_sigmoid=True),
136
+ dict(type=DiceCost, weight=5.0, pred_act=True, eps=1.0)
137
+ ]),
138
+ sampler=dict(type=MaskPseudoSampler)),
139
+ test_cfg=dict(
140
+ panoptic_on=True,
141
+ semantic_on=False,
142
+ instance_on=True,
143
+ # max_per_image is for instance segmentation.
144
+ max_per_image=100,
145
+ iou_thr=0.8,
146
+ # In Mask2Former's panoptic postprocessing,
147
+ # it will filter mask area where score is less than 0.5 .
148
+ filter_low_score=True,
149
+ object_mask_thr=0.,
150
+ ),
151
+ init_cfg=None
152
+ )
assets/000000000139.jpg ADDED

Git LFS Details

  • SHA256: ffe0f0cec3b2e27aab1967229cdf0a0d7751dcdd5800322f0b8ac0dffb3b8a8d
  • Pointer size: 131 Bytes
  • Size of remote file: 162 kB
assets/000000000285.jpg ADDED

Git LFS Details

  • SHA256: f3a2974ce3686332609124c70e3e6a2e3aca43fccf1cd1bd7c5c03820977f57d
  • Pointer size: 131 Bytes
  • Size of remote file: 336 kB
assets/000000000632.jpg ADDED

Git LFS Details

  • SHA256: a4cd7f45ac1ce27eaafb254b23af7c0b18a064be08870ceaaf03b2147f2ce550
  • Pointer size: 131 Bytes
  • Size of remote file: 156 kB
assets/000000000724.jpg ADDED

Git LFS Details

  • SHA256: 5c0e559c75d3969c8e3e297b61f61063f78045c9d4802b526ba616361f3823fd
  • Pointer size: 131 Bytes
  • Size of remote file: 130 kB
ext/cityscapes_scripts/createPanopticImgs.py ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python
2
+ #
3
+ # Converts the *instanceIds.png annotations of the Cityscapes dataset
4
+ # to COCO-style panoptic segmentation format (http://cocodataset.org/#format-data).
5
+ # The convertion is working for 'fine' set of the annotations.
6
+ #
7
+ # By default with this tool uses IDs specified in labels.py. You can use flag
8
+ # --use-train-id to get train ids for categories. 'ignoreInEval' categories are
9
+ # removed during the conversion.
10
+ #
11
+ # In panoptic segmentation format image_id is used to match predictions and ground truth.
12
+ # For cityscapes image_id has form <city>_123456_123456 and corresponds to the prefix
13
+ # of cityscapes image files.
14
+ #
15
+
16
+ # python imports
17
+ from __future__ import print_function, absolute_import, division, unicode_literals
18
+ import os
19
+ import glob
20
+ import sys
21
+ import argparse
22
+ import json
23
+ import numpy as np
24
+
25
+ # Image processing
26
+ from PIL import Image
27
+
28
+ # cityscapes imports
29
+ from ext.cityscapes_scripts.helpers.csHelpers import printError
30
+ from ext.cityscapes_scripts.helpers.labels import id2label, labels
31
+
32
+
33
+ import mmengine
34
+
35
+
36
+ # The main method
37
+ def convert2panoptic(cityscapesPath=None, outputFolder=None, useTrainId=False, setNames=["val", "train", "test"]):
38
+ # Where to look for Cityscapes
39
+ if cityscapesPath is None:
40
+ if 'CITYSCAPES_DATASET' in os.environ:
41
+ cityscapesPath = os.environ['CITYSCAPES_DATASET']
42
+ else:
43
+ cityscapesPath = 'data/cityscapes'
44
+ cityscapesPath = os.path.join(cityscapesPath, "gtFine")
45
+
46
+ if outputFolder is None:
47
+ outputFolder = cityscapesPath.replace('gtFine', "annotations")
48
+
49
+ mmengine.mkdir_or_exist(outputFolder)
50
+
51
+ categories = []
52
+ for label in labels:
53
+ if label.ignoreInEval:
54
+ continue
55
+ categories.append({'id': int(label.trainId) if useTrainId else int(label.id),
56
+ 'name': label.name,
57
+ 'color': label.color,
58
+ 'supercategory': label.category,
59
+ 'isthing': 1 if label.hasInstances else 0})
60
+
61
+ categories = sorted(categories, key=lambda x:x['id'])
62
+
63
+ for setName in setNames:
64
+ # how to search for all ground truth
65
+ searchFine = os.path.join(cityscapesPath, setName, "*", "*_instanceIds.png")
66
+ # search files
67
+ filesFine = glob.glob(searchFine)
68
+ filesFine.sort()
69
+
70
+ files = filesFine
71
+ # quit if we did not find anything
72
+ if not files:
73
+ printError(
74
+ "Did not find any files for {} set using matching pattern {}. Please consult the README.".format(setName, searchFine)
75
+ )
76
+ # a bit verbose
77
+ print("Converting {} annotation files for {} set.".format(len(files), setName))
78
+
79
+ trainIfSuffix = "_trainId" if useTrainId else ""
80
+ outputBaseFile = "cityscapes_panoptic_{}{}".format(setName, trainIfSuffix)
81
+ outFile = os.path.join(outputFolder, "{}.json".format(outputBaseFile))
82
+ print("Json file with the annotations in panoptic format will be saved in {}".format(outFile))
83
+ panopticFolder = os.path.join(outputFolder, outputBaseFile)
84
+ if not os.path.isdir(panopticFolder):
85
+ print("Creating folder {} for panoptic segmentation PNGs".format(panopticFolder))
86
+ os.mkdir(panopticFolder)
87
+ print("Corresponding segmentations in .png format will be saved in {}".format(panopticFolder))
88
+
89
+ images = []
90
+ annotations = []
91
+ for progress, f in enumerate(files):
92
+
93
+ originalFormat = np.array(Image.open(f))
94
+
95
+ fileName = os.path.basename(f)
96
+ location = fileName.split('_')[0]
97
+ imageId = fileName.replace("_gtFine_instanceIds.png", "")
98
+ fileName = os.path.join(location, fileName)
99
+ inputFileName = fileName.replace("_gtFine_instanceIds.png", "_leftImg8bit.png")
100
+ outputFileName = fileName.replace("_gtFine_instanceIds.png", "_panoptic.png")
101
+ # image entry, id for image is its filename without extension
102
+ images.append({"id": imageId,
103
+ "width": int(originalFormat.shape[1]),
104
+ "height": int(originalFormat.shape[0]),
105
+ "file_name": inputFileName})
106
+
107
+ pan_format = np.zeros(
108
+ (originalFormat.shape[0], originalFormat.shape[1], 3), dtype=np.uint8
109
+ )
110
+
111
+ segmentIds = np.unique(originalFormat)
112
+ segmInfo = []
113
+ for segmentId in segmentIds:
114
+ if segmentId < 1000:
115
+ semanticId = segmentId
116
+ isCrowd = 1
117
+ else:
118
+ semanticId = segmentId // 1000
119
+ isCrowd = 0
120
+ labelInfo = id2label[semanticId]
121
+ categoryId = labelInfo.trainId if useTrainId else labelInfo.id
122
+ if labelInfo.ignoreInEval:
123
+ continue
124
+ if not labelInfo.hasInstances:
125
+ isCrowd = 0
126
+
127
+ mask = originalFormat == segmentId
128
+ color = [segmentId % 256, segmentId // 256, segmentId // 256 // 256]
129
+ pan_format[mask] = color
130
+
131
+ area = np.sum(mask) # segment area computation
132
+
133
+ # bbox computation for a segment
134
+ hor = np.sum(mask, axis=0)
135
+ hor_idx = np.nonzero(hor)[0]
136
+ x = hor_idx[0]
137
+ width = hor_idx[-1] - x + 1
138
+ vert = np.sum(mask, axis=1)
139
+ vert_idx = np.nonzero(vert)[0]
140
+ y = vert_idx[0]
141
+ height = vert_idx[-1] - y + 1
142
+ bbox = [int(x), int(y), int(width), int(height)]
143
+
144
+ segmInfo.append({"id": int(segmentId),
145
+ "category_id": int(categoryId),
146
+ "area": int(area),
147
+ "bbox": bbox,
148
+ "iscrowd": isCrowd})
149
+
150
+ annotations.append({'image_id': imageId,
151
+ 'file_name': outputFileName,
152
+ "segments_info": segmInfo})
153
+
154
+ mmengine.mkdir_or_exist(os.path.dirname(os.path.join(panopticFolder, outputFileName)))
155
+ Image.fromarray(pan_format).save(os.path.join(panopticFolder, outputFileName))
156
+
157
+ print("\rProgress: {:>3.2f} %".format((progress + 1) * 100 / len(files)), end=' ')
158
+ sys.stdout.flush()
159
+
160
+ print("\nSaving the json file {}".format(outFile))
161
+ d = {'images': images,
162
+ 'annotations': annotations,
163
+ 'categories': categories}
164
+ with open(outFile, 'w') as f:
165
+ json.dump(d, f, sort_keys=True, indent=4)
166
+
167
+
168
+ def main():
169
+ parser = argparse.ArgumentParser()
170
+ parser.add_argument("--dataset-folder",
171
+ dest="cityscapesPath",
172
+ help="path to the Cityscapes dataset 'gtFine' folder",
173
+ default=None,
174
+ type=str)
175
+ parser.add_argument("--output-folder",
176
+ dest="outputFolder",
177
+ help="path to the output folder.",
178
+ default=None,
179
+ type=str)
180
+ parser.add_argument("--use-train-id", default=True,action="store_true", dest="useTrainId")
181
+ parser.add_argument("--set-names",
182
+ dest="setNames",
183
+ help="set names to which apply the function to",
184
+ nargs='+',
185
+ default=["val", "train"],
186
+ type=str)
187
+ args = parser.parse_args()
188
+
189
+ convert2panoptic(args.cityscapesPath, args.outputFolder, args.useTrainId, args.setNames)
190
+
191
+
192
+ # call the main
193
+ if __name__ == "__main__":
194
+ main()
ext/cityscapes_scripts/helpers/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # empty
ext/cityscapes_scripts/helpers/annotation.py ADDED
@@ -0,0 +1,441 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python
2
+ #
3
+ # Classes to store, read, and write annotations
4
+ #
5
+
6
+ from __future__ import print_function, absolute_import, division
7
+ import os
8
+ import json
9
+ import numpy as np
10
+ from collections import namedtuple
11
+
12
+ # get current date and time
13
+ import datetime
14
+ import locale
15
+
16
+ from abc import ABCMeta, abstractmethod
17
+ from .box3dImageTransform import Camera
18
+
19
+ # A point in a polygon
20
+ Point = namedtuple('Point', ['x', 'y'])
21
+
22
+
23
+ class CsObjectType():
24
+ """Type of an object"""
25
+ POLY = 1 # polygon
26
+ BBOX2D = 2 # bounding box
27
+ BBOX3D = 3 # 3d bounding box
28
+ IGNORE2D = 4 # 2d ignore region
29
+
30
+
31
+ class CsObject:
32
+ """Abstract base class for annotation objects"""
33
+ __metaclass__ = ABCMeta
34
+
35
+ def __init__(self, objType):
36
+ self.objectType = objType
37
+ # the label
38
+ self.label = ""
39
+
40
+ # If deleted or not
41
+ self.deleted = 0
42
+ # If verified or not
43
+ self.verified = 0
44
+ # The date string
45
+ self.date = ""
46
+ # The username
47
+ self.user = ""
48
+ # Draw the object
49
+ # Not read from or written to JSON
50
+ # Set to False if deleted object
51
+ # Might be set to False by the application for other reasons
52
+ self.draw = True
53
+
54
+ @abstractmethod
55
+ def __str__(self): pass
56
+
57
+ @abstractmethod
58
+ def fromJsonText(self, jsonText, objId=-1): pass
59
+
60
+ @abstractmethod
61
+ def toJsonText(self): pass
62
+
63
+ def updateDate(self):
64
+ try:
65
+ locale.setlocale(locale.LC_ALL, 'en_US.utf8')
66
+ except locale.Error:
67
+ locale.setlocale(locale.LC_ALL, 'en_US')
68
+ except locale.Error:
69
+ locale.setlocale(locale.LC_ALL, 'us_us.utf8')
70
+ except locale.Error:
71
+ locale.setlocale(locale.LC_ALL, 'us_us')
72
+ except Exception:
73
+ pass
74
+ self.date = datetime.datetime.now().strftime("%d-%b-%Y %H:%M:%S")
75
+
76
+ # Mark the object as deleted
77
+ def delete(self):
78
+ self.deleted = 1
79
+ self.draw = False
80
+
81
+
82
+ class CsPoly(CsObject):
83
+ """Class that contains the information of a single annotated object as polygon"""
84
+
85
+ # Constructor
86
+ def __init__(self):
87
+ CsObject.__init__(self, CsObjectType.POLY)
88
+ # the polygon as list of points
89
+ self.polygon = []
90
+ # the object ID
91
+ self.id = -1
92
+
93
+ def __str__(self):
94
+ polyText = ""
95
+ if self.polygon:
96
+ if len(self.polygon) <= 4:
97
+ for p in self.polygon:
98
+ polyText += '({},{}) '.format(p.x, p.y)
99
+ else:
100
+ polyText += '({},{}) ({},{}) ... ({},{}) ({},{})'.format(
101
+ self.polygon[0].x, self.polygon[0].y,
102
+ self.polygon[1].x, self.polygon[1].y,
103
+ self.polygon[-2].x, self.polygon[-2].y,
104
+ self.polygon[-1].x, self.polygon[-1].y)
105
+ else:
106
+ polyText = "none"
107
+ text = "Object: {} - {}".format(self.label, polyText)
108
+ return text
109
+
110
+ def fromJsonText(self, jsonText, objId=-1):
111
+ self.id = objId
112
+ self.label = str(jsonText['label'])
113
+ self.polygon = [Point(p[0], p[1]) for p in jsonText['polygon']]
114
+ if 'deleted' in jsonText.keys():
115
+ self.deleted = jsonText['deleted']
116
+ else:
117
+ self.deleted = 0
118
+ if 'verified' in jsonText.keys():
119
+ self.verified = jsonText['verified']
120
+ else:
121
+ self.verified = 1
122
+ if 'user' in jsonText.keys():
123
+ self.user = jsonText['user']
124
+ else:
125
+ self.user = ''
126
+ if 'date' in jsonText.keys():
127
+ self.date = jsonText['date']
128
+ else:
129
+ self.date = ''
130
+ if self.deleted == 1:
131
+ self.draw = False
132
+ else:
133
+ self.draw = True
134
+
135
+ def toJsonText(self):
136
+ objDict = {}
137
+ objDict['label'] = self.label
138
+ objDict['id'] = self.id
139
+ objDict['deleted'] = self.deleted
140
+ objDict['verified'] = self.verified
141
+ objDict['user'] = self.user
142
+ objDict['date'] = self.date
143
+ objDict['polygon'] = []
144
+ for pt in self.polygon:
145
+ objDict['polygon'].append([pt.x, pt.y])
146
+
147
+ return objDict
148
+
149
+
150
+ class CsBbox2d(CsObject):
151
+ """Class that contains the information of a single annotated object as bounding box"""
152
+
153
+ # Constructor
154
+ def __init__(self):
155
+ CsObject.__init__(self, CsObjectType.BBOX2D)
156
+ # the polygon as list of points
157
+ self.bbox_amodal_xywh = []
158
+ self.bbox_modal_xywh = []
159
+
160
+ # the ID of the corresponding object
161
+ self.instanceId = -1
162
+ # the label of the corresponding object
163
+ self.label = ""
164
+
165
+ def __str__(self):
166
+ bboxAmodalText = ""
167
+ bboxAmodalText += '[(x1: {}, y1: {}), (w: {}, h: {})]'.format(
168
+ self.bbox_amodal_xywh[0], self.bbox_amodal_xywh[1], self.bbox_amodal_xywh[2], self.bbox_amodal_xywh[3])
169
+
170
+ bboxModalText = ""
171
+ bboxModalText += '[(x1: {}, y1: {}), (w: {}, h: {})]'.format(
172
+ self.bbox_modal_xywh[0], self.bbox_modal_xywh[1], self.bbox_modal_xywh[2], self.bbox_modal_xywh[3])
173
+
174
+ text = "Object: {}\n - Amodal {}\n - Modal {}".format(
175
+ self.label, bboxAmodalText, bboxModalText)
176
+ return text
177
+
178
+ def setAmodalBox(self, bbox_amodal):
179
+ # sets the amodal box if required
180
+ self.bbox_amodal_xywh = [
181
+ bbox_amodal[0],
182
+ bbox_amodal[1],
183
+ bbox_amodal[2] - bbox_amodal[0],
184
+ bbox_amodal[3] - bbox_amodal[1]
185
+ ]
186
+
187
+ # access 2d boxes in [xmin, ymin, xmax, ymax] format
188
+ @property
189
+ def bbox_amodal(self):
190
+ """Returns the 2d box as [xmin, ymin, xmax, ymax]"""
191
+ return [
192
+ self.bbox_amodal_xywh[0],
193
+ self.bbox_amodal_xywh[1],
194
+ self.bbox_amodal_xywh[0] + self.bbox_amodal_xywh[2],
195
+ self.bbox_amodal_xywh[1] + self.bbox_amodal_xywh[3]
196
+ ]
197
+
198
+ @property
199
+ def bbox_modal(self):
200
+ """Returns the 2d box as [xmin, ymin, xmax, ymax]"""
201
+ return [
202
+ self.bbox_modal_xywh[0],
203
+ self.bbox_modal_xywh[1],
204
+ self.bbox_modal_xywh[0] + self.bbox_modal_xywh[2],
205
+ self.bbox_modal_xywh[1] + self.bbox_modal_xywh[3]
206
+ ]
207
+
208
+ def fromJsonText(self, jsonText, objId=-1):
209
+ # try to load from cityperson format
210
+ if 'bbox' in jsonText.keys() and 'bboxVis' in jsonText.keys():
211
+ self.bbox_amodal_xywh = jsonText['bbox']
212
+ self.bbox_modal_xywh = jsonText['bboxVis']
213
+ # both modal and amodal boxes are provided
214
+ elif "modal" in jsonText.keys() and "amodal" in jsonText.keys():
215
+ self.bbox_amodal_xywh = jsonText['amodal']
216
+ self.bbox_modal_xywh = jsonText['modal']
217
+ # only amodal boxes are provided
218
+ else:
219
+ self.bbox_modal_xywh = jsonText['amodal']
220
+ self.bbox_amodal_xywh = jsonText['amodal']
221
+
222
+ # load label and instanceId if available
223
+ if 'label' in jsonText.keys() and 'instanceId' in jsonText.keys():
224
+ self.label = str(jsonText['label'])
225
+ self.instanceId = jsonText['instanceId']
226
+
227
+ def toJsonText(self):
228
+ objDict = {}
229
+ objDict['label'] = self.label
230
+ objDict['instanceId'] = self.instanceId
231
+ objDict['modal'] = self.bbox_modal_xywh
232
+ objDict['amodal'] = self.bbox_amodal_xywh
233
+
234
+ return objDict
235
+
236
+
237
+ class CsBbox3d(CsObject):
238
+ """Class that contains the information of a single annotated object as 3D bounding box"""
239
+
240
+ # Constructor
241
+ def __init__(self):
242
+ CsObject.__init__(self, CsObjectType.BBOX3D)
243
+
244
+ self.bbox_2d = None
245
+
246
+ self.center = []
247
+ self.dims = []
248
+ self.rotation = []
249
+ self.instanceId = -1
250
+ self.label = ""
251
+ self.score = -1.
252
+
253
+ def __str__(self):
254
+ bbox2dText = str(self.bbox_2d)
255
+
256
+ bbox3dText = ""
257
+ bbox3dText += '\n - Center (x/y/z) [m]: {}/{}/{}'.format(
258
+ self.center[0], self.center[1], self.center[2])
259
+ bbox3dText += '\n - Dimensions (l/w/h) [m]: {}/{}/{}'.format(
260
+ self.dims[0], self.dims[1], self.dims[2])
261
+ bbox3dText += '\n - Rotation: {}/{}/{}/{}'.format(
262
+ self.rotation[0], self.rotation[1], self.rotation[2], self.rotation[3])
263
+
264
+ text = "Object: {}\n2D {}\n - 3D {}".format(
265
+ self.label, bbox2dText, bbox3dText)
266
+ return text
267
+
268
+ def fromJsonText(self, jsonText, objId=-1):
269
+ # load 2D box
270
+ self.bbox_2d = CsBbox2d()
271
+ self.bbox_2d.fromJsonText(jsonText['2d'])
272
+
273
+ self.center = jsonText['3d']['center']
274
+ self.dims = jsonText['3d']['dimensions']
275
+ self.rotation = jsonText['3d']['rotation']
276
+ self.label = jsonText['label']
277
+ self.score = jsonText['score']
278
+
279
+ if 'instanceId' in jsonText.keys():
280
+ self.instanceId = jsonText['instanceId']
281
+
282
+ def toJsonText(self):
283
+ objDict = {}
284
+ objDict['label'] = self.label
285
+ objDict['instanceId'] = self.instanceId
286
+ objDict['2d']['amodal'] = self.bbox_2d.bbox_amodal_xywh
287
+ objDict['2d']['modal'] = self.bbox_2d.bbox_modal_xywh
288
+ objDict['3d']['center'] = self.center
289
+ objDict['3d']['dimensions'] = self.dims
290
+ objDict['3d']['rotation'] = self.rotation
291
+
292
+ return objDict
293
+
294
+ @property
295
+ def depth(self):
296
+ # returns the BEV depth
297
+ return np.sqrt(self.center[0]**2 + self.center[1]**2).astype(int)
298
+
299
+
300
+ class CsIgnore2d(CsObject):
301
+ """Class that contains the information of a single annotated 2d ignore region"""
302
+
303
+ # Constructor
304
+ def __init__(self):
305
+ CsObject.__init__(self, CsObjectType.IGNORE2D)
306
+
307
+ self.bbox_xywh = []
308
+ self.label = ""
309
+ self.instanceId = -1
310
+
311
+ def __str__(self):
312
+ bbox2dText = ""
313
+ bbox2dText += 'Ignore Region: (x1: {}, y1: {}), (w: {}, h: {})'.format(
314
+ self.bbox_xywh[0], self.bbox_xywh[1], self.bbox_xywh[2], self.bbox_xywh[3])
315
+
316
+ return bbox2dText
317
+
318
+ def fromJsonText(self, jsonText, objId=-1):
319
+ self.bbox_xywh = jsonText['2d']
320
+
321
+ if 'label' in jsonText.keys():
322
+ self.label = jsonText['label']
323
+
324
+ if 'instanceId' in jsonText.keys():
325
+ self.instanceId = jsonText['instanceId']
326
+
327
+ def toJsonText(self):
328
+ objDict = {}
329
+ objDict['label'] = self.label
330
+ objDict['instanceId'] = self.instanceId
331
+ objDict['2d'] = self.bbox_xywh
332
+
333
+ return objDict
334
+
335
+ @property
336
+ def bbox(self):
337
+ """Returns the 2d box as [xmin, ymin, xmax, ymax]"""
338
+ return [
339
+ self.bbox_xywh[0],
340
+ self.bbox_xywh[1],
341
+ self.bbox_xywh[0] + self.bbox_xywh[2],
342
+ self.bbox_xywh[1] + self.bbox_xywh[3]
343
+ ]
344
+
345
+ # Extend api to be compatible to bbox2d
346
+ @property
347
+ def bbox_amodal_xywh(self):
348
+ return self.bbox_xywh
349
+
350
+ @property
351
+ def bbox_modal_xywh(self):
352
+ return self.bbox_xywh
353
+
354
+
355
+ class Annotation:
356
+ """The annotation of a whole image (doesn't support mixed annotations, i.e. combining CsPoly and CsBbox2d)"""
357
+
358
+ # Constructor
359
+ def __init__(self, objType=CsObjectType.POLY):
360
+ # the width of that image and thus of the label image
361
+ self.imgWidth = 0
362
+ # the height of that image and thus of the label image
363
+ self.imgHeight = 0
364
+ # the list of objects
365
+ self.objects = []
366
+ # the camera calibration
367
+ self.camera = None
368
+ assert objType in CsObjectType.__dict__.values()
369
+ self.objectType = objType
370
+
371
+ def toJson(self):
372
+ return json.dumps(self, default=lambda o: o.__dict__, sort_keys=True, indent=4)
373
+
374
+ def fromJsonText(self, jsonText):
375
+ jsonDict = json.loads(jsonText)
376
+ self.imgWidth = int(jsonDict['imgWidth'])
377
+ self.imgHeight = int(jsonDict['imgHeight'])
378
+ self.objects = []
379
+ # load objects
380
+ if self.objectType != CsObjectType.IGNORE2D:
381
+ for objId, objIn in enumerate(jsonDict['objects']):
382
+ if self.objectType == CsObjectType.POLY:
383
+ obj = CsPoly()
384
+ elif self.objectType == CsObjectType.BBOX2D:
385
+ obj = CsBbox2d()
386
+ elif self.objectType == CsObjectType.BBOX3D:
387
+ obj = CsBbox3d()
388
+ obj.fromJsonText(objIn, objId)
389
+ self.objects.append(obj)
390
+
391
+ # load ignores
392
+ if 'ignore' in jsonDict.keys():
393
+ for ignoreId, ignoreIn in enumerate(jsonDict['ignore']):
394
+ obj = CsIgnore2d()
395
+ obj.fromJsonText(ignoreIn, ignoreId)
396
+ self.objects.append(obj)
397
+
398
+ # load camera calibration
399
+ if 'sensor' in jsonDict.keys():
400
+ self.camera = Camera(fx=jsonDict['sensor']['fx'],
401
+ fy=jsonDict['sensor']['fy'],
402
+ u0=jsonDict['sensor']['u0'],
403
+ v0=jsonDict['sensor']['v0'],
404
+ sensor_T_ISO_8855=jsonDict['sensor']['sensor_T_ISO_8855'])
405
+
406
+ def toJsonText(self):
407
+ jsonDict = {}
408
+ jsonDict['imgWidth'] = self.imgWidth
409
+ jsonDict['imgHeight'] = self.imgHeight
410
+ jsonDict['objects'] = []
411
+ for obj in self.objects:
412
+ objDict = obj.toJsonText()
413
+ jsonDict['objects'].append(objDict)
414
+
415
+ return jsonDict
416
+
417
+ # Read a json formatted polygon file and return the annotation
418
+ def fromJsonFile(self, jsonFile):
419
+ if not os.path.isfile(jsonFile):
420
+ print('Given json file not found: {}'.format(jsonFile))
421
+ return
422
+ with open(jsonFile, 'r') as f:
423
+ jsonText = f.read()
424
+ self.fromJsonText(jsonText)
425
+
426
+ def toJsonFile(self, jsonFile):
427
+ with open(jsonFile, 'w') as f:
428
+ f.write(self.toJson())
429
+
430
+
431
+ # a dummy example
432
+ if __name__ == "__main__":
433
+ obj = CsPoly()
434
+ obj.label = 'car'
435
+ obj.polygon.append(Point(0, 0))
436
+ obj.polygon.append(Point(1, 0))
437
+ obj.polygon.append(Point(1, 1))
438
+ obj.polygon.append(Point(0, 1))
439
+
440
+ print(type(obj).__name__)
441
+ print(obj)
ext/cityscapes_scripts/helpers/csHelpers.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python
2
+ #
3
+ # Various helper methods and includes for Cityscapes
4
+ #
5
+
6
+ # Python imports
7
+ from __future__ import print_function, absolute_import, division
8
+ import os
9
+ import sys
10
+ import getopt
11
+ import glob
12
+ import math
13
+ import json
14
+ from collections import namedtuple
15
+ import logging
16
+ import traceback
17
+
18
+ # Image processing
19
+ from PIL import Image
20
+ from PIL import ImageDraw
21
+
22
+ # Numpy for datastructures
23
+ import numpy as np
24
+
25
+ # Cityscapes modules
26
+ # from .annotation import Annotation
27
+ from .labels import labels, name2label, id2label, trainId2label, category2labels
28
+
29
+
30
+ def printError(message):
31
+ """Print an error message and quit"""
32
+ print('ERROR: ' + str(message))
33
+ sys.exit(-1)
34
+
35
+
36
+ class colors:
37
+ """Class for colors"""
38
+ RED = '\033[31;1m'
39
+ GREEN = '\033[32;1m'
40
+ YELLOW = '\033[33;1m'
41
+ BLUE = '\033[34;1m'
42
+ MAGENTA = '\033[35;1m'
43
+ CYAN = '\033[36;1m'
44
+ BOLD = '\033[1m'
45
+ UNDERLINE = '\033[4m'
46
+ ENDC = '\033[0m'
47
+
48
+
49
+ def getColorEntry(val, args):
50
+ """Colored value output if colorized flag is activated."""
51
+
52
+ if not args.colorized:
53
+ return ""
54
+ if not isinstance(val, float) or math.isnan(val):
55
+ return colors.ENDC
56
+ if (val < .20):
57
+ return colors.RED
58
+ elif (val < .40):
59
+ return colors.YELLOW
60
+ elif (val < .60):
61
+ return colors.BLUE
62
+ elif (val < .80):
63
+ return colors.CYAN
64
+ else:
65
+ return colors.GREEN
66
+
67
+
68
+ # Cityscapes files have a typical filename structure
69
+ # <city>_<sequenceNb>_<frameNb>_<type>[_<type2>].<ext>
70
+ # This class contains the individual elements as members
71
+ # For the sequence and frame number, the strings are returned, including leading zeros
72
+ CsFile = namedtuple('csFile', ['city', 'sequenceNb', 'frameNb', 'type', 'type2', 'ext'])
73
+
74
+
75
+ def getCsFileInfo(fileName):
76
+ """Returns a CsFile object filled from the info in the given filename"""
77
+ baseName = os.path.basename(fileName)
78
+ parts = baseName.split('_')
79
+ parts = parts[:-1] + parts[-1].split('.')
80
+ if not parts:
81
+ printError('Cannot parse given filename ({}). Does not seem to be a valid Cityscapes file.'.format(fileName))
82
+ if len(parts) == 5:
83
+ csFile = CsFile(*parts[:-1], type2="", ext=parts[-1])
84
+ elif len(parts) == 6:
85
+ csFile = CsFile(*parts)
86
+ else:
87
+ printError('Found {} part(s) in given filename ({}). Expected 5 or 6.'.format(len(parts), fileName))
88
+
89
+ return csFile
90
+
91
+
92
+ def getCoreImageFileName(filename):
93
+ """Returns the part of Cityscapes filenames that is common to all data types
94
+
95
+ e.g. for city_123456_123456_gtFine_polygons.json returns city_123456_123456
96
+ """
97
+ csFile = getCsFileInfo(filename)
98
+ return "{}_{}_{}".format(csFile.city, csFile.sequenceNb, csFile.frameNb)
99
+
100
+
101
+ def getDirectory(fileName):
102
+ """Returns the directory name for the given filename
103
+
104
+ e.g.
105
+ fileName = "/foo/bar/foobar.txt"
106
+ return value is "bar"
107
+ Not much error checking though
108
+ """
109
+ dirName = os.path.dirname(fileName)
110
+ return os.path.basename(dirName)
111
+
112
+
113
+ def ensurePath(path):
114
+ """Make sure that the given path exists"""
115
+ if not path:
116
+ return
117
+ if not os.path.isdir(path):
118
+ os.makedirs(path)
119
+
120
+
121
+ def writeDict2JSON(dictName, fileName):
122
+ """Write a dictionary as json file"""
123
+ with open(fileName, 'w') as f:
124
+ f.write(json.dumps(dictName, default=lambda o: o.__dict__, sort_keys=True, indent=4))
125
+
126
+
127
+ # dummy main
128
+ if __name__ == "__main__":
129
+ printError("Only for include, not executable on its own.")
ext/cityscapes_scripts/helpers/labels.py ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python
2
+ #
3
+ # Cityscapes labels
4
+ #
5
+
6
+ from __future__ import print_function, absolute_import, division
7
+ from collections import namedtuple
8
+
9
+
10
+ #--------------------------------------------------------------------------------
11
+ # Definitions
12
+ #--------------------------------------------------------------------------------
13
+
14
+ # a label and all meta information
15
+ Label = namedtuple( 'Label' , [
16
+
17
+ 'name' , # The identifier of this label, e.g. 'car', 'person', ... .
18
+ # We use them to uniquely name a class
19
+
20
+ 'id' , # An integer ID that is associated with this label.
21
+ # The IDs are used to represent the label in ground truth images
22
+ # An ID of -1 means that this label does not have an ID and thus
23
+ # is ignored when creating ground truth images (e.g. license plate).
24
+ # Do not modify these IDs, since exactly these IDs are expected by the
25
+ # evaluation server.
26
+
27
+ 'trainId' , # Feel free to modify these IDs as suitable for your method. Then create
28
+ # ground truth images with train IDs, using the tools provided in the
29
+ # 'preparation' folder. However, make sure to validate or submit results
30
+ # to our evaluation server using the regular IDs above!
31
+ # For trainIds, multiple labels might have the same ID. Then, these labels
32
+ # are mapped to the same class in the ground truth images. For the inverse
33
+ # mapping, we use the label that is defined first in the list below.
34
+ # For example, mapping all void-type classes to the same ID in training,
35
+ # might make sense for some approaches.
36
+ # Max value is 255!
37
+
38
+ 'category' , # The name of the category that this label belongs to
39
+
40
+ 'categoryId' , # The ID of this category. Used to create ground truth images
41
+ # on category level.
42
+
43
+ 'hasInstances', # Whether this label distinguishes between single instances or not
44
+
45
+ 'ignoreInEval', # Whether pixels having this class as ground truth label are ignored
46
+ # during evaluations or not
47
+
48
+ 'color' , # The color of this label
49
+ ] )
50
+
51
+
52
+ #--------------------------------------------------------------------------------
53
+ # A list of all labels
54
+ #--------------------------------------------------------------------------------
55
+
56
+ # Please adapt the train IDs as appropriate for your approach.
57
+ # Note that you might want to ignore labels with ID 255 during training.
58
+ # Further note that the current train IDs are only a suggestion. You can use whatever you like.
59
+ # Make sure to provide your results using the original IDs and not the training IDs.
60
+ # Note that many IDs are ignored in evaluation and thus you never need to predict these!
61
+
62
+ labels = [
63
+ # name id trainId category catId hasInstances ignoreInEval color
64
+ Label( 'unlabeled' , 0 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ),
65
+ Label( 'ego vehicle' , 1 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ),
66
+ Label( 'rectification border' , 2 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ),
67
+ Label( 'out of roi' , 3 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ),
68
+ Label( 'static' , 4 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ),
69
+ Label( 'dynamic' , 5 , 255 , 'void' , 0 , False , True , (111, 74, 0) ),
70
+ Label( 'ground' , 6 , 255 , 'void' , 0 , False , True , ( 81, 0, 81) ),
71
+ Label( 'road' , 7 , 0 + 8, 'flat' , 1 , False , False , (128, 64,128) ),
72
+ Label( 'sidewalk' , 8 , 1 + 8, 'flat' , 1 , False , False , (244, 35,232) ),
73
+ Label( 'parking' , 9 , 255 , 'flat' , 1 , False , True , (250,170,160) ),
74
+ Label( 'rail track' , 10 , 255 , 'flat' , 1 , False , True , (230,150,140) ),
75
+ Label( 'building' , 11 , 2 + 8, 'construction' , 2 , False , False , ( 70, 70, 70) ),
76
+ Label( 'wall' , 12 , 3 + 8, 'construction' , 2 , False , False , (102,102,156) ),
77
+ Label( 'fence' , 13 , 4 + 8, 'construction' , 2 , False , False , (190,153,153) ),
78
+ Label( 'guard rail' , 14 , 255 , 'construction' , 2 , False , True , (180,165,180) ),
79
+ Label( 'bridge' , 15 , 255 , 'construction' , 2 , False , True , (150,100,100) ),
80
+ Label( 'tunnel' , 16 , 255 , 'construction' , 2 , False , True , (150,120, 90) ),
81
+ Label( 'pole' , 17 , 5 + 8, 'object' , 3 , False , False , (153,153,153) ),
82
+ Label( 'polegroup' , 18 , 255 , 'object' , 3 , False , True , (153,153,153) ),
83
+ Label( 'traffic light' , 19 , 6 + 8, 'object' , 3 , False , False , (250,170, 30) ),
84
+ Label( 'traffic sign' , 20 , 7 + 8, 'object' , 3 , False , False , (220,220, 0) ),
85
+ Label( 'vegetation' , 21 , 8 + 8, 'nature' , 4 , False , False , (107,142, 35) ),
86
+ Label( 'terrain' , 22 , 9 + 8, 'nature' , 4 , False , False , (152,251,152) ),
87
+ Label( 'sky' , 23 , 10 + 8, 'sky' , 5 , False , False , ( 70,130,180) ),
88
+ Label( 'person' , 24 , 11 - 11 , 'human' , 6 , True , False , (220, 20, 60) ),
89
+ Label( 'rider' , 25 , 12 - 11 , 'human' , 6 , True , False , (255, 0, 0) ),
90
+ Label( 'car' , 26 , 13 - 11, 'vehicle' , 7 , True , False , ( 0, 0,142) ),
91
+ Label( 'truck' , 27 , 14 - 11, 'vehicle' , 7 , True , False , ( 0, 0, 70) ),
92
+ Label( 'bus' , 28 , 15 - 11, 'vehicle' , 7 , True , False , ( 0, 60,100) ),
93
+ Label( 'caravan' , 29 , 255 , 'vehicle' , 7 , True , True , ( 0, 0, 90) ),
94
+ Label( 'trailer' , 30 , 255 , 'vehicle' , 7 , True , True , ( 0, 0,110) ),
95
+ Label( 'train' , 31 , 16 - 11, 'vehicle' , 7 , True , False , ( 0, 80,100) ),
96
+ Label( 'motorcycle' , 32 , 17 - 11, 'vehicle' , 7 , True , False , ( 0, 0,230) ),
97
+ Label( 'bicycle' , 33 , 18 - 11, 'vehicle' , 7 , True , False , (119, 11, 32) ),
98
+ Label( 'license plate' , -1 , -1 , 'vehicle' , 7 , False , True , ( 0, 0,142) ),
99
+ ]
100
+
101
+
102
+ #--------------------------------------------------------------------------------
103
+ # Create dictionaries for a fast lookup
104
+ #--------------------------------------------------------------------------------
105
+
106
+ # Please refer to the main method below for example usages!
107
+
108
+ # name to label object
109
+ name2label = { label.name : label for label in labels }
110
+ # id to label object
111
+ id2label = { label.id : label for label in labels }
112
+ # trainId to label object
113
+ trainId2label = { label.trainId : label for label in reversed(labels) }
114
+ # category to list of label objects
115
+ category2labels = {}
116
+ for label in labels:
117
+ category = label.category
118
+ if category in category2labels:
119
+ category2labels[category].append(label)
120
+ else:
121
+ category2labels[category] = [label]
122
+
123
+ #--------------------------------------------------------------------------------
124
+ # Assure single instance name
125
+ #--------------------------------------------------------------------------------
126
+
127
+ # returns the label name that describes a single instance (if possible)
128
+ # e.g. input | output
129
+ # ----------------------
130
+ # car | car
131
+ # cargroup | car
132
+ # foo | None
133
+ # foogroup | None
134
+ # skygroup | None
135
+ def assureSingleInstanceName( name ):
136
+ # if the name is known, it is not a group
137
+ if name in name2label:
138
+ return name
139
+ # test if the name actually denotes a group
140
+ if not name.endswith("group"):
141
+ return None
142
+ # remove group
143
+ name = name[:-len("group")]
144
+ # test if the new name exists
145
+ if not name in name2label:
146
+ return None
147
+ # test if the new name denotes a label that actually has instances
148
+ if not name2label[name].hasInstances:
149
+ return None
150
+ # all good then
151
+ return name
152
+
153
+ #--------------------------------------------------------------------------------
154
+ # Main for testing
155
+ #--------------------------------------------------------------------------------
156
+
157
+ # just a dummy main
158
+ if __name__ == "__main__":
159
+ # Print all the labels
160
+ print("List of cityscapes labels:")
161
+ print("")
162
+ print(" {:>21} | {:>3} | {:>7} | {:>14} | {:>10} | {:>12} | {:>12}".format( 'name', 'id', 'trainId', 'category', 'categoryId', 'hasInstances', 'ignoreInEval' ))
163
+ print(" " + ('-' * 98))
164
+ for label in labels:
165
+ print(" {:>21} | {:>3} | {:>7} | {:>14} | {:>10} | {:>12} | {:>12}".format( label.name, label.id, label.trainId, label.category, label.categoryId, label.hasInstances, label.ignoreInEval ))
166
+ print("")
167
+
168
+ print("Example usages:")
169
+
170
+ # Map from name to label
171
+ name = 'car'
172
+ id = name2label[name].id
173
+ print("ID of label '{name}': {id}".format( name=name, id=id ))
174
+
175
+ # Map from ID to label
176
+ category = id2label[id].category
177
+ print("Category of label with ID '{id}': {category}".format( id=id, category=category ))
178
+
179
+ # Map from trainID to label
180
+ trainId = 0
181
+ name = trainId2label[trainId].name
182
+ print("Name of label with trainID '{id}': {name}".format( id=trainId, name=name ))
ext/cityscapes_scripts/helpers/labels_cityPersons.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python
2
+ #
3
+ # CityPersons (cp) labels
4
+ #
5
+
6
+ from __future__ import print_function, absolute_import, division
7
+ from collections import namedtuple
8
+
9
+
10
+ #--------------------------------------------------------------------------------
11
+ # Definitions
12
+ #--------------------------------------------------------------------------------
13
+
14
+ # a label and all meta information
15
+ LabelCp = namedtuple( 'LabelCp' , [
16
+
17
+ 'name' , # The identifier of this label, e.g. 'pedestrian', 'rider', ... .
18
+ # We use them to uniquely name a class
19
+
20
+ 'id' , # An integer ID that is associated with this label.
21
+ # The IDs are used to represent the label in ground truth
22
+
23
+ 'hasInstances', # Whether this label distinguishes between single instances or not
24
+
25
+ 'ignoreInEval', # Whether pixels having this class as ground truth label are ignored
26
+ # during evaluations or not
27
+
28
+ 'color' , # The color of this label
29
+ ] )
30
+
31
+
32
+ #--------------------------------------------------------------------------------
33
+ # A list of all labels
34
+ #--------------------------------------------------------------------------------
35
+
36
+ # The 'ignore' label covers representations of humans, e.g. people on posters, reflections etc.
37
+ # Each annotation includes both the full bounding box (bbox) as well as a bounding box covering the visible area (bboxVis).
38
+ # The latter is obtained automatically from the segmentation masks.
39
+
40
+ labelsCp = [
41
+ # name id hasInstances ignoreInEval color
42
+ LabelCp( 'ignore' , 0 , False , True , (250,170, 30) ),
43
+ LabelCp( 'pedestrian' , 1 , True , False , (220, 20, 60) ),
44
+ LabelCp( 'rider' , 2 , True , False , ( 0, 0,142) ),
45
+ LabelCp( 'sitting person' , 3 , True , False , (107,142, 35) ),
46
+ LabelCp( 'person (other)' , 4 , True , False , (190,153,153) ),
47
+ LabelCp( 'person group' , 5 , False , True , (255, 0, 0) ),
48
+ ]
49
+
50
+
51
+ #--------------------------------------------------------------------------------
52
+ # Create dictionaries for a fast lookup
53
+ #--------------------------------------------------------------------------------
54
+
55
+ # Please refer to the main method below for example usages!
56
+
57
+ # name to label object
58
+ name2labelCp = { label.name : label for label in labelsCp }
59
+ # id to label object
60
+ id2labelCp = { label.id : label for label in labelsCp }
61
+
ext/cityscapes_scripts/helpers/version.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ import os
4
+
5
+ with open(os.path.join(os.path.dirname(__file__), '..', 'VERSION')) as f:
6
+ version = f.read().strip()
7
+
8
+ if __name__ == "__main__":
9
+ print(version)
ext/class_names/VIPSeg.py ADDED
@@ -0,0 +1,261 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ CLASSES = [
2
+ {"id": 0, "name": "wall", "isthing": 0, "color": [120, 120, 120]},
3
+ {"id": 1, "name": "ceiling", "isthing": 0, "color": [180, 120, 120]},
4
+ {"id": 2, "name": "door", "isthing": 1, "color": [6, 230, 230]},
5
+ {"id": 3, "name": "stair", "isthing": 0, "color": [80, 50, 50]},
6
+ {"id": 4, "name": "ladder", "isthing": 1, "color": [4, 200, 3]},
7
+ {"id": 5, "name": "escalator", "isthing": 0, "color": [120, 120, 80]},
8
+ {"id": 6, "name": "Playground_slide", "isthing": 0, "color": [140, 140, 140]},
9
+ {"id": 7, "name": "handrail_or_fence", "isthing": 0, "color": [204, 5, 255]},
10
+ {"id": 8, "name": "window", "isthing": 1, "color": [230, 230, 230]},
11
+ {"id": 9, "name": "rail", "isthing": 0, "color": [4, 250, 7]},
12
+ {"id": 10, "name": "goal", "isthing": 1, "color": [224, 5, 255]},
13
+ {"id": 11, "name": "pillar", "isthing": 0, "color": [235, 255, 7]},
14
+ {"id": 12, "name": "pole", "isthing": 0, "color": [150, 5, 61]},
15
+ {"id": 13, "name": "floor", "isthing": 0, "color": [120, 120, 70]},
16
+ {"id": 14, "name": "ground", "isthing": 0, "color": [8, 255, 51]},
17
+ {"id": 15, "name": "grass", "isthing": 0, "color": [255, 6, 82]},
18
+ {"id": 16, "name": "sand", "isthing": 0, "color": [143, 255, 140]},
19
+ {"id": 17, "name": "athletic_field", "isthing": 0, "color": [204, 255, 4]},
20
+ {"id": 18, "name": "road", "isthing": 0, "color": [255, 51, 7]},
21
+ {"id": 19, "name": "path", "isthing": 0, "color": [204, 70, 3]},
22
+ {"id": 20, "name": "crosswalk", "isthing": 0, "color": [0, 102, 200]},
23
+ {"id": 21, "name": "building", "isthing": 0, "color": [61, 230, 250]},
24
+ {"id": 22, "name": "house", "isthing": 0, "color": [255, 6, 51]},
25
+ {"id": 23, "name": "bridge", "isthing": 0, "color": [11, 102, 255]},
26
+ {"id": 24, "name": "tower", "isthing": 0, "color": [255, 7, 71]},
27
+ {"id": 25, "name": "windmill", "isthing": 0, "color": [255, 9, 224]},
28
+ {"id": 26, "name": "well_or_well_lid", "isthing": 0, "color": [9, 7, 230]},
29
+ {"id": 27, "name": "other_construction", "isthing": 0, "color": [220, 220, 220]},
30
+ {"id": 28, "name": "sky", "isthing": 0, "color": [255, 9, 92]},
31
+ {"id": 29, "name": "mountain", "isthing": 0, "color": [112, 9, 255]},
32
+ {"id": 30, "name": "stone", "isthing": 0, "color": [8, 255, 214]},
33
+ {"id": 31, "name": "wood", "isthing": 0, "color": [7, 255, 224]},
34
+ {"id": 32, "name": "ice", "isthing": 0, "color": [255, 184, 6]},
35
+ {"id": 33, "name": "snowfield", "isthing": 0, "color": [10, 255, 71]},
36
+ {"id": 34, "name": "grandstand", "isthing": 0, "color": [255, 41, 10]},
37
+ {"id": 35, "name": "sea", "isthing": 0, "color": [7, 255, 255]},
38
+ {"id": 36, "name": "river", "isthing": 0, "color": [224, 255, 8]},
39
+ {"id": 37, "name": "lake", "isthing": 0, "color": [102, 8, 255]},
40
+ {"id": 38, "name": "waterfall", "isthing": 0, "color": [255, 61, 6]},
41
+ {"id": 39, "name": "water", "isthing": 0, "color": [255, 194, 7]},
42
+ {"id": 40, "name": "billboard_or_Bulletin_Board", "isthing": 0, "color": [255, 122, 8]},
43
+ {"id": 41, "name": "sculpture", "isthing": 1, "color": [0, 255, 20]},
44
+ {"id": 42, "name": "pipeline", "isthing": 0, "color": [255, 8, 41]},
45
+ {"id": 43, "name": "flag", "isthing": 1, "color": [255, 5, 153]},
46
+ {"id": 44, "name": "parasol_or_umbrella", "isthing": 1, "color": [6, 51, 255]},
47
+ {"id": 45, "name": "cushion_or_carpet", "isthing": 0, "color": [235, 12, 255]},
48
+ {"id": 46, "name": "tent", "isthing": 1, "color": [160, 150, 20]},
49
+ {"id": 47, "name": "roadblock", "isthing": 1, "color": [0, 163, 255]},
50
+ {"id": 48, "name": "car", "isthing": 1, "color": [140, 140, 140]},
51
+ {"id": 49, "name": "bus", "isthing": 1, "color": [250, 10, 15]},
52
+ {"id": 50, "name": "truck", "isthing": 1, "color": [20, 255, 0]},
53
+ {"id": 51, "name": "bicycle", "isthing": 1, "color": [31, 255, 0]},
54
+ {"id": 52, "name": "motorcycle", "isthing": 1, "color": [255, 31, 0]},
55
+ {"id": 53, "name": "wheeled_machine", "isthing": 0, "color": [255, 224, 0]},
56
+ {"id": 54, "name": "ship_or_boat", "isthing": 1, "color": [153, 255, 0]},
57
+ {"id": 55, "name": "raft", "isthing": 1, "color": [0, 0, 255]},
58
+ {"id": 56, "name": "airplane", "isthing": 1, "color": [255, 71, 0]},
59
+ {"id": 57, "name": "tyre", "isthing": 0, "color": [0, 235, 255]},
60
+ {"id": 58, "name": "traffic_light", "isthing": 0, "color": [0, 173, 255]},
61
+ {"id": 59, "name": "lamp", "isthing": 0, "color": [31, 0, 255]},
62
+ {"id": 60, "name": "person", "isthing": 1, "color": [11, 200, 200]},
63
+ {"id": 61, "name": "cat", "isthing": 1, "color": [255, 82, 0]},
64
+ {"id": 62, "name": "dog", "isthing": 1, "color": [0, 255, 245]},
65
+ {"id": 63, "name": "horse", "isthing": 1, "color": [0, 61, 255]},
66
+ {"id": 64, "name": "cattle", "isthing": 1, "color": [0, 255, 112]},
67
+ {"id": 65, "name": "other_animal", "isthing": 1, "color": [0, 255, 133]},
68
+ {"id": 66, "name": "tree", "isthing": 0, "color": [255, 0, 0]},
69
+ {"id": 67, "name": "flower", "isthing": 0, "color": [255, 163, 0]},
70
+ {"id": 68, "name": "other_plant", "isthing": 0, "color": [255, 102, 0]},
71
+ {"id": 69, "name": "toy", "isthing": 0, "color": [194, 255, 0]},
72
+ {"id": 70, "name": "ball_net", "isthing": 0, "color": [0, 143, 255]},
73
+ {"id": 71, "name": "backboard", "isthing": 0, "color": [51, 255, 0]},
74
+ {"id": 72, "name": "skateboard", "isthing": 1, "color": [0, 82, 255]},
75
+ {"id": 73, "name": "bat", "isthing": 0, "color": [0, 255, 41]},
76
+ {"id": 74, "name": "ball", "isthing": 1, "color": [0, 255, 173]},
77
+ {"id": 75, "name": "cupboard_or_showcase_or_storage_rack", "isthing": 0, "color": [10, 0, 255]},
78
+ {"id": 76, "name": "box", "isthing": 1, "color": [173, 255, 0]},
79
+ {"id": 77, "name": "traveling_case_or_trolley_case", "isthing": 1, "color": [0, 255, 153]},
80
+ {"id": 78, "name": "basket", "isthing": 1, "color": [255, 92, 0]},
81
+ {"id": 79, "name": "bag_or_package", "isthing": 1, "color": [255, 0, 255]},
82
+ {"id": 80, "name": "trash_can", "isthing": 0, "color": [255, 0, 245]},
83
+ {"id": 81, "name": "cage", "isthing": 0, "color": [255, 0, 102]},
84
+ {"id": 82, "name": "plate", "isthing": 1, "color": [255, 173, 0]},
85
+ {"id": 83, "name": "tub_or_bowl_or_pot", "isthing": 1, "color": [255, 0, 20]},
86
+ {"id": 84, "name": "bottle_or_cup", "isthing": 1, "color": [255, 184, 184]},
87
+ {"id": 85, "name": "barrel", "isthing": 1, "color": [0, 31, 255]},
88
+ {"id": 86, "name": "fishbowl", "isthing": 1, "color": [0, 255, 61]},
89
+ {"id": 87, "name": "bed", "isthing": 1, "color": [0, 71, 255]},
90
+ {"id": 88, "name": "pillow", "isthing": 1, "color": [255, 0, 204]},
91
+ {"id": 89, "name": "table_or_desk", "isthing": 1, "color": [0, 255, 194]},
92
+ {"id": 90, "name": "chair_or_seat", "isthing": 1, "color": [0, 255, 82]},
93
+ {"id": 91, "name": "bench", "isthing": 1, "color": [0, 10, 255]},
94
+ {"id": 92, "name": "sofa", "isthing": 1, "color": [0, 112, 255]},
95
+ {"id": 93, "name": "shelf", "isthing": 0, "color": [51, 0, 255]},
96
+ {"id": 94, "name": "bathtub", "isthing": 0, "color": [0, 194, 255]},
97
+ {"id": 95, "name": "gun", "isthing": 1, "color": [0, 122, 255]},
98
+ {"id": 96, "name": "commode", "isthing": 1, "color": [0, 255, 163]},
99
+ {"id": 97, "name": "roaster", "isthing": 1, "color": [255, 153, 0]},
100
+ {"id": 98, "name": "other_machine", "isthing": 0, "color": [0, 255, 10]},
101
+ {"id": 99, "name": "refrigerator", "isthing": 1, "color": [255, 112, 0]},
102
+ {"id": 100, "name": "washing_machine", "isthing": 1, "color": [143, 255, 0]},
103
+ {"id": 101, "name": "Microwave_oven", "isthing": 1, "color": [82, 0, 255]},
104
+ {"id": 102, "name": "fan", "isthing": 1, "color": [163, 255, 0]},
105
+ {"id": 103, "name": "curtain", "isthing": 0, "color": [255, 235, 0]},
106
+ {"id": 104, "name": "textiles", "isthing": 0, "color": [8, 184, 170]},
107
+ {"id": 105, "name": "clothes", "isthing": 0, "color": [133, 0, 255]},
108
+ {"id": 106, "name": "painting_or_poster", "isthing": 1, "color": [0, 255, 92]},
109
+ {"id": 107, "name": "mirror", "isthing": 1, "color": [184, 0, 255]},
110
+ {"id": 108, "name": "flower_pot_or_vase", "isthing": 1, "color": [255, 0, 31]},
111
+ {"id": 109, "name": "clock", "isthing": 1, "color": [0, 184, 255]},
112
+ {"id": 110, "name": "book", "isthing": 0, "color": [0, 214, 255]},
113
+ {"id": 111, "name": "tool", "isthing": 0, "color": [255, 0, 112]},
114
+ {"id": 112, "name": "blackboard", "isthing": 0, "color": [92, 255, 0]},
115
+ {"id": 113, "name": "tissue", "isthing": 0, "color": [0, 224, 255]},
116
+ {"id": 114, "name": "screen_or_television", "isthing": 1, "color": [112, 224, 255]},
117
+ {"id": 115, "name": "computer", "isthing": 1, "color": [70, 184, 160]},
118
+ {"id": 116, "name": "printer", "isthing": 1, "color": [163, 0, 255]},
119
+ {"id": 117, "name": "Mobile_phone", "isthing": 1, "color": [153, 0, 255]},
120
+ {"id": 118, "name": "keyboard", "isthing": 1, "color": [71, 255, 0]},
121
+ {"id": 119, "name": "other_electronic_product", "isthing": 0, "color": [255, 0, 163]},
122
+ {"id": 120, "name": "fruit", "isthing": 0, "color": [255, 204, 0]},
123
+ {"id": 121, "name": "food", "isthing": 0, "color": [255, 0, 143]},
124
+ {"id": 122, "name": "instrument", "isthing": 1, "color": [0, 255, 235]},
125
+ {"id": 123, "name": "train", "isthing": 1, "color": [133, 255, 0]}
126
+ ]
127
+
128
+ CLASSES_THING = [
129
+ {'id': 2, 'name': 'door', 'isthing': 1, 'color': [6, 230, 230]},
130
+ {'id': 4, 'name': 'ladder', 'isthing': 1, 'color': [4, 200, 3]},
131
+ {'id': 8, 'name': 'window', 'isthing': 1, 'color': [230, 230, 230]},
132
+ {'id': 10, 'name': 'goal', 'isthing': 1, 'color': [224, 5, 255]},
133
+ {'id': 41, 'name': 'sculpture', 'isthing': 1, 'color': [0, 255, 20]},
134
+ {'id': 43, 'name': 'flag', 'isthing': 1, 'color': [255, 5, 153]},
135
+ {'id': 44, 'name': 'parasol_or_umbrella', 'isthing': 1, 'color': [6, 51, 255]},
136
+ {'id': 46, 'name': 'tent', 'isthing': 1, 'color': [160, 150, 20]},
137
+ {'id': 47, 'name': 'roadblock', 'isthing': 1, 'color': [0, 163, 255]},
138
+ {'id': 48, 'name': 'car', 'isthing': 1, 'color': [140, 140, 140]},
139
+ {'id': 49, 'name': 'bus', 'isthing': 1, 'color': [250, 10, 15]},
140
+ {'id': 50, 'name': 'truck', 'isthing': 1, 'color': [20, 255, 0]},
141
+ {'id': 51, 'name': 'bicycle', 'isthing': 1, 'color': [31, 255, 0]},
142
+ {'id': 52, 'name': 'motorcycle', 'isthing': 1, 'color': [255, 31, 0]},
143
+ {'id': 54, 'name': 'ship_or_boat', 'isthing': 1, 'color': [153, 255, 0]},
144
+ {'id': 55, 'name': 'raft', 'isthing': 1, 'color': [0, 0, 255]},
145
+ {'id': 56, 'name': 'airplane', 'isthing': 1, 'color': [255, 71, 0]},
146
+ {'id': 60, 'name': 'person', 'isthing': 1, 'color': [11, 200, 200]},
147
+ {'id': 61, 'name': 'cat', 'isthing': 1, 'color': [255, 82, 0]},
148
+ {'id': 62, 'name': 'dog', 'isthing': 1, 'color': [0, 255, 245]},
149
+ {'id': 63, 'name': 'horse', 'isthing': 1, 'color': [0, 61, 255]},
150
+ {'id': 64, 'name': 'cattle', 'isthing': 1, 'color': [0, 255, 112]},
151
+ {'id': 65, 'name': 'other_animal', 'isthing': 1, 'color': [0, 255, 133]},
152
+ {'id': 72, 'name': 'skateboard', 'isthing': 1, 'color': [0, 82, 255]},
153
+ {'id': 74, 'name': 'ball', 'isthing': 1, 'color': [0, 255, 173]},
154
+ {'id': 76, 'name': 'box', 'isthing': 1, 'color': [173, 255, 0]},
155
+ {'id': 77, 'name': 'traveling_case_or_trolley_case', 'isthing': 1, 'color': [0, 255, 153]},
156
+ {'id': 78, 'name': 'basket', 'isthing': 1, 'color': [255, 92, 0]},
157
+ {'id': 79, 'name': 'bag_or_package', 'isthing': 1, 'color': [255, 0, 255]},
158
+ {'id': 82, 'name': 'plate', 'isthing': 1, 'color': [255, 173, 0]},
159
+ {'id': 83, 'name': 'tub_or_bowl_or_pot', 'isthing': 1, 'color': [255, 0, 20]},
160
+ {'id': 84, 'name': 'bottle_or_cup', 'isthing': 1, 'color': [255, 184, 184]},
161
+ {'id': 85, 'name': 'barrel', 'isthing': 1, 'color': [0, 31, 255]},
162
+ {'id': 86, 'name': 'fishbowl', 'isthing': 1, 'color': [0, 255, 61]},
163
+ {'id': 87, 'name': 'bed', 'isthing': 1, 'color': [0, 71, 255]},
164
+ {'id': 88, 'name': 'pillow', 'isthing': 1, 'color': [255, 0, 204]},
165
+ {'id': 89, 'name': 'table_or_desk', 'isthing': 1, 'color': [0, 255, 194]},
166
+ {'id': 90, 'name': 'chair_or_seat', 'isthing': 1, 'color': [0, 255, 82]},
167
+ {'id': 91, 'name': 'bench', 'isthing': 1, 'color': [0, 10, 255]},
168
+ {'id': 92, 'name': 'sofa', 'isthing': 1, 'color': [0, 112, 255]},
169
+ {'id': 95, 'name': 'gun', 'isthing': 1, 'color': [0, 122, 255]},
170
+ {'id': 96, 'name': 'commode', 'isthing': 1, 'color': [0, 255, 163]},
171
+ {'id': 97, 'name': 'roaster', 'isthing': 1, 'color': [255, 153, 0]},
172
+ {'id': 99, 'name': 'refrigerator', 'isthing': 1, 'color': [255, 112, 0]},
173
+ {'id': 100, 'name': 'washing_machine', 'isthing': 1, 'color': [143, 255, 0]},
174
+ {'id': 101, 'name': 'Microwave_oven', 'isthing': 1, 'color': [82, 0, 255]},
175
+ {'id': 102, 'name': 'fan', 'isthing': 1, 'color': [163, 255, 0]},
176
+ {'id': 106, 'name': 'painting_or_poster', 'isthing': 1, 'color': [0, 255, 92]},
177
+ {'id': 107, 'name': 'mirror', 'isthing': 1, 'color': [184, 0, 255]},
178
+ {'id': 108, 'name': 'flower_pot_or_vase', 'isthing': 1, 'color': [255, 0, 31]},
179
+ {'id': 109, 'name': 'clock', 'isthing': 1, 'color': [0, 184, 255]},
180
+ {'id': 114, 'name': 'screen_or_television', 'isthing': 1, 'color': [112, 224, 255]},
181
+ {'id': 115, 'name': 'computer', 'isthing': 1, 'color': [70, 184, 160]},
182
+ {'id': 116, 'name': 'printer', 'isthing': 1, 'color': [163, 0, 255]},
183
+ {'id': 117, 'name': 'Mobile_phone', 'isthing': 1, 'color': [153, 0, 255]},
184
+ {'id': 118, 'name': 'keyboard', 'isthing': 1, 'color': [71, 255, 0]},
185
+ {'id': 122, 'name': 'instrument', 'isthing': 1, 'color': [0, 255, 235]},
186
+ {'id': 123, 'name': 'train', 'isthing': 1, 'color': [133, 255, 0]}
187
+ ]
188
+
189
+ CLASSES_STUFF = [
190
+ {'id': 0, 'name': 'wall', 'isthing': 0, 'color': [120, 120, 120]},
191
+ {'id': 1, 'name': 'ceiling', 'isthing': 0, 'color': [180, 120, 120]},
192
+ {'id': 3, 'name': 'stair', 'isthing': 0, 'color': [80, 50, 50]},
193
+ {'id': 5, 'name': 'escalator', 'isthing': 0, 'color': [120, 120, 80]},
194
+ {'id': 6, 'name': 'Playground_slide', 'isthing': 0, 'color': [140, 140, 140]},
195
+ {'id': 7, 'name': 'handrail_or_fence', 'isthing': 0, 'color': [204, 5, 255]},
196
+ {'id': 9, 'name': 'rail', 'isthing': 0, 'color': [4, 250, 7]},
197
+ {'id': 11, 'name': 'pillar', 'isthing': 0, 'color': [235, 255, 7]},
198
+ {'id': 12, 'name': 'pole', 'isthing': 0, 'color': [150, 5, 61]},
199
+ {'id': 13, 'name': 'floor', 'isthing': 0, 'color': [120, 120, 70]},
200
+ {'id': 14, 'name': 'ground', 'isthing': 0, 'color': [8, 255, 51]},
201
+ {'id': 15, 'name': 'grass', 'isthing': 0, 'color': [255, 6, 82]},
202
+ {'id': 16, 'name': 'sand', 'isthing': 0, 'color': [143, 255, 140]},
203
+ {'id': 17, 'name': 'athletic_field', 'isthing': 0, 'color': [204, 255, 4]},
204
+ {'id': 18, 'name': 'road', 'isthing': 0, 'color': [255, 51, 7]},
205
+ {'id': 19, 'name': 'path', 'isthing': 0, 'color': [204, 70, 3]},
206
+ {'id': 20, 'name': 'crosswalk', 'isthing': 0, 'color': [0, 102, 200]},
207
+ {'id': 21, 'name': 'building', 'isthing': 0, 'color': [61, 230, 250]},
208
+ {'id': 22, 'name': 'house', 'isthing': 0, 'color': [255, 6, 51]},
209
+ {'id': 23, 'name': 'bridge', 'isthing': 0, 'color': [11, 102, 255]},
210
+ {'id': 24, 'name': 'tower', 'isthing': 0, 'color': [255, 7, 71]},
211
+ {'id': 25, 'name': 'windmill', 'isthing': 0, 'color': [255, 9, 224]},
212
+ {'id': 26, 'name': 'well_or_well_lid', 'isthing': 0, 'color': [9, 7, 230]},
213
+ {'id': 27, 'name': 'other_construction', 'isthing': 0, 'color': [220, 220, 220]},
214
+ {'id': 28, 'name': 'sky', 'isthing': 0, 'color': [255, 9, 92]},
215
+ {'id': 29, 'name': 'mountain', 'isthing': 0, 'color': [112, 9, 255]},
216
+ {'id': 30, 'name': 'stone', 'isthing': 0, 'color': [8, 255, 214]},
217
+ {'id': 31, 'name': 'wood', 'isthing': 0, 'color': [7, 255, 224]},
218
+ {'id': 32, 'name': 'ice', 'isthing': 0, 'color': [255, 184, 6]},
219
+ {'id': 33, 'name': 'snowfield', 'isthing': 0, 'color': [10, 255, 71]},
220
+ {'id': 34, 'name': 'grandstand', 'isthing': 0, 'color': [255, 41, 10]},
221
+ {'id': 35, 'name': 'sea', 'isthing': 0, 'color': [7, 255, 255]},
222
+ {'id': 36, 'name': 'river', 'isthing': 0, 'color': [224, 255, 8]},
223
+ {'id': 37, 'name': 'lake', 'isthing': 0, 'color': [102, 8, 255]},
224
+ {'id': 38, 'name': 'waterfall', 'isthing': 0, 'color': [255, 61, 6]},
225
+ {'id': 39, 'name': 'water', 'isthing': 0, 'color': [255, 194, 7]},
226
+ {'id': 40, 'name': 'billboard_or_Bulletin_Board', 'isthing': 0, 'color': [255, 122, 8]},
227
+ {'id': 42, 'name': 'pipeline', 'isthing': 0, 'color': [255, 8, 41]},
228
+ {'id': 45, 'name': 'cushion_or_carpet', 'isthing': 0, 'color': [235, 12, 255]},
229
+ {'id': 53, 'name': 'wheeled_machine', 'isthing': 0, 'color': [255, 224, 0]},
230
+ {'id': 57, 'name': 'tyre', 'isthing': 0, 'color': [0, 235, 255]},
231
+ {'id': 58, 'name': 'traffic_light', 'isthing': 0, 'color': [0, 173, 255]},
232
+ {'id': 59, 'name': 'lamp', 'isthing': 0, 'color': [31, 0, 255]},
233
+ {'id': 66, 'name': 'tree', 'isthing': 0, 'color': [255, 0, 0]},
234
+ {'id': 67, 'name': 'flower', 'isthing': 0, 'color': [255, 163, 0]},
235
+ {'id': 68, 'name': 'other_plant', 'isthing': 0, 'color': [255, 102, 0]},
236
+ {'id': 69, 'name': 'toy', 'isthing': 0, 'color': [194, 255, 0]},
237
+ {'id': 70, 'name': 'ball_net', 'isthing': 0, 'color': [0, 143, 255]},
238
+ {'id': 71, 'name': 'backboard', 'isthing': 0, 'color': [51, 255, 0]},
239
+ {'id': 73, 'name': 'bat', 'isthing': 0, 'color': [0, 255, 41]},
240
+ {'id': 75, 'name': 'cupboard_or_showcase_or_storage_rack', 'isthing': 0, 'color': [10, 0, 255]},
241
+ {'id': 80, 'name': 'trash_can', 'isthing': 0, 'color': [255, 0, 245]},
242
+ {'id': 81, 'name': 'cage', 'isthing': 0, 'color': [255, 0, 102]},
243
+ {'id': 93, 'name': 'shelf', 'isthing': 0, 'color': [51, 0, 255]},
244
+ {'id': 94, 'name': 'bathtub', 'isthing': 0, 'color': [0, 194, 255]},
245
+ {'id': 98, 'name': 'other_machine', 'isthing': 0, 'color': [0, 255, 10]},
246
+ {'id': 103, 'name': 'curtain', 'isthing': 0, 'color': [255, 235, 0]},
247
+ {'id': 104, 'name': 'textiles', 'isthing': 0, 'color': [8, 184, 170]},
248
+ {'id': 105, 'name': 'clothes', 'isthing': 0, 'color': [133, 0, 255]},
249
+ {'id': 110, 'name': 'book', 'isthing': 0, 'color': [0, 214, 255]},
250
+ {'id': 111, 'name': 'tool', 'isthing': 0, 'color': [255, 0, 112]},
251
+ {'id': 112, 'name': 'blackboard', 'isthing': 0, 'color': [92, 255, 0]},
252
+ {'id': 113, 'name': 'tissue', 'isthing': 0, 'color': [0, 224, 255]},
253
+ {'id': 119, 'name': 'other_electronic_product', 'isthing': 0, 'color': [255, 0, 163]},
254
+ {'id': 120, 'name': 'fruit', 'isthing': 0, 'color': [255, 204, 0]},
255
+ {'id': 121, 'name': 'food', 'isthing': 0, 'color': [255, 0, 143]}
256
+ ]
257
+
258
+ COCO_THINGS = [itm['name'] for itm in CLASSES_THING]
259
+ COCO_STUFF = [itm['name'] for itm in CLASSES_STUFF]
260
+ COCO_CLASSES = [*COCO_THINGS, *COCO_STUFF]
261
+ PALETTE = [*[itm['color'] for itm in CLASSES_THING], *[itm['color'] for itm in CLASSES_STUFF]]
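A hedged usage sketch (not part of the commit) of the VIPSeg category metadata defined above, assuming the module is importable as ext.class_names.VIPSeg:

# Minimal sketch: index the VIPSeg category metadata defined above.
from ext.class_names.VIPSeg import CLASSES_THING, CLASSES_STUFF, COCO_CLASSES, PALETTE

# Map dataset ids to names and flag which ids are "thing" categories.
id_to_name = {c['id']: c['name'] for c in [*CLASSES_THING, *CLASSES_STUFF]}
thing_ids = {c['id'] for c in CLASSES_THING}

# COCO_CLASSES / PALETTE are ordered things-first, then stuff, so a prediction
# index i maps to class COCO_CLASSES[i] and display color PALETTE[i].
print(id_to_name[60], 60 in thing_ids)  # 'person' True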
ext/davis2017/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from __future__ import absolute_import
2
+
3
+ __version__ = '0.1.0'
ext/davis2017/davis.py ADDED
@@ -0,0 +1,122 @@
1
+ import os
2
+ from glob import glob
3
+ from collections import defaultdict
4
+ import numpy as np
5
+ from PIL import Image
6
+
7
+
8
+ class DAVIS(object):
9
+ SUBSET_OPTIONS = ['train', 'val', 'test-dev', 'test-challenge']
10
+ TASKS = ['semi-supervised', 'unsupervised']
11
+ DATASET_WEB = 'https://davischallenge.org/davis2017/code.html'
12
+ VOID_LABEL = 255
13
+
14
+ def __init__(self, root, task='unsupervised', subset='val', sequences='all', resolution='480p', codalab=False):
15
+ """
16
+ Class to read the DAVIS dataset
17
+ :param root: Path to the DAVIS folder that contains JPEGImages, Annotations, etc. folders.
18
+ :param task: Task to load the annotations, choose between semi-supervised or unsupervised.
19
+ :param subset: Subset whose annotations are loaded, e.g. 'train' or 'val'.
20
+ :param sequences: Sequences to consider, 'all' to use all the sequences in a set.
21
+ :param resolution: Resolution at which to load the dataset; choose between '480p' and 'Full-Resolution'
22
+ """
23
+ if subset not in self.SUBSET_OPTIONS:
24
+ raise ValueError(f'Subset should be in {self.SUBSET_OPTIONS}')
25
+ if task not in self.TASKS:
26
+ raise ValueError(f'The only tasks that are supported are {self.TASKS}')
27
+
28
+ self.task = task
29
+ self.subset = subset
30
+ self.root = root
31
+ self.img_path = os.path.join(self.root, 'JPEGImages', resolution)
32
+ annotations_folder = 'Annotations' if task == 'semi-supervised' else 'Annotations_unsupervised'
33
+ self.mask_path = os.path.join(self.root, annotations_folder, resolution)
34
+ year = '2019' if task == 'unsupervised' and (subset == 'test-dev' or subset == 'test-challenge') else '2017'
35
+ self.imagesets_path = os.path.join(self.root, 'ImageSets', year)
36
+
37
+ self._check_directories()
38
+
39
+ if sequences == 'all':
40
+ with open(os.path.join(self.imagesets_path, f'{self.subset}.txt'), 'r') as f:
41
+ tmp = f.readlines()
42
+ sequences_names = [x.strip() for x in tmp]
43
+ else:
44
+ sequences_names = sequences if isinstance(sequences, list) else [sequences]
45
+ self.sequences = defaultdict(dict)
46
+
47
+ for seq in sequences_names:
48
+ images = np.sort(glob(os.path.join(self.img_path, seq, '*.jpg'))).tolist()
49
+ if len(images) == 0 and not codalab:
50
+ raise FileNotFoundError(f'Images for sequence {seq} not found.')
51
+ self.sequences[seq]['images'] = images
52
+ masks = np.sort(glob(os.path.join(self.mask_path, seq, '*.png'))).tolist()
53
+ masks.extend([-1] * (len(images) - len(masks)))
54
+ self.sequences[seq]['masks'] = masks
55
+
56
+ def _check_directories(self):
57
+ if not os.path.exists(self.root):
58
+ raise FileNotFoundError(f'DAVIS not found in the specified directory, download it from {self.DATASET_WEB}')
59
+ if not os.path.exists(os.path.join(self.imagesets_path, f'{self.subset}.txt')):
60
+ raise FileNotFoundError(f'Subset sequences list for {self.subset} not found, download the missing subset '
61
+ f'for the {self.task} task from {self.DATASET_WEB}')
62
+ if self.subset in ['train', 'val'] and not os.path.exists(self.mask_path):
63
+ raise FileNotFoundError(f'Annotations folder for the {self.task} task not found, download it from {self.DATASET_WEB}')
64
+
65
+ def get_frames(self, sequence):
66
+ for img, msk in zip(self.sequences[sequence]['images'], self.sequences[sequence]['masks']):
67
+ image = np.array(Image.open(img))
68
+ mask = None if msk == -1 else np.array(Image.open(msk))  # missing masks are stored as -1 placeholders in __init__
69
+ yield image, mask
70
+
71
+ def _get_all_elements(self, sequence, obj_type):
72
+ obj = np.array(Image.open(self.sequences[sequence][obj_type][0]))
73
+ all_objs = np.zeros((len(self.sequences[sequence][obj_type]), *obj.shape))
74
+ obj_id = []
75
+ for i, obj in enumerate(self.sequences[sequence][obj_type]):
76
+ all_objs[i, ...] = np.array(Image.open(obj))
77
+ obj_id.append(''.join(obj.split('/')[-1].split('.')[:-1]))
78
+ return all_objs, obj_id
79
+
80
+ def get_all_images(self, sequence):
81
+ return self._get_all_elements(sequence, 'images')
82
+
83
+ def get_all_masks(self, sequence, separate_objects_masks=False):
84
+ masks, masks_id = self._get_all_elements(sequence, 'masks')
85
+ masks_void = np.zeros_like(masks)
86
+
87
+ # Separate void and object masks
88
+ for i in range(masks.shape[0]):
89
+ masks_void[i, ...] = masks[i, ...] == 255
90
+ masks[i, masks[i, ...] == 255] = 0
91
+
92
+ if separate_objects_masks:
93
+ num_objects = int(np.max(masks[0, ...]))
94
+ tmp = np.ones((num_objects, *masks.shape))
95
+ tmp = tmp * np.arange(1, num_objects + 1)[:, None, None, None]
96
+ masks = (tmp == masks[None, ...])
97
+ masks = masks > 0
98
+ return masks, masks_void, masks_id
99
+
100
+ def get_sequences(self):
101
+ for seq in self.sequences:
102
+ yield seq
103
+
104
+
105
+ if __name__ == '__main__':
106
+ from matplotlib import pyplot as plt
107
+
108
+ only_first_frame = True
109
+ subsets = ['train', 'val']
110
+
111
+ for s in subsets:
112
+ dataset = DAVIS(root='/home/csergi/scratch2/Databases/DAVIS2017_private', subset=s)
113
+ for seq in dataset.get_sequences():
114
+ g = dataset.get_frames(seq)
115
+ img, mask = next(g)
116
+ plt.subplot(2, 1, 1)
117
+ plt.title(seq)
118
+ plt.imshow(img)
119
+ plt.subplot(2, 1, 2)
120
+ plt.imshow(mask)
121
+ plt.show(block=True)
122
+
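For orientation, a hedged usage sketch (not part of the commit) of the DAVIS reader above; the dataset root path is a placeholder:

# Minimal sketch: iterate the validation sequences of a local DAVIS 2017 copy.
from ext.davis2017.davis import DAVIS

dataset = DAVIS(root='/path/to/DAVIS', task='semi-supervised', subset='val')  # placeholder path
for seq in dataset.get_sequences():
    for image, mask in dataset.get_frames(seq):
        # image: HxWx3 uint8 array; mask: HxW index map (None when no annotation is available)
        print(seq, image.shape, None if mask is None else mask.shape)
        break  # only peek at the first frame of each sequence
    break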
ext/davis2017/evaluation.py ADDED
@@ -0,0 +1,110 @@
1
+ import sys
2
+ from tqdm import tqdm
3
+ import warnings
4
+ warnings.filterwarnings("ignore", category=RuntimeWarning)
5
+
6
+ import numpy as np
7
+ from ext.davis2017.davis import DAVIS
8
+ from ext.davis2017.metrics import db_eval_boundary, db_eval_iou
9
+ from ext.davis2017 import utils
10
+ from ext.davis2017.results import Results
11
+ from scipy.optimize import linear_sum_assignment
12
+
13
+
14
+ class DAVISEvaluation(object):
15
+ def __init__(self, davis_root, task, gt_set, sequences='all', codalab=False):
16
+ """
17
+ Class to evaluate DAVIS sequences from a certain set and for a certain task
18
+ :param davis_root: Path to the DAVIS folder that contains JPEGImages, Annotations, etc. folders.
19
+ :param task: Task to compute the evaluation; choose between semi-supervised or unsupervised.
20
+ :param gt_set: Set to compute the evaluation
21
+ :param sequences: Sequences to consider for the evaluation, 'all' to use all the sequences in a set.
22
+ """
23
+ self.davis_root = davis_root
24
+ self.task = task
25
+ self.dataset = DAVIS(root=davis_root, task=task, subset=gt_set, sequences=sequences, codalab=codalab)
26
+
27
+ @staticmethod
28
+ def _evaluate_semisupervised(all_gt_masks, all_res_masks, all_void_masks, metric):
29
+ if all_res_masks.shape[0] > all_gt_masks.shape[0]:
30
+ sys.stdout.write("\nIn your PNG files there is an index higher than the number of objects in the sequence!")
31
+ sys.exit()
32
+ elif all_res_masks.shape[0] < all_gt_masks.shape[0]:
33
+ zero_padding = np.zeros((all_gt_masks.shape[0] - all_res_masks.shape[0], *all_res_masks.shape[1:]))
34
+ all_res_masks = np.concatenate([all_res_masks, zero_padding], axis=0)
35
+ j_metrics_res, f_metrics_res = np.zeros(all_gt_masks.shape[:2]), np.zeros(all_gt_masks.shape[:2])
36
+ for ii in range(all_gt_masks.shape[0]):
37
+ if 'J' in metric:
38
+ j_metrics_res[ii, :] = db_eval_iou(all_gt_masks[ii, ...], all_res_masks[ii, ...], all_void_masks)
39
+ if 'F' in metric:
40
+ f_metrics_res[ii, :] = db_eval_boundary(all_gt_masks[ii, ...], all_res_masks[ii, ...], all_void_masks)
41
+ return j_metrics_res, f_metrics_res
42
+
43
+ @staticmethod
44
+ def _evaluate_unsupervised(all_gt_masks, all_res_masks, all_void_masks, metric, max_n_proposals=20):
45
+ if all_res_masks.shape[0] > max_n_proposals:
46
+ sys.stdout.write(f"\nIn your PNG files there is an index higher than the maximum number ({max_n_proposals}) of proposals allowed!")
47
+ sys.exit()
48
+ elif all_res_masks.shape[0] < all_gt_masks.shape[0]:
49
+ zero_padding = np.zeros((all_gt_masks.shape[0] - all_res_masks.shape[0], *all_res_masks.shape[1:]))
50
+ all_res_masks = np.concatenate([all_res_masks, zero_padding], axis=0)
51
+ j_metrics_res = np.zeros((all_res_masks.shape[0], all_gt_masks.shape[0], all_gt_masks.shape[1]))
52
+ f_metrics_res = np.zeros((all_res_masks.shape[0], all_gt_masks.shape[0], all_gt_masks.shape[1]))
53
+ for ii in range(all_gt_masks.shape[0]):
54
+ for jj in range(all_res_masks.shape[0]):
55
+ if 'J' in metric:
56
+ j_metrics_res[jj, ii, :] = db_eval_iou(all_gt_masks[ii, ...], all_res_masks[jj, ...], all_void_masks)
57
+ if 'F' in metric:
58
+ f_metrics_res[jj, ii, :] = db_eval_boundary(all_gt_masks[ii, ...], all_res_masks[jj, ...], all_void_masks)
59
+ if 'J' in metric and 'F' in metric:
60
+ all_metrics = (np.mean(j_metrics_res, axis=2) + np.mean(f_metrics_res, axis=2)) / 2
61
+ else:
62
+ all_metrics = np.mean(j_metrics_res, axis=2) if 'J' in metric else np.mean(f_metrics_res, axis=2)
63
+ row_ind, col_ind = linear_sum_assignment(-all_metrics)
64
+ return j_metrics_res[row_ind, col_ind, :], f_metrics_res[row_ind, col_ind, :]
65
+
66
+ def evaluate(self, res_path, metric=('J', 'F'), debug=False):
67
+ metric = metric if isinstance(metric, tuple) or isinstance(metric, list) else [metric]
68
+ if 'T' in metric:
69
+ raise ValueError('Temporal metric not supported!')
70
+ if 'J' not in metric and 'F' not in metric:
71
+ raise ValueError('Metric possible values are J for IoU or F for Boundary')
72
+
73
+ # Containers
74
+ metrics_res = {}
75
+ if 'J' in metric:
76
+ metrics_res['J'] = {"M": [], "R": [], "D": [], "M_per_object": {}}
77
+ if 'F' in metric:
78
+ metrics_res['F'] = {"M": [], "R": [], "D": [], "M_per_object": {}}
79
+
80
+ # Sweep all sequences
81
+ results = Results(root_dir=res_path)
82
+ for seq in tqdm(list(self.dataset.get_sequences())):
83
+ all_gt_masks, all_void_masks, all_masks_id = self.dataset.get_all_masks(seq, True)
84
+ if self.task == 'semi-supervised':
85
+ all_gt_masks, all_masks_id = all_gt_masks[:, 1:-1, :, :], all_masks_id[1:-1]
86
+ all_res_masks = results.read_masks(seq, all_masks_id)
87
+ if self.task == 'unsupervised':
88
+ j_metrics_res, f_metrics_res = self._evaluate_unsupervised(all_gt_masks, all_res_masks, all_void_masks, metric)
89
+ elif self.task == 'semi-supervised':
90
+ j_metrics_res, f_metrics_res = self._evaluate_semisupervised(all_gt_masks, all_res_masks, None, metric)
91
+ for ii in range(all_gt_masks.shape[0]):
92
+ seq_name = f'{seq}_{ii+1}'
93
+ if 'J' in metric:
94
+ [JM, JR, JD] = utils.db_statistics(j_metrics_res[ii])
95
+ metrics_res['J']["M"].append(JM)
96
+ metrics_res['J']["R"].append(JR)
97
+ metrics_res['J']["D"].append(JD)
98
+ metrics_res['J']["M_per_object"][seq_name] = JM
99
+ if 'F' in metric:
100
+ [FM, FR, FD] = utils.db_statistics(f_metrics_res[ii])
101
+ metrics_res['F']["M"].append(FM)
102
+ metrics_res['F']["R"].append(FR)
103
+ metrics_res['F']["D"].append(FD)
104
+ metrics_res['F']["M_per_object"][seq_name] = FM
105
+
106
+ # Show progress
107
+ if debug:
108
+ sys.stdout.write(seq + '\n')
109
+ sys.stdout.flush()
110
+ return metrics_res
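A hedged sketch (not part of the commit) of how the evaluator above is typically driven; both paths are placeholders, and the J&F summary follows the structure of the returned metrics_res dictionary:

# Minimal sketch: evaluate a folder of predicted masks and report global J&F.
import numpy as np
from ext.davis2017.evaluation import DAVISEvaluation

evaluator = DAVISEvaluation(davis_root='/path/to/DAVIS', task='semi-supervised', gt_set='val')  # placeholder path
metrics = evaluator.evaluate('/path/to/results')  # placeholder path to per-sequence PNG folders

j_mean = np.mean(metrics['J']['M'])  # region similarity, averaged over objects
f_mean = np.mean(metrics['F']['M'])  # boundary accuracy, averaged over objects
print(f'J&F-Mean: {(j_mean + f_mean) / 2:.3f}  J-Mean: {j_mean:.3f}  F-Mean: {f_mean:.3f}')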
ext/davis2017/metrics.py ADDED
@@ -0,0 +1,197 @@
1
+ import math
2
+ import numpy as np
3
+ import cv2
4
+
5
+
6
+ def db_eval_iou(annotation, segmentation, void_pixels=None):
7
+ """ Compute region similarity as the Jaccard Index.
8
+ Arguments:
9
+ annotation (ndarray): binary annotation map.
10
+ segmentation (ndarray): binary segmentation map.
11
+ void_pixels (ndarray): optional mask with void pixels
12
+
13
+ Return:
14
+ jaccard (float): region similarity
15
+ """
16
+ assert annotation.shape == segmentation.shape, \
17
+ f'Annotation({annotation.shape}) and segmentation:{segmentation.shape} dimensions do not match.'
18
+ annotation = annotation.astype(bool)
19
+ segmentation = segmentation.astype(bool)
20
+
21
+ if void_pixels is not None:
22
+ assert annotation.shape == void_pixels.shape, \
23
+ f'Annotation({annotation.shape}) and void pixels:{void_pixels.shape} dimensions do not match.'
24
+ void_pixels = void_pixels.astype(bool)
25
+ else:
26
+ void_pixels = np.zeros_like(segmentation)
27
+
28
+ # Intersection between all sets
29
+ inters = np.sum((segmentation & annotation) & np.logical_not(void_pixels), axis=(-2, -1))
30
+ union = np.sum((segmentation | annotation) & np.logical_not(void_pixels), axis=(-2, -1))
31
+
32
+ j = inters / union
33
+ if j.ndim == 0:
34
+ j = 1 if np.isclose(union, 0) else j
35
+ else:
36
+ j[np.isclose(union, 0)] = 1
37
+ return j
38
+
39
+
40
+ def db_eval_boundary(annotation, segmentation, void_pixels=None, bound_th=0.008):
41
+ assert annotation.shape == segmentation.shape
42
+ if void_pixels is not None:
43
+ assert annotation.shape == void_pixels.shape
44
+ if annotation.ndim == 3:
45
+ n_frames = annotation.shape[0]
46
+ f_res = np.zeros(n_frames)
47
+ for frame_id in range(n_frames):
48
+ void_pixels_frame = None if void_pixels is None else void_pixels[frame_id, :, :, ]
49
+ f_res[frame_id] = f_measure(segmentation[frame_id, :, :, ], annotation[frame_id, :, :], void_pixels_frame, bound_th=bound_th)
50
+ elif annotation.ndim == 2:
51
+ f_res = f_measure(segmentation, annotation, void_pixels, bound_th=bound_th)
52
+ else:
53
+ raise ValueError(f'db_eval_boundary does not support tensors with {annotation.ndim} dimensions')
54
+ return f_res
55
+
56
+
57
+ def f_measure(foreground_mask, gt_mask, void_pixels=None, bound_th=0.008):
58
+ """
59
+ Compute the boundary F-measure between a predicted mask and a ground-truth mask.
60
+ Calculates precision/recall for boundaries between foreground_mask and
61
+ gt_mask using morphological operators to speed it up.
62
+
63
+ Arguments:
64
+ foreground_mask (ndarray): binary segmentation image.
65
+ gt_mask (ndarray): binary annotated image.
66
+ void_pixels (ndarray): optional mask with void pixels
67
+
68
+ Returns:
69
+ F (float): boundaries F-measure
70
+ """
71
+ assert np.atleast_3d(foreground_mask).shape[2] == 1
72
+ if void_pixels is not None:
73
+ void_pixels = void_pixels.astype(bool)
74
+ else:
75
+ void_pixels = np.zeros_like(foreground_mask).astype(bool)
76
+
77
+ bound_pix = bound_th if bound_th >= 1 else \
78
+ np.ceil(bound_th * np.linalg.norm(foreground_mask.shape))
79
+
80
+ # Get the pixel boundaries of both masks
81
+ fg_boundary = _seg2bmap(foreground_mask * np.logical_not(void_pixels))
82
+ gt_boundary = _seg2bmap(gt_mask * np.logical_not(void_pixels))
83
+
84
+ from skimage.morphology import disk
85
+
86
+ # fg_dil = binary_dilation(fg_boundary, disk(bound_pix))
87
+ fg_dil = cv2.dilate(fg_boundary.astype(np.uint8), disk(bound_pix).astype(np.uint8))
88
+ # gt_dil = binary_dilation(gt_boundary, disk(bound_pix))
89
+ gt_dil = cv2.dilate(gt_boundary.astype(np.uint8), disk(bound_pix).astype(np.uint8))
90
+
91
+ # Get the intersection
92
+ gt_match = gt_boundary * fg_dil
93
+ fg_match = fg_boundary * gt_dil
94
+
95
+ # Area of the intersection
96
+ n_fg = np.sum(fg_boundary)
97
+ n_gt = np.sum(gt_boundary)
98
+
99
+ # % Compute precision and recall
100
+ if n_fg == 0 and n_gt > 0:
101
+ precision = 1
102
+ recall = 0
103
+ elif n_fg > 0 and n_gt == 0:
104
+ precision = 0
105
+ recall = 1
106
+ elif n_fg == 0 and n_gt == 0:
107
+ precision = 1
108
+ recall = 1
109
+ else:
110
+ precision = np.sum(fg_match) / float(n_fg)
111
+ recall = np.sum(gt_match) / float(n_gt)
112
+
113
+ # Compute F measure
114
+ if precision + recall == 0:
115
+ F = 0
116
+ else:
117
+ F = 2 * precision * recall / (precision + recall)
118
+
119
+ return F
120
+
121
+
122
+ def _seg2bmap(seg, width=None, height=None):
123
+ """
124
+ From a segmentation, compute a binary boundary map with 1 pixel wide
125
+ boundaries. The boundary pixels are offset by 1/2 pixel towards the
126
+ origin from the actual segment boundary.
127
+ Arguments:
128
+ seg : Segments labeled from 1..k.
129
+ width : Width of desired bmap <= seg.shape[1]
130
+ height : Height of desired bmap <= seg.shape[0]
131
+ Returns:
132
+ bmap (ndarray): Binary boundary map.
133
+ David Martin <[email protected]>
134
+ January 2003
135
+ """
136
+
137
+ seg = seg.astype(bool)
138
+ seg[seg > 0] = 1
139
+
140
+ assert np.atleast_3d(seg).shape[2] == 1
141
+
142
+ width = seg.shape[1] if width is None else width
143
+ height = seg.shape[0] if height is None else height
144
+
145
+ h, w = seg.shape[:2]
146
+
147
+ ar1 = float(width) / float(height)
148
+ ar2 = float(w) / float(h)
149
+
150
+ assert not (
151
+ width > w or height > h or abs(ar1 - ar2) > 0.01  # use 'or': bitwise '|' binds tighter than comparisons
152
+ ), "Can't convert %dx%d seg to %dx%d bmap." % (w, h, width, height)
153
+
154
+ e = np.zeros_like(seg)
155
+ s = np.zeros_like(seg)
156
+ se = np.zeros_like(seg)
157
+
158
+ e[:, :-1] = seg[:, 1:]
159
+ s[:-1, :] = seg[1:, :]
160
+ se[:-1, :-1] = seg[1:, 1:]
161
+
162
+ b = seg ^ e | seg ^ s | seg ^ se
163
+ b[-1, :] = seg[-1, :] ^ e[-1, :]
164
+ b[:, -1] = seg[:, -1] ^ s[:, -1]
165
+ b[-1, -1] = 0
166
+
167
+ if w == width and h == height:
168
+ bmap = b
169
+ else:
170
+ bmap = np.zeros((height, width))
171
+ for x in range(w):
172
+ for y in range(h):
173
+ if b[y, x]:
174
+ j = 1 + math.floor((y - 1) + height / h)
175
+ i = 1 + math.floor((x - 1) + width / h)
176
+ bmap[j, i] = 1
177
+
178
+ return bmap
179
+
180
+
181
+ if __name__ == '__main__':
182
+ from davis2017.davis import DAVIS
183
+ from davis2017.results import Results
184
+
185
+ dataset = DAVIS(root='input_dir/ref', subset='val', sequences='aerobatics')
186
+ results = Results(root_dir='examples/osvos')
187
+ # Test timing F measure
188
+ for seq in dataset.get_sequences():
189
+ all_gt_masks, _, all_masks_id = dataset.get_all_masks(seq, True)
190
+ all_gt_masks, all_masks_id = all_gt_masks[:, 1:-1, :, :], all_masks_id[1:-1]
191
+ all_res_masks = results.read_masks(seq, all_masks_id)
192
+ f_metrics_res = np.zeros(all_gt_masks.shape[:2])
193
+ for ii in range(all_gt_masks.shape[0]):
194
+ f_metrics_res[ii, :] = db_eval_boundary(all_gt_masks[ii, ...], all_res_masks[ii, ...])
195
+
196
+ # Run using to profile code: python -m cProfile -o f_measure.prof metrics.py
197
+ # snakeviz f_measure.prof
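A hedged toy check (not part of the commit) of the two metrics above on synthetic binary masks:

# Minimal sketch: sanity-check db_eval_iou / db_eval_boundary on synthetic masks.
import numpy as np
from ext.davis2017.metrics import db_eval_iou, db_eval_boundary

gt = np.zeros((64, 64), dtype=np.uint8)
gt[16:48, 16:48] = 1          # ground-truth square
pred = np.zeros_like(gt)
pred[20:48, 16:48] = 1        # slightly shifted prediction

print('J (IoU):', db_eval_iou(gt, pred))            # region similarity in [0, 1]
print('F (boundary):', db_eval_boundary(gt, pred))  # boundary F-measure in [0, 1]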
ext/davis2017/results.py ADDED
@@ -0,0 +1,52 @@
1
+ import os
2
+ import numpy as np
3
+ from PIL import Image, ImagePalette
4
+ import sys
5
+
6
+
7
+ davis_palette = b'\x00\x00\x00\x80\x00\x00\x00\x80\x00\x80\x80\x00\x00\x00\x80\x80\x00\x80\x00\x80\x80\x80\x80\x80@\x00\x00\xc0\x00\x00@\x80\x00\xc0\x80\x00@\x00\x80\xc0\x00\x80@\x80\x80\xc0\x80\x80\x00@\x00\x80@\x00\x00\xc0\x00\x80\xc0\x00\x00@\x80\x80@\x80\x00\xc0\x80\x80\xc0\x80@@\x00\xc0@\x00@\xc0\x00\xc0\xc0\x00@@\x80\xc0@\x80@\xc0\x80\xc0\xc0\x80\x00\x00@\x80\x00@\x00\x80@\x80\x80@\x00\x00\xc0\x80\x00\xc0\x00\x80\xc0\x80\x80\xc0@\x00@\xc0\x00@@\x80@\xc0\x80@@\x00\xc0\xc0\x00\xc0@\x80\xc0\xc0\x80\xc0\x00@@\x80@@\x00\xc0@\x80\xc0@\x00@\xc0\x80@\xc0\x00\xc0\xc0\x80\xc0\xc0@@@\xc0@@@\xc0@\xc0\xc0@@@\xc0\xc0@\xc0@\xc0\xc0\xc0\xc0\xc0 \x00\x00\xa0\x00\x00 \x80\x00\xa0\x80\x00 \x00\x80\xa0\x00\x80 \x80\x80\xa0\x80\x80`\x00\x00\xe0\x00\x00`\x80\x00\xe0\x80\x00`\x00\x80\xe0\x00\x80`\x80\x80\xe0\x80\x80 @\x00\xa0@\x00 \xc0\x00\xa0\xc0\x00 @\x80\xa0@\x80 \xc0\x80\xa0\xc0\x80`@\x00\xe0@\x00`\xc0\x00\xe0\xc0\x00`@\x80\xe0@\x80`\xc0\x80\xe0\xc0\x80 \x00@\xa0\x00@ \x80@\xa0\x80@ \x00\xc0\xa0\x00\xc0 \x80\xc0\xa0\x80\xc0`\x00@\xe0\x00@`\x80@\xe0\x80@`\x00\xc0\xe0\x00\xc0`\x80\xc0\xe0\x80\xc0 @@\xa0@@ \xc0@\xa0\xc0@ @\xc0\xa0@\xc0 \xc0\xc0\xa0\xc0\xc0`@@\xe0@@`\xc0@\xe0\xc0@`@\xc0\xe0@\xc0`\xc0\xc0\xe0\xc0\xc0\x00 \x00\x80 \x00\x00\xa0\x00\x80\xa0\x00\x00 \x80\x80 \x80\x00\xa0\x80\x80\xa0\x80@ \x00\xc0 \x00@\xa0\x00\xc0\xa0\x00@ \x80\xc0 \x80@\xa0\x80\xc0\xa0\x80\x00`\x00\x80`\x00\x00\xe0\x00\x80\xe0\x00\x00`\x80\x80`\x80\x00\xe0\x80\x80\xe0\x80@`\x00\xc0`\x00@\xe0\x00\xc0\xe0\x00@`\x80\xc0`\x80@\xe0\x80\xc0\xe0\x80\x00 @\x80 @\x00\xa0@\x80\xa0@\x00 \xc0\x80 \xc0\x00\xa0\xc0\x80\xa0\xc0@ @\xc0 @@\xa0@\xc0\xa0@@ \xc0\xc0 \xc0@\xa0\xc0\xc0\xa0\xc0\x00`@\x80`@\x00\xe0@\x80\xe0@\x00`\xc0\x80`\xc0\x00\xe0\xc0\x80\xe0\xc0@`@\xc0`@@\xe0@\xc0\xe0@@`\xc0\xc0`\xc0@\xe0\xc0\xc0\xe0\xc0 \x00\xa0 \x00 \xa0\x00\xa0\xa0\x00 \x80\xa0 \x80 \xa0\x80\xa0\xa0\x80` \x00\xe0 \x00`\xa0\x00\xe0\xa0\x00` \x80\xe0 \x80`\xa0\x80\xe0\xa0\x80 `\x00\xa0`\x00 \xe0\x00\xa0\xe0\x00 `\x80\xa0`\x80 \xe0\x80\xa0\xe0\x80``\x00\xe0`\x00`\xe0\x00\xe0\xe0\x00``\x80\xe0`\x80`\xe0\x80\xe0\xe0\x80 @\xa0 @ \xa0@\xa0\xa0@ \xc0\xa0 \xc0 \xa0\xc0\xa0\xa0\xc0` @\xe0 @`\xa0@\xe0\xa0@` \xc0\xe0 \xc0`\xa0\xc0\xe0\xa0\xc0 `@\xa0`@ \xe0@\xa0\xe0@ `\xc0\xa0`\xc0 \xe0\xc0\xa0\xe0\xc0``@\xe0`@`\xe0@\xe0\xe0@``\xc0\xe0`\xc0`\xe0\xc0\xe0\xe0\xc0'
8
+ mose_palette = b'\x00\x00\x00\xe4\x1a\x1c7~\xb8M\xafJ\x98N\xa3\xff\x7f\x00\xff\xff3\xa6V(\xf7\x81\xbf\x99\x99\x99f\xc2\xa5\xfc\x8db\x8d\xa0\xcb\xe7\x8a\xc3\xa6\xd8T\xff\xd9/\xe5\xc4\x94\xb3\xb3\xb3\x8d\xd3\xc7\xff\xff\xb3\xbe\xba\xda\xfb\x80r\x80\xb1\xd3\xfd\xb4b\xb3\xdei\xfc\xcd\xe5\xd9\xd9\xd9\xbc\x80\xbd\xcc\xeb\xc5\xff\xedo'
9
+
10
+ class Results(object):
11
+ def __init__(self, root_dir):
12
+ self.root_dir = root_dir
13
+
14
+ def _read_mask(self, sequence, frame_id):
15
+ try:
16
+ mask_path = os.path.join(self.root_dir, sequence, f'{frame_id}.png')
17
+ # BUGFIX
18
+ # There is a bug in the codebase
19
+ # Here is a compensation.
20
+ img = Image.open(mask_path)
21
+ if img.mode != 'P':
22
+ img_color = np.array(img)
23
+ h, w, three = img_color.shape
24
+ assert three == 3
25
+
26
+ img_new = np.ones((h, w), dtype=np.uint8) * 255
27
+ color_map_np = np.frombuffer(davis_palette, dtype=np.uint8).reshape(-1, 3).copy()
28
+ for i in range(10):
29
+ cur_color = color_map_np[i]
30
+ mask = np.all(img_color == cur_color, axis=-1)
31
+ img_new[mask] = i
32
+ assert not np.all(img_new == 255)  # at least one pixel must have been mapped to an object index
33
+ img = img_new
34
+ # BUGFIX
35
+ return np.array(img)
36
+ except IOError as err:
37
+ sys.stdout.write(sequence + " frame %s not found!\n" % frame_id)
38
+ sys.stdout.write("The frames have to be indexed PNG files placed inside the corresponding sequence "
39
+ "folder.\nThe indexes have to match with the initial frame.\n")
40
+ sys.stderr.write("IOError: " + err.strerror + "\n")
41
+ sys.exit()
42
+
43
+ def read_masks(self, sequence, masks_id):
44
+ mask_0 = self._read_mask(sequence, masks_id[0])
45
+ masks = np.zeros((len(masks_id), *mask_0.shape))
46
+ for ii, m in enumerate(masks_id):
47
+ masks[ii, ...] = self._read_mask(sequence, m)
48
+ num_objects = int(np.max(masks))
49
+ tmp = np.ones((num_objects, *masks.shape))
50
+ tmp = tmp * np.arange(1, num_objects + 1)[:, None, None, None]
51
+ masks = (tmp == masks[None, ...]) > 0
52
+ return masks
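A hedged sketch (not part of the commit) of how the Results reader above is used; the root path and frame ids are placeholders:

# Minimal sketch: load predicted masks for one sequence and inspect the layout.
from ext.davis2017.results import Results

results = Results(root_dir='/path/to/results')                    # placeholder path
masks = results.read_masks('some-sequence', ['00000', '00001'])   # placeholder sequence and frame ids
# masks is a boolean array of shape (num_objects, num_frames, H, W),
# one channel per object id found in the PNGs.
print(masks.shape, masks.dtype)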
ext/davis2017/utils.py ADDED
@@ -0,0 +1,174 @@
1
+ import os
2
+ import errno
3
+ import numpy as np
4
+ from PIL import Image
5
+ import warnings
6
+ from ext.davis2017.davis import DAVIS
7
+
8
+
9
+ def _pascal_color_map(N=256, normalized=False):
10
+ """
11
+ Python implementation of the color map function for the PASCAL VOC data set.
12
+ Official Matlab version can be found in the PASCAL VOC devkit
13
+ http://host.robots.ox.ac.uk/pascal/VOC/voc2012/index.html#devkit
14
+ """
15
+
16
+ def bitget(byteval, idx):
17
+ return (byteval & (1 << idx)) != 0
18
+
19
+ dtype = 'float32' if normalized else 'uint8'
20
+ cmap = np.zeros((N, 3), dtype=dtype)
21
+ for i in range(N):
22
+ r = g = b = 0
23
+ c = i
24
+ for j in range(8):
25
+ r = r | (bitget(c, 0) << 7 - j)
26
+ g = g | (bitget(c, 1) << 7 - j)
27
+ b = b | (bitget(c, 2) << 7 - j)
28
+ c = c >> 3
29
+
30
+ cmap[i] = np.array([r, g, b])
31
+
32
+ cmap = cmap / 255 if normalized else cmap
33
+ return cmap
34
+
35
+
36
+ def overlay_semantic_mask(im, ann, alpha=0.5, colors=None, contour_thickness=None):
37
+ im, ann = np.asarray(im, dtype=np.uint8), np.asarray(ann, dtype=int)  # np.int alias was removed in NumPy 1.24
38
+ if im.shape[:-1] != ann.shape:
39
+ raise ValueError('First two dimensions of `im` and `ann` must match')
40
+ if im.shape[-1] != 3:
41
+ raise ValueError('im must have three channels in its last dimension')
42
+
43
+ colors = colors or _pascal_color_map()
44
+ colors = np.asarray(colors, dtype=np.uint8)
45
+
46
+ mask = colors[ann]
47
+ fg = im * alpha + (1 - alpha) * mask
48
+
49
+ img = im.copy()
50
+ img[ann > 0] = fg[ann > 0]
51
+
52
+ if contour_thickness: # pragma: no cover
53
+ import cv2
54
+ for obj_id in np.unique(ann[ann > 0]):
55
+ contours = cv2.findContours((ann == obj_id).astype(
56
+ np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2:]
57
+ cv2.drawContours(img, contours[0], -1, colors[obj_id].tolist(),
58
+ contour_thickness)
59
+ return img
60
+
61
+
62
+ def generate_obj_proposals(davis_root, subset, num_proposals, save_path):
63
+ dataset = DAVIS(davis_root, subset=subset, codalab=True)
64
+ for seq in dataset.get_sequences():
65
+ save_dir = os.path.join(save_path, seq)
66
+ if os.path.exists(save_dir):
67
+ continue
68
+ all_gt_masks, all_masks_id = dataset.get_all_masks(seq, True)
69
+ img_size = all_gt_masks.shape[2:]
70
+ num_rows = int(np.ceil(np.sqrt(num_proposals)))
71
+ proposals = np.zeros((num_proposals, len(all_masks_id), *img_size))
72
+ height_slices = np.floor(np.arange(0, img_size[0] + 1, img_size[0]/num_rows)).astype(np.uint).tolist()
73
+ width_slices = np.floor(np.arange(0, img_size[1] + 1, img_size[1]/num_rows)).astype(np.uint).tolist()
74
+ ii = 0
75
+ prev_h, prev_w = 0, 0
76
+ for h in height_slices[1:]:
77
+ for w in width_slices[1:]:
78
+ proposals[ii, :, prev_h:h, prev_w:w] = 1
79
+ prev_w = w
80
+ ii += 1
81
+ if ii == num_proposals:
82
+ break
83
+ prev_h, prev_w = h, 0
84
+ if ii == num_proposals:
85
+ break
86
+
87
+ os.makedirs(save_dir, exist_ok=True)
88
+ for i, mask_id in enumerate(all_masks_id):
89
+ mask = np.sum(proposals[:, i, ...] * np.arange(1, proposals.shape[0] + 1)[:, None, None], axis=0)
90
+ save_mask(mask, os.path.join(save_dir, f'{mask_id}.png'))
91
+
92
+
93
+ def generate_random_permutation_gt_obj_proposals(davis_root, subset, save_path):
94
+ dataset = DAVIS(davis_root, subset=subset, codalab=True)
95
+ for seq in dataset.get_sequences():
96
+ gt_masks, all_masks_id = dataset.get_all_masks(seq, True)
97
+ obj_swap = np.random.permutation(np.arange(gt_masks.shape[0]))
98
+ gt_masks = gt_masks[obj_swap, ...]
99
+ save_dir = os.path.join(save_path, seq)
100
+ os.makedirs(save_dir, exist_ok=True)
101
+ for i, mask_id in enumerate(all_masks_id):
102
+ mask = np.sum(gt_masks[:, i, ...] * np.arange(1, gt_masks.shape[0] + 1)[:, None, None], axis=0)
103
+ save_mask(mask, os.path.join(save_dir, f'{mask_id}.png'))
104
+
105
+
106
+ def color_map(N=256, normalized=False):
107
+ def bitget(byteval, idx):
108
+ return ((byteval & (1 << idx)) != 0)
109
+
110
+ dtype = 'float32' if normalized else 'uint8'
111
+ cmap = np.zeros((N, 3), dtype=dtype)
112
+ for i in range(N):
113
+ r = g = b = 0
114
+ c = i
115
+ for j in range(8):
116
+ r = r | (bitget(c, 0) << 7-j)
117
+ g = g | (bitget(c, 1) << 7-j)
118
+ b = b | (bitget(c, 2) << 7-j)
119
+ c = c >> 3
120
+
121
+ cmap[i] = np.array([r, g, b])
122
+
123
+ cmap = cmap/255 if normalized else cmap
124
+ return cmap
125
+
126
+
127
+ def save_mask(mask, img_path):
128
+ if np.max(mask) > 255:
129
+ raise ValueError('Maximum id pixel value is 255')
130
+ mask_img = Image.fromarray(mask.astype(np.uint8))
131
+ mask_img.putpalette(color_map().flatten().tolist())
132
+ mask_img.save(img_path)
133
+
134
+
135
+ def db_statistics(per_frame_values):
136
+ """ Compute mean,recall and decay from per-frame evaluation.
137
+ Arguments:
138
+ per_frame_values (ndarray): per-frame evaluation
139
+
140
+ Returns:
141
+ M,O,D (float,float,float):
142
+ return evaluation statistics: mean,recall,decay.
143
+ """
144
+
145
+ # strip off nan values
146
+ with warnings.catch_warnings():
147
+ warnings.simplefilter("ignore", category=RuntimeWarning)
148
+ M = np.nanmean(per_frame_values)
149
+ O = np.nanmean(per_frame_values > 0.5)
150
+
151
+ N_bins = 4
152
+ ids = np.round(np.linspace(1, len(per_frame_values), N_bins + 1) + 1e-10) - 1
153
+ ids = ids.astype(np.uint8)
154
+
155
+ D_bins = [per_frame_values[ids[i]:ids[i + 1] + 1] for i in range(0, 4)]
156
+
157
+ with warnings.catch_warnings():
158
+ warnings.simplefilter("ignore", category=RuntimeWarning)
159
+ D = np.nanmean(D_bins[0]) - np.nanmean(D_bins[3])
160
+
161
+ return M, O, D
162
+
163
+
164
+ def list_files(dir, extension=".png"):
165
+ return [os.path.splitext(file_)[0] for file_ in os.listdir(dir) if file_.endswith(extension)]
166
+
167
+
168
+ def force_symlink(file1, file2):
169
+ try:
170
+ os.symlink(file1, file2)
171
+ except OSError as e:
172
+ if e.errno == errno.EEXIST:
173
+ os.remove(file2)
174
+ os.symlink(file1, file2)
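A hedged toy example (not part of the commit) of db_statistics above, which condenses a per-frame score curve into (mean, recall, decay):

# Minimal sketch: db_statistics on a toy per-frame region-similarity curve.
import numpy as np
from ext.davis2017.utils import db_statistics

per_frame_j = np.array([0.9, 0.85, 0.8, 0.6, 0.4])  # toy per-frame scores
M, O, D = db_statistics(per_frame_j)
print(f'mean={M:.3f} recall={O:.3f} decay={D:.3f}')
# mean: average score; recall: fraction of frames above 0.5;
# decay: first-quartile mean minus last-quartile mean.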
ext/meta/sam_meta.py ADDED
@@ -0,0 +1,41 @@
1
+ meta_dict = {
2
+ 'vit_h': dict(
3
+ encoder_embed_dim=1280,
4
+ encoder_depth=32,
5
+ encoder_num_heads=16,
6
+ encoder_global_attn_indexes=[7, 15, 23, 31],
7
+ # common
8
+ prompt_embed_dim=256,
9
+ image_size=1024,
10
+ vit_patch_size=16,
11
+ image_embedding_size=64
12
+ ),
13
+ 'vit_l': dict(
14
+ encoder_embed_dim=1024,
15
+ encoder_depth=24,
16
+ encoder_num_heads=16,
17
+ encoder_global_attn_indexes=[5, 11, 17, 23],
18
+ # common
19
+ prompt_embed_dim=256,
20
+ image_size=1024,
21
+ vit_patch_size=16,
22
+ image_embedding_size=64
23
+ ),
24
+ 'vit_b': dict(
25
+ encoder_embed_dim=768,
26
+ encoder_depth=12,
27
+ encoder_num_heads=12,
28
+ encoder_global_attn_indexes=[2, 5, 8, 11],
29
+ # common
30
+ prompt_embed_dim=256,
31
+ image_size=1024,
32
+ vit_patch_size=16,
33
+ image_embedding_size=64
34
+ )
35
+ }
36
+
37
+ checkpoint_dict = {
38
+ 'vit_h': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth',
39
+ 'vit_l': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth',
40
+ 'vit_b': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth',
41
+ }
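A hedged sketch (not part of the commit) of how the SAM metadata above is typically consumed:

# Minimal sketch: pick a SAM backbone variant and its official checkpoint URL.
from ext.meta.sam_meta import meta_dict, checkpoint_dict

variant = 'vit_b'
cfg = meta_dict[variant]
print(cfg['encoder_embed_dim'], cfg['encoder_depth'])  # 768 12
print(checkpoint_dict[variant])  # download URL of the ViT-B SAM weights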
ext/open_clip/__init__.py ADDED
@@ -0,0 +1,15 @@
1
+ from .coca_model import CoCa
2
+ from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
3
+ from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer, create_loss
4
+ from .factory import list_models, add_model_config, get_model_config, load_checkpoint
5
+ from .loss import ClipLoss, DistillClipLoss, CoCaLoss
6
+ from .model import CLIP, CustomTextCLIP, CLIPTextCfg, CLIPVisionCfg, \
7
+ convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype, get_input_dtype
8
+ from .openai import load_openai_model, list_openai_models
9
+ from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model, \
10
+ get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained
11
+ from .push_to_hf_hub import push_pretrained_to_hf_hub, push_to_hf_hub
12
+ from .tokenizer import SimpleTokenizer, tokenize, decode
13
+ from .transform import image_transform, AugmentationCfg
14
+ from .zero_shot_classifier import build_zero_shot_classifier, build_zero_shot_classifier_legacy
15
+ from .zero_shot_metadata import OPENAI_IMAGENET_TEMPLATES, SIMPLE_IMAGENET_TEMPLATES, IMAGENET_CLASSNAMES
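A hedged usage sketch (not part of the commit) of the factory helpers re-exported above, following the standard open_clip API; the model name and pretrained tag are examples, not a requirement of this package:

# Minimal sketch: build a CLIP model and tokenizer through the re-exported helpers.
import torch
from ext.open_clip import create_model_and_transforms, get_tokenizer

# 'ViT-B-16' / 'laion2b_s34b_b88k' are examples; any (model, pretrained) pair known to open_clip works.
model, _, preprocess = create_model_and_transforms('ViT-B-16', pretrained='laion2b_s34b_b88k')
tokenizer = get_tokenizer('ViT-B-16')

texts = tokenizer(['a photo of a cat', 'a photo of a dog'])
with torch.no_grad():
    text_features = model.encode_text(texts)
print(text_features.shape)  # (2, embed_dim)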
ext/open_clip/bpe_simple_vocab_16e6.txt.gz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
3
+ size 1356917
ext/open_clip/coca_model.py ADDED
@@ -0,0 +1,458 @@
1
+ from typing import Optional
2
+
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import functional as F
6
+ import numpy as np
7
+ from dataclasses import dataclass
8
+
9
+ from .transformer import (
10
+ LayerNormFp32,
11
+ LayerNorm,
12
+ QuickGELU,
13
+ MultimodalTransformer,
14
+ )
15
+ from .model import CLIPTextCfg, CLIPVisionCfg, _build_vision_tower, _build_text_tower
16
+
17
+ try:
18
+ from transformers import (
19
+ BeamSearchScorer,
20
+ LogitsProcessorList,
21
+ TopPLogitsWarper,
22
+ TopKLogitsWarper,
23
+ RepetitionPenaltyLogitsProcessor,
24
+ MinLengthLogitsProcessor,
25
+ MaxLengthCriteria,
26
+ StoppingCriteriaList
27
+ )
28
+
29
+ GENERATION_TYPES = {
30
+ "top_k": TopKLogitsWarper,
31
+ "top_p": TopPLogitsWarper,
32
+ "beam_search": "beam_search"
33
+ }
34
+ _has_transformers = True
35
+ except ImportError as e:
36
+ GENERATION_TYPES = {
37
+ "top_k": None,
38
+ "top_p": None,
39
+ "beam_search": "beam_search"
40
+ }
41
+ _has_transformers = False
42
+
43
+
44
+ @dataclass
45
+ class MultimodalCfg(CLIPTextCfg):
46
+ mlp_ratio: int = 4
47
+ dim_head: int = 64
48
+ heads: int = 8
49
+ n_queries: int = 256
50
+ attn_pooler_heads: int = 8
51
+
52
+
53
+ def _build_text_decoder_tower(
54
+ embed_dim,
55
+ multimodal_cfg,
56
+ quick_gelu: bool = False,
57
+ cast_dtype: Optional[torch.dtype] = None,
58
+ ):
59
+ multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg
60
+ act_layer = QuickGELU if quick_gelu else nn.GELU
61
+ norm_layer = (
62
+ LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
63
+ )
64
+
65
+ decoder = MultimodalTransformer(
66
+ context_length=multimodal_cfg.context_length,
67
+ width=multimodal_cfg.width,
68
+ heads=multimodal_cfg.heads,
69
+ layers=multimodal_cfg.layers,
70
+ ls_init_value=multimodal_cfg.ls_init_value,
71
+ output_dim=embed_dim,
72
+ act_layer=act_layer,
73
+ norm_layer=norm_layer,
74
+ )
75
+
76
+ return decoder
77
+
78
+
79
+ class CoCa(nn.Module):
80
+ def __init__(
81
+ self,
82
+ embed_dim,
83
+ multimodal_cfg: MultimodalCfg,
84
+ text_cfg: CLIPTextCfg,
85
+ vision_cfg: CLIPVisionCfg,
86
+ quick_gelu: bool = False,
87
+ cast_dtype: Optional[torch.dtype] = None,
88
+ pad_id: int = 0,
89
+ ):
90
+ super().__init__()
91
+ multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg
92
+ text_cfg = CLIPTextCfg(**text_cfg) if isinstance(text_cfg, dict) else text_cfg
93
+ vision_cfg = CLIPVisionCfg(**vision_cfg) if isinstance(vision_cfg, dict) else vision_cfg
94
+
95
+ self.text = _build_text_tower(
96
+ embed_dim=embed_dim,
97
+ text_cfg=text_cfg,
98
+ quick_gelu=quick_gelu,
99
+ cast_dtype=cast_dtype,
100
+ )
101
+
102
+ vocab_size = (
103
+ text_cfg.vocab_size # for hf models
104
+ if hasattr(text_cfg, "hf_model_name") and text_cfg.hf_model_name is not None
105
+ else text_cfg.vocab_size
106
+ )
107
+
108
+ self.visual = _build_vision_tower(
109
+ embed_dim=embed_dim,
110
+ vision_cfg=vision_cfg,
111
+ quick_gelu=quick_gelu,
112
+ cast_dtype=cast_dtype,
113
+ )
114
+
115
+ self.text_decoder = _build_text_decoder_tower(
116
+ vocab_size,
117
+ multimodal_cfg=multimodal_cfg,
118
+ quick_gelu=quick_gelu,
119
+ cast_dtype=cast_dtype,
120
+ )
121
+
122
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
123
+ self.pad_id = pad_id
124
+
125
+ @torch.jit.ignore
126
+ def set_grad_checkpointing(self, enable=True):
127
+ self.visual.set_grad_checkpointing(enable)
128
+ self.text.set_grad_checkpointing(enable)
129
+ self.text_decoder.set_grad_checkpointing(enable)
130
+
131
+ def _encode_image(self, images, normalize=True):
132
+ image_latent, tokens_embs = self.visual(images)
133
+ image_latent = F.normalize(image_latent, dim=-1) if normalize else image_latent
134
+ return image_latent, tokens_embs
135
+
136
+ def _encode_text(self, text, normalize=True, embed_cls=True):
137
+ text = text[:, :-1] if embed_cls else text # make space for CLS token
138
+ text_latent, token_emb = self.text(text)
139
+ text_latent = F.normalize(text_latent, dim=-1) if normalize else text_latent
140
+ return text_latent, token_emb
141
+
142
+ def encode_image(self, images, normalize=True):
143
+ image_latent, _ = self._encode_image(images, normalize=normalize)
144
+ return image_latent
145
+
146
+ def encode_text(self, text, normalize=True, embed_cls=True):
147
+ text_latent, _ = self._encode_text(text, normalize=normalize, embed_cls=embed_cls)
148
+ return text_latent
149
+
150
+ def forward(self, image, text, embed_cls=True, image_latent=None, image_embs=None):
151
+ text_latent, token_embs = self._encode_text(text, embed_cls=embed_cls)
152
+ if image_latent is None or image_embs is None:
153
+ image_latent, image_embs = self._encode_image(image)
154
+
155
+ # TODO: add assertion to avoid bugs?
156
+ labels = text[:, -token_embs.shape[1]:]
157
+
158
+ logits = self.text_decoder(image_embs, token_embs)
159
+ return {
160
+ "image_features": image_latent,
161
+ "text_features": text_latent,
162
+ "logits": logits,
163
+ "labels": labels,
164
+ "logit_scale": self.logit_scale.exp()
165
+ }
166
+
167
+ def generate(
168
+ self,
169
+ image,
170
+ text=None,
171
+ seq_len=30,
172
+ max_seq_len=77,
173
+ temperature=1.,
174
+ generation_type="beam_search",
175
+ top_p=0.1, # keep tokens in the 1 - top_p quantile
176
+ top_k=1, # keeps the top_k most probable tokens
177
+ pad_token_id=None,
178
+ eos_token_id=None,
179
+ sot_token_id=None,
180
+ num_beams=6,
181
+ num_beam_groups=3,
182
+ min_seq_len=5,
183
+ stopping_criteria=None,
184
+ repetition_penalty=1.0,
185
+ fixed_output_length=False # if True output.shape == (batch_size, seq_len)
186
+ ):
187
+ # taking many ideas and components from HuggingFace GenerationMixin
188
+ # https://huggingface.co/docs/transformers/main/en/main_classes/text_generation
189
+ assert _has_transformers, "Please install transformers for generate functionality. `pip install transformers`."
190
+ assert seq_len > min_seq_len, "seq_len must be larger than min_seq_len"
191
+
192
+ with torch.no_grad():
193
+ sot_token_id = 49406 if sot_token_id is None else sot_token_id
194
+ eos_token_id = 49407 if eos_token_id is None else eos_token_id
195
+ pad_token_id = self.pad_id if pad_token_id is None else pad_token_id
196
+ logit_processor = LogitsProcessorList(
197
+ [
198
+ MinLengthLogitsProcessor(min_seq_len, eos_token_id),
199
+ RepetitionPenaltyLogitsProcessor(repetition_penalty),
200
+ ]
201
+ )
202
+
203
+ if stopping_criteria is None:
204
+ stopping_criteria = [MaxLengthCriteria(max_length=seq_len)]
205
+
206
+ stopping_criteria = StoppingCriteriaList(
207
+ stopping_criteria
208
+ )
209
+
210
+ device = image.device
211
+
212
+ if generation_type == "beam_search":
213
+ output = self._generate_beamsearch(
214
+ image_inputs = image,
215
+ pad_token_id=pad_token_id,
216
+ eos_token_id=eos_token_id,
217
+ sot_token_id=sot_token_id,
218
+ num_beams=num_beams,
219
+ num_beam_groups=num_beam_groups,
220
+ min_seq_len=min_seq_len,
221
+ stopping_criteria=stopping_criteria,
222
+ logit_processor=logit_processor,
223
+ )
224
+ if fixed_output_length and output.shape[1] < seq_len:
225
+ return torch.cat(
226
+ (output, torch.ones(output.shape[0], seq_len-output.shape[1], device=device, dtype=output.dtype) * self.pad_id),
227
+ dim=1
228
+ )
229
+ return output
230
+
231
+ elif generation_type == "top_p":
232
+ logit_warper = GENERATION_TYPES[generation_type](top_p)
233
+ elif generation_type == "top_k":
234
+ logit_warper = GENERATION_TYPES[generation_type](top_k)
235
+ else:
236
+ raise ValueError(
237
+ f"generation_type has to be one of "
238
+ f"{'| ' + ' | '.join(list(GENERATION_TYPES.keys())) + ' |'}."
239
+ )
240
+
241
+ image_latent, image_embs = self._encode_image(image)
242
+
243
+ if text is None:
244
+ text = torch.ones((image.shape[0], 1), device=device, dtype=torch.long) * sot_token_id
245
+
246
+ was_training = self.training
247
+ num_dims = len(text.shape)
248
+
249
+ if num_dims == 1:
250
+ text = text[None, :]
251
+
252
+ cur_len = text.shape[1]
253
+ self.eval()
254
+ out = text
255
+
256
+ while True:
257
+ x = out[:, -max_seq_len:]
258
+ cur_len = x.shape[1]
259
+ logits = self(image, x, image_latent=image_latent, image_embs=image_embs, embed_cls=False)["logits"][:, -1]
260
+ mask = (out[:, -1] == eos_token_id) | (out[:, -1] == pad_token_id)
261
+ sample = torch.ones((out.shape[0], 1), device=device, dtype=torch.long) * pad_token_id
262
+
263
+ if mask.all():
264
+ if not fixed_output_length:
265
+ break
266
+ else:
267
+ logits = logits[~mask, :]
268
+ filtered_logits = logit_processor(x[~mask, :], logits)
269
+ filtered_logits = logit_warper(x[~mask, :], filtered_logits)
270
+ probs = F.softmax(filtered_logits / temperature, dim=-1)
271
+
272
+ if (cur_len + 1 == seq_len):
273
+ sample[~mask, :] = torch.ones((sum(~mask), 1), device=device, dtype=torch.long) * eos_token_id
274
+ else:
275
+ sample[~mask, :] = torch.multinomial(probs, 1)
276
+
277
+ out = torch.cat((out, sample), dim=-1)
278
+
279
+ cur_len += 1
280
+
281
+ if stopping_criteria(out, None):
282
+ break
283
+
284
+ if num_dims == 1:
285
+ out = out.squeeze(0)
286
+
287
+ self.train(was_training)
288
+ return out
289
+
290
+ def _generate_beamsearch(
291
+ self,
292
+ image_inputs,
293
+ pad_token_id=None,
294
+ eos_token_id=None,
295
+ sot_token_id=None,
296
+ num_beams=6,
297
+ num_beam_groups=3,
298
+ min_seq_len=5,
299
+ stopping_criteria=None,
300
+ logit_processor=None,
301
+ logit_warper=None,
302
+ ):
303
+ device = image_inputs.device
304
+ batch_size = image_inputs.shape[0]
305
+ image_inputs = torch.repeat_interleave(image_inputs, num_beams, dim=0)
306
+ image_latent, image_embs = self._encode_image(image_inputs)
307
+
308
+ input_ids = torch.ones((batch_size * num_beams, 1), device=device, dtype=torch.long)
309
+ input_ids = input_ids * sot_token_id
310
+ beam_scorer = BeamSearchScorer(
311
+ batch_size=batch_size,
312
+ num_beams=num_beams,
313
+ device=device,
314
+ num_beam_groups=num_beam_groups,
315
+ )
316
+ # instantiate logits processors
317
+ logits_processor = (
318
+ LogitsProcessorList([MinLengthLogitsProcessor(min_seq_len, eos_token_id=eos_token_id)])
319
+ if logit_processor is None
320
+ else logit_processor
321
+ )
322
+
323
+ batch_size = len(beam_scorer._beam_hyps)
324
+ num_beams = beam_scorer.num_beams
325
+ num_beam_groups = beam_scorer.num_beam_groups
326
+ num_sub_beams = num_beams // num_beam_groups
327
+ batch_beam_size, cur_len = input_ids.shape
328
+ beam_indices = None
329
+
330
+ if num_beams * batch_size != batch_beam_size:
331
+ raise ValueError(
332
+ f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
333
+ )
334
+
335
+ beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device)
336
+ # initialise the score of the first beam of each group with 0 and the rest with -1e9. This ensures that the beams in
337
+ # the same group don't produce the same tokens every time.
338
+ beam_scores[:, ::num_sub_beams] = 0
339
+ beam_scores = beam_scores.view((batch_size * num_beams,))
340
+
341
+ while True:
342
+
343
+ # predicted tokens in cur_len step
344
+ current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device)
345
+
346
+ # indices which will form the beams in the next time step
347
+ reordering_indices = torch.zeros(batch_size * num_beams, dtype=torch.long, device=device)
348
+
349
+ # do one decoder step on all beams of all sentences in batch
350
+ model_inputs = prepare_inputs_for_generation(input_ids=input_ids, image_inputs=image_inputs)
351
+ outputs = self(
352
+ model_inputs['images'],
353
+ model_inputs['text'],
354
+ embed_cls=False,
355
+ image_latent=image_latent,
356
+ image_embs=image_embs
357
+ )
358
+
359
+ for beam_group_idx in range(num_beam_groups):
360
+ group_start_idx = beam_group_idx * num_sub_beams
361
+ group_end_idx = min(group_start_idx + num_sub_beams, num_beams)
362
+ group_size = group_end_idx - group_start_idx
363
+
364
+ # indices of beams of current group among all sentences in batch
365
+ batch_group_indices = []
366
+
367
+ for batch_idx in range(batch_size):
368
+ batch_group_indices.extend(
369
+ [batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)]
370
+ )
371
+ group_input_ids = input_ids[batch_group_indices]
372
+
373
+ # select outputs of beams of the current group only
374
+ next_token_logits = outputs['logits'][batch_group_indices, -1, :]
375
+ vocab_size = next_token_logits.shape[-1]
376
+
377
+ next_token_scores_processed = logits_processor(
378
+ group_input_ids, next_token_logits, current_tokens=current_tokens, beam_group_idx=beam_group_idx
379
+ )
380
+ next_token_scores = next_token_scores_processed + beam_scores[batch_group_indices].unsqueeze(-1)
381
+ next_token_scores = next_token_scores.expand_as(next_token_scores_processed)
382
+
383
+ # reshape for beam search
384
+ next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size)
385
+
386
+ next_token_scores, next_tokens = torch.topk(
387
+ next_token_scores, 2 * group_size, dim=1, largest=True, sorted=True
388
+ )
389
+
390
+ next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
391
+ next_tokens = next_tokens % vocab_size
392
+
393
+ # stateless
394
+ process_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
395
+ beam_outputs = beam_scorer.process(
396
+ group_input_ids,
397
+ next_token_scores,
398
+ next_tokens,
399
+ next_indices,
400
+ pad_token_id=pad_token_id,
401
+ eos_token_id=eos_token_id,
402
+ beam_indices=process_beam_indices,
403
+ )
404
+ beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"]
405
+ beam_next_tokens = beam_outputs["next_beam_tokens"]
406
+ beam_idx = beam_outputs["next_beam_indices"]
407
+
408
+ input_ids[batch_group_indices] = group_input_ids[beam_idx]
409
+ group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
410
+ current_tokens[batch_group_indices] = group_input_ids[:, -1]
411
+
412
+ # (beam_idx // group_size) -> batch_idx
413
+ # (beam_idx % group_size) -> offset of idx inside the group
414
+ reordering_indices[batch_group_indices] = (
415
+ num_beams * torch.div(beam_idx, group_size, rounding_mode="floor") + group_start_idx + (beam_idx % group_size)
416
+ )
417
+
418
+ input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1)
419
+
420
+ # increase cur_len
421
+ cur_len = cur_len + 1
422
+ if beam_scorer.is_done or stopping_criteria(input_ids, None):
423
+ break
424
+
425
+ final_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
426
+ sequence_outputs = beam_scorer.finalize(
427
+ input_ids,
428
+ beam_scores,
429
+ next_tokens,
430
+ next_indices,
431
+ pad_token_id=pad_token_id,
432
+ eos_token_id=eos_token_id,
433
+ max_length=stopping_criteria.max_length,
434
+ beam_indices=final_beam_indices,
435
+ )
436
+ return sequence_outputs['sequences']
437
+
438
+
439
+ def prepare_inputs_for_generation(input_ids, image_inputs, past=None, **kwargs):
440
+ if past:
441
+ input_ids = input_ids[:, -1].unsqueeze(-1)
442
+
443
+ attention_mask = kwargs.get("attention_mask", None)
444
+ position_ids = kwargs.get("position_ids", None)
445
+
446
+ if attention_mask is not None and position_ids is None:
447
+ # create position_ids on the fly for batch generation
448
+ position_ids = attention_mask.long().cumsum(-1) - 1
449
+ position_ids.masked_fill_(attention_mask == 0, 1)
450
+ else:
451
+ position_ids = None
452
+ return {
453
+ "text": input_ids,
454
+ "images": image_inputs,
455
+ "past_key_values": past,
456
+ "position_ids": position_ids,
457
+ "attention_mask": attention_mask,
458
+ }
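
Usage note (not part of the diff): a minimal, hedged sketch of how the generate()/beam-search path above is typically driven. The model name, pretrained tag, and the ext.open_clip import path are assumptions, not something this commit defines.

import torch
from PIL import Image

# Import path assumed from the repo layout (ext/open_clip/); create_model_and_transforms
# is defined later in this commit (factory.py).
from ext.open_clip import create_model_and_transforms

# "coca_ViT-B-32" is an illustrative config name; pretrained=None keeps the sketch self-contained.
model, _, preprocess = create_model_and_transforms("coca_ViT-B-32", pretrained=None)
model.eval()

image = preprocess(Image.open("assets/000000000139.jpg")).unsqueeze(0)
with torch.no_grad():
    # beam-search branch of generate(); generation_type="top_p" / "top_k" takes the sampling path instead
    token_ids = model.generate(image, generation_type="beam_search",
                               num_beams=6, num_beam_groups=3)
# token_ids can then be decoded back to caption text with the package tokenizer.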
ext/open_clip/constants.py ADDED
@@ -0,0 +1,2 @@
1
+ OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
2
+ OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
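
Usage note (not part of the diff): these are the OpenAI CLIP normalization statistics; a minimal sketch of wiring them into a torchvision preprocessing pipeline (import path assumed from the repo layout).

from torchvision import transforms
from ext.open_clip.constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD  # path assumed

preprocess = transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    # normalize with the same statistics used for OpenAI CLIP pretraining
    transforms.Normalize(mean=OPENAI_DATASET_MEAN, std=OPENAI_DATASET_STD),
])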
ext/open_clip/factory.py ADDED
@@ -0,0 +1,387 @@
1
+ import json
2
+ import logging
3
+ import os
4
+ import pathlib
5
+ import re
6
+ from copy import deepcopy
7
+ from pathlib import Path
8
+ from typing import Any, Dict, Optional, Tuple, Union
9
+
10
+ import torch
11
+
12
+ from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
13
+ from .model import CLIP, CustomTextCLIP, convert_weights_to_lp, convert_to_custom_text_state_dict,\
14
+ resize_pos_embed, get_cast_dtype
15
+ from .coca_model import CoCa
16
+ from .loss import ClipLoss, DistillClipLoss, CoCaLoss
17
+ from .openai import load_openai_model
18
+ from .pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained,\
19
+ list_pretrained_tags_by_model, download_pretrained_from_hf
20
+ from .transform import image_transform, AugmentationCfg
21
+ from .tokenizer import HFTokenizer, tokenize
22
+
23
+
24
+ HF_HUB_PREFIX = 'hf-hub:'
25
+ _MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"]
26
+ _MODEL_CONFIGS = {}  # dictionary of model architecture configs (model_name: config)
27
+
28
+
29
+ def _natural_key(string_):
30
+ return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]
31
+
32
+
33
+ def _rescan_model_configs():
34
+ global _MODEL_CONFIGS
35
+
36
+ config_ext = ('.json',)
37
+ config_files = []
38
+ for config_path in _MODEL_CONFIG_PATHS:
39
+ if config_path.is_file() and config_path.suffix in config_ext:
40
+ config_files.append(config_path)
41
+ elif config_path.is_dir():
42
+ for ext in config_ext:
43
+ config_files.extend(config_path.glob(f'*{ext}'))
44
+
45
+ for cf in config_files:
46
+ with open(cf, 'r') as f:
47
+ model_cfg = json.load(f)
48
+ if all(a in model_cfg for a in ('embed_dim', 'vision_cfg', 'text_cfg')):
49
+ _MODEL_CONFIGS[cf.stem] = model_cfg
50
+
51
+ _MODEL_CONFIGS = {k: v for k, v in sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))}
52
+
53
+
54
+ _rescan_model_configs() # initial populate of model config registry
55
+
56
+
57
+ def list_models():
58
+ """ enumerate available model architectures based on config files """
59
+ return list(_MODEL_CONFIGS.keys())
60
+
61
+
62
+ def add_model_config(path):
63
+ """ add model config path or file and update registry """
64
+ if not isinstance(path, Path):
65
+ path = Path(path)
66
+ _MODEL_CONFIG_PATHS.append(path)
67
+ _rescan_model_configs()
68
+
69
+
70
+ def get_model_config(model_name):
71
+ if model_name in _MODEL_CONFIGS:
72
+ return deepcopy(_MODEL_CONFIGS[model_name])
73
+ else:
74
+ return None
75
+
76
+
77
+ def get_tokenizer(model_name):
78
+ if model_name.startswith(HF_HUB_PREFIX):
79
+ tokenizer = HFTokenizer(model_name[len(HF_HUB_PREFIX):])
80
+ else:
81
+ config = get_model_config(model_name)
82
+ tokenizer = HFTokenizer(
83
+ config['text_cfg']['hf_tokenizer_name']) if 'hf_tokenizer_name' in config['text_cfg'] else tokenize
84
+ return tokenizer
85
+
86
+
87
+ def load_state_dict(checkpoint_path: str, map_location='cpu'):
88
+ checkpoint = torch.load(checkpoint_path, map_location=map_location)
89
+ if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
90
+ state_dict = checkpoint['state_dict']
91
+ else:
92
+ state_dict = checkpoint
93
+ if next(iter(state_dict.items()))[0].startswith('module'):
94
+ state_dict = {k[7:]: v for k, v in state_dict.items()}
95
+ return state_dict
96
+
97
+
98
+ def load_checkpoint(model, checkpoint_path, strict=True):
99
+ state_dict = load_state_dict(checkpoint_path)
100
+ # detect old format and make compatible with new format
101
+ if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'):
102
+ state_dict = convert_to_custom_text_state_dict(state_dict)
103
+ resize_pos_embed(state_dict, model)
104
+ incompatible_keys = model.load_state_dict(state_dict, strict=strict)
105
+ return incompatible_keys
106
+
107
+
108
+ def create_model(
109
+ model_name: str,
110
+ pretrained: Optional[str] = None,
111
+ precision: str = 'fp32',
112
+ device: Union[str, torch.device] = 'cpu',
113
+ jit: bool = False,
114
+ force_quick_gelu: bool = False,
115
+ force_custom_text: bool = False,
116
+ force_patch_dropout: Optional[float] = None,
117
+ force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
118
+ pretrained_image: bool = False,
119
+ pretrained_hf: bool = True,
120
+ cache_dir: Optional[str] = None,
121
+ output_dict: Optional[bool] = None,
122
+ require_pretrained: bool = False,
123
+ logger: logging.Logger = logging,
124
+ ):
125
+ has_hf_hub_prefix = model_name.startswith(HF_HUB_PREFIX)
126
+ if has_hf_hub_prefix:
127
+ model_id = model_name[len(HF_HUB_PREFIX):]
128
+ checkpoint_path = download_pretrained_from_hf(model_id, cache_dir=cache_dir)
129
+ config_path = download_pretrained_from_hf(model_id, filename='open_clip_config.json', cache_dir=cache_dir)
130
+
131
+ with open(config_path, 'r', encoding='utf-8') as f:
132
+ config = json.load(f)
133
+ pretrained_cfg = config['preprocess_cfg']
134
+ model_cfg = config['model_cfg']
135
+ else:
136
+ model_name = model_name.replace('/', '-') # for callers using old naming with / in ViT names
137
+ checkpoint_path = None
138
+ pretrained_cfg = {}
139
+ model_cfg = None
140
+
141
+ if isinstance(device, str):
142
+ device = torch.device(device)
143
+
144
+ if pretrained and pretrained.lower() == 'openai':
145
+ logger.info(f'Loading pretrained {model_name} from OpenAI.')
146
+ model = load_openai_model(
147
+ model_name,
148
+ precision=precision,
149
+ device=device,
150
+ cache_dir=cache_dir,
151
+ )
152
+ else:
153
+ model_cfg = model_cfg or get_model_config(model_name)
154
+ if model_cfg is not None:
155
+ logger.info(f'Loaded {model_name} model config.')
156
+ else:
157
+ logger.error(f'Model config for {model_name} not found; available models {list_models()}.')
158
+ raise RuntimeError(f'Model config for {model_name} not found.')
159
+
160
+ if force_quick_gelu:
161
+ # override for use of QuickGELU on non-OpenAI transformer models
162
+ model_cfg["quick_gelu"] = True
163
+
164
+ if force_patch_dropout is not None:
165
+ # override the default patch dropout value
166
+ model_cfg["vision_cfg"]["patch_dropout"] = force_patch_dropout
167
+
168
+ if force_image_size is not None:
169
+ # override model config's image size
170
+ model_cfg["vision_cfg"]["image_size"] = force_image_size
171
+
172
+ is_timm_model = 'timm_model_name' in model_cfg.get('vision_cfg', {})
173
+ if pretrained_image:
174
+ if is_timm_model:
175
+ # pretrained weight loading for timm models set via vision_cfg
176
+ model_cfg['vision_cfg']['timm_model_pretrained'] = True
177
+ else:
178
+ assert False, 'pretrained image towers currently only supported for timm models'
179
+
180
+ # cast_dtype set for fp16 and bf16 (manual mixed-precision), not set for 'amp' or 'pure' modes
181
+ cast_dtype = get_cast_dtype(precision)
182
+ is_hf_model = 'hf_model_name' in model_cfg.get('text_cfg', {})
183
+ custom_text = model_cfg.pop('custom_text', False) or force_custom_text or is_hf_model
184
+
185
+ if custom_text:
186
+ if is_hf_model:
187
+ model_cfg['text_cfg']['hf_model_pretrained'] = pretrained_hf
188
+ if "coca" in model_name:
189
+ model = CoCa(**model_cfg, cast_dtype=cast_dtype)
190
+ else:
191
+ model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype)
192
+ else:
193
+ model = CLIP(**model_cfg, cast_dtype=cast_dtype)
194
+
195
+ if precision in ("fp16", "bf16"):
196
+ dtype = torch.float16 if 'fp16' in precision else torch.bfloat16
197
+ # manual mixed precision that matches original OpenAI behaviour
198
+ if is_timm_model:
199
+ # FIXME this is a bit janky, create timm based model in low-precision and
200
+ # then cast only LayerNormFp32 instances back to float32 so they don't break.
201
+ # Why? The convert_weights_to_lp fn only works with native models.
202
+ model.to(device=device, dtype=dtype)
203
+ from .transformer import LayerNormFp32
204
+ def _convert_ln(m):
205
+ if isinstance(m, LayerNormFp32):
206
+ m.weight.data = m.weight.data.to(torch.float32)
207
+ m.bias.data = m.bias.data.to(torch.float32)
208
+ model.apply(_convert_ln)
209
+ else:
210
+ model.to(device=device)
211
+ convert_weights_to_lp(model, dtype=dtype)
212
+ elif precision in ("pure_fp16", "pure_bf16"):
213
+ dtype = torch.float16 if 'fp16' in precision else torch.bfloat16
214
+ model.to(device=device, dtype=dtype)
215
+ else:
216
+ model.to(device=device)
217
+
218
+ pretrained_loaded = False
219
+ if pretrained:
220
+ checkpoint_path = ''
221
+ pretrained_cfg = get_pretrained_cfg(model_name, pretrained)
222
+ if pretrained_cfg:
223
+ checkpoint_path = download_pretrained(pretrained_cfg, cache_dir=cache_dir)
224
+ elif os.path.exists(pretrained):
225
+ checkpoint_path = pretrained
226
+
227
+ if checkpoint_path:
228
+ logger.info(f'Loading pretrained {model_name} weights ({pretrained}).')
229
+ load_checkpoint(model, checkpoint_path)
230
+ else:
231
+ error_str = (
232
+ f'Pretrained weights ({pretrained}) not found for model {model_name}. '
233
+ f'Available pretrained tags: {list_pretrained_tags_by_model(model_name)}.')
234
+ logger.warning(error_str)
235
+ raise RuntimeError(error_str)
236
+ pretrained_loaded = True
237
+ elif has_hf_hub_prefix:
238
+ logger.info(f'Loading pretrained {model_name} weights ({pretrained}).')
239
+ load_checkpoint(model, checkpoint_path)
240
+ pretrained_loaded = True
241
+
242
+ if require_pretrained and not pretrained_loaded:
243
+ # callers of create_model_from_pretrained always expect pretrained weights
244
+ raise RuntimeError(
245
+ f'Pretrained weights were required for (model: {model_name}, pretrained: {pretrained}) but not loaded.')
246
+
247
+ # set image / mean metadata from pretrained_cfg if available, or use default
248
+ model.visual.image_mean = pretrained_cfg.get('mean', None) or OPENAI_DATASET_MEAN
249
+ model.visual.image_std = pretrained_cfg.get('std', None) or OPENAI_DATASET_STD
250
+
251
+ if output_dict and hasattr(model, "output_dict"):
252
+ model.output_dict = True
253
+
254
+ if jit:
255
+ model = torch.jit.script(model)
256
+
257
+ return model
258
+
259
+
260
+ def create_loss(args):
261
+ if args.distill:
262
+ return DistillClipLoss(
263
+ local_loss=args.local_loss,
264
+ gather_with_grad=args.gather_with_grad,
265
+ cache_labels=True,
266
+ rank=args.rank,
267
+ world_size=args.world_size,
268
+ use_horovod=args.horovod,
269
+ )
270
+ elif "coca" in args.model.lower():
271
+ return CoCaLoss(
272
+ caption_loss_weight=args.coca_caption_loss_weight,
273
+ clip_loss_weight=args.coca_contrastive_loss_weight,
274
+ local_loss=args.local_loss,
275
+ gather_with_grad=args.gather_with_grad,
276
+ cache_labels=True,
277
+ rank=args.rank,
278
+ world_size=args.world_size,
279
+ use_horovod=args.horovod,
280
+ )
281
+ return ClipLoss(
282
+ local_loss=args.local_loss,
283
+ gather_with_grad=args.gather_with_grad,
284
+ cache_labels=True,
285
+ rank=args.rank,
286
+ world_size=args.world_size,
287
+ use_horovod=args.horovod,
288
+ )
289
+
290
+
291
+ def create_model_and_transforms(
292
+ model_name: str,
293
+ pretrained: Optional[str] = None,
294
+ precision: str = 'fp32',
295
+ device: Union[str, torch.device] = 'cpu',
296
+ jit: bool = False,
297
+ force_quick_gelu: bool = False,
298
+ force_custom_text: bool = False,
299
+ force_patch_dropout: Optional[float] = None,
300
+ force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
301
+ pretrained_image: bool = False,
302
+ pretrained_hf: bool = True,
303
+ image_mean: Optional[Tuple[float, ...]] = None,
304
+ image_std: Optional[Tuple[float, ...]] = None,
305
+ aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,
306
+ cache_dir: Optional[str] = None,
307
+ output_dict: Optional[bool] = None,
308
+ logger: logging.Logger = logging,
309
+ ):
310
+ model = create_model(
311
+ model_name,
312
+ pretrained,
313
+ precision=precision,
314
+ device=device,
315
+ jit=jit,
316
+ force_quick_gelu=force_quick_gelu,
317
+ force_custom_text=force_custom_text,
318
+ force_patch_dropout=force_patch_dropout,
319
+ force_image_size=force_image_size,
320
+ pretrained_image=pretrained_image,
321
+ pretrained_hf=pretrained_hf,
322
+ cache_dir=cache_dir,
323
+ output_dict=output_dict,
324
+ logger=logger,
325
+ )
326
+
327
+ image_mean = image_mean or getattr(model.visual, 'image_mean', None)
328
+ image_std = image_std or getattr(model.visual, 'image_std', None)
329
+ preprocess_train = image_transform(
330
+ model.visual.image_size,
331
+ is_train=True,
332
+ mean=image_mean,
333
+ std=image_std,
334
+ aug_cfg=aug_cfg,
335
+ )
336
+ preprocess_val = image_transform(
337
+ model.visual.image_size,
338
+ is_train=False,
339
+ mean=image_mean,
340
+ std=image_std,
341
+ )
342
+
343
+ return model, preprocess_train, preprocess_val
344
+
345
+
346
+ def create_model_from_pretrained(
347
+ model_name: str,
348
+ pretrained: Optional[str] = None,
349
+ precision: str = 'fp32',
350
+ device: Union[str, torch.device] = 'cpu',
351
+ jit: bool = False,
352
+ force_quick_gelu: bool = False,
353
+ force_custom_text: bool = False,
354
+ force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
355
+ return_transform: bool = True,
356
+ image_mean: Optional[Tuple[float, ...]] = None,
357
+ image_std: Optional[Tuple[float, ...]] = None,
358
+ cache_dir: Optional[str] = None,
359
+ logger: logging.Logger = logging,
360
+ ):
361
+ model = create_model(
362
+ model_name,
363
+ pretrained,
364
+ precision=precision,
365
+ device=device,
366
+ jit=jit,
367
+ force_quick_gelu=force_quick_gelu,
368
+ force_custom_text=force_custom_text,
369
+ force_image_size=force_image_size,
370
+ cache_dir=cache_dir,
371
+ require_pretrained=True,
372
+ logger=logger,
373
+ )
374
+
375
+ if not return_transform:
376
+ return model
377
+
378
+ image_mean = image_mean or getattr(model.visual, 'image_mean', None)
379
+ image_std = image_std or getattr(model.visual, 'image_std', None)
380
+ preprocess = image_transform(
381
+ model.visual.image_size,
382
+ is_train=False,
383
+ mean=image_mean,
384
+ std=image_std,
385
+ )
386
+
387
+ return model, preprocess
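
Usage note (not part of the diff): a hedged sketch of the factory API above; the model name is illustrative and the import path is assumed from the repo layout.

import torch
from ext.open_clip.factory import create_model_and_transforms, get_tokenizer  # path assumed

# "ViT-B-16" matches one of the bundled model_configs/*.json files; pretrained=None builds
# randomly initialised weights, so no download is required for this sketch.
model, preprocess_train, preprocess_val = create_model_and_transforms(
    "ViT-B-16", pretrained=None, precision="fp32", device="cpu")
tokenizer = get_tokenizer("ViT-B-16")

text = tokenizer(["a diagram", "a dog", "a cat"])  # LongTensor of shape (3, context_length)
with torch.no_grad():
    text_features = model.encode_text(text, normalize=True)
print(text_features.shape)  # (3, embed_dim)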
ext/open_clip/generation_utils.py ADDED
File without changes
ext/open_clip/hf_configs.py ADDED
@@ -0,0 +1,56 @@
1
+ # HF architecture dict:
2
+ arch_dict = {
3
+ # https://huggingface.co/docs/transformers/model_doc/roberta#roberta
4
+ "roberta": {
5
+ "config_names": {
6
+ "context_length": "max_position_embeddings",
7
+ "vocab_size": "vocab_size",
8
+ "width": "hidden_size",
9
+ "heads": "num_attention_heads",
10
+ "layers": "num_hidden_layers",
11
+ "layer_attr": "layer",
12
+ "token_embeddings_attr": "embeddings"
13
+ },
14
+ "pooler": "mean_pooler",
15
+ },
16
+ # https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig
17
+ "xlm-roberta": {
18
+ "config_names": {
19
+ "context_length": "max_position_embeddings",
20
+ "vocab_size": "vocab_size",
21
+ "width": "hidden_size",
22
+ "heads": "num_attention_heads",
23
+ "layers": "num_hidden_layers",
24
+ "layer_attr": "layer",
25
+ "token_embeddings_attr": "embeddings"
26
+ },
27
+ "pooler": "mean_pooler",
28
+ },
29
+ # https://huggingface.co/docs/transformers/model_doc/mt5#mt5
30
+ "mt5": {
31
+ "config_names": {
32
+ # unlimited seqlen
33
+ # https://github.com/google-research/text-to-text-transfer-transformer/issues/273
34
+ # https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374
35
+ "context_length": "",
36
+ "vocab_size": "vocab_size",
37
+ "width": "d_model",
38
+ "heads": "num_heads",
39
+ "layers": "num_layers",
40
+ "layer_attr": "block",
41
+ "token_embeddings_attr": "embed_tokens"
42
+ },
43
+ "pooler": "mean_pooler",
44
+ },
45
+ # https://huggingface.co/docs/transformers/model_doc/bert
46
+ "bert": {
47
+ "config_names": {
48
+ "context_length": "max_position_embeddings",
49
+ "vocab_size": "vocab_size",
50
+ "width": "hidden_size",
51
+ "heads": "num_attention_heads",
52
+ "layers": "num_hidden_layers",
53
+ },
54
+ "pooler": "cls_pooler",
55
+ },
56
+ }
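
Usage note (not part of the diff): arch_dict only maps open_clip's generic names onto each HF architecture's config attributes; a small sketch of that lookup (requires transformers; import path assumed from the repo layout).

from transformers import AutoConfig
from ext.open_clip.hf_configs import arch_dict  # path assumed

cfg = AutoConfig.from_pretrained("roberta-base")   # cfg.model_type == "roberta"
names = arch_dict[cfg.model_type]["config_names"]
width = getattr(cfg, names["width"])               # hidden_size -> 768 for roberta-base
layers = getattr(cfg, names["layers"])             # num_hidden_layers -> 12
pooler = arch_dict[cfg.model_type]["pooler"]       # "mean_pooler"
print(width, layers, pooler)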
ext/open_clip/hf_model.py ADDED
@@ -0,0 +1,193 @@
1
+ """ huggingface model adapter
2
+
3
+ Wraps HuggingFace transformers (https://github.com/huggingface/transformers) models for use as a text tower in a CLIP model.
4
+ """
5
+ import re
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ from torch import TensorType
10
+
11
+ try:
12
+ import transformers
13
+ from transformers import AutoModel, AutoTokenizer, AutoConfig, PretrainedConfig
14
+ from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, \
15
+ BaseModelOutputWithPoolingAndCrossAttentions
16
+ except ImportError as e:
17
+ transformers = None
18
+
19
+
20
+ class BaseModelOutput:
21
+ pass
22
+
23
+
24
+ class PretrainedConfig:
25
+ pass
26
+
27
+ from .hf_configs import arch_dict
28
+
29
+
30
+ # utils
31
+ def _camel2snake(s):
32
+ return re.sub(r'(?<!^)(?=[A-Z])', '_', s).lower()
33
+
34
+
35
+ # TODO: ?last - for gpt-like models
36
+ _POOLERS = {}
37
+
38
+
39
+ def register_pooler(cls):
40
+ """Decorator registering pooler class"""
41
+ _POOLERS[_camel2snake(cls.__name__)] = cls
42
+ return cls
43
+
44
+
45
+ @register_pooler
46
+ class MeanPooler(nn.Module):
47
+ """Mean pooling"""
48
+
49
+ def forward(self, x: BaseModelOutput, attention_mask: TensorType):
50
+ masked_output = x.last_hidden_state * attention_mask.unsqueeze(-1)
51
+ return masked_output.sum(dim=1) / attention_mask.sum(-1, keepdim=True)
52
+
53
+
54
+ @register_pooler
55
+ class MaxPooler(nn.Module):
56
+ """Max pooling"""
57
+
58
+ def forward(self, x: BaseModelOutput, attention_mask: TensorType):
59
+ masked_output = x.last_hidden_state.masked_fill(attention_mask.unsqueeze(-1), -torch.inf)
60
+ return masked_output.max(1).values
61
+
62
+
63
+ @register_pooler
64
+ class ClsPooler(nn.Module):
65
+ """CLS token pooling"""
66
+
67
+ def __init__(self, use_pooler_output=True):
68
+ super().__init__()
69
+ self.cls_token_position = 0
70
+ self.use_pooler_output = use_pooler_output
71
+
72
+ def forward(self, x: BaseModelOutput, attention_mask: TensorType):
73
+ if (self.use_pooler_output and
74
+ isinstance(x, (BaseModelOutputWithPooling, BaseModelOutputWithPoolingAndCrossAttentions)) and
75
+ (x.pooler_output is not None)
76
+ ):
77
+ return x.pooler_output
78
+
79
+ return x.last_hidden_state[:, self.cls_token_position, :]
80
+
81
+
82
+ @register_pooler
83
+ class ClsLastHiddenStatePooler(nn.Module):
84
+ """CLS token pooling
85
+ NOTE: this is equivalent to ClsPooler above with use_pooler_output=False
86
+ """
87
+
88
+ def __init__(self):
89
+ super().__init__()
90
+ self.cls_token_position = 0
91
+
92
+ def forward(self, x: BaseModelOutput, attention_mask: TensorType):
93
+ return x.last_hidden_state[:, self.cls_token_position, :]
94
+
95
+
96
+ class HFTextEncoder(nn.Module):
97
+ """HuggingFace model adapter"""
98
+ output_tokens: torch.jit.Final[bool]
99
+
100
+ def __init__(
101
+ self,
102
+ model_name_or_path: str,
103
+ output_dim: int,
104
+ config: PretrainedConfig = None,
105
+ pooler_type: str = None,
106
+ proj: str = None,
107
+ pretrained: bool = True,
108
+ output_tokens: bool = False,
109
+ ):
110
+ super().__init__()
111
+ self.output_tokens = output_tokens
112
+ self.output_dim = output_dim
113
+
114
+ # TODO: find better way to get this information
115
+ uses_transformer_pooler = (pooler_type == "cls_pooler")
116
+
117
+ if transformers is None:
118
+ raise RuntimeError("Please `pip install transformers` to use pre-trained HuggingFace models")
119
+ if config is None:
120
+ self.config = AutoConfig.from_pretrained(model_name_or_path)
121
+ create_func, model_args = (AutoModel.from_pretrained, model_name_or_path) if pretrained else (
122
+ AutoModel.from_config, self.config)
123
+ # TODO: do all model configs have this attribute? PretrainedConfig does so yes??
124
+ if hasattr(self.config, "is_encoder_decoder") and self.config.is_encoder_decoder:
125
+ self.transformer = create_func(model_args)
126
+ self.transformer = self.transformer.encoder
127
+ else:
128
+ self.transformer = create_func(model_args, add_pooling_layer=uses_transformer_pooler)
129
+ else:
130
+ self.config = config
131
+ self.transformer = AutoModel.from_config(config)
132
+ if pooler_type is None: # get default arch pooler
133
+ pooler_type = (arch_dict[self.config.model_type]["pooler"])
134
+
135
+ # FIXME downstream users of OpenCLIP models use these attr, need to verify valid across all models
136
+ self.vocab_size = getattr(self.config, 'vocab_size', 0)
137
+ self.context_length = getattr(self.config, 'max_position_embeddings', 0)
138
+
139
+ self.pooler = _POOLERS[pooler_type]()
140
+
141
+ d_model = getattr(self.config, arch_dict[self.config.model_type]["config_names"]["width"])
142
+ if (d_model == output_dim) and (proj is None): # do we always need a proj?
143
+ self.proj = nn.Identity()
144
+ elif proj == 'linear':
145
+ self.proj = nn.Linear(d_model, output_dim, bias=False)
146
+ elif proj == 'mlp':
147
+ hidden_size = (d_model + output_dim) // 2
148
+ self.proj = nn.Sequential(
149
+ nn.Linear(d_model, hidden_size, bias=False),
150
+ nn.GELU(),
151
+ nn.Linear(hidden_size, output_dim, bias=False),
152
+ )
153
+
154
+ def forward(self, x: TensorType):
155
+ attn_mask = (x != self.config.pad_token_id).long()
156
+ out = self.transformer(input_ids=x, attention_mask=attn_mask)
157
+ pooled_out = self.pooler(out, attn_mask)
158
+ projected = self.proj(pooled_out)
159
+
160
+ seq_len = out.last_hidden_state.shape[1]
161
+ tokens = (
162
+ out.last_hidden_state[:, torch.arange(seq_len) != self.pooler.cls_token_position, :]
163
+ if type(self.pooler) == ClsPooler
164
+ else out.last_hidden_state
165
+ )
166
+
167
+ if self.output_tokens:
168
+ return projected, tokens
169
+ return projected
170
+
171
+ def lock(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
172
+ if not unlocked_layers: # full freezing
173
+ for n, p in self.transformer.named_parameters():
174
+ p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False
175
+ return
176
+
177
+ encoder = self.transformer.encoder if hasattr(self.transformer, 'encoder') else self.transformer
178
+ layer_list = getattr(encoder, arch_dict[self.config.model_type]["config_names"]["layer_attr"])
179
+ print(f"Unlocking {unlocked_layers}/{len(layer_list) + 1} layers of hf model")
180
+ embeddings = getattr(
181
+ self.transformer, arch_dict[self.config.model_type]["config_names"]["token_embeddings_attr"])
182
+ modules = [embeddings, *layer_list][:-unlocked_layers]
183
+ # freeze layers
184
+ for module in modules:
185
+ for n, p in module.named_parameters():
186
+ p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False
187
+
188
+ @torch.jit.ignore
189
+ def set_grad_checkpointing(self, enable=True):
190
+ self.transformer.gradient_checkpointing_enable()
191
+
192
+ def init_parameters(self):
193
+ pass
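
Usage note (not part of the diff): a hedged sketch of HFTextEncoder as a standalone text tower (requires transformers; the checkpoint name is illustrative and the import path assumed from the repo layout).

import torch
from ext.open_clip.hf_model import HFTextEncoder  # path assumed

encoder = HFTextEncoder(
    "roberta-base",          # any HF checkpoint whose model_type appears in hf_configs.arch_dict
    output_dim=512,
    proj="mlp",
    pooler_type="mean_pooler",
    pretrained=True,
)
encoder.eval()

# random token ids; positions equal to config.pad_token_id are masked out inside forward()
input_ids = torch.randint(low=5, high=1000, size=(2, 16))
with torch.no_grad():
    text_features = encoder(input_ids)
print(text_features.shape)  # torch.Size([2, 512])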
ext/open_clip/loss.py ADDED
@@ -0,0 +1,216 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.nn import functional as F
4
+
5
+ try:
6
+ import torch.distributed.nn
7
+ from torch import distributed as dist
8
+
9
+ has_distributed = True
10
+ except ImportError:
11
+ has_distributed = False
12
+
13
+ try:
14
+ import horovod.torch as hvd
15
+ except ImportError:
16
+ hvd = None
17
+
18
+
19
+ def gather_features(
20
+ image_features,
21
+ text_features,
22
+ local_loss=False,
23
+ gather_with_grad=False,
24
+ rank=0,
25
+ world_size=1,
26
+ use_horovod=False
27
+ ):
28
+ assert has_distributed, 'torch.distributed did not import correctly, please use a PyTorch version with support.'
29
+ if use_horovod:
30
+ assert hvd is not None, 'Please install horovod'
31
+ if gather_with_grad:
32
+ all_image_features = hvd.allgather(image_features)
33
+ all_text_features = hvd.allgather(text_features)
34
+ else:
35
+ with torch.no_grad():
36
+ all_image_features = hvd.allgather(image_features)
37
+ all_text_features = hvd.allgather(text_features)
38
+ if not local_loss:
39
+ # ensure grads for local rank when all_* features don't have a gradient
40
+ gathered_image_features = list(all_image_features.chunk(world_size, dim=0))
41
+ gathered_text_features = list(all_text_features.chunk(world_size, dim=0))
42
+ gathered_image_features[rank] = image_features
43
+ gathered_text_features[rank] = text_features
44
+ all_image_features = torch.cat(gathered_image_features, dim=0)
45
+ all_text_features = torch.cat(gathered_text_features, dim=0)
46
+ else:
47
+ # We gather tensors from all gpus
48
+ if gather_with_grad:
49
+ all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features), dim=0)
50
+ all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features), dim=0)
51
+ else:
52
+ gathered_image_features = [torch.zeros_like(image_features) for _ in range(world_size)]
53
+ gathered_text_features = [torch.zeros_like(text_features) for _ in range(world_size)]
54
+ dist.all_gather(gathered_image_features, image_features)
55
+ dist.all_gather(gathered_text_features, text_features)
56
+ if not local_loss:
57
+ # ensure grads for local rank when all_* features don't have a gradient
58
+ gathered_image_features[rank] = image_features
59
+ gathered_text_features[rank] = text_features
60
+ all_image_features = torch.cat(gathered_image_features, dim=0)
61
+ all_text_features = torch.cat(gathered_text_features, dim=0)
62
+
63
+ return all_image_features, all_text_features
64
+
65
+
66
+ class ClipLoss(nn.Module):
67
+
68
+ def __init__(
69
+ self,
70
+ local_loss=False,
71
+ gather_with_grad=False,
72
+ cache_labels=False,
73
+ rank=0,
74
+ world_size=1,
75
+ use_horovod=False,
76
+ ):
77
+ super().__init__()
78
+ self.local_loss = local_loss
79
+ self.gather_with_grad = gather_with_grad
80
+ self.cache_labels = cache_labels
81
+ self.rank = rank
82
+ self.world_size = world_size
83
+ self.use_horovod = use_horovod
84
+
85
+ # cache state
86
+ self.prev_num_logits = 0
87
+ self.labels = {}
88
+
89
+ def get_ground_truth(self, device, num_logits) -> torch.Tensor:
90
+ # calculate ground-truth labels and cache them if enabled
91
+ if self.prev_num_logits != num_logits or device not in self.labels:
92
+ labels = torch.arange(num_logits, device=device, dtype=torch.long)
93
+ if self.world_size > 1 and self.local_loss:
94
+ labels = labels + num_logits * self.rank
95
+ if self.cache_labels:
96
+ self.labels[device] = labels
97
+ self.prev_num_logits = num_logits
98
+ else:
99
+ labels = self.labels[device]
100
+ return labels
101
+
102
+ def get_logits(self, image_features, text_features, logit_scale):
103
+ if self.world_size > 1:
104
+ all_image_features, all_text_features = gather_features(
105
+ image_features, text_features,
106
+ self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod)
107
+
108
+ if self.local_loss:
109
+ logits_per_image = logit_scale * image_features @ all_text_features.T
110
+ logits_per_text = logit_scale * text_features @ all_image_features.T
111
+ else:
112
+ logits_per_image = logit_scale * all_image_features @ all_text_features.T
113
+ logits_per_text = logits_per_image.T
114
+ else:
115
+ logits_per_image = logit_scale * image_features @ text_features.T
116
+ logits_per_text = logit_scale * text_features @ image_features.T
117
+
118
+ return logits_per_image, logits_per_text
119
+
120
+ def forward(self, image_features, text_features, logit_scale, output_dict=False):
121
+ device = image_features.device
122
+ logits_per_image, logits_per_text = self.get_logits(image_features, text_features, logit_scale)
123
+
124
+ labels = self.get_ground_truth(device, logits_per_image.shape[0])
125
+
126
+ total_loss = (
127
+ F.cross_entropy(logits_per_image, labels) +
128
+ F.cross_entropy(logits_per_text, labels)
129
+ ) / 2
130
+
131
+ return {"contrastive_loss": total_loss} if output_dict else total_loss
132
+
133
+
134
+ class CoCaLoss(ClipLoss):
135
+ def __init__(
136
+ self,
137
+ caption_loss_weight,
138
+ clip_loss_weight,
139
+ pad_id=0, # pad_token for open_clip custom tokenizer
140
+ local_loss=False,
141
+ gather_with_grad=False,
142
+ cache_labels=False,
143
+ rank=0,
144
+ world_size=1,
145
+ use_horovod=False,
146
+ ):
147
+ super().__init__(
148
+ local_loss=local_loss,
149
+ gather_with_grad=gather_with_grad,
150
+ cache_labels=cache_labels,
151
+ rank=rank,
152
+ world_size=world_size,
153
+ use_horovod=use_horovod
154
+ )
155
+
156
+ self.clip_loss_weight = clip_loss_weight
157
+ self.caption_loss_weight = caption_loss_weight
158
+ self.caption_loss = nn.CrossEntropyLoss(ignore_index=pad_id)
159
+
160
+ def forward(self, image_features, text_features, logits, labels, logit_scale, output_dict=False):
161
+
162
+ clip_loss = torch.tensor(0)
163
+
164
+ if self.clip_loss_weight:
165
+ clip_loss = super().forward(image_features, text_features, logit_scale)
166
+ clip_loss = self.clip_loss_weight * clip_loss
167
+
168
+ caption_loss = self.caption_loss(
169
+ logits.permute(0, 2, 1),
170
+ labels,
171
+ )
172
+ caption_loss = caption_loss * self.caption_loss_weight
173
+
174
+ if output_dict:
175
+ return {"contrastive_loss": clip_loss, "caption_loss": caption_loss}
176
+
177
+ return clip_loss, caption_loss
178
+
179
+
180
+ class DistillClipLoss(ClipLoss):
181
+
182
+ def dist_loss(self, teacher_logits, student_logits):
183
+ return -(teacher_logits.softmax(dim=1) * student_logits.log_softmax(dim=1)).sum(dim=1).mean(dim=0)
184
+
185
+ def forward(
186
+ self,
187
+ image_features,
188
+ text_features,
189
+ logit_scale,
190
+ dist_image_features,
191
+ dist_text_features,
192
+ dist_logit_scale,
193
+ output_dict=False,
194
+ ):
195
+ logits_per_image, logits_per_text = \
196
+ self.get_logits(image_features, text_features, logit_scale)
197
+
198
+ dist_logits_per_image, dist_logits_per_text = \
199
+ self.get_logits(dist_image_features, dist_text_features, dist_logit_scale)
200
+
201
+ labels = self.get_ground_truth(image_features.device, logits_per_image.shape[0])
202
+
203
+ contrastive_loss = (
204
+ F.cross_entropy(logits_per_image, labels) +
205
+ F.cross_entropy(logits_per_text, labels)
206
+ ) / 2
207
+
208
+ distill_loss = (
209
+ self.dist_loss(dist_logits_per_image, logits_per_image) +
210
+ self.dist_loss(dist_logits_per_text, logits_per_text)
211
+ ) / 2
212
+
213
+ if output_dict:
214
+ return {"contrastive_loss": contrastive_loss, "distill_loss": distill_loss}
215
+
216
+ return contrastive_loss, distill_loss
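
Usage note (not part of the diff): a single-process ClipLoss sketch with dummy features; world_size defaults to 1, so the distributed gathering path is never triggered (import path assumed from the repo layout).

import torch
import torch.nn.functional as F
from ext.open_clip.loss import ClipLoss  # path assumed

loss_fn = ClipLoss(local_loss=False, gather_with_grad=False, cache_labels=True)

batch, dim = 8, 512
image_features = F.normalize(torch.randn(batch, dim), dim=-1)
text_features = F.normalize(torch.randn(batch, dim), dim=-1)
logit_scale = torch.tensor(100.0)  # exp() of the learned logit_scale parameter

# symmetric image->text / text->image cross-entropy
loss = loss_fn(image_features, text_features, logit_scale)
print(loss.item())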
ext/open_clip/model.py ADDED
@@ -0,0 +1,473 @@
1
+ """ CLIP Model
2
+
3
+ Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
4
+ """
5
+ from dataclasses import dataclass
6
+ import logging
7
+ import math
8
+ from typing import Optional, Tuple, Union
9
+
10
+ import numpy as np
11
+ import torch
12
+ import torch.nn.functional as F
13
+ from torch import nn
14
+ from torch.utils.checkpoint import checkpoint
15
+
16
+ from .hf_model import HFTextEncoder
17
+ from .modified_resnet import ModifiedResNet
18
+ from .timm_model import TimmModel
19
+ from .transformer import LayerNormFp32, LayerNorm, QuickGELU, Attention, VisionTransformer, TextTransformer
20
+ from .utils import to_2tuple
21
+
22
+
23
+ @dataclass
24
+ class CLIPVisionCfg:
25
+ layers: Union[Tuple[int, int, int, int], int] = 12
26
+ width: int = 768
27
+ head_width: int = 64
28
+ mlp_ratio: float = 4.0
29
+ patch_size: int = 16
30
+ image_size: Union[Tuple[int, int], int] = 224
31
+
32
+ ls_init_value: Optional[float] = None # layer scale initial value
33
+ patch_dropout: float = 0. # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results
34
+ input_patchnorm: bool = False # whether to use dual patchnorm - would only apply the input layernorm on each patch, as post-layernorm already exist in original clip vit design
35
+ global_average_pool: bool = False # whether to global average pool the last embedding layer, instead of using CLS token (https://arxiv.org/abs/2205.01580)
36
+ attentional_pool: bool = False # whether to use attentional pooler in the last embedding layer
37
+ n_queries: int = 256 # n_queries for attentional pooler
38
+ attn_pooler_heads: int = 8 # n heads for attentional_pooling
39
+ output_tokens: bool = False
40
+
41
+ timm_model_name: str = None # a valid model name overrides layers, width, patch_size
42
+ timm_model_pretrained: bool = False # use (imagenet) pretrained weights for named model
43
+ timm_pool: str = 'avg' # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
44
+ timm_proj: str = 'linear' # linear projection for timm model output ('linear', 'mlp', '')
45
+ timm_proj_bias: bool = False # enable bias final projection
46
+ timm_drop: float = 0. # head dropout
47
+ timm_drop_path: Optional[float] = None # backbone stochastic depth
48
+
49
+
50
+ @dataclass
51
+ class CLIPTextCfg:
52
+ context_length: int = 77
53
+ vocab_size: int = 49408
54
+ width: int = 512
55
+ heads: int = 8
56
+ layers: int = 12
57
+ ls_init_value: Optional[float] = None # layer scale initial value
58
+ hf_model_name: str = None
59
+ hf_tokenizer_name: str = None
60
+ hf_model_pretrained: bool = True
61
+ proj: str = 'mlp'
62
+ pooler_type: str = 'mean_pooler'
63
+ embed_cls: bool = False
64
+ pad_id: int = 0
65
+ output_tokens: bool = False
66
+
67
+
68
+ def get_cast_dtype(precision: str):
69
+ cast_dtype = None
70
+ if precision == 'bf16':
71
+ cast_dtype = torch.bfloat16
72
+ elif precision == 'fp16':
73
+ cast_dtype = torch.float16
74
+ return cast_dtype
75
+
76
+
77
+ def get_input_dtype(precision: str):
78
+ input_dtype = None
79
+ if precision in ('bf16', 'pure_bf16'):
80
+ input_dtype = torch.bfloat16
81
+ elif precision in ('fp16', 'pure_fp16'):
82
+ input_dtype = torch.float16
83
+ return input_dtype
84
+
85
+
86
+ def _build_vision_tower(
87
+ embed_dim: int,
88
+ vision_cfg: CLIPVisionCfg,
89
+ quick_gelu: bool = False,
90
+ cast_dtype: Optional[torch.dtype] = None
91
+ ):
92
+ if isinstance(vision_cfg, dict):
93
+ vision_cfg = CLIPVisionCfg(**vision_cfg)
94
+
95
+ # OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more
96
+ # memory efficient in recent PyTorch releases (>= 1.10).
97
+ # NOTE: timm models always use native GELU regardless of quick_gelu flag.
98
+ act_layer = QuickGELU if quick_gelu else nn.GELU
99
+
100
+ if vision_cfg.timm_model_name:
101
+ visual = TimmModel(
102
+ vision_cfg.timm_model_name,
103
+ pretrained=vision_cfg.timm_model_pretrained,
104
+ pool=vision_cfg.timm_pool,
105
+ proj=vision_cfg.timm_proj,
106
+ proj_bias=vision_cfg.timm_proj_bias,
107
+ drop=vision_cfg.timm_drop,
108
+ drop_path=vision_cfg.timm_drop_path,
109
+ patch_drop=vision_cfg.patch_dropout if vision_cfg.patch_dropout > 0 else None,
110
+ embed_dim=embed_dim,
111
+ image_size=vision_cfg.image_size,
112
+ )
113
+ elif isinstance(vision_cfg.layers, (tuple, list)):
114
+ vision_heads = vision_cfg.width * 32 // vision_cfg.head_width
115
+ visual = ModifiedResNet(
116
+ layers=vision_cfg.layers,
117
+ output_dim=embed_dim,
118
+ heads=vision_heads,
119
+ image_size=vision_cfg.image_size,
120
+ width=vision_cfg.width,
121
+ )
122
+ else:
123
+ vision_heads = vision_cfg.width // vision_cfg.head_width
124
+ norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
125
+ visual = VisionTransformer(
126
+ image_size=vision_cfg.image_size,
127
+ patch_size=vision_cfg.patch_size,
128
+ width=vision_cfg.width,
129
+ layers=vision_cfg.layers,
130
+ heads=vision_heads,
131
+ mlp_ratio=vision_cfg.mlp_ratio,
132
+ ls_init_value=vision_cfg.ls_init_value,
133
+ patch_dropout=vision_cfg.patch_dropout,
134
+ input_patchnorm=vision_cfg.input_patchnorm,
135
+ global_average_pool=vision_cfg.global_average_pool,
136
+ attentional_pool=vision_cfg.attentional_pool,
137
+ n_queries=vision_cfg.n_queries,
138
+ attn_pooler_heads=vision_cfg.attn_pooler_heads,
139
+ output_tokens=vision_cfg.output_tokens,
140
+ output_dim=embed_dim,
141
+ act_layer=act_layer,
142
+ norm_layer=norm_layer,
143
+ )
144
+
145
+ return visual
146
+
147
+
148
+ def _build_text_tower(
149
+ embed_dim: int,
150
+ text_cfg: CLIPTextCfg,
151
+ quick_gelu: bool = False,
152
+ cast_dtype: Optional[torch.dtype] = None,
153
+ ):
154
+ if isinstance(text_cfg, dict):
155
+ text_cfg = CLIPTextCfg(**text_cfg)
156
+
157
+ if text_cfg.hf_model_name:
158
+ text = HFTextEncoder(
159
+ text_cfg.hf_model_name,
160
+ output_dim=embed_dim,
161
+ proj=text_cfg.proj,
162
+ pooler_type=text_cfg.pooler_type,
163
+ pretrained=text_cfg.hf_model_pretrained,
164
+ output_tokens=text_cfg.output_tokens,
165
+ )
166
+ else:
167
+ act_layer = QuickGELU if quick_gelu else nn.GELU
168
+ norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
169
+
170
+ text = TextTransformer(
171
+ context_length=text_cfg.context_length,
172
+ vocab_size=text_cfg.vocab_size,
173
+ width=text_cfg.width,
174
+ heads=text_cfg.heads,
175
+ layers=text_cfg.layers,
176
+ ls_init_value=text_cfg.ls_init_value,
177
+ output_dim=embed_dim,
178
+ embed_cls=text_cfg.embed_cls,
179
+ output_tokens=text_cfg.output_tokens,
180
+ pad_id=text_cfg.pad_id,
181
+ act_layer=act_layer,
182
+ norm_layer=norm_layer,
183
+ )
184
+ return text
185
+
186
+
187
+ class CLIP(nn.Module):
188
+ output_dict: torch.jit.Final[bool]
189
+
190
+ def __init__(
191
+ self,
192
+ embed_dim: int,
193
+ vision_cfg: CLIPVisionCfg,
194
+ text_cfg: CLIPTextCfg,
195
+ quick_gelu: bool = False,
196
+ cast_dtype: Optional[torch.dtype] = None,
197
+ output_dict: bool = False,
198
+ ):
199
+ super().__init__()
200
+ self.output_dict = output_dict
201
+ self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
202
+
203
+ text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
204
+ self.transformer = text.transformer
205
+ self.context_length = text.context_length
206
+ self.vocab_size = text.vocab_size
207
+ self.token_embedding = text.token_embedding
208
+ self.positional_embedding = text.positional_embedding
209
+ self.ln_final = text.ln_final
210
+ self.text_projection = text.text_projection
211
+ self.register_buffer('attn_mask', text.attn_mask, persistent=False)
212
+
213
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
214
+
215
+ def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
216
+ # lock image tower as per LiT - https://arxiv.org/abs/2111.07991
217
+ self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)
218
+
219
+ @torch.jit.ignore
220
+ def set_grad_checkpointing(self, enable=True):
221
+ self.visual.set_grad_checkpointing(enable)
222
+ self.transformer.grad_checkpointing = enable
223
+
224
+ def encode_image(self, image, normalize: bool = False):
225
+ features = self.visual(image)
226
+ return F.normalize(features, dim=-1) if normalize else features
227
+
228
+ def encode_text(self, text, normalize: bool = False):
229
+ cast_dtype = self.transformer.get_cast_dtype()
230
+
231
+ x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model]
232
+
233
+ x = x + self.positional_embedding.to(cast_dtype)
234
+ x = x.permute(1, 0, 2) # NLD -> LND
235
+ x = self.transformer(x, attn_mask=self.attn_mask)
236
+ x = x.permute(1, 0, 2) # LND -> NLD
237
+ x = self.ln_final(x) # [batch_size, n_ctx, transformer.width]
238
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
239
+ x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
240
+ return F.normalize(x, dim=-1) if normalize else x
241
+
242
+ def forward(
243
+ self,
244
+ image: Optional[torch.Tensor] = None,
245
+ text: Optional[torch.Tensor] = None,
246
+ ):
247
+ image_features = self.encode_image(image, normalize=True) if image is not None else None
248
+ text_features = self.encode_text(text, normalize=True) if text is not None else None
249
+ if self.output_dict:
250
+ return {
251
+ "image_features": image_features,
252
+ "text_features": text_features,
253
+ "logit_scale": self.logit_scale.exp()
254
+ }
255
+ return image_features, text_features, self.logit_scale.exp()
256
+
257
+
258
+ class CustomTextCLIP(nn.Module):
259
+ output_dict: torch.jit.Final[bool]
260
+
261
+ def __init__(
262
+ self,
263
+ embed_dim: int,
264
+ vision_cfg: CLIPVisionCfg,
265
+ text_cfg: CLIPTextCfg,
266
+ quick_gelu: bool = False,
267
+ cast_dtype: Optional[torch.dtype] = None,
268
+ output_dict: bool = False,
269
+ ):
270
+ super().__init__()
271
+ self.output_dict = output_dict
272
+ self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
273
+ self.text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
274
+ self.context_length = self.text.context_length
275
+ self.vocab_size = self.text.vocab_size
276
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
277
+
278
+ def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
279
+ # lock image tower as per LiT - https://arxiv.org/abs/2111.07991
280
+ self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)
281
+
282
+ def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
283
+ self.text.lock(unlocked_layers, freeze_layer_norm)
284
+
285
+ @torch.jit.ignore
286
+ def set_grad_checkpointing(self, enable=True):
287
+ self.visual.set_grad_checkpointing(enable)
288
+ self.text.set_grad_checkpointing(enable)
289
+
290
+ def encode_image(self, image, normalize: bool = False):
291
+ features = self.visual(image)
292
+ return F.normalize(features, dim=-1) if normalize else features
293
+
294
+ def encode_text(self, text, normalize: bool = False):
295
+ features = self.text(text)
296
+ return F.normalize(features, dim=-1) if normalize else features
297
+
298
+ def forward(
299
+ self,
300
+ image: Optional[torch.Tensor] = None,
301
+ text: Optional[torch.Tensor] = None,
302
+ ):
303
+ image_features = self.encode_image(image, normalize=True) if image is not None else None
304
+ text_features = self.encode_text(text, normalize=True) if text is not None else None
305
+ if self.output_dict:
306
+ return {
307
+ "image_features": image_features,
308
+ "text_features": text_features,
309
+ "logit_scale": self.logit_scale.exp()
310
+ }
311
+ return image_features, text_features, self.logit_scale.exp()
312
+
313
+
314
+ def convert_weights_to_lp(model: nn.Module, dtype=torch.float16):
315
+ """Convert applicable model parameters to low-precision (bf16 or fp16)"""
316
+
317
+ def _convert_weights(l):
318
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
319
+ l.weight.data = l.weight.data.to(dtype)
320
+ if l.bias is not None:
321
+ l.bias.data = l.bias.data.to(dtype)
322
+
323
+ if isinstance(l, (nn.MultiheadAttention, Attention)):
324
+ for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
325
+ tensor = getattr(l, attr)
326
+ if tensor is not None:
327
+ tensor.data = tensor.data.to(dtype)
328
+
329
+ if isinstance(l, (CLIP, TextTransformer)):
330
+ # convert text nn.Parameter projections
331
+ attr = getattr(l, "text_projection", None)
332
+ if attr is not None:
333
+ attr.data = attr.data.to(dtype)
334
+
335
+ if isinstance(l, VisionTransformer):
336
+ # convert vision nn.Parameter projections
337
+ attr = getattr(l, "proj", None)
338
+ if attr is not None:
339
+ attr.data = attr.data.to(dtype)
340
+
341
+ model.apply(_convert_weights)
342
+
343
+
344
+ convert_weights_to_fp16 = convert_weights_to_lp # backwards compat
345
+
346
+
347
+ # used to maintain checkpoint compatibility
348
+ def convert_to_custom_text_state_dict(state_dict: dict):
349
+ if 'text_projection' in state_dict:
350
+ # old format state_dict, move text tower -> .text
351
+ new_state_dict = {}
352
+ for k, v in state_dict.items():
353
+ if any(k.startswith(p) for p in (
354
+ 'text_projection',
355
+ 'positional_embedding',
356
+ 'token_embedding',
357
+ 'transformer',
358
+ 'ln_final',
359
+ )):
360
+ k = 'text.' + k
361
+ new_state_dict[k] = v
362
+ return new_state_dict
363
+ return state_dict
364
+
365
+
366
+ def build_model_from_openai_state_dict(
367
+ state_dict: dict,
368
+ quick_gelu=True,
369
+ cast_dtype=torch.float16,
370
+ ):
371
+ vit = "visual.proj" in state_dict
372
+
373
+ if vit:
374
+ vision_width = state_dict["visual.conv1.weight"].shape[0]
375
+ vision_layers = len(
376
+ [k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
377
+         vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
+         grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
+         image_size = vision_patch_size * grid_size
+     else:
+         counts: list = [
+             len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
+         vision_layers = tuple(counts)
+         vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
+         output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
+         vision_patch_size = None
+         assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
+         image_size = output_width * 32
+
+     embed_dim = state_dict["text_projection"].shape[1]
+     context_length = state_dict["positional_embedding"].shape[0]
+     vocab_size = state_dict["token_embedding.weight"].shape[0]
+     transformer_width = state_dict["ln_final.weight"].shape[0]
+     transformer_heads = transformer_width // 64
+     transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))
+
+     vision_cfg = CLIPVisionCfg(
+         layers=vision_layers,
+         width=vision_width,
+         patch_size=vision_patch_size,
+         image_size=image_size,
+     )
+     text_cfg = CLIPTextCfg(
+         context_length=context_length,
+         vocab_size=vocab_size,
+         width=transformer_width,
+         heads=transformer_heads,
+         layers=transformer_layers,
+     )
+     model = CLIP(
+         embed_dim,
+         vision_cfg=vision_cfg,
+         text_cfg=text_cfg,
+         quick_gelu=quick_gelu,  # OpenAI models were trained with QuickGELU
+         cast_dtype=cast_dtype,
+     )
+
+     for key in ["input_resolution", "context_length", "vocab_size"]:
+         state_dict.pop(key, None)
+
+     convert_weights_to_fp16(model)  # OpenAI state dicts are partially converted to float16
+     model.load_state_dict(state_dict)
+     return model.eval()
+
+
+ def trace_model(model, batch_size=256, device=torch.device('cpu')):
+     model.eval()
+     image_size = model.visual.image_size
+     example_images = torch.ones((batch_size, 3, image_size, image_size), device=device)
+     example_text = torch.zeros((batch_size, model.context_length), dtype=torch.int, device=device)
+     model = torch.jit.trace_module(
+         model,
+         inputs=dict(
+             forward=(example_images, example_text),
+             encode_text=(example_text,),
+             encode_image=(example_images,)
+         ))
+     model.visual.image_size = image_size
+     return model
+
+
+ def resize_pos_embed(state_dict, model, interpolation: str = 'bicubic', antialias: bool = True):
+     # Rescale the grid of position embeddings when loading from state_dict
+     old_pos_embed = state_dict.get('visual.positional_embedding', None)
+     if old_pos_embed is None or not hasattr(model.visual, 'grid_size'):
+         return
+     grid_size = to_2tuple(model.visual.grid_size)
+     extra_tokens = 1  # FIXME detect different token configs (ie no class token, or more)
+     new_seq_len = grid_size[0] * grid_size[1] + extra_tokens
+     if new_seq_len == old_pos_embed.shape[0]:
+         return
+
+     if extra_tokens:
+         pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:]
+     else:
+         pos_emb_tok, pos_emb_img = None, old_pos_embed
+     old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img))))
+
+     logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size)
+     pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2)
+     pos_emb_img = F.interpolate(
+         pos_emb_img,
+         size=grid_size,
+         mode=interpolation,
+         antialias=antialias,
+         align_corners=False,
+     )
+     pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0]
+     if pos_emb_tok is not None:
+         new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0)
+     else:
+         new_pos_embed = pos_emb_img
+     state_dict['visual.positional_embedding'] = new_pos_embed
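For orientation, a minimal usage sketch of the helpers above (not part of the diff; the checkpoint filename is hypothetical, and the import path assumes this vendored copy is importable as ext.open_clip):

import torch

from ext.open_clip.model import CLIP, CLIPTextCfg, CLIPVisionCfg, resize_pos_embed, trace_model

# Build a ViT-B/16 tower at 336 px even though the pretrained weights were saved at 224 px.
model = CLIP(
    embed_dim=512,
    vision_cfg=CLIPVisionCfg(layers=12, width=768, patch_size=16, image_size=336),
    text_cfg=CLIPTextCfg(context_length=77, vocab_size=49408, width=512, heads=8, layers=12),
)

# A 224 px checkpoint carries a 14x14 (+1 class token) positional grid; resize_pos_embed
# interpolates it to this model's 21x21 grid so the weights load cleanly.
state_dict = torch.load("vit_b_16_224.pt", map_location="cpu")  # hypothetical checkpoint file
resize_pos_embed(state_dict, model)
model.load_state_dict(state_dict)

# trace_model JIT-traces forward/encode_image/encode_text with dummy inputs for deployment.
traced = trace_model(model, batch_size=2, device=torch.device("cpu"))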
ext/open_clip/model_configs/EVA01-g-14-plus.json ADDED
@@ -0,0 +1,18 @@
+ {
+     "embed_dim": 1024,
+     "vision_cfg": {
+         "image_size": 224,
+         "timm_model_name": "eva_giant_patch14_224",
+         "timm_model_pretrained": false,
+         "timm_pool": "token",
+         "timm_proj": null
+     },
+     "text_cfg": {
+         "context_length": 77,
+         "vocab_size": 49408,
+         "width": 1024,
+         "heads": 16,
+         "layers": 24
+     },
+     "custom_text": true
+ }
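A brief note on how a config such as EVA01-g-14-plus.json is consumed (assumed usage, not part of the diff): the factory resolves the model name to the JSON file of the same name, and because timm_model_name is set, the visual tower is delegated to a timm backbone, so timm must be installed. A minimal sketch, assuming the vendored package is importable as ext.open_clip:

from ext.open_clip.factory import create_model_and_transforms

# "EVA01-g-14-plus" resolves to model_configs/EVA01-g-14-plus.json;
# pretrained=None builds the architecture with randomly initialized weights.
model, preprocess_train, preprocess_val = create_model_and_transforms("EVA01-g-14-plus", pretrained=None)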
ext/open_clip/model_configs/EVA01-g-14.json ADDED
@@ -0,0 +1,18 @@
+ {
+     "embed_dim": 1024,
+     "vision_cfg": {
+         "image_size": 224,
+         "timm_model_name": "eva_giant_patch14_224",
+         "timm_model_pretrained": false,
+         "timm_pool": "token",
+         "timm_proj": null
+     },
+     "text_cfg": {
+         "context_length": 77,
+         "vocab_size": 49408,
+         "width": 768,
+         "heads": 12,
+         "layers": 12
+     },
+     "custom_text": true
+ }
ext/open_clip/model_configs/EVA02-B-16.json ADDED
@@ -0,0 +1,18 @@
+ {
+     "embed_dim": 512,
+     "vision_cfg": {
+         "image_size": 224,
+         "timm_model_name": "eva02_base_patch16_clip_224",
+         "timm_model_pretrained": false,
+         "timm_pool": "token",
+         "timm_proj": null
+     },
+     "text_cfg": {
+         "context_length": 77,
+         "vocab_size": 49408,
+         "width": 512,
+         "heads": 8,
+         "layers": 12
+     },
+     "custom_text": true
+ }
ext/open_clip/model_configs/EVA02-E-14-plus.json ADDED
@@ -0,0 +1,18 @@
+ {
+     "embed_dim": 1024,
+     "vision_cfg": {
+         "image_size": 224,
+         "timm_model_name": "eva02_enormous_patch14_clip_224",
+         "timm_model_pretrained": false,
+         "timm_pool": "token",
+         "timm_proj": null
+     },
+     "text_cfg": {
+         "context_length": 77,
+         "vocab_size": 49408,
+         "width": 1280,
+         "heads": 20,
+         "layers": 32
+     },
+     "custom_text": true
+ }
ext/open_clip/model_configs/EVA02-E-14.json ADDED
@@ -0,0 +1,18 @@
+ {
+     "embed_dim": 1024,
+     "vision_cfg": {
+         "image_size": 224,
+         "timm_model_name": "eva02_enormous_patch14_clip_224",
+         "timm_model_pretrained": false,
+         "timm_pool": "token",
+         "timm_proj": null
+     },
+     "text_cfg": {
+         "context_length": 77,
+         "vocab_size": 49408,
+         "width": 1024,
+         "heads": 16,
+         "layers": 24
+     },
+     "custom_text": true
+ }
ext/open_clip/model_configs/EVA02-L-14-336.json ADDED
@@ -0,0 +1,18 @@
+ {
+     "embed_dim": 768,
+     "vision_cfg": {
+         "image_size": 336,
+         "timm_model_name": "eva02_large_patch14_clip_336",
+         "timm_model_pretrained": false,
+         "timm_pool": "token",
+         "timm_proj": null
+     },
+     "text_cfg": {
+         "context_length": 77,
+         "vocab_size": 49408,
+         "width": 768,
+         "heads": 12,
+         "layers": 12
+     },
+     "custom_text": true
+ }
ext/open_clip/model_configs/EVA02-L-14.json ADDED
@@ -0,0 +1,18 @@
+ {
+     "embed_dim": 768,
+     "vision_cfg": {
+         "image_size": 224,
+         "timm_model_name": "eva02_large_patch14_clip_224",
+         "timm_model_pretrained": false,
+         "timm_pool": "token",
+         "timm_proj": null
+     },
+     "text_cfg": {
+         "context_length": 77,
+         "vocab_size": 49408,
+         "width": 768,
+         "heads": 12,
+         "layers": 12
+     },
+     "custom_text": true
+ }
ext/open_clip/model_configs/RN101-quickgelu.json ADDED
@@ -0,0 +1,22 @@
+ {
+     "embed_dim": 512,
+     "quick_gelu": true,
+     "vision_cfg": {
+         "image_size": 224,
+         "layers": [
+             3,
+             4,
+             23,
+             3
+         ],
+         "width": 64,
+         "patch_size": null
+     },
+     "text_cfg": {
+         "context_length": 77,
+         "vocab_size": 49408,
+         "width": 512,
+         "heads": 8,
+         "layers": 12
+     }
+ }
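The -quickgelu variants differ from their base configs only in setting "quick_gelu": true, which selects the sigmoid-based GELU approximation the original OpenAI weights were trained with instead of nn.GELU. Roughly, the activation it enables is (illustrative sketch, not part of the diff):

import torch

def quick_gelu(x: torch.Tensor) -> torch.Tensor:
    # OpenAI-style GELU approximation: x * sigmoid(1.702 * x)
    return x * torch.sigmoid(1.702 * x)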
ext/open_clip/model_configs/RN101.json ADDED
@@ -0,0 +1,21 @@
+ {
+     "embed_dim": 512,
+     "vision_cfg": {
+         "image_size": 224,
+         "layers": [
+             3,
+             4,
+             23,
+             3
+         ],
+         "width": 64,
+         "patch_size": null
+     },
+     "text_cfg": {
+         "context_length": 77,
+         "vocab_size": 49408,
+         "width": 512,
+         "heads": 8,
+         "layers": 12
+     }
+ }
ext/open_clip/model_configs/RN50-quickgelu.json ADDED
@@ -0,0 +1,22 @@
+ {
+     "embed_dim": 1024,
+     "quick_gelu": true,
+     "vision_cfg": {
+         "image_size": 224,
+         "layers": [
+             3,
+             4,
+             6,
+             3
+         ],
+         "width": 64,
+         "patch_size": null
+     },
+     "text_cfg": {
+         "context_length": 77,
+         "vocab_size": 49408,
+         "width": 512,
+         "heads": 8,
+         "layers": 12
+     }
+ }
ext/open_clip/model_configs/RN50.json ADDED
@@ -0,0 +1,21 @@
+ {
+     "embed_dim": 1024,
+     "vision_cfg": {
+         "image_size": 224,
+         "layers": [
+             3,
+             4,
+             6,
+             3
+         ],
+         "width": 64,
+         "patch_size": null
+     },
+     "text_cfg": {
+         "context_length": 77,
+         "vocab_size": 49408,
+         "width": 512,
+         "heads": 8,
+         "layers": 12
+     }
+ }
ext/open_clip/model_configs/RN50x16.json ADDED
@@ -0,0 +1,21 @@
+ {
+     "embed_dim": 768,
+     "vision_cfg": {
+         "image_size": 384,
+         "layers": [
+             6,
+             8,
+             18,
+             8
+         ],
+         "width": 96,
+         "patch_size": null
+     },
+     "text_cfg": {
+         "context_length": 77,
+         "vocab_size": 49408,
+         "width": 768,
+         "heads": 12,
+         "layers": 12
+     }
+ }
ext/open_clip/model_configs/RN50x4.json ADDED
@@ -0,0 +1,21 @@
+ {
+     "embed_dim": 640,
+     "vision_cfg": {
+         "image_size": 288,
+         "layers": [
+             4,
+             6,
+             10,
+             6
+         ],
+         "width": 80,
+         "patch_size": null
+     },
+     "text_cfg": {
+         "context_length": 77,
+         "vocab_size": 49408,
+         "width": 640,
+         "heads": 10,
+         "layers": 12
+     }
+ }
ext/open_clip/model_configs/RN50x64.json ADDED
@@ -0,0 +1,21 @@
+ {
+     "embed_dim": 1024,
+     "vision_cfg": {
+         "image_size": 448,
+         "layers": [
+             3,
+             15,
+             36,
+             10
+         ],
+         "width": 128,
+         "patch_size": null
+     },
+     "text_cfg": {
+         "context_length": 77,
+         "vocab_size": 49408,
+         "width": 1024,
+         "heads": 16,
+         "layers": 12
+     }
+ }
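The ResNet configs (RN50 through RN50x64) mirror what build_model_from_openai_state_dict infers from a checkpoint: layers is the per-stage block count, width is the stem width, patch_size is null, and image_size is 32 times the attention-pool grid (448 // 32 = 14 for RN50x64). An illustrative sketch using the RN50x64 values, assuming the vendored package is importable as ext.open_clip:

from ext.open_clip.model import CLIP, CLIPTextCfg, CLIPVisionCfg

# Passing a 4-tuple for layers selects the ModifiedResNet visual tower.
rn50x64 = CLIP(
    embed_dim=1024,
    vision_cfg=CLIPVisionCfg(layers=(3, 15, 36, 10), width=128, patch_size=None, image_size=448),
    text_cfg=CLIPTextCfg(context_length=77, vocab_size=49408, width=1024, heads=16, layers=12),
)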
ext/open_clip/model_configs/ViT-B-16-plus-240.json ADDED
@@ -0,0 +1,16 @@
+ {
+     "embed_dim": 640,
+     "vision_cfg": {
+         "image_size": 240,
+         "layers": 12,
+         "width": 896,
+         "patch_size": 16
+     },
+     "text_cfg": {
+         "context_length": 77,
+         "vocab_size": 49408,
+         "width": 640,
+         "heads": 10,
+         "layers": 12
+     }
+ }
ext/open_clip/model_configs/ViT-B-16-plus.json ADDED
@@ -0,0 +1,16 @@
+ {
+     "embed_dim": 640,
+     "vision_cfg": {
+         "image_size": 224,
+         "layers": 12,
+         "width": 896,
+         "patch_size": 16
+     },
+     "text_cfg": {
+         "context_length": 77,
+         "vocab_size": 49408,
+         "width": 640,
+         "heads": 10,
+         "layers": 12
+     }
+ }
ext/open_clip/model_configs/ViT-B-16.json ADDED
@@ -0,0 +1,16 @@
+ {
+     "embed_dim": 512,
+     "vision_cfg": {
+         "image_size": 224,
+         "layers": 12,
+         "width": 768,
+         "patch_size": 16
+     },
+     "text_cfg": {
+         "context_length": 77,
+         "vocab_size": 49408,
+         "width": 512,
+         "heads": 8,
+         "layers": 12
+     }
+ }
ext/open_clip/model_configs/ViT-B-32-plus-256.json ADDED
@@ -0,0 +1,16 @@
+ {
+     "embed_dim": 640,
+     "vision_cfg": {
+         "image_size": 256,
+         "layers": 12,
+         "width": 896,
+         "patch_size": 32
+     },
+     "text_cfg": {
+         "context_length": 77,
+         "vocab_size": 49408,
+         "width": 640,
+         "heads": 10,
+         "layers": 12
+     }
+ }
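For the ViT configs, the positional-embedding grid that resize_pos_embed operates on is image_size / patch_size per side. A quick check of the two shapes involved:

# ViT-B-32-plus-256: 256 / 32 = 8  ->  8 * 8 + 1 = 65 positional tokens (with class token)
# ViT-B-16 at 224:   224 / 16 = 14 -> 14 * 14 + 1 = 197 positional tokens
for image_size, patch_size in [(256, 32), (224, 16)]:
    grid = image_size // patch_size
    print(image_size, patch_size, grid * grid + 1)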