|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
import copy |
|
import logging |
|
from dataclasses import dataclass, field |
|
from typing import Optional, Dict, Sequence |
|
import io |
|
import torch |
|
import transformers |
|
from torch.utils.data import Dataset |
|
from transformers import Trainer |
|
import argparse |
|
import json |
|
import random;random.seed(42) |
|
|
|
def _make_r_io_base(f, mode: str): |
|
if not isinstance(f, io.IOBase): |
|
f = open(f, mode=mode) |
|
return f |
|
|
|
def jload(f, mode="r"):
    """Load a .json file and return the parsed value.

    Args:
        f: a path, or an already-open ``io.IOBase`` file object.
        mode: mode used to open *f* when it is a path.

    Returns:
        Whatever ``json.load`` produces for the file (e.g. a list or a dict).

    Raises:
        json.JSONDecodeError: if the content is not a single JSON document
            (for instance a JSON Lines file with several records).

    The file object is always closed afterwards, including one passed in by
    the caller. This also holds when parsing fails, so callers that fall back
    to another format on error do not leak a handle.
    """
    with _make_r_io_base(f, mode) as fh:
        return json.load(fh)
|
|
|
# Label value for positions excluded from the loss: prompt tokens are masked
# with it in preprocess(), and label padding uses it in the collator.
IGNORE_INDEX = -100

# Special-token strings that train() may add to the tokenizer.
DEFAULT_PAD_TOKEN = "[PAD]"
DEFAULT_EOS_TOKEN = "</s>"
DEFAULT_BOS_TOKEN = "<s>"
DEFAULT_UNK_TOKEN = "<unk>"

# Prompt templates keyed by name, filled in with str.format().
# SupervisedDataset uses the "statement" template.
PROMPT_DICT = dict(
    lean4=(
        "Statement and proof in natural language:\n\n{statement_text}\n\n"
        "Translate the statement and proof in natural language to lean4:"
    ),
    plain="{statement_text}",
    statement=(
        "Statement in natural language:\n{problem}\n"
        "Translate the statement in natural language to Lean4:"
    ),
    prompt_no_input=(
        "Below is an instruction that describes a task. Write a response "
        "that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Response:"
    ),
)
|
|
|
@dataclass
class ModelArguments:
    """Command-line options selecting the base model."""

    # Passed to from_pretrained() for both the model and the tokenizer in train().
    model_name_or_path: Optional[str] = "facebook/opt-125m"
|
|
|
|
|
@dataclass
class DataArguments:
    """Command-line options describing the training data."""

    # One or more JSON / JSON Lines files, comma-separated (SupervisedDataset
    # splits on ','). The default is None, so the annotation must admit None;
    # HfArgumentParser unwraps Optional[...] when building the CLI option.
    data_path: Optional[str] = field(default=None, metadata={"help": "Path to the training data."})
|
|
|
|
|
@dataclass
class TrainingArguments(transformers.TrainingArguments):
    """Hugging Face training options plus this script's additions and defaults.

    Adds ``cache_dir`` and ``model_max_length``. Overrides the defaults of
    ``optim`` (``"adamw_torch"``) and ``overwrite_output_dir`` (``True``).
    """

    # Cache directory forwarded to from_pretrained() for model and tokenizer.
    cache_dir: Optional[str] = None
    optim: str = "adamw_torch"
    # Becomes the tokenizer's model_max_length in train().
    model_max_length: int = field(
        default=2048,
        metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
    )
    overwrite_output_dir: bool = True
|
|
|
|
|
def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
    """Write the trainer's model weights to *output_dir* from a CPU copy.

    Every process reads ``trainer.model.state_dict()``. Only processes for
    which ``trainer.args.should_save`` is true copy the tensors to CPU and
    hand them to ``trainer._save``.
    """
    weights = trainer.model.state_dict()
    if not trainer.args.should_save:
        return
    host_weights = {}
    for name, tensor in weights.items():
        host_weights[name] = tensor.cpu()
    # Drop the reference to the original tensors before writing.
    del weights
    trainer._save(output_dir, state_dict=host_weights)
|
|
|
|
|
def smart_tokenizer_and_embedding_resize(
    special_tokens_dict: Dict,
    tokenizer: transformers.PreTrainedTokenizer,
    model: transformers.PreTrainedModel,
):
    """Add special tokens to *tokenizer* and resize *model*'s embeddings to match.

    Rows for newly added tokens are set to the mean of the pre-existing rows,
    computed separately for the input and the output embedding matrix.

    Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
    """
    added = tokenizer.add_special_tokens(special_tokens_dict)
    model.resize_token_embeddings(len(tokenizer))
    if added <= 0:
        return

    matrices = (
        model.get_input_embeddings().weight.data,
        model.get_output_embeddings().weight.data,
    )
    for matrix in matrices:
        # Only the last `added` rows are written, so the mean over the leading
        # rows is unaffected even if both matrices share storage.
        matrix[-added:] = matrix[:-added].mean(dim=0, keepdim=True)
|
|
|
|
|
def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict:
    """Tokenize each string separately, truncating to ``tokenizer.model_max_length``.

    Returns a dict with:
      * ``input_ids`` / ``labels``: the same list object, holding the first row
        of each encoding's ``input_ids`` tensor;
      * ``input_ids_lens`` / ``labels_lens``: the same list object, holding per
        string the number of ids different from ``tokenizer.pad_token_id``.

    NOTE(review): with one string per call, padding="longest" should add no
    padding, so the count is normally the full length. It undercounts if the
    pad id also occurs in the text (e.g. when pad == unk).
    """
    ids_per_string = []
    lens_per_string = []
    for text in strings:
        encoded = tokenizer(
            text,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        )
        ids_per_string.append(encoded.input_ids[0])
        lens_per_string.append(encoded.input_ids.ne(tokenizer.pad_token_id).sum().item())
    return dict(
        input_ids=ids_per_string,
        labels=ids_per_string,
        input_ids_lens=lens_per_string,
        labels_lens=lens_per_string,
    )
|
|
|
|
|
def preprocess(
    sources: Sequence[str],
    targets: Sequence[str],
    tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
    """Tokenize prompt+completion pairs and mask the prompt part of the labels.

    Each example is ``source + target``. Its labels are a deep copy of its
    token ids in which the first N positions are set to IGNORE_INDEX, where N
    is the token count of ``source`` tokenized on its own. Only completion
    tokens are therefore supervised.

    Returns ``dict(input_ids=[...], labels=[...])`` with one entry per example.
    """
    full_texts = [src + tgt for src, tgt in zip(sources, targets)]
    full_tok = _tokenize_fn(full_texts, tokenizer)
    prompt_tok = _tokenize_fn(sources, tokenizer)

    input_ids = full_tok["input_ids"]
    labels = copy.deepcopy(input_ids)
    for row, prompt_len in zip(labels, prompt_tok["input_ids_lens"]):
        row[:prompt_len] = IGNORE_INDEX
    return dict(input_ids=input_ids, labels=labels)
|
|
|
class SupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning.

    Reads records from ``data_args.data_path`` (one or more comma-separated
    files, each a JSON array or JSON Lines) and shuffles them with the
    module-level ``random`` generator. It keeps at most
    ``data_args.data_length`` records (all of them when that is None or unset)
    and stores raw strings:

    * ``self.sources``: the "statement" prompt filled with ``record["problem"]``;
    * ``self.targets``: ``record["statement"]`` followed by ``tokenizer.eos_token``.

    Tokenization is left to the collator.
    """

    def __init__(self, data_args, tokenizer: transformers.PreTrainedTokenizer):
        super().__init__()
        logging.warning("Loading data...")
        data_path = data_args.data_path
        # Optional alias table. If a module-level ``data_path_map`` exists, it
        # may map a dataset name to real path(s). When it is not defined
        # (NameError) or has no entry (KeyError), the path is used as given.
        try:
            data_path = data_path_map[data_path]  # noqa: F821
        except (NameError, KeyError):
            pass
        if not data_path:
            raise ValueError(
                "No training data: set --data_path to one or more comma-separated JSON/JSONL files."
            )

        list_data_dict = []
        for item in data_path.split(','):
            item = item.strip()
            if item:  # tolerate stray or trailing commas
                list_data_dict += self._load_records(item)

        # Shuffled copy; the generator is seeded at import (random.seed(42)).
        list_data_dict = random.sample(list_data_dict, len(list_data_dict))
        list_data_dict = list_data_dict[:getattr(data_args, "data_length", None)]

        logging.warning("Formatting inputs...")
        prompt_lean4 = PROMPT_DICT["statement"]
        sources = [prompt_lean4.format(problem=data['problem']) for data in list_data_dict]
        targets = [f"{data['statement']}{tokenizer.eos_token}" for data in list_data_dict]
        print(f"len of {len(list_data_dict)}")

        self.sources = sources
        self.targets = targets

    @staticmethod
    def _load_records(path):
        """Return the list of records in *path*, parsed as JSON or else as JSON Lines."""
        try:
            loaded = jload(path)
        except ValueError:
            # Not a single JSON document (json.JSONDecodeError is a ValueError).
            # Read it as JSON Lines: one record per non-blank line.
            with open(path, 'r', encoding="utf-8") as f:
                return [json.loads(line) for line in f if line.strip()]
        # A JSON Lines file holding exactly one record parses as one object.
        # Wrap it, or `list += dict` would add its keys as records.
        if isinstance(loaded, dict):
            return [loaded]
        return loaded

    def __len__(self):
        return len(self.sources)

    def naive__getitem__(self, i) -> Dict[str, torch.Tensor]:
        # NOTE(review): legacy accessor for pre-tokenized data. __init__ never
        # sets self.input_ids / self.labels, so this raises AttributeError
        # unless those attributes are assigned from outside.
        return dict(input_ids=self.input_ids[i], labels=self.labels[i])

    def __getitem__(self, i):
        # Raw strings, not token ids. The key names are the ones
        # DataCollatorForSupervisedDataset.__call__ reads and tokenizes.
        return dict(input_ids=self.sources[i], labels=self.targets[i])
|
|
|
@dataclass
class DataCollatorForSupervisedDataset(object):
    """Collate examples for supervised fine-tuning.

    ``__call__`` expects raw strings as produced by ``SupervisedDataset.__getitem__``:
    the prompt under ``"input_ids"`` and the completion under ``"labels"``. It
    tokenizes them with ``preprocess`` and pads the batch.
    ``naive__call__`` expects already-tokenized tensors under the same keys.
    """

    tokenizer: transformers.PreTrainedTokenizer

    def _pad_batch(self, input_ids, labels) -> Dict[str, torch.Tensor]:
        """Pad sequences at the end into batch-first tensors and build the attention mask."""
        pad_id = self.tokenizer.pad_token_id
        padded_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=True, padding_value=pad_id)
        padded_labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
        # NOTE(review): any real token equal to pad_token_id (e.g. when pad ==
        # unk) is masked out as well.
        return dict(
            input_ids=padded_ids,
            labels=padded_labels,
            attention_mask=padded_ids.ne(pad_id),
        )

    def naive__call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        ids = [example["input_ids"] for example in instances]
        lbls = [example["labels"] for example in instances]
        return self._pad_batch(ids, lbls)

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        prompts = [example['input_ids'] for example in instances]
        completions = [example['labels'] for example in instances]
        encoded = preprocess(prompts, completions, self.tokenizer)
        return self._pad_batch(encoded['input_ids'], encoded['labels'])
|
|
|
def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args) -> Dict:
    """Build the Trainer keyword arguments: training dataset, no eval set, and collator."""
    return {
        "train_dataset": SupervisedDataset(tokenizer=tokenizer, data_args=data_args),
        "eval_dataset": None,
        "data_collator": DataCollatorForSupervisedDataset(tokenizer=tokenizer),
    }
|
|
|
|
|
# Default Weights & Biases project for this script. setdefault leaves a
# WANDB_PROJECT already exported by the user untouched instead of overwriting it at import.
os.environ.setdefault("WANDB_PROJECT", "auto_statement")
|
|
|
def _parse_data_length(remaining_args):
    """Return the dataset-size cap from the arguments HfArgumentParser left unparsed.

    Accepts ``--data_length N`` (or ``--data_length=N``). For backward
    compatibility with the earlier positional parsing, it falls back to
    ``int(remaining_args[1])`` when that flag is absent but at least two
    leftover tokens exist. Returns None when no cap is given, which keeps
    every record.
    """
    extra = argparse.ArgumentParser(add_help=False, allow_abbrev=False)
    extra.add_argument("--data_length", type=int, default=None)
    known, _unknown = extra.parse_known_args(remaining_args)
    if known.data_length is not None:
        return known.data_length
    if len(remaining_args) >= 2:
        return int(remaining_args[1])
    return None


def train():
    """Fine-tune a causal LM to translate natural-language statements into Lean4.

    Command-line arguments are parsed into ModelArguments, DataArguments and
    TrainingArguments. An extra ``--data_length N`` caps the number of
    (shuffled) training records. The model is loaded in bfloat16 with
    FlashAttention-2 and gradient checkpointing, trained with ``Trainer``, and
    its weights are written to ``training_args.output_dir``.
    """
    parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args, remaining_args = parser.parse_args_into_dataclasses(
        return_remaining_strings=True
    )
    # The size cap is not a dataclass field, so HfArgumentParser leaves it in
    # remaining_args. Parse it by name rather than by position.
    if getattr(data_args, "data_length", None) is None:
        data_args.data_length = _parse_data_length(remaining_args)

    model = transformers.AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
    )
    # Disable the KV cache while training; it is incompatible with gradient checkpointing.
    model.config.use_cache = False
    model.gradient_checkpointing_enable()

    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        model_max_length=training_args.model_max_length,
        padding_side="right",
        use_fast=False,
    )
    if tokenizer.pad_token is None:
        smart_tokenizer_and_embedding_resize(
            special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN),
            tokenizer=tokenizer,
            model=model,
        )
    if "llama" in model_args.model_name_or_path:
        tokenizer.add_special_tokens(
            {
                "eos_token": DEFAULT_EOS_TOKEN,
                "bos_token": DEFAULT_BOS_TOKEN,
                "unk_token": DEFAULT_UNK_TOKEN,
            }
        )
    # Pad with the unk token when the tokenizer has one. Setting pad to None
    # would break padding later, so tokenizers without unk keep their pad token.
    # This is best effort: a tokenizer that refuses the assignment keeps its pad
    # token too, but the failure is logged instead of silently swallowed.
    if tokenizer.unk_token is not None:
        try:
            tokenizer.pad_token = tokenizer.unk_token
        except Exception as err:
            logging.warning(
                "Could not set pad_token to unk_token (%s); keeping pad_token=%r", err, tokenizer.pad_token
            )

    data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)
    trainer = Trainer(model=model, tokenizer=tokenizer, args=training_args, **data_module)
    trainer.train()
    model.config.use_cache = True
    safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir)
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Run fine-tuning only when executed as a script, not on import.
    train()
|
|