# File size: 4,735 Bytes
# Viewer header residue: 672cd49 908ce46 672cd49 a901e5b 672cd49 7b5fd9a 672cd49
# (Line-number gutter 1-137 from the original file viewer omitted.)
from src.model_run import RWKV_RNN
import numpy as np
import os, copy, types, gc, sys
import torch
from src.utils import TOKENIZER
# --- Runtime configuration -------------------------------------------------
# Turn off the cuDNN autotuner and TF32 math so results are computed in plain
# fp32 (matches FLOAT_MODE below).
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.allow_tf32 = False
torch.backends.cuda.matmul.allow_tf32 = False
np.set_printoptions(precision=4, suppress=True, linewidth=200)

# Tokenizer definition file(s) handed to TOKENIZER.
WORD_NAME = ["20B_tokenizer.json", "20B_tokenizer.json"]
# The chat code calls tokenizer.tokenizer.encode/decode and
# tokenizer.sample_logits, so the tokenizer object has to exist at module
# level before the first message is processed.
# NOTE(review): constructor arguments are assumed from the otherwise unused
# WORD_NAME constant -- confirm against src.utils.TOKENIZER.
tokenizer = TOKENIZER(WORD_NAME)

# Settings object passed to RWKV_RNN.
args = types.SimpleNamespace()
args.RUN_DEVICE = "cuda"
args.FLOAT_MODE = "fp32"
args.vocab_size = 50277
args.MODEL_NAME = 'zrwkv-37fifth'
# args.MODEL_NAME = 'zrwkv-23fifth'
args.n_layer = 12
args.n_embd = 768
args.ctx_len = 1024

# Speaker labels used to build prompts and to print the conversation.
user = "User"
bot = "Daniel"
interface = ":"

# Exported before the model is constructed.
# NOTE(review): presumably read by src.model_run when the model is built --
# confirm.
os.environ["RWKV_RUN_DEVICE"] = args.RUN_DEVICE
model = RWKV_RNN(args)

# Global conversation state, updated by run_rnn() and load_all_stat().
model_tokens = []      # every token id fed to the model so far
current_state = None   # RNN state returned by model.forward()
def run_rnn(tokens, newline_adj=0):
    """Feed ``tokens`` to the model one by one and return the final output.

    Every token is appended to the global ``model_tokens`` history and the
    global ``current_state`` is replaced by the state the model returns.
    All tokens except the last are run with ``preprocess_only=True`` (only
    the new state comes back); the last token is run normally so that its
    output can be returned.

    Args:
        tokens: Non-empty sequence of token ids; each is coerced with int().
        newline_adj: Amount added to output entry 187 before returning.
            NOTE(review): 187 is presumably the newline token id of the
            tokenizer in use -- confirm.

    Returns:
        The model output for the last token, with entry 0 forced to
        -999999999 and entry 187 shifted by ``newline_adj``.

    Raises:
        ValueError: If ``tokens`` is empty, since no output would exist.
    """
    global model_tokens, current_state
    if len(tokens) == 0:
        raise ValueError("run_rnn() needs at least one token")

    last_index = len(tokens) - 1
    for index, token in enumerate(tokens):
        model_tokens.append(int(token))
        if index == last_index:
            # Only the final step needs the model output.
            out, current_state = model.forward(model_tokens, current_state)
        else:
            # Earlier tokens only advance the state; running them through
            # the full forward as well would apply the same step twice.
            current_state = model.forward(
                model_tokens, current_state, preprocess_only=True
            )

    # NOTE(review): entry 0 is presumably the end-of-text token; the large
    # negative value keeps it from being sampled -- confirm.
    out[0] = -999999999
    out[187] += newline_adj
    return out
# Named snapshots of the chat: maps a snapshot name to a dict holding the
# last model output ('out'), the RNN state ('rnn') and the token history
# ('token'). Written by save_all_stat() and read by load_all_stat().
all_state = {}
def save_all_stat(name, last_out):
    """Store a snapshot of the conversation under ``name``.

    The snapshot keeps ``last_out`` as given (not copied) together with deep
    copies of the global ``current_state`` and ``model_tokens``, so later
    changes to those globals do not affect the stored snapshot.
    """
    snapshot = {}
    all_state[name] = snapshot
    snapshot['out'] = last_out
    snapshot['rnn'] = copy.deepcopy(current_state)
    snapshot['token'] = copy.deepcopy(model_tokens)
def load_all_stat(name):
    """Restore the snapshot stored under ``name`` and return its output.

    The global ``current_state`` and ``model_tokens`` are replaced by deep
    copies of the stored values, so the snapshot itself stays reusable.
    Raises KeyError if no snapshot with that name exists.
    """
    global model_tokens, current_state
    snapshot = all_state[name]
    current_state = copy.deepcopy(snapshot['rnn'])
    model_tokens = copy.deepcopy(snapshot['token'])
    return snapshot['out']
# Seed the snapshot table before any model output exists; an empty string
# stands in for the "last output". The 'chat' entry must be present because
# on_message_generator() loads it for every incoming message.
out = ""
save_all_stat('chat_init', out)
save_all_stat('chat', out) # ensure that 'chat' key is added to all_state
def reply_msg_generator():
    """Endless generator that prints every value sent into it as a bot line.

    Prime it with next(), then call send(text); each text is printed prefixed
    with the bot label and followed by an extra blank line.
    """
    while True:
        text = yield
        line = f'{bot}{interface} {text}\n'
        print(line)
# Generator that handles one chat turn per received message and streams the
# reply back in pieces.
#
# Protocol as used by the visible code: the caller primes the generator,
# passes the user's text in with send(), and then pulls the yielded strings.
# A bare `yield` (value None) marks the points where the generator waits for
# the next message. `tokenizer` is expected to be a module-level object.
#
# NOTE(review): this block reached review with its indentation stripped and
# with at least one statement cut short (the `tokenizer.sample_logits(` call
# has no arguments and no closing parenthesis), so the nesting of the
# statements below and any missing `else`/`break`/`continue` lines cannot be
# confirmed from here. The code is left untouched; the comments describe only
# what each visible statement does.
def on_message_generator():
global model_tokens, current_state
message = yield # This yield allows us to receive the initial message
while True:
# Turn the two-character sequence backslash + n into a real newline, then
# trim surrounding whitespace.
msg = message.replace('\\n','\n').strip()
# NOTE(review): the guard compares the character count against 10000 while
# the notice says "max 1000 tokens" -- confirm which limit is intended.
# NOTE(review): `message` is rebound to whatever the caller passes when it
# resumes this yield; a plain next() passes None, which the
# `message.replace` call above could not handle -- confirm the intended flow.
if len(msg) > 10000:
message = yield 'your message is too long (max 1000 tokens)'
# Restore the state saved under 'chat' (the end of the previous turn).
out = load_all_stat('chat')
# Prompt: the user's line followed by the bot label, so the model continues
# as the bot.
new = f"{user}{interface} {msg}\n{bot}{interface}"
# newline_adj is added to output entry 187 by run_rnn(); the large negative
# value pushes that entry down for the first sampled token.
out = run_rnn(tokenizer.tokenizer.encode(new), newline_adj=-999999999)
save_all_stat('chat_pre', out)
# `begin` is the history length where the reply starts; `out_last` is the
# index of the first reply token that has not been yielded yet.
begin = len(model_tokens)
out_last = begin
yield f'{bot}{interface}' # Yield the bot's prompt immediately
# Sample at most 8000 tokens for one reply.
for i in range(8000):
# NOTE(review): the argument list of this call is missing from the source.
token = tokenizer.sample_logits(
out = run_rnn([token], newline_adj=1)
# Decode only the tokens that have not been yielded yet.
xxx = tokenizer.tokenizer.decode(model_tokens[out_last:])
# Emit the pending text only when it has no U+FFFD replacement character,
# no "user" (case-insensitive), no newline, is not a lone ':' and is
# non-empty.
if '\ufffd' not in xxx and 'user' not in str(xxx).lower() and '\n' not in xxx and str(xxx) != ':' and str(xxx) != '\n\n' and len(str(xxx)) > 0:
yield xxx # Yield each part of the response as soon as it's ready
out_last = begin + i + 1
# NOTE(review): this repeats the previous statement; without indentation it
# is unclear whether it belongs to a different branch -- confirm.
out_last = begin + i + 1
# Decode the whole reply generated so far.
send_msg = tokenizer.tokenizer.decode(model_tokens[begin:])
# When the reply contains U+FFFD or a newline, or ends with a speaker label,
# strip it and remove the labels and newlines.
# NOTE(review): `send_msg` is not yielded or stored anywhere in the visible
# code, and no loop exit is visible after this check -- confirm.
if '\ufffd' in send_msg or send_msg.endswith(f'{user}{interface}') or send_msg.endswith(f'{bot}{interface}') or '\n' in send_msg:
send_msg = send_msg.strip()
send_msg = send_msg.replace(f'{user}{interface}', '')
send_msg = send_msg.replace(f'{bot}{interface}', '')
send_msg = send_msg.replace('\n', '')
# Store the state after this turn so that the next message continues from it.
save_all_stat('chat', out)
yield '\n' # Yield a newline at the end of the response
message = yield # Get the next message
# Console front end for on_message_generator().
#
# The generator yields None when it is waiting for a message and yields a
# string for every piece of output, so the loop alternates between reading a
# line from the user and printing whatever the generator produced.
print(f'Start chatting with {bot}! Pretend to pick up the phone.')
on_message_gen = on_message_generator()
next_message = next(on_message_gen)  # Start the generator
while True:
    if next_message is None:  # If the generator is ready for a new message
        msg = input(f'{user}{interface} ')
        if msg.strip():
            # Send the message to the generator and receive the next yield
            next_message = on_message_gen.send(msg)
        else:
            # Blank input: report it and ask again without touching the
            # generator.
            print('Error: please say something')
    else:  # If the generator has yielded part of the response
        print(next_message, end='', flush=True)
        next_message = next(on_message_gen)  # Get the next part of the response