from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import repeat

from .config import LlamaConfig
from .model import LlamaModel


class LlamaForCausalLM(nn.Module):
    def __init__(self, config: LlamaConfig):
        super().__init__()
        self.model = LlamaModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Weight tying: the output projection shares its matrix with the
        # input embedding, so the two are always updated together.
        if config.tie_word_embeddings:
            self.lm_head.weight = self.model.embed_tokens.weight

        self._init_weights()

    def _init_weights(self):
        """Initialize weights for all layers."""
        # Linear layers: Xavier-uniform weights, zero biases.
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

        # Initialize the embedding after the Linear pass so the tied lm_head
        # weight does not overwrite it. std is 1/24 (0.0416...), which equals
        # 1/sqrt(hidden_size) when hidden_size == 576.
        if hasattr(self.model, 'embed_tokens'):
            nn.init.normal_(self.model.embed_tokens.weight, mean=0.0, std=0.041666666666666664)

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        labels: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the final hidden states and the output-projection weights.

        Logits are not materialized here; callers project only the positions
        they need (generate uses just the last position). `labels` is not
        consumed in this method; the compute_loss sketch below shows one way
        it could be used.
        """
        hidden_states = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
        )
        return hidden_states, self.lm_head.weight
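
    # --- Illustrative sketch (not part of the original class) ---------------
    # One way the `labels` argument of forward() could be consumed, given that
    # forward() returns hidden states plus the projection weights rather than
    # logits. The shifted next-token cross-entropy with ignore_index=-100 is
    # the usual causal-LM convention and is an assumption here, not something
    # this file specifies.
    def compute_loss(
        self,
        hidden_states: torch.Tensor,
        classifier_weights: torch.Tensor,
        labels: torch.LongTensor,
    ) -> torch.Tensor:
        logits = hidden_states @ classifier_weights.T      # (bsz, seq, vocab)
        shift_logits = logits[:, :-1, :].contiguous()      # predict token t+1 from position t
        shift_labels = labels[:, 1:].contiguous()
        return F.cross_entropy(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
            ignore_index=-100,
        )
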
    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.LongTensor,
        max_new_tokens: int = 30,
        temperature: float = 0.0,
    ) -> torch.LongTensor:
        """Autoregressive decoding: greedy when temperature == 0, sampled otherwise."""
        self.eval()
        bsz, seq_len = input_ids.shape

        position_ids = repeat(
            torch.arange(seq_len, device=input_ids.device),
            'l -> b l',
            b=bsz,
        )

        for _ in range(max_new_tokens):
            # No KV cache: the full sequence is re-encoded on every step.
            hidden_states, classifier_weights = self.forward(input_ids, position_ids=position_ids)

            # Project only the last position to logits over the vocabulary.
            next_token_logits = hidden_states[:, -1] @ classifier_weights.T

            if temperature == 0:
                next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
            else:
                scaled_logits = next_token_logits / temperature
                probs = torch.softmax(scaled_logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)

            # Append the new token and extend the position ids by one step.
            input_ids = torch.cat([input_ids, next_token], dim=1)
            new_position_ids = position_ids[:, -1:] + 1
            position_ids = torch.cat([position_ids, new_position_ids], dim=1)

        return input_ids
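

# --- Illustrative usage sketch (not part of the original file) ---------------
# Assumes `config` is an already-built LlamaConfig and `prompt_ids` is a
# (batch, seq_len) LongTensor of token ids; tokenization and the exact set of
# LlamaConfig fields live outside this file and are not assumed here.
def _demo(config: LlamaConfig, prompt_ids: torch.LongTensor) -> torch.LongTensor:
    model = LlamaForCausalLM(config)
    # Training-style call: full forward pass plus the loss sketch defined above.
    hidden_states, classifier_weights = model(prompt_ids)
    loss = model.compute_loss(hidden_states, classifier_weights, labels=prompt_ids)
    print(f"untrained cross-entropy: {loss.item():.3f}")
    # Inference-style call: temperature=0.0 selects greedy decoding; > 0 samples.
    return model.generate(prompt_ids, max_new_tokens=20, temperature=0.0)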