Sadjad Alikhani
commited on
Upload inference.py
Browse files- inference.py +11 -17
inference.py
CHANGED
@@ -33,24 +33,18 @@ if torch.cuda.is_available():
|
|
33 |
|
34 |
def lwm_inference(preprocessed_chs, input_type, lwm_model):
|
35 |
|
36 |
-
|
37 |
-
dataset = prepare_for_LWM(preprocessed_chs, device)
|
38 |
-
elif input_type == 'raw':
|
39 |
-
dataset = create_raw_dataset(preprocessed_chs, device)
|
40 |
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
embedding_data = embedding_data[:, 1:]
|
52 |
-
|
53 |
-
dataset = embedding_data.float()
|
54 |
|
55 |
return dataset
|
56 |
|
|
|
33 |
|
34 |
def lwm_inference(preprocessed_chs, input_type, lwm_model):
|
35 |
|
36 |
+
dataset = prepare_for_LWM(preprocessed_chs, device)
|
|
|
|
|
|
|
37 |
|
38 |
+
# Process data through LWM
|
39 |
+
lwm_loss, embedding_data = evaluate(lwm_model, dataset)
|
40 |
+
print(f'LWM loss: {lwm_loss:.4f}')
|
41 |
+
|
42 |
+
if input_type == 'cls_emb':
|
43 |
+
embedding_data = embedding_data[:, 0]
|
44 |
+
elif input_type == 'channel_emb':
|
45 |
+
embedding_data = embedding_data[:, 1:]
|
46 |
+
|
47 |
+
dataset = embedding_data.float()
|
|
|
|
|
|
|
48 |
|
49 |
return dataset
|
50 |
|