Spaces:
Runtime error
Runtime error
File size: 1,899 Bytes
402c662 492f975 402c662 aed5af0 402c662 492f975 402c662 84a97f6 402c662 aed5af0 402c662 204e8d8 402c662 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 |
import torch
import gradio as gr
import argparse
from utils import load_hyperparam, load_model
from models.tokenize import Tokenizer
from models.llama import *
from generate import LmGeneration
import os
# Exported to the environment at import time, before any model code runs.
# NOTE(review): CUDA_LAUNCH_BLOCKING=1 is normally a debugging aid that makes
# kernel launches synchronous -- confirm it is still wanted for the demo.
os.environ.update(CUDA_LAUNCH_BLOCKING='1')

# Module-level state, filled in by init_args() and init_model() respectively.
args = None
lm_generation = None
def init_args():
    """Fill the module-level ``args`` namespace used by the rest of the demo.

    The command line is parsed with an argument parser that declares no
    options of its own; the resulting namespace is then populated with the
    fixed demo settings below, passed through ``load_hyperparam``, and
    finally extended with a ``Tokenizer`` instance and its vocabulary size.
    """
    global args

    cli = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    args = cli.parse_args()

    # Hard-coded demo configuration, applied in this order on top of the
    # parsed namespace.
    fixed_settings = {
        'load_model_path': './model_file/chatllama_7b.bin',
        'config_path': './config/llama_7b.json',
        'spm_model_path': './model_file/tokenizer.model',
        'batch_size': 1,
        'seq_length': 512,
        'world_size': 1,
        'use_int8': False,
        'top_p': 0,
        'repetition_penalty_range': 1024,
        'repetition_penalty_slope': 0,
        'repetition_penalty': 1.15,
    }
    for option, value in fixed_settings.items():
        setattr(args, option, value)

    # load_hyperparam returns the namespace to use from here on
    # (presumably merged with the JSON at config_path -- confirm in utils).
    args = load_hyperparam(args)

    tokenizer = Tokenizer(model_path=args.spm_model_path)
    args.tokenizer = tokenizer
    args.vocab_size = tokenizer.sp_model.vocab_size()
def init_model():
    """Build the LLaMa model, load its weights and create the global generator.

    Reads the module-level ``args`` (so ``init_args()`` must have run first)
    and stores the resulting ``LmGeneration`` object in the module-level
    ``lm_generation``. The model is put in eval mode, printed, and moved to
    CUDA when available, otherwise to the CPU.
    """
    global lm_generation

    # The model is constructed while HalfTensor is the default tensor type
    # (presumably so its parameters are created in fp16 -- confirm against
    # models.llama).
    torch.set_default_tensor_type(torch.HalfTensor)
    try:
        model = LLaMa(args)
    finally:
        # Restore the float32 default even when construction fails, so the
        # process-wide default is never left at half precision.
        torch.set_default_tensor_type(torch.FloatTensor)

    model = load_model(model, args.load_model_path)
    model.eval()
    print(model)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # Report peak CUDA memory (in GiB) only after the move to the target
    # device, so the figure includes the model's transfer.
    print(torch.cuda.max_memory_allocated() / 1024 ** 3)

    lm_generation = LmGeneration(model, args.tokenizer)
def chat(prompt, top_k, temperature):
    """Generate a reply to ``prompt`` with the given sampling settings.

    ``top_k`` is coerced to ``int`` and stored, together with ``temperature``,
    on the module-level ``args`` before generation. The prompt is submitted
    as a one-element batch and the first element of the result is returned.
    """
    args.top_k, args.temperature = int(top_k), temperature
    return lm_generation.generate(args, [prompt])[0]
if __name__ == '__main__':
    # Configuration and model must be ready before the UI can serve requests.
    init_args()
    init_model()

    # Inputs map positionally onto chat(prompt, top_k, temperature).
    top_k_slider = gr.Slider(1, 60, value=40, step=1)
    temperature_slider = gr.Slider(0.1, 2.0, value=1.2, step=0.1)

    demo = gr.Interface(
        fn=chat,
        inputs=["text", top_k_slider, temperature_slider],
        outputs="text",
    )
    demo.launch()
|