import os
import pathlib

import loguru
import requests
import torch

from .config import ModelConfig

logger = loguru.logger

__all__ = ['pretrained_models', 'load_model_state', 'load_vocab']
class PretrainedModel:
  """Describes a pretrained DeBERTa checkpoint hosted under the microsoft
  organization on the Hugging Face hub, including the URLs of its weights,
  config and vocabulary files."""

  def __init__(self, name, vocab, vocab_type, model='pytorch_model.bin', config='config.json', **kwargs):
    self.__dict__.update(kwargs)
    host = f'https://huggingface.co/microsoft/{name}/resolve/main/'
    self.name = name
    self.model_url = host + model
    self.config_url = host + config
    self.vocab_url = host + vocab
    self.vocab_type = vocab_type
pretrained_models = {
  'base': PretrainedModel('deberta-base', 'bpe_encoder.bin', 'gpt2'),
  'large': PretrainedModel('deberta-large', 'bpe_encoder.bin', 'gpt2'),
  'xlarge': PretrainedModel('deberta-xlarge', 'bpe_encoder.bin', 'gpt2'),
  'base-mnli': PretrainedModel('deberta-base-mnli', 'bpe_encoder.bin', 'gpt2'),
  'large-mnli': PretrainedModel('deberta-large-mnli', 'bpe_encoder.bin', 'gpt2'),
  'xlarge-mnli': PretrainedModel('deberta-xlarge-mnli', 'bpe_encoder.bin', 'gpt2'),
  'xlarge-v2': PretrainedModel('deberta-v2-xlarge', 'spm.model', 'spm'),
  'xxlarge-v2': PretrainedModel('deberta-v2-xxlarge', 'spm.model', 'spm'),
  'xlarge-v2-mnli': PretrainedModel('deberta-v2-xlarge-mnli', 'spm.model', 'spm'),
  'xxlarge-v2-mnli': PretrainedModel('deberta-v2-xxlarge-mnli', 'spm.model', 'spm'),
  'deberta-v3-small': PretrainedModel('deberta-v3-small', 'spm.model', 'spm'),
  'deberta-v3-base': PretrainedModel('deberta-v3-base', 'spm.model', 'spm'),
  'deberta-v3-large': PretrainedModel('deberta-v3-large', 'spm.model', 'spm'),
  'mdeberta-v3-base': PretrainedModel('mdeberta-v3-base', 'spm.model', 'spm'),
  'deberta-v3-xsmall': PretrainedModel('deberta-v3-xsmall', 'spm.model', 'spm'),
}
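# For example, the 'base' alias above resolves (via PretrainedModel) to:
#   pretrained_models['base'].model_url
#     -> 'https://huggingface.co/microsoft/deberta-base/resolve/main/pytorch_model.bin'
#   pretrained_models['base'].vocab_url
#     -> 'https://huggingface.co/microsoft/deberta-base/resolve/main/bpe_encoder.bin'
#   pretrained_models['base'].vocab_type
#     -> 'gpt2'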
def load_model_state(path_or_pretrained_id, tag=None, no_cache=False, cache_dir=None):
  """Load a model state dict and its config, either from a local checkpoint
  path or from one of the pretrained model ids registered above."""
  model_path = path_or_pretrained_id
  if model_path and (not os.path.exists(model_path)) and (path_or_pretrained_id.lower() in pretrained_models):
    # The argument is a registry alias: resolve it to a local cache path.
    _tag = tag
    pretrained = pretrained_models[path_or_pretrained_id.lower()]
    if _tag is None:
      _tag = 'latest'
    if not cache_dir:
      cache_dir = os.path.join(pathlib.Path.home(), f'.~DeBERTa/assets/{_tag}/{pretrained.name}')
    os.makedirs(cache_dir, exist_ok=True)
    model_path = os.path.join(cache_dir, 'pytorch_model.bin')
    if no_cache or (not os.path.exists(model_path)):
      # Minimal download path: fetch the checkpoint directly from
      # pretrained.model_url when it is not cached (or a fresh copy is forced).
      with requests.get(pretrained.model_url, stream=True) as resp:
        resp.raise_for_status()
        with open(model_path, 'wb') as fs:
          for chunk in resp.iter_content(chunk_size=1024 * 1024):
            fs.write(chunk)
  elif not model_path:
    return None, None

  config_path = os.path.join(os.path.dirname(model_path), 'model_config.json')
  model_state = torch.load(model_path, map_location='cpu')
  logger.info("Loaded pretrained model file {}".format(model_path))
  if 'config' in model_state:
    # Newer checkpoints embed the model config in the state dict itself.
    model_config = ModelConfig.from_dict(model_state['config'])
  elif os.path.exists(config_path):
    model_config = ModelConfig.from_json_file(config_path)
  else:
    model_config = None
  return model_state, model_config
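if __name__ == '__main__':
  # Minimal usage sketch: resolve the 'base' alias through the registry above,
  # downloading the checkpoint into the local cache on first use. Run this as a
  # module (e.g. `python -m <package>.<this_module>`) so the relative import of
  # ModelConfig resolves; an explicit checkpoint path works as well.
  state, config = load_model_state('base')
  logger.info('Loaded state dict with {} entries (config found: {})'.format(len(state), config is not None))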