import torch |
|
from transformers import ByT5Tokenizer, T5Tokenizer, T5Config |
|
import datasets |
|
|
|
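# ByT5 operates on raw UTF-8 bytes, so the tokenizer needs no vocabulary file and only the
# slow (pure-Python) ByT5Tokenizer exists.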
tokenizer = ByT5Tokenizer.from_pretrained('google/byt5-small')
|
|
|
|
|
|
|
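# Build T5-SQuAD-style examples: the encoder input is "question: ... context: ..." and the
# decoder target is the answer text. The literal '</s>' follows the original T5 SQuAD recipe;
# the tokenizer will additionally append its own EOS token when encoding.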
def add_eos_to_examples(example):
    example['input_text'] = 'question: %s context: %s </s>' % (example['asr_question_parsed'], example['asr_content_parsed'])
    example['target_text'] = '%s </s>' % example['asr_answer_parsed']
    return example
|
|
|
|
|
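# Tokenize in batches to fixed lengths. ByT5 lengths are counted in bytes, so
# max_length=1024 corresponds to roughly 1024 characters of question + context.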
def convert_to_features(example_batch):
    input_encodings = tokenizer.batch_encode_plus(example_batch['input_text'], padding='max_length', truncation=True, max_length=1024)
    target_encodings = tokenizer.batch_encode_plus(example_batch['target_text'], padding='max_length', truncation=True, max_length=128)

    encodings = {
        'input_ids': input_encodings['input_ids'],
        'attention_mask': input_encodings['attention_mask'],
        'target_ids': target_encodings['input_ids'],
        'target_attention_mask': target_encodings['attention_mask']
    }
    print(encodings.keys())
    return encodings
|
|
|
|
|
|
|
import pandas as pd |
|
from datasets import Dataset |
|
|
|
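# Load the deduplicated NMSQA train/dev splits (wav2vec-u2 transcripts) from parquet into
# Hugging Face Datasets.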
train_dataset = Dataset.from_pandas(pd.read_parquet("./hf/NMSQA-wav2vecu2/train_for_squad_reduce_silIsSpace_dedup.parquet")) |
|
valid_dataset = Dataset.from_pandas(pd.read_parquet("./hf/NMSQA-wav2vecu2/dev_for_squad_reduce_silIsSpace_dedup.parquet")) |
|
|
|
|
|
train_dataset = train_dataset.map(add_eos_to_examples, load_from_cache_file=False)
train_dataset = train_dataset.map(convert_to_features, batched=True, load_from_cache_file=False)

valid_dataset = valid_dataset.map(add_eos_to_examples, load_from_cache_file=False)
valid_dataset = valid_dataset.map(convert_to_features, batched=True, load_from_cache_file=False)
|
|
|
|
|
|
|
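# Expose only the tokenized columns as torch tensors; the remaining text columns stay in
# the dataset but are not returned when indexing.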
columns = ['input_ids', 'target_ids', 'attention_mask', 'target_attention_mask'] |
|
train_dataset.set_format(type="torch", columns=columns) |
|
valid_dataset.set_format(type="torch", columns=columns) |
|
from torch.utils.data import DataLoader |
|
from typing import Dict, List, Optional |
|
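# Custom collator: pad positions (token id 0 for ByT5) in the labels are set to -100 so the
# loss ignores them, and target_ids / target_attention_mask are mapped to the labels /
# decoder_attention_mask arguments expected by T5ForConditionalGeneration.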
def collate_batch(batch: List) -> Dict[str, torch.Tensor]:
    """
    Take a list of samples from a Dataset and collate them into a batch.
    Returns:
        A dictionary of tensors
    """
    input_ids = torch.stack([example['input_ids'] for example in batch])
    lm_labels = torch.stack([example['target_ids'] for example in batch])
    lm_labels[lm_labels[:, :] == 0] = -100
    attention_mask = torch.stack([example['attention_mask'] for example in batch])
    decoder_attention_mask = torch.stack([example['target_attention_mask'] for example in batch])

    return {
        'input_ids': input_ids,
        'attention_mask': attention_mask,
        'labels': lm_labels,
        'decoder_attention_mask': decoder_attention_mask
    }
|
print("training samples",len(train_dataset), "validation samples",len(valid_dataset)) |
|
|
|
import dataclasses |
|
import logging |
|
import os |
|
import sys |
|
from dataclasses import dataclass, field |
|
|
|
import json |
|
import numpy as np |
|
import torch |
|
|
|
from transformers import T5ForConditionalGeneration, ByT5Tokenizer, EvalPrediction, T5Tokenizer |
|
from transformers import ( |
|
HfArgumentParser, |
|
DataCollator, |
|
Trainer, |
|
TrainingArguments, |
|
set_seed, |
|
) |
|
logger = logging.getLogger(__name__) |
|
@dataclass |
|
class ModelArguments: |
|
""" |
|
Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. |
|
""" |
|
|
|
model_name_or_path: str = field( |
|
metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} |
|
) |
|
tokenizer_name: Optional[str] = field( |
|
default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} |
|
) |
|
cache_dir: Optional[str] = field( |
|
default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"} |
|
) |
|
|
|
@dataclass |
|
class DataTrainingArguments: |
|
""" |
|
Arguments pertaining to what data we are going to input our model for training and eval. |
|
""" |
|
train_file_path: Optional[str] = field( |
|
default='train_data.pt', |
|
metadata={"help": "Path for cached train dataset"}, |
|
) |
|
valid_file_path: Optional[str] = field( |
|
default='valid_data.pt', |
|
metadata={"help": "Path for cached valid dataset"}, |
|
) |
|
max_len: Optional[int] = field( |
|
default=512, |
|
metadata={"help": "Max input length for the source text"}, |
|
) |
|
target_max_len: Optional[int] = field( |
|
default=32, |
|
metadata={"help": "Max input length for the target text"}, |
|
) |
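    # Note: max_len and target_max_len are read from args.json but are not wired into the
    # preprocessing above, which hard-codes max_length=1024 for inputs and 128 for targets.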
|
|
|
|
|
|
|
def main():
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath('args.json'))
|
|
|
if ( |
|
os.path.exists(training_args.output_dir) |
|
and os.listdir(training_args.output_dir) |
|
and training_args.do_train |
|
and not training_args.overwrite_output_dir |
|
): |
|
raise ValueError( |
|
f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome." |
|
) |
|
|
|
|
|
logging.basicConfig( |
|
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", |
|
datefmt="%m/%d/%Y %H:%M:%S", |
|
level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN, |
|
) |
|
set_seed(training_args.seed) |
|
tokenizer = ByT5Tokenizer.from_pretrained( |
|
model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, |
|
cache_dir=model_args.cache_dir, |
|
) |
|
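    # The model is built from the config alone, i.e. with randomly initialized weights
    # (matching the "randinit" output_dir); use T5ForConditionalGeneration.from_pretrained(...)
    # instead to start from the released checkpoint.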
    config = T5Config.from_pretrained(model_args.model_name_or_path)
    model = T5ForConditionalGeneration(config=config)
|
|
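    # Trainer drives the training loop; collate_batch supplies the labels and
    # decoder_attention_mask tensors for the seq2seq loss.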
trainer = Trainer( |
|
model=model, |
|
args=training_args, |
|
train_dataset=train_dataset, |
|
eval_dataset=valid_dataset, |
|
data_collator=collate_batch, |
|
) |
|
|
|
|
|
if training_args.do_train: |
|
trainer.train( |
|
model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None |
|
) |
|
trainer.save_model() |
|
|
|
|
|
if trainer.is_world_master(): |
|
tokenizer.save_pretrained(training_args.output_dir) |
|
|
|
|
|
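    # Optional evaluation on the dev split; metrics are logged and written to
    # eval_results.txt in the output directory.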
results = {} |
|
if training_args.do_eval and training_args.local_rank in [-1, 0]: |
|
logger.info("*** Evaluate ***") |
|
|
|
eval_output = trainer.evaluate() |
|
|
|
output_eval_file = os.path.join(training_args.output_dir, "eval_results.txt") |
|
with open(output_eval_file, "w") as writer: |
|
logger.info("***** Eval results *****") |
|
for key in sorted(eval_output.keys()): |
|
logger.info(" %s = %s", key, str(eval_output[key])) |
|
writer.write("%s = %s\n" % (key, str(eval_output[key]))) |
|
|
|
results.update(eval_output) |
|
|
|
return results |
|
|
|
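# Conventional entry point for TPU launchers such as torch_xla's xla_spawn; this script
# simply calls main() directly at the bottom instead.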
def _mp_fn(index): |
|
|
|
main() |
|
|
|
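# Hyperparameters are dumped to args.json and re-read by HfArgumentParser in main().
# remove_unused_columns must stay False so the target_ids / target_attention_mask columns
# reach collate_batch; keys that belong to none of the three argument dataclasses
# (e.g. 'training_script') may be ignored or rejected depending on the transformers version.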
args_dict = {
    "tpu_num_cores": 1,
    "training_script": "train_t5_squad.py",
    "model_name_or_path": "voidful/phoneme_byt5_g2p_v1",
    "tokenizer_name": "google/byt5-small",
    "max_len": 1024,
    "target_max_len": 128,
    "output_dir": "wav2vecu2-byt5small-randinit-squad",
    "overwrite_output_dir": True,
    "per_device_train_batch_size": 4,
    "per_device_eval_batch_size": 1,
    "gradient_accumulation_steps": 4,
    "learning_rate": 3e-4,
    "num_train_epochs": 3,
    "do_train": True,
    "do_eval": False,
    "save_strategy": "epoch",
    "save_total_limit": 100,
    "push_to_hub": False,
    "remove_unused_columns": False
}
|
|
os.environ["WANDB_DISABLED"] = "true" |
|
with open('args.json', 'w') as f: |
|
json.dump(args_dict, f) |
|
import torch.multiprocessing as mp |
|
|
|
if __name__ == '__main__':
    main()
|
|
|
|