File size: 1,268 Bytes
6bce1f7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d73e9f0
6bce1f7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
# coding: UTF-8

import os
import torch

class Config(object):
    """Paths and hyperparameters derived from a single data directory.

    The data directory must already contain ``train.txt``, ``dev.txt`` and
    ``label.txt``.  A ``model`` sub-directory is created inside it (if it is
    not there yet) to hold the saved checkpoint ``bert_model.pth``.
    """

    def __init__(self, data_dir):
        """Validate ``data_dir`` and populate every configuration attribute.

        Args:
            data_dir: Path of an existing directory holding ``train.txt``,
                ``dev.txt`` and ``label.txt``.

        Raises:
            AssertionError: If ``data_dir`` does not exist or one of the
                three required files is missing.
        """
        # Raised explicitly instead of via ``assert`` so the checks still run
        # under ``python -O``; the exception type is kept as AssertionError so
        # existing callers observe the same failure type as before.
        if not os.path.exists(data_dir):
            raise AssertionError(
                "data directory does not exist: {}".format(data_dir)
            )

        self.train_file = os.path.join(data_dir, "train.txt")
        self.dev_file = os.path.join(data_dir, "dev.txt")
        self.label_file = os.path.join(data_dir, "label.txt")
        for required_file in (self.train_file, self.dev_file, self.label_file):
            if not os.path.isfile(required_file):
                raise AssertionError(
                    "required data file not found: {}".format(required_file)
                )

        self.saved_model_dir = os.path.join(data_dir, "model")
        self.saved_model = os.path.join(self.saved_model_dir, "bert_model.pth")
        if not os.path.exists(self.saved_model_dir):
            # exist_ok=True tolerates another process creating the directory
            # between the existence check above and this call.
            os.makedirs(self.saved_model_dir, exist_ok=True)

        # One label per line, surrounding whitespace stripped.  The context
        # manager guarantees the file handle is closed.
        with open(self.label_file, "r", encoding="UTF-8") as label_stream:
            self.label_list = [line.strip() for line in label_stream]
        self.num_labels = len(self.label_list)

        # Training-loop settings.
        # NOTE(review): the exact meaning of log_batch / require_improvement
        # depends on the training code, which is not visible here; presumably
        # a logging interval and an early-stopping patience, both counted in
        # batches -- confirm against the trainer.
        self.num_epochs = 3
        self.log_batch = 100
        self.batch_size = 128
        self.max_seq_len = 32
        self.require_improvement = 1000

        # Optimizer / scheduler settings.
        self.warmup_steps = 0
        self.weight_decay = 0.01
        self.max_grad_norm = 1.0
        self.learning_rate = 5e-5
        # Use the first CUDA device when one is available, otherwise the CPU.
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")