'''
Author: Qiguang Chen
Date: 2023-01-11 10:39:26
LastEditors: Qiguang Chen
LastEditTime: 2023-02-08 00:42:56
Description: manage the whole process of model training and prediction.
'''
import os
import random

import dill
import numpy as np
import torch
from tqdm import tqdm

from common import utils
from common.config import Config
from common.loader import DataFactory
from common.logger import Logger
from common.metric import Evaluator
from common.tokenizer import get_tokenizer, get_tokenizer_class, load_embedding
from common.utils import InputData, OutputData, instantiate


class ModelManager(object):
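    """Coordinate the whole OpenSLU pipeline driven by a single Config object:
    data loading, model initialization, training, evaluation, checkpointing and
    single-utterance prediction."""
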
    def __init__(self, config: Config):
        """Create a model manager from the given config.

        Args:
            config (Config): configuration used to manage the whole OpenSLU process
        """
        # init config
        self.config = config
        self.__set_seed(self.config.base.get("seed"))
        self.device = self.config.base.get("device")
        # enable accelerator
        if "accelerator" in self.config and self.config["accelerator"].get("use_accelerator"):
            from accelerate import Accelerator
            self.accelerator = Accelerator(log_with="wandb")
        else:
            self.accelerator = None
        if self.config.base.get("train"):
            self.tokenizer = get_tokenizer(
                self.config.tokenizer.get("_tokenizer_name_"))
            self.logger = Logger(
                "wandb", self.config.base["name"], start_time=config.start_time, accelerator=self.accelerator)
            # init dataloader & load data
            if self.config.base.get("save_dir"):
                self.model_save_dir = self.config.base["save_dir"]
            else:
                if not os.path.exists("save/"):
                    os.mkdir("save/")
                self.model_save_dir = "save/" + config.start_time
            if not os.path.exists(self.model_save_dir):
                os.mkdir(self.model_save_dir)
            batch_size = self.config.base["batch_size"]
            df = DataFactory(tokenizer=self.tokenizer,
                             use_multi_intent=self.config.base.get("multi_intent"),
                             to_lower_case=self.config.base.get("_to_lower_case_"))
            train_dataset = df.load_dataset(self.config.dataset, split="train")
            # update label and vocabulary
            df.update_label_names(train_dataset)
            df.update_vocabulary(train_dataset)
            # init tokenizer config and dataloaders
            tokenizer_config = {key: self.config.tokenizer[key]
                                for key in self.config.tokenizer if key[0] != "_" and key[-1] != "_"}
            self.train_dataloader = df.get_data_loader(train_dataset,
                                                       batch_size,
                                                       shuffle=True,
                                                       device=self.device,
                                                       enable_label=True,
                                                       align_mode=self.config.tokenizer.get("_align_mode_"),
                                                       label2tensor=True,
                                                       **tokenizer_config)
            dev_dataset = df.load_dataset(
                self.config.dataset, split="validation")
            self.dev_dataloader = df.get_data_loader(dev_dataset,
                                                     batch_size,
                                                     shuffle=False,
                                                     device=self.device,
                                                     enable_label=True,
                                                     align_mode=self.config.tokenizer.get("_align_mode_"),
                                                     label2tensor=False,
                                                     **tokenizer_config)
            df.update_vocabulary(dev_dataset)
            # add intent label num and slot label num to config
            if int(self.config.get_intent_label_num()) == 0 or int(self.config.get_slot_label_num()) == 0:
                self.intent_list = df.intent_label_list
                self.intent_dict = df.intent_label_dict
                self.config.set_intent_label_num(len(self.intent_list))
                self.slot_list = df.slot_label_list
                self.slot_dict = df.slot_label_dict
                self.config.set_slot_label_num(len(self.slot_list))
                self.config.set_vocab_size(self.tokenizer.vocab_size)
            # autoload embedding for non-pretrained encoder
            if self.config["model"]["encoder"].get("embedding") and \
                    self.config["model"]["encoder"]["embedding"].get("load_embedding_name"):
                self.config["model"]["encoder"]["embedding"]["embedding_matrix"] = load_embedding(
                    self.tokenizer,
                    self.config["model"]["encoder"]["embedding"].get("load_embedding_name"))
            # fill template in config
            self.config.autoload_template()
            # save config
            self.logger.set_config(self.config)
            self.model = None
            self.optimizer = None
            self.total_step = None
            self.lr_scheduler = None
            if self.config.tokenizer.get("_tokenizer_name_") == "word_tokenizer":
                self.tokenizer.save(os.path.join(self.model_save_dir, "tokenizer.json"))
                utils.save_json(os.path.join(
                    self.model_save_dir, "label.json"), {"intent": self.intent_list, "slot": self.slot_list})
            if self.config.base.get("test"):
                self.test_dataset = df.load_dataset(
                    self.config.dataset, split="test")
                self.test_dataloader = df.get_data_loader(self.test_dataset,
                                                          batch_size,
                                                          shuffle=False,
                                                          device=self.device,
                                                          enable_label=True,
                                                          align_mode=self.config.tokenizer.get("_align_mode_"),
                                                          label2tensor=False,
                                                          **tokenizer_config)

    def init_model(self, model):
        """Initialize model, optimizer and lr_scheduler.

        Args:
            model (Any): pytorch model
        """
        self.model = model
        self.model.to(self.device)
        if self.config.base.get("train"):
            self.optimizer = instantiate(
                self.config["optimizer"])(self.model.parameters())
            self.total_step = int(self.config.base.get(
                "epoch_num")) * len(self.train_dataloader)
            self.lr_scheduler = instantiate(self.config["scheduler"])(
                optimizer=self.optimizer,
                num_training_steps=self.total_step
            )
            if self.accelerator is not None:
                self.model, self.optimizer, self.train_dataloader, self.lr_scheduler = self.accelerator.prepare(
                    self.model, self.optimizer, self.train_dataloader, self.lr_scheduler)
                if self.config.base.get("load_dir_path"):
                    self.accelerator.load_state(self.config.base.get("load_dir_path"))
                # self.dev_dataloader = self.accelerator.prepare(self.dev_dataloader)

    def eval(self, step: int, best_metric: float) -> float:
        """Evaluate the model on the dev set; test and save it when the dev metric improves.

        Args:
            step (int): the training step the model has reached
            best_metric (float): previous best metric value, used to decide whether to test and save the model

        Returns:
            float: updated best metric value
        """
        # TODO: save dev
        _, res = self.__evaluate(self.model, self.dev_dataloader)
        self.logger.log_metric(res, metric_split="dev", step=step)
        if res[self.config.base.get("best_key")] > best_metric:
            best_metric = res[self.config.base.get("best_key")]
            outputs, test_res = self.__evaluate(
                self.model, self.test_dataloader)
            if not os.path.exists(self.model_save_dir):
                os.mkdir(self.model_save_dir)
            if self.accelerator is None:
                torch.save(self.model, os.path.join(
                    self.model_save_dir, "model.pkl"))
                torch.save(self.optimizer, os.path.join(
                    self.model_save_dir, "optimizer.pkl"))
                torch.save(self.lr_scheduler, os.path.join(
                    self.model_save_dir, "lr_scheduler.pkl"), pickle_module=dill)
                torch.save(step, os.path.join(
                    self.model_save_dir, "step.pkl"))
            else:
                self.accelerator.wait_for_everyone()
                unwrapped_model = self.accelerator.unwrap_model(self.model)
                self.accelerator.save(unwrapped_model.state_dict(),
                                      os.path.join(self.model_save_dir, "model.pkl"))
                self.accelerator.save_state(output_dir=self.model_save_dir)
            outputs.save(self.model_save_dir, self.test_dataset)
            self.logger.log_metric(test_res, metric_split="test", step=step)
        return best_metric

    def train(self) -> float:
        """Train the model.

        Returns:
            float: updated best metric value
        """
        step = 0
        best_metric = 0
        progress_bar = tqdm(range(self.total_step))
        for _ in range(int(self.config.base.get("epoch_num"))):
            for data in self.train_dataloader:
                if step == 0:
                    self.logger.info(data.get_item(
                        0, tokenizer=self.tokenizer, intent_map=self.intent_list, slot_map=self.slot_list))
                output = self.model(data)
                if self.accelerator is not None and hasattr(self.model, "module"):
                    loss, intent_loss, slot_loss = self.model.module.compute_loss(
                        pred=output, target=data)
                else:
                    loss, intent_loss, slot_loss = self.model.compute_loss(
                        pred=output, target=data)
                self.logger.log_loss(loss, "Loss", step=step)
                self.logger.log_loss(intent_loss, "Intent Loss", step=step)
                self.logger.log_loss(slot_loss, "Slot Loss", step=step)
                self.optimizer.zero_grad()
                if self.accelerator is not None:
                    self.accelerator.backward(loss)
                else:
                    loss.backward()
                self.optimizer.step()
                self.lr_scheduler.step()
                if not self.config.base.get("eval_by_epoch") and step % self.config.base.get(
                        "eval_step") == 0 and step != 0:
                    best_metric = self.eval(step, best_metric)
                step += 1
                progress_bar.update(1)
            if self.config.base.get("eval_by_epoch"):
                best_metric = self.eval(step, best_metric)
        self.logger.finish()
        return best_metric

    def __set_seed(self, seed_value: int):
        """Manually set random seeds for reproducibility.

        Args:
            seed_value (int): random seed
        """
        random.seed(seed_value)
        np.random.seed(seed_value)
        torch.manual_seed(seed_value)
        torch.random.manual_seed(seed_value)
        os.environ['PYTHONHASHSEED'] = str(seed_value)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(seed_value)
            torch.cuda.manual_seed_all(seed_value)
            torch.backends.cudnn.deterministic = True
            # note: cudnn benchmark mode picks the fastest algorithms and may
            # introduce slight non-determinism across runs
            torch.backends.cudnn.benchmark = True
        return

    def __evaluate(self, model, dataloader):
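        """Run the model over a dataloader, decode predictions and compute metrics.

        Args:
            model (Any): the (possibly accelerator-wrapped) model to evaluate
            dataloader: dataloader providing the evaluation batches

        Returns:
            (OutputData, dict): merged decoded outputs and the computed metric results
        """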
        model.eval()
        inps = InputData()
        outputs = OutputData()
        for data in dataloader:
            torch.cuda.empty_cache()
            output = model(data)
            if self.accelerator is not None and hasattr(self.model, "module"):
                decode_output = model.module.decode(output, data)
            else:
                decode_output = model.decode(output, data)
            decode_output.map_output(slot_map=self.slot_list,
                                     intent_map=self.intent_list)
            data, decode_output = utils.remove_slot_ignore_index(
                data, decode_output, ignore_index="#")
            inps.merge_input_data(data)
            outputs.merge_output_data(decode_output)
        if "metric" in self.config:
            res = Evaluator.compute_all_metric(
                inps, outputs, intent_label_map=self.intent_dict, metric_list=self.config.metric)
        else:
            res = Evaluator.compute_all_metric(
                inps, outputs, intent_label_map=self.intent_dict)
        model.train()
        return outputs, res

    def load(self):
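        """Load a trained model, tokenizer and label maps from `base.model_dir` for prediction."""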
        self.model = torch.load(os.path.join(self.config.base["model_dir"], "model.pkl"),
                                map_location=self.config.base["device"])
        if self.config.tokenizer["_tokenizer_name_"] == "word_tokenizer":
            self.tokenizer = get_tokenizer_class(self.config.tokenizer["_tokenizer_name_"]).from_file(
                os.path.join(self.config.base["model_dir"], "tokenizer.json"))
        else:
            self.tokenizer = get_tokenizer(self.config.tokenizer["_tokenizer_name_"])
        self.model.to(self.device)
        label = utils.load_json(os.path.join(self.config.base["model_dir"], "label.json"))
        self.intent_list = label["intent"]
        self.slot_list = label["slot"]
        self.data_factory = DataFactory(tokenizer=self.tokenizer,
                                        use_multi_intent=self.config.base.get("multi_intent"),
                                        to_lower_case=self.config.tokenizer.get("_to_lower_case_"))

    def predict(self, text_data):
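        """Predict the intent and slots of a single utterance.

        Args:
            text_data (str): raw input text with tokens separated by spaces

        Returns:
            dict: predicted "intent", "slot" labels and the decoded "text"
        """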
        self.model.eval()
        tokenizer_config = {key: self.config.tokenizer[key]
                            for key in self.config.tokenizer if key[0] != "_" and key[-1] != "_"}
        align_mode = self.config.tokenizer.get("_align_mode_")
        inputs = self.data_factory.batch_fn(batch=[{"text": text_data.split(" ")}],
                                            device=self.device,
                                            config=tokenizer_config,
                                            enable_label=False,
                                            align_mode=align_mode if align_mode is not None else "general",
                                            label2tensor=False)
        output = self.model(inputs)
        decode_output = self.model.decode(output, inputs)
        decode_output.map_output(slot_map=self.slot_list,
                                 intent_map=self.intent_list)
        if self.config.base.get("multi_intent"):
            intent = decode_output.intent_ids[0]
        else:
            intent = [decode_output.intent_ids[0]]
        return {"intent": intent, "slot": decode_output.slot_ids[0], "text": self.tokenizer.decode(inputs.input_ids[0])}
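

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original module). The exact
# way the Config object is built and the model is instantiated is an
# assumption here; the project's own run script is the canonical entry point.
#
#     config = Config(...)             # normally loaded from a YAML config file
#     manager = ModelManager(config)
#     manager.init_model(my_model)     # any pytorch SLU model following the repo's interface
#     best_metric = manager.train()
#
# For inference with a trained checkpoint:
#
#     manager = ModelManager(config)
#     manager.load()
#     print(manager.predict("listen to rock music"))
# ---------------------------------------------------------------------------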