from src.model_run import RWKV_RNN
import numpy as np
import os, copy, types, gc, sys
import torch
from src.utils import TOKENIZER
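# Minimal command-line chat loop around an RWKV checkpoint running in RNN mode:
# the script keeps a single recurrent state, feeds each user turn into it, and
# streams the bot's reply token by token as it is sampled.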

torch.backends.cudnn.benchmark = False
torch.backends.cudnn.allow_tf32 = False
torch.backends.cuda.matmul.allow_tf32 = False
np.set_printoptions(precision=4, suppress=True, linewidth=200)

WORD_NAME = ["20B_tokenizer.json", "20B_tokenizer.json"]
UNKNOWN_CHAR = None
tokenizer = TOKENIZER(WORD_NAME, UNKNOWN_CHAR=UNKNOWN_CHAR)

args = types.SimpleNamespace()
args.RUN_DEVICE = "cuda"  
args.FLOAT_MODE = "fp32"
args.vocab_size = 50277
args.MODEL_NAME = 'zrwkv-37fifth'
# args.MODEL_NAME = 'zrwkv-23fifth'

# Architecture of the checkpoint above and the context window used when sampling.
args.n_layer = 12
args.n_embd = 768
args.ctx_len = 1024

user = "User"
bot = "Daniel"
interface = ":"
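# Each turn is rendered as "User: ..." / "Daniel: ..." before being fed to the model.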

os.environ["RWKV_RUN_DEVICE"] = args.RUN_DEVICE
MODEL_NAME = args.MODEL_NAME

model = RWKV_RNN(args)

# Global RNN state: model_tokens mirrors every token fed so far, and
# current_state holds the recurrent hidden state returned by model.forward().
model_tokens = []
current_state = None

def run_rnn(tokens, newline_adj = 0):
    global model_tokens, current_state
    for i in range(len(tokens)):
        model_tokens += [int(tokens[i])]
        if i == len(tokens) - 1:
            # Full forward pass on the final token to get the output logits.
            out, current_state = model.forward(model_tokens, current_state)
        else:
            # Earlier tokens only need to update the recurrent state.
            current_state = model.forward(model_tokens, current_state, preprocess_only = True)

    out[0] = -999999999  # disable the <|endoftext|> token
    out[187] += newline_adj  # adjust the probability of '\n'
    return out

# Named snapshots of (logits, RNN state, token history) so the conversation
# can be saved and restored between turns.
all_state = {}
def save_all_stat(name, last_out):
    all_state[name] = {}
    all_state[name]['out'] = last_out
    all_state[name]['rnn'] = copy.deepcopy(current_state)
    all_state[name]['token'] = copy.deepcopy(model_tokens)

def load_all_stat(name):
    global model_tokens, current_state
    current_state = copy.deepcopy(all_state[name]['rnn'])
    model_tokens = copy.deepcopy(all_state[name]['token'])
    return all_state[name]['out']


out = ""
gc.collect()

save_all_stat('chat_init', out)
save_all_stat('chat', out)  # ensure that 'chat' key is added to all_state

# Helper for printing a finished reply with the bot's label (currently unused).
def reply_msg_generator():
    while True:
        msg = yield
        print(f'{bot}{interface} {msg}\n')

def on_message_generator():
    global model_tokens, current_state
    message = yield  # This yield allows us to receive the initial message
    while True:
        msg = message.replace('\\n','\n').strip()
        if len(msg) > 10000:
            # Reject overlong input and wait for the next message instead of processing it.
            yield 'your message is too long (max 10000 characters)\n'
            message = yield
            continue

        out = load_all_stat('chat')  # restore the conversation's RNN state and token history
        new = f"{user}{interface} {msg}\n{bot}{interface}"
        # Feed the user's turn; heavily suppress '\n' so the reply does not end immediately.
        out = run_rnn(tokenizer.tokenizer.encode(new), newline_adj=-999999999)
        save_all_stat('chat_pre', out)

        begin = len(model_tokens)
        out_last = begin
        yield f'{bot}{interface}'  # Yield the bot's prompt immediately
        for i in range(8000):  # hard cap on tokens generated per reply
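            # Draw the next token from the logits (nucleus sampling via the tokenizer helper).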
            token = tokenizer.sample_logits(
                out,
                model_tokens,
                args.ctx_len,
                temperature=1.0,
                top_p_usual=0.85,
                top_p_newline=0.85,
            )
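            # Feed the sampled token back in; a small positive newline_adj nudges '\n' upward so the reply can end.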
            out = run_rnn([token], newline_adj=1)

            xxx = tokenizer.tokenizer.decode(model_tokens[out_last:])
            # Stream the chunk only if it decodes cleanly and contains no role
            # markers or newlines; either way, advance past it.
            if '\ufffd' not in xxx and 'user' not in xxx.lower() and '\n' not in xxx and xxx != ':' and xxx != '\n\n' and len(xxx) > 0:
                yield xxx  # Yield each part of the response as soon as it's ready
            out_last = begin + i + 1

            send_msg = tokenizer.tokenizer.decode(model_tokens[begin:])
            # Stop once the model starts a new turn, emits a newline, or produces an
            # undecodable character, then strip the role markers from the reply.
            if '\ufffd' in send_msg or send_msg.endswith(f'{user}{interface}') or send_msg.endswith(f'{bot}{interface}') or '\n' in send_msg:
                send_msg = send_msg.strip()
                send_msg = send_msg.replace(f'{user}{interface}', '')
                send_msg = send_msg.replace(f'{bot}{interface}', '')
                send_msg = send_msg.replace('\n', '')
                break
        save_all_stat('chat', out)
        yield '\n'  # Yield a newline at the end of the response
        message = yield  # Get the next message

print('Start chatting with Daniel! Pretend to pick up the phone.')

on_message_gen = on_message_generator()
next_message = next(on_message_gen)  # Start the generator
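# Driver loop: a None yield from the generator means it wants the next user
# message; any other yielded value is a chunk of the reply to print.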
while True:
    if next_message is None:  # If the generator is ready for a new message
        msg = input(f'{user}{interface} ')
        if len(msg.strip()) > 0:
            next_message = on_message_gen.send(msg)  # Send the message to the generator and receive the next yield
        else:
            print('Error: please say something')
    else:  # If the generator has yielded part of the response
        print(next_message, end='', flush=True)
        next_message = next(on_message_gen)  # Get the next part of the response