Linhz commited on
Commit
6e80c15
1 Parent(s): 79e2bbf

Update Model/NER/VLSP2021/Load_model.py

Browse files
Files changed (1) hide show
  1. Model/NER/VLSP2021/Load_model.py +34 -34
Model/NER/VLSP2021/Load_model.py CHANGED
@@ -1,34 +1,34 @@
1
- from transformers import RobertaConfig, AutoConfig
2
- from transformers import AutoTokenizer, AutoModelForTokenClassification
3
- from Model.NER.VLSP2021.Ner_CRF import PhoBertCrf,PhoBertSoftmax,PhoBertLstmCrf
4
- from Model.NER.VLSP2021.Predict_Ner import ViTagger
5
- import torch
6
- from spacy import displacy
7
- import re
8
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
9
- MODEL_MAPPING = {
10
- 'vinai/phobert-base': {
11
- 'softmax': PhoBertSoftmax,
12
- 'crf': PhoBertCrf,
13
- 'lstm_crf': PhoBertLstmCrf
14
- },
15
- }
16
- if device == 'cpu':
17
- checkpoint_data = torch.load('E:/demo_datn/pythonProject1/Model/NER/VLSP2021/best_model.pt', map_location='cpu')
18
- else:
19
- checkpoint_data = torch.load('E:/demo_datn/pythonProject1/Model/NER/VLSP2021/best_model.pt')
20
-
21
- configs = checkpoint_data['args']
22
- print(configs.model_name_or_path)
23
- tokenizer = AutoTokenizer.from_pretrained(configs.model_name_or_path)
24
- model_clss = MODEL_MAPPING[configs.model_name_or_path][configs.model_arch]
25
- config = AutoConfig.from_pretrained(configs.model_name_or_path,
26
- num_labels=len(checkpoint_data['classes']),
27
- finetuning_task=configs.task)
28
- model = model_clss(config=config)
29
- model.resize_token_embeddings(len(tokenizer))
30
- model.to(device)
31
- model.load_state_dict(checkpoint_data['model'],strict=False)
32
- print(model)
33
-
34
-
 
1
+ from transformers import RobertaConfig, AutoConfig
2
+ from transformers import AutoTokenizer, AutoModelForTokenClassification
3
+ from Model.NER.VLSP2021.Ner_CRF import PhoBertCrf,PhoBertSoftmax,PhoBertLstmCrf
4
+ from Model.NER.VLSP2021.Predict_Ner import ViTagger
5
+ import torch
6
+ from spacy import displacy
7
+ import re
8
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
9
+ MODEL_MAPPING = {
10
+ 'vinai/phobert-base': {
11
+ 'softmax': PhoBertSoftmax,
12
+ 'crf': PhoBertCrf,
13
+ 'lstm_crf': PhoBertLstmCrf
14
+ },
15
+ }
16
+ if device == 'cpu':
17
+ checkpoint_data = torch.load('/Model/NER/VLSP2021/best_model.pt', map_location='cpu')
18
+ else:
19
+ checkpoint_data = torch.load('/Model/NER/VLSP2021/best_model.pt')
20
+
21
+ configs = checkpoint_data['args']
22
+ print(configs.model_name_or_path)
23
+ tokenizer = AutoTokenizer.from_pretrained(configs.model_name_or_path)
24
+ model_clss = MODEL_MAPPING[configs.model_name_or_path][configs.model_arch]
25
+ config = AutoConfig.from_pretrained(configs.model_name_or_path,
26
+ num_labels=len(checkpoint_data['classes']),
27
+ finetuning_task=configs.task)
28
+ model = model_clss(config=config)
29
+ model.resize_token_embeddings(len(tokenizer))
30
+ model.to(device)
31
+ model.load_state_dict(checkpoint_data['model'],strict=False)
32
+ print(model)
33
+
34
+