File size: 2,973 Bytes
223340a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
'''
Author: Qiguang Chen
LastEditors: Qiguang Chen
Date: 2023-02-12 22:23:58
LastEditTime: 2023-02-19 14:14:56
Description: 

'''
import json
import os
import queue
import shutil
import torch
import dill
from common import utils


class Saver():
    """Write training artifacts (model, train state, tokenizer, labels, outputs) to disk.

    Everything is stored under ``self.model_save_dir``. When ``max_save_num``
    is not 1, each checkpoint goes into its own ``<model_save_dir>/<step>``
    sub-directory and ``self.save_pool`` tracks those directories so the
    oldest one can be removed once the pool is full.
    """

    def __init__(self, config, start_time=None) -> None:
        """Resolve the save directory and the checkpoint retention settings.

        Args:
            config: Mapping-like object read via ``.get``. Keys used here:
                ``save_dir``, ``save_mode`` (defaults to ``"save-by-eval"``)
                and ``max_save_num`` (defaults to 1).
            start_time: Name of the run directory created under ``save/`` when
                ``save_dir`` is not configured. When omitted, a timestamp of
                the current local time is used.
        """
        self.config = config
        save_dir = self.config.get("save_dir")
        if save_dir:
            self.model_save_dir = save_dir
        else:
            if start_time is None:
                # The declared default used to crash on string concatenation;
                # fall back to a timestamp so the default is actually usable.
                import time
                start_time = time.strftime("%Y%m%d%H%M%S")
            self.model_save_dir = "save/" + start_time
        # makedirs also creates missing parents ("save/" or a nested save_dir)
        # and does not fail when the directory already exists.
        os.makedirs(self.model_save_dir, exist_ok=True)

        save_mode = config.get("save_mode")
        self.save_mode = save_mode if save_mode is not None else "save-by-eval"

        max_save_num = self.config.get("max_save_num")
        self.max_save_num = max_save_num if max_save_num is not None else 1
        # Size the pool from the resolved value: the raw config entry may be
        # None, which is not a valid Queue maxsize.
        self.save_pool = queue.Queue(maxsize=self.max_save_num)

    def save_tokenizer(self, tokenizer):
        """Serialize ``tokenizer`` with dill to ``tokenizer.pkl`` in the save directory."""
        with open(os.path.join(self.model_save_dir, "tokenizer.pkl"), 'wb') as f:
            dill.dump(tokenizer, f)

    def save_label(self, intent_list, slot_list):
        """Write the intent and slot label lists to ``label.json`` in the save directory."""
        utils.save_json(
            os.path.join(self.model_save_dir, "label.json"),
            {"intent": intent_list, "slot": slot_list},
        )

    def _rotate_checkpoint_dir(self, step):
        """Return the directory for ``step``, evicting the oldest tracked one if the pool is full."""
        step_dir = os.path.join(self.model_save_dir, str(step))
        # Saving the same step twice must not enqueue the directory twice,
        # otherwise a still-wanted checkpoint would be evicted early.
        # ``Queue.queue`` is the underlying container of the pool.
        if step_dir in self.save_pool.queue:
            return step_dir
        if self.save_pool.full():
            stale_dir = self.save_pool.get()
            try:
                shutil.rmtree(stale_dir)
            except FileNotFoundError:
                # Already gone; nothing left to clean up.
                pass
        self.save_pool.put(step_dir)
        return step_dir

    def save_model(self, model, train_state, accelerator=None):
        """Save a checkpoint of ``model`` for the current training step.

        Args:
            model: Model object to persist as ``model.pkl``.
            train_state: Mapping holding at least ``"step"``. Without an
                accelerator it is also written to ``train_state.pkl``.
            accelerator: Optional object providing ``wait_for_everyone``,
                ``unwrap_model``, ``save`` and ``save_state``. When given, the
                unwrapped model and the accelerator state are saved through it.
                NOTE(review): ``train_state.pkl`` is not written in this
                branch - confirm that ``save_state`` covers what is needed.
        """
        step = train_state["step"]
        if self.max_save_num != 1:
            model_save_dir = self._rotate_checkpoint_dir(step)
        else:
            # A single retained checkpoint is overwritten in place.
            model_save_dir = self.model_save_dir
        os.makedirs(model_save_dir, exist_ok=True)

        if accelerator is None:
            torch.save(model, os.path.join(model_save_dir, "model.pkl"))
            torch.save(train_state, os.path.join(model_save_dir, "train_state.pkl"), pickle_module=dill)
        else:
            accelerator.wait_for_everyone()
            unwrapped_model = accelerator.unwrap_model(model)
            accelerator.save(unwrapped_model, os.path.join(model_save_dir, "model.pkl"))
            accelerator.save_state(output_dir=model_save_dir)

    def auto_save_step(self, model, train_state, accelerator=None):
        """Save a checkpoint when step-based saving is due.

        A save happens only in ``"save-by-step"`` mode, at a non-zero step
        that is a multiple of ``config["save_step"]``.

        Returns:
            True if a checkpoint was saved, False otherwise.
        """
        step = train_state["step"]
        if self.save_mode != "save-by-step":
            return False
        if step % self.config.get("save_step") != 0 or step == 0:
            return False
        self.save_model(model, train_state, accelerator)
        return True

    def save_output(self, outputs, dataset):
        """Delegate to ``outputs.save`` with the save directory and ``dataset``."""
        outputs.save(self.model_save_dir, dataset)