yangfan commited on
Commit
abc7a8b
·
1 Parent(s): a9e3801

feat(*): change models for hg

Browse files
Files changed (1) hide show
  1. main.py +4 -3
main.py CHANGED
@@ -14,11 +14,13 @@ from transformers import BertConfig, BertTokenizer, BertForSequenceClassificatio
14
  parser = argparse.ArgumentParser(description="Bert Chinese Text Classification")
15
  parser.add_argument("--mode", type=str, required=True, help="train/demo/predict")
16
  parser.add_argument("--data_dir", type=str, default="./data", help="training data and saved model path")
17
- parser.add_argument("--pretrained_bert_dir", type=str, default="./pretrained_bert", help="pretrained bert model path")
18
  parser.add_argument("--seed", type=int, default=1, help="random seed for initialization")
19
  parser.add_argument("--input_file", type=str, default="./data/input.txt", help="input file to be predicted")
20
  args = parser.parse_args()
21
 
 
 
22
  def set_seed(seed):
23
  np.random.seed(seed)
24
  torch.manual_seed(seed)
@@ -31,8 +33,7 @@ def main():
31
 
32
  tokenizer = BertTokenizer.from_pretrained(args.pretrained_bert_dir)
33
  bert_config = BertConfig.from_pretrained(args.pretrained_bert_dir, num_labels=config.num_labels)
34
- model = BertForSequenceClassification.from_pretrained(
35
- os.path.join(args.pretrained_bert_dir, "pytorch_model.bin"),
36
  config=bert_config
37
  )
38
  model.to(config.device)
 
14
  parser = argparse.ArgumentParser(description="Bert Chinese Text Classification")
15
  parser.add_argument("--mode", type=str, required=True, help="train/demo/predict")
16
  parser.add_argument("--data_dir", type=str, default="./data", help="training data and saved model path")
17
+ #parser.add_argument("--pretrained_bert_dir", type=str, default="./pretrained_bert", help="pretrained bert model path")
18
  parser.add_argument("--seed", type=int, default=1, help="random seed for initialization")
19
  parser.add_argument("--input_file", type=str, default="./data/input.txt", help="input file to be predicted")
20
  args = parser.parse_args()
21
 
22
+ args.pretrained_bert_dir = "bert-base-chinese"
23
+
24
  def set_seed(seed):
25
  np.random.seed(seed)
26
  torch.manual_seed(seed)
 
33
 
34
  tokenizer = BertTokenizer.from_pretrained(args.pretrained_bert_dir)
35
  bert_config = BertConfig.from_pretrained(args.pretrained_bert_dir, num_labels=config.num_labels)
36
+ model = BertForSequenceClassification.from_pretrained(args.pretrained_bert_dir,
 
37
  config=bert_config
38
  )
39
  model.to(config.device)