'''
Author: Qiguang Chen
Date: 2023-01-11 10:39:26
LastEditors: Qiguang Chen
LastEditTime: 2023-02-08 00:42:56
Description: manage all processes of model training and prediction.
'''
import os
import random

import dill
import numpy as np
import torch
from tqdm import tqdm

from common import utils
from common.config import Config
from common.loader import DataFactory
from common.logger import Logger
from common.metric import Evaluator
from common.tokenizer import get_tokenizer, get_tokenizer_class, load_embedding
from common.utils import InputData, OutputData, instantiate


class ModelManager(object):
    def __init__(self, config: Config):
        """Create a model manager from config.

        Args:
            config (Config): configuration to manage all processes in OpenSLU
        """
        # init config
        self.config = config
        self.__set_seed(self.config.base.get("seed"))
        self.device = self.config.base.get("device")

        # enable accelerator
        if "accelerator" in self.config and self.config["accelerator"].get("use_accelerator"):
            from accelerate import Accelerator
            self.accelerator = Accelerator(log_with="wandb")
        else:
            self.accelerator = None

        if self.config.base.get("train"):
            self.tokenizer = get_tokenizer(
                self.config.tokenizer.get("_tokenizer_name_"))
            self.logger = Logger("wandb",
                                 self.config.base["name"],
                                 start_time=config.start_time,
                                 accelerator=self.accelerator)

            # init dataloader & load data
            if self.config.base.get("save_dir"):
                self.model_save_dir = self.config.base["save_dir"]
            else:
                if not os.path.exists("save/"):
                    os.mkdir("save/")
                self.model_save_dir = "save/" + config.start_time
            if not os.path.exists(self.model_save_dir):
                os.mkdir(self.model_save_dir)

            batch_size = self.config.base["batch_size"]
            df = DataFactory(tokenizer=self.tokenizer,
                             use_multi_intent=self.config.base.get("multi_intent"),
                             to_lower_case=self.config.base.get("_to_lower_case_"))
            train_dataset = df.load_dataset(self.config.dataset, split="train")

            # update label and vocabulary
            df.update_label_names(train_dataset)
            df.update_vocabulary(train_dataset)

            # init tokenizer config and dataloaders
            tokenizer_config = {key: self.config.tokenizer[key]
                                for key in self.config.tokenizer
                                if key[0] != "_" and key[-1] != "_"}
            self.train_dataloader = df.get_data_loader(train_dataset,
                                                       batch_size,
                                                       shuffle=True,
                                                       device=self.device,
                                                       enable_label=True,
                                                       align_mode=self.config.tokenizer.get("_align_mode_"),
                                                       label2tensor=True,
                                                       **tokenizer_config)
            dev_dataset = df.load_dataset(self.config.dataset, split="validation")
            self.dev_dataloader = df.get_data_loader(dev_dataset,
                                                     batch_size,
                                                     shuffle=False,
                                                     device=self.device,
                                                     enable_label=True,
                                                     align_mode=self.config.tokenizer.get("_align_mode_"),
                                                     label2tensor=False,
                                                     **tokenizer_config)
            df.update_vocabulary(dev_dataset)

            # add intent label num and slot label num to config
            if int(self.config.get_intent_label_num()) == 0 or int(self.config.get_slot_label_num()) == 0:
                self.intent_list = df.intent_label_list
                self.intent_dict = df.intent_label_dict
                self.config.set_intent_label_num(len(self.intent_list))
                self.slot_list = df.slot_label_list
                self.slot_dict = df.slot_label_dict
                self.config.set_slot_label_num(len(self.slot_list))
                self.config.set_vocab_size(self.tokenizer.vocab_size)

            # autoload embedding for non-pretrained encoder
            if self.config["model"]["encoder"].get("embedding") and \
                    self.config["model"]["encoder"]["embedding"].get("load_embedding_name"):
                self.config["model"]["encoder"]["embedding"]["embedding_matrix"] = load_embedding(
                    self.tokenizer,
                    self.config["model"]["encoder"]["embedding"].get("load_embedding_name"))

            # fill template in config
            self.config.autoload_template()
            # save config
            self.logger.set_config(self.config)

            self.model = None
            self.optimizer = None
            self.total_step = None
            self.lr_scheduler = None
            if self.config.tokenizer.get("_tokenizer_name_") == "word_tokenizer":
                self.tokenizer.save(os.path.join(self.model_save_dir, "tokenizer.json"))
            utils.save_json(os.path.join(self.model_save_dir, "label.json"),
                            {"intent": self.intent_list, "slot": self.slot_list})

            if self.config.base.get("test"):
                self.test_dataset = df.load_dataset(self.config.dataset, split="test")
                self.test_dataloader = df.get_data_loader(self.test_dataset,
                                                          batch_size,
                                                          shuffle=False,
                                                          device=self.device,
                                                          enable_label=True,
                                                          align_mode=self.config.tokenizer.get("_align_mode_"),
                                                          label2tensor=False,
                                                          **tokenizer_config)

    def init_model(self, model):
        """Initialize the model, optimizer and lr scheduler.

        Args:
            model (Any): PyTorch model
        """
        self.model = model
        self.model.to(self.device)
        if self.config.base.get("train"):
            self.optimizer = instantiate(
                self.config["optimizer"])(self.model.parameters())
            self.total_step = int(self.config.base.get("epoch_num")) * len(self.train_dataloader)
            self.lr_scheduler = instantiate(self.config["scheduler"])(
                optimizer=self.optimizer,
                num_training_steps=self.total_step
            )
            if self.accelerator is not None:
                self.model, self.optimizer, self.train_dataloader, self.lr_scheduler = self.accelerator.prepare(
                    self.model, self.optimizer, self.train_dataloader, self.lr_scheduler)
                if self.config.base.get("load_dir_path"):
                    self.accelerator.load_state(self.config.base.get("load_dir_path"))
                # self.dev_dataloader = self.accelerator.prepare(self.dev_dataloader)

    def eval(self, step: int, best_metric: float) -> float:
        """Evaluate the model on the dev set; test and save it when the dev metric improves.

        Args:
            step (int): current training step
            best_metric (float): best metric value so far, used to decide whether to test and save the model

        Returns:
            float: updated best metric value
        """
        # TODO: save dev
        _, res = self.__evaluate(self.model, self.dev_dataloader)
        self.logger.log_metric(res, metric_split="dev", step=step)
        if res[self.config.base.get("best_key")] > best_metric:
            best_metric = res[self.config.base.get("best_key")]
            outputs, test_res = self.__evaluate(self.model, self.test_dataloader)
            if not os.path.exists(self.model_save_dir):
                os.mkdir(self.model_save_dir)
            if self.accelerator is None:
                torch.save(self.model, os.path.join(self.model_save_dir, "model.pkl"))
                torch.save(self.optimizer, os.path.join(self.model_save_dir, "optimizer.pkl"))
                torch.save(self.lr_scheduler, os.path.join(self.model_save_dir, "lr_scheduler.pkl"),
                           pickle_module=dill)
                torch.save(step, os.path.join(self.model_save_dir, "step.pkl"))
            else:
                self.accelerator.wait_for_everyone()
                unwrapped_model = self.accelerator.unwrap_model(self.model)
                self.accelerator.save(unwrapped_model.state_dict(),
                                      os.path.join(self.model_save_dir, "model.pkl"))
                self.accelerator.save_state(output_dir=self.model_save_dir)
            outputs.save(self.model_save_dir, self.test_dataset)
            self.logger.log_metric(test_res, metric_split="test", step=step)
        return best_metric

    def train(self) -> float:
        """Train the model.

        Returns:
            float: updated best metric value
        """
        step = 0
        best_metric = 0
        progress_bar = tqdm(range(self.total_step))
        for _ in range(int(self.config.base.get("epoch_num"))):
            for data in self.train_dataloader:
                if step == 0:
                    self.logger.info(data.get_item(0,
                                                   tokenizer=self.tokenizer,
                                                   intent_map=self.intent_list,
                                                   slot_map=self.slot_list))
                output = self.model(data)
                if self.accelerator is not None and hasattr(self.model, "module"):
                    loss, intent_loss, slot_loss = self.model.module.compute_loss(
                        pred=output, target=data)
                else:
                    loss, intent_loss, slot_loss = self.model.compute_loss(
                        pred=output, target=data)
                self.logger.log_loss(loss, "Loss", step=step)
                self.logger.log_loss(intent_loss, "Intent Loss", step=step)
                self.logger.log_loss(slot_loss, "Slot Loss", step=step)
                self.optimizer.zero_grad()

                if self.accelerator is not None:
                    self.accelerator.backward(loss)
                else:
                    loss.backward()
                self.optimizer.step()
                self.lr_scheduler.step()

                if not self.config.base.get("eval_by_epoch") \
                        and step % self.config.base.get("eval_step") == 0 and step != 0:
                    best_metric = self.eval(step, best_metric)
                step += 1
                progress_bar.update(1)
            if self.config.base.get("eval_by_epoch"):
                best_metric = self.eval(step, best_metric)
        self.logger.finish()
        return best_metric

    def __set_seed(self, seed_value: int):
        """Manually set random seeds.

        Args:
            seed_value (int): random seed
        """
        random.seed(seed_value)
        np.random.seed(seed_value)
        torch.manual_seed(seed_value)
        torch.random.manual_seed(seed_value)
        os.environ['PYTHONHASHSEED'] = str(seed_value)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(seed_value)
            torch.cuda.manual_seed_all(seed_value)
            torch.backends.cudnn.deterministic = True
            # disable cuDNN autotuning so runs with the same seed stay reproducible
            torch.backends.cudnn.benchmark = False
        return

    def __evaluate(self, model, dataloader):
        model.eval()
        inps = InputData()
        outputs = OutputData()
        for data in dataloader:
            torch.cuda.empty_cache()
            output = model(data)
            if self.accelerator is not None and hasattr(self.model, "module"):
                decode_output = model.module.decode(output, data)
            else:
                decode_output = model.decode(output, data)

            decode_output.map_output(slot_map=self.slot_list,
                                     intent_map=self.intent_list)
            data, decode_output = utils.remove_slot_ignore_index(
                data, decode_output, ignore_index="#")
            inps.merge_input_data(data)
            outputs.merge_output_data(decode_output)
        if "metric" in self.config:
            res = Evaluator.compute_all_metric(inps,
                                               outputs,
                                               intent_label_map=self.intent_dict,
                                               metric_list=self.config.metric)
        else:
            res = Evaluator.compute_all_metric(inps,
                                               outputs,
                                               intent_label_map=self.intent_dict)
        model.train()
        return outputs, res

    def load(self):
        self.model = torch.load(os.path.join(self.config.base["model_dir"], "model.pkl"),
                                map_location=self.config.base["device"])
        if self.config.tokenizer["_tokenizer_name_"] == "word_tokenizer":
            self.tokenizer = get_tokenizer_class(self.config.tokenizer["_tokenizer_name_"]).from_file(
                os.path.join(self.config.base["model_dir"], "tokenizer.json"))
        else:
            self.tokenizer = get_tokenizer(self.config.tokenizer["_tokenizer_name_"])
        self.model.to(self.device)
        label = utils.load_json(os.path.join(self.config.base["model_dir"], "label.json"))
        self.intent_list = label["intent"]
        self.slot_list = label["slot"]
        self.data_factory = DataFactory(tokenizer=self.tokenizer,
                                        use_multi_intent=self.config.base.get("multi_intent"),
                                        to_lower_case=self.config.tokenizer.get("_to_lower_case_"))

    def predict(self, text_data):
        self.model.eval()
        tokenizer_config = {key: self.config.tokenizer[key]
                            for key in self.config.tokenizer
                            if key[0] != "_" and key[-1] != "_"}
        align_mode = self.config.tokenizer.get("_align_mode_")
        inputs = self.data_factory.batch_fn(batch=[{"text": text_data.split(" ")}],
                                            device=self.device,
                                            config=tokenizer_config,
                                            enable_label=False,
                                            align_mode=align_mode if align_mode is not None else "general",
                                            label2tensor=False)
        output = self.model(inputs)
        decode_output = self.model.decode(output, inputs)
        decode_output.map_output(slot_map=self.slot_list,
                                 intent_map=self.intent_list)
        if self.config.base.get("multi_intent"):
            intent = decode_output.intent_ids[0]
        else:
            intent = [decode_output.intent_ids[0]]
        return {"intent": intent,
                "slot": decode_output.slot_ids[0],
                "text": self.tokenizer.decode(inputs.input_ids[0])}
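

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module): shows how a
# driver script might wire ModelManager together. `Config.load_from_args()`
# and building the model via `instantiate(config["model"])` are assumptions
# made for this example; the actual OpenSLU entry script may construct the
# config and model differently.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Hypothetical config loading; replace with the project's real loader.
    config = Config.load_from_args()
    manager = ModelManager(config)
    if config.base.get("train"):
        # Assumed pattern: build the model from its config block, mirroring how
        # init_model() instantiates the optimizer and scheduler above.
        manager.init_model(instantiate(config["model"]))
        manager.train()
    else:
        # Inference path: restore a saved model, tokenizer and labels, then predict.
        manager.load()
        print(manager.predict("list flights from denver to boston"))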