|
|
|
|
|
import os
|
|
import time
|
|
import torch
|
|
import argparse
|
|
import numpy as np
|
|
from tqdm import tqdm
|
|
from train import train
|
|
from config import Config
|
|
from preprocess import DataProcessor, get_time_dif
|
|
from transformers import BertConfig, BertTokenizer, BertForSequenceClassification
|
|
|
|
parser = argparse.ArgumentParser(description="Bert Chinese Text Classification")
|
|
parser.add_argument("--mode", type=str, required=True, help="train/demo/predict")
|
|
parser.add_argument("--data_dir", type=str, default="./data", help="training data and saved model path")
|
|
|
|
parser.add_argument("--seed", type=int, default=1, help="random seed for initialization")
|
|
parser.add_argument("--input_file", type=str, default="./data/input.txt", help="input file to be predicted")
|
|
args = parser.parse_args()
|
|
|
|
args.pretrained_bert_dir = "bert-base-chinese"
|
|
|
|
def set_seed(seed):
|
|
np.random.seed(seed)
|
|
torch.manual_seed(seed)
|
|
torch.cuda.manual_seed_all(seed)
|
|
torch.backends.cudnn.deterministic = True
|
|
|
|
def main():
|
|
set_seed(args.seed)
|
|
config = Config(args.data_dir)
|
|
|
|
tokenizer = BertTokenizer.from_pretrained(args.pretrained_bert_dir)
|
|
bert_config = BertConfig.from_pretrained(args.pretrained_bert_dir, num_labels=config.num_labels)
|
|
model = BertForSequenceClassification.from_pretrained(args.pretrained_bert_dir,
|
|
config=bert_config
|
|
)
|
|
model.to(config.device)
|
|
|
|
if args.mode == "train":
|
|
print("loading data...")
|
|
start_time = time.time()
|
|
train_iterator = DataProcessor(config.train_file, config.device, tokenizer, config.batch_size, config.max_seq_len, args.seed)
|
|
dev_iterator = DataProcessor(config.dev_file, config.device, tokenizer, config.batch_size, config.max_seq_len, args.seed)
|
|
time_dif = get_time_dif(start_time)
|
|
print("time usage:", time_dif)
|
|
|
|
|
|
train(model, config, train_iterator, dev_iterator)
|
|
|
|
elif args.mode == "demo":
|
|
model.load_state_dict(torch.load(config.saved_model))
|
|
model.eval()
|
|
while True:
|
|
sentence = input("请输入文本:\n")
|
|
inputs = tokenizer(
|
|
sentence,
|
|
max_length=config.max_seq_len,
|
|
truncation="longest_first",
|
|
return_tensors="pt")
|
|
inputs = inputs.to(config.device)
|
|
with torch.no_grad():
|
|
outputs = model(**inputs)
|
|
logits = outputs[0]
|
|
label = torch.max(logits.data, 1)[1].tolist()
|
|
print("分类结果:" + config.label_list[label[0]])
|
|
flag = str(input("continue? (y/n):"))
|
|
if flag == "Y" or flag == "y":
|
|
continue
|
|
else:
|
|
break
|
|
else:
|
|
model.load_state_dict(torch.load(config.saved_model))
|
|
model.eval()
|
|
|
|
text = []
|
|
with open(args.input_file, mode="r", encoding="UTF-8") as f:
|
|
for line in tqdm(f):
|
|
sentence = line.strip()
|
|
if not sentence: continue
|
|
text.append(sentence)
|
|
|
|
num_samples = len(text)
|
|
num_batches = (num_samples - 1) // config.batch_size + 1
|
|
for i in range(num_batches):
|
|
start = i * config.batch_size
|
|
end = min(num_samples, (i + 1) * config.batch_size)
|
|
inputs = tokenizer.batch_encode_plus(
|
|
text[start: end],
|
|
padding=True,
|
|
max_length=config.max_seq_len,
|
|
truncation="longest_first",
|
|
return_tensors="pt")
|
|
inputs = inputs.to(config.device)
|
|
|
|
outputs = model(**inputs)
|
|
logits = outputs[0]
|
|
|
|
preds = torch.max(logits.data, 1)[1].tolist()
|
|
labels = [config.label_list[_] for _ in preds]
|
|
for j in range(start, end):
|
|
print("%s\t%s" % (text[j], labels[j - start]))
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main()
|
|
|