Upload tutorial.py
Browse files- tutorial.py +33 -3
tutorial.py
CHANGED
@@ -74,11 +74,13 @@ scenario_names = np.array([
|
|
74 |
scenario_idxs = np.array([0, 1, 2, 3, 4, 5])[3]
|
75 |
selected_scenario_names = scenario_names[scenario_idxs]
|
76 |
|
|
|
|
|
77 |
preprocessed_chs = tokenizer(
|
78 |
selected_scenario_names=selected_scenario_names,
|
79 |
manual_data=None,
|
80 |
gen_raw=True,
|
81 |
-
snr_db=
|
82 |
)
|
83 |
|
84 |
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
@@ -87,7 +89,7 @@ model = lwm.from_pretrained(device=device)
|
|
87 |
#%%
|
88 |
from inference import lwm_inference, create_raw_dataset
|
89 |
input_types = ['cls_emb', 'channel_emb', 'raw']
|
90 |
-
selected_input_type = input_types[
|
91 |
|
92 |
if selected_input_type in ['cls_emb', 'channel_emb']:
|
93 |
dataset = lwm_inference(preprocessed_chs, selected_input_type, model, device)
|
@@ -139,7 +141,7 @@ for task in tasks:
|
|
139 |
#%% TRAINING
|
140 |
#%% TRAINING PARAMETERS
|
141 |
task = ['LoS/NLoS Classification', 'Beam Prediction'][0] # Select the task
|
142 |
-
n_trials =
|
143 |
num_classes = 2 if task == 'LoS/NLoS Classification' else n_beams # Set number of classes based on the task
|
144 |
input_types = ['raw', 'cls_emb'] # Types of input data
|
145 |
split_ratios = np.array([.005, .0075, .01, .015, .02, .03,
|
@@ -174,6 +176,20 @@ for input_type_idx, input_type in enumerate(input_types):
|
|
174 |
print(f"\ninput type: {input_type}, \nnumber of training samples: {int(split_ratio*len(dataset))}, \ntrial: {trial}\n")
|
175 |
|
176 |
torch.manual_seed(trial) # Set seed for reproducibility
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
177 |
train_loader, test_loader = get_data_loaders(
|
178 |
dataset,
|
179 |
labels,
|
@@ -240,6 +256,20 @@ for input_type_idx, input_type in enumerate(input_types):
|
|
240 |
for trial in range(n_trials):
|
241 |
|
242 |
torch.manual_seed(trial) # Set seed for reproducibility
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
243 |
train_loader, test_loader = get_data_loaders(
|
244 |
dataset,
|
245 |
labels,
|
|
|
74 |
scenario_idxs = np.array([0, 1, 2, 3, 4, 5])[3]
|
75 |
selected_scenario_names = scenario_names[scenario_idxs]
|
76 |
|
77 |
+
snr_db = None
|
78 |
+
|
79 |
preprocessed_chs = tokenizer(
|
80 |
selected_scenario_names=selected_scenario_names,
|
81 |
manual_data=None,
|
82 |
gen_raw=True,
|
83 |
+
snr_db=snr_db
|
84 |
)
|
85 |
|
86 |
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
|
|
89 |
#%%
|
90 |
from inference import lwm_inference, create_raw_dataset
|
91 |
input_types = ['cls_emb', 'channel_emb', 'raw']
|
92 |
+
selected_input_type = input_types[1]
|
93 |
|
94 |
if selected_input_type in ['cls_emb', 'channel_emb']:
|
95 |
dataset = lwm_inference(preprocessed_chs, selected_input_type, model, device)
|
|
|
141 |
#%% TRAINING
|
142 |
#%% TRAINING PARAMETERS
|
143 |
task = ['LoS/NLoS Classification', 'Beam Prediction'][0] # Select the task
|
144 |
+
n_trials = 1 # Number of trials for each configuration
|
145 |
num_classes = 2 if task == 'LoS/NLoS Classification' else n_beams # Set number of classes based on the task
|
146 |
input_types = ['raw', 'cls_emb'] # Types of input data
|
147 |
split_ratios = np.array([.005, .0075, .01, .015, .02, .03,
|
|
|
176 |
print(f"\ninput type: {input_type}, \nnumber of training samples: {int(split_ratio*len(dataset))}, \ntrial: {trial}\n")
|
177 |
|
178 |
torch.manual_seed(trial) # Set seed for reproducibility
|
179 |
+
|
180 |
+
if snr_db is not None:
|
181 |
+
preprocessed_chs = tokenizer(
|
182 |
+
selected_scenario_names=selected_scenario_names,
|
183 |
+
manual_data=None,
|
184 |
+
gen_raw=True,
|
185 |
+
snr_db=snr_db
|
186 |
+
)
|
187 |
+
if input_type in ['cls_emb', 'channel_emb']:
|
188 |
+
dataset = lwm_inference(preprocessed_chs, input_type, model, device)
|
189 |
+
else:
|
190 |
+
dataset = create_raw_dataset(preprocessed_chs, device)
|
191 |
+
dataset = dataset.view(dataset.size(0), -1)
|
192 |
+
|
193 |
train_loader, test_loader = get_data_loaders(
|
194 |
dataset,
|
195 |
labels,
|
|
|
256 |
for trial in range(n_trials):
|
257 |
|
258 |
torch.manual_seed(trial) # Set seed for reproducibility
|
259 |
+
|
260 |
+
if snr_db is not None:
|
261 |
+
preprocessed_chs = tokenizer(
|
262 |
+
selected_scenario_names=selected_scenario_names,
|
263 |
+
manual_data=None,
|
264 |
+
gen_raw=True,
|
265 |
+
snr_db=snr_db
|
266 |
+
)
|
267 |
+
if input_type in ['cls_emb', 'channel_emb']:
|
268 |
+
dataset = lwm_inference(preprocessed_chs, input_type, model, device)
|
269 |
+
else:
|
270 |
+
dataset = create_raw_dataset(preprocessed_chs, device)
|
271 |
+
dataset = dataset.view(dataset.size(0), -1)
|
272 |
+
|
273 |
train_loader, test_loader = get_data_loaders(
|
274 |
dataset,
|
275 |
labels,
|