# fast_detect_gpt/model.py
# Copyright (c) Guangsheng Bao.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import time
import os
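
# Load a pretrained class (model or tokenizer) from a local snapshot under
# cache_dir when one exists, otherwise fetch it from the Hugging Face Hub.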
def from_pretrained(cls, model_name, kwargs, cache_dir):
    # use local model if it exists
    local_path = os.path.join(cache_dir, 'local.' + model_name.replace("/", "_"))
    if os.path.exists(local_path):
        return cls.from_pretrained(local_path, **kwargs)
    return cls.from_pretrained(model_name, **kwargs, cache_dir=cache_dir)

# predefined models
model_fullnames = {
    'gpt2': 'gpt2',
    'gpt2-xl': 'gpt2-xl',
    'opt-2.7b': 'facebook/opt-2.7b',
    'gpt-neo-2.7B': 'EleutherAI/gpt-neo-2.7B',
    'gpt-j-6B': 'EleutherAI/gpt-j-6B',
    'gpt-neox-20b': 'EleutherAI/gpt-neox-20b',
    'mgpt': 'sberbank-ai/mGPT',
    'pubmedgpt': 'stanford-crfm/pubmedgpt',
    'mt5-xl': 'google/mt5-xl',
    'llama-13b': 'huggyllama/llama-13b',
    'llama2-13b': 'TheBloke/Llama-2-13B-fp16',
    'bloom-7b1': 'bigscience/bloom-7b1',
    'opt-13b': 'facebook/opt-13b',
}
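# larger models are loaded in half precision to reduce GPU memory usage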
float16_models = ['gpt-j-6B', 'gpt-neox-20b', 'llama-13b', 'llama2-13b', 'bloom-7b1', 'opt-13b']
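
# Resolve a short model alias to its full Hub repo id; names not in the
# table above are passed through unchanged.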
def get_model_fullname(model_name):
    return model_fullnames[model_name] if model_name in model_fullnames else model_name
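
# Load a causal LM by alias or repo id and move it to the target device.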
def load_model(model_name, device, cache_dir):
    model_fullname = get_model_fullname(model_name)
    print(f'Loading model {model_fullname}...')
    model_kwargs = {}
    if model_name in float16_models:
        model_kwargs.update(dict(torch_dtype=torch.float16))
    if 'gpt-j' in model_name:
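        # the gpt-j repo on the Hub provides a dedicated float16 revision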
        model_kwargs.update(dict(revision='float16'))
    model = from_pretrained(AutoModelForCausalLM, model_fullname, model_kwargs, cache_dir)
    print('Moving model to GPU...', end='', flush=True)
    start = time.time()
    model.to(device)
    print(f'DONE ({time.time() - start:.2f}s)')
    return model
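
# Load the matching tokenizer, choosing the padding side per dataset and
# ensuring a pad token id is set.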
def load_tokenizer(model_name, for_dataset, cache_dir):
    model_fullname = get_model_fullname(model_name)
    optional_tok_kwargs = {}
    if "facebook/opt-" in model_fullname:
        print("Using non-fast tokenizer for OPT")
        optional_tok_kwargs['use_fast'] = False
    if for_dataset in ['pubmed']:
        optional_tok_kwargs['padding_side'] = 'left'
    else:
        optional_tok_kwargs['padding_side'] = 'right'
    base_tokenizer = from_pretrained(AutoTokenizer, model_fullname, optional_tok_kwargs, cache_dir=cache_dir)
    if base_tokenizer.pad_token_id is None:
        base_tokenizer.pad_token_id = base_tokenizer.eos_token_id
        if '13b' in model_fullname:
            base_tokenizer.pad_token_id = 0
    return base_tokenizer
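
# smoke test: load the tokenizer and the model for a given alias on CPU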
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--model_name', type=str, default="bloom-7b1")
    parser.add_argument('--cache_dir', type=str, default="../cache")
    args = parser.parse_args()

    load_tokenizer(args.model_name, 'xsum', args.cache_dir)
    load_model(args.model_name, 'cpu', args.cache_dir)