# Copyright (c) Guangsheng Bao.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
import time

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


def from_pretrained(cls, model_name, kwargs, cache_dir):
    # Prefer a locally saved copy of the model if one exists in the cache dir.
    local_path = os.path.join(cache_dir, 'local.' + model_name.replace("/", "_"))
    if os.path.exists(local_path):
        return cls.from_pretrained(local_path, **kwargs)
    return cls.from_pretrained(model_name, **kwargs, cache_dir=cache_dir)

# predefined models: short aliases mapped to their full Hugging Face Hub names
model_fullnames = {
    'gpt2': 'gpt2',
    'gpt2-xl': 'gpt2-xl',
    'opt-2.7b': 'facebook/opt-2.7b',
    'gpt-neo-2.7B': 'EleutherAI/gpt-neo-2.7B',
    'gpt-j-6B': 'EleutherAI/gpt-j-6B',
    'gpt-neox-20b': 'EleutherAI/gpt-neox-20b',
    'mgpt': 'sberbank-ai/mGPT',
    'pubmedgpt': 'stanford-crfm/pubmedgpt',
    'mt5-xl': 'google/mt5-xl',
    'llama-13b': 'huggyllama/llama-13b',
    'llama2-13b': 'TheBloke/Llama-2-13B-fp16',
    'bloom-7b1': 'bigscience/bloom-7b1',
    'opt-13b': 'facebook/opt-13b',
}

# larger models are loaded in float16 to reduce memory usage
float16_models = ['gpt-j-6B', 'gpt-neox-20b', 'llama-13b', 'llama2-13b', 'bloom-7b1', 'opt-13b']

def get_model_fullname(model_name):
    return model_fullnames[model_name] if model_name in model_fullnames else model_name

def load_model(model_name, device, cache_dir):
    model_fullname = get_model_fullname(model_name)
    print(f'Loading model {model_fullname}...')
    model_kwargs = {}
    if model_name in float16_models:
        model_kwargs.update(dict(torch_dtype=torch.float16))
    if 'gpt-j' in model_name:
        # gpt-j-6B provides a dedicated float16 revision on the Hub
        model_kwargs.update(dict(revision='float16'))
    model = from_pretrained(AutoModelForCausalLM, model_fullname, model_kwargs, cache_dir)
    print(f'Moving model to {device}...', end='', flush=True)
    start = time.time()
    model.to(device)
    print(f'DONE ({time.time() - start:.2f}s)')
    return model

def load_tokenizer(model_name, for_dataset, cache_dir):
    model_fullname = get_model_fullname(model_name)
    optional_tok_kwargs = {}
    if "facebook/opt-" in model_fullname:
        print("Using non-fast tokenizer for OPT")
        # 'use_fast' is the kwarg recognized by AutoTokenizer.from_pretrained
        optional_tok_kwargs['use_fast'] = False
    if for_dataset in ['pubmed']:
        optional_tok_kwargs['padding_side'] = 'left'
    else:
        optional_tok_kwargs['padding_side'] = 'right'
    base_tokenizer = from_pretrained(AutoTokenizer, model_fullname, optional_tok_kwargs, cache_dir=cache_dir)
    if base_tokenizer.pad_token_id is None:
        base_tokenizer.pad_token_id = base_tokenizer.eos_token_id
        if '13b' in model_fullname:
            base_tokenizer.pad_token_id = 0
    return base_tokenizer

if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_name', type=str, default="bloom-7b1")
    parser.add_argument('--cache_dir', type=str, default="../cache")
    args = parser.parse_args()

    load_tokenizer(args.model_name, 'xsum', args.cache_dir)
    load_model(args.model_name, 'cpu', args.cache_dir)
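
# Example: using the loaders from another script. A minimal sketch, assuming
# this file is saved as model.py and a CUDA device is available (pass 'cpu'
# otherwise); the model alias, cache path, and prompt text below are
# illustrative placeholders, not values fixed by this module.
#
#     from model import load_model, load_tokenizer
#
#     tokenizer = load_tokenizer('gpt-neo-2.7B', 'xsum', './cache')
#     model = load_model('gpt-neo-2.7B', 'cuda', './cache')
#     model.eval()
#     inputs = tokenizer('Sample text to score.', return_tensors='pt').to('cuda')
#     with torch.no_grad():
#         logits = model(**inputs).logits  # shape: [batch, seq_len, vocab]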