TPLT - Three-Phase Lightweight Transformer, архитектура на основе трансформера, предполагающая структурирование всей архитектуры до трёх основных модулей.

Для инференса модели используем следующий код:

import json
import torch
import torch.nn as nn
from safetensors.torch import load_file

# Run on GPU when available, otherwise fall back to CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Character vocabulary as a JSON list; list order defines the token ids.
# NOTE(review): despite the .txt extension the file must contain JSON —
# presumably including the "<pad>", "<s>", "</s>" special tokens; verify.
with open("chars.txt", "r", encoding="utf-8") as f:
    vocab_chars = json.load(f)

# Model hyperparameters: vocab_size, dim, num_heads, num_layers, max_len.
with open("model_config.json", "r", encoding="utf-8") as f:
    cfg = json.load(f)

class Tokenizer:
    """Character-level tokenizer with fixed-length padding.

    The vocabulary list passed in must already contain the special
    tokens "<pad>", "<s>" and "</s>".
    """

    def __init__(self, chars):
        self.chars = chars
        self.stoi = dict((ch, idx) for idx, ch in enumerate(chars))
        self.itos = {idx: ch for ch, idx in self.stoi.items()}
        self.pad_token = "<pad>"
        self.sos_token = "<s>"
        self.eos_token = "</s>"

    def encode(self, text, max_len):
        """Map *text* to a list of exactly *max_len* token ids.

        Unknown characters fall back to the <pad> id; the sequence is
        truncated or right-padded with <pad> to the requested length.
        """
        pad_id = self.stoi[self.pad_token]
        ids = [self.stoi.get(ch, pad_id) for ch in text][:max_len]
        return ids + [pad_id] * (max_len - len(ids))

    def decode(self, ids):
        """Join ids back into a string, skipping special and unknown ids."""
        specials = {self.pad_token, self.sos_token, self.eos_token}
        pieces = []
        for idx in ids:
            ch = self.itos.get(idx)
            if ch is not None and ch not in specials:
                pieces.append(ch)
        return "".join(pieces)

# Shared tokenizer instance used for encoding inputs and decoding outputs.
tokenizer = Tokenizer(vocab_chars)

class TinyTPLT(nn.Module):
    """Small encoder-decoder character transformer for the TPLT demo.

    Args:
        vocab_size: number of tokens in the vocabulary.
        dim: embedding / model width.
        num_heads: attention heads per layer.
        num_layers: depth of both the encoder and the decoder stacks.
        max_len: maximum sequence length (size of the learned positions).
    """

    def __init__(self, vocab_size, dim, num_heads, num_layers, max_len):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, dim)
        # Learned absolute positional embeddings, shared by both stacks.
        self.pos_embed = nn.Parameter(torch.zeros(1, max_len, dim))

        # Encoder and decoder layers share the same hyperparameters.
        layer_kwargs = dict(
            d_model=dim,
            nhead=num_heads,
            dim_feedforward=dim * 4,
            batch_first=True,
            dropout=0.1,
            activation="gelu",
        )
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(**layer_kwargs), num_layers=num_layers
        )
        self.decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(**layer_kwargs), num_layers=num_layers
        )

        self.fc_out = nn.Linear(dim, vocab_size)
        self.max_len = max_len

    def forward(self, src, tgt, tgt_mask=None, pad_idx=None):
        """Encode *src*, decode *tgt* against it, and return vocab logits.

        Padding positions (token id == pad_idx, default 0) are masked out
        of attention on both the source and target sides.
        """
        src_h = self.embed(src) + self.pos_embed[:, : src.size(1), :]
        tgt_h = self.embed(tgt) + self.pos_embed[:, : tgt.size(1), :]
        if pad_idx is None:
            pad_idx = 0
        src_pad = src.eq(pad_idx)
        tgt_pad = tgt.eq(pad_idx)
        memory = self.encoder(src_h, src_key_padding_mask=src_pad)
        decoded = self.decoder(
            tgt_h,
            memory,
            tgt_mask=tgt_mask,
            tgt_key_padding_mask=tgt_pad,
            memory_key_padding_mask=src_pad,
        )
        return self.fc_out(decoded)

# Build the model from the config and move it to the selected device.
model = TinyTPLT(
    vocab_size=cfg["vocab_size"],
    dim=cfg["dim"],
    num_heads=cfg["num_heads"],
    num_layers=cfg["num_layers"],
    max_len=cfg["max_len"],
).to(DEVICE)

# Load trained weights; strict=True fails fast on any key mismatch.
state_dict = load_file("tplt.safetensors")
model.load_state_dict(state_dict, strict=True)
model.eval()

# Special-token ids and the generation length limit used during decoding.
pad_idx = tokenizer.stoi["<pad>"]
sos_idx = tokenizer.stoi["<s>"]
eos_idx = tokenizer.stoi["</s>"]
max_len = cfg["max_len"]

@torch.no_grad()
def translate(word):
    """Greedily decode the model's output string for *word*.

    Decoding starts from <s>, appends the argmax token at each step
    (never <pad> or <s>), and stops at </s> or after max_len - 1 steps.
    """
    src = torch.tensor(
        [tokenizer.encode(word, max_len=max_len)], dtype=torch.long, device=DEVICE
    )
    generated = torch.tensor([[sos_idx]], dtype=torch.long, device=DEVICE)
    result = []

    step = 0
    while step < max_len - 1:
        step += 1
        n = generated.size(1)
        # Standard causal mask: position i cannot attend to positions > i.
        mask = torch.triu(
            torch.full((n, n), float("-inf"), device=DEVICE), diagonal=1
        )
        logits = model(src, generated, tgt_mask=mask, pad_idx=pad_idx)
        last = logits[:, -1, :]
        # Never emit padding or a second start-of-sequence token.
        last[0, pad_idx] = -float("inf")
        last[0, sos_idx] = -float("inf")
        choice = last.argmax(dim=-1).item()
        if choice == eos_idx:
            break
        result.append(choice)
        generated = torch.cat(
            [generated, torch.tensor([[choice]], dtype=torch.long, device=DEVICE)],
            dim=1,
        )

    return tokenizer.decode(result)

# Demo: print the model's translation for a few sample words.
for w in ["Voda", "Ignis", 'Minarik', "Solur", 'Il', "Lumin"]:
    print(f"{w:10} -> '{translate(w)}'")

Модель имеет 66 тыс. параметров, создана для демонстрации работоспособности архитектуры и протестирована на специально составленном языке для проверки работоспособности NLP-моделей машинного обучения — Spectrum, обеспечивающем лёгкое и быстрое запоминание информации на уровне, достаточном для нормальной базовой речи.

Downloads last month

-

Downloads are not tracked for this model. How to track
Inference Providers NEW
This model isn't deployed by any Inference Provider. 🙋 Ask for provider support

Collection including Darkester/TPLT-66k