# LayTextLLM-All / build_mlp.py
import torch
import torch.nn as nn
import math
import re


def build_layout_projector():
    """Build the projector that maps 4-d layout (bounding-box) features into the LLM hidden space."""
    projector_type = 'mlp2x_gelu'
    mm_hidden_size = 4   # dimensionality of the layout features
    hidden_size = 4096   # LLM hidden size

    mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
    if mlp_gelu_match:
        # 'mlpNx_gelu' -> N linear layers with GELU activations in between.
        mlp_depth = int(mlp_gelu_match.group(1))
        modules = [nn.Linear(mm_hidden_size, hidden_size)]
        for _ in range(1, mlp_depth):
            modules.append(nn.GELU())
            modules.append(nn.Linear(hidden_size, hidden_size))
        return nn.Sequential(*modules)

    if projector_type == 'identity':
        return IdentityMap()

    raise ValueError(f'Unknown projector type: {projector_type}')
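
# For reference, with the default settings above ('mlp2x_gelu', 4 -> 4096),
# build_layout_projector() returns:
#   nn.Sequential(
#       nn.Linear(4, 4096),
#       nn.GELU(),
#       nn.Linear(4096, 4096),
#   )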


class IdentityMap(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, *args, **kwargs):
        return x

    @property
    def config(self):
        return {'mm_projector_type': 'identity'}


class PLoRA(nn.Linear):
    """Linear layer with a partial LoRA branch that is applied only to the
    positions selected by `im_mask` in forward()."""

    def __init__(self,
                 in_features: int,
                 out_features: int,
                 bias: bool = True,
                 device=None,
                 dtype=None,
                 lora_r=8,
                 lora_alpha=16,
                 lora_dropout=0.05,
                 lora_len=0,
                 **kwargs) -> None:
        super().__init__(in_features, out_features, bias, device, dtype)
        self.lora_r = lora_r
        self.lora_alpha = lora_alpha
        self.lora_len = lora_len
        if lora_dropout > 0.:
            self.lora_dropout = nn.Dropout(p=lora_dropout)
        else:
            self.lora_dropout = lambda x: x
        self.lora_scaling = self.lora_alpha / self.lora_r

        # Low-rank adapter: in_features -> lora_r -> out_features.
        self.Plora_A = nn.Linear(
            in_features, self.lora_r, bias=False, device=device, dtype=dtype)
        self.Plora_B = nn.Linear(
            self.lora_r, out_features, bias=False, device=device, dtype=dtype)
        self.reset_parameters()
    def reset_parameters(self):
        # Default nn.Linear initialization for the base weight and bias.
        super().reset_parameters()
        # nn.Linear.__init__ also calls reset_parameters() before the LoRA
        # branches exist, hence the hasattr guard.
        if hasattr(self, 'Plora_A'):
            # Initialize A the same way as the default for nn.Linear and B to
            # zero, so the LoRA branch starts as a no-op.
            nn.init.kaiming_uniform_(self.Plora_A.weight, a=math.sqrt(5))
            nn.init.zeros_(self.Plora_B.weight)
    def forward(self, x, im_mask=None):
        # Base (full-rank) projection.
        res = super().forward(x)
        if im_mask is not None:
            if torch.sum(im_mask) > 0:
                # Apply the low-rank branch only to the masked positions.
                part_x = x[im_mask]
                res[im_mask] += self.Plora_B(
                    self.Plora_A(
                        self.lora_dropout(part_x))) * self.lora_scaling
            else:
                # No masked positions in this batch: run a zero-weighted pass so
                # the LoRA parameters still appear in the autograd graph (e.g.
                # to avoid unused-parameter errors in distributed training).
                part_x = x[:, :1]
                res[:, :1] += self.Plora_B(
                    self.Plora_A(self.lora_dropout(part_x))) * 0
        return res
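

# Minimal, illustrative smoke test (not part of the original module): it feeds
# dummy tensors through the projector and a PLoRA layer and only checks output
# shapes. Batch size, sequence length, and the mask pattern are arbitrary.
if __name__ == '__main__':
    with torch.no_grad():
        projector = build_layout_projector()
        boxes = torch.rand(2, 16, 4)              # (batch, num_boxes, 4)
        print(projector(boxes).shape)             # torch.Size([2, 16, 4096])

        plora = PLoRA(4096, 4096, bias=False, lora_r=8, lora_alpha=16)
        hidden = torch.randn(2, 16, 4096)         # (batch, seq_len, hidden)
        mask = torch.zeros(2, 16, dtype=torch.bool)
        mask[:, :4] = True                        # pretend the first 4 tokens are layout tokens
        print(plora(hidden, im_mask=mask).shape)  # torch.Size([2, 16, 4096])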