Sadjad Alikhani commited on
Commit
68f053b
·
verified ·
1 Parent(s): 1d6fec0

Upload inference.py

Browse files
Files changed (1) hide show
  1. inference.py +11 -17
inference.py CHANGED
@@ -33,24 +33,18 @@ if torch.cuda.is_available():
33
 
34
  def lwm_inference(preprocessed_chs, input_type, lwm_model):
35
 
36
- if input_type in ['cls_emb', 'channel_emb']:
37
- dataset = prepare_for_LWM(preprocessed_chs, device)
38
- elif input_type == 'raw':
39
- dataset = create_raw_dataset(preprocessed_chs, device)
40
 
41
- if input_type in ['cls_emb','channel_emb']:
42
-
43
- # Process data through LWM
44
- lwm_loss, embedding_data = evaluate(lwm_model, dataset)
45
-
46
- print(f'LWM loss: {lwm_loss:.4f}')
47
-
48
- if input_type == 'cls_emb':
49
- embedding_data = embedding_data[:, 0]
50
- elif input_type == 'channel_emb':
51
- embedding_data = embedding_data[:, 1:]
52
-
53
- dataset = embedding_data.float()
54
 
55
  return dataset
56
 
 
33
 
34
  def lwm_inference(preprocessed_chs, input_type, lwm_model):
35
 
36
+ dataset = prepare_for_LWM(preprocessed_chs, device)
 
 
 
37
 
38
+ # Process data through LWM
39
+ lwm_loss, embedding_data = evaluate(lwm_model, dataset)
40
+ print(f'LWM loss: {lwm_loss:.4f}')
41
+
42
+ if input_type == 'cls_emb':
43
+ embedding_data = embedding_data[:, 0]
44
+ elif input_type == 'channel_emb':
45
+ embedding_data = embedding_data[:, 1:]
46
+
47
+ dataset = embedding_data.float()
 
 
 
48
 
49
  return dataset
50