# Importing libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel
from .configuration_gpt import GPTConfig


class GPT(nn.Module):
    """

    The GPT language model:

    - Embeddings (token + positional)

    - Stack of Transformer blocks

    - Final LayerNorm + Linear head for output logits

    """

    def __init__(
        self,
        block_size: int = 1024,
        vocab_size: int = 50304,
        n_layer: int = 12,
        n_head: int = 12,
        n_embd: int = 768,
    ):
        super().__init__()

        # Store model hyperparameters
        self.block_size = block_size
        self.vocab_size = vocab_size
        self.n_layer = n_layer
        self.n_head = n_head
        self.n_embd = n_embd

        # Transformer components stored in a module dictionary
        self.transformer = nn.ModuleDict(
            dict(
                wte=nn.Embedding(self.vocab_size, self.n_embd),  # Token embedding
                wpe=nn.Embedding(self.block_size, self.n_embd),  # Positional embedding
                h=nn.ModuleList(
                    [self.Block(self.n_embd, self.n_head) for _ in range(self.n_layer)]
                ),  # Transformer blocks
                ln_f=nn.LayerNorm(self.n_embd),  # Final layer normalization
            )
        )

        # Linear head for output logits
        self.lm_head = nn.Linear(self.n_embd, self.vocab_size, bias=False)

        # Tie weights between token embedding and output projection
        self.transformer.wte.weight = self.lm_head.weight

    def forward(self, x):
        B, T = x.shape  # Batch size and sequence length
        assert T <= self.block_size, "Cannot forward sequence longer than block size"

        # Token and positional embeddings
        tok_emb = self.transformer.wte(x)
        pos_emb = self.transformer.wpe(torch.arange(T, device=x.device))
        x = tok_emb + pos_emb.unsqueeze(0)

        # Forward pass through transformer blocks
        for block in self.transformer.h:
            x = block(x)

        x = self.transformer.ln_f(x)  # Final layer norm
        logits = self.lm_head(x)  # Compute logits
        return logits

    class CausalSelfAttention(nn.Module):
        """

        Multi-head self-attention with causal masking.

        """

        def __init__(self, n_embd, n_head):
            super().__init__()
            assert (
                n_embd % n_head == 0
            ), "Embedding dimension must be divisible by number of heads"
            self.n_head = n_head
            self.n_embd = n_embd

            # Linear layers for query, key, and value
            self.c_attn = nn.Linear(n_embd, 3 * n_embd)
            self.c_proj = nn.Linear(n_embd, n_embd)

        def forward(self, x):
            B, T, C = x.size()
            qkv = self.c_attn(x)
            q, k, v = qkv.split(self.n_embd, dim=2)

            # Reshape and transpose for multi-head attention
            k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
            q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
            v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
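            # q, k, v now have shape (B, n_head, T, head_dim)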

            # Apply scaled dot-product attention with causal masking
            y = F.scaled_dot_product_attention(q, k, v, is_causal=True)

            # Reshape and apply output projection
            y = y.transpose(1, 2).contiguous().view(B, T, C)
            y = self.c_proj(y)
            return y

    class MLP(nn.Module):
        """

        Feed-forward network block used in Transformer architectures.

        """

        def __init__(self, n_embd):
            super().__init__()
            self.c_fc = nn.Linear(n_embd, 4 * n_embd)
            self.gelu = nn.GELU(approximate="tanh")
            self.c_proj = nn.Linear(4 * n_embd, n_embd)

        def forward(self, x):
            return self.c_proj(self.gelu(self.c_fc(x)))

    class Block(nn.Module):
        """

        A single Transformer block.

        """

        def __init__(self, n_embd, n_head):
            super().__init__()
            self.ln_1 = nn.LayerNorm(n_embd)
            self.attn = GPT.CausalSelfAttention(n_embd, n_head)
            self.ln_2 = nn.LayerNorm(n_embd)
            self.mlp = GPT.MLP(n_embd)

        def forward(self, x):
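            # Pre-norm residual connections: LayerNorm is applied before each sublayer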
            x = x + self.attn(self.ln_1(x))
            x = x + self.mlp(self.ln_2(x))
            return x


class GPTModelForTextGeneration(PreTrainedModel):
    """

    A wrapper class for GPT-based text generation.

    This integrates a Transformer model within the Hugging Face `PreTrainedModel` framework.

    """

    config_class = GPTConfig

    def __init__(self, config):
        super().__init__(config)

        # Instantiate the GPT model with the provided configuration
        self.model = GPT(
            block_size=config.block_size,
            vocab_size=config.vocab_size,
            n_layer=config.n_layer,
            n_head=config.n_head,
            n_embd=config.n_embd,
        )

    def forward(self, input_ids: torch.Tensor):
        # Check input_ids type and shape
        assert isinstance(input_ids, torch.Tensor), "input_ids must be a PyTorch tensor"

        tokens = input_ids.clone()  # Avoid modifying input_ids directly
        tokens = tokens.unsqueeze(0) if tokens.dim() == 1 else tokens

        assert (
            tokens.ndim == 2 and tokens.shape[0] == 1
        ), "input_ids must have 2 dimensions: (1, sequence_length)"

        # Check token values
        assert torch.all(
            (tokens >= 0) & (tokens < self.model.vocab_size)
        ), "input_ids contain invalid token values"

        # Forward pass through the model (use __call__ so module hooks run)
        logits = self.model(tokens)

        return {"logits": logits}

    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.Tensor,
        max_length: int = 50,
        do_sample: bool = True,
        top_k: int = 50,
        top_p: float = 0.95,
        temperature: float = 0.9,
        device: str = "cpu",
    ):
        """

        Generates text using autoregressive sampling with top-k, top-p, and temperature.

        """

        # Validate device type
        if device.startswith("cuda"):
            assert torch.cuda.is_available(), "CUDA is not available, please use 'cpu'"
            if device != "cuda":  # Check for specific CUDA device (cuda:n)
                try:
                    device_index = int(device.split(":")[1])  # Extract device number
                    assert (
                        0 <= device_index < torch.cuda.device_count()
                    ), f"Invalid CUDA device index: {device_index}"
                except (IndexError, ValueError):
                    raise ValueError(
                        "Invalid device format. Use 'cpu', 'cuda', or 'cuda:N' where N is an integer."
                    )
        elif device != "cpu":
            raise ValueError("Invalid device. Use 'cpu', 'cuda', or 'cuda:N'.")

        # Move input tensor and model to the specified device
        input_ids = input_ids.to(device)
        self.model.to(device)

        # Check input_ids type and shape
        assert isinstance(input_ids, torch.Tensor), "input_ids must be a PyTorch tensor"
        tokens = input_ids.clone()  # Avoid modifying input_ids directly
        tokens = tokens.unsqueeze(0) if tokens.dim() == 1 else tokens

        assert (
            tokens.ndim == 2 and tokens.shape[0] == 1
        ), "input_ids must have 2 dimensions: (1, sequence_length)"

        # Check token values
        assert torch.all(
            (tokens >= 0) & (tokens < self.model.vocab_size)
        ), "input_ids contain invalid token values"

        # Check max_length
        assert (
            isinstance(max_length, int) and max_length >= 1
        ), "max_length must be a positive integer"
        assert (
            max_length <= self.model.block_size
        ), f"max_length must be in range [1, {self.model.block_size}]"

        # Check top_k
        assert isinstance(top_k, int) and top_k >= 1, "top_k must be a positive integer"

        # Check top_p
        assert (
            isinstance(top_p, (int, float)) and 0.0 <= top_p <= 1.0
        ), "top_p must be in range [0, 1]"

        # Check temperature
        assert (
            isinstance(temperature, (int, float)) and 0.0 <= temperature <= 1.0
        ), "temperature must be in range [0, 1]"

        # Move tokens to the correct device
        tokens = tokens.to(device)

        # Autoregressive token generation loop
        while tokens.size(1) < max_length:
            logits = self(tokens)["logits"][:, -1, :]
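            # Clamp temperature away from zero so the division stays numerically safe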
            logits = logits / max(0.01, temperature)

            if do_sample:
                top_k = min(top_k, logits.size(-1))  # Safety check

                # Remove all tokens with a probability less than the last token of the top-k
                indices_to_remove = (
                    logits < torch.topk(logits, top_k, dim=1)[0][..., -1, None]
                )
                logits[indices_to_remove] = float("-inf")

                sorted_logits, sorted_indices = torch.sort(logits, descending=True)

                cumulative_probs = torch.cumsum(
                    F.softmax(sorted_logits, dim=-1), dim=-1
                )
                # Remove tokens with cumulative probability above the threshold
                sorted_indices_to_remove = cumulative_probs > top_p
                # Shift the indices to the right to keep also the first token above the threshold
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
                    ..., :-1
                ].clone()
                sorted_indices_to_remove[..., 0] = 0

                # Replace logits to be removed with -inf in the sorted_logits
                sorted_logits[sorted_indices_to_remove] = float("-inf")
                # Then reverse the sorting process by mapping back sorted_logits to their original position
                logits = torch.gather(sorted_logits, 1, sorted_indices.argsort(-1))

                # Sample the next token from the filtered distribution
                next_tokens = torch.multinomial(F.softmax(logits, dim=-1), 1)
            else:
                next_tokens = torch.argmax(logits, dim=-1, keepdim=True)

            tokens = torch.cat((tokens, next_tokens), dim=1)

        return tokens.flatten()
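

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the module API). It
# assumes this file lives in a package next to configuration_gpt.py, so it
# must be run as `python -m <package>.<module>`; the GPTConfig keyword
# arguments below are an assumption about that class, not a verified API.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)

    # A deliberately tiny model so the smoke test runs quickly on CPU.
    model = GPT(block_size=64, vocab_size=256, n_layer=2, n_head=2, n_embd=32)
    model.eval()

    # Random prompt of 8 token ids; GPT.forward returns (B, T, vocab_size) logits.
    prompt = torch.randint(0, 256, (1, 8))
    with torch.no_grad():
        logits = model(prompt)
    print(logits.shape)  # torch.Size([1, 8, 256])

    # Hypothetical: assumes GPTConfig forwards these kwargs to same-named fields.
    config = GPTConfig(block_size=64, vocab_size=256, n_layer=2, n_head=2, n_embd=32)
    lm = GPTModelForTextGeneration(config)
    generated = lm.generate(
        prompt.flatten(),  # generate() also accepts a 1-D prompt
        max_length=16,
        do_sample=True,
        top_k=10,
        top_p=0.9,
        temperature=0.8,
    )
    print(generated.shape)  # 1-D tensor of 16 token ids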