saily committed on
Commit
2ffe0f4
1 Parent(s): b5e256c

change train py some

Browse files
Files changed (3) hide show
  1. config.py +1 -1
  2. main.py +8 -2
  3. preprocess.py +4 -1
config.py CHANGED
@@ -24,7 +24,7 @@ class Config(object):
24
  self.num_epochs = 3
25
  self.log_batch = 100
26
  self.batch_size = 128
27
- self.max_seq_len = 32
28
  self.require_improvement = 1000
29
 
30
  self.warmup_steps = 0
 
24
  self.num_epochs = 3
25
  self.log_batch = 100
26
  self.batch_size = 128
27
+ self.max_seq_len = 128
28
  self.require_improvement = 1000
29
 
30
  self.warmup_steps = 0
main.py CHANGED
@@ -12,6 +12,9 @@ from config import Config
12
  from preprocess import DataProcessor, get_time_dif
13
  from transformers import BertConfig, BertTokenizer, BertForSequenceClassification
14
 
 
 
 
15
  parser = argparse.ArgumentParser(description="Bert Chinese Text Classification")
16
  parser.add_argument("--mode", type=str, required=True, help="train/demo/predict")
17
  parser.add_argument("--data_dir", type=str, default="./data", help="training data and saved model path")
@@ -40,11 +43,14 @@ def main():
40
  )
41
  model.to(config.device)
42
 
 
 
 
43
  if args.mode == "train":
44
  print("loading data...")
45
  start_time = time.time()
46
- train_iterator = DataProcessor(config.train_file, config.device, tokenizer, config.batch_size, config.max_seq_len, args.seed)
47
- dev_iterator = DataProcessor(config.dev_file, config.device, tokenizer, config.batch_size, config.max_seq_len, args.seed)
48
  time_dif = get_time_dif(start_time)
49
  print("time usage:", time_dif)
50
 
 
12
  from preprocess import DataProcessor, get_time_dif
13
  from transformers import BertConfig, BertTokenizer, BertForSequenceClassification
14
 
15
+ from ebart import PegasusSummarizer
16
+
17
+
18
  parser = argparse.ArgumentParser(description="Bert Chinese Text Classification")
19
  parser.add_argument("--mode", type=str, required=True, help="train/demo/predict")
20
  parser.add_argument("--data_dir", type=str, default="./data", help="training data and saved model path")
 
43
  )
44
  model.to(config.device)
45
 
46
+ #
47
+ summarizerModel = PegasusSummarizer()
48
+
49
  if args.mode == "train":
50
  print("loading data...")
51
  start_time = time.time()
52
+ train_iterator = DataProcessor(config.train_file, config.device, summarizerModel,tokenizer, config.batch_size, config.max_seq_len, args.seed)
53
+ dev_iterator = DataProcessor(config.dev_file, config.device, summarizerModel,tokenizer, config.batch_size, config.max_seq_len, args.seed)
54
  time_dif = get_time_dif(start_time)
55
  print("time usage:", time_dif)
56
 
preprocess.py CHANGED
@@ -6,6 +6,7 @@ import torch
6
  import random
7
  from tqdm import tqdm
8
  from datetime import timedelta
 
9
 
10
  def get_time_dif(start_time):
11
  end_time = time.time()
@@ -14,10 +15,11 @@ def get_time_dif(start_time):
14
 
15
  ## 获取张量数据
16
  class DataProcessor(object):
17
- def __init__(self, dataPath, device, tokenizer, batch_size, max_seq_len, seed):
18
  self.seed = seed
19
  self.device = device
20
  self.tokenizer = tokenizer
 
21
  self.batch_size = batch_size
22
  self.max_seq_len = max_seq_len
23
 
@@ -51,6 +53,7 @@ class DataProcessor(object):
51
  content, label = line.rsplit('\t', 1)
52
  mapped_label = labels_map.get(label)
53
  if mapped_label is not None and isinstance(mapped_label, int) and mapped_label >= 0:
 
54
  contents.append(content)
55
  labels.append(mapped_label)
56
  else:
 
6
  import random
7
  from tqdm import tqdm
8
  from datetime import timedelta
9
+ from ebart import PegasusSummarizer
10
 
11
  def get_time_dif(start_time):
12
  end_time = time.time()
 
15
 
16
  ## 获取张量数据
17
  class DataProcessor(object):
18
+ def __init__(self, dataPath, device,summarizerModel, tokenizer, batch_size, max_seq_len, seed):
19
  self.seed = seed
20
  self.device = device
21
  self.tokenizer = tokenizer
22
+ self.summarizerModel=summarizerModel
23
  self.batch_size = batch_size
24
  self.max_seq_len = max_seq_len
25
 
 
53
  content, label = line.rsplit('\t', 1)
54
  mapped_label = labels_map.get(label)
55
  if mapped_label is not None and isinstance(mapped_label, int) and mapped_label >= 0:
56
+ content= self.summarizerModel.generate_summary(content, 128, 64)
57
  contents.append(content)
58
  labels.append(mapped_label)
59
  else: