from typing import List
import warnings

import torch
from torch import nn, Tensor
from torchvision import transforms

from torchtune.models.llama3 import lora_llama3_8b, llama3_8b
from torchtune.modules.peft import LORA_ATTN_MODULES, LoRALinear
from torchtune.modules import TransformerDecoder

with warnings.catch_warnings():
    warnings.simplefilter("ignore", UserWarning)
    from imagebind.models import imagebind_model
    from models.imagebind_wrapper import get_imagebind_v2, V2_PATH
    from models.imagebind_wrapper import ImageBind

IMAGEBIND_DIM = 1024
CLIP_DIM = 768


class MMEmbedding(nn.Embedding):
    """Token embedding that splices projected multimodal embeddings into its output.

    Copies the configuration of an existing ``nn.Embedding`` ``e`` (note that the
    parent constructor creates fresh weights; ``e``'s weights are not copied). At
    the positions recorded via ``set_context``, the regular token embeddings are
    overwritten with ImageBind (and, optionally, CLIP) embeddings projected into
    ``perception_tokens`` llama-sized embedding vectors.
    """

    def __init__(self, e: nn.Embedding, perception_tokens: int = 1, use_clip: bool = False):
        super().__init__(
            num_embeddings=e.num_embeddings,
            embedding_dim=e.embedding_dim,
            padding_idx=e.padding_idx,
            max_norm=e.max_norm,
            norm_type=e.norm_type,
            scale_grad_by_freq=e.scale_grad_by_freq,
            sparse=e.sparse,
        )
        self._perception_tokens = perception_tokens
        self._context = []
        self._use_clip = use_clip

        dim_in = IMAGEBIND_DIM + (CLIP_DIM if use_clip else 0)
        dim_out = e.embedding_dim * perception_tokens

        # Two-layer MLP mapping an ImageBind (+ optional CLIP) embedding to
        # `perception_tokens` llama-dimension token embeddings.
        self.proj_to_llama = nn.Sequential(
            nn.Linear(dim_in, dim_out),
            nn.GELU(),
            nn.LayerNorm(dim_out),
            nn.Linear(dim_out, dim_out),
        )

    def set_context(self, context):
        self._context = context

    def forward(self, input: Tensor) -> Tensor:
        r = super().forward(input)
        # self._context is first indexed by batch idx
        for b, context_dict in enumerate(self._context):
            # then by sequence idx
            for s, embed in context_dict.items():
                # and then must be transformed from imagebind dim -> llama3 dim
                if self._use_clip:
                    llama_embed = self.proj_to_llama(
                        torch.cat([embed["ib_embed"], embed["clip_embed"]])
                    )
                else:
                    llama_embed = self.proj_to_llama(embed["ib_embed"])
                r[b, s:s+self._perception_tokens] = llama_embed.view(self._perception_tokens, -1)
        return r
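

# Illustrative sketch (not part of the original module): the `context` consumed by
# MMEmbedding.set_context is a list with one dict per batch element; each dict maps
# the sequence index of the first perception token to precomputed embeddings. The
# 1-D shapes below follow from the cat/view logic in forward(); the variable names
# are hypothetical.
#
#     context = [
#         {  # batch element 0
#             5: {
#                 "ib_embed": torch.randn(IMAGEBIND_DIM),   # always required
#                 "clip_embed": torch.randn(CLIP_DIM),      # only read when use_clip=True
#             },
#         },
#     ]
#     mm_embedding.set_context(context)
#     token_embeds = mm_embedding(tokens)  # rows 5..5+perception_tokens-1 of batch 0 are overwritten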


class MMLinear(nn.Linear):
    """Output projection that also owns a llama-to-CLIP projection head.

    ``forward`` currently behaves exactly like the wrapped ``nn.Linear``; the
    ``proj_from_llama`` head maps hidden states at the positions recorded via
    ``set_context`` back into CLIP space, but that path is commented out below.
    """

    def __init__(self, o: nn.Linear):
        super().__init__(
            in_features=o.in_features,
            out_features=o.out_features,
            bias=(o.bias is not None),
        )
        self._context = []

        # Two-layer MLP mapping llama hidden states (dim_in) to CLIP space (CLIP_DIM).
        dim_out = CLIP_DIM
        dim_in = o.in_features
        self.proj_from_llama = nn.Sequential(
            nn.Linear(dim_in, dim_out),
            nn.GELU(),
            nn.LayerNorm(dim_out),
            nn.Linear(dim_out, dim_out),
        )

    def set_context(self, context):
        self._context = context

    def forward(self, input_bsd: Tensor) -> Tensor:
        # self._context has the indexes of image llama tokens: process these with proj_from_llama
        self._clip_projections = []
        # # self._context is first indexed by batch idx
        # for b, context_dict in enumerate(self._context):
        #     # then by sequence idx
        #     for s, embed in context_dict.items():
        #         # and then must be transformed from llama3 dim -> clip dim
        #         self._clip_projections.append((
        #             self.proj_from_llama(input_bsd[b, s]),
        #             (embed["clip_embed"] if "clip_embed" in embed else None) # terrible
        #         ))
        r = super().forward(input_bsd)
        return r
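

# Note (descriptive, not original behavior): with the block above commented out, this
# layer acts as a plain nn.Linear. If it were re-enabled, `_clip_projections` would hold
# (proj_from_llama(hidden_state), reference clip_embed or None) pairs for each annotated
# position; one hypothetical consumer (purely illustrative, not asserted to be the
# training recipe's objective) could score them as:
#
#     for pred_clip, target_clip in mm_linear._clip_projections:
#         if target_clip is not None:
#             aux_loss = torch.nn.functional.mse_loss(pred_clip, target_clip)  # hypothetical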



def lora_mmllama3_8b(
    lora_attn_modules: List[LORA_ATTN_MODULES],
    apply_lora_to_mlp: bool = False,
    apply_lora_to_output: bool = False,
    lora_rank: int = 8,
    lora_alpha: float = 16,
    quantize_base: bool = False,
    perception_tokens: int = 2,
    use_clip: bool = False,
) -> TransformerDecoder:
    """Build a LoRA llama3-8b and swap in the multimodal embedding/output layers."""
    # Keyword arguments guard against silently binding to the wrong positional
    # parameter if the upstream torchtune builder signature changes.
    llama3 = lora_llama3_8b(
        lora_attn_modules=lora_attn_modules,
        apply_lora_to_mlp=apply_lora_to_mlp,
        apply_lora_to_output=apply_lora_to_output,
        lora_rank=lora_rank,
        lora_alpha=lora_alpha,
        quantize_base=quantize_base,
    )
    llama3.tok_embeddings = MMEmbedding(llama3.tok_embeddings, perception_tokens, use_clip)
    llama3.output = MMLinear(llama3.output)
    return llama3


def mmllama3_8b(
    perception_tokens: int = 2,
    use_clip: bool = False,
) -> TransformerDecoder:
    """Build a full (non-LoRA) llama3-8b and swap in the multimodal embedding/output layers."""
    llama3 = llama3_8b()
    llama3.tok_embeddings = MMEmbedding(llama3.tok_embeddings, perception_tokens, use_clip)
    llama3.output = MMLinear(llama3.output)
    return llama3
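

# Example usage (illustrative sketch; the `context` structure mirrors the MMEmbedding
# sketch above, and the choice of attention modules is an assumption):
#
#     model = lora_mmllama3_8b(
#         lora_attn_modules=["q_proj", "v_proj"],
#         lora_rank=8,
#         lora_alpha=16,
#         perception_tokens=2,
#         use_clip=False,
#     )
#     model.tok_embeddings.set_context(context)
#     model.output.set_context(context)
#     logits = model(tokens)  # tokens: (batch, seq_len) LongTensor of llama3 token ids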


def imagebind_huge(use_v2: bool = True):
    """Load ImageBind-huge (v2 wrapper by default) and attach a PIL preprocessing transform."""
    if use_v2:
        imagebind = ImageBind(v2=True)
    else:
        imagebind = imagebind_model.imagebind_huge(pretrained=True)
    # CLIP-style image preprocessing: resize shorter side to 224 (bicubic),
    # center-crop to 224x224, then normalize with the CLIP mean/std.
    imagebind.transform_from_pil = transforms.Compose([
        transforms.Resize(
            224, interpolation=transforms.InterpolationMode.BICUBIC
        ),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=(0.48145466, 0.4578275, 0.40821073),
            std=(0.26862954, 0.26130258, 0.27577711),
        ),
    ])
    return imagebind
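

# Example usage (illustrative sketch; the PIL import and the image path are assumptions):
#
#     from PIL import Image
#
#     ib = imagebind_huge(use_v2=True)
#     img = Image.open("example.jpg").convert("RGB")
#     pixel_values = ib.transform_from_pil(img)  # (3, 224, 224) float tensor, CLIP-normalized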