Spaces:
Runtime error
Runtime error
File size: 1,397 Bytes
85efb5b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 |
# pyre-unsafe
import copy
from .dino_encoder import DinoVisionTower
from .siglip_encoder import SiglipVisionTower
def _get_aux_cfg_list(cfg, primary, fallback):
    """Return ``cfg.<primary>``, or ``cfg.<fallback>`` when primary is missing or None."""
    value = getattr(cfg, primary, None)
    if value is None:
        value = getattr(cfg, fallback, None)
    return value


def build_vision_tower_aux_list(vision_tower_cfg, **kwargs):
    """Build the auxiliary vision towers described by ``vision_tower_cfg``.

    Tower names are read from ``mm_vision_tower_aux_list``, falling back to
    ``vision_tower_aux_list``. Per-tower token lengths are read from
    ``mm_vision_tower_aux_token_len_list``, falling back to
    ``vision_tower_aux_token_len_list``. The two lists are paired
    element-wise. Each name gets an ``-interp<token_len>`` suffix and is
    dispatched on its case-insensitive contents:

    * names containing ``"siglip"`` -> ``SiglipVisionTower``
    * names containing ``"dinov2"`` -> ``DinoVisionTower``

    Each tower is constructed with its own deep copy of ``vision_tower_cfg``
    as ``args``.

    Args:
        vision_tower_cfg: Config object exposing the attributes listed above.
        **kwargs: Forwarded unchanged to every tower constructor.

    Returns:
        List of constructed tower objects, in the same order as the config's
        name list.

    Raises:
        ValueError: If the name list or token-length list is missing, if the
            two lists have different lengths, or if a tower name matches no
            known tower family.
    """
    names = _get_aux_cfg_list(
        vision_tower_cfg, "mm_vision_tower_aux_list", "vision_tower_aux_list"
    )
    token_lens = _get_aux_cfg_list(
        vision_tower_cfg,
        "mm_vision_tower_aux_token_len_list",
        "vision_tower_aux_token_len_list",
    )
    if names is None or token_lens is None:
        raise ValueError(
            "vision_tower_cfg must define both an aux vision tower list "
            "(mm_vision_tower_aux_list / vision_tower_aux_list) and a token "
            "length list (mm_vision_tower_aux_token_len_list / "
            "vision_tower_aux_token_len_list)"
        )

    names = list(names)
    token_lens = list(token_lens)
    # zip() would silently drop the unmatched tail; a length mismatch is a
    # config error, so surface it instead of building fewer towers.
    if len(names) != len(token_lens):
        raise ValueError(
            f"Got {len(names)} aux vision tower(s) but {len(token_lens)} "
            "token length(s); the two lists must be the same length"
        )

    vision_tower_aux_list = []
    for name, token_len in zip(names, token_lens):
        # Independent copy per tower, so no two towers share one config object.
        config = copy.deepcopy(vision_tower_cfg)
        name += f"-interp{token_len}"
        lowered = name.lower()
        if "siglip" in lowered:
            vision_tower_aux_list.append(
                SiglipVisionTower(name, args=config, **kwargs)
            )
        # SSL-based Vision Towers
        elif "dinov2" in lowered:
            vision_tower_aux_list.append(
                DinoVisionTower(name, args=config, **kwargs)
            )
        else:
            raise ValueError(f"Unknown vision tower: {name}")
    return vision_tower_aux_list
|