Neu256 committed
Commit a82b5a1 · 1 Parent(s): 0efbf87

Upload 4 files

Files changed (4)
  1. base_model.pth +3 -0
  2. model.py +199 -0
  3. requirements.txt +4 -0
  4. utils.py +121 -0
base_model.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:065d18e5699492bb121d4011df95c85bf5505eed62225aa843db7c558258b9e2
+ size 201382377
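base_model.pth is stored as a Git LFS pointer, so only its hash and size appear in the diff. For orientation only, here is a minimal sketch of how the checkpoint might be restored with the load_model_from_checkpoint helper from utils.py below. The hyperparameters reuse the defaults from utils.py and vocab_size is a placeholder; the commit does not state the values the checkpoint was actually trained with, and load_state_dict will fail if they do not match.

# Sketch only - hyperparameters are assumptions taken from utils.py defaults;
# vocab_size must match the tokenizer the checkpoint was trained with.
from model import Transformer
from utils import (BLOCK_SIZE, DEVICE, DROPOUT, NUM_EMBED, NUM_HEAD, NUM_LAYER,
                   load_model_from_checkpoint)

model = load_model_from_checkpoint(
    Transformer,
    path_to_checkpoint="base_model.pth",
    vocab_size=100,            # placeholder, see note above
    num_embed=NUM_EMBED,
    block_size=BLOCK_SIZE,
    num_heads=NUM_HEAD,
    num_layers=NUM_LAYER,
    dropout=DROPOUT,
).to(DEVICE)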
model.py ADDED
@@ -0,0 +1,199 @@
+ import torch
+ import torch.nn as nn
+ from torch.nn import functional as F
+ from utils import DEVICE
+
+
+ class AttentionHead(nn.Module):
+     """
+     One head of the self-attention layer
+     """
+
+     def __init__(self, head_size, num_embed, block_size, dropout):
+         super().__init__()
+         self.key = nn.Linear(num_embed, head_size, bias=False)
+         self.query = nn.Linear(num_embed, head_size, bias=False)
+         self.value = nn.Linear(num_embed, head_size, bias=False)
+         # tril is a lower triangular matrix; it is not a parameter
+         # of the model, so we assign it to the module using register_buffer
+         self.register_buffer("tril", torch.tril(torch.ones(block_size, block_size)))
+
+         # let's also add dropout
+         self.dropout = nn.Dropout(dropout)
+
+     def forward(self, x):
+         B, T, C = x.shape
+         k = self.key(x)
+         q = self.query(x)
+         # compute scaled attention scores
+         # (B, T, C) @ (B, C, T) -> (B, T, T)
+         wei = q @ k.transpose(-2, -1) * C**-0.5
+         # the tril matrix (lower triangular matrix) is used to mask
+         # future positions (setting them to -inf) so that the
+         # decoder "learns" to predict the next words
+         wei = wei.masked_fill(self.tril[:T, :T] == 0, float("-inf"))  # (B,T,T)
+         wei = F.softmax(wei, dim=-1)  # (B,T,T)
+         wei = self.dropout(wei)
+         # weighted aggregation of the values
+         v = self.value(x)
+         out = wei @ v  # (B,T,T) @ (B,T,C) ---> (B,T,C)
+         return out
+
+
+ class MultiHeadAttention(nn.Module):
+     """
+     Multiple heads of self-attention in parallel
+     """
+
+     def __init__(self, num_heads, head_size, num_embed, block_size, dropout):
+         super().__init__()
+         self.heads = nn.ModuleList(
+             [
+                 AttentionHead(
+                     head_size=head_size,
+                     num_embed=num_embed,
+                     block_size=block_size,
+                     dropout=dropout,
+                 )
+                 for _ in range(num_heads)
+             ]
+         )
+         self.proj = nn.Linear(num_embed, num_embed)
+         self.dropout = nn.Dropout(dropout)
+
+     def forward(self, x):
+         # concatenate the outputs of all self-attention heads
+         out = torch.cat([h(x) for h in self.heads], dim=-1)
+         # apply the linear projection layer
+         out = self.dropout(self.proj(out))
+         return out
+
+
+ class FeedForward(nn.Module):
+     """
+     A simple linear layer followed by ReLU
+     """
+
+     def __init__(self, num_embed, dropout):
+         super().__init__()
+         self.net = nn.Sequential(
+             # in the "Attention Is All You Need" paper
+             # the authors use a feed-forward layer of size 2048
+             # while the output of the model is 512,
+             # so we apply the same factor of 4
+             nn.Linear(num_embed, 4 * num_embed),
+             nn.ReLU(),
+             # apply the linear projection layer
+             nn.Linear(4 * num_embed, num_embed),
+             nn.Dropout(dropout),
+         )
+
+     def forward(self, x):
+         return self.net(x)
+
+
+ class TransformerBlock(nn.Module):
+     """
+     This class groups together MultiHeadAttention and the
+     FeedForward NN, so that we can stack it in the Transformer
+     """
+
+     def __init__(self, num_heads, block_size, num_embed, dropout):
+         super().__init__()
+         head_size = num_embed // num_heads
+         self.sa = MultiHeadAttention(
+             num_heads=num_heads,
+             head_size=head_size,
+             num_embed=num_embed,
+             block_size=block_size,
+             dropout=dropout,
+         )
+         self.ffwd = FeedForward(num_embed=num_embed, dropout=dropout)
+         # add the layer normalization
+         self.ln1 = nn.LayerNorm(num_embed)
+         self.ln2 = nn.LayerNorm(num_embed)
+
+     def forward(self, x):
+         # "x +" is the skip (or residual) connection;
+         # it helps with optimization
+         # we also apply layer normalization before self-attention
+         # and feed-forward (a pre-norm reshuffle relative to the original paper)
+         x = x + self.sa(self.ln1(x))
+         x = x + self.ffwd(self.ln2(x))
+         return x
+
+
+ class Transformer(nn.Module):
+     def __init__(self, **kwargs):
+         super().__init__()
+         # nn.Embedding is a simple lookup table that stores embeddings of a fixed dictionary and size
+         # each token id is mapped to a learned embedding vector of size num_embed
+         # see more: https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html
+         self.vocab_size = kwargs.get("vocab_size", 100)
+         self.num_embed = kwargs.get("num_embed", 32)
+         self.block_size = kwargs.get("block_size", 8)
+         self.num_heads = kwargs.get("num_heads", 4)
+         self.num_layers = kwargs.get("num_layers", 4)
+         self.dropout = kwargs.get("dropout", 0.2)
+         # each token reads its embedding vector from a lookup table
+         self.token_embedding_table = nn.Embedding(self.vocab_size, self.num_embed)
+         # each position from 0 to block_size-1 gets its own embedding
+         self.position_embedding_table = nn.Embedding(self.block_size, self.num_embed)
+         self.blocks = nn.Sequential(
+             *[
+                 TransformerBlock(
+                     num_heads=self.num_heads,
+                     block_size=self.block_size,
+                     num_embed=self.num_embed,
+                     dropout=self.dropout,
+                 )
+                 for _ in range(self.num_layers)
+             ]
+         )
+         # we add the layer norm before the final Linear layer
+         self.ln_f = nn.LayerNorm(self.num_embed)
+         self.lm_head = nn.Linear(self.num_embed, self.vocab_size)
+
+     def forward(self, idx, targets=None):
+         B, T = idx.shape
+         # idx and targets are (B, T) tensors of integers
+         # the token_emb is (B, T, C), C = NUM_EMBED
+         token_emb = self.token_embedding_table(idx)
+         # (T, C)
+         posit_emb = self.position_embedding_table(torch.arange(T, device=DEVICE))
+
+         x = token_emb + posit_emb
+         # apply the stack of transformer blocks
+         x = self.blocks(x)
+         # (B, T, vocab_size)
+         logits = self.lm_head(x)
+         # compute the loss
+         if targets is not None:
+             # cross_entropy expects inputs of shape (batch_size, num_classes),
+             # so we need to reformat our logits dimensions to
+             # (batch_size * time, dim_vocabulary), time = block_size
+             B, T, C = logits.shape
+             logits = torch.reshape(logits, (B * T, C))
+             targets = torch.reshape(targets, (B * T,))
+             loss = F.cross_entropy(logits, targets)
+         else:
+             loss = None
+         return logits, loss
+
+     def generate(self, idx: torch.Tensor, max_new_tokens: int, block_size: int):
+         # idx is a (B, T) array of indices in the current context
+         for _ in range(max_new_tokens):
+             # crop the context to the last block_size tokens,
+             # because the position embedding table only covers block_size positions
+             idx_crop = idx[:, -block_size:]
+             # get the predictions
+             logits, loss = self.forward(idx_crop)
+             # focus only on the last time step
+             logits = logits[:, -1, :]  # becomes (B, C)
+             # apply softmax to get probabilities
+             probs = F.softmax(logits, dim=-1)  # (B, C)
+             # sample from the distribution with probabilities probs
+             idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
+             # append sampled index to the running sequence
+             idx = torch.cat((idx, idx_next), dim=1)  # (B, T+1)
+         return idx
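Not part of the upload, but for readers of the diff: a minimal sketch of how the Transformer above can be instantiated and sampled from. The hyperparameters and the single-token prompt are illustrative assumptions; an untrained model will of course emit random token ids.

import torch
from model import Transformer
from utils import DEVICE

# Illustrative sketch with assumed hyperparameters; token ids go in and come out.
block_size = 64
model = Transformer(
    vocab_size=100,      # assumed; must match the tokenizer in real use
    num_embed=128,
    block_size=block_size,
    num_heads=4,
    num_layers=2,
    dropout=0.2,
).to(DEVICE)
context = torch.zeros((1, 1), dtype=torch.long, device=DEVICE)  # (B=1, T=1) prompt
out = model.generate(context, max_new_tokens=20, block_size=block_size)
print(out.shape)  # torch.Size([1, 21])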
requirements.txt ADDED
@@ -0,0 +1,4 @@
+ # python>=3.9 is the recommended interpreter version (Python itself cannot be installed via pip)
+ torch>=1.13.1
+ transformers>=4.25.1
+ numpy
utils.py ADDED
@@ -0,0 +1,121 @@
+ import os
+ import torch
+ from datetime import datetime
+
+ # hyperparameters
+ BATCH_SIZE = 32  # how many independent sequences will we process in parallel?
+ BLOCK_SIZE = 64  # what is the maximum context length for predictions?
+ MAX_ITER = 500  # number of training iterations
+ EVAL_INTER = 1
+ LEARNING_RATE = 3e-4
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+ NUM_HEAD = 6
+ NUM_EMBED = NUM_HEAD * 128
+ NUM_LAYER = 6
+ DROPOUT = 0.2
+
+ def encode(text_seq: str, tokenizer) -> torch.Tensor:
+     """
+     Encode the input text with a pre-trained tokenizer and return a tensor of token ids
+     """
+     # tokenize the input text
+     tokens = tokenizer.tokenize(text_seq)
+     # convert the tokens to their corresponding ids
+     token_indices = tokenizer.convert_tokens_to_ids(tokens)
+     token_indices = torch.tensor(token_indices, dtype=torch.long)
+     return token_indices
+
+
+ def decode(enc_sec: torch.Tensor, tokenizer) -> str:
+     """
+     Decode a sequence of token indices back to a string
+     """
+     # convert the indices to a list
+     enc_sec = enc_sec.tolist()
+     # decode the indices to a string
+     text = tokenizer.decode(enc_sec)
+     return text
+
+
+ def get_batch(data: torch.Tensor, block_size: int, batch_size: int):
+     """
+     This is a simple function to create batches of data.
+     GPUs allow for parallel processing, so we can feed multiple chunks at once;
+     that is why we need batches - they define how many independent sequences
+     we process in parallel.
+
+     Parameters:
+         data (torch.Tensor): 1-D tensor of token ids to sample batches from
+         block_size (int): size of the text chunk that is processed at once
+         batch_size (int): number of sequences to process in parallel
+
+     Returns:
+         x, y: a tuple with the token sequences and the shifted token targets
+     """
+     ix = torch.randint(len(data) - block_size, (batch_size,))
+     # we stack batch_size rows of sequences,
+     # so x and y are matrices with rows_num=batch_size
+     # and col_num=block_size
+     x = torch.stack([data[i : i + block_size] for i in ix])
+     # y is x shifted one position right - because we predict
+     # the next token in y having all the previous tokens as context
+     y = torch.stack([data[i + 1 : i + block_size + 1] for i in ix])
+     x, y = x.to(DEVICE), y.to(DEVICE)
+     return x, y
+
+
+ @torch.no_grad()
+ def estimate_loss(
+     data: torch.Tensor,
+     model: torch.nn.Module,
+     block_size: int,
+     batch_size: int,
+     eval_iters: int = 10,
+ ):
+     model.eval()
+     losses = torch.zeros(eval_iters)
+     for k in range(eval_iters):
+         X, Y = get_batch(data=data, block_size=block_size, batch_size=batch_size)
+         logits, loss = model.forward(X, Y)
+         losses[k] = loss.item()
+     out = losses.mean()
+     model.train()
+     return out
+
+
+ def load_model_from_checkpoint(
+     model_class: torch.nn.Module,
+     path_to_checkpoint: str = "checkpoints/state_dict_model.pt",
+     **kwargs: dict,
+ ) -> torch.nn.Module:
+     try:
+         state_dict = torch.load(path_to_checkpoint, map_location=DEVICE)
+         print("Successfully loaded model from the checkpoint")
+     except Exception as e:
+         print(f"Error loading the model from the checkpoint. {e}")
+         raise
+
+     model = model_class(**kwargs)
+     # load the state_dict into the model
+     model.load_state_dict(state_dict)
+     return model
+
+
+ def save_model_to_chekpoint(
+     model: torch.nn.Module, path_to_checkpoint: str = "checkpoints", epoch: int = 0
+ ):
+     # check if the path exists, otherwise create it
+     if not os.path.exists(path_to_checkpoint):
+         os.makedirs(path_to_checkpoint)
+
+     # datetime object containing the current date and time
+     now = datetime.now()
+     # dd.mm.YYYY_HH-MM-SS (':' is avoided so the name is a valid filename on Windows)
+     dt_string = now.strftime("%d.%m.%Y_%H-%M-%S")
+     checkpoint_name = "checkpoint_epoch-" + str(epoch) + "_" + dt_string + ".pt"
+     full_path = os.path.join(path_to_checkpoint, checkpoint_name)
+     try:
+         torch.save(model.state_dict(), full_path)
+         print("Successfully saved the model to {}".format(full_path))
+     except Exception as e:
+         print(f"Error saving the model to checkpoint. {e}")