Add Pegasus summarization to preprocessing: set max_seq_len=128, pass a PegasusSummarizer into DataProcessor, and summarize each content line before tokenization
Browse files- config.py +1 -1
- main.py +8 -2
- preprocess.py +4 -1
config.py
CHANGED
@@ -24,7 +24,7 @@ class Config(object):
|
|
24 |
self.num_epochs = 3
|
25 |
self.log_batch = 100
|
26 |
self.batch_size = 128
|
27 |
-
self.max_seq_len =
|
28 |
self.require_improvement = 1000
|
29 |
|
30 |
self.warmup_steps = 0
|
|
|
24 |
self.num_epochs = 3
|
25 |
self.log_batch = 100
|
26 |
self.batch_size = 128
|
27 |
+
self.max_seq_len = 128
|
28 |
self.require_improvement = 1000
|
29 |
|
30 |
self.warmup_steps = 0
|
main.py
CHANGED
@@ -12,6 +12,9 @@ from config import Config
|
|
12 |
from preprocess import DataProcessor, get_time_dif
|
13 |
from transformers import BertConfig, BertTokenizer, BertForSequenceClassification
|
14 |
|
|
|
|
|
|
|
15 |
parser = argparse.ArgumentParser(description="Bert Chinese Text Classification")
|
16 |
parser.add_argument("--mode", type=str, required=True, help="train/demo/predict")
|
17 |
parser.add_argument("--data_dir", type=str, default="./data", help="training data and saved model path")
|
@@ -40,11 +43,14 @@ def main():
|
|
40 |
)
|
41 |
model.to(config.device)
|
42 |
|
|
|
|
|
|
|
43 |
if args.mode == "train":
|
44 |
print("loading data...")
|
45 |
start_time = time.time()
|
46 |
-
train_iterator = DataProcessor(config.train_file, config.device, tokenizer, config.batch_size, config.max_seq_len, args.seed)
|
47 |
-
dev_iterator = DataProcessor(config.dev_file, config.device, tokenizer, config.batch_size, config.max_seq_len, args.seed)
|
48 |
time_dif = get_time_dif(start_time)
|
49 |
print("time usage:", time_dif)
|
50 |
|
|
|
12 |
from preprocess import DataProcessor, get_time_dif
|
13 |
from transformers import BertConfig, BertTokenizer, BertForSequenceClassification
|
14 |
|
15 |
+
from ebart import PegasusSummarizer
|
16 |
+
|
17 |
+
|
18 |
parser = argparse.ArgumentParser(description="Bert Chinese Text Classification")
|
19 |
parser.add_argument("--mode", type=str, required=True, help="train/demo/predict")
|
20 |
parser.add_argument("--data_dir", type=str, default="./data", help="training data and saved model path")
|
|
|
43 |
)
|
44 |
model.to(config.device)
|
45 |
|
46 |
+
#
|
47 |
+
summarizerModel = PegasusSummarizer()
|
48 |
+
|
49 |
if args.mode == "train":
|
50 |
print("loading data...")
|
51 |
start_time = time.time()
|
52 |
+
train_iterator = DataProcessor(config.train_file, config.device, summarizerModel,tokenizer, config.batch_size, config.max_seq_len, args.seed)
|
53 |
+
dev_iterator = DataProcessor(config.dev_file, config.device, summarizerModel,tokenizer, config.batch_size, config.max_seq_len, args.seed)
|
54 |
time_dif = get_time_dif(start_time)
|
55 |
print("time usage:", time_dif)
|
56 |
|
preprocess.py
CHANGED
@@ -6,6 +6,7 @@ import torch
|
|
6 |
import random
|
7 |
from tqdm import tqdm
|
8 |
from datetime import timedelta
|
|
|
9 |
|
10 |
def get_time_dif(start_time):
|
11 |
end_time = time.time()
|
@@ -14,10 +15,11 @@ def get_time_dif(start_time):
|
|
14 |
|
15 |
## 获取张量数据
|
16 |
class DataProcessor(object):
|
17 |
-
def __init__(self, dataPath, device, tokenizer, batch_size, max_seq_len, seed):
|
18 |
self.seed = seed
|
19 |
self.device = device
|
20 |
self.tokenizer = tokenizer
|
|
|
21 |
self.batch_size = batch_size
|
22 |
self.max_seq_len = max_seq_len
|
23 |
|
@@ -51,6 +53,7 @@ class DataProcessor(object):
|
|
51 |
content, label = line.rsplit('\t', 1)
|
52 |
mapped_label = labels_map.get(label)
|
53 |
if mapped_label is not None and isinstance(mapped_label, int) and mapped_label >= 0:
|
|
|
54 |
contents.append(content)
|
55 |
labels.append(mapped_label)
|
56 |
else:
|
|
|
6 |
import random
|
7 |
from tqdm import tqdm
|
8 |
from datetime import timedelta
|
9 |
+
from ebart import PegasusSummarizer
|
10 |
|
11 |
def get_time_dif(start_time):
|
12 |
end_time = time.time()
|
|
|
15 |
|
16 |
## 获取张量数据
|
17 |
class DataProcessor(object):
|
18 |
+
def __init__(self, dataPath, device,summarizerModel, tokenizer, batch_size, max_seq_len, seed):
|
19 |
self.seed = seed
|
20 |
self.device = device
|
21 |
self.tokenizer = tokenizer
|
22 |
+
self.summarizerModel=summarizerModel
|
23 |
self.batch_size = batch_size
|
24 |
self.max_seq_len = max_seq_len
|
25 |
|
|
|
53 |
content, label = line.rsplit('\t', 1)
|
54 |
mapped_label = labels_map.get(label)
|
55 |
if mapped_label is not None and isinstance(mapped_label, int) and mapped_label >= 0:
|
56 |
+
content= self.summarizerModel.generate_summary(content, 128, 64)
|
57 |
contents.append(content)
|
58 |
labels.append(mapped_label)
|
59 |
else:
|