DIY Collection: here I keep my homemade architectures and private projects (clos-rise).
TPLT - Three-Phase Lightweight Transformer is a transformer-based architecture that organizes the entire model into three main modules.
Use the following code to run inference with the model:
import json

import torch
import torch.nn as nn
from safetensors.torch import load_file

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Character vocabulary: a JSON list of characters, including the special tokens.
with open("chars.txt", "r", encoding="utf-8") as f:
    vocab_chars = json.load(f)

# Hyperparameters saved at training time.
with open("model_config.json", "r", encoding="utf-8") as f:
    cfg = json.load(f)
class Tokenizer:
    """Character-level tokenizer with <pad>, <s> and </s> special tokens."""

    def __init__(self, chars):
        self.chars = chars
        self.stoi = {c: i for i, c in enumerate(chars)}
        self.itos = {i: c for c, i in self.stoi.items()}
        self.pad_token = "<pad>"
        self.sos_token = "<s>"
        self.eos_token = "</s>"

    def encode(self, text, max_len):
        # Map characters to ids (unknown characters fall back to the pad id),
        # then truncate and right-pad to max_len.
        ids = [self.stoi.get(c, self.stoi[self.pad_token]) for c in text]
        ids = ids[:max_len]
        ids += [self.stoi[self.pad_token]] * (max_len - len(ids))
        return ids

    def decode(self, ids):
        # Drop special tokens and join the remaining characters.
        return "".join(
            self.itos[i]
            for i in ids
            if i in self.itos and self.itos[i] not in {self.pad_token, self.sos_token, self.eos_token}
        )


tokenizer = Tokenizer(vocab_chars)
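# Illustrative usage (this assumes "<pad>", "<s>" and "</s>" are present in chars.txt):
# tokenizer.encode("abc", max_len=6) returns the three character ids followed by
# three pad ids, and tokenizer.decode(...) strips the special tokens back out.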
class TinyTPLT(nn.Module):
    """Minimal encoder-decoder transformer with learned positional embeddings."""

    def __init__(self, vocab_size, dim, num_heads, num_layers, max_len):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, dim)
        self.pos_embed = nn.Parameter(torch.zeros(1, max_len, dim))
        enc_layer = nn.TransformerEncoderLayer(
            d_model=dim,
            nhead=num_heads,
            dim_feedforward=dim * 4,
            batch_first=True,
            dropout=0.1,
            activation="gelu",
        )
        self.encoder = nn.TransformerEncoder(enc_layer, num_layers=num_layers)
        dec_layer = nn.TransformerDecoderLayer(
            d_model=dim,
            nhead=num_heads,
            dim_feedforward=dim * 4,
            batch_first=True,
            dropout=0.1,
            activation="gelu",
        )
        self.decoder = nn.TransformerDecoder(dec_layer, num_layers=num_layers)
        self.fc_out = nn.Linear(dim, vocab_size)
        self.max_len = max_len

    def forward(self, src, tgt, tgt_mask=None, pad_idx=None):
        # Token + positional embeddings for source and target.
        src_emb = self.embed(src) + self.pos_embed[:, :src.size(1), :]
        tgt_emb = self.embed(tgt) + self.pos_embed[:, :tgt.size(1), :]
        if pad_idx is None:
            pad_idx = 0
        # Boolean padding masks (True marks positions that attention should ignore).
        src_key_padding_mask = src == pad_idx
        tgt_key_padding_mask = tgt == pad_idx
        memory = self.encoder(src_emb, src_key_padding_mask=src_key_padding_mask)
        out = self.decoder(
            tgt_emb,
            memory,
            tgt_mask=tgt_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=src_key_padding_mask,
        )
        return self.fc_out(out)
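# Note: forward returns per-position logits of shape (batch, tgt_len, vocab_size);
# the autoregressive loop below only uses the logits of the last target position.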
model = TinyTPLT(
    vocab_size=cfg["vocab_size"],
    dim=cfg["dim"],
    num_heads=cfg["num_heads"],
    num_layers=cfg["num_layers"],
    max_len=cfg["max_len"],
).to(DEVICE)

# Load the trained weights.
state_dict = load_file("tplt.safetensors")
model.load_state_dict(state_dict, strict=True)
model.eval()

pad_idx = tokenizer.stoi["<pad>"]
sos_idx = tokenizer.stoi["<s>"]
eos_idx = tokenizer.stoi["</s>"]
max_len = cfg["max_len"]
@torch.no_grad()
def translate(word):
    # Encode the source word and start the target sequence with <s>.
    src_ids = tokenizer.encode(word, max_len=max_len)
    src = torch.tensor([src_ids], dtype=torch.long, device=DEVICE)
    tgt_seq = torch.tensor([[sos_idx]], dtype=torch.long, device=DEVICE)
    out_ids = []
    # Greedy autoregressive decoding, one character per step.
    for _ in range(max_len - 1):
        cur_len = tgt_seq.size(1)
        # Causal mask: -inf above the diagonal blocks attention to future positions.
        causal_mask = torch.triu(torch.full((cur_len, cur_len), float("-inf"), device=DEVICE), diagonal=1)
        logits = model(src, tgt_seq, tgt_mask=causal_mask, pad_idx=pad_idx)
        next_logits = logits[:, -1, :]
        # Never emit <pad> or <s>.
        next_logits[0, pad_idx] = -float("inf")
        next_logits[0, sos_idx] = -float("inf")
        next_id = next_logits.argmax(dim=-1).item()
        if next_id == eos_idx:
            break
        out_ids.append(next_id)
        tgt_seq = torch.cat([tgt_seq, torch.tensor([[next_id]], dtype=torch.long, device=DEVICE)], dim=1)
    return tokenizer.decode(out_ids)


for w in ["Voda", "Ignis", "Minarik", "Solur", "Il", "Lumin"]:
    print(f"{w:10} -> '{translate(w)}'")
The model has about 66K parameters and was created to demonstrate that the architecture works. It was tested on Spectrum, a language constructed specifically for checking NLP machine-learning models, which makes information easy and quick to memorize at a level sufficient for normal basic speech.
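As a quick sanity check of the reported parameter count, the snippet below (an illustrative addition, reusing the model instance defined in the inference code above) sums the trainable parameters:

n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Trainable parameters: {n_params:,}")  # expected to be on the order of 66K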