# --------------------------------------------------------
# Ristretto
# Copyright (c) 2025 LiAutoAD
# Licensed under The MIT License
# --------------------------------------------------------

import numpy as np
from torch import nn
import torch.nn.functional as F



class FFN(nn.Module):
    """Gated feed-forward block: LayerNorm, then a SiLU-gated expansion
    (SwiGLU-style), then a projection to out_dim."""

    def __init__(self, dim, out_dim, mlp_ratio=3):
        super().__init__()
        self.layernorm = nn.LayerNorm(dim)
        self.f1 = nn.Linear(dim, mlp_ratio * dim)  # gate branch
        self.f2 = nn.Linear(dim, mlp_ratio * dim)  # value branch
        self.g = nn.Linear(mlp_ratio * dim, out_dim)
        self.act = nn.SiLU()

    def forward(self, x):
        x = self.layernorm(x)
        x1, x2 = self.f1(x), self.f2(x)
        x = self.act(x1) * x2  # SiLU-gated elementwise product
        x = self.g(x)
        return x
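

# Shape sketch for FFN (assumed sizes, illustration only): FFN(1024, 4096)
# maps [bs, N, 1024] through LayerNorm, two parallel Linear(1024, 3072)
# branches, a SiLU(gate) * value product, and Linear(3072, 4096),
# giving [bs, N, 4096].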


class TokenAdaptiveProjector(nn.Module):
    """Pools ViT patch tokens down to a target number of image tokens and
    projects them to the LLM hidden size."""

    def __init__(self, vit_hidden_size, llm_hidden_size, num_image_token):
        super().__init__()
        self.num_image_token = num_image_token
        self.mlp = FFN(vit_hidden_size, llm_hidden_size)

    def find_resize_hw(self, H, W, num_image_token):
        # Target grid is square: target_h * target_w == num_image_token.
        target_h = target_w = int(num_image_token ** 0.5)
        # Round H and W up to the nearest multiple of the target grid so the
        # feature map divides evenly into target_h x target_w pooling cells.
        resize_h = int(np.ceil(H / target_h)) * target_h
        resize_w = int(np.ceil(W / target_w)) * target_w
        return resize_h, resize_w, target_h, target_w

    def forward(self, x, num_image_token=None):
        # x: [bs, L, C] patch tokens from the vision encoder; L must be a square number.
        bs, L, C = x.shape

        if num_image_token is None:
            num_image_token = self.num_image_token

        H = W = int(L ** 0.5)
        assert L == H * W, "L should equal H * W"

        resize_h, resize_w, target_h, target_w = self.find_resize_hw(
            H, W, num_image_token
        )

        x = x.view(bs, H, W, C).permute(0, 3, 1, 2)  # [bs, C, H, W]
        # Resize so the spatial grid divides evenly into the target token grid.
        if resize_h != H or resize_w != W:
            x = F.interpolate(
                x, size=(resize_h, resize_w), mode="bilinear", align_corners=True
            )
            _, _, H, W = x.shape

        # Average-pool each (patch_h, patch_w) cell into one token, yielding a
        # target_h x target_w grid, i.e. num_image_token tokens per image.
        n = target_h
        patch_h, patch_w = H // n, W // n

        x = (
            F.avg_pool2d(x, (patch_h, patch_w)).permute(0, 2, 3, 1).reshape(bs, -1, C)
        )
        x = self.mlp(x)  # project pooled tokens to the LLM hidden size
        return x
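

# Minimal usage sketch (illustrative only; the hidden sizes and shapes below
# are assumed values, not taken from the Ristretto configs). It shows how the
# projector maps ViT patch tokens to a fixed LLM token budget.
if __name__ == "__main__":
    import torch

    projector = TokenAdaptiveProjector(
        vit_hidden_size=1024, llm_hidden_size=4096, num_image_token=256
    )

    # 32 x 32 patch grid: already a multiple of the 16 x 16 target grid,
    # so pooling happens without any interpolation.
    x = torch.randn(2, 32 * 32, 1024)
    print(projector(x).shape)  # torch.Size([2, 256, 4096])

    # 24 x 24 patch grid: bilinearly resized to 32 x 32, then pooled to 16 x 16.
    x = torch.randn(2, 24 * 24, 1024)
    print(projector(x).shape)  # torch.Size([2, 256, 4096])

    # Per-call override of the token budget: pooled to an 8 x 8 grid instead.
    x = torch.randn(2, 32 * 32, 1024)
    print(projector(x, num_image_token=64).shape)  # torch.Size([2, 64, 4096])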