ModernBERT Fine-Tuned on Persian Data

Persian ModernBERT is a Persian-language Masked Language Model (MLM) built by fine-tuning ModernBERT-base with a custom tokenizer on a corpus of 2.5 billion tokens, nearly twice the 1.3 billion tokens ParsBERT was trained on. It uses Flash Attention 2 for efficient training and inference.

Model Details

  • Base Model: answerdotai/ModernBERT-base
  • Tokenizer: Custom, optimized for Persian
  • Corpus: 2.5 billion Persian tokens from diverse sources
  • Objective: Masked Language Modeling (MLM)
  • Attention Mechanism: Flash Attention v2
  • Precision: torch.bfloat16 for efficient computation on modern hardware

Usage

You can use this model directly with the transformers library. Until the next transformers release, doing so requires installing transformers from main:

pip install git+https://github.com/huggingface/transformers.git

Since ModernBERT is a Masked Language Model (MLM), you can use the fill-mask pipeline or load it via AutoModelForMaskedLM. To use ModernBERT for downstream tasks like classification, retrieval, or QA, fine-tune it following standard BERT fine-tuning recipes.
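The fill-mask pipeline mentioned above can be sketched as follows. The helper names predict_mask and format_predictions are hypothetical and introduced only for illustration; the sketch assumes the tokenizer's mask token is [MASK], as in the examples in this card, and that the pipeline returns its usual list of dictionaries with score, token_str, and sequence keys.

```python
def format_predictions(results):
    """Turn fill-mask pipeline output (a list of dicts) into readable lines."""
    return [
        f"{r['token_str'].strip()} ({r['score']:.3f}): {r['sequence']}"
        for r in results
    ]


def predict_mask(text, top_k=5):
    """Run the fill-mask pipeline on a sentence containing one mask token."""
    # Imported lazily so format_predictions can be used without transformers.
    import torch
    from transformers import pipeline

    fill_mask = pipeline(
        "fill-mask",
        model="myrkur/Persian-ModernBert-base",
        torch_dtype=torch.bfloat16,
    )
    return fill_mask(text, top_k=top_k)


# Example call (downloads the model on first use):
# for line in format_predictions(predict_mask("حال و [MASK] مردم خوب است.")):
#     print(line)
```

The pipeline handles tokenization, mask lookup, and decoding internally, so it is the shortest route when you only need the top predictions.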

⚠️ If your GPU supports it, we recommend using ModernBERT with Flash Attention 2 to reach the highest efficiency. To do so, install Flash Attention as follows, then use the model as normal:

pip install flash-attn
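Flash Attention 2 is only usable when a CUDA GPU is present and the flash-attn package is installed. A minimal sketch of choosing the attn_implementation value at load time is shown here; the helper name pick_attn_implementation is hypothetical, and falling back to "eager" is an assumption that simply mirrors the CPU example in this card.

```python
import importlib.util


def pick_attn_implementation(cuda_available: bool) -> str:
    """Return "flash_attention_2" only if a GPU and flash-attn are both present."""
    # find_spec returns None when the package is not importable.
    flash_installed = importlib.util.find_spec("flash_attn") is not None
    if cuda_available and flash_installed:
        return "flash_attention_2"
    return "eager"  # assumed fallback, same value the CPU example uses
```

In practice you would pass torch.cuda.is_available() as the argument and hand the result to from_pretrained through its attn_implementation parameter.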

Inference on CPU

Load the Model and Tokenizer

import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM

# Load custom tokenizer and fine-tuned model
tokenizer = AutoTokenizer.from_pretrained("myrkur/Persian-ModernBert-base")
model = AutoModelForMaskedLM.from_pretrained("myrkur/Persian-ModernBert-base", attn_implementation="eager", torch_dtype=torch.bfloat16, device_map="cpu")

Example: Masked Token Prediction

text = "حال و [MASK] مردم خوب است."
inputs = tokenizer(text, return_tensors="pt")  # tensors are created on the CPU by default
with torch.no_grad():  # inference only, so skip gradient tracking
    token_logits = model(**inputs).logits

# Find the [MASK] token and decode top predictions
mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]
mask_token_logits = token_logits[0, mask_token_index, :]
top_5_tokens = torch.topk(mask_token_logits, 5, dim=1).indices[0].tolist()

for token in top_5_tokens:
    print(f"Prediction: {text.replace(tokenizer.mask_token, tokenizer.decode([token]))}")

Inference on GPU

Load the Model and Tokenizer

import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM

# Load custom tokenizer and fine-tuned model
tokenizer = AutoTokenizer.from_pretrained("myrkur/Persian-ModernBert-base")
model = AutoModelForMaskedLM.from_pretrained("myrkur/Persian-ModernBert-base", attn_implementation="flash_attention_2", torch_dtype=torch.bfloat16, device_map="cuda")

Example: Masked Token Prediction

text = "حال و [MASK] مردم خوب است."
inputs = tokenizer(text, return_tensors="pt").to(model.device)  # move inputs to the GPU
with torch.no_grad():  # inference only, so skip gradient tracking
    token_logits = model(**inputs).logits

# Find the [MASK] token and decode top predictions
mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]
mask_token_logits = token_logits[0, mask_token_index, :]
top_5_tokens = torch.topk(mask_token_logits, 5, dim=1).indices[0].tolist()

for token in top_5_tokens:
    print(f"Prediction: {text.replace(tokenizer.mask_token, tokenizer.decode([token]))}")

Training Details

Dataset

The model was fine-tuned on a custom dataset with 2.5 billion Persian tokens. The dataset was preprocessed and tokenized using a custom tokenizer designed to maximize efficiency and coverage for Persian.

Training Configuration

  • Optimizer: AdamW
  • Learning Rate: 6e-4
  • Batch Size: 32
  • Epochs: 2
  • Scheduler: Inverse square root
  • Precision: bfloat16 for faster computation and lower memory usage
  • Masking Strategy: Whole Word Masking (WWM) with a probability of 30%
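Two of the settings above, the inverse square root scheduler and whole word masking at 30%, can be illustrated with a small pure-Python sketch. This is not the training code used for this model. The warmup_steps value is an assumption (the card does not state one), selecting each word independently with probability 30% is a simplification, and the function names are hypothetical. The word_ids input is assumed to have the shape returned by a fast tokenizer's word_ids() method: one word index per token, with None for special tokens.

```python
import math
import random


def inverse_sqrt_lr(step, peak_lr=6e-4, warmup_steps=1000):
    """Linear warmup to peak_lr, then decay proportional to 1/sqrt(step)."""
    if step < warmup_steps:
        return peak_lr * step / warmup_steps
    return peak_lr * math.sqrt(warmup_steps / step)


def whole_word_mask(word_ids, mask_prob=0.30, rng=None):
    """Return one boolean per token; all tokens of a chosen word are masked together."""
    rng = rng or random.Random()
    # Pick whole words, never individual sub-word tokens.
    words = sorted({w for w in word_ids if w is not None})
    chosen = {w for w in words if rng.random() < mask_prob}
    # Special tokens have word id None and are never masked.
    return [w in chosen for w in word_ids]


# Tokens: [CLS], two pieces of word 0, word 1, three pieces of word 2, [SEP]
example_word_ids = [None, 0, 0, 1, 2, 2, 2, None]
print(whole_word_mask(example_word_ids, rng=random.Random(0)))
print(inverse_sqrt_lr(4000))
```

With these assumed defaults the learning rate reaches 6e-4 at step 1000 and has halved by step 4000, since the decay factor is sqrt(1000/4000) = 0.5.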

Efficient Training with Flash Attention

The model uses the flash_attention_2 implementation, which avoids materializing the full attention matrix and therefore reduces memory use and speeds up training.

Model size: 150M params (Safetensors, tensor type BF16)