Sadjad Alikhani
commited on
Update inference.py
Browse files- inference.py +5 -2
inference.py
CHANGED
@@ -79,11 +79,14 @@ def evaluate(model, dataloader):
|
|
79 |
criterionMCM = nn.MSELoss()
|
80 |
|
81 |
with torch.no_grad():
|
82 |
-
for batch in dataloader:
|
83 |
input_ids = batch[0]
|
84 |
masked_tokens = batch[1]
|
85 |
masked_pos = batch[2]
|
86 |
-
|
|
|
|
|
|
|
87 |
logits_lm, output = model(input_ids, masked_pos)
|
88 |
|
89 |
output_batch_preproc = output
|
|
|
79 |
criterionMCM = nn.MSELoss()
|
80 |
|
81 |
with torch.no_grad():
|
82 |
+
for idx, batch in enumerate(dataloader):
|
83 |
input_ids = batch[0]
|
84 |
masked_tokens = batch[1]
|
85 |
masked_pos = batch[2]
|
86 |
+
|
87 |
+
if idx == 0:
|
88 |
+
print(input_ids[0])
|
89 |
+
|
90 |
logits_lm, output = model(input_ids, masked_pos)
|
91 |
|
92 |
output_batch_preproc = output
|