Sadjad Alikhani commited on
Commit
5d9edcb
·
verified ·
1 Parent(s): 573514a

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +5 -2
inference.py CHANGED
@@ -79,11 +79,14 @@ def evaluate(model, dataloader):
79
  criterionMCM = nn.MSELoss()
80
 
81
  with torch.no_grad():
82
- for batch in dataloader:
83
  input_ids = batch[0]
84
  masked_tokens = batch[1]
85
  masked_pos = batch[2]
86
-
 
 
 
87
  logits_lm, output = model(input_ids, masked_pos)
88
 
89
  output_batch_preproc = output
 
79
  criterionMCM = nn.MSELoss()
80
 
81
  with torch.no_grad():
82
+ for idx, batch in enumerate(dataloader):
83
  input_ids = batch[0]
84
  masked_tokens = batch[1]
85
  masked_pos = batch[2]
86
+
87
+ if idx == 0:
88
+ print(input_ids[0])
89
+
90
  logits_lm, output = model(input_ids, masked_pos)
91
 
92
  output_batch_preproc = output