Sadjad Alikhani committed
Commit d0a7b2e · verified · 1 Parent(s): b7a626b

Upload 3 files

Files changed (3)
  1. inference.py +165 -0
  2. input_preprocess.py +294 -0
  3. lwm_model.py +154 -0
inference.py ADDED
@@ -0,0 +1,165 @@
# -*- coding: utf-8 -*-
"""
Created on Sun Sep 15 18:27:17 2024

@author: salikha4
"""

import os
import csv
import json
import shutil
import random
import argparse
from datetime import datetime
import pandas as pd
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset
from torch.optim import Adam
import numpy as np
import DeepMIMOv3  # needed by steering_vec(); this import was missing in the upload
#from lwm_model import LWM, load_model
import warnings
warnings.filterwarnings('ignore')
from input_preprocess import *

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if torch.cuda.is_available():
    torch.cuda.empty_cache()

# Folders
# MODELS_FOLDER = 'models/'


def dataset_gen(preprocessed_chs, input_type, lwm_model):
    """Return raw channels or LWM embeddings, depending on input_type."""
    if input_type in ['cls_emb', 'channel_emb']:
        dataset = prepare_for_LWM(preprocessed_chs, device)
    elif input_type == 'raw':
        dataset = create_raw_dataset(preprocessed_chs, device)

    if input_type in ['cls_emb', 'channel_emb']:
        # Process data through LWM
        lwm_loss, embedding_data = evaluate(lwm_model, dataset)
        print(f'LWM loss: {lwm_loss:.4f}')

        if input_type == 'cls_emb':
            embedding_data = embedding_data[:, 0]   # [CLS] token embedding only
        elif input_type == 'channel_emb':
            embedding_data = embedding_data[:, 1:]  # per-patch embeddings, [CLS] dropped

        dataset = embedding_data.float()

    return dataset


def prepare_for_LWM(data, device, batch_size=64, shuffle=False):
    """Wrap (input_ids, masked_tokens, masked_pos) triples in a DataLoader."""
    input_ids, masked_tokens, masked_pos = zip(*data)

    input_ids_tensor = torch.tensor(input_ids, device=device).float()
    masked_tokens_tensor = torch.tensor(masked_tokens, device=device).float()
    masked_pos_tensor = torch.tensor(masked_pos, device=device).long()

    dataset = TensorDataset(input_ids_tensor, masked_tokens_tensor, masked_pos_tensor)

    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)


def create_raw_dataset(data, device):
    """Create a dataset for raw channel data (the [CLS] token is dropped)."""
    input_ids, _, _ = zip(*data)
    input_data = torch.tensor(input_ids, device=device)[:, 1:]
    return input_data.float()


def label_gen(task, data, scenario, n_beams=64):
    """Generate downstream-task labels for users with valid (LoS != -1) channels."""
    idxs = np.where(data['user']['LoS'] != -1)[0]

    if task == 'LoS/NLoS Classification':
        label = data['user']['LoS'][idxs]
    elif task == 'Beam Prediction':
        parameters, row_column_users, n_ant_bs, n_ant_ue, n_subcarriers = get_parameters(scenario)
        n_users = len(data['user']['channel'])
        n_subbands = 1
        fov = 120

        # Set up beamformers: n_beams steering vectors uniformly spanning the field of view
        beam_angles = np.around(np.arange(-fov/2, fov/2 + .1, fov/(n_beams - 1)), 2)

        F1 = np.array([steering_vec(parameters['bs_antenna']['shape'],
                                    phi=azi*np.pi/180,
                                    kd=2*np.pi*parameters['bs_antenna']['spacing']).squeeze()
                       for azi in beam_angles])

        full_dbm = np.zeros((n_beams, n_subbands, n_users), dtype=float)
        for ue_idx in tqdm(range(n_users), desc='Computing the channel for each user'):
            if data['user']['LoS'][ue_idx] == -1:
                full_dbm[:, :, ue_idx] = np.nan
            else:
                chs = F1 @ data['user']['channel'][ue_idx]
                full_linear = np.abs(np.mean(chs.squeeze().reshape((n_beams, n_subbands, -1)), axis=-1))
                full_dbm[:, :, ue_idx] = np.around(20*np.log10(full_linear) + 30, 1)

        best_beams = np.argmax(np.mean(full_dbm, axis=1), axis=0)
        best_beams = best_beams.astype(float)
        best_beams[np.isnan(full_dbm[0, 0, :])] = np.nan
        max_bf_pwr = np.max(np.mean(full_dbm, axis=1), axis=0)  # computed but currently unused

        label = best_beams[idxs]

    return label.astype(int)


def steering_vec(array, phi=0, theta=0, kd=np.pi):
    # phi = azimuth, theta = elevation (radians)
    idxs = DeepMIMOv3.ant_indices(array)
    resp = DeepMIMOv3.array_response(idxs, phi, theta + np.pi/2, kd)
    return resp / np.linalg.norm(resp)


def evaluate(model, dataloader):
    """Run the masked-channel-modeling loss over a dataloader and collect embeddings."""
    model.eval()
    running_loss = 0.0
    outputs = []
    criterionMCM = nn.MSELoss()

    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch[0]
            masked_tokens = batch[1]
            masked_pos = batch[2]

            logits_lm, output = model(input_ids, masked_pos)
            outputs.append(output)

            loss_lm = criterionMCM(logits_lm, masked_tokens)
            loss = loss_lm / torch.var(masked_tokens)  # normalize by token variance
            running_loss += loss.item()

    average_loss = running_loss / len(dataloader)
    output_total = torch.cat(outputs, dim=0)

    return average_loss, output_total


def label_prepend(deepmimo_data, preprocessed_chs, task, scenario_idxs, n_beams=64):
    """Append the downstream-task label to each preprocessed sample."""
    labels = []
    for scenario_idx in scenario_idxs:
        scenario_name = scenarios_list()[scenario_idx]
        # data = DeepMIMO_data_gen(scenario_name)
        data = deepmimo_data[scenario_idx]
        labels.extend(label_gen(task, data, scenario_name, n_beams=n_beams))

    preprocessed_chs = [preprocessed_chs[i] + [labels[i]] for i in range(len(preprocessed_chs))]

    return preprocessed_chs
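
Taken together, these functions form the inference pipeline: preprocess the channels, optionally pass them through the pretrained LWM, and hand the result to a downstream task. Below is a minimal usage sketch (not part of the upload); it assumes the three uploaded files are on the Python path, the DeepMIMO scenario files are available under ./scenarios, and input_type is one of 'raw', 'cls_emb', or 'channel_emb':

    import torch
    from input_preprocess import tokenizer
    from lwm_model import LWM
    from inference import dataset_gen

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    preprocessed_chs = tokenizer([0], gen_raw=False)       # scenario 0 = 'city_18_denver'
    lwm_model = LWM.from_pretrained(device=device)         # fetches model_weights.pth from the Hub
    cls_embs = dataset_gen(preprocessed_chs, 'cls_emb', lwm_model)  # one 64-dim [CLS] embedding per user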
input_preprocess.py ADDED
@@ -0,0 +1,294 @@
# -*- coding: utf-8 -*-
"""
Created on Fri Sep 13 16:13:29 2024

This script generates preprocessed data from wireless communication scenarios,
including token generation, patch creation, and data sampling for machine learning models.

@author: salikha4
"""

import numpy as np
import os
from tqdm import tqdm
import time
import pickle
import DeepMIMOv3  # needed by DeepMIMO_data_gen() and get_parameters(); this import was missing in the upload

#%% Scenarios List
def scenarios_list():
    """Returns an array of available scenarios."""
    return np.array([
        'city_18_denver', 'city_15_indianapolis', 'city_19_oklahoma',
        'city_12_fortworth', 'city_11_santaclara', 'city_7_sandiego'
    ])

#%% Token Generation
def tokenizer(scenario_idxs, gen_raw=True):
    """
    Generates tokens by preparing and preprocessing the dataset.

    Args:
        scenario_idxs (list): Indices of the scenarios to load (see scenarios_list()).
        gen_raw (bool): Whether to generate raw (unmasked) tokens. Defaults to True.

    Returns:
        preprocessed_data (list): One [input_ids, masked_tokens, masked_pos] sample per user.
    """

    # Patch generation or loading
    scenario_names = scenarios_list()[scenario_idxs]  # was referenced but undefined in the upload
    deepmimo_data = [DeepMIMO_data_gen(scenario_name) for scenario_name in scenario_names]
    n_scenarios = len(scenario_names)

    patches = [patch_maker(deepmimo_data[scenario_idx]) for scenario_idx in range(n_scenarios)]
    patches = np.vstack(patches)

    # Define dimensions
    patch_size = patches.shape[2]
    n_patches = patches.shape[1]
    n_masks_half = int(0.15 * n_patches / 2)
    sequence_length = n_patches + 1
    element_length = patch_size

    word2id = {'[CLS]': 0.2 * np.ones((patch_size)), '[MASK]': 0.1 * np.ones((patch_size))}

    # Generate preprocessed channels
    preprocessed_data = []
    for user_idx in tqdm(range(len(patches)), desc="Processing items"):
        sample = make_sample(user_idx, patches, word2id, n_patches, n_masks_half, patch_size, gen_raw=gen_raw)
        preprocessed_data.append(sample)

    return preprocessed_data

#%% Patch Creation
def patch_maker(data, patch_size=16, norm_factor=1e6):
    """
    Creates patches from the dataset based on the scenario.

    Args:
        data (dict): DeepMIMO data for one scenario.
        patch_size (int): Size of each patch. Defaults to 16.
        norm_factor (int): Normalization factor for channels. Defaults to 1e6.

    Returns:
        patch (numpy array): Generated patches, shape (n_users, n_patches, patch_size).
    """
    idxs = np.where(data['user']['LoS'] != -1)[0]

    # Reshaping and normalizing channels; real and imaginary parts are concatenated
    original_ch = data['user']['channel'][idxs]
    flat_channels = original_ch.reshape((original_ch.shape[0], -1)).astype(np.csingle)
    flat_channels_complex = np.hstack((flat_channels.real, flat_channels.imag)) * norm_factor

    # Create patches
    n_patches = flat_channels_complex.shape[1] // patch_size
    patch = np.zeros((len(idxs), n_patches, patch_size))
    for idx in range(n_patches):
        patch[:, idx, :] = flat_channels_complex[:, idx * patch_size:(idx + 1) * patch_size]

    return patch


#%% Data Generation for Scenario Areas
def DeepMIMO_data_gen(scenario):
    """
    Generates or loads data for a given scenario.

    Args:
        scenario (str): Scenario name.

    Returns:
        data (dict): Loaded or generated data.
    """

    parameters, row_column_users, n_ant_bs, n_ant_ue, n_subcarriers = get_parameters(scenario)

    deepMIMO_dataset = DeepMIMOv3.generate_data(parameters)
    uniform_idxs = uniform_sampling(deepMIMO_dataset, [1, 1], len(parameters['user_rows']),
                                    users_per_row=row_column_users[scenario]['n_per_row'])
    data = select_by_idx(deepMIMO_dataset, uniform_idxs)[0]

    return data

#%%%
def get_parameters(scenario):

    n_ant_bs = 32
    n_ant_ue = 1
    n_subcarriers = 32
    scs = 30e3  # subcarrier spacing (Hz)

    row_column_users = {
        'city_18_denver': {'n_rows': 85, 'n_per_row': 82},
        'city_15_indianapolis': {'n_rows': 80, 'n_per_row': 79},
        'city_19_oklahoma': {'n_rows': 82, 'n_per_row': 75},
        'city_12_fortworth': {'n_rows': 86, 'n_per_row': 72},
        'city_11_santaclara': {'n_rows': 47, 'n_per_row': 114},
        'city_7_sandiego': {'n_rows': 71, 'n_per_row': 83}
    }

    parameters = DeepMIMOv3.default_params()
    parameters['dataset_folder'] = './scenarios'
    parameters['scenario'] = scenario

    if scenario == 'O1_3p5':
        parameters['active_BS'] = np.array([4])
    elif scenario in ['city_18_denver', 'city_15_indianapolis']:
        parameters['active_BS'] = np.array([3])
    else:
        parameters['active_BS'] = np.array([1])

    if scenario == 'Boston5G_3p5':  # this scenario expects 'n_rows' as a (start, stop) pair
        parameters['user_rows'] = np.arange(row_column_users[scenario]['n_rows'][0],
                                            row_column_users[scenario]['n_rows'][1])
    else:
        parameters['user_rows'] = np.arange(row_column_users[scenario]['n_rows'])
    parameters['bs_antenna']['shape'] = np.array([n_ant_bs, 1])  # horizontal, vertical
    parameters['bs_antenna']['rotation'] = np.array([0, 0, -135])  # (x, y, z)
    parameters['ue_antenna']['shape'] = np.array([n_ant_ue, 1])
    parameters['enable_BS2BS'] = False
    parameters['OFDM']['subcarriers'] = n_subcarriers
    parameters['OFDM']['selected_subcarriers'] = np.arange(n_subcarriers)

    parameters['OFDM']['bandwidth'] = scs * n_subcarriers / 1e9  # GHz
    parameters['num_paths'] = 20

    return parameters, row_column_users, n_ant_bs, n_ant_ue, n_subcarriers


#%% Sample Generation
def make_sample(user_idx, patch, word2id, n_patches, n_masks, patch_size, gen_raw=False):
    """
    Generates a sample for each user, including masking and tokenizing.

    Args:
        user_idx (int): Index of the user.
        patch (numpy array): Patches data.
        word2id (dict): Dictionary for special tokens.
        n_patches (int): Number of patches.
        n_masks (int): Number of masked positions per (real/imaginary) half.
        patch_size (int): Size of each patch.
        gen_raw (bool): If True, record masked positions but leave tokens unaltered.

    Returns:
        sample (list): [input_ids, masked_tokens, masked_pos] for the user.
    """

    tokens = patch[user_idx]
    input_ids = np.vstack((word2id['[CLS]'], tokens))

    # Mask mirrored positions in the real and imaginary halves of the sequence
    real_tokens_size = int(n_patches / 2)
    masks_pos_real = np.random.choice(range(0, real_tokens_size), size=n_masks, replace=False)
    masks_pos_imag = masks_pos_real + real_tokens_size
    masked_pos = np.hstack((masks_pos_real, masks_pos_imag)) + 1  # +1 skips the [CLS] token

    masked_tokens = []
    for pos in masked_pos:
        original_masked_tokens = input_ids[pos].copy()
        masked_tokens.append(original_masked_tokens)
        if not gen_raw:
            rnd_num = np.random.rand()
            if rnd_num < 0.1:
                input_ids[pos] = np.random.rand(patch_size)   # 10%: replace with a random token
            elif rnd_num < 0.9:
                input_ids[pos] = word2id['[MASK]']            # 80%: [MASK]; remaining 10%: unchanged

    return [input_ids, masked_tokens, masked_pos]


#%% Sampling and Data Selection
def uniform_sampling(dataset, sampling_div, n_rows, users_per_row):
    """
    Performs uniform sampling on the dataset.

    Args:
        dataset (dict): DeepMIMO dataset.
        sampling_div (list): Step sizes along [x, y] dimensions.
        n_rows (int): Number of rows for user selection.
        users_per_row (int): Number of users per row.

    Returns:
        uniform_idxs (numpy array): Indices of the selected samples.
    """
    cols = np.arange(users_per_row, step=sampling_div[0])
    rows = np.arange(n_rows, step=sampling_div[1])
    uniform_idxs = np.array([j + i * users_per_row for i in rows for j in cols])

    return uniform_idxs

def select_by_idx(dataset, idxs):
    """
    Selects a subset of the dataset based on the provided indices.

    Args:
        dataset (dict): Dataset to trim.
        idxs (numpy array): Indices of users to select.

    Returns:
        dataset_t (list): Trimmed dataset based on selected indices.
    """
    dataset_t = []  # trimmed dataset
    for bs_idx in range(len(dataset)):
        dataset_t.append({})
        dataset_t[bs_idx]['location'] = dataset[bs_idx]['location']
        dataset_t[bs_idx]['user'] = {k: dataset[bs_idx]['user'][k][idxs] for k in dataset[bs_idx]['user']}

    return dataset_t

#%% Save and Load Utilities
def save_var(var, path):
    """
    Saves a variable to a pickle file.

    Args:
        var (object): Variable to be saved.
        path (str): Path to save the file.

    Returns:
        None
    """
    path_full = path if path.endswith('.p') else (path + '.pickle')
    with open(path_full, 'wb') as handle:
        pickle.dump(var, handle)

def load_var(path):
    """
    Loads a variable from a pickle file.

    Args:
        path (str): Path of the file to load.

    Returns:
        var (object): Loaded variable.
    """
    path_full = path if path.endswith('.p') else (path + '.pickle')
    with open(path_full, 'rb') as handle:
        var = pickle.load(handle)

    return var

#%%
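
With the default 32-antenna, 32-subcarrier setup from get_parameters, each user's channel flattens to 2 × 32 × 32 = 2048 real values, which patch_maker splits into 128 patches of 16 values. This short check (an illustrative sketch, not part of the upload) reproduces the numbers that lwm_model.py hard-codes as ELEMENT_LENGTH = 16 and MAX_LEN = 129:

    # Sanity check of the token geometry implied above
    n_ant_bs, n_subcarriers, patch_size = 32, 32, 16
    n_reals = 2 * n_ant_bs * n_subcarriers        # real + imaginary parts -> 2048
    n_patches = n_reals // patch_size             # 128 patches per user
    sequence_length = n_patches + 1               # +1 for [CLS] -> 129 (= MAX_LEN)
    n_masks_half = int(0.15 * n_patches / 2)      # 9 masked positions per half
    print(n_patches, sequence_length, 2 * n_masks_half)  # 128 129 18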
lwm_model.py ADDED
@@ -0,0 +1,154 @@
# -*- coding: utf-8 -*-
"""
Created on Sun Sep 15 19:55:23 2024

@author: salikha4
"""

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from inference import *
from load_data import load_DeepMIMO_data
from input_preprocess import *
from huggingface_hub import hf_hub_download


ELEMENT_LENGTH = 16
D_MODEL = 64
MAX_LEN = 129
N_LAYERS = 12
N_HEADS = 12
D_FF = D_MODEL * 4
D_K = D_MODEL // N_HEADS
D_V = D_MODEL // N_HEADS
DROPOUT = 0.1

class LayerNormalization(nn.Module):
    def __init__(self, d_model: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.eps = eps
        self.alpha = nn.Parameter(torch.ones(d_model))
        self.bias = nn.Parameter(torch.zeros(d_model))

    def forward(self, x):
        mean = x.mean(dim=-1, keepdim=True)
        std = x.std(dim=-1, keepdim=True)
        return self.alpha * (x - mean) / (std + self.eps) + self.bias

class Embedding(nn.Module):
    def __init__(self, element_length, d_model, max_len):
        super().__init__()
        self.element_length = element_length
        self.d_model = d_model
        self.proj = nn.Linear(element_length, d_model)
        self.pos_embed = nn.Embedding(max_len, d_model)
        self.norm = LayerNormalization(d_model)

    def forward(self, x):
        seq_len = x.size(1)
        pos = torch.arange(seq_len, dtype=torch.long, device=x.device)
        pos = pos.unsqueeze(0).expand_as(x[:, :, 0])
        tok_emb = self.proj(x.float())
        embedding = tok_emb + self.pos_embed(pos)
        return self.norm(embedding)

class ScaledDotProductAttention(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, Q, K, V):
        scores = torch.matmul(Q, K.transpose(-1, -2)) / np.sqrt(D_K)
        attn = F.softmax(scores, dim=-1)
        context = torch.matmul(attn, V)
        return context, attn

class MultiHeadAttention(nn.Module):
    def __init__(self):
        super().__init__()
        self.W_Q = nn.Linear(D_MODEL, D_K * N_HEADS)
        self.W_K = nn.Linear(D_MODEL, D_K * N_HEADS)
        self.W_V = nn.Linear(D_MODEL, D_V * N_HEADS)
        self.linear = nn.Linear(N_HEADS * D_V, D_MODEL)
        self.norm = LayerNormalization(D_MODEL)
        self.dropout = nn.Dropout(DROPOUT)

    def forward(self, Q, K, V):
        residual, batch_size = Q, Q.size(0)
        q_s = self.W_Q(Q).view(batch_size, -1, N_HEADS, D_K).transpose(1, 2)
        k_s = self.W_K(K).view(batch_size, -1, N_HEADS, D_K).transpose(1, 2)
        v_s = self.W_V(V).view(batch_size, -1, N_HEADS, D_V).transpose(1, 2)

        context, attn = ScaledDotProductAttention()(q_s, k_s, v_s)
        output = context.transpose(1, 2).contiguous().view(batch_size, -1, N_HEADS * D_V)
        output = self.linear(output)
        return residual + self.dropout(output), attn

class PoswiseFeedForwardNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(D_MODEL, D_FF)
        self.fc2 = nn.Linear(D_FF, D_MODEL)
        self.dropout = nn.Dropout(DROPOUT)
        self.norm = LayerNormalization(D_MODEL)

    def forward(self, x):
        output = self.fc2(self.dropout(F.relu(self.fc1(x))))
        return x + self.dropout(output)

class EncoderLayer(nn.Module):
    def __init__(self):
        super().__init__()
        self.enc_self_attn = MultiHeadAttention()
        self.pos_ffn = PoswiseFeedForwardNet()
        self.norm = LayerNormalization(D_MODEL)

    def forward(self, enc_inputs):
        attn_outputs, attn = self.enc_self_attn(enc_inputs, enc_inputs, enc_inputs)
        attn_outputs = self.norm(attn_outputs)
        enc_outputs = self.pos_ffn(attn_outputs)
        return enc_outputs, attn

class LWM(torch.nn.Module):
    def __init__(self, element_length=16, d_model=64, max_len=129, n_layers=12):
        super().__init__()
        self.embedding = Embedding(element_length, d_model, max_len)
        self.layers = nn.ModuleList([EncoderLayer() for _ in range(n_layers)])
        self.linear = nn.Linear(d_model, d_model)
        self.norm = LayerNormalization(d_model)

        # The decoder ties its weights to the transposed input projection,
        # mapping d_model-dim hidden states back to element_length-dim patches
        embed_weight = self.embedding.proj.weight
        d_model, n_dim = embed_weight.size()
        self.decoder = nn.Linear(d_model, n_dim, bias=False)
        self.decoder.weight = nn.Parameter(embed_weight.transpose(0, 1))
        self.decoder_bias = nn.Parameter(torch.zeros(n_dim))

    @classmethod
    def from_pretrained(cls, ckpt_name='model_weights.pth', device='cuda', use_auth_token=None):
        # Define model
        model = cls().to(device)

        # Download model weights using Hugging Face Hub
        ckpt_path = hf_hub_download(repo_id="sadjadalikhani/LWM", filename=ckpt_name, use_auth_token=use_auth_token)

        # Load the model weights
        model.load_state_dict(torch.load(ckpt_path, map_location=device))
        print(f"Model loaded successfully from {ckpt_path} to {device}")

        return model

    def forward(self, input_ids, masked_pos):
        # Encode the full sequence
        output = self.embedding(input_ids)
        for layer in self.layers:
            output, _ = layer(output)

        # Gather the hidden states at the masked positions and decode them back to patches
        masked_pos = masked_pos.long()[:, :, None].expand(-1, -1, output.size(-1))
        h_masked = torch.gather(output, 1, masked_pos)
        h_masked = self.norm(F.relu(self.linear(h_masked)))
        logits_lm = self.decoder(h_masked) + self.decoder_bias

        return logits_lm, output
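
A quick shape walkthrough of one forward pass (a sketch with random weights and inputs; note that lwm_model imports load_data, which is not among the three uploaded files, so this assumes that module is also available):

    import torch
    from lwm_model import LWM

    model = LWM()                                   # 12 encoder layers, d_model = 64
    input_ids = torch.randn(2, 129, 16)             # (batch, [CLS] + 128 patches, patch_size)
    masked_pos = torch.randint(1, 129, (2, 18))     # 18 masked positions per sample
    logits_lm, output = model(input_ids, masked_pos)
    print(logits_lm.shape)                          # torch.Size([2, 18, 16])  -- reconstructed patches
    print(output.shape)                             # torch.Size([2, 129, 64]) -- per-token embeddings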