File size: 1,899 Bytes
402c662
 
 
 
 
 
 
 
 
492f975
 
402c662
 
 
 
 
 
 
 
 
 
 
 
 
aed5af0
402c662
492f975
402c662
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84a97f6
402c662
 
aed5af0
402c662
204e8d8
402c662
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import torch

import gradio as gr
import argparse
from utils import load_hyperparam, load_model
from models.tokenize import Tokenizer
from models.llama import *
from generate import LmGeneration

import os
# Force synchronous CUDA kernel launches. NOTE(review): this is usually a
# debugging aid (errors surface at the offending call) and can slow down
# generation — confirm it is intended for the served demo.
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

# Module-level state: `args` is filled in by init_args(), `lm_generation` by
# init_model(); chat() reads both on every request.
args = None
lm_generation = None


def init_args(argv=None):
    """Build the module-level ``args`` namespace used by ``init_model`` and ``chat``.

    The checkpoint, config and tokenizer paths are exposed as command-line
    options. Their defaults are the previously hard-coded values, so running
    the script with no arguments behaves exactly as before. All other
    settings are fixed here.

    The namespace is then passed through ``load_hyperparam`` (from ``utils``;
    presumably it reads the file at ``config_path`` — verify). Finally a
    ``Tokenizer`` and its vocabulary size are attached.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) parses ``sys.argv[1:]``, as before.
    """
    global args
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    # The parser previously declared no options, so every CLI argument was
    # rejected and --help listed nothing. Expose the file paths instead of
    # hard-coding them, keeping the old values as defaults.
    parser.add_argument('--load_model_path', type=str,
                        default='./model_file/chatllama_7b.bin',
                        help='Model checkpoint to load.')
    parser.add_argument('--config_path', type=str,
                        default='./config/llama_7b.json',
                        help='Model configuration file.')
    parser.add_argument('--spm_model_path', type=str,
                        default='./model_file/tokenizer.model',
                        help='Tokenizer model file.')
    args = parser.parse_args(argv)

    # Fixed inference settings. chat() passes the whole namespace to
    # LmGeneration.generate, and it also sets top_k and temperature per request.
    args.batch_size = 1
    args.seq_length = 512
    args.world_size = 1
    args.use_int8 = False
    args.top_p = 0
    args.repetition_penalty_range = 1024
    args.repetition_penalty_slope = 0
    args.repetition_penalty = 1.15

    args = load_hyperparam(args)

    args.tokenizer = Tokenizer(model_path=args.spm_model_path)
    args.vocab_size = args.tokenizer.sp_model.vocab_size()


def init_model():
    """Construct the LLaMa model, load its weights and create the global generator.

    Reads the module-level ``args``, so ``init_args()`` must have run first.
    Sets the module-level ``lm_generation`` used by ``chat``.
    """
    global lm_generation
    # Build the model with half-precision default tensors. Restore the
    # float32 default in `finally` so that a failure inside the constructor
    # cannot leave the process-wide default type set to HalfTensor.
    torch.set_default_tensor_type(torch.HalfTensor)
    try:
        model = LLaMa(args)
    finally:
        torch.set_default_tensor_type(torch.FloatTensor)
    model = load_model(model, args.load_model_path)
    model.eval()
    print(model)

    # NOTE(review): on the CPU fallback the model stays in half precision;
    # confirm every op it uses is supported there.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    # Report peak GPU memory (GiB) only after the weights are on the device.
    # The original printed this before the transfer, so it could never
    # include the memory taken by the moved model.
    if device.type == "cuda":
        print(torch.cuda.max_memory_allocated() / 1024 ** 3)
    lm_generation = LmGeneration(model, args.tokenizer)


def chat(prompt, top_k, temperature):
    """Gradio callback: generate a reply to ``prompt``.

    ``top_k`` (coerced to ``int``) and ``temperature`` come from the two UI
    sliders. They are written onto the shared module-level ``args`` before
    generation, so the values persist after the call.

    Returns:
        Element 0 of ``lm_generation.generate``'s result for the single
        submitted prompt.
    """
    batch = [prompt]
    args.top_k, args.temperature = int(top_k), temperature
    return lm_generation.generate(args, batch)[0]


if __name__ == '__main__':
    # init_model() reads the global `args`, so the arguments must be set up first.
    init_args()
    init_model()

    # UI inputs, in the order of chat()'s parameters: prompt, top_k, temperature.
    prompt_input = "text"
    top_k_slider = gr.Slider(1, 60, value=40, step=1)
    temperature_slider = gr.Slider(0.1, 2.0, value=1.2, step=0.1)

    demo = gr.Interface(
        fn=chat,
        inputs=[prompt_input, top_k_slider, temperature_slider],
        outputs="text",
    )
    demo.launch()