import torch import torch.nn as nn from transformers import AutoTokenizer, AutoModelForCausalLM from RBFLayer import RBFLayer def l_norm(x, p=2): return torch.norm(x, p=p, dim=-1) # Gaussian RBF def rbf_gaussian(x): return (-x.pow(2)).exp() class CustomRBFFeedForward(nn.Module): def __init__(self, in_features, out_features, num_kernels): super(CustomRBFFeedForward, self).__init__() # RBFLayer from the given implementation self.rbf_layer = RBFLayer( in_features_dim=in_features, # Input size (e.g., 896) num_kernels=num_kernels, # Number of kernels in the RBF layer (can be tuned) out_features_dim=out_features, # Output size (e.g., 4864) radial_function=rbf_gaussian, # Use the Gaussian RBF norm_function=l_norm # Use Euclidean norm ) def forward(self, x): # Apply the RBF layer to the input x return self.rbf_layer(x)