File size: 3,859 Bytes
71bd54f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import torch

def define_optimizer(model, lr, alpha):
    """Build an RMSprop optimizer over all parameters of ``model``.

    Args:
        model: Object exposing ``parameters()`` (e.g. a ``torch.nn.Module``).
        lr: Learning rate forwarded to ``torch.optim.RMSprop``.
        alpha: ``alpha`` value forwarded to ``torch.optim.RMSprop``.

    Returns:
        The freshly created optimizer, with gradients already cleared.
    """
    hyper_params = {"lr": lr, "alpha": alpha}
    rmsprop = torch.optim.RMSprop(model.parameters(), **hyper_params)
    # Start from a clean slate in case the parameters already carry gradients.
    rmsprop.zero_grad()
    return rmsprop

def tuple_of_tensors_to_tensor(tuple_of_tensors):
    """Stack an iterable of tensors along a new leading dimension (dim 0).

    Args:
        tuple_of_tensors: Iterable of tensors accepted by ``torch.stack``.

    Returns:
        A single tensor whose first dimension indexes the input tensors.
    """
    # Materialize first so that any iterable (not only tuples) is accepted.
    tensors = [*tuple_of_tensors]
    stacked = torch.stack(tensors, dim=0)
    return stacked

def predict(model, inputs, notes, device):
    """Run the model and binarize its sigmoid-activated outputs at 0.5.

    Args:
        model: Callable module invoked as ``model(inputs, notes)``.
        inputs: First argument passed to the model.
        notes: Second argument passed to the model.
        device: Unused here; kept so existing callers keep working.

    Returns:
        Tuple ``(outputs, predicted)`` where ``outputs`` is the raw model
        output and ``predicted`` is a float tensor of the same shape holding
        1.0 where ``sigmoid(outputs) > 0.5`` and 0.0 elsewhere.
    """
    # Call the module itself rather than ``model.forward`` so that any
    # registered forward hooks are executed as well.
    outputs = model(inputs, notes)
    predicted = (torch.sigmoid(outputs) > 0.5).float()
    return outputs, predicted

def display_train(epoch, num_epochs, i, model, correct, total, loss, train_loader, valid_loader, device):
    """Print the current training metrics, then run and report validation.

    Args:
        epoch: Current epoch index (printed as ``epoch + 1``).
        num_epochs: Total number of epochs, used in the printed messages.
        i: Current step index within the epoch (printed as ``i + 1``).
        model: Model forwarded to ``eval_valid``.
        correct: Number of correct training predictions so far.
        total: Number of training samples seen so far.
        loss: Loss object exposing ``item()``.
        train_loader: Training loader; only its length is used.
        valid_loader: Validation loader forwarded to ``eval_valid``.
        device: Device forwarded to ``eval_valid``.

    Returns:
        Tuple ``(train_accuracy, valid_accuracy, valid_loss)``.
    """
    step_report = f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], Train Loss: {loss.item():.4f}'
    print(step_report)

    train_accuracy = correct / total
    accuracy_report = f'Epoch [{epoch+1}/{num_epochs}], Train Accuracy: {train_accuracy:.4f}'
    print(accuracy_report)

    # eval_valid returns (loss, accuracy); this function reports accuracy first.
    validation_metrics = eval_valid(model, valid_loader, epoch, num_epochs, device)
    valid_loss, valid_accuracy = validation_metrics
    return (train_accuracy, valid_accuracy, valid_loss)

def eval_valid(model, valid_loader, epoch, num_epochs, device):
    """Evaluate the model on the validation loader and print the metrics.

    A sample counts as correct when the label at the position of the
    highest-scoring output (argmax over dim 1) equals 1. The loss is binary
    cross-entropy computed on the raw outputs (treated as logits); per-batch
    mean losses are weighted by batch size and averaged over all samples.

    The model is put in eval mode for the duration of the call and its
    previous train/eval mode is restored afterwards, so calling this in the
    middle of training does not leave dropout/batch-norm layers disabled.

    Args:
        model: Module evaluated through ``predict``.
        valid_loader: Iterable yielding ``(inputs, labels, notes)`` batches.
        epoch: Current epoch index (printed as ``epoch + 1``).
        num_epochs: Total number of epochs, used in the printed message.
        device: Device the batch tensors are moved to.

    Returns:
        Tuple ``(validation_loss, valid_accuracy)``.

    Raises:
        ZeroDivisionError: If the loader yields no samples.
    """
    was_training = model.training
    model.eval()
    correct = 0
    total = 0
    running_loss = 0
    try:
        with torch.no_grad():
            for inputs, labels, notes in valid_loader:
                # Swap dims 1 and 2 of the inputs (presumably the layout the
                # model expects -- confirm against the dataset) and move the
                # batch to the target device.
                inputs = inputs.transpose(1,2).float().to(device)
                labels = labels.float().to(device)
                notes = notes.to(device)

                # Only the raw outputs are needed; the thresholded prediction
                # returned by predict() is not part of the current criterion.
                outputs, _ = predict(model, inputs, notes, device)
                loss = torch.nn.functional.binary_cross_entropy_with_logits(outputs, labels)
                running_loss += loss.item()*len(labels)

                total += labels.size(0)
                #TODO: change acc criteria
                # Pick, for every sample, the label at the argmax position and
                # count how many of them are 1 -- done in one tensor op rather
                # than a per-sample Python loop.
                _, indices = torch.max(outputs, dim=1)
                hits = labels.gather(1, indices.unsqueeze(1)) == 1
                correct += int(hits.sum().item())
    finally:
        # Restore whatever mode the caller had the model in.
        model.train(was_training)

    # Compute final accuracy and display
    valid_accuracy = correct/total
    validation_loss = running_loss/total
    print(f'Epoch [{epoch+1}/{num_epochs}], Validation Accuracy: {valid_accuracy:.4f}, Validation Loss: {validation_loss:.4f}')
    return validation_loss, valid_accuracy


def eval_test(model, test_loader, device):
    """Evaluate the model on the test loader and print the test accuracy.

    A sample counts as correct when the label at the position of the
    highest-scoring output (argmax over dim 1) equals 1. The model is
    switched to eval mode and left in that mode.

    Args:
        model: Module evaluated through ``predict``.
        test_loader: Iterable yielding ``(inputs, labels, notes)`` batches.
        device: Device the batch tensors are moved to.

    Returns:
        The test accuracy, i.e. correct predictions divided by sample count.

    Raises:
        ZeroDivisionError: If the loader yields no samples.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels, notes in test_loader:
            # Swap dims 1 and 2 of the inputs (presumably the layout the model
            # expects -- confirm against the dataset) and move the batch to
            # the target device.
            inputs = inputs.transpose(1,2).float().to(device)
            labels = labels.float().to(device)
            notes = notes.to(device)

            # Only the raw outputs are needed; the thresholded prediction
            # returned by predict() is not part of the current criterion.
            outputs, _ = predict(model, inputs, notes, device)

            total += labels.size(0)
            #TODO: change acc criteria
            # Pick, for every sample, the label at the argmax position and
            # count how many of them are 1 -- done in one tensor op rather
            # than a per-sample Python loop.
            _, indices = torch.max(outputs, dim=1)
            hits = labels.gather(1, indices.unsqueeze(1)) == 1
            correct += int(hits.sum().item())

    # Compute final accuracy and display
    test_accuracy = correct/total
    print(f'Ended Training, Test Accuracy: {test_accuracy:.4f}')
    return test_accuracy