import torch  # needed: gaussian_rbf/euclidean_norm call torch.exp / torch.norm;
              # `import torch.nn as nn` alone does NOT bind the name `torch`.
import torch.nn as nn

from transformers import BertForMaskedLM

from RBFLayer import RBFLayer  # project-local custom RBF implementation


def gaussian_rbf(x):
    """Gaussian radial basis function: exp(-x**2), applied element-wise."""
    return torch.exp(-x ** 2)


def euclidean_norm(x):
    """Euclidean (L2) norm of ``x`` along its last dimension."""
    return torch.norm(x, p=2, dim=-1)


class CustomBertForMaskedLM(BertForMaskedLM):
    """BERT masked-LM whose encoder feed-forward dense layers are replaced
    by RBF layers.

    For every encoder layer, both the intermediate projection
    (hidden -> intermediate) and the output projection
    (intermediate -> hidden) are swapped for an ``RBFLayer`` using a
    Gaussian radial function and a Euclidean norm.
    """

    def __init__(self, config):
        super().__init__(config)
        # Read dimensions from the config rather than hard-coding 768/3072,
        # so the class works with any BERT variant, not just bert-base.
        hidden_dim = config.hidden_size
        intermediate_dim = config.intermediate_size
        for layer in self.bert.encoder.layer:
            # Replace the intermediate dense layer (hidden -> intermediate).
            layer.intermediate.dense = RBFLayer(
                in_features_dim=hidden_dim,
                num_kernels=2,  # number of kernels in the RBF layer
                out_features_dim=intermediate_dim,
                radial_function=gaussian_rbf,
                norm_function=euclidean_norm,
            )
            # Replace the output dense layer (intermediate -> hidden).
            layer.output.dense = RBFLayer(
                in_features_dim=intermediate_dim,
                num_kernels=2,
                out_features_dim=hidden_dim,
                radial_function=gaussian_rbf,
                norm_function=euclidean_norm,
            )