from typing import Optional

import torch
import torch.nn as nn

from .config import LlamaConfig
from .rms_norm import LlamaRMSNorm
from .decoder import LlamaDecoderLayer

class LlamaModel(nn.Module):
    """Bare Llama backbone: token embedding, a stack of decoder layers, and a final RMSNorm."""

    def __init__(self, config: LlamaConfig):
        super().__init__()
        self.config = config
        # Token embedding table; padding_idx=None means no row is zeroed or frozen.
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=None)
        # One decoder layer per hidden layer; each layer receives its index.
        self.layers = nn.ModuleList(
            [LlamaDecoderLayer(config, i) for i in range(config.num_hidden_layers)]
        )
        # Final normalization applied after the last decoder layer.
        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:
        # Embed token ids, then run the result through every decoder layer in order.
        hidden_states = self.embed_tokens(input_ids)
        for decoder_layer in self.layers:
            hidden_states = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
            )
        # Normalize the final hidden states before returning them.
        hidden_states = self.norm(hidden_states)
        return hidden_states
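

# Minimal usage sketch, not part of the original module. Assumptions: LlamaConfig
# accepts the keyword arguments shown (the class above only reads these fields as
# attributes, and a real config likely needs more, e.g. attention-head counts),
# and LlamaDecoderLayer tolerates attention_mask=None / position_ids=None by
# building its own causal mask and positions.
if __name__ == "__main__":
    config = LlamaConfig(
        vocab_size=32000,
        hidden_size=256,
        num_hidden_layers=2,
        rms_norm_eps=1e-6,
    )
    model = LlamaModel(config)
    input_ids = torch.randint(0, config.vocab_size, (1, 8))  # (batch, seq_len)
    hidden_states = model(input_ids)
    print(hidden_states.shape)  # expected: torch.Size([1, 8, 256])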