File size: 4,166 Bytes
6bce1f7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
abc7a8b
6bce1f7
 
 
 
abc7a8b
 
6bce1f7
 
 
 
 
 
 
 
 
 
 
 
abc7a8b
6bce1f7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
# coding: UTF-8

import argparse
import os
import random
import time

import numpy as np
import torch
from tqdm import tqdm
from transformers import BertConfig, BertTokenizer, BertForSequenceClassification

from config import Config
from preprocess import DataProcessor, get_time_dif
from train import train

# Command-line interface. Arguments are parsed at import time and the resulting
# module-level `args` namespace is read by main().
parser = argparse.ArgumentParser(description="Bert Chinese Text Classification")
parser.add_argument("--mode", type=str, required=True, help="train/demo/predict")
parser.add_argument("--data_dir", type=str, default="./data", help="training data and saved model path")
# The default equals the value that used to be hard-coded onto `args`, so
# existing invocations behave exactly as before; the value is handed to the
# `from_pretrained` calls in main().
parser.add_argument("--pretrained_bert_dir", type=str, default="bert-base-chinese", help="pretrained bert model path")
parser.add_argument("--seed", type=int, default=1, help="random seed for initialization")
parser.add_argument("--input_file", type=str, default="./data/input.txt", help="input file to be predicted")
args = parser.parse_args()

def set_seed(seed):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random`` module, NumPy, and PyTorch (CPU and all CUDA
    devices), and configures cuDNN for deterministic algorithm selection.

    Args:
        seed: Integer seed applied to all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Safe on CPU-only hosts: the call is ignored when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # The cuDNN auto-tuner may pick different kernels from run to run; keep it
    # off so that `deterministic = True` is actually effective.
    torch.backends.cudnn.benchmark = False

def _load_finetuned_weights(model, config):
    """Load the checkpoint at ``config.saved_model`` into *model* and switch it to eval mode."""
    # map_location remaps the stored tensors onto the configured device, so a
    # checkpoint written on a GPU can still be restored on a CPU-only host.
    state_dict = torch.load(config.saved_model, map_location=config.device)
    model.load_state_dict(state_dict)
    model.eval()


def _predict_labels(model, inputs, config):
    """Run *model* on tokenized *inputs* and return one label name per input row."""
    inputs = inputs.to(config.device)
    # Inference only: no_grad avoids building an autograd graph for each batch.
    with torch.no_grad():
        logits = model(**inputs)[0]
    return [config.label_list[idx] for idx in logits.argmax(dim=1).tolist()]


def _run_train(model, config, tokenizer, seed):
    """Build the train/dev iterators, report the loading time, and call ``train``."""
    print("loading data...")
    start_time = time.time()
    train_iterator = DataProcessor(config.train_file, config.device, tokenizer, config.batch_size, config.max_seq_len, seed)
    dev_iterator = DataProcessor(config.dev_file, config.device, tokenizer, config.batch_size, config.max_seq_len, seed)
    print("time usage:", get_time_dif(start_time))

    train(model, config, train_iterator, dev_iterator)


def _run_demo(model, config, tokenizer):
    """Classify sentences typed on stdin until the user declines to continue."""
    _load_finetuned_weights(model, config)
    while True:
        sentence = input("请输入文本:\n")
        inputs = tokenizer(
            sentence,
            max_length=config.max_seq_len,
            truncation="longest_first",
            return_tensors="pt")
        print("分类结果:" + _predict_labels(model, inputs, config)[0])
        # Any answer other than "y"/"Y" ends the session.
        if input("continue? (y/n):") not in ("y", "Y"):
            break


def _run_predict(model, config, tokenizer, input_file):
    """Classify every non-blank line of *input_file* and print ``text<TAB>label`` rows."""
    _load_finetuned_weights(model, config)

    with open(input_file, mode="r", encoding="UTF-8") as f:
        # One sample per line; lines that are empty after stripping are skipped.
        text = [sentence for sentence in (line.strip() for line in tqdm(f)) if sentence]

    for start in range(0, len(text), config.batch_size):
        batch = text[start:start + config.batch_size]
        inputs = tokenizer.batch_encode_plus(
            batch,
            padding=True,
            max_length=config.max_seq_len,
            truncation="longest_first",
            return_tensors="pt")
        for sentence, label in zip(batch, _predict_labels(model, inputs, config)):
            print("%s\t%s" % (sentence, label))


def main():
    """Entry point: build the tokenizer and classifier, then dispatch on ``args.mode``.

    ``"train"`` fine-tunes the model, ``"demo"`` starts an interactive prompt,
    and any other value runs batch prediction over ``args.input_file``.
    """
    set_seed(args.seed)
    config = Config(args.data_dir)

    tokenizer = BertTokenizer.from_pretrained(args.pretrained_bert_dir)
    bert_config = BertConfig.from_pretrained(args.pretrained_bert_dir, num_labels=config.num_labels)
    model = BertForSequenceClassification.from_pretrained(args.pretrained_bert_dir,
        config=bert_config
    )
    model.to(config.device)

    if args.mode == "train":
        _run_train(model, config, tokenizer, args.seed)
    elif args.mode == "demo":
        _run_demo(model, config, tokenizer)
    else:
        # Every other --mode value (including "predict") falls through to
        # batch prediction, matching the original dispatch.
        _run_predict(model, config, tokenizer, args.input_file)
                

if __name__ == "__main__":
    main()