lwm / model.py
Sadjad Alikhani
upload required files
ebfb25d verified
raw
history blame
874 Bytes
# -*- coding: utf-8 -*-
"""
Created on Fri Sep 13 19:16:12 2024
@author: salikha4
"""
from transformers import PreTrainedModel, PretrainedConfig
from lwm_model import LWM
class WirelessConfig(PretrainedConfig):
    """Hugging Face configuration object for the LWM wireless-channel model.

    Stores the four hyper-parameters that are later passed to ``LWM``.

    Args:
        element_length: Stored as ``self.element_length`` (default 16).
        d_model: Stored as ``self.d_model`` (default 64).
        max_len: Stored as ``self.max_len`` (default 129).
        n_layers: Stored as ``self.n_layers`` (default 12).
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.

    NOTE(review): the meaning of ``element_length`` / ``max_len`` (e.g. patch
    size and sequence length) is defined by ``LWM``, which is not visible
    here; confirm before documenting units.
    """

    # Value of PretrainedConfig.model_type for this configuration class.
    model_type = "lwm"

    def __init__(self, element_length=16, d_model=64, max_len=129, n_layers=12, **kwargs):
        # Let the base class consume any generic config kwargs first, then
        # attach the model-specific hyper-parameters.
        super().__init__(**kwargs)
        hyperparams = {
            "element_length": element_length,
            "d_model": d_model,
            "max_len": max_len,
            "n_layers": n_layers,
        }
        for attr_name, attr_value in hyperparams.items():
            setattr(self, attr_name, attr_value)
class WirelessChannelModel(PreTrainedModel):
    """``PreTrainedModel`` wrapper that exposes an ``LWM`` network.

    The wrapped network is built from a ``WirelessConfig`` and kept in
    ``self.lwm``. ``forward`` delegates to it without modification.
    """

    config_class = WirelessConfig

    def __init__(self, config):
        super().__init__(config)
        # LWM is constructed positionally; this order must match LWM's
        # constructor signature in lwm_model (not visible here) — TODO confirm.
        lwm_args = (
            config.element_length,
            config.d_model,
            config.max_len,
            config.n_layers,
        )
        self.lwm = LWM(*lwm_args)

    def forward(self, input_ids, masked_pos):
        """Run the wrapped ``LWM`` and return its output unchanged.

        Args:
            input_ids: First positional input to ``LWM``; shape/dtype are
                defined by ``LWM`` (not visible here) — TODO confirm.
            masked_pos: Second positional input to ``LWM``; presumably the
                indices of masked positions — verify against ``LWM``.

        Returns:
            Whatever ``self.lwm(input_ids, masked_pos)`` returns.
        """
        output = self.lwm(input_ids, masked_pos)
        return output