# Copyright (c) Guangsheng Bao.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
import time

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


def from_pretrained(cls, model_name, kwargs, cache_dir):
    # use local model if it exists
    local_path = os.path.join(cache_dir, 'local.' + model_name.replace("/", "_"))
    if os.path.exists(local_path):
        return cls.from_pretrained(local_path, **kwargs)
    return cls.from_pretrained(model_name, **kwargs, cache_dir=cache_dir)
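
# Usage sketch (illustrative): a local snapshot of 'facebook/opt-2.7b' would be picked up
# from '<cache_dir>/local.facebook_opt-2.7b' before falling back to the hub download:
#   model = from_pretrained(AutoModelForCausalLM, 'facebook/opt-2.7b', {}, '../cache')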

# predefined models
model_fullnames = {
    'gpt2': 'gpt2',
    'gpt2-xl': 'gpt2-xl',
    'opt-2.7b': 'facebook/opt-2.7b',
    'gpt-neo-2.7B': 'EleutherAI/gpt-neo-2.7B',
    'gpt-j-6B': 'EleutherAI/gpt-j-6B',
    'gpt-neox-20b': 'EleutherAI/gpt-neox-20b',
    'mgpt': 'sberbank-ai/mGPT',
    'pubmedgpt': 'stanford-crfm/pubmedgpt',
    'mt5-xl': 'google/mt5-xl',
    'llama-13b': 'huggyllama/llama-13b',
    'llama2-13b': 'TheBloke/Llama-2-13B-fp16',
    'bloom-7b1': 'bigscience/bloom-7b1',
    'opt-13b': 'facebook/opt-13b',
}
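
# larger checkpoints are loaded in float16, roughly halving GPU memory versus float32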
float16_models = ['gpt-j-6B', 'gpt-neox-20b', 'llama-13b', 'llama2-13b', 'bloom-7b1', 'opt-13b']


def get_model_fullname(model_name):
    # map a short alias to its full hub id; unknown names pass through unchanged
    return model_fullnames[model_name] if model_name in model_fullnames else model_name


def load_model(model_name, device, cache_dir):
    model_fullname = get_model_fullname(model_name)
    print(f'Loading model {model_fullname}...')
    model_kwargs = {}
    if model_name in float16_models:
        model_kwargs.update(dict(torch_dtype=torch.float16))
    if 'gpt-j' in model_name:
        # EleutherAI/gpt-j-6B publishes a dedicated 'float16' branch with half-precision weights
        model_kwargs.update(dict(revision='float16'))
    model = from_pretrained(AutoModelForCausalLM, model_fullname, model_kwargs, cache_dir)
    print('Moving model to GPU...', end='', flush=True)
    start = time.time()
    model.to(device)
    print(f'DONE ({time.time() - start:.2f}s)')
    return model
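
# Usage sketch (assumes a CUDA device is available; pass 'cpu' to stay on the host):
#   model = load_model('gpt2', 'cuda', '../cache')
#   model.eval()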


def load_tokenizer(model_name, for_dataset, cache_dir):
    model_fullname = get_model_fullname(model_name)
    optional_tok_kwargs = {}
    if "facebook/opt-" in model_fullname:
        print("Using non-fast tokenizer for OPT")
        optional_tok_kwargs['use_fast'] = False
    if for_dataset in ['pubmed']:
        optional_tok_kwargs['padding_side'] = 'left'
    else:
        optional_tok_kwargs['padding_side'] = 'right'
    base_tokenizer = from_pretrained(AutoTokenizer, model_fullname, optional_tok_kwargs, cache_dir=cache_dir)
    if base_tokenizer.pad_token_id is None:
        # models without a padding token fall back to EOS for padding
        base_tokenizer.pad_token_id = base_tokenizer.eos_token_id
        if '13b' in model_fullname:
            # the llama-family 13B tokenizers pad with token id 0 instead
            base_tokenizer.pad_token_id = 0
    return base_tokenizer
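
# Usage sketch ('pubmed' selects left padding; other dataset names pad on the right):
#   tokenizer = load_tokenizer('gpt2', 'xsum', '../cache')
#   batch = tokenizer(["a sample text"], padding=True, return_tensors='pt')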


if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--model_name', type=str, default="bloom-7b1")
    parser.add_argument('--cache_dir', type=str, default="../cache")
    args = parser.parse_args()

    # smoke test: load both the tokenizer and the model for the requested checkpoint
    load_tokenizer(args.model_name, 'xsum', args.cache_dir)
    load_model(args.model_name, 'cpu', args.cache_dir)
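
# Example invocation (assuming this file is saved as model.py; weights download on first run):
#   python model.py --model_name gpt2 --cache_dir ../cache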