koCSN_SAPR / utils /load_model.py
yuneun92's picture
Upload 13 files
bcb1848 verified
raw
history blame
No virus
611 Bytes
"""
모델들 불러오는 모듈
"""
import torch
# from .load_model import KCSN
# from .arguments import get_train_args
# args = get_train_args()
def load_ner(path='model/NER.pth', map_location=None):
    """
    Load the NER model from a checkpoint file.

    The checkpoint is read with ``torch.load`` and is expected to be a dict
    holding the pickled model object under ``'model'`` and its parameters
    under ``'model_state_dict'``; the state dict is applied to the model
    before returning.

    Args:
        path: Path to the checkpoint file.
        map_location: Forwarded to ``torch.load``. Pass e.g. ``'cpu'`` to load
            a checkpoint saved on GPU onto a CPU-only machine. ``None`` (the
            default) keeps the device locations stored in the file, matching
            the previous behavior.

    Returns:
        A ``(model, checkpoint)`` tuple: the model with the stored state dict
        loaded, and the full checkpoint dict as read from disk.

    Raises:
        KeyError: if the checkpoint lacks ``'model'`` or ``'model_state_dict'``.
    """
    import inspect  # local import: only needed to probe torch.load's signature

    load_kwargs = {'map_location': map_location}
    # The checkpoint contains a whole pickled model object, which the
    # weights-only unpickler (default since PyTorch 2.6) rejects, so full
    # unpickling must be requested explicitly. Older torch versions have no
    # ``weights_only`` parameter (unknown kwargs would be passed on to
    # pickle.load and raise TypeError), so only pass it when supported.
    # SECURITY NOTE: full unpickling can execute arbitrary code -- only load
    # checkpoint files from trusted sources.
    if 'weights_only' in inspect.signature(torch.load).parameters:
        load_kwargs['weights_only'] = False
    checkpoint = torch.load(path, **load_kwargs)

    model = checkpoint['model']
    model.load_state_dict(checkpoint['model_state_dict'])
    return model, checkpoint
# def load_fs(path = 'model/FS.pth'):
# """
# Find Speaker 모델
# """
# model = KCSN(args)
# checkpoint = torch.load(path)
# model.load_state_dict(checkpoint['model_state_dict'])
# return model, checkpoint