"""PyTorch INFLM model.""" |
|
import torch
from torch import nn

from transformers.models.llama.modeling_llama import (
    LlamaDecoderLayer,
    LlamaModel,
    LlamaForCausalLM,
)

from .configuration_inflm import INFLMConfig
|
_CONFIG_FOR_DOC = "INFLMConfig"


class INFLMDecoderLayer(LlamaDecoderLayer):
    """Llama decoder layer with the RMSNorm sublayers replaced by standard ``nn.LayerNorm``."""

    def __init__(self, config: INFLMConfig, layer_idx: int):
        super().__init__(config, layer_idx)
        # Swap the RMSNorm modules created by LlamaDecoderLayer for LayerNorm.
        self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)


class INFLMModel(LlamaModel):
    """Transformer decoder consisting of ``config.num_hidden_layers`` ``INFLMDecoderLayer`` blocks."""

    config_class = INFLMConfig
    _no_split_modules = ["INFLMDecoderLayer"]

    def __init__(self, config: INFLMConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [INFLMDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        # Final norm is LayerNorm rather than the RMSNorm used by LlamaModel.
        self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        self.gradient_checkpointing = False

        # Initialize weights and apply final processing.
        self.post_init()


class INFLMForCausalLM(LlamaForCausalLM):
    """INFLM model with a causal language-modeling head on top."""

    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config: INFLMConfig):
        super().__init__(config)
        self.model = INFLMModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing.
        self.post_init()
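

# ---------------------------------------------------------------------------
# Minimal usage sketch (kept as a comment so importing this module stays
# side-effect free). The config keyword arguments below are assumptions based
# on the fields this file reads from INFLMConfig (vocab_size, hidden_size,
# num_hidden_layers, layer_norm_eps, pad_token_id) plus Llama-style defaults;
# check the actual INFLMConfig signature before running.
#
#     config = INFLMConfig(
#         vocab_size=32000,
#         hidden_size=512,
#         num_hidden_layers=2,
#         layer_norm_eps=1e-5,
#         pad_token_id=0,
#     )
#     model = INFLMForCausalLM(config)
#     input_ids = torch.randint(0, config.vocab_size, (1, 16))
#     outputs = model(input_ids=input_ids, labels=input_ids)
#     print(outputs.loss, outputs.logits.shape)
# ---------------------------------------------------------------------------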