File size: 957 Bytes
046d995 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 |
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModelForCausalLM
from RBFLayer import RBFLayer
def l_norm(x: torch.Tensor, p: float = 2) -> torch.Tensor:
    """Return the p-norm of *x* taken over its last dimension.

    Args:
        x: Input tensor; the last dimension is reduced.
        p: Order of the norm, forwarded unchanged to ``torch.norm``.
            The default of 2 gives the Euclidean norm.

    Returns:
        A tensor with the last dimension of *x* removed.
    """
    return torch.norm(x, p=p, dim=-1)
# Gaussian RBF
def rbf_gaussian(x: torch.Tensor) -> torch.Tensor:
    """Apply the Gaussian radial function ``exp(-x**2)`` element-wise.

    Args:
        x: Input tensor.

    Returns:
        A tensor of the same shape as *x* holding ``exp(-x**2)``.
    """
    return torch.exp(-x.pow(2))
class CustomRBFFeedForward(nn.Module):
    """Feed-forward module that delegates to a single ``RBFLayer``.

    The wrapped layer is configured with the Gaussian radial function
    (``rbf_gaussian``) and the Euclidean norm (``l_norm`` with its default
    ``p=2``).
    """

    def __init__(self, in_features: int, out_features: int, num_kernels: int):
        """Build the underlying RBF layer.

        Args:
            in_features: Input feature size (e.g., 896).
            out_features: Output feature size (e.g., 4864).
            num_kernels: Number of kernels in the RBF layer (can be tuned).
        """
        super().__init__()
        # RBFLayer comes from the project-local ``RBFLayer`` module; the
        # keyword names below are that implementation's constructor API.
        self.rbf_layer = RBFLayer(
            in_features_dim=in_features,
            num_kernels=num_kernels,
            out_features_dim=out_features,
            radial_function=rbf_gaussian,  # Gaussian RBF
            norm_function=l_norm,  # Euclidean norm
        )

    def forward(self, x):
        """Apply the RBF layer to *x* and return its output unchanged."""
        return self.rbf_layer(x)
|