yangfan
commited on
Commit
·
abc7a8b
1
Parent(s):
a9e3801
feat(*): change models for hg
Browse files
main.py
CHANGED
@@ -14,11 +14,13 @@ from transformers import BertConfig, BertTokenizer, BertForSequenceClassificatio
|
|
14 |
parser = argparse.ArgumentParser(description="Bert Chinese Text Classification")
|
15 |
parser.add_argument("--mode", type=str, required=True, help="train/demo/predict")
|
16 |
parser.add_argument("--data_dir", type=str, default="./data", help="training data and saved model path")
|
17 |
-
parser.add_argument("--pretrained_bert_dir", type=str, default="./pretrained_bert", help="pretrained bert model path")
|
18 |
parser.add_argument("--seed", type=int, default=1, help="random seed for initialization")
|
19 |
parser.add_argument("--input_file", type=str, default="./data/input.txt", help="input file to be predicted")
|
20 |
args = parser.parse_args()
|
21 |
|
|
|
|
|
22 |
def set_seed(seed):
|
23 |
np.random.seed(seed)
|
24 |
torch.manual_seed(seed)
|
@@ -31,8 +33,7 @@ def main():
|
|
31 |
|
32 |
tokenizer = BertTokenizer.from_pretrained(args.pretrained_bert_dir)
|
33 |
bert_config = BertConfig.from_pretrained(args.pretrained_bert_dir, num_labels=config.num_labels)
|
34 |
-
model = BertForSequenceClassification.from_pretrained(
|
35 |
-
os.path.join(args.pretrained_bert_dir, "pytorch_model.bin"),
|
36 |
config=bert_config
|
37 |
)
|
38 |
model.to(config.device)
|
|
|
14 |
parser = argparse.ArgumentParser(description="Bert Chinese Text Classification")
|
15 |
parser.add_argument("--mode", type=str, required=True, help="train/demo/predict")
|
16 |
parser.add_argument("--data_dir", type=str, default="./data", help="training data and saved model path")
|
17 |
+
#parser.add_argument("--pretrained_bert_dir", type=str, default="./pretrained_bert", help="pretrained bert model path")
|
18 |
parser.add_argument("--seed", type=int, default=1, help="random seed for initialization")
|
19 |
parser.add_argument("--input_file", type=str, default="./data/input.txt", help="input file to be predicted")
|
20 |
args = parser.parse_args()
|
21 |
|
22 |
+
args.pretrained_bert_dir = "bert-base-chinese"
|
23 |
+
|
24 |
def set_seed(seed):
|
25 |
np.random.seed(seed)
|
26 |
torch.manual_seed(seed)
|
|
|
33 |
|
34 |
tokenizer = BertTokenizer.from_pretrained(args.pretrained_bert_dir)
|
35 |
bert_config = BertConfig.from_pretrained(args.pretrained_bert_dir, num_labels=config.num_labels)
|
36 |
+
model = BertForSequenceClassification.from_pretrained(args.pretrained_bert_dir,
|
|
|
37 |
config=bert_config
|
38 |
)
|
39 |
model.to(config.device)
|