# coding: UTF-8
import spaces
import time
import torch
import argparse
import numpy as np
from tqdm import tqdm
from train import train
from config import Config
from preprocess import DataProcessor, get_time_dif
from transformers import BertConfig, BertTokenizer, BertForSequenceClassification
from ebart import PegasusSummarizer

parser = argparse.ArgumentParser(description="Bert Chinese Text Classification")
parser.add_argument("--mode", type=str, required=True, help="train/demo/predict")
parser.add_argument("--data_dir", type=str, default="./data", help="training data and saved model path")
parser.add_argument("--seed", type=int, default=1, help="random seed for initialization")
parser.add_argument("--input_file", type=str, default="./data_12345/test.txt", help="input file to be predicted")
args = parser.parse_args()
# The pretrained model is fixed to bert-base-chinese rather than taken from the CLI.
args.pretrained_bert_dir = "bert-base-chinese"


def set_seed(seed):
    """Fix all random seeds so runs are reproducible."""
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True


@spaces.GPU
def main():
    set_seed(args.seed)
    config = Config(args.data_dir)

    # Load the pretrained BERT tokenizer and a sequence-classification head
    # sized to the label set defined in Config.
    tokenizer = BertTokenizer.from_pretrained(args.pretrained_bert_dir)
    bert_config = BertConfig.from_pretrained(args.pretrained_bert_dir, num_labels=config.num_labels)
    model = BertForSequenceClassification.from_pretrained(args.pretrained_bert_dir, config=bert_config)
    model.to(config.device)

    # DataProcessor needs the summarizer in train mode; leaving this commented
    # out (as it was) raises a NameError below.
    summarizerModel = PegasusSummarizer()

    if args.mode == "train":
        print("loading data...")
        start_time = time.time()
        train_iterator = DataProcessor(config.train_file, config.device, summarizerModel, tokenizer,
                                       config.batch_size, config.max_seq_len, args.seed)
        dev_iterator = DataProcessor(config.dev_file, config.device, summarizerModel, tokenizer,
                                     config.batch_size, config.max_seq_len, args.seed)
        time_dif = get_time_dif(start_time)
        print("time usage:", time_dif)
        train(model, config, train_iterator, dev_iterator)

    elif args.mode == "demo":
        # Interactive mode: classify one sentence at a time from stdin.
        model.load_state_dict(torch.load(config.saved_model, map_location=config.device))
        model.eval()
        while True:
            sentence = input("please input txt:\n")
            inputs = tokenizer(
                sentence,
                max_length=config.max_seq_len,
                truncation="longest_first",
                return_tensors="pt")
            inputs = inputs.to(config.device)
            with torch.no_grad():
                outputs = model(**inputs)
            logits = outputs[0]
            label = torch.max(logits.data, 1)[1].tolist()
            print("Classification result: " + config.label_list[label[0]])
            flag = input("continue? (y/n): ")
            if flag not in ("y", "Y"):
                break

    else:
        # Batch prediction: read one sentence per line from the input file and
        # print "sentence<TAB>label" for each.
        model.load_state_dict(torch.load(config.saved_model, map_location=config.device))
        model.eval()
        text = []
        with open(args.input_file, mode="r", encoding="UTF-8") as f:
            for line in tqdm(f):
                sentence = line.strip()
                if not sentence:
                    continue
                text.append(sentence)

        num_samples = len(text)
        num_batches = (num_samples - 1) // config.batch_size + 1
        for i in range(num_batches):
            start = i * config.batch_size
            end = min(num_samples, (i + 1) * config.batch_size)
            inputs = tokenizer.batch_encode_plus(
                text[start:end],
                padding=True,
                max_length=config.max_seq_len,
                truncation="longest_first",
                return_tensors="pt")
            inputs = inputs.to(config.device)
            # Inference only: disable gradient tracking to save memory.
            with torch.no_grad():
                outputs = model(**inputs)
            logits = outputs[0]
            preds = torch.max(logits.data, 1)[1].tolist()
            labels = [config.label_list[p] for p in preds]
            for j in range(start, end):
                print("%s\t%s" % (text[j], labels[j - start]))


if __name__ == "__main__":
    main()
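
# Example invocations, one per mode (a sketch: "main.py" is an assumed script
# name, and the paths shown are just the argparse defaults above; any mode
# other than train/demo falls through to the batch-prediction branch):
#
#   python main.py --mode train --data_dir ./data --seed 1
#   python main.py --mode demo
#   python main.py --mode predict --input_file ./data_12345/test.txt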