|
--- |
|
license: openrail |
|
language: |
|
- ru |
|
library_name: transformers |
|
tags: |
|
- pytorch |
|
- causal-lm |
|
--- |
|
|
|
## CharGPT-96M |
|
|
|
Крошечная языковая модель с **посимвольной** токенизацией для всевозможных экспериментов, когда задача решается плохо из-за BPE токенизации на слова и их части. |
|
|
|
К примеру, если вы хотите делать детектор орфографии, или модельку для фонетическую транскрипцию и т.д., данная модель с посимвольной токенизацией может оказаться предпочтительнее, чем обычные GPT. |
|
|
|
Размер модели - **96 миллионов** параметров. |
|
|
|
### Особенности предварительной тренировки |
|
|
|
Я делал эту модель для экспериментов с русской поэзией в рамках проекта ["Литературная студия"](https://github.com/Koziev/verslibre). |
|
Поэтому корпус претрейна содержал значительное количество текстов поэтического формата. |
|
Это может повлиять на ваши downstream задачи. |
|
|
|
Объем корпуса претрейна - около 30B токенов. |
|
|
|
### Использование |
|
|
|
С библиотекой transformerts модель можно использовать штатным способом как обычную GPT: |
|
|
|
``` |
|
import os |
|
import torch |
|
import transformers |
|
|
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
|
|
|
model_name_or_path = 'inkoziev/chargpt-96M' |
|
model = transformers.GPT2LMHeadModel.from_pretrained(model_name_or_path) |
|
model.to(device) |
|
model.eval() |
|
|
|
tokenizer = transformers.GPT2Tokenizer.from_pretrained(model_name_or_path) |
|
tokenizer.add_special_tokens({'bos_token': '<s>', 'eos_token': '</s>', 'pad_token': '<pad>'}) |
|
|
|
prompt = '<s>У Лукоморья дуб зеленый\n' |
|
encoded_prompt = tokenizer.encode(prompt, return_tensors='pt') |
|
|
|
output_sequences = model.generate( |
|
input_ids=encoded_prompt.to(device), |
|
max_length=400, |
|
temperature=1.0, |
|
top_k=0, |
|
top_p=0.8, |
|
repetition_penalty=1.0, |
|
do_sample=True, |
|
num_return_sequences=5, |
|
pad_token_id=0, |
|
) |
|
|
|
for o in output_sequences: |
|
text = tokenizer.decode(o) |
|
if text.startswith('<s>'): |
|
text = text.replace('<s>', '') |
|
text = text[:text.index('</s>')].strip() |
|
print(text) |
|
print('-'*80) |
|
``` |
|
|