# LongVU_Llama3_2_3B / vision_sampler.py
# Commit 15fca6e (verified) by jadechoghari: "add initial files"
import math
import numpy as np
import torch
import torch.utils.checkpoint
from torch import nn
# https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """Build a fixed 2D sine/cosine positional embedding for a square grid.

    Args:
        embed_dim: Width of each embedding vector.
        grid_size: Side length of the grid (height and width are equal).
        cls_token: When true, an all-zero row is prepended for a class token.

    Returns:
        A NumPy array of shape ``[grid_size*grid_size, embed_dim]``, or
        ``[1 + grid_size*grid_size, embed_dim]`` when ``cls_token`` is set.
    """
    rows = np.arange(grid_size, dtype=np.float32)
    cols = np.arange(grid_size, dtype=np.float32)
    # The width axis is handed to meshgrid first, so plane 0 of the stacked
    # result carries the column coordinate and plane 1 the row coordinate.
    coords = np.stack(np.meshgrid(cols, rows), axis=0)
    coords = coords.reshape([2, 1, grid_size, grid_size])

    table = get_2d_sincos_pos_embed_from_grid(embed_dim, coords)
    if not cls_token:
        return table
    cls_row = np.zeros([1, embed_dim])
    return np.concatenate([cls_row, table], axis=0)
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """Encode a two-plane coordinate grid as sine/cosine embeddings.

    Each coordinate plane (``grid[0]`` and ``grid[1]``) is encoded with
    ``embed_dim // 2`` channels by the 1D helper, and the two encodings are
    joined along the channel axis, ``grid[0]`` first.

    Args:
        embed_dim: Total embedding width; must be even.
        grid: Array whose first axis indexes the two coordinate planes.

    Returns:
        Array of shape ``(M, embed_dim)``, where ``M`` is the number of
        positions in one plane.
    """
    assert embed_dim % 2 == 0
    half_dim = embed_dim // 2
    per_plane = [
        get_1d_sincos_pos_embed_from_grid(half_dim, grid[plane])  # (M, D/2)
        for plane in (0, 1)
    ]
    return np.concatenate(per_plane, axis=1)  # (M, D)
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """Encode scalar positions with sine and cosine at geometric frequencies.

    Args:
        embed_dim: Output width per position; must be even.
        pos: NumPy array of positions; it is flattened to ``(M,)`` first.

    Returns:
        Array of shape ``(M, embed_dim)``: the first half of the channels
        holds the sines, the second half the cosines.
    """
    assert embed_dim % 2 == 0
    half_dim = embed_dim // 2
    # Exponents run over [0, 1) so the frequencies fall from 1 towards 1/10000.
    exponents = np.arange(half_dim, dtype=np.float32) / (embed_dim / 2.0)
    inv_freq = 1.0 / 10000**exponents  # (D/2,)

    flat_pos = pos.reshape(-1)  # (M,)
    angles = np.einsum("m,d->md", flat_pos, inv_freq)  # outer product, (M, D/2)
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)  # (M, D)
class CrossAttention(nn.Module):
    """Multi-head cross-attention from query tokens to one set of vision latents.

    Queries, keys and values are each layer-normalised and linearly projected
    to ``hidden_dim``, attended with scaled dot-product attention, and the
    result is projected back to ``q_dim``.
    """

    def __init__(self, q_dim, kv_dim, hidden_dim, num_heads, attention_bias=False):
        """Create the projections.

        Args:
            q_dim: Feature width of the query tokens (also the output width).
            kv_dim: Feature width of the vision latents.
            hidden_dim: Total attention width, split evenly across heads.
            num_heads: Number of attention heads.
            attention_bias: Whether the linear projections carry a bias.

        Raises:
            ValueError: If ``hidden_dim`` is not a multiple of ``num_heads``.
        """
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.head_dim = self.hidden_dim // self.num_heads

        inner_dim = self.head_dim * self.num_heads
        if inner_dim != self.hidden_dim:
            raise ValueError(
                f"hidden_dim must be divisible by num_heads (got `hidden_dim`: {self.hidden_dim}"
                f" and `num_heads`: {self.num_heads})."
            )

        def normed_projection(in_dim):
            # LayerNorm on the raw features followed by the head projection.
            return nn.Sequential(
                nn.LayerNorm(in_dim),
                nn.Linear(in_dim, inner_dim, bias=attention_bias),
            )

        self.q_proj = normed_projection(q_dim)
        self.k_proj = normed_projection(kv_dim)
        self.v_proj = normed_projection(kv_dim)
        self.o_proj = nn.Linear(inner_dim, q_dim, bias=attention_bias)

    def forward(self, vision_latents, queries, attention_mask):
        """Attend from ``queries`` to ``vision_latents``.

        Args:
            vision_latents: 3-D tensor ``(bsz, v_len, kv_dim)``.
            queries: 3-D tensor ``(bsz, q_len, q_dim)``.
            attention_mask: ``None`` or a tensor of size
                ``(bsz, 1, q_len, v_len)`` handed to SDPA as ``attn_mask``.

        Returns:
            Tensor of shape ``(bsz, q_len, q_dim)``.

        Raises:
            ValueError: If a mask is given with any other size.
        """
        _, q_len, _ = queries.size()
        bsz, v_len, _ = vision_latents.size()

        q = self.q_proj(queries)
        k = self.k_proj(vision_latents)
        v = self.v_proj(vision_latents)

        def split_heads(states, seq_len):
            # (bsz, seq, hidden) -> (bsz, heads, seq, head_dim)
            return states.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        q = split_heads(q, q_len)
        k = split_heads(k, v_len)
        v = split_heads(v, v_len)

        has_mask = attention_mask is not None
        if has_mask and attention_mask.size() != (bsz, 1, q_len, v_len):
            raise ValueError(
                f"Attention mask should be of size {(bsz, 1, q_len, v_len)}, but is {attention_mask.size()}"
            )

        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged
        # with non-contiguous inputs with custom attn_mask.
        # Reference: https://github.com/pytorch/pytorch/issues/112577.
        if has_mask and q.device.type == "cuda":
            q, k, v = q.contiguous(), k.contiguous(), v.contiguous()

        attended = torch.nn.functional.scaled_dot_product_attention(
            q, k, v, attn_mask=attention_mask
        )

        # (bsz, heads, q_len, head_dim) -> (bsz, q_len, hidden)
        merged = attended.transpose(1, 2).contiguous().reshape(bsz, q_len, self.hidden_dim)
        return self.o_proj(merged)
class AggregationBlock(nn.Module):
    """Aggregate one vision source with either cross-attention or an MLP.

    With ``attention`` truthy the block wraps a ``CrossAttention`` layer;
    otherwise it wraps an ``MLP`` that is applied to the vision latents alone.
    """

    def __init__(
        self, attention, q_dim, kv_dim, hidden_dim, num_heads, attention_bias=False
    ):
        """Create the wrapped layer.

        Args:
            attention: Selects cross-attention (truthy) or the MLP (falsy).
            q_dim: Query feature width; also the MLP hidden and output width.
            kv_dim: Vision latent feature width.
            hidden_dim: Attention width; must be a multiple of ``num_heads``.
            num_heads: Number of attention heads.
            attention_bias: Bias flag forwarded to ``CrossAttention``.

        Raises:
            ValueError: If ``hidden_dim`` is not a multiple of ``num_heads``
                (checked in both modes).
        """
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.head_dim = self.hidden_dim // self.num_heads
        if self.hidden_dim != (self.head_dim * self.num_heads):
            raise ValueError(
                f"hidden_dim must be divisible by num_heads (got `hidden_dim`: {self.hidden_dim}"
                f" and `num_heads`: {self.num_heads})."
            )

        self.attention = attention
        self.attention_layer = (
            CrossAttention(q_dim, kv_dim, hidden_dim, num_heads, attention_bias)
            if attention
            else MLP(kv_dim, q_dim, q_dim)
        )

    def forward(self, vision_latents, queries, attention_mask):
        """Run the wrapped layer.

        In MLP mode ``queries`` and ``attention_mask`` are ignored and only
        ``vision_latents`` is transformed.
        """
        if not self.attention:
            return self.attention_layer(vision_latents)
        return self.attention_layer(vision_latents, queries, attention_mask)
class MultiKVCrossAttention(nn.Module):
    """Multi-head cross-attention over several key/value sources at once.

    Every source has its own LayerNorm + Linear key and value projections
    (registered as ``k_proj_{i}`` / ``v_proj_{i}``). The projected sources are
    concatenated along the sequence axis and attended to jointly.
    """

    def __init__(self, q_dim, kv_dim_list, hidden_dim, num_heads, attention_bias=False):
        """Create the projections.

        Args:
            q_dim: Feature width of the query tokens (also the output width).
            kv_dim_list: Feature width of each key/value source.
            hidden_dim: Total attention width, split evenly across heads.
            num_heads: Number of attention heads.
            attention_bias: Whether the linear projections carry a bias.

        Raises:
            ValueError: If ``hidden_dim`` is not a multiple of ``num_heads``.
        """
        super().__init__()

        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.head_dim = self.hidden_dim // self.num_heads

        if (self.head_dim * self.num_heads) != self.hidden_dim:
            raise ValueError(
                f"hidden_dim must be divisible by num_heads (got `hidden_dim`: {self.hidden_dim}"
                f" and `num_heads`: {self.num_heads})."
            )

        self.q_proj = nn.Sequential(
            nn.LayerNorm(q_dim),
            nn.Linear(q_dim, self.num_heads * self.head_dim, bias=attention_bias),
        )
        self.num_of_kvs = len(kv_dim_list)
        # The per-source attribute names are part of the state_dict layout, so
        # they are kept as-is rather than moved into a ModuleList.
        for i, kv_dim in enumerate(kv_dim_list):
            setattr(
                self,
                f"k_proj_{i}",
                nn.Sequential(
                    nn.LayerNorm(kv_dim),
                    nn.Linear(
                        kv_dim, self.num_heads * self.head_dim, bias=attention_bias
                    ),
                ),
            )
            setattr(
                self,
                f"v_proj_{i}",
                nn.Sequential(
                    nn.LayerNorm(kv_dim),
                    nn.Linear(
                        kv_dim, self.num_heads * self.head_dim, bias=attention_bias
                    ),
                ),
            )
        self.o_proj = nn.Linear(
            self.num_heads * self.head_dim, q_dim, bias=attention_bias
        )

    def forward(
        self,
        queries,
        *vision_latents_attention_mask_list,
    ):
        """Attend from ``queries`` to all key/value sources.

        Args:
            queries: 3-D tensor ``(bsz, q_len, q_dim)``.
            *vision_latents_attention_mask_list: The first ``num_of_kvs``
                entries are the latents, each ``(bsz, len_i, kv_dim_i)``. Any
                remaining entries are the per-source masks, each
                ``(bsz, 1, q_len, len_i)``; they are concatenated along the
                last axis. The masks may be omitted entirely, in which case
                attention is unmasked.

        Returns:
            Tensor of shape ``(bsz, q_len, q_dim)``.

        Raises:
            ValueError: If the concatenated mask is not of size
                ``(bsz, 1, q_len, sum(len_i))``.
        """
        vision_latents_list = vision_latents_attention_mask_list[: self.num_of_kvs]
        attention_mask_list = vision_latents_attention_mask_list[self.num_of_kvs :]

        bsz, q_len, _ = queries.size()

        query_states = self.q_proj(queries)
        key_states = torch.cat(
            [
                getattr(self, f"k_proj_{i}")(vision_latents_list[i])
                for i in range(self.num_of_kvs)
            ],
            dim=1,
        )
        value_states = torch.cat(
            [
                getattr(self, f"v_proj_{i}")(vision_latents_list[i])
                for i in range(self.num_of_kvs)
            ],
            dim=1,
        )

        v_len = key_states.shape[1]

        # (bsz, seq, hidden) -> (bsz, heads, seq, head_dim)
        query_states = query_states.view(
            bsz, q_len, self.num_heads, self.head_dim
        ).transpose(1, 2)
        key_states = key_states.view(
            bsz, v_len, self.num_heads, self.head_dim
        ).transpose(1, 2)
        value_states = value_states.view(
            bsz, v_len, self.num_heads, self.head_dim
        ).transpose(1, 2)

        # torch.cat rejects an empty sequence, so a call without masks must
        # skip the concatenation; SDPA accepts attn_mask=None.
        if attention_mask_list:
            attention_mask = torch.cat(attention_mask_list, dim=-1)
        else:
            attention_mask = None

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, v_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, v_len)}, but is {attention_mask.size()}"
                )

        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
        # Reference: https://github.com/pytorch/pytorch/issues/112577.
        if query_states.device.type == "cuda" and attention_mask is not None:
            query_states = query_states.contiguous()
            key_states = key_states.contiguous()
            value_states = value_states.contiguous()

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=attention_mask,
        )

        # (bsz, heads, q_len, head_dim) -> (bsz, q_len, hidden)
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_dim)

        attn_output = self.o_proj(attn_output)

        return attn_output
class MLP(nn.Module):
    """Bias-free two-layer perceptron with a GELU between the layers."""

    def __init__(self, d_in, d_hidden, d_out):
        """Create the layers.

        Args:
            d_in: Input feature width.
            d_hidden: Width after the first linear layer.
            d_out: Output feature width.
        """
        super().__init__()
        self.linear_1 = nn.Linear(d_in, d_hidden, bias=False)
        self.act = nn.GELU()
        self.linear_2 = nn.Linear(d_hidden, d_out, bias=False)

    def forward(self, x):
        """Apply ``linear_1``, then GELU, then ``linear_2`` to ``x``."""
        hidden = self.linear_1(x)
        hidden = self.act(hidden)
        return self.linear_2(hidden)
class VisionCrossAttentionLayer(nn.Module):
    """Refine query tokens by jointly cross-attending to several vision sources.

    The queries are fused with a projected context feature, attended over all
    sources through one ``MultiKVCrossAttention``, normalised, mapped back to
    ``q_dim`` by an MLP and added to the original queries.
    """

    def __init__(
        self,
        q_dim,
        context_dim,
        kv_dim_list,
        kv_size_list,
        hidden_dim=1024,
        layer_idx=0,
    ):
        """Create the layer.

        Args:
            q_dim: Feature width of the query tokens (input and output).
            context_dim: Feature width of the context feature.
            kv_dim_list: Feature width of each vision source.
            kv_size_list: Grid side length of each vision source; sources with
                a side greater than 1 get a learned positional embedding of
                ``kv_size**2`` rows.
            hidden_dim: Internal width of the layer.
            layer_idx: Accepted but not used by this layer.
        """
        super().__init__()
        num_heads = 16
        self.num_of_kvs = len(kv_dim_list)

        self.proj_context = nn.Linear(context_dim, hidden_dim, bias=False)
        self.proj_in = nn.Linear(q_dim + hidden_dim, hidden_dim, bias=False)
        self.proj_out = MLP(hidden_dim, hidden_dim, q_dim)
        self.norm = nn.LayerNorm(hidden_dim)
        self.cross_attn = MultiKVCrossAttention(
            hidden_dim, kv_dim_list, hidden_dim, num_heads
        )

        self.kv_size_list = kv_size_list
        # Learned, randomly initialised positional tables, one per grid-shaped
        # source; single-token sources get none.
        for idx, side in enumerate(kv_size_list):
            if side <= 1:
                continue
            table = nn.Parameter(torch.randn(side**2, hidden_dim))
            setattr(self, "pos_embed_{}".format(idx), table)

    def forward(
        self,
        queries,
        context_feature,
        *vision_latents_attention_mask_list,
    ) -> torch.FloatTensor:
        """Update ``queries`` from the vision sources.

        Args:
            queries: Query tokens; the last axis has width ``q_dim``.
            context_feature: Context tokens concatenated to the queries on the
                last axis after projection, so the leading axes must match.
            *vision_latents_attention_mask_list: The first ``num_of_kvs``
                entries are the vision latents, the rest are their masks. Each
                mask is viewed as ``(bsz, 1, 1, -1)`` and expanded over the
                query length.

        Returns:
            Tensor with the same shape as ``queries``.
        """
        skip = queries

        projected_context = self.proj_context(context_feature)
        fused = self.proj_in(torch.cat([queries, projected_context], -1))

        num_kvs = self.num_of_kvs
        latents = vision_latents_attention_mask_list[:num_kvs]
        masks = vision_latents_attention_mask_list[num_kvs:]

        # Broadcast each per-key mask across every query position.
        q_len = fused.shape[1]
        expanded_masks = [
            mask.view(mask.shape[0], 1, 1, -1).expand(-1, -1, q_len, -1)
            for mask in masks
        ]

        # Sources with more than one token receive their positional table.
        latents_with_pos = []
        for idx, tokens in enumerate(latents):
            if tokens.shape[1] > 1:
                table = getattr(self, "pos_embed_{}".format(idx))
                tokens = tokens + table[None, :, :].to(tokens.dtype)
            latents_with_pos.append(tokens)

        attended = self.cross_attn(fused, *latents_with_pos, *expanded_masks)

        fused = self.norm(fused + attended)
        return self.proj_out(fused) + skip
class VisionAggregationLayer(nn.Module):
    """Refine query tokens by aggregating each vision source separately.

    Every source is summarised by its own ``AggregationBlock`` (cross-attention
    for grid-shaped sources, an MLP for single-token sources). With more than
    one source the summaries are mixed with softmax weights predicted from the
    fused query/context feature.
    """

    def __init__(
        self,
        q_dim,
        context_dim,
        kv_dim_list,
        kv_size_list,
        hidden_dim=1024,
        layer_idx=0,
    ):
        """Create the layer.

        Args:
            q_dim: Feature width of the query tokens (input and output).
            context_dim: Feature width of the context feature.
            kv_dim_list: Feature width of each vision source.
            kv_size_list: Grid side length of each vision source. A side
                greater than 1 selects cross-attention plus a learned
                positional embedding of ``kv_size**2`` rows; otherwise the
                source is handled by an MLP block.
            hidden_dim: Internal width of the layer.
            layer_idx: Accepted but not used by this layer.
        """
        super().__init__()
        num_heads = 16
        self.num_of_kvs = len(kv_dim_list)

        self.proj_context = nn.Linear(context_dim, hidden_dim, bias=False)
        self.proj_in = nn.Linear(q_dim + hidden_dim, hidden_dim, bias=False)

        self.proj_out = MLP(hidden_dim, hidden_dim, q_dim)

        self.norm = nn.LayerNorm(hidden_dim)

        if self.num_of_kvs > 1:
            # Predicts one mixing logit per source from the fused feature.
            self.weight_mlp = MLP(q_dim + hidden_dim, hidden_dim, self.num_of_kvs)

        for i, kv_size in enumerate(kv_size_list):
            use_attention = kv_size > 1
            if use_attention:
                setattr(
                    self,
                    f"pos_embed_{i}",
                    nn.Parameter(torch.randn(kv_size**2, hidden_dim)),
                )
            setattr(
                self,
                f"aggregate_{i}",
                AggregationBlock(
                    use_attention, hidden_dim, kv_dim_list[i], hidden_dim, num_heads
                ),
            )

    def forward(
        self,
        queries,
        context_feature,
        *vision_latents_attention_mask_list,
    ) -> torch.FloatTensor:
        """Update ``queries`` from the vision sources.

        Args:
            queries: Query tokens; the last axis has width ``q_dim``.
            context_feature: Context tokens concatenated to the queries on the
                last axis after projection, so the leading axes must match.
            *vision_latents_attention_mask_list: The first ``num_of_kvs``
                entries are the vision latents, the rest are their masks. Each
                mask is viewed as ``(bsz, 1, 1, -1)`` and expanded over the
                query length. The masks may be omitted entirely, in which case
                every source is aggregated unmasked.

        Returns:
            Tensor with the same shape as ``queries``.

        Raises:
            ValueError: If masks are supplied for only some of the sources.
        """
        residual = queries
        context_feature = self.proj_context(context_feature)
        queries = torch.cat([queries, context_feature], -1)

        if self.num_of_kvs > 1:
            # Softmax over the sources, then a trailing axis so the weights
            # broadcast over the feature axis of the stacked summaries.
            combination_weight = self.weight_mlp(queries).softmax(-1).unsqueeze(-1)
        else:
            combination_weight = 1

        queries = self.proj_in(queries)

        vision_latents_list = vision_latents_attention_mask_list[: self.num_of_kvs]
        attention_mask_list = vision_latents_attention_mask_list[self.num_of_kvs :]

        if not attention_mask_list:
            # No masks given: aggregate every source unmasked. Pairing the
            # latents with an empty mask list would leave nothing to stack.
            masks = [None] * len(vision_latents_list)
        elif len(attention_mask_list) < len(vision_latents_list):
            # Pairing would silently drop the sources that have no mask.
            raise ValueError(
                f"Expected {len(vision_latents_list)} attention masks (or none), "
                f"but got {len(attention_mask_list)}"
            )
        else:
            q_len = queries.shape[1]
            masks = [
                mask.view(mask.shape[0], 1, 1, -1).expand(-1, -1, q_len, -1)
                for mask in attention_mask_list
            ]

        aggregated_vision_latents_list = []
        for i, (vision_latents, attention_mask) in enumerate(
            zip(vision_latents_list, masks)
        ):
            # Sources with more than one token receive their positional table.
            if vision_latents.shape[1] > 1:
                pos_embed = getattr(self, f"pos_embed_{i}")
                vision_latents = vision_latents + pos_embed[None, :, :].to(
                    vision_latents.dtype
                )
            aggregated_vision_latents_list.append(
                getattr(self, f"aggregate_{i}")(
                    vision_latents, queries, attention_mask
                )
            )

        # Stack the per-source summaries on a new axis 2 and mix them.
        aggregated_vision_latents = torch.stack(aggregated_vision_latents_list, 2)

        queries = queries + (aggregated_vision_latents * combination_weight).sum(2)

        queries = self.norm(queries)

        queries = self.proj_out(queries)

        queries = queries + residual

        return queries
class VisionTokenSampler(nn.Module):
    """Stack of vision sampling layers applied to the queries in sequence."""

    def __init__(
        self,
        q_dim,
        context_dim,
        kv_dim_list,
        kv_size_list,
        vision_hidden_size,
        num_of_layers=1,
        layer_type="joint",
    ):
        """Build ``num_of_layers`` identical-shaped layers.

        Args:
            q_dim: Feature width of the query tokens.
            context_dim: Feature width of the context feature.
            kv_dim_list: Feature width of each vision source.
            kv_size_list: Grid side length of each vision source.
            vision_hidden_size: Passed to every layer as its ``hidden_dim``.
            num_of_layers: Number of stacked layers.
            layer_type: ``"joint"`` builds ``VisionCrossAttentionLayer``
                instances, ``"sep"`` builds ``VisionAggregationLayer`` ones.
        """
        super().__init__()
        assert layer_type in ["joint", "sep"]

        layer_cls = (
            VisionCrossAttentionLayer
            if layer_type == "joint"
            else VisionAggregationLayer
        )
        stack = []
        for idx in range(num_of_layers):
            stack.append(
                layer_cls(
                    q_dim,
                    context_dim,
                    kv_dim_list,
                    kv_size_list,
                    vision_hidden_size,
                    idx,
                )
            )
        self.layers = nn.ModuleList(stack)

    def forward(self, queries, context_feature, *vision_latents_attention_mask_list):
        """Pass the queries through every layer with the same context and sources."""
        current = queries
        for sampler_layer in self.layers:
            current = sampler_layer(
                current, context_feature, *vision_latents_attention_mask_list
            )
        return current