# ViMNer: Model/NER/VLSP2021/Load_model.py
# (Header captured from the Hugging Face file viewer: author Linhz,
#  commit 6e80c15 "Update Model/NER/VLSP2021/Load_model.py" (verified), 1.31 kB.)
import re
import warnings
from pathlib import Path

import torch
from spacy import displacy
from transformers import RobertaConfig, AutoConfig
from transformers import AutoTokenizer, AutoModelForTokenClassification

from Model.NER.VLSP2021.Ner_CRF import PhoBertCrf,PhoBertSoftmax,PhoBertLstmCrf
from Model.NER.VLSP2021.Predict_Ner import ViTagger
# Pick the compute device once, at import time: use CUDA when PyTorch reports
# an available GPU, otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
# Registry used to pick the model class for a checkpoint: the outer key is the
# pretrained backbone name (``configs.model_name_or_path``), the inner key is
# the architecture variant (``configs.model_arch``).
MODEL_MAPPING = {
    'vinai/phobert-base': dict(
        softmax=PhoBertSoftmax,
        crf=PhoBertCrf,
        lstm_crf=PhoBertLstmCrf,
    ),
}
# Resolve the fine-tuned checkpoint file. The historical location is the
# absolute path '/Model/NER/VLSP2021/best_model.pt', which only exists when the
# repository is mounted at the filesystem root. The package-style imports above
# (``Model.NER.VLSP2021...``) imply the repo root is on sys.path instead, so
# when the legacy path is absent fall back to the file next to this module
# (presumably where best_model.pt lives in the repo -- confirm on deploy).
_LEGACY_CHECKPOINT_PATH = Path('/Model/NER/VLSP2021/best_model.pt')
_LOCAL_CHECKPOINT_PATH = Path(__file__).resolve().parent / 'best_model.pt'
_checkpoint_path = (
    _LEGACY_CHECKPOINT_PATH
    if _LEGACY_CHECKPOINT_PATH.is_file()
    else _LOCAL_CHECKPOINT_PATH
)


def _load_checkpoint(path, map_location):
    """Load the training checkpoint dict written with ``torch.save``.

    The checkpoint carries an ``args`` object that is read via attribute
    access (presumably an ``argparse.Namespace``) next to the weights.
    Objects like that are rejected by ``weights_only=True``, which is the
    default from PyTorch 2.6 on, so full unpickling is requested explicitly.

    SECURITY: ``weights_only=False`` unpickles arbitrary Python objects.
    Only point this at checkpoints you trust.

    Args:
        path: Location of the checkpoint file.
        map_location: Device the stored tensors are mapped onto.

    Returns:
        The deserialised checkpoint object (read below as a dict with
        ``'args'``, ``'classes'`` and ``'model'`` entries).
    """
    try:
        return torch.load(str(path), map_location=map_location, weights_only=False)
    except TypeError:
        # PyTorch < 1.13 has no ``weights_only`` keyword and always unpickles fully.
        return torch.load(str(path), map_location=map_location)


# map_location=device covers both former branches: 'cpu' forces CPU tensors,
# 'cuda' places them on the current GPU.
checkpoint_data = _load_checkpoint(_checkpoint_path, map_location=device)
configs = checkpoint_data['args']
print(configs.model_name_or_path)
tokenizer = AutoTokenizer.from_pretrained(configs.model_name_or_path)
# Raises KeyError when the checkpoint's backbone/architecture is not registered.
model_clss = MODEL_MAPPING[configs.model_name_or_path][configs.model_arch]
config = AutoConfig.from_pretrained(configs.model_name_or_path,
                                    num_labels=len(checkpoint_data['classes']),
                                    finetuning_task=configs.task)
model = model_clss(config=config)
model.resize_token_embeddings(len(tokenizer))
model.to(device)

# strict=False tolerates key mismatches, but ignoring them silently could leave
# parameters at their freshly constructed values. Keep the lenient load, yet
# report any mismatch so a wrong/partial checkpoint does not go unnoticed.
_load_result = model.load_state_dict(checkpoint_data['model'], strict=False)
_missing_keys = list(getattr(_load_result, 'missing_keys', []))
_unexpected_keys = list(getattr(_load_result, 'unexpected_keys', []))
if _missing_keys or _unexpected_keys:
    warnings.warn(
        f"Checkpoint {_checkpoint_path} does not fully match "
        f"{model_clss.__name__}: {len(_missing_keys)} missing key(s) "
        f"(e.g. {_missing_keys[:3]}), {len(_unexpected_keys)} unexpected "
        f"key(s) (e.g. {_unexpected_keys[:3]}); missing parameters keep "
        "their pre-load values.",
        RuntimeWarning,
    )

# This module only loads the model for inference: switch off training-time
# behaviour such as dropout.
model.eval()
print(model)