Uzbek Punctuation Restoration Model
This repository contains a fine-tuned version of xlm-roberta-base
for Uzbek Punctuation Restoration, developed by booba-uz.
The model automatically restores punctuation marks in raw, unpunctuated Uzbek text, making it useful for post-processing ASR transcripts and OCR output and for normalizing casual user input.
Model Overview
- Repository: booba-uz/punctuation_xlm_roberta_base
- Base Model: xlm-roberta-base
- Task: Token-level Punctuation Restoration
- Target Language: Uzbek
- Special Features (see the sketch after this list):
  - LSTM Layer: Disabled (`lstm_dim = -1`)
  - CRF Layer: Disabled (`use_crf = False`)
  - BERT weights: Fully trainable (`freeze_bert = False`)
  - Augmentation: 15% probability, type = all (`augment_type = all`)
- Max Sequence Length: 256
- Learning Rate: 2e-6
- Batch Size: 16
- Training Epochs: 20
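The sketch below shows how a model with the configuration above might be instantiated with the `DeepPunctuation` class used in the usage example further down. It is a minimal illustration, assuming `model.py` from the accompanying training codebase is importable; it is not the full training script.

```python
# Minimal sketch, assuming model.py from the accompanying codebase is on the path.
from model import DeepPunctuation

deep_punctuation = DeepPunctuation(
    'xlm-roberta-base',  # base model listed above
    freeze_bert=False,   # BERT weights fully trainable
    lstm_dim=-1,         # LSTM setting listed above
)
```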
Training Results
The model was trained for 20 epochs on a punctuation-annotated Uzbek dataset and showed steady convergence.
| Metric | Final Value |
|---|---|
| Train Loss | 0.0502 |
| Validation Loss | 0.0364 |
| Validation Accuracy | 97.43% |
Evaluation Scores
| Class (Punctuation) | Precision | Recall | F1-Score |
|---|---|---|---|
| O (No punctuation) | 98.39% | 98.96% | 98.67% |
| . (Period) | 85.38% | 76.91% | 80.93% |
| , (Comma) | 94.84% | 95.01% | 94.92% |
| ? (Question mark) | 95.59% | 97.87% | 96.72% |
| ! (Exclamation mark) | 91.65% | 88.64% | 90.12% |
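These are standard token-level precision, recall, and F1 scores computed per punctuation class. As an illustration only (not the evaluation script used for the numbers above), per-class scores of this kind can be computed from gold and predicted label sequences with scikit-learn; the label names and example sequences below are assumptions for the sketch.

```python
# Illustrative sketch: per-class precision/recall/F1 for token-level
# punctuation labels. The example label sequences are made up.
from sklearn.metrics import classification_report

label_names = ['O', '.', ',', '?', '!']   # assumed label set
y_true = ['O', ',', 'O', '.', 'O', '?']   # gold label per word
y_pred = ['O', ',', 'O', ',', 'O', '?']   # predicted label per word

print(classification_report(y_true, y_pred, labels=label_names, digits=4, zero_division=0))
```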
Use Cases
- Speech-to-Text Post-processing: Automatically restore punctuation in ASR transcriptions for better readability.
- OCR Post-processing: Add punctuation to text extracted from scanned Uzbek documents.
- Text Normalization: Pre-clean raw text to improve downstream NLP tasks such as machine translation, summarization, and sentiment analysis.
Usage Example
The script below runs inference on a text file; it expects `model.py` and `config.py` from the training codebase to be importable.
```python
import re
import torch
import argparse
from model import DeepPunctuation, DeepPunctuationCRF
from config import *

parser = argparse.ArgumentParser(description='Punctuation restoration inference on text file')
parser.add_argument('--cuda', default=True, type=lambda x: (str(x).lower() == 'true'), help='use cuda if available')
parser.add_argument('--pretrained-model', default='xlm-roberta-large', type=str, help='pretrained language model')
parser.add_argument('--lstm-dim', default=-1, type=int,
                    help='hidden dimension in LSTM layer, if -1 is set equal to hidden dimension in language model')
parser.add_argument('--use-crf', default=False, type=lambda x: (str(x).lower() == 'true'),
                    help='whether to use CRF layer or not')
parser.add_argument('--language', default='en', type=str, help='language English (en) or Bangla (bn)')
parser.add_argument('--in-file', default='data/test_en.txt', type=str, help='path to inference file')
parser.add_argument('--weight-path', default='xlm-roberta-large.pt', type=str, help='model weight path')
parser.add_argument('--sequence-length', default=256, type=int,
                    help='sequence length to use when preparing dataset (default 256)')
parser.add_argument('--out-file', default='data/test_en_out.txt', type=str, help='output file location')
args = parser.parse_args()

# tokenizer
tokenizer = MODELS[args.pretrained_model][1].from_pretrained(args.pretrained_model)
token_style = MODELS[args.pretrained_model][3]

# logs
model_save_path = args.weight_path

# Model
device = torch.device('cuda' if (args.cuda and torch.cuda.is_available()) else 'cpu')
if args.use_crf:
    deep_punctuation = DeepPunctuationCRF(args.pretrained_model, freeze_bert=False, lstm_dim=args.lstm_dim)
else:
    deep_punctuation = DeepPunctuation(args.pretrained_model, freeze_bert=False, lstm_dim=args.lstm_dim)
deep_punctuation.to(device)


def inference():
    deep_punctuation.load_state_dict(torch.load(model_save_path))
    deep_punctuation.eval()

    with open(args.in_file, 'r', encoding='utf-8') as f:
        text = f.read()
    # strip any existing punctuation so the model re-inserts it
    text = re.sub(r"[,:\-–.!;?]", '', text)
    words_original_case = text.split()
    words = text.lower().split()

    word_pos = 0
    sequence_len = args.sequence_length
    result = ""
    decode_idx = 0
    # label id -> punctuation mark; for non-English (Bangla) the period becomes a danda
    punctuation_map = {0: '', 1: ',', 2: '.', 3: '?'}
    if args.language != 'en':
        punctuation_map[2] = '।'

    # process the input in windows of at most sequence_len sub-word tokens
    while word_pos < len(words):
        x = [TOKEN_IDX[token_style]['START_SEQ']]
        y_mask = [0]

        while len(x) < sequence_len and word_pos < len(words):
            tokens = tokenizer.tokenize(words[word_pos])
            if len(tokens) + len(x) >= sequence_len:
                break
            else:
                for i in range(len(tokens) - 1):
                    x.append(tokenizer.convert_tokens_to_ids(tokens[i]))
                    y_mask.append(0)
                x.append(tokenizer.convert_tokens_to_ids(tokens[-1]))
                y_mask.append(1)
                word_pos += 1
        x.append(TOKEN_IDX[token_style]['END_SEQ'])
        y_mask.append(0)
        if len(x) < sequence_len:
            x = x + [TOKEN_IDX[token_style]['PAD'] for _ in range(sequence_len - len(x))]
            y_mask = y_mask + [0 for _ in range(sequence_len - len(y_mask))]
        attn_mask = [1 if token != TOKEN_IDX[token_style]['PAD'] else 0 for token in x]

        x = torch.tensor(x).reshape(1, -1)
        y_mask = torch.tensor(y_mask)
        attn_mask = torch.tensor(attn_mask).reshape(1, -1)
        x, attn_mask, y_mask = x.to(device), attn_mask.to(device), y_mask.to(device)

        with torch.no_grad():
            if args.use_crf:
                y = torch.zeros(x.shape[0])
                y_predict = deep_punctuation(x, attn_mask, y)
                y_predict = y_predict.view(-1)
            else:
                y_predict = deep_punctuation(x, attn_mask)
                y_predict = y_predict.view(-1, y_predict.shape[2])
                y_predict = torch.argmax(y_predict, dim=1).view(-1)
        # re-attach the predicted punctuation to the original-case words
        for i in range(y_mask.shape[0]):
            if y_mask[i] == 1:
                result += words_original_case[decode_idx] + punctuation_map[y_predict[i].item()] + ' '
                decode_idx += 1
    print('Punctuated text')
    print(result)
    with open(args.out_file, 'w', encoding='utf-8') as f:
        f.write(result)


if __name__ == '__main__':
    inference()
```
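Note that the script's defaults target the upstream setup (xlm-roberta-large, English/Bangla test files). To run this checkpoint you would presumably pass `--pretrained-model xlm-roberta-base`, point `--weight-path` at the weight file distributed with this repository, set `--in-file`/`--out-file` to your Uzbek text files, and keep `--sequence-length 256`; the exact weight filename is not stated here, so substitute whatever the repository ships.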
License
MIT License: free to use, modify, and distribute.
Contributing
Pull requests are welcome! For major changes, please open an issue first to discuss your ideas.
Acknowledgements
Thanks to Hugging Face and the Uzbek NLP community for providing the tools and resources that made this model possible!