'''
Author: Qiguang Chen
LastEditors: Qiguang Chen
Date: 2023-02-12 22:23:58
LastEditTime: 2023-02-19 14:14:56
Description: Helper that writes the tokenizer, label lists, model checkpoints
             and outputs of a run into a save directory.

'''
import json
import os
import queue
import shutil

import torch
import dill

from common import utils


class Saver:
    """Write run artifacts (tokenizer, labels, checkpoints, outputs) to disk.

    Config keys read by this class (all via ``config.get``):

    * ``save_dir``     -- target directory; when missing/empty,
      ``"save/" + start_time`` is used instead.
    * ``save_mode``    -- defaults to ``"save-by-eval"``; this class itself
      only reacts to ``"save-by-step"`` (see :meth:`auto_save_step`).
    * ``max_save_num`` -- defaults to ``1``. With ``1`` every checkpoint is
      written straight into the save directory; with any other value each
      checkpoint gets its own ``<save_dir>/<step>`` sub-directory and the
      oldest one is removed once ``max_save_num`` directories are tracked.
    * ``save_step``    -- step interval used by :meth:`auto_save_step`.
    """

    def __init__(self, config, start_time=None) -> None:
        """Resolve the save directory, create it and read the save policy.

        Args:
            config: mapping-like object supporting ``get`` and ``[]``.
            start_time: string appended to ``"save/"`` when ``config`` has no
                ``save_dir``; required in that case.

        Raises:
            TypeError: if neither ``save_dir`` nor ``start_time`` is given.
        """
        self.config = config
        if self.config.get("save_dir"):
            self.model_save_dir = self.config["save_dir"]
        else:
            if start_time is None:
                # The original concatenation failed with an opaque TypeError;
                # keep the exception type but say what is actually missing.
                raise TypeError(
                    "start_time is required when config has no 'save_dir'")
            self.model_save_dir = "save/" + start_time
        # makedirs also creates missing parents (including "save/") and does
        # not fail if the directory already exists.
        os.makedirs(self.model_save_dir, exist_ok=True)

        save_mode = config.get("save_mode")
        self.save_mode = save_mode if save_mode is not None else "save-by-eval"

        max_save_num = self.config.get("max_save_num")
        self.max_save_num = max_save_num if max_save_num is not None else 1
        # Size the pool with the *resolved* value: Queue(maxsize=None) would
        # make Queue.full() raise. A maxsize <= 0 means an unbounded queue,
        # in which case no checkpoint directory is ever pruned.
        self.save_pool = queue.Queue(maxsize=self.max_save_num)

    def save_tokenizer(self, tokenizer):
        """Serialize ``tokenizer`` with dill to ``<save_dir>/tokenizer.pkl``."""
        with open(os.path.join(self.model_save_dir, "tokenizer.pkl"), 'wb') as f:
            dill.dump(tokenizer, f)

    def save_label(self, intent_list, slot_list):
        """Pass the intent and slot lists to ``utils.save_json``.

        The target path is ``<save_dir>/label.json`` and the payload is a
        dict with the keys ``"intent"`` and ``"slot"``.
        """
        utils.save_json(os.path.join(self.model_save_dir, "label.json"),
                        {"intent": intent_list, "slot": slot_list})

    def _resolve_checkpoint_dir(self, step):
        """Return (and create) the directory the checkpoint for ``step`` uses.

        When several checkpoints are kept, this also evicts the oldest tracked
        directory once the pool is full.
        """
        if self.max_save_num != 1:
            checkpoint_dir = os.path.join(self.model_save_dir, str(step))
            if self.save_pool.full():
                # The evicted directory may already be gone (e.g. the same
                # step was saved, and therefore queued, more than once).
                shutil.rmtree(self.save_pool.get(), ignore_errors=True)
            self.save_pool.put(checkpoint_dir)
        else:
            checkpoint_dir = self.model_save_dir
        os.makedirs(checkpoint_dir, exist_ok=True)
        return checkpoint_dir

    def save_model(self, model, train_state, accelerator=None):
        """Write a checkpoint for the current training step.

        Args:
            model: object to persist as ``model.pkl``.
            train_state: mapping with at least a ``"step"`` entry.
            accelerator: optional object providing ``wait_for_everyone``,
                ``unwrap_model``, ``save`` and ``save_state`` (presumably a
                Hugging Face ``accelerate.Accelerator`` -- TODO confirm).

        Without an accelerator, ``model.pkl`` and ``train_state.pkl`` are
        written with ``torch.save``. With one, the unwrapped model is written
        through ``accelerator.save`` and ``accelerator.save_state`` is called
        on the same directory; ``train_state.pkl`` is not written on that path.
        """
        model_save_dir = self._resolve_checkpoint_dir(train_state["step"])
        model_path = os.path.join(model_save_dir, "model.pkl")
        if accelerator is None:
            torch.save(model, model_path)
            torch.save(train_state,
                       os.path.join(model_save_dir, "train_state.pkl"),
                       pickle_module=dill)
        else:
            accelerator.wait_for_everyone()
            unwrapped_model = accelerator.unwrap_model(model)
            accelerator.save(unwrapped_model, model_path)
            accelerator.save_state(output_dir=model_save_dir)

    def auto_save_step(self, model, train_state, accelerator=None):
        """Save a checkpoint if the step-based policy says so.

        A checkpoint is written only when ``save_mode`` is ``"save-by-step"``,
        the current step is a multiple of the configured ``save_step`` and the
        step is not 0.

        Returns:
            ``True`` if a checkpoint was written, ``False`` otherwise.
        """
        step = train_state["step"]
        if (self.save_mode == "save-by-step"
                and step % self.config.get("save_step") == 0
                and step != 0):
            self.save_model(model, train_state, accelerator)
            return True
        return False

    def save_output(self, outputs, dataset):
        """Delegate to ``outputs.save(<save_dir>, dataset)``."""
        outputs.save(self.model_save_dir, dataset)