from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import repeat
from .config import LlamaConfig
from .model import LlamaModel


class LlamaForCausalLM(nn.Module):
    def __init__(self, config: LlamaConfig):
        super().__init__()
        self.model = LlamaModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Weight tying: the input token embedding and the output classifier
        # (lm_head) share a single weight matrix.
        if config.tie_word_embeddings:
            self.lm_head.weight = self.model.embed_tokens.weight
        self._init_weights()

    def _init_weights(self):
        """Initialize weights for all layers."""
        # Initialize embeddings
        if hasattr(self.model, 'embed_tokens'):
            nn.init.normal_(self.model.embed_tokens.weight, mean=0.0, std=0.041666666666666664)
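        # Note: the std above equals 1/24 up to float precision. When
        # tie_word_embeddings is set, lm_head shares this tensor, so the
        # Xavier pass below re-initializes it again.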
        # Initialize linear layers
        for module in self.modules():
            if isinstance(module, nn.Linear):
                # Xavier/Glorot initialization for weights
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    # Zero initialization for biases
                    nn.init.zeros_(module.bias)

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        labels: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
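        # labels is accepted in the signature but unused: no loss is computed
        # here. The model returns the final hidden states together with the
        # (possibly tied) lm_head weight; callers project to logits themselves,
        # as generate() does below.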
        hidden_states = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
        )
        return hidden_states, self.lm_head.weight

    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.LongTensor,
        max_new_tokens: int = 30,
        temperature: float = 0.0,
    ) -> torch.LongTensor:
        self.eval()
        bsz, seq_len = input_ids.shape
        position_ids = repeat(
            torch.arange(seq_len, device=input_ids.device),
            'l -> b l',
            b=bsz
        )
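        # Each step re-encodes the full sequence (no KV cache). The last
        # position's hidden state is projected with the classifier weights;
        # decoding is greedy at temperature == 0, otherwise the next token is
        # sampled from the temperature-scaled softmax.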
        for _ in range(max_new_tokens):
            hidden_states, classifier_weights = self.forward(input_ids, position_ids=position_ids)
            # Get logits by computing hidden_states @ classifier_weights.T
            next_token_logits = hidden_states[:, -1] @ classifier_weights.T
            if temperature == 0:
                next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
            else:
                scaled_logits = next_token_logits / temperature
                probs = torch.softmax(scaled_logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
            input_ids = torch.cat([input_ids, next_token], dim=1)
            new_position_ids = position_ids[:, -1:] + 1
            position_ids = torch.cat([position_ids, new_position_ids], dim=1)
        return input_ids
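

# Usage sketch (commented out; assumes a populated LlamaConfig instance named
# `config` and integer token ids from some tokenizer, both hypothetical here):
#
#   model = LlamaForCausalLM(config)
#   prompt_ids = torch.tensor([[1, 42, 7]])  # (batch=1, seq_len=3) token ids
#   out = model.generate(prompt_ids, max_new_tokens=10, temperature=0.0)
#   # out has shape (1, 3 + 10): the prompt followed by 10 greedily decoded tokens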