Upload run-chatgpt.py
run-chatgpt.py
ADDED
@@ -0,0 +1,215 @@
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

import numpy as np
import math, os, sys, types, time, gc
import torch
from src.utils import TOKENIZER
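# Usage sketch (the argument value is just an example): running `python run-chatgpt.py 0`
# pins the process to GPU 0 via CUDA_VISIBLE_DEVICES below; with no argument the
# try/except falls through and the default device visibility is kept.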
try:
    os.environ["CUDA_VISIBLE_DEVICES"] = sys.argv[1]
except:
    pass
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
np.set_printoptions(precision=4, suppress=True, linewidth=200)
args = types.SimpleNamespace()

########################################################################################################
# Step 1: set model & config
# Do this first: pip install torchdynamo
########################################################################################################

args.RUN_DEVICE = "cpu"  # 'cpu' (already fast) // 'cuda'
args.FLOAT_MODE = "fp32"  # fp32 (good for CPU) // fp16 (good for GPU, does not work for CPU) // bf16 (less accurate, but works for CPU)

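# Example GPU settings (a hedged sketch, not benchmarked here -- measure fp16 vs fp32 on your own card):
#   args.RUN_DEVICE = "cuda"
#   args.FLOAT_MODE = "fp16"  # fp16 needs a GPU; bf16 also works on CPU but is less accurate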
# if args.RUN_DEVICE == "cuda":
#     os.environ["RWKV_RUN_BACKEND"] = 'nvfuser'  # !!!BUGGY!!! wrong output
os.environ["RWKV_JIT_ON"] = '1'  # '1' or '0'. very useful for GPU/CPU fp32, but might be harmful for GPU fp16. please benchmark !!!

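# Tokenizer note: the two identical entries are [vocab, vocab] for the Pile model, and
# vocab_size is the 50277-token NeoX/Pile vocabulary plus one extra id added by this
# fine-tune (presumably the <|STK_SP|> separator used in the prompt and checked as the
# stop token, vocab_size - 1, in the generation loop below).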
TOKEN_MODE = "pile"
WORD_NAME = [
    "20B_tokenizer_openchatgpt.json",
    "20B_tokenizer_openchatgpt.json",
]  # [vocab, vocab] for Pile model
UNKNOWN_CHAR = None
vocab_size = 50277 + 1

# Download Pile models: https://huggingface.co/BlinkDL
# or, set MODEL_NAME to your fine-tuned model

# MODEL_NAME = "/fsx/BlinkDL/rwkv-release/RWKV-4-Pile-169M-20220807-8023"
# n_layer = 12
# n_embd = 768
# ctx_len = 1024

MODEL_NAME = './out2/rwkv-5'
n_layer = 24
n_embd = 1024
ctx_len = 1024

# MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-1b5/RWKV-4-Pile-1B5-20220903-8040'
# n_layer = 24
# n_embd = 2048
# ctx_len = 1024

# MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-3b/RWKV-4-Pile-3B-20221008-8023'
# n_layer = 32
# n_embd = 2560
# ctx_len = 1024

# MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-7b/RWKV-4-Pile-7B-20221115-8047'
# n_layer = 32
# n_embd = 4096
# ctx_len = 1024

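# n_layer / n_embd / ctx_len must match the checkpoint named by MODEL_NAME, otherwise
# the weights will not load correctly; the commented blocks above list the stock Pile sizes.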
args.MODEL_NAME = MODEL_NAME
args.n_layer = n_layer
args.n_embd = n_embd
args.ctx_len = ctx_len
args.vocab_size = vocab_size
args.head_qk = 0
args.pre_ffn = 0
args.grad_cp = 0
args.my_pos_emb = 0
os.environ["RWKV_RUN_DEVICE"] = args.RUN_DEVICE

########################################################################################################
# Step 2: set prompt & sampling stuffs
########################################################################################################

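# The prompt follows the fine-tune's chat format: [System] / [User] / [Assistant] turns,
# each turn terminated by the <|STK_SP|> special token; generation continues from the
# open [Assistant] turn. Edit the [User] line (e.g. replace "Hi!") to ask something else.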
context = """quality: high

[System]
Assistant is a distilled language model trained by the community.<|STK_SP|>

[System]
<|STK_SP|>

[User]
Hi!<|STK_SP|>

[Assistant]
"""

NUM_TRIALS = 999
LENGTH_PER_TRIAL = 333

TEMPERATURE = 1.0
top_p = 0.8
top_p_newline = 0.9  # only used in TOKEN_MODE = char

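# For reference, a minimal top-p (nucleus) sampling sketch. This is NOT the
# TOKENIZER.sample_logits implementation used below -- just an unused illustration of
# the idea: keep the smallest set of tokens whose cumulative probability reaches top_p,
# renormalize, and sample at the given temperature.
def _top_p_sample_sketch(logits, temperature=1.0, p=0.8):
    probs = torch.softmax(logits.float() / temperature, dim=-1)
    sorted_probs, sorted_ids = torch.sort(probs, descending=True)
    cum = torch.cumsum(sorted_probs, dim=-1)
    keep = sorted_ids[cum - sorted_probs < p]  # tokens whose preceding mass is still < p
    kept_probs = probs[keep] / probs[keep].sum()
    return keep[torch.multinomial(kept_probs, 1)].item()
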
DEBUG_DEBUG = False  # True False --> show softmax output

########################################################################################################

print(f'\nUsing {args.RUN_DEVICE.upper()}. Loading {MODEL_NAME}...')
from src.model_run import RWKV_RNN

model = RWKV_RNN(args)

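# Warm-up: run one forward pass on a single token (187, which this tokenizer decodes to
# '\n' -- see the assert further down), presumably so the JIT / kernels are compiled
# before any timing starts.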
print(f'\nOptimizing speed...')
out, _ = model.forward([187], None)
# print(out)
gc.collect()
torch.cuda.empty_cache()

# input(0)

print(f'\nLoading tokenizer {WORD_NAME}...')
tokenizer = TOKENIZER(WORD_NAME, UNKNOWN_CHAR=UNKNOWN_CHAR)
if TOKEN_MODE == "pile":
    assert tokenizer.tokenizer.decode([187]) == '\n'

########################################################################################################

if tokenizer.charMode:
    context = tokenizer.refine_context(context)
    ctx = [tokenizer.stoi.get(s, tokenizer.UNKNOWN_CHAR) for s in context]
else:
    ctx = tokenizer.tokenizer.encode(context)
src_len = len(ctx)
src_ctx = ctx.copy()

print("\nYour prompt has " + str(src_len) + " tokens.")
print(
    "Note: currently the first run takes a while if your prompt is long, as we are using RNN to preprocess the prompt. Use GPT to build the hidden state for better speed.\n"
)

time_slot = {}
time_ref = time.time_ns()

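# Record the best (minimum) elapsed time seen for each label across trials.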
def record_time(name):
    if name not in time_slot:
        time_slot[name] = 1e20
    tt = (time.time_ns() - time_ref) / 1e9
    if tt < time_slot[name]:
        time_slot[name] = tt

init_state = None
init_out = None
state = None
out = None

for TRIAL in range(1 if DEBUG_DEBUG else NUM_TRIALS):
    print(("-" * 50) + '\n' + context, end="")

    time_ref = time.time_ns()
    ctx = src_ctx.copy()

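    # On the first trial only: feed the prompt token by token to build the RNN hidden
    # state, keeping the final state/logits so later trials can start from a clone.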
    if TRIAL == 0:
        for i in range(src_len):
            x = ctx[: i + 1]
            if i == src_len - 1:
                init_out, init_state = model.forward(x, init_state)
            else:
                init_state = model.forward(x, init_state, preprocess_only=True)
        gc.collect()
        torch.cuda.empty_cache()

    record_time('preprocess')
    out_last = src_len
    for i in range(src_len, src_len + (1 if DEBUG_DEBUG else LENGTH_PER_TRIAL)):
        x = ctx[: i + 1]
        x = x[-ctx_len:]

        if i == src_len:
            out = init_out.clone()
            state = init_state.clone()
        else:
            out, state = model.forward(x, state)
        if DEBUG_DEBUG:
            print("model", np.array(x), "==>", np.array(out), np.max(out.cpu().numpy()), np.min(out.cpu().numpy()))
        if TOKEN_MODE == "pile":
            out[0] = -999999999  # disable <|endoftext|>

        ttt = tokenizer.sample_logits(
            out,
            x,
            ctx_len,
            temperature=TEMPERATURE,
            top_p_usual=top_p,
            top_p_newline=top_p_newline,
        )
        ctx += [ttt]

        if ttt == vocab_size - 1:
            break

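        # Decode incrementally: a partially generated multi-byte UTF-8 sequence decodes
        # to the replacement character '\ufffd', so printing is deferred until the
        # pending tokens form valid text.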
        if tokenizer.charMode:
            char = tokenizer.itos[ttt]
            print(char, end="", flush=True)
        else:
            char = tokenizer.tokenizer.decode(ctx[out_last:])
            if '\ufffd' not in char:
                print(char, end="", flush=True)
                out_last = i+1

    record_time('total')
    # print(f'\n\n{time_slot}\n\n')
    print(
        f"\n\n--- preprocess {round(time_slot['preprocess'], 2)}s, generation {round(time_slot['total']-time_slot['preprocess'], 2)}s ", end = ''
    )

    print(("-" * 50) + '\n')