Spaces:
Sleeping
Sleeping
Haobo Yuan
committed on
Commit
•
b34d1d6
1
Parent(s):
3a3cc44
add omg code
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +5 -4
- .gitignore +125 -0
- app/configs/m2_convl.py +152 -0
- assets/000000000139.jpg +3 -0
- assets/000000000285.jpg +3 -0
- assets/000000000632.jpg +3 -0
- assets/000000000724.jpg +3 -0
- ext/cityscapes_scripts/createPanopticImgs.py +194 -0
- ext/cityscapes_scripts/helpers/__init__.py +1 -0
- ext/cityscapes_scripts/helpers/annotation.py +441 -0
- ext/cityscapes_scripts/helpers/csHelpers.py +129 -0
- ext/cityscapes_scripts/helpers/labels.py +182 -0
- ext/cityscapes_scripts/helpers/labels_cityPersons.py +61 -0
- ext/cityscapes_scripts/helpers/version.py +9 -0
- ext/class_names/VIPSeg.py +261 -0
- ext/davis2017/__init__.py +3 -0
- ext/davis2017/davis.py +122 -0
- ext/davis2017/evaluation.py +110 -0
- ext/davis2017/metrics.py +197 -0
- ext/davis2017/results.py +52 -0
- ext/davis2017/utils.py +174 -0
- ext/meta/sam_meta.py +41 -0
- ext/open_clip/__init__.py +15 -0
- ext/open_clip/bpe_simple_vocab_16e6.txt.gz +3 -0
- ext/open_clip/coca_model.py +458 -0
- ext/open_clip/constants.py +2 -0
- ext/open_clip/factory.py +387 -0
- ext/open_clip/generation_utils.py +0 -0
- ext/open_clip/hf_configs.py +56 -0
- ext/open_clip/hf_model.py +193 -0
- ext/open_clip/loss.py +216 -0
- ext/open_clip/model.py +473 -0
- ext/open_clip/model_configs/EVA01-g-14-plus.json +18 -0
- ext/open_clip/model_configs/EVA01-g-14.json +18 -0
- ext/open_clip/model_configs/EVA02-B-16.json +18 -0
- ext/open_clip/model_configs/EVA02-E-14-plus.json +18 -0
- ext/open_clip/model_configs/EVA02-E-14.json +18 -0
- ext/open_clip/model_configs/EVA02-L-14-336.json +18 -0
- ext/open_clip/model_configs/EVA02-L-14.json +18 -0
- ext/open_clip/model_configs/RN101-quickgelu.json +22 -0
- ext/open_clip/model_configs/RN101.json +21 -0
- ext/open_clip/model_configs/RN50-quickgelu.json +22 -0
- ext/open_clip/model_configs/RN50.json +21 -0
- ext/open_clip/model_configs/RN50x16.json +21 -0
- ext/open_clip/model_configs/RN50x4.json +21 -0
- ext/open_clip/model_configs/RN50x64.json +21 -0
- ext/open_clip/model_configs/ViT-B-16-plus-240.json +16 -0
- ext/open_clip/model_configs/ViT-B-16-plus.json +16 -0
- ext/open_clip/model_configs/ViT-B-16.json +16 -0
- ext/open_clip/model_configs/ViT-B-32-plus-256.json +16 -0
.gitattributes
CHANGED
@@ -17,10 +17,6 @@
|
|
17 |
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
-
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
-
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
-
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
-
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
@@ -33,3 +29,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
17 |
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
*.pb filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
20 |
*.rar filter=lfs diff=lfs merge=lfs -text
|
21 |
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
22 |
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
|
29 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
30 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
31 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
32 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
33 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
34 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
35 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
36 |
+
*.jpg filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Byte-compiled / optimized / DLL files
|
2 |
+
__pycache__/
|
3 |
+
*.py[cod]
|
4 |
+
*$py.class
|
5 |
+
|
6 |
+
# C extensions
|
7 |
+
*.so
|
8 |
+
|
9 |
+
# Distribution / packaging
|
10 |
+
.Python
|
11 |
+
build/
|
12 |
+
develop-eggs/
|
13 |
+
dist/
|
14 |
+
downloads/
|
15 |
+
eggs/
|
16 |
+
.eggs/
|
17 |
+
lib/
|
18 |
+
lib64/
|
19 |
+
parts/
|
20 |
+
sdist/
|
21 |
+
var/
|
22 |
+
wheels/
|
23 |
+
*.egg-info/
|
24 |
+
.installed.cfg
|
25 |
+
*.egg
|
26 |
+
MANIFEST
|
27 |
+
|
28 |
+
# PyInstaller
|
29 |
+
# Usually these files are written by a python script from a template
|
30 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
31 |
+
*.manifest
|
32 |
+
*.spec
|
33 |
+
|
34 |
+
# Installer logs
|
35 |
+
pip-log.txt
|
36 |
+
pip-delete-this-directory.txt
|
37 |
+
|
38 |
+
# Unit test / coverage reports
|
39 |
+
htmlcov/
|
40 |
+
.tox/
|
41 |
+
.coverage
|
42 |
+
.coverage.*
|
43 |
+
.cache
|
44 |
+
nosetests.xml
|
45 |
+
coverage.xml
|
46 |
+
*.cover
|
47 |
+
.hypothesis/
|
48 |
+
.pytest_cache/
|
49 |
+
|
50 |
+
# Translations
|
51 |
+
*.mo
|
52 |
+
*.pot
|
53 |
+
|
54 |
+
# Django stuff:
|
55 |
+
*.log
|
56 |
+
local_settings.py
|
57 |
+
db.sqlite3
|
58 |
+
|
59 |
+
# Flask stuff:
|
60 |
+
instance/
|
61 |
+
.webassets-cache
|
62 |
+
|
63 |
+
# Scrapy stuff:
|
64 |
+
.scrapy
|
65 |
+
|
66 |
+
# Sphinx documentation
|
67 |
+
docs/en/_build/
|
68 |
+
docs/zh_cn/_build/
|
69 |
+
|
70 |
+
# PyBuilder
|
71 |
+
target/
|
72 |
+
|
73 |
+
# Jupyter Notebook
|
74 |
+
.ipynb_checkpoints
|
75 |
+
|
76 |
+
# pyenv
|
77 |
+
.python-version
|
78 |
+
|
79 |
+
# celery beat schedule file
|
80 |
+
celerybeat-schedule
|
81 |
+
|
82 |
+
# SageMath parsed files
|
83 |
+
*.sage.py
|
84 |
+
|
85 |
+
# Environments
|
86 |
+
.env
|
87 |
+
.venv
|
88 |
+
env/
|
89 |
+
venv/
|
90 |
+
ENV/
|
91 |
+
env.bak/
|
92 |
+
venv.bak/
|
93 |
+
|
94 |
+
# Spyder project settings
|
95 |
+
.spyderproject
|
96 |
+
.spyproject
|
97 |
+
|
98 |
+
# Rope project settings
|
99 |
+
.ropeproject
|
100 |
+
|
101 |
+
# mkdocs documentation
|
102 |
+
/site
|
103 |
+
|
104 |
+
# mypy
|
105 |
+
.mypy_cache/
|
106 |
+
data/
|
107 |
+
data
|
108 |
+
.vscode
|
109 |
+
.idea/
|
110 |
+
.DS_Store
|
111 |
+
|
112 |
+
# custom
|
113 |
+
*.pkl
|
114 |
+
*.pkl.json
|
115 |
+
*.log.json
|
116 |
+
docs/modelzoo_statistics.md
|
117 |
+
mmdet/.mim
|
118 |
+
work_dirs/
|
119 |
+
|
120 |
+
# Pytorch
|
121 |
+
*.py~
|
122 |
+
*.sh~
|
123 |
+
|
124 |
+
# remove tmp folder
|
125 |
+
tmp/
|
app/configs/m2_convl.py
ADDED
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torch.nn import GroupNorm, ReLU
|
2 |
+
|
3 |
+
from mmdet.models import MSDeformAttnPixelDecoder, CrossEntropyLoss, DiceLoss, FocalLoss
|
4 |
+
from mmdet.models.task_modules.assigners import HungarianAssigner, ClassificationCost, CrossEntropyLossCost, DiceCost
|
5 |
+
from mmdet.models.task_modules.samplers import MaskPseudoSampler
|
6 |
+
|
7 |
+
from seg.models.detectors import Mask2formerVideo
|
8 |
+
from seg.models.fusion_head import OMGFusionHead
|
9 |
+
from seg.models.heads import Mask2FormerVideoHead
|
10 |
+
from seg.models.backbones import OpenCLIPBackbone
|
11 |
+
|
12 |
+
# Class counts forwarded to both the panoptic head and the fusion head below.
num_things_classes = 80
num_stuff_classes = 53

# The open-vocabulary classifier name is derived from these two values, and the
# backbone below is built from the same model name so the two cannot drift apart.
ov_model_name = 'convnext_large_d_320'
ov_datasets_name = 'CocoPanopticOVDataset'
model = dict(
    type=Mask2formerVideo,
    data_preprocessor=None,  # to fill
    backbone=dict(
        type=OpenCLIPBackbone,
        # Reuse the shared constant instead of repeating the literal.
        model_name=ov_model_name,
        fix=True,
        init_cfg=dict(
            type='clip_pretrain',
            checkpoint='laion2b_s29b_b131k_ft_soup'
        )
    ),
    panoptic_head=dict(
        init_cfg=dict(
            type='Pretrained',
            checkpoint='./models/omg_seg_convl.pth',
            prefix='panoptic_head.'
        ),
        type=Mask2FormerVideoHead,
        sphere_cls=True,
        ov_classifier_name=f'{ov_model_name}_{ov_datasets_name}',
        logit=None,
        in_channels=[192, 384, 768, 1536],  # pass to pixel_decoder inside
        strides=[4, 8, 16, 32],
        feat_channels=256,
        out_channels=256,
        num_things_classes=num_things_classes,
        num_stuff_classes=num_stuff_classes,
        num_queries=300,
        num_transformer_feat_level=3,
        pixel_decoder=dict(
            type=MSDeformAttnPixelDecoder,
            num_outs=3,
            norm_cfg=dict(type=GroupNorm, num_groups=32),
            act_cfg=dict(type=ReLU),
            encoder=dict(  # DeformableDetrTransformerEncoder
                num_layers=6,
                layer_cfg=dict(  # DeformableDetrTransformerEncoderLayer
                    self_attn_cfg=dict(  # MultiScaleDeformableAttention
                        embed_dims=256,
                        num_heads=8,
                        num_levels=3,
                        num_points=4,
                        dropout=0.0,
                        batch_first=True),
                    ffn_cfg=dict(
                        embed_dims=256,
                        feedforward_channels=1024,
                        num_fcs=2,
                        ffn_drop=0.0,
                        act_cfg=dict(type=ReLU, inplace=True)))),
            positional_encoding=dict(num_feats=128, normalize=True)),
        enforce_decoder_input_project=False,
        positional_encoding=dict(num_feats=128, normalize=True),
        transformer_decoder=dict(  # Mask2FormerTransformerDecoder
            return_intermediate=True,
            num_layers=9,
            layer_cfg=dict(  # Mask2FormerTransformerDecoderLayer
                self_attn_cfg=dict(  # MultiheadAttention
                    embed_dims=256,
                    num_heads=8,
                    dropout=0.0,
                    batch_first=True),
                cross_attn_cfg=dict(  # MultiheadAttention
                    embed_dims=256,
                    num_heads=8,
                    dropout=0.0,
                    batch_first=True),
                ffn_cfg=dict(
                    embed_dims=256,
                    feedforward_channels=2048,
                    num_fcs=2,
                    ffn_drop=0.0,
                    # Use the imported class, as every other activation
                    # entry in this config does.
                    act_cfg=dict(type=ReLU, inplace=True))),
            init_cfg=None),
        loss_cls=dict(
            type=CrossEntropyLoss,
            use_sigmoid=False,
            loss_weight=2.0,
            reduction='mean',
            class_weight=None  # [1.0] * num_classes + [0.1]
        ),
        loss_mask=dict(
            type=CrossEntropyLoss,
            use_sigmoid=True,
            reduction='mean',
            loss_weight=5.0),
        loss_dice=dict(
            type=DiceLoss,
            use_sigmoid=True,
            activate=True,
            reduction='mean',
            naive_dice=True,
            eps=1.0,
            loss_weight=5.0),
        loss_iou=dict(
            type=FocalLoss,
            use_sigmoid=True,
            loss_weight=2.0,
            reduction='mean'
        )
    ),
    panoptic_fusion_head=dict(
        type=OMGFusionHead,
        num_things_classes=num_things_classes,
        num_stuff_classes=num_stuff_classes,
        loss_panoptic=None,
        init_cfg=None
    ),
    train_cfg=dict(
        num_points=12544,
        oversample_ratio=3.0,
        importance_sample_ratio=0.75,
        assigner=dict(
            type=HungarianAssigner,
            match_costs=[
                dict(type=ClassificationCost, weight=2.0),
                dict(
                    type=CrossEntropyLossCost, weight=5.0, use_sigmoid=True),
                dict(type=DiceCost, weight=5.0, pred_act=True, eps=1.0)
            ]),
        sampler=dict(type=MaskPseudoSampler)),
    test_cfg=dict(
        panoptic_on=True,
        semantic_on=False,
        instance_on=True,
        # max_per_image is for instance segmentation.
        max_per_image=100,
        iou_thr=0.8,
        # In Mask2Former's panoptic postprocessing,
        # it will filter mask area where score is less than 0.5 .
        filter_low_score=True,
        object_mask_thr=0.,
    ),
    init_cfg=None
)
|
assets/000000000139.jpg
ADDED
Git LFS Details
|
assets/000000000285.jpg
ADDED
Git LFS Details
|
assets/000000000632.jpg
ADDED
Git LFS Details
|
assets/000000000724.jpg
ADDED
Git LFS Details
|
ext/cityscapes_scripts/createPanopticImgs.py
ADDED
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python
|
2 |
+
#
|
3 |
+
# Converts the *instanceIds.png annotations of the Cityscapes dataset
|
4 |
+
# to COCO-style panoptic segmentation format (http://cocodataset.org/#format-data).
|
5 |
+
# The conversion works for the 'fine' set of the annotations.
|
6 |
+
#
|
7 |
+
# By default this tool uses the IDs specified in labels.py. You can use the flag
|
8 |
+
# --use-train-id to get train ids for categories. 'ignoreInEval' categories are
|
9 |
+
# removed during the conversion.
|
10 |
+
#
|
11 |
+
# In panoptic segmentation format image_id is used to match predictions and ground truth.
|
12 |
+
# For cityscapes image_id has form <city>_123456_123456 and corresponds to the prefix
|
13 |
+
# of cityscapes image files.
|
14 |
+
#
|
15 |
+
|
16 |
+
# python imports
|
17 |
+
from __future__ import print_function, absolute_import, division, unicode_literals
|
18 |
+
import os
|
19 |
+
import glob
|
20 |
+
import sys
|
21 |
+
import argparse
|
22 |
+
import json
|
23 |
+
import numpy as np
|
24 |
+
|
25 |
+
# Image processing
|
26 |
+
from PIL import Image
|
27 |
+
|
28 |
+
# cityscapes imports
|
29 |
+
from ext.cityscapes_scripts.helpers.csHelpers import printError
|
30 |
+
from ext.cityscapes_scripts.helpers.labels import id2label, labels
|
31 |
+
|
32 |
+
|
33 |
+
import mmengine
|
34 |
+
|
35 |
+
|
36 |
+
# The main method
|
37 |
+
def _mask_bbox(mask):
    """Return the tight ``[x, y, width, height]`` box around the True pixels of ``mask``."""
    cols = np.nonzero(np.sum(mask, axis=0))[0]
    rows = np.nonzero(np.sum(mask, axis=1))[0]
    x, y = cols[0], rows[0]
    return [int(x), int(y), int(cols[-1] - x + 1), int(rows[-1] - y + 1)]


def convert2panoptic(cityscapesPath=None, outputFolder=None, useTrainId=False, setNames=None):
    """Convert Cityscapes ``*_instanceIds.png`` files to COCO-style panoptic data.

    For every requested split, one JSON file (images, annotations, categories)
    and one folder of panoptic PNGs are written below ``outputFolder``.

    Args:
        cityscapesPath: Folder that is searched with the pattern
            ``<set>/*/*_instanceIds.png``. When ``None``, the
            ``CITYSCAPES_DATASET`` environment variable (or ``data/cityscapes``
            when it is unset) joined with ``gtFine`` is used.
        outputFolder: Destination folder. When ``None`` it is derived from
            ``cityscapesPath`` by replacing ``gtFine`` with ``annotations``.
        useTrainId: Use ``label.trainId`` instead of ``label.id`` as the
            category id; also adds a ``_trainId`` suffix to the output names.
        setNames: Split names to convert. Defaults to ``val``, ``train`` and
            ``test`` (a ``None`` sentinel replaces the former mutable default).
    """
    if setNames is None:
        setNames = ["val", "train", "test"]

    # Where to look for Cityscapes
    if cityscapesPath is None:
        cityscapesPath = os.environ.get('CITYSCAPES_DATASET', 'data/cityscapes')
        cityscapesPath = os.path.join(cityscapesPath, "gtFine")

    if outputFolder is None:
        outputFolder = cityscapesPath.replace('gtFine', "annotations")

    mmengine.mkdir_or_exist(outputFolder)

    # Category table: every label that is not ignored in evaluation.
    categories = [
        {'id': int(label.trainId) if useTrainId else int(label.id),
         'name': label.name,
         'color': label.color,
         'supercategory': label.category,
         'isthing': 1 if label.hasInstances else 0}
        for label in labels
        if not label.ignoreInEval
    ]
    categories.sort(key=lambda entry: entry['id'])

    trainIfSuffix = "_trainId" if useTrainId else ""

    for setName in setNames:
        # how to search for all ground truth
        searchFine = os.path.join(cityscapesPath, setName, "*", "*_instanceIds.png")
        files = sorted(glob.glob(searchFine))

        # quit if we did not find anything
        # NOTE(review): printError presumably terminates the run - confirm in csHelpers.
        if not files:
            printError(
                "Did not find any files for {} set using matching pattern {}. Please consult the README.".format(setName, searchFine)
            )
        # a bit verbose
        print("Converting {} annotation files for {} set.".format(len(files), setName))

        outputBaseFile = "cityscapes_panoptic_{}{}".format(setName, trainIfSuffix)
        outFile = os.path.join(outputFolder, "{}.json".format(outputBaseFile))
        print("Json file with the annotations in panoptic format will be saved in {}".format(outFile))
        panopticFolder = os.path.join(outputFolder, outputBaseFile)
        if not os.path.isdir(panopticFolder):
            print("Creating folder {} for panoptic segmentation PNGs".format(panopticFolder))
            os.mkdir(panopticFolder)
        print("Corresponding segmentations in .png format will be saved in {}".format(panopticFolder))

        images = []
        annotations = []
        for progress, f in enumerate(files):

            originalFormat = np.array(Image.open(f))

            fileName = os.path.basename(f)
            # The part of the file name before the first underscore is used as
            # the sub-folder for the input and output files.
            location = fileName.split('_')[0]
            imageId = fileName.replace("_gtFine_instanceIds.png", "")
            fileName = os.path.join(location, fileName)
            inputFileName = fileName.replace("_gtFine_instanceIds.png", "_leftImg8bit.png")
            outputFileName = fileName.replace("_gtFine_instanceIds.png", "_panoptic.png")
            # image entry, id for image is its filename without extension
            images.append({"id": imageId,
                           "width": int(originalFormat.shape[1]),
                           "height": int(originalFormat.shape[0]),
                           "file_name": inputFileName})

            pan_format = np.zeros(
                (originalFormat.shape[0], originalFormat.shape[1], 3), dtype=np.uint8
            )

            segmInfo = []
            for segmentId in np.unique(originalFormat):
                # Ids below 1000 carry only a semantic id; larger ids encode
                # ``semanticId * 1000 + instance``.
                if segmentId < 1000:
                    semanticId = segmentId
                    isCrowd = 1
                else:
                    semanticId = segmentId // 1000
                    isCrowd = 0
                labelInfo = id2label[semanticId]
                categoryId = labelInfo.trainId if useTrainId else labelInfo.id
                if labelInfo.ignoreInEval:
                    continue
                if not labelInfo.hasInstances:
                    isCrowd = 0

                mask = originalFormat == segmentId
                # Encode the id as base-256 digits on plain Python ints so every
                # channel value is guaranteed to fit into uint8.
                sid = int(segmentId)
                pan_format[mask] = [sid % 256, sid // 256 % 256, sid // 256 // 256]

                area = np.sum(mask)  # segment area computation

                segmInfo.append({"id": int(segmentId),
                                 "category_id": int(categoryId),
                                 "area": int(area),
                                 "bbox": _mask_bbox(mask),
                                 "iscrowd": isCrowd})

            annotations.append({'image_id': imageId,
                                'file_name': outputFileName,
                                "segments_info": segmInfo})

            panopticFile = os.path.join(panopticFolder, outputFileName)
            mmengine.mkdir_or_exist(os.path.dirname(panopticFile))
            Image.fromarray(pan_format).save(panopticFile)

            print("\rProgress: {:>3.2f} %".format((progress + 1) * 100 / len(files)), end=' ')
            sys.stdout.flush()

        print("\nSaving the json file {}".format(outFile))
        d = {'images': images,
             'annotations': annotations,
             'categories': categories}
        with open(outFile, 'w') as f:
            json.dump(d, f, sort_keys=True, indent=4)
|
166 |
+
|
167 |
+
|
168 |
+
def main():
    """Parse the command line and run :func:`convert2panoptic`.

    Train ids are used by default. ``--use-train-id`` is kept for backward
    compatibility (it was already the default), and ``--no-use-train-id``
    allows switching to the ids from labels.py, which was previously
    impossible from the command line.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset-folder",
                        dest="cityscapesPath",
                        help="path to the Cityscapes dataset 'gtFine' folder",
                        default=None,
                        type=str)
    parser.add_argument("--output-folder",
                        dest="outputFolder",
                        help="path to the output folder.",
                        default=None,
                        type=str)
    parser.add_argument("--use-train-id",
                        dest="useTrainId",
                        help="use train ids as category ids (default)",
                        default=True,
                        action="store_true")
    # Opt-out switch writing to the same destination as --use-train-id.
    parser.add_argument("--no-use-train-id",
                        dest="useTrainId",
                        help="use the ids from labels.py as category ids",
                        action="store_false")
    parser.add_argument("--set-names",
                        dest="setNames",
                        help="set names to which apply the function to",
                        nargs='+',
                        default=["val", "train"],
                        type=str)
    args = parser.parse_args()

    convert2panoptic(args.cityscapesPath, args.outputFolder, args.useTrainId, args.setNames)
|
190 |
+
|
191 |
+
|
192 |
+
# call the main
|
193 |
+
if __name__ == "__main__":
|
194 |
+
main()
|
ext/cityscapes_scripts/helpers/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
# empty
|
ext/cityscapes_scripts/helpers/annotation.py
ADDED
@@ -0,0 +1,441 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python
|
2 |
+
#
|
3 |
+
# Classes to store, read, and write annotations
|
4 |
+
#
|
5 |
+
|
6 |
+
from __future__ import print_function, absolute_import, division
|
7 |
+
import os
|
8 |
+
import json
|
9 |
+
import numpy as np
|
10 |
+
from collections import namedtuple
|
11 |
+
|
12 |
+
# get current date and time
|
13 |
+
import datetime
|
14 |
+
import locale
|
15 |
+
|
16 |
+
from abc import ABCMeta, abstractmethod
|
17 |
+
from .box3dImageTransform import Camera
|
18 |
+
|
19 |
+
# A point in a polygon
|
20 |
+
Point = namedtuple('Point', ['x', 'y'])
|
21 |
+
|
22 |
+
|
23 |
+
class CsObjectType:
    """Integer codes identifying the kind of an annotation object."""
    # POLY: polygon, BBOX2D: bounding box,
    # BBOX3D: 3d bounding box, IGNORE2D: 2d ignore region
    POLY, BBOX2D, BBOX3D, IGNORE2D = range(1, 5)
|
29 |
+
|
30 |
+
|
31 |
+
class CsObject:
    """Abstract base class for annotation objects"""
    # NOTE(review): this is the Python 2 spelling of a metaclass declaration;
    # confirm whether abstract-method enforcement is expected under Python 3.
    __metaclass__ = ABCMeta

    def __init__(self, objType):
        """Store the object type and reset all bookkeeping fields."""
        self.objectType = objType
        # the label
        self.label = ""

        # If deleted or not
        self.deleted = 0
        # If verified or not
        self.verified = 0
        # The date string
        self.date = ""
        # The username
        self.user = ""
        # Draw the object
        # Not read from or written to JSON
        # Set to False if deleted object
        # Might be set to False by the application for other reasons
        self.draw = True

    @abstractmethod
    def __str__(self): pass

    @abstractmethod
    def fromJsonText(self, jsonText, objId=-1): pass

    @abstractmethod
    def toJsonText(self): pass

    def updateDate(self):
        """Set ``self.date`` to the current time, formatted ``%d-%b-%Y %H:%M:%S``.

        An English locale is requested first, on a best-effort basis. Each
        candidate name is tried in turn; if none is available the current
        locale is left untouched. The former chain of sibling
        ``except locale.Error`` clauses could never reach the later
        fallbacks, so a failing fallback raised instead of being skipped.
        """
        for candidate in ('en_US.utf8', 'en_US', 'us_us.utf8', 'us_us'):
            try:
                locale.setlocale(locale.LC_ALL, candidate)
            except locale.Error:
                continue
            break
        self.date = datetime.datetime.now().strftime("%d-%b-%Y %H:%M:%S")

    # Mark the object as deleted
    def delete(self):
        """Flag the object as deleted and stop drawing it."""
        self.deleted = 1
        self.draw = False
|
80 |
+
|
81 |
+
|
82 |
+
class CsPoly(CsObject):
    """A single annotated object described by a polygon."""

    def __init__(self):
        """Create an empty polygon object with no id assigned."""
        CsObject.__init__(self, CsObjectType.POLY)
        # the polygon as list of points
        self.polygon = []
        # the object ID
        self.id = -1

    def __str__(self):
        """Return the label plus a short rendering of the polygon.

        Polygons with up to four points are printed in full; longer ones show
        the first two and last two points; an empty polygon prints ``none``.
        """
        pts = self.polygon
        if not pts:
            summary = "none"
        elif len(pts) <= 4:
            summary = ''.join('({},{}) '.format(pt.x, pt.y) for pt in pts)
        else:
            shown = (pts[0], pts[1], pts[-2], pts[-1])
            coords = [value for pt in shown for value in (pt.x, pt.y)]
            summary = '({},{}) ({},{}) ... ({},{}) ({},{})'.format(*coords)
        return "Object: {} - {}".format(self.label, summary)

    def fromJsonText(self, jsonText, objId=-1):
        """Fill this object from a parsed JSON dictionary.

        ``label`` and ``polygon`` are required; ``deleted``, ``verified``,
        ``user`` and ``date`` fall back to ``0``, ``1``, ``''`` and ``''``.
        """
        self.id = objId
        self.label = str(jsonText['label'])
        self.polygon = [Point(pair[0], pair[1]) for pair in jsonText['polygon']]
        self.deleted = jsonText['deleted'] if 'deleted' in jsonText.keys() else 0
        self.verified = jsonText['verified'] if 'verified' in jsonText.keys() else 1
        self.user = jsonText['user'] if 'user' in jsonText.keys() else ''
        self.date = jsonText['date'] if 'date' in jsonText.keys() else ''
        # Deleted objects are not drawn.
        self.draw = not (self.deleted == 1)

    def toJsonText(self):
        """Return a JSON-serialisable dictionary describing this object."""
        return {
            'label': self.label,
            'id': self.id,
            'deleted': self.deleted,
            'verified': self.verified,
            'user': self.user,
            'date': self.date,
            'polygon': [[pt.x, pt.y] for pt in self.polygon],
        }
|
148 |
+
|
149 |
+
|
150 |
+
class CsBbox2d(CsObject):
    """A single annotated object described by 2D bounding boxes."""

    def __init__(self):
        """Create an object with empty boxes, no instance id and no label."""
        CsObject.__init__(self, CsObjectType.BBOX2D)
        # amodal and modal boxes, each stored as [x, y, w, h]
        self.bbox_amodal_xywh = []
        self.bbox_modal_xywh = []

        # the ID of the corresponding object
        self.instanceId = -1
        # the label of the corresponding object
        self.label = ""

    def __str__(self):
        """Return the label followed by the amodal and modal boxes."""
        def describe(box):
            return '[(x1: {}, y1: {}), (w: {}, h: {})]'.format(
                box[0], box[1], box[2], box[3])

        amodal_text = describe(self.bbox_amodal_xywh)
        modal_text = describe(self.bbox_modal_xywh)
        return "Object: {}\n - Amodal {}\n - Modal {}".format(
            self.label, amodal_text, modal_text)

    def setAmodalBox(self, bbox_amodal):
        """Store an ``[xmin, ymin, xmax, ymax]`` box as the amodal ``[x, y, w, h]`` box."""
        xmin, ymin = bbox_amodal[0], bbox_amodal[1]
        xmax, ymax = bbox_amodal[2], bbox_amodal[3]
        self.bbox_amodal_xywh = [xmin, ymin, xmax - xmin, ymax - ymin]

    # access 2d boxes in [xmin, ymin, xmax, ymax] format
    @property
    def bbox_amodal(self):
        """Returns the 2d box as [xmin, ymin, xmax, ymax]"""
        box = self.bbox_amodal_xywh
        return [box[0], box[1], box[0] + box[2], box[1] + box[3]]

    @property
    def bbox_modal(self):
        """Returns the 2d box as [xmin, ymin, xmax, ymax]"""
        box = self.bbox_modal_xywh
        return [box[0], box[1], box[0] + box[2], box[1] + box[3]]

    def fromJsonText(self, jsonText, objId=-1):
        """Fill the boxes (and label/instance id, if both are present) from a dictionary.

        Key pairs are looked up in this order: ``bbox``/``bboxVis``, then
        ``amodal``/``modal``; otherwise ``amodal`` is used for both boxes.
        """
        present = jsonText.keys()
        if 'bbox' in present and 'bboxVis' in present:
            # cityperson format
            amodal_key, modal_key = 'bbox', 'bboxVis'
        elif "modal" in present and "amodal" in present:
            # both modal and amodal boxes are provided
            amodal_key, modal_key = 'amodal', 'modal'
        else:
            # only amodal boxes are provided
            amodal_key = modal_key = 'amodal'
        self.bbox_amodal_xywh = jsonText[amodal_key]
        self.bbox_modal_xywh = jsonText[modal_key]

        # load label and instanceId if available
        if 'label' in present and 'instanceId' in present:
            self.label = str(jsonText['label'])
            self.instanceId = jsonText['instanceId']

    def toJsonText(self):
        """Return a JSON-serialisable dictionary describing this object."""
        return {
            'label': self.label,
            'instanceId': self.instanceId,
            'modal': self.bbox_modal_xywh,
            'amodal': self.bbox_amodal_xywh,
        }
|
235 |
+
|
236 |
+
|
237 |
+
class CsBbox3d(CsObject):
    """Class that contains the information of a single annotated object as 3D bounding box"""

    # Constructor
    def __init__(self):
        """Create an object with no 2D box, empty 3D parameters and score -1."""
        CsObject.__init__(self, CsObjectType.BBOX3D)

        # CsBbox2d instance, populated by fromJsonText
        self.bbox_2d = None

        self.center = []
        self.dims = []
        self.rotation = []
        self.instanceId = -1
        self.label = ""
        self.score = -1.

    def __str__(self):
        """Return the label, the 2D box text and the 3D centre/dimensions/rotation."""
        bbox2dText = str(self.bbox_2d)

        bbox3dText = ""
        bbox3dText += '\n - Center (x/y/z) [m]: {}/{}/{}'.format(
            self.center[0], self.center[1], self.center[2])
        bbox3dText += '\n - Dimensions (l/w/h) [m]: {}/{}/{}'.format(
            self.dims[0], self.dims[1], self.dims[2])
        bbox3dText += '\n - Rotation: {}/{}/{}/{}'.format(
            self.rotation[0], self.rotation[1], self.rotation[2], self.rotation[3])

        text = "Object: {}\n2D {}\n - 3D {}".format(
            self.label, bbox2dText, bbox3dText)
        return text

    def fromJsonText(self, jsonText, objId=-1):
        """Fill this object from a dictionary with ``2d``, ``3d``, ``label`` and ``score`` entries.

        ``instanceId`` is optional and keeps its previous value when absent.
        """
        # load 2D box
        self.bbox_2d = CsBbox2d()
        self.bbox_2d.fromJsonText(jsonText['2d'])

        self.center = jsonText['3d']['center']
        self.dims = jsonText['3d']['dimensions']
        self.rotation = jsonText['3d']['rotation']
        self.label = jsonText['label']
        self.score = jsonText['score']

        if 'instanceId' in jsonText.keys():
            self.instanceId = jsonText['instanceId']

    def toJsonText(self):
        """Return a JSON-serialisable dictionary describing this object.

        The nested ``2d`` and ``3d`` dictionaries are created explicitly;
        assigning into them without creating them first raised ``KeyError``.
        ``score`` is included because ``fromJsonText`` requires it, so the
        output can be read back.
        """
        objDict = {}
        objDict['label'] = self.label
        objDict['instanceId'] = self.instanceId
        objDict['score'] = self.score
        objDict['2d'] = {
            'amodal': self.bbox_2d.bbox_amodal_xywh,
            'modal': self.bbox_2d.bbox_modal_xywh,
        }
        objDict['3d'] = {
            'center': self.center,
            'dimensions': self.dims,
            'rotation': self.rotation,
        }

        return objDict

    @property
    def depth(self):
        """Return sqrt(center[0]**2 + center[1]**2), cast to int (the BEV depth)."""
        # returns the BEV depth
        return np.sqrt(self.center[0]**2 + self.center[1]**2).astype(int)
|
298 |
+
|
299 |
+
|
300 |
+
class CsIgnore2d(CsObject):
    """A single annotated 2D ignore region, stored as an [x, y, w, h] box."""

    def __init__(self):
        CsObject.__init__(self, CsObjectType.IGNORE2D)

        self.bbox_xywh = []
        self.label = ""
        self.instanceId = -1

    def __str__(self):
        box = self.bbox_xywh
        template = 'Ignore Region: (x1: {}, y1: {}), (w: {}, h: {})'
        return "" + template.format(box[0], box[1], box[2], box[3])

    def fromJsonText(self, jsonText, objId=-1):
        """Load the region from a dict; '2d' is required, the rest optional.

        The objId argument is not used.
        """
        self.bbox_xywh = jsonText['2d']

        # optional entries overwrite the defaults only when present
        if 'label' in jsonText:
            self.label = jsonText['label']
        if 'instanceId' in jsonText:
            self.instanceId = jsonText['instanceId']

    def toJsonText(self):
        """Return the region as a dict with keys label, instanceId and 2d."""
        return {
            'label': self.label,
            'instanceId': self.instanceId,
            '2d': self.bbox_xywh,
        }

    @property
    def bbox(self):
        """Returns the 2d box as [xmin, ymin, xmax, ymax]"""
        xMin = self.bbox_xywh[0]
        yMin = self.bbox_xywh[1]
        xMax = self.bbox_xywh[0] + self.bbox_xywh[2]
        yMax = self.bbox_xywh[1] + self.bbox_xywh[3]
        return [xMin, yMin, xMax, yMax]

    # The two read-only aliases below expose the same attribute names as the
    # 2D bounding-box class; both simply return the stored xywh box.
    @property
    def bbox_amodal_xywh(self):
        return self.bbox_xywh

    @property
    def bbox_modal_xywh(self):
        return self.bbox_xywh
|
353 |
+
|
354 |
+
|
355 |
+
class Annotation:
    """Annotation of one whole image.

    All regular objects of an annotation share a single object type
    (mixing e.g. CsPoly and CsBbox2d is not supported).
    """

    def __init__(self, objType=CsObjectType.POLY):
        # width / height of the image (and therefore of the label image)
        self.imgWidth = 0
        self.imgHeight = 0
        # annotated objects, in file order
        self.objects = []
        # camera calibration, set when the json has a 'sensor' section
        self.camera = None
        assert objType in CsObjectType.__dict__.values()
        self.objectType = objType

    def toJson(self):
        """Serialize the whole annotation to a json string via each object's __dict__."""
        return json.dumps(self, default=lambda o: o.__dict__, sort_keys=True, indent=4)

    def fromJsonText(self, jsonText):
        """Parse a json string and populate size, objects and camera."""
        jsonDict = json.loads(jsonText)
        self.imgWidth = int(jsonDict['imgWidth'])
        self.imgHeight = int(jsonDict['imgHeight'])
        self.objects = []

        # regular objects: one instance of the configured type per entry
        if self.objectType != CsObjectType.IGNORE2D:
            if self.objectType == CsObjectType.POLY:
                objClass = CsPoly
            elif self.objectType == CsObjectType.BBOX2D:
                objClass = CsBbox2d
            elif self.objectType == CsObjectType.BBOX3D:
                objClass = CsBbox3d
            for objId, objIn in enumerate(jsonDict['objects']):
                obj = objClass()
                obj.fromJsonText(objIn, objId)
                self.objects.append(obj)

        # ignore regions are appended to the same object list
        if 'ignore' in jsonDict:
            for ignoreId, ignoreIn in enumerate(jsonDict['ignore']):
                region = CsIgnore2d()
                region.fromJsonText(ignoreIn, ignoreId)
                self.objects.append(region)

        # camera calibration
        if 'sensor' in jsonDict:
            sensor = jsonDict['sensor']
            self.camera = Camera(fx=sensor['fx'],
                                 fy=sensor['fy'],
                                 u0=sensor['u0'],
                                 v0=sensor['v0'],
                                 sensor_T_ISO_8855=sensor['sensor_T_ISO_8855'])

    def toJsonText(self):
        """Return a dict with the image size and every object's toJsonText() result."""
        return {
            'imgWidth': self.imgWidth,
            'imgHeight': self.imgHeight,
            'objects': [obj.toJsonText() for obj in self.objects],
        }

    def fromJsonFile(self, jsonFile):
        """Read a json file and load it; prints a message and returns if it does not exist."""
        if os.path.isfile(jsonFile):
            with open(jsonFile, 'r') as f:
                self.fromJsonText(f.read())
            return
        print('Given json file not found: {}'.format(jsonFile))

    def toJsonFile(self, jsonFile):
        """Write the output of toJson() to the given file."""
        with open(jsonFile, 'w') as f:
            f.write(self.toJson())
|
429 |
+
|
430 |
+
|
431 |
+
# a dummy example
|
432 |
+
if __name__ == "__main__":
    # Build a polygon object labelled 'car' from four corner points
    # and print its type name and string form.
    obj = CsPoly()
    obj.label = 'car'
    for x, y in ((0, 0), (1, 0), (1, 1), (0, 1)):
        obj.polygon.append(Point(x, y))

    print(type(obj).__name__)
    print(obj)
|
ext/cityscapes_scripts/helpers/csHelpers.py
ADDED
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python
|
2 |
+
#
|
3 |
+
# Various helper methods and includes for Cityscapes
|
4 |
+
#
|
5 |
+
|
6 |
+
# Python imports
|
7 |
+
from __future__ import print_function, absolute_import, division
|
8 |
+
import os
|
9 |
+
import sys
|
10 |
+
import getopt
|
11 |
+
import glob
|
12 |
+
import math
|
13 |
+
import json
|
14 |
+
from collections import namedtuple
|
15 |
+
import logging
|
16 |
+
import traceback
|
17 |
+
|
18 |
+
# Image processing
|
19 |
+
from PIL import Image
|
20 |
+
from PIL import ImageDraw
|
21 |
+
|
22 |
+
# Numpy for datastructures
|
23 |
+
import numpy as np
|
24 |
+
|
25 |
+
# Cityscapes modules
|
26 |
+
# from .annotation import Annotation
|
27 |
+
from .labels import labels, name2label, id2label, trainId2label, category2labels
|
28 |
+
|
29 |
+
|
30 |
+
def printError(message):
    """Print `message` prefixed with 'ERROR: ' and terminate with exit code -1."""
    text = 'ERROR: ' + str(message)
    print(text)
    sys.exit(-1)
|
34 |
+
|
35 |
+
|
36 |
+
class colors:
    """Terminal escape sequences used to colorize console output."""
    # foreground colors
    RED = '\033[31;1m'
    GREEN = '\033[32;1m'
    YELLOW = '\033[33;1m'
    BLUE = '\033[34;1m'
    MAGENTA = '\033[35;1m'
    CYAN = '\033[36;1m'
    # text attributes
    BOLD = '\033[1m'
    UNDERLINE = '\033[4m'
    # reset sequence
    ENDC = '\033[0m'
|
47 |
+
|
48 |
+
|
49 |
+
def getColorEntry(val, args):
    """Return the color escape sequence for a score, or "" if coloring is off.

    args.colorized switches coloring on. Values that are not floats (or are
    NaN) get the reset sequence. Float values map to a color by threshold:
    < .20 red, < .40 yellow, < .60 blue, < .80 cyan, otherwise green.
    """
    if not args.colorized:
        return ""
    if not isinstance(val, float) or math.isnan(val):
        return colors.ENDC

    # ascending upper bounds; the first bound the value stays below wins
    thresholds = (
        (.20, colors.RED),
        (.40, colors.YELLOW),
        (.60, colors.BLUE),
        (.80, colors.CYAN),
    )
    for upperBound, color in thresholds:
        if val < upperBound:
            return color
    return colors.GREEN
|
66 |
+
|
67 |
+
|
68 |
+
# Cityscapes files have a typical filename structure
|
69 |
+
# <city>_<sequenceNb>_<frameNb>_<type>[_<type2>].<ext>
|
70 |
+
# This class contains the individual elements as members
|
71 |
+
# For the sequence and frame number, the strings are returned, including leading zeros
|
72 |
+
CsFile = namedtuple(
    'csFile',
    [
        'city',
        'sequenceNb',
        'frameNb',
        'type',
        'type2',
        'ext',
    ],
)
|
73 |
+
|
74 |
+
|
75 |
+
def getCsFileInfo(fileName):
    """Split a Cityscapes file name into a CsFile tuple.

    The base name is split at '_' and its last piece additionally at '.'.
    Five resulting parts give a CsFile with an empty type2; six parts fill
    every field. Any other count is reported through printError.
    """
    underscoreParts = os.path.basename(fileName).split('_')
    # the final piece still carries the extension(s): split it at the dots
    parts = underscoreParts[:-1] + underscoreParts[-1].split('.')
    if not parts:
        printError('Cannot parse given filename ({}). Does not seem to be a valid Cityscapes file.'.format(fileName))

    partCount = len(parts)
    if partCount == 5:
        csFile = CsFile(*parts[:-1], type2="", ext=parts[-1])
    elif partCount == 6:
        csFile = CsFile(*parts)
    else:
        printError('Found {} part(s) in given filename ({}). Expected 5 or 6.'.format(len(parts), fileName))

    return csFile
|
90 |
+
|
91 |
+
|
92 |
+
def getCoreImageFileName(filename):
    """Return '<city>_<sequenceNb>_<frameNb>' for a Cityscapes file name.

    e.g. city_123456_123456_gtFine_polygons.json -> city_123456_123456
    """
    info = getCsFileInfo(filename)
    core = "{}_{}_{}".format(info.city, info.sequenceNb, info.frameNb)
    return core
|
99 |
+
|
100 |
+
|
101 |
+
def getDirectory(fileName):
    """Return the name of the directory that directly contains fileName.

    Example: "/foo/bar/foobar.txt" -> "bar".
    No validation is performed on the input.
    """
    return os.path.basename(os.path.dirname(fileName))
|
111 |
+
|
112 |
+
|
113 |
+
def ensurePath(path):
    """Create the directory `path` (including parents) unless it already exists.

    An empty or None path is ignored.
    """
    if path and not os.path.isdir(path):
        os.makedirs(path)
|
119 |
+
|
120 |
+
|
121 |
+
def writeDict2JSON(dictName, fileName):
    """Write `dictName` to `fileName` as json with sorted keys and 4-space indent.

    Objects that json cannot encode natively are serialized via their __dict__.
    """
    with open(fileName, 'w') as jsonFile:
        text = json.dumps(dictName, default=lambda o: o.__dict__, sort_keys=True, indent=4)
        jsonFile.write(text)
|
125 |
+
|
126 |
+
|
127 |
+
# dummy main
|
128 |
+
if __name__ == "__main__":
    # This module only provides helpers; running it directly reports an error.
    printError("Only for include, not executable on its own.")
|
ext/cityscapes_scripts/helpers/labels.py
ADDED
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python
|
2 |
+
#
|
3 |
+
# Cityscapes labels
|
4 |
+
#
|
5 |
+
|
6 |
+
from __future__ import print_function, absolute_import, division
|
7 |
+
from collections import namedtuple
|
8 |
+
|
9 |
+
|
10 |
+
#--------------------------------------------------------------------------------
|
11 |
+
# Definitions
|
12 |
+
#--------------------------------------------------------------------------------
|
13 |
+
|
14 |
+
# a label and all meta information
|
15 |
+
Label = namedtuple('Label', [
    # The identifier of this label, e.g. 'car', 'person', ... .
    # It is used to uniquely name a class.
    'name',

    # An integer ID associated with this label. The IDs represent the label
    # in ground truth images. An ID of -1 means the label has no ID and is
    # ignored when creating ground truth images (e.g. license plate).
    # Do not modify these IDs, since exactly these IDs are expected by the
    # evaluation server.
    'id',

    # Training ID; feel free to modify these as suitable for your method and
    # then create ground truth images with train IDs using the tools in the
    # 'preparation' folder. Results must still be validated or submitted to
    # the evaluation server using the regular IDs above!
    # Several labels may share one trainId; they are then mapped to the same
    # class in the ground truth images. For the inverse mapping, the label
    # defined first in the list below is used. Mapping all void-type classes
    # to the same ID, for example, might make sense for some approaches.
    # Max value is 255!
    'trainId',

    # The name of the category that this label belongs to
    'category',

    # The ID of this category. Used to create ground truth images
    # on category level.
    'categoryId',

    # Whether this label distinguishes between single instances or not
    'hasInstances',

    # Whether pixels having this class as ground truth label are ignored
    # during evaluations or not
    'ignoreInEval',

    # The color of this label
    'color',
])
|
50 |
+
|
51 |
+
|
52 |
+
#--------------------------------------------------------------------------------
|
53 |
+
# A list of all labels
|
54 |
+
#--------------------------------------------------------------------------------
|
55 |
+
|
56 |
+
# Please adapt the train IDs as appropriate for your approach.
|
57 |
+
# Note that you might want to ignore labels with ID 255 during training.
|
58 |
+
# Further note that the current train IDs are only a suggestion. You can use whatever you like.
|
59 |
+
# Make sure to provide your results using the original IDs and not the training IDs.
|
60 |
+
# Note that many IDs are ignored in evaluation and thus you never need to predict these!
|
61 |
+
|
62 |
+
# The trainId column keeps its arithmetic form: labels without instances use
# `n + 8` and labels with instances use `n - 11`, so the evaluated instance
# classes end up as 0..7 and the remaining evaluated classes as 8..18.
labels = [
    #      name                    id   trainId  category        catId hasInstances ignoreInEval color
    Label('unlabeled',             0,   255,     'void',         0,    False,       True,        (0, 0, 0)),
    Label('ego vehicle',           1,   255,     'void',         0,    False,       True,        (0, 0, 0)),
    Label('rectification border',  2,   255,     'void',         0,    False,       True,        (0, 0, 0)),
    Label('out of roi',            3,   255,     'void',         0,    False,       True,        (0, 0, 0)),
    Label('static',                4,   255,     'void',         0,    False,       True,        (0, 0, 0)),
    Label('dynamic',               5,   255,     'void',         0,    False,       True,        (111, 74, 0)),
    Label('ground',                6,   255,     'void',         0,    False,       True,        (81, 0, 81)),
    Label('road',                  7,   0 + 8,   'flat',         1,    False,       False,       (128, 64, 128)),
    Label('sidewalk',              8,   1 + 8,   'flat',         1,    False,       False,       (244, 35, 232)),
    Label('parking',               9,   255,     'flat',         1,    False,       True,        (250, 170, 160)),
    Label('rail track',            10,  255,     'flat',         1,    False,       True,        (230, 150, 140)),
    Label('building',              11,  2 + 8,   'construction', 2,    False,       False,       (70, 70, 70)),
    Label('wall',                  12,  3 + 8,   'construction', 2,    False,       False,       (102, 102, 156)),
    Label('fence',                 13,  4 + 8,   'construction', 2,    False,       False,       (190, 153, 153)),
    Label('guard rail',            14,  255,     'construction', 2,    False,       True,        (180, 165, 180)),
    Label('bridge',                15,  255,     'construction', 2,    False,       True,        (150, 100, 100)),
    Label('tunnel',                16,  255,     'construction', 2,    False,       True,        (150, 120, 90)),
    Label('pole',                  17,  5 + 8,   'object',       3,    False,       False,       (153, 153, 153)),
    Label('polegroup',             18,  255,     'object',       3,    False,       True,        (153, 153, 153)),
    Label('traffic light',         19,  6 + 8,   'object',       3,    False,       False,       (250, 170, 30)),
    Label('traffic sign',          20,  7 + 8,   'object',       3,    False,       False,       (220, 220, 0)),
    Label('vegetation',            21,  8 + 8,   'nature',       4,    False,       False,       (107, 142, 35)),
    Label('terrain',               22,  9 + 8,   'nature',       4,    False,       False,       (152, 251, 152)),
    Label('sky',                   23,  10 + 8,  'sky',          5,    False,       False,       (70, 130, 180)),
    Label('person',                24,  11 - 11, 'human',        6,    True,        False,       (220, 20, 60)),
    Label('rider',                 25,  12 - 11, 'human',        6,    True,        False,       (255, 0, 0)),
    Label('car',                   26,  13 - 11, 'vehicle',      7,    True,        False,       (0, 0, 142)),
    Label('truck',                 27,  14 - 11, 'vehicle',      7,    True,        False,       (0, 0, 70)),
    Label('bus',                   28,  15 - 11, 'vehicle',      7,    True,        False,       (0, 60, 100)),
    Label('caravan',               29,  255,     'vehicle',      7,    True,        True,        (0, 0, 90)),
    Label('trailer',               30,  255,     'vehicle',      7,    True,        True,        (0, 0, 110)),
    Label('train',                 31,  16 - 11, 'vehicle',      7,    True,        False,       (0, 80, 100)),
    Label('motorcycle',            32,  17 - 11, 'vehicle',      7,    True,        False,       (0, 0, 230)),
    Label('bicycle',               33,  18 - 11, 'vehicle',      7,    True,        False,       (119, 11, 32)),
    Label('license plate',         -1,  -1,      'vehicle',      7,    False,       True,        (0, 0, 142)),
]
|
100 |
+
|
101 |
+
|
102 |
+
#--------------------------------------------------------------------------------
|
103 |
+
# Create dictionaries for a fast lookup
|
104 |
+
#--------------------------------------------------------------------------------
|
105 |
+
|
106 |
+
# Please refer to the main method below for example usages!
|
107 |
+
|
108 |
+
# name to label object
|
109 |
+
name2label = {label.name: label for label in labels}
# id -> label object
id2label = {label.id: label for label in labels}
# trainId -> label object; iterating in reverse lets the label defined FIRST
# in the list win when several labels share a trainId
trainId2label = {label.trainId: label for label in reversed(labels)}
# category name -> list of its label objects, in definition order
category2labels = {}
for label in labels:
    category = label.category
    category2labels.setdefault(category, []).append(label)
|
122 |
+
|
123 |
+
#--------------------------------------------------------------------------------
|
124 |
+
# Assure single instance name
|
125 |
+
#--------------------------------------------------------------------------------
|
126 |
+
|
127 |
+
# returns the label name that describes a single instance (if possible)
|
128 |
+
# e.g. input | output
|
129 |
+
# ----------------------
|
130 |
+
# car | car
|
131 |
+
# cargroup | car
|
132 |
+
# foo | None
|
133 |
+
# foogroup | None
|
134 |
+
# skygroup | None
|
135 |
+
def assureSingleInstanceName(name):
    """Return the single-instance label name for `name`, or None.

    A known label name is returned unchanged. A name ending in "group" is
    reduced to its stem, which is returned only if it is a known label that
    has instances. Everything else yields None.
    """
    # a known name is not a group
    if name in name2label:
        return name

    suffix = "group"
    if name.endswith(suffix):
        stem = name[:-len(suffix)]
        # the stem must exist and must denote a label with instances
        if stem in name2label and name2label[stem].hasInstances:
            return stem

    return None
|
152 |
+
|
153 |
+
#--------------------------------------------------------------------------------
|
154 |
+
# Main for testing
|
155 |
+
#--------------------------------------------------------------------------------
|
156 |
+
|
157 |
+
# just a dummy main
|
158 |
+
if __name__ == "__main__":
    # Print a table with one row per label
    rowFormat = " {:>21} | {:>3} | {:>7} | {:>14} | {:>10} | {:>12} | {:>12}"
    print("List of cityscapes labels:")
    print("")
    print(rowFormat.format('name', 'id', 'trainId', 'category', 'categoryId', 'hasInstances', 'ignoreInEval'))
    print(" " + ('-' * 98))
    for label in labels:
        print(rowFormat.format(label.name, label.id, label.trainId, label.category,
                               label.categoryId, label.hasInstances, label.ignoreInEval))
    print("")

    print("Example usages:")

    # name -> label
    name = 'car'
    labelId = name2label[name].id
    print("ID of label '{name}': {id}".format(name=name, id=labelId))

    # ID -> label
    category = id2label[labelId].category
    print("Category of label with ID '{id}': {category}".format(id=labelId, category=category))

    # trainID -> label
    trainId = 0
    name = trainId2label[trainId].name
    print("Name of label with trainID '{id}': {name}".format(id=trainId, name=name))
|
ext/cityscapes_scripts/helpers/labels_cityPersons.py
ADDED
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python
|
2 |
+
#
|
3 |
+
# CityPersons (cp) labels
|
4 |
+
#
|
5 |
+
|
6 |
+
from __future__ import print_function, absolute_import, division
|
7 |
+
from collections import namedtuple
|
8 |
+
|
9 |
+
|
10 |
+
#--------------------------------------------------------------------------------
|
11 |
+
# Definitions
|
12 |
+
#--------------------------------------------------------------------------------
|
13 |
+
|
14 |
+
# a label and all meta information
|
15 |
+
LabelCp = namedtuple('LabelCp', [
    # The identifier of this label, e.g. 'pedestrian', 'rider', ... .
    # It is used to uniquely name a class.
    'name',

    # An integer ID associated with this label.
    # The IDs are used to represent the label in ground truth.
    'id',

    # Whether this label distinguishes between single instances or not
    'hasInstances',

    # Whether pixels having this class as ground truth label are ignored
    # during evaluations or not
    'ignoreInEval',

    # The color of this label
    'color',
])
|
30 |
+
|
31 |
+
|
32 |
+
#--------------------------------------------------------------------------------
|
33 |
+
# A list of all labels
|
34 |
+
#--------------------------------------------------------------------------------
|
35 |
+
|
36 |
+
# The 'ignore' label covers representations of humans, e.g. people on posters, reflections etc.
|
37 |
+
# Each annotation includes both the full bounding box (bbox) as well as a bounding box covering the visible area (bboxVis).
|
38 |
+
# The latter is obtained automatically from the segmentation masks.
|
39 |
+
|
40 |
+
labelsCp = [
    #        name              id  hasInstances ignoreInEval color
    LabelCp('ignore',          0,  False,       True,        (250, 170, 30)),
    LabelCp('pedestrian',      1,  True,        False,       (220, 20, 60)),
    LabelCp('rider',           2,  True,        False,       (0, 0, 142)),
    LabelCp('sitting person',  3,  True,        False,       (107, 142, 35)),
    LabelCp('person (other)',  4,  True,        False,       (190, 153, 153)),
    LabelCp('person group',    5,  False,       True,        (255, 0, 0)),
]
|
49 |
+
|
50 |
+
|
51 |
+
#--------------------------------------------------------------------------------
|
52 |
+
# Create dictionaries for a fast lookup
|
53 |
+
#--------------------------------------------------------------------------------
|
54 |
+
|
55 |
+
# Note: this module has no main method; the main block of labels.py shows example usages of the analogous dictionaries.
|
56 |
+
|
57 |
+
# name to label object
|
58 |
+
name2labelCp = {label.name: label for label in labelsCp}
# id -> label object
id2labelCp = {label.id: label for label in labelsCp}
|
61 |
+
|
ext/cityscapes_scripts/helpers/version.py
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
|
3 |
+
import os
|
4 |
+
|
5 |
+
# The version string is read at import time from the file 'VERSION' located
# one directory above this module.
with open(
    os.path.join(os.path.dirname(__file__), '..', 'VERSION')
) as f:
    version = f.read().strip()

if __name__ == "__main__":
    # running the module directly prints the version string
    print(version)
|
ext/class_names/VIPSeg.py
ADDED
@@ -0,0 +1,261 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
CLASSES = [
|
2 |
+
{"id": 0, "name": "wall", "isthing": 0, "color": [120, 120, 120]},
|
3 |
+
{"id": 1, "name": "ceiling", "isthing": 0, "color": [180, 120, 120]},
|
4 |
+
{"id": 2, "name": "door", "isthing": 1, "color": [6, 230, 230]},
|
5 |
+
{"id": 3, "name": "stair", "isthing": 0, "color": [80, 50, 50]},
|
6 |
+
{"id": 4, "name": "ladder", "isthing": 1, "color": [4, 200, 3]},
|
7 |
+
{"id": 5, "name": "escalator", "isthing": 0, "color": [120, 120, 80]},
|
8 |
+
{"id": 6, "name": "Playground_slide", "isthing": 0, "color": [140, 140, 140]},
|
9 |
+
{"id": 7, "name": "handrail_or_fence", "isthing": 0, "color": [204, 5, 255]},
|
10 |
+
{"id": 8, "name": "window", "isthing": 1, "color": [230, 230, 230]},
|
11 |
+
{"id": 9, "name": "rail", "isthing": 0, "color": [4, 250, 7]},
|
12 |
+
{"id": 10, "name": "goal", "isthing": 1, "color": [224, 5, 255]},
|
13 |
+
{"id": 11, "name": "pillar", "isthing": 0, "color": [235, 255, 7]},
|
14 |
+
{"id": 12, "name": "pole", "isthing": 0, "color": [150, 5, 61]},
|
15 |
+
{"id": 13, "name": "floor", "isthing": 0, "color": [120, 120, 70]},
|
16 |
+
{"id": 14, "name": "ground", "isthing": 0, "color": [8, 255, 51]},
|
17 |
+
{"id": 15, "name": "grass", "isthing": 0, "color": [255, 6, 82]},
|
18 |
+
{"id": 16, "name": "sand", "isthing": 0, "color": [143, 255, 140]},
|
19 |
+
{"id": 17, "name": "athletic_field", "isthing": 0, "color": [204, 255, 4]},
|
20 |
+
{"id": 18, "name": "road", "isthing": 0, "color": [255, 51, 7]},
|
21 |
+
{"id": 19, "name": "path", "isthing": 0, "color": [204, 70, 3]},
|
22 |
+
{"id": 20, "name": "crosswalk", "isthing": 0, "color": [0, 102, 200]},
|
23 |
+
{"id": 21, "name": "building", "isthing": 0, "color": [61, 230, 250]},
|
24 |
+
{"id": 22, "name": "house", "isthing": 0, "color": [255, 6, 51]},
|
25 |
+
{"id": 23, "name": "bridge", "isthing": 0, "color": [11, 102, 255]},
|
26 |
+
{"id": 24, "name": "tower", "isthing": 0, "color": [255, 7, 71]},
|
27 |
+
{"id": 25, "name": "windmill", "isthing": 0, "color": [255, 9, 224]},
|
28 |
+
{"id": 26, "name": "well_or_well_lid", "isthing": 0, "color": [9, 7, 230]},
|
29 |
+
{"id": 27, "name": "other_construction", "isthing": 0, "color": [220, 220, 220]},
|
30 |
+
{"id": 28, "name": "sky", "isthing": 0, "color": [255, 9, 92]},
|
31 |
+
{"id": 29, "name": "mountain", "isthing": 0, "color": [112, 9, 255]},
|
32 |
+
{"id": 30, "name": "stone", "isthing": 0, "color": [8, 255, 214]},
|
33 |
+
{"id": 31, "name": "wood", "isthing": 0, "color": [7, 255, 224]},
|
34 |
+
{"id": 32, "name": "ice", "isthing": 0, "color": [255, 184, 6]},
|
35 |
+
{"id": 33, "name": "snowfield", "isthing": 0, "color": [10, 255, 71]},
|
36 |
+
{"id": 34, "name": "grandstand", "isthing": 0, "color": [255, 41, 10]},
|
37 |
+
{"id": 35, "name": "sea", "isthing": 0, "color": [7, 255, 255]},
|
38 |
+
{"id": 36, "name": "river", "isthing": 0, "color": [224, 255, 8]},
|
39 |
+
{"id": 37, "name": "lake", "isthing": 0, "color": [102, 8, 255]},
|
40 |
+
{"id": 38, "name": "waterfall", "isthing": 0, "color": [255, 61, 6]},
|
41 |
+
{"id": 39, "name": "water", "isthing": 0, "color": [255, 194, 7]},
|
42 |
+
{"id": 40, "name": "billboard_or_Bulletin_Board", "isthing": 0, "color": [255, 122, 8]},
|
43 |
+
{"id": 41, "name": "sculpture", "isthing": 1, "color": [0, 255, 20]},
|
44 |
+
{"id": 42, "name": "pipeline", "isthing": 0, "color": [255, 8, 41]},
|
45 |
+
{"id": 43, "name": "flag", "isthing": 1, "color": [255, 5, 153]},
|
46 |
+
{"id": 44, "name": "parasol_or_umbrella", "isthing": 1, "color": [6, 51, 255]},
|
47 |
+
{"id": 45, "name": "cushion_or_carpet", "isthing": 0, "color": [235, 12, 255]},
|
48 |
+
{"id": 46, "name": "tent", "isthing": 1, "color": [160, 150, 20]},
|
49 |
+
{"id": 47, "name": "roadblock", "isthing": 1, "color": [0, 163, 255]},
|
50 |
+
{"id": 48, "name": "car", "isthing": 1, "color": [140, 140, 140]},
|
51 |
+
{"id": 49, "name": "bus", "isthing": 1, "color": [250, 10, 15]},
|
52 |
+
{"id": 50, "name": "truck", "isthing": 1, "color": [20, 255, 0]},
|
53 |
+
{"id": 51, "name": "bicycle", "isthing": 1, "color": [31, 255, 0]},
|
54 |
+
{"id": 52, "name": "motorcycle", "isthing": 1, "color": [255, 31, 0]},
|
55 |
+
{"id": 53, "name": "wheeled_machine", "isthing": 0, "color": [255, 224, 0]},
|
56 |
+
{"id": 54, "name": "ship_or_boat", "isthing": 1, "color": [153, 255, 0]},
|
57 |
+
{"id": 55, "name": "raft", "isthing": 1, "color": [0, 0, 255]},
|
58 |
+
{"id": 56, "name": "airplane", "isthing": 1, "color": [255, 71, 0]},
|
59 |
+
{"id": 57, "name": "tyre", "isthing": 0, "color": [0, 235, 255]},
|
60 |
+
{"id": 58, "name": "traffic_light", "isthing": 0, "color": [0, 173, 255]},
|
61 |
+
{"id": 59, "name": "lamp", "isthing": 0, "color": [31, 0, 255]},
|
62 |
+
{"id": 60, "name": "person", "isthing": 1, "color": [11, 200, 200]},
|
63 |
+
{"id": 61, "name": "cat", "isthing": 1, "color": [255, 82, 0]},
|
64 |
+
{"id": 62, "name": "dog", "isthing": 1, "color": [0, 255, 245]},
|
65 |
+
{"id": 63, "name": "horse", "isthing": 1, "color": [0, 61, 255]},
|
66 |
+
{"id": 64, "name": "cattle", "isthing": 1, "color": [0, 255, 112]},
|
67 |
+
{"id": 65, "name": "other_animal", "isthing": 1, "color": [0, 255, 133]},
|
68 |
+
{"id": 66, "name": "tree", "isthing": 0, "color": [255, 0, 0]},
|
69 |
+
{"id": 67, "name": "flower", "isthing": 0, "color": [255, 163, 0]},
|
70 |
+
{"id": 68, "name": "other_plant", "isthing": 0, "color": [255, 102, 0]},
|
71 |
+
{"id": 69, "name": "toy", "isthing": 0, "color": [194, 255, 0]},
|
72 |
+
{"id": 70, "name": "ball_net", "isthing": 0, "color": [0, 143, 255]},
|
73 |
+
{"id": 71, "name": "backboard", "isthing": 0, "color": [51, 255, 0]},
|
74 |
+
{"id": 72, "name": "skateboard", "isthing": 1, "color": [0, 82, 255]},
|
75 |
+
{"id": 73, "name": "bat", "isthing": 0, "color": [0, 255, 41]},
|
76 |
+
{"id": 74, "name": "ball", "isthing": 1, "color": [0, 255, 173]},
|
77 |
+
{"id": 75, "name": "cupboard_or_showcase_or_storage_rack", "isthing": 0, "color": [10, 0, 255]},
|
78 |
+
{"id": 76, "name": "box", "isthing": 1, "color": [173, 255, 0]},
|
79 |
+
{"id": 77, "name": "traveling_case_or_trolley_case", "isthing": 1, "color": [0, 255, 153]},
|
80 |
+
{"id": 78, "name": "basket", "isthing": 1, "color": [255, 92, 0]},
|
81 |
+
{"id": 79, "name": "bag_or_package", "isthing": 1, "color": [255, 0, 255]},
|
82 |
+
{"id": 80, "name": "trash_can", "isthing": 0, "color": [255, 0, 245]},
|
83 |
+
{"id": 81, "name": "cage", "isthing": 0, "color": [255, 0, 102]},
|
84 |
+
{"id": 82, "name": "plate", "isthing": 1, "color": [255, 173, 0]},
|
85 |
+
{"id": 83, "name": "tub_or_bowl_or_pot", "isthing": 1, "color": [255, 0, 20]},
|
86 |
+
{"id": 84, "name": "bottle_or_cup", "isthing": 1, "color": [255, 184, 184]},
|
87 |
+
{"id": 85, "name": "barrel", "isthing": 1, "color": [0, 31, 255]},
|
88 |
+
{"id": 86, "name": "fishbowl", "isthing": 1, "color": [0, 255, 61]},
|
89 |
+
{"id": 87, "name": "bed", "isthing": 1, "color": [0, 71, 255]},
|
90 |
+
{"id": 88, "name": "pillow", "isthing": 1, "color": [255, 0, 204]},
|
91 |
+
{"id": 89, "name": "table_or_desk", "isthing": 1, "color": [0, 255, 194]},
|
92 |
+
{"id": 90, "name": "chair_or_seat", "isthing": 1, "color": [0, 255, 82]},
|
93 |
+
{"id": 91, "name": "bench", "isthing": 1, "color": [0, 10, 255]},
|
94 |
+
{"id": 92, "name": "sofa", "isthing": 1, "color": [0, 112, 255]},
|
95 |
+
{"id": 93, "name": "shelf", "isthing": 0, "color": [51, 0, 255]},
|
96 |
+
{"id": 94, "name": "bathtub", "isthing": 0, "color": [0, 194, 255]},
|
97 |
+
{"id": 95, "name": "gun", "isthing": 1, "color": [0, 122, 255]},
|
98 |
+
{"id": 96, "name": "commode", "isthing": 1, "color": [0, 255, 163]},
|
99 |
+
{"id": 97, "name": "roaster", "isthing": 1, "color": [255, 153, 0]},
|
100 |
+
{"id": 98, "name": "other_machine", "isthing": 0, "color": [0, 255, 10]},
|
101 |
+
{"id": 99, "name": "refrigerator", "isthing": 1, "color": [255, 112, 0]},
|
102 |
+
{"id": 100, "name": "washing_machine", "isthing": 1, "color": [143, 255, 0]},
|
103 |
+
{"id": 101, "name": "Microwave_oven", "isthing": 1, "color": [82, 0, 255]},
|
104 |
+
{"id": 102, "name": "fan", "isthing": 1, "color": [163, 255, 0]},
|
105 |
+
{"id": 103, "name": "curtain", "isthing": 0, "color": [255, 235, 0]},
|
106 |
+
{"id": 104, "name": "textiles", "isthing": 0, "color": [8, 184, 170]},
|
107 |
+
{"id": 105, "name": "clothes", "isthing": 0, "color": [133, 0, 255]},
|
108 |
+
{"id": 106, "name": "painting_or_poster", "isthing": 1, "color": [0, 255, 92]},
|
109 |
+
{"id": 107, "name": "mirror", "isthing": 1, "color": [184, 0, 255]},
|
110 |
+
{"id": 108, "name": "flower_pot_or_vase", "isthing": 1, "color": [255, 0, 31]},
|
111 |
+
{"id": 109, "name": "clock", "isthing": 1, "color": [0, 184, 255]},
|
112 |
+
{"id": 110, "name": "book", "isthing": 0, "color": [0, 214, 255]},
|
113 |
+
{"id": 111, "name": "tool", "isthing": 0, "color": [255, 0, 112]},
|
114 |
+
{"id": 112, "name": "blackboard", "isthing": 0, "color": [92, 255, 0]},
|
115 |
+
{"id": 113, "name": "tissue", "isthing": 0, "color": [0, 224, 255]},
|
116 |
+
{"id": 114, "name": "screen_or_television", "isthing": 1, "color": [112, 224, 255]},
|
117 |
+
{"id": 115, "name": "computer", "isthing": 1, "color": [70, 184, 160]},
|
118 |
+
{"id": 116, "name": "printer", "isthing": 1, "color": [163, 0, 255]},
|
119 |
+
{"id": 117, "name": "Mobile_phone", "isthing": 1, "color": [153, 0, 255]},
|
120 |
+
{"id": 118, "name": "keyboard", "isthing": 1, "color": [71, 255, 0]},
|
121 |
+
{"id": 119, "name": "other_electronic_product", "isthing": 0, "color": [255, 0, 163]},
|
122 |
+
{"id": 120, "name": "fruit", "isthing": 0, "color": [255, 204, 0]},
|
123 |
+
{"id": 121, "name": "food", "isthing": 0, "color": [255, 0, 143]},
|
124 |
+
{"id": 122, "name": "instrument", "isthing": 1, "color": [0, 255, 235]},
|
125 |
+
{"id": 123, "name": "train", "isthing": 1, "color": [133, 255, 0]}
|
126 |
+
]
|
127 |
+
|
128 |
+
# Compact (id, name, RGB colour) rows for every VIPSeg category flagged as a
# "thing" (countable instance). Expanded below into COCO-panoptic style dicts.
_THING_ROWS = (
    (2, 'door', (6, 230, 230)),
    (4, 'ladder', (4, 200, 3)),
    (8, 'window', (230, 230, 230)),
    (10, 'goal', (224, 5, 255)),
    (41, 'sculpture', (0, 255, 20)),
    (43, 'flag', (255, 5, 153)),
    (44, 'parasol_or_umbrella', (6, 51, 255)),
    (46, 'tent', (160, 150, 20)),
    (47, 'roadblock', (0, 163, 255)),
    (48, 'car', (140, 140, 140)),
    (49, 'bus', (250, 10, 15)),
    (50, 'truck', (20, 255, 0)),
    (51, 'bicycle', (31, 255, 0)),
    (52, 'motorcycle', (255, 31, 0)),
    (54, 'ship_or_boat', (153, 255, 0)),
    (55, 'raft', (0, 0, 255)),
    (56, 'airplane', (255, 71, 0)),
    (60, 'person', (11, 200, 200)),
    (61, 'cat', (255, 82, 0)),
    (62, 'dog', (0, 255, 245)),
    (63, 'horse', (0, 61, 255)),
    (64, 'cattle', (0, 255, 112)),
    (65, 'other_animal', (0, 255, 133)),
    (72, 'skateboard', (0, 82, 255)),
    (74, 'ball', (0, 255, 173)),
    (76, 'box', (173, 255, 0)),
    (77, 'traveling_case_or_trolley_case', (0, 255, 153)),
    (78, 'basket', (255, 92, 0)),
    (79, 'bag_or_package', (255, 0, 255)),
    (82, 'plate', (255, 173, 0)),
    (83, 'tub_or_bowl_or_pot', (255, 0, 20)),
    (84, 'bottle_or_cup', (255, 184, 184)),
    (85, 'barrel', (0, 31, 255)),
    (86, 'fishbowl', (0, 255, 61)),
    (87, 'bed', (0, 71, 255)),
    (88, 'pillow', (255, 0, 204)),
    (89, 'table_or_desk', (0, 255, 194)),
    (90, 'chair_or_seat', (0, 255, 82)),
    (91, 'bench', (0, 10, 255)),
    (92, 'sofa', (0, 112, 255)),
    (95, 'gun', (0, 122, 255)),
    (96, 'commode', (0, 255, 163)),
    (97, 'roaster', (255, 153, 0)),
    (99, 'refrigerator', (255, 112, 0)),
    (100, 'washing_machine', (143, 255, 0)),
    (101, 'Microwave_oven', (82, 0, 255)),
    (102, 'fan', (163, 255, 0)),
    (106, 'painting_or_poster', (0, 255, 92)),
    (107, 'mirror', (184, 0, 255)),
    (108, 'flower_pot_or_vase', (255, 0, 31)),
    (109, 'clock', (0, 184, 255)),
    (114, 'screen_or_television', (112, 224, 255)),
    (115, 'computer', (70, 184, 160)),
    (116, 'printer', (163, 0, 255)),
    (117, 'Mobile_phone', (153, 0, 255)),
    (118, 'keyboard', (71, 255, 0)),
    (122, 'instrument', (0, 255, 235)),
    (123, 'train', (133, 255, 0)),
)

# One dict per "thing" category, in ascending id order; colours are lists.
CLASSES_THING = [
    {'id': cat_id, 'name': cat_name, 'isthing': 1, 'color': list(rgb)}
    for cat_id, cat_name, rgb in _THING_ROWS
]
|
188 |
+
|
189 |
+
# Compact (id, name, RGB colour) rows for every VIPSeg category flagged as
# "stuff" (amorphous region). Expanded below into COCO-panoptic style dicts.
_STUFF_ROWS = (
    (0, 'wall', (120, 120, 120)),
    (1, 'ceiling', (180, 120, 120)),
    (3, 'stair', (80, 50, 50)),
    (5, 'escalator', (120, 120, 80)),
    (6, 'Playground_slide', (140, 140, 140)),
    (7, 'handrail_or_fence', (204, 5, 255)),
    (9, 'rail', (4, 250, 7)),
    (11, 'pillar', (235, 255, 7)),
    (12, 'pole', (150, 5, 61)),
    (13, 'floor', (120, 120, 70)),
    (14, 'ground', (8, 255, 51)),
    (15, 'grass', (255, 6, 82)),
    (16, 'sand', (143, 255, 140)),
    (17, 'athletic_field', (204, 255, 4)),
    (18, 'road', (255, 51, 7)),
    (19, 'path', (204, 70, 3)),
    (20, 'crosswalk', (0, 102, 200)),
    (21, 'building', (61, 230, 250)),
    (22, 'house', (255, 6, 51)),
    (23, 'bridge', (11, 102, 255)),
    (24, 'tower', (255, 7, 71)),
    (25, 'windmill', (255, 9, 224)),
    (26, 'well_or_well_lid', (9, 7, 230)),
    (27, 'other_construction', (220, 220, 220)),
    (28, 'sky', (255, 9, 92)),
    (29, 'mountain', (112, 9, 255)),
    (30, 'stone', (8, 255, 214)),
    (31, 'wood', (7, 255, 224)),
    (32, 'ice', (255, 184, 6)),
    (33, 'snowfield', (10, 255, 71)),
    (34, 'grandstand', (255, 41, 10)),
    (35, 'sea', (7, 255, 255)),
    (36, 'river', (224, 255, 8)),
    (37, 'lake', (102, 8, 255)),
    (38, 'waterfall', (255, 61, 6)),
    (39, 'water', (255, 194, 7)),
    (40, 'billboard_or_Bulletin_Board', (255, 122, 8)),
    (42, 'pipeline', (255, 8, 41)),
    (45, 'cushion_or_carpet', (235, 12, 255)),
    (53, 'wheeled_machine', (255, 224, 0)),
    (57, 'tyre', (0, 235, 255)),
    (58, 'traffic_light', (0, 173, 255)),
    (59, 'lamp', (31, 0, 255)),
    (66, 'tree', (255, 0, 0)),
    (67, 'flower', (255, 163, 0)),
    (68, 'other_plant', (255, 102, 0)),
    (69, 'toy', (194, 255, 0)),
    (70, 'ball_net', (0, 143, 255)),
    (71, 'backboard', (51, 255, 0)),
    (73, 'bat', (0, 255, 41)),
    (75, 'cupboard_or_showcase_or_storage_rack', (10, 0, 255)),
    (80, 'trash_can', (255, 0, 245)),
    (81, 'cage', (255, 0, 102)),
    (93, 'shelf', (51, 0, 255)),
    (94, 'bathtub', (0, 194, 255)),
    (98, 'other_machine', (0, 255, 10)),
    (103, 'curtain', (255, 235, 0)),
    (104, 'textiles', (8, 184, 170)),
    (105, 'clothes', (133, 0, 255)),
    (110, 'book', (0, 214, 255)),
    (111, 'tool', (255, 0, 112)),
    (112, 'blackboard', (92, 255, 0)),
    (113, 'tissue', (0, 224, 255)),
    (119, 'other_electronic_product', (255, 0, 163)),
    (120, 'fruit', (255, 204, 0)),
    (121, 'food', (255, 0, 143)),
)

# One dict per "stuff" category, in ascending id order; colours are lists.
CLASSES_STUFF = [
    {'id': cat_id, 'name': cat_name, 'isthing': 0, 'color': list(rgb)}
    for cat_id, cat_name, rgb in _STUFF_ROWS
]
|
257 |
+
|
258 |
+
# Flat views over the category tables. "Thing" categories always come first,
# followed by "stuff" categories, so index i in COCO_CLASSES and PALETTE refer
# to the same category.
COCO_THINGS = [entry['name'] for entry in CLASSES_THING]
COCO_STUFF = [entry['name'] for entry in CLASSES_STUFF]
COCO_CLASSES = COCO_THINGS + COCO_STUFF
PALETTE = [entry['color'] for entry in (*CLASSES_THING, *CLASSES_STUFF)]
|
ext/davis2017/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import absolute_import
|
2 |
+
|
3 |
+
__version__ = '0.1.0'
|
ext/davis2017/davis.py
ADDED
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from glob import glob
|
3 |
+
from collections import defaultdict
|
4 |
+
import numpy as np
|
5 |
+
from PIL import Image
|
6 |
+
|
7 |
+
|
8 |
+
class DAVIS(object):
    """Index the image and annotation files of a DAVIS dataset folder.

    After construction, ``self.sequences`` maps every sequence name to a dict
    with two parallel lists: ``'images'`` (sorted ``*.jpg`` paths) and
    ``'masks'`` (sorted ``*.png`` paths, padded with ``-1`` for frames that
    have no annotation file).
    """
    SUBSET_OPTIONS = ['train', 'val', 'test-dev', 'test-challenge']
    TASKS = ['semi-supervised', 'unsupervised']
    DATASET_WEB = 'https://davischallenge.org/davis2017/code.html'
    VOID_LABEL = 255
    # Placeholder stored in the 'masks' list when a frame has no annotation.
    _MISSING_MASK = -1

    def __init__(self, root, task='unsupervised', subset='val', sequences='all', resolution='480p', codalab=False):
        """
        Class to read the DAVIS dataset
        :param root: Path to the DAVIS folder that contains JPEGImages, Annotations, etc. folders.
        :param task: Task to load the annotations, choose between semi-supervised or unsupervised.
        :param subset: Set to load the annotations
        :param sequences: Sequences to consider, 'all' to use all the sequences in a set.
        :param resolution: Name of the resolution sub-folder inside JPEGImages/Annotations (default '480p').
        :param codalab: When True, a sequence without any image file is accepted instead of raising.
        :raises ValueError: if ``subset`` or ``task`` is not one of the supported options.
        :raises FileNotFoundError: if a required folder, subset list or sequence image is missing.
        """
        if subset not in self.SUBSET_OPTIONS:
            raise ValueError(f'Subset should be in {self.SUBSET_OPTIONS}')
        if task not in self.TASKS:
            raise ValueError(f'The only tasks that are supported are {self.TASKS}')

        self.task = task
        self.subset = subset
        self.root = root
        self.img_path = os.path.join(self.root, 'JPEGImages', resolution)
        annotations_folder = 'Annotations' if task == 'semi-supervised' else 'Annotations_unsupervised'
        self.mask_path = os.path.join(self.root, annotations_folder, resolution)
        # The unsupervised test splits are listed under ImageSets/2019.
        is_test_split = subset in ('test-dev', 'test-challenge')
        year = '2019' if task == 'unsupervised' and is_test_split else '2017'
        self.imagesets_path = os.path.join(self.root, 'ImageSets', year)

        self._check_directories()

        if sequences == 'all':
            with open(os.path.join(self.imagesets_path, f'{self.subset}.txt'), 'r') as f:
                # Skip blank lines (e.g. a trailing empty line) so they are not
                # mistaken for a sequence with an empty name.
                sequences_names = [line.strip() for line in f if line.strip()]
        else:
            sequences_names = sequences if isinstance(sequences, list) else [sequences]
        self.sequences = defaultdict(dict)

        for seq in sequences_names:
            images = sorted(glob(os.path.join(self.img_path, seq, '*.jpg')))
            if len(images) == 0 and not codalab:
                raise FileNotFoundError(f'Images for sequence {seq} not found.')
            self.sequences[seq]['images'] = images
            masks = sorted(glob(os.path.join(self.mask_path, seq, '*.png')))
            # Keep 'masks' at least as long as 'images' by padding with the
            # missing-mask placeholder.
            masks.extend([self._MISSING_MASK] * (len(images) - len(masks)))
            self.sequences[seq]['masks'] = masks

    @staticmethod
    def _is_missing(path):
        """Return True when ``path`` is the placeholder for an absent file."""
        return path is None or path == DAVIS._MISSING_MASK

    def _check_directories(self):
        """Raise FileNotFoundError when the dataset root, subset list or annotations are absent."""
        if not os.path.exists(self.root):
            raise FileNotFoundError(f'DAVIS not found in the specified directory, download it from {self.DATASET_WEB}')
        if not os.path.exists(os.path.join(self.imagesets_path, f'{self.subset}.txt')):
            raise FileNotFoundError(f'Subset sequences list for {self.subset} not found, download the missing subset '
                                    f'for the {self.task} task from {self.DATASET_WEB}')
        if self.subset in ['train', 'val'] and not os.path.exists(self.mask_path):
            raise FileNotFoundError(f'Annotations folder for the {self.task} task not found, download it from {self.DATASET_WEB}')

    def get_frames(self, sequence):
        """Yield ``(image, mask)`` arrays frame by frame; ``mask`` is None when no annotation exists."""
        entry = self.sequences[sequence]
        for img, msk in zip(entry['images'], entry['masks']):
            image = np.array(Image.open(img))
            # Missing annotations are stored as -1 (see __init__), not None.
            mask = None if self._is_missing(msk) else np.array(Image.open(msk))
            yield image, mask

    def _get_all_elements(self, sequence, obj_type):
        """Load every file of ``obj_type`` ('images' or 'masks') of a sequence.

        :return: ``(stack, ids)`` where ``stack`` is a float array with one
            entry per file along axis 0 and ``ids`` are the file names without
            directory and extension.
        :raises FileNotFoundError: if any entry is the missing-file placeholder.
        """
        paths = self.sequences[sequence][obj_type]
        n_missing = sum(1 for path in paths if self._is_missing(path))
        if n_missing:
            raise FileNotFoundError(f'{n_missing} {obj_type} file(s) missing for sequence {sequence}.')
        # The first file fixes the shape of every frame in the stack.
        first = np.array(Image.open(paths[0]))
        all_objs = np.zeros((len(paths), *first.shape))
        obj_id = []
        for i, path in enumerate(paths):
            all_objs[i, ...] = np.array(Image.open(path))
            # basename/splitext is portable across path separators and keeps
            # dots that are part of the frame name.
            obj_id.append(os.path.splitext(os.path.basename(path))[0])
        return all_objs, obj_id

    def get_all_images(self, sequence):
        """Return ``(images, ids)`` for all frames of ``sequence``."""
        return self._get_all_elements(sequence, 'images')

    def get_all_masks(self, sequence, separate_objects_masks=False):
        """Load all annotation masks of a sequence.

        :param sequence: Name of the sequence.
        :param separate_objects_masks: When True, return one boolean mask stack
            per object id (ids 1..max id found in the first frame) along a new
            leading axis; otherwise a single foreground/background stack.
        :return: ``(masks, masks_void, masks_id)``; ``masks`` is boolean,
            ``masks_void`` marks pixels labelled ``VOID_LABEL`` with 1.
        """
        masks, masks_id = self._get_all_elements(sequence, 'masks')

        # Separate void and object masks
        is_void = masks == self.VOID_LABEL
        masks_void = is_void.astype(masks.dtype)
        masks[is_void] = 0

        if separate_objects_masks:
            num_objects = int(np.max(masks[0, ...]))
            # Broadcasting compares every object id against the label map
            # without materialising a (num_objects, *masks.shape) float array.
            object_ids = np.arange(1, num_objects + 1)[:, None, None, None]
            masks = (object_ids == masks[None, ...])
        masks = masks > 0
        return masks, masks_void, masks_id

    def get_sequences(self):
        """Yield the names of the loaded sequences."""
        yield from self.sequences
|
103 |
+
|
104 |
+
|
105 |
+
if __name__ == '__main__':
    # Manual visual check: for each sequence of the train/val subsets, show the
    # first image above its first annotation mask.
    from matplotlib import pyplot as plt

    only_first_frame = True
    subsets = ['train', 'val']

    for subset_name in subsets:
        dataset = DAVIS(root='/home/csergi/scratch2/Databases/DAVIS2017_private', subset=subset_name)
        for seq in dataset.get_sequences():
            img, mask = next(dataset.get_frames(seq))
            # Top panel: RGB frame titled with the sequence name.
            plt.subplot(2, 1, 1)
            plt.title(seq)
            plt.imshow(img)
            # Bottom panel: annotation of the same frame.
            plt.subplot(2, 1, 2)
            plt.imshow(mask)
            plt.show(block=True)
|
122 |
+
|
ext/davis2017/evaluation.py
ADDED
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
from tqdm import tqdm
|
3 |
+
import warnings
|
4 |
+
warnings.filterwarnings("ignore", category=RuntimeWarning)
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
from ext.davis2017.davis import DAVIS
|
8 |
+
from ext.davis2017.metrics import db_eval_boundary, db_eval_iou
|
9 |
+
from ext.davis2017 import utils
|
10 |
+
from ext.davis2017.results import Results
|
11 |
+
from scipy.optimize import linear_sum_assignment
|
12 |
+
|
13 |
+
|
14 |
+
class DAVISEvaluation(object):
    """Compute the J (region) and F (boundary) measures of results on DAVIS sequences."""

    def __init__(self, davis_root, task, gt_set, sequences='all', codalab=False):
        """
        Class to evaluate DAVIS sequences from a certain set and for a certain task
        :param davis_root: Path to the DAVIS folder that contains JPEGImages, Annotations, etc. folders.
        :param task: Task to compute the evaluation, chose between semi-supervised or unsupervised.
        :param gt_set: Set to compute the evaluation
        :param sequences: Sequences to consider for the evaluation, 'all' to use all the sequences in a set.
        :param codalab: Forwarded unchanged to the DAVIS dataset reader.
        """
        self.davis_root = davis_root
        self.task = task
        self.dataset = DAVIS(root=davis_root, task=task, subset=gt_set, sequences=sequences, codalab=codalab)

    @staticmethod
    def _pad_results(all_gt_masks, all_res_masks):
        """Append all-zero result masks so there is at least one result per ground-truth object."""
        n_missing = all_gt_masks.shape[0] - all_res_masks.shape[0]
        if n_missing <= 0:
            return all_res_masks
        zero_padding = np.zeros((n_missing, *all_res_masks.shape[1:]))
        return np.concatenate([all_res_masks, zero_padding], axis=0)

    @staticmethod
    def _match_proposals(all_metrics):
        """Optimally assign proposals (rows) to ground-truth objects (columns).

        :param all_metrics: 2-D score matrix, higher is better.
        :return: ``(row_ind, col_ind)`` index arrays sorted by ground-truth
            index, so entry ``k`` describes the match of the ``k``-th assigned
            ground-truth object.
        """
        row_ind, col_ind = linear_sum_assignment(-all_metrics)
        # linear_sum_assignment orders pairs by row (proposal); callers index
        # the output by ground-truth object, so reorder by column.
        order = np.argsort(col_ind)
        return row_ind[order], col_ind[order]

    @staticmethod
    def _evaluate_semisupervised(all_gt_masks, all_res_masks, all_void_masks, metric):
        """Score result object ``i`` against ground-truth object ``i``.

        :return: ``(j_metrics_res, f_metrics_res)``, each shaped like the first
            two axes of ``all_gt_masks``; a metric not requested in ``metric``
            is left as zeros.
        """
        if all_res_masks.shape[0] > all_gt_masks.shape[0]:
            # NOTE(review): sys.exit() without an argument terminates with exit
            # status 0; kept as-is because callers may rely on it.
            sys.stdout.write("\nIn your PNG files there is an index higher than the number of objects in the sequence!")
            sys.exit()
        all_res_masks = DAVISEvaluation._pad_results(all_gt_masks, all_res_masks)
        j_metrics_res, f_metrics_res = np.zeros(all_gt_masks.shape[:2]), np.zeros(all_gt_masks.shape[:2])
        for ii in range(all_gt_masks.shape[0]):
            if 'J' in metric:
                j_metrics_res[ii, :] = db_eval_iou(all_gt_masks[ii, ...], all_res_masks[ii, ...], all_void_masks)
            if 'F' in metric:
                f_metrics_res[ii, :] = db_eval_boundary(all_gt_masks[ii, ...], all_res_masks[ii, ...], all_void_masks)
        return j_metrics_res, f_metrics_res

    @staticmethod
    def _evaluate_unsupervised(all_gt_masks, all_res_masks, all_void_masks, metric, max_n_proposals=20):
        """Score every proposal against every ground-truth object and keep the best assignment.

        :return: ``(j_metrics_res, f_metrics_res)`` with one row per matched
            ground-truth object, ordered by ground-truth object index.
        """
        if all_res_masks.shape[0] > max_n_proposals:
            sys.stdout.write(f"\nIn your PNG files there is an index higher than the maximum number ({max_n_proposals}) of proposals allowed!")
            sys.exit()
        all_res_masks = DAVISEvaluation._pad_results(all_gt_masks, all_res_masks)
        # Axes: (proposal, ground-truth object, second axis of all_gt_masks).
        scores_shape = (all_res_masks.shape[0], all_gt_masks.shape[0], all_gt_masks.shape[1])
        j_metrics_res = np.zeros(scores_shape)
        f_metrics_res = np.zeros(scores_shape)
        for ii in range(all_gt_masks.shape[0]):
            for jj in range(all_res_masks.shape[0]):
                if 'J' in metric:
                    j_metrics_res[jj, ii, :] = db_eval_iou(all_gt_masks[ii, ...], all_res_masks[jj, ...], all_void_masks)
                if 'F' in metric:
                    f_metrics_res[jj, ii, :] = db_eval_boundary(all_gt_masks[ii, ...], all_res_masks[jj, ...], all_void_masks)
        if 'J' in metric and 'F' in metric:
            all_metrics = (np.mean(j_metrics_res, axis=2) + np.mean(f_metrics_res, axis=2)) / 2
        else:
            all_metrics = np.mean(j_metrics_res, axis=2) if 'J' in metric else np.mean(f_metrics_res, axis=2)
        row_ind, col_ind = DAVISEvaluation._match_proposals(all_metrics)
        return j_metrics_res[row_ind, col_ind, :], f_metrics_res[row_ind, col_ind, :]

    def evaluate(self, res_path, metric=('J', 'F'), debug=False):
        """Evaluate the results stored under ``res_path`` on every loaded sequence.

        :param res_path: Root directory handed to ``Results``.
        :param metric: 'J', 'F', or a tuple/list containing them.
        :param debug: When True, write each sequence name to stdout once it is processed.
        :return: Dict keyed by requested metric; each value holds the lists
            "M", "R", "D" (one entry per object) and the dict "M_per_object"
            keyed by ``'<sequence>_<object index>'``.
        :raises ValueError: if 'T' is requested or neither 'J' nor 'F' is.
        """
        metric = metric if isinstance(metric, (tuple, list)) else [metric]
        if 'T' in metric:
            raise ValueError('Temporal metric not supported!')
        if 'J' not in metric and 'F' not in metric:
            raise ValueError('Metric possible values are J for IoU or F for Boundary')

        # Containers
        metrics_res = {
            key: {"M": [], "R": [], "D": [], "M_per_object": {}}
            for key in ('J', 'F') if key in metric
        }

        # Sweep all sequences
        results = Results(root_dir=res_path)
        for seq in tqdm(list(self.dataset.get_sequences())):
            all_gt_masks, all_void_masks, all_masks_id = self.dataset.get_all_masks(seq, True)
            if self.task == 'semi-supervised':
                # The first and last frames are excluded from the evaluation.
                all_gt_masks, all_masks_id = all_gt_masks[:, 1:-1, :, :], all_masks_id[1:-1]
            all_res_masks = results.read_masks(seq, all_masks_id)
            if self.task == 'unsupervised':
                j_metrics_res, f_metrics_res = self._evaluate_unsupervised(all_gt_masks, all_res_masks, all_void_masks, metric)
            elif self.task == 'semi-supervised':
                j_metrics_res, f_metrics_res = self._evaluate_semisupervised(all_gt_masks, all_res_masks, None, metric)
            per_metric = {'J': j_metrics_res, 'F': f_metrics_res}
            for ii in range(all_gt_masks.shape[0]):
                seq_name = f'{seq}_{ii+1}'
                for key, store in metrics_res.items():
                    mean, recall, decay = utils.db_statistics(per_metric[key][ii])
                    store["M"].append(mean)
                    store["R"].append(recall)
                    store["D"].append(decay)
                    store["M_per_object"][seq_name] = mean

            # Show progress
            if debug:
                sys.stdout.write(seq + '\n')
                sys.stdout.flush()
        return metrics_res
|
ext/davis2017/metrics.py
ADDED
@@ -0,0 +1,197 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import numpy as np
|
3 |
+
import cv2
|
4 |
+
|
5 |
+
|
6 |
+
def db_eval_iou(annotation, segmentation, void_pixels=None):
    """ Compute region similarity as the Jaccard Index.
    Arguments:
        annotation   (ndarray): binary annotation   map.
        segmentation (ndarray): binary segmentation map.
        void_pixels  (ndarray): optional mask with void pixels

    Return:
        jaccard (float): region similarity. The last two axes are reduced, so
            a 2-D input yields a scalar and an N-D input yields an array with
            one value per leading index. An empty union counts as 1.
    """
    assert annotation.shape == segmentation.shape, \
        f'Annotation({annotation.shape}) and segmentation:{segmentation.shape} dimensions do not match.'
    annotation = annotation.astype(bool)
    segmentation = segmentation.astype(bool)

    if void_pixels is not None:
        assert annotation.shape == void_pixels.shape, \
            f'Annotation({annotation.shape}) and void pixels:{void_pixels.shape} dimensions do not match.'
        void_pixels = void_pixels.astype(bool)
    else:
        void_pixels = np.zeros_like(segmentation)

    # Intersection between all sets
    valid = np.logical_not(void_pixels)
    inters = np.sum((segmentation & annotation) & valid, axis=(-2, -1))
    union = np.sum((segmentation | annotation) & valid, axis=(-2, -1))

    # Guard the division so an empty union never produces 0/0 (which would
    # emit a RuntimeWarning and a transient NaN); such entries score 1.
    empty = union == 0
    j = np.where(empty, 1.0, inters / np.where(empty, 1, union))
    # j[()] turns a 0-d array into a NumPy float scalar.
    return j[()] if j.ndim == 0 else j
|
38 |
+
|
39 |
+
|
40 |
+
def db_eval_boundary(annotation, segmentation, void_pixels=None, bound_th=0.008):
    """Compute the boundary F-measure for one mask or for a stack of masks.

    Arguments:
        annotation   (ndarray): 2-D annotation map or 3-D stack indexed along axis 0.
        segmentation (ndarray): segmentation with the same shape as ``annotation``.
        void_pixels  (ndarray): optional void mask with the same shape.
        bound_th     (float):   boundary tolerance forwarded to ``f_measure``.

    Return:
        The ``f_measure`` result for a 2-D input, or an array holding one
        ``f_measure`` result per index of axis 0 for a 3-D input.

    Raises:
        ValueError: if ``annotation`` is neither 2-D nor 3-D.
    """
    assert annotation.shape == segmentation.shape
    if void_pixels is not None:
        assert annotation.shape == void_pixels.shape

    n_dims = annotation.ndim
    if n_dims == 2:
        return f_measure(segmentation, annotation, void_pixels, bound_th=bound_th)
    if n_dims != 3:
        raise ValueError(f'db_eval_boundary does not support tensors with {annotation.ndim} dimensions')

    # 3-D input: evaluate each slice along axis 0 independently.
    scores = np.zeros(annotation.shape[0])
    for idx in range(annotation.shape[0]):
        void_slice = void_pixels[idx, :, :] if void_pixels is not None else None
        scores[idx] = f_measure(segmentation[idx, :, :], annotation[idx, :, :], void_slice, bound_th=bound_th)
    return scores
|
55 |
+
|
56 |
+
|
57 |
+
def f_measure(foreground_mask, gt_mask, void_pixels=None, bound_th=0.008):
    """
    Compute the boundary F-measure between a segmentation and its annotation.

    Boundaries of both masks are extracted, each one is dilated by the
    tolerance radius, and precision/recall are the fractions of boundary
    pixels of one mask that fall inside the dilated boundary of the other.

    Arguments:
        foreground_mask (ndarray): binary segmentation image.
        gt_mask         (ndarray): binary annotated image.
        void_pixels     (ndarray): optional mask with void pixels
        bound_th        (float):   tolerance; used directly as a pixel radius
                                   when >= 1, otherwise scaled by the norm of
                                   the mask shape and rounded up.

    Returns:
        F (float): boundaries F-measure
    """
    assert np.atleast_3d(foreground_mask).shape[2] == 1
    if void_pixels is None:
        void_pixels = np.zeros_like(foreground_mask).astype(bool)
    else:
        void_pixels = void_pixels.astype(bool)

    if bound_th >= 1:
        bound_pix = bound_th
    else:
        bound_pix = np.ceil(bound_th * np.linalg.norm(foreground_mask.shape))

    # Get the pixel boundaries of both masks (void pixels zeroed out first)
    keep = np.logical_not(void_pixels)
    fg_boundary = _seg2bmap(foreground_mask * keep)
    gt_boundary = _seg2bmap(gt_mask * keep)

    from skimage.morphology import disk

    # Dilation is done with OpenCV on uint8 images, using a disk-shaped
    # structuring element of radius bound_pix.
    structuring_element = disk(bound_pix).astype(np.uint8)
    fg_dil = cv2.dilate(fg_boundary.astype(np.uint8), structuring_element)
    gt_dil = cv2.dilate(gt_boundary.astype(np.uint8), structuring_element)

    # Boundary pixels of one mask lying within tolerance of the other
    gt_match = gt_boundary * fg_dil
    fg_match = fg_boundary * gt_dil

    # Number of boundary pixels in each mask
    n_fg = np.sum(fg_boundary)
    n_gt = np.sum(gt_boundary)

    # Compute precision and recall; empty boundaries get fixed values
    no_fg = n_fg == 0
    no_gt = n_gt == 0
    if no_fg and no_gt:
        precision, recall = 1, 1
    elif no_fg and n_gt > 0:
        precision, recall = 1, 0
    elif n_fg > 0 and no_gt:
        precision, recall = 0, 1
    else:
        precision = np.sum(fg_match) / float(n_fg)
        recall = np.sum(gt_match) / float(n_gt)

    # Compute F measure (harmonic mean of precision and recall, 0 when both are 0)
    total = precision + recall
    F = 0 if total == 0 else 2 * precision * recall / total

    return F
|
120 |
+
|
121 |
+
|
122 |
+
def _seg2bmap(seg, width=None, height=None):
    """
    From a segmentation, compute a binary boundary map with 1 pixel wide
    boundaries. The boundary pixels are offset by 1/2 pixel towards the
    origin from the actual segment boundary.

    Arguments:
        seg     : Segments labeled from 1..k (any non-zero value is treated
                  as foreground).
        width   : Width of desired bmap  <= seg.shape[1]
        height  : Height of desired bmap <= seg.shape[0]

    Returns:
        bmap (ndarray): Binary boundary map. Boolean at the native
        resolution; a float array of 0/1 values when a smaller
        ``width``/``height`` is requested.

    Raises:
        AssertionError: if the requested size is larger than ``seg`` or its
        aspect ratio differs from that of ``seg`` by more than 0.01.

    David Martin <[email protected]>
    January 2003
    """

    # astype() returns a copy, so the caller's array is never modified.
    seg = seg.astype(bool)

    assert np.atleast_3d(seg).shape[2] == 1

    h, w = seg.shape[:2]
    width = w if width is None else width
    height = h if height is None else height

    ar1 = float(width) / float(height)
    ar2 = float(w) / float(h)

    # Logical `or` is required here: bitwise `|` binds tighter than `>` and
    # would turn this into an unrelated chained comparison.
    assert not (
        width > w or height > h or abs(ar1 - ar2) > 0.01
    ), "Can't convert %dx%d seg to %dx%d bmap." % (w, h, width, height)

    # Copies of seg shifted by one pixel: east, south and south-east neighbours.
    e = np.zeros_like(seg)
    s = np.zeros_like(seg)
    se = np.zeros_like(seg)

    e[:, :-1] = seg[:, 1:]
    s[:-1, :] = seg[1:, :]
    se[:-1, :-1] = seg[1:, 1:]

    # A pixel is on the boundary when it differs from any of those neighbours.
    b = seg ^ e | seg ^ s | seg ^ se
    # The last row/column only have one valid neighbour direction.
    b[-1, :] = seg[-1, :] ^ e[-1, :]
    b[:, -1] = seg[:, -1] ^ s[:, -1]
    b[-1, -1] = 0

    if w == width and h == height:
        return b

    # Down-sample: scale every boundary coordinate into the smaller map
    # (0-based equivalent of the 1-based formula 1 + floor((y - 1) * height / h)).
    bmap = np.zeros((height, width))
    ys, xs = np.nonzero(b)[:2]
    bmap[ys * height // h, xs * width // w] = 1
    return bmap
|
179 |
+
|
180 |
+
|
181 |
+
if __name__ == '__main__':
    # Ad-hoc driver: runs db_eval_boundary over one DAVIS sequence so the
    # F-measure code can be timed or profiled.
    from davis2017.davis import DAVIS
    from davis2017.results import Results

    dataset = DAVIS(root='input_dir/ref', subset='val', sequences='aerobatics')
    results = Results(root_dir='examples/osvos')
    # Test timing F measure
    for seq in dataset.get_sequences():
        gt_masks_full, _, mask_ids_full = dataset.get_all_masks(seq, True)
        # Drop the first and the last frame of the sequence.
        all_gt_masks = gt_masks_full[:, 1:-1, :, :]
        all_masks_id = mask_ids_full[1:-1]
        all_res_masks = results.read_masks(seq, all_masks_id)
        # One row per object, one column per frame.
        f_metrics_res = np.zeros(all_gt_masks.shape[:2])
        n_objects = all_gt_masks.shape[0]
        for obj_idx in range(n_objects):
            f_metrics_res[obj_idx, :] = db_eval_boundary(all_gt_masks[obj_idx, ...], all_res_masks[obj_idx, ...])
|
195 |
+
|
196 |
+
# To profile this code, run: python -m cProfile -o f_measure.prof metrics.py
|
197 |
+
# snakeviz f_measure.prof
|
ext/davis2017/results.py
ADDED
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import numpy as np
|
3 |
+
from PIL import Image, ImagePalette
|
4 |
+
import sys
|
5 |
+
|
6 |
+
|
7 |
+
davis_palette = b'\x00\x00\x00\x80\x00\x00\x00\x80\x00\x80\x80\x00\x00\x00\x80\x80\x00\x80\x00\x80\x80\x80\x80\x80@\x00\x00\xc0\x00\x00@\x80\x00\xc0\x80\x00@\x00\x80\xc0\x00\x80@\x80\x80\xc0\x80\x80\x00@\x00\x80@\x00\x00\xc0\x00\x80\xc0\x00\x00@\x80\x80@\x80\x00\xc0\x80\x80\xc0\x80@@\x00\xc0@\x00@\xc0\x00\xc0\xc0\x00@@\x80\xc0@\x80@\xc0\x80\xc0\xc0\x80\x00\x00@\x80\x00@\x00\x80@\x80\x80@\x00\x00\xc0\x80\x00\xc0\x00\x80\xc0\x80\x80\xc0@\x00@\xc0\x00@@\x80@\xc0\x80@@\x00\xc0\xc0\x00\xc0@\x80\xc0\xc0\x80\xc0\x00@@\x80@@\x00\xc0@\x80\xc0@\x00@\xc0\x80@\xc0\x00\xc0\xc0\x80\xc0\xc0@@@\xc0@@@\xc0@\xc0\xc0@@@\xc0\xc0@\xc0@\xc0\xc0\xc0\xc0\xc0 \x00\x00\xa0\x00\x00 \x80\x00\xa0\x80\x00 \x00\x80\xa0\x00\x80 \x80\x80\xa0\x80\x80`\x00\x00\xe0\x00\x00`\x80\x00\xe0\x80\x00`\x00\x80\xe0\x00\x80`\x80\x80\xe0\x80\x80 @\x00\xa0@\x00 \xc0\x00\xa0\xc0\x00 @\x80\xa0@\x80 \xc0\x80\xa0\xc0\x80`@\x00\xe0@\x00`\xc0\x00\xe0\xc0\x00`@\x80\xe0@\x80`\xc0\x80\xe0\xc0\x80 \x00@\xa0\x00@ \x80@\xa0\x80@ \x00\xc0\xa0\x00\xc0 \x80\xc0\xa0\x80\xc0`\x00@\xe0\x00@`\x80@\xe0\x80@`\x00\xc0\xe0\x00\xc0`\x80\xc0\xe0\x80\xc0 @@\xa0@@ \xc0@\xa0\xc0@ @\xc0\xa0@\xc0 \xc0\xc0\xa0\xc0\xc0`@@\xe0@@`\xc0@\xe0\xc0@`@\xc0\xe0@\xc0`\xc0\xc0\xe0\xc0\xc0\x00 \x00\x80 \x00\x00\xa0\x00\x80\xa0\x00\x00 \x80\x80 \x80\x00\xa0\x80\x80\xa0\x80@ \x00\xc0 \x00@\xa0\x00\xc0\xa0\x00@ \x80\xc0 \x80@\xa0\x80\xc0\xa0\x80\x00`\x00\x80`\x00\x00\xe0\x00\x80\xe0\x00\x00`\x80\x80`\x80\x00\xe0\x80\x80\xe0\x80@`\x00\xc0`\x00@\xe0\x00\xc0\xe0\x00@`\x80\xc0`\x80@\xe0\x80\xc0\xe0\x80\x00 @\x80 @\x00\xa0@\x80\xa0@\x00 \xc0\x80 \xc0\x00\xa0\xc0\x80\xa0\xc0@ @\xc0 @@\xa0@\xc0\xa0@@ \xc0\xc0 \xc0@\xa0\xc0\xc0\xa0\xc0\x00`@\x80`@\x00\xe0@\x80\xe0@\x00`\xc0\x80`\xc0\x00\xe0\xc0\x80\xe0\xc0@`@\xc0`@@\xe0@\xc0\xe0@@`\xc0\xc0`\xc0@\xe0\xc0\xc0\xe0\xc0 \x00\xa0 \x00 \xa0\x00\xa0\xa0\x00 \x80\xa0 \x80 \xa0\x80\xa0\xa0\x80` \x00\xe0 \x00`\xa0\x00\xe0\xa0\x00` \x80\xe0 \x80`\xa0\x80\xe0\xa0\x80 `\x00\xa0`\x00 \xe0\x00\xa0\xe0\x00 `\x80\xa0`\x80 
\xe0\x80\xa0\xe0\x80``\x00\xe0`\x00`\xe0\x00\xe0\xe0\x00``\x80\xe0`\x80`\xe0\x80\xe0\xe0\x80 @\xa0 @ \xa0@\xa0\xa0@ \xc0\xa0 \xc0 \xa0\xc0\xa0\xa0\xc0` @\xe0 @`\xa0@\xe0\xa0@` \xc0\xe0 \xc0`\xa0\xc0\xe0\xa0\xc0 `@\xa0`@ \xe0@\xa0\xe0@ `\xc0\xa0`\xc0 \xe0\xc0\xa0\xe0\xc0``@\xe0`@`\xe0@\xe0\xe0@``\xc0\xe0`\xc0`\xe0\xc0\xe0\xe0\xc0'
|
8 |
+
mose_palette = b'\x00\x00\x00\xe4\x1a\x1c7~\xb8M\xafJ\x98N\xa3\xff\x7f\x00\xff\xff3\xa6V(\xf7\x81\xbf\x99\x99\x99f\xc2\xa5\xfc\x8db\x8d\xa0\xcb\xe7\x8a\xc3\xa6\xd8T\xff\xd9/\xe5\xc4\x94\xb3\xb3\xb3\x8d\xd3\xc7\xff\xff\xb3\xbe\xba\xda\xfb\x80r\x80\xb1\xd3\xfd\xb4b\xb3\xdei\xfc\xcd\xe5\xd9\xd9\xd9\xbc\x80\xbd\xcc\xeb\xc5\xff\xedo'
|
9 |
+
|
10 |
+
class Results(object):
    """Reader for result masks stored as ``<root_dir>/<sequence>/<frame_id>.png``."""

    def __init__(self, root_dir):
        self.root_dir = root_dir

    def _read_mask(self, sequence, frame_id):
        """Load the mask of one frame as an ndarray of object ids.

        Images that are not in palette mode ('P') are treated as RGB renderings
        of the DAVIS palette and converted back to ids: pixels matching one of
        the first 10 palette colors get that color's index, every other pixel
        is left at 255.

        On an IOError a diagnostic is printed and the process exits with a
        non-zero status.
        """
        try:
            mask_path = os.path.join(self.root_dir, sequence, f'{frame_id}.png')
            # BUGFIX
            # There is a bug in the codebase
            # Here is a compensation.
            img = Image.open(mask_path)
            if img.mode != 'P':
                img_color = np.array(img)
                h, w, three = img_color.shape
                assert three == 3

                img_new = np.ones((h, w), dtype=np.uint8) * 255
                color_map_np = np.frombuffer(davis_palette, dtype=np.uint8).reshape(-1, 3).copy()
                for i in range(10):
                    cur_color = color_map_np[i]
                    mask = np.all(img_color == cur_color, axis=-1)
                    img_new[mask] = i
                # At least one pixel must have matched a palette color.
                assert not np.all(img_new == 255)
                img = img_new
            # BUGFIX
            return np.array(img)
        except IOError as err:
            sys.stdout.write(sequence + " frame %s not found!\n" % frame_id)
            sys.stdout.write("The frames have to be indexed PNG files placed inside the corespondent sequence "
                             "folder.\nThe indexes have to match with the initial frame.\n")
            # strerror is None for errors raised without an errno (e.g. an
            # unreadable image), so fall back to the exception text.
            sys.stderr.write("IOError: " + (err.strerror or str(err)) + "\n")
            # Exit with a failure status so callers can detect the error.
            sys.exit(1)

    def read_masks(self, sequence, masks_id):
        """Read the masks of the given frames and split them per object.

        Arguments:
            sequence: name of the sequence folder below ``root_dir``.
            masks_id: non-empty sequence of frame ids to load.

        Returns:
            Boolean ndarray of shape (num_objects, len(masks_id), *mask_shape)
            where entry ``[k]`` marks the pixels whose id equals ``k + 1`` and
            ``num_objects`` is the largest id found in the loaded masks.
        """
        mask_0 = self._read_mask(sequence, masks_id[0])
        masks = np.zeros((len(masks_id), *mask_0.shape))
        masks[0, ...] = mask_0
        for ii, m in enumerate(masks_id[1:], start=1):
            masks[ii, ...] = self._read_mask(sequence, m)
        num_objects = int(np.max(masks))
        # Broadcasting compares every object id against every frame without
        # materialising a (num_objects, frames, H, W) array of ids first.
        object_ids = np.arange(1, num_objects + 1)[:, None, None, None]
        return object_ids == masks[None, ...]
|
ext/davis2017/utils.py
ADDED
@@ -0,0 +1,174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import errno
|
3 |
+
import numpy as np
|
4 |
+
from PIL import Image
|
5 |
+
import warnings
|
6 |
+
from ext.davis2017.davis import DAVIS
|
7 |
+
|
8 |
+
|
9 |
+
def _pascal_color_map(N=256, normalized=False):
    """
    Python implementation of the color map function for the PASCAL VOC data set.
    Official Matlab version can be found in the PASCAL VOC devkit
    http://host.robots.ox.ac.uk/pascal/VOC/voc2012/index.html#devkit

    Returns an (N, 3) array of RGB colors: uint8 in 0..255, or float32 scaled
    to 0..1 when ``normalized`` is true.
    """
    remaining = np.arange(N)
    channels = np.zeros((N, 3), dtype=np.int64)
    # Each pass consumes the three lowest bits of the label (one per channel)
    # and writes them from the most significant color bit downwards.
    for shift in range(7, -1, -1):
        for ch in range(3):
            channels[:, ch] |= ((remaining >> ch) & 1) << shift
        remaining = remaining >> 3

    cmap = channels.astype('float32' if normalized else 'uint8')
    return cmap / 255 if normalized else cmap
|
34 |
+
|
35 |
+
|
36 |
+
def overlay_semantic_mask(im, ann, alpha=0.5, colors=None, contour_thickness=None):
    """Blend a color-coded annotation over an image.

    Arguments:
        im: image of shape (..., 3); converted to uint8.
        ann: integer annotation whose shape equals ``im.shape[:-1]``; pixels
            with a value > 0 are overlaid.
        alpha (float): weight of the image in the blend; the color gets
            ``1 - alpha``.
        colors: optional color table indexed by annotation value. When it is
            ``None`` or empty, ``_pascal_color_map()`` is used.
        contour_thickness: when truthy, object contours are additionally
            drawn with OpenCV using this thickness.

    Returns:
        ndarray (uint8): a copy of ``im`` with the overlay applied.

    Raises:
        ValueError: if the shapes of ``im`` and ``ann`` do not match or ``im``
        does not have three channels.
    """
    # The `np.int` alias was removed from NumPy; the builtin int is the same dtype.
    im, ann = np.asarray(im, dtype=np.uint8), np.asarray(ann, dtype=int)
    if im.shape[:-1] != ann.shape:
        raise ValueError('First two dimensions of `im` and `ann` must match')
    if im.shape[-1] != 3:
        raise ValueError('im must have three channels at the 3 dimension')

    # An explicit check is needed: `colors or ...` raises for ndarray input.
    if colors is None or len(colors) == 0:
        colors = _pascal_color_map()
    colors = np.asarray(colors, dtype=np.uint8)

    mask = colors[ann]
    fg = im * alpha + (1 - alpha) * mask

    img = im.copy()
    img[ann > 0] = fg[ann > 0]

    if contour_thickness:  # pragma: no cover
        import cv2
        for obj_id in np.unique(ann[ann > 0]):
            # [-2:] keeps (contours, hierarchy) regardless of how many values
            # findContours returns.
            contours = cv2.findContours((ann == obj_id).astype(
                np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2:]
            cv2.drawContours(img, contours[0], -1, colors[obj_id].tolist(),
                             contour_thickness)
    return img
|
60 |
+
|
61 |
+
|
62 |
+
def generate_obj_proposals(davis_root, subset, num_proposals, save_path):
    """Write grid-shaped object proposals for every sequence of a DAVIS subset.

    Each frame is tiled with up to ``num_proposals`` rectangular cells laid
    out on a ``ceil(sqrt(num_proposals))``-wide grid; cell ``k`` is saved with
    id ``k + 1`` in ``<save_path>/<sequence>/<mask_id>.png``. Sequences whose
    output folder already exists are skipped.
    """
    dataset = DAVIS(davis_root, subset=subset, codalab=True)
    for seq in dataset.get_sequences():
        save_dir = os.path.join(save_path, seq)
        if os.path.exists(save_dir):
            continue
        # get_all_masks yields three values (see the metrics driver, which
        # unpacks `masks, _, masks_id`); the middle one is not needed here.
        all_gt_masks, _, all_masks_id = dataset.get_all_masks(seq, True)
        img_size = all_gt_masks.shape[2:]
        num_rows = int(np.ceil(np.sqrt(num_proposals)))
        proposals = np.zeros((num_proposals, len(all_masks_id), *img_size))
        height_slices = np.floor(np.arange(0, img_size[0] + 1, img_size[0]/num_rows)).astype(np.uint).tolist()
        width_slices = np.floor(np.arange(0, img_size[1] + 1, img_size[1]/num_rows)).astype(np.uint).tolist()
        ii = 0
        prev_h, prev_w = 0, 0
        # Fill the grid row by row until num_proposals cells have been created.
        for h in height_slices[1:]:
            for w in width_slices[1:]:
                proposals[ii, :, prev_h:h, prev_w:w] = 1
                prev_w = w
                ii += 1
                if ii == num_proposals:
                    break
            prev_h, prev_w = h, 0
            if ii == num_proposals:
                break

        os.makedirs(save_dir, exist_ok=True)
        for i, mask_id in enumerate(all_masks_id):
            # Collapse the per-proposal binary masks into one id map.
            mask = np.sum(proposals[:, i, ...] * np.arange(1, proposals.shape[0] + 1)[:, None, None], axis=0)
            save_mask(mask, os.path.join(save_dir, f'{mask_id}.png'))
|
91 |
+
|
92 |
+
|
93 |
+
def generate_random_permutation_gt_obj_proposals(davis_root, subset, save_path):
    """Save the ground-truth masks of a DAVIS subset with randomly permuted object ids.

    For every sequence the per-object masks are shuffled along the object
    axis and written to ``<save_path>/<sequence>/<mask_id>.png``, object
    ``k`` of the shuffled order receiving id ``k + 1``.
    """
    dataset = DAVIS(davis_root, subset=subset, codalab=True)
    for seq in dataset.get_sequences():
        # get_all_masks yields three values (see the metrics driver, which
        # unpacks `masks, _, masks_id`); the middle one is not needed here.
        gt_masks, _, all_masks_id = dataset.get_all_masks(seq, True)
        obj_swap = np.random.permutation(np.arange(gt_masks.shape[0]))
        gt_masks = gt_masks[obj_swap, ...]
        save_dir = os.path.join(save_path, seq)
        os.makedirs(save_dir, exist_ok=True)
        for i, mask_id in enumerate(all_masks_id):
            # Collapse the per-object binary masks into one id map.
            mask = np.sum(gt_masks[:, i, ...] * np.arange(1, gt_masks.shape[0] + 1)[:, None, None], axis=0)
            save_mask(mask, os.path.join(save_dir, f'{mask_id}.png'))
|
104 |
+
|
105 |
+
|
106 |
+
def color_map(N=256, normalized=False):
    """Build an (N, 3) RGB color table from the bits of each label index.

    The table is uint8 in 0..255, or float32 scaled to 0..1 when
    ``normalized`` is true.
    """
    cmap = np.zeros((N, 3), dtype='float32' if normalized else 'uint8')
    for label in range(N):
        rgb = [0, 0, 0]
        bits = label
        # Each pass consumes the three lowest bits of the label (one per
        # channel), filling the color bits from the most significant down.
        for shift in range(7, -1, -1):
            for ch in range(3):
                rgb[ch] |= ((bits >> ch) & 1) << shift
            bits >>= 3
        cmap[label] = rgb

    if normalized:
        return cmap / 255
    return cmap
|
125 |
+
|
126 |
+
|
127 |
+
def save_mask(mask, img_path):
    """Save an id mask to ``img_path`` with the ``color_map()`` palette attached.

    Raises:
        ValueError: if the largest value in ``mask`` exceeds 255.
    """
    if np.max(mask) > 255:
        raise ValueError('Maximum id pixel value is 255')
    palette = color_map().flatten().tolist()
    indexed = Image.fromarray(mask.astype(np.uint8))
    indexed.putpalette(palette)
    indexed.save(img_path)
|
133 |
+
|
134 |
+
|
135 |
+
def db_statistics(per_frame_values):
    """ Compute mean,recall and decay from per-frame evaluation.
    Arguments:
        per_frame_values (ndarray): per-frame evaluation

    Returns:
        M,O,D (float,float,float):
            return evaluation statistics: mean,recall,decay.
            M is the NaN-ignoring mean, O the fraction of frames whose value
            exceeds 0.5, and D the mean of the first quarter of the frames
            minus the mean of the last quarter.
    """

    # NaN values are ignored by nanmean; warnings about all-NaN/empty input are muted.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=RuntimeWarning)
        M = np.nanmean(per_frame_values)
        O = np.nanmean(per_frame_values > 0.5)

    # Split the sequence into N_bins consecutive bins.
    N_bins = 4
    ids = np.round(np.linspace(1, len(per_frame_values), N_bins + 1) + 1e-10) - 1
    # A full-width integer type is required: uint8 would wrap the indices of
    # sequences longer than 256 frames.
    ids = ids.astype(np.int64)

    D_bins = [per_frame_values[ids[i]:ids[i + 1] + 1] for i in range(N_bins)]

    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=RuntimeWarning)
        D = np.nanmean(D_bins[0]) - np.nanmean(D_bins[-1])

    return M, O, D
|
162 |
+
|
163 |
+
|
164 |
+
def list_files(dir, extension=".png"):
    """Return the names, without their extension, of the entries of ``dir`` that end with ``extension``."""
    matching = (entry for entry in os.listdir(dir) if entry.endswith(extension))
    return [os.path.splitext(entry)[0] for entry in matching]
|
166 |
+
|
167 |
+
|
168 |
+
def force_symlink(file1, file2):
    """Create the symlink ``file2`` pointing to ``file1``, replacing an existing ``file2``.

    Raises:
        OSError: if the link cannot be created for any reason other than
        ``file2`` already existing (such errors used to be dropped silently).
    """
    try:
        os.symlink(file1, file2)
    except FileExistsError:
        # FileExistsError is the OSError raised for errno.EEXIST; every other
        # OSError propagates to the caller.
        os.remove(file2)
        os.symlink(file1, file2)
|
ext/meta/sam_meta.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Architecture settings keyed by backbone name. The last four entries of
# every configuration are identical across backbones.
meta_dict = {
    'vit_h': {
        'encoder_embed_dim': 1280,
        'encoder_depth': 32,
        'encoder_num_heads': 16,
        'encoder_global_attn_indexes': [7, 15, 23, 31],
        # common
        'prompt_embed_dim': 256,
        'image_size': 1024,
        'vit_patch_size': 16,
        'image_embedding_size': 64,
    },
    'vit_l': {
        'encoder_embed_dim': 1024,
        'encoder_depth': 24,
        'encoder_num_heads': 16,
        'encoder_global_attn_indexes': [5, 11, 17, 23],
        # common
        'prompt_embed_dim': 256,
        'image_size': 1024,
        'vit_patch_size': 16,
        'image_embedding_size': 64,
    },
    'vit_b': {
        'encoder_embed_dim': 768,
        'encoder_depth': 12,
        'encoder_num_heads': 12,
        'encoder_global_attn_indexes': [2, 5, 8, 11],
        # common
        'prompt_embed_dim': 256,
        'image_size': 1024,
        'vit_patch_size': 16,
        'image_embedding_size': 64,
    },
}
|
36 |
+
|
37 |
+
# Download URL of the checkpoint for each backbone name.
checkpoint_dict = dict(
    vit_h='https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth',
    vit_l='https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth',
    vit_b='https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth',
)
|
ext/open_clip/__init__.py
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .coca_model import CoCa
|
2 |
+
from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
|
3 |
+
from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer, create_loss
|
4 |
+
from .factory import list_models, add_model_config, get_model_config, load_checkpoint
|
5 |
+
from .loss import ClipLoss, DistillClipLoss, CoCaLoss
|
6 |
+
from .model import CLIP, CustomTextCLIP, CLIPTextCfg, CLIPVisionCfg, \
|
7 |
+
convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype, get_input_dtype
|
8 |
+
from .openai import load_openai_model, list_openai_models
|
9 |
+
from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model, \
|
10 |
+
get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained
|
11 |
+
from .push_to_hf_hub import push_pretrained_to_hf_hub, push_to_hf_hub
|
12 |
+
from .tokenizer import SimpleTokenizer, tokenize, decode
|
13 |
+
from .transform import image_transform, AugmentationCfg
|
14 |
+
from .zero_shot_classifier import build_zero_shot_classifier, build_zero_shot_classifier_legacy
|
15 |
+
from .zero_shot_metadata import OPENAI_IMAGENET_TEMPLATES, SIMPLE_IMAGENET_TEMPLATES, IMAGENET_CLASSNAMES
|
ext/open_clip/bpe_simple_vocab_16e6.txt.gz
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
|
3 |
+
size 1356917
|
ext/open_clip/coca_model.py
ADDED
@@ -0,0 +1,458 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from torch import nn
|
5 |
+
from torch.nn import functional as F
|
6 |
+
import numpy as np
|
7 |
+
from dataclasses import dataclass
|
8 |
+
|
9 |
+
from .transformer import (
|
10 |
+
LayerNormFp32,
|
11 |
+
LayerNorm,
|
12 |
+
QuickGELU,
|
13 |
+
MultimodalTransformer,
|
14 |
+
)
|
15 |
+
from .model import CLIPTextCfg, CLIPVisionCfg, _build_vision_tower, _build_text_tower
|
16 |
+
|
17 |
+
# Defaults used when `transformers` cannot be imported: the sampling warpers
# are unavailable and only the "beam_search" marker is present.
GENERATION_TYPES = {
    "top_k": None,
    "top_p": None,
    "beam_search": "beam_search"
}
_has_transformers = False

try:
    from transformers import (
        BeamSearchScorer,
        LogitsProcessorList,
        TopPLogitsWarper,
        TopKLogitsWarper,
        RepetitionPenaltyLogitsProcessor,
        MinLengthLogitsProcessor,
        MaxLengthCriteria,
        StoppingCriteriaList
    )
except ImportError:
    # Keep the defaults set above.
    pass
else:
    GENERATION_TYPES.update(top_k=TopKLogitsWarper, top_p=TopPLogitsWarper)
    _has_transformers = True
|
42 |
+
|
43 |
+
|
44 |
+
@dataclass
class MultimodalCfg(CLIPTextCfg):
    """Settings of the multimodal text decoder; adds the fields below to ``CLIPTextCfg``."""

    mlp_ratio: int = 4
    dim_head: int = 64
    heads: int = 8  # passed as `heads` to MultimodalTransformer
    n_queries: int = 256
    attn_pooler_heads: int = 8
|
51 |
+
|
52 |
+
|
53 |
+
def _build_text_decoder_tower(
        embed_dim,
        multimodal_cfg,
        quick_gelu: bool = False,
        cast_dtype: Optional[torch.dtype] = None,
):
    """Construct the ``MultimodalTransformer`` used as text decoder.

    Arguments:
        embed_dim: passed to the transformer as ``output_dim``.
        multimodal_cfg: a ``MultimodalCfg`` or a dict of its fields.
        quick_gelu: use ``QuickGELU`` instead of ``nn.GELU`` as activation.
        cast_dtype: when float16 or bfloat16, ``LayerNormFp32`` is used as
            normalisation layer instead of ``LayerNorm``.

    Returns:
        The constructed ``MultimodalTransformer``.
    """
    if isinstance(multimodal_cfg, dict):
        multimodal_cfg = MultimodalCfg(**multimodal_cfg)

    if quick_gelu:
        act_layer = QuickGELU
    else:
        act_layer = nn.GELU

    half_precision = cast_dtype in (torch.float16, torch.bfloat16)
    norm_layer = LayerNormFp32 if half_precision else LayerNorm

    return MultimodalTransformer(
        context_length=multimodal_cfg.context_length,
        width=multimodal_cfg.width,
        heads=multimodal_cfg.heads,
        layers=multimodal_cfg.layers,
        ls_init_value=multimodal_cfg.ls_init_value,
        output_dim=embed_dim,
        act_layer=act_layer,
        norm_layer=norm_layer,
    )
|
77 |
+
|
78 |
+
|
79 |
+
class CoCa(nn.Module):
|
80 |
+
def __init__(
|
81 |
+
self,
|
82 |
+
embed_dim,
|
83 |
+
multimodal_cfg: MultimodalCfg,
|
84 |
+
text_cfg: CLIPTextCfg,
|
85 |
+
vision_cfg: CLIPVisionCfg,
|
86 |
+
quick_gelu: bool = False,
|
87 |
+
cast_dtype: Optional[torch.dtype] = None,
|
88 |
+
pad_id: int = 0,
|
89 |
+
):
|
90 |
+
super().__init__()
|
91 |
+
multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg
|
92 |
+
text_cfg = CLIPTextCfg(**text_cfg) if isinstance(text_cfg, dict) else text_cfg
|
93 |
+
vision_cfg = CLIPVisionCfg(**vision_cfg) if isinstance(vision_cfg, dict) else vision_cfg
|
94 |
+
|
95 |
+
self.text = _build_text_tower(
|
96 |
+
embed_dim=embed_dim,
|
97 |
+
text_cfg=text_cfg,
|
98 |
+
quick_gelu=quick_gelu,
|
99 |
+
cast_dtype=cast_dtype,
|
100 |
+
)
|
101 |
+
|
102 |
+
vocab_size = (
|
103 |
+
text_cfg.vocab_size # for hf models
|
104 |
+
if hasattr(text_cfg, "hf_model_name") and text_cfg.hf_model_name is not None
|
105 |
+
else text_cfg.vocab_size
|
106 |
+
)
|
107 |
+
|
108 |
+
self.visual = _build_vision_tower(
|
109 |
+
embed_dim=embed_dim,
|
110 |
+
vision_cfg=vision_cfg,
|
111 |
+
quick_gelu=quick_gelu,
|
112 |
+
cast_dtype=cast_dtype,
|
113 |
+
)
|
114 |
+
|
115 |
+
self.text_decoder = _build_text_decoder_tower(
|
116 |
+
vocab_size,
|
117 |
+
multimodal_cfg=multimodal_cfg,
|
118 |
+
quick_gelu=quick_gelu,
|
119 |
+
cast_dtype=cast_dtype,
|
120 |
+
)
|
121 |
+
|
122 |
+
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
|
123 |
+
self.pad_id = pad_id
|
124 |
+
|
125 |
+
@torch.jit.ignore
|
126 |
+
def set_grad_checkpointing(self, enable=True):
|
127 |
+
self.visual.set_grad_checkpointing(enable)
|
128 |
+
self.text.set_grad_checkpointing(enable)
|
129 |
+
self.text_decoder.set_grad_checkpointing(enable)
|
130 |
+
|
131 |
+
def _encode_image(self, images, normalize=True):
|
132 |
+
image_latent, tokens_embs = self.visual(images)
|
133 |
+
image_latent = F.normalize(image_latent, dim=-1) if normalize else image_latent
|
134 |
+
return image_latent, tokens_embs
|
135 |
+
|
136 |
+
def _encode_text(self, text, normalize=True, embed_cls=True):
|
137 |
+
text = text[:, :-1] if embed_cls else text # make space for CLS token
|
138 |
+
text_latent, token_emb = self.text(text)
|
139 |
+
text_latent = F.normalize(text_latent, dim=-1) if normalize else text_latent
|
140 |
+
return text_latent, token_emb
|
141 |
+
|
142 |
+
def encode_image(self, images, normalize=True):
|
143 |
+
image_latent, _ = self._encode_image(images, normalize=normalize)
|
144 |
+
return image_latent
|
145 |
+
|
146 |
+
def encode_text(self, text, normalize=True, embed_cls=True):
|
147 |
+
text_latent, _ = self._encode_text(text, normalize=normalize, embed_cls=embed_cls)
|
148 |
+
return text_latent
|
149 |
+
|
150 |
+
def forward(self, image, text, embed_cls=True, image_latent=None, image_embs=None):
|
151 |
+
text_latent, token_embs = self._encode_text(text, embed_cls=embed_cls)
|
152 |
+
if image_latent is None or image_embs is None:
|
153 |
+
image_latent, image_embs = self._encode_image(image)
|
154 |
+
|
155 |
+
# TODO: add assertion to avoid bugs?
|
156 |
+
labels = text[:, -token_embs.shape[1]:]
|
157 |
+
|
158 |
+
logits = self.text_decoder(image_embs, token_embs)
|
159 |
+
return {
|
160 |
+
"image_features": image_latent,
|
161 |
+
"text_features": text_latent,
|
162 |
+
"logits": logits,
|
163 |
+
"labels": labels,
|
164 |
+
"logit_scale": self.logit_scale.exp()
|
165 |
+
}
|
166 |
+
|
167 |
+
def generate(
|
168 |
+
self,
|
169 |
+
image,
|
170 |
+
text=None,
|
171 |
+
seq_len=30,
|
172 |
+
max_seq_len=77,
|
173 |
+
temperature=1.,
|
174 |
+
generation_type="beam_search",
|
175 |
+
top_p=0.1, # keep tokens in the 1 - top_p quantile
|
176 |
+
top_k=1, # keeps the top_k most probable tokens
|
177 |
+
pad_token_id=None,
|
178 |
+
eos_token_id=None,
|
179 |
+
sot_token_id=None,
|
180 |
+
num_beams=6,
|
181 |
+
num_beam_groups=3,
|
182 |
+
min_seq_len=5,
|
183 |
+
stopping_criteria=None,
|
184 |
+
repetition_penalty=1.0,
|
185 |
+
fixed_output_length=False # if True output.shape == (batch_size, seq_len)
|
186 |
+
):
|
187 |
+
# taking many ideas and components from HuggingFace GenerationMixin
|
188 |
+
# https://huggingface.co/docs/transformers/main/en/main_classes/text_generation
|
189 |
+
assert _has_transformers, "Please install transformers for generate functionality. `pip install transformers`."
|
190 |
+
assert seq_len > min_seq_len, "seq_len must be larger than min_seq_len"
|
191 |
+
|
192 |
+
with torch.no_grad():
|
193 |
+
sot_token_id = 49406 if sot_token_id is None else sot_token_id
|
194 |
+
eos_token_id = 49407 if eos_token_id is None else eos_token_id
|
195 |
+
pad_token_id = self.pad_id if pad_token_id is None else pad_token_id
|
196 |
+
logit_processor = LogitsProcessorList(
|
197 |
+
[
|
198 |
+
MinLengthLogitsProcessor(min_seq_len, eos_token_id),
|
199 |
+
RepetitionPenaltyLogitsProcessor(repetition_penalty),
|
200 |
+
]
|
201 |
+
)
|
202 |
+
|
203 |
+
if stopping_criteria is None:
|
204 |
+
stopping_criteria = [MaxLengthCriteria(max_length=seq_len)]
|
205 |
+
|
206 |
+
stopping_criteria = StoppingCriteriaList(
|
207 |
+
stopping_criteria
|
208 |
+
)
|
209 |
+
|
210 |
+
device = image.device
|
211 |
+
|
212 |
+
if generation_type == "beam_search":
|
213 |
+
output = self._generate_beamsearch(
|
214 |
+
image_inputs = image,
|
215 |
+
pad_token_id=pad_token_id,
|
216 |
+
eos_token_id=eos_token_id,
|
217 |
+
sot_token_id=sot_token_id,
|
218 |
+
num_beams=num_beams,
|
219 |
+
num_beam_groups=num_beam_groups,
|
220 |
+
min_seq_len=min_seq_len,
|
221 |
+
stopping_criteria=stopping_criteria,
|
222 |
+
logit_processor=logit_processor,
|
223 |
+
)
|
224 |
+
if fixed_output_length and output.shape[1] < seq_len:
|
225 |
+
return torch.cat(
|
226 |
+
(output, torch.ones(output.shape[0], seq_len-output.shape[1], device=device, dtype=output.dtype) * self.pad_id),
|
227 |
+
dim=1
|
228 |
+
)
|
229 |
+
return output
|
230 |
+
|
231 |
+
elif generation_type == "top_p":
|
232 |
+
logit_warper = GENERATION_TYPES[generation_type](top_p)
|
233 |
+
elif generation_type == "top_k":
|
234 |
+
logit_warper = GENERATION_TYPES[generation_type](top_k)
|
235 |
+
else:
|
236 |
+
raise ValueError(
|
237 |
+
f"generation_type has to be one of "
|
238 |
+
f"{'| ' + ' | '.join(list(GENERATION_TYPES.keys())) + ' |'}."
|
239 |
+
)
|
240 |
+
|
241 |
+
image_latent, image_embs = self._encode_image(image)
|
242 |
+
|
243 |
+
if text is None:
|
244 |
+
text = torch.ones((image.shape[0], 1), device=device, dtype=torch.long) * sot_token_id
|
245 |
+
|
246 |
+
was_training = self.training
|
247 |
+
num_dims = len(text.shape)
|
248 |
+
|
249 |
+
if num_dims == 1:
|
250 |
+
text = text[None, :]
|
251 |
+
|
252 |
+
cur_len = text.shape[1]
|
253 |
+
self.eval()
|
254 |
+
out = text
|
255 |
+
|
256 |
+
while True:
|
257 |
+
x = out[:, -max_seq_len:]
|
258 |
+
cur_len = x.shape[1]
|
259 |
+
logits = self(image, x, image_latent=image_latent, image_embs=image_embs, embed_cls=False)["logits"][:, -1]
|
260 |
+
mask = (out[:, -1] == eos_token_id) | (out[:, -1] == pad_token_id)
|
261 |
+
sample = torch.ones((out.shape[0], 1), device=device, dtype=torch.long) * pad_token_id
|
262 |
+
|
263 |
+
if mask.all():
|
264 |
+
if not fixed_output_length:
|
265 |
+
break
|
266 |
+
else:
|
267 |
+
logits = logits[~mask, :]
|
268 |
+
filtered_logits = logit_processor(x[~mask, :], logits)
|
269 |
+
filtered_logits = logit_warper(x[~mask, :], filtered_logits)
|
270 |
+
probs = F.softmax(filtered_logits / temperature, dim=-1)
|
271 |
+
|
272 |
+
if (cur_len + 1 == seq_len):
|
273 |
+
sample[~mask, :] = torch.ones((sum(~mask), 1), device=device, dtype=torch.long) * eos_token_id
|
274 |
+
else:
|
275 |
+
sample[~mask, :] = torch.multinomial(probs, 1)
|
276 |
+
|
277 |
+
out = torch.cat((out, sample), dim=-1)
|
278 |
+
|
279 |
+
cur_len += 1
|
280 |
+
|
281 |
+
if stopping_criteria(out, None):
|
282 |
+
break
|
283 |
+
|
284 |
+
if num_dims == 1:
|
285 |
+
out = out.squeeze(0)
|
286 |
+
|
287 |
+
self.train(was_training)
|
288 |
+
return out
|
289 |
+
|
290 |
+
def _generate_beamsearch(
|
291 |
+
self,
|
292 |
+
image_inputs,
|
293 |
+
pad_token_id=None,
|
294 |
+
eos_token_id=None,
|
295 |
+
sot_token_id=None,
|
296 |
+
num_beams=6,
|
297 |
+
num_beam_groups=3,
|
298 |
+
min_seq_len=5,
|
299 |
+
stopping_criteria=None,
|
300 |
+
logit_processor=None,
|
301 |
+
logit_warper=None,
|
302 |
+
):
|
303 |
+
device = image_inputs.device
|
304 |
+
batch_size = image_inputs.shape[0]
|
305 |
+
image_inputs = torch.repeat_interleave(image_inputs, num_beams, dim=0)
|
306 |
+
image_latent, image_embs = self._encode_image(image_inputs)
|
307 |
+
|
308 |
+
input_ids = torch.ones((batch_size * num_beams, 1), device=device, dtype=torch.long)
|
309 |
+
input_ids = input_ids * sot_token_id
|
310 |
+
beam_scorer = BeamSearchScorer(
|
311 |
+
batch_size=batch_size,
|
312 |
+
num_beams=num_beams,
|
313 |
+
device=device,
|
314 |
+
num_beam_groups=num_beam_groups,
|
315 |
+
)
|
316 |
+
# instantiate logits processors
|
317 |
+
logits_processor = (
|
318 |
+
LogitsProcessorList([MinLengthLogitsProcessor(min_seq_len, eos_token_id=eos_token_id)])
|
319 |
+
if logit_processor is None
|
320 |
+
else logit_processor
|
321 |
+
)
|
322 |
+
|
323 |
+
batch_size = len(beam_scorer._beam_hyps)
|
324 |
+
num_beams = beam_scorer.num_beams
|
325 |
+
num_beam_groups = beam_scorer.num_beam_groups
|
326 |
+
num_sub_beams = num_beams // num_beam_groups
|
327 |
+
batch_beam_size, cur_len = input_ids.shape
|
328 |
+
beam_indices = None
|
329 |
+
|
330 |
+
if num_beams * batch_size != batch_beam_size:
|
331 |
+
raise ValueError(
|
332 |
+
f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
|
333 |
+
)
|
334 |
+
|
335 |
+
beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device)
|
336 |
+
# initialise score of first beam of each group with 0 and the rest with 1e-9. This ensures that the beams in
|
337 |
+
# the same group don't produce same tokens everytime.
|
338 |
+
beam_scores[:, ::num_sub_beams] = 0
|
339 |
+
beam_scores = beam_scores.view((batch_size * num_beams,))
|
340 |
+
|
341 |
+
while True:
|
342 |
+
|
343 |
+
# predicted tokens in cur_len step
|
344 |
+
current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device)
|
345 |
+
|
346 |
+
# indices which will form the beams in the next time step
|
347 |
+
reordering_indices = torch.zeros(batch_size * num_beams, dtype=torch.long, device=device)
|
348 |
+
|
349 |
+
# do one decoder step on all beams of all sentences in batch
|
350 |
+
model_inputs = prepare_inputs_for_generation(input_ids=input_ids, image_inputs=image_inputs)
|
351 |
+
outputs = self(
|
352 |
+
model_inputs['images'],
|
353 |
+
model_inputs['text'],
|
354 |
+
embed_cls=False,
|
355 |
+
image_latent=image_latent,
|
356 |
+
image_embs=image_embs
|
357 |
+
)
|
358 |
+
|
359 |
+
for beam_group_idx in range(num_beam_groups):
|
360 |
+
group_start_idx = beam_group_idx * num_sub_beams
|
361 |
+
group_end_idx = min(group_start_idx + num_sub_beams, num_beams)
|
362 |
+
group_size = group_end_idx - group_start_idx
|
363 |
+
|
364 |
+
# indices of beams of current group among all sentences in batch
|
365 |
+
batch_group_indices = []
|
366 |
+
|
367 |
+
for batch_idx in range(batch_size):
|
368 |
+
batch_group_indices.extend(
|
369 |
+
[batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)]
|
370 |
+
)
|
371 |
+
group_input_ids = input_ids[batch_group_indices]
|
372 |
+
|
373 |
+
# select outputs of beams of currentg group only
|
374 |
+
next_token_logits = outputs['logits'][batch_group_indices, -1, :]
|
375 |
+
vocab_size = next_token_logits.shape[-1]
|
376 |
+
|
377 |
+
next_token_scores_processed = logits_processor(
|
378 |
+
group_input_ids, next_token_logits, current_tokens=current_tokens, beam_group_idx=beam_group_idx
|
379 |
+
)
|
380 |
+
next_token_scores = next_token_scores_processed + beam_scores[batch_group_indices].unsqueeze(-1)
|
381 |
+
next_token_scores = next_token_scores.expand_as(next_token_scores_processed)
|
382 |
+
|
383 |
+
# reshape for beam search
|
384 |
+
next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size)
|
385 |
+
|
386 |
+
next_token_scores, next_tokens = torch.topk(
|
387 |
+
next_token_scores, 2 * group_size, dim=1, largest=True, sorted=True
|
388 |
+
)
|
389 |
+
|
390 |
+
next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
|
391 |
+
next_tokens = next_tokens % vocab_size
|
392 |
+
|
393 |
+
# stateless
|
394 |
+
process_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
|
395 |
+
beam_outputs = beam_scorer.process(
|
396 |
+
group_input_ids,
|
397 |
+
next_token_scores,
|
398 |
+
next_tokens,
|
399 |
+
next_indices,
|
400 |
+
pad_token_id=pad_token_id,
|
401 |
+
eos_token_id=eos_token_id,
|
402 |
+
beam_indices=process_beam_indices,
|
403 |
+
)
|
404 |
+
beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"]
|
405 |
+
beam_next_tokens = beam_outputs["next_beam_tokens"]
|
406 |
+
beam_idx = beam_outputs["next_beam_indices"]
|
407 |
+
|
408 |
+
input_ids[batch_group_indices] = group_input_ids[beam_idx]
|
409 |
+
group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
|
410 |
+
current_tokens[batch_group_indices] = group_input_ids[:, -1]
|
411 |
+
|
412 |
+
# (beam_idx // group_size) -> batch_idx
|
413 |
+
# (beam_idx % group_size) -> offset of idx inside the group
|
414 |
+
reordering_indices[batch_group_indices] = (
|
415 |
+
num_beams * torch.div(beam_idx, group_size, rounding_mode="floor") + group_start_idx + (beam_idx % group_size)
|
416 |
+
)
|
417 |
+
|
418 |
+
input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1)
|
419 |
+
|
420 |
+
# increase cur_len
|
421 |
+
cur_len = cur_len + 1
|
422 |
+
if beam_scorer.is_done or stopping_criteria(input_ids, None):
|
423 |
+
break
|
424 |
+
|
425 |
+
final_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
|
426 |
+
sequence_outputs = beam_scorer.finalize(
|
427 |
+
input_ids,
|
428 |
+
beam_scores,
|
429 |
+
next_tokens,
|
430 |
+
next_indices,
|
431 |
+
pad_token_id=pad_token_id,
|
432 |
+
eos_token_id=eos_token_id,
|
433 |
+
max_length=stopping_criteria.max_length,
|
434 |
+
beam_indices=final_beam_indices,
|
435 |
+
)
|
436 |
+
return sequence_outputs['sequences']
|
437 |
+
|
438 |
+
|
439 |
+
def prepare_inputs_for_generation(input_ids, image_inputs, past=None, **kwargs):
    """Bundle the tensors needed for one decoding step into a dict.

    Args:
        input_ids: 2-D token tensor; only its last column is kept when
            ``past`` is truthy.
        image_inputs: returned unchanged under the ``"images"`` key.
        past: returned unchanged under ``"past_key_values"``.
        **kwargs: ``attention_mask`` and ``position_ids`` are read; anything
            else is ignored.

    Returns:
        Dict with keys ``text``, ``images``, ``past_key_values``,
        ``position_ids`` and ``attention_mask``. ``position_ids`` is derived
        from the attention mask only when a mask is given and no position ids
        were supplied; in every other case it is None.
    """
    text = input_ids[:, -1].unsqueeze(-1) if past else input_ids

    attention_mask = kwargs.get("attention_mask", None)
    supplied_positions = kwargs.get("position_ids", None)

    position_ids = None
    if attention_mask is not None and supplied_positions is None:
        # create position_ids on the fly for batch generation: running count
        # of attended tokens, with masked-out slots pinned to 1
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)

    prepared = {}
    prepared["text"] = text
    prepared["images"] = image_inputs
    prepared["past_key_values"] = past
    prepared["position_ids"] = position_ids
    prepared["attention_mask"] = attention_mask
    return prepared
|
ext/open_clip/constants.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
# Default image normalisation statistics, one value per channel.
# NOTE(review): presumably RGB order on a 0-1 pixel scale - confirm against the image transform.
OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
|
ext/open_clip/factory.py
ADDED
@@ -0,0 +1,387 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import logging
|
3 |
+
import os
|
4 |
+
import pathlib
|
5 |
+
import re
|
6 |
+
from copy import deepcopy
|
7 |
+
from pathlib import Path
|
8 |
+
from typing import Any, Dict, Optional, Tuple, Union
|
9 |
+
|
10 |
+
import torch
|
11 |
+
|
12 |
+
from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
|
13 |
+
from .model import CLIP, CustomTextCLIP, convert_weights_to_lp, convert_to_custom_text_state_dict,\
|
14 |
+
resize_pos_embed, get_cast_dtype
|
15 |
+
from .coca_model import CoCa
|
16 |
+
from .loss import ClipLoss, DistillClipLoss, CoCaLoss
|
17 |
+
from .openai import load_openai_model
|
18 |
+
from .pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained,\
|
19 |
+
list_pretrained_tags_by_model, download_pretrained_from_hf
|
20 |
+
from .transform import image_transform, AugmentationCfg
|
21 |
+
from .tokenizer import HFTokenizer, tokenize
|
22 |
+
|
23 |
+
|
24 |
+
# Model names starting with this prefix are resolved through the HF hub download helpers.
HF_HUB_PREFIX = 'hf-hub:'
# Files or directories scanned for '*.json' architecture configs.
_MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"]
_MODEL_CONFIGS = {}  # dictionary (model_name: config) of model architecture configs
|
27 |
+
|
28 |
+
|
29 |
+
def _natural_key(string_):
|
30 |
+
return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]
|
31 |
+
|
32 |
+
|
33 |
+
def _rescan_model_configs():
    """Reload the model config registry from every registered path.

    Each entry of ``_MODEL_CONFIG_PATHS`` may be a single ``.json`` file or a
    directory whose ``*.json`` files are collected. A file is registered under
    its stem only if it defines ``embed_dim``, ``vision_cfg`` and ``text_cfg``.
    The registry is then rebuilt in natural-sort order of the model names.
    Entries already present are kept (and overwritten on a name clash).
    """
    global _MODEL_CONFIGS

    config_ext = ('.json',)
    config_files = []
    for config_path in _MODEL_CONFIG_PATHS:
        if config_path.is_file() and config_path.suffix in config_ext:
            config_files.append(config_path)
        elif config_path.is_dir():
            for ext in config_ext:
                config_files.extend(config_path.glob(f'*{ext}'))

    required_keys = ('embed_dim', 'vision_cfg', 'text_cfg')
    for cf in config_files:
        # JSON is UTF-8; do not depend on the platform's default encoding
        # (matches how create_model opens its downloaded config).
        with open(cf, 'r', encoding='utf-8') as f:
            model_cfg = json.load(f)
        if all(a in model_cfg for a in required_keys):
            _MODEL_CONFIGS[cf.stem] = model_cfg

    _MODEL_CONFIGS = dict(sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0])))
|
52 |
+
|
53 |
+
|
54 |
+
_rescan_model_configs() # initial populate of model config registry
|
55 |
+
|
56 |
+
|
57 |
+
def list_models():
    """Return the names of all registered model architectures as a new list."""
    return [model_name for model_name in _MODEL_CONFIGS]
|
60 |
+
|
61 |
+
|
62 |
+
def add_model_config(path):
    """Register an extra config file or directory and rescan the registry."""
    config_location = path if isinstance(path, Path) else Path(path)
    _MODEL_CONFIG_PATHS.append(config_location)
    _rescan_model_configs()
|
68 |
+
|
69 |
+
|
70 |
+
def get_model_config(model_name):
    """Return a deep copy of the registered config for ``model_name``, or None if unknown."""
    if model_name not in _MODEL_CONFIGS:
        return None
    registered = _MODEL_CONFIGS[model_name]
    return deepcopy(registered)
|
75 |
+
|
76 |
+
|
77 |
+
def get_tokenizer(model_name):
    """Return the tokenizer matching ``model_name``.

    Names carrying the ``hf-hub:`` prefix get an ``HFTokenizer`` built from
    the remainder of the name. Otherwise the registered model config decides:
    an ``HFTokenizer`` when ``text_cfg`` names one via ``hf_tokenizer_name``,
    else the module's default ``tokenize`` callable.

    Raises:
        RuntimeError: if no config is registered for ``model_name``.
    """
    if model_name.startswith(HF_HUB_PREFIX):
        return HFTokenizer(model_name[len(HF_HUB_PREFIX):])

    # accept the old naming with '/' in ViT names, exactly as create_model does
    model_name = model_name.replace('/', '-')
    config = get_model_config(model_name)
    if config is None:
        # previously this surfaced as an opaque TypeError on config['text_cfg']
        raise RuntimeError(f'Model config for {model_name} not found.')

    text_cfg = config['text_cfg']
    if 'hf_tokenizer_name' in text_cfg:
        return HFTokenizer(text_cfg['hf_tokenizer_name'])
    return tokenize
|
85 |
+
|
86 |
+
|
87 |
+
def load_state_dict(checkpoint_path: str, map_location='cpu'):
    """Load a checkpoint file and return its state dict.

    Args:
        checkpoint_path: file passed to ``torch.load``.
        map_location: forwarded to ``torch.load``.

    Returns:
        ``checkpoint['state_dict']`` when the checkpoint is a dict holding
        that key, otherwise the checkpoint itself. If the first key carries a
        ``module.`` prefix, that prefix is stripped from every key that has
        it; keys without the prefix are left untouched.
    """
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from trusted sources.
    checkpoint = torch.load(checkpoint_path, map_location=map_location)
    if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    else:
        state_dict = checkpoint

    prefix = 'module.'
    # an empty state dict has no first key; treat it as "nothing to strip"
    first_key = next(iter(state_dict.keys()), '')
    if first_key.startswith(prefix):
        state_dict = {
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in state_dict.items()
        }
    return state_dict
|
96 |
+
|
97 |
+
|
98 |
+
def load_checkpoint(model, checkpoint_path, strict=True):
    """Load weights from ``checkpoint_path`` into ``model``.

    Returns whatever ``model.load_state_dict`` returns (the incompatible-keys
    report).
    """
    weights = load_state_dict(checkpoint_path)

    # detect old format and make compatible with new format
    legacy_layout = 'positional_embedding' in weights and not hasattr(model, 'positional_embedding')
    if legacy_layout:
        weights = convert_to_custom_text_state_dict(weights)

    resize_pos_embed(weights, model)
    return model.load_state_dict(weights, strict=strict)
|
106 |
+
|
107 |
+
|
108 |
+
def create_model(
        model_name: str,
        pretrained: Optional[str] = None,
        precision: str = 'fp32',
        device: Union[str, torch.device] = 'cpu',
        jit: bool = False,
        force_quick_gelu: bool = False,
        force_custom_text: bool = False,
        force_patch_dropout: Optional[float] = None,
        force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
        pretrained_image: bool = False,
        pretrained_hf: bool = True,
        cache_dir: Optional[str] = None,
        output_dict: Optional[bool] = None,
        require_pretrained: bool = False,
        logger: logging.Logger = logging,
):
    """Build a CLIP-style model and optionally load pretrained weights.

    Args:
        model_name: registered architecture name, or ``hf-hub:<model id>`` to
            fetch both config and checkpoint through the hub download helper.
        pretrained: ``'openai'``, a registered pretrained tag, or a path to a
            checkpoint file.
        precision: ``fp16``/``bf16`` select manual mixed precision,
            ``pure_fp16``/``pure_bf16`` cast the whole model, anything else
            leaves the dtype unchanged.
        device: target device (a string is converted to ``torch.device``).
        jit: script the finished model with ``torch.jit.script``.
        force_quick_gelu: set ``quick_gelu`` in the model config.
        force_custom_text: build ``CustomTextCLIP`` even for a plain config.
        force_patch_dropout: override ``vision_cfg['patch_dropout']``.
        force_image_size: override ``vision_cfg['image_size']``.
        pretrained_image: request pretrained weights for a timm image tower.
        pretrained_hf: stored as ``text_cfg['hf_model_pretrained']`` for HF
            text towers.
        cache_dir: forwarded to the download helpers.
        output_dict: if truthy and the model has an ``output_dict``
            attribute, set it to True.
        require_pretrained: raise if no pretrained weights were loaded.
        logger: object providing ``info``/``warning``/``error``.

    Returns:
        The constructed (and possibly scripted) model.

    Raises:
        RuntimeError: unknown model config, pretrained weights not found, or
            pretrained weights required but not loaded.
        AssertionError: ``pretrained_image`` requested for a non-timm tower.
    """
    # NOTE(review): the block nesting below was reconstructed from a view
    # that had lost its indentation - confirm against the original file.
    has_hf_hub_prefix = model_name.startswith(HF_HUB_PREFIX)
    if has_hf_hub_prefix:
        model_id = model_name[len(HF_HUB_PREFIX):]
        checkpoint_path = download_pretrained_from_hf(model_id, cache_dir=cache_dir)
        config_path = download_pretrained_from_hf(model_id, filename='open_clip_config.json', cache_dir=cache_dir)

        with open(config_path, 'r', encoding='utf-8') as f:
            config = json.load(f)
        pretrained_cfg = config['preprocess_cfg']
        model_cfg = config['model_cfg']
    else:
        model_name = model_name.replace('/', '-')  # for callers using old naming with / in ViT names
        checkpoint_path = None
        pretrained_cfg = {}
        model_cfg = None

    if isinstance(device, str):
        device = torch.device(device)

    if pretrained and pretrained.lower() == 'openai':
        logger.info(f'Loading pretrained {model_name} from OpenAI.')
        model = load_openai_model(
            model_name,
            precision=precision,
            device=device,
            cache_dir=cache_dir,
        )
    else:
        model_cfg = model_cfg or get_model_config(model_name)
        if model_cfg is not None:
            logger.info(f'Loaded {model_name} model config.')
        else:
            logger.error(f'Model config for {model_name} not found; available models {list_models()}.')
            raise RuntimeError(f'Model config for {model_name} not found.')

        if force_quick_gelu:
            # override for use of QuickGELU on non-OpenAI transformer models
            model_cfg["quick_gelu"] = True

        if force_patch_dropout is not None:
            # override the default patch dropout value
            model_cfg["vision_cfg"]["patch_dropout"] = force_patch_dropout

        if force_image_size is not None:
            # override model config's image size
            model_cfg["vision_cfg"]["image_size"] = force_image_size

        is_timm_model = 'timm_model_name' in model_cfg.get('vision_cfg', {})
        if pretrained_image:
            if is_timm_model:
                # pretrained weight loading for timm models set via vision_cfg
                model_cfg['vision_cfg']['timm_model_pretrained'] = True
            else:
                # raised explicitly (rather than `assert False`) so the check
                # survives `python -O`; the exception type is unchanged
                raise AssertionError('pretrained image towers currently only supported for timm models')

        # cast_dtype set for fp16 and bf16 (manual mixed-precision), not set for 'amp' or 'pure' modes
        cast_dtype = get_cast_dtype(precision)
        is_hf_model = 'hf_model_name' in model_cfg.get('text_cfg', {})
        custom_text = model_cfg.pop('custom_text', False) or force_custom_text or is_hf_model

        if custom_text:
            if is_hf_model:
                model_cfg['text_cfg']['hf_model_pretrained'] = pretrained_hf
            if "coca" in model_name:
                model = CoCa(**model_cfg, cast_dtype=cast_dtype)
            else:
                model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype)
        else:
            model = CLIP(**model_cfg, cast_dtype=cast_dtype)

        if precision in ("fp16", "bf16"):
            dtype = torch.float16 if 'fp16' in precision else torch.bfloat16
            # manual mixed precision that matches original OpenAI behaviour
            if is_timm_model:
                # FIXME this is a bit janky, create timm based model in low-precision and
                # then cast only LayerNormFp32 instances back to float32 so they don't break.
                # Why? The convert_weights_to_lp fn only works with native models.
                model.to(device=device, dtype=dtype)
                from .transformer import LayerNormFp32

                def _convert_ln(m):
                    if isinstance(m, LayerNormFp32):
                        m.weight.data = m.weight.data.to(torch.float32)
                        m.bias.data = m.bias.data.to(torch.float32)

                model.apply(_convert_ln)
            else:
                model.to(device=device)
                convert_weights_to_lp(model, dtype=dtype)
        elif precision in ("pure_fp16", "pure_bf16"):
            dtype = torch.float16 if 'fp16' in precision else torch.bfloat16
            model.to(device=device, dtype=dtype)
        else:
            model.to(device=device)

        pretrained_loaded = False
        if pretrained:
            checkpoint_path = ''
            pretrained_cfg = get_pretrained_cfg(model_name, pretrained)
            if pretrained_cfg:
                checkpoint_path = download_pretrained(pretrained_cfg, cache_dir=cache_dir)
            elif os.path.exists(pretrained):
                checkpoint_path = pretrained

            if checkpoint_path:
                logger.info(f'Loading pretrained {model_name} weights ({pretrained}).')
                load_checkpoint(model, checkpoint_path)
            else:
                # fixed: the two sentences ran together and the closing
                # parenthesis was missing
                error_str = (
                    f'Pretrained weights ({pretrained}) not found for model {model_name}. '
                    f'Available pretrained tags ({list_pretrained_tags_by_model(model_name)}).')
                logger.warning(error_str)
                raise RuntimeError(error_str)
            pretrained_loaded = True
        elif has_hf_hub_prefix:
            # `pretrained` is empty on this path, so report the hub checkpoint instead
            logger.info(f'Loading pretrained {model_name} weights ({checkpoint_path}).')
            load_checkpoint(model, checkpoint_path)
            pretrained_loaded = True

        if require_pretrained and not pretrained_loaded:
            # callers of create_model_from_pretrained always expect pretrained weights
            raise RuntimeError(
                f'Pretrained weights were required for (model: {model_name}, pretrained: {pretrained}) but not loaded.')

        # set image / mean metadata from pretrained_cfg if available, or use default
        model.visual.image_mean = pretrained_cfg.get('mean', None) or OPENAI_DATASET_MEAN
        model.visual.image_std = pretrained_cfg.get('std', None) or OPENAI_DATASET_STD

    if output_dict and hasattr(model, "output_dict"):
        model.output_dict = True

    if jit:
        model = torch.jit.script(model)

    return model
|
258 |
+
|
259 |
+
|
260 |
+
def create_loss(args):
    """Pick and build the loss object described by ``args``.

    ``args.distill`` selects ``DistillClipLoss``; otherwise a model name
    containing "coca" (case-insensitive) selects ``CoCaLoss``; everything
    else gets ``ClipLoss``. All three share the same distributed settings.
    """
    def _distributed_settings():
        # read lazily so attributes are touched in the same order as before
        return dict(
            local_loss=args.local_loss,
            gather_with_grad=args.gather_with_grad,
            cache_labels=True,
            rank=args.rank,
            world_size=args.world_size,
            use_horovod=args.horovod,
        )

    if args.distill:
        return DistillClipLoss(**_distributed_settings())

    if "coca" in args.model.lower():
        return CoCaLoss(
            caption_loss_weight=args.coca_caption_loss_weight,
            clip_loss_weight=args.coca_contrastive_loss_weight,
            **_distributed_settings(),
        )

    return ClipLoss(**_distributed_settings())
|
289 |
+
|
290 |
+
|
291 |
+
def create_model_and_transforms(
        model_name: str,
        pretrained: Optional[str] = None,
        precision: str = 'fp32',
        device: Union[str, torch.device] = 'cpu',
        jit: bool = False,
        force_quick_gelu: bool = False,
        force_custom_text: bool = False,
        force_patch_dropout: Optional[float] = None,
        force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
        pretrained_image: bool = False,
        pretrained_hf: bool = True,
        image_mean: Optional[Tuple[float, ...]] = None,
        image_std: Optional[Tuple[float, ...]] = None,
        aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,
        cache_dir: Optional[str] = None,
        output_dict: Optional[bool] = None,
        logger: logging.Logger = logging,
):
    """Create a model together with its training and validation transforms.

    The model arguments are forwarded to ``create_model`` unchanged.
    ``image_mean`` / ``image_std`` fall back to the attributes stored on
    ``model.visual`` when not given (or falsy). ``aug_cfg`` is applied to the
    training transform only.

    Returns:
        ``(model, preprocess_train, preprocess_val)``.
    """
    model_kwargs = dict(
        precision=precision,
        device=device,
        jit=jit,
        force_quick_gelu=force_quick_gelu,
        force_custom_text=force_custom_text,
        force_patch_dropout=force_patch_dropout,
        force_image_size=force_image_size,
        pretrained_image=pretrained_image,
        pretrained_hf=pretrained_hf,
        cache_dir=cache_dir,
        output_dict=output_dict,
        logger=logger,
    )
    model = create_model(model_name, pretrained, **model_kwargs)

    # explicit statistics win; otherwise use what create_model stored on the tower
    image_mean = image_mean or getattr(model.visual, 'image_mean', None)
    image_std = image_std or getattr(model.visual, 'image_std', None)

    preprocess_train = image_transform(
        model.visual.image_size, is_train=True, mean=image_mean, std=image_std, aug_cfg=aug_cfg)
    preprocess_val = image_transform(
        model.visual.image_size, is_train=False, mean=image_mean, std=image_std)

    return model, preprocess_train, preprocess_val
|
344 |
+
|
345 |
+
|
346 |
+
def create_model_from_pretrained(
        model_name: str,
        pretrained: Optional[str] = None,
        precision: str = 'fp32',
        device: Union[str, torch.device] = 'cpu',
        jit: bool = False,
        force_quick_gelu: bool = False,
        force_custom_text: bool = False,
        force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
        return_transform: bool = True,
        image_mean: Optional[Tuple[float, ...]] = None,
        image_std: Optional[Tuple[float, ...]] = None,
        cache_dir: Optional[str] = None,
        logger: logging.Logger = logging,
):
    """Create a model whose pretrained weights must load successfully.

    Calls ``create_model`` with ``require_pretrained=True``.

    Returns:
        The model alone when ``return_transform`` is falsy, otherwise
        ``(model, preprocess)`` where ``preprocess`` is the evaluation
        (``is_train=False``) image transform.
    """
    model_kwargs = dict(
        precision=precision,
        device=device,
        jit=jit,
        force_quick_gelu=force_quick_gelu,
        force_custom_text=force_custom_text,
        force_image_size=force_image_size,
        cache_dir=cache_dir,
        require_pretrained=True,
        logger=logger,
    )
    model = create_model(model_name, pretrained, **model_kwargs)

    if return_transform:
        # explicit statistics win; otherwise use what is stored on the tower
        image_mean = image_mean or getattr(model.visual, 'image_mean', None)
        image_std = image_std or getattr(model.visual, 'image_std', None)
        preprocess = image_transform(
            model.visual.image_size, is_train=False, mean=image_mean, std=image_std)
        return model, preprocess

    return model
|
ext/open_clip/generation_utils.py
ADDED
File without changes
|
ext/open_clip/hf_configs.py
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# HF architecture dict:
# Keyed by model type. Each entry holds:
#   "config_names": maps the generic names used in this package to the
#                   attribute names of that architecture's HF config
#   "pooler":       name of the default pooler for that architecture
arch_dict = {
    # https://huggingface.co/docs/transformers/model_doc/roberta#roberta
    "roberta": {
        "config_names": {
            "context_length": "max_position_embeddings",
            "vocab_size": "vocab_size",
            "width": "hidden_size",
            "heads": "num_attention_heads",
            "layers": "num_hidden_layers",
            "layer_attr": "layer",
            "token_embeddings_attr": "embeddings"
        },
        "pooler": "mean_pooler",
    },
    # https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig
    "xlm-roberta": {
        "config_names": {
            "context_length": "max_position_embeddings",
            "vocab_size": "vocab_size",
            "width": "hidden_size",
            "heads": "num_attention_heads",
            "layers": "num_hidden_layers",
            "layer_attr": "layer",
            "token_embeddings_attr": "embeddings"
        },
        "pooler": "mean_pooler",
    },
    # https://huggingface.co/docs/transformers/model_doc/mt5#mt5
    "mt5": {
        "config_names": {
            # unlimited seqlen
            # https://github.com/google-research/text-to-text-transfer-transformer/issues/273
            # https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374
            "context_length": "",
            "vocab_size": "vocab_size",
            "width": "d_model",
            "heads": "num_heads",
            "layers": "num_layers",
            "layer_attr": "block",
            "token_embeddings_attr": "embed_tokens"
        },
        "pooler": "mean_pooler",
    },
    # https://huggingface.co/docs/transformers/model_doc/bert
    # NOTE(review): unlike the entries above, this one defines no
    # "layer_attr" / "token_embeddings_attr" - confirm nothing looks them up for bert.
    "bert": {
        "config_names": {
            "context_length": "max_position_embeddings",
            "vocab_size": "vocab_size",
            "width": "hidden_size",
            "heads": "num_attention_heads",
            "layers": "num_hidden_layers",
        },
        "pooler": "cls_pooler",
    },
}
|
ext/open_clip/hf_model.py
ADDED
@@ -0,0 +1,193 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
""" huggingface model adapter
|
2 |
+
|
3 |
+
Wraps HuggingFace transformers (https://github.com/huggingface/transformers) models for use as a text tower in CLIP model.
|
4 |
+
"""
|
5 |
+
import re
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
from torch import TensorType
|
10 |
+
|
11 |
+
try:
|
12 |
+
import transformers
|
13 |
+
from transformers import AutoModel, AutoTokenizer, AutoConfig, PretrainedConfig
|
14 |
+
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, \
|
15 |
+
BaseModelOutputWithPoolingAndCrossAttentions
|
16 |
+
except ImportError as e:
|
17 |
+
transformers = None
|
18 |
+
|
19 |
+
|
20 |
+
class BaseModelOutput:
|
21 |
+
pass
|
22 |
+
|
23 |
+
|
24 |
+
class PretrainedConfig:
|
25 |
+
pass
|
26 |
+
|
27 |
+
from .hf_configs import arch_dict
|
28 |
+
|
29 |
+
|
30 |
+
# utils
|
31 |
+
def _camel2snake(s):
|
32 |
+
return re.sub(r'(?<!^)(?=[A-Z])', '_', s).lower()
|
33 |
+
|
34 |
+
|
35 |
+
# Registry of pooler classes, keyed by the snake_case form of the class name
# (filled in by the `register_pooler` decorator).
# TODO: ?last - for gpt-like models
_POOLERS = {}
|
37 |
+
|
38 |
+
|
39 |
+
def register_pooler(cls):
    """Class decorator: store ``cls`` in the pooler registry and return it unchanged.

    The registry key is the snake_case form of the class name.
    """
    registry_key = _camel2snake(cls.__name__)
    _POOLERS[registry_key] = cls
    return cls
|
43 |
+
|
44 |
+
|
45 |
+
@register_pooler
class MeanPooler(nn.Module):
    """Mean pooling weighted by the attention mask."""

    def forward(self, x: BaseModelOutput, attention_mask: TensorType):
        # zero out masked positions, then divide the sum over dim 1 by the
        # per-row sum of the mask
        weights = attention_mask.unsqueeze(-1)
        masked_output = x.last_hidden_state * weights
        summed = masked_output.sum(dim=1)
        return summed / attention_mask.sum(-1, keepdim=True)
|
52 |
+
|
53 |
+
|
54 |
+
@register_pooler
class MaxPooler(nn.Module):
    """Max pooling over the positions kept by the attention mask."""

    def forward(self, x: BaseModelOutput, attention_mask: TensorType):
        # MeanPooler multiplies by the mask, i.e. non-zero means "keep this
        # token". The positions to suppress are therefore the zeros of the
        # mask; comparing with 0 also yields the boolean mask masked_fill
        # expects. (The original filled the kept positions instead.)
        ignored = attention_mask.unsqueeze(-1) == 0
        masked_output = x.last_hidden_state.masked_fill(ignored, -torch.inf)
        return masked_output.max(1).values
|
61 |
+
|
62 |
+
|
63 |
+
@register_pooler
|
64 |
+
class ClsPooler(nn.Module):
|
65 |
+
"""CLS token pooling"""
|
66 |
+
|
67 |
+
def __init__(self, use_pooler_output=True):
|
68 |
+
super().__init__()
|
69 |
+
self.cls_token_position = 0
|
70 |
+
self.use_pooler_output = use_pooler_output
|
71 |
+
|
72 |
+
def forward(self, x: BaseModelOutput, attention_mask: TensorType):
|
73 |
+
if (self.use_pooler_output and
|
74 |
+
isinstance(x, (BaseModelOutputWithPooling, BaseModelOutputWithPoolingAndCrossAttentions)) and
|
75 |
+
(x.pooler_output is not None)
|
76 |
+
):
|
77 |
+
return x.pooler_output
|
78 |
+
|
79 |
+
return x.last_hidden_state[:, self.cls_token_position, :]
|
80 |
+
|
81 |
+
|
82 |
+
@register_pooler
|
83 |
+
class ClsLastHiddenStatePooler(nn.Module):
|
84 |
+
"""CLS token pooling
|
85 |
+
NOTE: this is equivalent to ClsPooler above with use_pooler_output=False
|
86 |
+
"""
|
87 |
+
|
88 |
+
def __init__(self):
|
89 |
+
super().__init__()
|
90 |
+
self.cls_token_position = 0
|
91 |
+
|
92 |
+
def forward(self, x: BaseModelOutput, attention_mask: TensorType):
|
93 |
+
return x.last_hidden_state[:, self.cls_token_position, :]
|
94 |
+
|
95 |
+
|
96 |
+
class HFTextEncoder(nn.Module):
    """HuggingFace model adapter.

    Runs token ids through a ``transformers`` model, reduces the sequence with
    the configured pooler and maps the pooled vector through ``self.proj``.
    """
    output_tokens: torch.jit.Final[bool]

    def __init__(
            self,
            model_name_or_path: str,
            output_dim: int,
            config: PretrainedConfig = None,
            pooler_type: str = None,
            proj: str = None,
            pretrained: bool = True,
            output_tokens: bool = False,
    ):
        """Build the transformer, pooler and projection.

        Args:
            model_name_or_path: identifier handed to ``AutoConfig`` / ``AutoModel``
                when ``config`` is not supplied.
            output_dim: size of the projected output.
            config: ready-made config; when given, the model is built from it
                with ``AutoModel.from_config`` and ``pretrained`` is not consulted.
            pooler_type: key into ``_POOLERS``; ``None`` selects the default
                registered for the model type in ``arch_dict``.
            proj: ``'linear'``, ``'mlp'``, or ``None`` (identity, only valid when
                the model width already equals ``output_dim``).
            pretrained: load pretrained weights instead of initialising from config.
            output_tokens: make ``forward`` also return per-token hidden states.

        Raises:
            RuntimeError: if ``transformers`` is not installed.
            ValueError: if ``proj`` does not select a usable projection.
        """
        super().__init__()
        self.output_tokens = output_tokens
        self.output_dim = output_dim

        # TODO: find better way to get this information
        uses_transformer_pooler = (pooler_type == "cls_pooler")

        if transformers is None:
            raise RuntimeError("Please `pip install transformers` to use pre-trained HuggingFace models")
        if config is None:
            self.config = AutoConfig.from_pretrained(model_name_or_path)
            create_func, model_args = (AutoModel.from_pretrained, model_name_or_path) if pretrained else (
                AutoModel.from_config, self.config)
            # TODO: do all model configs have this attribute? PretrainedConfig does so yes??
            if hasattr(self.config, "is_encoder_decoder") and self.config.is_encoder_decoder:
                # Encoder-decoder models: keep only the encoder half.
                self.transformer = create_func(model_args)
                self.transformer = self.transformer.encoder
            else:
                self.transformer = create_func(model_args, add_pooling_layer=uses_transformer_pooler)
        else:
            self.config = config
            self.transformer = AutoModel.from_config(config)
        if pooler_type is None:  # get default arch pooler
            pooler_type = (arch_dict[self.config.model_type]["pooler"])

        # FIXME downstream users of OpenCLIP models use these attr, need to verify valid across all models
        self.vocab_size = getattr(self.config, 'vocab_size', 0)
        self.context_length = getattr(self.config, 'max_position_embeddings', 0)

        self.pooler = _POOLERS[pooler_type]()

        d_model = getattr(self.config, arch_dict[self.config.model_type]["config_names"]["width"])
        if (d_model == output_dim) and (proj is None):  # do we always need a proj?
            self.proj = nn.Identity()
        elif proj == 'linear':
            self.proj = nn.Linear(d_model, output_dim, bias=False)
        elif proj == 'mlp':
            hidden_size = (d_model + output_dim) // 2
            self.proj = nn.Sequential(
                nn.Linear(d_model, hidden_size, bias=False),
                nn.GELU(),
                nn.Linear(hidden_size, output_dim, bias=False),
            )
        else:
            # Previously this case left self.proj undefined and forward() later
            # failed with an opaque AttributeError; fail early and explain.
            raise ValueError(
                f"Cannot build text projection: proj={proj!r} with model width {d_model} and "
                f"output_dim {output_dim}. Use proj='linear' or proj='mlp', or proj=None only "
                f"when the model width equals output_dim."
            )

    def forward(self, x: TensorType):
        """Encode token ids.

        Returns the projected pooled features, or ``(projected, tokens)`` when
        ``output_tokens`` is set.
        """
        # Positions equal to the pad token are masked out (0), all others kept (1).
        attn_mask = (x != self.config.pad_token_id).long()
        out = self.transformer(input_ids=x, attention_mask=attn_mask)
        pooled_out = self.pooler(out, attn_mask)
        projected = self.proj(pooled_out)

        if not self.output_tokens:
            # Skip building the token view when nobody asked for it.
            return projected

        if isinstance(self.pooler, ClsPooler):
            # Drop the CLS position from the returned tokens.
            seq_len = out.last_hidden_state.shape[1]
            keep = torch.arange(seq_len) != self.pooler.cls_token_position
            tokens = out.last_hidden_state[:, keep, :]
        else:
            tokens = out.last_hidden_state
        return projected, tokens

    @staticmethod
    def _freeze(named_parameters, freeze_layer_norm: bool):
        """Disable gradients; "LayerNorm" parameters stay trainable when freeze_layer_norm is False."""
        for n, p in named_parameters:
            p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False

    def lock(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
        """Freeze the transformer, leaving the last ``unlocked_layers`` modules untouched.

        The module list considered is ``[embeddings, *layers]``; with
        ``unlocked_layers == 0`` every transformer parameter is frozen.
        """
        if not unlocked_layers:  # full freezing
            self._freeze(self.transformer.named_parameters(), freeze_layer_norm)
            return

        encoder = self.transformer.encoder if hasattr(self.transformer, 'encoder') else self.transformer
        layer_list = getattr(encoder, arch_dict[self.config.model_type]["config_names"]["layer_attr"])
        print(f"Unlocking {unlocked_layers}/{len(layer_list) + 1} layers of hf model")
        embeddings = getattr(
            self.transformer, arch_dict[self.config.model_type]["config_names"]["token_embeddings_attr"])
        modules = [embeddings, *layer_list][:-unlocked_layers]
        # freeze layers
        for module in modules:
            self._freeze(module.named_parameters(), freeze_layer_norm)

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        """Turn gradient checkpointing of the wrapped transformer on or off."""
        # The flag used to be ignored (checkpointing was always enabled).
        if enable:
            self.transformer.gradient_checkpointing_enable()
        else:
            self.transformer.gradient_checkpointing_disable()

    def init_parameters(self):
        """No-op: this class performs no extra parameter initialisation."""
        pass
|
ext/open_clip/loss.py
ADDED
@@ -0,0 +1,216 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from torch.nn import functional as F
|
4 |
+
|
5 |
+
try:
|
6 |
+
import torch.distributed.nn
|
7 |
+
from torch import distributed as dist
|
8 |
+
|
9 |
+
has_distributed = True
|
10 |
+
except ImportError:
|
11 |
+
has_distributed = False
|
12 |
+
|
13 |
+
try:
|
14 |
+
import horovod.torch as hvd
|
15 |
+
except ImportError:
|
16 |
+
hvd = None
|
17 |
+
|
18 |
+
|
19 |
+
def gather_features(
        image_features,
        text_features,
        local_loss=False,
        gather_with_grad=False,
        rank=0,
        world_size=1,
        use_horovod=False
):
    """Collect image and text features from every worker and concatenate them on dim 0.

    With ``gather_with_grad`` the gathered tensors come from the
    gradient-aware gather call. Otherwise the gather runs without gradient
    tracking and, unless ``local_loss`` is set, this worker's slot (``rank``)
    is replaced by the original local tensors so they still receive gradients.

    Returns:
        ``(all_image_features, all_text_features)``.
    """
    assert has_distributed, 'torch.distributed did not import correctly, please use a PyTorch version with support.'

    if use_horovod:
        assert hvd is not None, 'Please install horovod'
        if gather_with_grad:
            all_image_features = hvd.allgather(image_features)
            all_text_features = hvd.allgather(text_features)
            return all_image_features, all_text_features

        with torch.no_grad():
            all_image_features = hvd.allgather(image_features)
            all_text_features = hvd.allgather(text_features)
        if local_loss:
            return all_image_features, all_text_features

        # ensure grads for local rank when all_* features don't have a gradient
        image_chunks = list(all_image_features.chunk(world_size, dim=0))
        text_chunks = list(all_text_features.chunk(world_size, dim=0))
        image_chunks[rank] = image_features
        text_chunks[rank] = text_features
        return torch.cat(image_chunks, dim=0), torch.cat(text_chunks, dim=0)

    # We gather tensors from all gpus
    if gather_with_grad:
        all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features), dim=0)
        all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features), dim=0)
        return all_image_features, all_text_features

    image_chunks = [torch.zeros_like(image_features) for _ in range(world_size)]
    text_chunks = [torch.zeros_like(text_features) for _ in range(world_size)]
    dist.all_gather(image_chunks, image_features)
    dist.all_gather(text_chunks, text_features)
    if not local_loss:
        # ensure grads for local rank when all_* features don't have a gradient
        image_chunks[rank] = image_features
        text_chunks[rank] = text_features
    return torch.cat(image_chunks, dim=0), torch.cat(text_chunks, dim=0)
|
64 |
+
|
65 |
+
|
66 |
+
class ClipLoss(nn.Module):
    """Symmetric cross-entropy loss over image-to-text and text-to-image logits."""

    def __init__(
            self,
            local_loss=False,
            gather_with_grad=False,
            cache_labels=False,
            rank=0,
            world_size=1,
            use_horovod=False,
    ):
        """Store the distributed settings and initialise the label cache.

        Args:
            local_loss: compare local features against the gathered ones
                instead of gathered against gathered.
            gather_with_grad: forwarded to ``gather_features``.
            cache_labels: reuse target labels between calls.
            rank: index of this worker.
            world_size: number of workers; features are gathered only when > 1.
            use_horovod: forwarded to ``gather_features``.
        """
        super().__init__()
        self.local_loss = local_loss
        self.gather_with_grad = gather_with_grad
        self.cache_labels = cache_labels
        self.rank = rank
        self.world_size = world_size
        self.use_horovod = use_horovod

        # cache state
        self.prev_num_logits = 0
        self.labels = {}

    def get_ground_truth(self, device, num_logits) -> torch.Tensor:
        """Return target indices ``0..num_logits-1`` (offset by rank for local loss).

        Labels are cached per device when ``cache_labels`` is set.
        """
        # Validate the cache entry against its own length. A single shared
        # "previous size" could hand a second device a stale entry of the
        # wrong length after the batch size changed.
        cached = self.labels.get(device)
        if cached is not None and cached.shape[0] == num_logits:
            return cached

        labels = torch.arange(num_logits, device=device, dtype=torch.long)
        if self.world_size > 1 and self.local_loss:
            # Local rows match the slice of gathered columns owned by this rank.
            labels = labels + num_logits * self.rank
        if self.cache_labels:
            self.labels[device] = labels
            self.prev_num_logits = num_logits
        return labels

    def get_logits(self, image_features, text_features, logit_scale):
        """Return ``(logits_per_image, logits_per_text)`` scaled by ``logit_scale``."""
        if self.world_size > 1:
            all_image_features, all_text_features = gather_features(
                image_features, text_features,
                self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod)

            if self.local_loss:
                logits_per_image = logit_scale * image_features @ all_text_features.T
                logits_per_text = logit_scale * text_features @ all_image_features.T
            else:
                logits_per_image = logit_scale * all_image_features @ all_text_features.T
                logits_per_text = logits_per_image.T
        else:
            logits_per_image = logit_scale * image_features @ text_features.T
            logits_per_text = logit_scale * text_features @ image_features.T

        return logits_per_image, logits_per_text

    def forward(self, image_features, text_features, logit_scale, output_dict=False):
        """Compute the contrastive loss.

        Returns the scalar loss, or ``{"contrastive_loss": loss}`` when
        ``output_dict`` is true.
        """
        device = image_features.device
        logits_per_image, logits_per_text = self.get_logits(image_features, text_features, logit_scale)

        labels = self.get_ground_truth(device, logits_per_image.shape[0])

        total_loss = (
            F.cross_entropy(logits_per_image, labels) +
            F.cross_entropy(logits_per_text, labels)
        ) / 2

        return {"contrastive_loss": total_loss} if output_dict else total_loss
|
132 |
+
|
133 |
+
|
134 |
+
class CoCaLoss(ClipLoss):
    """Weighted pair of a contrastive loss and a captioning cross-entropy loss."""

    def __init__(
            self,
            caption_loss_weight,
            clip_loss_weight,
            pad_id=0,  # pad_token for open_clip custom tokenizer
            local_loss=False,
            gather_with_grad=False,
            cache_labels=False,
            rank=0,
            world_size=1,
            use_horovod=False,
    ):
        """Configure both loss terms.

        Args:
            caption_loss_weight: multiplier applied to the captioning loss.
            clip_loss_weight: multiplier applied to the contrastive loss; a
                falsy value skips the contrastive term entirely.
            pad_id: label value ignored by the captioning loss.
            local_loss, gather_with_grad, cache_labels, rank, world_size,
                use_horovod: passed through to ``ClipLoss``.
        """
        super().__init__(
            local_loss=local_loss,
            gather_with_grad=gather_with_grad,
            cache_labels=cache_labels,
            rank=rank,
            world_size=world_size,
            use_horovod=use_horovod
        )

        self.clip_loss_weight = clip_loss_weight
        self.caption_loss_weight = caption_loss_weight
        self.caption_loss = nn.CrossEntropyLoss(ignore_index=pad_id)

    def forward(self, image_features, text_features, logits, labels, logit_scale, output_dict=False):
        """Return ``(clip_loss, caption_loss)`` or a dict with both when ``output_dict`` is true."""
        if self.clip_loss_weight:
            clip_loss = super().forward(image_features, text_features, logit_scale)
            clip_loss = self.clip_loss_weight * clip_loss
        else:
            # Keep the placeholder on the same device as the real losses so
            # callers can combine or stack the two values without a mismatch.
            clip_loss = torch.tensor(0, device=logits.device)

        # CrossEntropyLoss reads classes from dim 1, hence the permute
        # (assumes logits is (batch, seq, vocab) -- TODO confirm against caller).
        caption_loss = self.caption_loss(
            logits.permute(0, 2, 1),
            labels,
        )
        caption_loss = caption_loss * self.caption_loss_weight

        if output_dict:
            return {"contrastive_loss": clip_loss, "caption_loss": caption_loss}

        return clip_loss, caption_loss
|
178 |
+
|
179 |
+
|
180 |
+
class DistillClipLoss(ClipLoss):
    """Contrastive loss plus a soft-target distillation term against teacher logits."""

    def dist_loss(self, teacher_logits, student_logits):
        """Cross-entropy of student log-probabilities under teacher probabilities, averaged over rows."""
        teacher_probs = teacher_logits.softmax(dim=1)
        student_log_probs = student_logits.log_softmax(dim=1)
        per_row = (teacher_probs * student_log_probs).sum(dim=1)
        return -per_row.mean(dim=0)

    def forward(
            self,
            image_features,
            text_features,
            logit_scale,
            dist_image_features,
            dist_text_features,
            dist_logit_scale,
            output_dict=False,
    ):
        """Return ``(contrastive_loss, distill_loss)`` or a dict with both when ``output_dict`` is true."""
        # Student logits first, then teacher logits (same call order as before,
        # which matters when get_logits gathers across workers).
        student_img, student_txt = self.get_logits(image_features, text_features, logit_scale)
        teacher_img, teacher_txt = self.get_logits(dist_image_features, dist_text_features, dist_logit_scale)

        targets = self.get_ground_truth(image_features.device, student_img.shape[0])

        image_ce = F.cross_entropy(student_img, targets)
        text_ce = F.cross_entropy(student_txt, targets)
        contrastive_loss = (image_ce + text_ce) / 2

        image_kd = self.dist_loss(teacher_img, student_img)
        text_kd = self.dist_loss(teacher_txt, student_txt)
        distill_loss = (image_kd + text_kd) / 2

        if not output_dict:
            return contrastive_loss, distill_loss
        return {"contrastive_loss": contrastive_loss, "distill_loss": distill_loss}
|
ext/open_clip/model.py
ADDED
@@ -0,0 +1,473 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
""" CLIP Model
|
2 |
+
|
3 |
+
Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
|
4 |
+
"""
|
5 |
+
from dataclasses import dataclass
|
6 |
+
import logging
|
7 |
+
import math
|
8 |
+
from typing import Optional, Tuple, Union
|
9 |
+
|
10 |
+
import numpy as np
|
11 |
+
import torch
|
12 |
+
import torch.nn.functional as F
|
13 |
+
from torch import nn
|
14 |
+
from torch.utils.checkpoint import checkpoint
|
15 |
+
|
16 |
+
from .hf_model import HFTextEncoder
|
17 |
+
from .modified_resnet import ModifiedResNet
|
18 |
+
from .timm_model import TimmModel
|
19 |
+
from .transformer import LayerNormFp32, LayerNorm, QuickGELU, Attention, VisionTransformer, TextTransformer
|
20 |
+
from .utils import to_2tuple
|
21 |
+
|
22 |
+
|
23 |
+
@dataclass
class CLIPVisionCfg:
    """Settings for the image tower built by ``_build_vision_tower``.

    Field order and defaults are part of the interface: the generated
    ``__init__`` accepts the fields positionally and configs are also created
    from dicts via ``CLIPVisionCfg(**cfg)``.
    """
    # int -> VisionTransformer depth; tuple/list -> ModifiedResNet stage sizes
    layers: Union[Tuple[int, int, int, int], int] = 12
    width: int = 768
    # number of heads is derived as width // head_width (width * 32 // head_width for the ResNet)
    head_width: int = 64
    mlp_ratio: float = 4.0
    patch_size: int = 16
    image_size: Union[Tuple[int, int], int] = 224

    # layer scale initial value
    ls_init_value: Optional[float] = None
    # what fraction of patches to dropout during training (0 would mean disabled and no patches
    # dropped) - 0.5 to 0.75 recommended in the paper for optimal results
    patch_dropout: float = 0.
    # whether to use dual patchnorm - would only apply the input layernorm on each patch, as
    # post-layernorm already exist in original clip vit design
    input_patchnorm: bool = False
    # whether to global average pool the last embedding layer, instead of using CLS token
    # (https://arxiv.org/abs/2205.01580)
    global_average_pool: bool = False
    # whether to use attentional pooler in the last embedding layer
    attentional_pool: bool = False
    # n_queries for attentional pooler
    n_queries: int = 256
    # n heads for attentional_pooling
    attn_pooler_heads: int = 8
    output_tokens: bool = False

    # a valid model name overrides layers, width, patch_size
    # (annotated Optional because the default is None)
    timm_model_name: Optional[str] = None
    # use (imagenet) pretrained weights for named model
    timm_model_pretrained: bool = False
    # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
    timm_pool: str = 'avg'
    # linear projection for timm model output ('linear', 'mlp', '')
    timm_proj: str = 'linear'
    # enable bias final projection
    timm_proj_bias: bool = False
    # head dropout
    timm_drop: float = 0.
    # backbone stochastic depth
    timm_drop_path: Optional[float] = None
|
48 |
+
|
49 |
+
|
50 |
+
@dataclass
class CLIPTextCfg:
    """Settings for the text tower built by ``_build_text_tower``.

    When ``hf_model_name`` is set an ``HFTextEncoder`` is built from the
    ``hf_*``, ``proj``, ``pooler_type`` and ``output_tokens`` fields;
    otherwise a ``TextTransformer`` is built from the remaining fields.
    Field order and defaults are part of the interface.
    """
    context_length: int = 77
    vocab_size: int = 49408
    width: int = 512
    heads: int = 8
    layers: int = 12
    # layer scale initial value
    ls_init_value: Optional[float] = None
    # Optional because the default is None (meaning: use the built-in TextTransformer)
    hf_model_name: Optional[str] = None
    # not read by the tower builder in this file -- presumably consumed by tokenizer setup elsewhere
    hf_tokenizer_name: Optional[str] = None
    hf_model_pretrained: bool = True
    proj: str = 'mlp'
    pooler_type: str = 'mean_pooler'
    embed_cls: bool = False
    pad_id: int = 0
    output_tokens: bool = False
|
66 |
+
|
67 |
+
|
68 |
+
def get_cast_dtype(precision: str):
    """Map a precision name to the dtype to cast to, or None when no cast applies.

    Only 'bf16' and 'fp16' are recognised; every other value yields None.
    """
    if precision == 'bf16':
        return torch.bfloat16
    if precision == 'fp16':
        return torch.float16
    return None
|
75 |
+
|
76 |
+
|
77 |
+
def get_input_dtype(precision: str):
    """Map a precision name to the dtype inputs should use, or None for no change.

    'bf16' / 'pure_bf16' give bfloat16, 'fp16' / 'pure_fp16' give float16.
    """
    if precision in ('bf16', 'pure_bf16'):
        return torch.bfloat16
    if precision in ('fp16', 'pure_fp16'):
        return torch.float16
    return None
|
84 |
+
|
85 |
+
|
86 |
+
def _build_vision_tower(
        embed_dim: int,
        vision_cfg: CLIPVisionCfg,
        quick_gelu: bool = False,
        cast_dtype: Optional[torch.dtype] = None
):
    """Instantiate the image tower described by ``vision_cfg``.

    Selection order: a timm model when ``timm_model_name`` is set, a
    ModifiedResNet when ``layers`` is a tuple/list, otherwise a
    VisionTransformer. ``vision_cfg`` may also be given as a plain dict.
    """
    if isinstance(vision_cfg, dict):
        vision_cfg = CLIPVisionCfg(**vision_cfg)

    if vision_cfg.timm_model_name:
        # NOTE: timm models always use native GELU regardless of quick_gelu flag.
        patch_drop = vision_cfg.patch_dropout if vision_cfg.patch_dropout > 0 else None
        return TimmModel(
            vision_cfg.timm_model_name,
            pretrained=vision_cfg.timm_model_pretrained,
            pool=vision_cfg.timm_pool,
            proj=vision_cfg.timm_proj,
            proj_bias=vision_cfg.timm_proj_bias,
            drop=vision_cfg.timm_drop,
            drop_path=vision_cfg.timm_drop_path,
            patch_drop=patch_drop,
            embed_dim=embed_dim,
            image_size=vision_cfg.image_size,
        )

    if isinstance(vision_cfg.layers, (tuple, list)):
        resnet_heads = vision_cfg.width * 32 // vision_cfg.head_width
        return ModifiedResNet(
            layers=vision_cfg.layers,
            output_dim=embed_dim,
            heads=resnet_heads,
            image_size=vision_cfg.image_size,
            width=vision_cfg.width,
        )

    # OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more
    # memory efficient in recent PyTorch releases (>= 1.10).
    activation = QuickGELU if quick_gelu else nn.GELU
    # Low-precision casts use the fp32 layer norm variant.
    low_precision = cast_dtype in (torch.float16, torch.bfloat16)
    normalization = LayerNormFp32 if low_precision else LayerNorm
    vit_heads = vision_cfg.width // vision_cfg.head_width
    return VisionTransformer(
        image_size=vision_cfg.image_size,
        patch_size=vision_cfg.patch_size,
        width=vision_cfg.width,
        layers=vision_cfg.layers,
        heads=vit_heads,
        mlp_ratio=vision_cfg.mlp_ratio,
        ls_init_value=vision_cfg.ls_init_value,
        patch_dropout=vision_cfg.patch_dropout,
        input_patchnorm=vision_cfg.input_patchnorm,
        global_average_pool=vision_cfg.global_average_pool,
        attentional_pool=vision_cfg.attentional_pool,
        n_queries=vision_cfg.n_queries,
        attn_pooler_heads=vision_cfg.attn_pooler_heads,
        output_tokens=vision_cfg.output_tokens,
        output_dim=embed_dim,
        act_layer=activation,
        norm_layer=normalization,
    )
|
146 |
+
|
147 |
+
|
148 |
+
def _build_text_tower(
        embed_dim: int,
        text_cfg: CLIPTextCfg,
        quick_gelu: bool = False,
        cast_dtype: Optional[torch.dtype] = None,
):
    """Instantiate the text tower described by ``text_cfg``.

    Builds an HFTextEncoder when ``hf_model_name`` is set, otherwise a
    TextTransformer. ``text_cfg`` may also be given as a plain dict.
    """
    if isinstance(text_cfg, dict):
        text_cfg = CLIPTextCfg(**text_cfg)

    if text_cfg.hf_model_name:
        return HFTextEncoder(
            text_cfg.hf_model_name,
            output_dim=embed_dim,
            proj=text_cfg.proj,
            pooler_type=text_cfg.pooler_type,
            pretrained=text_cfg.hf_model_pretrained,
            output_tokens=text_cfg.output_tokens,
        )

    activation = QuickGELU if quick_gelu else nn.GELU
    # Low-precision casts use the fp32 layer norm variant.
    low_precision = cast_dtype in (torch.float16, torch.bfloat16)
    normalization = LayerNormFp32 if low_precision else LayerNorm
    return TextTransformer(
        context_length=text_cfg.context_length,
        vocab_size=text_cfg.vocab_size,
        width=text_cfg.width,
        heads=text_cfg.heads,
        layers=text_cfg.layers,
        ls_init_value=text_cfg.ls_init_value,
        output_dim=embed_dim,
        embed_cls=text_cfg.embed_cls,
        output_tokens=text_cfg.output_tokens,
        pad_id=text_cfg.pad_id,
        act_layer=activation,
        norm_layer=normalization,
    )
|
185 |
+
|
186 |
+
|
187 |
+
class CLIP(nn.Module):
    """CLIP model whose text-tower submodules are attached directly to the model.

    The text tower is built once and its parts (transformer, embeddings,
    final norm, projection, attention mask) are re-registered on this module.
    """
    output_dict: torch.jit.Final[bool]

    def __init__(
            self,
            embed_dim: int,
            vision_cfg: CLIPVisionCfg,
            text_cfg: CLIPTextCfg,
            quick_gelu: bool = False,
            cast_dtype: Optional[torch.dtype] = None,
            output_dict: bool = False,
    ):
        super().__init__()
        self.output_dict = output_dict
        self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)

        # Lift the text tower's pieces onto this module (registration order is kept as-is).
        text_tower = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
        self.transformer = text_tower.transformer
        self.context_length = text_tower.context_length
        self.vocab_size = text_tower.vocab_size
        self.token_embedding = text_tower.token_embedding
        self.positional_embedding = text_tower.positional_embedding
        self.ln_final = text_tower.ln_final
        self.text_projection = text_tower.text_projection
        self.register_buffer('attn_mask', text_tower.attn_mask, persistent=False)

        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

    def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
        """Delegate to the image tower's ``lock`` (LiT-style locking, https://arxiv.org/abs/2111.07991)."""
        self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        """Switch gradient checkpointing for the image tower and the text transformer."""
        self.visual.set_grad_checkpointing(enable)
        self.transformer.grad_checkpointing = enable

    def encode_image(self, image, normalize: bool = False):
        """Run the image tower; optionally L2-normalise along the last dim."""
        features = self.visual(image)
        if normalize:
            return F.normalize(features, dim=-1)
        return features

    def encode_text(self, text, normalize: bool = False):
        """Embed token ids, run the text transformer and project the EOT position."""
        cast_dtype = self.transformer.get_cast_dtype()

        hidden = self.token_embedding(text).to(cast_dtype)  # [batch_size, n_ctx, d_model]
        hidden = hidden + self.positional_embedding.to(cast_dtype)

        hidden = hidden.permute(1, 0, 2)  # NLD -> LND
        hidden = self.transformer(hidden, attn_mask=self.attn_mask)
        hidden = hidden.permute(1, 0, 2)  # LND -> NLD
        hidden = self.ln_final(hidden)  # [batch_size, n_ctx, transformer.width]

        # take features from the eot embedding (eot_token is the highest number in each sequence)
        batch_index = torch.arange(hidden.shape[0])
        eot_index = text.argmax(dim=-1)
        features = hidden[batch_index, eot_index] @ self.text_projection
        if normalize:
            return F.normalize(features, dim=-1)
        return features

    def forward(
            self,
            image: Optional[torch.Tensor] = None,
            text: Optional[torch.Tensor] = None,
    ):
        """Encode whichever inputs are given (normalised) and return them with exp(logit_scale).

        Returns a dict when ``output_dict`` is set, otherwise a 3-tuple; the
        feature entry of a missing input is None.
        """
        image_features = self.encode_image(image, normalize=True) if image is not None else None
        text_features = self.encode_text(text, normalize=True) if text is not None else None
        if not self.output_dict:
            return image_features, text_features, self.logit_scale.exp()
        return {
            "image_features": image_features,
            "text_features": text_features,
            "logit_scale": self.logit_scale.exp()
        }
|
256 |
+
|
257 |
+
|
258 |
+
class CustomTextCLIP(nn.Module):
    """CLIP model that keeps the text tower as a single ``self.text`` submodule."""
    output_dict: torch.jit.Final[bool]

    def __init__(
            self,
            embed_dim: int,
            vision_cfg: CLIPVisionCfg,
            text_cfg: CLIPTextCfg,
            quick_gelu: bool = False,
            cast_dtype: Optional[torch.dtype] = None,
            output_dict: bool = False,
    ):
        super().__init__()
        self.output_dict = output_dict
        self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
        self.text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
        # Mirror the text tower's sizes on the model itself.
        self.context_length = self.text.context_length
        self.vocab_size = self.text.vocab_size
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

    def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
        """Delegate to the image tower's ``lock`` (LiT-style locking, https://arxiv.org/abs/2111.07991)."""
        self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)

    def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
        """Delegate to the text tower's ``lock``."""
        self.text.lock(unlocked_layers, freeze_layer_norm)

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        """Forward the gradient-checkpointing switch to both towers."""
        self.visual.set_grad_checkpointing(enable)
        self.text.set_grad_checkpointing(enable)

    def encode_image(self, image, normalize: bool = False):
        """Run the image tower; optionally L2-normalise along the last dim."""
        features = self.visual(image)
        if normalize:
            return F.normalize(features, dim=-1)
        return features

    def encode_text(self, text, normalize: bool = False):
        """Run the text tower; optionally L2-normalise along the last dim."""
        features = self.text(text)
        if normalize:
            return F.normalize(features, dim=-1)
        return features

    def forward(
            self,
            image: Optional[torch.Tensor] = None,
            text: Optional[torch.Tensor] = None,
    ):
        """Encode whichever inputs are given (normalised) and return them with exp(logit_scale).

        Returns a dict when ``output_dict`` is set, otherwise a 3-tuple; the
        feature entry of a missing input is None.
        """
        image_features = self.encode_image(image, normalize=True) if image is not None else None
        text_features = self.encode_text(text, normalize=True) if text is not None else None
        if not self.output_dict:
            return image_features, text_features, self.logit_scale.exp()
        return {
            "image_features": image_features,
            "text_features": text_features,
            "logit_scale": self.logit_scale.exp()
        }
|
312 |
+
|
313 |
+
|
314 |
+
def convert_weights_to_lp(model: nn.Module, dtype=torch.float16):
    """Convert applicable model parameters to low-precision (bf16 or fp16).

    Applied in place through ``model.apply``. The following are cast to
    ``dtype``: weights and biases of Conv1d / Conv2d / Linear, the projection
    tensors of attention modules, ``text_projection`` on CLIP /
    TextTransformer, and ``proj`` on VisionTransformer.
    """
    attention_tensor_names = (
        *(f"{s}_proj_weight" for s in ("in", "q", "k", "v")),
        "in_proj_bias",
        "bias_k",
        "bias_v",
    )

    def _convert_weights(module):
        if isinstance(module, (nn.Conv1d, nn.Conv2d, nn.Linear)):
            module.weight.data = module.weight.data.to(dtype)
            if module.bias is not None:
                module.bias.data = module.bias.data.to(dtype)

        if isinstance(module, (nn.MultiheadAttention, Attention)):
            for attr in attention_tensor_names:
                # Default to None: an attention implementation is not
                # guaranteed to define every one of these names, and
                # nn.Module raises AttributeError for a missing attribute.
                tensor = getattr(module, attr, None)
                if tensor is not None:
                    tensor.data = tensor.data.to(dtype)

        if isinstance(module, (CLIP, TextTransformer)):
            # convert text nn.Parameter projections
            attr = getattr(module, "text_projection", None)
            if attr is not None:
                attr.data = attr.data.to(dtype)

        if isinstance(module, VisionTransformer):
            # convert vision nn.Parameter projections
            attr = getattr(module, "proj", None)
            if attr is not None:
                attr.data = attr.data.to(dtype)

    model.apply(_convert_weights)


convert_weights_to_fp16 = convert_weights_to_lp  # backwards compat
|
345 |
+
|
346 |
+
|
347 |
+
# used to maintain checkpoint compatibility
|
348 |
+
def convert_to_custom_text_state_dict(state_dict: dict):
    """Re-key an old-format state dict so text-tower entries live under ``text.``.

    A state dict without a top-level 'text_projection' key is returned
    unchanged (the same object). Otherwise a new dict is returned in which
    every key starting with one of the text-tower prefixes gets 'text.'
    prepended; all other keys are copied as they are.
    """
    if 'text_projection' not in state_dict:
        return state_dict

    # old format state_dict, move text tower -> .text
    text_prefixes = (
        'text_projection',
        'positional_embedding',
        'token_embedding',
        'transformer',
        'ln_final',
    )
    return {
        ('text.' + key if key.startswith(text_prefixes) else key): value
        for key, value in state_dict.items()
    }
|
364 |
+
|
365 |
+
|
366 |
+
def build_model_from_openai_state_dict(
        state_dict: dict,
        quick_gelu=True,
        cast_dtype=torch.float16,
):
    """Build a ``CLIP`` model whose configuration is inferred from ``state_dict``.

    The vision and text hyper-parameters are derived from tensor shapes and
    key names found in ``state_dict``. The presence of ``"visual.proj"``
    selects the ViT branch; otherwise the ``visual.layer{1..4}`` /
    ``visual.attnpool`` keys are used.

    Side effect: the keys ``"input_resolution"``, ``"context_length"`` and
    ``"vocab_size"`` are removed from the caller's ``state_dict`` (if present)
    before the weights are loaded.

    Args:
        state_dict: Mapping from parameter name to tensor.
        quick_gelu: Forwarded to ``CLIP``.
        cast_dtype: Forwarded to ``CLIP``.

    Returns:
        The model with ``state_dict`` loaded, switched to eval mode.
    """
    if "visual.proj" in state_dict:
        patch_embed_weight = state_dict["visual.conv1.weight"]
        vision_width = patch_embed_weight.shape[0]
        vision_layers = sum(
            1 for name in state_dict
            if name.startswith("visual.") and name.endswith(".attn.in_proj_weight")
        )
        vision_patch_size = patch_embed_weight.shape[-1]
        num_positions = state_dict["visual.positional_embedding"].shape[0]
        # One position is subtracted before taking the square root to get the grid side.
        image_size = vision_patch_size * round((num_positions - 1) ** 0.5)
    else:
        # Number of distinct block indices under each of visual.layer1..layer4.
        vision_layers = tuple(
            len({name.split(".")[2] for name in state_dict if name.startswith(f"visual.layer{stage}")})
            for stage in (1, 2, 3, 4)
        )
        vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
        num_positions = state_dict["visual.attnpool.positional_embedding"].shape[0]
        output_width = round((num_positions - 1) ** 0.5)
        vision_patch_size = None
        assert output_width ** 2 + 1 == num_positions
        image_size = output_width * 32

    embed_dim = state_dict["text_projection"].shape[1]
    context_length = state_dict["positional_embedding"].shape[0]
    vocab_size = state_dict["token_embedding.weight"].shape[0]
    transformer_width = state_dict["ln_final.weight"].shape[0]
    transformer_heads = transformer_width // 64
    transformer_layers = len(
        {name.split(".")[2] for name in state_dict if name.startswith("transformer.resblocks")}
    )

    model = CLIP(
        embed_dim,
        vision_cfg=CLIPVisionCfg(
            layers=vision_layers,
            width=vision_width,
            patch_size=vision_patch_size,
            image_size=image_size,
        ),
        text_cfg=CLIPTextCfg(
            context_length=context_length,
            vocab_size=vocab_size,
            width=transformer_width,
            heads=transformer_heads,
            layers=transformer_layers,
        ),
        quick_gelu=quick_gelu,  # OpenAI models were trained with QuickGELU
        cast_dtype=cast_dtype,
    )

    # Drop the non-parameter entries before loading.
    for meta_key in ("input_resolution", "context_length", "vocab_size"):
        state_dict.pop(meta_key, None)

    convert_weights_to_fp16(model)  # OpenAI state dicts are partially converted to float16
    model.load_state_dict(state_dict)
    return model.eval()
|
424 |
+
|
425 |
+
|
426 |
+
def trace_model(model, batch_size=256, device=torch.device('cpu')):
    """Trace ``forward``, ``encode_text`` and ``encode_image`` of ``model``.

    The model is put in eval mode and traced with ``torch.jit.trace_module``
    using an all-ones image batch of shape
    ``(batch_size, 3, image_size, image_size)`` and an all-zeros integer token
    batch of shape ``(batch_size, model.context_length)``, both on ``device``.

    Returns:
        The traced module, with ``visual.image_size`` set to the value read
        from the original model.
    """
    model.eval()
    resolution = model.visual.image_size

    dummy_images = torch.ones((batch_size, 3, resolution, resolution), device=device)
    dummy_tokens = torch.zeros((batch_size, model.context_length), dtype=torch.int, device=device)
    trace_inputs = {
        "forward": (dummy_images, dummy_tokens),
        "encode_text": (dummy_tokens,),
        "encode_image": (dummy_images,),
    }

    model = torch.jit.trace_module(model, inputs=trace_inputs)
    # Re-attach the resolution on the traced visual tower.
    model.visual.image_size = resolution
    return model
|
440 |
+
|
441 |
+
|
442 |
+
def resize_pos_embed(state_dict, model, interpolation: str = 'bicubic', antialias: bool = True):
    """Rescale the grid of position embeddings when loading from state_dict.

    Reads ``state_dict['visual.positional_embedding']`` and, if its length
    differs from what ``model.visual.grid_size`` requires, interpolates the
    grid part to the new size and writes the result back into ``state_dict``.
    Nothing happens when the key is missing, when ``model.visual`` has no
    ``grid_size`` attribute, or when the length already matches.

    Args:
        state_dict: Checkpoint mapping; modified in place.
        model: Model whose ``visual.grid_size`` defines the target grid.
        interpolation: ``mode`` passed to ``F.interpolate``.
        antialias: ``antialias`` passed to ``F.interpolate``.
    """
    old_pos_embed = state_dict.get('visual.positional_embedding', None)
    if old_pos_embed is None:
        return
    if not hasattr(model.visual, 'grid_size'):
        return

    grid_size = to_2tuple(model.visual.grid_size)
    extra_tokens = 1  # FIXME detect different token configs (ie no class token, or more)
    target_cells = grid_size[0] * grid_size[1]
    if target_cells + extra_tokens == old_pos_embed.shape[0]:
        return

    # Separate the leading extra token(s) from the per-cell embeddings.
    if extra_tokens:
        token_part = old_pos_embed[:extra_tokens]
        grid_part = old_pos_embed[extra_tokens:]
    else:
        token_part = None
        grid_part = old_pos_embed
    # The old grid is taken to be square: side = int(sqrt(number of cells)).
    old_grid_size = to_2tuple(int(math.sqrt(len(grid_part))))

    logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size)

    # (cells, dim) -> (1, dim, H, W), the layout F.interpolate is called with here.
    grid_part = grid_part.reshape(1, old_grid_size[0], old_grid_size[1], -1)
    grid_part = grid_part.permute(0, 3, 1, 2)
    grid_part = F.interpolate(
        grid_part,
        size=grid_size,
        mode=interpolation,
        antialias=antialias,
        align_corners=False,
    )
    # (1, dim, H, W) -> (H * W, dim)
    grid_part = grid_part.permute(0, 2, 3, 1).reshape(1, target_cells, -1)[0]

    if token_part is None:
        resized = grid_part
    else:
        resized = torch.cat([token_part, grid_part], dim=0)
    state_dict['visual.positional_embedding'] = resized
|
ext/open_clip/model_configs/EVA01-g-14-plus.json
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 1024,
|
3 |
+
"vision_cfg": {
|
4 |
+
"image_size": 224,
|
5 |
+
"timm_model_name": "eva_giant_patch14_224",
|
6 |
+
"timm_model_pretrained": false,
|
7 |
+
"timm_pool": "token",
|
8 |
+
"timm_proj": null
|
9 |
+
},
|
10 |
+
"text_cfg": {
|
11 |
+
"context_length": 77,
|
12 |
+
"vocab_size": 49408,
|
13 |
+
"width": 1024,
|
14 |
+
"heads": 16,
|
15 |
+
"layers": 24
|
16 |
+
},
|
17 |
+
"custom_text": true
|
18 |
+
}
|
ext/open_clip/model_configs/EVA01-g-14.json
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 1024,
|
3 |
+
"vision_cfg": {
|
4 |
+
"image_size": 224,
|
5 |
+
"timm_model_name": "eva_giant_patch14_224",
|
6 |
+
"timm_model_pretrained": false,
|
7 |
+
"timm_pool": "token",
|
8 |
+
"timm_proj": null
|
9 |
+
},
|
10 |
+
"text_cfg": {
|
11 |
+
"context_length": 77,
|
12 |
+
"vocab_size": 49408,
|
13 |
+
"width": 768,
|
14 |
+
"heads": 12,
|
15 |
+
"layers": 12
|
16 |
+
},
|
17 |
+
"custom_text": true
|
18 |
+
}
|
ext/open_clip/model_configs/EVA02-B-16.json
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 512,
|
3 |
+
"vision_cfg": {
|
4 |
+
"image_size": 224,
|
5 |
+
"timm_model_name": "eva02_base_patch16_clip_224",
|
6 |
+
"timm_model_pretrained": false,
|
7 |
+
"timm_pool": "token",
|
8 |
+
"timm_proj": null
|
9 |
+
},
|
10 |
+
"text_cfg": {
|
11 |
+
"context_length": 77,
|
12 |
+
"vocab_size": 49408,
|
13 |
+
"width": 512,
|
14 |
+
"heads": 8,
|
15 |
+
"layers": 12
|
16 |
+
},
|
17 |
+
"custom_text": true
|
18 |
+
}
|
ext/open_clip/model_configs/EVA02-E-14-plus.json
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 1024,
|
3 |
+
"vision_cfg": {
|
4 |
+
"image_size": 224,
|
5 |
+
"timm_model_name": "eva02_enormous_patch14_clip_224",
|
6 |
+
"timm_model_pretrained": false,
|
7 |
+
"timm_pool": "token",
|
8 |
+
"timm_proj": null
|
9 |
+
},
|
10 |
+
"text_cfg": {
|
11 |
+
"context_length": 77,
|
12 |
+
"vocab_size": 49408,
|
13 |
+
"width": 1280,
|
14 |
+
"heads": 20,
|
15 |
+
"layers": 32
|
16 |
+
},
|
17 |
+
"custom_text": true
|
18 |
+
}
|
ext/open_clip/model_configs/EVA02-E-14.json
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 1024,
|
3 |
+
"vision_cfg": {
|
4 |
+
"image_size": 224,
|
5 |
+
"timm_model_name": "eva02_enormous_patch14_clip_224",
|
6 |
+
"timm_model_pretrained": false,
|
7 |
+
"timm_pool": "token",
|
8 |
+
"timm_proj": null
|
9 |
+
},
|
10 |
+
"text_cfg": {
|
11 |
+
"context_length": 77,
|
12 |
+
"vocab_size": 49408,
|
13 |
+
"width": 1024,
|
14 |
+
"heads": 16,
|
15 |
+
"layers": 24
|
16 |
+
},
|
17 |
+
"custom_text": true
|
18 |
+
}
|
ext/open_clip/model_configs/EVA02-L-14-336.json
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 768,
|
3 |
+
"vision_cfg": {
|
4 |
+
"image_size": 336,
|
5 |
+
"timm_model_name": "eva02_large_patch14_clip_336",
|
6 |
+
"timm_model_pretrained": false,
|
7 |
+
"timm_pool": "token",
|
8 |
+
"timm_proj": null
|
9 |
+
},
|
10 |
+
"text_cfg": {
|
11 |
+
"context_length": 77,
|
12 |
+
"vocab_size": 49408,
|
13 |
+
"width": 768,
|
14 |
+
"heads": 12,
|
15 |
+
"layers": 12
|
16 |
+
},
|
17 |
+
"custom_text": true
|
18 |
+
}
|
ext/open_clip/model_configs/EVA02-L-14.json
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 768,
|
3 |
+
"vision_cfg": {
|
4 |
+
"image_size": 224,
|
5 |
+
"timm_model_name": "eva02_large_patch14_clip_224",
|
6 |
+
"timm_model_pretrained": false,
|
7 |
+
"timm_pool": "token",
|
8 |
+
"timm_proj": null
|
9 |
+
},
|
10 |
+
"text_cfg": {
|
11 |
+
"context_length": 77,
|
12 |
+
"vocab_size": 49408,
|
13 |
+
"width": 768,
|
14 |
+
"heads": 12,
|
15 |
+
"layers": 12
|
16 |
+
},
|
17 |
+
"custom_text": true
|
18 |
+
}
|
ext/open_clip/model_configs/RN101-quickgelu.json
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 512,
|
3 |
+
"quick_gelu": true,
|
4 |
+
"vision_cfg": {
|
5 |
+
"image_size": 224,
|
6 |
+
"layers": [
|
7 |
+
3,
|
8 |
+
4,
|
9 |
+
23,
|
10 |
+
3
|
11 |
+
],
|
12 |
+
"width": 64,
|
13 |
+
"patch_size": null
|
14 |
+
},
|
15 |
+
"text_cfg": {
|
16 |
+
"context_length": 77,
|
17 |
+
"vocab_size": 49408,
|
18 |
+
"width": 512,
|
19 |
+
"heads": 8,
|
20 |
+
"layers": 12
|
21 |
+
}
|
22 |
+
}
|
ext/open_clip/model_configs/RN101.json
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 512,
|
3 |
+
"vision_cfg": {
|
4 |
+
"image_size": 224,
|
5 |
+
"layers": [
|
6 |
+
3,
|
7 |
+
4,
|
8 |
+
23,
|
9 |
+
3
|
10 |
+
],
|
11 |
+
"width": 64,
|
12 |
+
"patch_size": null
|
13 |
+
},
|
14 |
+
"text_cfg": {
|
15 |
+
"context_length": 77,
|
16 |
+
"vocab_size": 49408,
|
17 |
+
"width": 512,
|
18 |
+
"heads": 8,
|
19 |
+
"layers": 12
|
20 |
+
}
|
21 |
+
}
|
ext/open_clip/model_configs/RN50-quickgelu.json
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 1024,
|
3 |
+
"quick_gelu": true,
|
4 |
+
"vision_cfg": {
|
5 |
+
"image_size": 224,
|
6 |
+
"layers": [
|
7 |
+
3,
|
8 |
+
4,
|
9 |
+
6,
|
10 |
+
3
|
11 |
+
],
|
12 |
+
"width": 64,
|
13 |
+
"patch_size": null
|
14 |
+
},
|
15 |
+
"text_cfg": {
|
16 |
+
"context_length": 77,
|
17 |
+
"vocab_size": 49408,
|
18 |
+
"width": 512,
|
19 |
+
"heads": 8,
|
20 |
+
"layers": 12
|
21 |
+
}
|
22 |
+
}
|
ext/open_clip/model_configs/RN50.json
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 1024,
|
3 |
+
"vision_cfg": {
|
4 |
+
"image_size": 224,
|
5 |
+
"layers": [
|
6 |
+
3,
|
7 |
+
4,
|
8 |
+
6,
|
9 |
+
3
|
10 |
+
],
|
11 |
+
"width": 64,
|
12 |
+
"patch_size": null
|
13 |
+
},
|
14 |
+
"text_cfg": {
|
15 |
+
"context_length": 77,
|
16 |
+
"vocab_size": 49408,
|
17 |
+
"width": 512,
|
18 |
+
"heads": 8,
|
19 |
+
"layers": 12
|
20 |
+
}
|
21 |
+
}
|
ext/open_clip/model_configs/RN50x16.json
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 768,
|
3 |
+
"vision_cfg": {
|
4 |
+
"image_size": 384,
|
5 |
+
"layers": [
|
6 |
+
6,
|
7 |
+
8,
|
8 |
+
18,
|
9 |
+
8
|
10 |
+
],
|
11 |
+
"width": 96,
|
12 |
+
"patch_size": null
|
13 |
+
},
|
14 |
+
"text_cfg": {
|
15 |
+
"context_length": 77,
|
16 |
+
"vocab_size": 49408,
|
17 |
+
"width": 768,
|
18 |
+
"heads": 12,
|
19 |
+
"layers": 12
|
20 |
+
}
|
21 |
+
}
|
ext/open_clip/model_configs/RN50x4.json
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 640,
|
3 |
+
"vision_cfg": {
|
4 |
+
"image_size": 288,
|
5 |
+
"layers": [
|
6 |
+
4,
|
7 |
+
6,
|
8 |
+
10,
|
9 |
+
6
|
10 |
+
],
|
11 |
+
"width": 80,
|
12 |
+
"patch_size": null
|
13 |
+
},
|
14 |
+
"text_cfg": {
|
15 |
+
"context_length": 77,
|
16 |
+
"vocab_size": 49408,
|
17 |
+
"width": 640,
|
18 |
+
"heads": 10,
|
19 |
+
"layers": 12
|
20 |
+
}
|
21 |
+
}
|
ext/open_clip/model_configs/RN50x64.json
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 1024,
|
3 |
+
"vision_cfg": {
|
4 |
+
"image_size": 448,
|
5 |
+
"layers": [
|
6 |
+
3,
|
7 |
+
15,
|
8 |
+
36,
|
9 |
+
10
|
10 |
+
],
|
11 |
+
"width": 128,
|
12 |
+
"patch_size": null
|
13 |
+
},
|
14 |
+
"text_cfg": {
|
15 |
+
"context_length": 77,
|
16 |
+
"vocab_size": 49408,
|
17 |
+
"width": 1024,
|
18 |
+
"heads": 16,
|
19 |
+
"layers": 12
|
20 |
+
}
|
21 |
+
}
|
ext/open_clip/model_configs/ViT-B-16-plus-240.json
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 640,
|
3 |
+
"vision_cfg": {
|
4 |
+
"image_size": 240,
|
5 |
+
"layers": 12,
|
6 |
+
"width": 896,
|
7 |
+
"patch_size": 16
|
8 |
+
},
|
9 |
+
"text_cfg": {
|
10 |
+
"context_length": 77,
|
11 |
+
"vocab_size": 49408,
|
12 |
+
"width": 640,
|
13 |
+
"heads": 10,
|
14 |
+
"layers": 12
|
15 |
+
}
|
16 |
+
}
|
ext/open_clip/model_configs/ViT-B-16-plus.json
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 640,
|
3 |
+
"vision_cfg": {
|
4 |
+
"image_size": 224,
|
5 |
+
"layers": 12,
|
6 |
+
"width": 896,
|
7 |
+
"patch_size": 16
|
8 |
+
},
|
9 |
+
"text_cfg": {
|
10 |
+
"context_length": 77,
|
11 |
+
"vocab_size": 49408,
|
12 |
+
"width": 640,
|
13 |
+
"heads": 10,
|
14 |
+
"layers": 12
|
15 |
+
}
|
16 |
+
}
|
ext/open_clip/model_configs/ViT-B-16.json
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 512,
|
3 |
+
"vision_cfg": {
|
4 |
+
"image_size": 224,
|
5 |
+
"layers": 12,
|
6 |
+
"width": 768,
|
7 |
+
"patch_size": 16
|
8 |
+
},
|
9 |
+
"text_cfg": {
|
10 |
+
"context_length": 77,
|
11 |
+
"vocab_size": 49408,
|
12 |
+
"width": 512,
|
13 |
+
"heads": 8,
|
14 |
+
"layers": 12
|
15 |
+
}
|
16 |
+
}
|
ext/open_clip/model_configs/ViT-B-32-plus-256.json
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 640,
|
3 |
+
"vision_cfg": {
|
4 |
+
"image_size": 256,
|
5 |
+
"layers": 12,
|
6 |
+
"width": 896,
|
7 |
+
"patch_size": 32
|
8 |
+
},
|
9 |
+
"text_cfg": {
|
10 |
+
"context_length": 77,
|
11 |
+
"vocab_size": 49408,
|
12 |
+
"width": 640,
|
13 |
+
"heads": 10,
|
14 |
+
"layers": 12
|
15 |
+
}
|
16 |
+
}
|