from typing import Optional

import torch
import torch.nn as nn

from .config import LlamaConfig
from .decoder import LlamaDecoderLayer
from .rms_norm import LlamaRMSNorm


class LlamaModel(nn.Module):
    """Stack of Llama decoder layers with token embeddings and a final RMSNorm."""

    def __init__(self, config: LlamaConfig):
        super().__init__()
        self.config = config
        # Token embedding table mapping vocabulary ids to hidden-size vectors.
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=None)
        # One decoder layer per transformer block; each layer receives its index.
        self.layers = nn.ModuleList(
            [LlamaDecoderLayer(config, i) for i in range(config.num_hidden_layers)]
        )
        # Final normalization applied after the last decoder layer.
        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:
        # Embed the input token ids into hidden states.
        hidden_states = self.embed_tokens(input_ids)

        # Pass the hidden states through each decoder layer in sequence.
        for decoder_layer in self.layers:
            hidden_states = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
            )

        # Apply the final RMSNorm and return the last hidden states.
        hidden_states = self.norm(hidden_states)
        return hidden_states
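

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, run via `python -m <package>.<module>`).
# It assumes LlamaConfig accepts the keyword fields referenced above
# (vocab_size, hidden_size, num_hidden_layers, rms_norm_eps); the actual config
# in this repo may require additional fields, e.g. the attention head count
# consumed by LlamaDecoderLayer, and the example values are placeholders.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    config = LlamaConfig(
        vocab_size=32000,      # assumed placeholder value
        hidden_size=256,       # assumed placeholder value
        num_hidden_layers=2,   # assumed placeholder value
        rms_norm_eps=1e-6,     # assumed placeholder value
    )
    model = LlamaModel(config)
    # Batch of 1 sequence with 8 random token ids.
    input_ids = torch.randint(0, config.vocab_size, (1, 8))
    hidden_states = model(input_ids)
    # Expected output shape: (1, 8, config.hidden_size)
    print(hidden_states.shape)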