language: en
datasets:
- wikitext
ByT5 base English fine-tuned for OCR Correction
This model is a fine-tuned version of byt5-base for OCR correction. ByT5 was introduced in the paper "ByT5: Towards a Token-Free Future with Pre-trained Byte-to-Byte Models". The idea and code for fine-tuning the model for OCR correction were taken from here.
Model description
byt5-base-english-ocr-correction is the byt5-base model fine-tuned on an OCR correction dataset. Its input is a sentence that an OCR model has transcribed incorrectly, and its output is the same sentence with the errors corrected.
The training data was created by taking the wikitext dataset and adding synthetic OCR errors with the nlpaug library.
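nlpaug's `OcrAug` corrupts text by replacing characters with look-alikes that OCR engines commonly confuse. It picks a fraction of words (`aug_word_p`) and then a fraction of characters inside them (`aug_char_p`). The sketch below illustrates that idea without any dependencies. It assumes a small hand-written confusion table (`OCR_CONFUSIONS`) and a helper (`add_ocr_noise`), both hypothetical. This is not nlpaug's actual implementation, and not necessarily how this model's training pairs were built.

```python
import random

# Hypothetical table of visually similar characters that OCR engines tend to confuse.
# Illustration only: this is NOT nlpaug's OcrAug mapping.
OCR_CONFUSIONS = {
    "o": ["0"], "O": ["0"], "0": ["O", "o"],
    "l": ["1", "I"], "I": ["1", "l"], "1": ["l", "I"],
    "e": ["c"], "c": ["e"], "S": ["5"], "5": ["S"],
    "B": ["8"], "8": ["B"], "g": ["9"], "Z": ["2"],
}

def add_ocr_noise(text: str, word_p: float, char_p: float, rng: random.Random) -> str:
    """Select each word with probability word_p, then replace each confusable
    character in a selected word with probability char_p."""
    noisy_words = []
    for word in text.split(" "):
        if rng.random() < word_p:
            word = "".join(
                rng.choice(OCR_CONFUSIONS[ch])
                if ch in OCR_CONFUSIONS and rng.random() < char_p
                else ch
                for ch in word
            )
        noisy_words.append(word)
    return " ".join(noisy_words)

clean = "Life is like a box of chocolates"
noisy = add_ocr_noise(clean, word_p=0.6, char_p=0.4, rng=random.Random(0))
pair = (noisy, clean)  # (model input, training target)
print(pair)
```

Each training example is then a `(noisy, clean)` pair. The model learns to map the first element back to the second.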
Intended uses & limitations
You can use the model for Text-to-Text Generation to remove errors caused by an OCR model.
How to use
```python
# Training: compute the loss for a single (noisy, clean) sentence pair
import torch
import nlpaug.augmenter.char as nac
from transformers import T5ForConditionalGeneration

# Corrupt a clean sentence with synthetic OCR errors
aug = nac.OcrAug(aug_char_p=0.4, aug_word_p=0.6)
corrected_text = "Life is like a box of chocolates"
augmented_text = aug.augment(corrected_text)
if isinstance(augmented_text, list):  # recent nlpaug versions return a list
    augmented_text = augmented_text[0]

model = T5ForConditionalGeneration.from_pretrained('yelpfeast/byt5-base-english-ocr-correction')

# ByT5 works on raw UTF-8 bytes; add 3 to skip the special tokens (pad, eos, unk)
input_ids = torch.tensor([list(augmented_text.encode("utf-8"))]) + 3  # noisy input
labels = torch.tensor([list(corrected_text.encode("utf-8"))]) + 3  # clean target
loss = model(input_ids, labels=labels).loss  # forward pass
```
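The `+ 3` in the training example comes from ByT5's byte-level vocabulary. ByT5 has no learned vocabulary. Ids 0, 1 and 2 are reserved for `<pad>`, `</s>` and `<unk>`, and every UTF-8 byte `b` becomes id `b + 3`. The sketch below writes that mapping out by hand. The names `encode` and `decode` are illustrative helpers, not library functions. The `AutoTokenizer` used for inference performs the same mapping and also appends `</s>`.

```python
PAD_ID, EOS_ID, UNK_ID = 0, 1, 2  # reserved special-token ids
OFFSET = 3  # byte b is stored as id b + 3

def encode(text: str, add_eos: bool = True) -> list[int]:
    """Map each UTF-8 byte to its ByT5 id, optionally appending </s>."""
    ids = [b + OFFSET for b in text.encode("utf-8")]
    return ids + [EOS_ID] if add_eos else ids

def decode(ids: list[int]) -> str:
    """Drop special tokens, undo the offset and reassemble the UTF-8 bytes."""
    data = bytes(i - OFFSET for i in ids if i >= OFFSET)
    return data.decode("utf-8", errors="ignore")

print(encode("b0x"))  # [101, 51, 123, 1]
```

A non-ASCII character such as `é` occupies two ids, because it is two bytes in UTF-8. As a result, sequence lengths are measured in bytes rather than characters.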
```python
# Inference: correct a sentence containing synthetic OCR errors
from transformers import T5ForConditionalGeneration, AutoTokenizer
import nlpaug.augmenter.char as nac

aug = nac.OcrAug(aug_char_p=0.4, aug_word_p=0.6)
corrected_text = "Life is like a box of chocolates"
augmented_text = aug.augment(corrected_text)
print(augmented_text)

model = T5ForConditionalGeneration.from_pretrained('yelpfeast/byt5-base-english-ocr-correction')
tokenizer = AutoTokenizer.from_pretrained("yelpfeast/byt5-base-english-ocr-correction")

inputs = tokenizer(augmented_text, return_tensors="pt", padding=True)
output_sequences = model.generate(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    do_sample=False,  # greedy decoding for deterministic output
    max_length=256,  # ByT5 emits one token per byte, so allow room for the whole sentence
)
print(tokenizer.batch_decode(output_sequences, skip_special_tokens=True))
```
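ByT5 produces one token per byte, so a full OCR page becomes a very long sequence. A simple workaround, assumed here rather than taken from this model card, is to split the text into word-bounded chunks and correct them as a batch. The helper `chunk_text` and the `max_bytes` value are hypothetical. Note that splitting on whitespace also collapses line breaks and repeated spaces.

```python
def chunk_text(text: str, max_bytes: int) -> list[str]:
    """Greedily pack whitespace-separated words into chunks of at most
    max_bytes UTF-8 bytes. A single word longer than max_bytes becomes its own chunk."""
    chunks: list[str] = []
    current = ""
    for word in text.split():
        candidate = f"{current} {word}" if current else word
        if len(candidate.encode("utf-8")) <= max_bytes or not current:
            current = candidate
        else:
            chunks.append(current)
            current = word
    if current:
        chunks.append(current)
    return chunks

page = "Lifc is likc a b0x of ch0colates. Y0u never kn0w what you're g0nna get."
chunks = chunk_text(page, max_bytes=40)
print(chunks)
```

The list `chunks` can then be passed to `tokenizer(chunks, return_tensors="pt", padding=True)` and `model.generate(...)` as in the inference example. The decoded chunks can be rejoined with spaces.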
Limitations
The model was trained on text that was artificially corrupted to resemble OCR errors. A given OCR engine may produce different kinds of errors, in which case the model may not fully correct its output.
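Given this limitation, check the model against output from your own OCR engine before relying on it. One common measure, which is not part of the original card, is character error rate (CER). CER is the edit distance between a text and a hand-corrected reference, divided by the reference length. Comparing the CER of the raw OCR text with that of the model's output shows whether the model helps. The sketch below uses a standard dynamic-programming Levenshtein distance. The helper names are illustrative.

```python
def levenshtein(a: str, b: str) -> int:
    """Minimum number of single-character insertions, deletions and substitutions."""
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, start=1):
        curr = [i]
        for j, cb in enumerate(b, start=1):
            curr.append(min(
                prev[j] + 1,               # delete ca
                curr[j - 1] + 1,           # insert cb
                prev[j - 1] + (ca != cb),  # substitute (free if equal)
            ))
        prev = curr
    return prev[-1]

def cer(prediction: str, reference: str) -> float:
    """Character error rate of prediction against reference."""
    return levenshtein(prediction, reference) / max(len(reference), 1)

reference = "Life is like a box of chocolates"
ocr_output = "Lifc is like a b0x of chocolates"
print(cer(ocr_output, reference))  # 0.0625 (2 substitutions / 32 characters)
```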