# scandinavian-tokenizer / train_tokenizer.py
# Author: versae
# Scandi+English tokenizer on OSCAR (commit e5b62ac)
import argparse
import json
import os
import tempfile
from pathlib import Path
from tqdm import tqdm
from datasets import load_dataset
from tokenizers import SentencePieceBPETokenizer
from transformers import LlamaTokenizerFast, TrainingArguments, AutoTokenizer
def _list_data_files(path):
    """Return the local data files to load for a ``--dataset_type`` dataset.

    Args:
        path: A single data file, or a directory containing data files.

    Returns:
        A list of file paths as strings. For a directory, only regular files
        directly inside it are returned (subdirectories and dotfiles such as
        ``.DS_Store`` are skipped), sorted by name: ``os.listdir`` order is
        arbitrary, and a stable order keeps the seeded shuffle reproducible.

    Raises:
        OSError: e.g. ``FileNotFoundError`` if ``path`` does not exist.
    """
    if os.path.isfile(path):
        return [str(path)]
    return [
        str(entry)
        for entry in sorted(Path(path).iterdir())
        if entry.is_file() and not entry.name.startswith(".")
    ]


def _load_text_dataset(args):
    """Load, trim, shuffle and optionally subsample the training corpus.

    With ``args.dataset_type`` set (e.g. 'text', 'json', 'csv'),
    ``args.dataset_name`` is a local file or directory loaded with that
    builder; otherwise it is a Hub dataset name and is loaded in streaming mode.

    Returns:
        A ``datasets.Dataset`` (local files) or ``datasets.IterableDataset``
        (streamed Hub dataset) restricted to the ``text`` column.
    """
    # Local import: only needed for the streaming/non-streaming check below.
    from datasets import IterableDataset

    token = args.hub_token or None
    if args.dataset_type:
        data_files = _list_data_files(args.dataset_name)
        print(f"Training on {len(data_files)} files")
        dataset = load_dataset(
            args.dataset_type,
            data_files=data_files,
            split=args.dataset_split,
            token=token,
        )
    else:
        dataset = load_dataset(
            args.dataset_name,
            split=args.dataset_split,
            streaming=True,
            token=token,
        )
    print(dataset)

    # Drop every column except "text". Streaming datasets may report
    # ``column_names`` as None when their features are not resolved yet.
    extra_columns = [col for col in (dataset.column_names or []) if col != "text"]
    if extra_columns:
        dataset = dataset.remove_columns(extra_columns)

    # Randomize document order (for streamed data this is a buffered shuffle).
    dataset = dataset.shuffle(seed=args.seed)

    if args.num_samples:
        if isinstance(dataset, IterableDataset):
            # IterableDataset has no ``select``; ``take`` keeps the first n examples.
            dataset = dataset.take(args.num_samples)
        else:
            # Clamp so requesting more samples than exist doesn't raise IndexError.
            dataset = dataset.select(range(min(args.num_samples, len(dataset))))
    return dataset


def _text_batches(dataset, batch_size=1000):
    """Yield the ``text`` column as lists of up to ``batch_size`` documents.

    Streams the corpus instead of materializing ``dataset["text"]`` as a
    single Python list holding every document in memory at once.
    """
    for batch in dataset.iter(batch_size=batch_size):
        yield batch["text"]


def _read_json(path):
    """Load and return the JSON document stored (as UTF-8) at ``path``."""
    with open(path, encoding="utf-8") as f:
        return json.load(f)


def _write_json(path, data):
    """Write ``data`` to ``path`` as indented UTF-8 JSON."""
    # ensure_ascii=False emits raw characters such as "▁" or "ø"; without an
    # explicit encoding, open() uses the locale encoding, which may be unable
    # to represent them (e.g. cp1252 on Windows).
    with open(path, "w", encoding="utf-8") as f:
        json.dump(data, f, indent=2, ensure_ascii=False)


def _apply_reference_config(new_tokenizer_json, reference_tokenizer_json):
    """Copy the reference tokenizer's processing pipeline into the new one.

    The normalizer, pre-tokenizer, post-processor and decoder are replaced
    wholesale and the model's ``fuse_unk``/``byte_fallback`` flags are copied;
    the newly trained vocabulary and merges are left untouched.

    Mutates and returns ``new_tokenizer_json``.
    """
    for key in ("normalizer", "pre_tokenizer", "post_processor", "decoder"):
        new_tokenizer_json[key] = reference_tokenizer_json[key]
    for key in ("fuse_unk", "byte_fallback"):
        new_tokenizer_json["model"][key] = reference_tokenizer_json["model"][key]
    return new_tokenizer_json


def main(args):
    """Train a SentencePiece-BPE tokenizer and save it as a ``LlamaTokenizerFast``.

    The vocabulary is learned from the ``text`` column of the dataset named by
    ``args``. The normalizer, pre-tokenizer, post-processor, decoder and the
    ``fuse_unk``/``byte_fallback`` model flags are copied from
    ``args.reference_tokenizer``. The result is saved to ``args.output``.

    Args:
        args: Parsed command-line arguments (``dataset_name``, ``dataset_type``,
            ``dataset_split``, ``hub_token``, ``reference_tokenizer``, ``seed``,
            ``num_samples``, ``vocab_size``, ``output``).

    Raises:
        ValueError: If ``dataset_name``, ``dataset_split`` or
            ``reference_tokenizer`` is missing. These are checked before any
            download or training, so a bad invocation fails immediately.
    """
    if args.dataset_name is None:
        raise ValueError("No dataset name provided or dataset is already tokenized")
    if args.dataset_split is None:
        # load_dataset(split=None) returns a dict of all splits, which the
        # column filtering and shuffling below cannot handle.
        raise ValueError("No dataset split provided. Try using `--dataset_split train`")
    if args.reference_tokenizer is None:
        raise ValueError(
            "No reference tokenizer provided. Try using "
            "`--reference_tokenizer 'mistralai/Mistral-7B-Instruct-v0.2'` "
            "(private or gated repositories also need `--hub_token`)"
        )
    token = args.hub_token or None

    # Fetch the reference tokenizer before training so authentication or
    # network errors surface now rather than after a long training run.
    reference_tokenizer = AutoTokenizer.from_pretrained(args.reference_tokenizer, token=token)

    dataset = _load_text_dataset(args)
    try:
        num_documents = len(dataset)
    except TypeError:  # streamed datasets have no length
        num_documents = None

    train_kwargs = {}
    if args.vocab_size is not None:
        # vocab_size must be an int; when --vocab_size is omitted, fall back to
        # the library's default instead of passing None.
        train_kwargs["vocab_size"] = args.vocab_size

    tokenizer = SentencePieceBPETokenizer()
    tokenizer.train_from_iterator(
        _text_batches(dataset),
        show_progress=True,
        special_tokens=["<unk>", "<s>", "</s>", "<pad>"],
        length=num_documents,
        **train_kwargs,
    )

    # All intermediate files live in one directory that is removed on exit.
    with tempfile.TemporaryDirectory(prefix="tokenizer_") as tmp_dir:
        new_tokenizer_file = str(Path(tmp_dir) / "tokenizer.json")
        tokenizer.save(new_tokenizer_file, pretty=True)

        reference_tokenizer_path = Path(tmp_dir) / "reference"
        reference_tokenizer.save_pretrained(str(reference_tokenizer_path))

        new_tokenizer_json = _read_json(new_tokenizer_file)
        reference_tokenizer_json = _read_json(reference_tokenizer_path / "tokenizer.json")
        _apply_reference_config(new_tokenizer_json, reference_tokenizer_json)
        _write_json(new_tokenizer_file, new_tokenizer_json)

        # The ids 0-3 assume the trainer assigned the special tokens in the
        # order they are listed in train_from_iterator above.
        new_llama_tokenizer = LlamaTokenizerFast(
            tokenizer_file=new_tokenizer_file,
            name_or_path=args.reference_tokenizer + "-tokenizer",
            unk_token="<unk>",
            unk_token_id=0,
            bos_token="<s>",
            bos_token_id=1,
            eos_token="</s>",
            eos_token_id=2,
            pad_token="<pad>",
            pad_token_id=3,
            padding_side="right",
        )
        new_llama_tokenizer.save_pretrained(args.output)
def parse_args(argv=None):
    """Parse the command line for the tokenizer-training script.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        An ``argparse.Namespace`` with the attributes ``main`` reads.
    """
    parser = argparse.ArgumentParser(description="Train a new Llama tokenizer")
    parser.add_argument(
        "--dataset_name",
        type=str,
        default=None,
        help="Hub dataset name, or a local file/directory when --dataset_type is set",
    )
    parser.add_argument(
        "--dataset_type",
        type=str,
        default=None,
        help="The type, 'text', 'json', or 'csv'. Leave blank for regular HF datasets",
    )
    parser.add_argument(
        "--dataset_split",
        type=str,
        default=None,
        help="The dataset split to train on (e.g. 'train')",
    )
    parser.add_argument(
        "--hub_token",
        type=str,
        default=None,
        help="Hugging Face access token used to load the dataset and the reference tokenizer",
    )
    parser.add_argument(
        "--reference_tokenizer",
        type=str,
        default=None,
        help=(
            "Hub name of the tokenizer whose normalizer, pre-tokenizer, "
            "post-processor and decoder are copied into the new tokenizer"
        ),
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=123,
        help="Random seed used to shuffle the dataset",
    )
    parser.add_argument(
        "--num_samples",
        type=int,
        default=None,
        help="Number of samples to use from the dataset",
    )
    parser.add_argument(
        "--vocab_size",
        type=int,
        default=None,
        help="Vocabulary size to use for the tokenizer",
    )
    parser.add_argument(
        "--output",
        type=str,
        default="./",
        help="Output directory for the new tokenizer",
    )
    return parser.parse_args(argv)


if __name__ == "__main__":
    main(parse_args())
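# How to run (--hub_token takes the access-token string itself; passing
# `True` would send the literal text "True" as the token):
# python train_tokenizer.py --dataset_name texts/all.txt --dataset_type text --dataset_split train --reference_tokenizer mistralai/Mistral-7B-Instruct-v0.2 --vocab_size 32768 --hub_token "$HF_TOKEN"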