Sadjad Alikhani
commited on
Commit
•
4d31f33
1
Parent(s):
c79de44
Update input_preprocess.py
Browse files- input_preprocess.py +2 -22
input_preprocess.py
CHANGED
@@ -16,19 +16,6 @@ import pickle
|
|
16 |
import DeepMIMOv3
|
17 |
import torch
|
18 |
|
19 |
-
def set_random_seed(seed=42):
|
20 |
-
torch.manual_seed(seed)
|
21 |
-
np.random.seed(seed)
|
22 |
-
#random.seed(seed)
|
23 |
-
if torch.cuda.is_available():
|
24 |
-
torch.cuda.manual_seed_all(seed)
|
25 |
-
# Ensures deterministic behavior
|
26 |
-
torch.backends.cudnn.deterministic = True
|
27 |
-
torch.backends.cudnn.benchmark = False
|
28 |
-
|
29 |
-
# Apply random seed
|
30 |
-
set_random_seed()
|
31 |
-
|
32 |
#%% Scenarios List
|
33 |
def scenarios_list():
|
34 |
"""Returns an array of available scenarios."""
|
@@ -208,7 +195,6 @@ def get_parameters(scenario):
|
|
208 |
|
209 |
return parameters, row_column_users, n_ant_bs, n_ant_ue, n_subcarriers
|
210 |
|
211 |
-
|
212 |
#%% Sample Generation
|
213 |
def make_sample(user_idx, patch, word2id, n_patches, n_masks, patch_size, gen_raw=False):
|
214 |
"""
|
@@ -226,7 +212,6 @@ def make_sample(user_idx, patch, word2id, n_patches, n_masks, patch_size, gen_ra
|
|
226 |
Returns:
|
227 |
sample (list): Generated sample for the user.
|
228 |
"""
|
229 |
-
set_random_seed()
|
230 |
|
231 |
tokens = patch[user_idx]
|
232 |
input_ids = np.vstack((word2id['[CLS]'], tokens))
|
@@ -246,8 +231,7 @@ def make_sample(user_idx, patch, word2id, n_patches, n_masks, patch_size, gen_ra
|
|
246 |
input_ids[pos] = np.random.rand(patch_size)
|
247 |
elif rnd_num < 0.9:
|
248 |
input_ids[pos] = word2id['[MASK]']
|
249 |
-
|
250 |
-
# print(f'masked_pos: {masked_pos}')
|
251 |
return [input_ids, masked_tokens, masked_pos]
|
252 |
|
253 |
|
@@ -323,8 +307,7 @@ def load_var(path):
|
|
323 |
|
324 |
return var
|
325 |
|
326 |
-
#%%
|
327 |
-
|
328 |
def label_gen(task, data, scenario, n_beams=64):
|
329 |
|
330 |
idxs = np.where(data['user']['LoS'] != -1)[0]
|
@@ -364,13 +347,10 @@ def label_gen(task, data, scenario, n_beams=64):
|
|
364 |
return label.astype(int)
|
365 |
|
366 |
def steering_vec(array, phi=0, theta=0, kd=np.pi):
|
367 |
-
# phi = azimuth
|
368 |
-
# theta = elevation
|
369 |
idxs = DeepMIMOv3.ant_indices(array)
|
370 |
resp = DeepMIMOv3.array_response(idxs, phi, theta+np.pi/2, kd)
|
371 |
return resp / np.linalg.norm(resp)
|
372 |
|
373 |
-
|
374 |
def label_prepend(deepmimo_data, preprocessed_chs, task, scenario_idxs, n_beams=64):
|
375 |
labels = []
|
376 |
for scenario_idx in scenario_idxs:
|
|
|
16 |
import DeepMIMOv3
|
17 |
import torch
|
18 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
#%% Scenarios List
|
20 |
def scenarios_list():
|
21 |
"""Returns an array of available scenarios."""
|
|
|
195 |
|
196 |
return parameters, row_column_users, n_ant_bs, n_ant_ue, n_subcarriers
|
197 |
|
|
|
198 |
#%% Sample Generation
|
199 |
def make_sample(user_idx, patch, word2id, n_patches, n_masks, patch_size, gen_raw=False):
|
200 |
"""
|
|
|
212 |
Returns:
|
213 |
sample (list): Generated sample for the user.
|
214 |
"""
|
|
|
215 |
|
216 |
tokens = patch[user_idx]
|
217 |
input_ids = np.vstack((word2id['[CLS]'], tokens))
|
|
|
231 |
input_ids[pos] = np.random.rand(patch_size)
|
232 |
elif rnd_num < 0.9:
|
233 |
input_ids[pos] = word2id['[MASK]']
|
234 |
+
|
|
|
235 |
return [input_ids, masked_tokens, masked_pos]
|
236 |
|
237 |
|
|
|
307 |
|
308 |
return var
|
309 |
|
310 |
+
#%% Label Generation
|
|
|
311 |
def label_gen(task, data, scenario, n_beams=64):
|
312 |
|
313 |
idxs = np.where(data['user']['LoS'] != -1)[0]
|
|
|
347 |
return label.astype(int)
|
348 |
|
349 |
def steering_vec(array, phi=0, theta=0, kd=np.pi):
|
|
|
|
|
350 |
idxs = DeepMIMOv3.ant_indices(array)
|
351 |
resp = DeepMIMOv3.array_response(idxs, phi, theta+np.pi/2, kd)
|
352 |
return resp / np.linalg.norm(resp)
|
353 |
|
|
|
354 |
def label_prepend(deepmimo_data, preprocessed_chs, task, scenario_idxs, n_beams=64):
|
355 |
labels = []
|
356 |
for scenario_idx in scenario_idxs:
|