File size: 3,462 Bytes
7c071a8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
# coding=utf-8

import ctypes


class TokenWord(ctypes.Structure):
    """ctypes structure pairing an integer ``token`` with a byte-string ``word``.

    ``word`` is a fixed-size ``char`` array of 2048 bytes; change the length
    here if a different maximum word size is required.
    """

    _fields_ = [("token", ctypes.c_int), ("word", ctypes.c_char * 2048)]


class TPUChatglm:
    """Thin ctypes wrapper around the ``Baichuan2_*`` functions of a shared library.

    On construction the library is loaded, the signatures of the native
    functions are declared, and a native model object is created; its opaque
    handle is kept in ``self.obj`` and passed to every subsequent native call.
    """

    # Strings returned by the native token functions on which generation stops.
    _STOP_TOKENS = ('_GETMAX_', '_GETEOS_')

    def __init__(self, device_id=3,
                 bmodel_path="../model/baichuan2-7b-test_int8.bmodel",
                 token_path="../model/tokenizer.model",
                 lib_path='./build/libtpuchat.so'):
        """Load the shared library and create the native model object.

        Args:
            device_id: Integer device id passed to the native constructor.
            bmodel_path: Path of the ``.bmodel`` file passed to the native
                constructor.
            token_path: Path of the tokenizer model passed to the native
                constructor.
            lib_path: Path of the shared library to load.

        Raises:
            OSError: If the shared library cannot be loaded.
            RuntimeError: If the native constructor returns a NULL handle.
        """
        self.lib = ctypes.cdll.LoadLibrary(lib_path)
        self.device_id = device_id
        self.bmodel_path = bmodel_path
        self.token_path = token_path
        self.libset()
        self.init()

    def libset(self):
        """Declare argument and return types of the native functions."""
        lib = self.lib

        # Constructor. The return type must be declared as a pointer:
        # ctypes defaults to a C int, which would truncate a 64-bit handle.
        lib.Baichuan2_with_devid_and_model.argtypes = [ctypes.c_int, ctypes.c_char_p, ctypes.c_char_p]
        lib.Baichuan2_with_devid_and_model.restype = ctypes.c_void_p

        lib.Baichuan2_delete.argtypes = [ctypes.c_void_p]

        # deinit
        lib.Baichuan2_deinit.argtypes = [ctypes.c_void_p]

        # Baichuan2_predict_first_token
        lib.Baichuan2_predict_first_token.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
        lib.Baichuan2_predict_first_token.restype = ctypes.c_char_p

        # Baichuan2_predict_next_token
        lib.Baichuan2_predict_next_token.argtypes = [ctypes.c_void_p]
        lib.Baichuan2_predict_next_token.restype = ctypes.c_char_p

        # get_eos
        lib.get_eos.argtypes = [ctypes.c_void_p]
        lib.get_eos.restype = ctypes.c_int
        # get_history
        lib.get_history.argtypes = [ctypes.c_void_p]
        lib.get_history.restype = ctypes.c_char_p
        # set_history
        lib.set_history.argtypes = [ctypes.c_void_p, ctypes.c_char_p]

    def init(self):
        """Create the native model object and store its handle in ``self.obj``.

        Raises:
            RuntimeError: If the native constructor returns a NULL handle.
        """
        self.obj = self.lib.Baichuan2_with_devid_and_model(
            self.device_id,
            self.bmodel_path.encode('utf-8'),
            self.token_path.encode('utf-8'),
        )
        # A ``c_void_p`` return value of NULL is converted to None by ctypes.
        # Fail here rather than handing a NULL pointer to later native calls.
        if not self.obj:
            raise RuntimeError(
                "Baichuan2_with_devid_and_model returned a NULL handle "
                "(device_id={!r}, bmodel_path={!r}, token_path={!r})".format(
                    self.device_id, self.bmodel_path, self.token_path))

    def predict_first_token(self, context):
        """Send ``context`` to the native library and return the first token as ``str``."""
        raw = self.lib.Baichuan2_predict_first_token(self.obj, context.encode('utf-8'))
        return raw.decode('utf-8')

    def predict_next_token(self):
        """Return the next token from the native library as ``str``."""
        return self.lib.Baichuan2_predict_next_token(self.obj).decode('utf-8')

    def _iter_tokens(self, prompt):
        """Yield tokens for ``prompt`` one by one until a stop string is returned.

        The first token comes from ``predict_first_token`` and the rest from
        ``predict_next_token``. The stop string itself is never yielded.
        """
        token = self.predict_first_token(prompt)
        while token not in self._STOP_TOKENS:
            yield token
            token = self.predict_next_token()

    def predict(self, context):
        """Generate the complete response for ``context``.

        Args:
            context: Prompt text, passed to the native library unchanged.

        Returns:
            The concatenation of all generated tokens, including the first
            one, up to but excluding the stop string.
        """
        return ''.join(self._iter_tokens(context))

    def stream_predict(self, query, history):
        """Generate the response for ``query`` incrementally.

        A ``(query, '')`` entry is appended to ``history`` when iteration
        starts; after every generated token (including the first) that entry
        is replaced by ``(query, response_so_far)``. Earlier ``history``
        entries are not included in the prompt.

        Args:
            query: User text; wrapped in ``<reserved_106>``/``<reserved_107>``
                before being sent to the native library.
            history: List of ``(query, response)`` tuples; mutated in place.

        Yields:
            ``(response_so_far, history)`` after each generated token, where
            ``history`` is the same list object that was passed in.
        """
        history.append((query, ''))

        # NOTE(review): the reserved tokens are presumably the model's
        # user/assistant role markers - confirm against the tokenizer.
        prompt = "<reserved_106>" + query + "<reserved_107>"

        res = ''
        for token in self._iter_tokens(prompt):
            res += token
            history[-1] = (query, res)
            yield res, history

    def get_config(self):
        """Placeholder; currently does nothing and returns ``None``."""
        pass