Upload 18 files
Browse files- 20B_tokenizer.json +0 -0
- chat_kivy.py +130 -0
- cuda/wkv_cuda.cu +133 -0
- cuda/wkv_cuda_bf16.cu +132 -0
- cuda/wkv_op.cpp +21 -0
- cuda/wkv_op_bf16.cpp +25 -0
- run.py +223 -0
- src/__init__.py +0 -0
- src/binidx.py +269 -0
- src/dataset.py +245 -0
- src/model.py +610 -0
- src/model_img.py +446 -0
- src/model_run.py +233 -0
- src/trainer.py +192 -0
- src/utils.py +130 -0
- train.py +350 -0
- verify.py +104 -0
- zrwkv-37fifth.pth +3 -0
20B_tokenizer.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
chat_kivy.py
ADDED
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
print('Loading...')
|
2 |
+
from src.model_run import RWKV_RNN
|
3 |
+
import numpy as np
|
4 |
+
import os, copy, types, gc, sys
|
5 |
+
import torch
|
6 |
+
from src.utils import TOKENIZER
|
7 |
+
|
8 |
+
torch.backends.cudnn.benchmark = False
|
9 |
+
torch.backends.cudnn.allow_tf32 = False
|
10 |
+
torch.backends.cuda.matmul.allow_tf32 = False
|
11 |
+
np.set_printoptions(precision=4, suppress=True, linewidth=200)
|
12 |
+
|
13 |
+
WORD_NAME = ["20B_tokenizer.json", "20B_tokenizer.json"]
|
14 |
+
UNKNOWN_CHAR = None
|
15 |
+
tokenizer = TOKENIZER(WORD_NAME, UNKNOWN_CHAR=UNKNOWN_CHAR)
|
16 |
+
|
17 |
+
args = types.SimpleNamespace()
|
18 |
+
args.RUN_DEVICE = "cpu"
|
19 |
+
args.FLOAT_MODE = "fp32"
|
20 |
+
args.vocab_size = 50277
|
21 |
+
args.MODEL_NAME = 'zrwkv-37fifth'
|
22 |
+
args.n_layer = 12
|
23 |
+
args.n_embd = 768
|
24 |
+
args.ctx_len = 1024
|
25 |
+
|
26 |
+
user = "User"
|
27 |
+
bot = "Daniel"
|
28 |
+
interface = ":"
|
29 |
+
|
30 |
+
os.environ["RWKV_RUN_DEVICE"] = args.RUN_DEVICE
|
31 |
+
MODEL_NAME = args.MODEL_NAME
|
32 |
+
|
33 |
+
print(f'loading... {MODEL_NAME}')
|
34 |
+
model = RWKV_RNN(args)
|
35 |
+
|
36 |
+
model_tokens = []
|
37 |
+
current_state = None
|
38 |
+
|
39 |
+
def run_rnn(tokens, newline_adj = 0):
|
40 |
+
global model_tokens, current_state
|
41 |
+
for i in range(len(tokens)):
|
42 |
+
model_tokens += [int(tokens[i])]
|
43 |
+
if i == len(tokens) - 1:
|
44 |
+
out, current_state = model.forward(model_tokens, current_state)
|
45 |
+
else:
|
46 |
+
current_state = model.forward(model_tokens, current_state, preprocess_only = True)
|
47 |
+
|
48 |
+
out[0] = -999999999
|
49 |
+
out[187] += newline_adj
|
50 |
+
return out
|
51 |
+
|
52 |
+
all_state = {}
|
53 |
+
def save_all_stat(name, last_out):
|
54 |
+
all_state[name] = {}
|
55 |
+
all_state[name]['out'] = last_out
|
56 |
+
all_state[name]['rnn'] = copy.deepcopy(current_state)
|
57 |
+
all_state[name]['token'] = copy.deepcopy(model_tokens)
|
58 |
+
|
59 |
+
def load_all_stat(name):
|
60 |
+
global model_tokens, current_state
|
61 |
+
current_state = copy.deepcopy(all_state[name]['rnn'])
|
62 |
+
model_tokens = copy.deepcopy(all_state[name]['token'])
|
63 |
+
return all_state[name]['out']
|
64 |
+
|
65 |
+
print(f'\nRun prompt...')
|
66 |
+
|
67 |
+
out = ""
|
68 |
+
gc.collect()
|
69 |
+
|
70 |
+
save_all_stat('chat_init', out)
|
71 |
+
save_all_stat('chat', out) # ensure that 'chat' key is added to all_state
|
72 |
+
|
73 |
+
print(f'### prompt ###\n[{tokenizer.tokenizer.decode(model_tokens)}]\n')
|
74 |
+
|
75 |
+
|
76 |
+
def reply_msg(msg):
|
77 |
+
print(f'{bot}{interface} {msg}\n')
|
78 |
+
|
79 |
+
def on_message(message):
|
80 |
+
global model_tokens, current_state
|
81 |
+
|
82 |
+
msg = message.replace('\\n','\n').strip()
|
83 |
+
if len(msg) > 10000:
|
84 |
+
reply_msg('your message is too long (max 1000 tokens)')
|
85 |
+
return
|
86 |
+
|
87 |
+
out = load_all_stat('chat')
|
88 |
+
new = f"{user}{interface} {msg}\n{bot}{interface}"
|
89 |
+
out = run_rnn(tokenizer.tokenizer.encode(new), newline_adj=-999999999)
|
90 |
+
save_all_stat('chat_pre', out)
|
91 |
+
|
92 |
+
begin = len(model_tokens)
|
93 |
+
out_last = begin
|
94 |
+
print(f'{bot}{interface}', end='', flush=True)
|
95 |
+
for i in range(8000):
|
96 |
+
token = tokenizer.sample_logits(
|
97 |
+
out,
|
98 |
+
model_tokens,
|
99 |
+
args.ctx_len,
|
100 |
+
temperature=1.0,
|
101 |
+
top_p_usual=0.85,
|
102 |
+
top_p_newline=0.85,
|
103 |
+
)
|
104 |
+
out = run_rnn([token], newline_adj=1)
|
105 |
+
|
106 |
+
xxx = tokenizer.tokenizer.decode(model_tokens[out_last:])
|
107 |
+
if '\ufffd' not in xxx and 'user' not in str(xxx).lower() and '\n' not in xxx and str(xxx) != ':' and str(xxx) != '\n\n' and len(str(xxx)) > 0:
|
108 |
+
print(xxx, end='', flush=True)
|
109 |
+
out_last = begin + i + 1
|
110 |
+
else:
|
111 |
+
print('\n', end='', flush=True)
|
112 |
+
out_last = begin + i + 1
|
113 |
+
|
114 |
+
send_msg = tokenizer.tokenizer.decode(model_tokens[begin:])
|
115 |
+
if '\ufffd' in send_msg or send_msg.endswith(f'{user}{interface}') or send_msg.endswith(f'{bot}{interface}') or '\n' in send_msg:
|
116 |
+
send_msg = send_msg.strip()
|
117 |
+
send_msg = send_msg.replace(f'{user}{interface}', '')
|
118 |
+
send_msg = send_msg.replace(f'{bot}{interface}', '')
|
119 |
+
send_msg = send_msg.replace('\n', '')
|
120 |
+
break
|
121 |
+
save_all_stat('chat', out)
|
122 |
+
|
123 |
+
print('Start chatting with Daniel!')
|
124 |
+
|
125 |
+
while True:
|
126 |
+
msg = input(f'{user}{interface} ')
|
127 |
+
if len(msg.strip()) > 0:
|
128 |
+
on_message(msg)
|
129 |
+
else:
|
130 |
+
print('Error: please say something')
|
cuda/wkv_cuda.cu
ADDED
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#include <stdio.h>
|
2 |
+
#include <assert.h>
|
3 |
+
|
4 |
+
#define MIN_VALUE (-1e38)
|
5 |
+
|
6 |
+
template <typename F>
|
7 |
+
__global__ void kernel_forward(const int B, const int T, const int C,
|
8 |
+
const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v,
|
9 |
+
F *__restrict__ const _y) {
|
10 |
+
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
11 |
+
const int _b = idx / C;
|
12 |
+
const int _c = idx % C;
|
13 |
+
const int _offset = _b * T * C + _c;
|
14 |
+
|
15 |
+
F u = _u[_c];
|
16 |
+
F w = _w[_c];
|
17 |
+
const F *__restrict__ const k = _k + _offset;
|
18 |
+
const F *__restrict__ const v = _v + _offset;
|
19 |
+
F *__restrict__ const y = _y + _offset;
|
20 |
+
|
21 |
+
// aa and bb are running sums divided by exp(pp) (to avoid overflow)
|
22 |
+
F aa = 0, bb = 0, pp = MIN_VALUE;
|
23 |
+
for (int i = 0; i < T; i++) {
|
24 |
+
const int ii = i * C;
|
25 |
+
const F kk = k[ii];
|
26 |
+
const F vv = v[ii];
|
27 |
+
|
28 |
+
F ww = u + kk;
|
29 |
+
F p = max(pp, ww);
|
30 |
+
F e1 = exp(pp - p);
|
31 |
+
F e2 = exp(ww - p);
|
32 |
+
y[ii] = (e1 * aa + e2 * vv) / (e1 * bb + e2);
|
33 |
+
|
34 |
+
ww = w + pp;
|
35 |
+
p = max(ww, kk);
|
36 |
+
e1 = exp(ww - p);
|
37 |
+
e2 = exp(kk - p);
|
38 |
+
aa = e1 * aa + e2 * vv;
|
39 |
+
bb = e1 * bb + e2;
|
40 |
+
pp = p;
|
41 |
+
}
|
42 |
+
}
|
43 |
+
|
44 |
+
template <typename F>
|
45 |
+
__global__ void kernel_backward(const int B, const int T, const int C,
|
46 |
+
const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v,
|
47 |
+
const F *__restrict__ const _y, const F *__restrict__ const _gy,
|
48 |
+
F *__restrict__ const _gw, F *__restrict__ const _gu, F *__restrict__ const _gk, F *__restrict__ const _gv) {
|
49 |
+
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
50 |
+
const int _b = idx / C;
|
51 |
+
const int _c = idx % C;
|
52 |
+
const int _offset = _b * T * C + _c;
|
53 |
+
|
54 |
+
F u = _u[_c];
|
55 |
+
F w = _w[_c];
|
56 |
+
const F *__restrict__ const k = _k + _offset;
|
57 |
+
const F *__restrict__ const v = _v + _offset;
|
58 |
+
const F *__restrict__ const y = _y + _offset;
|
59 |
+
const F *__restrict__ const gy = _gy + _offset;
|
60 |
+
F *__restrict__ const gk = _gk + _offset;
|
61 |
+
F *__restrict__ const gv = _gv + _offset;
|
62 |
+
|
63 |
+
F q[Tmax], r[Tmax];
|
64 |
+
|
65 |
+
F gw = 0, gu = 0, aa = 0, bb = 0, ga = 0, gb = 0, pp = MIN_VALUE;
|
66 |
+
for (int i = 0; i < T; i++) {
|
67 |
+
const int ii = i * C;
|
68 |
+
const F kk = k[ii];
|
69 |
+
const F vv = v[ii];
|
70 |
+
const F yy = y[ii];
|
71 |
+
|
72 |
+
F ww = u + kk;
|
73 |
+
F p = max(pp, ww);
|
74 |
+
F e1 = exp(pp - p);
|
75 |
+
F e2 = exp(ww - p);
|
76 |
+
const F qq = gy[ii] / (e1 * bb + e2);
|
77 |
+
gw += (ga - gb * yy) * e1 * qq;
|
78 |
+
gu += (vv - yy) * e2 * qq;
|
79 |
+
q[i] = qq;
|
80 |
+
r[i] = ww - p;
|
81 |
+
|
82 |
+
ww = w + pp;
|
83 |
+
p = max(ww, kk);
|
84 |
+
e1 = exp(ww - p);
|
85 |
+
e2 = exp(kk - p);
|
86 |
+
ga = e1 * (aa + ga);
|
87 |
+
gb = e1 * (bb + gb);
|
88 |
+
aa = e1 * aa + e2 * vv;
|
89 |
+
bb = e1 * bb + e2;
|
90 |
+
pp = p;
|
91 |
+
}
|
92 |
+
const int _offsetBC = _b * C + _c;
|
93 |
+
_gw[_offsetBC] = gw * _w[_c]; // multiply by w because of w -> -exp(w) in python forward()
|
94 |
+
_gu[_offsetBC] = gu;
|
95 |
+
|
96 |
+
aa = 0, bb = 0, pp = MIN_VALUE;
|
97 |
+
for (int i = T - 1; i >= 0; i--) {
|
98 |
+
const int ii = i * C;
|
99 |
+
const F kk = k[ii];
|
100 |
+
const F vv = v[ii];
|
101 |
+
const F yy = y[ii];
|
102 |
+
const F qq = q[i];
|
103 |
+
const F rr = r[i];
|
104 |
+
|
105 |
+
F e1 = qq * exp(rr);
|
106 |
+
F e2 = exp(kk + pp);
|
107 |
+
gk[ii] = e1 * (vv - yy) + e2 * (aa * vv + bb);
|
108 |
+
gv[ii] = e1 + e2 * aa;
|
109 |
+
|
110 |
+
const F ww = w + pp;
|
111 |
+
const F www = rr - u - kk;
|
112 |
+
const F p = max(ww, www);
|
113 |
+
e1 = exp(ww - p);
|
114 |
+
e2 = qq * exp(www - p);
|
115 |
+
aa = e1 * aa + e2;
|
116 |
+
bb = e1 * bb - e2 * yy;
|
117 |
+
pp = p;
|
118 |
+
}
|
119 |
+
}
|
120 |
+
|
121 |
+
void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y) {
|
122 |
+
dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
|
123 |
+
assert(B * C % threadsPerBlock.x == 0);
|
124 |
+
dim3 numBlocks(B * C / threadsPerBlock.x);
|
125 |
+
kernel_forward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y);
|
126 |
+
}
|
127 |
+
|
128 |
+
void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y, float *gy, float *gw, float *gu, float *gk, float *gv) {
|
129 |
+
dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
|
130 |
+
assert(B * C % threadsPerBlock.x == 0);
|
131 |
+
dim3 numBlocks(B * C / threadsPerBlock.x);
|
132 |
+
kernel_backward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y, gy, gw, gu, gk, gv);
|
133 |
+
}
|
cuda/wkv_cuda_bf16.cu
ADDED
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#include <stdio.h>
|
2 |
+
#include <assert.h>
|
3 |
+
#include "ATen/ATen.h"
|
4 |
+
#define MIN_VALUE (-1e38)
|
5 |
+
typedef at::BFloat16 bf16;
|
6 |
+
|
7 |
+
__global__ void kernel_forward(const int B, const int T, const int C,
|
8 |
+
const float *__restrict__ const _w, const bf16 *__restrict__ const _u, const bf16 *__restrict__ const _k, const bf16 *__restrict__ const _v,
|
9 |
+
bf16 *__restrict__ const _y) {
|
10 |
+
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
11 |
+
const int _b = idx / C;
|
12 |
+
const int _c = idx % C;
|
13 |
+
const int _offset = _b * T * C + _c;
|
14 |
+
|
15 |
+
float u = float(_u[_c]);
|
16 |
+
float w = _w[_c];
|
17 |
+
const bf16 *__restrict__ const k = _k + _offset;
|
18 |
+
const bf16 *__restrict__ const v = _v + _offset;
|
19 |
+
bf16 *__restrict__ const y = _y + _offset;
|
20 |
+
|
21 |
+
// aa and bb are running sums divided by exp(pp) (to avoid overflow)
|
22 |
+
float aa = 0, bb = 0, pp = MIN_VALUE;
|
23 |
+
for (int i = 0; i < T; i++) {
|
24 |
+
const int ii = i * C;
|
25 |
+
const float kk = float(k[ii]);
|
26 |
+
const float vv = float(v[ii]);
|
27 |
+
|
28 |
+
float ww = u + kk;
|
29 |
+
float p = max(pp, ww);
|
30 |
+
float e1 = exp(pp - p);
|
31 |
+
float e2 = exp(ww - p);
|
32 |
+
y[ii] = bf16((e1 * aa + e2 * vv) / (e1 * bb + e2));
|
33 |
+
|
34 |
+
ww = w + pp;
|
35 |
+
p = max(ww, kk);
|
36 |
+
e1 = exp(ww - p);
|
37 |
+
e2 = exp(kk - p);
|
38 |
+
aa = e1 * aa + e2 * vv;
|
39 |
+
bb = e1 * bb + e2;
|
40 |
+
pp = p;
|
41 |
+
}
|
42 |
+
}
|
43 |
+
|
44 |
+
__global__ void kernel_backward(const int B, const int T, const int C,
|
45 |
+
const float *__restrict__ const _w, const bf16 *__restrict__ const _u, const bf16 *__restrict__ const _k, const bf16 *__restrict__ const _v,
|
46 |
+
const bf16 *__restrict__ const _y, const bf16 *__restrict__ const _gy,
|
47 |
+
bf16 *__restrict__ const _gw, bf16 *__restrict__ const _gu, bf16 *__restrict__ const _gk, bf16 *__restrict__ const _gv) {
|
48 |
+
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
49 |
+
const int _b = idx / C;
|
50 |
+
const int _c = idx % C;
|
51 |
+
const int _offset = _b * T * C + _c;
|
52 |
+
|
53 |
+
float u = float(_u[_c]);
|
54 |
+
float w = _w[_c];
|
55 |
+
const bf16 *__restrict__ const k = _k + _offset;
|
56 |
+
const bf16 *__restrict__ const v = _v + _offset;
|
57 |
+
const bf16 *__restrict__ const y = _y + _offset;
|
58 |
+
const bf16 *__restrict__ const gy = _gy + _offset;
|
59 |
+
bf16 *__restrict__ const gk = _gk + _offset;
|
60 |
+
bf16 *__restrict__ const gv = _gv + _offset;
|
61 |
+
|
62 |
+
float q[Tmax], r[Tmax];
|
63 |
+
|
64 |
+
float gw = 0, gu = 0, aa = 0, bb = 0, ga = 0, gb = 0, pp = MIN_VALUE;
|
65 |
+
for (int i = 0; i < T; i++) {
|
66 |
+
const int ii = i * C;
|
67 |
+
const float kk = float(k[ii]);
|
68 |
+
const float vv = float(v[ii]);
|
69 |
+
const float yy = float(y[ii]);
|
70 |
+
|
71 |
+
float ww = u + kk;
|
72 |
+
float p = max(pp, ww);
|
73 |
+
float e1 = exp(pp - p);
|
74 |
+
float e2 = exp(ww - p);
|
75 |
+
const float qq = float(gy[ii]) / (e1 * bb + e2);
|
76 |
+
gw += (ga - gb * yy) * e1 * qq;
|
77 |
+
gu += (vv - yy) * e2 * qq;
|
78 |
+
q[i] = qq;
|
79 |
+
r[i] = ww - p;
|
80 |
+
|
81 |
+
ww = w + pp;
|
82 |
+
p = max(ww, kk);
|
83 |
+
e1 = exp(ww - p);
|
84 |
+
e2 = exp(kk - p);
|
85 |
+
ga = e1 * (aa + ga);
|
86 |
+
gb = e1 * (bb + gb);
|
87 |
+
aa = e1 * aa + e2 * vv;
|
88 |
+
bb = e1 * bb + e2;
|
89 |
+
pp = p;
|
90 |
+
}
|
91 |
+
const int _offsetBC = _b * C + _c;
|
92 |
+
_gw[_offsetBC] = bf16(gw * _w[_c]); // multiply by w because of w -> -exp(w) in python forward()
|
93 |
+
_gu[_offsetBC] = bf16(gu);
|
94 |
+
|
95 |
+
aa = 0, bb = 0, pp = MIN_VALUE;
|
96 |
+
for (int i = T - 1; i >= 0; i--) {
|
97 |
+
const int ii = i * C;
|
98 |
+
const float kk = float(k[ii]);
|
99 |
+
const float vv = float(v[ii]);
|
100 |
+
const float yy = float(y[ii]);
|
101 |
+
const float qq = q[i];
|
102 |
+
const float rr = r[i];
|
103 |
+
|
104 |
+
float e1 = qq * exp(rr);
|
105 |
+
float e2 = exp(kk + pp);
|
106 |
+
gk[ii] = bf16(e1 * (vv - yy) + e2 * (aa * vv + bb));
|
107 |
+
gv[ii] = bf16(e1 + e2 * aa);
|
108 |
+
|
109 |
+
const float ww = w + pp;
|
110 |
+
const float www = rr - u - kk;
|
111 |
+
const float p = max(ww, www);
|
112 |
+
e1 = exp(ww - p);
|
113 |
+
e2 = qq * exp(www - p);
|
114 |
+
aa = e1 * aa + e2;
|
115 |
+
bb = e1 * bb - e2 * yy;
|
116 |
+
pp = p;
|
117 |
+
}
|
118 |
+
}
|
119 |
+
|
120 |
+
void cuda_forward(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y) {
|
121 |
+
dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
|
122 |
+
assert(B * C % threadsPerBlock.x == 0);
|
123 |
+
dim3 numBlocks(B * C / threadsPerBlock.x);
|
124 |
+
kernel_forward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y);
|
125 |
+
}
|
126 |
+
|
127 |
+
void cuda_backward(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y, bf16 *gy, bf16 *gw, bf16 *gu, bf16 *gk, bf16 *gv) {
|
128 |
+
dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
|
129 |
+
assert(B * C % threadsPerBlock.x == 0);
|
130 |
+
dim3 numBlocks(B * C / threadsPerBlock.x);
|
131 |
+
kernel_backward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y, gy, gw, gu, gk, gv);
|
132 |
+
}
|
cuda/wkv_op.cpp
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#include <torch/extension.h>
|
2 |
+
|
3 |
+
void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y);
|
4 |
+
void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y, float *gy, float *gw, float *gu, float *gk, float *gv);
|
5 |
+
|
6 |
+
void forward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y) {
|
7 |
+
cuda_forward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>());
|
8 |
+
}
|
9 |
+
void backward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y, torch::Tensor &gy, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gk, torch::Tensor &gv) {
|
10 |
+
cuda_backward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>(), gy.data_ptr<float>(), gw.data_ptr<float>(), gu.data_ptr<float>(), gk.data_ptr<float>(), gv.data_ptr<float>());
|
11 |
+
}
|
12 |
+
|
13 |
+
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
14 |
+
m.def("forward", &forward, "wkv forward");
|
15 |
+
m.def("backward", &backward, "wkv backward");
|
16 |
+
}
|
17 |
+
|
18 |
+
TORCH_LIBRARY(wkv, m) {
|
19 |
+
m.def("forward", forward);
|
20 |
+
m.def("backward", backward);
|
21 |
+
}
|
cuda/wkv_op_bf16.cpp
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#include <torch/extension.h>
|
2 |
+
#include "ATen/ATen.h"
|
3 |
+
typedef at::BFloat16 bf16;
|
4 |
+
|
5 |
+
void cuda_forward(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y);
|
6 |
+
void cuda_backward(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y, bf16 *gy, bf16 *gw, bf16 *gu, bf16 *gk, bf16 *gv);
|
7 |
+
|
8 |
+
void forward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y) {
|
9 |
+
cuda_forward(B, T, C, w.data_ptr<float>(), u.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), y.data_ptr<bf16>());
|
10 |
+
}
|
11 |
+
void backward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y,
|
12 |
+
torch::Tensor &gy, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gk, torch::Tensor &gv) {
|
13 |
+
cuda_backward(B, T, C, w.data_ptr<float>(), u.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), y.data_ptr<bf16>(),
|
14 |
+
gy.data_ptr<bf16>(), gw.data_ptr<bf16>(), gu.data_ptr<bf16>(), gk.data_ptr<bf16>(), gv.data_ptr<bf16>());
|
15 |
+
}
|
16 |
+
|
17 |
+
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
18 |
+
m.def("forward", &forward, "wkv forward");
|
19 |
+
m.def("backward", &backward, "wkv backward");
|
20 |
+
}
|
21 |
+
|
22 |
+
TORCH_LIBRARY(wkv, m) {
|
23 |
+
m.def("forward", forward);
|
24 |
+
m.def("backward", backward);
|
25 |
+
}
|
run.py
ADDED
@@ -0,0 +1,223 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
########################################################################################################
|
2 |
+
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
|
3 |
+
########################################################################################################
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import math, os, sys, types, time, gc
|
7 |
+
import torch
|
8 |
+
from src.utils import TOKENIZER
|
9 |
+
try:
|
10 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = sys.argv[1]
|
11 |
+
except:
|
12 |
+
pass
|
13 |
+
torch.backends.cudnn.benchmark = True
|
14 |
+
torch.backends.cudnn.allow_tf32 = True
|
15 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
16 |
+
np.set_printoptions(precision=4, suppress=True, linewidth=200)
|
17 |
+
args = types.SimpleNamespace()
|
18 |
+
|
19 |
+
########################################################################################################
|
20 |
+
# Step 1: set model & config (use v4 to run your trained-from-scratch models. v4 and v4neo are compatible)
|
21 |
+
########################################################################################################
|
22 |
+
|
23 |
+
args.RUN_DEVICE = "cuda" # 'cuda' // 'cpu' (already fast)
|
24 |
+
args.FLOAT_MODE = "fp16" # fp16 (good for GPU, does not work for CPU) // fp32 (good for CPU) // bf16 (less accurate, but works for CPU)
|
25 |
+
|
26 |
+
# if args.RUN_DEVICE == "cuda":
|
27 |
+
# os.environ["RWKV_RUN_BACKEND"] = 'nvfuser' # !!!BUGGY!!! wrong output
|
28 |
+
os.environ["RWKV_JIT_ON"] = '1' # '1' or '0'. very useful for GPU/CPU fp32, but might be harmful for GPU fp16. please benchmark !!!
|
29 |
+
|
30 |
+
TOKEN_MODE = "pile"
|
31 |
+
WORD_NAME = [
|
32 |
+
"20B_tokenizer.json",
|
33 |
+
"20B_tokenizer.json",
|
34 |
+
] # [vocab, vocab] for Pile model
|
35 |
+
UNKNOWN_CHAR = None
|
36 |
+
vocab_size = 50277
|
37 |
+
|
38 |
+
# Download Pile models: https://huggingface.co/BlinkDL
|
39 |
+
# or, set MODEL_NAME to your fine-tuned model
|
40 |
+
|
41 |
+
# MODEL_NAME = "/fsx/BlinkDL/rwkv-release/RWKV-4-Pile-169M-20220807-8023"
|
42 |
+
# n_layer = 12
|
43 |
+
# n_embd = 768
|
44 |
+
# ctx_len = 1024
|
45 |
+
|
46 |
+
# MODEL_NAME = '/fsx/BlinkDL/rwkv-release/RWKV-4-Pile-430M-20220808-8066'
|
47 |
+
# n_layer = 24
|
48 |
+
# n_embd = 1024
|
49 |
+
# ctx_len = 1024
|
50 |
+
|
51 |
+
# MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-1b5/RWKV-4-Pile-1B5-20220903-8040'
|
52 |
+
# n_layer = 24
|
53 |
+
# n_embd = 2048
|
54 |
+
# ctx_len = 1024
|
55 |
+
|
56 |
+
# MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-3b/RWKV-4-Pile-3B-20221008-8023'
|
57 |
+
# n_layer = 32
|
58 |
+
# n_embd = 2560
|
59 |
+
# ctx_len = 1024
|
60 |
+
|
61 |
+
MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-7b/RWKV-4-Pile-7B-20221115-8047'
|
62 |
+
n_layer = 32
|
63 |
+
n_embd = 4096
|
64 |
+
ctx_len = 1024
|
65 |
+
|
66 |
+
args.MODEL_NAME = MODEL_NAME
|
67 |
+
args.n_layer = n_layer
|
68 |
+
args.n_embd = n_embd
|
69 |
+
args.ctx_len = ctx_len
|
70 |
+
args.vocab_size = vocab_size
|
71 |
+
args.head_qk = 0
|
72 |
+
args.pre_ffn = 0
|
73 |
+
args.grad_cp = 0
|
74 |
+
args.my_pos_emb = 0
|
75 |
+
os.environ["RWKV_RUN_DEVICE"] = args.RUN_DEVICE
|
76 |
+
|
77 |
+
########################################################################################################
|
78 |
+
# Step 2: set prompt & sampling stuffs
|
79 |
+
########################################################################################################
|
80 |
+
|
81 |
+
# context = 'A'
|
82 |
+
# context = "\nIn the"
|
83 |
+
# context = '\nSugar:'
|
84 |
+
context = "\nIn a shocking finding, scientist discovered a herd of dragons living in a remote, previously unexplored valley, in Tibet. Even more surprising to the researchers was the fact that the dragons spoke perfect Chinese."
|
85 |
+
|
86 |
+
# context = "\n深圳是" # test Chinese
|
87 |
+
# context = "\n東京は" # test Japanese
|
88 |
+
|
89 |
+
# ###### A good prompt for Q&A ######
|
90 |
+
# context = '''
|
91 |
+
# Questions & Helpful Answers
|
92 |
+
# Ask Research Experts
|
93 |
+
# Question:
|
94 |
+
# Can penguins fly?
|
95 |
+
|
96 |
+
# Full Answer:
|
97 |
+
# '''
|
98 |
+
|
99 |
+
# ###### A good prompt for chatbot ######
|
100 |
+
# context = '''
|
101 |
+
# The following is a conversation between a highly knowledgeable and intelligent AI assistant called Bot, and a human user called User. In the following interactions, User and Bot converse in natural language, and Bot always answer User's questions. Bot is very smart, polite and humorous. Bot knows a lot, and always tells the truth. The conversation begins.
|
102 |
+
|
103 |
+
# User: who is president of usa?
|
104 |
+
|
105 |
+
# Bot: It’s Joe Biden; he was sworn in earlier this year.
|
106 |
+
|
107 |
+
# User: french revolution what year
|
108 |
+
|
109 |
+
# Bot: It started in 1789, but it lasted 10 years until 1799.
|
110 |
+
|
111 |
+
# User: guess i marry who ?
|
112 |
+
|
113 |
+
# Bot: Only if you tell me more about yourself - what are your interests?
|
114 |
+
|
115 |
+
# User: wat is lhc
|
116 |
+
|
117 |
+
# Bot: It’s a large and very expensive piece of science equipment. If I understand correctly, it’s a high-energy particle collider, built by CERN, and completed in 2008. They used it to confirm the existence of the Higgs boson in 2012.
|
118 |
+
|
119 |
+
# User:''' # type your question here
|
120 |
+
|
121 |
+
NUM_TRIALS = 999
|
122 |
+
LENGTH_PER_TRIAL = 333
|
123 |
+
|
124 |
+
TEMPERATURE = 1.0
|
125 |
+
top_p = 0.8
|
126 |
+
top_p_newline = 0.9 # only used in TOKEN_MODE = char
|
127 |
+
|
128 |
+
DEBUG_DEBUG = False # True False --> show softmax output
|
129 |
+
|
130 |
+
########################################################################################################
|
131 |
+
|
132 |
+
from src.model_run import RWKV_RNN
|
133 |
+
|
134 |
+
model = RWKV_RNN(args)
|
135 |
+
|
136 |
+
out, _ = model.forward([187], None)
|
137 |
+
# print(out)
|
138 |
+
gc.collect()
|
139 |
+
torch.cuda.empty_cache()
|
140 |
+
|
141 |
+
# input(0)
|
142 |
+
|
143 |
+
tokenizer = TOKENIZER(WORD_NAME, UNKNOWN_CHAR=UNKNOWN_CHAR)
|
144 |
+
if TOKEN_MODE == "pile":
|
145 |
+
assert tokenizer.tokenizer.decode([187]) == '\n'
|
146 |
+
|
147 |
+
########################################################################################################
|
148 |
+
|
149 |
+
if tokenizer.charMode:
|
150 |
+
context = tokenizer.refine_context(context)
|
151 |
+
ctx = [tokenizer.stoi.get(s, tokenizer.UNKNOWN_CHAR) for s in context]
|
152 |
+
else:
|
153 |
+
ctx = tokenizer.tokenizer.encode(context)
|
154 |
+
src_len = len(ctx)
|
155 |
+
src_ctx = ctx.copy()
|
156 |
+
|
157 |
+
|
158 |
+
time_slot = {}
|
159 |
+
time_ref = time.time_ns()
|
160 |
+
|
161 |
+
def record_time(name):
|
162 |
+
if name not in time_slot:
|
163 |
+
time_slot[name] = 1e20
|
164 |
+
tt = (time.time_ns() - time_ref) / 1e9
|
165 |
+
if tt < time_slot[name]:
|
166 |
+
time_slot[name] = tt
|
167 |
+
|
168 |
+
init_state = None
|
169 |
+
init_out = None
|
170 |
+
state = None
|
171 |
+
out = None
|
172 |
+
|
173 |
+
for TRIAL in range(1 if DEBUG_DEBUG else NUM_TRIALS):
|
174 |
+
|
175 |
+
time_ref = time.time_ns()
|
176 |
+
ctx = src_ctx.copy()
|
177 |
+
|
178 |
+
if TRIAL == 0:
|
179 |
+
for i in range(src_len):
|
180 |
+
x = ctx[: i + 1]
|
181 |
+
if i == src_len - 1:
|
182 |
+
init_out, init_state = model.forward(x, init_state)
|
183 |
+
else:
|
184 |
+
init_state = model.forward(x, init_state, preprocess_only=True)
|
185 |
+
gc.collect()
|
186 |
+
torch.cuda.empty_cache()
|
187 |
+
|
188 |
+
record_time('preprocess')
|
189 |
+
out_last = src_len
|
190 |
+
for i in range(src_len, src_len + (1 if DEBUG_DEBUG else LENGTH_PER_TRIAL)):
|
191 |
+
x = ctx[: i + 1]
|
192 |
+
x = x[-ctx_len:]
|
193 |
+
|
194 |
+
if i == src_len:
|
195 |
+
out = init_out.clone()
|
196 |
+
state = init_state.clone()
|
197 |
+
else:
|
198 |
+
out, state = model.forward(x, state)
|
199 |
+
if DEBUG_DEBUG:
|
200 |
+
if TOKEN_MODE == "pile":
|
201 |
+
out[0] = -999999999 # disable <|endoftext|>
|
202 |
+
|
203 |
+
ttt = tokenizer.sample_logits(
|
204 |
+
out,
|
205 |
+
x,
|
206 |
+
ctx_len,
|
207 |
+
temperature=TEMPERATURE,
|
208 |
+
top_p_usual=top_p,
|
209 |
+
top_p_newline=top_p_newline,
|
210 |
+
)
|
211 |
+
ctx += [ttt]
|
212 |
+
|
213 |
+
if tokenizer.charMode:
|
214 |
+
char = tokenizer.itos[ttt]
|
215 |
+
else:
|
216 |
+
char = tokenizer.tokenizer.decode(ctx[out_last:])
|
217 |
+
if '\ufffd' not in char: # is valid utf8 string?
|
218 |
+
out_last = i+1
|
219 |
+
|
220 |
+
record_time('total')
|
221 |
+
# print(f'\n\n{time_slot}\n\n')
|
222 |
+
|
223 |
+
|
src/__init__.py
ADDED
File without changes
|
src/binidx.py
ADDED
@@ -0,0 +1,269 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from lib2to3.pgen2 import token
|
2 |
+
import os
|
3 |
+
import torch
|
4 |
+
import numpy as np
|
5 |
+
import shutil
|
6 |
+
import struct
|
7 |
+
from functools import lru_cache
|
8 |
+
from itertools import accumulate
|
9 |
+
|
10 |
+
def print_rank_0(*message):
|
11 |
+
pass
|
12 |
+
# """If distributed is initialized print only on rank 0."""
|
13 |
+
# if torch.distributed.is_initialized():
|
14 |
+
# if torch.distributed.get_rank() == 0:
|
15 |
+
# print(*message, flush=True)
|
16 |
+
# else:
|
17 |
+
# print(*message, flush=True)
|
18 |
+
|
19 |
+
def _warmup_mmap_file(path):
|
20 |
+
pass
|
21 |
+
# with open(path, "rb") as stream:
|
22 |
+
# while stream.read(100 * 1024 * 1024):
|
23 |
+
# pass
|
24 |
+
|
25 |
+
dtypes = {
|
26 |
+
1: np.uint8,
|
27 |
+
2: np.int8,
|
28 |
+
3: np.int16,
|
29 |
+
4: np.int32,
|
30 |
+
5: np.int64,
|
31 |
+
6: float,
|
32 |
+
7: np.double,
|
33 |
+
8: np.uint16,
|
34 |
+
}
|
35 |
+
|
36 |
+
def code(dtype):
|
37 |
+
for k in dtypes.keys():
|
38 |
+
if dtypes[k] == dtype:
|
39 |
+
return k
|
40 |
+
raise ValueError(dtype)
|
41 |
+
|
42 |
+
def index_file_path(prefix_path):
|
43 |
+
return prefix_path + ".idx"
|
44 |
+
|
45 |
+
def data_file_path(prefix_path):
|
46 |
+
return prefix_path + ".bin"
|
47 |
+
|
48 |
+
class MMapIndexedDataset(torch.utils.data.Dataset):
|
49 |
+
class Index(object):
|
50 |
+
_HDR_MAGIC = b"MMIDIDX\x00\x00"
|
51 |
+
|
52 |
+
@classmethod
|
53 |
+
def writer(cls, path, dtype):
|
54 |
+
class _Writer(object):
|
55 |
+
def __enter__(self):
|
56 |
+
self._file = open(path, "wb")
|
57 |
+
|
58 |
+
# Write Magic string so we can check the file format then opening it again.
|
59 |
+
self._file.write(cls._HDR_MAGIC)
|
60 |
+
# Write version number
|
61 |
+
# Little endian unsigned 64 Bit integer
|
62 |
+
self._file.write(struct.pack("<Q", 1))
|
63 |
+
# Little endian unsigned 8 Bit integer
|
64 |
+
self._file.write(struct.pack("<B", code(dtype)))
|
65 |
+
|
66 |
+
return self
|
67 |
+
|
68 |
+
@staticmethod
|
69 |
+
def _get_pointers(sizes):
|
70 |
+
dtype_size = dtype().itemsize
|
71 |
+
address = 0
|
72 |
+
pointers = []
|
73 |
+
|
74 |
+
for size in sizes:
|
75 |
+
pointers.append(address)
|
76 |
+
address += size * dtype_size
|
77 |
+
|
78 |
+
return pointers
|
79 |
+
|
80 |
+
def write(self, sizes, doc_idx):
|
81 |
+
pointers = self._get_pointers(sizes)
|
82 |
+
|
83 |
+
# Little endian unsigned 64 Bit integer
|
84 |
+
self._file.write(struct.pack("<Q", len(sizes)))
|
85 |
+
# Little endian unsigned 64 Bit integer
|
86 |
+
self._file.write(struct.pack("<Q", len(doc_idx)))
|
87 |
+
|
88 |
+
sizes = np.array(sizes, dtype=np.int32)
|
89 |
+
self._file.write(sizes.tobytes(order="C"))
|
90 |
+
del sizes
|
91 |
+
|
92 |
+
pointers = np.array(pointers, dtype=np.int64)
|
93 |
+
self._file.write(pointers.tobytes(order="C"))
|
94 |
+
del pointers
|
95 |
+
|
96 |
+
doc_idx = np.array(doc_idx, dtype=np.int64)
|
97 |
+
self._file.write(doc_idx.tobytes(order="C"))
|
98 |
+
|
99 |
+
def __exit__(self, exc_type, exc_val, exc_tb):
|
100 |
+
self._file.close()
|
101 |
+
|
102 |
+
return _Writer()
|
103 |
+
|
104 |
+
def __init__(self, path, skip_warmup=False):
|
105 |
+
with open(path, "rb") as stream:
|
106 |
+
magic_test = stream.read(9)
|
107 |
+
assert self._HDR_MAGIC == magic_test, (
|
108 |
+
"Index file doesn't match expected format. "
|
109 |
+
"Make sure that --dataset-impl is configured properly."
|
110 |
+
)
|
111 |
+
# Little endian unsigned 64 Bit integer
|
112 |
+
version = struct.unpack("<Q", stream.read(8))
|
113 |
+
assert (1,) == version
|
114 |
+
|
115 |
+
# Little endian unsigned 8 Bit integer
|
116 |
+
(dtype_code,) = struct.unpack("<B", stream.read(1))
|
117 |
+
self._dtype = dtypes[dtype_code]
|
118 |
+
self._dtype_size = self._dtype().itemsize
|
119 |
+
|
120 |
+
self._len = struct.unpack("<Q", stream.read(8))[0]
|
121 |
+
self._doc_count = struct.unpack("<Q", stream.read(8))[0]
|
122 |
+
offset = stream.tell()
|
123 |
+
|
124 |
+
if not skip_warmup:
|
125 |
+
print_rank_0(" warming up index mmap file...")
|
126 |
+
_warmup_mmap_file(path)
|
127 |
+
|
128 |
+
self._bin_buffer_mmap = np.memmap(path, mode="r", order="C")
|
129 |
+
self._bin_buffer = memoryview(self._bin_buffer_mmap)
|
130 |
+
print_rank_0(" reading sizes...")
|
131 |
+
self._sizes = np.frombuffer(
|
132 |
+
self._bin_buffer, dtype=np.int32, count=self._len, offset=offset
|
133 |
+
)
|
134 |
+
print_rank_0(" reading pointers...")
|
135 |
+
self._pointers = np.frombuffer(
|
136 |
+
self._bin_buffer,
|
137 |
+
dtype=np.int64,
|
138 |
+
count=self._len,
|
139 |
+
offset=offset + self._sizes.nbytes,
|
140 |
+
)
|
141 |
+
print_rank_0(" reading document index...")
|
142 |
+
self._doc_idx = np.frombuffer(
|
143 |
+
self._bin_buffer,
|
144 |
+
dtype=np.int64,
|
145 |
+
count=self._doc_count,
|
146 |
+
offset=offset + self._sizes.nbytes + self._pointers.nbytes,
|
147 |
+
)
|
148 |
+
|
149 |
+
def __del__(self):
|
150 |
+
self._bin_buffer_mmap._mmap.close()
|
151 |
+
del self._bin_buffer_mmap
|
152 |
+
|
153 |
+
@property
|
154 |
+
def dtype(self):
|
155 |
+
return self._dtype
|
156 |
+
|
157 |
+
@property
|
158 |
+
def sizes(self):
|
159 |
+
return self._sizes
|
160 |
+
|
161 |
+
@property
|
162 |
+
def doc_idx(self):
|
163 |
+
return self._doc_idx
|
164 |
+
|
165 |
+
@lru_cache(maxsize=8)
|
166 |
+
def __getitem__(self, i):
|
167 |
+
return self._pointers[i], self._sizes[i]
|
168 |
+
|
169 |
+
def __len__(self):
|
170 |
+
return self._len
|
171 |
+
|
172 |
+
def __init__(self, path, skip_warmup=False):
|
173 |
+
super().__init__()
|
174 |
+
|
175 |
+
self._path = None
|
176 |
+
self._index = None
|
177 |
+
self._bin_buffer = None
|
178 |
+
|
179 |
+
self._do_init(path, skip_warmup)
|
180 |
+
|
181 |
+
def __getstate__(self):
|
182 |
+
return self._path
|
183 |
+
|
184 |
+
def __setstate__(self, state):
|
185 |
+
self._do_init(state)
|
186 |
+
|
187 |
+
def _do_init(self, path, skip_warmup):
|
188 |
+
self._path = path
|
189 |
+
self._index = self.Index(index_file_path(self._path), skip_warmup)
|
190 |
+
|
191 |
+
if not skip_warmup:
|
192 |
+
print_rank_0(" warming up data mmap file...")
|
193 |
+
_warmup_mmap_file(data_file_path(self._path))
|
194 |
+
print_rank_0(" creating numpy buffer of mmap...")
|
195 |
+
self._bin_buffer_mmap = np.memmap(
|
196 |
+
data_file_path(self._path), mode="r", order="C"
|
197 |
+
)
|
198 |
+
print_rank_0(" creating memory view of numpy buffer...")
|
199 |
+
self._bin_buffer = memoryview(self._bin_buffer_mmap)
|
200 |
+
|
201 |
+
def __del__(self):
|
202 |
+
self._bin_buffer_mmap._mmap.close()
|
203 |
+
del self._bin_buffer_mmap
|
204 |
+
del self._index
|
205 |
+
|
206 |
+
def __len__(self):
|
207 |
+
return len(self._index)
|
208 |
+
|
209 |
+
# @lru_cache(maxsize=8)
|
210 |
+
def __getitem__(self, idx):
|
211 |
+
if isinstance(idx, int):
|
212 |
+
ptr, size = self._index[idx]
|
213 |
+
np_array = np.frombuffer(
|
214 |
+
self._bin_buffer, dtype=self._index.dtype, count=size, offset=ptr
|
215 |
+
)
|
216 |
+
return np_array
|
217 |
+
elif isinstance(idx, slice):
|
218 |
+
start, stop, step = idx.indices(len(self))
|
219 |
+
if step != 1:
|
220 |
+
raise ValueError(
|
221 |
+
"Slices into indexed_dataset must be contiguous")
|
222 |
+
ptr = self._index._pointers[start]
|
223 |
+
sizes = self._index._sizes[idx]
|
224 |
+
offsets = list(accumulate(sizes))
|
225 |
+
total_size = sum(sizes)
|
226 |
+
np_array = np.frombuffer(
|
227 |
+
self._bin_buffer, dtype=self._index.dtype, count=total_size, offset=ptr
|
228 |
+
)
|
229 |
+
sents = np.split(np_array, offsets[:-1])
|
230 |
+
return sents
|
231 |
+
|
232 |
+
def get(self, idx, offset=0, length=None):
|
233 |
+
"""Retrieves a single item from the dataset with the option to only
|
234 |
+
return a portion of the item.
|
235 |
+
|
236 |
+
get(idx) is the same as [idx] but get() does not support slicing.
|
237 |
+
"""
|
238 |
+
ptr, size = self._index[idx]
|
239 |
+
if length is None:
|
240 |
+
length = size - offset
|
241 |
+
ptr += offset * np.dtype(self._index.dtype).itemsize
|
242 |
+
np_array = np.frombuffer(
|
243 |
+
self._bin_buffer, dtype=self._index.dtype, count=length, offset=ptr
|
244 |
+
)
|
245 |
+
return np_array
|
246 |
+
|
247 |
+
@property
|
248 |
+
def sizes(self):
|
249 |
+
return self._index.sizes
|
250 |
+
|
251 |
+
@property
|
252 |
+
def doc_idx(self):
|
253 |
+
return self._index.doc_idx
|
254 |
+
|
255 |
+
def get_doc_idx(self):
|
256 |
+
return self._index._doc_idx
|
257 |
+
|
258 |
+
def set_doc_idx(self, doc_idx_):
|
259 |
+
self._index._doc_idx = doc_idx_
|
260 |
+
|
261 |
+
@property
|
262 |
+
def supports_prefetch(self):
|
263 |
+
return False
|
264 |
+
|
265 |
+
@staticmethod
|
266 |
+
def exists(path):
|
267 |
+
return os.path.exists(index_file_path(path)) and os.path.exists(
|
268 |
+
data_file_path(path)
|
269 |
+
)
|
src/dataset.py
ADDED
@@ -0,0 +1,245 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
########################################################################################################
|
2 |
+
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
|
3 |
+
########################################################################################################
|
4 |
+
|
5 |
+
import json, math, random, os, sys
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
from torch.utils.data import Dataset
|
9 |
+
from pytorch_lightning.utilities import rank_zero_info
|
10 |
+
from .binidx import MMapIndexedDataset
|
11 |
+
from .utils import MaybeIsPrime
|
12 |
+
|
13 |
+
|
14 |
+
class MyDataset(Dataset):
|
15 |
+
def __init__(self, args):
|
16 |
+
self.args = args
|
17 |
+
|
18 |
+
if args.data_type == "binidx":
|
19 |
+
self.vocab_size = args.vocab_size
|
20 |
+
rank_zero_info(f"Current vocab size = {self.vocab_size} (make sure it's correct)")
|
21 |
+
|
22 |
+
if args.my_pile_version == 1:
|
23 |
+
self.data = MMapIndexedDataset(args.data_file)
|
24 |
+
self.data_size = len(self.data._bin_buffer) // self.data._index._dtype_size
|
25 |
+
rank_zero_info(f"Data has {self.data_size} tokens.")
|
26 |
+
else:
|
27 |
+
data_list = open(args.data_file, "r", encoding='utf-8').read().strip().split('\n')
|
28 |
+
data_list = [i.strip().split(' ') for i in data_list]
|
29 |
+
self.data = []
|
30 |
+
self.data_size = int(data_list[-1][-1])
|
31 |
+
rank_zero_info(f"Data has {self.data_size} chunks.")
|
32 |
+
for d in data_list:
|
33 |
+
data = MMapIndexedDataset(d[0])
|
34 |
+
data_size = len(data._bin_buffer) // data._index._dtype_size
|
35 |
+
assert (data_size - args.ctx_len) == int(d[1])
|
36 |
+
self.data += [[int(d[-1]), int(d[1]), data]]
|
37 |
+
# rank_zero_info(self.data)
|
38 |
+
|
39 |
+
if args.my_qa_mask > 0:
|
40 |
+
# self.data_pile = MMapIndexedDataset('/fsx/pile/pile_20B_tokenizer_text_document')
|
41 |
+
self.data_pile = MMapIndexedDataset('/fsx/pile_deduped/pile_0.87_deduped_text_document')
|
42 |
+
self.data_pile_size = len(self.data_pile._bin_buffer) // self.data._index._dtype_size
|
43 |
+
else:
|
44 |
+
self.data_pile = None
|
45 |
+
self.data_pile_size = 0
|
46 |
+
|
47 |
+
if args.my_pile_stage > 0:
|
48 |
+
# assert self.data_size == 332115325534 and self.vocab_size == 50277
|
49 |
+
self.samples_per_epoch = args.epoch_steps * args.real_bsz
|
50 |
+
assert self.samples_per_epoch == 40320
|
51 |
+
rank_zero_info(f"########## Pile 20b-tokenized stage {args.my_pile_stage} ##########")
|
52 |
+
dataset_slot = self.data_size // args.ctx_len
|
53 |
+
if args.my_pile_stage != 4:
|
54 |
+
assert MaybeIsPrime(args.magic_prime)
|
55 |
+
assert args.magic_prime % 3 == 2
|
56 |
+
assert args.magic_prime / dataset_slot > 0.99 and args.magic_prime / dataset_slot <= 1
|
57 |
+
elif args.data_type == "numpy":
|
58 |
+
self.data = np.load(args.data_file).astype("int")
|
59 |
+
self.vocab_size = args.vocab_size
|
60 |
+
rank_zero_info("Current vocab size =", self.vocab_size, "(make sure it's correct)")
|
61 |
+
self.data_size = len(self.data)
|
62 |
+
rank_zero_info(f"Data has {self.data_size} tokens.")
|
63 |
+
elif args.data_type == "uint16":
|
64 |
+
self.data = np.fromfile(args.data_file, dtype=np.uint16).astype("int32").reshape(-1, args.my_sample_len)
|
65 |
+
self.vocab_size = args.vocab_size
|
66 |
+
rank_zero_info("Current vocab size =", self.vocab_size, "(make sure it's correct)")
|
67 |
+
self.data_size = self.data.shape[0]
|
68 |
+
rank_zero_info(f"Data has {self.data_size} samples.")
|
69 |
+
elif args.data_type == "wds_img":
|
70 |
+
self.vocab_size = -1
|
71 |
+
self.data_size = -1
|
72 |
+
self.data = None
|
73 |
+
self.error_count = 0
|
74 |
+
else:
|
75 |
+
if args.data_type == "dummy":
|
76 |
+
rank_zero_info("Building dummy data...")
|
77 |
+
self.data = ""
|
78 |
+
for i in range(100000):
|
79 |
+
aa = (i) % 10000
|
80 |
+
bb = (i * i) % 10000
|
81 |
+
cc = aa + bb
|
82 |
+
self.data += f".{aa}+{bb}={cc}."
|
83 |
+
else:
|
84 |
+
self.data = open(args.data_file, "r", encoding=args.data_type).read()
|
85 |
+
rank_zero_info("Building token list...")
|
86 |
+
unique = sorted(list(set(self.data)))
|
87 |
+
self.vocab_size = len(unique)
|
88 |
+
# rank_zero_info()
|
89 |
+
# for u in unique:
|
90 |
+
# print(u, end=' ')
|
91 |
+
# rank_zero_info('\n\n')
|
92 |
+
xx = 0
|
93 |
+
xxObj = {}
|
94 |
+
for u in unique:
|
95 |
+
xxObj[xx] = u
|
96 |
+
xx += 1
|
97 |
+
with open(f"{args.proj_dir}/vocab.json", "w", encoding="utf-16le") as vocab_file:
|
98 |
+
vocab_file.write(json.dumps(xxObj, ensure_ascii=False))
|
99 |
+
self.data_size = len(self.data)
|
100 |
+
rank_zero_info(f"Data has {self.data_size} tokens, {self.vocab_size} vocab size.")
|
101 |
+
self.stoi = {ch: i for i, ch in enumerate(unique)}
|
102 |
+
self.itos = {i: ch for i, ch in enumerate(unique)}
|
103 |
+
|
104 |
+
def __len__(self):
|
105 |
+
return self.args.epoch_steps * self.args.micro_bsz
|
106 |
+
|
107 |
+
def __getitem__(self, idx):
|
108 |
+
args = self.args
|
109 |
+
rank = self.global_rank
|
110 |
+
epoch = self.real_epoch
|
111 |
+
world_size = self.world_size
|
112 |
+
# print(f"epoch {epoch} idx {idx} rank {rank}/{world_size}")
|
113 |
+
|
114 |
+
if args.data_type == "wds_img":
|
115 |
+
def init_wds(self, bias=0):
|
116 |
+
def identity(x):
|
117 |
+
return x
|
118 |
+
import webdataset as wds
|
119 |
+
import torchvision.transforms as transforms
|
120 |
+
# img_transform = transforms.Compose(
|
121 |
+
# [transforms.CenterCrop(256)]
|
122 |
+
# )
|
123 |
+
img_transform = transforms.Compose([
|
124 |
+
transforms.CenterCrop(512),
|
125 |
+
transforms.Resize((args.my_img_size))
|
126 |
+
])
|
127 |
+
self.data_raw = wds.WebDataset(args.data_file, resampled=True).shuffle(10000, initial=1000, rng=random.Random(epoch*100000+rank+bias*1e9)).decode("torchrgb").to_tuple("jpg", "json", "txt").map_tuple(img_transform, identity, identity)
|
128 |
+
for pp in self.data_raw.pipeline:
|
129 |
+
if 'Resampled' in str(pp):
|
130 |
+
pp.deterministic = True
|
131 |
+
def worker_seed():
|
132 |
+
return rank*100000+epoch+bias*1e9
|
133 |
+
pp.worker_seed = worker_seed
|
134 |
+
self.data = iter(self.data_raw)
|
135 |
+
# print(f"WebDataset loaded for rank {rank} epoch {epoch}")
|
136 |
+
if self.data == None:
|
137 |
+
init_wds(self)
|
138 |
+
trial = 0
|
139 |
+
while trial < 10:
|
140 |
+
try:
|
141 |
+
dd = next(self.data) # jpg, json, txt
|
142 |
+
break
|
143 |
+
except:
|
144 |
+
print(f'[dataloader error - epoch {epoch} rank {rank} - trying a new shuffle]')
|
145 |
+
self.error_count += 1
|
146 |
+
init_wds(self, self.error_count)
|
147 |
+
trial += 1
|
148 |
+
pass
|
149 |
+
# print(f"epoch {epoch} idx {idx} rank {rank}/{world_size} {dd[2]}")
|
150 |
+
# with open(f"sample_{rank}.txt", "a", encoding="utf-8") as tmp:
|
151 |
+
# tmp.write(f"epoch {epoch} idx {idx} rank {rank}/{world_size} {int(dd[1]['key'])}\n")
|
152 |
+
return dd[0], dd[2]
|
153 |
+
else:
|
154 |
+
if args.data_type == "uint16":
|
155 |
+
i = np.random.randint(0, self.data_size-1)
|
156 |
+
dix = self.data[i]
|
157 |
+
x = torch.tensor(dix[:-1], dtype=torch.long)
|
158 |
+
y = torch.tensor(dix[1:], dtype=torch.long)
|
159 |
+
else:
|
160 |
+
ctx_len = args.ctx_len
|
161 |
+
req_len = ctx_len + 1
|
162 |
+
magic_prime = args.magic_prime
|
163 |
+
data = self.data
|
164 |
+
|
165 |
+
if args.my_pile_stage > 0:
|
166 |
+
ii = 1 + epoch * self.samples_per_epoch + (idx * world_size) + rank
|
167 |
+
|
168 |
+
if args.my_qa_mask > 0:
|
169 |
+
ii_orig = ii
|
170 |
+
if ii % 2 == 0:
|
171 |
+
ii = -1
|
172 |
+
data = self.data_pile
|
173 |
+
else:
|
174 |
+
ii = ii // 2
|
175 |
+
if data == self.data_pile:
|
176 |
+
i = np.random.randint(0, self.data_pile_size - req_len)
|
177 |
+
else:
|
178 |
+
if args.my_pile_stage == 4 or ii < args.my_random_steps:
|
179 |
+
# cheat: pick a random spot in dataset
|
180 |
+
if args.my_pile_version == 1:
|
181 |
+
i = np.random.randint(0, self.data_size - req_len)
|
182 |
+
else:
|
183 |
+
i = np.random.randint(0, self.data_size)
|
184 |
+
else:
|
185 |
+
ii = ii - args.my_random_steps
|
186 |
+
factor = (math.sqrt(5) - 1) / 2
|
187 |
+
factor = int(magic_prime * factor)
|
188 |
+
i = ((factor * ii * ii * ii) % magic_prime) * ctx_len
|
189 |
+
i = i + args.my_pile_shift
|
190 |
+
# print(f"epoch {epoch} idx {idx} rank {rank}/{world_size} ii {ii} pos {round(i / self.data_size, 3)}")
|
191 |
+
else:
|
192 |
+
# cheat: pick a random spot in dataset
|
193 |
+
i = np.random.randint(0, self.data_size - req_len)
|
194 |
+
|
195 |
+
if args.data_type == "binidx":
|
196 |
+
if args.my_pile_version == 1:
|
197 |
+
dix = data.get(idx=0, offset=i, length=req_len).astype(int)
|
198 |
+
else:
|
199 |
+
# self.data : cutoff, chunk_count, data
|
200 |
+
for j in range(len(data)):
|
201 |
+
if i < data[j][0]:
|
202 |
+
ii = i
|
203 |
+
i = (i - (data[j-1][0] if j > 0 else 0)) % data[j][1]
|
204 |
+
dix = data[j][2].get(idx=0, offset=i, length=req_len).astype(int)
|
205 |
+
# print(ii, j, i)
|
206 |
+
break
|
207 |
+
elif args.data_type == "numpy":
|
208 |
+
dix = data[i : i + req_len]
|
209 |
+
else:
|
210 |
+
dix = [self.stoi[s] for s in data[i : i + req_len]]
|
211 |
+
|
212 |
+
if args.my_qa_mask == 1:
|
213 |
+
if data == self.data_pile:
|
214 |
+
z = [1] * ctx_len
|
215 |
+
else:
|
216 |
+
z = [0] * ctx_len
|
217 |
+
z_sum = 0
|
218 |
+
isGood = False
|
219 |
+
for i in range(3, ctx_len):
|
220 |
+
if dix[i] == 27 and dix[i-1] == 34 and dix[i-2] == 187 and dix[i-3] == 187:
|
221 |
+
isGood = True
|
222 |
+
if dix[i] == 0:
|
223 |
+
isGood = False
|
224 |
+
if isGood:
|
225 |
+
z[i] = 1
|
226 |
+
z_sum += 1
|
227 |
+
if z_sum == 0:
|
228 |
+
z = [1] * ctx_len
|
229 |
+
i = np.random.randint(0, self.data_pile_size - req_len)
|
230 |
+
dix = self.data_pile.get(idx=0, offset=i, length=req_len).astype(int)
|
231 |
+
z = torch.tensor(z, dtype=torch.bfloat16)
|
232 |
+
|
233 |
+
x = torch.tensor(dix[:-1], dtype=torch.long)
|
234 |
+
y = torch.tensor(dix[1:], dtype=torch.long)
|
235 |
+
|
236 |
+
# if ii_orig < 50:
|
237 |
+
# # if rank == 1:
|
238 |
+
# print('rank', rank, 'i', ii_orig, ii, i, 'x', x[:5], '...', x[-5:])
|
239 |
+
# else:
|
240 |
+
# exit(0)
|
241 |
+
|
242 |
+
if args.my_qa_mask == 1:
|
243 |
+
return x, y, z
|
244 |
+
|
245 |
+
return x, y
|
src/model.py
ADDED
@@ -0,0 +1,610 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
########################################################################################################
|
2 |
+
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
|
3 |
+
########################################################################################################
|
4 |
+
|
5 |
+
import os, math, gc, importlib
|
6 |
+
import torch
|
7 |
+
# torch._C._jit_set_profiling_executor(True)
|
8 |
+
# torch._C._jit_set_profiling_mode(True)
|
9 |
+
import torch.nn as nn
|
10 |
+
from torch.nn import functional as F
|
11 |
+
import pytorch_lightning as pl
|
12 |
+
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
|
13 |
+
from pytorch_lightning.strategies import DeepSpeedStrategy
|
14 |
+
if importlib.util.find_spec('deepspeed'):
|
15 |
+
import deepspeed
|
16 |
+
from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam
|
17 |
+
|
18 |
+
# from deepspeed.runtime.fp16.onebit.zoadam import ZeroOneAdam
|
19 |
+
|
20 |
+
try:
|
21 |
+
print('RWKV_MY_TESTING', os.environ["RWKV_MY_TESTING"])
|
22 |
+
except:
|
23 |
+
os.environ["RWKV_MY_TESTING"] = ''
|
24 |
+
|
25 |
+
def __nop(ob):
|
26 |
+
return ob
|
27 |
+
|
28 |
+
|
29 |
+
MyModule = nn.Module
|
30 |
+
MyFunction = __nop
|
31 |
+
if os.environ["RWKV_JIT_ON"] == "1":
|
32 |
+
MyModule = torch.jit.ScriptModule
|
33 |
+
MyFunction = torch.jit.script_method
|
34 |
+
|
35 |
+
|
36 |
+
########################################################################################################
|
37 |
+
# CUDA Kernel
|
38 |
+
########################################################################################################
|
39 |
+
|
40 |
+
T_MAX = int(os.environ["RWKV_T_MAX"]) # TAKES LOTS OF VRAM!
|
41 |
+
# it's possible to go beyond CUDA limitations if you slice the ctx and pass the hidden state in each slice
|
42 |
+
|
43 |
+
from torch.utils.cpp_extension import load
|
44 |
+
|
45 |
+
if os.environ["RWKV_FLOAT_MODE"] == "bf16":
|
46 |
+
wkv_cuda = load(name=f"wkv_{T_MAX}_bf16", sources=["cuda/wkv_op_bf16.cpp", "cuda/wkv_cuda_bf16.cu"], verbose=True, extra_cuda_cflags=["-t 4", "-std=c++17", "-res-usage", "--maxrregcount 60", "--use_fast_math", "-O3", "-Xptxas -O3", "--extra-device-vectorization", f"-DTmax={T_MAX}"])
|
47 |
+
class WKV(torch.autograd.Function):
|
48 |
+
@staticmethod
|
49 |
+
def forward(ctx, B, T, C, w, u, k, v):
|
50 |
+
ctx.B = B
|
51 |
+
ctx.T = T
|
52 |
+
ctx.C = C
|
53 |
+
assert T <= T_MAX
|
54 |
+
assert B * C % min(C, 32) == 0
|
55 |
+
w = -torch.exp(w.float().contiguous())
|
56 |
+
u = u.contiguous()
|
57 |
+
k = k.contiguous()
|
58 |
+
v = v.contiguous()
|
59 |
+
y = torch.empty((B, T, C), device=w.device, memory_format=torch.contiguous_format, dtype=torch.bfloat16)
|
60 |
+
wkv_cuda.forward(B, T, C, w, u, k, v, y)
|
61 |
+
ctx.save_for_backward(w, u, k, v, y)
|
62 |
+
return y
|
63 |
+
@staticmethod
|
64 |
+
def backward(ctx, gy):
|
65 |
+
B = ctx.B
|
66 |
+
T = ctx.T
|
67 |
+
C = ctx.C
|
68 |
+
assert T <= T_MAX
|
69 |
+
assert B * C % min(C, 32) == 0
|
70 |
+
w, u, k, v, y = ctx.saved_tensors
|
71 |
+
gw = torch.empty((B, C), device=gy.device, memory_format=torch.contiguous_format, dtype=torch.bfloat16)
|
72 |
+
gu = torch.empty((B, C), device=gy.device, memory_format=torch.contiguous_format, dtype=torch.bfloat16)
|
73 |
+
gk = torch.empty((B, T, C), device=gy.device, memory_format=torch.contiguous_format, dtype=torch.bfloat16)
|
74 |
+
gv = torch.empty((B, T, C), device=gy.device, memory_format=torch.contiguous_format, dtype=torch.bfloat16)
|
75 |
+
wkv_cuda.backward(B, T, C, w, u, k, v, y, gy.contiguous(), gw, gu, gk, gv)
|
76 |
+
gw = torch.sum(gw, dim=0)
|
77 |
+
gu = torch.sum(gu, dim=0)
|
78 |
+
return (None, None, None, gw, gu, gk, gv)
|
79 |
+
else:
|
80 |
+
wkv_cuda = load(name=f"wkv_{T_MAX}", sources=["cuda/wkv_op.cpp", "cuda/wkv_cuda.cu"], verbose=True, extra_cuda_cflags=["-res-usage", "--maxrregcount 60", "--use_fast_math", "-O3", "-Xptxas -O3", "--extra-device-vectorization", f"-DTmax={T_MAX}"])
|
81 |
+
class WKV(torch.autograd.Function):
|
82 |
+
@staticmethod
|
83 |
+
def forward(ctx, B, T, C, w, u, k, v):
|
84 |
+
ctx.B = B
|
85 |
+
ctx.T = T
|
86 |
+
ctx.C = C
|
87 |
+
assert T <= T_MAX
|
88 |
+
assert B * C % min(C, 32) == 0
|
89 |
+
if "32" in os.environ["RWKV_FLOAT_MODE"]:
|
90 |
+
w = -torch.exp(w.contiguous())
|
91 |
+
u = u.contiguous()
|
92 |
+
k = k.contiguous()
|
93 |
+
v = v.contiguous()
|
94 |
+
else:
|
95 |
+
w = -torch.exp(w.float().contiguous())
|
96 |
+
u = u.float().contiguous()
|
97 |
+
k = k.float().contiguous()
|
98 |
+
v = v.float().contiguous()
|
99 |
+
y = torch.empty((B, T, C), device=w.device, memory_format=torch.contiguous_format)
|
100 |
+
wkv_cuda.forward(B, T, C, w, u, k, v, y)
|
101 |
+
ctx.save_for_backward(w, u, k, v, y)
|
102 |
+
if "32" in os.environ["RWKV_FLOAT_MODE"]:
|
103 |
+
return y
|
104 |
+
elif os.environ["RWKV_FLOAT_MODE"] == "fp16":
|
105 |
+
return y.half()
|
106 |
+
elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
|
107 |
+
return y.bfloat16()
|
108 |
+
@staticmethod
|
109 |
+
def backward(ctx, gy):
|
110 |
+
B = ctx.B
|
111 |
+
T = ctx.T
|
112 |
+
C = ctx.C
|
113 |
+
assert T <= T_MAX
|
114 |
+
assert B * C % min(C, 32) == 0
|
115 |
+
w, u, k, v, y = ctx.saved_tensors
|
116 |
+
gw = torch.empty((B, C), device=gy.device, memory_format=torch.contiguous_format)
|
117 |
+
gu = torch.empty((B, C), device=gy.device, memory_format=torch.contiguous_format)
|
118 |
+
gk = torch.empty((B, T, C), device=gy.device, memory_format=torch.contiguous_format)
|
119 |
+
gv = torch.empty((B, T, C), device=gy.device, memory_format=torch.contiguous_format)
|
120 |
+
if "32" in os.environ["RWKV_FLOAT_MODE"]:
|
121 |
+
wkv_cuda.backward(B, T, C, w, u, k, v, y, gy.contiguous(), gw, gu, gk, gv)
|
122 |
+
else:
|
123 |
+
wkv_cuda.backward(B, T, C, w, u, k, v, y, gy.float().contiguous(), gw, gu, gk, gv)
|
124 |
+
gw = torch.sum(gw, dim=0)
|
125 |
+
gu = torch.sum(gu, dim=0)
|
126 |
+
if "32" in os.environ["RWKV_FLOAT_MODE"]:
|
127 |
+
return (None, None, None, gw, gu, gk, gv)
|
128 |
+
elif os.environ["RWKV_FLOAT_MODE"] == "fp16":
|
129 |
+
return (None, None, None, gw.half(), gu.half(), gk.half(), gv.half())
|
130 |
+
elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
|
131 |
+
return (None, None, None, gw.bfloat16(), gu.bfloat16(), gk.bfloat16(), gv.bfloat16())
|
132 |
+
|
133 |
+
|
134 |
+
def RUN_CUDA(B, T, C, w, u, k, v):
|
135 |
+
return WKV.apply(B, T, C, w, u, k, v)
|
136 |
+
|
137 |
+
|
138 |
+
########################################################################################################
|
139 |
+
# RWKV: RWKV Time-mix + RWKV Channel-mix
|
140 |
+
########################################################################################################
|
141 |
+
|
142 |
+
|
143 |
+
class RWKV_TimeMix(MyModule):
|
144 |
+
def __init__(self, args, layer_id):
|
145 |
+
super().__init__()
|
146 |
+
self.args = args
|
147 |
+
self.layer_id = layer_id
|
148 |
+
self.ctx_len = args.ctx_len
|
149 |
+
self.n_embd = args.n_embd
|
150 |
+
|
151 |
+
with torch.no_grad(): # fancy init
|
152 |
+
ratio_0_to_1 = layer_id / (args.n_layer - 1) # 0 to 1
|
153 |
+
ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer) # 1 to ~0
|
154 |
+
ddd = torch.ones(1, 1, args.n_embd)
|
155 |
+
for i in range(args.n_embd):
|
156 |
+
ddd[0, 0, i] = i / args.n_embd
|
157 |
+
|
158 |
+
# fancy time_decay
|
159 |
+
decay_speed = torch.ones(args.dim_att)
|
160 |
+
for h in range(args.dim_att):
|
161 |
+
decay_speed[h] = -5 + 8 * (h / (args.dim_att - 1)) ** (0.7 + 1.3 * ratio_0_to_1)
|
162 |
+
self.time_decay = nn.Parameter(decay_speed)
|
163 |
+
# print(layer_id, self.time_decay.flatten()[:3].cpu().numpy(), '...', self.time_decay.flatten()[-3:].cpu().numpy())
|
164 |
+
|
165 |
+
# fancy time_first
|
166 |
+
zigzag = torch.tensor([(i + 1) % 3 - 1 for i in range(args.dim_att)]) * 0.5
|
167 |
+
self.time_first = nn.Parameter(torch.ones(args.dim_att) * math.log(0.3) + zigzag)
|
168 |
+
|
169 |
+
# fancy time_mix
|
170 |
+
self.time_mix_k = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
|
171 |
+
self.time_mix_v = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1)
|
172 |
+
self.time_mix_r = nn.Parameter(torch.pow(ddd, 0.5 * ratio_1_to_almost0))
|
173 |
+
|
174 |
+
self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
|
175 |
+
self.key = nn.Linear(args.n_embd, args.dim_att, bias=False)
|
176 |
+
self.value = nn.Linear(args.n_embd, args.dim_att, bias=False)
|
177 |
+
self.receptance = nn.Linear(args.n_embd, args.dim_att, bias=False)
|
178 |
+
self.output = nn.Linear(args.dim_att, args.n_embd, bias=False)
|
179 |
+
|
180 |
+
if 'a' in os.environ["RWKV_MY_TESTING"]:
|
181 |
+
self.register_buffer("att_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len)))
|
182 |
+
d_qkv = args.n_embd // 16
|
183 |
+
self.qq = nn.Linear(args.n_embd, d_qkv, bias=False)
|
184 |
+
self.kk = nn.Linear(args.n_embd, d_qkv, bias=False)
|
185 |
+
self.vv = nn.Linear(args.n_embd, d_qkv, bias=False)
|
186 |
+
self.oo = nn.Linear(d_qkv, args.n_embd, bias=False)
|
187 |
+
with torch.no_grad():
|
188 |
+
self.time_mix_qq = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
|
189 |
+
self.time_mix_kk = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
|
190 |
+
self.time_mix_vv = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1)
|
191 |
+
|
192 |
+
if 'a' not in os.environ["RWKV_MY_TESTING"]:
|
193 |
+
@MyFunction
|
194 |
+
def jit_func(self, x):
|
195 |
+
xx = self.time_shift(x) # Mix x with the previous timestep to produce xk, xv, xr
|
196 |
+
xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
|
197 |
+
xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
|
198 |
+
xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
|
199 |
+
k = self.key(xk)
|
200 |
+
v = self.value(xv)
|
201 |
+
r = self.receptance(xr)
|
202 |
+
sr = torch.sigmoid(r)
|
203 |
+
return sr, k, v
|
204 |
+
|
205 |
+
def forward(self, x):
|
206 |
+
B, T, C = x.size() # x = (Batch,Time,Channel)
|
207 |
+
sr, k, v = self.jit_func(x)
|
208 |
+
rwkv = sr * RUN_CUDA(B, T, self.args.dim_att, self.time_decay, self.time_first, k, v)
|
209 |
+
return self.output(rwkv)
|
210 |
+
|
211 |
+
if 'a' in os.environ["RWKV_MY_TESTING"]:
|
212 |
+
@MyFunction
|
213 |
+
def QKV(self, q, k, v):
|
214 |
+
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
|
215 |
+
att = att.masked_fill(self.att_mask == 0, float('-inf'))
|
216 |
+
att = F.softmax(att, dim = -1)
|
217 |
+
x = att @ v
|
218 |
+
return x
|
219 |
+
|
220 |
+
@MyFunction
|
221 |
+
def jit_funcQKV(self, x):
|
222 |
+
xx = self.time_shift(x) # Mix x with the previous timestep to produce xk, xv, xr
|
223 |
+
xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
|
224 |
+
xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
|
225 |
+
xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
|
226 |
+
xqq = x * self.time_mix_qq + xx * (1 - self.time_mix_qq)
|
227 |
+
xkk = x * self.time_mix_kk + xx * (1 - self.time_mix_kk)
|
228 |
+
xvv = x * self.time_mix_vv + xx * (1 - self.time_mix_vv)
|
229 |
+
k = self.key(xk)
|
230 |
+
v = self.value(xv)
|
231 |
+
r = self.receptance(xr)
|
232 |
+
sr = torch.sigmoid(r)
|
233 |
+
qq = self.qq(xqq)
|
234 |
+
kk = self.kk(xkk)
|
235 |
+
vv = self.vv(xvv)
|
236 |
+
return sr, k, v, qq, kk, vv
|
237 |
+
|
238 |
+
def forward(self, x):
|
239 |
+
B, T, C = x.size() # x = (Batch,Time,Channel)
|
240 |
+
sr, k, v, qq, kk, vv = self.jit_funcQKV(x)
|
241 |
+
rwkv = sr * RUN_CUDA(B, T, self.args.dim_att, self.time_decay, self.time_first, k, v)
|
242 |
+
rwkv = self.output(rwkv) + self.oo(self.QKV(qq, kk, vv))
|
243 |
+
return rwkv
|
244 |
+
|
245 |
+
########################################################################################################
|
246 |
+
|
247 |
+
class RWKV_ChannelMix(MyModule):
|
248 |
+
def __init__(self, args, layer_id):
|
249 |
+
super().__init__()
|
250 |
+
self.args = args
|
251 |
+
self.layer_id = layer_id
|
252 |
+
self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
|
253 |
+
|
254 |
+
with torch.no_grad(): # fancy init of time_mix
|
255 |
+
ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer) # 1 to ~0
|
256 |
+
ddd = torch.ones(1, 1, args.n_embd)
|
257 |
+
for i in range(args.n_embd):
|
258 |
+
ddd[0, 0, i] = i / args.n_embd
|
259 |
+
self.time_mix_k = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
|
260 |
+
self.time_mix_r = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
|
261 |
+
|
262 |
+
self.key = nn.Linear(args.n_embd, args.dim_ffn, bias=False)
|
263 |
+
self.receptance = nn.Linear(args.n_embd, args.n_embd, bias=False)
|
264 |
+
self.value = nn.Linear(args.dim_ffn, args.n_embd, bias=False)
|
265 |
+
|
266 |
+
@MyFunction
|
267 |
+
def forward(self, x):
|
268 |
+
xx = self.time_shift(x)
|
269 |
+
xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
|
270 |
+
xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
|
271 |
+
k = self.key(xk)
|
272 |
+
k = torch.square(torch.relu(k))
|
273 |
+
kv = self.value(k)
|
274 |
+
return torch.sigmoid(self.receptance(xr)) * kv
|
275 |
+
|
276 |
+
class MishGLU(MyModule):
|
277 |
+
def __init__(self, args, layer_id):
|
278 |
+
super().__init__()
|
279 |
+
self.args = args
|
280 |
+
self.layer_id = layer_id
|
281 |
+
self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
|
282 |
+
|
283 |
+
with torch.no_grad():
|
284 |
+
ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer)
|
285 |
+
|
286 |
+
x = torch.ones(1, 1, args.n_embd)
|
287 |
+
for i in range(args.n_embd):
|
288 |
+
x[0, 0, i] = i / args.n_embd
|
289 |
+
|
290 |
+
self.time_mix_k = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
|
291 |
+
self.time_mix_r = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
|
292 |
+
self.aa = nn.Linear(args.n_embd, args.dim_ffn, bias=False)
|
293 |
+
self.bb = nn.Linear(args.n_embd, args.dim_ffn, bias=False)
|
294 |
+
self.value = nn.Linear(args.dim_ffn, args.n_embd, bias=False)
|
295 |
+
|
296 |
+
@MyFunction
|
297 |
+
def forward(self, x):
|
298 |
+
xx = self.time_shift(x)
|
299 |
+
xa = x * self.time_mix_k + xx * (1 - self.time_mix_k)
|
300 |
+
xb = x * self.time_mix_r + xx * (1 - self.time_mix_r)
|
301 |
+
a = self.aa(xa)
|
302 |
+
b = self.bb(xb)
|
303 |
+
return self.value(a * F.mish(b))
|
304 |
+
|
305 |
+
########################################################################################################
|
306 |
+
# The RWKV Model with our blocks
|
307 |
+
########################################################################################################
|
308 |
+
|
309 |
+
|
310 |
+
class Block(nn.Module):
|
311 |
+
def __init__(self, args, layer_id):
|
312 |
+
super().__init__()
|
313 |
+
self.args = args
|
314 |
+
self.layer_id = layer_id
|
315 |
+
|
316 |
+
self.ln1 = nn.LayerNorm(args.n_embd)
|
317 |
+
self.ln2 = nn.LayerNorm(args.n_embd)
|
318 |
+
|
319 |
+
if self.layer_id == 0:
|
320 |
+
self.ln0 = nn.LayerNorm(args.n_embd)
|
321 |
+
if args.my_pos_emb > 0:
|
322 |
+
self.pos_emb_x = nn.Parameter(torch.zeros((1,args.my_pos_emb,args.n_embd)))
|
323 |
+
self.pos_emb_y = nn.Parameter(torch.zeros((args.my_pos_emb,1,args.n_embd)))
|
324 |
+
|
325 |
+
if self.layer_id == 0 and self.args.pre_ffn > 0:
|
326 |
+
self.ffnPre = RWKV_ChannelMix(args, 0)
|
327 |
+
else:
|
328 |
+
self.att = RWKV_TimeMix(args, layer_id)
|
329 |
+
|
330 |
+
if 'g' in os.environ["RWKV_MY_TESTING"]:
|
331 |
+
self.ffn = MishGLU(args, layer_id)
|
332 |
+
else:
|
333 |
+
self.ffn = RWKV_ChannelMix(args, layer_id)
|
334 |
+
|
335 |
+
if args.tiny_att_dim > 0 and self.layer_id == args.tiny_att_layer:
|
336 |
+
self.tiny_ln = nn.LayerNorm(args.n_embd)
|
337 |
+
self.tiny_q = nn.Linear(args.n_embd, args.tiny_att_dim, bias=False)
|
338 |
+
self.tiny_k = nn.Linear(args.n_embd, args.tiny_att_dim, bias=False)
|
339 |
+
self.tiny_v = nn.Linear(args.n_embd, args.n_embd, bias=False)
|
340 |
+
self.register_buffer("tiny_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len)))
|
341 |
+
|
342 |
+
def forward(self, x, x_emb=None):
|
343 |
+
args = self.args
|
344 |
+
B, T, C = x.size()
|
345 |
+
if self.layer_id == 0:
|
346 |
+
x = self.ln0(x)
|
347 |
+
if args.my_pos_emb > 0:
|
348 |
+
pos_emb = (self.pos_emb_x + self.pos_emb_y).reshape(T+1, -1)[:-1,:]
|
349 |
+
x = x + pos_emb
|
350 |
+
|
351 |
+
if self.layer_id == 0 and args.pre_ffn > 0:
|
352 |
+
x = x + self.ffnPre(self.ln1(x))
|
353 |
+
else:
|
354 |
+
x = x + self.att(self.ln1(x))
|
355 |
+
x = x + self.ffn(self.ln2(x))
|
356 |
+
|
357 |
+
if args.tiny_att_dim > 0 and self.layer_id == args.tiny_att_layer:
|
358 |
+
xx = self.tiny_ln(x)
|
359 |
+
q = self.tiny_q(xx)[:, :T, :]
|
360 |
+
k = self.tiny_k(xx)[:, :T, :]
|
361 |
+
c = (q @ k.transpose(-2, -1)) * (args.tiny_att_dim ** (-0.5))
|
362 |
+
c = c.masked_fill(self.tiny_mask[:T, :T] == 0, 0)
|
363 |
+
x = x + c @ self.tiny_v(x_emb)
|
364 |
+
return x
|
365 |
+
|
366 |
+
|
367 |
+
class L2Wrap(torch.autograd.Function):
|
368 |
+
@staticmethod
|
369 |
+
def forward(ctx, loss, y):
|
370 |
+
ctx.save_for_backward(y)
|
371 |
+
return loss
|
372 |
+
|
373 |
+
@staticmethod
|
374 |
+
def backward(ctx, grad_output):
|
375 |
+
y = ctx.saved_tensors[0]
|
376 |
+
# to encourage the logits to be close to 0
|
377 |
+
factor = 1e-4 / (y.shape[0] * y.shape[1])
|
378 |
+
maxx, ids = torch.max(y, -1, keepdim=True)
|
379 |
+
gy = torch.zeros_like(y)
|
380 |
+
gy.scatter_(-1, ids, maxx * factor)
|
381 |
+
return (grad_output, gy)
|
382 |
+
|
383 |
+
|
384 |
+
class RWKV(pl.LightningModule):
|
385 |
+
def __init__(self, args):
|
386 |
+
super().__init__()
|
387 |
+
self.args = args
|
388 |
+
if not hasattr(args, 'dim_att'):
|
389 |
+
args.dim_att = args.n_embd
|
390 |
+
if not hasattr(args, 'dim_ffn'):
|
391 |
+
args.dim_ffn = args.n_embd * 4
|
392 |
+
if not hasattr(args, 'tiny_att_layer'):
|
393 |
+
args.tiny_att_layer = -1
|
394 |
+
if not hasattr(args, 'tiny_att_dim'):
|
395 |
+
args.tiny_att_dim = -1
|
396 |
+
|
397 |
+
self.emb = nn.Embedding(args.vocab_size, args.n_embd)
|
398 |
+
|
399 |
+
self.blocks = nn.ModuleList([Block(args, i) for i in range(args.n_layer)])
|
400 |
+
|
401 |
+
self.ln_out = nn.LayerNorm(args.n_embd)
|
402 |
+
self.head = nn.Linear(args.n_embd, args.vocab_size, bias=False)
|
403 |
+
|
404 |
+
if args.head_qk > 0:
|
405 |
+
self.head_q = nn.Linear(args.n_embd, args.head_qk, bias=False)
|
406 |
+
self.head_k = nn.Linear(args.n_embd, args.head_qk, bias=False)
|
407 |
+
self.register_buffer("copy_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len)))
|
408 |
+
|
409 |
+
def configure_optimizers(self):
|
410 |
+
args = self.args
|
411 |
+
if args.layerwise_lr > 0:
|
412 |
+
lr_1x = set()
|
413 |
+
lr_2x = set()
|
414 |
+
lr_3x = set()
|
415 |
+
for n, p in self.named_parameters():
|
416 |
+
if "time_mix" in n:
|
417 |
+
if args.my_pile_stage == 2:
|
418 |
+
lr_2x.add(n)
|
419 |
+
else:
|
420 |
+
lr_1x.add(n)
|
421 |
+
elif "time_decay" in n:
|
422 |
+
if args.my_pile_stage == 2:
|
423 |
+
lr_3x.add(n)
|
424 |
+
else:
|
425 |
+
lr_2x.add(n)
|
426 |
+
elif "time_first" in n:
|
427 |
+
lr_3x.add(n)
|
428 |
+
else:
|
429 |
+
lr_1x.add(n)
|
430 |
+
lr_1x = sorted(list(lr_1x))
|
431 |
+
lr_2x = sorted(list(lr_2x))
|
432 |
+
lr_3x = sorted(list(lr_3x))
|
433 |
+
# print('1x', lr_1x)
|
434 |
+
# print('2x', lr_2x)
|
435 |
+
# print('3x', lr_3x)
|
436 |
+
param_dict = {n: p for n, p in self.named_parameters()}
|
437 |
+
if args.my_pile_stage == 2:
|
438 |
+
optim_groups = [
|
439 |
+
{"params": [param_dict[n] for n in lr_1x], "weight_decay": 0.0, "my_lr_scale": 1.0},
|
440 |
+
{"params": [param_dict[n] for n in lr_2x], "weight_decay": 0.0, "my_lr_scale": 5.0},# test: 2e-3 / args.lr_init},
|
441 |
+
{"params": [param_dict[n] for n in lr_3x], "weight_decay": 0.0, "my_lr_scale": 5.0},# test: 3e-3 / args.lr_init},
|
442 |
+
]
|
443 |
+
else:
|
444 |
+
optim_groups = [
|
445 |
+
{"params": [param_dict[n] for n in lr_1x], "weight_decay": 0.0, "my_lr_scale": 1.0},
|
446 |
+
{"params": [param_dict[n] for n in lr_2x], "weight_decay": 0.0, "my_lr_scale": 2.0},
|
447 |
+
{"params": [param_dict[n] for n in lr_3x], "weight_decay": 0.0, "my_lr_scale": 3.0},
|
448 |
+
]
|
449 |
+
else:
|
450 |
+
optim_groups = [
|
451 |
+
{"params": [p for n, p in self.named_parameters()], "weight_decay": 0.0},
|
452 |
+
]
|
453 |
+
|
454 |
+
if self.deepspeed_offload:
|
455 |
+
return DeepSpeedCPUAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adamw_mode=False, weight_decay=0, amsgrad=False)
|
456 |
+
return FusedAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adam_w_mode=False, weight_decay=0, amsgrad=False)
|
457 |
+
# return ZeroOneAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, weight_decay=0, amsgrad=False, cuda_aware=False)
|
458 |
+
|
459 |
+
@property
|
460 |
+
def deepspeed_offload(self) -> bool:
|
461 |
+
strategy = self.trainer.strategy
|
462 |
+
if isinstance(strategy, DeepSpeedStrategy):
|
463 |
+
cfg = strategy.config["zero_optimization"]
|
464 |
+
return cfg.get("offload_optimizer") or cfg.get("offload_param")
|
465 |
+
return False
|
466 |
+
|
467 |
+
def forward(self, idx):
|
468 |
+
args = self.args
|
469 |
+
B, T = idx.size()
|
470 |
+
assert T <= args.ctx_len, "Cannot forward, model ctx_len is exhausted."
|
471 |
+
|
472 |
+
x = self.emb(idx)
|
473 |
+
x_emb = x
|
474 |
+
|
475 |
+
if args.tiny_att_dim > 0:
|
476 |
+
for block in self.blocks:
|
477 |
+
if args.grad_cp == 1:
|
478 |
+
x = deepspeed.checkpointing.checkpoint(block, x, x_emb)
|
479 |
+
else:
|
480 |
+
x = block(x, x_emb)
|
481 |
+
else:
|
482 |
+
for block in self.blocks:
|
483 |
+
if args.grad_cp == 1:
|
484 |
+
x = deepspeed.checkpointing.checkpoint(block, x)
|
485 |
+
else:
|
486 |
+
x = block(x)
|
487 |
+
|
488 |
+
x = self.ln_out(x)
|
489 |
+
|
490 |
+
if args.head_qk > 0:
|
491 |
+
q = self.head_q(x)[:, :T, :]
|
492 |
+
k = self.head_k(x)[:, :T, :]
|
493 |
+
c = (q @ k.transpose(-2, -1)) * (1.0 / args.head_qk)
|
494 |
+
c = c.masked_fill(self.copy_mask[:T, :T] == 0, 0)
|
495 |
+
|
496 |
+
if "32" in os.environ["RWKV_FLOAT_MODE"]:
|
497 |
+
c = c @ F.one_hot(idx, num_classes=args.vocab_size)
|
498 |
+
elif os.environ["RWKV_FLOAT_MODE"] == "fp16":
|
499 |
+
c = c @ F.one_hot(idx, num_classes=args.vocab_size).half()
|
500 |
+
elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
|
501 |
+
c = c @ F.one_hot(idx, num_classes=args.vocab_size).bfloat16()
|
502 |
+
|
503 |
+
x = self.head(x) + c
|
504 |
+
else:
|
505 |
+
x = self.head(x)
|
506 |
+
|
507 |
+
return x
|
508 |
+
|
509 |
+
def training_step(self, batch, batch_idx):
|
510 |
+
args = self.args
|
511 |
+
if args.my_qa_mask != 1:
|
512 |
+
idx, targets = batch
|
513 |
+
logits = self(idx)
|
514 |
+
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
|
515 |
+
else:
|
516 |
+
idx, targets, mask = batch
|
517 |
+
mask = mask.view(-1)
|
518 |
+
sum_mask = torch.sum(mask).item()
|
519 |
+
# if sum_mask == 0:
|
520 |
+
# return torch.tensor([0.0], requires_grad=True)
|
521 |
+
|
522 |
+
logits = self(idx)
|
523 |
+
if sum_mask == mask.shape[0]:
|
524 |
+
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
|
525 |
+
# print('rank', self.global_rank, 'loss', loss.item())
|
526 |
+
else:
|
527 |
+
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), reduction='none')
|
528 |
+
# loss_raw = loss
|
529 |
+
loss = torch.sum(loss * mask) / sum_mask
|
530 |
+
|
531 |
+
# torch.set_printoptions(threshold=10000)
|
532 |
+
# if True: #self.global_rank == 1:
|
533 |
+
# tmp = ''
|
534 |
+
# sss = 0
|
535 |
+
# ccc = 0
|
536 |
+
# for i in range(mask.shape[0]):
|
537 |
+
# if mask[i] > 0:
|
538 |
+
# tmp += str(idx.view(-1)[i].item()) + ','
|
539 |
+
# sss += loss_raw.view(-1)[i].float().item()
|
540 |
+
# ccc += 1
|
541 |
+
# print('rank', self.global_rank, 'loss', loss.item(), 'lavg', sss / ccc)#, 'tmp', tmp, 'input', idx)
|
542 |
+
|
543 |
+
return L2Wrap.apply(loss, logits)
|
544 |
+
|
545 |
+
def training_step_end(self, batch_parts):
|
546 |
+
all = self.all_gather(batch_parts)
|
547 |
+
if self.trainer.is_global_zero:
|
548 |
+
self.trainer.my_loss_all = all
|
549 |
+
|
550 |
+
def generate_init_weight(self):
|
551 |
+
print(
|
552 |
+
f"""
|
553 |
+
############################################################################
|
554 |
+
#
|
555 |
+
# Init model weight (slow for large models)...
|
556 |
+
#
|
557 |
+
############################################################################
|
558 |
+
"""
|
559 |
+
)
|
560 |
+
m = {}
|
561 |
+
for n in self.state_dict():
|
562 |
+
p = self.state_dict()[n]
|
563 |
+
shape = p.shape
|
564 |
+
|
565 |
+
gain = 1.0
|
566 |
+
scale = 1.0
|
567 |
+
if "ln_" in n or ".ln" in n or "time_" in n or "_mask" in n or "pos_emb" in n or '.mask.' in n:
|
568 |
+
m[n] = p
|
569 |
+
else:
|
570 |
+
if n == "emb.weight":
|
571 |
+
scale = -1 * self.args.lr_init
|
572 |
+
else:
|
573 |
+
if shape[0] > shape[1]:
|
574 |
+
gain = math.sqrt(shape[0] / shape[1])
|
575 |
+
for kk in [".att.key.", ".att.receptance.", ".att.output.", ".att.key.", ".ffn.value.", ".ffn.receptance.", ".ffnPre.value.", ".ffnPre.receptance.", "head_q.", '.oo.', '.rr.']:
|
576 |
+
if kk in n:
|
577 |
+
scale = 0
|
578 |
+
if n == "head.weight":
|
579 |
+
scale = 0.5
|
580 |
+
if "head_k." in n:
|
581 |
+
scale = 0.1
|
582 |
+
if "head_q." in n:
|
583 |
+
scale = 0
|
584 |
+
|
585 |
+
print(f"{str(shape[0]).ljust(5)} {str(shape[1]).ljust(5)} {str(scale).ljust(4)} {n}")
|
586 |
+
|
587 |
+
if self.args.accelerator.upper() == "GPU":
|
588 |
+
m[n] = torch.empty((shape[0], shape[1]), device="cuda")
|
589 |
+
else:
|
590 |
+
m[n] = torch.empty((shape[0], shape[1]))
|
591 |
+
|
592 |
+
if scale == 0:
|
593 |
+
nn.init.zeros_(m[n])
|
594 |
+
elif scale < 0:
|
595 |
+
nn.init.uniform_(m[n], a=scale, b=-scale)
|
596 |
+
else:
|
597 |
+
nn.init.orthogonal_(m[n], gain=gain * scale)
|
598 |
+
|
599 |
+
m[n] = m[n].cpu()
|
600 |
+
if os.environ["RWKV_FLOAT_MODE"] == "fp16":
|
601 |
+
m[n] = m[n].half()
|
602 |
+
elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
|
603 |
+
m[n] = m[n].bfloat16()
|
604 |
+
|
605 |
+
# if n == "emb.weight":
|
606 |
+
# print(m[n])
|
607 |
+
|
608 |
+
gc.collect()
|
609 |
+
torch.cuda.empty_cache()
|
610 |
+
return m
|
src/model_img.py
ADDED
@@ -0,0 +1,446 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
########################################################################################################
|
2 |
+
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
|
3 |
+
########################################################################################################
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import os, math, gc
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
import torch.nn.functional as F
|
10 |
+
import torchvision as vision
|
11 |
+
import pytorch_lightning as pl
|
12 |
+
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
|
13 |
+
from pytorch_lightning.strategies import DeepSpeedStrategy
|
14 |
+
import deepspeed
|
15 |
+
from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam
|
16 |
+
# from pytorch_msssim import MS_SSIM
|
17 |
+
|
18 |
+
def __nop(ob):
|
19 |
+
return ob
|
20 |
+
MyModule = torch.jit.ScriptModule
|
21 |
+
# MyFunction = __nop
|
22 |
+
MyFunction = torch.jit.script_method
|
23 |
+
|
24 |
+
import clip
|
25 |
+
from transformers import CLIPModel
|
26 |
+
|
27 |
+
class L2pooling(nn.Module):
|
28 |
+
def __init__(self, filter_size=5, stride=2, channels=None, pad_off=0):
|
29 |
+
super(L2pooling, self).__init__()
|
30 |
+
self.padding = (filter_size - 2) // 2
|
31 |
+
self.stride = stride
|
32 |
+
self.channels = channels
|
33 |
+
a = np.hanning(filter_size)[1:-1]
|
34 |
+
g = torch.Tensor(a[:, None] * a[None, :])
|
35 |
+
g = g / torch.sum(g)
|
36 |
+
self.register_buffer(
|
37 |
+
"filter", g[None, None, :, :].repeat((self.channels, 1, 1, 1))
|
38 |
+
)
|
39 |
+
|
40 |
+
def forward(self, input):
|
41 |
+
input = input**2
|
42 |
+
out = F.conv2d(
|
43 |
+
input,
|
44 |
+
self.filter,
|
45 |
+
stride=self.stride,
|
46 |
+
padding=self.padding,
|
47 |
+
groups=input.shape[1],
|
48 |
+
)
|
49 |
+
return (out + 1e-12).sqrt()
|
50 |
+
|
51 |
+
|
52 |
+
class DISTS(torch.nn.Module):
|
53 |
+
def __init__(self, load_weights=True):
|
54 |
+
super(DISTS, self).__init__()
|
55 |
+
vgg_pretrained_features = vision.models.vgg16(
|
56 |
+
weights="VGG16_Weights.IMAGENET1K_V1"
|
57 |
+
).features
|
58 |
+
self.stage1 = torch.nn.Sequential()
|
59 |
+
self.stage2 = torch.nn.Sequential()
|
60 |
+
self.stage3 = torch.nn.Sequential()
|
61 |
+
self.stage4 = torch.nn.Sequential()
|
62 |
+
self.stage5 = torch.nn.Sequential()
|
63 |
+
for x in range(0, 4):
|
64 |
+
self.stage1.add_module(str(x), vgg_pretrained_features[x])
|
65 |
+
self.stage2.add_module(str(4), L2pooling(channels=64))
|
66 |
+
for x in range(5, 9):
|
67 |
+
self.stage2.add_module(str(x), vgg_pretrained_features[x])
|
68 |
+
self.stage3.add_module(str(9), L2pooling(channels=128))
|
69 |
+
for x in range(10, 16):
|
70 |
+
self.stage3.add_module(str(x), vgg_pretrained_features[x])
|
71 |
+
self.stage4.add_module(str(16), L2pooling(channels=256))
|
72 |
+
for x in range(17, 23):
|
73 |
+
self.stage4.add_module(str(x), vgg_pretrained_features[x])
|
74 |
+
self.stage5.add_module(str(23), L2pooling(channels=512))
|
75 |
+
for x in range(24, 30):
|
76 |
+
self.stage5.add_module(str(x), vgg_pretrained_features[x])
|
77 |
+
|
78 |
+
self.register_buffer(
|
79 |
+
"mean", torch.tensor([0.485, 0.456, 0.406]).view(1, -1, 1, 1)
|
80 |
+
)
|
81 |
+
self.register_buffer(
|
82 |
+
"std", torch.tensor([0.229, 0.224, 0.225]).view(1, -1, 1, 1)
|
83 |
+
)
|
84 |
+
|
85 |
+
self.chns = [3, 64, 128, 256, 512, 512]
|
86 |
+
self.register_buffer(
|
87 |
+
"alpha", nn.Parameter(torch.randn(1, sum(self.chns), 1, 1))
|
88 |
+
)
|
89 |
+
self.register_buffer("beta", nn.Parameter(torch.randn(1, sum(self.chns), 1, 1)))
|
90 |
+
self.alpha.data.normal_(0.1, 0.01)
|
91 |
+
self.beta.data.normal_(0.1, 0.01)
|
92 |
+
weights = torch.load("test/DISTS_weights.pt")
|
93 |
+
self.alpha.data = weights["alpha"]
|
94 |
+
self.beta.data = weights["beta"]
|
95 |
+
|
96 |
+
for param in self.parameters():
|
97 |
+
param.requires_grad = False
|
98 |
+
|
99 |
+
def forward_once(self, x):
|
100 |
+
h = (x - self.mean) / self.std
|
101 |
+
h = self.stage1(h)
|
102 |
+
h_relu1_2 = h
|
103 |
+
h = self.stage2(h)
|
104 |
+
h_relu2_2 = h
|
105 |
+
h = self.stage3(h)
|
106 |
+
h_relu3_3 = h
|
107 |
+
h = self.stage4(h)
|
108 |
+
h_relu4_3 = h
|
109 |
+
h = self.stage5(h)
|
110 |
+
h_relu5_3 = h
|
111 |
+
return [x, h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3]
|
112 |
+
|
113 |
+
def forward(self, x, y, require_grad=False, batch_average=False):
|
114 |
+
if require_grad:
|
115 |
+
feats0 = self.forward_once(x)
|
116 |
+
feats1 = self.forward_once(y)
|
117 |
+
else:
|
118 |
+
with torch.no_grad():
|
119 |
+
feats0 = self.forward_once(x)
|
120 |
+
feats1 = self.forward_once(y)
|
121 |
+
dist1 = 0
|
122 |
+
dist2 = 0
|
123 |
+
c1 = 1e-6
|
124 |
+
c2 = 1e-6
|
125 |
+
w_sum = self.alpha.sum() + self.beta.sum()
|
126 |
+
alpha = torch.split(self.alpha / w_sum, self.chns, dim=1)
|
127 |
+
beta = torch.split(self.beta / w_sum, self.chns, dim=1)
|
128 |
+
|
129 |
+
for k in range(len(self.chns)):
|
130 |
+
x_mean = feats0[k].mean([2, 3], keepdim=True)
|
131 |
+
y_mean = feats1[k].mean([2, 3], keepdim=True)
|
132 |
+
S1 = (2 * x_mean * y_mean + c1) / (x_mean**2 + y_mean**2 + c1)
|
133 |
+
dist1 = dist1 + (alpha[k] * S1).sum(1, keepdim=True)
|
134 |
+
|
135 |
+
x_var = ((feats0[k] - x_mean) ** 2).mean([2, 3], keepdim=True)
|
136 |
+
y_var = ((feats1[k] - y_mean) ** 2).mean([2, 3], keepdim=True)
|
137 |
+
xy_cov = (feats0[k] * feats1[k]).mean(
|
138 |
+
[2, 3], keepdim=True
|
139 |
+
) - x_mean * y_mean
|
140 |
+
S2 = (2 * xy_cov + c2) / (x_var + y_var + c2)
|
141 |
+
dist2 = dist2 + (beta[k] * S2).sum(1, keepdim=True)
|
142 |
+
|
143 |
+
score = 1 - (dist1 + dist2).squeeze()
|
144 |
+
|
145 |
+
if batch_average:
|
146 |
+
return score.mean()
|
147 |
+
else:
|
148 |
+
return score
|
149 |
+
|
150 |
+
class ToBinary(torch.autograd.Function):
|
151 |
+
@staticmethod
|
152 |
+
def forward(ctx, x):#, noise_scale):
|
153 |
+
# if noise_scale > 0:
|
154 |
+
# noise_min = 0.5 - noise_scale / 2
|
155 |
+
# noise_max = 0.5 + noise_scale / 2
|
156 |
+
# return torch.floor(x + torch.empty_like(x).uniform_(noise_min, noise_max))
|
157 |
+
# else:
|
158 |
+
return torch.floor(x + 0.5) # no need for noise when we have plenty of data
|
159 |
+
|
160 |
+
@staticmethod
|
161 |
+
def backward(ctx, grad_output):
|
162 |
+
return grad_output.clone()#, None
|
163 |
+
|
164 |
+
########################################################################################################
|
165 |
+
|
166 |
+
class R_ENCODER(MyModule):
|
167 |
+
def __init__(self, args):
|
168 |
+
super().__init__()
|
169 |
+
self.args = args
|
170 |
+
dd = 8
|
171 |
+
self.Bxx = nn.BatchNorm2d(dd*64)
|
172 |
+
|
173 |
+
self.CIN = nn.Conv2d(3, dd, kernel_size=3, padding=1)
|
174 |
+
self.Cx0 = nn.Conv2d(dd, 32, kernel_size=3, padding=1)
|
175 |
+
self.Cx1 = nn.Conv2d(32, dd, kernel_size=3, padding=1)
|
176 |
+
|
177 |
+
self.B00 = nn.BatchNorm2d(dd*4)
|
178 |
+
self.C00 = nn.Conv2d(dd*4, 256, kernel_size=3, padding=1)
|
179 |
+
self.C01 = nn.Conv2d(256, dd*4, kernel_size=3, padding=1)
|
180 |
+
self.C02 = nn.Conv2d(dd*4, 256, kernel_size=3, padding=1)
|
181 |
+
self.C03 = nn.Conv2d(256, dd*4, kernel_size=3, padding=1)
|
182 |
+
|
183 |
+
self.B10 = nn.BatchNorm2d(dd*16)
|
184 |
+
self.C10 = nn.Conv2d(dd*16, 256, kernel_size=3, padding=1)
|
185 |
+
self.C11 = nn.Conv2d(256, dd*16, kernel_size=3, padding=1)
|
186 |
+
self.C12 = nn.Conv2d(dd*16, 256, kernel_size=3, padding=1)
|
187 |
+
self.C13 = nn.Conv2d(256, dd*16, kernel_size=3, padding=1)
|
188 |
+
|
189 |
+
self.B20 = nn.BatchNorm2d(dd*64)
|
190 |
+
self.C20 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1)
|
191 |
+
self.C21 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1)
|
192 |
+
self.C22 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1)
|
193 |
+
self.C23 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1)
|
194 |
+
# self.B21 = nn.BatchNorm2d(dd*64)
|
195 |
+
# self.C24 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1)
|
196 |
+
# self.C25 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1)
|
197 |
+
# self.C26 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1)
|
198 |
+
# self.C27 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1)
|
199 |
+
|
200 |
+
self.COUT = nn.Conv2d(dd*64, args.my_img_bit, kernel_size=3, padding=1)
|
201 |
+
|
202 |
+
@MyFunction
|
203 |
+
def forward(self, img):
|
204 |
+
ACT = F.mish
|
205 |
+
|
206 |
+
x = self.CIN(img)
|
207 |
+
xx = self.Bxx(F.pixel_unshuffle(x, 8))
|
208 |
+
x = x + self.Cx1(ACT(self.Cx0(x)))
|
209 |
+
|
210 |
+
x = F.pixel_unshuffle(x, 2)
|
211 |
+
x = x + self.C01(ACT(self.C00(ACT(self.B00(x)))))
|
212 |
+
x = x + self.C03(ACT(self.C02(x)))
|
213 |
+
|
214 |
+
x = F.pixel_unshuffle(x, 2)
|
215 |
+
x = x + self.C11(ACT(self.C10(ACT(self.B10(x)))))
|
216 |
+
x = x + self.C13(ACT(self.C12(x)))
|
217 |
+
|
218 |
+
x = F.pixel_unshuffle(x, 2)
|
219 |
+
x = x + self.C21(ACT(self.C20(ACT(self.B20(x)))))
|
220 |
+
x = x + self.C23(ACT(self.C22(x)))
|
221 |
+
# x = x + self.C25(ACT(self.C24(ACT(self.B21(x)))))
|
222 |
+
# x = x + self.C27(ACT(self.C26(x)))
|
223 |
+
|
224 |
+
x = self.COUT(x + xx)
|
225 |
+
return torch.sigmoid(x)
|
226 |
+
|
227 |
+
########################################################################################################
|
228 |
+
|
229 |
+
class R_DECODER(MyModule):
|
230 |
+
def __init__(self, args):
|
231 |
+
super().__init__()
|
232 |
+
self.args = args
|
233 |
+
dd = 8
|
234 |
+
self.CIN = nn.Conv2d(args.my_img_bit, dd*64, kernel_size=3, padding=1)
|
235 |
+
|
236 |
+
self.B00 = nn.BatchNorm2d(dd*64)
|
237 |
+
self.C00 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1)
|
238 |
+
self.C01 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1)
|
239 |
+
self.C02 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1)
|
240 |
+
self.C03 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1)
|
241 |
+
# self.B01 = nn.BatchNorm2d(dd*64)
|
242 |
+
# self.C04 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1)
|
243 |
+
# self.C05 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1)
|
244 |
+
# self.C06 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1)
|
245 |
+
# self.C07 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1)
|
246 |
+
|
247 |
+
self.B10 = nn.BatchNorm2d(dd*16)
|
248 |
+
self.C10 = nn.Conv2d(dd*16, 256, kernel_size=3, padding=1)
|
249 |
+
self.C11 = nn.Conv2d(256, dd*16, kernel_size=3, padding=1)
|
250 |
+
self.C12 = nn.Conv2d(dd*16, 256, kernel_size=3, padding=1)
|
251 |
+
self.C13 = nn.Conv2d(256, dd*16, kernel_size=3, padding=1)
|
252 |
+
|
253 |
+
self.B20 = nn.BatchNorm2d(dd*4)
|
254 |
+
self.C20 = nn.Conv2d(dd*4, 256, kernel_size=3, padding=1)
|
255 |
+
self.C21 = nn.Conv2d(256, dd*4, kernel_size=3, padding=1)
|
256 |
+
self.C22 = nn.Conv2d(dd*4, 256, kernel_size=3, padding=1)
|
257 |
+
self.C23 = nn.Conv2d(256, dd*4, kernel_size=3, padding=1)
|
258 |
+
|
259 |
+
self.Cx0 = nn.Conv2d(dd, 32, kernel_size=3, padding=1)
|
260 |
+
self.Cx1 = nn.Conv2d(32, dd, kernel_size=3, padding=1)
|
261 |
+
self.COUT = nn.Conv2d(dd, 3, kernel_size=3, padding=1)
|
262 |
+
|
263 |
+
@MyFunction
|
264 |
+
def forward(self, code):
|
265 |
+
ACT = F.mish
|
266 |
+
x = self.CIN(code)
|
267 |
+
|
268 |
+
x = x + self.C01(ACT(self.C00(ACT(self.B00(x)))))
|
269 |
+
x = x + self.C03(ACT(self.C02(x)))
|
270 |
+
# x = x + self.C05(ACT(self.C04(ACT(self.B01(x)))))
|
271 |
+
# x = x + self.C07(ACT(self.C06(x)))
|
272 |
+
x = F.pixel_shuffle(x, 2)
|
273 |
+
|
274 |
+
x = x + self.C11(ACT(self.C10(ACT(self.B10(x)))))
|
275 |
+
x = x + self.C13(ACT(self.C12(x)))
|
276 |
+
x = F.pixel_shuffle(x, 2)
|
277 |
+
|
278 |
+
x = x + self.C21(ACT(self.C20(ACT(self.B20(x)))))
|
279 |
+
x = x + self.C23(ACT(self.C22(x)))
|
280 |
+
x = F.pixel_shuffle(x, 2)
|
281 |
+
|
282 |
+
x = x + self.Cx1(ACT(self.Cx0(x)))
|
283 |
+
x = self.COUT(x)
|
284 |
+
|
285 |
+
return torch.sigmoid(x)
|
286 |
+
|
287 |
+
########################################################################################################`
|
288 |
+
|
289 |
+
def cosine_loss(x, y):
|
290 |
+
x = F.normalize(x, dim=-1)
|
291 |
+
y = F.normalize(y, dim=-1)
|
292 |
+
return 1 - torch.einsum('ij,ij->i',[x,y])
|
293 |
+
|
294 |
+
class RWKV_IMG(pl.LightningModule):
|
295 |
+
def __init__(self, args):
|
296 |
+
super().__init__()
|
297 |
+
self.args = args
|
298 |
+
|
299 |
+
self.encoder = R_ENCODER(args)
|
300 |
+
self.decoder = R_DECODER(args)
|
301 |
+
|
302 |
+
self.clip_model = None
|
303 |
+
clip_name = args.my_img_clip
|
304 |
+
if clip_name == 'B32':
|
305 |
+
clip_name = 'ViT-B/32'
|
306 |
+
elif clip_name == 'B16':
|
307 |
+
clip_name = 'ViT-B/16'
|
308 |
+
elif clip_name == 'L14':
|
309 |
+
clip_name = 'ViT-L/14'
|
310 |
+
elif clip_name == 'OB32':
|
311 |
+
clip_name = "laion/CLIP-ViT-B-32-laion2B-s34B-b79K"
|
312 |
+
self.clip_model = CLIPModel.from_pretrained(clip_name)
|
313 |
+
self.clip_model.encode_image = self.clip_model.get_image_features
|
314 |
+
if self.clip_model == None:
|
315 |
+
self.clip_model, _ = clip.load(clip_name, jit = True)
|
316 |
+
self.register_buffer(
|
317 |
+
"clip_mean", torch.tensor([0.48145466, 0.4578275, 0.40821073]).view(1, 3, 1, 1)
|
318 |
+
)
|
319 |
+
self.register_buffer(
|
320 |
+
"clip_std", torch.tensor([0.26862954, 0.26130258, 0.27577711]).view(1, 3, 1, 1)
|
321 |
+
)
|
322 |
+
|
323 |
+
for n, p in self.named_parameters():
|
324 |
+
if 'clip_model' in n:
|
325 |
+
p.requires_grad = False
|
326 |
+
|
327 |
+
self.loss_dists = DISTS()
|
328 |
+
# self.loss_ssim = MS_SSIM(data_range=1, size_average=True, channel=3)
|
329 |
+
|
330 |
+
def configure_optimizers(self):
|
331 |
+
args = self.args
|
332 |
+
optim_groups = [
|
333 |
+
{"params": [p for n, p in self.named_parameters()], "weight_decay": 0.0},
|
334 |
+
]
|
335 |
+
if self.deepspeed_offload:
|
336 |
+
return DeepSpeedCPUAdam(
|
337 |
+
optim_groups,
|
338 |
+
lr=self.args.lr_init,
|
339 |
+
betas=self.args.betas,
|
340 |
+
eps=self.args.adam_eps,
|
341 |
+
bias_correction=True,
|
342 |
+
adamw_mode=False,
|
343 |
+
weight_decay=0,
|
344 |
+
amsgrad=False,
|
345 |
+
)
|
346 |
+
return FusedAdam(
|
347 |
+
optim_groups,
|
348 |
+
lr=self.args.lr_init,
|
349 |
+
betas=self.args.betas,
|
350 |
+
eps=self.args.adam_eps,
|
351 |
+
bias_correction=True,
|
352 |
+
adam_w_mode=False,
|
353 |
+
weight_decay=0,
|
354 |
+
amsgrad=False,
|
355 |
+
)
|
356 |
+
# return ZeroOneAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, weight_decay=0, amsgrad=False, cuda_aware=False)
|
357 |
+
|
358 |
+
@property
|
359 |
+
def deepspeed_offload(self) -> bool:
|
360 |
+
strategy = self.trainer.strategy
|
361 |
+
if isinstance(strategy, DeepSpeedStrategy):
|
362 |
+
config = strategy.config["zero_optimization"]
|
363 |
+
return config.get("offload_optimizer") or config.get("offload_param")
|
364 |
+
return False
|
365 |
+
|
366 |
+
def forward(self, img):
|
367 |
+
z = self.encoder(img)
|
368 |
+
z = ToBinary.apply(z)#, self.args.my_img_noise_scale)
|
369 |
+
out = self.decoder(z)
|
370 |
+
return out
|
371 |
+
|
372 |
+
def training_step(self, batch, batch_idx):
|
373 |
+
args = self.args
|
374 |
+
img, txt = batch
|
375 |
+
out = self(img)
|
376 |
+
if self.trainer.is_global_zero:
|
377 |
+
if (self.trainer.global_step + 1) % (100 * int(args.devices)) == 0:
|
378 |
+
img_dir = f"test/image_model/{args.run_name}"
|
379 |
+
if not os.path.exists(img_dir):
|
380 |
+
os.makedirs(img_dir)
|
381 |
+
vision.utils.save_image(
|
382 |
+
img[:4], f"{img_dir}/{self.trainer.global_step}-src.jpg"#, padding=0
|
383 |
+
)
|
384 |
+
vision.utils.save_image(
|
385 |
+
out[:4], f"{img_dir}/{self.trainer.global_step}-out.jpg"#, padding=0
|
386 |
+
)
|
387 |
+
|
388 |
+
# loss_ssim = 1 - self.loss_ssim(out, img)
|
389 |
+
loss_dists = self.loss_dists(out, img, require_grad=True, batch_average=True)
|
390 |
+
|
391 |
+
iii = self.clip_model.encode_image((img - self.clip_mean) / self.clip_std)
|
392 |
+
ooo = self.clip_model.encode_image((out - self.clip_mean) / self.clip_std)
|
393 |
+
loss_clip = torch.mean(cosine_loss(iii, ooo))
|
394 |
+
|
395 |
+
if args.my_img_l1_scale > 0:
|
396 |
+
loss_l1 = F.l1_loss(out, img)
|
397 |
+
return loss_dists + loss_clip * args.my_img_clip_scale + loss_l1 * args.my_img_l1_scale
|
398 |
+
else:
|
399 |
+
return loss_dists + loss_clip * args.my_img_clip_scale
|
400 |
+
|
401 |
+
def training_step_end(self, batch_parts):
|
402 |
+
all = self.all_gather(batch_parts)
|
403 |
+
if self.trainer.is_global_zero:
|
404 |
+
self.trainer.my_loss_all = all
|
405 |
+
|
406 |
+
def generate_init_weight(self):
|
407 |
+
print(
|
408 |
+
f"""
|
409 |
+
############################################################################
|
410 |
+
#
|
411 |
+
# Init model weight (slow for large models)...
|
412 |
+
#
|
413 |
+
############################################################################
|
414 |
+
"""
|
415 |
+
)
|
416 |
+
m = {}
|
417 |
+
for n in self.state_dict():
|
418 |
+
scale = 1
|
419 |
+
p = self.state_dict()[n]
|
420 |
+
shape = p.shape
|
421 |
+
ss = n.split('.')
|
422 |
+
|
423 |
+
# if ss[0] in ['encoder', 'decoder']:
|
424 |
+
# if ss[2] == 'bias':
|
425 |
+
# scale = 0
|
426 |
+
# # elif n == 'encoder.CIN.weight':
|
427 |
+
# # nn.init.dirac_(p)
|
428 |
+
# else:
|
429 |
+
# try:
|
430 |
+
# if ss[1][0] == 'C' and (int(ss[1][2]) % 2 == 1):
|
431 |
+
# scale = 0
|
432 |
+
# except:
|
433 |
+
# pass
|
434 |
+
# m[n] = p * scale
|
435 |
+
|
436 |
+
m[n] = p
|
437 |
+
|
438 |
+
m[n] = m[n].cpu()
|
439 |
+
if os.environ["RWKV_FLOAT_MODE"] == "fp16":
|
440 |
+
m[n] = m[n].half()
|
441 |
+
elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
|
442 |
+
m[n] = m[n].bfloat16()
|
443 |
+
|
444 |
+
gc.collect()
|
445 |
+
torch.cuda.empty_cache()
|
446 |
+
return m
|
src/model_run.py
ADDED
@@ -0,0 +1,233 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
########################################################################################################
|
2 |
+
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
|
3 |
+
########################################################################################################
|
4 |
+
|
5 |
+
import types
|
6 |
+
import torch
|
7 |
+
import math, os, gc
|
8 |
+
from torch.nn import functional as F
|
9 |
+
import torch.nn as nn
|
10 |
+
from typing import List, Dict
|
11 |
+
|
12 |
+
MyModule = nn.Module
|
13 |
+
def __nop(ob):
|
14 |
+
return ob
|
15 |
+
MyFunction = __nop
|
16 |
+
|
17 |
+
# # try torchdynamo
|
18 |
+
# import torchdynamo
|
19 |
+
# MyFunction = torchdynamo.optimize(os.environ["RWKV_RUN_BACKEND"]) # !!!BUGGY!!! wrong output
|
20 |
+
|
21 |
+
# try torch jit --> faster for fp32, slower for fp16 (why?)
|
22 |
+
if os.environ["RWKV_JIT_ON"] == "1":
|
23 |
+
MyModule = torch.jit.ScriptModule
|
24 |
+
MyFunction = torch.jit.script_method
|
25 |
+
|
26 |
+
RWKV_HEAD_QK_DIM = 0
|
27 |
+
|
28 |
+
DEBUG_TIME = False # True False - show trained time-coeffs
|
29 |
+
|
30 |
+
RWKV_RESCALE_LAYER = 6 # set x=x/2 every X layer
|
31 |
+
|
32 |
+
############################################################################################################
|
33 |
+
|
34 |
+
class RWKV_RNN(MyModule):
|
35 |
+
def __init__(self, args):
|
36 |
+
super().__init__()
|
37 |
+
|
38 |
+
self.args = args
|
39 |
+
self.FLOAT_MODE = args.FLOAT_MODE
|
40 |
+
self.RUN_DEVICE = args.RUN_DEVICE
|
41 |
+
|
42 |
+
with torch.no_grad():
|
43 |
+
w = torch.load(args.MODEL_NAME + '.pth', map_location='cpu')
|
44 |
+
# refine weights and send to correct device
|
45 |
+
keys = list(w.keys())
|
46 |
+
if 'pos_emb_x' in keys:
|
47 |
+
w['pos_emb'] = (w['pos_emb_x'] + w['pos_emb_y']).reshape(args.ctx_len+1, -1)[:-1,:]
|
48 |
+
keys = list(w.keys())
|
49 |
+
print_need_newline = False
|
50 |
+
for x in keys:
|
51 |
+
block_id = 0
|
52 |
+
if 'blocks.' in x:
|
53 |
+
block_id = int(x.split('.')[1])
|
54 |
+
if 'att.output.weight' in x:
|
55 |
+
w[x] = w[x] / (2 ** int(block_id // RWKV_RESCALE_LAYER))
|
56 |
+
if 'ffn.value.weight' in x:
|
57 |
+
w[x] = w[x] / (2 ** int(block_id // RWKV_RESCALE_LAYER))
|
58 |
+
|
59 |
+
if '.time_' in x:
|
60 |
+
w[x] = w[x].squeeze()
|
61 |
+
if DEBUG_TIME:
|
62 |
+
print(x, w[x].numpy())
|
63 |
+
if '.time_decay' in x:
|
64 |
+
w[x] = w[x].float()
|
65 |
+
w[x] = -torch.exp(w[x])
|
66 |
+
elif '.time_first' in x:
|
67 |
+
w[x] = w[x].float()
|
68 |
+
else:
|
69 |
+
if self.FLOAT_MODE == "fp32":
|
70 |
+
w[x] = w[x].float()
|
71 |
+
elif self.FLOAT_MODE == "bf16":
|
72 |
+
w[x] = w[x].bfloat16()
|
73 |
+
elif self.FLOAT_MODE == "fp16":
|
74 |
+
w[x] = w[x].half()
|
75 |
+
|
76 |
+
w[x].requires_grad = False
|
77 |
+
if args.RUN_DEVICE == 'cuda' and x != 'emb.weight':
|
78 |
+
w[x] = w[x].cuda()
|
79 |
+
|
80 |
+
if ('blocks.' not in x) or ('blocks.0.' in x):
|
81 |
+
if print_need_newline:
|
82 |
+
print_need_newline = False
|
83 |
+
else:
|
84 |
+
print_need_newline = True
|
85 |
+
|
86 |
+
# store weights in self.w
|
87 |
+
keys = list(w.keys())
|
88 |
+
self.w = types.SimpleNamespace()
|
89 |
+
for x in keys:
|
90 |
+
xx = x.split('.')
|
91 |
+
here = self.w
|
92 |
+
for i in range(len(xx)):
|
93 |
+
if xx[i].isdigit():
|
94 |
+
ii = int(xx[i])
|
95 |
+
if ii not in here:
|
96 |
+
here[ii] = types.SimpleNamespace()
|
97 |
+
here = here[ii]
|
98 |
+
else:
|
99 |
+
if i == len(xx) - 1:
|
100 |
+
setattr(here, xx[i], w[x])
|
101 |
+
elif not hasattr(here, xx[i]):
|
102 |
+
if xx[i+1].isdigit():
|
103 |
+
setattr(here, xx[i], {})
|
104 |
+
else:
|
105 |
+
setattr(here, xx[i], types.SimpleNamespace())
|
106 |
+
here = getattr(here, xx[i])
|
107 |
+
|
108 |
+
self.eval()
|
109 |
+
gc.collect()
|
110 |
+
torch.cuda.empty_cache()
|
111 |
+
|
112 |
+
def LN(self, x, w):
|
113 |
+
return F.layer_norm(x, (self.args.n_embd,), weight=w.weight, bias=w.bias)
|
114 |
+
|
115 |
+
# state[] 0=ffn_xx 1=att_xx 2=att_aa 3=att_bb 4=att_pp
|
116 |
+
|
117 |
+
@MyFunction
|
118 |
+
def FF(self, x, state, i:int, time_mix_k, time_mix_r, kw, vw, rw):
|
119 |
+
if self.FLOAT_MODE == "bf16":
|
120 |
+
xk = x * time_mix_k + state[5*i+0].type(torch.bfloat16) * (1 - time_mix_k)
|
121 |
+
xr = x * time_mix_r + state[5*i+0].type(torch.bfloat16) * (1 - time_mix_r)
|
122 |
+
state[5*i+0] = x.float()
|
123 |
+
elif self.FLOAT_MODE == "fp16":
|
124 |
+
xk = x * time_mix_k + state[5*i+0].half() * (1 - time_mix_k)
|
125 |
+
xr = x * time_mix_r + state[5*i+0].half() * (1 - time_mix_r)
|
126 |
+
state[5*i+0] = x.float()
|
127 |
+
else:
|
128 |
+
xk = x * time_mix_k + state[5*i+0] * (1 - time_mix_k)
|
129 |
+
xr = x * time_mix_r + state[5*i+0] * (1 - time_mix_r)
|
130 |
+
state[5*i+0] = x
|
131 |
+
|
132 |
+
r = torch.sigmoid(rw @ xr)
|
133 |
+
k = torch.square(torch.relu(kw @ xk))
|
134 |
+
kv = vw @ k
|
135 |
+
|
136 |
+
return r * kv
|
137 |
+
|
138 |
+
@MyFunction
|
139 |
+
def SA(self, x, state, i:int, time_mix_k, time_mix_v, time_mix_r, time_first, time_decay, kw, vw, rw, ow):
|
140 |
+
if self.FLOAT_MODE == "bf16":
|
141 |
+
xk = x * time_mix_k + state[5*i+1].type(torch.bfloat16) * (1 - time_mix_k)
|
142 |
+
xv = x * time_mix_v + state[5*i+1].type(torch.bfloat16) * (1 - time_mix_v)
|
143 |
+
xr = x * time_mix_r + state[5*i+1].type(torch.bfloat16) * (1 - time_mix_r)
|
144 |
+
state[5*i+1] = x.float()
|
145 |
+
elif self.FLOAT_MODE == "fp16":
|
146 |
+
xk = x * time_mix_k + state[5*i+1].half() * (1 - time_mix_k)
|
147 |
+
xv = x * time_mix_v + state[5*i+1].half() * (1 - time_mix_v)
|
148 |
+
xr = x * time_mix_r + state[5*i+1].half() * (1 - time_mix_r)
|
149 |
+
state[5*i+1] = x.float()
|
150 |
+
else:
|
151 |
+
xk = x * time_mix_k + state[5*i+1] * (1 - time_mix_k)
|
152 |
+
xv = x * time_mix_v + state[5*i+1] * (1 - time_mix_v)
|
153 |
+
xr = x * time_mix_r + state[5*i+1] * (1 - time_mix_r)
|
154 |
+
state[5*i+1] = x
|
155 |
+
|
156 |
+
r = torch.sigmoid(rw @ xr)
|
157 |
+
k = kw @ xk
|
158 |
+
v = vw @ xv
|
159 |
+
|
160 |
+
if '16' in self.FLOAT_MODE:
|
161 |
+
kk = k.float()
|
162 |
+
vv = v.float()
|
163 |
+
else:
|
164 |
+
kk = k
|
165 |
+
vv = v
|
166 |
+
aa = state[5*i+2]
|
167 |
+
bb = state[5*i+3]
|
168 |
+
pp = state[5*i+4]
|
169 |
+
ww = time_first + kk
|
170 |
+
p = torch.maximum(pp, ww)
|
171 |
+
e1 = torch.exp(pp - p)
|
172 |
+
e2 = torch.exp(ww - p)
|
173 |
+
a = e1 * aa + e2 * vv
|
174 |
+
b = e1 * bb + e2
|
175 |
+
ww = pp + time_decay
|
176 |
+
p = torch.maximum(ww, kk)
|
177 |
+
e1 = torch.exp(ww - p)
|
178 |
+
e2 = torch.exp(kk - p)
|
179 |
+
state[5*i+2] = e1 * aa + e2 * vv
|
180 |
+
state[5*i+3] = e1 * bb + e2
|
181 |
+
state[5*i+4] = p
|
182 |
+
if self.FLOAT_MODE == "bf16":
|
183 |
+
wkv = (a / b).type(torch.bfloat16)
|
184 |
+
elif self.FLOAT_MODE == "fp16":
|
185 |
+
wkv = (a / b).half()
|
186 |
+
else:
|
187 |
+
wkv = a / b
|
188 |
+
|
189 |
+
return ow @ (r * wkv)
|
190 |
+
|
191 |
+
def forward(self, ctx, state, preprocess_only = False):
|
192 |
+
with torch.no_grad():
|
193 |
+
w = self.w
|
194 |
+
args = self.args
|
195 |
+
|
196 |
+
x = w.emb.weight[ctx[-1]]
|
197 |
+
if self.RUN_DEVICE == 'cuda':
|
198 |
+
x = x.cuda()
|
199 |
+
try:
|
200 |
+
pos_emb = w.pos_emb[len(ctx)-1]
|
201 |
+
x = x + pos_emb
|
202 |
+
except:
|
203 |
+
pass
|
204 |
+
|
205 |
+
if state == None:
|
206 |
+
state = torch.zeros(args.n_layer * 5, args.n_embd, device=self.RUN_DEVICE)
|
207 |
+
for i in range(args.n_layer):
|
208 |
+
state[5*i+4] -= 1e30
|
209 |
+
|
210 |
+
for i in range(args.n_layer):
|
211 |
+
if i == 0:
|
212 |
+
x = self.LN(x, w.blocks[i].ln0)
|
213 |
+
|
214 |
+
ww = w.blocks[i].att
|
215 |
+
x = x + self.SA(self.LN(x, w.blocks[i].ln1), state, i,
|
216 |
+
ww.time_mix_k, ww.time_mix_v, ww.time_mix_r, ww.time_first, ww.time_decay,
|
217 |
+
ww.key.weight, ww.value.weight, ww.receptance.weight, ww.output.weight)
|
218 |
+
|
219 |
+
ww = w.blocks[i].ffn
|
220 |
+
x = x + self.FF(self.LN(x, w.blocks[i].ln2), state, i,
|
221 |
+
ww.time_mix_k, ww.time_mix_r,
|
222 |
+
ww.key.weight, ww.value.weight, ww.receptance.weight)
|
223 |
+
|
224 |
+
if (i+1) % RWKV_RESCALE_LAYER == 0:
|
225 |
+
x = x / 2
|
226 |
+
|
227 |
+
if preprocess_only:
|
228 |
+
return state
|
229 |
+
|
230 |
+
x = self.LN(x, w.ln_out)
|
231 |
+
x = w.head.weight @ x
|
232 |
+
|
233 |
+
return x.float(), state
|
src/trainer.py
ADDED
@@ -0,0 +1,192 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os, math, time, datetime, subprocess
|
2 |
+
import torch
|
3 |
+
from torch.utils.data import DataLoader
|
4 |
+
import pytorch_lightning as pl
|
5 |
+
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
|
6 |
+
|
7 |
+
def my_save(dd, ff):
|
8 |
+
if '14b-run1' not in ff:
|
9 |
+
torch.save(dd, ff)
|
10 |
+
else:
|
11 |
+
fn = ff.split('/')[-1]
|
12 |
+
fff = '/dev/shm/' + fn
|
13 |
+
torch.save(dd, fff)
|
14 |
+
subprocess.Popen(f" aws s3 mv {fff} s3://rwkv-14b-4k/{fn} --quiet", shell=True)
|
15 |
+
|
16 |
+
class train_callback(pl.Callback):
|
17 |
+
def __init__(self, args):
|
18 |
+
super().__init__()
|
19 |
+
self.args = args
|
20 |
+
|
21 |
+
def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
|
22 |
+
args = self.args
|
23 |
+
# if args.cuda_cleanup > 0:
|
24 |
+
# torch.cuda.empty_cache()
|
25 |
+
real_step = trainer.global_step + args.epoch_begin * args.epoch_steps
|
26 |
+
|
27 |
+
# LR schedule
|
28 |
+
w_step = args.warmup_steps
|
29 |
+
if args.lr_final == args.lr_init or args.epoch_count == 0:
|
30 |
+
lr = args.lr_init
|
31 |
+
if trainer.global_step < w_step:
|
32 |
+
lr = lr * (0.2 + 0.8 * trainer.global_step / w_step)
|
33 |
+
else:
|
34 |
+
decay_step = real_step - args.my_pile_edecay * args.epoch_steps
|
35 |
+
decay_total = (args.epoch_count - args.my_pile_edecay) * args.epoch_steps
|
36 |
+
progress = (decay_step - w_step + 1) / (decay_total - w_step)
|
37 |
+
progress = min(1, max(0, progress))
|
38 |
+
|
39 |
+
if args.lr_final == 0 or args.lr_init == 0: # linear decay
|
40 |
+
lr = args.lr_init + (args.lr_final - args.lr_init) * progress
|
41 |
+
else: # exp decay
|
42 |
+
lr = args.lr_init * math.exp(math.log(args.lr_final / args.lr_init) * pow(progress, 1))
|
43 |
+
|
44 |
+
if trainer.global_step < w_step:
|
45 |
+
lr = lr * (0.2 + 0.8 * trainer.global_step / w_step)
|
46 |
+
# if trainer.is_global_zero:
|
47 |
+
# print(trainer.global_step, decay_step, decay_total, w_step, progress, lr)
|
48 |
+
|
49 |
+
for param_group in trainer.optimizers[0].param_groups:
|
50 |
+
if args.layerwise_lr > 0:
|
51 |
+
param_group["lr"] = lr * param_group["my_lr_scale"]
|
52 |
+
# print(param_group["lr"], param_group["my_lr_scale"])
|
53 |
+
else:
|
54 |
+
param_group["lr"] = lr
|
55 |
+
|
56 |
+
trainer.my_lr = lr
|
57 |
+
# rank_zero_info(f"{real_step} {lr}")
|
58 |
+
|
59 |
+
if trainer.global_step == 0:
|
60 |
+
if trainer.is_global_zero: # logging
|
61 |
+
trainer.my_loss_sum = 0
|
62 |
+
trainer.my_loss_count = 0
|
63 |
+
trainer.my_log = open(args.proj_dir + "/train_log.txt", "a")
|
64 |
+
trainer.my_log.write(f"NEW RUN {args.my_timestamp}\n{vars(self.args)}\n")
|
65 |
+
try:
|
66 |
+
print(f"\n{trainer.strategy.config}\n")
|
67 |
+
trainer.my_log.write(f"{trainer.strategy.config}\n")
|
68 |
+
except:
|
69 |
+
pass
|
70 |
+
trainer.my_log.flush()
|
71 |
+
if len(args.wandb) > 0:
|
72 |
+
print("Login to wandb...")
|
73 |
+
import wandb
|
74 |
+
wandb.init(
|
75 |
+
project=args.wandb,
|
76 |
+
name=args.run_name + " " + args.my_timestamp,
|
77 |
+
config=args,
|
78 |
+
save_code=False,
|
79 |
+
)
|
80 |
+
trainer.my_wandb = wandb
|
81 |
+
|
82 |
+
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
|
83 |
+
args = self.args
|
84 |
+
if trainer.is_global_zero: # logging
|
85 |
+
t_now = time.time_ns()
|
86 |
+
token_per_step = args.ctx_len * args.real_bsz
|
87 |
+
real_step = trainer.global_step + args.epoch_begin * args.epoch_steps
|
88 |
+
kt_s = 0
|
89 |
+
try:
|
90 |
+
t_cost = (t_now - trainer.my_time_ns) / 1e9
|
91 |
+
kt_s = token_per_step / t_cost / 1000
|
92 |
+
self.log("REAL it/s", 1.0 / t_cost, prog_bar=True, on_step=True)
|
93 |
+
self.log("Kt/s", kt_s, prog_bar=True, on_step=True)
|
94 |
+
except:
|
95 |
+
pass
|
96 |
+
trainer.my_time_ns = t_now
|
97 |
+
trainer.my_loss = trainer.my_loss_all.float().mean().item()
|
98 |
+
trainer.my_loss_sum += trainer.my_loss
|
99 |
+
trainer.my_loss_count += 1
|
100 |
+
trainer.my_epoch_loss = trainer.my_loss_sum / trainer.my_loss_count
|
101 |
+
self.log("lr", trainer.my_lr, prog_bar=True, on_step=True)
|
102 |
+
self.log("loss", trainer.my_epoch_loss, prog_bar=True, on_step=True)
|
103 |
+
# self.log("s", real_step, prog_bar=True, on_step=True)
|
104 |
+
|
105 |
+
if len(args.wandb) > 0:
|
106 |
+
lll = {"loss": trainer.my_loss, "lr": trainer.my_lr, "Gtokens": real_step * token_per_step / 1e9}
|
107 |
+
if kt_s > 0:
|
108 |
+
lll["kt/s"] = kt_s
|
109 |
+
trainer.my_wandb.log(lll, step=int(real_step))
|
110 |
+
if args.magic_prime > 0:
|
111 |
+
expand_factor = 2 if args.my_qa_mask > 0 else 1
|
112 |
+
if int(real_step) == int(args.magic_prime * expand_factor // args.real_bsz) - 1 + int(args.my_random_steps):
|
113 |
+
to_save_dict = pl_module.state_dict()
|
114 |
+
my_save(
|
115 |
+
to_save_dict,
|
116 |
+
f"{args.proj_dir}/rwkv-final.pth",
|
117 |
+
)
|
118 |
+
|
119 |
+
|
120 |
+
def on_train_epoch_start(self, trainer, pl_module):
|
121 |
+
args = self.args
|
122 |
+
dataset = trainer.train_dataloader.dataset.datasets
|
123 |
+
assert "MyDataset" in str(dataset)
|
124 |
+
dataset.global_rank = trainer.global_rank
|
125 |
+
dataset.real_epoch = int(args.epoch_begin + trainer.current_epoch)
|
126 |
+
dataset.world_size = trainer.world_size
|
127 |
+
# print(f'########## world_size {dataset.world_size} global_rank {dataset.global_rank} real_epoch {dataset.real_epoch} ##########')
|
128 |
+
|
129 |
+
def on_train_epoch_end(self, trainer, pl_module):
|
130 |
+
args = self.args
|
131 |
+
if trainer.is_global_zero: # logging & save state_dict
|
132 |
+
if (args.epoch_save > 0 and trainer.current_epoch % args.epoch_save == 0) or trainer.current_epoch == args.epoch_count - 1:
|
133 |
+
if args.data_type == 'wds_img':
|
134 |
+
raw_dict = pl_module.state_dict()
|
135 |
+
to_save_dict = {}
|
136 |
+
for k in raw_dict:
|
137 |
+
if k.startswith('encoder.') or k.startswith('decoder.'):
|
138 |
+
to_save_dict[k] = raw_dict[k]
|
139 |
+
else:
|
140 |
+
to_save_dict = pl_module.state_dict()
|
141 |
+
try:
|
142 |
+
my_save(
|
143 |
+
to_save_dict,
|
144 |
+
f"{args.proj_dir}/rwkv-{args.epoch_begin + trainer.current_epoch}.pth",
|
145 |
+
)
|
146 |
+
except Exception as e:
|
147 |
+
print('Error\n\n', e, '\n\n')
|
148 |
+
trainer.my_log.write(f"{args.epoch_begin + trainer.current_epoch} {trainer.my_epoch_loss:.6f} {math.exp(trainer.my_epoch_loss):.4f} {trainer.my_lr:.8f} {datetime.datetime.now()} {trainer.current_epoch}\n")
|
149 |
+
trainer.my_log.flush()
|
150 |
+
|
151 |
+
trainer.my_loss_sum = 0
|
152 |
+
trainer.my_loss_count = 0
|
153 |
+
|
154 |
+
|
155 |
+
@rank_zero_only
|
156 |
+
def generate_init_weight(model, init_weight_name):
|
157 |
+
mm = model.generate_init_weight()
|
158 |
+
|
159 |
+
if model.args.my_pile_stage == 1:
|
160 |
+
if len(model.args.load_model) > 0:
|
161 |
+
print(f"Combine weights from {model.args.load_model}...")
|
162 |
+
load_dict = torch.load(model.args.load_model, map_location="cpu")
|
163 |
+
for k in load_dict:
|
164 |
+
assert k in mm
|
165 |
+
src = load_dict[k]
|
166 |
+
try:
|
167 |
+
mm[k] = src.reshape(mm[k].shape)
|
168 |
+
except:
|
169 |
+
tmp = mm[k].squeeze().clone()
|
170 |
+
print(k, src.shape, '-->', mm[k].shape)
|
171 |
+
ss = src.shape[0]
|
172 |
+
dd = tmp.shape[0]
|
173 |
+
for i in range(dd):
|
174 |
+
pos = i / dd * ss
|
175 |
+
if pos >= ss - 1:
|
176 |
+
tmp[i] = src[ss-1]
|
177 |
+
else:
|
178 |
+
p0 = int(math.floor(pos))
|
179 |
+
ii = pos - p0
|
180 |
+
tmp[i] = src[p0] * (1-ii) + src[p0+1] * (ii)
|
181 |
+
mm[k] = tmp.reshape(mm[k].shape)
|
182 |
+
sss = src.squeeze().float().cpu().numpy()
|
183 |
+
print(sss[:10], '...', sss[-10:])
|
184 |
+
mmm = mm[k].squeeze().float().cpu().numpy()
|
185 |
+
print(mmm[:10], '...', mmm[-10:])
|
186 |
+
|
187 |
+
print(f"Save to {init_weight_name}...")
|
188 |
+
torch.save(mm, init_weight_name)
|
189 |
+
|
190 |
+
if model.args.my_pile_stage == 1:
|
191 |
+
print("Done. Now go for stage 2.")
|
192 |
+
exit(0)
|
src/utils.py
ADDED
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json, time, random, os
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
from torch.nn import functional as F
|
5 |
+
|
6 |
+
time_slot = {}
|
7 |
+
time_ref = time.time_ns()
|
8 |
+
|
9 |
+
def record_time(name):
|
10 |
+
if name not in time_slot:
|
11 |
+
time_slot[name] = 1e20
|
12 |
+
tt = (time.time_ns() - time_ref) / 1e9
|
13 |
+
if tt < time_slot[name]:
|
14 |
+
time_slot[name] = tt
|
15 |
+
|
16 |
+
class TOKENIZER():
|
17 |
+
def __init__(self, WORD_NAME, UNKNOWN_CHAR='\ue083'):
|
18 |
+
if 'list' in str(type(WORD_NAME)):
|
19 |
+
self.charMode = False
|
20 |
+
if WORD_NAME[0] == WORD_NAME[1]:
|
21 |
+
from transformers import PreTrainedTokenizerFast
|
22 |
+
self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=WORD_NAME[0])
|
23 |
+
else:
|
24 |
+
from transformers import GPT2TokenizerFast
|
25 |
+
self.tokenizer = GPT2TokenizerFast(WORD_NAME[0], WORD_NAME[1])
|
26 |
+
self.vocab_size = len(self.tokenizer)
|
27 |
+
else:
|
28 |
+
self.charMode = True
|
29 |
+
with open(WORD_NAME + '.json', "r", encoding="utf-16") as result_file:
|
30 |
+
self.word_table = json.load(result_file)
|
31 |
+
|
32 |
+
self.vocab_size = len(self.word_table)
|
33 |
+
|
34 |
+
self.stoi = {v: int(k) for k, v in self.word_table.items()}
|
35 |
+
self.itos = {int(k): v for k, v in self.word_table.items()}
|
36 |
+
|
37 |
+
self.UNKNOWN_CHAR = self.stoi[UNKNOWN_CHAR]
|
38 |
+
|
39 |
+
def refine_context(self, context):
|
40 |
+
context = context.strip().split('\n')
|
41 |
+
for c in range(len(context)):
|
42 |
+
context[c] = context[c].strip().strip('\u3000').strip('\r')
|
43 |
+
context = list(filter(lambda c: c != '', context))
|
44 |
+
context = '\n' + ('\n'.join(context)).strip()
|
45 |
+
if context == '':
|
46 |
+
context = '\n'
|
47 |
+
return context
|
48 |
+
|
49 |
+
def sample_logits(self, out, x, ctx_len, temperature=1.0, top_p_usual=None, top_p_newline=None):
|
50 |
+
# out[self.UNKNOWN_CHAR] = -float('Inf')
|
51 |
+
lastChar = int(x[-1])
|
52 |
+
|
53 |
+
probs = F.softmax(out, dim=-1)
|
54 |
+
|
55 |
+
if self.charMode:
|
56 |
+
if self.itos[lastChar] == '\n':
|
57 |
+
top_p = top_p_newline
|
58 |
+
else:
|
59 |
+
top_p = top_p_usual
|
60 |
+
else:
|
61 |
+
top_p = top_p_usual
|
62 |
+
|
63 |
+
if os.environ["RWKV_RUN_DEVICE"] == "cpu":
|
64 |
+
probs = probs.numpy()
|
65 |
+
sorted_probs = np.sort(probs)[::-1]
|
66 |
+
cumulative_probs = np.cumsum(sorted_probs)
|
67 |
+
cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])
|
68 |
+
probs[probs < cutoff] = 0
|
69 |
+
if temperature != 1.0:
|
70 |
+
probs = probs.pow(1.0 / temperature)
|
71 |
+
probs = probs / np.sum(probs)
|
72 |
+
out = np.random.choice(a=len(probs), p=probs)
|
73 |
+
return out
|
74 |
+
else:
|
75 |
+
sorted_probs = torch.sort(probs, descending=True)[0]
|
76 |
+
cumulative_probs = torch.cumsum(sorted_probs, dim=-1).cpu().numpy()
|
77 |
+
cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])
|
78 |
+
probs[probs < cutoff] = 0
|
79 |
+
if temperature != 1.0:
|
80 |
+
probs = probs.pow(1.0 / temperature)
|
81 |
+
out = torch.multinomial(probs, num_samples=1)[0]
|
82 |
+
return out
|
83 |
+
|
84 |
+
def MaybeIsPrime(number):
|
85 |
+
if FermatPrimalityTest(number) and MillerRabinPrimalityTest(number):
|
86 |
+
return True
|
87 |
+
else:
|
88 |
+
return False
|
89 |
+
|
90 |
+
|
91 |
+
def FermatPrimalityTest(number):
|
92 |
+
if number > 1:
|
93 |
+
for time in range(3):
|
94 |
+
randomNumber = random.randint(2, number) - 1
|
95 |
+
if pow(randomNumber, number - 1, number) != 1:
|
96 |
+
return False
|
97 |
+
return True
|
98 |
+
else:
|
99 |
+
return False
|
100 |
+
|
101 |
+
|
102 |
+
def MillerRabinPrimalityTest(number):
|
103 |
+
if number == 2:
|
104 |
+
return True
|
105 |
+
elif number == 1 or number % 2 == 0:
|
106 |
+
return False
|
107 |
+
oddPartOfNumber = number - 1
|
108 |
+
timesTwoDividNumber = 0
|
109 |
+
while oddPartOfNumber % 2 == 0:
|
110 |
+
oddPartOfNumber = oddPartOfNumber // 2
|
111 |
+
timesTwoDividNumber = timesTwoDividNumber + 1
|
112 |
+
|
113 |
+
for time in range(3):
|
114 |
+
while True:
|
115 |
+
randomNumber = random.randint(2, number) - 1
|
116 |
+
if randomNumber != 0 and randomNumber != 1:
|
117 |
+
break
|
118 |
+
|
119 |
+
randomNumberWithPower = pow(randomNumber, oddPartOfNumber, number)
|
120 |
+
|
121 |
+
if (randomNumberWithPower != 1) and (randomNumberWithPower != number - 1):
|
122 |
+
iterationNumber = 1
|
123 |
+
|
124 |
+
while (iterationNumber <= timesTwoDividNumber - 1) and (randomNumberWithPower != number - 1):
|
125 |
+
randomNumberWithPower = pow(randomNumberWithPower, 2, number)
|
126 |
+
iterationNumber = iterationNumber + 1
|
127 |
+
if randomNumberWithPower != (number - 1):
|
128 |
+
return False
|
129 |
+
|
130 |
+
return True
|
train.py
ADDED
@@ -0,0 +1,350 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
########################################################################################################
|
2 |
+
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
|
3 |
+
########################################################################################################
|
4 |
+
|
5 |
+
if __name__ == "__main__":
|
6 |
+
from argparse import ArgumentParser
|
7 |
+
from pytorch_lightning import Trainer
|
8 |
+
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
|
9 |
+
|
10 |
+
rank_zero_info("########## work in progress ##########")
|
11 |
+
|
12 |
+
########################################################################################################
|
13 |
+
#
|
14 |
+
# example: train a simple L12-D768 RWKV on dummy data
|
15 |
+
#
|
16 |
+
# python train.py --load_model "" --wandb "" --proj_dir "out" \
|
17 |
+
# --data_file "" --data_type "dummy" --vocab_size 0 \
|
18 |
+
# --ctx_len 128 --epoch_steps 1000 --epoch_count 20 --epoch_begin 0 --epoch_save 10 \
|
19 |
+
# --micro_bsz 16 --n_layer 12 --n_embd 768 --pre_ffn 0 --head_qk 0 \
|
20 |
+
# --lr_init 6e-4 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.99 --adam_eps 1e-8 \
|
21 |
+
# --accelerator gpu --devices 1 --precision bf16 --strategy ddp_find_unused_parameters_false --grad_cp 0
|
22 |
+
|
23 |
+
# example: train a simple L6-D512 RWKV from scratch on enwik8
|
24 |
+
#
|
25 |
+
# python train.py --load_model "" --wandb "" --proj_dir "out" \
|
26 |
+
# --data_file "../data/enwik8" --data_type "utf-8" --vocab_size 0 \
|
27 |
+
# --ctx_len 512 --epoch_steps 5000 --epoch_count 500 --epoch_begin 0 --epoch_save 5 \
|
28 |
+
# --micro_bsz 12 --n_layer 6 --n_embd 512 --pre_ffn 0 --head_qk 0 \
|
29 |
+
# --lr_init 8e-4 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.99 --adam_eps 1e-8 \
|
30 |
+
# --accelerator gpu --devices 1 --precision bf16 --strategy ddp_find_unused_parameters_false --grad_cp 0
|
31 |
+
|
32 |
+
# example: fine-tune RWKV 1.5B using 8xA100 40G = 1.76it/s = 115k token/s, VRAM 37477M
|
33 |
+
#
|
34 |
+
# python train.py --load_model "/fsx/BlinkDL/CODE/FP16/out_1b2/all-8040.pth" --wandb "" --proj_dir "out" \
|
35 |
+
# --data_file "../data/train.npy" --data_type "numpy" --vocab_size 50277 \
|
36 |
+
# --ctx_len 1024 --epoch_steps 1000 --epoch_count 1000 --epoch_begin 0 --epoch_save 5 \
|
37 |
+
# --micro_bsz 8 --n_layer 24 --n_embd 2048 --pre_ffn 0 --head_qk 0 \
|
38 |
+
# --lr_init 1e-5 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.999 --adam_eps 1e-8 \
|
39 |
+
# --accelerator gpu --devices 8 --precision bf16 --strategy deepspeed_stage_2 --grad_cp 0
|
40 |
+
|
41 |
+
# example: fine-tune RWKV 1.5B using 1 GPU fp16 (VRAM 16G) NOTE: fp16 might overflow
|
42 |
+
#
|
43 |
+
# python train.py --load_model "/fsx/BlinkDL/CODE/FP16/out_1b2/all-8040.pth" --wandb "" --proj_dir "out" \
|
44 |
+
# --data_file "../data/train.npy" --data_type "numpy" --vocab_size 50277 \
|
45 |
+
# --ctx_len 1024 --epoch_steps 200 --epoch_count 1000 --epoch_begin 0 --epoch_save 1 \
|
46 |
+
# --micro_bsz 11 --n_layer 24 --n_embd 2048 --pre_ffn 0 --head_qk 0 \
|
47 |
+
# --lr_init 1e-5 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.999 --adam_eps 1e-8 \
|
48 |
+
# --accelerator gpu --devices 1 --precision fp16 --strategy deepspeed_stage_2_offload --grad_cp 1
|
49 |
+
|
50 |
+
parser = ArgumentParser()
|
51 |
+
|
52 |
+
parser.add_argument("--load_model", default="", type=str) # full path, with .pth
|
53 |
+
parser.add_argument("--wandb", default="", type=str) # wandb project name. if "" then don't use wandb
|
54 |
+
parser.add_argument("--proj_dir", default="out", type=str)
|
55 |
+
parser.add_argument("--random_seed", default="-1", type=int)
|
56 |
+
|
57 |
+
parser.add_argument("--data_file", default="", type=str)
|
58 |
+
parser.add_argument("--data_type", default="utf-8", type=str)
|
59 |
+
parser.add_argument("--vocab_size", default=0, type=int) # vocab_size = 0 means auto (for char-level LM and .txt data)
|
60 |
+
|
61 |
+
parser.add_argument("--ctx_len", default=1024, type=int)
|
62 |
+
parser.add_argument("--epoch_steps", default=1000, type=int) # a mini "epoch" has [epoch_steps] steps
|
63 |
+
parser.add_argument("--epoch_count", default=500, type=int) # train for this many "epochs". will continue afterwards with lr = lr_final
|
64 |
+
parser.add_argument("--epoch_begin", default=0, type=int) # if you load a model trained for x "epochs", set epoch_begin = x
|
65 |
+
parser.add_argument("--epoch_save", default=5, type=int) # save the model every [epoch_save] "epochs"
|
66 |
+
|
67 |
+
parser.add_argument("--micro_bsz", default=12, type=int) # micro batch size (batch size per GPU)
|
68 |
+
parser.add_argument("--n_layer", default=6, type=int)
|
69 |
+
parser.add_argument("--n_embd", default=512, type=int)
|
70 |
+
parser.add_argument("--dim_att", default=0, type=int)
|
71 |
+
parser.add_argument("--dim_ffn", default=0, type=int)
|
72 |
+
parser.add_argument("--pre_ffn", default=0, type=int) # replace first att layer by ffn (sometimes better)
|
73 |
+
parser.add_argument("--head_qk", default=0, type=int) # my headQK trick
|
74 |
+
parser.add_argument("--tiny_att_dim", default=0, type=int) # tiny attention dim
|
75 |
+
parser.add_argument("--tiny_att_layer", default=-999, type=int) # tiny attention @ which layer
|
76 |
+
|
77 |
+
parser.add_argument("--lr_init", default=6e-4, type=float) # 6e-4 for L12-D768, 4e-4 for L24-D1024, 3e-4 for L24-D2048
|
78 |
+
parser.add_argument("--lr_final", default=1e-5, type=float)
|
79 |
+
parser.add_argument("--warmup_steps", default=-1, type=int) # try 50 if you load a model
|
80 |
+
parser.add_argument("--beta1", default=0.9, type=float)
|
81 |
+
parser.add_argument("--beta2", default=0.99, type=float) # use 0.999 when your model is close to convergence
|
82 |
+
parser.add_argument("--adam_eps", default=1e-8, type=float)
|
83 |
+
parser.add_argument("--grad_cp", default=0, type=int) # gradient checkpt: saves VRAM, but slower
|
84 |
+
|
85 |
+
parser.add_argument("--my_pile_version", default=1, type=int) # my special pile version
|
86 |
+
parser.add_argument("--my_pile_stage", default=0, type=int) # my special pile mode
|
87 |
+
parser.add_argument("--my_pile_shift", default=-1, type=int) # my special pile mode - text shift
|
88 |
+
parser.add_argument("--my_pile_edecay", default=0, type=int)
|
89 |
+
parser.add_argument("--layerwise_lr", default=1, type=int) # layerwise lr for faster convergence (but slower it/s)
|
90 |
+
parser.add_argument("--ds_bucket_mb", default=200, type=int) # deepspeed bucket size in MB. 200 seems enough
|
91 |
+
# parser.add_argument("--cuda_cleanup", default=0, type=int) # extra cuda cleanup (sometimes helpful)
|
92 |
+
|
93 |
+
parser.add_argument("--my_img_version", default=0, type=str)
|
94 |
+
parser.add_argument("--my_img_size", default=0, type=int)
|
95 |
+
parser.add_argument("--my_img_bit", default=0, type=int)
|
96 |
+
parser.add_argument("--my_img_clip", default='x', type=str)
|
97 |
+
parser.add_argument("--my_img_clip_scale", default=1, type=float)
|
98 |
+
parser.add_argument("--my_img_l1_scale", default=0, type=float)
|
99 |
+
parser.add_argument("--my_img_encoder", default='x', type=str)
|
100 |
+
# parser.add_argument("--my_img_noise_scale", default=0, type=float)
|
101 |
+
parser.add_argument("--my_sample_len", default=0, type=int)
|
102 |
+
parser.add_argument("--my_ffn_shift", default=1, type=int)
|
103 |
+
parser.add_argument("--my_att_shift", default=1, type=int)
|
104 |
+
parser.add_argument("--my_pos_emb", default=0, type=int)
|
105 |
+
parser.add_argument("--load_partial", default=0, type=int)
|
106 |
+
parser.add_argument("--magic_prime", default=0, type=int)
|
107 |
+
parser.add_argument("--my_qa_mask", default=0, type=int)
|
108 |
+
parser.add_argument("--my_random_steps", default=0, type=int)
|
109 |
+
parser.add_argument("--my_testing", default='', type=str)
|
110 |
+
|
111 |
+
parser = Trainer.add_argparse_args(parser)
|
112 |
+
args = parser.parse_args()
|
113 |
+
|
114 |
+
########################################################################################################
|
115 |
+
|
116 |
+
import os, warnings, math, datetime, sys, time, importlib
|
117 |
+
import numpy as np
|
118 |
+
import torch
|
119 |
+
from torch.utils.data import DataLoader
|
120 |
+
if "deepspeed" in args.strategy:
|
121 |
+
import deepspeed
|
122 |
+
import pytorch_lightning as pl
|
123 |
+
from pytorch_lightning import seed_everything
|
124 |
+
|
125 |
+
if args.random_seed >= 0:
|
126 |
+
print(f"########## WARNING: GLOBAL SEED {args.random_seed} THIS WILL AFFECT MULTIGPU SAMPLING ##########\n" * 3)
|
127 |
+
seed_everything(args.random_seed)
|
128 |
+
|
129 |
+
np.set_printoptions(precision=4, suppress=True, linewidth=200)
|
130 |
+
warnings.filterwarnings("ignore", ".*Consider increasing the value of the `num_workers` argument*")
|
131 |
+
warnings.filterwarnings("ignore", ".*The progress bar already tracks a metric with the*")
|
132 |
+
# os.environ["WDS_SHOW_SEED"] = "1"
|
133 |
+
|
134 |
+
args.my_timestamp = datetime.datetime.today().strftime("%Y-%m-%d-%H-%M-%S")
|
135 |
+
args.enable_checkpointing = False
|
136 |
+
args.replace_sampler_ddp = False
|
137 |
+
args.logger = False
|
138 |
+
args.gradient_clip_val = 1.0
|
139 |
+
args.num_sanity_val_steps = 0
|
140 |
+
args.check_val_every_n_epoch = int(1e20)
|
141 |
+
args.log_every_n_steps = int(1e20)
|
142 |
+
args.max_epochs = -1 # continue forever
|
143 |
+
args.betas = (args.beta1, args.beta2)
|
144 |
+
args.real_bsz = int(args.num_nodes) * int(args.devices) * args.micro_bsz
|
145 |
+
os.environ["RWKV_T_MAX"] = str(args.ctx_len)
|
146 |
+
os.environ["RWKV_MY_TESTING"] = args.my_testing
|
147 |
+
if args.dim_att <= 0:
|
148 |
+
args.dim_att = args.n_embd
|
149 |
+
if args.dim_ffn <= 0:
|
150 |
+
args.dim_ffn = args.n_embd * 4
|
151 |
+
|
152 |
+
if args.data_type == "wds_img":
|
153 |
+
args.run_name = f"v{args.my_img_version}-{args.my_img_size}-{args.my_img_bit}bit-{args.my_img_clip}x{args.my_img_clip_scale}"
|
154 |
+
args.proj_dir = f"{args.proj_dir}-{args.run_name}"
|
155 |
+
else:
|
156 |
+
args.run_name = f"{args.vocab_size} ctx{args.ctx_len} L{args.n_layer} D{args.n_embd}"
|
157 |
+
if not os.path.exists(args.proj_dir):
|
158 |
+
os.makedirs(args.proj_dir)
|
159 |
+
|
160 |
+
if args.my_pile_stage > 0:
|
161 |
+
magic_prime_bak = args.magic_prime
|
162 |
+
|
163 |
+
if args.my_pile_version == 1:
|
164 |
+
if args.ctx_len == 1024:
|
165 |
+
args.magic_prime = 324331313
|
166 |
+
args.epoch_count = 8043
|
167 |
+
elif args.ctx_len == 2048:
|
168 |
+
args.magic_prime = 162165671
|
169 |
+
args.epoch_count = 4021
|
170 |
+
elif args.ctx_len == 4096:
|
171 |
+
args.magic_prime = 81082817
|
172 |
+
args.epoch_count = 2010
|
173 |
+
elif args.ctx_len == 8192:
|
174 |
+
args.magic_prime = 40541399
|
175 |
+
args.epoch_count = 1005
|
176 |
+
else:
|
177 |
+
if args.ctx_len == 1024:
|
178 |
+
args.magic_prime = 1670239709
|
179 |
+
args.epoch_count = 41423
|
180 |
+
elif args.ctx_len == 2048:
|
181 |
+
args.magic_prime = 835119767
|
182 |
+
args.epoch_count = 20711
|
183 |
+
elif args.ctx_len == 4096:
|
184 |
+
args.magic_prime = 417559889
|
185 |
+
args.epoch_count = 10355
|
186 |
+
elif args.ctx_len == 6144:
|
187 |
+
args.magic_prime = 278373239
|
188 |
+
args.epoch_count = 6903
|
189 |
+
elif args.ctx_len == 8192:
|
190 |
+
args.magic_prime = 208779911
|
191 |
+
args.epoch_count = 5177
|
192 |
+
if args.my_pile_shift < 0:
|
193 |
+
args.my_pile_shift = 0
|
194 |
+
|
195 |
+
if magic_prime_bak > 0:
|
196 |
+
args.magic_prime = magic_prime_bak
|
197 |
+
|
198 |
+
args.epoch_steps = 40320 // args.real_bsz
|
199 |
+
assert args.epoch_steps * args.real_bsz == 40320
|
200 |
+
if args.my_pile_stage == 2:
|
201 |
+
assert args.lr_final == args.lr_init
|
202 |
+
if args.my_pile_stage >= 2: # find latest saved model
|
203 |
+
list_p = []
|
204 |
+
for p in os.listdir(args.proj_dir):
|
205 |
+
if p.startswith("rwkv") and p.endswith(".pth"):
|
206 |
+
p = ((p.split("-"))[1].split("."))[0]
|
207 |
+
if p == "init":
|
208 |
+
p = -1
|
209 |
+
else:
|
210 |
+
p = int(p)
|
211 |
+
list_p += [p]
|
212 |
+
list_p.sort()
|
213 |
+
max_p = list_p[-1]
|
214 |
+
if len(list_p) > 1:
|
215 |
+
args.my_pile_prev_p = list_p[-2] # in case max_p is corrupted
|
216 |
+
if max_p == -1:
|
217 |
+
args.load_model = f"{args.proj_dir}/rwkv-init.pth"
|
218 |
+
else:
|
219 |
+
args.load_model = f"{args.proj_dir}/rwkv-{max_p}.pth"
|
220 |
+
if args.warmup_steps < 0:
|
221 |
+
if args.my_pile_stage == 2:
|
222 |
+
args.warmup_steps = 10
|
223 |
+
else:
|
224 |
+
args.warmup_steps = 30
|
225 |
+
args.epoch_begin = max_p + 1
|
226 |
+
|
227 |
+
samples_per_epoch = args.epoch_steps * args.real_bsz
|
228 |
+
tokens_per_epoch = samples_per_epoch * args.ctx_len
|
229 |
+
rank_zero_info(
|
230 |
+
f"""
|
231 |
+
############################################################################
|
232 |
+
#
|
233 |
+
# RWKV-4 {args.precision.upper()} on {args.num_nodes}x{args.devices} {args.accelerator.upper()}, bsz {args.num_nodes}x{args.devices}x{args.micro_bsz}={args.real_bsz}, {args.strategy} {'with grad_cp' if args.grad_cp > 0 else ''}
|
234 |
+
#
|
235 |
+
# Data = {args.data_file} ({args.data_type}), ProjDir = {args.proj_dir}
|
236 |
+
#
|
237 |
+
# Epoch = {args.epoch_begin} to {args.epoch_begin + args.epoch_count - 1} (will continue afterwards), save every {args.epoch_save} epoch
|
238 |
+
#
|
239 |
+
# Each "epoch" = {args.epoch_steps} steps, {samples_per_epoch} samples, {tokens_per_epoch} tokens
|
240 |
+
#
|
241 |
+
# Model = {args.n_layer} n_layer, {args.n_embd} n_embd, {args.ctx_len} ctx_len
|
242 |
+
#
|
243 |
+
# Adam = lr {args.lr_init} to {args.lr_final}, warmup {args.warmup_steps} steps, beta {args.betas}, eps {args.adam_eps}
|
244 |
+
#
|
245 |
+
# Found torch {torch.__version__}, recommend 1.13.1+cu117 or newer
|
246 |
+
# Found deepspeed {deepspeed.__version__ if importlib.util.find_spec('deepspeed') else 'None'}, recommend 0.7.0 (faster than newer versions)
|
247 |
+
# Found pytorch_lightning {pl.__version__}, recommend 1.9.1 or newer
|
248 |
+
#
|
249 |
+
############################################################################
|
250 |
+
"""
|
251 |
+
)
|
252 |
+
rank_zero_info(str(vars(args)) + "\n")
|
253 |
+
|
254 |
+
assert args.data_type in ["utf-8", "utf-16le", "numpy", "binidx", "dummy", "wds_img", "uint16"]
|
255 |
+
|
256 |
+
if args.lr_final == 0 or args.lr_init == 0:
|
257 |
+
rank_zero_info("\n\nNote: lr_final = 0 or lr_init = 0. Using linear LR schedule instead.\n\n")
|
258 |
+
|
259 |
+
assert args.precision in ["fp32", "tf32", "fp16", "bf16"]
|
260 |
+
os.environ["RWKV_FLOAT_MODE"] = args.precision
|
261 |
+
if args.precision == "fp32":
|
262 |
+
for i in range(10):
|
263 |
+
rank_zero_info("\n\nNote: you are using fp32 (very slow). Try bf16 / tf32 for faster training.\n\n")
|
264 |
+
if args.precision == "fp16":
|
265 |
+
rank_zero_info("\n\nNote: you are using fp16 (might overflow). Try bf16 / tf32 for stable training.\n\n")
|
266 |
+
|
267 |
+
os.environ["RWKV_JIT_ON"] = "1"
|
268 |
+
if "deepspeed_stage_3" in args.strategy:
|
269 |
+
os.environ["RWKV_JIT_ON"] = "0"
|
270 |
+
|
271 |
+
torch.backends.cudnn.benchmark = True
|
272 |
+
torch.backends.cudnn.enabled = True
|
273 |
+
if args.precision == "fp32":
|
274 |
+
torch.backends.cudnn.allow_tf32 = False
|
275 |
+
torch.backends.cuda.matmul.allow_tf32 = False
|
276 |
+
else:
|
277 |
+
torch.backends.cudnn.allow_tf32 = True
|
278 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
279 |
+
|
280 |
+
if "32" in args.precision:
|
281 |
+
args.precision = 32
|
282 |
+
elif args.precision == "fp16":
|
283 |
+
args.precision = 16
|
284 |
+
else:
|
285 |
+
args.precision = "bf16"
|
286 |
+
|
287 |
+
########################################################################################################
|
288 |
+
|
289 |
+
from src.trainer import train_callback, generate_init_weight
|
290 |
+
from src.dataset import MyDataset
|
291 |
+
|
292 |
+
train_data = MyDataset(args)
|
293 |
+
args.vocab_size = train_data.vocab_size
|
294 |
+
|
295 |
+
if args.data_type == 'wds_img':
|
296 |
+
from src.model_img import RWKV_IMG
|
297 |
+
model = RWKV_IMG(args)
|
298 |
+
else:
|
299 |
+
from src.model import RWKV
|
300 |
+
model = RWKV(args)
|
301 |
+
|
302 |
+
if len(args.load_model) == 0 or args.my_pile_stage == 1: # shall we build the initial weights?
|
303 |
+
init_weight_name = f"{args.proj_dir}/rwkv-init.pth"
|
304 |
+
generate_init_weight(model, init_weight_name) # save initial weights
|
305 |
+
args.load_model = init_weight_name
|
306 |
+
|
307 |
+
rank_zero_info(f"########## Loading {args.load_model}... ##########")
|
308 |
+
try:
|
309 |
+
load_dict = torch.load(args.load_model, map_location="cpu")
|
310 |
+
except:
|
311 |
+
rank_zero_info(f"Bad checkpoint {args.load_model}")
|
312 |
+
if args.my_pile_stage >= 2: # try again using another checkpoint
|
313 |
+
max_p = args.my_pile_prev_p
|
314 |
+
if max_p == -1:
|
315 |
+
args.load_model = f"{args.proj_dir}/rwkv-init.pth"
|
316 |
+
else:
|
317 |
+
args.load_model = f"{args.proj_dir}/rwkv-{max_p}.pth"
|
318 |
+
args.epoch_begin = max_p + 1
|
319 |
+
rank_zero_info(f"Trying {args.load_model}")
|
320 |
+
load_dict = torch.load(args.load_model, map_location="cpu")
|
321 |
+
|
322 |
+
if args.load_partial == 1:
|
323 |
+
load_keys = load_dict.keys()
|
324 |
+
for k in model.state_dict():
|
325 |
+
if k not in load_keys:
|
326 |
+
load_dict[k] = model.state_dict()[k]
|
327 |
+
model.load_state_dict(load_dict)
|
328 |
+
|
329 |
+
trainer = Trainer.from_argparse_args(
|
330 |
+
args,
|
331 |
+
callbacks=[train_callback(args)],
|
332 |
+
)
|
333 |
+
|
334 |
+
if trainer.global_rank == 0:
|
335 |
+
for n in model.state_dict():
|
336 |
+
shape = model.state_dict()[n].shape
|
337 |
+
shape = [i for i in shape if i != 1]
|
338 |
+
if len(shape) > 1:
|
339 |
+
print(f"{str(shape[0]).ljust(5)} {str(shape[1]).ljust(5)} {n}")
|
340 |
+
else:
|
341 |
+
print(f"{str(shape[0]).ljust(5)} {n}")
|
342 |
+
|
343 |
+
if "deepspeed" in args.strategy:
|
344 |
+
trainer.strategy.config["zero_optimization"]["allgather_bucket_size"] = args.ds_bucket_mb * 1000 * 1000
|
345 |
+
trainer.strategy.config["zero_optimization"]["reduce_bucket_size"] = args.ds_bucket_mb * 1000 * 1000
|
346 |
+
|
347 |
+
# must set shuffle=False, persistent_workers=False (because worker is in another thread)
|
348 |
+
data_loader = DataLoader(train_data, shuffle=False, pin_memory=True, batch_size=args.micro_bsz, num_workers=1, persistent_workers=False, drop_last=True)
|
349 |
+
|
350 |
+
trainer.fit(model, data_loader)
|
verify.py
ADDED
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
########################################################################################################
|
2 |
+
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
|
3 |
+
########################################################################################################
|
4 |
+
|
5 |
+
# this is for verifying the results of different models and make sure they agree with each other
|
6 |
+
|
7 |
+
import os, sys, types
|
8 |
+
import numpy as np
|
9 |
+
import torch
|
10 |
+
np.set_printoptions(precision=4, suppress=True, linewidth=200)
|
11 |
+
try:
|
12 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = sys.argv[1]
|
13 |
+
except:
|
14 |
+
pass
|
15 |
+
torch.backends.cudnn.benchmark = True
|
16 |
+
torch.backends.cudnn.allow_tf32 = False
|
17 |
+
torch.backends.cuda.matmul.allow_tf32 = False
|
18 |
+
|
19 |
+
os.environ['RWKV_FLOAT_MODE'] = 'bf16' # bf16 or fp32
|
20 |
+
os.environ['RWKV_RUN_DEVICE'] = 'cuda' # currently model_train requires CUDA
|
21 |
+
RUN_DEVICE = os.environ['RWKV_RUN_DEVICE']
|
22 |
+
|
23 |
+
TOKEN_MODE = 'pile'
|
24 |
+
|
25 |
+
if TOKEN_MODE == 'pile':
|
26 |
+
WORD_NAME = ['20B_tokenizer.json', '20B_tokenizer.json']
|
27 |
+
MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-3b/RWKV-4-Pile-3B-20221003-6783'
|
28 |
+
n_layer = 32
|
29 |
+
n_embd = 2560
|
30 |
+
ctx_len = 1024
|
31 |
+
UNKNOWN_CHAR = None
|
32 |
+
|
33 |
+
from src.utils import TOKENIZER
|
34 |
+
tokenizer = TOKENIZER(WORD_NAME, UNKNOWN_CHAR=UNKNOWN_CHAR)
|
35 |
+
if TOKEN_MODE == 'pile':
|
36 |
+
tokenizer.vocab_size = 50277
|
37 |
+
|
38 |
+
########################################################################################################
|
39 |
+
|
40 |
+
os.environ["RWKV_JIT_ON"] = "1"
|
41 |
+
os.environ["RWKV_T_MAX"] = str(ctx_len)
|
42 |
+
|
43 |
+
from src.model_run import RWKV_RNN
|
44 |
+
from src.model import RWKV
|
45 |
+
|
46 |
+
args = types.SimpleNamespace()
|
47 |
+
args.vocab_size = tokenizer.vocab_size
|
48 |
+
args.ctx_len = ctx_len
|
49 |
+
args.n_embd = n_embd
|
50 |
+
args.n_layer = n_layer
|
51 |
+
args.head_qk = 0
|
52 |
+
args.pre_ffn = 0
|
53 |
+
args.grad_cp = 0
|
54 |
+
args.my_pos_emb = 0
|
55 |
+
model_train = RWKV(args).to(RUN_DEVICE)
|
56 |
+
|
57 |
+
if os.environ['RWKV_FLOAT_MODE'] == 'fp16':
|
58 |
+
model_train = model_train.half()
|
59 |
+
elif os.environ['RWKV_FLOAT_MODE'] == 'bf16':
|
60 |
+
model_train = model_train.bfloat16()
|
61 |
+
|
62 |
+
print('loading ' + MODEL_NAME)
|
63 |
+
m2 = torch.load(MODEL_NAME + '.pth', map_location='cpu')
|
64 |
+
model_train.load_state_dict(m2)
|
65 |
+
|
66 |
+
if os.environ['RWKV_FLOAT_MODE'] == 'fp16':
|
67 |
+
model_train = model_train.half()
|
68 |
+
elif os.environ['RWKV_FLOAT_MODE'] == 'bf16':
|
69 |
+
model_train = model_train.bfloat16()
|
70 |
+
|
71 |
+
args.MODEL_NAME = MODEL_NAME
|
72 |
+
args.RUN_DEVICE = RUN_DEVICE
|
73 |
+
args.FLOAT_MODE = os.environ['RWKV_FLOAT_MODE']
|
74 |
+
model_rnn = RWKV_RNN(args)
|
75 |
+
|
76 |
+
########################################################################################################
|
77 |
+
|
78 |
+
print(f"\nVerifying {os.environ['RWKV_RUN_DEVICE']} {os.environ['RWKV_FLOAT_MODE']}")
|
79 |
+
|
80 |
+
# context = '\nIn a'
|
81 |
+
context = '\nIn a shocking finding, scientist discovered a herd of dragons living in a remote, previously unexplored valley, in Tibet. Even more surprising to the researchers was the fact that the dragons spoke perfect Chinese.'
|
82 |
+
|
83 |
+
if TOKEN_MODE == 'pile':
|
84 |
+
ctx = tokenizer.tokenizer.encode(context)
|
85 |
+
print(f'input len {len(ctx)} data {ctx}')
|
86 |
+
|
87 |
+
########################################################################################################
|
88 |
+
|
89 |
+
with torch.no_grad():
|
90 |
+
print('\nRWKV-train output')
|
91 |
+
out = model_train.forward(torch.tensor([ctx]).to(RUN_DEVICE))[0].detach().cpu().float().numpy()
|
92 |
+
print(out, '\n')
|
93 |
+
|
94 |
+
print('\nRWKV-RNN output')
|
95 |
+
state = None
|
96 |
+
out = None
|
97 |
+
src_len = len(ctx)
|
98 |
+
for i in range(src_len):
|
99 |
+
x = ctx[:i+1]
|
100 |
+
out, state = model_rnn.forward(x, state)
|
101 |
+
if i < 3 or i >= src_len - 3:
|
102 |
+
print(out.detach().cpu().numpy())
|
103 |
+
if i == 2:
|
104 |
+
print('...')
|
zrwkv-37fifth.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:426991ea8333bdc4a16fa27551b1e8e7ebe9090e2a5ff346d95290f4ffc55a3e
|
3 |
+
size 338718755
|