Update input_preprocess.py
Browse files- input_preprocess.py +1 -1
input_preprocess.py
CHANGED
|
@@ -371,4 +371,4 @@ def create_labels(task, scenario_names, n_beams=64):
|
|
| 371 |
for scenario_name in scenario_names:
|
| 372 |
data = DeepMIMO_data_gen(scenario_name)
|
| 373 |
labels.extend(label_gen(task, data, scenario_name, n_beams=n_beams))
|
| 374 |
-
return labels
|
|
|
|
| 371 |
for scenario_name in scenario_names:
|
| 372 |
data = DeepMIMO_data_gen(scenario_name)
|
| 373 |
labels.extend(label_gen(task, data, scenario_name, n_beams=n_beams))
|
| 374 |
+
return torch.tensor(labels)
|