# coding: UTF-8
import time
import torch
import argparse
import numpy as np
from tqdm import tqdm
from train import train
from config import Config
from preprocess import DataProcessor, get_time_dif
from transformers import BertConfig, BertTokenizer, BertForSequenceClassification
parser = argparse.ArgumentParser(description="Bert Chinese Text Classification")
parser.add_argument("--mode", type=str, required=True, help="train/demo/predict")
parser.add_argument("--data_dir", type=str, default="./data", help="training data and saved model path")
#parser.add_argument("--pretrained_bert_dir", type=str, default="./pretrained_bert", help="pretrained bert model path")
parser.add_argument("--seed", type=int, default=1, help="random seed for initialization")
parser.add_argument("--input_file", type=str, default="./data/input.txt", help="input file to be predicted")
args = parser.parse_args()
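
# Load pretrained weights from the "bert-base-chinese" checkpoint on the
# Hugging Face hub instead of a local directory (see the commented-out
# --pretrained_bert_dir flag above).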
args.pretrained_bert_dir = "bert-base-chinese"


def set_seed(seed):
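    """Seed the NumPy and PyTorch RNGs so runs are reproducible."""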
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True


def main():
    set_seed(args.seed)
    config = Config(args.data_dir)
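
    # Build the tokenizer and a BERT sequence-classification model, with the
    # label count taken from the dataset config.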
    tokenizer = BertTokenizer.from_pretrained(args.pretrained_bert_dir)
    bert_config = BertConfig.from_pretrained(args.pretrained_bert_dir, num_labels=config.num_labels)
    model = BertForSequenceClassification.from_pretrained(
        args.pretrained_bert_dir,
        config=bert_config,
    )
    model.to(config.device)

    if args.mode == "train":
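        # Build batched iterators over the train/dev sets, then fine-tune.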
print("loading data...")
start_time = time.time()
train_iterator = DataProcessor(config.train_file, config.device, tokenizer, config.batch_size, config.max_seq_len, args.seed)
dev_iterator = DataProcessor(config.dev_file, config.device, tokenizer, config.batch_size, config.max_seq_len, args.seed)
time_dif = get_time_dif(start_time)
print("time usage:", time_dif)
# train
train(model, config, train_iterator, dev_iterator)
    elif args.mode == "demo":
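        # Interactive demo: classify sentences typed at the console.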
        model.load_state_dict(torch.load(config.saved_model, map_location=config.device))
        model.eval()
        while True:
            sentence = input("Enter text:\n")
            inputs = tokenizer(
                sentence,
                max_length=config.max_seq_len,
                truncation="longest_first",
                return_tensors="pt")
            inputs = inputs.to(config.device)
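            # Run inference without tracking gradients.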
            with torch.no_grad():
                outputs = model(**inputs)
            logits = outputs[0]
            label = torch.argmax(logits, dim=1).tolist()
            print("Classification result: " + config.label_list[label[0]])
            flag = input("continue? (y/n): ")
            if flag.strip().lower() != "y":
                break
    else:
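        # Batch prediction: label every non-empty line of the input file.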
        model.load_state_dict(torch.load(config.saved_model, map_location=config.device))
        model.eval()
        text = []
        with open(args.input_file, mode="r", encoding="UTF-8") as f:
            for line in tqdm(f):
                sentence = line.strip()
                if not sentence:
                    continue
                text.append(sentence)
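
        # Tokenize and classify the sentences in fixed-size batches.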
        num_samples = len(text)
        num_batches = (num_samples - 1) // config.batch_size + 1
        for i in range(num_batches):
            start = i * config.batch_size
            end = min(num_samples, (i + 1) * config.batch_size)
            inputs = tokenizer.batch_encode_plus(
                text[start:end],
                padding=True,
                max_length=config.max_seq_len,
                truncation="longest_first",
                return_tensors="pt")
            inputs = inputs.to(config.device)
            with torch.no_grad():
                outputs = model(**inputs)
            logits = outputs[0]
            preds = torch.argmax(logits, dim=1).tolist()
            labels = [config.label_list[p] for p in preds]
            for j in range(start, end):
                print("%s\t%s" % (text[j], labels[j - start]))


if __name__ == "__main__":
    main()