Sadjad Alikhani commited on
Commit
dd4577b
·
verified ·
1 Parent(s): 89297d3

Upload 5 files

Browse files
Files changed (5) hide show
  1. inference.py +171 -0
  2. input_preprocess.py +296 -0
  3. load_data.py +39 -0
  4. lwm_model.py +173 -0
  5. model.py +160 -0
inference.py ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ Created on Sun Sep 15 18:27:17 2024
4
+
5
+ @author: salikha4
6
+ """
7
+
8
+ import os
9
+ import csv
10
+ import json
11
+ import shutil
12
+ import random
13
+ import argparse
14
+ from datetime import datetime
15
+ import pandas as pd
16
+ import time
17
+ import torch
18
+ import torch.nn as nn
19
+ import torch.nn.functional as F
20
+ from torch.utils.data import Dataset, DataLoader, TensorDataset
21
+ from torch.optim import Adam
22
+ import numpy as np
23
+ from lwm_model import LWM, load_model
24
+ import warnings
25
+ warnings.filterwarnings('ignore')
26
+ from input_preprocess import *
27
+
28
+ # Device configuration
29
+ device_idx_ds = 3
30
+ device = torch.device(f'cuda:{device_idx_ds}' if torch.cuda.is_available() else "cpu")
31
+ if torch.cuda.is_available():
32
+ torch.cuda.empty_cache()
33
+
34
+ # Folders
35
+ # MODELS_FOLDER = 'models/'
36
+
37
+ def dataset_gen(preprocessed_chs, input_type, scenario_idxs, lwm_model):
38
+
39
+ if input_type in ['cls_emb', 'channel_emb']:
40
+ dataset = prepare_for_LWM(preprocessed_chs, device)
41
+ elif input_type == 'raw':
42
+ dataset = create_raw_dataset(preprocessed_chs, device)
43
+
44
+ if input_type in ['cls_emb','channel_emb']:
45
+ # model = LWM().to(device)
46
+ # ckpt_name = 'model_weights.pth'
47
+ # ckpt_path = os.path.join(MODELS_FOLDER, ckpt_name)
48
+ # lwm_model = load_model(model, ckpt_path, device)
49
+ # print(f"Model loaded successfully on {device}")
50
+
51
+ # Process data through LWM
52
+ lwm_loss, embedding_data = evaluate(lwm_model, dataset)
53
+
54
+ print(f'LWM loss: {lwm_loss:.4f}')
55
+
56
+ if input_type == 'cls_emb':
57
+ embedding_data = embedding_data[:, 0]
58
+ elif input_type == 'channel_emb':
59
+ embedding_data = embedding_data[:, 1:]
60
+
61
+ dataset = embedding_data.float()
62
+
63
+ return dataset
64
+
65
+
66
+ def prepare_for_LWM(data, device, batch_size=64, shuffle=False):
67
+
68
+ input_ids, masked_tokens, masked_pos = zip(*data)
69
+
70
+ input_ids_tensor = torch.tensor(input_ids, device=device).float()
71
+ masked_tokens_tensor = torch.tensor(masked_tokens, device=device).float()
72
+ masked_pos_tensor = torch.tensor(masked_pos, device=device).long()
73
+
74
+ dataset = TensorDataset(input_ids_tensor, masked_tokens_tensor, masked_pos_tensor)
75
+
76
+ return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
77
+
78
+
79
+ def create_raw_dataset(data, device):
80
+ """Create a dataset for raw channel data."""
81
+ input_ids, _, _ = zip(*data)
82
+ input_data = torch.tensor(input_ids, device=device)[:, 1:]
83
+ return input_data.float()
84
+
85
+
86
+ def label_gen(task, data, scenario, n_beams=64):
87
+
88
+ idxs = np.where(data['user']['LoS'] != -1)[0]
89
+
90
+ if task == 'LoS/NLoS Classification':
91
+ label = data['user']['LoS'][idxs]
92
+ elif task == 'Beam Prediction':
93
+ parameters, row_column_users, n_ant_bs, n_ant_ue, n_subcarriers = get_parameters(scenario)
94
+ n_users = len(data['user']['channel'])
95
+ n_subbands = 1
96
+ fov = 120
97
+
98
+ # Setup Beamformers
99
+ beam_angles = np.around(np.arange(-fov/2, fov/2+.1, fov/(n_beams-1)), 2)
100
+
101
+ F1 = np.array([steering_vec(parameters['bs_antenna']['shape'],
102
+ phi=azi*np.pi/180,
103
+ kd=2*np.pi*parameters['bs_antenna']['spacing']).squeeze()
104
+ for azi in beam_angles])
105
+
106
+ full_dbm = np.zeros((n_beams, n_subbands, n_users), dtype=float)
107
+ for ue_idx in tqdm(range(n_users), desc='Computing the channel for each user'):
108
+ if data['user']['LoS'][ue_idx] == -1:
109
+ full_dbm[:,:,ue_idx] = np.nan
110
+ else:
111
+ chs = F1 @ data['user']['channel'][ue_idx]
112
+ full_linear = np.abs(np.mean(chs.squeeze().reshape((n_beams, n_subbands, -1)), axis=-1))
113
+ full_dbm[:,:,ue_idx] = np.around(20*np.log10(full_linear) + 30, 1)
114
+
115
+ best_beams = np.argmax(np.mean(full_dbm,axis=1), axis=0)
116
+ best_beams = best_beams.astype(float)
117
+ best_beams[np.isnan(full_dbm[0,0,:])] = np.nan
118
+ max_bf_pwr = np.max(np.mean(full_dbm,axis=1), axis=0)
119
+
120
+ label = best_beams[idxs]
121
+
122
+ return label.astype(int)
123
+
124
+
125
+ def steering_vec(array, phi=0, theta=0, kd=np.pi):
126
+ # phi = azimuth
127
+ # theta = elevation
128
+ idxs = DeepMIMOv3.ant_indices(array)
129
+ resp = DeepMIMOv3.array_response(idxs, phi, theta+np.pi/2, kd)
130
+ return resp / np.linalg.norm(resp)
131
+
132
+
133
+ def evaluate(model, dataloader):
134
+
135
+ model.eval()
136
+ running_loss = 0.0
137
+ outputs = []
138
+ criterionMCM = nn.MSELoss()
139
+
140
+ with torch.no_grad():
141
+ for batch in dataloader:
142
+ input_ids = batch[0]
143
+ masked_tokens = batch[1]
144
+ masked_pos = batch[2]
145
+
146
+ logits_lm, output = model(input_ids, masked_pos)
147
+
148
+ output_batch_preproc = output
149
+ outputs.append(output_batch_preproc)
150
+
151
+ loss_lm = criterionMCM(logits_lm, masked_tokens)
152
+ loss = loss_lm/torch.var(masked_tokens)
153
+ running_loss += loss.item()
154
+
155
+ average_loss = running_loss / len(dataloader)
156
+ output_total = torch.cat(outputs, dim=0)
157
+
158
+ return average_loss, output_total
159
+
160
+
161
+ def label_prepend(deepmimo_data, preprocessed_chs, task, scenario_idxs, n_beams=64):
162
+ labels = []
163
+ for scenario_idx in scenario_idxs:
164
+ scenario_name = scenarios_list()[scenario_idx]
165
+ # data = DeepMIMO_data_gen(scenario_name)
166
+ data = deepmimo_data[scenario_idx]
167
+ labels.extend(label_gen(task, data, scenario_name, n_beams=n_beams))
168
+
169
+ preprocessed_chs = [preprocessed_chs[i] + [labels[i]] for i in range(len(preprocessed_chs))]
170
+
171
+ return preprocessed_chs
input_preprocess.py ADDED
@@ -0,0 +1,296 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ Created on Fri Sep 13 16:13:29 2024
4
+
5
+ This script generates preprocessed data from wireless communication scenarios,
6
+ including token generation, patch creation, and data sampling for machine learning models.
7
+
8
+ @author: salikha4
9
+ """
10
+
11
+ import numpy as np
12
+ import os
13
+ from tqdm import tqdm
14
+ import time
15
+ import pickle
16
+ import DeepMIMOv3
17
+
18
+ vars_folder = 'variables/'
19
+ os.makedirs(vars_folder, exist_ok=True)
20
+
21
+ #%% Scenarios List
22
+ def scenarios_list():
23
+ """Returns an array of available scenarios."""
24
+ return np.array([
25
+ 'city_18_denver', 'city_15_indianapolis', 'city_19_oklahoma',
26
+ 'city_12_fortworth', 'city_11_santaclara', 'city_7_sandiego'
27
+ ])
28
+
29
+ #%% Token Generation
30
+ def tokenizer(deepmimo_data, gen_raw=True):
31
+ """
32
+ Generates tokens by preparing and preprocessing the dataset.
33
+
34
+ Args:
35
+ scenario_idxs (list): Indices of the scenarios.
36
+ patch_gen (bool): Whether to generate patches. Defaults to True.
37
+ patch_size (int): Size of each patch. Defaults to 16.
38
+ gen_deepMIMO_data (bool): Whether to generate DeepMIMO data. Defaults to False.
39
+ gen_raw (bool): Whether to generate raw data. Defaults to False.
40
+ save_data (bool): Whether to save the preprocessed data. Defaults to False.
41
+
42
+ Returns:
43
+ preprocessed_data, sequence_length, element_length: Preprocessed data and related dimensions.
44
+ """
45
+
46
+ # Patch generation or loading
47
+ n_scenarios = len(deepmimo_data)
48
+ patches = [patch_maker(deepmimo_data[scenario_idx]) for scenario_idx in range(n_scenarios)]
49
+ patches = np.vstack(patches)
50
+
51
+ # Define dimensions
52
+ patch_size = patches.shape[2]
53
+ n_patches = patches.shape[1]
54
+ n_masks_half = int(0.15 * n_patches / 2)
55
+ sequence_length = n_patches + 1
56
+ element_length = patch_size
57
+
58
+ word2id = {'[CLS]': 0.2 * np.ones((patch_size)), '[MASK]': 0.1 * np.ones((patch_size))}
59
+
60
+ # Generate preprocessed channels
61
+ preprocessed_data = []
62
+ for user_idx in tqdm(range(len(patches)), desc="Processing items"):
63
+ sample = make_sample(user_idx, patches, word2id, n_patches, n_masks_half, patch_size, gen_raw=gen_raw)
64
+ preprocessed_data.append(sample)
65
+
66
+ return preprocessed_data
67
+
68
+ #%% Patch Creation
69
+ def patch_maker(data, patch_size=16, norm_factor=1e6):
70
+ """
71
+ Creates patches from the dataset based on the scenario.
72
+
73
+ Args:-
74
+ patch_size (int): Size of each patch.
75
+ scenario (str): Selected scenario for data generation.
76
+ gen_deepMIMO_data (bool): Whether to generate DeepMIMO data.
77
+ norm_factor (int): Normalization factor for channels.
78
+
79
+ Returns:
80
+ patch (numpy array): Generated patches.
81
+ """
82
+ idxs = np.where(data['user']['LoS'] != -1)[0]
83
+
84
+ # Reshaping and normalizing channels
85
+ original_ch = data['user']['channel'][idxs]
86
+ flat_channels = original_ch.reshape((original_ch.shape[0], -1)).astype(np.csingle)
87
+ flat_channels_complex = np.hstack((flat_channels.real, flat_channels.imag)) * norm_factor
88
+
89
+ # Create patches
90
+ n_patches = flat_channels_complex.shape[1] // patch_size
91
+ patch = np.zeros((len(idxs), n_patches, patch_size))
92
+ for idx in range(n_patches):
93
+ patch[:, idx, :] = flat_channels_complex[:, idx * patch_size:(idx + 1) * patch_size]
94
+
95
+ return patch
96
+
97
+
98
+ #%% Data Generation for Scenario Areas
99
+ def DeepMIMO_data_gen(scenario):
100
+ """
101
+ Generates or loads data for a given scenario.
102
+
103
+ Args:
104
+ scenario (str): Scenario name.
105
+ gen_deepMIMO_data (bool): Whether to generate DeepMIMO data.
106
+ save_data (bool): Whether to save generated data.
107
+
108
+ Returns:
109
+ data (dict): Loaded or generated data.
110
+ """
111
+
112
+ parameters, row_column_users, n_ant_bs, n_ant_ue, n_subcarriers = get_parameters(scenario)
113
+
114
+ deepMIMO_dataset = DeepMIMOv3.generate_data(parameters)
115
+ uniform_idxs = uniform_sampling(deepMIMO_dataset, [1, 1], len(parameters['user_rows']),
116
+ users_per_row=row_column_users[scenario]['n_per_row'])
117
+ data = select_by_idx(deepMIMO_dataset, uniform_idxs)[0]
118
+
119
+ return data
120
+
121
+ #%%%
122
+ def get_parameters(scenario):
123
+
124
+ n_ant_bs = 32 #32
125
+ n_ant_ue = 1
126
+ n_subcarriers = 32 #32
127
+ scs = 30e3
128
+
129
+ row_column_users = {
130
+ 'city_18_denver': {
131
+ 'n_rows': 85,
132
+ 'n_per_row': 82
133
+ },
134
+ 'city_15_indianapolis': {
135
+ 'n_rows': 80,
136
+ 'n_per_row': 79
137
+ },
138
+ 'city_19_oklahoma': {
139
+ 'n_rows': 82,
140
+ 'n_per_row': 75
141
+ },
142
+ 'city_12_fortworth': {
143
+ 'n_rows': 86,
144
+ 'n_per_row': 72
145
+ },
146
+ 'city_11_santaclara': {
147
+ 'n_rows': 47,
148
+ 'n_per_row': 114
149
+ },
150
+ 'city_7_sandiego': {
151
+ 'n_rows': 71,
152
+ 'n_per_row': 83
153
+ }}
154
+
155
+ parameters = DeepMIMOv3.default_params()
156
+ parameters['dataset_folder'] = './scenarios'
157
+ parameters['scenario'] = scenario
158
+
159
+ if scenario == 'O1_3p5':
160
+ parameters['active_BS'] = np.array([4])
161
+ elif scenario in ['city_18_denver', 'city_15_indianapolis']:
162
+ parameters['active_BS'] = np.array([3])
163
+ else:
164
+ parameters['active_BS'] = np.array([1])
165
+
166
+ if scenario == 'Boston5G_3p5':
167
+ parameters['user_rows'] = np.arange(row_column_users[scenario]['n_rows'][0],
168
+ row_column_users[scenario]['n_rows'][1])
169
+ else:
170
+ parameters['user_rows'] = np.arange(row_column_users[scenario]['n_rows'])
171
+ parameters['bs_antenna']['shape'] = np.array([n_ant_bs, 1]) # Horizontal, Vertical
172
+ parameters['bs_antenna']['rotation'] = np.array([0,0,-135]) # (x,y,z)
173
+ parameters['ue_antenna']['shape'] = np.array([n_ant_ue, 1])
174
+ parameters['enable_BS2BS'] = False
175
+ parameters['OFDM']['subcarriers'] = n_subcarriers
176
+ parameters['OFDM']['selected_subcarriers'] = np.arange(n_subcarriers)
177
+
178
+ parameters['OFDM']['bandwidth'] = scs * n_subcarriers / 1e9
179
+ parameters['num_paths'] = 20
180
+
181
+ return parameters, row_column_users, n_ant_bs, n_ant_ue, n_subcarriers
182
+
183
+
184
+ #%% Sample Generation
185
+ def make_sample(user_idx, patch, word2id, n_patches, n_masks, patch_size, gen_raw=False):
186
+ """
187
+ Generates a sample for each user, including masking and tokenizing.
188
+
189
+ Args:
190
+ user_idx (int): Index of the user.
191
+ patch (numpy array): Patches data.
192
+ word2id (dict): Dictionary for special tokens.
193
+ n_patches (int): Number of patches.
194
+ n_masks (int): Number of masks.
195
+ patch_size (int): Size of each patch.
196
+ gen_raw (bool): Whether to generate raw tokens.
197
+
198
+ Returns:
199
+ sample (list): Generated sample for the user.
200
+ """
201
+
202
+ tokens = patch[user_idx]
203
+ input_ids = np.vstack((word2id['[CLS]'], tokens))
204
+
205
+ real_tokens_size = int(n_patches / 2)
206
+ masks_pos_real = np.random.choice(range(0, real_tokens_size), size=n_masks, replace=False)
207
+ masks_pos_imag = masks_pos_real + real_tokens_size
208
+ masked_pos = np.hstack((masks_pos_real, masks_pos_imag)) + 1
209
+
210
+ masked_tokens = []
211
+ for pos in masked_pos:
212
+ original_masked_tokens = input_ids[pos].copy()
213
+ masked_tokens.append(original_masked_tokens)
214
+ if not gen_raw:
215
+ rnd_num = np.random.rand()
216
+ if rnd_num < 0.1:
217
+ input_ids[pos] = np.random.rand(patch_size)
218
+ elif rnd_num < 0.9:
219
+ input_ids[pos] = word2id['[MASK]']
220
+
221
+ return [input_ids, masked_tokens, masked_pos]
222
+
223
+
224
+ #%% Sampling and Data Selection
225
+ def uniform_sampling(dataset, sampling_div, n_rows, users_per_row):
226
+ """
227
+ Performs uniform sampling on the dataset.
228
+
229
+ Args:
230
+ dataset (dict): DeepMIMO dataset.
231
+ sampling_div (list): Step sizes along [x, y] dimensions.
232
+ n_rows (int): Number of rows for user selection.
233
+ users_per_row (int): Number of users per row.
234
+
235
+ Returns:
236
+ uniform_idxs (numpy array): Indices of the selected samples.
237
+ """
238
+ cols = np.arange(users_per_row, step=sampling_div[0])
239
+ rows = np.arange(n_rows, step=sampling_div[1])
240
+ uniform_idxs = np.array([j + i * users_per_row for i in rows for j in cols])
241
+
242
+ return uniform_idxs
243
+
244
+ def select_by_idx(dataset, idxs):
245
+ """
246
+ Selects a subset of the dataset based on the provided indices.
247
+
248
+ Args:
249
+ dataset (dict): Dataset to trim.
250
+ idxs (numpy array): Indices of users to select.
251
+
252
+ Returns:
253
+ dataset_t (list): Trimmed dataset based on selected indices.
254
+ """
255
+ dataset_t = [] # Trimmed dataset
256
+ for bs_idx in range(len(dataset)):
257
+ dataset_t.append({})
258
+ for key in dataset[bs_idx].keys():
259
+ dataset_t[bs_idx]['location'] = dataset[bs_idx]['location']
260
+ dataset_t[bs_idx]['user'] = {k: dataset[bs_idx]['user'][k][idxs] for k in dataset[bs_idx]['user']}
261
+
262
+ return dataset_t
263
+
264
+ #%% Save and Load Utilities
265
+ def save_var(var, path):
266
+ """
267
+ Saves a variable to a pickle file.
268
+
269
+ Args:
270
+ var (object): Variable to be saved.
271
+ path (str): Path to save the file.
272
+
273
+ Returns:
274
+ None
275
+ """
276
+ path_full = path if path.endswith('.p') else (path + '.pickle')
277
+ with open(path_full, 'wb') as handle:
278
+ pickle.dump(var, handle)
279
+
280
+ def load_var(path):
281
+ """
282
+ Loads a variable from a pickle file.
283
+
284
+ Args:
285
+ path (str): Path of the file to load.
286
+
287
+ Returns:
288
+ var (object): Loaded variable.
289
+ """
290
+ path_full = path if path.endswith('.p') else (path + '.pickle')
291
+ with open(path_full, 'rb') as handle:
292
+ var = pickle.load(handle)
293
+
294
+ return var
295
+
296
+ #%%
load_data.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ Created on Sun Sep 15 17:19:21 2024
4
+
5
+ @author: salikha4
6
+ """
7
+
8
+ import requests
9
+ import zipfile
10
+ import os
11
+ import pickle
12
+
13
+ def load_DeepMIMO_data(zip_file_name='dataset.zip', p_file_name='deepmimo_data.p', extract_path='./data'):
14
+ url = "https://huggingface.co/datasets/sadjadalikhani/lwm/resolve/main/dataset.zip"
15
+ # Step 1: Download the ZIP file from Hugging Face
16
+ print(f"Downloading ZIP file from {url}...")
17
+ response = requests.get(url)
18
+ with open(zip_file_name, 'wb') as f:
19
+ f.write(response.content)
20
+ print(f"Downloaded ZIP file to {zip_file_name}.")
21
+
22
+ # Step 2: Unzip the file
23
+ print(f"Extracting ZIP file to {extract_path}...")
24
+ with zipfile.ZipFile(zip_file_name, 'r') as zip_ref:
25
+ zip_ref.extractall(extract_path)
26
+ print(f"Extracted ZIP file contents to {extract_path}.")
27
+
28
+ # Step 3: Load the .p file
29
+ p_file_path = os.path.join(extract_path, p_file_name)
30
+ print(f"Loading .p file from {p_file_path}...")
31
+ with open(p_file_path, 'rb') as f:
32
+ data = pickle.load(f)
33
+
34
+ print("Data successfully loaded from the .p file.")
35
+ return data
36
+
37
+
38
+
39
+
lwm_model.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ LWM (Large Wireless Model) Implementation and Loading
3
+
4
+ @author: salikha4
5
+
6
+ This module defines a Large Wireless Model (LWM) using PyTorch, including custom layers
7
+ for embedding, self-attention, and feed-forward networks. It also provides functionality
8
+ to load a pre-trained model from a checkpoint.
9
+
10
+ Dependencies:
11
+ - torch
12
+ - numpy
13
+ """
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+ import torch.nn.functional as F
18
+ import numpy as np
19
+
20
+ ELEMENT_LENGTH = 16
21
+ D_MODEL = 64
22
+ MAX_LEN = 129
23
+ N_LAYERS = 12
24
+ N_HEADS = 12
25
+ D_FF = D_MODEL * 4
26
+ D_K = D_MODEL // N_HEADS
27
+ D_V = D_MODEL // N_HEADS
28
+ DROPOUT = 0.1
29
+
30
+ class LayerNormalization(nn.Module):
31
+ def __init__(self, d_model: int, eps: float = 1e-6) -> None:
32
+ super().__init__()
33
+ self.eps = eps
34
+ self.alpha = nn.Parameter(torch.ones(d_model))
35
+ self.bias = nn.Parameter(torch.zeros(d_model))
36
+
37
+ def forward(self, x):
38
+ mean = x.mean(dim=-1, keepdim=True)
39
+ std = x.std(dim=-1, keepdim=True)
40
+ return self.alpha * (x - mean) / (std + self.eps) + self.bias
41
+
42
+ class Embedding(nn.Module):
43
+ def __init__(self, element_length, d_model, max_len):
44
+ super().__init__()
45
+ self.element_length = element_length
46
+ self.d_model = d_model
47
+ self.proj = nn.Linear(element_length, d_model)
48
+ self.pos_embed = nn.Embedding(max_len, d_model)
49
+ self.norm = LayerNormalization(d_model)
50
+
51
+ def forward(self, x):
52
+ seq_len = x.size(1)
53
+ pos = torch.arange(seq_len, dtype=torch.long, device=x.device)
54
+ pos = pos.unsqueeze(0).expand_as(x[:, :, 0])
55
+ tok_emb = self.proj(x.float())
56
+ embedding = tok_emb + self.pos_embed(pos)
57
+ return self.norm(embedding)
58
+
59
+ class ScaledDotProductAttention(nn.Module):
60
+ def __init__(self):
61
+ super().__init__()
62
+
63
+ def forward(self, Q, K, V):
64
+ scores = torch.matmul(Q, K.transpose(-1, -2)) / np.sqrt(D_K)
65
+ attn = F.softmax(scores, dim=-1)
66
+ context = torch.matmul(attn, V)
67
+ return context, attn
68
+
69
+ class MultiHeadAttention(nn.Module):
70
+ def __init__(self):
71
+ super().__init__()
72
+ self.W_Q = nn.Linear(D_MODEL, D_K * N_HEADS)
73
+ self.W_K = nn.Linear(D_MODEL, D_K * N_HEADS)
74
+ self.W_V = nn.Linear(D_MODEL, D_V * N_HEADS)
75
+ self.linear = nn.Linear(N_HEADS * D_V, D_MODEL)
76
+ self.norm = LayerNormalization(D_MODEL)
77
+ self.dropout = nn.Dropout(DROPOUT)
78
+
79
+ def forward(self, Q, K, V):
80
+ residual, batch_size = Q, Q.size(0)
81
+ q_s = self.W_Q(Q).view(batch_size, -1, N_HEADS, D_K).transpose(1, 2)
82
+ k_s = self.W_K(K).view(batch_size, -1, N_HEADS, D_K).transpose(1, 2)
83
+ v_s = self.W_V(V).view(batch_size, -1, N_HEADS, D_V).transpose(1, 2)
84
+
85
+ context, attn = ScaledDotProductAttention()(q_s, k_s, v_s)
86
+ output = context.transpose(1, 2).contiguous().view(batch_size, -1, N_HEADS * D_V)
87
+ output = self.linear(output)
88
+ return residual + self.dropout(output), attn #residual + self.dropout(output), attn
89
+
90
+ class PoswiseFeedForwardNet(nn.Module):
91
+ def __init__(self):
92
+ super().__init__()
93
+ self.fc1 = nn.Linear(D_MODEL, D_FF)
94
+ self.fc2 = nn.Linear(D_FF, D_MODEL)
95
+ self.dropout = nn.Dropout(DROPOUT)
96
+ self.norm = LayerNormalization(D_MODEL)
97
+
98
+ def forward(self, x):
99
+ output = self.fc2(self.dropout(F.relu(self.fc1(x))))
100
+ return x + self.dropout(output) #x + self.dropout(output)
101
+
102
+ class EncoderLayer(nn.Module):
103
+ def __init__(self):
104
+ super().__init__()
105
+ self.enc_self_attn = MultiHeadAttention()
106
+ self.pos_ffn = PoswiseFeedForwardNet()
107
+ self.norm = LayerNormalization(D_MODEL)
108
+
109
+ def forward(self, enc_inputs):
110
+ attn_outputs, attn = self.enc_self_attn(enc_inputs, enc_inputs, enc_inputs)
111
+ attn_outputs = self.norm(attn_outputs)
112
+ enc_outputs = self.pos_ffn(attn_outputs)
113
+ return enc_outputs, attn
114
+
115
+ class LWM(nn.Module):
116
+ def __init__(self, element_length=16, d_model=64, max_len=129, n_layers=12):
117
+ super().__init__()
118
+
119
+ self.embedding = Embedding(element_length, d_model, max_len)
120
+ self.layers = nn.ModuleList([EncoderLayer() for _ in range(n_layers)])
121
+ self.linear = nn.Linear(d_model, d_model)
122
+ self.norm = LayerNormalization(d_model)
123
+
124
+ embed_weight = self.embedding.proj.weight
125
+ d_model, n_dim = embed_weight.size()
126
+ self.decoder = nn.Linear(d_model, n_dim, bias=False)
127
+ self.decoder.weight = nn.Parameter(embed_weight.transpose(0, 1))
128
+ self.decoder_bias = nn.Parameter(torch.zeros(n_dim))
129
+
130
+ def forward(self, input_ids, masked_pos):
131
+ output = self.embedding(input_ids)
132
+
133
+ for layer in self.layers:
134
+ output, _ = layer(output)
135
+
136
+ masked_pos = masked_pos.long()[:, :, None].expand(-1, -1, output.size(-1))
137
+ h_masked = torch.gather(output, 1, masked_pos)
138
+ h_masked = self.norm(F.relu(self.linear(h_masked)))
139
+ logits_lm = self.decoder(h_masked) + self.decoder_bias
140
+
141
+ return logits_lm, output
142
+
143
+ def load_model(model, model_path, device=None):
144
+ """
145
+ Load a pre-trained LWM model from a checkpoint.
146
+
147
+ Args:
148
+ model_path (str): Path to the checkpoint file.
149
+ device (torch.device, optional): Device to load the model onto.
150
+
151
+ Returns:
152
+ LWM: Loaded model instance.
153
+ """
154
+ if device is None:
155
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
156
+
157
+ #model = LWM(ELEMENT_LENGTH, D_MODEL, MAX_LEN, N_LAYERS)
158
+ state_dict = torch.load(model_path, map_location=device)
159
+ model.load_state_dict(state_dict)
160
+ model.to(device)
161
+ return model
162
+
163
+ # Usage example
164
+ if __name__ == "__main__":
165
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
166
+ model_name = 'model_weights.pth'
167
+ model_path = f'models/{model_name}'
168
+
169
+ model = LWM()
170
+
171
+ model = load_model(model, model_path, device)
172
+ print(f"Model loaded successfully on {device}")
173
+ print(f"Model parameters: {sum(p.numel() for p in model.parameters())}")
model.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ Created on Sun Sep 15 19:55:23 2024
4
+
5
+ @author: salikha4
6
+ """
7
+
8
+ import os
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ import numpy as np
13
+
14
+ from inference import *
15
+ from load_data import load_DeepMIMO_data
16
+ from input_preprocess import *
17
+ from lwm_model import LWM, load_model
18
+
19
+
20
+ ELEMENT_LENGTH = 16
21
+ D_MODEL = 64
22
+ MAX_LEN = 129
23
+ N_LAYERS = 12
24
+ N_HEADS = 12
25
+ D_FF = D_MODEL * 4
26
+ D_K = D_MODEL // N_HEADS
27
+ D_V = D_MODEL // N_HEADS
28
+ DROPOUT = 0.1
29
+
30
+ class LayerNormalization(nn.Module):
31
+ def __init__(self, d_model: int, eps: float = 1e-6) -> None:
32
+ super().__init__()
33
+ self.eps = eps
34
+ self.alpha = nn.Parameter(torch.ones(d_model))
35
+ self.bias = nn.Parameter(torch.zeros(d_model))
36
+
37
+ def forward(self, x):
38
+ mean = x.mean(dim=-1, keepdim=True)
39
+ std = x.std(dim=-1, keepdim=True)
40
+ return self.alpha * (x - mean) / (std + self.eps) + self.bias
41
+
42
+ class Embedding(nn.Module):
43
+ def __init__(self, element_length, d_model, max_len):
44
+ super().__init__()
45
+ self.element_length = element_length
46
+ self.d_model = d_model
47
+ self.proj = nn.Linear(element_length, d_model)
48
+ self.pos_embed = nn.Embedding(max_len, d_model)
49
+ self.norm = LayerNormalization(d_model)
50
+
51
+ def forward(self, x):
52
+ seq_len = x.size(1)
53
+ pos = torch.arange(seq_len, dtype=torch.long, device=x.device)
54
+ pos = pos.unsqueeze(0).expand_as(x[:, :, 0])
55
+ tok_emb = self.proj(x.float())
56
+ embedding = tok_emb + self.pos_embed(pos)
57
+ return self.norm(embedding)
58
+
59
+ class ScaledDotProductAttention(nn.Module):
60
+ def __init__(self):
61
+ super().__init__()
62
+
63
+ def forward(self, Q, K, V):
64
+ scores = torch.matmul(Q, K.transpose(-1, -2)) / np.sqrt(D_K)
65
+ attn = F.softmax(scores, dim=-1)
66
+ context = torch.matmul(attn, V)
67
+ return context, attn
68
+
69
+ class MultiHeadAttention(nn.Module):
70
+ def __init__(self):
71
+ super().__init__()
72
+ self.W_Q = nn.Linear(D_MODEL, D_K * N_HEADS)
73
+ self.W_K = nn.Linear(D_MODEL, D_K * N_HEADS)
74
+ self.W_V = nn.Linear(D_MODEL, D_V * N_HEADS)
75
+ self.linear = nn.Linear(N_HEADS * D_V, D_MODEL)
76
+ self.norm = LayerNormalization(D_MODEL)
77
+ self.dropout = nn.Dropout(DROPOUT)
78
+
79
+ def forward(self, Q, K, V):
80
+ residual, batch_size = Q, Q.size(0)
81
+ q_s = self.W_Q(Q).view(batch_size, -1, N_HEADS, D_K).transpose(1, 2)
82
+ k_s = self.W_K(K).view(batch_size, -1, N_HEADS, D_K).transpose(1, 2)
83
+ v_s = self.W_V(V).view(batch_size, -1, N_HEADS, D_V).transpose(1, 2)
84
+
85
+ context, attn = ScaledDotProductAttention()(q_s, k_s, v_s)
86
+ output = context.transpose(1, 2).contiguous().view(batch_size, -1, N_HEADS * D_V)
87
+ output = self.linear(output)
88
+ return residual + self.dropout(output), attn #residual + self.dropout(output), attn
89
+
90
+ class PoswiseFeedForwardNet(nn.Module):
91
+ def __init__(self):
92
+ super().__init__()
93
+ self.fc1 = nn.Linear(D_MODEL, D_FF)
94
+ self.fc2 = nn.Linear(D_FF, D_MODEL)
95
+ self.dropout = nn.Dropout(DROPOUT)
96
+ self.norm = LayerNormalization(D_MODEL)
97
+
98
+ def forward(self, x):
99
+ output = self.fc2(self.dropout(F.relu(self.fc1(x))))
100
+ return x + self.dropout(output) #x + self.dropout(output)
101
+
102
+ class EncoderLayer(nn.Module):
103
+ def __init__(self):
104
+ super().__init__()
105
+ self.enc_self_attn = MultiHeadAttention()
106
+ self.pos_ffn = PoswiseFeedForwardNet()
107
+ self.norm = LayerNormalization(D_MODEL)
108
+
109
+ def forward(self, enc_inputs):
110
+ attn_outputs, attn = self.enc_self_attn(enc_inputs, enc_inputs, enc_inputs)
111
+ attn_outputs = self.norm(attn_outputs)
112
+ enc_outputs = self.pos_ffn(attn_outputs)
113
+ return enc_outputs, attn
114
+
115
+ class LWM(torch.nn.Module):
116
+ def __init__(self, element_length=16, d_model=64, max_len=129, n_layers=12):
117
+ super().__init__()
118
+
119
+ self.embedding = Embedding(element_length, d_model, max_len)
120
+ self.layers = nn.ModuleList([EncoderLayer() for _ in range(n_layers)])
121
+ self.linear = nn.Linear(d_model, d_model)
122
+ self.norm = LayerNormalization(d_model)
123
+
124
+ embed_weight = self.embedding.proj.weight
125
+ d_model, n_dim = embed_weight.size()
126
+ self.decoder = nn.Linear(d_model, n_dim, bias=False)
127
+ self.decoder.weight = nn.Parameter(embed_weight.transpose(0, 1))
128
+ self.decoder_bias = nn.Parameter(torch.zeros(n_dim))
129
+
130
+ def forward(self, input_ids, masked_pos):
131
+ output = self.embedding(input_ids)
132
+
133
+ for layer in self.layers:
134
+ output, _ = layer(output)
135
+
136
+ masked_pos = masked_pos.long()[:, :, None].expand(-1, -1, output.size(-1))
137
+ h_masked = torch.gather(output, 1, masked_pos)
138
+ h_masked = self.norm(F.relu(self.linear(h_masked)))
139
+ logits_lm = self.decoder(h_masked) + self.decoder_bias
140
+
141
+ return logits_lm, output
142
+
143
+ @classmethod
144
+ def from_pretrained(cls, ckpt_name='model_weights.pth', device='cuda'):
145
+ # Define model
146
+ model = cls().to(device)
147
+
148
+ # Download the model weights (from a remote or local repository)
149
+ ckpt_path = f'https://huggingface.co/sadjadalikhani/lwm/resolve/main/{ckpt_name}'
150
+
151
+ # Load the model weights
152
+ model.load_state_dict(torch.hub.load_state_dict_from_url(ckpt_path, map_location=device))
153
+ print(f"Model loaded successfully from {ckpt_path} to {device}")
154
+
155
+ return model
156
+
157
+ def forward(self, x):
158
+ # Define the forward pass for the model
159
+ pass
160
+