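"""Train a small fully connected classifier with plain NumPy.

Architecture: 64 inputs -> 16 (ReLU) -> 10 (ReLU) -> 5 (softmax).
The network is trained with per-sample stochastic gradient descent on a
sparse cross-entropy loss, then used to classify every feature vector in
image.npy; the resulting class map and the learned weights are saved to disk.
"""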
import numpy as np
import matplotlib.pyplot as plt


# Hyperparameters
INPUT_DIM = 64   # number of input features per sample
OUT_DIM = 5      # number of output classes
H1_DIM = 16      # first hidden layer size
H2_DIM = 10      # second hidden layer size

ALPHA = 0.005    # learning rate
NUM_EPOCHS = 100

# Global state
img_map = []     # predicted class label for every feature vector in `image`
loss_arr = []    # total loss per epoch, for plotting


def relu(t):
    return np.maximum(t, 0)


def softmax(t):
    out = np.exp(t - np.max(t))  # subtract the max for numerical stability
    return out / np.sum(out)


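# Cross-entropy loss for a single sample: z is the (1, num_classes) softmax
# output, y is the integer class label.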
def sparse_cross_entropy(z, y):
    return -np.log(z[0, y])


def to_full(y, num_classes):
    y_full = np.zeros((1, num_classes))
    y_full[0, y] = 1
    return y_full


def relu_deriv(t):
    # Derivative of ReLU; the subgradient at 0 is taken to be 1 here
    return (t >= 0).astype(float)


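# Forward pass: two ReLU hidden layers followed by a softmax output layer.
# Uses the global parameters W1, b1, W2, b2, W3, b3 defined in __main__.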
def predict(x):
    t1 = x @ W1 + b1
    h1 = relu(t1)
    t2 = h1 @ W2 + b2
    h2 = relu(t2)
    t3 = h2 @ W3 + b3
    z = softmax(t3)
    return z


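# Fraction of dataset samples whose predicted class (argmax of the network
# output) matches the true label.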
def calc_accuracy():
    correct = 0
    for y, x in dataset:
        z = predict(x)
        y_pred = np.argmax(z)
        if y_pred == y:
            correct += 1
    acc = correct / len(dataset)
    return acc


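# Classify every feature vector in `image` and collect the predicted labels
# in the global img_map list.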
def create_map():
    for x in image:
        z = predict(x)
        y_pred = np.argmax(z)
        img_map.append(y_pred)


if __name__ == "__main__":
    # Load data: dataset is an array of (label, feature_vector) pairs,
    # image is an array of feature vectors to classify after training
    dataset = np.load('./../dataset.npy', allow_pickle=True)

    image = np.load('./../image.npy')

    # Randomly initialize the weights ('synapses') and biases
    W1 = np.random.rand(INPUT_DIM, H1_DIM)
    b1 = np.random.rand(1, H1_DIM)
    W2 = np.random.rand(H1_DIM, H2_DIM)
    b2 = np.random.rand(1, H2_DIM)
    W3 = np.random.rand(H2_DIM, OUT_DIM)
    b3 = np.random.rand(1, OUT_DIM)

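    # Rescale from [0, 1) to a symmetric uniform range [-sqrt(1/fan_in), +sqrt(1/fan_in))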
    W1 = (W1 - 0.5) * 2 * np.sqrt(1/INPUT_DIM)
    b1 = (b1 - 0.5) * 2 * np.sqrt(1/INPUT_DIM)
    W2 = (W2 - 0.5) * 2 * np.sqrt(1/H1_DIM)
    b2 = (b2 - 0.5) * 2 * np.sqrt(1/H1_DIM)
    W3 = (W3 - 0.5) * 2 * np.sqrt(1/H2_DIM)
    b3 = (b3 - 0.5) * 2 * np.sqrt(1/H2_DIM)

    loss = 0  # running loss, accumulated over one epoch

    # Training loop: per-sample stochastic gradient descent with backpropagation
    for ep in range(NUM_EPOCHS):
        print(ep)
        np.random.shuffle(dataset)
        for i in range(len(dataset)):

            x = dataset[i][1]
            y = dataset[i][0]

            # Forward
            t1 = x @ W1 + b1
            h1 = relu(t1)
            t2 = h1 @ W2 + b2
            h2 = relu(t2)
            t3 = h2 @ W3 + b3
            z = softmax(t3)
            E = sparse_cross_entropy(z, y)

            # Backward: with softmax + cross-entropy the output-layer gradient
            # simplifies to dE/dt3 = z - y_full; earlier layers follow the chain rule
            y_full = to_full(y, OUT_DIM)
            dE_dt3 = z - y_full
            dE_dW3 = h2.T @ dE_dt3
            dE_db3 = np.sum(dE_dt3, axis=0, keepdims=True)
            dE_dh2 = dE_dt3 @ W3.T
            dE_dt2 = dE_dh2 * relu_deriv(t2)
            dE_dW2 = h1.T @ dE_dt2
            dE_db2 = np.sum(dE_dt2, axis=0, keepdims=True)
            dE_dh1 = dE_dt2 @ W2.T
            dE_dt1 = dE_dh1 * relu_deriv(t1)
            dE_dW1 = x.T @ dE_dt1
            dE_db1 = np.sum(dE_dt1, axis=0, keepdims=True)

            # Update
            W1 = W1 - ALPHA * dE_dW1
            b1 = b1 - ALPHA * dE_db1
            W2 = W2 - ALPHA * dE_dW2
            b2 = b2 - ALPHA * dE_db2
            W3 = W3 - ALPHA * dE_dW3
            b3 = b3 - ALPHA * dE_db3

            loss += E

        loss_arr.append(loss)
        loss = 0

    # Accuracy on the training set
    accuracy = calc_accuracy()
    print("Accuracy:", accuracy)

    # Classify the full image and save the predicted class map
    create_map()
    np.save('map', np.asarray(img_map))

    # Plot the per-epoch training loss
    plt.plot(loss_arr)
    plt.xlabel('Epoch')
    plt.ylabel('Total loss')
    plt.show()

    # Save synapses
    np.savez('./../synapses', W1, b1, W2, b2, W3, b3)
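    # Note: np.savez called with positional arguments stores the arrays under
    # the default keys 'arr_0' ... 'arr_5' in synapses.npz. A minimal sketch
    # for loading them back (assuming the same relative paths):
    #   params = np.load('./../synapses.npz')
    #   W1, b1, W2, b2, W3, b3 = (params[f'arr_{i}'] for i in range(6))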