"""Evaluate the saved CNN checkpoint on ``loader_test`` and print its accuracy."""
import torch

from cnn import CNN, Model
from utils import loader_test

# Use the GPU when one is present; otherwise fall back to the CPU so the
# script still runs on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = CNN().to(device)
# map_location remaps tensors saved on another device (e.g. a GPU checkpoint
# opened on a CPU-only host) onto `device` instead of failing to deserialize.
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
model.load_state_dict(torch.load('net_params.pth', map_location=device))
# Alternative model/checkpoint kept from the original script; uncomment to
# evaluate `Model` instead of `CNN`.
# model = Model().to(device)
# model.load_state_dict(torch.load('cls_params.pth', map_location=device))


def test():
    """Run the module-level ``model`` over ``loader_test`` and print accuracy.

    Each batch is expected to unpack into ``(input_ids, attention_mask,
    token_type_ids, labels)``. The predicted class is the argmax of the model
    output along dim 1 and is compared element-wise with ``labels``.

    Prints the batch index and the running correct/total tally for every
    batch, then the overall accuracy. If the loader yields no samples, a
    message is printed instead of dividing by zero.

    Returns:
        None. Results are reported on stdout only.
    """
    model.eval()
    correct = 0
    total = 0

    for i, (input_ids, attention_mask, token_type_ids, labels) in enumerate(loader_test):
        print(i)

        # The batch tensors must live on the same device as the model.
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        token_type_ids = token_type_ids.to(device)
        labels = labels.to(device)

        # Evaluation only: skip building the autograd graph.
        with torch.no_grad():
            out = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
            )

        # assumes `out` is (batch, num_classes) -- TODO confirm against the model
        out = out.argmax(dim=1)
        correct += (out == labels).sum().item()
        total += len(labels)
        print('correct: ', correct, 'total: ', total)

    if total == 0:
        # Empty loader: accuracy is undefined, so report it rather than
        # raising ZeroDivisionError.
        print('accuracy: undefined (no samples evaluated)')
        return

    print('accuracy:', correct / total)


test()