from typing import Optional

import torch
import torch.nn as nn

from .config import LlamaConfig
from .decoder import LlamaDecoderLayer
from .rms_norm import LlamaRMSNorm


class LlamaModel(nn.Module):
    """Llama decoder-only backbone: token embeddings, a stack of decoder layers, and a final RMSNorm."""

    def __init__(self, config: LlamaConfig):
        super().__init__()
        self.config = config
        # Token embedding table; no padding index is reserved.
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=None)
        # Decoder stack; each layer receives its index within the model.
        self.layers = nn.ModuleList(
            [LlamaDecoderLayer(config, i) for i in range(config.num_hidden_layers)]
        )
        # Final normalization applied to the output of the last decoder layer.
        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:
        # Embed token ids: (batch, seq_len) -> (batch, seq_len, hidden_size).
        hidden_states = self.embed_tokens(input_ids)

        # Pass the hidden states through each decoder layer in order.
        for decoder_layer in self.layers:
            hidden_states = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
            )

        # Normalize the final hidden states before returning them.
        hidden_states = self.norm(hidden_states)
        return hidden_states
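

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch, not part of the model definition. It assumes
# LlamaConfig accepts the keyword arguments below; only vocab_size,
# hidden_size, num_hidden_layers, and rms_norm_eps are referenced in this
# file. The remaining fields (num_attention_heads, intermediate_size,
# max_position_embeddings) are hypothetical values the decoder layers may
# require and should be adjusted to match the actual LlamaConfig definition.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    config = LlamaConfig(
        vocab_size=1000,
        hidden_size=64,
        num_hidden_layers=2,
        num_attention_heads=4,        # assumed field used by the attention blocks
        intermediate_size=256,        # assumed field used by the MLP blocks
        max_position_embeddings=128,  # assumed field used by positional embeddings
        rms_norm_eps=1e-6,
    )
    model = LlamaModel(config)

    batch_size, seq_len = 2, 16
    input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len))
    position_ids = torch.arange(seq_len).unsqueeze(0).expand(batch_size, -1)

    with torch.no_grad():
        hidden_states = model(input_ids, position_ids=position_ids)

    # Expected shape: (batch_size, seq_len, hidden_size).
    print(hidden_states.shape)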