File size: 3,525 Bytes
786f6a6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 |
import copy
from collections import deque, defaultdict
from dataclasses import dataclass, field, replace, asdict
from typing import Any, Deque, Dict, Tuple, Optional, Union
__all__ = ['PretrainedCfg', 'filter_pretrained_cfg', 'DefaultCfg']
@dataclass
class PretrainedCfg:
    """Configuration record describing one set of pretrained weights.

    Groups where the weights come from, the input / preprocessing settings,
    the classifier head settings, and descriptive meta-data. Every field has
    a default, so an instance can be built with only the values that differ.
    """
    # weight source locations
    url: Optional[Union[str, Tuple[str, str]]] = None  # remote URL
    file: Optional[str] = None  # local / shared filesystem path
    state_dict: Optional[Dict[str, Any]] = None  # in-memory state dict
    hf_hub_id: Optional[str] = None  # Hugging Face Hub model id ('organization/model')
    hf_hub_filename: Optional[str] = None  # Hugging Face Hub filename (overrides default)

    source: Optional[str] = None  # source of cfg / weight location used (url, file, hf-hub)
    architecture: Optional[str] = None  # architecture variant can be set when not implicit
    tag: Optional[str] = None  # pretrained tag of source
    custom_load: bool = False  # use custom model specific model.load_pretrained() (ie for npz files)

    # input / data config
    input_size: Tuple[int, int, int] = (3, 224, 224)
    test_input_size: Optional[Tuple[int, int, int]] = None
    min_input_size: Optional[Tuple[int, int, int]] = None
    fixed_input_size: bool = False
    interpolation: str = 'bicubic'
    crop_pct: float = 0.875
    test_crop_pct: Optional[float] = None
    crop_mode: str = 'center'
    mean: Tuple[float, ...] = (0.485, 0.456, 0.406)
    std: Tuple[float, ...] = (0.229, 0.224, 0.225)

    # head / classifier config and meta-data
    num_classes: int = 1000
    label_offset: Optional[int] = None
    # Variable-length tuples: `Tuple[str]` would declare exactly one element.
    label_names: Optional[Tuple[str, ...]] = None
    label_descriptions: Optional[Dict[str, str]] = None

    # model attributes that vary with above or required for pretrained adaptation
    pool_size: Optional[Tuple[int, ...]] = None
    test_pool_size: Optional[Tuple[int, ...]] = None
    first_conv: Optional[str] = None
    classifier: Optional[str] = None

    license: Optional[str] = None
    description: Optional[str] = None
    origin_url: Optional[str] = None
    paper_name: Optional[str] = None
    paper_ids: Optional[Union[str, Tuple[str, ...]]] = None
    notes: Optional[Tuple[str, ...]] = None

    @property
    def has_weights(self) -> bool:
        """Return True when `url`, `file` or `hf_hub_id` is set to a truthy value.

        The result is coerced to `bool` so callers get a real flag rather than
        the underlying URL / path / hub id (or None).
        """
        # NOTE(review): `state_dict` is not consulted here; confirm whether an
        # in-memory state dict alone should count as having weights.
        return bool(self.url or self.file or self.hf_hub_id)

    def to_dict(self, remove_source=False, remove_null=True):
        """Return this config as a plain dict, filtered by `filter_pretrained_cfg`.

        Args:
            remove_source: Forwarded to `filter_pretrained_cfg`.
            remove_null: Forwarded to `filter_pretrained_cfg`.

        Returns:
            The filtered dict built from `dataclasses.asdict(self)`.
        """
        return filter_pretrained_cfg(
            asdict(self),
            remove_source=remove_source,
            remove_null=remove_null,
        )
def filter_pretrained_cfg(cfg, remove_source=False, remove_null=True):
    """Return a filtered copy of a pretrained-config dict.

    Args:
        cfg: Mapping of config keys to values; it is not modified.
        remove_source: When True, drop the keys that describe where the
            weights come from ('url', 'file', 'state_dict', 'hf_hub_id',
            'hf_hub_filename', 'source').
        remove_null: When True, drop keys whose value is None, except
            'pool_size', 'first_conv' and 'classifier', which are always kept.

    Returns:
        A new dict holding the surviving key/value pairs in their original order.
    """
    # 'state_dict' is a weight source too; the set previously listed
    # 'hf_hub_id' twice, so an in-memory state dict leaked through.
    source_keys = {'url', 'file', 'state_dict', 'hf_hub_id', 'hf_hub_filename', 'source'}
    keep_null = {'pool_size', 'first_conv', 'classifier'}  # always keep these keys, even if none

    filtered_cfg = {}
    for k, v in cfg.items():
        if remove_source and k in source_keys:
            continue
        if remove_null and v is None and k not in keep_null:
            continue
        filtered_cfg[k] = v
    return filtered_cfg
@dataclass
class DefaultCfg:
    """Set of pretrained configs keyed by tag, with an ordering over the tags.

    The first entry of `tags` selects the config returned by `default` and
    `default_with_tag`.
    """
    # ordered tags; the first one is treated as the default
    tags: Deque[str] = field(default_factory=deque)
    # pretrained configs, looked up by tag
    cfgs: Dict[str, PretrainedCfg] = field(default_factory=dict)
    # flag: at least one of the configs has a pretrained source set
    is_pretrained: bool = False

    @property
    def default(self):
        """Config registered under the first tag.

        Raises IndexError when `tags` is empty and KeyError when the first
        tag has no entry in `cfgs`.
        """
        first_tag = self.tags[0]
        return self.cfgs[first_tag]

    @property
    def default_with_tag(self):
        """Pair of (first tag, config registered under it).

        Raises IndexError when `tags` is empty and KeyError when the first
        tag has no entry in `cfgs`.
        """
        return self.tags[0], self.cfgs[self.tags[0]]
|