import torch
import torch.nn as nn


class LoRALayer(nn.Module):
    def __init__(self, input_dim, output_dim, rank, alpha):
        super().__init__()
        std_dev = 1 / torch.sqrt(torch.tensor(rank).float())
        self.A = nn.Parameter(torch.randn(input_dim, rank) * std_dev)  # Low-rank matrix A
        self.B = nn.Parameter(torch.zeros(rank, output_dim))  # Low-rank matrix B
        self.alpha = alpha  # Scaling factor
    def forward(self, x):
        # Low-rank update: alpha * (x @ A @ B); the residual addition to the
        # frozen layer's output happens in LinearWithLoRA, not here
        return self.alpha * (x @ self.A @ self.B)

class LinearWithLoRA(nn.Module):
    def __init__(self, linear_layer, rank, alpha):
        super().__init__()
        self.linear = linear_layer  # Original (frozen) linear layer
        self.lora = LoRALayer(linear_layer.in_features, linear_layer.out_features, rank, alpha)  # LoRA layer

    def forward(self, x):
        # Combine original linear layer output with LoRA adaptation
        return self.linear(x) + self.lora(x)
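

# --- Usage sketch (illustrative, not part of the uploaded module) ---
# A minimal example of wrapping a pretrained nn.Linear with LoRA. The
# dimensions (768), rank (8), and alpha (16) below are assumptions chosen
# for demonstration, not values taken from this repository.
if __name__ == "__main__":
    base = nn.Linear(768, 768)
    for p in base.parameters():
        p.requires_grad = False  # Freeze the pretrained weights; only A and B train

    lora_linear = LinearWithLoRA(base, rank=8, alpha=16)
    x = torch.randn(4, 768)
    out = lora_linear(x)
    # Because B is zero-initialized, the output initially equals base(x)
    print(out.shape)  # torch.Size([4, 768])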