|
|
|
"""
|
|
Created on Fri Sep 13 19:16:12 2024
|
|
|
|
@author: salikha4
|
|
"""
|
|
|
|
from transformers import PreTrainedModel, PretrainedConfig
|
|
from lwm_model import LWM
|
|
|
|
class WirelessConfig(PretrainedConfig):
    """Hugging Face configuration object for the LWM model.

    Holds the four hyperparameters that ``WirelessChannelModel`` passes
    positionally to ``LWM``. Any extra keyword arguments are forwarded
    unchanged to ``PretrainedConfig``.

    Args:
        element_length: First positional argument given to ``LWM``
            (NOTE(review): presumably the per-token/element feature size —
            confirm against ``lwm_model.LWM``).
        d_model: Second positional argument given to ``LWM``; presumably
            the model's hidden width.
        max_len: Third positional argument given to ``LWM``; presumably the
            maximum sequence length it accepts.
        n_layers: Fourth positional argument given to ``LWM``; presumably
            the number of stacked layers.
        **kwargs: Passed through to ``PretrainedConfig.__init__``.
    """

    # Identifier Hugging Face uses to tag/serialize this configuration type.
    model_type = "lwm"

    def __init__(
        self,
        element_length: int = 16,
        d_model: int = 64,
        max_len: int = 129,
        n_layers: int = 12,
        **kwargs,
    ):
        # Let the base class consume its own options before we attach ours.
        super().__init__(**kwargs)

        # Assigned left-to-right, i.e. in the same order as the signature.
        self.element_length, self.d_model, self.max_len, self.n_layers = (
            element_length,
            d_model,
            max_len,
            n_layers,
        )
|
|
|
|
class WirelessChannelModel(PreTrainedModel):
    """``PreTrainedModel`` wrapper that exposes the project's ``LWM`` network.

    Wrapping ``LWM`` this way lets it be used through the Hugging Face
    ``PreTrainedModel`` API, with ``WirelessConfig`` as its configuration
    class.
    """

    config_class = WirelessConfig

    def __init__(self, config: WirelessConfig):
        """Build the wrapped ``LWM`` from ``config``.

        Args:
            config: A ``WirelessConfig``; its ``element_length``,
                ``d_model``, ``max_len`` and ``n_layers`` fields are read.
        """
        super().__init__(config)

        # LWM takes these hyperparameters positionally, in exactly this order.
        lwm_args = (
            config.element_length,
            config.d_model,
            config.max_len,
            config.n_layers,
        )
        self.lwm = LWM(*lwm_args)

    def forward(self, input_ids, masked_pos):
        """Delegate the forward pass to the wrapped ``LWM``.

        Args:
            input_ids: Passed straight through as the first argument of
                ``LWM`` (NOTE(review): shape/dtype not visible here —
                confirm against ``lwm_model.LWM.forward``).
            masked_pos: Passed straight through as the second argument of
                ``LWM``; presumably the indices of masked positions.

        Returns:
            Whatever ``self.lwm(input_ids, masked_pos)`` returns, unchanged.
        """
        lwm_output = self.lwm(input_ids, masked_pos)
        return lwm_output
|
|
|