from typing import Optional

from transformers import AutoModelForCausalLM, AutoTokenizer
import open_clip

from .flamingo import Flamingo
from .flamingo_lm import FlamingoLMMixin
from .utils import extend_instance


def create_model_and_transforms(
    clip_vision_encoder_path: str,
    clip_vision_encoder_pretrained: str,
    lang_encoder_path: str,
    tokenizer_path: str,
    cross_attn_every_n_layers: int = 1,
    use_local_files: bool = False,
    decoder_layers_attr_name: Optional[str] = None,
    freeze_lm_embeddings: bool = False,
    **flamingo_kwargs,
):
    """
    Initialize a Flamingo model from a pretrained vision encoder and language encoder.
    Appends special tokens to the tokenizer and freezes the backbones.

    Args:
        clip_vision_encoder_path (str): path to pretrained clip model (e.g. "ViT-B-32")
        clip_vision_encoder_pretrained (str): name of pretraining dataset for clip model (e.g. "laion2b_s32b_b79k")
        lang_encoder_path (str): path to pretrained language encoder
        tokenizer_path (str): path to pretrained tokenizer
        cross_attn_every_n_layers (int, optional): determines how often to add a cross-attention layer. Defaults to 1.
        use_local_files (bool, optional): whether to load from local files only. Defaults to False.
        decoder_layers_attr_name (str, optional): name of the decoder layers attribute; inferred from the model class if None.
        freeze_lm_embeddings (bool, optional): whether to keep the LM input embeddings frozen. Defaults to False.
        **flamingo_kwargs: additional kwargs forwarded to the Flamingo constructor.

    Returns:
        Flamingo: Flamingo model from pretrained vision and language encoders
        Image processor: Pipeline to preprocess input images
        Tokenizer: A tokenizer for the language model
    """
    vision_encoder, _, image_processor = open_clip.create_model_and_transforms(
        clip_vision_encoder_path, pretrained=clip_vision_encoder_pretrained
    )
    # set the vision encoder to output the visual features
    vision_encoder.visual.output_tokens = True

    text_tokenizer = AutoTokenizer.from_pretrained(
        tokenizer_path,
        local_files_only=use_local_files,
        trust_remote_code=True,
    )
    # add Flamingo special tokens to the tokenizer
    text_tokenizer.add_special_tokens(
        {"additional_special_tokens": ["<|endofchunk|>", "<image>"]}
    )
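    # "<image>" marks where visual features are spliced into the text stream;
    # "<|endofchunk|>" separates consecutive (image, text) examples in interleaved data.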

    if text_tokenizer.pad_token is None:
        # Issue: GPT models don't have a pad token, which we use to
        # modify labels for the loss.
        text_tokenizer.add_special_tokens({"pad_token": "<PAD>"})

    lang_encoder = AutoModelForCausalLM.from_pretrained(
        lang_encoder_path,
        local_files_only=use_local_files,
        trust_remote_code=True,
    )

    # hacks for MPT-1B, which doesn't have a get_input_embeddings method
    if "mpt-1b-redpajama-200b" in lang_encoder_path:

        class EmbeddingFnMixin:
            def get_input_embeddings(self):
                return self.transformer.wte

            def set_input_embeddings(self, new_embeddings):
                self.transformer.wte = new_embeddings

        extend_instance(lang_encoder, EmbeddingFnMixin)
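
    # extend_instance (from .utils) applies a mixin to an already-constructed
    # instance by rewriting its class, leaving the pretrained weights untouched.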
    # convert LM to FlamingoLM
    extend_instance(lang_encoder, FlamingoLMMixin)

    if decoder_layers_attr_name is None:
        decoder_layers_attr_name = _infer_decoder_layers_attr_name(lang_encoder)
    lang_encoder.set_decoder_layers_attr_name(decoder_layers_attr_name)
    lang_encoder.resize_token_embeddings(len(text_tokenizer))
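
    # The special tokens were appended to the vocabulary above, so encoding each
    # one and taking the last id recovers its token id even if the tokenizer
    # prepends a BOS token.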
    model = Flamingo(
        vision_encoder,
        lang_encoder,
        text_tokenizer.encode("<|endofchunk|>")[-1],
        text_tokenizer.encode("<image>")[-1],
        vis_dim=open_clip.get_model_config(clip_vision_encoder_path)["vision_cfg"][
            "width"
        ],
        cross_attn_every_n_layers=cross_attn_every_n_layers,
        **flamingo_kwargs,
    )

    # Freeze all parameters
    model.requires_grad_(False)
    assert sum(p.numel() for p in model.parameters() if p.requires_grad) == 0

    # Unfreeze perceiver, gated_cross_attn_layers, and LM input embeddings
    model.perceiver.requires_grad_(True)
    model.lang_encoder.gated_cross_attn_layers.requires_grad_(True)
    if not freeze_lm_embeddings:
        model.lang_encoder.get_input_embeddings().requires_grad_(True)
        # TODO: investigate also training the output embeddings when untied

    print(
        f"Flamingo model initialized with {sum(p.numel() for p in model.parameters() if p.requires_grad)} trainable parameters"
    )
    return model, image_processor, text_tokenizer


def _infer_decoder_layers_attr_name(model):
    for k in __KNOWN_DECODER_LAYERS_ATTR_NAMES:
        if k.lower() in model.__class__.__name__.lower():
            return __KNOWN_DECODER_LAYERS_ATTR_NAMES[k]

    raise ValueError(
        "We require the attribute name for the nn.ModuleList in the decoder that"
        " stores the transformer block layers; please pass decoder_layers_attr_name"
        " explicitly."
    )


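# Maps a lowercased substring of the HF model class name to the attribute path of
# the nn.ModuleList that holds its transformer blocks.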
__KNOWN_DECODER_LAYERS_ATTR_NAMES = {
    "opt": "model.decoder.layers",
    "gptj": "transformer.h",
    "gpt-j": "transformer.h",
    "pythia": "gpt_neox.layers",
    "llama": "model.layers",
    "gptneoxforcausallm": "gpt_neox.layers",
    "mpt": "transformer.blocks",
    "mosaicgpt": "transformer.blocks",
}
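

# A minimal usage sketch (illustrative; not executed on import). The CLIP names
# come from the docstring above, and "facebook/opt-125m" stands in for any
# supported causal LM checkpoint:
#
#   model, image_processor, tokenizer = create_model_and_transforms(
#       clip_vision_encoder_path="ViT-B-32",
#       clip_vision_encoder_pretrained="laion2b_s32b_b79k",
#       lang_encoder_path="facebook/opt-125m",
#       tokenizer_path="facebook/opt-125m",
#       cross_attn_every_n_layers=1,
#   )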