|
--- |
|
language: |
|
- ru |
|
license: mit |
|
tags: |
|
- text2text-generation |
|
datasets: |
|
- Grpp/t5-russian-spell_I |
|
base_model: Grpp/rut5-base |
|
widget: |
|
- text: 'Исправь: нападавше иты кроме того при наадении на отдел уиполицииранение' |
|
--- |
|
|
|
# Model Card for Grpp/T5_spell-base |
|
|
|
<!-- Provide a quick summary of what the model is/does. --> |
|
|
|
Модель для исправления опечаток на русском языке |
|
|
|
## Model Details |
|
|
|
### Model Description |
|
|
|
<!-- Provide a longer summary of what this model is. --> |
|
|
|
|
|
|
|
- **Language(s) (NLP):** ['ru'] |
|
- **License:** mit |
|
- **Finetuned from model [optional]:** Grpp/rut5-base |
|
|
|
|
|
## Uses |
|
|
|
<!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. --> |
|
|
|
### Direct Use |
|
|
|
<!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. --> |
|
|
|
|
|
```python |
|
from transformers import pipeline |
|
from transformers import T5ForConditionalGeneration, T5Tokenizer, Text2TextGenerationPipeline |
|
|
|
# Создание pipeline для генерации текста |
|
PIPELINE = Text2TextGenerationPipeline(model=model, tokenizer=tokenizer, device=0) |
|
|
|
def answer_m(list_texts): |
|
texts = [] |
|
for txt in tqdm(list_texts): |
|
texts.append( |
|
PIPELINE( |
|
txt, |
|
max_length=256, |
|
repetition_penalty=1.5, |
|
temperature=0.7, |
|
top_k=50, |
|
num_return_sequences=1 |
|
)[0]['generated_text']) |
|
return texts |
|
|
|
text = 'нападавше иты кроме того при наадении на отдел уиполицииранение получилаи женщина из гражчданских сообщилон анронимныйистточни агентста тасс со ссылкой на источник пишет что у одногао из преступников быиевзрычатычгтка полицейские потребовали чтобнападавшие останвеились после чего те дотали ножи' |
|
prefix = 'Исправь: ' |
|
text_to_model = prefix + text |
|
|
|
answer_m([text_to_model]) |
|
|
|
# ['Нападавшие иты Кроме того, при нападении на отдел полиции ранение получила женщина из гражданских сообщил один аналогичный источник. Агентство ТАСС со ссылкой на источник пишет, что у одного из преступников были взрывчатка: полицейские потребовали, чтобы напавшие остановились после чего те достали ножы.'] |
|
|
|
``` |
|
|
|
|
|
|
|
## Training Details |
|
|
|
### Training Procedure |
|
|
|
|
|
```python |
|
import torch |
|
from transformers import T5ForConditionalGeneration, T5Tokenizer |
|
from torch.utils.data import Dataset, DataLoader |
|
from transformers import AdamW |
|
from tqdm.auto import tqdm |
|
|
|
raw_model = 'Grpp/T5_spell-base' # предобученная модель |
|
|
|
DATASET = "Grpp/t5-russian-spell_I" # Введите наазвание название датасета |
|
|
|
model = T5ForConditionalGeneration.from_pretrained(raw_model).cuda(); |
|
tokenizer = T5Tokenizer.from_pretrained(raw_model) |
|
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5) |
|
|
|
# Загрузка датасета |
|
new_dataset = load_dataset(DATASET) |
|
|
|
model.to('cuda') |
|
_ = model.config |
|
|
|
batch_size = 8 # сколько примеров показываем модели за один шаг |
|
report_steps = 1000 # раз в сколько шагов печатаем результат |
|
epochs = 1 # сколько раз мы покажем данные модели |
|
|
|
class TextDataset(Dataset): |
|
def __init__(self, tokenizer, pairs): |
|
self.tokenizer = tokenizer |
|
self.pairs = pairs |
|
|
|
def __len__(self): |
|
return len(self.pairs) |
|
|
|
def __getitem__(self, idx): |
|
question = self.pairs[idx]['input_text'].replace('Spell correct: ', 'Исправь: ') |
|
answer = self.pairs[idx]['label_text'] |
|
source = self.tokenizer(question, padding='max_length', truncation=True, max_length=256, return_tensors='pt') |
|
target = self.tokenizer(answer, padding='max_length', truncation=True, max_length=256, return_tensors='pt') |
|
target.input_ids[target.input_ids == 0] = -100 |
|
return source, target |
|
|
|
def train_epoch(model, dataloader, optimizer): |
|
model.train() |
|
losses = [] |
|
for i, (x, y) in enumerate(tqdm(dataloader)): |
|
optimizer.zero_grad() |
|
outputs = model( |
|
input_ids=x['input_ids'].squeeze().to(model.device), |
|
attention_mask=x['attention_mask'].squeeze().to(model.device), |
|
labels=y['input_ids'].squeeze().to(model.device), |
|
decoder_attention_mask=y['attention_mask'].squeeze().to(model.device), |
|
) |
|
loss = outputs.loss |
|
loss.backward() |
|
optimizer.step() |
|
|
|
losses.append(loss.item()) |
|
if i % report_steps == 0: |
|
print('step', i, 'loss', np.mean(losses[-report_steps:])) |
|
return np.mean(losses) |
|
|
|
|
|
# Создаем датасет и даталоадер |
|
dataset = TextDataset(tokenizer, new_dataset['train']) |
|
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True) |
|
|
|
# Оптимизатор |
|
optimizer = AdamW(model.parameters(), lr=5e-5) |
|
|
|
|
|
model_name_t5 = 'T5_spell-base' |
|
# Обучение модели |
|
for epoch in range(epochs): |
|
print('EPOCH', epoch + 1) |
|
epoch_loss = train_epoch(model, dataloader, optimizer) |
|
print(f'Epoch {epoch + 1} Loss: {epoch_loss}') |
|
|
|
# Сохранение модели после каждой эпохи |
|
print('saving') |
|
model.save_pretrained(model_name_t5) |
|
tokenizer.save_pretrained(model_name_t5) |
|
``` |
|
|