wi-lab commited on
Commit
50b1857
·
verified ·
1 Parent(s): 21b55b0

Upload tutorial.py

Browse files
Files changed (1) hide show
  1. tutorial.py +33 -3
tutorial.py CHANGED
@@ -74,11 +74,13 @@ scenario_names = np.array([
74
  scenario_idxs = np.array([0, 1, 2, 3, 4, 5])[3]
75
  selected_scenario_names = scenario_names[scenario_idxs]
76
 
 
 
77
  preprocessed_chs = tokenizer(
78
  selected_scenario_names=selected_scenario_names,
79
  manual_data=None,
80
  gen_raw=True,
81
- snr_db=None
82
  )
83
 
84
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
@@ -87,7 +89,7 @@ model = lwm.from_pretrained(device=device)
87
  #%%
88
  from inference import lwm_inference, create_raw_dataset
89
  input_types = ['cls_emb', 'channel_emb', 'raw']
90
- selected_input_type = input_types[2]
91
 
92
  if selected_input_type in ['cls_emb', 'channel_emb']:
93
  dataset = lwm_inference(preprocessed_chs, selected_input_type, model, device)
@@ -139,7 +141,7 @@ for task in tasks:
139
  #%% TRAINING
140
  #%% TRAINING PARAMETERS
141
  task = ['LoS/NLoS Classification', 'Beam Prediction'][0] # Select the task
142
- n_trials = 10 # Number of trials for each configuration
143
  num_classes = 2 if task == 'LoS/NLoS Classification' else n_beams # Set number of classes based on the task
144
  input_types = ['raw', 'cls_emb'] # Types of input data
145
  split_ratios = np.array([.005, .0075, .01, .015, .02, .03,
@@ -174,6 +176,20 @@ for input_type_idx, input_type in enumerate(input_types):
174
  print(f"\ninput type: {input_type}, \nnumber of training samples: {int(split_ratio*len(dataset))}, \ntrial: {trial}\n")
175
 
176
  torch.manual_seed(trial) # Set seed for reproducibility
 
 
 
 
 
 
 
 
 
 
 
 
 
 
177
  train_loader, test_loader = get_data_loaders(
178
  dataset,
179
  labels,
@@ -240,6 +256,20 @@ for input_type_idx, input_type in enumerate(input_types):
240
  for trial in range(n_trials):
241
 
242
  torch.manual_seed(trial) # Set seed for reproducibility
 
 
 
 
 
 
 
 
 
 
 
 
 
 
243
  train_loader, test_loader = get_data_loaders(
244
  dataset,
245
  labels,
 
74
  scenario_idxs = np.array([0, 1, 2, 3, 4, 5])[3]
75
  selected_scenario_names = scenario_names[scenario_idxs]
76
 
77
+ snr_db = None
78
+
79
  preprocessed_chs = tokenizer(
80
  selected_scenario_names=selected_scenario_names,
81
  manual_data=None,
82
  gen_raw=True,
83
+ snr_db=snr_db
84
  )
85
 
86
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
89
  #%%
90
  from inference import lwm_inference, create_raw_dataset
91
  input_types = ['cls_emb', 'channel_emb', 'raw']
92
+ selected_input_type = input_types[1]
93
 
94
  if selected_input_type in ['cls_emb', 'channel_emb']:
95
  dataset = lwm_inference(preprocessed_chs, selected_input_type, model, device)
 
141
  #%% TRAINING
142
  #%% TRAINING PARAMETERS
143
  task = ['LoS/NLoS Classification', 'Beam Prediction'][0] # Select the task
144
+ n_trials = 1 # Number of trials for each configuration
145
  num_classes = 2 if task == 'LoS/NLoS Classification' else n_beams # Set number of classes based on the task
146
  input_types = ['raw', 'cls_emb'] # Types of input data
147
  split_ratios = np.array([.005, .0075, .01, .015, .02, .03,
 
176
  print(f"\ninput type: {input_type}, \nnumber of training samples: {int(split_ratio*len(dataset))}, \ntrial: {trial}\n")
177
 
178
  torch.manual_seed(trial) # Set seed for reproducibility
179
+
180
+ if snr_db is not None:
181
+ preprocessed_chs = tokenizer(
182
+ selected_scenario_names=selected_scenario_names,
183
+ manual_data=None,
184
+ gen_raw=True,
185
+ snr_db=snr_db
186
+ )
187
+ if input_type in ['cls_emb', 'channel_emb']:
188
+ dataset = lwm_inference(preprocessed_chs, input_type, model, device)
189
+ else:
190
+ dataset = create_raw_dataset(preprocessed_chs, device)
191
+ dataset = dataset.view(dataset.size(0), -1)
192
+
193
  train_loader, test_loader = get_data_loaders(
194
  dataset,
195
  labels,
 
256
  for trial in range(n_trials):
257
 
258
  torch.manual_seed(trial) # Set seed for reproducibility
259
+
260
+ if snr_db is not None:
261
+ preprocessed_chs = tokenizer(
262
+ selected_scenario_names=selected_scenario_names,
263
+ manual_data=None,
264
+ gen_raw=True,
265
+ snr_db=snr_db
266
+ )
267
+ if input_type in ['cls_emb', 'channel_emb']:
268
+ dataset = lwm_inference(preprocessed_chs, input_type, model, device)
269
+ else:
270
+ dataset = create_raw_dataset(preprocessed_chs, device)
271
+ dataset = dataset.view(dataset.size(0), -1)
272
+
273
  train_loader, test_loader = get_data_loaders(
274
  dataset,
275
  labels,