File size: 2,656 Bytes
eb186c4
 
70b0dde
 
 
8544327
 
 
 
 
 
 
0c6aaea
8544327
a4e806e
8544327
0a31f26
8544327
 
 
0d63905
 
 
8544327
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
---
license: openrail
language:
- ru
library_name: transformers
tags:
  - pytorch
  - causal-lm
---

## CharGPT-96M

Крошечная языковая модель с **посимвольной** токенизацией для всевозможных экспериментов, когда задача решается плохо из-за BPE токенизации на слова и их части.

К примеру, если вы хотите делать детектор орфографии, или модельку для фонетическую транскрипцию и т.д., данная модель с посимвольной токенизацией может оказаться предпочтительнее, чем обычные GPT.

Размер модели - **96 миллионов** параметров.

### Особенности предварительной тренировки

Я делал эту модель для экспериментов с русской поэзией в рамках проекта ["Литературная студия"](https://github.com/Koziev/verslibre).
Поэтому корпус претрейна содержал значительное количество текстов поэтического формата.
Это может повлиять на ваши downstream задачи.

Объем корпуса претрейна - около 30B токенов.

### Использование

С библиотекой transformerts модель можно использовать штатным способом как обычную GPT:

```
import os
import torch
import transformers

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model_name_or_path = 'inkoziev/chargpt-96M'
model = transformers.GPT2LMHeadModel.from_pretrained(model_name_or_path)
model.to(device)
model.eval()

tokenizer = transformers.GPT2Tokenizer.from_pretrained(model_name_or_path)
tokenizer.add_special_tokens({'bos_token': '<s>', 'eos_token': '</s>', 'pad_token': '<pad>'})

prompt = '<s>У Лукоморья дуб зеленый\n'
encoded_prompt = tokenizer.encode(prompt, return_tensors='pt')

output_sequences = model.generate(
    input_ids=encoded_prompt.to(device),
    max_length=400,
    temperature=1.0,
    top_k=0,
    top_p=0.8,
    repetition_penalty=1.0,
    do_sample=True,
    num_return_sequences=5,
    pad_token_id=0,
)

for o in output_sequences:
    text = tokenizer.decode(o)
    if text.startswith('<s>'):
        text = text.replace('<s>', '')
    text = text[:text.index('</s>')].strip()
    print(text)
    print('-'*80)
```