LightChen2333 commited on
Commit
bab7439
·
1 Parent(s): 89c300b

Upload model_manager.py

Browse files
Files changed (1) hide show
  1. model_manager.py +324 -0
model_manager.py ADDED
@@ -0,0 +1,324 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ Author: Qiguang Chen
3
+ Date: 2023-01-11 10:39:26
4
+ LastEditors: Qiguang Chen
5
+ LastEditTime: 2023-02-08 00:42:56
6
+ Description: manage all process of model training and prediction.
7
+
8
+ '''
9
+ import os
10
+ import random
11
+
12
+ import numpy as np
13
+ import torch
14
+ from tqdm import tqdm
15
+
16
+
17
+ from common import utils
18
+ from common.loader import DataFactory
19
+ from common.logger import Logger
20
+ from common.metric import Evaluator
21
+ from common.tokenizer import get_tokenizer, get_tokenizer_class, load_embedding
22
+ from common.utils import InputData, instantiate
23
+ from common.utils import OutputData
24
+ from common.config import Config
25
+ import dill
26
+
27
+
28
+ class ModelManager(object):
29
+ def __init__(self, config: Config):
30
+ """create model manager by config
31
+
32
+ Args:
33
+ config (Config): configuration to manage all process in OpenSLU
34
+ """
35
+ # init config
36
+ self.config = config
37
+ self.__set_seed(self.config.base.get("seed"))
38
+ self.device = self.config.base.get("device")
39
+
40
+ # enable accelerator
41
+ if "accelerator" in self.config and self.config["accelerator"].get("use_accelerator"):
42
+ from accelerate import Accelerator
43
+ self.accelerator = Accelerator(log_with="wandb")
44
+ else:
45
+ self.accelerator = None
46
+ if self.config.base.get("train"):
47
+ self.tokenizer = get_tokenizer(
48
+ self.config.tokenizer.get("_tokenizer_name_"))
49
+ self.logger = Logger(
50
+ "wandb", self.config.base["name"], start_time=config.start_time, accelerator=self.accelerator)
51
+
52
+ # init dataloader & load data
53
+ if self.config.base.get("save_dir"):
54
+ self.model_save_dir = self.config.base["save_dir"]
55
+ else:
56
+ if not os.path.exists("save/"):
57
+ os.mkdir("save/")
58
+ self.model_save_dir = "save/" + config.start_time
59
+ if not os.path.exists(self.model_save_dir):
60
+ os.mkdir(self.model_save_dir)
61
+ batch_size = self.config.base["batch_size"]
62
+ df = DataFactory(tokenizer=self.tokenizer,
63
+ use_multi_intent=self.config.base.get("multi_intent"),
64
+ to_lower_case=self.config.base.get("_to_lower_case_"))
65
+ train_dataset = df.load_dataset(self.config.dataset, split="train")
66
+
67
+ # update label and vocabulary
68
+ df.update_label_names(train_dataset)
69
+ df.update_vocabulary(train_dataset)
70
+
71
+ # init tokenizer config and dataloaders
72
+ tokenizer_config = {key: self.config.tokenizer[key]
73
+ for key in self.config.tokenizer if key[0] != "_" and key[-1] != "_"}
74
+ self.train_dataloader = df.get_data_loader(train_dataset,
75
+ batch_size,
76
+ shuffle=True,
77
+ device=self.device,
78
+ enable_label=True,
79
+ align_mode=self.config.tokenizer.get(
80
+ "_align_mode_"),
81
+ label2tensor=True,
82
+ **tokenizer_config)
83
+ dev_dataset = df.load_dataset(
84
+ self.config.dataset, split="validation")
85
+ self.dev_dataloader = df.get_data_loader(dev_dataset,
86
+ batch_size,
87
+ shuffle=False,
88
+ device=self.device,
89
+ enable_label=True,
90
+ align_mode=self.config.tokenizer.get(
91
+ "_align_mode_"),
92
+ label2tensor=False,
93
+ **tokenizer_config)
94
+ df.update_vocabulary(dev_dataset)
95
+ # add intent label num and slot label num to config
96
+ if int(self.config.get_intent_label_num()) == 0 or int(self.config.get_slot_label_num()) == 0:
97
+ self.intent_list = df.intent_label_list
98
+ self.intent_dict = df.intent_label_dict
99
+ self.config.set_intent_label_num(len(self.intent_list))
100
+ self.slot_list = df.slot_label_list
101
+ self.slot_dict = df.slot_label_dict
102
+ self.config.set_slot_label_num(len(self.slot_list))
103
+ self.config.set_vocab_size(self.tokenizer.vocab_size)
104
+
105
+ # autoload embedding for non-pretrained encoder
106
+ if self.config["model"]["encoder"].get("embedding") and self.config["model"]["encoder"]["embedding"].get(
107
+ "load_embedding_name"):
108
+ self.config["model"]["encoder"]["embedding"]["embedding_matrix"] = load_embedding(self.tokenizer,
109
+ self.config["model"][
110
+ "encoder"][
111
+ "embedding"].get(
112
+ "load_embedding_name"))
113
+ # fill template in config
114
+ self.config.autoload_template()
115
+ # save config
116
+ self.logger.set_config(self.config)
117
+
118
+ self.model = None
119
+ self.optimizer = None
120
+ self.total_step = None
121
+ self.lr_scheduler = None
122
+ if self.config.tokenizer.get("_tokenizer_name_") == "word_tokenizer":
123
+ self.tokenizer.save(os.path.join(self.model_save_dir, "tokenizer.json"))
124
+ utils.save_json(os.path.join(
125
+ self.model_save_dir, "label.json"), {"intent": self.intent_list,"slot": self.slot_list})
126
+ if self.config.base.get("test"):
127
+ self.test_dataset = df.load_dataset(
128
+ self.config.dataset, split="test")
129
+ self.test_dataloader = df.get_data_loader(self.test_dataset,
130
+ batch_size,
131
+ shuffle=False,
132
+ device=self.device,
133
+ enable_label=True,
134
+ align_mode=self.config.tokenizer.get(
135
+ "_align_mode_"),
136
+ label2tensor=False,
137
+ **tokenizer_config)
138
+
139
+ def init_model(self, model):
140
+ """init model, optimizer, lr_scheduler
141
+
142
+ Args:
143
+ model (Any): pytorch model
144
+ """
145
+ self.model = model
146
+ self.model.to(self.device)
147
+ if self.config.base.get("train"):
148
+ self.optimizer = instantiate(
149
+ self.config["optimizer"])(self.model.parameters())
150
+ self.total_step = int(self.config.base.get(
151
+ "epoch_num")) * len(self.train_dataloader)
152
+ self.lr_scheduler = instantiate(self.config["scheduler"])(
153
+ optimizer=self.optimizer,
154
+ num_training_steps=self.total_step
155
+ )
156
+ if self.accelerator is not None:
157
+ self.model, self.optimizer, self.train_dataloader, self.lr_scheduler = self.accelerator.prepare(
158
+ self.model, self.optimizer, self.train_dataloader, self.lr_scheduler)
159
+ if self.config.base.get("load_dir_path"):
160
+ self.accelerator.load_state(self.config.base.get("load_dir_path"))
161
+ # self.dev_dataloader = self.accelerator.prepare(self.dev_dataloader)
162
+
163
+ def eval(self, step: int, best_metric: float) -> float:
164
+ """ evaluation models.
165
+
166
+ Args:
167
+ step (int): which step the model has trained in
168
+ best_metric (float): last best metric value to judge whether to test or save model
169
+
170
+ Returns:
171
+ float: updated best metric value
172
+ """
173
+ # TODO: save dev
174
+ _, res = self.__evaluate(self.model, self.dev_dataloader)
175
+ self.logger.log_metric(res, metric_split="dev", step=step)
176
+ if res[self.config.base.get("best_key")] > best_metric:
177
+ best_metric = res[self.config.base.get("best_key")]
178
+ outputs, test_res = self.__evaluate(
179
+ self.model, self.test_dataloader)
180
+ if not os.path.exists(self.model_save_dir):
181
+ os.mkdir(self.model_save_dir)
182
+ if self.accelerator is None:
183
+ torch.save(self.model, os.path.join(
184
+ self.model_save_dir, "model.pkl"))
185
+ torch.save(self.optimizer, os.path.join(
186
+ self.model_save_dir, "optimizer.pkl"))
187
+ torch.save(self.lr_scheduler, os.path.join(
188
+ self.model_save_dir, "lr_scheduler.pkl"), pickle_module=dill)
189
+ torch.save(step, os.path.join(
190
+ self.model_save_dir, "step.pkl"))
191
+ else:
192
+ self.accelerator.wait_for_everyone()
193
+ unwrapped_model = self.accelerator.unwrap_model(self.model)
194
+ self.accelerator.save(unwrapped_model.state_dict(
195
+ ), os.path.join(self.model_save_dir, "model.pkl"))
196
+ self.accelerator.save_state(output_dir=self.model_save_dir)
197
+ outputs.save(self.model_save_dir, self.test_dataset)
198
+ self.logger.log_metric(test_res, metric_split="test", step=step)
199
+ return best_metric
200
+
201
+ def train(self) -> float:
202
+ """ train models.
203
+
204
+ Returns:
205
+ float: updated best metric value
206
+ """
207
+ step = 0
208
+ best_metric = 0
209
+ progress_bar = tqdm(range(self.total_step))
210
+ for _ in range(int(self.config.base.get("epoch_num"))):
211
+ for data in self.train_dataloader:
212
+ if step == 0:
213
+ self.logger.info(data.get_item(
214
+ 0, tokenizer=self.tokenizer, intent_map=self.intent_list, slot_map=self.slot_list))
215
+ output = self.model(data)
216
+ if self.accelerator is not None and hasattr(self.model, "module"):
217
+ loss, intent_loss, slot_loss = self.model.module.compute_loss(
218
+ pred=output, target=data)
219
+ else:
220
+ loss, intent_loss, slot_loss = self.model.compute_loss(
221
+ pred=output, target=data)
222
+ self.logger.log_loss(loss, "Loss", step=step)
223
+ self.logger.log_loss(intent_loss, "Intent Loss", step=step)
224
+ self.logger.log_loss(slot_loss, "Slot Loss", step=step)
225
+ self.optimizer.zero_grad()
226
+
227
+ if self.accelerator is not None:
228
+ self.accelerator.backward(loss)
229
+ else:
230
+ loss.backward()
231
+ self.optimizer.step()
232
+ self.lr_scheduler.step()
233
+ if not self.config.base.get("eval_by_epoch") and step % self.config.base.get(
234
+ "eval_step") == 0 and step != 0:
235
+ best_metric = self.eval(step, best_metric)
236
+ step += 1
237
+ progress_bar.update(1)
238
+ if self.config.base.get("eval_by_epoch"):
239
+ best_metric = self.eval(step, best_metric)
240
+ self.logger.finish()
241
+ return best_metric
242
+
243
+ def __set_seed(self, seed_value: int):
244
+ """Manually set random seeds.
245
+
246
+ Args:
247
+ seed_value (int): random seed
248
+ """
249
+ random.seed(seed_value)
250
+ np.random.seed(seed_value)
251
+ torch.manual_seed(seed_value)
252
+ torch.random.manual_seed(seed_value)
253
+ os.environ['PYTHONHASHSEED'] = str(seed_value)
254
+ if torch.cuda.is_available():
255
+ torch.cuda.manual_seed(seed_value)
256
+ torch.cuda.manual_seed_all(seed_value)
257
+ torch.backends.cudnn.deterministic = True
258
+ torch.backends.cudnn.benchmark = True
259
+ return
260
+
261
+ def __evaluate(self, model, dataloader):
262
+ model.eval()
263
+ inps = InputData()
264
+ outputs = OutputData()
265
+ for data in dataloader:
266
+ torch.cuda.empty_cache()
267
+ output = model(data)
268
+ if self.accelerator is not None and hasattr(self.model, "module"):
269
+ decode_output = model.module.decode(output, data)
270
+ else:
271
+ decode_output = model.decode(output, data)
272
+
273
+ decode_output.map_output(slot_map=self.slot_list,
274
+ intent_map=self.intent_list)
275
+ data, decode_output = utils.remove_slot_ignore_index(
276
+ data, decode_output, ignore_index="#")
277
+
278
+ inps.merge_input_data(data)
279
+ outputs.merge_output_data(decode_output)
280
+ if "metric" in self.config:
281
+ res = Evaluator.compute_all_metric(
282
+ inps, outputs, intent_label_map=self.intent_dict, metric_list=self.config.metric)
283
+ else:
284
+ res = Evaluator.compute_all_metric(
285
+ inps, outputs, intent_label_map=self.intent_dict)
286
+ model.train()
287
+ return outputs, res
288
+
289
+ def load(self):
290
+
291
+ self.model = torch.load(os.path.join(self.config.base["model_dir"], "model.pkl"),map_location=self.config.base["device"])
292
+ if self.config.tokenizer["_tokenizer_name_"] == "word_tokenizer":
293
+ self.tokenizer = get_tokenizer_class(self.config.tokenizer["_tokenizer_name_"]).from_file(
294
+ os.path.join(self.config.base["model_dir"], "tokenizer.json"))
295
+ else:
296
+ self.tokenizer = get_tokenizer(self.config.tokenizer["_tokenizer_name_"])
297
+ self.model.to(self.device)
298
+ label = utils.load_json(os.path.join(self.config.base["model_dir"], "label.json"))
299
+ self.intent_list = label["intent"]
300
+ self.slot_list = label["slot"]
301
+ self.data_factory=DataFactory(tokenizer=self.tokenizer,
302
+ use_multi_intent=self.config.base.get("multi_intent"),
303
+ to_lower_case=self.config.tokenizer.get("_to_lower_case_"))
304
+
305
+ def predict(self, text_data):
306
+ self.model.eval()
307
+ tokenizer_config = {key: self.config.tokenizer[key]
308
+ for key in self.config.tokenizer if key[0] != "_" and key[-1] != "_"}
309
+ align_mode = self.config.tokenizer.get("_align_mode_")
310
+ inputs = self.data_factory.batch_fn(batch=[{"text": text_data.split(" ")}],
311
+ device=self.device,
312
+ config=tokenizer_config,
313
+ enable_label=False,
314
+ align_mode= align_mode if align_mode is not None else "general",
315
+ label2tensor=False)
316
+ output = self.model(inputs)
317
+ decode_output = self.model.decode(output, inputs)
318
+ decode_output.map_output(slot_map=self.slot_list,
319
+ intent_map=self.intent_list)
320
+ if self.config.base.get("multi_intent"):
321
+ intent = decode_output.intent_ids[0]
322
+ else:
323
+ intent = [decode_output.intent_ids[0]]
324
+ return {"intent": intent, "slot": decode_output.slot_ids[0], "text": self.tokenizer.decode(inputs.input_ids[0])}