Spaces:
Sleeping
Sleeping
# ------------------------------------------------------------------------ | |
# Copyright (c) 2023-present, BAAI. All Rights Reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
# ------------------------------------------------------------------------ | |
"""Easy model builder.""" | |
from functools import partial | |
import pickle | |
import torch | |
from tokenize_anything.modeling import ConceptProjector | |
from tokenize_anything.modeling import ImageDecoder | |
from tokenize_anything.modeling import ImageEncoderViT | |
from tokenize_anything.modeling import ImageTokenizer | |
from tokenize_anything.modeling import PromptEncoder | |
from tokenize_anything.modeling import TextDecoder | |
from tokenize_anything.modeling import TextTokenizer | |
def get_device(device_index):
    """Return the device to run on.

    Args:
        device_index: Index of the CUDA device to select when CUDA is available.

    Returns:
        ``torch.device("cuda", device_index)`` if ``torch.cuda.is_available()``
        is true, otherwise ``torch.device("cpu")`` (the index is then ignored).
    """
    cuda_ready = torch.cuda.is_available()
    return torch.device("cuda", device_index) if cuda_ready else torch.device("cpu")
def load_weights(module, weights_file, strict=True):
    """Load a weights file into ``module``.

    Args:
        module: Object exposing ``load_state_dict(state_dict, strict=...)``.
        weights_file: Location of the weights (``str`` or path-like object).
            A falsy value (``None`` or ``""``) makes the call a no-op.
            A name ending in ``.pkl`` is read with ``pickle`` and every value
            is converted with ``torch.as_tensor``; anything else is handed to
            ``torch.load`` and mapped to the CPU.
        strict: Forwarded unchanged to ``module.load_state_dict``.

    Returns:
        None. The module is updated in place.
    """
    if not weights_file:
        return
    # Test the suffix on the string form so that ``pathlib.Path`` inputs are
    # accepted instead of failing with AttributeError on ``.endswith``.
    if str(weights_file).endswith(".pkl"):
        # SECURITY: unpickling can execute arbitrary code. Only pass
        # checkpoints that come from a trusted source.
        with open(weights_file, "rb") as f:
            raw_state = pickle.load(f)
        # Build a fresh mapping rather than rewriting the dict while iterating it.
        state_dict = {name: torch.as_tensor(value) for name, value in raw_state.items()}
    else:
        state_dict = torch.load(weights_file, map_location="cpu")
    module.load_state_dict(state_dict, strict=strict)
def vit_encoder(depth, embed_dim, num_heads, out_dim, image_size):
    """Build an image encoder with ViT.

    The caller chooses ``depth``, ``embed_dim``, ``num_heads``, ``out_dim``
    and ``image_size``; the remaining ``ImageEncoderViT`` arguments are fixed
    here (``mlp_ratio=4``, ``patch_size=16``, ``window_size=16``).

    Returns:
        A new ``ImageEncoderViT`` instance.
    """
    # Hyper-parameters shared by every encoder size built through this helper.
    fixed_args = {"mlp_ratio": 4, "patch_size": 16, "window_size": 16}
    return ImageEncoderViT(
        depth=depth,
        embed_dim=embed_dim,
        num_heads=num_heads,
        image_size=image_size,
        out_dim=out_dim,
        **fixed_args,
    )
def image_tokenizer(image_encoder, checkpoint=None, device=0, dtype="float16", **kwargs):
    """Build an image tokenizer.

    Args:
        image_encoder: Callable invoked as
            ``image_encoder(out_dim=..., image_size=...)`` to create the encoder.
        checkpoint: Weights file passed to ``load_weights``; a falsy value
            skips loading.
        device: Device index passed to ``get_device``.
        dtype: ``"float16"``, ``"bfloat16"`` or ``"float32"``. Any other value
            leaves the model's dtype untouched.
        **kwargs: Optional overrides (defaults in parentheses):
            ``image_size`` (1024), ``image_embed_dim`` (256),
            ``sem_embed_dim`` (1024), ``text_embed_dim`` (512),
            ``text_decoder_depth`` (12), ``text_seq_len`` (40),
            ``training`` (False). Other keys are ignored.

    Returns:
        The assembled ``ImageTokenizer``, moved to the selected device, put in
        eval mode unless ``training`` is truthy, and cast according to ``dtype``.
    """
    opts = {
        "image_size": 1024,
        "image_embed_dim": 256,
        "sem_embed_dim": 1024,
        "text_embed_dim": 512,
        "text_decoder_depth": 12,
        "text_seq_len": 40,
        "training": False,
        **kwargs,
    }
    img_dim = opts["image_embed_dim"]
    txt_dim = opts["text_embed_dim"]
    img_size = opts["image_size"]

    # Sub-modules are created in the same order as before so that any
    # construction-time randomness is consumed identically.
    text_tokenizer = TextTokenizer()
    encoder = image_encoder(out_dim=img_dim, image_size=img_size)
    prompt_encoder = PromptEncoder(embed_dim=img_dim, image_size=img_size)
    image_decoder = ImageDecoder(
        embed_dim=img_dim,
        num_heads=img_dim // 32,
        sem_embed_dim=opts["sem_embed_dim"],
        depth=2,
        num_mask_tokens=4,
    )
    concept_projector = ConceptProjector()
    text_decoder = TextDecoder(
        depth=opts["text_decoder_depth"],
        embed_dim=txt_dim,
        num_heads=txt_dim // 64,
        prompt_embed_dim=img_dim,
        max_seq_len=opts["text_seq_len"],
        vocab_size=text_tokenizer.n_words,
        mlp_ratio=4,
    )
    model = ImageTokenizer(
        image_encoder=encoder,
        prompt_encoder=prompt_encoder,
        image_decoder=image_decoder,
        text_tokenizer=text_tokenizer,
        concept_projector=concept_projector,
        text_decoder=text_decoder,
    )

    load_weights(model, checkpoint)
    model = model.to(device=get_device(device))
    if not opts["training"]:
        model = model.eval()
    # Cast last, after the weights are loaded and placed on the device.
    if dtype == "float16":
        model = model.half()
    elif dtype == "bfloat16":
        model = model.bfloat16()
    elif dtype == "float32":
        model = model.float()
    return model
# Encoder factories for the three ViT sizes. ``out_dim`` and ``image_size``
# are left unbound and must be supplied by the caller.
vit_b_encoder = partial(
    vit_encoder,
    depth=12,
    embed_dim=768,
    num_heads=12,
)
vit_l_encoder = partial(
    vit_encoder,
    depth=24,
    embed_dim=1024,
    num_heads=16,
)
vit_h_encoder = partial(
    vit_encoder,
    depth=32,
    embed_dim=1280,
    num_heads=16,
)
# Maps a model name to a builder: ``image_tokenizer`` with its encoder factory
# already bound. Remaining arguments are passed when the builder is called.
model_registry = {
    model_name: partial(image_tokenizer, image_encoder=encoder_factory)
    for model_name, encoder_factory in (
        ("tap_vit_b", vit_b_encoder),
        ("tap_vit_l", vit_l_encoder),
        ("tap_vit_h", vit_h_encoder),
    )
}