Spaces:
Runtime error
Runtime error
jwyang
commited on
Commit
·
4121bec
1
Parent(s):
520e34b
first commit
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- app.py +123 -0
- config.py +245 -0
- configs/Base-RCNN-C4.yaml +18 -0
- configs/Base-RCNN-FPN.yaml +42 -0
- configs/CLIP_fast_rcnn_R_50_C4.yaml +71 -0
- configs/CLIP_fast_rcnn_swin_base_C4.yaml +74 -0
- configs/mask_rcnn_CLIP_R_50_C4_1x.yaml +55 -0
- configs/mask_rcnn_R_50_C4_1x.yaml +23 -0
- configs/mask_rcnn_R_50_FPN_1x.yaml +23 -0
- datasets/README.md +140 -0
- datasets/custom_images/dog_and_cat.jfif +0 -0
- datasets/prepare_ade20k_sem_seg.py +26 -0
- datasets/prepare_cocofied_lvis.py +176 -0
- datasets/prepare_for_tests.sh +22 -0
- datasets/prepare_panoptic_fpn.py +116 -0
- detectron2/__init__.py +10 -0
- detectron2/__pycache__/__init__.cpython-39.pyc +0 -0
- detectron2/checkpoint/__init__.py +10 -0
- detectron2/checkpoint/__pycache__/__init__.cpython-39.pyc +0 -0
- detectron2/checkpoint/__pycache__/c2_model_loading.cpython-39.pyc +0 -0
- detectron2/checkpoint/__pycache__/catalog.cpython-39.pyc +0 -0
- detectron2/checkpoint/__pycache__/clip_model_loading.cpython-39.pyc +0 -0
- detectron2/checkpoint/__pycache__/detection_checkpoint.cpython-39.pyc +0 -0
- detectron2/checkpoint/c2_model_loading.py +407 -0
- detectron2/checkpoint/catalog.py +115 -0
- detectron2/checkpoint/clip_model_loading.py +415 -0
- detectron2/checkpoint/detection_checkpoint.py +134 -0
- detectron2/config/__init__.py +24 -0
- detectron2/config/__pycache__/__init__.cpython-39.pyc +0 -0
- detectron2/config/__pycache__/compat.cpython-39.pyc +0 -0
- detectron2/config/__pycache__/config.cpython-39.pyc +0 -0
- detectron2/config/__pycache__/defaults.cpython-39.pyc +0 -0
- detectron2/config/__pycache__/instantiate.cpython-39.pyc +0 -0
- detectron2/config/__pycache__/lazy.cpython-39.pyc +0 -0
- detectron2/config/compat.py +229 -0
- detectron2/config/config.py +249 -0
- detectron2/config/defaults.py +786 -0
- detectron2/config/instantiate.py +82 -0
- detectron2/config/lazy.py +370 -0
- detectron2/data/__init__.py +19 -0
- detectron2/data/__pycache__/__init__.cpython-39.pyc +0 -0
- detectron2/data/__pycache__/build.cpython-39.pyc +0 -0
- detectron2/data/__pycache__/catalog.cpython-39.pyc +0 -0
- detectron2/data/__pycache__/clip_build.cpython-39.pyc +0 -0
- detectron2/data/__pycache__/common.cpython-39.pyc +0 -0
- detectron2/data/__pycache__/dataset_mapper.cpython-39.pyc +0 -0
- detectron2/data/__pycache__/detection_utils.cpython-39.pyc +0 -0
- detectron2/data/build.py +536 -0
- detectron2/data/catalog.py +236 -0
- detectron2/data/clip_build.py +158 -0
app.py
ADDED
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import requests
|
3 |
+
import logging
|
4 |
+
import os
|
5 |
+
import gradio as gr
|
6 |
+
import numpy as np
|
7 |
+
import cv2
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
from PIL import Image
|
11 |
+
from torchvision import transforms
|
12 |
+
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
13 |
+
from timm.data import create_transform
|
14 |
+
from config import get_config
|
15 |
+
|
16 |
+
from collections import OrderedDict
|
17 |
+
|
18 |
+
import detectron2.utils.comm as comm
|
19 |
+
from detectron2.checkpoint import DetectionCheckpointer
|
20 |
+
from detectron2.config import get_cfg
|
21 |
+
from detectron2.data import MetadataCatalog
|
22 |
+
from detectron2.engine import DefaultTrainer as Trainer
|
23 |
+
from detectron2.engine import default_argument_parser, default_setup, hooks, launch
|
24 |
+
from detectron2.evaluation import (
|
25 |
+
CityscapesInstanceEvaluator,
|
26 |
+
CityscapesSemSegEvaluator,
|
27 |
+
COCOEvaluator,
|
28 |
+
COCOPanopticEvaluator,
|
29 |
+
DatasetEvaluators,
|
30 |
+
LVISEvaluator,
|
31 |
+
PascalVOCDetectionEvaluator,
|
32 |
+
SemSegEvaluator,
|
33 |
+
verify_results,
|
34 |
+
FLICKR30KEvaluator,
|
35 |
+
)
|
36 |
+
from detectron2.modeling import GeneralizedRCNNWithTTA
|
37 |
+
|
38 |
+
def parse_option(argv=None):
    """Parse the command-line options of the RegionCLIP demo script.

    Arguments the parser does not know are tolerated and discarded
    (``parse_known_args``), so extra flags do not abort start-up.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) ``sys.argv[1:]`` is used, which is the behaviour
            callers relied on before this parameter existed.

    Returns:
        argparse.Namespace: namespace with a ``config_file`` attribute.
    """
    parser = argparse.ArgumentParser('RegionCLIP demo script', add_help=False)
    parser.add_argument(
        '--config-file',
        type=str,
        default="configs/CLIP_fast_rcnn_R_50_C4.yaml",
        metavar="FILE",
        help='path to config file',
    )
    # Unrecognised arguments are returned separately; they are not needed here.
    args, _ = parser.parse_known_args(argv)

    return args
|
44 |
+
|
45 |
+
def build_transforms(img_size, center_crop=True):
    """Compose the image preprocessing pipeline.

    Args:
        img_size: Target size handed to ``Resize`` / ``CenterCrop``.
        center_crop: When true, resize to ``int((256 / 224) * img_size)``
            and then centre-crop to ``img_size``; otherwise only resize
            to ``img_size``.

    Returns:
        A ``transforms.Compose`` whose last step is ``ToTensor``.
    """
    if center_crop:
        # Resize slightly larger than the target, then crop the centre.
        enlarged = int((256 / 224) * img_size)
        steps = [
            transforms.Resize(enlarged),
            transforms.CenterCrop(img_size),
        ]
    else:
        steps = [transforms.Resize(img_size)]
    steps.append(transforms.ToTensor())
    return transforms.Compose(steps)
|
61 |
+
|
62 |
+
def setup(args):
    """Build the frozen config used by the demo.

    Starts from ``get_cfg()``, merges the YAML file named by
    ``args.config_file``, freezes the result and then passes it together
    with ``args`` to ``default_setup``.

    Args:
        args: Parsed options; must provide ``config_file``.

    Returns:
        The frozen config object.
    """
    config = get_cfg()
    config.merge_from_file(args.config_file)
    # Freeze before default_setup so later code cannot modify the config.
    config.freeze()
    default_setup(config, args)
    return config
|
71 |
+
|
72 |
+
'''
build model
'''
# Everything below runs at import time: parse options, build the config,
# instantiate the model and load its weights.
args = parse_option()
cfg = setup(args)

model = Trainer.build_model(cfg)
DetectionCheckpointer(
    model,
    save_dir=cfg.OUTPUT_DIR,
).resume_or_load(cfg.MODEL.WEIGHTS, resume=False)

# A second checkpoint (backbone/RPN weights) is loaded only for the CLIP
# architectures, and only when one is configured and regions come from an RPN.
if (
    cfg.MODEL.META_ARCHITECTURE in ['CLIPRCNN', 'CLIPFastRCNN', 'PretrainFastRCNN']
    and cfg.MODEL.CLIP.BB_RPN_WEIGHTS is not None
    and cfg.MODEL.CLIP.CROP_REGION_TYPE == 'RPN'
):
    DetectionCheckpointer(
        model,
        save_dir=cfg.OUTPUT_DIR,
        bb_rpn_weights=True,
    ).resume_or_load(cfg.MODEL.CLIP.BB_RPN_WEIGHTS, resume=False)

'''
build data transform
'''
# Inference pipeline: resize to 800 without centre cropping.
eval_transforms = build_transforms(800, center_crop=False)
|
94 |
+
|
95 |
+
def localize_object(image, texts):
    """Run the model on one image together with a text query.

    Args:
        image: Image array supplied by the Gradio "image" input; it is
            converted with ``PIL.Image.fromarray``.
        texts: Value of the Gradio "text" input, forwarded unchanged as the
            first positional argument of the model.

    Returns:
        Whatever ``model(texts, [{"image": tensor}])`` returns.

    Raises:
        ValueError: If ``image`` is ``None`` (no image was submitted).
    """
    if image is None:
        # Fail with a clear message instead of an opaque error inside PIL.
        raise ValueError("localize_object requires an input image, got None")

    log = logging.getLogger(__name__)
    log.debug("query texts: %s", texts)

    # The transform pipeline ends in ToTensor; the result is scaled by 255
    # before it is handed to the model.
    # NOTE(review): presumably the model expects 0-255 pixel values - confirm.
    img_t = eval_transforms(Image.fromarray(image).convert("RGB")) * 255
    log.debug("input tensor shape: %s", tuple(img_t.shape))

    model.eval()
    with torch.no_grad():
        res = model(texts, [{"image": img_t}])

    # NOTE(review): the Gradio interface declares a PIL image output, but the
    # raw model output is returned here - confirm the model call already
    # yields something Gradio can render, or add a visualisation step.
    return res
|
106 |
+
|
107 |
+
|
108 |
+
image = gr.inputs.Image()

# Web UI: one image plus one text box in, one rendered image out.
demo = gr.Interface(
    description="RegionCLIP for Open-Vocabulary Object Detection",
    fn=localize_object,
    inputs=["image", "text"],
    outputs=[
        gr.outputs.Image(type="pil", label="grounding results"),
    ],
    examples=[
        ["./elephants.png", "an elephant"],
        ["./apple_with_ipod.jpg", "an apple"],
    ],
)
# Starts the Gradio server.
demo.launch()
|
config.py
ADDED
@@ -0,0 +1,245 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# --------------------------------------------------------
|
2 |
+
# Unified Contrastive Learning (UniCL)
|
3 |
+
# Copyright (c) 2022 Microsoft
|
4 |
+
# Licensed under The MIT License [see LICENSE for details]
|
5 |
+
# Written by Jianwei Yang ([email protected])
|
6 |
+
# Based on Swin Transformer written by Zhe Liu
|
7 |
+
# --------------------------------------------------------
|
8 |
+
|
9 |
+
import os
|
10 |
+
import yaml
|
11 |
+
from yacs.config import CfgNode as CN
|
12 |
+
|
13 |
+
_C = CN()
|
14 |
+
_C.VERBOSE = False
|
15 |
+
|
16 |
+
# Base config files
|
17 |
+
_C.BASE = ['']
|
18 |
+
|
19 |
+
# -----------------------------------------------------------------------------
|
20 |
+
# Data settings
|
21 |
+
# -----------------------------------------------------------------------------
|
22 |
+
_C.DATA = CN()
|
23 |
+
# Batch size for a single GPU, could be overwritten by command line argument
|
24 |
+
_C.DATA.BATCH_SIZE = 128
|
25 |
+
# Path to dataset, could be overwritten by command line argument
|
26 |
+
_C.DATA.DATA_PATH = ''
|
27 |
+
# Dataset name
|
28 |
+
_C.DATA.DATASET = 'imagenet'
|
29 |
+
# Input image size
|
30 |
+
_C.DATA.IMG_SIZE = 224
|
31 |
+
# Interpolation to resize image (random, bilinear, bicubic)
|
32 |
+
_C.DATA.INTERPOLATION = 'bicubic'
|
33 |
+
# Use zipped dataset instead of folder dataset
|
34 |
+
# could be overwritten by command line argument
|
35 |
+
_C.DATA.ZIP_MODE = False
|
36 |
+
# Cache Data in Memory, could be overwritten by command line argument
|
37 |
+
_C.DATA.CACHE_MODE = 'part'
|
38 |
+
# Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.
|
39 |
+
_C.DATA.PIN_MEMORY = True
|
40 |
+
# Number of data loading threads
|
41 |
+
_C.DATA.NUM_WORKERS = 8
|
42 |
+
|
43 |
+
# -----------------------------------------------------------------------------
|
44 |
+
# Model settings
|
45 |
+
# -----------------------------------------------------------------------------
|
46 |
+
_C.MODEL = CN()
|
47 |
+
# Model name
|
48 |
+
_C.MODEL.NAME = ''
|
49 |
+
# Checkpoint to resume, could be overwritten by command line argument
|
50 |
+
_C.MODEL.RESUME = ''
|
51 |
+
# Number of classes, overwritten in data preparation
|
52 |
+
_C.MODEL.NUM_CLASSES = 0
|
53 |
+
# Label Smoothing
|
54 |
+
_C.MODEL.LABEL_SMOOTHING = 0.1
|
55 |
+
# Whether to load a pretrained model
|
56 |
+
_C.MODEL.PRETRAINED = ''
|
57 |
+
# Projection dimension
|
58 |
+
_C.MODEL.DIM_PROJECTION = 512
|
59 |
+
# Model specific
|
60 |
+
_C.MODEL.SPEC = CN(new_allowed=True)
|
61 |
+
# -----------------------------------------------------------------------------
|
62 |
+
# Build Image Encoder
|
63 |
+
# -----------------------------------------------------------------------------
|
64 |
+
_C.MODEL.IMAGE_ENCODER = CN()
|
65 |
+
# Image encoder type
|
66 |
+
_C.MODEL.IMAGE_ENCODER.TYPE = 'swin'
|
67 |
+
# Input image size
|
68 |
+
_C.MODEL.IMAGE_ENCODER.IMG_SIZE = 224
|
69 |
+
# Dropout rate
|
70 |
+
_C.MODEL.IMAGE_ENCODER.DROP_RATE = 0.0
|
71 |
+
# Drop path rate
|
72 |
+
_C.MODEL.IMAGE_ENCODER.DROP_PATH_RATE = 0.1
|
73 |
+
|
74 |
+
# Swin Transformer parameters
|
75 |
+
_C.MODEL.IMAGE_ENCODER.SWIN = CN()
|
76 |
+
_C.MODEL.IMAGE_ENCODER.SWIN.PATCH_SIZE = 4
|
77 |
+
_C.MODEL.IMAGE_ENCODER.SWIN.IN_CHANS = 3
|
78 |
+
_C.MODEL.IMAGE_ENCODER.SWIN.EMBED_DIM = 96
|
79 |
+
_C.MODEL.IMAGE_ENCODER.SWIN.DEPTHS = [2, 2, 6, 2]
|
80 |
+
_C.MODEL.IMAGE_ENCODER.SWIN.NUM_HEADS = [3, 6, 12, 24]
|
81 |
+
_C.MODEL.IMAGE_ENCODER.SWIN.WINDOW_SIZE = 7
|
82 |
+
_C.MODEL.IMAGE_ENCODER.SWIN.MLP_RATIO = 4.
|
83 |
+
_C.MODEL.IMAGE_ENCODER.SWIN.QKV_BIAS = True
|
84 |
+
_C.MODEL.IMAGE_ENCODER.SWIN.QK_SCALE = None
|
85 |
+
_C.MODEL.IMAGE_ENCODER.SWIN.APE = False
|
86 |
+
_C.MODEL.IMAGE_ENCODER.SWIN.PATCH_NORM = True
|
87 |
+
|
88 |
+
# FocalNet parameters
|
89 |
+
_C.MODEL.IMAGE_ENCODER.FOCAL = CN()
|
90 |
+
_C.MODEL.IMAGE_ENCODER.FOCAL.PATCH_SIZE = 4
|
91 |
+
_C.MODEL.IMAGE_ENCODER.FOCAL.IN_CHANS = 3
|
92 |
+
_C.MODEL.IMAGE_ENCODER.FOCAL.EMBED_DIM = 96
|
93 |
+
_C.MODEL.IMAGE_ENCODER.FOCAL.DEPTHS = [2, 2, 6, 2]
|
94 |
+
_C.MODEL.IMAGE_ENCODER.FOCAL.MLP_RATIO = 4.
|
95 |
+
_C.MODEL.IMAGE_ENCODER.FOCAL.PATCH_NORM = True
|
96 |
+
_C.MODEL.IMAGE_ENCODER.FOCAL.FOCAL_LEVELS = [2, 2, 2, 2]
|
97 |
+
_C.MODEL.IMAGE_ENCODER.FOCAL.FOCAL_WINDOWS = [3, 3, 3, 3]
|
98 |
+
_C.MODEL.IMAGE_ENCODER.FOCAL.FOCAL_FACTORS = [2, 2, 2, 2]
|
99 |
+
_C.MODEL.IMAGE_ENCODER.FOCAL.USE_CONV_EMBED = False
|
100 |
+
_C.MODEL.IMAGE_ENCODER.FOCAL.USE_LAYERSCALE = False
|
101 |
+
_C.MODEL.IMAGE_ENCODER.FOCAL.USE_POSTLN = False
|
102 |
+
|
103 |
+
# -----------------------------------------------------------------------------
|
104 |
+
# Build Text Encoder
|
105 |
+
# -----------------------------------------------------------------------------
|
106 |
+
_C.MODEL.TEXT_ENCODER = CN()
|
107 |
+
|
108 |
+
_C.MODEL.TEXT_ENCODER.NAME = 'transformer'
|
109 |
+
_C.MODEL.TEXT_ENCODER.LOAD_PRETRAINED = False
|
110 |
+
_C.MODEL.TEXT_ENCODER.PRETRAINED = ''
|
111 |
+
_C.MODEL.TEXT_ENCODER.TOKENIZER = 'clip'
|
112 |
+
_C.MODEL.TEXT_ENCODER.CONTEXT_LENGTH = 77
|
113 |
+
_C.MODEL.TEXT_ENCODER.WIDTH = 1024
|
114 |
+
_C.MODEL.TEXT_ENCODER.HEADS = 16
|
115 |
+
_C.MODEL.TEXT_ENCODER.LAYERS = 12
|
116 |
+
_C.MODEL.TEXT_ENCODER.AUTOGRESSIVE = True
|
117 |
+
|
118 |
+
# -----------------------------------------------------------------------------
|
119 |
+
# Training settings
|
120 |
+
# -----------------------------------------------------------------------------
|
121 |
+
_C.TRAIN = CN()
|
122 |
+
_C.TRAIN.START_EPOCH = 0
|
123 |
+
_C.TRAIN.EPOCHS = 32
|
124 |
+
_C.TRAIN.WARMUP_EPOCHS = 5
|
125 |
+
_C.TRAIN.WEIGHT_DECAY = 0.1
|
126 |
+
_C.TRAIN.BASE_LR = 5e-4
|
127 |
+
_C.TRAIN.WARMUP_LR = 5e-7
|
128 |
+
_C.TRAIN.MIN_LR = 5e-6
|
129 |
+
# Clip gradient norm
|
130 |
+
_C.TRAIN.CLIP_GRAD = 5.0
|
131 |
+
# Auto resume from latest checkpoint
|
132 |
+
_C.TRAIN.AUTO_RESUME = True
|
133 |
+
# Gradient accumulation steps
|
134 |
+
# could be overwritten by command line argument
|
135 |
+
_C.TRAIN.ACCUMULATION_STEPS = 0
|
136 |
+
# Whether to use gradient checkpointing to save memory
|
137 |
+
# could be overwritten by command line argument
|
138 |
+
_C.TRAIN.USE_CHECKPOINT = False
|
139 |
+
|
140 |
+
# LR scheduler
|
141 |
+
_C.TRAIN.LR_SCHEDULER = CN()
|
142 |
+
_C.TRAIN.LR_SCHEDULER.NAME = 'cosine'
|
143 |
+
# Epoch interval to decay LR, used in StepLRScheduler
|
144 |
+
_C.TRAIN.LR_SCHEDULER.DECAY_EPOCHS = 30
|
145 |
+
# LR decay rate, used in StepLRScheduler
|
146 |
+
_C.TRAIN.LR_SCHEDULER.DECAY_RATE = 0.1
|
147 |
+
|
148 |
+
# Optimizer
|
149 |
+
_C.TRAIN.OPTIMIZER = CN()
|
150 |
+
_C.TRAIN.OPTIMIZER.NAME = 'adamw'
|
151 |
+
# Optimizer Epsilon
|
152 |
+
_C.TRAIN.OPTIMIZER.EPS = 1e-8
|
153 |
+
# Optimizer Betas
|
154 |
+
_C.TRAIN.OPTIMIZER.BETAS = (0.9, 0.999)
|
155 |
+
# SGD momentum
|
156 |
+
_C.TRAIN.OPTIMIZER.MOMENTUM = 0.9
|
157 |
+
|
158 |
+
# -----------------------------------------------------------------------------
|
159 |
+
# Augmentation settings
|
160 |
+
# -----------------------------------------------------------------------------
|
161 |
+
_C.AUG = CN()
|
162 |
+
# Color jitter factor
|
163 |
+
_C.AUG.COLOR_JITTER = 0.4
|
164 |
+
# Use AutoAugment policy. "v0" or "original"
|
165 |
+
_C.AUG.AUTO_AUGMENT = 'rand-m9-mstd0.5-inc1'
|
166 |
+
# Random erase prob
|
167 |
+
_C.AUG.REPROB = 0.25
|
168 |
+
# Random erase mode
|
169 |
+
_C.AUG.REMODE = 'pixel'
|
170 |
+
# Random erase count
|
171 |
+
_C.AUG.RECOUNT = 1
|
172 |
+
# Mixup alpha, mixup enabled if > 0
|
173 |
+
_C.AUG.MIXUP = 0.8
|
174 |
+
# Cutmix alpha, cutmix enabled if > 0
|
175 |
+
_C.AUG.CUTMIX = 1.0
|
176 |
+
# Cutmix min/max ratio, overrides alpha and enables cutmix if set
|
177 |
+
_C.AUG.CUTMIX_MINMAX = None
|
178 |
+
# Probability of performing mixup or cutmix when either/both is enabled
|
179 |
+
_C.AUG.MIXUP_PROB = 1.0
|
180 |
+
# Probability of switching to cutmix when both mixup and cutmix enabled
|
181 |
+
_C.AUG.MIXUP_SWITCH_PROB = 0.5
|
182 |
+
# How to apply mixup/cutmix params. Per "batch", "pair", or "elem"
|
183 |
+
_C.AUG.MIXUP_MODE = 'batch'
|
184 |
+
|
185 |
+
# -----------------------------------------------------------------------------
|
186 |
+
# Testing settings
|
187 |
+
# -----------------------------------------------------------------------------
|
188 |
+
_C.TEST = CN()
|
189 |
+
# Whether to use center crop when testing
|
190 |
+
_C.TEST.CROP = True
|
191 |
+
|
192 |
+
# -----------------------------------------------------------------------------
|
193 |
+
# Misc
|
194 |
+
# -----------------------------------------------------------------------------
|
195 |
+
# Mixed precision opt level, if O0, no amp is used ('O0', 'O1', 'O2')
|
196 |
+
# overwritten by command line argument
|
197 |
+
_C.AMP_OPT_LEVEL = ''
|
198 |
+
# Path to output folder, overwritten by command line argument
|
199 |
+
_C.OUTPUT = ''
|
200 |
+
# Tag of experiment, overwritten by command line argument
|
201 |
+
_C.TAG = 'default'
|
202 |
+
# Frequency to save checkpoint
|
203 |
+
_C.SAVE_FREQ = 1
|
204 |
+
# Frequency of logging info
|
205 |
+
_C.PRINT_FREQ = 100
|
206 |
+
# Fixed random seed
|
207 |
+
_C.SEED = 0
|
208 |
+
# Perform evaluation only, overwritten by command line argument
|
209 |
+
_C.EVAL_MODE = False
|
210 |
+
# Test throughput only, overwritten by command line argument
|
211 |
+
_C.THROUGHPUT_MODE = False
|
212 |
+
# Debug only: skip dataloader initialization; overwritten by command line argument
|
213 |
+
_C.DEBUG_MODE = False
|
214 |
+
# local rank for DistributedDataParallel, given by command line argument
|
215 |
+
_C.LOCAL_RANK = 0
|
216 |
+
|
217 |
+
|
218 |
+
def _update_config_from_file(config, cfg_file):
    """Merge a YAML config file, and the files it lists under ``BASE``, into ``config``.

    Every non-empty entry of the file's ``BASE`` list is resolved relative
    to the directory of ``cfg_file`` and merged first (recursively);
    ``cfg_file`` itself is merged afterwards. ``config`` is defrosted on
    entry and frozen again before returning.

    Args:
        config: Config node to update in place.
        cfg_file: Path of the YAML file to merge.
    """
    config.defrost()
    with open(cfg_file, 'r') as f:
        # NOTE(review): FullLoader can construct more than plain data types;
        # prefer yaml.safe_load if config files may come from untrusted sources.
        # An empty YAML document parses to None, so fall back to an empty mapping.
        yaml_cfg = yaml.load(f, Loader=yaml.FullLoader) or {}

    # A missing or null BASE entry means there are no parent files to merge.
    for base_cfg in yaml_cfg.get('BASE') or ():
        if base_cfg:
            _update_config_from_file(
                config, os.path.join(os.path.dirname(cfg_file), base_cfg)
            )
    print('=> merge config from {}'.format(cfg_file))
    config.merge_from_file(cfg_file)
    config.freeze()
|
231 |
+
|
232 |
+
|
233 |
+
def update_config(config, args):
    """Apply the YAML file named by ``args.cfg`` to ``config``, then freeze it.

    Args:
        config: Config node to update in place.
        args: Object exposing the config path as ``args.cfg``.
    """
    cfg_path = args.cfg
    _update_config_from_file(config, cfg_path)
    config.freeze()
|
236 |
+
|
237 |
+
|
238 |
+
def get_config(args):
    """Return a config built from the module defaults and updated from ``args``.

    Args:
        args: Object passed through to ``update_config``.

    Returns:
        A clone of ``_C`` after ``update_config`` has been applied to it.
    """
    # Clone first so the module-level defaults in _C are never altered.
    cloned = _C.clone()
    update_config(cloned, args)
    return cloned
|
configs/Base-RCNN-C4.yaml
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MODEL:
|
2 |
+
META_ARCHITECTURE: "GeneralizedRCNN"
|
3 |
+
RPN:
|
4 |
+
PRE_NMS_TOPK_TEST: 6000
|
5 |
+
POST_NMS_TOPK_TEST: 1000
|
6 |
+
ROI_HEADS:
|
7 |
+
NAME: "Res5ROIHeads"
|
8 |
+
DATASETS:
|
9 |
+
TRAIN: ("coco_2017_train",)
|
10 |
+
TEST: ("coco_2017_val",)
|
11 |
+
SOLVER:
|
12 |
+
IMS_PER_BATCH: 16
|
13 |
+
BASE_LR: 0.02
|
14 |
+
STEPS: (60000, 80000)
|
15 |
+
MAX_ITER: 90000
|
16 |
+
INPUT:
|
17 |
+
MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800)
|
18 |
+
VERSION: 2
|
configs/Base-RCNN-FPN.yaml
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MODEL:
|
2 |
+
META_ARCHITECTURE: "GeneralizedRCNN"
|
3 |
+
BACKBONE:
|
4 |
+
NAME: "build_resnet_fpn_backbone"
|
5 |
+
RESNETS:
|
6 |
+
OUT_FEATURES: ["res2", "res3", "res4", "res5"]
|
7 |
+
FPN:
|
8 |
+
IN_FEATURES: ["res2", "res3", "res4", "res5"]
|
9 |
+
ANCHOR_GENERATOR:
|
10 |
+
SIZES: [[32], [64], [128], [256], [512]] # One size for each in feature map
|
11 |
+
ASPECT_RATIOS: [[0.5, 1.0, 2.0]] # Three aspect ratios (same for all in feature maps)
|
12 |
+
RPN:
|
13 |
+
IN_FEATURES: ["p2", "p3", "p4", "p5", "p6"]
|
14 |
+
PRE_NMS_TOPK_TRAIN: 2000 # Per FPN level
|
15 |
+
PRE_NMS_TOPK_TEST: 1000 # Per FPN level
|
16 |
+
# Detectron1 uses 2000 proposals per-batch,
|
17 |
+
# (See "modeling/rpn/rpn_outputs.py" for details of this legacy issue)
|
18 |
+
# which is approximately 1000 proposals per-image since the default batch size for FPN is 2.
|
19 |
+
POST_NMS_TOPK_TRAIN: 1000
|
20 |
+
POST_NMS_TOPK_TEST: 1000
|
21 |
+
ROI_HEADS:
|
22 |
+
NAME: "StandardROIHeads"
|
23 |
+
IN_FEATURES: ["p2", "p3", "p4", "p5"]
|
24 |
+
ROI_BOX_HEAD:
|
25 |
+
NAME: "FastRCNNConvFCHead"
|
26 |
+
NUM_FC: 2
|
27 |
+
POOLER_RESOLUTION: 7
|
28 |
+
ROI_MASK_HEAD:
|
29 |
+
NAME: "MaskRCNNConvUpsampleHead"
|
30 |
+
NUM_CONV: 4
|
31 |
+
POOLER_RESOLUTION: 14
|
32 |
+
DATASETS:
|
33 |
+
TRAIN: ("coco_2017_train",)
|
34 |
+
TEST: ("coco_2017_val",)
|
35 |
+
SOLVER:
|
36 |
+
IMS_PER_BATCH: 16
|
37 |
+
BASE_LR: 0.02
|
38 |
+
STEPS: (60000, 80000)
|
39 |
+
MAX_ITER: 90000
|
40 |
+
INPUT:
|
41 |
+
MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800)
|
42 |
+
VERSION: 2
|
configs/CLIP_fast_rcnn_R_50_C4.yaml
ADDED
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
_BASE_: "./Base-RCNN-C4.yaml"
|
2 |
+
MODEL:
|
3 |
+
META_ARCHITECTURE: "CLIPFastRCNN" # "CLIPRCNN" # "GeneralizedRCNN"
|
4 |
+
MASK_ON: False
|
5 |
+
WEIGHTS: "./model_final.pth"
|
6 |
+
BACKBONE:
|
7 |
+
NAME: "build_clip_resnet_backbone" # "build_resnet_fpn_backbone"
|
8 |
+
FREEZE_AT: 2
|
9 |
+
TEXT_BACKBONE:
|
10 |
+
NAME: "build_clip_language_encoder"
|
11 |
+
CLIP:
|
12 |
+
CROP_REGION_TYPE: "RPN"
|
13 |
+
OFFLINE_RPN_CONFIG: "./configs/mask_rcnn_R_50_FPN_1x.yaml"
|
14 |
+
USE_TEXT_EMB_CLASSIFIER: True
|
15 |
+
TEXT_EMB_PATH: "./lvis_1203_cls_emb_notnorm_rn50x4.pth"
|
16 |
+
NO_BOX_DELTA: True
|
17 |
+
OFFLINE_RPN_NMS_THRESH: 0.7
|
18 |
+
CLSS_TEMP: 0.01
|
19 |
+
MULTIPLY_RPN_SCORE: True
|
20 |
+
TEXT_EMB_DIM: 640
|
21 |
+
RESNETS:
|
22 |
+
DEPTH: 200
|
23 |
+
OUT_FEATURES: ["res4"]
|
24 |
+
NORM: FrozenBN
|
25 |
+
STEM_OUT_CHANNELS: 64
|
26 |
+
RES2_OUT_CHANNELS: 256
|
27 |
+
RPN:
|
28 |
+
HEAD_NAME: StandardRPNHead
|
29 |
+
IN_FEATURES: ["res4"]
|
30 |
+
POST_NMS_TOPK_TEST: 1000
|
31 |
+
NMS_THRESH:
|
32 |
+
ROI_HEADS:
|
33 |
+
NAME: "CLIPRes5ROIHeads" # "Res5ROIHeads" # "StandardROIHeads"
|
34 |
+
IN_FEATURES: ["res4"]
|
35 |
+
NUM_CLASSES: 1203
|
36 |
+
NMS_THRESH_TEST: 0.3
|
37 |
+
SCORE_THRESH_TEST: 0.0
|
38 |
+
ROI_BOX_HEAD:
|
39 |
+
NAME: ""
|
40 |
+
NUM_FC: 0
|
41 |
+
CLS_AGNOSTIC_BBOX_REG: True
|
42 |
+
POOLER_RESOLUTION: 18
|
43 |
+
ROI_MASK_HEAD:
|
44 |
+
NAME: "MaskRCNNConvUpsampleHead"
|
45 |
+
NUM_CONV: 0
|
46 |
+
POOLER_RESOLUTION: 14
|
47 |
+
PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
|
48 |
+
PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
|
49 |
+
INPUT:
|
50 |
+
MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800)
|
51 |
+
DATASETS:
|
52 |
+
TRAIN: ("lvis_v1_train",)
|
53 |
+
TEST: ("lvis_v1_val",)
|
54 |
+
TEST:
|
55 |
+
DETECTIONS_PER_IMAGE: 300 # LVIS allows up to 300
|
56 |
+
EVAL_PERIOD: 25000
|
57 |
+
SOLVER:
|
58 |
+
IMS_PER_BATCH: 16
|
59 |
+
BASE_LR: 0.02
|
60 |
+
STEPS: (120000, 160000)
|
61 |
+
MAX_ITER: 180000 # 180000 * 16 / 100000 ~ 28.8 epochs
|
62 |
+
DATALOADER:
|
63 |
+
SAMPLER_TRAIN: "RepeatFactorTrainingSampler"
|
64 |
+
REPEAT_THRESHOLD: 0.001
|
65 |
+
INPUT:
|
66 |
+
MIN_SIZE_TRAIN_SAMPLING: choice
|
67 |
+
MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800)
|
68 |
+
MAX_SIZE_TRAIN: 1333
|
69 |
+
MIN_SIZE_TEST: 800
|
70 |
+
MAX_SIZE_TEST: 1333
|
71 |
+
FORMAT: "RGB"
|
configs/CLIP_fast_rcnn_swin_base_C4.yaml
ADDED
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
_BASE_: "./Base-RCNN-C4.yaml"
|
2 |
+
MODEL:
|
3 |
+
META_ARCHITECTURE: "CLIPFastRCNN" # "CLIPRCNN" # "GeneralizedRCNN"
|
4 |
+
BACKBONE:
|
5 |
+
NAME: "build_clip_swin" # "build_resnet_fpn_backbone"
|
6 |
+
FREEZE_AT: 2
|
7 |
+
TEXT_BACKBONE:
|
8 |
+
NAME: "build_clip_swin_text_backbone"
|
9 |
+
SPEC:
|
10 |
+
EMBED_DIM: 512
|
11 |
+
VISION:
|
12 |
+
PATCH_SIZE: 4
|
13 |
+
IN_CHANS: 3
|
14 |
+
EMBED_DIM: 128
|
15 |
+
DEPTHS: [ 2, 2, 18, 2 ]
|
16 |
+
NUM_HEADS: [ 4, 8, 16, 32 ]
|
17 |
+
WINDOW_SIZE: 7
|
18 |
+
MLP_RATIO: 4.
|
19 |
+
QKV_BIAS: True
|
20 |
+
APE: False
|
21 |
+
PATCH_NORM: True
|
22 |
+
DROP_RATE: 0.0
|
23 |
+
DROP_PATH_RATE: 0.2
|
24 |
+
OUT_FEATURES: ["stage2", "stage3", "stage4", "stage5"]
|
25 |
+
TEXT:
|
26 |
+
NAME: 'transformer'
|
27 |
+
TOKENIZER: clip
|
28 |
+
CONTEXT_LENGTH: 77
|
29 |
+
WIDTH: 512
|
30 |
+
HEADS: 8
|
31 |
+
LAYERS: 12
|
32 |
+
WEIGHTS: "" # "detectron2://ImageNetPretrained/MSRA/R-50.pkl"
|
33 |
+
MASK_ON: True
|
34 |
+
RPN:
|
35 |
+
HEAD_NAME: StandardRPNHead
|
36 |
+
IN_FEATURES: ["stage4"]
|
37 |
+
ROI_HEADS:
|
38 |
+
NAME: "CLIPSwinROIHeads" # "Res5ROIHeads" # "StandardROIHeads"
|
39 |
+
IN_FEATURES: ["stage4"]
|
40 |
+
NUM_CLASSES: 1203
|
41 |
+
SCORE_THRESH_TEST: 0.0001
|
42 |
+
ROI_BOX_HEAD:
|
43 |
+
NAME: ""
|
44 |
+
NUM_FC: 0
|
45 |
+
POOLER_RESOLUTION: 14
|
46 |
+
ROI_MASK_HEAD:
|
47 |
+
NAME: "MaskRCNNConvUpsampleHead"
|
48 |
+
NUM_CONV: 0
|
49 |
+
POOLER_RESOLUTION: 14
|
50 |
+
PIXEL_MEAN: [0.485, 0.456, 0.406]
|
51 |
+
PIXEL_STD: [0.229, 0.224, 0.225]
|
52 |
+
INPUT:
|
53 |
+
MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800)
|
54 |
+
DATASETS:
|
55 |
+
TRAIN: ("lvis_v1_train",)
|
56 |
+
TEST: ("lvis_v1_val",)
|
57 |
+
TEST:
|
58 |
+
DETECTIONS_PER_IMAGE: 300 # LVIS allows up to 300
|
59 |
+
EVAL_PERIOD: 25000
|
60 |
+
SOLVER:
|
61 |
+
IMS_PER_BATCH: 16
|
62 |
+
BASE_LR: 0.02
|
63 |
+
STEPS: (120000, 160000)
|
64 |
+
MAX_ITER: 180000 # 180000 * 16 / 100000 ~ 28.8 epochs
|
65 |
+
DATALOADER:
|
66 |
+
SAMPLER_TRAIN: "RepeatFactorTrainingSampler"
|
67 |
+
REPEAT_THRESHOLD: 0.001
|
68 |
+
INPUT:
|
69 |
+
MIN_SIZE_TRAIN_SAMPLING: choice
|
70 |
+
MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800)
|
71 |
+
MAX_SIZE_TRAIN: 1333
|
72 |
+
MIN_SIZE_TEST: 800
|
73 |
+
MAX_SIZE_TEST: 1333
|
74 |
+
FORMAT: "RGB"
|
configs/mask_rcnn_CLIP_R_50_C4_1x.yaml
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
_BASE_: "./Base-RCNN-C4.yaml"
|
2 |
+
MODEL:
|
3 |
+
META_ARCHITECTURE: "GeneralizedRCNN"
|
4 |
+
BACKBONE:
|
5 |
+
NAME: "build_clip_resnet_backbone" #"build_clip_resnet_fpn_backbone" # "build_resnet_fpn_backbone"
|
6 |
+
FREEZE_AT: 2
|
7 |
+
WEIGHTS: "" # "detectron2://ImageNetPretrained/MSRA/R-50.pkl"
|
8 |
+
MASK_ON: True
|
9 |
+
RESNETS:
|
10 |
+
DEPTH: 50
|
11 |
+
OUT_FEATURES: ["res4"]
|
12 |
+
NORM: FrozenBN
|
13 |
+
STEM_OUT_CHANNELS: 64
|
14 |
+
RES2_OUT_CHANNELS: 256
|
15 |
+
RPN:
|
16 |
+
HEAD_NAME: StandardRPNHead
|
17 |
+
IN_FEATURES: ["res4"]
|
18 |
+
ROI_HEADS:
|
19 |
+
NAME: "CLIPRes5ROIHeads" # "Res5ROIHeads" # "StandardROIHeads"
|
20 |
+
IN_FEATURES: ["res4"]
|
21 |
+
NUM_CLASSES: 1203
|
22 |
+
SCORE_THRESH_TEST: 0.0001
|
23 |
+
ROI_BOX_HEAD:
|
24 |
+
NAME: ""
|
25 |
+
NUM_FC: 0
|
26 |
+
POOLER_RESOLUTION: 14
|
27 |
+
ROI_MASK_HEAD:
|
28 |
+
NAME: "MaskRCNNConvUpsampleHead"
|
29 |
+
NUM_CONV: 0
|
30 |
+
POOLER_RESOLUTION: 14
|
31 |
+
PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] # [103.530, 116.280, 123.675] #
|
32 |
+
PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] # [1.0, 1.0, 1.0] #
|
33 |
+
INPUT:
|
34 |
+
MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800)
|
35 |
+
DATASETS:
|
36 |
+
TRAIN: ("lvis_v1_train",)
|
37 |
+
TEST: ("lvis_v1_val",)
|
38 |
+
TEST:
|
39 |
+
DETECTIONS_PER_IMAGE: 300 # LVIS allows up to 300
|
40 |
+
EVAL_PERIOD: 25000
|
41 |
+
SOLVER:
|
42 |
+
IMS_PER_BATCH: 16
|
43 |
+
BASE_LR: 0.02
|
44 |
+
STEPS: (120000, 160000) # (140000,) #
|
45 |
+
MAX_ITER: 180000 # 180000 * 16 / 100000 ~ 28.8 epochs
|
46 |
+
DATALOADER:
|
47 |
+
SAMPLER_TRAIN: "RepeatFactorTrainingSampler"
|
48 |
+
REPEAT_THRESHOLD: 0.001
|
49 |
+
INPUT:
|
50 |
+
MIN_SIZE_TRAIN_SAMPLING: choice
|
51 |
+
MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800)
|
52 |
+
MAX_SIZE_TRAIN: 1333
|
53 |
+
MIN_SIZE_TEST: 800
|
54 |
+
MAX_SIZE_TEST: 1333
|
55 |
+
FORMAT: "RGB" # "BGR"
|
configs/mask_rcnn_R_50_C4_1x.yaml
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
_BASE_: "./Base-RCNN-C4.yaml"
|
2 |
+
MODEL:
|
3 |
+
WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl"
|
4 |
+
MASK_ON: True
|
5 |
+
RESNETS:
|
6 |
+
DEPTH: 50
|
7 |
+
ROI_HEADS:
|
8 |
+
NUM_CLASSES: 1203
|
9 |
+
SCORE_THRESH_TEST: 0.0001
|
10 |
+
INPUT:
|
11 |
+
MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800)
|
12 |
+
DATASETS:
|
13 |
+
TRAIN: ("lvis_v1_train",)
|
14 |
+
TEST: ("lvis_v1_val",)
|
15 |
+
TEST:
|
16 |
+
DETECTIONS_PER_IMAGE: 300 # LVIS allows up to 300
|
17 |
+
EVAL_PERIOD: 50000
|
18 |
+
SOLVER:
|
19 |
+
STEPS: (120000, 160000)
|
20 |
+
MAX_ITER: 180000 # 180000 * 16 / 100000 ~ 28.8 epochs
|
21 |
+
DATALOADER:
|
22 |
+
SAMPLER_TRAIN: "RepeatFactorTrainingSampler"
|
23 |
+
REPEAT_THRESHOLD: 0.001
|
configs/mask_rcnn_R_50_FPN_1x.yaml
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
_BASE_: "./Base-RCNN-FPN.yaml"
|
2 |
+
MODEL:
|
3 |
+
WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl"
|
4 |
+
MASK_ON: True
|
5 |
+
RESNETS:
|
6 |
+
DEPTH: 50
|
7 |
+
ROI_HEADS:
|
8 |
+
NUM_CLASSES: 1203
|
9 |
+
SCORE_THRESH_TEST: 0.0001
|
10 |
+
INPUT:
|
11 |
+
MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800)
|
12 |
+
DATASETS:
|
13 |
+
TRAIN: ("lvis_v1_train",)
|
14 |
+
TEST: ("lvis_v1_val",)
|
15 |
+
TEST:
|
16 |
+
DETECTIONS_PER_IMAGE: 300 # LVIS allows up to 300
|
17 |
+
EVAL_PERIOD: 50000
|
18 |
+
SOLVER:
|
19 |
+
STEPS: (120000, 160000)
|
20 |
+
MAX_ITER: 180000 # 180000 * 16 / 100000 ~ 28.8 epochs
|
21 |
+
DATALOADER:
|
22 |
+
SAMPLER_TRAIN: "RepeatFactorTrainingSampler"
|
23 |
+
REPEAT_THRESHOLD: 0.001
|
datasets/README.md
ADDED
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Use Builtin Datasets
|
2 |
+
|
3 |
+
A dataset can be used by accessing [DatasetCatalog](https://detectron2.readthedocs.io/modules/data.html#detectron2.data.DatasetCatalog)
|
4 |
+
for its data, or [MetadataCatalog](https://detectron2.readthedocs.io/modules/data.html#detectron2.data.MetadataCatalog) for its metadata (class names, etc).
|
5 |
+
This document explains how to set up the builtin datasets so they can be used by the above APIs.
|
6 |
+
[Use Custom Datasets](https://detectron2.readthedocs.io/tutorials/datasets.html) gives a deeper dive on how to use `DatasetCatalog` and `MetadataCatalog`,
|
7 |
+
and how to add new datasets to them.
|
8 |
+
|
9 |
+
Detectron2 has builtin support for a few datasets.
|
10 |
+
The datasets are assumed to exist in a directory specified by the environment variable
|
11 |
+
`DETECTRON2_DATASETS`.
|
12 |
+
Under this directory, detectron2 will look for datasets in the structure described below, if needed.
|
13 |
+
```
|
14 |
+
$DETECTRON2_DATASETS/
|
15 |
+
coco/
|
16 |
+
lvis/
|
17 |
+
cityscapes/
|
18 |
+
VOC20{07,12}/
|
19 |
+
```
|
20 |
+
|
21 |
+
You can set the location for builtin datasets by `export DETECTRON2_DATASETS=/path/to/datasets`.
|
22 |
+
If left unset, the default is `./datasets` relative to your current working directory.
|
23 |
+
|
24 |
+
The [model zoo](https://github.com/facebookresearch/detectron2/blob/master/MODEL_ZOO.md)
|
25 |
+
contains configs and models that use these builtin datasets.
|
26 |
+
|
27 |
+
## Expected dataset structure for [COCO instance/keypoint detection](https://cocodataset.org/#download):
|
28 |
+
|
29 |
+
```
|
30 |
+
coco/
|
31 |
+
annotations/
|
32 |
+
instances_{train,val}2017.json
|
33 |
+
person_keypoints_{train,val}2017.json
|
34 |
+
{train,val}2017/
|
35 |
+
# image files that are mentioned in the corresponding json
|
36 |
+
```
|
37 |
+
|
38 |
+
You can use the 2014 version of the dataset as well.
|
39 |
+
|
40 |
+
Some of the builtin tests (`dev/run_*_tests.sh`) use a tiny version of the COCO dataset,
|
41 |
+
which you can download with `./datasets/prepare_for_tests.sh`.
|
42 |
+
|
43 |
+
## Expected dataset structure for PanopticFPN:
|
44 |
+
|
45 |
+
Extract panoptic annotations from [COCO website](https://cocodataset.org/#download)
|
46 |
+
into the following structure:
|
47 |
+
```
|
48 |
+
coco/
|
49 |
+
annotations/
|
50 |
+
panoptic_{train,val}2017.json
|
51 |
+
panoptic_{train,val}2017/ # png annotations
|
52 |
+
panoptic_stuff_{train,val}2017/ # generated by the script mentioned below
|
53 |
+
```
|
54 |
+
|
55 |
+
Install panopticapi by:
|
56 |
+
```
|
57 |
+
pip install git+https://github.com/cocodataset/panopticapi.git
|
58 |
+
```
|
59 |
+
Then, run `python datasets/prepare_panoptic_fpn.py`, to extract semantic annotations from panoptic annotations.
|
60 |
+
|
61 |
+
## Expected dataset structure for [LVIS instance segmentation](https://www.lvisdataset.org/dataset):
|
62 |
+
```
|
63 |
+
coco/
|
64 |
+
{train,val,test}2017/
|
65 |
+
lvis/
|
66 |
+
lvis_v0.5_{train,val}.json
|
67 |
+
lvis_v0.5_image_info_test.json
|
68 |
+
lvis_v1_{train,val}.json
|
69 |
+
lvis_v1_image_info_test{,_challenge}.json
|
70 |
+
```
|
71 |
+
|
72 |
+
Install lvis-api by:
|
73 |
+
```
|
74 |
+
pip install git+https://github.com/lvis-dataset/lvis-api.git
|
75 |
+
```
|
76 |
+
|
77 |
+
To evaluate models trained on the COCO dataset using LVIS annotations,
|
78 |
+
run `python datasets/prepare_cocofied_lvis.py` to prepare "cocofied" LVIS annotations.
|
79 |
+
|
80 |
+
## Expected dataset structure for [cityscapes](https://www.cityscapes-dataset.com/downloads/):
|
81 |
+
```
|
82 |
+
cityscapes/
|
83 |
+
gtFine/
|
84 |
+
train/
|
85 |
+
aachen/
|
86 |
+
color.png, instanceIds.png, labelIds.png, polygons.json,
|
87 |
+
labelTrainIds.png
|
88 |
+
...
|
89 |
+
val/
|
90 |
+
test/
|
91 |
+
# below are generated Cityscapes panoptic annotation
|
92 |
+
cityscapes_panoptic_train.json
|
93 |
+
cityscapes_panoptic_train/
|
94 |
+
cityscapes_panoptic_val.json
|
95 |
+
cityscapes_panoptic_val/
|
96 |
+
cityscapes_panoptic_test.json
|
97 |
+
cityscapes_panoptic_test/
|
98 |
+
leftImg8bit/
|
99 |
+
train/
|
100 |
+
val/
|
101 |
+
test/
|
102 |
+
```
|
103 |
+
Install cityscapes scripts by:
|
104 |
+
```
|
105 |
+
pip install git+https://github.com/mcordts/cityscapesScripts.git
|
106 |
+
```
|
107 |
+
|
108 |
+
Note: to create labelTrainIds.png, first prepare the above structure, then run cityscapesScripts with:
|
109 |
+
```
|
110 |
+
CITYSCAPES_DATASET=/path/to/abovementioned/cityscapes python cityscapesscripts/preparation/createTrainIdLabelImgs.py
|
111 |
+
```
|
112 |
+
These files are not needed for instance segmentation.
|
113 |
+
|
114 |
+
Note: to generate the Cityscapes panoptic dataset, run cityscapesScripts with:
|
115 |
+
```
|
116 |
+
CITYSCAPES_DATASET=/path/to/abovementioned/cityscapes python cityscapesscripts/preparation/createPanopticImgs.py
|
117 |
+
```
|
118 |
+
These files are not needed for semantic and instance segmentation.
|
119 |
+
|
120 |
+
## Expected dataset structure for [Pascal VOC](http://host.robots.ox.ac.uk/pascal/VOC/index.html):
|
121 |
+
```
|
122 |
+
VOC20{07,12}/
|
123 |
+
Annotations/
|
124 |
+
ImageSets/
|
125 |
+
Main/
|
126 |
+
trainval.txt
|
127 |
+
test.txt
|
128 |
+
# train.txt or val.txt, if you use these splits
|
129 |
+
JPEGImages/
|
130 |
+
```
|
131 |
+
|
132 |
+
## Expected dataset structure for [ADE20k Scene Parsing](http://sceneparsing.csail.mit.edu/):
|
133 |
+
```
|
134 |
+
ADEChallengeData2016/
|
135 |
+
annotations/
|
136 |
+
annotations_detectron2/
|
137 |
+
images/
|
138 |
+
objectInfo150.txt
|
139 |
+
```
|
140 |
+
The directory `annotations_detectron2` is generated by running `python datasets/prepare_ade20k_sem_seg.py`.
|
datasets/custom_images/dog_and_cat.jfif
ADDED
Binary file (121 kB). View file
|
|
datasets/prepare_ade20k_sem_seg.py
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
4 |
+
import numpy as np
|
5 |
+
import os
|
6 |
+
from pathlib import Path
|
7 |
+
import tqdm
|
8 |
+
from PIL import Image
|
9 |
+
|
10 |
+
|
11 |
+
def convert(input, output):
    """
    Shift the labels of one annotation image down by one and save the result.

    The subtraction is done in uint8 arithmetic, so label 0 (ignore) wraps
    around to 255 and every other label ``k`` becomes ``k - 1``.

    Args:
        input: source annotation image, anything accepted by ``PIL.Image.open``.
        output: destination, anything accepted by ``PIL.Image.Image.save``.

    Raises:
        AssertionError: if the decoded image is not of dtype uint8.
    """
    # Close the file handle as soon as the pixels are copied into an array;
    # this function is called once per annotation file.
    with Image.open(input) as src:
        img = np.asarray(src)
    # Raised explicitly (instead of a bare `assert`) so the check still runs
    # under `python -O`; the wrap-around below is only correct for uint8.
    if img.dtype != np.uint8:
        raise AssertionError("expected a uint8 image, got dtype {}".format(img.dtype))
    img = img - 1  # 0 (ignore) becomes 255. others are shifted by 1
    Image.fromarray(img).save(output)
|
16 |
+
|
17 |
+
|
18 |
+
if __name__ == "__main__":
    # ADE20k lives under $DETECTRON2_DATASETS (default: ./datasets).
    ade_root = Path(os.getenv("DETECTRON2_DATASETS", "datasets")) / "ADEChallengeData2016"
    for split in ("training", "validation"):
        src_dir = ade_root / "annotations" / split
        dst_dir = ade_root / "annotations_detectron2" / split
        dst_dir.mkdir(parents=True, exist_ok=True)
        # List the directory up front so tqdm knows the total.
        src_files = list(src_dir.iterdir())
        for src_file in tqdm.tqdm(src_files):
            # Each converted file keeps its name, under the mirrored split directory.
            convert(src_file, dst_dir / src_file.name)
|
datasets/prepare_cocofied_lvis.py
ADDED
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
4 |
+
|
5 |
+
import copy
|
6 |
+
import json
|
7 |
+
import os
|
8 |
+
from collections import defaultdict
|
9 |
+
|
10 |
+
# This mapping is extracted from the official LVIS mapping:
|
11 |
+
# https://github.com/lvis-dataset/lvis-api/blob/master/data/coco_to_synset.json
|
12 |
+
COCO_SYNSET_CATEGORIES = [
    {"synset": "person.n.01", "coco_cat_id": 1},
    {"synset": "bicycle.n.01", "coco_cat_id": 2},
    {"synset": "car.n.01", "coco_cat_id": 3},
    {"synset": "motorcycle.n.01", "coco_cat_id": 4},
    {"synset": "airplane.n.01", "coco_cat_id": 5},
    {"synset": "bus.n.01", "coco_cat_id": 6},
    {"synset": "train.n.01", "coco_cat_id": 7},
    {"synset": "truck.n.01", "coco_cat_id": 8},
    {"synset": "boat.n.01", "coco_cat_id": 9},
    {"synset": "traffic_light.n.01", "coco_cat_id": 10},
    {"synset": "fireplug.n.01", "coco_cat_id": 11},
    {"synset": "stop_sign.n.01", "coco_cat_id": 13},
    {"synset": "parking_meter.n.01", "coco_cat_id": 14},
    {"synset": "bench.n.01", "coco_cat_id": 15},
    {"synset": "bird.n.01", "coco_cat_id": 16},
    {"synset": "cat.n.01", "coco_cat_id": 17},
    {"synset": "dog.n.01", "coco_cat_id": 18},
    {"synset": "horse.n.01", "coco_cat_id": 19},
    {"synset": "sheep.n.01", "coco_cat_id": 20},
    {"synset": "beef.n.01", "coco_cat_id": 21},
    {"synset": "elephant.n.01", "coco_cat_id": 22},
    {"synset": "bear.n.01", "coco_cat_id": 23},
    {"synset": "zebra.n.01", "coco_cat_id": 24},
    {"synset": "giraffe.n.01", "coco_cat_id": 25},
    {"synset": "backpack.n.01", "coco_cat_id": 27},
    {"synset": "umbrella.n.01", "coco_cat_id": 28},
    {"synset": "bag.n.04", "coco_cat_id": 31},
    {"synset": "necktie.n.01", "coco_cat_id": 32},
    {"synset": "bag.n.06", "coco_cat_id": 33},
    {"synset": "frisbee.n.01", "coco_cat_id": 34},
    {"synset": "ski.n.01", "coco_cat_id": 35},
    {"synset": "snowboard.n.01", "coco_cat_id": 36},
    {"synset": "ball.n.06", "coco_cat_id": 37},
    {"synset": "kite.n.03", "coco_cat_id": 38},
    {"synset": "baseball_bat.n.01", "coco_cat_id": 39},
    {"synset": "baseball_glove.n.01", "coco_cat_id": 40},
    {"synset": "skateboard.n.01", "coco_cat_id": 41},
    {"synset": "surfboard.n.01", "coco_cat_id": 42},
    {"synset": "tennis_racket.n.01", "coco_cat_id": 43},
    {"synset": "bottle.n.01", "coco_cat_id": 44},
    {"synset": "wineglass.n.01", "coco_cat_id": 46},
    {"synset": "cup.n.01", "coco_cat_id": 47},
    {"synset": "fork.n.01", "coco_cat_id": 48},
    {"synset": "knife.n.01", "coco_cat_id": 49},
    {"synset": "spoon.n.01", "coco_cat_id": 50},
    {"synset": "bowl.n.03", "coco_cat_id": 51},
    {"synset": "banana.n.02", "coco_cat_id": 52},
    {"synset": "apple.n.01", "coco_cat_id": 53},
    {"synset": "sandwich.n.01", "coco_cat_id": 54},
    {"synset": "orange.n.01", "coco_cat_id": 55},
    {"synset": "broccoli.n.01", "coco_cat_id": 56},
    {"synset": "carrot.n.01", "coco_cat_id": 57},
    {"synset": "frank.n.02", "coco_cat_id": 58},
    {"synset": "pizza.n.01", "coco_cat_id": 59},
    {"synset": "doughnut.n.02", "coco_cat_id": 60},
    {"synset": "cake.n.03", "coco_cat_id": 61},
    {"synset": "chair.n.01", "coco_cat_id": 62},
    {"synset": "sofa.n.01", "coco_cat_id": 63},
    {"synset": "pot.n.04", "coco_cat_id": 64},
    {"synset": "bed.n.01", "coco_cat_id": 65},
    {"synset": "dining_table.n.01", "coco_cat_id": 67},
    {"synset": "toilet.n.02", "coco_cat_id": 70},
    {"synset": "television_receiver.n.01", "coco_cat_id": 72},
    {"synset": "laptop.n.01", "coco_cat_id": 73},
    {"synset": "mouse.n.04", "coco_cat_id": 74},
    {"synset": "remote_control.n.01", "coco_cat_id": 75},
    {"synset": "computer_keyboard.n.01", "coco_cat_id": 76},
    {"synset": "cellular_telephone.n.01", "coco_cat_id": 77},
    {"synset": "microwave.n.02", "coco_cat_id": 78},
    {"synset": "oven.n.01", "coco_cat_id": 79},
    {"synset": "toaster.n.02", "coco_cat_id": 80},
    {"synset": "sink.n.01", "coco_cat_id": 81},
    {"synset": "electric_refrigerator.n.01", "coco_cat_id": 82},
    {"synset": "book.n.01", "coco_cat_id": 84},
    {"synset": "clock.n.01", "coco_cat_id": 85},
    {"synset": "vase.n.01", "coco_cat_id": 86},
    {"synset": "scissors.n.01", "coco_cat_id": 87},
    {"synset": "teddy.n.01", "coco_cat_id": 88},
    {"synset": "hand_blower.n.01", "coco_cat_id": 89},
    {"synset": "toothbrush.n.01", "coco_cat_id": 90},
]


def cocofy_lvis(input_filename, output_filename):
    """
    Filter LVIS instance segmentation annotations to remove all categories that are not included in
    COCO. The new json files can be used to evaluate COCO AP using `lvis-api`. The category ids in
    the output json are the non-contiguous COCO dataset ids.

    Kept annotations are renumbered consecutively starting from 1. A category is
    written to the output only if it is referenced by at least one kept annotation
    or by an image's "not_exhaustive_category_ids" / "neg_category_ids" list.

    Args:
        input_filename (str): path to the LVIS json file.
        output_filename (str): path to the COCOfied json file.

    Raises:
        KeyError: if an annotation or image refers to a category id that is not
            listed under "categories" in the input file.
    """

    with open(input_filename, "r") as f:
        lvis_json = json.load(f)

    # `lvis_json` is private to this call, so the output can share its values
    # instead of deep-copying them. "annotations" and "categories" are rebuilt below.
    cocofied_lvis = {key: value for key, value in lvis_json.items() if key != "annotations"}

    # Mapping from lvis cat id to coco cat id via synset
    lvis_cat_id_to_synset = {cat["id"]: cat["synset"] for cat in lvis_json["categories"]}
    synset_to_coco_cat_id = {x["synset"]: x["coco_cat_id"] for x in COCO_SYNSET_CATEGORIES}
    # COCO category ids that are actually referenced somewhere in the output
    coco_cat_ids_in_use = set()

    new_annos = []
    for ann in lvis_json["annotations"]:
        synset = lvis_cat_id_to_synset[ann["category_id"]]
        coco_cat_id = synset_to_coco_cat_id.get(synset)
        if coco_cat_id is None:
            # Category has no COCO counterpart: drop the annotation.
            continue
        # Only top-level fields change, so a shallow copy is enough; ids restart at 1.
        new_annos.append(dict(ann, category_id=coco_cat_id, id=len(new_annos) + 1))
        coco_cat_ids_in_use.add(coco_cat_id)
    cocofied_lvis["annotations"] = new_annos

    for image in cocofied_lvis["images"]:
        for key in ["not_exhaustive_category_ids", "neg_category_ids"]:
            if key not in image:
                # Leave image records without this field untouched rather than failing.
                continue
            new_category_list = []
            for lvis_cat_id in image[key]:
                coco_cat_id = synset_to_coco_cat_id.get(lvis_cat_id_to_synset[lvis_cat_id])
                if coco_cat_id is None:
                    continue
                new_category_list.append(coco_cat_id)
                coco_cat_ids_in_use.add(coco_cat_id)
            image[key] = new_category_list

    new_categories = []
    for cat in lvis_json["categories"]:
        coco_cat_id = synset_to_coco_cat_id.get(cat["synset"])
        if coco_cat_id is None or coco_cat_id not in coco_cat_ids_in_use:
            continue
        new_categories.append(dict(cat, id=coco_cat_id))
    cocofied_lvis["categories"] = new_categories

    with open(output_filename, "w") as f:
        json.dump(cocofied_lvis, f)
    print("{} is COCOfied and stored in {}.".format(input_filename, output_filename))
|
167 |
+
|
168 |
+
|
169 |
+
if __name__ == "__main__":
    # LVIS json files live under $DETECTRON2_DATASETS/lvis (default: ./datasets/lvis).
    lvis_root = os.path.join(os.getenv("DETECTRON2_DATASETS", "datasets"), "lvis")
    for split_name in ("lvis_v0.5_train", "lvis_v0.5_val"):
        print("Start COCOfing {}.".format(split_name))
        # The result is written next to the source file, with a "_cocofied" suffix.
        source_json = os.path.join(lvis_root, "{}.json".format(split_name))
        target_json = os.path.join(lvis_root, "{}_cocofied.json".format(split_name))
        cocofy_lvis(source_json, target_json)
|
datasets/prepare_for_tests.sh
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash -e
# Copyright (c) Facebook, Inc. and its affiliates.

# Download some files needed for running tests.

# Work from the directory that contains this script. `dirname` (unlike
# "${0%/*}") also works when the script is started as `bash prepare_for_tests.sh`,
# i.e. when $0 contains no slash.
cd -- "$(dirname -- "$0")"

BASE=https://dl.fbaipublicfiles.com/detectron2
mkdir -p coco/annotations

for anno in instances_val2017_100 \
  person_keypoints_val2017_100 \
  instances_minival2014_100 \
  person_keypoints_minival2014_100; do

  dest="coco/annotations/$anno.json"
  # A real if/else: with `test && {..} || {..}` the download branch would also
  # run whenever the first branch itself failed.
  # `-s` treats an empty file (e.g. left by an interrupted download) as missing.
  if [[ -s "$dest" ]]; then
    echo "$dest exists. Skipping ..."
  else
    wget "$BASE/annotations/coco/$anno.json" -O "$dest"
  fi
done
|
datasets/prepare_panoptic_fpn.py
ADDED
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
4 |
+
|
5 |
+
import functools
|
6 |
+
import json
|
7 |
+
import multiprocessing as mp
|
8 |
+
import numpy as np
|
9 |
+
import os
|
10 |
+
import time
|
11 |
+
from fvcore.common.download import download
|
12 |
+
from panopticapi.utils import rgb2id
|
13 |
+
from PIL import Image
|
14 |
+
|
15 |
+
from detectron2.data.datasets.builtin_meta import COCO_CATEGORIES
|
16 |
+
|
17 |
+
|
18 |
+
def _process_panoptic_to_semantic(input_panoptic, output_semantic, segments, id_map):
    """
    Convert one panoptic annotation image into a semantic annotation image.

    Pixels not covered by any entry of ``segments`` keep the value 255.

    Args:
        input_panoptic (str): path of the panoptic PNG to read.
        output_semantic (str): path the semantic image is saved to.
        segments (list[dict]): segment records for this image; each needs "id"
            (the value `rgb2id` yields for the segment's pixels) and "category_id".
        id_map (dict): category id -> value to write in the output image.

    Raises:
        KeyError: if a segment's "category_id" is missing from ``id_map``.
    """
    # Close the file as soon as the pixels are copied out; this runs once per
    # image inside pool workers.
    with Image.open(input_panoptic) as panoptic_img:
        panoptic = np.asarray(panoptic_img, dtype=np.uint32)
    panoptic = rgb2id(panoptic)
    output = np.zeros_like(panoptic, dtype=np.uint8) + 255
    for seg in segments:
        output[panoptic == seg["id"]] = id_map[seg["category_id"]]
    Image.fromarray(output).save(output_semantic)
|
27 |
+
|
28 |
+
|
29 |
+
def separate_coco_semantic_from_panoptic(panoptic_json, panoptic_root, sem_seg_root, categories):
    """
    Create semantic segmentation annotations from panoptic segmentation
    annotations, to be used by PanopticFPN.

    It maps all thing categories to class 0, and maps all unlabeled pixels to class 255.
    It maps all stuff categories to contiguous ids starting from 1.

    One output image is written per entry of the json's "annotations" list,
    under the same file name, using a pool of worker processes.

    Args:
        panoptic_json (str): path to the panoptic json file, in COCO's format.
        panoptic_root (str): a directory with panoptic annotation files, in COCO's format.
        sem_seg_root (str): a directory to output semantic annotation files
        categories (list[dict]): category metadata. Each dict needs to have:
            "id": corresponds to the "category_id" in the json annotations
            "isthing": 0 or 1
    """
    os.makedirs(sem_seg_root, exist_ok=True)

    stuff_ids = [k["id"] for k in categories if k["isthing"] == 0]
    thing_ids = [k["id"] for k in categories if k["isthing"] == 1]
    # Stuff classes get ids 1..len(stuff_ids); 0 and 255 are reserved below.
    assert len(stuff_ids) <= 254
    # map from category id to id in the output semantic annotation
    id_map = {stuff_id: i + 1 for i, stuff_id in enumerate(stuff_ids)}
    id_map.update((thing_id, 0) for thing_id in thing_ids)
    id_map[0] = 255

    with open(panoptic_json) as f:
        obj = json.load(f)

    def iter_annotations():
        # Yields (input path, output path, segments) for each annotated image.
        for anno in obj["annotations"]:
            file_name = anno["file_name"]
            yield (
                os.path.join(panoptic_root, file_name),
                os.path.join(sem_seg_root, file_name),
                anno["segments_info"],
            )

    print("Start writing to {} ...".format(sem_seg_root))
    start = time.time()
    # The context manager shuts the workers down on exit, so repeated calls do
    # not accumulate processes. `starmap` blocks until every image is written,
    # so no work is lost when the pool is torn down.
    with mp.Pool(processes=max(mp.cpu_count() // 2, 4)) as pool:
        pool.starmap(
            functools.partial(_process_panoptic_to_semantic, id_map=id_map),
            iter_annotations(),
            chunksize=100,
        )
    print("Finished. time: {:.2f}s".format(time.time() - start))
|
78 |
+
|
79 |
+
|
80 |
+
if __name__ == "__main__":
    # COCO lives under $DETECTRON2_DATASETS/coco (default: ./datasets/coco).
    dataset_dir = os.path.join(os.getenv("DETECTRON2_DATASETS", "datasets"), "coco")
    for s in ["val2017", "train2017"]:
        separate_coco_semantic_from_panoptic(
            os.path.join(dataset_dir, "annotations/panoptic_{}.json".format(s)),
            os.path.join(dataset_dir, "panoptic_{}".format(s)),
            os.path.join(dataset_dir, "panoptic_stuff_{}".format(s)),
            COCO_CATEGORIES,
        )

    # Prepare val2017_100 for quick testing:

    dest_dir = os.path.join(dataset_dir, "annotations/")
    URL_PREFIX = "https://dl.fbaipublicfiles.com/detectron2/"
    download(URL_PREFIX + "annotations/coco/panoptic_val2017_100.json", dest_dir)
    with open(os.path.join(dest_dir, "panoptic_val2017_100.json")) as f:
        obj = json.load(f)

    def link_val100(dir_full, dir_100):
        """
        Populate ``dir_100`` with relative symlinks into ``dir_full``, one
        "<basename>.png" link per image listed in panoptic_val2017_100.json.
        """
        print("Creating " + dir_100 + " ...")
        os.makedirs(dir_100, exist_ok=True)
        for img in obj["images"]:
            basename = os.path.splitext(img["file_name"])[0]
            src = os.path.join(dir_full, basename + ".png")
            dst = os.path.join(dir_100, basename + ".png")
            if os.path.lexists(dst):
                # Already linked by an earlier run; os.symlink would raise
                # FileExistsError and abort the script.
                continue
            # Link relative to dir_100 so the dataset directory can be moved.
            src = os.path.relpath(src, start=dir_100)
            os.symlink(src, dst)

    link_val100(
        os.path.join(dataset_dir, "panoptic_val2017"),
        os.path.join(dataset_dir, "panoptic_val2017_100"),
    )

    link_val100(
        os.path.join(dataset_dir, "panoptic_stuff_val2017"),
        os.path.join(dataset_dir, "panoptic_stuff_val2017_100"),
    )
|
detectron2/__init__.py
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
|
3 |
+
from .utils.env import setup_environment
|
4 |
+
|
5 |
+
setup_environment()
|
6 |
+
|
7 |
+
|
8 |
+
# This line will be programmatically read/written by setup.py.
|
9 |
+
# Leave them at the bottom of this file and don't touch them.
|
10 |
+
__version__ = "0.4"
|
detectron2/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (223 Bytes). View file
|
|
detectron2/checkpoint/__init__.py
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
3 |
+
# File:
|
4 |
+
|
5 |
+
|
6 |
+
from . import catalog as _UNUSED # register the handler
|
7 |
+
from .detection_checkpoint import DetectionCheckpointer
|
8 |
+
from fvcore.common.checkpoint import Checkpointer, PeriodicCheckpointer
|
9 |
+
|
10 |
+
__all__ = ["Checkpointer", "PeriodicCheckpointer", "DetectionCheckpointer"]
|
detectron2/checkpoint/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (391 Bytes). View file
|
|
detectron2/checkpoint/__pycache__/c2_model_loading.cpython-39.pyc
ADDED
Binary file (16.6 kB). View file
|
|
detectron2/checkpoint/__pycache__/catalog.cpython-39.pyc
ADDED
Binary file (4.82 kB). View file
|
|
detectron2/checkpoint/__pycache__/clip_model_loading.cpython-39.pyc
ADDED
Binary file (14.5 kB). View file
|
|
detectron2/checkpoint/__pycache__/detection_checkpoint.cpython-39.pyc
ADDED
Binary file (3.88 kB). View file
|
|
detectron2/checkpoint/c2_model_loading.py
ADDED
@@ -0,0 +1,407 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
import copy
|
3 |
+
import logging
|
4 |
+
import re
|
5 |
+
from typing import Dict, List
|
6 |
+
import torch
|
7 |
+
from tabulate import tabulate
|
8 |
+
|
9 |
+
|
10 |
+
# Exact-name overrides, applied before any other rewriting.
_C2_HARD_CODED_NAMES = {"pred_b": "linear_b", "pred_w": "linear_w"}

# Ordered (regex, replacement) rules, applied after "_" has been turned into ".".
# Order matters: e.g. "\.b$" must produce ".bias" before "bn\.bias$" can match.
_C2_REGEX_RULES = (
    (r"\.b$", ".bias"),
    (r"\.w$", ".weight"),
    # Uniform both bn and gn names to "norm"
    (r"bn\.s$", "norm.weight"),
    (r"bn\.bias$", "norm.bias"),
    (r"bn\.rm", "norm.running_mean"),  # NOTE: this pattern has no "$" anchor, unlike its siblings
    (r"bn\.running.mean$", "norm.running_mean"),
    (r"bn\.riv$", "norm.running_var"),
    (r"bn\.running.var$", "norm.running_var"),
    (r"bn\.gamma$", "norm.weight"),
    (r"bn\.beta$", "norm.bias"),
    (r"gn\.s$", "norm.weight"),
    (r"gn\.bias$", "norm.bias"),
    # stem
    (r"^res\.conv1\.norm\.", "conv1.norm."),
    # to avoid mis-matching with "conv1" in other components (e.g. detection head)
    (r"^conv1\.", "stem.conv1."),
)

# Residual-block branch names. layer1-4 is used by torchvision, however we
# follow the C2 naming strategy (res2-5), so the stage prefixes are left alone.
_C2_BLOCK_RENAMES = (
    (".branch1.", ".shortcut."),
    (".branch2a.", ".conv1."),
    (".branch2b.", ".conv2."),
    (".branch2c.", ".conv3."),
)

# DensePose substitutions, applied after the "^body.conv.fcn" rewrite.
# "Index.UV.lowres" must come before "V.lowres", which is a substring of it.
_C2_DENSEPOSE_RENAMES = (
    ("AnnIndex.lowres", "ann_index_lowres"),
    ("Index.UV.lowres", "index_uv_lowres"),
    ("U.lowres", "u_lowres"),
    ("V.lowres", "v_lowres"),
)


def _convert_basic_c2_name(key):
    """Apply every basic C2 renaming rule, in order, to a single key."""
    key = _C2_HARD_CODED_NAMES.get(key, key)  # some hard-coded mappings
    key = key.replace("_", ".")
    for pattern, replacement in _C2_REGEX_RULES:
        key = re.sub(pattern, replacement, key)
    for old, new in _C2_BLOCK_RENAMES:
        key = key.replace(old, new)
    # The dots here are regex wildcards, not escaped literals.
    key = re.sub(r"^body.conv.fcn", "body_conv_fcn", key)
    for old, new in _C2_DENSEPOSE_RENAMES:
        key = key.replace(old, new)
    return key


def convert_basic_c2_names(original_keys):
    """
    Apply some basic name conversion to names in C2 weights.
    It only deals with typical backbone models.

    Underscores become dots, "b"/"w" suffixes become "bias"/"weight", bn/gn
    parameter names are unified under "norm", the stem conv is prefixed with
    "stem.", and residual branches are renamed to shortcut/conv1/conv2/conv3.
    The input is not modified.

    Args:
        original_keys (list[str]):
    Returns:
        list[str]: The same number of strings matching those in original_keys.
    """
    return [_convert_basic_c2_name(k) for k in original_keys]
|
64 |
+
|
65 |
+
|
66 |
+
def convert_c2_detectron_names(weights):
|
67 |
+
"""
|
68 |
+
Map Caffe2 Detectron weight names to Detectron2 names.
|
69 |
+
|
70 |
+
Args:
|
71 |
+
weights (dict): name -> tensor
|
72 |
+
|
73 |
+
Returns:
|
74 |
+
dict: detectron2 names -> tensor
|
75 |
+
dict: detectron2 names -> C2 names
|
76 |
+
"""
|
77 |
+
logger = logging.getLogger(__name__)
|
78 |
+
logger.info("Renaming Caffe2 weights ......")
|
79 |
+
original_keys = sorted(weights.keys())
|
80 |
+
layer_keys = copy.deepcopy(original_keys)
|
81 |
+
|
82 |
+
layer_keys = convert_basic_c2_names(layer_keys)
|
83 |
+
|
84 |
+
# --------------------------------------------------------------------------
|
85 |
+
# RPN hidden representation conv
|
86 |
+
# --------------------------------------------------------------------------
|
87 |
+
# FPN case
|
88 |
+
# In the C2 model, the RPN hidden layer conv is defined for FPN level 2 and then
|
89 |
+
# shared for all other levels, hence the appearance of "fpn2"
|
90 |
+
layer_keys = [
|
91 |
+
k.replace("conv.rpn.fpn2", "proposal_generator.rpn_head.conv") for k in layer_keys
|
92 |
+
]
|
93 |
+
# Non-FPN case
|
94 |
+
layer_keys = [k.replace("conv.rpn", "proposal_generator.rpn_head.conv") for k in layer_keys]
|
95 |
+
|
96 |
+
# --------------------------------------------------------------------------
|
97 |
+
# RPN box transformation conv
|
98 |
+
# --------------------------------------------------------------------------
|
99 |
+
# FPN case (see note above about "fpn2")
|
100 |
+
layer_keys = [
|
101 |
+
k.replace("rpn.bbox.pred.fpn2", "proposal_generator.rpn_head.anchor_deltas")
|
102 |
+
for k in layer_keys
|
103 |
+
]
|
104 |
+
layer_keys = [
|
105 |
+
k.replace("rpn.cls.logits.fpn2", "proposal_generator.rpn_head.objectness_logits")
|
106 |
+
for k in layer_keys
|
107 |
+
]
|
108 |
+
# Non-FPN case
|
109 |
+
layer_keys = [
|
110 |
+
k.replace("rpn.bbox.pred", "proposal_generator.rpn_head.anchor_deltas") for k in layer_keys
|
111 |
+
]
|
112 |
+
layer_keys = [
|
113 |
+
k.replace("rpn.cls.logits", "proposal_generator.rpn_head.objectness_logits")
|
114 |
+
for k in layer_keys
|
115 |
+
]
|
116 |
+
|
117 |
+
# --------------------------------------------------------------------------
|
118 |
+
# Fast R-CNN box head
|
119 |
+
# --------------------------------------------------------------------------
|
120 |
+
layer_keys = [re.sub("^bbox\\.pred", "bbox_pred", k) for k in layer_keys]
|
121 |
+
layer_keys = [re.sub("^cls\\.score", "cls_score", k) for k in layer_keys]
|
122 |
+
layer_keys = [re.sub("^fc6\\.", "box_head.fc1.", k) for k in layer_keys]
|
123 |
+
layer_keys = [re.sub("^fc7\\.", "box_head.fc2.", k) for k in layer_keys]
|
124 |
+
# 4conv1fc head tensor names: head_conv1_w, head_conv1_gn_s
|
125 |
+
layer_keys = [re.sub("^head\\.conv", "box_head.conv", k) for k in layer_keys]
|
126 |
+
|
127 |
+
# --------------------------------------------------------------------------
|
128 |
+
# FPN lateral and output convolutions
|
129 |
+
# --------------------------------------------------------------------------
|
130 |
+
def fpn_map(name):
|
131 |
+
"""
|
132 |
+
Look for keys with the following patterns:
|
133 |
+
1) Starts with "fpn.inner."
|
134 |
+
Example: "fpn.inner.res2.2.sum.lateral.weight"
|
135 |
+
Meaning: These are lateral pathway convolutions
|
136 |
+
2) Starts with "fpn.res"
|
137 |
+
Example: "fpn.res2.2.sum.weight"
|
138 |
+
Meaning: These are FPN output convolutions
|
139 |
+
"""
|
140 |
+
splits = name.split(".")
|
141 |
+
norm = ".norm" if "norm" in splits else ""
|
142 |
+
if name.startswith("fpn.inner."):
|
143 |
+
# splits example: ['fpn', 'inner', 'res2', '2', 'sum', 'lateral', 'weight']
|
144 |
+
stage = int(splits[2][len("res") :])
|
145 |
+
return "fpn_lateral{}{}.{}".format(stage, norm, splits[-1])
|
146 |
+
elif name.startswith("fpn.res"):
|
147 |
+
# splits example: ['fpn', 'res2', '2', 'sum', 'weight']
|
148 |
+
stage = int(splits[1][len("res") :])
|
149 |
+
return "fpn_output{}{}.{}".format(stage, norm, splits[-1])
|
150 |
+
return name
|
151 |
+
|
152 |
+
layer_keys = [fpn_map(k) for k in layer_keys]
|
153 |
+
|
154 |
+
# --------------------------------------------------------------------------
|
155 |
+
# Mask R-CNN mask head
|
156 |
+
# --------------------------------------------------------------------------
|
157 |
+
# roi_heads.StandardROIHeads case
|
158 |
+
layer_keys = [k.replace(".[mask].fcn", "mask_head.mask_fcn") for k in layer_keys]
|
159 |
+
layer_keys = [re.sub("^\\.mask\\.fcn", "mask_head.mask_fcn", k) for k in layer_keys]
|
160 |
+
layer_keys = [k.replace("mask.fcn.logits", "mask_head.predictor") for k in layer_keys]
|
161 |
+
# roi_heads.Res5ROIHeads case
|
162 |
+
layer_keys = [k.replace("conv5.mask", "mask_head.deconv") for k in layer_keys]
|
163 |
+
|
164 |
+
# --------------------------------------------------------------------------
|
165 |
+
# Keypoint R-CNN head
|
166 |
+
# --------------------------------------------------------------------------
|
167 |
+
# interestingly, the keypoint head convs have blob names that are simply "conv_fcnX"
|
168 |
+
layer_keys = [k.replace("conv.fcn", "roi_heads.keypoint_head.conv_fcn") for k in layer_keys]
|
169 |
+
layer_keys = [
|
170 |
+
k.replace("kps.score.lowres", "roi_heads.keypoint_head.score_lowres") for k in layer_keys
|
171 |
+
]
|
172 |
+
layer_keys = [k.replace("kps.score.", "roi_heads.keypoint_head.score.") for k in layer_keys]
|
173 |
+
|
174 |
+
# --------------------------------------------------------------------------
|
175 |
+
# Done with replacements
|
176 |
+
# --------------------------------------------------------------------------
|
177 |
+
assert len(set(layer_keys)) == len(layer_keys)
|
178 |
+
assert len(original_keys) == len(layer_keys)
|
179 |
+
|
180 |
+
new_weights = {}
|
181 |
+
new_keys_to_original_keys = {}
|
182 |
+
for orig, renamed in zip(original_keys, layer_keys):
|
183 |
+
new_keys_to_original_keys[renamed] = orig
|
184 |
+
if renamed.startswith("bbox_pred.") or renamed.startswith("mask_head.predictor."):
|
185 |
+
# remove the meaningless prediction weight for background class
|
186 |
+
new_start_idx = 4 if renamed.startswith("bbox_pred.") else 1
|
187 |
+
new_weights[renamed] = weights[orig][new_start_idx:]
|
188 |
+
logger.info(
|
189 |
+
"Remove prediction weight for background class in {}. The shape changes from "
|
190 |
+
"{} to {}.".format(
|
191 |
+
renamed, tuple(weights[orig].shape), tuple(new_weights[renamed].shape)
|
192 |
+
)
|
193 |
+
)
|
194 |
+
elif renamed.startswith("cls_score."):
|
195 |
+
# move weights of bg class from original index 0 to last index
|
196 |
+
logger.info(
|
197 |
+
"Move classification weights for background class in {} from index 0 to "
|
198 |
+
"index {}.".format(renamed, weights[orig].shape[0] - 1)
|
199 |
+
)
|
200 |
+
new_weights[renamed] = torch.cat([weights[orig][1:], weights[orig][:1]])
|
201 |
+
else:
|
202 |
+
new_weights[renamed] = weights[orig]
|
203 |
+
|
204 |
+
return new_weights, new_keys_to_original_keys
|
205 |
+
|
206 |
+
|
207 |
+
# Note the current matching is not symmetric.
|
208 |
+
# it assumes model_state_dict will have longer names.
|
209 |
+
def align_and_update_state_dicts(model_state_dict, ckpt_state_dict, c2_conversion=True):
    """
    Match names between the two state-dicts, and return a new checkpoint state-dict with
    names converted to match model_state_dict with heuristics. The returned dict can be
    later loaded with fvcore checkpointer.
    If `c2_conversion==True`, `ckpt_state_dict` is assumed to be a Caffe2
    model and will be renamed at first.

    Strategy: suppose that the models that we will create will have prefixes appended
    to each of its keys, for example due to an extra level of nesting that the original
    pre-trained weights from ImageNet won't contain. For example, model.state_dict()
    might return backbone[0].body.res2.conv1.weight, while the pre-trained model contains
    res2.conv1.weight. We thus want to match both parameters together.
    For that, we look for each model weight, look among all loaded keys if there is one
    that is a suffix of the current weight name, and use it if that's the case.
    If multiple matches exist, take the one with longest size
    of the corresponding name. For example, for the same model as before, the pretrained
    weight file can contain both res2.conv1.weight, as well as conv1.weight. In this case,
    we want to match backbone[0].body.conv1.weight to conv1.weight, and
    backbone[0].body.res2.conv1.weight to res2.conv1.weight.

    Args:
        model_state_dict (dict): model parameter name -> value exposing ``.shape``.
        ckpt_state_dict (dict): checkpoint parameter name -> value exposing ``.shape``.
        c2_conversion (bool): rename the checkpoint keys with
            ``convert_c2_detectron_names`` before matching.

    Returns:
        dict: matched checkpoint values keyed by model names, plus every unmatched
        checkpoint entry under its own name. When nothing matches (including an empty
        checkpoint), the (possibly renamed) checkpoint dict is returned unchanged.

    Raises:
        ValueError: if one checkpoint key is the best match for several model keys.
    """
    logger = logging.getLogger(__name__)
    model_keys = sorted(model_state_dict.keys())
    if c2_conversion:
        ckpt_state_dict, original_keys = convert_c2_detectron_names(ckpt_state_dict)
        # original_keys: the name in the original dict (before renaming)
    else:
        original_keys = {x: x for x in ckpt_state_dict.keys()}
    ckpt_keys = sorted(ckpt_state_dict.keys())

    if not ckpt_keys:
        # A (len(model_keys), 0) match matrix cannot be reduced with max(1);
        # an empty checkpoint simply matches nothing.
        logger.warning("No weights in checkpoint matched with model.")
        return ckpt_state_dict

    def match(a, b):
        # Matched ckpt_key should be a complete (starts with '.') suffix.
        # For example, roi_heads.mesh_head.whatever_conv1 does not match conv1,
        # but matches whatever_conv1 or mesh_head.whatever_conv1.
        return a == b or a.endswith("." + b)

    # get a matrix of string matches, where each (i, j) entry correspond to the size of the
    # ckpt_key string, if it matches
    match_matrix = [len(j) if match(i, j) else 0 for i in model_keys for j in ckpt_keys]
    match_matrix = torch.as_tensor(match_matrix).view(len(model_keys), len(ckpt_keys))
    # use the matched one with longest size in case of multiple matches
    max_match_size, idxs = match_matrix.max(1)
    # remove indices that correspond to no-match
    idxs[max_match_size == 0] = -1

    # matched_pairs (matched checkpoint key --> matched model key)
    matched_keys = {}
    result_state_dict = {}
    for idx_model, idx_ckpt in enumerate(idxs.tolist()):
        if idx_ckpt == -1:
            continue
        key_model = model_keys[idx_model]
        key_ckpt = ckpt_keys[idx_ckpt]
        value_ckpt = ckpt_state_dict[key_ckpt]
        shape_in_model = model_state_dict[key_model].shape

        if shape_in_model != value_ckpt.shape:
            logger.warning(
                "Shape of {} in checkpoint is {}, while shape of {} in model is {}.".format(
                    key_ckpt, value_ckpt.shape, key_model, shape_in_model
                )
            )
            logger.warning(
                "{} will not be loaded. Please double check and see if this is desired.".format(
                    key_ckpt
                )
            )
            continue

        assert key_model not in result_state_dict
        result_state_dict[key_model] = value_ckpt
        if key_ckpt in matched_keys:  # already added to matched_keys
            logger.error(
                "Ambiguity found for {} in checkpoint! "
                "It matches at least two keys in the model ({} and {}).".format(
                    key_ckpt, key_model, matched_keys[key_ckpt]
                )
            )
            raise ValueError("Cannot match one checkpoint key to multiple keys in the model.")

        matched_keys[key_ckpt] = key_model

    # logging:
    matched_model_keys = sorted(matched_keys.values())
    if len(matched_model_keys) == 0:
        logger.warning("No weights in checkpoint matched with model.")
        return ckpt_state_dict
    common_prefix = _longest_common_prefix(matched_model_keys)
    rev_matched_keys = {v: k for k, v in matched_keys.items()}
    original_keys = {k: original_keys[rev_matched_keys[k]] for k in matched_model_keys}

    model_key_groups = _group_keys_by_module(matched_model_keys, original_keys)
    table = []
    memo = set()
    for key_model in matched_model_keys:
        if key_model in memo:
            continue
        if key_model in model_key_groups:
            # Parameters of one submodule are reported on a single table row.
            group = model_key_groups[key_model]
            memo |= set(group)
            shapes = [tuple(model_state_dict[k].shape) for k in group]
            table.append(
                (
                    _longest_common_prefix([k[len(common_prefix) :] for k in group]) + "*",
                    _group_str([original_keys[k] for k in group]),
                    " ".join([str(x).replace(" ", "") for x in shapes]),
                )
            )
        else:
            key_checkpoint = original_keys[key_model]
            shape = str(tuple(model_state_dict[key_model].shape))
            table.append((key_model[len(common_prefix) :], key_checkpoint, shape))
    table_str = tabulate(
        table, tablefmt="pipe", headers=["Names in Model", "Names in Checkpoint", "Shapes"]
    )
    logger.info(
        "Following weights matched with "
        + (f"submodule {common_prefix[:-1]}" if common_prefix else "model")
        + ":\n"
        + table_str
    )

    # Carry over checkpoint entries that matched nothing, under their own names.
    # matched_keys is a dict, so the membership test is already O(1).
    for k in ckpt_keys:
        if k not in matched_keys:
            result_state_dict[k] = ckpt_state_dict[k]
    return result_state_dict
|
335 |
+
|
336 |
+
|
337 |
+
def _group_keys_by_module(keys: List[str], original_names: Dict[str, str]):
|
338 |
+
"""
|
339 |
+
Params in the same submodule are grouped together.
|
340 |
+
|
341 |
+
Args:
|
342 |
+
keys: names of all parameters
|
343 |
+
original_names: mapping from parameter name to their name in the checkpoint
|
344 |
+
|
345 |
+
Returns:
|
346 |
+
dict[name -> all other names in the same group]
|
347 |
+
"""
|
348 |
+
|
349 |
+
def _submodule_name(key):
|
350 |
+
pos = key.rfind(".")
|
351 |
+
if pos < 0:
|
352 |
+
return None
|
353 |
+
prefix = key[: pos + 1]
|
354 |
+
return prefix
|
355 |
+
|
356 |
+
all_submodules = [_submodule_name(k) for k in keys]
|
357 |
+
all_submodules = [x for x in all_submodules if x]
|
358 |
+
all_submodules = sorted(all_submodules, key=len)
|
359 |
+
|
360 |
+
ret = {}
|
361 |
+
for prefix in all_submodules:
|
362 |
+
group = [k for k in keys if k.startswith(prefix)]
|
363 |
+
if len(group) <= 1:
|
364 |
+
continue
|
365 |
+
original_name_lcp = _longest_common_prefix_str([original_names[k] for k in group])
|
366 |
+
if len(original_name_lcp) == 0:
|
367 |
+
# don't group weights if original names don't share prefix
|
368 |
+
continue
|
369 |
+
|
370 |
+
for k in group:
|
371 |
+
if k in ret:
|
372 |
+
continue
|
373 |
+
ret[k] = group
|
374 |
+
return ret
|
375 |
+
|
376 |
+
|
377 |
+
def _longest_common_prefix(names: List[str]) -> str:
|
378 |
+
"""
|
379 |
+
["abc.zfg", "abc.zef"] -> "abc."
|
380 |
+
"""
|
381 |
+
names = [n.split(".") for n in names]
|
382 |
+
m1, m2 = min(names), max(names)
|
383 |
+
ret = [a for a, b in zip(m1, m2) if a == b]
|
384 |
+
ret = ".".join(ret) + "." if len(ret) else ""
|
385 |
+
return ret
|
386 |
+
|
387 |
+
|
388 |
+
def _longest_common_prefix_str(names: List[str]) -> str:
|
389 |
+
m1, m2 = min(names), max(names)
|
390 |
+
lcp = [a for a, b in zip(m1, m2) if a == b]
|
391 |
+
lcp = "".join(lcp)
|
392 |
+
return lcp
|
393 |
+
|
394 |
+
|
395 |
+
def _group_str(names: List[str]) -> str:
    """
    Collapse names sharing a prefix into brace notation:
    "common1", "common2", "common3" becomes "common{1,2,3}".

    Two verbose batch-norm spellings are additionally shortened to "bn_*".
    """
    shared = _longest_common_prefix_str(names)
    suffixes = ",".join(name[len(shared) :] for name in names)
    grouped = shared + "{" + suffixes + "}"

    # add some simplification for BN specifically
    for verbose in (
        "bn_{beta,running_mean,running_var,gamma}",
        "bn_beta,bn_running_mean,bn_running_var,bn_gamma",
    ):
        grouped = grouped.replace(verbose, "bn_*")
    return grouped
|
detectron2/checkpoint/catalog.py
ADDED
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
import logging
|
3 |
+
|
4 |
+
from detectron2.utils.file_io import PathHandler, PathManager
|
5 |
+
|
6 |
+
|
7 |
+
class ModelCatalog(object):
    """
    Store mappings from names to third-party models.

    `get` resolves a catalog name to a download URL under `S3_C2_DETECTRON_PREFIX`.
    """

    S3_C2_DETECTRON_PREFIX = "https://dl.fbaipublicfiles.com/detectron"

    # MSRA models have STRIDE_IN_1X1=True. False otherwise.
    # NOTE: all BN models here have fused BN into an affine layer.
    # As a result, you should only load them to a model with "FrozenBN".
    # Loading them to a model with regular BN or SyncBN is wrong.
    # Even when loaded to FrozenBN, it is still different from affine by an epsilon,
    # which should be negligible for training.
    # NOTE: all models here uses PIXEL_STD=[1,1,1]
    # NOTE: Most of the BN models here are no longer used. We use the
    # re-converted pre-trained models under detectron2 model zoo instead.
    C2_IMAGENET_MODELS = {
        "MSRA/R-50": "ImageNetPretrained/MSRA/R-50.pkl",
        "MSRA/R-101": "ImageNetPretrained/MSRA/R-101.pkl",
        "FAIR/R-50-GN": "ImageNetPretrained/47261647/R-50-GN.pkl",
        "FAIR/R-101-GN": "ImageNetPretrained/47592356/R-101-GN.pkl",
        "FAIR/X-101-32x8d": "ImageNetPretrained/20171220/X-101-32x8d.pkl",
        "FAIR/X-101-64x4d": "ImageNetPretrained/FBResNeXt/X-101-64x4d.pkl",
        "FAIR/X-152-32x8d-IN5k": "ImageNetPretrained/25093814/X-152-32x8d-IN5k.pkl",
    }

    C2_DETECTRON_PATH_FORMAT = (
        "{prefix}/{url}/output/train/{dataset}/{type}/model_final.pkl"  # noqa B950
    )

    C2_DATASET_COCO = "coco_2014_train%3Acoco_2014_valminusminival"
    C2_DATASET_COCO_KEYPOINTS = "keypoints_coco_2014_train%3Akeypoints_coco_2014_valminusminival"

    # format: {model_name} -> part of the url
    C2_DETECTRON_MODELS = {
        "35857197/e2e_faster_rcnn_R-50-C4_1x": "35857197/12_2017_baselines/e2e_faster_rcnn_R-50-C4_1x.yaml.01_33_49.iAX0mXvW",  # noqa B950
        "35857345/e2e_faster_rcnn_R-50-FPN_1x": "35857345/12_2017_baselines/e2e_faster_rcnn_R-50-FPN_1x.yaml.01_36_30.cUF7QR7I",  # noqa B950
        "35857890/e2e_faster_rcnn_R-101-FPN_1x": "35857890/12_2017_baselines/e2e_faster_rcnn_R-101-FPN_1x.yaml.01_38_50.sNxI7sX7",  # noqa B950
        "36761737/e2e_faster_rcnn_X-101-32x8d-FPN_1x": "36761737/12_2017_baselines/e2e_faster_rcnn_X-101-32x8d-FPN_1x.yaml.06_31_39.5MIHi1fZ",  # noqa B950
        "35858791/e2e_mask_rcnn_R-50-C4_1x": "35858791/12_2017_baselines/e2e_mask_rcnn_R-50-C4_1x.yaml.01_45_57.ZgkA7hPB",  # noqa B950
        "35858933/e2e_mask_rcnn_R-50-FPN_1x": "35858933/12_2017_baselines/e2e_mask_rcnn_R-50-FPN_1x.yaml.01_48_14.DzEQe4wC",  # noqa B950
        "35861795/e2e_mask_rcnn_R-101-FPN_1x": "35861795/12_2017_baselines/e2e_mask_rcnn_R-101-FPN_1x.yaml.02_31_37.KqyEK4tT",  # noqa B950
        "36761843/e2e_mask_rcnn_X-101-32x8d-FPN_1x": "36761843/12_2017_baselines/e2e_mask_rcnn_X-101-32x8d-FPN_1x.yaml.06_35_59.RZotkLKI",  # noqa B950
        "48616381/e2e_mask_rcnn_R-50-FPN_2x_gn": "GN/48616381/04_2018_gn_baselines/e2e_mask_rcnn_R-50-FPN_2x_gn_0416.13_23_38.bTlTI97Q",  # noqa B950
        "37697547/e2e_keypoint_rcnn_R-50-FPN_1x": "37697547/12_2017_baselines/e2e_keypoint_rcnn_R-50-FPN_1x.yaml.08_42_54.kdzV35ao",  # noqa B950
        "35998355/rpn_R-50-C4_1x": "35998355/12_2017_baselines/rpn_R-50-C4_1x.yaml.08_00_43.njH5oD9L",  # noqa B950
        "35998814/rpn_R-50-FPN_1x": "35998814/12_2017_baselines/rpn_R-50-FPN_1x.yaml.08_06_03.Axg0r179",  # noqa B950
        "36225147/fast_R-50-FPN_1x": "36225147/12_2017_baselines/fast_rcnn_R-50-FPN_1x.yaml.08_39_09.L3obSdQ2",  # noqa B950
    }

    @staticmethod
    def get(name):
        """
        Resolve a catalog name to a URL.

        Names starting with "Caffe2Detectron/COCO" or "ImageNetPretrained/" are
        dispatched to the matching resolver; any other name raises RuntimeError.
        A recognised prefix followed by an unknown model name raises KeyError
        from the table lookup.
        """
        resolvers = (
            ("Caffe2Detectron/COCO", ModelCatalog._get_c2_detectron_baseline),
            ("ImageNetPretrained/", ModelCatalog._get_c2_imagenet_pretrained),
        )
        for marker, resolve in resolvers:
            if name.startswith(marker):
                return resolve(name)
        raise RuntimeError("model not present in the catalog: {}".format(name))

    @staticmethod
    def _get_c2_imagenet_pretrained(name):
        """Build the URL of an ImageNet-pretrained entry of `C2_IMAGENET_MODELS`."""
        model_key = name[len("ImageNetPretrained/") :]
        relative_path = ModelCatalog.C2_IMAGENET_MODELS[model_key]
        return "/".join([ModelCatalog.S3_C2_DETECTRON_PREFIX, relative_path])

    @staticmethod
    def _get_c2_detectron_baseline(name):
        """Build the URL of a baseline entry of `C2_DETECTRON_MODELS`."""
        model_key = name[len("Caffe2Detectron/COCO/") :]
        url_part = ModelCatalog.C2_DETECTRON_MODELS[model_key]

        dataset = (
            ModelCatalog.C2_DATASET_COCO_KEYPOINTS
            if "keypoint_rcnn" in model_key
            else ModelCatalog.C2_DATASET_COCO
        )
        # this one model is somehow different from others ..
        model_type = "rpn" if "35998355/rpn_R-50-C4_1x" in model_key else "generalized_rcnn"

        # Detectron C2 models are stored in the structure defined in `C2_DETECTRON_PATH_FORMAT`.
        return ModelCatalog.C2_DETECTRON_PATH_FORMAT.format(
            prefix=ModelCatalog.S3_C2_DETECTRON_PREFIX,
            url=url_part,
            type=model_type,
            dataset=dataset,
        )
|
93 |
+
|
94 |
+
|
95 |
+
class ModelCatalogHandler(PathHandler):
    """
    Resolve URL like catalog://.

    The part of the path after the prefix is looked up with `ModelCatalog.get`,
    and the resulting location is handed to `PathManager`.
    """

    PREFIX = "catalog://"

    def _get_supported_prefixes(self):
        return [self.PREFIX]

    def _get_local_path(self, path, **kwargs):
        log = logging.getLogger(__name__)
        entry_name = path[len(self.PREFIX) :]
        resolved = ModelCatalog.get(entry_name)
        log.info("Catalog entry {} points to {}".format(path, resolved))
        return PathManager.get_local_path(resolved, **kwargs)

    def _open(self, path, mode="r", **kwargs):
        # kwargs are forwarded to PathManager.open only, not to _get_local_path.
        local_path = self._get_local_path(path)
        return PathManager.open(local_path, mode, **kwargs)
|
113 |
+
|
114 |
+
|
115 |
+
# Register the catalog handler with PathManager when this module is imported,
# so that paths starting with ModelCatalogHandler.PREFIX ("catalog://") are routed to it.
PathManager.register_handler(ModelCatalogHandler())
|
detectron2/checkpoint/clip_model_loading.py
ADDED
@@ -0,0 +1,415 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
import copy
|
3 |
+
import logging
|
4 |
+
import re
|
5 |
+
from typing import Dict, List
|
6 |
+
import torch
|
7 |
+
from tabulate import tabulate
|
8 |
+
|
9 |
+
|
10 |
+
def convert_basic_clip_names(
    original_keys,
    add_backbone_prefix=False,
    use_whole_clip=False,
    use_fpn_arch=False,
    regionclip=False,
):
    """
    Apply some basic name conversion to names in CLIP weights.
    It only deals with typical backbone models.

    If any key contains "visual.transformer" the weights are treated as a ViT and
    the names are returned untouched. Otherwise (ResNet) every "visual." in a key
    is replaced according to the flags.

    Args:
        original_keys (list[str]): weight names; the input is not modified.
        add_backbone_prefix (bool): selects a backbone prefix (CLIPRCNN or
            CLIPFastRCNN); when False, "visual." is simply removed
            (GeneralizedRCNN or ProposalNetwork).
        use_whole_clip (bool): with add_backbone_prefix, use
            "clip_backbone.visual." (CLIPRCNN).
        use_fpn_arch (bool): with add_backbone_prefix and not use_whole_clip, use
            "backbone.bottom_up." (FPN) instead of "backbone." (C4).
        regionclip (bool): accepted but not used by this function.

    Returns:
        tuple: (list[str] with the same number of names as original_keys,
        bool that is True when the keys were detected as ViT weights).
    """
    layer_keys = copy.deepcopy(original_keys)

    # Every key is inspected (no short-circuit), as in a plain scan over the list.
    vit = any(["visual.transformer" in key for key in layer_keys])
    if vit:
        # vit: names are kept as they are
        return layer_keys, vit

    # load pretrained oai clip (resnet)
    if not add_backbone_prefix:  # GeneralizedRCNN or ProposalNetwork
        visual_prefix = ""
    elif use_whole_clip:  # CLIPRCNN
        visual_prefix = "clip_backbone.visual."
    elif use_fpn_arch:  # CLIPFastRCNN, FPN
        visual_prefix = "backbone.bottom_up."
    else:  # CLIPFastRCNN, C4
        visual_prefix = "backbone."

    return [key.replace("visual.", visual_prefix) for key in layer_keys], vit
|
45 |
+
|
46 |
+
|
47 |
+
def convert_clip_names(
    weights,
    add_backbone_prefix=False,
    use_whole_clip=False,
    use_fpn_arch=False,
    regionclip=False,
):
    """
    Map CLIP Detectron weight names to Detectron2 names.

    Keys are first passed through `convert_basic_clip_names` (all four flags are
    forwarded unchanged) and then through an ordered series of detection-head
    renames. Selected heads also have their tensors adjusted for the
    background class (see the final loop).

    Args:
        weights (dict): name -> tensor

    Returns:
        dict: detectron2 names -> tensor
        dict: detectron2 names -> original names
        bool: the second value returned by `convert_basic_clip_names`
    """
    logger = logging.getLogger(__name__)
    logger.info("Renaming CLIP weights ......")
    original_keys = sorted(weights.keys())
    layer_keys, use_vit = convert_basic_clip_names(
        copy.deepcopy(original_keys), add_backbone_prefix, use_whole_clip, use_fpn_arch, regionclip
    )

    def replace_all(keys, old, new):
        # Plain substring replacement applied to every key.
        return [key.replace(old, new) for key in keys]

    def sub_all(keys, pattern, new):
        # Regex substitution applied to every key.
        return [re.sub(pattern, new, key) for key in keys]

    # --------------------------------------------------------------------------
    # RPN hidden representation conv and box transformation conv
    # --------------------------------------------------------------------------
    # In the C2 model, the RPN hidden layer conv is defined for FPN level 2 and then
    # shared for all other levels, hence the appearance of "fpn2". Order matters:
    # each "fpn2" (FPN) form is handled before its shorter non-FPN form.
    rpn_renames = (
        ("conv.rpn.fpn2", "proposal_generator.rpn_head.conv"),
        ("conv.rpn", "proposal_generator.rpn_head.conv"),
        ("rpn.bbox.pred.fpn2", "proposal_generator.rpn_head.anchor_deltas"),
        ("rpn.cls.logits.fpn2", "proposal_generator.rpn_head.objectness_logits"),
        ("rpn.bbox.pred", "proposal_generator.rpn_head.anchor_deltas"),
        ("rpn.cls.logits", "proposal_generator.rpn_head.objectness_logits"),
    )
    for old, new in rpn_renames:
        layer_keys = replace_all(layer_keys, old, new)

    # --------------------------------------------------------------------------
    # Fast R-CNN box head
    # --------------------------------------------------------------------------
    box_head_patterns = (
        ("^bbox\\.pred", "bbox_pred"),
        ("^cls\\.score", "cls_score"),
        ("^fc6\\.", "box_head.fc1."),
        ("^fc7\\.", "box_head.fc2."),
        # 4conv1fc head tensor names: head_conv1_w, head_conv1_gn_s
        ("^head\\.conv", "box_head.conv"),
    )
    for pattern, new in box_head_patterns:
        layer_keys = sub_all(layer_keys, pattern, new)

    # --------------------------------------------------------------------------
    # FPN lateral and output convolutions
    # --------------------------------------------------------------------------
    def fpn_map(name):
        """
        Rename FPN keys; any other name is returned unchanged.

        1) Starts with "fpn.inner." (lateral pathway convolutions)
            Example: "fpn.inner.res2.2.sum.lateral.weight"
        2) Starts with "fpn.res" (FPN output convolutions)
            Example: "fpn.res2.2.sum.weight"
        """
        parts = name.split(".")
        norm = ".norm" if "norm" in parts else ""
        if name.startswith("fpn.inner."):
            # parts example: ['fpn', 'inner', 'res2', '2', 'sum', 'lateral', 'weight']
            template, stage_token = "fpn_lateral{}{}.{}", parts[2]
        elif name.startswith("fpn.res"):
            # parts example: ['fpn', 'res2', '2', 'sum', 'weight']
            template, stage_token = "fpn_output{}{}.{}", parts[1]
        else:
            return name
        stage = int(stage_token[len("res") :])
        return template.format(stage, norm, parts[-1])

    layer_keys = [fpn_map(key) for key in layer_keys]

    # --------------------------------------------------------------------------
    # Mask R-CNN mask head
    # --------------------------------------------------------------------------
    # roi_heads.StandardROIHeads case
    layer_keys = replace_all(layer_keys, ".[mask].fcn", "mask_head.mask_fcn")
    layer_keys = sub_all(layer_keys, "^\\.mask\\.fcn", "mask_head.mask_fcn")
    layer_keys = replace_all(layer_keys, "mask.fcn.logits", "mask_head.predictor")
    # roi_heads.Res5ROIHeads case
    layer_keys = replace_all(layer_keys, "conv5.mask", "mask_head.deconv")

    # --------------------------------------------------------------------------
    # Keypoint R-CNN head
    # --------------------------------------------------------------------------
    # interestingly, the keypoint head convs have blob names that are simply "conv_fcnX"
    keypoint_renames = (
        ("conv.fcn", "roi_heads.keypoint_head.conv_fcn"),
        ("kps.score.lowres", "roi_heads.keypoint_head.score_lowres"),
        ("kps.score.", "roi_heads.keypoint_head.score."),
    )
    for old, new in keypoint_renames:
        layer_keys = replace_all(layer_keys, old, new)

    # --------------------------------------------------------------------------
    # Done with replacements
    # --------------------------------------------------------------------------
    # Renaming must neither merge two keys into one nor change their count.
    assert len(set(layer_keys)) == len(layer_keys)
    assert len(original_keys) == len(layer_keys)

    new_weights = {}
    new_keys_to_original_keys = {}
    for orig, renamed in zip(original_keys, layer_keys):
        new_keys_to_original_keys[renamed] = orig
        is_bbox_pred = renamed.startswith("bbox_pred.")
        if is_bbox_pred or renamed.startswith("mask_head.predictor."):
            # remove the meaningless prediction weight for background class
            # (first 4 rows for bbox_pred, first row for the mask predictor)
            new_start_idx = 4 if is_bbox_pred else 1
            new_weights[renamed] = weights[orig][new_start_idx:]
            logger.info(
                "Remove prediction weight for background class in {}. The shape changes from "
                "{} to {}.".format(
                    renamed, tuple(weights[orig].shape), tuple(new_weights[renamed].shape)
                )
            )
        elif renamed.startswith("cls_score."):
            # move weights of bg class from original index 0 to last index
            logger.info(
                "Move classification weights for background class in {} from index 0 to "
                "index {}.".format(renamed, weights[orig].shape[0] - 1)
            )
            new_weights[renamed] = torch.cat([weights[orig][1:], weights[orig][:1]])
        else:
            new_weights[renamed] = weights[orig]

    return new_weights, new_keys_to_original_keys, use_vit
|
186 |
+
|
187 |
+
|
188 |
+
# Note the current matching is not symmetric.
|
189 |
+
# it assumes model_state_dict will have longer names.
|
190 |
+
def align_and_update_state_dicts_for_CLIP(model_state_dict, ckpt_state_dict, c2_conversion=True, bb_rpn_weights=False, regionclip=False):
    """
    Extended from ./c2_model_loading.py
    Match names between the two state-dict, and returns a new chkpt_state_dict with names
    converted to match model_state_dict with heuristics. The returned dict can be later
    loaded with fvcore checkpointer.

    Strategy: suppose that the models that we will create will have prefixes appended
    to each of its keys, for example due to an extra level of nesting that the original
    pre-trained weights from ImageNet won't contain. For example, model.state_dict()
    might return backbone[0].body.res2.conv1.weight, while the pre-trained model contains
    res2.conv1.weight. We thus want to match both parameters together.
    For that, we look for each model weight, look among all loaded keys if there is one
    that is a suffix of the current weight name, and use it if that's the case.
    If multiple matches exist, take the one with longest size
    of the corresponding name. For example, for the same model as before, the pretrained
    weight file can contain both res2.conv1.weight, as well as conv1.weight. In this case,
    we want to match backbone[0].body.conv1.weight to conv1.weight, and
    backbone[0].body.res2.conv1.weight to res2.conv1.weight.

    Args:
        model_state_dict: mapping from model parameter name to a value with a ``.shape``.
        ckpt_state_dict: mapping from checkpoint parameter name to its value.
        c2_conversion: kept for signature compatibility with the function in
            c2_model_loading.py; this function does not read it.
        bb_rpn_weights: if True, the checkpoint is treated as a 2nd set of weights for
            the offline backbone & RPN: only keys containing 'backbone' or
            'proposal_generator' are kept (renamed to 'offline_backbone' /
            'offline_proposal_generator'), plus an 'ignore_others' marker entry.
        regionclip: forwarded unchanged to `convert_clip_names`.

    Returns:
        dict: matched checkpoint values stored under the model's key names, plus every
        unmatched checkpoint entry under its converted name. If nothing matched, the
        converted checkpoint dict is returned as-is.

    Raises:
        ValueError: if one checkpoint key is the best match of more than one model key.
    """
    model_keys = sorted(model_state_dict.keys())
    use_whole_clip = False  # whether use the whole clip (text & visual encoders), typically in CLIPRCNN meta arch
    add_backbone_prefix = False  # convert to 'backbone.' prefix, typically in CLIPFastRCNN meta arch
    use_fpn_arch = False  # if use FPN arch then convert to `bottom_up`, typically in CLIPFastRCNN meta arch with FPN backbone
    if bb_rpn_weights:  # a 2nd pretrained weights to load, for offline backbone & RPN, then convert the ckpt key names and only keep the ones we need
        new_ckpt_state_dict = {}
        for original_k in ckpt_state_dict:
            if 'backbone' in original_k:
                new_key = original_k.replace('backbone', 'offline_backbone')
                new_ckpt_state_dict[new_key] = ckpt_state_dict[original_k]
            if 'proposal_generator' in original_k:
                new_key = original_k.replace('proposal_generator', 'offline_proposal_generator')
                new_ckpt_state_dict[new_key] = ckpt_state_dict[original_k]
        new_ckpt_state_dict['ignore_others'] = torch.tensor([1])  # ignore other model weights (not 'offline_*') in batch_norm.py
        ckpt_state_dict = new_ckpt_state_dict
    else:  # the 1st pretrained weights to load
        # if use the whole clip, then convert ckpt 'visual.' names to 'clip_backbone.visual.'
        use_whole_clip = any('clip_backbone' in model_key for model_key in model_keys)
        # if there are backbone & offline_backbone, then convert the ckpt 'visual.' names to 'backbone.' to avoid ambiguity
        add_backbone_prefix = any('offline_backbone' in model_key for model_key in model_keys)
        use_fpn_arch = any('fpn' in model_key for model_key in model_keys)
    # original_keys: the name in the original dict (before renaming).
    # This runs for both branches above: the code below needs original_keys and use_vit.
    ckpt_state_dict, original_keys, use_vit = convert_clip_names(ckpt_state_dict, add_backbone_prefix, use_whole_clip, use_fpn_arch, regionclip)
    ckpt_keys = sorted(ckpt_state_dict.keys())

    def match(a, b):
        # Matched ckpt_key should be a complete (starts with '.') suffix.
        # For example, roi_heads.mesh_head.whatever_conv1 does not match conv1,
        # but matches whatever_conv1 or mesh_head.whatever_conv1.
        return a == b or a.endswith("." + b)

    # get a matrix of string matches, where each (i, j) entry correspond to the size of the
    # ckpt_key string, if it matches
    match_matrix = [len(j) if match(i, j) else 0 for i in model_keys for j in ckpt_keys]
    match_matrix = torch.as_tensor(match_matrix).view(len(model_keys), len(ckpt_keys))
    # use the matched one with longest size in case of multiple matches
    max_match_size, idxs = match_matrix.max(1)
    # remove indices that correspond to no-match
    idxs[max_match_size == 0] = -1

    logger = logging.getLogger(__name__)
    # matched_pairs (matched checkpoint key --> matched model key)
    matched_keys = {}
    result_state_dict = {}
    for idx_model, idx_ckpt in enumerate(idxs.tolist()):
        if idx_ckpt == -1:
            continue
        key_model = model_keys[idx_model]
        key_ckpt = ckpt_keys[idx_ckpt]
        value_ckpt = ckpt_state_dict[key_ckpt]
        shape_in_model = model_state_dict[key_model].shape

        if shape_in_model != value_ckpt.shape:
            logger.warning(
                "Shape of {} in checkpoint is {}, while shape of {} in model is {}.".format(
                    key_ckpt, value_ckpt.shape, key_model, shape_in_model
                )
            )
            logger.warning(
                "{} will not be loaded. Please double check and see if this is desired.".format(
                    key_ckpt
                )
            )
            continue

        assert key_model not in result_state_dict
        result_state_dict[key_model] = value_ckpt
        if key_ckpt in matched_keys:  # already added to matched_keys
            logger.error(
                "Ambiguity found for {} in checkpoint! "
                "It matches at least two keys in the model ({} and {}).".format(
                    key_ckpt, key_model, matched_keys[key_ckpt]
                )
            )
            raise ValueError("Cannot match one checkpoint key to multiple keys in the model.")

        matched_keys[key_ckpt] = key_model

    # logging:
    matched_model_keys = sorted(matched_keys.values())
    mmk_list = "The following model parameters are loaded from checkpoints:\n" + "".join(
        mmk + "\n" for mmk in matched_model_keys
    )
    if len(matched_model_keys) == 0:
        logger.warning("No weights in checkpoint matched with model.")
        return ckpt_state_dict
    common_prefix = _longest_common_prefix(matched_model_keys)
    rev_matched_keys = {v: k for k, v in matched_keys.items()}
    original_keys = {k: original_keys[rev_matched_keys[k]] for k in matched_model_keys}

    model_key_groups = _group_keys_by_module(matched_model_keys, original_keys)
    table = []
    memo = set()
    for key_model in matched_model_keys:
        if key_model in memo:
            continue
        if key_model in model_key_groups:
            group = model_key_groups[key_model]
            memo |= set(group)
            shapes = [tuple(model_state_dict[k].shape) for k in group]
            table.append(
                (
                    _longest_common_prefix([k[len(common_prefix) :] for k in group]) + "*",
                    _group_str([original_keys[k] for k in group]),
                    " ".join([str(x).replace(" ", "") for x in shapes]),
                )
            )
        else:
            key_checkpoint = original_keys[key_model]
            shape = str(tuple(model_state_dict[key_model].shape))
            table.append((key_model[len(common_prefix) :], key_checkpoint, shape))
    table_str = tabulate(
        table, tablefmt="pipe", headers=["Names in Model", "Names in Checkpoint", "Shapes"]
    )
    if len(table) != 1 and not use_vit:  # do this for now; the table function has some bugs when the whole CLIP is loaded
        logger.info(
            "Following weights matched with "
            + (f"submodule {common_prefix[:-1]}" if common_prefix else "model")
            + ":\n"
            + table_str
        )
    else:
        logger.info(mmk_list)

    # Carry over checkpoint entries that matched nothing, under their converted names.
    # matched_keys is a dict, so the membership test is O(1) per key.
    for k in ckpt_keys:
        if k not in matched_keys:
            result_state_dict[k] = ckpt_state_dict[k]
    return result_state_dict
|
343 |
+
|
344 |
+
|
345 |
+
def _group_keys_by_module(keys: List[str], original_names: Dict[str, str]):
|
346 |
+
"""
|
347 |
+
Params in the same submodule are grouped together.
|
348 |
+
|
349 |
+
Args:
|
350 |
+
keys: names of all parameters
|
351 |
+
original_names: mapping from parameter name to their name in the checkpoint
|
352 |
+
|
353 |
+
Returns:
|
354 |
+
dict[name -> all other names in the same group]
|
355 |
+
"""
|
356 |
+
|
357 |
+
def _submodule_name(key):
|
358 |
+
pos = key.rfind(".")
|
359 |
+
if pos < 0:
|
360 |
+
return None
|
361 |
+
prefix = key[: pos + 1]
|
362 |
+
return prefix
|
363 |
+
|
364 |
+
all_submodules = [_submodule_name(k) for k in keys]
|
365 |
+
all_submodules = [x for x in all_submodules if x]
|
366 |
+
all_submodules = sorted(all_submodules, key=len)
|
367 |
+
|
368 |
+
ret = {}
|
369 |
+
for prefix in all_submodules:
|
370 |
+
group = [k for k in keys if k.startswith(prefix)]
|
371 |
+
if len(group) <= 1:
|
372 |
+
continue
|
373 |
+
original_name_lcp = _longest_common_prefix_str([original_names[k] for k in group])
|
374 |
+
if len(original_name_lcp) == 0:
|
375 |
+
# don't group weights if original names don't share prefix
|
376 |
+
continue
|
377 |
+
|
378 |
+
for k in group:
|
379 |
+
if k in ret:
|
380 |
+
continue
|
381 |
+
ret[k] = group
|
382 |
+
return ret
|
383 |
+
|
384 |
+
|
385 |
+
def _longest_common_prefix(names: List[str]) -> str:
|
386 |
+
"""
|
387 |
+
["abc.zfg", "abc.zef"] -> "abc."
|
388 |
+
"""
|
389 |
+
names = [n.split(".") for n in names]
|
390 |
+
m1, m2 = min(names), max(names)
|
391 |
+
ret = [a for a, b in zip(m1, m2) if a == b]
|
392 |
+
ret = ".".join(ret) + "." if len(ret) else ""
|
393 |
+
return ret
|
394 |
+
|
395 |
+
|
396 |
+
def _longest_common_prefix_str(names: List[str]) -> str:
|
397 |
+
m1, m2 = min(names), max(names)
|
398 |
+
lcp = [a for a, b in zip(m1, m2) if a == b]
|
399 |
+
lcp = "".join(lcp)
|
400 |
+
return lcp
|
401 |
+
|
402 |
+
|
403 |
+
def _group_str(names: List[str]) -> str:
    """
    Collapse names sharing a prefix into brace notation:
    "common1", "common2", "common3" becomes "common{1,2,3}".
    """
    shared = _longest_common_prefix_str(names)
    skip = len(shared)
    grouped = shared + "{" + ",".join(name[skip:] for name in names) + "}"

    # add some simplification for BN specifically
    for verbose_bn in (
        "bn_{beta,running_mean,running_var,gamma}",
        "bn_beta,bn_running_mean,bn_running_var,bn_gamma",
    ):
        grouped = grouped.replace(verbose_bn, "bn_*")
    return grouped
|
detectron2/checkpoint/detection_checkpoint.py
ADDED
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
import logging
|
3 |
+
import os
|
4 |
+
import pickle
|
5 |
+
import torch
|
6 |
+
from fvcore.common.checkpoint import Checkpointer
|
7 |
+
from torch.nn.parallel import DistributedDataParallel
|
8 |
+
|
9 |
+
import detectron2.utils.comm as comm
|
10 |
+
from detectron2.utils.env import TORCH_VERSION
|
11 |
+
from detectron2.utils.file_io import PathManager
|
12 |
+
|
13 |
+
from .c2_model_loading import align_and_update_state_dicts
|
14 |
+
from .clip_model_loading import align_and_update_state_dicts_for_CLIP
|
15 |
+
|
16 |
+
class DetectionCheckpointer(Checkpointer):
    """
    Same as :class:`Checkpointer`, but is able to:
    1. handle models in detectron & detectron2 model zoo, and apply conversions for legacy models.
    2. correctly load checkpoints that are only available on the master worker

    NOTE(review): the name-matching conversion in `_load_model` is currently commented
    out, so checkpoints are handed to the parent `_load_model` without renaming.
    """

    def __init__(self, model, save_dir="", *, save_to_disk=None, bb_rpn_weights=False, **checkpointables):
        """
        Args:
            model: the model passed through to :class:`Checkpointer`.
            save_dir: passed through to :class:`Checkpointer`.
            save_to_disk: when None, defaults to `comm.is_main_process()`.
            bb_rpn_weights: stored on the instance; only referenced by the
                commented-out conversion code in `_load_model`.
            checkpointables: passed through to :class:`Checkpointer`.
        """
        is_main_process = comm.is_main_process()
        super().__init__(
            model,
            save_dir,
            save_to_disk=is_main_process if save_to_disk is None else save_to_disk,
            **checkpointables,
        )
        self.path_manager = PathManager
        self.bb_rpn_weights = bb_rpn_weights

    def load(self, path, *args, **kwargs):
        """
        Load a checkpoint via the parent `load`. For a DistributedDataParallel model,
        first check on every worker whether the file is readable.

        Raises:
            OSError: if the first entry of the gathered results (reported below as
                the main worker) does not have the file.
        """
        need_sync = False

        if path and isinstance(self.model, DistributedDataParallel):
            # `logger` is only bound in this branch; it is only used again when
            # need_sync is True, which can only be set in this branch.
            logger = logging.getLogger(__name__)
            path = self.path_manager.get_local_path(path)
            has_file = os.path.isfile(path)
            all_has_file = comm.all_gather(has_file)
            if not all_has_file[0]:
                raise OSError(f"File {path} not found on main worker.")
            if not all(all_has_file):
                logger.warning(
                    f"Not all workers can read checkpoint {path}. "
                    "Training may fail to fully resume."
                )
                # TODO: broadcast the checkpoint file contents from main
                # worker, and load from it instead.
                need_sync = True
            if not has_file:
                path = None  # don't load if not readable
        ret = super().load(path, *args, **kwargs)

        if need_sync:
            logger.info("Broadcasting model states from main worker ...")
            if TORCH_VERSION >= (1, 7):
                # Private DistributedDataParallel method, called only for torch >= 1.7.
                self.model._sync_params_and_buffers()
        return ret

    def _load_file(self, filename):
        """
        Read a checkpoint file and return a dict that has at least a "model" entry.

        The format is chosen from the file name: ".pkl" (Detectron2 model zoo or
        Caffe2 / Detectron1), ".pyth" (pycls), names containing "OAI_CLIP", and
        otherwise whatever the parent `_load_file` reads.
        """
        if filename.endswith(".pkl"):
            with PathManager.open(filename, "rb") as f:
                # NOTE(security): unpickling can execute arbitrary code from the file;
                # only load checkpoints from trusted sources.
                data = pickle.load(f, encoding="latin1")
            if "model" in data and "__author__" in data:
                # file is in Detectron2 model zoo format
                self.logger.info("Reading a file from '{}'".format(data["__author__"]))
                return data
            else:
                # assume file is from Caffe2 / Detectron1 model zoo
                if "blobs" in data:
                    # Detection models have "blobs", but ImageNet models don't
                    data = data["blobs"]
                # drop entries whose name ends with "_momentum"
                data = {k: v for k, v in data.items() if not k.endswith("_momentum")}
                return {"model": data, "__author__": "Caffe2", "matching_heuristics": True}
        elif filename.endswith(".pyth"):
            # assume file is from pycls; no one else seems to use the ".pyth" extension
            with PathManager.open(filename, "rb") as f:
                # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
                data = torch.load(f)
            assert (
                "model_state" in data
            ), f"Cannot load .pyth file {filename}; pycls checkpoints must contain 'model_state'."
            model_state = {
                k: v
                for k, v in data["model_state"].items()
                if not k.endswith("num_batches_tracked")
            }
            return {"model": model_state, "__author__": "pycls", "matching_heuristics": True}
        elif "OAI_CLIP" in filename:
            # assume file is from OpenAI CLIP pre-trained model
            loaded = super()._load_file(filename)  # load native pth checkpoint
            if "model" not in loaded:
                loaded = {"model": loaded}
            return {"model": loaded["model"], "__author__": "OAI_CLIP", "matching_heuristics": True}

        loaded = super()._load_file(filename)  # load native pth checkpoint
        if "model" not in loaded:
            # a bare state dict: wrap it so callers can always read loaded["model"]
            loaded = {"model": loaded}
        return loaded

    def _load_model(self, checkpoint):
        """
        Load `checkpoint` through the parent `_load_model` and return its
        incompatible-keys result, with "pixel_mean" / "pixel_std" removed from the
        missing keys when the model has buffers of those names.
        """
        # NOTE(review): the name-matching heuristics below are disabled, so the
        # "matching_heuristics" / "__author__" fields set by `_load_file` are not
        # acted on here. Confirm this is intended before re-enabling.
        # if checkpoint.get("matching_heuristics", False) or self.bb_rpn_weights:
        #     self._convert_ndarray_to_tensor(checkpoint["model"])
        #     # convert weights by name-matching heuristics
        #     if checkpoint.get("__author__", "NA") == "OAI_CLIP" or self.bb_rpn_weights: # for OAI_CLIP or 2nd ckpt (offline modules)
        #         checkpoint["model"] = align_and_update_state_dicts_for_CLIP(
        #             self.model.state_dict(),
        #             checkpoint["model"],
        #             bb_rpn_weights=self.bb_rpn_weights,
        #         )
        #     else: # default loading
        #         checkpoint["model"] = align_and_update_state_dicts(
        #             self.model.state_dict(),
        #             checkpoint["model"],
        #             c2_conversion=checkpoint.get("__author__", None) == "Caffe2",
        #         )
        # for non-caffe2 models, use standard ways to load it
        # if not self.bb_rpn_weights:
        #     checkpoint = {'model': {'backbone.' + key: val for key, val in checkpoint['model'].items()}}
        incompatible = super()._load_model(checkpoint)
        del checkpoint  # try saving memory

        # only buffers registered directly on the top-level model (recurse=False)
        model_buffers = dict(self.model.named_buffers(recurse=False))
        for k in ["pixel_mean", "pixel_std"]:
            # Ignore missing key message about pixel_mean/std.
            # Though they may be missing in old checkpoints, they will be correctly
            # initialized from config anyway.
            if k in model_buffers:
                try:
                    incompatible.missing_keys.remove(k)
                except ValueError:
                    # the key was not reported as missing; nothing to remove
                    pass
        return incompatible
|
detectron2/config/__init__.py
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
from .compat import downgrade_config, upgrade_config
|
3 |
+
from .config import CfgNode, get_cfg, global_cfg, set_global_cfg, configurable
|
4 |
+
from .instantiate import instantiate
|
5 |
+
from .lazy import LazyCall, LazyConfig
|
6 |
+
|
7 |
+
# Names exported by `from <this package> import *`; all are imported above.
__all__ = [
    "CfgNode",
    "get_cfg",
    "global_cfg",
    "set_global_cfg",
    "downgrade_config",
    "upgrade_config",
    "configurable",
    "instantiate",
    "LazyCall",
    "LazyConfig",
]


# NOTE(review): fixup_module_metadata is defined in detectron2.utils.env (not visible
# here); presumably it adjusts metadata of the exported names -- confirm there.
from detectron2.utils.env import fixup_module_metadata

fixup_module_metadata(__name__, globals(), __all__)
# Remove the helper so it does not remain as an attribute of this package.
del fixup_module_metadata
|
detectron2/config/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (617 Bytes). View file
|
|
detectron2/config/__pycache__/compat.cpython-39.pyc
ADDED
Binary file (7.7 kB). View file
|
|
detectron2/config/__pycache__/config.cpython-39.pyc
ADDED
Binary file (7.57 kB). View file
|
|
detectron2/config/__pycache__/defaults.cpython-39.pyc
ADDED
Binary file (9.22 kB). View file
|
|
detectron2/config/__pycache__/instantiate.cpython-39.pyc
ADDED
Binary file (2.58 kB). View file
|
|
detectron2/config/__pycache__/lazy.cpython-39.pyc
ADDED
Binary file (11.5 kB). View file
|
|
detectron2/config/compat.py
ADDED
@@ -0,0 +1,229 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
"""
|
3 |
+
Backward compatibility of configs.
|
4 |
+
|
5 |
+
Instructions to bump version:
|
6 |
+
+ It's not needed to bump version if new keys are added.
|
7 |
+
It's only needed when backward-incompatible changes happen
|
8 |
+
(i.e., some existing keys disappear, or the meaning of a key changes)
|
9 |
+
+ To bump version, do the following:
|
10 |
+
1. Increment _C.VERSION in defaults.py
|
11 |
+
2. Add a converter in this file.
|
12 |
+
|
13 |
+
Each ConverterVX has a function "upgrade" which in-place upgrades config from X-1 to X,
|
14 |
+
and a function "downgrade" which in-place downgrades config from X to X-1
|
15 |
+
|
16 |
+
In each function, VERSION is left unchanged.
|
17 |
+
|
18 |
+
Each converter assumes that its input has the relevant keys
|
19 |
+
(i.e., the input is not a partial config).
|
20 |
+
3. Run the tests (test_config.py) to make sure the upgrade & downgrade
|
21 |
+
functions are consistent.
|
22 |
+
"""
|
23 |
+
|
24 |
+
import logging
|
25 |
+
from typing import List, Optional, Tuple
|
26 |
+
|
27 |
+
from .config import CfgNode as CN
|
28 |
+
from .defaults import _C
|
29 |
+
|
30 |
+
__all__ = ["upgrade_config", "downgrade_config"]
|
31 |
+
|
32 |
+
|
33 |
+
def upgrade_config(cfg: CN, to_version: Optional[int] = None) -> CN:
    """
    Return a copy of `cfg` upgraded from its current version to a newer one.

    Args:
        cfg (CfgNode): the config to upgrade; it is cloned, not modified.
        to_version (int): defaults to the latest version.
    """
    upgraded = cfg.clone()
    target = _C.VERSION if to_version is None else to_version

    assert upgraded.VERSION <= target, "Cannot upgrade from v{} to v{}!".format(
        upgraded.VERSION, target
    )
    # Apply ConverterV<n>.upgrade for each version step, one version at a time.
    for next_version in range(upgraded.VERSION + 1, target + 1):
        globals()["ConverterV" + str(next_version)].upgrade(upgraded)
        upgraded.VERSION = next_version
    return upgraded
|
53 |
+
|
54 |
+
|
55 |
+
def downgrade_config(cfg: CN, to_version: int) -> CN:
    """
    Return a copy of `cfg` downgraded from its current version to an older one.

    Args:
        cfg (CfgNode): the config to downgrade; it is cloned, not modified.
        to_version (int): the version to stop at.

    Note:
        A general downgrade of arbitrary configs is not always possible due to the
        different functionalities in different versions.
        The purpose of downgrade is only to recover the defaults in old versions,
        allowing it to load an old partial yaml config.
        Therefore, the implementation only needs to fill in the default values
        in the old version when a general downgrade is not possible.
    """
    downgraded = cfg.clone()
    assert downgraded.VERSION >= to_version, "Cannot downgrade from v{} to v{}!".format(
        downgraded.VERSION, to_version
    )
    # Apply ConverterV<n>.downgrade for each version step, newest first.
    current = downgraded.VERSION
    while current > to_version:
        globals()["ConverterV" + str(current)].downgrade(downgraded)
        current -= 1
        downgraded.VERSION = current
    return downgraded
|
80 |
+
|
81 |
+
|
82 |
+
def guess_version(cfg: CN, filename: str) -> int:
    """
    Guess the version of a partial config that does not specify VERSION.

    Returns 1 when the config contains "MODEL.WEIGHT" or "TEST.AUG_ON",
    otherwise the latest version. A warning naming `filename` is logged either way.

    This makes it easier for users to migrate.
    """
    log = logging.getLogger(__name__)

    def _has(name: str) -> bool:
        # Walk the dotted path one component at a time; fail on the first missing one.
        node = cfg
        for part in name.split("."):
            if part not in node:
                return False
            node = node[part]
        return True

    # Most users' partial configs have "MODEL.WEIGHT", so guess on it
    if any(_has(key) for key in ("MODEL.WEIGHT", "TEST.AUG_ON")):
        guessed = 1
        log.warning("Config '{}' has no VERSION. Assuming it to be v{}.".format(filename, guessed))
        return guessed

    latest = _C.VERSION
    log.warning(
        "Config '{}' has no VERSION. Assuming it to be compatible with latest v{}.".format(
            filename, latest
        )
    )
    return latest
|
114 |
+
|
115 |
+
|
116 |
+
def _rename(cfg: CN, old: str, new: str) -> None:
    """
    Move the value at dotted key `old` to dotted key `new`, in place.

    Missing intermediate nodes on the `new` path are created; after the old key
    is deleted, parent nodes left empty are deleted as well.
    """
    src_path = old.split(".")
    dst_path = new.split(".")

    def _walk(path: List[str], create: bool = False):
        # Follow `path` from the root; optionally create missing nodes on the way.
        node = cfg
        for part in path:
            if create and part not in node:
                node[part] = CN()
            node = node[part]
        return node

    # Read the old value first, then store it under the new key.
    value = _walk(src_path)
    _walk(dst_path[:-1], create=True)[dst_path[-1]] = value

    # Delete the old key, then walk upwards removing parents that became empty.
    remaining = src_path
    while True:
        parent = _walk(remaining[:-1])
        del parent[remaining[-1]]
        if len(parent) != 0 or len(remaining) <= 1:
            break
        remaining = remaining[:-1]
|
144 |
+
|
145 |
+
|
146 |
+
class _RenameConverter:
    """
    Base converter for version changes that are plain key renames.
    """

    # (old name, new name) pairs; subclasses override this.
    RENAME: List[Tuple[str, str]] = []

    @classmethod
    def upgrade(cls, cfg: CN) -> None:
        """Rename every old key to its new name, in list order."""
        for old_name, new_name in cls.RENAME:
            _rename(cfg, old_name, new_name)

    @classmethod
    def downgrade(cls, cfg: CN) -> None:
        """Rename every new key back to its old name, in reverse list order."""
        for old_name, new_name in reversed(cls.RENAME):
            _rename(cfg, new_name, old_name)
|
162 |
+
|
163 |
+
|
164 |
+
class ConverterV1(_RenameConverter):
    # Looked up by name ("ConverterV" + version) in upgrade_config / downgrade_config.
    # The single (old name, new name) pair is applied by the inherited
    # upgrade() and reverted by the inherited downgrade().
    RENAME = [("MODEL.RPN_HEAD.NAME", "MODEL.RPN.HEAD_NAME")]
|
166 |
+
|
167 |
+
|
168 |
+
class ConverterV2(_RenameConverter):
    """
    A large bulk of rename, before public release.

    Besides the plain renames in RENAME, upgrade() and downgrade() move the anchor
    settings between MODEL.RPN / MODEL.RETINANET and MODEL.ANCHOR_GENERATOR.
    """

    # (old name, new name) pairs handled by the inherited rename logic.
    RENAME = [
        ("MODEL.WEIGHT", "MODEL.WEIGHTS"),
        ("MODEL.PANOPTIC_FPN.SEMANTIC_LOSS_SCALE", "MODEL.SEM_SEG_HEAD.LOSS_WEIGHT"),
        ("MODEL.PANOPTIC_FPN.RPN_LOSS_SCALE", "MODEL.RPN.LOSS_WEIGHT"),
        ("MODEL.PANOPTIC_FPN.INSTANCE_LOSS_SCALE", "MODEL.PANOPTIC_FPN.INSTANCE_LOSS_WEIGHT"),
        ("MODEL.PANOPTIC_FPN.COMBINE_ON", "MODEL.PANOPTIC_FPN.COMBINE.ENABLED"),
        (
            "MODEL.PANOPTIC_FPN.COMBINE_OVERLAP_THRESHOLD",
            "MODEL.PANOPTIC_FPN.COMBINE.OVERLAP_THRESH",
        ),
        (
            "MODEL.PANOPTIC_FPN.COMBINE_STUFF_AREA_LIMIT",
            "MODEL.PANOPTIC_FPN.COMBINE.STUFF_AREA_LIMIT",
        ),
        (
            "MODEL.PANOPTIC_FPN.COMBINE_INSTANCES_CONFIDENCE_THRESHOLD",
            "MODEL.PANOPTIC_FPN.COMBINE.INSTANCES_CONFIDENCE_THRESH",
        ),
        ("MODEL.ROI_HEADS.SCORE_THRESH", "MODEL.ROI_HEADS.SCORE_THRESH_TEST"),
        ("MODEL.ROI_HEADS.NMS", "MODEL.ROI_HEADS.NMS_THRESH_TEST"),
        ("MODEL.RETINANET.INFERENCE_SCORE_THRESHOLD", "MODEL.RETINANET.SCORE_THRESH_TEST"),
        ("MODEL.RETINANET.INFERENCE_TOPK_CANDIDATES", "MODEL.RETINANET.TOPK_CANDIDATES_TEST"),
        ("MODEL.RETINANET.INFERENCE_NMS_THRESHOLD", "MODEL.RETINANET.NMS_THRESH_TEST"),
        ("TEST.DETECTIONS_PER_IMG", "TEST.DETECTIONS_PER_IMAGE"),
        ("TEST.AUG_ON", "TEST.AUG.ENABLED"),
        ("TEST.AUG_MIN_SIZES", "TEST.AUG.MIN_SIZES"),
        ("TEST.AUG_MAX_SIZE", "TEST.AUG.MAX_SIZE"),
        ("TEST.AUG_FLIP", "TEST.AUG.FLIP"),
    ]

    @classmethod
    def upgrade(cls, cfg: CN) -> None:
        """Apply the renames, then move the anchor settings to MODEL.ANCHOR_GENERATOR."""
        super().upgrade(cfg)

        # The anchor settings come from the section that matches the meta
        # architecture; the other section's copies are deleted.
        if cfg.MODEL.META_ARCHITECTURE == "RetinaNet":
            _rename(
                cfg, "MODEL.RETINANET.ANCHOR_ASPECT_RATIOS", "MODEL.ANCHOR_GENERATOR.ASPECT_RATIOS"
            )
            _rename(cfg, "MODEL.RETINANET.ANCHOR_SIZES", "MODEL.ANCHOR_GENERATOR.SIZES")
            del cfg["MODEL"]["RPN"]["ANCHOR_SIZES"]
            del cfg["MODEL"]["RPN"]["ANCHOR_ASPECT_RATIOS"]
        else:
            _rename(cfg, "MODEL.RPN.ANCHOR_ASPECT_RATIOS", "MODEL.ANCHOR_GENERATOR.ASPECT_RATIOS")
            _rename(cfg, "MODEL.RPN.ANCHOR_SIZES", "MODEL.ANCHOR_GENERATOR.SIZES")
            del cfg["MODEL"]["RETINANET"]["ANCHOR_SIZES"]
            del cfg["MODEL"]["RETINANET"]["ANCHOR_ASPECT_RATIOS"]
        # Removed for either architecture; downgrade() restores it as an empty list.
        del cfg["MODEL"]["RETINANET"]["ANCHOR_STRIDES"]

    @classmethod
    def downgrade(cls, cfg: CN) -> None:
        """Revert the renames, then copy the anchor settings back to both old sections."""
        super().downgrade(cfg)

        _rename(cfg, "MODEL.ANCHOR_GENERATOR.ASPECT_RATIOS", "MODEL.RPN.ANCHOR_ASPECT_RATIOS")
        _rename(cfg, "MODEL.ANCHOR_GENERATOR.SIZES", "MODEL.RPN.ANCHOR_SIZES")
        # Both the RPN and RETINANET sections receive the same anchor values.
        cfg.MODEL.RETINANET.ANCHOR_ASPECT_RATIOS = cfg.MODEL.RPN.ANCHOR_ASPECT_RATIOS
        cfg.MODEL.RETINANET.ANCHOR_SIZES = cfg.MODEL.RPN.ANCHOR_SIZES
        cfg.MODEL.RETINANET.ANCHOR_STRIDES = []  # this is not used anywhere in any version
|
detectron2/config/config.py
ADDED
@@ -0,0 +1,249 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
3 |
+
|
4 |
+
import functools
|
5 |
+
import inspect
|
6 |
+
import logging
|
7 |
+
from fvcore.common.config import CfgNode as _CfgNode
|
8 |
+
|
9 |
+
from detectron2.utils.file_io import PathManager
|
10 |
+
|
11 |
+
|
12 |
+
class CfgNode(_CfgNode):
    """
    A `fvcore.common.config.CfgNode` with two differences:

    1. Unsafe yaml loading is the default.
       Note that this may lead to arbitrary code execution: you must not
       load a config file from untrusted sources before manually inspecting
       the content of the file.
    2. Config versioning is supported.
       When attempting to merge an old config, it will convert the old config automatically.
    """

    @classmethod
    def _open_cfg(cls, filename):
        """Open `filename` for reading through PathManager."""
        return PathManager.open(filename, "r")

    # Note that the default value of allow_unsafe is changed to True
    def merge_from_file(self, cfg_filename: str, allow_unsafe: bool = True) -> None:
        """
        Merge the yaml config at `cfg_filename` into this config, upgrading the
        loaded config first when it is of an older version.
        """
        assert PathManager.isfile(cfg_filename), f"Config file '{cfg_filename}' does not exist!"
        file_cfg = type(self)(self.load_yaml_with_base(cfg_filename, allow_unsafe=allow_unsafe))

        # defaults.py needs to import CfgNode
        from .defaults import _C

        assert (
            _C.VERSION == self.VERSION
        ), "CfgNode.merge_from_file is only allowed on a config object of latest version!"

        file_version = file_cfg.get("VERSION", None)
        if file_version is None:
            from .compat import guess_version

            file_version = guess_version(file_cfg, cfg_filename)
        assert file_version <= self.VERSION, "Cannot merge a v{} config into a v{} config.".format(
            file_version, self.VERSION
        )

        if file_version == self.VERSION:
            # Same version: a plain merge is enough.
            self.merge_from_other_cfg(file_cfg)
            return

        # compat.py needs to import CfgNode
        from .compat import upgrade_config, downgrade_config

        logging.getLogger(__name__).warning(
            "Loading an old v{} config file '{}' by automatically upgrading to v{}. "
            "See docs/CHANGELOG.md for instructions to update your files.".format(
                file_version, cfg_filename, self.VERSION
            )
        )
        # To convert, first obtain a full config at an old version,
        # merge the file into it, then upgrade the result back.
        legacy_cfg = downgrade_config(self, to_version=file_version)
        legacy_cfg.merge_from_other_cfg(file_cfg)
        converted_cfg = upgrade_config(legacy_cfg)
        self.clear()
        self.update(converted_cfg)

    def dump(self, *args, **kwargs):
        """
        Returns:
            str: a yaml string representation of the config
        """
        # to make it show up in docs
        return super().dump(*args, **kwargs)
|
79 |
+
|
80 |
+
|
81 |
+
# Module-level config object, initially empty; `set_global_cfg` fills it in place.
global_cfg = CfgNode()
|
82 |
+
|
83 |
+
|
84 |
+
def get_cfg() -> CfgNode:
    """
    Return a fresh copy of the default config.

    Returns:
        CfgNode: a clone of the defaults `_C` defined in `.defaults`.
    """
    # Imported lazily: defaults.py needs to import CfgNode from this module.
    from .defaults import _C as default_cfg

    return default_cfg.clone()
|
94 |
+
|
95 |
+
|
96 |
+
def set_global_cfg(cfg: CfgNode) -> None:
    """
    Make the global config mirror the given cfg.

    If "cfg" has the key "KEY", then after `set_global_cfg(cfg)` the key
    can be read with:
    ::
        from detectron2.config import global_cfg
        print(global_cfg.KEY)

    This hacky global config lets you reach these values anywhere, without
    passing the config object or its values deep into the code. It was
    introduced for quick prototyping / research exploration.

    Args:
        cfg (CfgNode): the config whose content replaces the global one.
    """
    # `global_cfg` is only mutated, never rebound, so existing references to
    # the object observe the new content.
    global_cfg.clear()
    global_cfg.update(cfg)
|
113 |
+
|
114 |
+
|
115 |
+
def configurable(init_func=None, *, from_config=None):
    """
    Decorate a function or a class's __init__ method so that it can be called
    with a :class:`CfgNode` object using a :func:`from_config` function that translates
    :class:`CfgNode` to arguments.

    Examples:
    ::
        # Usage 1: Decorator on __init__:
        class A:
            @configurable
            def __init__(self, a, b=2, c=3):
                pass

            @classmethod
            def from_config(cls, cfg):   # 'cfg' must be the first argument
                # Returns kwargs to be passed to __init__
                return {"a": cfg.A, "b": cfg.B}

        a1 = A(a=1, b=2)  # regular construction
        a2 = A(cfg)       # construct with a cfg
        a3 = A(cfg, b=3, c=4)  # construct with extra overwrite

        # Usage 2: Decorator on any function. Needs an extra from_config argument:
        @configurable(from_config=lambda cfg: {"a": cfg.A, "b": cfg.B})
        def a_func(a, b=2, c=3):
            pass

        a1 = a_func(a=1, b=2)  # regular call
        a2 = a_func(cfg)       # call with a cfg
        a3 = a_func(cfg, b=3, c=4)  # call with extra overwrite

    Args:
        init_func (callable): a class's ``__init__`` method in usage 1. The
            class must have a ``from_config`` classmethod which takes `cfg` as
            the first argument.
        from_config (callable): the from_config function in usage 2. It must take `cfg`
            as its first argument.

    Returns:
        callable: the wrapped ``__init__`` in usage 1; a decorator in usage 2
        (or this function itself for a bare ``@configurable()``). In usage 2
        the decorated function also exposes the translator as its
        ``from_config`` attribute.
    """

    if init_func is not None:
        assert (
            inspect.isfunction(init_func)
            and from_config is None
            and init_func.__name__ == "__init__"
        ), "Incorrect use of @configurable. Check API documentation for examples."

        @functools.wraps(init_func)
        def wrapped(self, *args, **kwargs):
            # Look up from_config on the class, so subclasses can override it.
            try:
                from_config_func = type(self).from_config
            except AttributeError as e:
                raise AttributeError(
                    "Class with @configurable must have a 'from_config' classmethod."
                ) from e
            if not inspect.ismethod(from_config_func):
                raise TypeError("Class with @configurable must have a 'from_config' classmethod.")

            if _called_with_cfg(*args, **kwargs):
                explicit_args = _get_args_from_config(from_config_func, *args, **kwargs)
                init_func(self, **explicit_args)
            else:
                init_func(self, *args, **kwargs)

        return wrapped

    else:
        if from_config is None:
            return configurable  # @configurable() is made equivalent to @configurable
        assert inspect.isfunction(
            from_config
        ), "from_config argument of configurable must be a function!"

        def wrapper(orig_func):
            @functools.wraps(orig_func)
            def wrapped(*args, **kwargs):
                if _called_with_cfg(*args, **kwargs):
                    explicit_args = _get_args_from_config(from_config, *args, **kwargs)
                    return orig_func(**explicit_args)
                else:
                    return orig_func(*args, **kwargs)

            # Expose the translator so callers can inspect or reuse it, just as
            # `cls.from_config` is reachable in usage 1.
            wrapped.from_config = from_config
            return wrapped

        return wrapper
|
200 |
+
|
201 |
+
|
202 |
+
def _get_args_from_config(from_config_func, *args, **kwargs):
|
203 |
+
"""
|
204 |
+
Use `from_config` to obtain explicit arguments.
|
205 |
+
|
206 |
+
Returns:
|
207 |
+
dict: arguments to be used for cls.__init__
|
208 |
+
"""
|
209 |
+
signature = inspect.signature(from_config_func)
|
210 |
+
if list(signature.parameters.keys())[0] != "cfg":
|
211 |
+
if inspect.isfunction(from_config_func):
|
212 |
+
name = from_config_func.__name__
|
213 |
+
else:
|
214 |
+
name = f"{from_config_func.__self__}.from_config"
|
215 |
+
raise TypeError(f"{name} must take 'cfg' as the first argument!")
|
216 |
+
support_var_arg = any(
|
217 |
+
param.kind in [param.VAR_POSITIONAL, param.VAR_KEYWORD]
|
218 |
+
for param in signature.parameters.values()
|
219 |
+
)
|
220 |
+
if support_var_arg: # forward all arguments to from_config, if from_config accepts them
|
221 |
+
ret = from_config_func(*args, **kwargs)
|
222 |
+
else:
|
223 |
+
# forward supported arguments to from_config
|
224 |
+
supported_arg_names = set(signature.parameters.keys())
|
225 |
+
extra_kwargs = {}
|
226 |
+
for name in list(kwargs.keys()):
|
227 |
+
if name not in supported_arg_names:
|
228 |
+
extra_kwargs[name] = kwargs.pop(name)
|
229 |
+
ret = from_config_func(*args, **kwargs)
|
230 |
+
# forward the other arguments to __init__
|
231 |
+
ret.update(extra_kwargs)
|
232 |
+
return ret
|
233 |
+
|
234 |
+
|
235 |
+
def _called_with_cfg(*args, **kwargs):
    """
    Tell whether a call passes a config object.

    Returns:
        bool: True if the first positional argument, or the keyword argument
            "cfg", is a config object (`_CfgNode` or omegaconf `DictConfig`),
            meaning the call should be forwarded to from_config.
    """
    from omegaconf import DictConfig

    cfg_types = (_CfgNode, DictConfig)
    if args and isinstance(args[0], cfg_types):
        return True
    # `from_config`'s first argument is forced to be "cfg", so checking the
    # first positional argument and the "cfg" keyword covers all cases.
    return isinstance(kwargs.get("cfg"), cfg_types)
|
detectron2/config/defaults.py
ADDED
@@ -0,0 +1,786 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
from .config import CfgNode as CN
|
3 |
+
|
4 |
+
# -----------------------------------------------------------------------------
|
5 |
+
# Convention about Training / Test specific parameters
|
6 |
+
# -----------------------------------------------------------------------------
|
7 |
+
# Whenever an argument can be either used for training or for testing, the
|
8 |
+
# corresponding name will be post-fixed by a _TRAIN for a training parameter,
|
9 |
+
# or _TEST for a test-specific parameter.
|
10 |
+
# For example, the number of images during training will be
|
11 |
+
# IMAGES_PER_BATCH_TRAIN, while the number of images for testing will be
|
12 |
+
# IMAGES_PER_BATCH_TEST
|
13 |
+
|
14 |
+
# -----------------------------------------------------------------------------
|
15 |
+
# Config definition
|
16 |
+
# -----------------------------------------------------------------------------
|
17 |
+
|
18 |
+
_C = CN()
|
19 |
+
|
20 |
+
# The version number, to upgrade from old configs to new ones if any
|
21 |
+
# changes happen. It's recommended to keep a VERSION in your config file.
|
22 |
+
_C.VERSION = 2
|
23 |
+
|
24 |
+
_C.MODEL = CN()
|
25 |
+
_C.MODEL.LOAD_PROPOSALS = False
|
26 |
+
_C.MODEL.MASK_ON = False
|
27 |
+
_C.MODEL.KEYPOINT_ON = False
|
28 |
+
_C.MODEL.DEVICE = "cuda"
|
29 |
+
_C.MODEL.META_ARCHITECTURE = "GeneralizedRCNN"
|
30 |
+
|
31 |
+
# Path (a file path, or URL like detectron2://.., https://..) to a checkpoint file
|
32 |
+
# to be loaded to the model. You can find available models in the model zoo.
|
33 |
+
_C.MODEL.WEIGHTS = ""
|
34 |
+
|
35 |
+
# Values to be used for image normalization (BGR order, since INPUT.FORMAT defaults to BGR).
|
36 |
+
# To train on images of different number of channels, just set different mean & std.
|
37 |
+
# Default values are the mean pixel value from ImageNet: [103.53, 116.28, 123.675]
|
38 |
+
_C.MODEL.PIXEL_MEAN = [103.530, 116.280, 123.675]
|
39 |
+
# When using pre-trained models in Detectron1 or any MSRA models,
|
40 |
+
# std has been absorbed into its conv1 weights, so the std needs to be set 1.
|
41 |
+
# Otherwise, you can use [57.375, 57.120, 58.395] (ImageNet std)
|
42 |
+
_C.MODEL.PIXEL_STD = [1.0, 1.0, 1.0]
|
43 |
+
|
44 |
+
|
45 |
+
# -----------------------------------------------------------------------------
|
46 |
+
# INPUT
|
47 |
+
# -----------------------------------------------------------------------------
|
48 |
+
_C.INPUT = CN()
|
49 |
+
# Size of the smallest side of the image during training
|
50 |
+
_C.INPUT.MIN_SIZE_TRAIN = (800,)
|
51 |
+
# Sample size of smallest side by choice or random selection from range given by
|
52 |
+
# INPUT.MIN_SIZE_TRAIN
|
53 |
+
_C.INPUT.MIN_SIZE_TRAIN_SAMPLING = "choice"
|
54 |
+
# Maximum size of the side of the image during training
|
55 |
+
_C.INPUT.MAX_SIZE_TRAIN = 1333
|
56 |
+
# Size of the smallest side of the image during testing. Set to zero to disable resize in testing.
|
57 |
+
_C.INPUT.MIN_SIZE_TEST = 800
|
58 |
+
# Maximum size of the side of the image during testing
|
59 |
+
_C.INPUT.MAX_SIZE_TEST = 1333
|
60 |
+
# Mode for flipping images used in data augmentation during training
|
61 |
+
# choose one of ["horizontal", "vertical", "none"]
|
62 |
+
_C.INPUT.RANDOM_FLIP = "horizontal"
|
63 |
+
|
64 |
+
# `True` if cropping is used for data augmentation during training
|
65 |
+
_C.INPUT.CROP = CN({"ENABLED": False})
|
66 |
+
# Cropping type. See documentation of `detectron2.data.transforms.RandomCrop` for explanation.
|
67 |
+
_C.INPUT.CROP.TYPE = "relative_range"
|
68 |
+
# Size of crop in range (0, 1] if CROP.TYPE is "relative" or "relative_range" and in number of
|
69 |
+
# pixels if CROP.TYPE is "absolute"
|
70 |
+
_C.INPUT.CROP.SIZE = [0.9, 0.9]
|
71 |
+
|
72 |
+
|
73 |
+
# Whether the model needs RGB, YUV, HSV etc.
|
74 |
+
# Should be one of the modes defined here, as we use PIL to read the image:
|
75 |
+
# https://pillow.readthedocs.io/en/stable/handbook/concepts.html#concept-modes
|
76 |
+
# with BGR being the one exception. One can set image format to BGR, we will
|
77 |
+
# internally use RGB for conversion and flip the channels over
|
78 |
+
_C.INPUT.FORMAT = "BGR"
|
79 |
+
# The ground truth mask format that the model will use.
|
80 |
+
# Mask R-CNN supports either "polygon" or "bitmask" as ground truth.
|
81 |
+
_C.INPUT.MASK_FORMAT = "polygon" # alternative: "bitmask"
|
82 |
+
|
83 |
+
################### Text Tokenizer from MSR-CLIP ##################
|
84 |
+
_C.INPUT.TEXT_TOKENIZER = "openai_bpe" # "bert-base-cased"
|
85 |
+
|
86 |
+
################## Data Augmentation from MSR-CLIP ##################
|
87 |
+
_C.AUG = CN()
|
88 |
+
_C.AUG.SCALE = (0.08, 1.0)
|
89 |
+
_C.AUG.RATIO = (3.0/4.0, 4.0/3.0)
|
90 |
+
_C.AUG.COLOR_JITTER = [0.4, 0.4, 0.4, 0.1, 0.0]
|
91 |
+
_C.AUG.GRAY_SCALE = 0.0
|
92 |
+
_C.AUG.GAUSSIAN_BLUR = 0.0
|
93 |
+
_C.AUG.DROPBLOCK_LAYERS = [3, 4]
|
94 |
+
_C.AUG.DROPBLOCK_KEEP_PROB = 1.0
|
95 |
+
_C.AUG.DROPBLOCK_BLOCK_SIZE = 7
|
96 |
+
_C.AUG.MIXUP_PROB = 0.0
|
97 |
+
_C.AUG.MIXUP = 0.0
|
98 |
+
_C.AUG.MIXCUT = 0.0
|
99 |
+
_C.AUG.MIXCUT_MINMAX = []
|
100 |
+
_C.AUG.MIXUP_SWITCH_PROB = 0.5
|
101 |
+
_C.AUG.MIXUP_MODE = 'batch'
|
102 |
+
_C.AUG.MIXCUT_AND_MIXUP = False
|
103 |
+
_C.AUG.INTERPOLATION = 3
|
104 |
+
_C.AUG.USE_TIMM = False
|
105 |
+
_C.AUG.TIMM_AUG = CN(new_allowed=True)
|
106 |
+
_C.AUG.TIMM_AUG.USE_LOADER = False
|
107 |
+
_C.AUG.TIMM_AUG.USE_TRANSFORM = False
|
108 |
+
|
109 |
+
_C.AUG.TRAIN = CN()
|
110 |
+
_C.AUG.TRAIN.IMAGE_SIZE = [224, 224] # width * height, ex: 192 * 256
|
111 |
+
_C.AUG.TRAIN.MAX_SIZE = None # the maximum size for longer edge after resizing
|
112 |
+
_C.AUG.TEST = CN()
|
113 |
+
_C.AUG.TEST.IMAGE_SIZE = [224, 224] # width * height, ex: 192 * 256
|
114 |
+
_C.AUG.TEST.MAX_SIZE = None # the maximum size for longer edge after resizing
|
115 |
+
_C.AUG.TEST.CENTER_CROP = False
|
116 |
+
_C.AUG.TEST.INTERPOLATION = 3
|
117 |
+
|
118 |
+
|
119 |
+
# -----------------------------------------------------------------------------
|
120 |
+
# Dataset
|
121 |
+
# -----------------------------------------------------------------------------
|
122 |
+
_C.DATASETS = CN()
|
123 |
+
# List of the dataset names for training. Must be registered in DatasetCatalog
|
124 |
+
# Samples from these datasets will be merged and used as one dataset.
|
125 |
+
_C.DATASETS.TRAIN = ()
|
126 |
+
# List of the pre-computed proposal files for training, which must be consistent
|
127 |
+
# with datasets listed in DATASETS.TRAIN.
|
128 |
+
_C.DATASETS.PROPOSAL_FILES_TRAIN = ()
|
129 |
+
# Number of top scoring precomputed proposals to keep for training
|
130 |
+
_C.DATASETS.PRECOMPUTED_PROPOSAL_TOPK_TRAIN = 2000
|
131 |
+
# List of the dataset names for testing. Must be registered in DatasetCatalog
|
132 |
+
_C.DATASETS.TEST = ()
|
133 |
+
# List of the pre-computed proposal files for test, which must be consistent
|
134 |
+
# with datasets listed in DATASETS.TEST.
|
135 |
+
_C.DATASETS.PROPOSAL_FILES_TEST = ()
|
136 |
+
# Number of top scoring precomputed proposals to keep for test
|
137 |
+
_C.DATASETS.PRECOMPUTED_PROPOSAL_TOPK_TEST = 1000
|
138 |
+
################## Data Loading from MSR-CLIP ##################
|
139 |
+
# List of dataset class names for training
|
140 |
+
_C.DATASETS.FACTORY_TRAIN = ()
|
141 |
+
# List of dataset folder for training
|
142 |
+
_C.DATASETS.PATH_TRAIN = ()
|
143 |
+
# List of the dataset names for auxiliary training, as present in paths_catalog.py
|
144 |
+
_C.DATASETS.AUX = ()
|
145 |
+
# List of dataset class names for auxiliary training
|
146 |
+
_C.DATASETS.FACTORY_AUX = ()
|
147 |
+
# List of dataset folder for auxiliary training
|
148 |
+
_C.DATASETS.PATH_AUX = ()
|
149 |
+
# List of dataset class names for testing
|
150 |
+
_C.DATASETS.FACTORY_TEST = ()
|
151 |
+
# List of dataset folder for testing
|
152 |
+
_C.DATASETS.PATH_TEST = ()
|
153 |
+
# Labelmap file to convert to tsv or for demo purpose
|
154 |
+
_C.DATASETS.LABELMAP_FILE = ''
|
155 |
+
_C.DATASETS.ATTR_LABELMAP_FILE = ''
|
156 |
+
_C.DATASETS.FILTERED_CLASSIFICATION_DATASETS = ''
|
157 |
+
# hierarchy file for test time score aggregation (developed on OpenImages)
|
158 |
+
_C.DATASETS.HIERARCHY_FILE = ''
|
159 |
+
# List of box extra fields for training/testing
|
160 |
+
# If given, will not infer from the other cfgs.
|
161 |
+
_C.DATASETS.BOX_EXTRA_FIELDS = ()
|
162 |
+
|
163 |
+
_C.DATASETS.NUM_CLASSES = 0
|
164 |
+
_C.DATASETS.ROOT = ''
|
165 |
+
_C.DATASETS.TRAIN_SET = 'train'
|
166 |
+
_C.DATASETS.VAL_SET = ''
|
167 |
+
_C.DATASETS.TEST_SET = 'val'
|
168 |
+
|
169 |
+
# The maximum total input sequence length after WordPiece tokenization
|
170 |
+
# Sequences longer than this will be truncated, and sequences shorter than this will be padded.
|
171 |
+
_C.DATASETS.MAX_SEQ_LENGTH = 35
|
172 |
+
|
173 |
+
# -----------------------------------------------------------------------------
|
174 |
+
# DataLoader
|
175 |
+
# -----------------------------------------------------------------------------
|
176 |
+
_C.DATALOADER = CN()
|
177 |
+
# Number of data loading threads
|
178 |
+
_C.DATALOADER.NUM_WORKERS = 4
|
179 |
+
# If True, each batch should contain only images for which the aspect ratio
|
180 |
+
# is compatible. This groups portrait images together, and landscape images
|
181 |
+
# are not batched with portrait images.
|
182 |
+
_C.DATALOADER.ASPECT_RATIO_GROUPING = True
|
183 |
+
# Options: TrainingSampler, RepeatFactorTrainingSampler
|
184 |
+
_C.DATALOADER.SAMPLER_TRAIN = "TrainingSampler"
|
185 |
+
# Repeat threshold for RepeatFactorTrainingSampler
|
186 |
+
_C.DATALOADER.REPEAT_THRESHOLD = 0.0
|
187 |
+
# If True, when working on datasets that have instance annotations, the
|
188 |
+
# training dataloader will filter out images without associated annotations
|
189 |
+
_C.DATALOADER.FILTER_EMPTY_ANNOTATIONS = True
|
190 |
+
|
191 |
+
# ---------------------------------------------------------------------------- #
|
192 |
+
# CLIP options
|
193 |
+
# ---------------------------------------------------------------------------- #
|
194 |
+
_C.MODEL.CLIP = CN()
|
195 |
+
|
196 |
+
_C.MODEL.CLIP.CROP_REGION_TYPE = "" # options: "GT", "RPN"
|
197 |
+
_C.MODEL.CLIP.BB_RPN_WEIGHTS = None # the weights of pretrained MaskRCNN
|
198 |
+
_C.MODEL.CLIP.IMS_PER_BATCH_TEST = 8 # the #images during inference per batch
|
199 |
+
|
200 |
+
_C.MODEL.CLIP.USE_TEXT_EMB_CLASSIFIER = False # if True, use the CLIP text embedding as the classifier's weights
|
201 |
+
_C.MODEL.CLIP.TEXT_EMB_PATH = None # "/mnt/output_storage/trained_models/lvis_cls_emb/lvis_1203_cls_emb.pth"
|
202 |
+
_C.MODEL.CLIP.OFFLINE_RPN_CONFIG = None # option: all configs of pretrained RPN
|
203 |
+
_C.MODEL.CLIP.NO_BOX_DELTA = False # if True, during inference, no box delta will be applied to region proposals
|
204 |
+
|
205 |
+
_C.MODEL.CLIP.BG_CLS_LOSS_WEIGHT = None # if not None, it is the loss weight for bg regions
|
206 |
+
_C.MODEL.CLIP.ONLY_SAMPLE_FG_PROPOSALS = False # if True, during training, ignore all bg proposals and only sample fg proposals
|
207 |
+
_C.MODEL.CLIP.MULTIPLY_RPN_SCORE = False # if True, during inference, multiply RPN scores with classification scores
|
208 |
+
|
209 |
+
_C.MODEL.CLIP.OPENSET_TEST_NUM_CLASSES = None # if an integer, it is #all_cls in test
|
210 |
+
_C.MODEL.CLIP.OPENSET_TEST_TEXT_EMB_PATH = None # if not None, enables the openset/zero-shot training, the category embeddings during test
|
211 |
+
|
212 |
+
_C.MODEL.CLIP.CLSS_TEMP = None # if None, dot product wo normalization & temperature; if float, normalization plus temperature
|
213 |
+
_C.MODEL.CLIP.RUN_CVPR_OVR = False # if True, train CVPR OVR model with their text embeddings
|
214 |
+
_C.MODEL.CLIP.FOCAL_SCALED_LOSS = None # if not None (float value for gamma), apply focal loss scaling idea to standard cross-entropy loss
|
215 |
+
|
216 |
+
_C.MODEL.CLIP.OFFLINE_RPN_NMS_THRESH = None # the threshold of NMS in offline RPN
|
217 |
+
_C.MODEL.CLIP.PRETRAIN_IMG_TXT_LEVEL = True # if True, pretrain model using image-text level matching
|
218 |
+
_C.MODEL.CLIP.PRETRAIN_ONLY_EOT = False # if True, use end-of-token emb to match region features, in image-text level matching
|
219 |
+
_C.MODEL.CLIP.PRETRAIN_RPN_REGIONS = None # if not None, the number of RPN regions per image during pretraining
|
220 |
+
_C.MODEL.CLIP.PRETRAIN_SAMPLE_REGIONS = None # if not None, the number of regions per image during pretraining after sampling, to avoid overfitting
|
221 |
+
_C.MODEL.CLIP.GATHER_GPUS = False # if True, gather tensors across GPUS to increase batch size
|
222 |
+
_C.MODEL.CLIP.GRID_REGIONS = False # if True, use grid boxes to extract grid features, instead of object proposals
|
223 |
+
_C.MODEL.CLIP.CONCEPT_POOL_EMB = None # if not None, it provides the file path of embs of concept pool and thus enables region-concept matching
|
224 |
+
_C.MODEL.CLIP.CONCEPT_THRES = None # if not None, the threshold to filter out the regions with low matching score with concept embs, dependent on temp (default: 0.01)
|
225 |
+
|
226 |
+
_C.MODEL.CLIP.OFFLINE_RPN_LSJ_PRETRAINED = False # if True, use large-scale jittering (LSJ) pretrained RPN
|
227 |
+
_C.MODEL.CLIP.TEACHER_RESNETS_DEPTH = 50 # the type of visual encoder of teacher model, sucha as ResNet 50, 101, 200 (a flag for 50x4)
|
228 |
+
_C.MODEL.CLIP.TEACHER_CONCEPT_POOL_EMB = None # if not None, it uses the same concept embedding as student model; otherwise, uses a seperate embedding of teacher model
|
229 |
+
_C.MODEL.CLIP.TEACHER_POOLER_RESOLUTION = 14 # RoIpooling resolution of teacher model
|
230 |
+
|
231 |
+
_C.MODEL.CLIP.TEXT_EMB_DIM = 1024 # the dimension of precomputed class embeddings
|
232 |
+
|
233 |
+
# ---------------------------------------------------------------------------- #
|
234 |
+
# Backbone options
|
235 |
+
# ---------------------------------------------------------------------------- #
|
236 |
+
_C.MODEL.BACKBONE = CN()
|
237 |
+
|
238 |
+
_C.MODEL.BACKBONE.NAME = "build_resnet_backbone"
|
239 |
+
# Freeze the first several stages so they are not trained.
|
240 |
+
# There are 5 stages in ResNet. The first is a convolution, and the following
|
241 |
+
# stages are each group of residual blocks.
|
242 |
+
_C.MODEL.BACKBONE.FREEZE_AT = 2
|
243 |
+
|
244 |
+
_C.MODEL.TEXT_BACKBONE = CN()
|
245 |
+
_C.MODEL.TEXT_BACKBONE.NAME = "build_clip_swin_text_backbone"
|
246 |
+
|
247 |
+
|
248 |
+
# ---------------------------------------------------------------------------- #
|
249 |
+
# FPN options
|
250 |
+
# ---------------------------------------------------------------------------- #
|
251 |
+
_C.MODEL.FPN = CN()
|
252 |
+
# Names of the input feature maps to be used by FPN
|
253 |
+
# They must have contiguous power of 2 strides
|
254 |
+
# e.g., ["res2", "res3", "res4", "res5"]
|
255 |
+
_C.MODEL.FPN.IN_FEATURES = []
|
256 |
+
_C.MODEL.FPN.OUT_CHANNELS = 256
|
257 |
+
|
258 |
+
# Options: "" (no norm), "GN"
|
259 |
+
_C.MODEL.FPN.NORM = ""
|
260 |
+
|
261 |
+
# Types for fusing the FPN top-down and lateral features. Can be either "sum" or "avg"
|
262 |
+
_C.MODEL.FPN.FUSE_TYPE = "sum"
|
263 |
+
|
264 |
+
|
265 |
+
# ---------------------------------------------------------------------------- #
|
266 |
+
# Proposal generator options
|
267 |
+
# ---------------------------------------------------------------------------- #
|
268 |
+
_C.MODEL.PROPOSAL_GENERATOR = CN()
|
269 |
+
# Current proposal generators include "RPN", "RRPN" and "PrecomputedProposals"
|
270 |
+
_C.MODEL.PROPOSAL_GENERATOR.NAME = "RPN"
|
271 |
+
# Proposal height and width both need to be greater than MIN_SIZE
|
272 |
+
# (at the scale used during training or inference)
|
273 |
+
_C.MODEL.PROPOSAL_GENERATOR.MIN_SIZE = 0
|
274 |
+
|
275 |
+
|
276 |
+
# ---------------------------------------------------------------------------- #
|
277 |
+
# Anchor generator options
|
278 |
+
# ---------------------------------------------------------------------------- #
|
279 |
+
_C.MODEL.ANCHOR_GENERATOR = CN()
|
280 |
+
# The generator can be any name in the ANCHOR_GENERATOR registry
|
281 |
+
_C.MODEL.ANCHOR_GENERATOR.NAME = "DefaultAnchorGenerator"
|
282 |
+
# Anchor sizes (i.e. sqrt of area) in absolute pixels w.r.t. the network input.
|
283 |
+
# Format: list[list[float]]. SIZES[i] specifies the list of sizes to use for
|
284 |
+
# IN_FEATURES[i]; len(SIZES) must be equal to len(IN_FEATURES) or 1.
|
285 |
+
# When len(SIZES) == 1, SIZES[0] is used for all IN_FEATURES.
|
286 |
+
_C.MODEL.ANCHOR_GENERATOR.SIZES = [[32, 64, 128, 256, 512]]
|
287 |
+
# Anchor aspect ratios. For each area given in `SIZES`, anchors with different aspect
|
288 |
+
# ratios are generated by an anchor generator.
|
289 |
+
# Format: list[list[float]]. ASPECT_RATIOS[i] specifies the list of aspect ratios (H/W)
|
290 |
+
# to use for IN_FEATURES[i]; len(ASPECT_RATIOS) == len(IN_FEATURES) must be true,
|
291 |
+
# or len(ASPECT_RATIOS) == 1 is true and aspect ratio list ASPECT_RATIOS[0] is used
|
292 |
+
# for all IN_FEATURES.
|
293 |
+
_C.MODEL.ANCHOR_GENERATOR.ASPECT_RATIOS = [[0.5, 1.0, 2.0]]
|
294 |
+
# Anchor angles.
|
295 |
+
# list[list[float]], the angle in degrees, for each input feature map.
|
296 |
+
# ANGLES[i] specifies the list of angles for IN_FEATURES[i].
|
297 |
+
_C.MODEL.ANCHOR_GENERATOR.ANGLES = [[-90, 0, 90]]
|
298 |
+
# Relative offset between the center of the first anchor and the top-left corner of the image
|
299 |
+
# Value has to be in [0, 1). Recommend to use 0.5, which means half stride.
|
300 |
+
# The value is not expected to affect model accuracy.
|
301 |
+
_C.MODEL.ANCHOR_GENERATOR.OFFSET = 0.0
|
302 |
+
|
303 |
+
# ---------------------------------------------------------------------------- #
|
304 |
+
# RPN options
|
305 |
+
# ---------------------------------------------------------------------------- #
|
306 |
+
_C.MODEL.RPN = CN()
|
307 |
+
_C.MODEL.RPN.HEAD_NAME = "StandardRPNHead" # used by RPN_HEAD_REGISTRY
|
308 |
+
|
309 |
+
# Names of the input feature maps to be used by RPN
|
310 |
+
# e.g., ["p2", "p3", "p4", "p5", "p6"] for FPN
|
311 |
+
_C.MODEL.RPN.IN_FEATURES = ["res4"]
|
312 |
+
# Remove RPN anchors that go outside the image by BOUNDARY_THRESH pixels
|
313 |
+
# Set to -1 or a large value, e.g. 100000, to disable pruning anchors
|
314 |
+
_C.MODEL.RPN.BOUNDARY_THRESH = -1
|
315 |
+
# IOU overlap ratios [BG_IOU_THRESHOLD, FG_IOU_THRESHOLD]
|
316 |
+
# Minimum overlap required between an anchor and ground-truth box for the
|
317 |
+
# (anchor, gt box) pair to be a positive example (IoU >= FG_IOU_THRESHOLD
|
318 |
+
# ==> positive RPN example: 1)
|
319 |
+
# Maximum overlap allowed between an anchor and ground-truth box for the
|
320 |
+
# (anchor, gt box) pair to be a negative examples (IoU < BG_IOU_THRESHOLD
|
321 |
+
# ==> negative RPN example: 0)
|
322 |
+
# Anchors with overlap in between (BG_IOU_THRESHOLD <= IoU < FG_IOU_THRESHOLD)
|
323 |
+
# are ignored (-1)
|
324 |
+
_C.MODEL.RPN.IOU_THRESHOLDS = [0.3, 0.7]
|
325 |
+
_C.MODEL.RPN.IOU_LABELS = [0, -1, 1]
|
326 |
+
# Number of regions per image used to train RPN
|
327 |
+
_C.MODEL.RPN.BATCH_SIZE_PER_IMAGE = 256
|
328 |
+
# Target fraction of foreground (positive) examples per RPN minibatch
|
329 |
+
_C.MODEL.RPN.POSITIVE_FRACTION = 0.5
|
330 |
+
# Options are: "smooth_l1", "giou"
|
331 |
+
_C.MODEL.RPN.BBOX_REG_LOSS_TYPE = "smooth_l1"
|
332 |
+
_C.MODEL.RPN.BBOX_REG_LOSS_WEIGHT = 1.0
|
333 |
+
# Weights on (dx, dy, dw, dh) for normalizing RPN anchor regression targets
|
334 |
+
_C.MODEL.RPN.BBOX_REG_WEIGHTS = (1.0, 1.0, 1.0, 1.0)
|
335 |
+
# The transition point from L1 to L2 loss. Set to 0.0 to make the loss simply L1.
|
336 |
+
_C.MODEL.RPN.SMOOTH_L1_BETA = 0.0
|
337 |
+
_C.MODEL.RPN.LOSS_WEIGHT = 1.0
|
338 |
+
# Number of top scoring RPN proposals to keep before applying NMS
|
339 |
+
# When FPN is used, this is *per FPN level* (not total)
|
340 |
+
_C.MODEL.RPN.PRE_NMS_TOPK_TRAIN = 12000
|
341 |
+
_C.MODEL.RPN.PRE_NMS_TOPK_TEST = 6000
|
342 |
+
# Number of top scoring RPN proposals to keep after applying NMS
|
343 |
+
# When FPN is used, this limit is applied per level and then again to the union
|
344 |
+
# of proposals from all levels
|
345 |
+
# NOTE: When FPN is used, the meaning of this config is different from Detectron1.
|
346 |
+
# It means per-batch topk in Detectron1, but per-image topk here.
|
347 |
+
# See the "find_top_rpn_proposals" function for details.
|
348 |
+
_C.MODEL.RPN.POST_NMS_TOPK_TRAIN = 2000
|
349 |
+
_C.MODEL.RPN.POST_NMS_TOPK_TEST = 1000
|
350 |
+
# NMS threshold used on RPN proposals
|
351 |
+
_C.MODEL.RPN.NMS_THRESH = 0.7
|
352 |
+
# Set this to -1 to use the same number of output channels as input channels.
|
353 |
+
_C.MODEL.RPN.CONV_DIMS = [-1]
|
354 |
+
|
355 |
+
# ---------------------------------------------------------------------------- #
|
356 |
+
# ROI HEADS options
|
357 |
+
# ---------------------------------------------------------------------------- #
|
358 |
+
_C.MODEL.ROI_HEADS = CN()
|
359 |
+
_C.MODEL.ROI_HEADS.NAME = "Res5ROIHeads"
|
360 |
+
# Number of foreground classes
|
361 |
+
_C.MODEL.ROI_HEADS.NUM_CLASSES = 80
|
362 |
+
# Names of the input feature maps to be used by ROI heads
|
363 |
+
# Currently all heads (box, mask, ...) use the same input feature map list
|
364 |
+
# e.g., ["p2", "p3", "p4", "p5"] is commonly used for FPN
|
365 |
+
_C.MODEL.ROI_HEADS.IN_FEATURES = ["res4"]
|
366 |
+
# IOU overlap ratios [IOU_THRESHOLD]
|
367 |
+
# Overlap threshold for an RoI to be considered background (if < IOU_THRESHOLD)
|
368 |
+
# Overlap threshold for an RoI to be considered foreground (if >= IOU_THRESHOLD)
|
369 |
+
_C.MODEL.ROI_HEADS.IOU_THRESHOLDS = [0.5]
|
370 |
+
_C.MODEL.ROI_HEADS.IOU_LABELS = [0, 1]
|
371 |
+
# RoI minibatch size *per image* (number of regions of interest [ROIs])
|
372 |
+
# Total number of RoIs per training minibatch =
|
373 |
+
# ROI_HEADS.BATCH_SIZE_PER_IMAGE * SOLVER.IMS_PER_BATCH
|
374 |
+
# E.g., a common configuration is: 512 * 16 = 8192
|
375 |
+
_C.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 512
|
376 |
+
# Target fraction of RoI minibatch that is labeled foreground (i.e. class > 0)
|
377 |
+
_C.MODEL.ROI_HEADS.POSITIVE_FRACTION = 0.25
|
378 |
+
|
379 |
+
# Only used on test mode
|
380 |
+
|
381 |
+
# Minimum score threshold (assuming scores in a [0, 1] range); a value chosen to
|
382 |
+
# balance obtaining high recall with not having too many low precision
|
383 |
+
# detections that will slow down inference post processing steps (like NMS)
|
384 |
+
# A default threshold of 0.0 increases AP by ~0.2-0.3 but significantly slows down
|
385 |
+
# inference.
|
386 |
+
_C.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.05
|
387 |
+
# Overlap threshold used for non-maximum suppression (suppress boxes with
|
388 |
+
# IoU >= this threshold)
|
389 |
+
_C.MODEL.ROI_HEADS.NMS_THRESH_TEST = 0.5
|
390 |
+
# If True, augment proposals with ground-truth boxes before sampling proposals to
|
391 |
+
# train ROI heads.
|
392 |
+
_C.MODEL.ROI_HEADS.PROPOSAL_APPEND_GT = True
|
393 |
+
|
394 |
+
# Use soft NMS instead of standard NMS if set to True
|
395 |
+
_C.MODEL.ROI_HEADS.SOFT_NMS_ENABLED = False
|
396 |
+
# See soft NMS paper for definition of these options
|
397 |
+
_C.MODEL.ROI_HEADS.SOFT_NMS_METHOD = "gaussian" # "linear"
|
398 |
+
_C.MODEL.ROI_HEADS.SOFT_NMS_SIGMA = 0.5
|
399 |
+
# For the linear_threshold we use NMS_THRESH_TEST
|
400 |
+
_C.MODEL.ROI_HEADS.SOFT_NMS_PRUNE = 0.001
|
401 |
+
|
402 |
+
# ---------------------------------------------------------------------------- #
|
403 |
+
# Box Head
|
404 |
+
# ---------------------------------------------------------------------------- #
|
405 |
+
_C.MODEL.ROI_BOX_HEAD = CN()
|
406 |
+
# C4 models don't use the head name option
|
407 |
+
# Options for non-C4 models: FastRCNNConvFCHead,
|
408 |
+
_C.MODEL.ROI_BOX_HEAD.NAME = ""
|
409 |
+
# Options are: "smooth_l1", "giou"
|
410 |
+
_C.MODEL.ROI_BOX_HEAD.BBOX_REG_LOSS_TYPE = "smooth_l1"
|
411 |
+
# The final scaling coefficient on the box regression loss, used to balance the magnitude of its
|
412 |
+
# gradients with other losses in the model. See also `MODEL.ROI_KEYPOINT_HEAD.LOSS_WEIGHT`.
|
413 |
+
_C.MODEL.ROI_BOX_HEAD.BBOX_REG_LOSS_WEIGHT = 1.0
|
414 |
+
# Default weights on (dx, dy, dw, dh) for normalizing bbox regression targets
|
415 |
+
# These are empirically chosen to approximately lead to unit variance targets
|
416 |
+
_C.MODEL.ROI_BOX_HEAD.BBOX_REG_WEIGHTS = (10.0, 10.0, 5.0, 5.0)
|
417 |
+
# The transition point from L1 to L2 loss. Set to 0.0 to make the loss simply L1.
|
418 |
+
_C.MODEL.ROI_BOX_HEAD.SMOOTH_L1_BETA = 0.0
|
419 |
+
_C.MODEL.ROI_BOX_HEAD.POOLER_RESOLUTION = 14
|
420 |
+
_C.MODEL.ROI_BOX_HEAD.POOLER_SAMPLING_RATIO = 0
|
421 |
+
# Type of pooling operation applied to the incoming feature map for each RoI
|
422 |
+
_C.MODEL.ROI_BOX_HEAD.POOLER_TYPE = "ROIAlignV2"
|
423 |
+
|
424 |
+
_C.MODEL.ROI_BOX_HEAD.NUM_FC = 0
|
425 |
+
# Hidden layer dimension for FC layers in the RoI box head
|
426 |
+
_C.MODEL.ROI_BOX_HEAD.FC_DIM = 1024
|
427 |
+
_C.MODEL.ROI_BOX_HEAD.NUM_CONV = 0
|
428 |
+
# Channel dimension for Conv layers in the RoI box head
|
429 |
+
_C.MODEL.ROI_BOX_HEAD.CONV_DIM = 256
|
430 |
+
# Normalization method for the convolution layers.
|
431 |
+
# Options: "" (no norm), "GN", "SyncBN".
|
432 |
+
_C.MODEL.ROI_BOX_HEAD.NORM = ""
|
433 |
+
# Whether to use class agnostic for bbox regression
|
434 |
+
_C.MODEL.ROI_BOX_HEAD.CLS_AGNOSTIC_BBOX_REG = False
|
435 |
+
# If true, RoI heads use bounding boxes predicted by the box head rather than proposal boxes.
|
436 |
+
_C.MODEL.ROI_BOX_HEAD.TRAIN_ON_PRED_BOXES = False
|
437 |
+
|
438 |
+
# ---------------------------------------------------------------------------- #
|
439 |
+
# Cascaded Box Head
|
440 |
+
# ---------------------------------------------------------------------------- #
|
441 |
+
_C.MODEL.ROI_BOX_CASCADE_HEAD = CN()
|
442 |
+
# The number of cascade stages is implicitly defined by the length of the following two configs.
|
443 |
+
_C.MODEL.ROI_BOX_CASCADE_HEAD.BBOX_REG_WEIGHTS = (
|
444 |
+
(10.0, 10.0, 5.0, 5.0),
|
445 |
+
(20.0, 20.0, 10.0, 10.0),
|
446 |
+
(30.0, 30.0, 15.0, 15.0),
|
447 |
+
)
|
448 |
+
_C.MODEL.ROI_BOX_CASCADE_HEAD.IOUS = (0.5, 0.6, 0.7)
|
449 |
+
|
450 |
+
|
451 |
+
# ---------------------------------------------------------------------------- #
|
452 |
+
# Mask Head
|
453 |
+
# ---------------------------------------------------------------------------- #
|
454 |
+
_C.MODEL.ROI_MASK_HEAD = CN()
|
455 |
+
_C.MODEL.ROI_MASK_HEAD.NAME = "MaskRCNNConvUpsampleHead"
|
456 |
+
_C.MODEL.ROI_MASK_HEAD.POOLER_RESOLUTION = 14
|
457 |
+
_C.MODEL.ROI_MASK_HEAD.POOLER_SAMPLING_RATIO = 0
|
458 |
+
_C.MODEL.ROI_MASK_HEAD.NUM_CONV = 0 # The number of convs in the mask head
|
459 |
+
_C.MODEL.ROI_MASK_HEAD.CONV_DIM = 256
|
460 |
+
# Normalization method for the convolution layers.
|
461 |
+
# Options: "" (no norm), "GN", "SyncBN".
|
462 |
+
_C.MODEL.ROI_MASK_HEAD.NORM = ""
|
463 |
+
# Whether to use class agnostic for mask prediction
|
464 |
+
_C.MODEL.ROI_MASK_HEAD.CLS_AGNOSTIC_MASK = False
|
465 |
+
# Type of pooling operation applied to the incoming feature map for each RoI
|
466 |
+
_C.MODEL.ROI_MASK_HEAD.POOLER_TYPE = "ROIAlignV2"
|
467 |
+
|
468 |
+
|
469 |
+
# ---------------------------------------------------------------------------- #
|
470 |
+
# Keypoint Head
|
471 |
+
# ---------------------------------------------------------------------------- #
|
472 |
+
_C.MODEL.ROI_KEYPOINT_HEAD = CN()
|
473 |
+
_C.MODEL.ROI_KEYPOINT_HEAD.NAME = "KRCNNConvDeconvUpsampleHead"
|
474 |
+
_C.MODEL.ROI_KEYPOINT_HEAD.POOLER_RESOLUTION = 14
|
475 |
+
_C.MODEL.ROI_KEYPOINT_HEAD.POOLER_SAMPLING_RATIO = 0
|
476 |
+
_C.MODEL.ROI_KEYPOINT_HEAD.CONV_DIMS = tuple(512 for _ in range(8))
|
477 |
+
_C.MODEL.ROI_KEYPOINT_HEAD.NUM_KEYPOINTS = 17 # 17 is the number of keypoints in COCO.
|
478 |
+
|
479 |
+
# Images with too few (or no) keypoints are excluded from training.
|
480 |
+
_C.MODEL.ROI_KEYPOINT_HEAD.MIN_KEYPOINTS_PER_IMAGE = 1
|
481 |
+
# Normalize by the total number of visible keypoints in the minibatch if True.
|
482 |
+
# Otherwise, normalize by the total number of keypoints that could ever exist
|
483 |
+
# in the minibatch.
|
484 |
+
# The keypoint softmax loss is only calculated on visible keypoints.
|
485 |
+
# Since the number of visible keypoints can vary significantly between
|
486 |
+
# minibatches, this has the effect of up-weighting the importance of
|
487 |
+
# minibatches with few visible keypoints. (Imagine the extreme case of
|
488 |
+
# only one visible keypoint versus N: in the case of N, each one
|
489 |
+
# contributes 1/N to the gradient compared to the single keypoint
|
490 |
+
# determining the gradient direction). Instead, we can normalize the
|
491 |
+
# loss by the total number of keypoints, if it were the case that all
|
492 |
+
# keypoints were visible in a full minibatch. (Returning to the example,
|
493 |
+
# this means that the one visible keypoint contributes as much as each
|
494 |
+
# of the N keypoints.)
|
495 |
+
_C.MODEL.ROI_KEYPOINT_HEAD.NORMALIZE_LOSS_BY_VISIBLE_KEYPOINTS = True
|
496 |
+
# Multi-task loss weight to use for keypoints
|
497 |
+
# Recommended values:
|
498 |
+
# - use 1.0 if NORMALIZE_LOSS_BY_VISIBLE_KEYPOINTS is True
|
499 |
+
# - use 4.0 if NORMALIZE_LOSS_BY_VISIBLE_KEYPOINTS is False
|
500 |
+
_C.MODEL.ROI_KEYPOINT_HEAD.LOSS_WEIGHT = 1.0
|
501 |
+
# Type of pooling operation applied to the incoming feature map for each RoI
|
502 |
+
_C.MODEL.ROI_KEYPOINT_HEAD.POOLER_TYPE = "ROIAlignV2"
|
503 |
+
|
504 |
+
# ---------------------------------------------------------------------------- #
|
505 |
+
# Semantic Segmentation Head
|
506 |
+
# ---------------------------------------------------------------------------- #
|
507 |
+
_C.MODEL.SEM_SEG_HEAD = CN()
|
508 |
+
_C.MODEL.SEM_SEG_HEAD.NAME = "SemSegFPNHead"
|
509 |
+
_C.MODEL.SEM_SEG_HEAD.IN_FEATURES = ["p2", "p3", "p4", "p5"]
|
510 |
+
# Label in the semantic segmentation ground truth that is ignored, i.e., no loss is calculated for
|
511 |
+
# the corresponding pixel.
|
512 |
+
_C.MODEL.SEM_SEG_HEAD.IGNORE_VALUE = 255
|
513 |
+
# Number of classes in the semantic segmentation head
|
514 |
+
_C.MODEL.SEM_SEG_HEAD.NUM_CLASSES = 54
|
515 |
+
# Number of channels in the 3x3 convs inside semantic-FPN heads.
|
516 |
+
_C.MODEL.SEM_SEG_HEAD.CONVS_DIM = 128
|
517 |
+
# Outputs from semantic-FPN heads are up-scaled to the COMMON_STRIDE stride.
|
518 |
+
_C.MODEL.SEM_SEG_HEAD.COMMON_STRIDE = 4
|
519 |
+
# Normalization method for the convolution layers. Options: "" (no norm), "GN".
|
520 |
+
_C.MODEL.SEM_SEG_HEAD.NORM = "GN"
|
521 |
+
_C.MODEL.SEM_SEG_HEAD.LOSS_WEIGHT = 1.0
|
522 |
+
|
523 |
+
_C.MODEL.PANOPTIC_FPN = CN()
|
524 |
+
# Scaling of all losses from instance detection / segmentation head.
|
525 |
+
_C.MODEL.PANOPTIC_FPN.INSTANCE_LOSS_WEIGHT = 1.0
|
526 |
+
|
527 |
+
# options when combining instance & semantic segmentation outputs
|
528 |
+
_C.MODEL.PANOPTIC_FPN.COMBINE = CN({"ENABLED": True}) # "COMBINE.ENABLED" is deprecated & not used
|
529 |
+
_C.MODEL.PANOPTIC_FPN.COMBINE.OVERLAP_THRESH = 0.5
|
530 |
+
_C.MODEL.PANOPTIC_FPN.COMBINE.STUFF_AREA_LIMIT = 4096
|
531 |
+
_C.MODEL.PANOPTIC_FPN.COMBINE.INSTANCES_CONFIDENCE_THRESH = 0.5
|
532 |
+
|
533 |
+
|
534 |
+
# ---------------------------------------------------------------------------- #
|
535 |
+
# RetinaNet Head
|
536 |
+
# ---------------------------------------------------------------------------- #
|
537 |
+
_C.MODEL.RETINANET = CN()
|
538 |
+
|
539 |
+
# This is the number of foreground classes.
|
540 |
+
_C.MODEL.RETINANET.NUM_CLASSES = 80
|
541 |
+
|
542 |
+
_C.MODEL.RETINANET.IN_FEATURES = ["p3", "p4", "p5", "p6", "p7"]
|
543 |
+
|
544 |
+
# Convolutions to use in the cls and bbox tower
|
545 |
+
# NOTE: this doesn't include the last conv for logits
|
546 |
+
_C.MODEL.RETINANET.NUM_CONVS = 4
|
547 |
+
|
548 |
+
# IoU overlap ratio [bg, fg] for labeling anchors.
|
549 |
+
# Anchors with < bg are labeled negative (0)
|
550 |
+
# Anchors with >= bg and < fg are ignored (-1)
|
551 |
+
# Anchors with >= fg are labeled positive (1)
|
552 |
+
_C.MODEL.RETINANET.IOU_THRESHOLDS = [0.4, 0.5]
|
553 |
+
_C.MODEL.RETINANET.IOU_LABELS = [0, -1, 1]
|
554 |
+
|
555 |
+
# Prior prob for rare case (i.e. foreground) at the beginning of training.
|
556 |
+
# This is used to set the bias for the logits layer of the classifier subnet.
|
557 |
+
# This improves training stability in the case of heavy class imbalance.
|
558 |
+
_C.MODEL.RETINANET.PRIOR_PROB = 0.01
|
559 |
+
|
560 |
+
# Inference cls score threshold, only anchors with score > SCORE_THRESH_TEST are
|
561 |
+
# considered for inference (to improve speed)
|
562 |
+
_C.MODEL.RETINANET.SCORE_THRESH_TEST = 0.05
|
563 |
+
# Select topk candidates before NMS
|
564 |
+
_C.MODEL.RETINANET.TOPK_CANDIDATES_TEST = 1000
|
565 |
+
_C.MODEL.RETINANET.NMS_THRESH_TEST = 0.5
|
566 |
+
|
567 |
+
# Weights on (dx, dy, dw, dh) for normalizing Retinanet anchor regression targets
|
568 |
+
_C.MODEL.RETINANET.BBOX_REG_WEIGHTS = (1.0, 1.0, 1.0, 1.0)
|
569 |
+
|
570 |
+
# Loss parameters
|
571 |
+
_C.MODEL.RETINANET.FOCAL_LOSS_GAMMA = 2.0
|
572 |
+
_C.MODEL.RETINANET.FOCAL_LOSS_ALPHA = 0.25
|
573 |
+
_C.MODEL.RETINANET.SMOOTH_L1_LOSS_BETA = 0.1
|
574 |
+
# Options are: "smooth_l1", "giou"
|
575 |
+
_C.MODEL.RETINANET.BBOX_REG_LOSS_TYPE = "smooth_l1"
|
576 |
+
|
577 |
+
# One of BN, SyncBN, FrozenBN, GN
|
578 |
+
# Only supports GN until unshared norm is implemented
|
579 |
+
_C.MODEL.RETINANET.NORM = ""
|
580 |
+
|
581 |
+
|
582 |
+
# ---------------------------------------------------------------------------- #
|
583 |
+
# ResNe[X]t options (ResNets = {ResNet, ResNeXt})
|
584 |
+
# Note that parts of a resnet may be used for both the backbone and the head
|
585 |
+
# These options apply to both
|
586 |
+
# ---------------------------------------------------------------------------- #
|
587 |
+
_C.MODEL.RESNETS = CN()
|
588 |
+
|
589 |
+
_C.MODEL.RESNETS.DEPTH = 50
|
590 |
+
_C.MODEL.RESNETS.OUT_FEATURES = ["res4"] # res4 for C4 backbone, res2..5 for FPN backbone
|
591 |
+
|
592 |
+
# Number of groups to use; 1 ==> ResNet; > 1 ==> ResNeXt
|
593 |
+
_C.MODEL.RESNETS.NUM_GROUPS = 1
|
594 |
+
|
595 |
+
# Options: FrozenBN, GN, "SyncBN", "BN"
|
596 |
+
_C.MODEL.RESNETS.NORM = "FrozenBN"
|
597 |
+
|
598 |
+
# Baseline width of each group.
|
599 |
+
# Scaling this parameter will scale the width of all bottleneck layers.
|
600 |
+
_C.MODEL.RESNETS.WIDTH_PER_GROUP = 64
|
601 |
+
|
602 |
+
# Place the stride 2 conv on the 1x1 filter
|
603 |
+
# Use True only for the original MSRA ResNet; use False for C2 and Torch models
|
604 |
+
_C.MODEL.RESNETS.STRIDE_IN_1X1 = True
|
605 |
+
|
606 |
+
# Apply dilation in stage "res5"
|
607 |
+
_C.MODEL.RESNETS.RES5_DILATION = 1
|
608 |
+
|
609 |
+
# Output width of res2. Scaling this parameter will scale the width of all 1x1 convs in ResNet
|
610 |
+
# For R18 and R34, this needs to be set to 64
|
611 |
+
_C.MODEL.RESNETS.RES2_OUT_CHANNELS = 256
|
612 |
+
_C.MODEL.RESNETS.STEM_OUT_CHANNELS = 64
|
613 |
+
|
614 |
+
# Apply Deformable Convolution in stages
|
615 |
+
# Specify if apply deform_conv on Res2, Res3, Res4, Res5
|
616 |
+
_C.MODEL.RESNETS.DEFORM_ON_PER_STAGE = [False, False, False, False]
|
617 |
+
# Use True to use modulated deform_conv (DeformableV2, https://arxiv.org/abs/1811.11168);
|
618 |
+
# Use False for DeformableV1.
|
619 |
+
_C.MODEL.RESNETS.DEFORM_MODULATED = False
|
620 |
+
# Number of groups in deformable conv.
|
621 |
+
_C.MODEL.RESNETS.DEFORM_NUM_GROUPS = 1
|
622 |
+
|
623 |
+
|
624 |
+
# ---------------------------------------------------------------------------- #
|
625 |
+
# Swin options
|
626 |
+
# Note that parts of a resnet may be used for both the backbone and the head
|
627 |
+
# These options apply to both
|
628 |
+
# ---------------------------------------------------------------------------- #
|
629 |
+
_C.MODEL.SPEC = CN()
|
630 |
+
_C.MODEL.SPEC.EMBED_DIM = 512
|
631 |
+
|
632 |
+
_C.MODEL.SPEC.VISION = CN()
|
633 |
+
_C.MODEL.SPEC.VISION.PATCH_SIZE = 4
|
634 |
+
_C.MODEL.SPEC.VISION.IN_CHANS = 3
|
635 |
+
_C.MODEL.SPEC.VISION.EMBED_DIM = 96
|
636 |
+
_C.MODEL.SPEC.VISION.DEPTHS = [2, 2, 6, 2]
|
637 |
+
_C.MODEL.SPEC.VISION.NUM_HEADS = [3, 6, 12, 24]
|
638 |
+
_C.MODEL.SPEC.VISION.WINDOW_SIZE = 7
|
639 |
+
_C.MODEL.SPEC.VISION.MLP_RATIO = 4.
|
640 |
+
_C.MODEL.SPEC.VISION.DROP_RATE = .0
|
641 |
+
_C.MODEL.SPEC.VISION.ATTN_DROP_RATE = .0
|
642 |
+
_C.MODEL.SPEC.VISION.DROP_PATH_RATE = .0
|
643 |
+
_C.MODEL.SPEC.VISION.QKV_BIAS = True
|
644 |
+
_C.MODEL.SPEC.VISION.QK_SCALE = False
|
645 |
+
_C.MODEL.SPEC.VISION.APE = False
|
646 |
+
_C.MODEL.SPEC.VISION.PATCH_NORM = True
|
647 |
+
_C.MODEL.SPEC.VISION.OUT_FEATURES = ["stage2", "stage3", "stage4", "stage5"]
|
648 |
+
|
649 |
+
_C.MODEL.SPEC.TEXT = CN()
|
650 |
+
_C.MODEL.SPEC.TEXT.NAME = 'transformer'
|
651 |
+
_C.MODEL.SPEC.TEXT.LOAD_PRETRAINED = False
|
652 |
+
_C.MODEL.SPEC.TEXT.PRETRAINED = ''
|
653 |
+
_C.MODEL.SPEC.TEXT.TOKENIZER = 'clip'
|
654 |
+
_C.MODEL.SPEC.TEXT.CONTEXT_LENGTH = 77
|
655 |
+
_C.MODEL.SPEC.TEXT.WIDTH = 512
|
656 |
+
_C.MODEL.SPEC.TEXT.HEADS = 8
|
657 |
+
_C.MODEL.SPEC.TEXT.LAYERS = 12
|
658 |
+
_C.MODEL.SPEC.TEXT.AUTOGRESSIVE = True
|
659 |
+
|
660 |
+
# ---------------------------------------------------------------------------- #
|
661 |
+
# Solver
|
662 |
+
# ---------------------------------------------------------------------------- #
|
663 |
+
_C.SOLVER = CN()
|
664 |
+
|
665 |
+
# See detectron2/solver/build.py for LR scheduler options
|
666 |
+
_C.SOLVER.LR_SCHEDULER_NAME = "WarmupMultiStepLR"
|
667 |
+
|
668 |
+
_C.SOLVER.MAX_ITER = 40000
|
669 |
+
|
670 |
+
_C.SOLVER.BASE_LR = 0.001
|
671 |
+
|
672 |
+
_C.SOLVER.MOMENTUM = 0.9
|
673 |
+
|
674 |
+
_C.SOLVER.NESTEROV = False
|
675 |
+
|
676 |
+
_C.SOLVER.WEIGHT_DECAY = 0.0001
|
677 |
+
# The weight decay that's applied to parameters of normalization layers
|
678 |
+
# (typically the affine transformation)
|
679 |
+
_C.SOLVER.WEIGHT_DECAY_NORM = 0.0
|
680 |
+
|
681 |
+
_C.SOLVER.GAMMA = 0.1
|
682 |
+
# The iteration number to decrease learning rate by GAMMA.
|
683 |
+
_C.SOLVER.STEPS = (30000,)
|
684 |
+
|
685 |
+
_C.SOLVER.WARMUP_FACTOR = 1.0 / 1000
|
686 |
+
_C.SOLVER.WARMUP_ITERS = 1000
|
687 |
+
_C.SOLVER.WARMUP_METHOD = "linear"
|
688 |
+
|
689 |
+
# Save a checkpoint after every this number of iterations
|
690 |
+
_C.SOLVER.CHECKPOINT_PERIOD = 5000
|
691 |
+
|
692 |
+
# Number of images per batch across all machines. This is also the number
|
693 |
+
# of training images per step (i.e. per iteration). If we use 16 GPUs
|
694 |
+
# and IMS_PER_BATCH = 32, each GPU will see 2 images per batch.
|
695 |
+
# May be adjusted automatically if REFERENCE_WORLD_SIZE is set.
|
696 |
+
_C.SOLVER.IMS_PER_BATCH = 16
|
697 |
+
|
698 |
+
# The reference number of workers (GPUs) this config is meant to train with.
|
699 |
+
# It takes no effect when set to 0.
|
700 |
+
# With a non-zero value, it will be used by DefaultTrainer to compute a desired
|
701 |
+
# per-worker batch size, and then scale the other related configs (total batch size,
|
702 |
+
# learning rate, etc) to match the per-worker batch size.
|
703 |
+
# See documentation of `DefaultTrainer.auto_scale_workers` for details:
|
704 |
+
_C.SOLVER.REFERENCE_WORLD_SIZE = 0
|
705 |
+
|
706 |
+
# Detectron v1 (and previous detection code) used a 2x higher LR and 0 WD for
|
707 |
+
# biases. This is not useful (at least for recent models). You should avoid
|
708 |
+
# changing these and they exist only to reproduce Detectron v1 training if
|
709 |
+
# desired.
|
710 |
+
_C.SOLVER.BIAS_LR_FACTOR = 1.0
|
711 |
+
_C.SOLVER.WEIGHT_DECAY_BIAS = _C.SOLVER.WEIGHT_DECAY
|
712 |
+
|
713 |
+
# Gradient clipping
|
714 |
+
_C.SOLVER.CLIP_GRADIENTS = CN({"ENABLED": False})
|
715 |
+
# Type of gradient clipping, currently 2 values are supported:
|
716 |
+
# - "value": the absolute values of elements of each gradients are clipped
|
717 |
+
# - "norm": the norm of the gradient for each parameter is clipped thus
|
718 |
+
# affecting all elements in the parameter
|
719 |
+
_C.SOLVER.CLIP_GRADIENTS.CLIP_TYPE = "value"
|
720 |
+
# Maximum absolute value used for clipping gradients
|
721 |
+
_C.SOLVER.CLIP_GRADIENTS.CLIP_VALUE = 1.0
|
722 |
+
# Floating point number p for L-p norm to be used with the "norm"
|
723 |
+
# gradient clipping type; for L-inf, please specify .inf
|
724 |
+
_C.SOLVER.CLIP_GRADIENTS.NORM_TYPE = 2.0
|
725 |
+
|
726 |
+
# Enable automatic mixed precision for training
|
727 |
+
# Note that this does not change model's inference behavior.
|
728 |
+
# To use AMP in inference, run inference under autocast()
|
729 |
+
_C.SOLVER.AMP = CN({"ENABLED": False})
|
730 |
+
|
731 |
+
# ---------------------------------------------------------------------------- #
|
732 |
+
# Specific test options
|
733 |
+
# ---------------------------------------------------------------------------- #
|
734 |
+
_C.TEST = CN()
|
735 |
+
# For end-to-end tests to verify the expected accuracy.
|
736 |
+
# Each item is [task, metric, value, tolerance]
|
737 |
+
# e.g.: [['bbox', 'AP', 38.5, 0.2]]
|
738 |
+
_C.TEST.EXPECTED_RESULTS = []
|
739 |
+
# The period (in terms of steps) to evaluate the model during training.
|
740 |
+
# Set to 0 to disable.
|
741 |
+
_C.TEST.EVAL_PERIOD = 0
|
742 |
+
# The sigmas used to calculate keypoint OKS. See http://cocodataset.org/#keypoints-eval
|
743 |
+
# When empty, it will use the defaults in COCO.
|
744 |
+
# Otherwise it should be a list[float] with the same length as ROI_KEYPOINT_HEAD.NUM_KEYPOINTS.
|
745 |
+
_C.TEST.KEYPOINT_OKS_SIGMAS = []
|
746 |
+
# Maximum number of detections to return per image during inference (100 is
|
747 |
+
# based on the limit established for the COCO dataset).
|
748 |
+
_C.TEST.DETECTIONS_PER_IMAGE = 100
|
749 |
+
|
750 |
+
_C.TEST.AUG = CN({"ENABLED": False})
|
751 |
+
_C.TEST.AUG.MIN_SIZES = (400, 500, 600, 700, 800, 900, 1000, 1100, 1200)
|
752 |
+
_C.TEST.AUG.MAX_SIZE = 4000
|
753 |
+
_C.TEST.AUG.FLIP = True
|
754 |
+
|
755 |
+
_C.TEST.PRECISE_BN = CN({"ENABLED": False})
|
756 |
+
_C.TEST.PRECISE_BN.NUM_ITER = 200
|
757 |
+
|
758 |
+
# ---------------------------------------------------------------------------- #
|
759 |
+
# Misc options
|
760 |
+
# ---------------------------------------------------------------------------- #
|
761 |
+
# Directory where output files are written
|
762 |
+
_C.OUTPUT_DIR = "./output"
|
763 |
+
# Set seed to negative to fully randomize everything.
|
764 |
+
# Set seed to positive to use a fixed seed. Note that a fixed seed increases
|
765 |
+
# reproducibility but does not guarantee fully deterministic behavior.
|
766 |
+
# Disabling all parallelism further increases reproducibility.
|
767 |
+
_C.SEED = -1
|
768 |
+
# Benchmark different cudnn algorithms.
|
769 |
+
# If input images have very different sizes, this option will have large overhead
|
770 |
+
# for about 10k iterations. It usually hurts total time, but can benefit for certain models.
|
771 |
+
# If input images have the same or similar sizes, benchmark is often helpful.
|
772 |
+
_C.CUDNN_BENCHMARK = False
|
773 |
+
# The period (in terms of steps) for minibatch visualization at train time.
|
774 |
+
# Set to 0 to disable.
|
775 |
+
_C.VIS_PERIOD = 0
|
776 |
+
|
777 |
+
# global config is for quick hack purposes.
|
778 |
+
# You can set them in command line or config files,
|
779 |
+
# and access it with:
|
780 |
+
#
|
781 |
+
# from detectron2.config import global_cfg
|
782 |
+
# print(global_cfg.HACK)
|
783 |
+
#
|
784 |
+
# Do not commit any configs into it.
|
785 |
+
_C.GLOBAL = CN()
|
786 |
+
_C.GLOBAL.HACK = 1.0
|
detectron2/config/instantiate.py
ADDED
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
import dataclasses
|
3 |
+
import logging
|
4 |
+
from collections import abc
|
5 |
+
from typing import Any
|
6 |
+
|
7 |
+
from detectron2.utils.registry import _convert_target_to_string, locate
|
8 |
+
|
9 |
+
__all__ = ["dump_dataclass", "instantiate"]
|
10 |
+
|
11 |
+
|
12 |
+
def dump_dataclass(obj: Any):
    """
    Convert a dataclass instance into a nested dict that can be instantiated later.

    The resulting dict carries a "_target_" entry computed from the type of
    ``obj`` plus one entry per dataclass field. Field values that are
    dataclasses, and dataclass items found directly inside list/tuple fields,
    are dumped recursively. List and tuple fields both come back as lists.

    Args:
        obj: a dataclass object (an instance, not the dataclass type itself)

    Returns:
        dict
    """
    is_dataclass_instance = dataclasses.is_dataclass(obj) and not isinstance(obj, type)
    assert is_dataclass_instance, "dump_dataclass() requires an instance of a dataclass."

    dumped = {"_target_": _convert_target_to_string(type(obj))}
    for field in dataclasses.fields(obj):
        value = getattr(obj, field.name)
        if dataclasses.is_dataclass(value):
            # Nested dataclass: replace it by its own dumped dict.
            value = dump_dataclass(value)
        elif isinstance(value, (list, tuple)):
            # Only direct items are inspected; deeper nesting is left untouched.
            value = [
                dump_dataclass(item) if dataclasses.is_dataclass(item) else item
                for item in value
            ]
        dumped[field.name] = value
    return dumped
|
34 |
+
|
35 |
+
|
36 |
+
def instantiate(cfg):
    """
    Recursively build the objects described by "_target_" dictionaries.

    Args:
        cfg: a dict-like object with "_target_" that defines the caller, and
            other keys that define the arguments. Lists and ``ListConfig``
            objects are processed element by element.

    Returns:
        object instantiated by cfg; anything that is not a list, a
        ``ListConfig`` or a mapping containing "_target_" is returned as-is.
    """
    from omegaconf import ListConfig

    if isinstance(cfg, ListConfig):
        built_items = [instantiate(item) for item in cfg]
        return ListConfig(built_items, flags={"allow_objects": True})
    if isinstance(cfg, list):
        # Plain lists are specialized because many classes take
        # list[objects] as arguments, such as ResNet, DatasetMapper
        return [instantiate(item) for item in cfg]

    if not (isinstance(cfg, abc.Mapping) and "_target_" in cfg):
        return cfg  # return as-is if don't know what to do

    # conceptually equivalent to hydra.utils.instantiate(cfg) with _convert_=all,
    # but faster: https://github.com/facebookresearch/hydra/issues/1200
    kwargs = {key: instantiate(value) for key, value in cfg.items()}
    target = instantiate(kwargs.pop("_target_"))

    if isinstance(target, str):
        # A dotted-path string is resolved to the object it names.
        target_name = target
        target = locate(target_name)
        assert target is not None, target_name
    else:
        try:
            target_name = target.__module__ + "." + target.__qualname__
        except Exception:
            # target could be anything, so the above could fail
            target_name = str(target)
    assert callable(target), f"_target_ {target} does not define a callable object"

    try:
        return target(**kwargs)
    except TypeError:
        # Log which target failed, then let the original error propagate.
        logging.getLogger(__name__).error(f"Error when instantiating {target_name}!")
        raise
|
detectron2/config/lazy.py
ADDED
@@ -0,0 +1,370 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
import ast
|
3 |
+
import builtins
|
4 |
+
import importlib
|
5 |
+
import inspect
|
6 |
+
import logging
|
7 |
+
import os
|
8 |
+
import uuid
|
9 |
+
from collections import abc
|
10 |
+
from contextlib import contextmanager
|
11 |
+
from copy import deepcopy
|
12 |
+
from typing import List, Tuple, Union
|
13 |
+
import cloudpickle
|
14 |
+
import yaml
|
15 |
+
from omegaconf import DictConfig, ListConfig, OmegaConf
|
16 |
+
|
17 |
+
from detectron2.utils.file_io import PathManager
|
18 |
+
from detectron2.utils.registry import _convert_target_to_string
|
19 |
+
|
20 |
+
__all__ = ["LazyCall", "LazyConfig"]
|
21 |
+
|
22 |
+
|
23 |
+
class LazyCall:
    """
    Wrap a callable so that when it's called, the call will not be executed,
    but returns a dict that describes the call.

    LazyCall object has to be called with only keyword arguments. Positional
    arguments are not yet supported.

    Examples:
    ::
        from detectron2.config import instantiate, LazyCall

        layer_cfg = LazyCall(nn.Conv2d)(in_channels=32, out_channels=32)
        layer_cfg.out_channels = 64   # can edit it afterwards
        layer = instantiate(layer_cfg)
    """

    def __init__(self, target):
        """
        Args:
            target: a callable, or a string / mapping that defines a callable.

        Raises:
            TypeError: if ``target`` is neither callable nor a str/Mapping.
        """
        if not (callable(target) or isinstance(target, (str, abc.Mapping))):
            # f-string so the offending value actually appears in the message
            # (the placeholder used to be emitted literally).
            raise TypeError(
                f"target of LazyCall must be a callable or defines a callable! Got {target}"
            )
        self._target = target

    def __call__(self, **kwargs):
        """
        Record the call instead of executing it.

        Returns:
            DictConfig holding the keyword arguments plus a "_target_" entry
            set to the wrapped target.
        """
        kwargs["_target_"] = self._target
        return DictConfig(content=kwargs, flags={"allow_objects": True})
|
50 |
+
|
51 |
+
|
52 |
+
def _visit_dict_config(cfg, func):
    """
    Call ``func`` on every DictConfig reachable from ``cfg``.

    A DictConfig is passed to ``func`` first and then its values are walked;
    a ListConfig is only walked. Any other object is ignored.
    """
    if isinstance(cfg, DictConfig):
        func(cfg)
        children = cfg.values()
    elif isinstance(cfg, ListConfig):
        children = cfg
    else:
        return
    for child in children:
        _visit_dict_config(child, func)
|
63 |
+
|
64 |
+
|
65 |
+
def _validate_py_syntax(filename):
    """
    Read ``filename`` through PathManager and check that it parses as Python.

    Raises:
        SyntaxError: naming the config file, chained to the parser's error.
    """
    # see also https://github.com/open-mmlab/mmcv/blob/master/mmcv/utils/config.py
    with PathManager.open(filename, "r") as stream:
        source_text = stream.read()
    try:
        ast.parse(source_text)
    except SyntaxError as parse_error:
        message = f"Config file {filename} has syntax error!"
        raise SyntaxError(message) from parse_error
|
73 |
+
|
74 |
+
|
75 |
+
def _cast_to_config(obj):
    """
    Wrap a plain dict in a DictConfig that allows arbitrary objects as values.

    Anything that is not a ``dict`` is returned unchanged.
    """
    if not isinstance(obj, dict):
        return obj
    return DictConfig(obj, flags={"allow_objects": True})
|
80 |
+
|
81 |
+
|
82 |
+
_CFG_PACKAGE_NAME = "detectron2._cfg_loader"
"""
A namespace to put all imported config into.
"""


def _random_package_name(filename):
    """
    Build a randomized package name for a config file that is being loaded.

    The name is the config namespace prefix, followed by the first four
    characters of a fresh UUID4 string, a dot, and the base name of ``filename``.
    """
    random_suffix = str(uuid.uuid4())[:4]
    base_name = os.path.basename(filename)
    return "".join((_CFG_PACKAGE_NAME, random_suffix, ".", base_name))
|
91 |
+
|
92 |
+
|
93 |
+
@contextmanager
def _patch_import():
    """
    Enhance relative import statements in config files, so that they:
    1. locate files purely based on relative location, regardless of packages.
       e.g. you can import file without having __init__
    2. do not cache modules globally; modifications of module states has no side effect
    3. support other storage system through PathManager
    4. imported dict are turned into omegaconf.DictConfig automatically

    While the context is active, ``builtins.__import__`` is replaced by a
    wrapper; the previous importer is always put back on exit, including when
    the body of the ``with`` statement raises.

    Yields:
        the replacement import function that is installed for the duration
        of the context.
    """
    old_import = builtins.__import__

    def find_relative_file(original_file, relative_import_path, level):
        # Resolve a relative import to a ".py" path next to (or above) the
        # importing file: one directory up per extra leading dot.
        cur_file = os.path.dirname(original_file)
        for _ in range(level - 1):
            cur_file = os.path.dirname(cur_file)
        cur_name = relative_import_path.lstrip(".")
        for part in cur_name.split("."):
            cur_file = os.path.join(cur_file, part)
        # NOTE: directory import is not handled. Because then it's unclear
        # if such import should produce python module or DictConfig. This can
        # be discussed further if needed.
        if not cur_file.endswith(".py"):
            cur_file += ".py"
        if not PathManager.isfile(cur_file):
            raise ImportError(
                f"Cannot import name {relative_import_path} from "
                f"{original_file}: {cur_file} has to exist."
            )
        return cur_file

    def new_import(name, globals=None, locals=None, fromlist=(), level=0):
        if (
            # Only deal with relative imports inside config files
            level != 0
            and globals is not None
            and (globals.get("__package__", "") or "").startswith(_CFG_PACKAGE_NAME)
        ):
            cur_file = find_relative_file(globals["__file__"], name, level)
            _validate_py_syntax(cur_file)
            # Build a fresh module object each time; it is returned to the
            # importing code directly rather than looked up from a cache.
            spec = importlib.machinery.ModuleSpec(
                _random_package_name(cur_file), None, origin=cur_file
            )
            module = importlib.util.module_from_spec(spec)
            module.__file__ = cur_file
            with PathManager.open(cur_file) as f:
                content = f.read()
            exec(compile(content, cur_file, "exec"), module.__dict__)
            # turn imported dict into DictConfig automatically
            # (distinct loop variable so the `name` argument is not shadowed)
            for imported_name in fromlist:
                val = _cast_to_config(module.__dict__[imported_name])
                module.__dict__[imported_name] = val
            return module
        return old_import(name, globals, locals, fromlist=fromlist, level=level)

    builtins.__import__ = new_import
    try:
        yield new_import
    finally:
        # Restore even if the with-body raised; otherwise the patched importer
        # would stay installed for the rest of the process.
        builtins.__import__ = old_import
|
150 |
+
|
151 |
+
|
152 |
+
class LazyConfig:
    """
    Provide methods to save, load, and override an omegaconf config object
    which may contain definition of lazily-constructed objects.
    """

    @staticmethod
    def load_rel(filename: str, keys: Union[None, str, Tuple[str, ...]] = None):
        """
        Similar to :meth:`load()`, but load path relative to the caller's
        source file.

        This has the same functionality as a relative import, except that this method
        accepts filename as a string, so more characters are allowed in the filename.

        Args:
            filename: path relative to the directory of the file that calls this method.
            keys: forwarded unchanged to :meth:`load()`.

        Returns:
            whatever :meth:`load()` returns for the resolved path.
        """
        caller_frame = inspect.stack()[1]
        caller_fname = caller_frame[0].f_code.co_filename
        assert caller_fname != "<string>", "load_rel Unable to find caller"
        caller_dir = os.path.dirname(caller_fname)
        filename = os.path.join(caller_dir, filename)
        return LazyConfig.load(filename, keys)

    @staticmethod
    def load(filename: str, keys: Union[None, str, Tuple[str, ...]] = None):
        """
        Load a config file.

        NOTE(security): a ``.py`` config is executed with ``exec`` and a yaml
        config is parsed with ``yaml.unsafe_load``; both can run arbitrary
        code, so only load config files from trusted sources.

        Args:
            filename: absolute path or relative path w.r.t. the current working directory
            keys: keys to load and return. If not given, return all keys
                (whose values are config objects) in a dict.

        Returns:
            If ``keys`` is a string, the config object stored under that key;
            if ``keys`` is a tuple, a tuple with one config object per key;
            otherwise a config object holding the content of the file.

        Raises:
            ValueError: if the file extension is not ``.py``, ``.yaml`` or ``.yml``.
        """
        has_keys = keys is not None
        filename = filename.replace("/./", "/")  # redundant
        if os.path.splitext(filename)[1] not in [".py", ".yaml", ".yml"]:
            raise ValueError(f"Config file {filename} has to be a python or yaml file.")
        if filename.endswith(".py"):
            _validate_py_syntax(filename)

            with _patch_import():
                # Record the filename
                module_namespace = {
                    "__file__": filename,
                    "__package__": _random_package_name(filename),
                }
                with PathManager.open(filename) as f:
                    content = f.read()
                # Compile first with filename to:
                # 1. make filename appears in stacktrace
                # 2. make load_rel able to find its parent's (possibly remote) location
                exec(compile(content, filename, "exec"), module_namespace)

            ret = module_namespace
        else:
            with PathManager.open(filename) as f:
                obj = yaml.unsafe_load(f)
            ret = OmegaConf.create(obj, flags={"allow_objects": True})

        if has_keys:
            if isinstance(keys, str):
                return _cast_to_config(ret[keys])
            else:
                return tuple(_cast_to_config(ret[a]) for a in keys)
        else:
            if filename.endswith(".py"):
                # when not specified, only load those that are config objects
                ret = DictConfig(
                    {
                        name: _cast_to_config(value)
                        for name, value in ret.items()
                        if isinstance(value, (DictConfig, ListConfig, dict))
                        and not name.startswith("_")
                    },
                    flags={"allow_objects": True},
                )
            return ret

    @staticmethod
    def save(cfg, filename: str):
        """
        Save a config object to a yaml file, falling back to cloudpickle.

        If the config cannot be written as yaml, the error is logged and the
        config is pickled to ``filename + ".pkl"`` instead. Failures of that
        fallback are logged as well; this method does not raise for
        serialization problems.

        Args:
            cfg: an omegaconf config object
            filename: yaml file name to save the config file
        """
        logger = logging.getLogger(__name__)
        try:
            cfg = deepcopy(cfg)
        except Exception:
            # Not deep-copyable: keep working on the original object and
            # skip the cosmetic rewrite below so that it is left untouched.
            pass
        else:
            # if it's deep-copyable, then...
            def _replace_type_by_name(x):
                if "_target_" in x and callable(x._target_):
                    try:
                        x._target_ = _convert_target_to_string(x._target_)
                    except AttributeError:
                        pass

            # not necessary, but makes yaml looks nicer
            _visit_dict_config(cfg, _replace_type_by_name)

        try:
            # Serialize before opening the output so that a serialization
            # failure does not leave an empty or truncated yaml file behind.
            container = OmegaConf.to_container(cfg, resolve=False)
            dumped = yaml.dump(
                container, default_flow_style=None, allow_unicode=True, width=9999
            )
            with PathManager.open(filename, "w") as f:
                f.write(dumped)
        except Exception:
            logger.exception("Unable to serialize the config to yaml. Error:")
            new_filename = filename + ".pkl"
            try:
                # retry by pickle
                with PathManager.open(new_filename, "wb") as f:
                    cloudpickle.dump(cfg, f)
                logger.warning(f"Config saved using cloudpickle at {new_filename} ...")
            except Exception:
                # Saving stays best-effort, but report the failure instead of
                # hiding the fact that nothing was written.
                logger.exception(
                    f"Unable to save the config using cloudpickle at {new_filename}. Error:"
                )

    @staticmethod
    def apply_overrides(cfg, overrides: List[str]):
        """
        In-place override contents of cfg.

        Args:
            cfg: an omegaconf config object
            overrides: list of strings in the format of "a=b" to override configs.
                See https://hydra.cc/docs/next/advanced/override_grammar/basic/
                for syntax.

        Returns:
            the cfg object

        Raises:
            KeyError: if an override goes through an existing key whose value
                is not a config object.
            NotImplementedError: if an override requests a deletion.
        """

        def safe_update(cfg, key, value):
            # Refuse to update "a.b.c" when an existing prefix ("a" or "a.b")
            # holds something that is not a config node.
            parts = key.split(".")
            for idx in range(1, len(parts)):
                prefix = ".".join(parts[:idx])
                v = OmegaConf.select(cfg, prefix, default=None)
                if v is None:
                    break
                if not OmegaConf.is_config(v):
                    raise KeyError(
                        f"Trying to update key {key}, but {prefix} "
                        f"is not a config, but has type {type(v)}."
                    )
            OmegaConf.update(cfg, key, value, merge=True)

        from hydra.core.override_parser.overrides_parser import OverridesParser

        parser = OverridesParser.create()
        overrides = parser.parse_overrides(overrides)
        for o in overrides:
            key = o.key_or_group
            value = o.value()
            if o.is_delete():
                # TODO support this
                raise NotImplementedError("deletion is not yet a supported override")
            safe_update(cfg, key, value)
        return cfg

    @staticmethod
    def to_py(cfg, prefix: str = "cfg."):
        """
        Convert a config object into its equivalent Python code.

        Args:
            cfg: an omegaconf config object
            prefix: root name for the resulting code (default: "cfg.")


        Returns:
            str of formatted Python code
        """
        import black

        cfg = OmegaConf.to_container(cfg, resolve=True)

        def _to_str(obj, prefix=None, inside_call=False):
            if prefix is None:
                prefix = []
            if isinstance(obj, abc.Mapping) and "_target_" in obj:
                # Dict representing a function call
                target = _convert_target_to_string(obj.pop("_target_"))
                args = []
                for k, v in sorted(obj.items()):
                    args.append(f"{k}={_to_str(v, inside_call=True)}")
                args = ", ".join(args)
                call = f"{target}({args})"
                return "".join(prefix) + call
            elif isinstance(obj, abc.Mapping) and not inside_call:
                # Dict that is not inside a call is a list of top-level config objects that we
                # render as one object per line with dot separated prefixes
                key_list = []
                for k, v in sorted(obj.items()):
                    if isinstance(v, abc.Mapping) and "_target_" not in v:
                        key_list.append(_to_str(v, prefix=prefix + [k + "."]))
                    else:
                        key = "".join(prefix) + k
                        key_list.append(f"{key}={_to_str(v)}")
                return "\n".join(key_list)
            elif isinstance(obj, abc.Mapping):
                # Dict that is inside a call is rendered as a regular dict
                return (
                    "{"
                    + ",".join(
                        f"{repr(k)}: {_to_str(v, inside_call=inside_call)}"
                        for k, v in sorted(obj.items())
                    )
                    + "}"
                )
            elif isinstance(obj, list):
                return "[" + ",".join(_to_str(x, inside_call=inside_call) for x in obj) + "]"
            else:
                return repr(obj)

        py_str = _to_str(cfg, prefix=[prefix])
        try:
            return black.format_str(py_str, mode=black.Mode())
        except black.InvalidInput:
            # Fall back to the unformatted code when black cannot parse it.
            return py_str
|
detectron2/data/__init__.py
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
from . import transforms # isort:skip
|
3 |
+
|
4 |
+
from .build import (
|
5 |
+
build_batch_data_loader,
|
6 |
+
build_detection_test_loader,
|
7 |
+
build_detection_train_loader,
|
8 |
+
get_detection_dataset_dicts,
|
9 |
+
load_proposals_into_dataset,
|
10 |
+
print_instances_class_histogram,
|
11 |
+
)
|
12 |
+
from .catalog import DatasetCatalog, MetadataCatalog, Metadata
|
13 |
+
from .common import DatasetFromList, MapDataset
|
14 |
+
from .dataset_mapper import DatasetMapper
|
15 |
+
|
16 |
+
# ensure the builtin datasets are registered
|
17 |
+
from . import datasets, samplers # isort:skip
|
18 |
+
|
19 |
+
__all__ = [k for k in globals().keys() if not k.startswith("_")]
|
detectron2/data/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (859 Bytes). View file
|
|
detectron2/data/__pycache__/build.cpython-39.pyc
ADDED
Binary file (16.9 kB). View file
|
|
detectron2/data/__pycache__/catalog.cpython-39.pyc
ADDED
Binary file (7.6 kB). View file
|
|
detectron2/data/__pycache__/clip_build.cpython-39.pyc
ADDED
Binary file (4.32 kB). View file
|
|
detectron2/data/__pycache__/common.cpython-39.pyc
ADDED
Binary file (6.84 kB). View file
|
|
detectron2/data/__pycache__/dataset_mapper.cpython-39.pyc
ADDED
Binary file (5.89 kB). View file
|
|
detectron2/data/__pycache__/detection_utils.cpython-39.pyc
ADDED
Binary file (18.2 kB). View file
|
|
detectron2/data/build.py
ADDED
@@ -0,0 +1,536 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
import itertools
|
3 |
+
import logging
|
4 |
+
import numpy as np
|
5 |
+
import operator
|
6 |
+
import pickle
|
7 |
+
import torch.utils.data
|
8 |
+
from tabulate import tabulate
|
9 |
+
from termcolor import colored
|
10 |
+
|
11 |
+
from detectron2.config import configurable
|
12 |
+
from detectron2.structures import BoxMode
|
13 |
+
from detectron2.utils.comm import get_world_size
|
14 |
+
from detectron2.utils.env import seed_all_rng
|
15 |
+
from detectron2.utils.file_io import PathManager
|
16 |
+
from detectron2.utils.logger import _log_api_usage, log_first_n
|
17 |
+
|
18 |
+
from .catalog import DatasetCatalog, MetadataCatalog
|
19 |
+
from .common import AspectRatioGroupedDataset, DatasetFromList, MapDataset
|
20 |
+
from .dataset_mapper import DatasetMapper
|
21 |
+
from .detection_utils import check_metadata_consistency
|
22 |
+
from .samplers import InferenceSampler, RepeatFactorTrainingSampler, TrainingSampler
|
23 |
+
|
24 |
+
from .clip_build import make_clip_dataset
|
25 |
+
|
26 |
+
"""
|
27 |
+
This file contains the default logic to build a dataloader for training or testing.
|
28 |
+
"""
|
29 |
+
|
30 |
+
__all__ = [
|
31 |
+
"build_batch_data_loader",
|
32 |
+
"build_detection_train_loader",
|
33 |
+
"build_detection_test_loader",
|
34 |
+
"get_detection_dataset_dicts",
|
35 |
+
"load_proposals_into_dataset",
|
36 |
+
"print_instances_class_histogram",
|
37 |
+
]
|
38 |
+
|
39 |
+
|
40 |
+
def filter_images_with_only_crowd_annotations(dataset_dicts):
    """
    Drop images that have no annotations or only crowd annotations
    (i.e., keep only images with at least one non-crowd annotation).
    A common training-time preprocessing on COCO dataset.

    Args:
        dataset_dicts (list[dict]): annotations in Detectron2 Dataset format.

    Returns:
        list[dict]: the same format, but filtered.
    """
    total = len(dataset_dicts)
    # An annotation without an "iscrowd" key is treated as non-crowd.
    kept = [
        record
        for record in dataset_dicts
        if any(ann.get("iscrowd", 0) == 0 for ann in record["annotations"])
    ]
    remaining = len(kept)
    logging.getLogger(__name__).info(
        "Removed {} images with no usable annotations. {} images left.".format(
            total - remaining, remaining
        )
    )
    return kept
|
69 |
+
|
70 |
+
|
71 |
+
def filter_images_with_few_keypoints(dataset_dicts, min_keypoints_per_image):
    """
    Drop images whose number of visible keypoints is below a threshold.

    Args:
        dataset_dicts (list[dict]): annotations in Detectron2 Dataset format.
        min_keypoints_per_image (int): images with fewer visible keypoints
            than this value are removed.

    Returns:
        list[dict]: the same format as dataset_dicts, but filtered.
    """
    total = len(dataset_dicts)

    kept = []
    for record in dataset_dicts:
        # Each keypoints field has the format [x1, y1, v1, ...], where v is visibility;
        # every third value starting at index 2 is therefore a visibility flag.
        visible = sum(
            (np.array(ann["keypoints"][2::3]) > 0).sum()
            for ann in record["annotations"]
            if "keypoints" in ann
        )
        if visible >= min_keypoints_per_image:
            kept.append(record)

    removed = total - len(kept)
    logging.getLogger(__name__).info(
        "Removed {} images with fewer than {} keypoints.".format(
            removed, min_keypoints_per_image
        )
    )
    return kept
|
103 |
+
|
104 |
+
|
105 |
+
def load_proposals_into_dataset(dataset_dicts, proposal_file):
    """
    Load precomputed object proposals into the dataset.

    The proposal file should be a pickled dict with the following keys:

    - "ids": list[int] or list[str], the image ids
    - "boxes": list[np.ndarray], each is an Nx4 array of boxes corresponding to the image id
    - "objectness_logits": list[np.ndarray], each is an N sized array of objectness scores
      corresponding to the boxes.
    - "bbox_mode": the BoxMode of the boxes array. Defaults to ``BoxMode.XYXY_ABS``.

    Args:
        dataset_dicts (list[dict]): annotations in Detectron2 Dataset format.
        proposal_file (str): file path of pre-computed proposals, in pkl format.

    Returns:
        list[dict]: the same format as dataset_dicts, but added proposal field.
        The input dicts are updated in place and the same list is returned.

    Raises:
        KeyError: if an image of ``dataset_dicts`` has no entry in the proposal file.
    """
    logger = logging.getLogger(__name__)
    logger.info("Loading proposals from: {}".format(proposal_file))

    # NOTE(security): unpickling can execute arbitrary code; only use
    # proposal files that come from a trusted source.
    with PathManager.open(proposal_file, "rb") as f:
        proposals = pickle.load(f, encoding="latin1")

    # Rename the key names in D1 proposal files
    rename_keys = {"indexes": "ids", "scores": "objectness_logits"}
    for old_key, new_key in rename_keys.items():
        if old_key in proposals:
            proposals[new_key] = proposals.pop(old_key)

    # Fetch the indexes of all proposals that are in the dataset
    # Convert image_id to str since they could be int.
    img_ids = {str(record["image_id"]) for record in dataset_dicts}
    id_to_index = {
        str(proposal_id): index
        for index, proposal_id in enumerate(proposals["ids"])
        if str(proposal_id) in img_ids
    }

    # Assuming default bbox_mode of precomputed proposals are 'XYXY_ABS'
    bbox_mode = BoxMode(proposals["bbox_mode"]) if "bbox_mode" in proposals else BoxMode.XYXY_ABS

    for record in dataset_dicts:
        # Get the index of the proposal
        image_id = str(record["image_id"])
        try:
            i = id_to_index[image_id]
        except KeyError:
            # Same exception type as a plain lookup, but say which image and file.
            raise KeyError(
                "No precomputed proposals found for image_id={} in {}".format(
                    image_id, proposal_file
                )
            ) from None

        boxes = proposals["boxes"][i]
        objectness_logits = proposals["objectness_logits"][i]
        # Sort the proposals in descending order of the scores
        inds = objectness_logits.argsort()[::-1]
        record["proposal_boxes"] = boxes[inds]
        record["proposal_objectness_logits"] = objectness_logits[inds]
        record["proposal_bbox_mode"] = bbox_mode

    return dataset_dicts
|
157 |
+
|
158 |
+
|
159 |
+
def print_instances_class_histogram(dataset_dicts, class_names):
    """
    Log a table with the number of non-crowd instances of every class.

    Args:
        dataset_dicts (list[dict]): list of dataset dicts.
        class_names (list[str]): list of class names (zero-indexed).
    """
    num_classes = len(class_names)
    hist_bins = np.arange(num_classes + 1)
    # Use the builtin ``int``: the ``np.int`` alias it replaces was removed
    # in NumPy 1.24 and raises AttributeError there.
    histogram = np.zeros((num_classes,), dtype=int)
    for entry in dataset_dicts:
        annos = entry["annotations"]
        classes = np.asarray(
            [x["category_id"] for x in annos if not x.get("iscrowd", 0)], dtype=int
        )
        if len(classes):
            assert classes.min() >= 0, f"Got an invalid category_id={classes.min()}"
            assert (
                classes.max() < num_classes
            ), f"Got an invalid category_id={classes.max()} for a dataset of {num_classes} classes"
        histogram += np.histogram(classes, bins=hist_bins)[0]

    # Up to three (category, #instances) column pairs per table row.
    N_COLS = min(6, len(class_names) * 2)

    def short_name(x):
        # make long class names shorter. useful for lvis
        if len(x) > 13:
            return x[:11] + ".."
        return x

    # Flat list alternating between class name and instance count.
    data = list(
        itertools.chain(*[[short_name(class_names[i]), int(v)] for i, v in enumerate(histogram)])
    )
    total_num_instances = sum(data[1::2])
    # Pad so that the "total" entry starts on a new table row.
    data.extend([None] * (N_COLS - (len(data) % N_COLS)))
    if num_classes > 1:
        data.extend(["total", total_num_instances])
    data = itertools.zip_longest(*[data[i::N_COLS] for i in range(N_COLS)])
    table = tabulate(
        data,
        headers=["category", "#instances"] * (N_COLS // 2),
        tablefmt="pipe",
        numalign="left",
        stralign="center",
    )
    log_first_n(
        logging.INFO,
        "Distribution of instances among all {} categories:\n".format(num_classes)
        + colored(table, "cyan"),
        key="message",
    )
|
209 |
+
|
210 |
+
|
211 |
+
def get_detection_dataset_dicts(names, filter_empty=True, min_keypoints=0, proposal_files=None):
    """
    Load and prepare dataset dicts for instance detection/segmentation and semantic segmentation.

    Args:
        names (str or list[str]): a dataset name or a list of dataset names
        filter_empty (bool): whether to filter out images without instance annotations
        min_keypoints (int): filter out images with fewer keypoints than
            `min_keypoints`. Set to 0 to do nothing.
        proposal_files (list[str]): if given, a list of object proposal files
            that match each dataset in `names`.

    Returns:
        list[dict]: a list of dicts following the standard dataset dict format.
    """
    if isinstance(names, str):
        names = [names]
    assert len(names), names

    # Fetch every dataset first, then verify that none of them is empty.
    per_dataset = [DatasetCatalog.get(dataset_name) for dataset_name in names]
    for dataset_name, dicts in zip(names, per_dataset):
        assert len(dicts), "Dataset '{}' is empty!".format(dataset_name)

    if proposal_files is not None:
        assert len(names) == len(proposal_files)
        # load precomputed proposals from proposal files
        per_dataset = [
            load_proposals_into_dataset(dicts, proposal_file)
            for dicts, proposal_file in zip(per_dataset, proposal_files)
        ]

    # Merge all datasets into one flat list of records.
    dataset_dicts = list(itertools.chain.from_iterable(per_dataset))

    # Instance-level filtering and statistics only apply when the records
    # carry an "annotations" field (checked on the first record).
    if "annotations" in dataset_dicts[0]:
        if filter_empty:
            dataset_dicts = filter_images_with_only_crowd_annotations(dataset_dicts)
        if min_keypoints > 0:
            dataset_dicts = filter_images_with_few_keypoints(dataset_dicts, min_keypoints)

        try:
            class_names = MetadataCatalog.get(names[0]).thing_classes
            check_metadata_consistency("thing_classes", names)
            print_instances_class_histogram(dataset_dicts, class_names)
        except AttributeError:  # class names are not available for this dataset
            pass

    assert len(dataset_dicts), "No valid data found in {}.".format(",".join(names))
    return dataset_dicts
|
259 |
+
|
260 |
+
|
261 |
+
def build_batch_data_loader(
    dataset, sampler, total_batch_size, *, aspect_ratio_grouping=False, num_workers=0
):
    """
    Build a batched dataloader. The main differences from `torch.utils.data.DataLoader` are:
    1. support aspect ratio grouping options
    2. use no "batch collation", because this is common for detection training

    Args:
        dataset (torch.utils.data.Dataset): map-style PyTorch dataset. Can be indexed.
        sampler (torch.utils.data.sampler.Sampler): a sampler that produces indices
        total_batch_size, aspect_ratio_grouping, num_workers: see
            :func:`build_detection_train_loader`.

    Returns:
        iterable[list]. Length of each list is the batch size of the current
            GPU. Each element in the list comes from the dataset.
    """
    world_size = get_world_size()
    assert (
        total_batch_size > 0 and total_batch_size % world_size == 0
    ), "Total batch size ({}) must be divisible by the number of gpus ({}).".format(
        total_batch_size, world_size
    )
    # Every process receives an equal share of the total batch.
    batch_size = total_batch_size // world_size

    if not aspect_ratio_grouping:
        # drop_last so the batch always have the same size
        batch_sampler = torch.utils.data.sampler.BatchSampler(
            sampler, batch_size, drop_last=True
        )
        return torch.utils.data.DataLoader(
            dataset,
            num_workers=num_workers,
            batch_sampler=batch_sampler,
            collate_fn=trivial_batch_collator,
            worker_init_fn=worker_init_reset_seed,
        )

    # With grouping, the loader yields individual mapped dicts and the
    # wrapper below assembles them into batches.
    single_item_loader = torch.utils.data.DataLoader(
        dataset,
        sampler=sampler,
        num_workers=num_workers,
        batch_sampler=None,
        collate_fn=operator.itemgetter(0),  # don't batch, but yield individual elements
        worker_init_fn=worker_init_reset_seed,
    )
    return AspectRatioGroupedDataset(single_item_loader, batch_size)
|
308 |
+
|
309 |
+
|
310 |
+
def _select_train_sampler_from_cfg(cfg, dataset):
    """Create the training sampler named by ``cfg.DATALOADER.SAMPLER_TRAIN`` for ``dataset``."""
    sampler_name = cfg.DATALOADER.SAMPLER_TRAIN
    logger = logging.getLogger(__name__)
    logger.info("Using training sampler {}".format(sampler_name))
    if sampler_name == "TrainingSampler":
        return TrainingSampler(len(dataset))
    if sampler_name == "RepeatFactorTrainingSampler":
        repeat_factors = RepeatFactorTrainingSampler.repeat_factors_from_category_frequency(
            dataset, cfg.DATALOADER.REPEAT_THRESHOLD
        )
        return RepeatFactorTrainingSampler(repeat_factors)
    raise ValueError("Unknown training sampler: {}".format(sampler_name))


def _train_loader_from_config(cfg, mapper=None, *, dataset=None, sampler=None):
    """
    Translate a config into the keyword arguments of :func:`build_detection_train_loader`.

    Args:
        cfg: config object; the fields read are under ``DATASETS``, ``DATALOADER``,
            ``SOLVER`` and ``MODEL``.
        mapper: optional mapper; ignored (replaced by None) when 'yfcc100m'
            is among ``cfg.DATASETS.TRAIN``, otherwise defaults to
            ``DatasetMapper(cfg, True)``.
        dataset: optional dataset; ignored when 'yfcc100m' is among
            ``cfg.DATASETS.TRAIN``, otherwise defaults to the dataset dicts
            of ``cfg.DATASETS.TRAIN``.
        sampler: optional sampler; when None it is created according to
            ``cfg.DATALOADER.SAMPLER_TRAIN``.

    Returns:
        dict: with keys "dataset", "sampler", "mapper", "total_batch_size",
        "aspect_ratio_grouping" and "num_workers".

    Raises:
        ValueError: if a sampler has to be created and
            ``cfg.DATALOADER.SAMPLER_TRAIN`` is not a known sampler name.
    """
    if 'yfcc100m' in cfg.DATASETS.TRAIN:  # dataset, transform/aug., sampler for image-text pairs training
        logger = logging.getLogger(__name__)
        logger.info("Creating dataset {}".format(cfg.DATASETS.TRAIN))
        # Only the datasets are needed here; the other two returned values are unused.
        datasets, _, _ = make_clip_dataset(
            cfg, is_train=True,
            transforms=None,  # for training, we use our own defined transforms
        )
        dataset = datasets[0]  # during training, a single (possibly concatenated) dataset was returned
        mapper = None
    else:
        # the following is the default code in Detectron2
        if dataset is None:
            dataset = get_detection_dataset_dicts(
                cfg.DATASETS.TRAIN,
                filter_empty=cfg.DATALOADER.FILTER_EMPTY_ANNOTATIONS,
                min_keypoints=cfg.MODEL.ROI_KEYPOINT_HEAD.MIN_KEYPOINTS_PER_IMAGE
                if cfg.MODEL.KEYPOINT_ON
                else 0,
                proposal_files=cfg.DATASETS.PROPOSAL_FILES_TRAIN if cfg.MODEL.LOAD_PROPOSALS else None,
            )
            _log_api_usage("dataset." + cfg.DATASETS.TRAIN[0])

        if mapper is None:
            mapper = DatasetMapper(cfg, True)

    if sampler is None:
        sampler = _select_train_sampler_from_cfg(cfg, dataset)

    return {
        "dataset": dataset,
        "sampler": sampler,
        "mapper": mapper,
        "total_batch_size": cfg.SOLVER.IMS_PER_BATCH,
        "aspect_ratio_grouping": cfg.DATALOADER.ASPECT_RATIO_GROUPING,
        "num_workers": cfg.DATALOADER.NUM_WORKERS,
    }
|
377 |
+
|
378 |
+
|
379 |
+
# TODO can allow dataset as an iterable or IterableDataset to make this function more general
|
380 |
+
@configurable(from_config=_train_loader_from_config)
def build_detection_train_loader(
    dataset, *, mapper, sampler=None, total_batch_size, aspect_ratio_grouping=True, num_workers=0
):
    """
    Build a dataloader for object detection with some default features.
    This interface is experimental.

    Args:
        dataset (list or torch.utils.data.Dataset): a list of dataset dicts,
            or a map-style pytorch dataset. They can be obtained by using
            :func:`DatasetCatalog.get` or :func:`get_detection_dataset_dicts`.
        mapper (callable): a callable which takes a sample (dict) from dataset and
            returns the format to be consumed by the model.
            When using cfg, the default choice is ``DatasetMapper(cfg, is_train=True)``.
            If None, the dataset is used without mapping.
        sampler (torch.utils.data.sampler.Sampler or None): a sampler that produces
            indices to be applied on ``dataset``. Default to :class:`TrainingSampler`,
            which coordinates an infinite random shuffle sequence across all workers.
        total_batch_size (int): total batch size across all workers. Batching
            simply puts data into a list.
        aspect_ratio_grouping (bool): whether to group images with similar
            aspect ratio for efficiency. When enabled, it requires each
            element in dataset be a dict with keys "width" and "height".
        num_workers (int): number of parallel data loading workers

    Returns:
        The loader built by :func:`build_batch_data_loader`. Each output from
        it is a ``list[mapped_element]`` whose length is the per-process batch
        size computed there, where ``mapped_element`` is produced by the ``mapper``.
    """
    # Wrap plain lists so that they behave like a map-style dataset.
    wrapped = DatasetFromList(dataset, copy=False) if isinstance(dataset, list) else dataset
    if mapper is not None:
        wrapped = MapDataset(wrapped, mapper)

    chosen_sampler = TrainingSampler(len(wrapped)) if sampler is None else sampler
    assert isinstance(chosen_sampler, torch.utils.data.sampler.Sampler)

    loader_options = {
        "aspect_ratio_grouping": aspect_ratio_grouping,
        "num_workers": num_workers,
    }
    return build_batch_data_loader(wrapped, chosen_sampler, total_batch_size, **loader_options)
|
425 |
+
|
426 |
+
|
427 |
+
def _test_loader_from_config(cfg, dataset_name, mapper=None):
    """
    Turn a config into the keyword arguments of :func:`build_detection_test_loader`.

    The explicit ``dataset_name`` is used rather than the names stored in the config,
    because each test set is normally evaluated by itself instead of being merged
    with the others.

    Args:
        cfg: config node; entries under ``DATASETS``, ``DATALOADER`` and ``MODEL`` are read.
        dataset_name (str): name of the test dataset to load.
        mapper (callable or None): sample mapper; when None on the default path,
            ``DatasetMapper(cfg, False)`` is created.

    Returns:
        dict: keyword arguments with keys ``dataset``, ``mapper``, ``num_workers`` and,
        for the ``CLIPRCNN`` meta architecture, ``clip_batch_size``.
    """
    test_names = cfg.DATASETS.TEST
    if 'yfcc100m' in test_names:
        # image-text pair data: the dataset is built by the CLIP helper and no mapper is applied
        logging.getLogger(__name__).info("Creating dataset {}".format(test_names))
        clip_datasets, precomputed_tokens, dataset_classes = make_clip_dataset(
            cfg,
            is_train=False,
            transforms=None,  # None makes the helper build the transforms itself
        )
        # only the first entry of the returned list is handed to the loader
        return dict(
            dataset=clip_datasets[0],
            mapper=None,
            num_workers=cfg.DATALOADER.NUM_WORKERS,
        )

    # default Detectron2 path
    if cfg.MODEL.LOAD_PROPOSALS:
        # pick the proposal file that sits at the same position as this dataset name
        position = list(cfg.DATASETS.TEST).index(dataset_name)
        proposal_files = [cfg.DATASETS.PROPOSAL_FILES_TEST[position]]
    else:
        proposal_files = None
    dataset = get_detection_dataset_dicts(
        [dataset_name],
        filter_empty=False,
        proposal_files=proposal_files,
    )
    if mapper is None:
        mapper = DatasetMapper(cfg, False)

    uses_clip = cfg.MODEL.META_ARCHITECTURE == 'CLIPRCNN'
    loader_kwargs = {
        "dataset": dataset,
        "mapper": mapper,
        "num_workers": cfg.DATALOADER.NUM_WORKERS,
    }
    if uses_clip:
        # CLIP inference may use a larger batch to speed things up
        loader_kwargs["clip_batch_size"] = cfg.MODEL.CLIP.IMS_PER_BATCH_TEST
    return loader_kwargs
|
462 |
+
|
463 |
+
|
464 |
+
@configurable(from_config=_test_loader_from_config)
def build_detection_test_loader(dataset, *, mapper, sampler=None, num_workers=0, clip_batch_size=None):
    """
    Similar to `build_detection_train_loader`, but by default uses a batch size of 1,
    and :class:`InferenceSampler`. This sampler coordinates all workers to
    produce the exact set of all samples.
    This interface is experimental.

    Args:
        dataset (list or torch.utils.data.Dataset): a list of dataset dicts,
            or a map-style pytorch dataset. They can be obtained by using
            :func:`DatasetCatalog.get` or :func:`get_detection_dataset_dicts`.
        mapper (callable): a callable which takes a sample (dict) from dataset
            and returns the format to be consumed by the model.
            When using cfg, the default choice is ``DatasetMapper(cfg, is_train=False)``.
        sampler (torch.utils.data.sampler.Sampler or None): a sampler that produces
            indices to be applied on ``dataset``. Default to :class:`InferenceSampler`,
            which splits the dataset across all workers.
        num_workers (int): number of parallel data loading workers
        clip_batch_size (int or None): when truthy, the total number of images per
            batch; it is divided by ``get_world_size()`` to obtain the per-process
            batch size (never less than 1). When None or 0, one image per batch is used.

    Returns:
        DataLoader: a torch DataLoader, that loads the given detection
        dataset, with test-time transformation and batching.

    Examples:
    ::
        data_loader = build_detection_test_loader(
            DatasetRegistry.get("my_test"),
            mapper=DatasetMapper(...))

        # or, instantiate with a CfgNode:
        data_loader = build_detection_test_loader(cfg, "my_test")
    """
    if isinstance(dataset, list):
        dataset = DatasetFromList(dataset, copy=False)
    if mapper is not None:
        dataset = MapDataset(dataset, mapper)
    if sampler is None:
        sampler = InferenceSampler(len(dataset))

    if clip_batch_size:
        # Multiple images per process. Clamp to 1: the integer division yields 0 when
        # clip_batch_size < world size, and BatchSampler rejects a batch size of 0.
        images_per_batch = max(1, clip_batch_size // get_world_size())
    else:
        # Always use 1 image per worker during inference since this is the
        # standard when reporting inference time in papers.
        images_per_batch = 1

    batch_sampler = torch.utils.data.sampler.BatchSampler(
        sampler, images_per_batch, drop_last=False
    )
    return torch.utils.data.DataLoader(
        dataset,
        num_workers=num_workers,
        batch_sampler=batch_sampler,
        collate_fn=trivial_batch_collator,
    )
|
525 |
+
|
526 |
+
|
527 |
+
def trivial_batch_collator(batch):
    """
    Identity collate function: the batch is handed back untouched.

    Args:
        batch: the object passed in by the data loader for one batch.

    Returns:
        The very same ``batch`` object.
    """
    return batch
|
532 |
+
|
533 |
+
|
534 |
+
def worker_init_reset_seed(worker_id):
    """
    Re-seed the random number generators via ``seed_all_rng``.

    The seed is torch's initial seed reduced modulo 2**31, offset by ``worker_id``.

    Args:
        worker_id (int): value added to the base seed.
    """
    base_seed = torch.initial_seed() % (2 ** 31)
    seed_all_rng(base_seed + worker_id)
|
detectron2/data/catalog.py
ADDED
@@ -0,0 +1,236 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
import copy
|
3 |
+
import logging
|
4 |
+
import types
|
5 |
+
from collections import UserDict
|
6 |
+
from typing import List
|
7 |
+
|
8 |
+
from detectron2.utils.logger import log_first_n
|
9 |
+
|
10 |
+
__all__ = ["DatasetCatalog", "MetadataCatalog", "Metadata"]
|
11 |
+
|
12 |
+
|
13 |
+
class _DatasetCatalog(UserDict):
|
14 |
+
"""
|
15 |
+
A global dictionary that stores information about the datasets and how to obtain them.
|
16 |
+
|
17 |
+
It contains a mapping from strings
|
18 |
+
(which are names that identify a dataset, e.g. "coco_2014_train")
|
19 |
+
to a function which parses the dataset and returns the samples in the
|
20 |
+
format of `list[dict]`.
|
21 |
+
|
22 |
+
The returned dicts should be in Detectron2 Dataset format (See DATASETS.md for details)
|
23 |
+
if used with the data loader functionalities in `data/build.py,data/detection_transform.py`.
|
24 |
+
|
25 |
+
The purpose of having this catalog is to make it easy to choose
|
26 |
+
different datasets, by just using the strings in the config.
|
27 |
+
"""
|
28 |
+
|
29 |
+
def register(self, name, func):
|
30 |
+
"""
|
31 |
+
Args:
|
32 |
+
name (str): the name that identifies a dataset, e.g. "coco_2014_train".
|
33 |
+
func (callable): a callable which takes no arguments and returns a list of dicts.
|
34 |
+
It must return the same results if called multiple times.
|
35 |
+
"""
|
36 |
+
assert callable(func), "You must register a function with `DatasetCatalog.register`!"
|
37 |
+
assert name not in self, "Dataset '{}' is already registered!".format(name)
|
38 |
+
self[name] = func
|
39 |
+
|
40 |
+
def get(self, name):
|
41 |
+
"""
|
42 |
+
Call the registered function and return its results.
|
43 |
+
|
44 |
+
Args:
|
45 |
+
name (str): the name that identifies a dataset, e.g. "coco_2014_train".
|
46 |
+
|
47 |
+
Returns:
|
48 |
+
list[dict]: dataset annotations.
|
49 |
+
"""
|
50 |
+
try:
|
51 |
+
f = self[name]
|
52 |
+
except KeyError as e:
|
53 |
+
raise KeyError(
|
54 |
+
"Dataset '{}' is not registered! Available datasets are: {}".format(
|
55 |
+
name, ", ".join(list(self.keys()))
|
56 |
+
)
|
57 |
+
) from e
|
58 |
+
return f()
|
59 |
+
|
60 |
+
def list(self) -> List[str]:
|
61 |
+
"""
|
62 |
+
List all registered datasets.
|
63 |
+
|
64 |
+
Returns:
|
65 |
+
list[str]
|
66 |
+
"""
|
67 |
+
return list(self.keys())
|
68 |
+
|
69 |
+
def remove(self, name):
|
70 |
+
"""
|
71 |
+
Alias of ``pop``.
|
72 |
+
"""
|
73 |
+
self.pop(name)
|
74 |
+
|
75 |
+
def __str__(self):
|
76 |
+
return "DatasetCatalog(registered datasets: {})".format(", ".join(self.keys()))
|
77 |
+
|
78 |
+
__repr__ = __str__
|
79 |
+
|
80 |
+
|
81 |
+
# Module-level catalog instance exported as ``DatasetCatalog``.
DatasetCatalog = _DatasetCatalog()
# Give the instance its own docstring: the class docstring followed by
# ``automethod`` directives for ``register`` and ``get``.
DatasetCatalog.__doc__ = (
    _DatasetCatalog.__doc__
    + """
    .. automethod:: detectron2.data.catalog.DatasetCatalog.register
    .. automethod:: detectron2.data.catalog.DatasetCatalog.get
"""
)
|
89 |
+
|
90 |
+
|
91 |
+
class Metadata(types.SimpleNamespace):
    """
    A class that supports simple attribute setter/getter.
    It is intended for storing metadata of a dataset and make it accessible globally.

    An attribute can be set once; setting it again is only allowed with an equal value.
    Attribute names listed in ``_RENAMED`` are transparently redirected to their
    new names (with a logged warning).

    Examples:
    ::
        # somewhere when you load the data:
        MetadataCatalog.get("mydataset").thing_classes = ["person", "dog"]

        # somewhere when you print statistics or visualize:
        classes = MetadataCatalog.get("mydataset").thing_classes
    """

    # the name of the dataset
    # set default to N/A so that `self.name` in the errors will not trigger getattr again
    name: str = "N/A"

    # old attribute name -> current attribute name
    _RENAMED = {
        "class_names": "thing_classes",
        "dataset_id_to_contiguous_id": "thing_dataset_id_to_contiguous_id",
        "stuff_class_names": "stuff_classes",
    }

    def __getattr__(self, key):
        """
        Fallback lookup, invoked only when normal attribute lookup fails.

        Renamed keys are forwarded to their new name; anything else raises
        ``AttributeError`` with a message describing the available keys.
        """
        if key in self._RENAMED:
            log_first_n(
                logging.WARNING,
                "Metadata '{}' was renamed to '{}'!".format(key, self._RENAMED[key]),
                n=10,
            )
            return getattr(self, self._RENAMED[key])

        # "name" exists in every metadata
        if len(self.__dict__) > 1:
            raise AttributeError(
                "Attribute '{}' does not exist in the metadata of dataset '{}'. Available "
                "keys are {}.".format(key, self.name, str(self.__dict__.keys()))
            )
        else:
            raise AttributeError(
                f"Attribute '{key}' does not exist in the metadata of dataset '{self.name}': "
                "metadata is empty."
            )

    def __setattr__(self, key, val):
        """
        Set ``key`` to ``val``, refusing to overwrite an existing, different value.

        Raises:
            AssertionError: if ``key`` already holds a value that is not equal to ``val``.
        """
        if key in self._RENAMED:
            log_first_n(
                logging.WARNING,
                "Metadata '{}' was renamed to '{}'!".format(key, self._RENAMED[key]),
                n=10,
            )
            setattr(self, self._RENAMED[key], val)
            # The value now lives under the new name. Stop here: falling through would
            # read the old name back through __getattr__ (logging a second warning)
            # and compare the value with itself for no benefit.
            return

        # Ensure that metadata of the same name stays consistent
        try:
            oldval = getattr(self, key)
            assert oldval == val, (
                "Attribute '{}' in the metadata of '{}' cannot be set "
                "to a different value!\n{} != {}".format(key, self.name, oldval, val)
            )
        except AttributeError:
            # not set yet: store it for the first time
            super().__setattr__(key, val)

    def as_dict(self):
        """
        Returns all the metadata as a dict.
        Note that modifications to the returned dict will not reflect on the Metadata object.
        """
        return copy.copy(self.__dict__)

    def set(self, **kwargs):
        """
        Set multiple metadata with kwargs.

        Returns:
            Metadata: ``self``, so calls can be chained.
        """
        for k, v in kwargs.items():
            setattr(self, k, v)
        return self

    def get(self, key, default=None):
        """
        Access an attribute and return its value if exists.
        Otherwise return default.
        """
        try:
            return getattr(self, key)
        except AttributeError:
            return default
|
179 |
+
|
180 |
+
|
181 |
+
class _MetadataCatalog(UserDict):
|
182 |
+
"""
|
183 |
+
MetadataCatalog is a global dictionary that provides access to
|
184 |
+
:class:`Metadata` of a given dataset.
|
185 |
+
|
186 |
+
The metadata associated with a certain name is a singleton: once created, the
|
187 |
+
metadata will stay alive and will be returned by future calls to ``get(name)``.
|
188 |
+
|
189 |
+
It's like global variables, so don't abuse it.
|
190 |
+
It's meant for storing knowledge that's constant and shared across the execution
|
191 |
+
of the program, e.g.: the class names in COCO.
|
192 |
+
"""
|
193 |
+
|
194 |
+
def get(self, name):
|
195 |
+
"""
|
196 |
+
Args:
|
197 |
+
name (str): name of a dataset (e.g. coco_2014_train).
|
198 |
+
|
199 |
+
Returns:
|
200 |
+
Metadata: The :class:`Metadata` instance associated with this name,
|
201 |
+
or create an empty one if none is available.
|
202 |
+
"""
|
203 |
+
assert len(name)
|
204 |
+
r = super().get(name, None)
|
205 |
+
if r is None:
|
206 |
+
r = self[name] = Metadata(name=name)
|
207 |
+
return r
|
208 |
+
|
209 |
+
def list(self):
|
210 |
+
"""
|
211 |
+
List all registered metadata.
|
212 |
+
|
213 |
+
Returns:
|
214 |
+
list[str]: keys (names of datasets) of all registered metadata
|
215 |
+
"""
|
216 |
+
return list(self.keys())
|
217 |
+
|
218 |
+
def remove(self, name):
|
219 |
+
"""
|
220 |
+
Alias of ``pop``.
|
221 |
+
"""
|
222 |
+
self.pop(name)
|
223 |
+
|
224 |
+
def __str__(self):
|
225 |
+
return "MetadataCatalog(registered metadata: {})".format(", ".join(self.keys()))
|
226 |
+
|
227 |
+
__repr__ = __str__
|
228 |
+
|
229 |
+
|
230 |
+
# Module-level catalog instance exported as ``MetadataCatalog``.
MetadataCatalog = _MetadataCatalog()
# Give the instance its own docstring: the class docstring followed by
# an ``automethod`` directive for ``get``.
MetadataCatalog.__doc__ = (
    _MetadataCatalog.__doc__
    + """
    .. automethod:: detectron2.data.catalog.MetadataCatalog.get
"""
)
|
detectron2/data/clip_build.py
ADDED
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
|
2 |
+
import bisect
|
3 |
+
import copy
|
4 |
+
import logging
|
5 |
+
import os
|
6 |
+
import torch
|
7 |
+
import torch.utils.data
|
8 |
+
import torch.distributed
|
9 |
+
from torch.utils.data.dataset import ConcatDataset
|
10 |
+
|
11 |
+
from .catalog import DatasetCatalog
|
12 |
+
from .clip_datasets.clip_img_txt_pair_tsv import CLIPImgTxtPairTSVDataset
|
13 |
+
|
14 |
+
from .transforms.build import build_clip_transforms
|
15 |
+
|
16 |
+
def config_tsv_dataset_args(cfg, dataset_file, factory_name=None, is_train=True):
    """
    Build the base keyword arguments for a TSV dataset factory.

    Args:
        cfg: config node; ``cfg.DATASETS.MAX_SEQ_LENGTH`` is read and the whole
            node is passed on under the ``"args"`` key.
        dataset_file: not used by this function; kept so existing callers keep working.
        factory_name (str or None): name of the dataset factory. ``None`` selects
            "CLIPImgTxtPairTSVDataset", the only factory handled here.
        is_train (bool): not used by this function; kept so existing callers keep working.

    Returns:
        tuple[dict, str]: the argument dict (keys ``"args"`` and ``"seq_len"``) and
        the resolved factory name.

    Raises:
        ValueError: if ``factory_name`` names a factory this function does not handle.
    """
    # The code that derived the name from the dataset file was removed, since the
    # name is always "CLIPImgTxtPairTSVDataset"; fall back to it explicitly so a
    # missing factory_name no longer leaves the local variable unbound.
    supported_factories = ("CLIPImgTxtPairTSVDataset",)
    tsv_dataset_name = supported_factories[0] if factory_name is None else factory_name

    if tsv_dataset_name not in supported_factories:
        raise ValueError(
            "Unsupported TSV dataset factory '{}'; expected one of {}".format(
                tsv_dataset_name, list(supported_factories)
            )
        )

    # no need for extra arguments
    args = {
        "args": cfg,
        "seq_len": cfg.DATASETS.MAX_SEQ_LENGTH,  # cfg.max_seq_length
    }
    return args, tsv_dataset_name
|
28 |
+
|
29 |
+
|
30 |
+
def build_dataset(cfg, transforms, dataset_catalog, is_train=True, is_aux=False):
    """
    Build the image-text pair TSV dataset(s) listed in the config.

    Arguments:
        cfg: config file.
        transforms (callable): transforms to apply to each (image, target) sample
        dataset_catalog (DatasetCatalog): contains the information on how to construct a dataset.
            It is not referenced anywhere in the body of this function.
        is_train (bool): whether to setup the dataset for training or testing
        is_aux (bool): for training, read the ``AUX`` / ``FACTORY_AUX`` / ``PATH_AUX`` entries of
            ``cfg.DATASETS`` instead of the ``TRAIN`` ones. Has no effect when ``is_train`` is False.

    Returns:
        tuple: ``(datasets, precomputed_tokens, dataset_classes)``.
        For testing, ``datasets`` is the list of all datasets that were built.
        For training, it is a one-element list holding a single dataset (a ``ConcatDataset``
        when more than one was built), or all three values are ``None`` when nothing was built.

    Raises:
        RuntimeError: if the dataset list or the factory list from the config is not a list/tuple.
    """

    # pick the config entries matching the requested split (train / aux / test)
    dataset_list = (cfg.DATASETS.TRAIN if not is_aux else cfg.DATASETS.AUX) if is_train else cfg.DATASETS.TEST
    factory_list = (cfg.DATASETS.FACTORY_TRAIN if not is_aux else cfg.DATASETS.FACTORY_AUX) if is_train else cfg.DATASETS.FACTORY_TEST
    path_list = (cfg.DATASETS.PATH_TRAIN if not is_aux else cfg.DATASETS.PATH_AUX) if is_train else cfg.DATASETS.PATH_TEST

    if not isinstance(dataset_list, (list, tuple)):
        raise RuntimeError(
            "dataset_list should be a list of strings, got {}".format(dataset_list))
    if not isinstance(factory_list, (list, tuple)):
        raise RuntimeError(
            "factory_list should be a list of strings, got {}".format(factory_list))
    datasets = []
    target_offset = 0  # never changed below; passed to every dataset as-is
    for i, dataset_name in enumerate(dataset_list):
        # a dataset without a matching factory entry gets factory_name None
        factory_name = factory_list[i] if i < len(factory_list) else None

        if factory_name == "CLIPImgTxtPairTSVDataset":
            # one config entry may merge several datasets, joined with '+';
            # the matching path entry must contain the same number of '+'-separated paths
            dataset_names_merged = dataset_name.split('+')
            path_lists_merged = path_list[i].split('+')

            assert len(dataset_names_merged) == len(path_lists_merged), "number of datasets must match that of dataset paths"

            image_tsv_list = []
            text_tsv_list = []
            dataset_name_list = []
            map_files = []
            max_num_tsv = 20 # maximum tsv files to load within a given folder

            for dname, dpath in zip(dataset_names_merged, path_lists_merged):
                args, tsv_dataset_name = config_tsv_dataset_args(
                    cfg, dataset_name, factory_name, is_train
                )
                factory = CLIPImgTxtPairTSVDataset if tsv_dataset_name in ["CLIPImgTxtPairTSVDataset"] else None
                # remember how many files were collected before this dataset, to count its own files
                prev_len = len(image_tsv_list)

                # dpath may be a single tsv file or a folder of tsv files
                isFile = os.path.isfile(dpath)
                if isFile:
                    dpath_listed_files = [os.path.basename(dpath)]
                    dpath = os.path.dirname(dpath)
                else:
                    dpath_listed_files = sorted(os.listdir(dpath))

                for filename in dpath_listed_files:
                    # image tsv files are recognised by name; the text tsv path is derived
                    # from the image file name by substring replacement (its existence is not checked here)
                    if ("images" in filename or "image" in filename or "img" in filename) and filename.endswith(".tsv"):
                        image_tsv_list.append(os.path.join(dpath, filename))
                        if "images" in filename: # "images" - "text"
                            text_tsv_list.append(os.path.join(dpath, filename.replace("images", "text")))
                        elif "image" in filename: # "image"-"text"
                            text_tsv_list.append(os.path.join(dpath, filename.replace("image", "text")))
                        elif "img" in filename: # "img"-"caption"
                            text_tsv_list.append(os.path.join(dpath, filename.replace("img", "caption")))
                        if len(image_tsv_list) - prev_len == max_num_tsv:
                            break
                # one dataset-name entry per tsv file collected for this dataset
                dataset_name_list += [dname] * (len(image_tsv_list) - prev_len)

                # only "imagenet22k" gets a label map file; every other dataset gets None
                if dname == "imagenet22k":
                    map_files += [os.path.join(dpath, 'darknet_data_imagenet.labels.list')] * (len(image_tsv_list) - prev_len)
                else:
                    map_files += [None] * (len(image_tsv_list) - prev_len)

            assert len(image_tsv_list) == len(text_tsv_list), \
                "the number image tsv files must be equal to that of text tsv files, otherwise check your data!"

            args["image_tsv_file"] = image_tsv_list
            args["text_tsv_file"] = text_tsv_list
            args["dataset_name"] = dataset_name_list
            args["map_file"] = map_files
            args["filtered_datasets"] = cfg.DATASETS.FILTERED_CLASSIFICATION_DATASETS
            assert len(image_tsv_list) == len(text_tsv_list) == len(dataset_name_list) == len(map_files)

            print("number of image tsv files: ", len(image_tsv_list))
            print("number of text tsv fies: ", len(text_tsv_list))

        # NOTE(review): `args` and `factory` are only assigned inside the branch above, so an entry
        # with any other factory name appears to reach this point without them — confirm that
        # every configured factory is "CLIPImgTxtPairTSVDataset".
        args["is_train"] = is_train
        args["transforms"] = transforms
        args["target_offset"] = target_offset
        if "bpe" in cfg.INPUT.TEXT_TOKENIZER:
            from detectron2.data.datasets.clip_prompt_utils import SimpleTokenizer as _Tokenizer
            tokenizer = _Tokenizer()
            args["tokenizer_type"] = "bpe"
        # NOTE(review): `tokenizer` is only assigned when TEXT_TOKENIZER contains "bpe" — confirm
        # that no other tokenizer setting is expected here.
        args["tokenizer"] = tokenizer
        # make dataset from factory
        dataset = factory(**args)
        datasets.append(dataset)

    # collect optional per-dataset extras exposed as attributes on the built datasets
    precomputed_tokens = {}
    dataset_classes = {}
    for dataset in datasets:
        if hasattr(dataset, "input_ids_all_classes"):
            # always stored under the fixed key "imagenet"; a later dataset overwrites an earlier one
            precomputed_tokens["imagenet"] = \
                [dataset.input_ids_all_classes, dataset.input_mask_all_classes, dataset.segment_ids_all_classes]
        if hasattr(dataset, "classnames"):
            if isinstance(dataset.classnames, dict):
                dataset_classes.update(dataset.classnames)
            else:
                dataset_classes[dataset.dataset_name] = dataset.classnames

    # for testing, return a list of datasets
    if not is_train:
        return datasets, precomputed_tokens, dataset_classes

    if len(datasets) == 0:
        return None, None, None

    # for training, concatenate all datasets into a single one
    dataset = datasets[0]
    if len(datasets) > 1:
        dataset = ConcatDataset(datasets)
    return [dataset], precomputed_tokens, dataset_classes
|
147 |
+
|
148 |
+
|
149 |
+
def make_clip_dataset(cfg, is_train=True, is_aux=False, transforms=None):
    """
    Build the CLIP image-text datasets described by ``cfg``.

    Args:
        cfg: config node, forwarded to the transform builder and to ``build_dataset``.
        is_train (bool): forwarded to the transform builder and to ``build_dataset``.
        is_aux (bool): forwarded to ``build_dataset``.
        transforms: transforms handed to ``build_dataset``; when None they are created
            with ``build_clip_transforms(cfg, is_train)``.

    Returns:
        tuple: ``(datasets, precomputed_tokens, dataset_classes)`` as produced by
        ``build_dataset``, or ``(None, None, None)`` when ``datasets`` is empty or None.
    """
    if transforms is None:
        transforms = build_clip_transforms(cfg, is_train)
    print("data transforms: ")
    print(transforms)

    built = build_dataset(cfg, transforms, DatasetCatalog, is_train, is_aux)
    datasets, precomputed_tokens, dataset_classes = built

    if datasets:
        return datasets, precomputed_tokens, dataset_classes
    # nothing was built: report that uniformly with three Nones
    return None, None, None
|