# translation-en-fr/model.py
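"""Encoder-decoder Transformer for English-to-French machine translation, deployed for inference on Hugging Face Spaces."""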
from dataclasses import dataclass
import torch
from torch import nn
import torch.nn.functional as F
import os
from utils import get_file_FROM_HF
from safetensors.torch import load_file
@dataclass
class TransformerConfig:
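    """Hyperparameters of the translation Transformer (vocabulary sizes, model width, layer counts, dropout, device)."""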
src_vocab_size: int = 32000
tgt_vocab_size: int = 32000
max_seq_length: int = 64
d_model: int = 512
num_heads: int = 8
num_encoder_layers: int = 6
num_decoder_layers: int = 6
dropout_p: float = 0.1
dff: int = 2048
device: str = 'cpu'
# Source Embedding block
class SourceEmbedding(nn.Module):
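    """Token-embedding lookup for source (English) token ids."""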
def __init__(self, config: TransformerConfig):
super().__init__()
self.src_embedding = nn.Embedding(num_embeddings=config.src_vocab_size, embedding_dim=config.d_model)
def forward(self, x):
x = self.src_embedding(x)
return x
# Target Embedding block
class TargetEmbedding(nn.Module):
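    """Token-embedding lookup for target (French) token ids."""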
def __init__(self, config: TransformerConfig):
super().__init__()
self.tgt_embedding = nn.Embedding(num_embeddings=config.tgt_vocab_size, embedding_dim=config.d_model)
def forward(self, x):
x = self.tgt_embedding(x)
return x
# Position Encoding (PE)
class PositionEncoding(nn.Module):
def __init__(self, config: TransformerConfig, require_grad=False):
super().__init__()
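        # precompute the sinusoidal position-encoding table of shape (max_seq_length, d_model)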
        pe = torch.zeros(config.max_seq_length, config.d_model)
        pos = torch.arange(0, config.max_seq_length).reshape(-1, 1)
        i = torch.arange(0, config.d_model, step=2)
        denominator = torch.pow(10000, (2 * i) / config.d_model)
        pe[:, 0::2] = torch.sin(pos / denominator)
        pe[:, 1::2] = torch.cos(pos / denominator)
        self.PE = nn.Parameter(pe, requires_grad=require_grad)
def forward(self, x):
        seq_length = x.shape[1]
        return x + self.PE[:seq_length]
# Multi-Head Attention block (used for encoder self-attention, masked decoder self-attention and encoder-decoder cross-attention)
class MultiheadAttention(nn.Module):
def __init__(self, config:TransformerConfig):
super().__init__()
self.config = config
        # d_model must be divisible by num_heads to obtain the per-head dimension
        assert config.d_model % self.config.num_heads == 0, "d_model is not divisible by num_heads"
self.head_dim = self.config.d_model // self.config.num_heads
self.q_proj = nn.Linear(in_features=self.config.d_model, out_features=self.config.d_model)
self.k_proj = nn.Linear(in_features=self.config.d_model, out_features=self.config.d_model)
self.v_proj = nn.Linear(in_features=self.config.d_model, out_features=self.config.d_model)
self.out_proj = nn.Linear(in_features=self.config.d_model, out_features=self.config.d_model)
def forward(self, src, tgt=None, attention_mask=None, causal=False):
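        """Self-attention over `src` when `tgt` is None, otherwise cross-attention with queries taken from `tgt`.

        `attention_mask` is a (batch, seq_length) padding mask, truthy where attention is allowed;
        `causal` additionally restricts each position to attend only to itself and earlier positions.
        """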
batch, src_seq_length, d_model = src.shape
if tgt is None:
q = self.q_proj(src).reshape(batch, src_seq_length, self.config.num_heads, self.head_dim).transpose(1,2).contiguous()
k = self.k_proj(src).reshape(batch, src_seq_length, self.config.num_heads, self.head_dim).transpose(1,2).contiguous()
v = self.v_proj(src).reshape(batch, src_seq_length, self.config.num_heads, self.head_dim).transpose(1,2).contiguous()
            # SELF-ATTENTION (also used as masked self-attention in the decoder)
            if attention_mask is not None:
                # padding mask: (batch, seq) -> (batch, 1, seq, seq); True marks positions that may be attended to
                attention_mask = attention_mask.bool()
                attention_mask = attention_mask.unsqueeze(1).unsqueeze(1).repeat(1, 1, src_seq_length, 1).to(self.config.device)
            if causal:
                # causal mask: position i may only attend to positions <= i
                causal_mask = ~torch.triu(torch.ones((src_seq_length, src_seq_length), dtype=torch.bool), diagonal=1)
                causal_mask = causal_mask.unsqueeze(0).unsqueeze(0).to(self.config.device)
                if attention_mask is not None:
                    # combine the padding mask with the causal mask
                    attention_mask = (causal_mask & attention_mask).to(self.config.device)
                else:
                    attention_mask = causal_mask
            attention_out = F.scaled_dot_product_attention(q, k, v,
                                                           attn_mask=attention_mask,
                                                           dropout_p=self.config.dropout_p if self.training else 0.0,
                                                           is_causal=False)
# CROSS ATTENTION
else:
tgt_seq_length = tgt.shape[1]
q = self.q_proj(tgt).reshape(batch, tgt_seq_length, self.config.num_heads, self.head_dim).transpose(1,2).contiguous()
k = self.k_proj(src).reshape(batch, src_seq_length, self.config.num_heads, self.head_dim).transpose(1,2).contiguous()
v = self.v_proj(src).reshape(batch, src_seq_length, self.config.num_heads, self.head_dim).transpose(1,2).contiguous()
            if attention_mask is not None:
                # source padding mask broadcast over the target (query) positions
                attention_mask = attention_mask.bool()
                attention_mask = attention_mask.unsqueeze(1).unsqueeze(1).repeat(1, 1, tgt_seq_length, 1).to(self.config.device)
attention_out = F.scaled_dot_product_attention(q,k,v,
attn_mask=attention_mask,
dropout_p=self.config.dropout_p if self.training else 0.0,
is_causal=False)
attention_out = attention_out.transpose(1,2).flatten(2)
attention_out = self.out_proj(attention_out)
return attention_out
# Position Wise Feed Forward Network (MLP)
class FeedForward(nn.Module):
def __init__(self, config: TransformerConfig):
super().__init__()
        self.hidden_layer = nn.Linear(in_features=config.d_model, out_features=config.dff)  # e.g. 512 -> 2048
        self.hidden_dropout = nn.Dropout(p=config.dropout_p)
        self.output_layer = nn.Linear(in_features=config.dff, out_features=config.d_model)  # e.g. 2048 -> 512
self.output_dropout = nn.Dropout(p=config.dropout_p)
def forward(self, x):
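        # expand d_model -> dff, apply GELU and dropout, then project back dff -> d_model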
x = self.hidden_layer(x)
x = F.gelu(x)
x = self.hidden_dropout(x)
x = self.output_layer(x)
x = self.output_dropout(x)
return x
# Encoder block
class EncoderBlock(nn.Module):
def __init__(self, config: TransformerConfig):
super().__init__()
self.multi_head_attention = MultiheadAttention(config=config)
self.feed_forward = FeedForward(config=config)
self.layer_norm_1 = nn.LayerNorm(config.d_model)
self.layer_norm_2 = nn.LayerNorm(config.d_model)
self.dropout = nn.Dropout(config.dropout_p)
def forward(self, x, attention_mask=None):
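        # post-norm residual layout: x + sublayer(x), followed by LayerNorm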
x = x + self.dropout(self.multi_head_attention(src=x, attention_mask=attention_mask))
x = self.layer_norm_1(x)
x = x + self.feed_forward(x)
x = self.layer_norm_2(x)
return x
# Decoder block
class DecoderBlock(nn.Module):
def __init__(self, config: TransformerConfig):
super().__init__()
self.masked_multi_head_attention = MultiheadAttention(config=config)
self.dropout_masked = nn.Dropout(config.dropout_p)
self.cross_multi_head_attention = MultiheadAttention(config=config)
self.dropout_cross = nn.Dropout(config.dropout_p)
self.feed_forward = FeedForward(config=config)
self.layer_norm_1 = nn.LayerNorm(config.d_model)
self.layer_norm_2 = nn.LayerNorm(config.d_model)
self.layer_norm_3 = nn.LayerNorm(config.d_model)
    def forward(self, src, tgt, src_attention_mask=None, tgt_attention_mask=None):
        # masked self-attention over the target, with residual connection and post-norm
        tgt = tgt + self.dropout_masked(self.masked_multi_head_attention(tgt, attention_mask=tgt_attention_mask, causal=True))
        tgt = self.layer_norm_1(tgt)
        # cross-attention: queries from the target, keys/values from the encoder output
        tgt = tgt + self.dropout_cross(self.cross_multi_head_attention(src, tgt, attention_mask=src_attention_mask))
        tgt = self.layer_norm_2(tgt)
        # position-wise feed-forward, with residual connection and post-norm
        tgt = tgt + self.feed_forward(tgt)
        tgt = self.layer_norm_3(tgt)
        return tgt
# Transformer (put it all together)
class Transformer(nn.Module):
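    """Full encoder-decoder Transformer: embeddings, position encoding, encoder/decoder stacks and output projection."""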
def __init__(self, config: TransformerConfig):
super().__init__()
self.src_embedding = SourceEmbedding(config=config)
self.tgt_embedding = TargetEmbedding(config=config)
self.position_encoding = PositionEncoding(config=config)
self.encoder = nn.ModuleList(
[EncoderBlock(config=config) for _ in range(config.num_encoder_layers)]
)
self.decoder = nn.ModuleList(
[DecoderBlock(config=config) for _ in range(config.num_decoder_layers)]
)
self.output = nn.Linear(config.d_model, config.tgt_vocab_size)
## Init weights
self.apply(_init_weights_)
def forward(self, src_ids, tgt_ids, src_attention_mask=None, tgt_attention_mask=None):
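        """Return next-token logits of shape (batch, tgt_seq_length, tgt_vocab_size)."""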
# embed token ids
src_embed = self.src_embedding(src_ids)
tgt_embed = self.tgt_embedding(tgt_ids)
# add position encoding
src_embed = self.position_encoding(src_embed)
tgt_embed = self.position_encoding(tgt_embed)
for layer in self.encoder:
src_embed = layer(src_embed, src_attention_mask)
for layer in self.decoder:
tgt_embed = layer(src_embed, tgt_embed, src_attention_mask, tgt_attention_mask)
pred = self.output(tgt_embed)
return pred
@torch.no_grad()
def inference(self, src_ids, tgt_start_id, tgt_end_id, max_seq_length):
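        """Greedy decoding: encode `src_ids` once, then generate target ids one at a time,
        starting from `tgt_start_id`, until `tgt_end_id` is produced or `max_seq_length` steps are reached."""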
tgt_ids = torch.tensor([tgt_start_id], device=src_ids.device).reshape(1,1)
        # encode the source sequence once
src_embed = self.src_embedding(src_ids)
src_embed = self.position_encoding(src_embed)
for layer in self.encoder:
src_embed = layer(src_embed)
        # generate target tokens autoregressively (greedy decoding)
for i in range(max_seq_length):
tgt_embed = self.tgt_embedding(tgt_ids)
tgt_embed = self.position_encoding(tgt_embed)
for layer in self.decoder:
tgt_embed = layer(src_embed, tgt_embed)
            # predict the next token from the last position's representation only
            tgt_embed = tgt_embed[:, -1]
            pred = self.output(tgt_embed)
            pred = pred.argmax(dim=-1).unsqueeze(0)
            tgt_ids = torch.cat([tgt_ids, pred], dim=-1)
if torch.all(pred == tgt_end_id):
break
return tgt_ids.squeeze().cpu().tolist()
def load_weights_from_checkpoints(self, path_to_checkpoints):
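        """Load weights from a local safetensors file, downloading it from the Hugging Face Hub if it is missing."""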
if not os.path.exists(path_to_checkpoints):
print("------------------- LOADING MODEL CHECKPOINTS FROM HUGGING FACE --------------------------")
folder = os.path.dirname(path_to_checkpoints)
os.makedirs(folder, exist_ok=True)
path_to_checkpoints = get_file_FROM_HF(repo_id="ngia/ml-translation-en-fr", file_path="final_checkpoint/model.safetensors", local_dir=folder)
        checkpoints = load_file(filename=path_to_checkpoints)
        self.load_state_dict(checkpoints)
return self
def _init_weights_(module):
"""
    Simple weight initialization taken directly from the Hugging Face
    `modeling_roberta.py` implementation.
"""
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=0.02)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=0.02)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
if __name__ == "__main__":
config = TransformerConfig()
model = Transformer(config=config)
    # smoke test: random token ids through a randomly initialized model
    english = torch.randint(low=0, high=1000, size=(1, 3))
    res = model.inference(src_ids=english, tgt_start_id=1, tgt_end_id=2, max_seq_length=config.max_seq_length)
    print(res)
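    # The smoke test above uses random weights. To run the deployed checkpoint instead, the weights
    # can be loaded first; the local path below is only an assumption (the file is fetched from the
    # ngia/ml-translation-en-fr repo when it is not already present):
    # model = model.load_weights_from_checkpoints("checkpoints/model.safetensors")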