Update engine.py
Browse files
engine.py
CHANGED
@@ -108,7 +108,9 @@ def predict_fn(data_loader, model, device, extract_features=False):
|
|
108 |
mask=mask,
|
109 |
token_type_ids=token_type_ids
|
110 |
).cpu().detach().numpy().tolist())
|
111 |
-
|
|
|
|
|
112 |
fin_outputs.extend(torch.argmax(
|
113 |
outputs, dim=1).cpu().detach().numpy().tolist())
|
114 |
|
|
|
108 |
mask=mask,
|
109 |
token_type_ids=token_type_ids
|
110 |
).cpu().detach().numpy().tolist())
|
111 |
+
print("1",torch.argmax(outputs, dim=1))
|
112 |
+
print("2",torch.argmax(outputs, dim=1).cpu())
|
113 |
+
print("3",torch.argmax(outputs, dim=1).cpu().numpy())
|
114 |
fin_outputs.extend(torch.argmax(
|
115 |
outputs, dim=1).cpu().detach().numpy().tolist())
|
116 |
|