Sadjad Alikhani commited on
Commit
fedc36d
1 Parent(s): b4f2449

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +2 -2
inference.py CHANGED
@@ -25,7 +25,7 @@ warnings.filterwarnings('ignore')
25
 
26
  def lwm_inference(preprocessed_chs, input_type, lwm_model, device):
27
 
28
- dataset = prepare_for_LWM(preprocessed_chs, device)
29
  # Process data through LWM
30
  lwm_loss, embedding_data = evaluate(lwm_model, dataset)
31
  print(f'LWM loss: {lwm_loss:.4f}')
@@ -38,7 +38,7 @@ def lwm_inference(preprocessed_chs, input_type, lwm_model, device):
38
  dataset = embedding_data.float()
39
  return dataset
40
 
41
- def prepare_for_LWM(data, device, batch_size=64, shuffle=False):
42
 
43
  input_ids, masked_tokens, masked_pos = zip(*data)
44
 
 
25
 
26
  def lwm_inference(preprocessed_chs, input_type, lwm_model, device):
27
 
28
+ dataset = prepare_for_lwm(preprocessed_chs, device)
29
  # Process data through LWM
30
  lwm_loss, embedding_data = evaluate(lwm_model, dataset)
31
  print(f'LWM loss: {lwm_loss:.4f}')
 
38
  dataset = embedding_data.float()
39
  return dataset
40
 
41
+ def prepare_for_lwm(data, device, batch_size=64, shuffle=False):
42
 
43
  input_ids, masked_tokens, masked_pos = zip(*data)
44