Sadjad Alikhani committed
Commit a21070d · verified · 1 Parent(s): 571a0d4

Upload input_preprocess.py

Files changed (1)
  1. input_preprocess.py +295 -0
input_preprocess.py ADDED
# -*- coding: utf-8 -*-
"""
Created on Fri Sep 13 16:13:29 2024

This script generates preprocessed data from wireless communication scenarios,
including token generation, patch creation, and data sampling for machine learning models.

@author: salikha4
"""

import numpy as np
from tqdm import tqdm
import pickle
import DeepMIMOv3

#%% Scenarios List
def scenarios_list():
    """Returns an array of the available scenario names."""
    return np.array([
        'city_18_denver', 'city_15_indianapolis', 'city_19_oklahoma',
        'city_12_fortworth', 'city_11_santaclara', 'city_7_sandiego'
    ])

#%% Token Generation
def tokenizer(scenario_names, gen_raw=True):
    """
    Generates tokens by preparing and preprocessing the dataset.

    Args:
        scenario_names (list): Names of the scenarios to load.
        gen_raw (bool): Whether to keep the masked tokens unaltered (raw).
            Defaults to True.

    Returns:
        preprocessed_data (list): One [input_ids, masked_tokens, masked_pos]
            sample per user.
    """

    # Generate the DeepMIMO data and patches for every scenario
    deepmimo_data = [DeepMIMO_data_gen(scenario_name) for scenario_name in scenario_names]
    n_scenarios = len(scenario_names)

    patches = [patch_maker(deepmimo_data[scenario_idx]) for scenario_idx in range(n_scenarios)]
    patches = np.vstack(patches)

    # Define dimensions
    patch_size = patches.shape[2]
    n_patches = patches.shape[1]
    n_masks_half = int(0.15 * n_patches / 2)  # 15% of patches masked, split over real/imag halves
    sequence_length = n_patches + 1  # +1 for the [CLS] token (unused here)
    element_length = patch_size  # (unused here)

    word2id = {'[CLS]': 0.2 * np.ones(patch_size), '[MASK]': 0.1 * np.ones(patch_size)}

    # Generate preprocessed channels
    preprocessed_data = []
    for user_idx in tqdm(range(len(patches)), desc="Processing items"):
        sample = make_sample(user_idx, patches, word2id, n_patches, n_masks_half, patch_size, gen_raw=gen_raw)
        preprocessed_data.append(sample)

    return preprocessed_data

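# A minimal usage sketch (assumes the scenario ray-tracing files are available
# under ./scenarios):
#   samples = tokenizer(scenarios_list()[:1], gen_raw=True)
#   input_ids, masked_tokens, masked_pos = samples[0]
#   # input_ids stacks [CLS] on top of the channel patches, so its shape is
#   # (n_patches + 1, patch_size)
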
#%% Patch Creation
def patch_maker(data, patch_size=16, norm_factor=1e6):
    """
    Creates patches from the channels of a DeepMIMO dataset.

    Args:
        data (dict): DeepMIMO dataset for a single scenario.
        patch_size (int): Size of each patch. Defaults to 16.
        norm_factor (float): Normalization factor for the channels. Defaults to 1e6.

    Returns:
        patch (numpy array): Generated patches of shape (n_users, n_patches, patch_size).
    """
    # Keep only users with a valid channel (a LoS flag of -1 marks blocked users)
    idxs = np.where(data['user']['LoS'] != -1)[0]

    # Reshape and normalize the channels; real and imaginary parts are concatenated
    original_ch = data['user']['channel'][idxs]
    flat_channels = original_ch.reshape((original_ch.shape[0], -1)).astype(np.csingle)
    flat_channels_complex = np.hstack((flat_channels.real, flat_channels.imag)) * norm_factor

    # Create patches
    n_patches = flat_channels_complex.shape[1] // patch_size
    patch = np.zeros((len(idxs), n_patches, patch_size))
    for idx in range(n_patches):
        patch[:, idx, :] = flat_channels_complex[:, idx * patch_size:(idx + 1) * patch_size]

    return patch

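# Shape sketch (illustrative numbers, assuming the 32-antenna / 32-subcarrier
# setup from get_parameters below): each channel flattens to 32 * 32 = 1024
# complex entries, the real/imaginary split doubles that to 2048 values, and
# 2048 // 16 = 128 patches of size 16 per user:
#   p = patch_maker(data)  # p.shape == (n_users, 128, 16)
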
#%% Data Generation for Scenario Areas
def DeepMIMO_data_gen(scenario):
    """
    Generates DeepMIMO data for a given scenario.

    Args:
        scenario (str): Scenario name.

    Returns:
        data (dict): Generated data for the active basestation.
    """

    parameters, row_column_users, n_ant_bs, n_ant_ue, n_subcarriers = get_parameters(scenario)

    deepMIMO_dataset = DeepMIMOv3.generate_data(parameters)
    uniform_idxs = uniform_sampling(deepMIMO_dataset, [1, 1], len(parameters['user_rows']),
                                    users_per_row=row_column_users[scenario]['n_per_row'])
    data = select_by_idx(deepMIMO_dataset, uniform_idxs)[0]

    return data

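# Usage sketch (the channel layout assumes DeepMIMOv3's per-user
# (n_ue_ant, n_bs_ant, n_subcarriers) convention):
#   data = DeepMIMO_data_gen('city_18_denver')
#   data['user']['channel'].shape  # expected (n_users, 1, 32, 32)
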
#%% Scenario Parameters
def get_parameters(scenario):

    n_ant_bs = 32
    n_ant_ue = 1
    n_subcarriers = 32
    scs = 30e3  # Subcarrier spacing (Hz)

    row_column_users = {
        'city_18_denver': {
            'n_rows': 85,
            'n_per_row': 82
        },
        'city_15_indianapolis': {
            'n_rows': 80,
            'n_per_row': 79
        },
        'city_19_oklahoma': {
            'n_rows': 82,
            'n_per_row': 75
        },
        'city_12_fortworth': {
            'n_rows': 86,
            'n_per_row': 72
        },
        'city_11_santaclara': {
            'n_rows': 47,
            'n_per_row': 114
        },
        'city_7_sandiego': {
            'n_rows': 71,
            'n_per_row': 83
        }
    }

    parameters = DeepMIMOv3.default_params()
    parameters['dataset_folder'] = './scenarios'
    parameters['scenario'] = scenario

    # The 'O1_3p5' and 'Boston5G_3p5' branches below cover legacy scenarios;
    # using them requires adding matching entries to row_column_users (with a
    # (start, end) row range for 'Boston5G_3p5').
    if scenario == 'O1_3p5':
        parameters['active_BS'] = np.array([4])
    elif scenario in ['city_18_denver', 'city_15_indianapolis']:
        parameters['active_BS'] = np.array([3])
    else:
        parameters['active_BS'] = np.array([1])

    if scenario == 'Boston5G_3p5':
        parameters['user_rows'] = np.arange(row_column_users[scenario]['n_rows'][0],
                                            row_column_users[scenario]['n_rows'][1])
    else:
        parameters['user_rows'] = np.arange(row_column_users[scenario]['n_rows'])
    parameters['bs_antenna']['shape'] = np.array([n_ant_bs, 1])  # Horizontal, Vertical
    parameters['bs_antenna']['rotation'] = np.array([0, 0, -135])  # (x, y, z)
    parameters['ue_antenna']['shape'] = np.array([n_ant_ue, 1])
    parameters['enable_BS2BS'] = False
    parameters['OFDM']['subcarriers'] = n_subcarriers
    parameters['OFDM']['selected_subcarriers'] = np.arange(n_subcarriers)

    parameters['OFDM']['bandwidth'] = scs * n_subcarriers / 1e9  # Total bandwidth in GHz
    parameters['num_paths'] = 20

    return parameters, row_column_users, n_ant_bs, n_ant_ue, n_subcarriers

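# Bandwidth check: with scs = 30e3 Hz and n_subcarriers = 32, the configured
# OFDM bandwidth is 30e3 * 32 / 1e9 = 0.00096 GHz, i.e. 0.96 MHz.
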
#%% Sample Generation
def make_sample(user_idx, patch, word2id, n_patches, n_masks, patch_size, gen_raw=False):
    """
    Generates a sample for one user, including masking and tokenizing.

    Args:
        user_idx (int): Index of the user.
        patch (numpy array): Patches data.
        word2id (dict): Dictionary for special tokens.
        n_patches (int): Number of patches.
        n_masks (int): Number of masked positions per (real/imaginary) half.
        patch_size (int): Size of each patch.
        gen_raw (bool): If True, record the masked positions but leave the
            tokens unaltered.

    Returns:
        sample (list): [input_ids, masked_tokens, masked_pos] for the user.
    """

    tokens = patch[user_idx]
    input_ids = np.vstack((word2id['[CLS]'], tokens))

    # Pick mask positions in the real half and mirror them into the imaginary
    # half; the +1 offset accounts for the prepended [CLS] token.
    real_tokens_size = int(n_patches / 2)
    masks_pos_real = np.random.choice(range(0, real_tokens_size), size=n_masks, replace=False)
    masks_pos_imag = masks_pos_real + real_tokens_size
    masked_pos = np.hstack((masks_pos_real, masks_pos_imag)) + 1

    # BERT-style corruption: 10% of masked positions get random values, 80% get
    # the [MASK] token, and the remaining 10% are left unchanged.
    masked_tokens = []
    for pos in masked_pos:
        original_masked_tokens = input_ids[pos].copy()
        masked_tokens.append(original_masked_tokens)
        if not gen_raw:
            rnd_num = np.random.rand()
            if rnd_num < 0.1:
                input_ids[pos] = np.random.rand(patch_size)
            elif rnd_num < 0.9:
                input_ids[pos] = word2id['[MASK]']

    return [input_ids, masked_tokens, masked_pos]

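# Masking arithmetic (illustrative, for the 128-patch case sketched above):
# n_masks = int(0.15 * 128 / 2) = 9, so 9 positions in the real half plus their
# 9 mirrored imaginary positions are selected, i.e. 18/128 ≈ 14% of the tokens.
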
#%% Sampling and Data Selection
def uniform_sampling(dataset, sampling_div, n_rows, users_per_row):
    """
    Performs uniform sampling on the dataset.

    Args:
        dataset (dict): DeepMIMO dataset.
        sampling_div (list): Step sizes along the [x, y] dimensions.
        n_rows (int): Number of rows for user selection.
        users_per_row (int): Number of users per row.

    Returns:
        uniform_idxs (numpy array): Indices of the selected samples.
    """
    cols = np.arange(users_per_row, step=sampling_div[0])
    rows = np.arange(n_rows, step=sampling_div[1])
    # Users are laid out row-major: user index = row * users_per_row + column
    uniform_idxs = np.array([j + i * users_per_row for i in rows for j in cols])

    return uniform_idxs

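# Worked example (hypothetical sizes): with 3 rows of 4 users each and
# sampling_div = [2, 2], cols = [0, 2] and rows = [0, 2], so the selected
# row-major indices are [0, 2, 8, 10]. Note that the dataset argument is not
# used by the computation:
#   uniform_sampling(None, [2, 2], n_rows=3, users_per_row=4)
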
def select_by_idx(dataset, idxs):
    """
    Selects a subset of the dataset based on the provided user indices.

    Args:
        dataset (list): DeepMIMO dataset (one dict per basestation).
        idxs (numpy array): Indices of users to select.

    Returns:
        dataset_t (list): Trimmed dataset based on the selected indices.
    """
    dataset_t = []  # Trimmed dataset
    for bs_idx in range(len(dataset)):
        dataset_t.append({})
        dataset_t[bs_idx]['location'] = dataset[bs_idx]['location']
        dataset_t[bs_idx]['user'] = {k: dataset[bs_idx]['user'][k][idxs] for k in dataset[bs_idx]['user']}

    return dataset_t

#%% Save and Load Utilities
def save_var(var, path):
    """
    Saves a variable to a pickle file.

    Args:
        var (object): Variable to be saved.
        path (str): Path to save the file.

    Returns:
        None
    """
    path_full = path if path.endswith(('.p', '.pickle')) else (path + '.pickle')
    with open(path_full, 'wb') as handle:
        pickle.dump(var, handle)

def load_var(path):
    """
    Loads a variable from a pickle file.

    Args:
        path (str): Path of the file to load.

    Returns:
        var (object): Loaded variable.
    """
    path_full = path if path.endswith(('.p', '.pickle')) else (path + '.pickle')
    with open(path_full, 'rb') as handle:
        var = pickle.load(handle)

    return var

#%%
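# A minimal end-to-end driver sketch, assuming the DeepMIMO scenario files have
# been downloaded into ./scenarios; it simply chains the functions above:
if __name__ == '__main__':
    scenarios = scenarios_list()[:1]  # e.g. 'city_18_denver'
    samples = tokenizer(scenarios, gen_raw=True)
    save_var(samples, './preprocessed_data')  # written to ./preprocessed_data.pickle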