Uzbek Punctuation Restoration Model πŸ‡ΊπŸ‡Ώ

This repository contains a fine-tuned version of xlm-roberta-base for Uzbek Punctuation Restoration, developed by booba-uz.

The model restores punctuation marks in raw, unpunctuated Uzbek text, which makes it useful for post-processing ASR transcripts, cleaning OCR output, and normalizing casual user input.


πŸ“Œ Model Overview

  • Repository: booba-uz/punctuation_xlm_roberta_base
  • Base Model: xlm-roberta-base
  • Task: Token-level Punctuation Restoration
  • Target Language: Uzbek
  • Architecture (see the loading sketch after this list):
    • LSTM layer: lstm_dim = -1 (per the upstream code, -1 sets the LSTM hidden size equal to the encoder's hidden size)
    • CRF Layer: Disabled (use_crf = False)
    • Encoder weights: Fully trainable (freeze_bert = False)
  • Training Configuration:
    • Augmentation: 15% probability, type = all (augment_type = all)
    • Max Sequence Length: 256
    • Learning Rate: 2e-6
    • Batch Size: 16
    • Training Epochs: 20
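
These flags map directly onto the model constructor used by the usage example further down. A minimal loading sketch, assuming the model.py module from the upstream punctuation-restoration codebase that the script imports:

import torch
from model import DeepPunctuation  # same module the usage example imports

# Build the network exactly as configured above: trainable encoder,
# no CRF head, and the LSTM hidden size tied to the encoder (-1).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = DeepPunctuation('xlm-roberta-base', freeze_bert=False, lstm_dim=-1)
model.to(device)
model.eval()  # load fine-tuned weights with load_state_dict before predicting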

πŸš€ Training Results

The model was trained on an Uzbek punctuation-annotated dataset for 20 epochs, and showed steady convergence.

Metric                  Final Value
------                  -----------
Train Loss              0.0502
Validation Loss         0.0364
Validation Accuracy     97.43%

πŸ† Evaluation Scores

Class (Punctuation)     Precision   Recall    F1-Score
-------------------     ---------   ------    --------
O (No punctuation)      98.39%      98.96%    98.67%
. (Period)              85.38%      76.91%    80.93%
, (Comma)               94.84%      95.01%    94.92%
? (Question mark)       95.59%      97.87%    96.72%
! (Exclamation mark)    91.65%      88.64%    90.12%
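
The card does not state how these scores were computed; per-class precision/recall/F1 of this kind are conventionally produced from word-level labels, for example with scikit-learn. A sketch for illustration only (label ids and ordering are assumptions, not taken from this repository):

from sklearn.metrics import classification_report

# Toy gold/predicted punctuation label ids, one per word (made up).
y_true = [0, 1, 0, 2, 0, 3, 4]
y_pred = [0, 1, 0, 2, 0, 3, 0]
print(classification_report(
    y_true, y_pred,
    labels=[0, 1, 2, 3, 4],
    target_names=['O', ',', '.', '?', '!'],
    zero_division=0,
))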

πŸ’‘ Use Cases

  • πŸ—£οΈ Speech-to-Text Post-processing:
    Automatically restore punctuation in transcriptions for better readability.

  • πŸ“ OCR Post-processing:
    Useful for adding punctuation to scanned Uzbek documents.

  • πŸ’¬ Text Normalization:
    Improve downstream NLP tasks such as machine translation, summarization, and sentiment analysis by pre-cleaning raw text (see the sketch below).
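
The pre-cleaning mentioned above is the same step the usage example below applies before prediction: existing punctuation is stripped and the text is lower-cased, while the original casing is kept for the output. A standalone sketch (the example sentence is made up):

import re

text = "Salom, qalaysiz? Bugun ob-havo juda yaxshi."
# Same regex as the inference script: drop punctuation already present.
clean = re.sub(r"[,:\-–.!;?]", '', text)
words_original_case = clean.split()   # casing preserved for the output
words = clean.lower().split()         # lower-cased input for the model
print(words)  # ['salom', 'qalaysiz', 'bugun', 'obhavo', 'juda', 'yaxshi']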


πŸ’» Usage Example

The script below is adapted from the upstream punctuation-restoration inference code: it strips any punctuation already present, feeds the words to the model in fixed-length windows, and re-inserts the predicted marks after each word.

import re
import torch

import argparse
from model import DeepPunctuation, DeepPunctuationCRF
from config import *

parser = argparse.ArgumentParser(description='Punctuation restoration inference on text file')
parser.add_argument('--cuda', default=True, type=lambda x: (str(x).lower() == 'true'), help='use cuda if available')
parser.add_argument('--pretrained-model', default='xlm-roberta-base', type=str, help='pretrained language model')
parser.add_argument('--lstm-dim', default=-1, type=int,
                    help='hidden dimension in LSTM layer, if -1 is set equal to hidden dimension in language model')
parser.add_argument('--use-crf', default=False, type=lambda x: (str(x).lower() == 'true'),
                    help='whether to use CRF layer or not')
parser.add_argument('--language', default='uz', type=str, help='language of the input text: Uzbek (uz), English (en), or Bangla (bn)')
parser.add_argument('--in-file', default='data/test_uz.txt', type=str, help='path to inference file')
parser.add_argument('--weight-path', default='xlm-roberta-base.pt', type=str, help='model weight path')
parser.add_argument('--sequence-length', default=256, type=int,
                    help='sequence length to use when preparing dataset (default 256)')
parser.add_argument('--out-file', default='data/test_uz_out.txt', type=str, help='output file location')

args = parser.parse_args()

# tokenizer (MODELS and TOKEN_IDX are lookup tables provided by config.py)
tokenizer = MODELS[args.pretrained_model][1].from_pretrained(args.pretrained_model)
token_style = MODELS[args.pretrained_model][3]

# logs
model_save_path = args.weight_path

# Model
device = torch.device('cuda' if (args.cuda and torch.cuda.is_available()) else 'cpu')
if args.use_crf:
    deep_punctuation = DeepPunctuationCRF(args.pretrained_model, freeze_bert=False, lstm_dim=args.lstm_dim)
else:
    deep_punctuation = DeepPunctuation(args.pretrained_model, freeze_bert=False, lstm_dim=args.lstm_dim)
deep_punctuation.to(device)


def inference():
    # map_location lets CUDA-trained weights load on CPU-only machines
    deep_punctuation.load_state_dict(torch.load(model_save_path, map_location=device))
    deep_punctuation.eval()

    with open(args.in_file, 'r', encoding='utf-8') as f:
        text = f.read()
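    # Strip punctuation that is already present and lower-case the text;
    # the original casing is kept separately and restored in the output.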
    text = re.sub(r"[,:\-–.!;?]", '', text)
    words_original_case = text.split()
    words = text.lower().split()

    word_pos = 0
    sequence_len = args.sequence_length
    result = ""
    decode_idx = 0
    punctuation_map = {0: '', 1: ',', 2: '.', 3: '?'}
    # NOTE: the evaluation table above also lists an exclamation-mark class;
    # if this checkpoint was trained with five labels, extend the map
    # accordingly (the label id is an assumption): punctuation_map[4] = '!'
    if args.language == 'bn':
        # the Bangla full stop (danda) replaces the period only for Bangla
        punctuation_map[2] = 'ΰ₯€'

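    # Pack words into fixed-length windows. Each window starts with the
    # encoder's start token; y_mask flags the last sub-token of every word,
    # i.e. the position whose prediction is decoded back to that word.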
    while word_pos < len(words):
        x = [TOKEN_IDX[token_style]['START_SEQ']]
        y_mask = [0]

        while len(x) < sequence_len and word_pos < len(words):
            tokens = tokenizer.tokenize(words[word_pos])
            if len(tokens) + len(x) >= sequence_len:
                break
            else:
                for i in range(len(tokens) - 1):
                    x.append(tokenizer.convert_tokens_to_ids(tokens[i]))
                    y_mask.append(0)
                x.append(tokenizer.convert_tokens_to_ids(tokens[-1]))
                y_mask.append(1)
                word_pos += 1
        x.append(TOKEN_IDX[token_style]['END_SEQ'])
        y_mask.append(0)
        if len(x) < sequence_len:
            x = x + [TOKEN_IDX[token_style]['PAD'] for _ in range(sequence_len - len(x))]
            y_mask = y_mask + [0 for _ in range(sequence_len - len(y_mask))]
        attn_mask = [1 if token != TOKEN_IDX[token_style]['PAD'] else 0 for token in x]

        x = torch.tensor(x).reshape(1,-1)
        y_mask = torch.tensor(y_mask)
        attn_mask = torch.tensor(attn_mask).reshape(1,-1)
        x, attn_mask, y_mask = x.to(device), attn_mask.to(device), y_mask.to(device)

        with torch.no_grad():
            if args.use_crf:
                y = torch.zeros(x.shape[0])
                y_predict = deep_punctuation(x, attn_mask, y)
                y_predict = y_predict.view(-1)
            else:
                y_predict = deep_punctuation(x, attn_mask)
                y_predict = y_predict.view(-1, y_predict.shape[2])
                y_predict = torch.argmax(y_predict, dim=1).view(-1)
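        # Write each word back in its original casing, followed by the
        # punctuation predicted at its final sub-token.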
        for i in range(y_mask.shape[0]):
            if y_mask[i] == 1:
                result += words_original_case[decode_idx] + punctuation_map[y_predict[i].item()] + ' '
                decode_idx += 1
    print('Punctuated text')
    print(result)
    with open(args.out_file, 'w', encoding='utf-8') as f:
        f.write(result)


if __name__ == '__main__':
    inference()
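
To run the script end to end, save it (for example as inference.py, a placeholder name) next to the model.py and config.py modules it imports, and point the flags at this model's configuration; all paths below are placeholders:

python inference.py \
    --pretrained-model=xlm-roberta-base \
    --weight-path=weights/xlm-roberta-base.pt \
    --language=uz \
    --in-file=data/input_uz.txt \
    --out-file=data/output_uz.txt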


πŸ“„ License

MIT License β€” Free to use, modify, and distribute.


🀝 Contributing

Pull requests are welcome! For major changes, please open an issue first to discuss your ideas.


πŸ™Œ Acknowledgements

Thanks to Hugging Face and the Uzbek NLP community for the tools and resources that made this model possible!
