import torch
import torch.nn as nn
from model import (
    SwitchTransformer,
    SwitchTransformerLayer,
    MultiHeadAttention,
    SwitchFeedForward,
    FeedForward,
)
from transformers import AutoTokenizer

device = 'cpu'

# Dense feed-forward block used as the expert template (d_model=768, d_ff=4*768).
ff = FeedForward(768, 768 * 4)

# Multi-head attention: 8 heads over a 768-dimensional model, dropout 0.2.
attn = MultiHeadAttention(8, 768, 0.2)

# Switch (mixture-of-experts) feed-forward layer: 4 experts cloned from `ff`,
# capacity factor 1.25, no token dropping when an expert overflows.
st_ff = SwitchFeedForward(
    capacity_factor=1.25,
    drop_tokens=False,
    n_experts=4,
    expert=ff,
    d_model=768,
    is_scale_prob=True,
)

# A single transformer layer combining the attention block and the switch FFN.
st_layer = SwitchTransformerLayer(
    d_model=768,
    attn=attn,
    feed_forward=st_ff,
    dropout_prob=0.2,
)

# Stack 4 copies of the layer into the full Switch Transformer.
model = SwitchTransformer(
    layer=st_layer,
    n_layers=4,
    n_experts=4,
    device=device,
    load_balancing_loss_ceof=0.05,  # keyword spelling follows the constructor in model.py
).to(device)

# Load the trained weights onto CPU.
model.load_state_dict(torch.load("switch_transformer.pt", map_location=torch.device('cpu')))

# Tokenizer for Kazakh text.
tokenizer = AutoTokenizer.from_pretrained("Kyrmasch/kaz-roberta-squad2-kaz")
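
# Minimal inference sketch (not part of the original script). It assumes the
# model's forward pass accepts a tensor of token ids and returns token logits;
# the names and call below are illustrative only, so check model.py for the
# actual forward() signature and any auxiliary load-balancing outputs.
model.eval()
text = "Example input text"  # placeholder string; the tokenizer targets Kazakh
enc = tokenizer(text, return_tensors="pt")
with torch.no_grad():
    # Hypothetical call: the real forward() may also expect a mask or return
    # routing statistics in addition to the logits.
    logits = model(enc["input_ids"].to(device))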