JoshuaChak commited on
Commit
893630b
1 Parent(s): b7553eb

Upload chat.cpp with huggingface_hub

Browse files
Files changed (1) hide show
  1. chat.cpp +428 -0
chat.cpp ADDED
@@ -0,0 +1,428 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ //===----------------------------------------------------------------------===//
2
+ //
3
+ // Copyright (C) 2023 Sophgo Technologies Inc. All rights reserved.
4
+ //
5
+ // TPU-MLIR is licensed under the 2-Clause BSD License except for the
6
+ // third-party components.
7
+ //
8
+ //===----------------------------------------------------------------------===//
9
+
10
+ #include <iostream>
11
+ #include <cstdlib>
12
+ #include <vector>
13
+ #include <assert.h>
14
+ #include <chrono>
15
+ #include <algorithm>
16
+ #include <pybind11/pybind11.h>
17
+ #include <pybind11/stl.h>
18
+ #include "memory.h"
19
+ #include "bmruntime_interface.h"
20
+ #include <getopt.h>
21
+ #include <stdio.h>
22
+ #include <inttypes.h>
23
+ #include <random>
24
+ #include <numeric>
25
+
26
+ static const uint16_t ATTENTION_MASK = 0xF0E2;
27
+
28
+ class Llama3 {
29
+ public:
30
+ void init(const std::vector<int> &devid, std::string model_path);
31
+ void deinit();
32
+ int forward_first(std::vector<int> &tokens);
33
+ int forward_next();
34
+ std::vector<int> generate(std::vector<int> &history_tokens, int EOS);
35
+
36
+ std::mt19937 sgen;
37
+ Llama3() : sgen(std::random_device()()){};
38
+
39
+ private:
40
+ void net_launch(const bm_net_info_t *net, int stage_idx = 0);
41
+ inline void d2d(bm_device_mem_t &dst, bm_device_mem_t &src);
42
+
43
+ void head_launch(const bm_net_info_t *net, bm_device_mem_t &logits_mem);
44
+ int greedy_search(const bm_net_info_t *net, bm_device_mem_t &logits_mem);
45
+ int penalty_sample(const bm_net_info_t *net, bm_device_mem_t &logits_mem);
46
+
47
+ public:
48
+ int token_length;
49
+ int SEQLEN; // read from bmodel
50
+ int NUM_LAYERS; // read from bmodel
51
+ bool io_alone;
52
+ std::vector<int> visited_tokens;
53
+
54
+ // generation
55
+ float temperature;
56
+ float top_p;
57
+ float repeat_penalty;
58
+ int repeat_last_n;
59
+ int max_new_tokens;
60
+ std::string generation_mode;
61
+ std::string prompt_mode;
62
+
63
+ private:
64
+ std::vector<bm_handle_t> handles;
65
+ bm_handle_t bm_handle;
66
+ void *p_bmrt;
67
+ std::vector<const bm_net_info_t *> net_blocks;
68
+ std::vector<const bm_net_info_t *> net_blocks_cache;
69
+ const bm_net_info_t *net_embed;
70
+ const bm_net_info_t *net_embed_cache;
71
+ const bm_net_info_t *net_lm, *net_greedy_head, *net_penalty_sample_head;
72
+ std::vector<bm_device_mem_t> past_key;
73
+ std::vector<bm_device_mem_t> past_value;
74
+ };
75
+
76
+ void Llama3::net_launch(const bm_net_info_t *net, int stage_idx) {
77
+ std::vector<bm_tensor_t> in_tensors(net->input_num);
78
+ std::vector<bm_tensor_t> out_tensors(net->output_num);
79
+
80
+ for (int i = 0; i < net->input_num; i++) {
81
+ bmrt_tensor_with_device(
82
+ &in_tensors[i], net->stages[stage_idx].input_mems[i],
83
+ net->input_dtypes[i], net->stages[stage_idx].input_shapes[i]);
84
+ }
85
+ for (int i = 0; i < net->output_num; i++) {
86
+ bmrt_tensor_with_device(
87
+ &out_tensors[i], net->stages[stage_idx].output_mems[i],
88
+ net->output_dtypes[i], net->stages[stage_idx].output_shapes[i]);
89
+ }
90
+ auto ret = bmrt_launch_tensor_ex(p_bmrt, net->name, in_tensors.data(),
91
+ net->input_num, out_tensors.data(),
92
+ net->output_num, true, false);
93
+ assert(ret);
94
+ bm_thread_sync(bm_handle);
95
+ }
96
+
97
+ void Llama3::d2d(bm_device_mem_t &dst, bm_device_mem_t &src) {
98
+ bm_memcpy_d2d_byte(bm_handle, dst, 0, src, 0, bm_mem_get_device_size(src));
99
+ }
100
+
101
+ void Llama3::init(const std::vector<int> &devices, std::string model_path) {
102
+
103
+ // request bm_handle
104
+ std::cout << "Device [ ";
105
+ for (auto d : devices) {
106
+ std::cout << d << " ";
107
+ }
108
+ std::cout << "] loading ....\n";
109
+ for (auto d : devices) {
110
+ bm_handle_t h;
111
+ bm_status_t status = bm_dev_request(&h, d);
112
+ assert(BM_SUCCESS == status);
113
+ handles.push_back(h);
114
+ }
115
+ bm_handle = handles[0];
116
+
117
+ // create bmruntime
118
+ #ifdef SOC_TARGET
119
+ p_bmrt = bmrt_create(handles[0]);
120
+ #else
121
+ p_bmrt = bmrt_create_ex(handles.data(), handles.size());
122
+ #endif
123
+ assert(NULL != p_bmrt);
124
+
125
+ // load bmodel by file
126
+ printf("Model[%s] loading ....\n", model_path.c_str());
127
+ bool ret = bmrt_load_bmodel(p_bmrt, model_path.c_str());
128
+ assert(true == ret);
129
+ printf("Done!\n");
130
+
131
+ // net embed and lm_head
132
+ net_embed = bmrt_get_network_info(p_bmrt, "embedding");
133
+ net_embed_cache = bmrt_get_network_info(p_bmrt, "embedding_cache");
134
+ net_lm = bmrt_get_network_info(p_bmrt, "lm_head");
135
+ net_greedy_head = bmrt_get_network_info(p_bmrt, "greedy_head");
136
+ net_penalty_sample_head = bmrt_get_network_info(p_bmrt, "penalty_sample_head");
137
+ SEQLEN = net_embed->stages[0].input_shapes[0].dims[1]; // real seqlen
138
+ auto num_nets = bmrt_get_network_number(p_bmrt);
139
+ NUM_LAYERS = (num_nets - 5) / 2;
140
+
141
+ // resize
142
+ visited_tokens.resize(SEQLEN);
143
+
144
+ // net blocks
145
+ for (int i = 0; i < NUM_LAYERS; i++) {
146
+ auto block_name = "block_" + std::to_string(i);
147
+ auto cache_name = "block_cache_" + std::to_string(i);
148
+ net_blocks.emplace_back(bmrt_get_network_info(p_bmrt, block_name.c_str()));
149
+ net_blocks_cache.emplace_back(
150
+ bmrt_get_network_info(p_bmrt, cache_name.c_str()));
151
+ }
152
+
153
+ // kv cache
154
+ past_key.resize(NUM_LAYERS);
155
+ past_value.resize(NUM_LAYERS);
156
+ auto addr_mode = net_blocks_cache[0]->addr_mode;
157
+ io_alone = addr_mode == 1;
158
+ for (int i = 0; i < NUM_LAYERS; i++) {
159
+ assert(addr_mode == net_blocks_cache[i]->addr_mode);
160
+ if (io_alone) {
161
+ past_key[i] = net_blocks_cache[i]->stages[0].input_mems[3];
162
+ past_value[i] = net_blocks_cache[i]->stages[0].input_mems[4];
163
+ } else {
164
+ auto ret = bm_malloc_device_byte(bm_handle, &past_key[i],
165
+ net_blocks_cache[i]->max_input_bytes[3]);
166
+ assert(BM_SUCCESS == ret);
167
+ ret = bm_malloc_device_byte(bm_handle, &past_value[i],
168
+ net_blocks_cache[i]->max_input_bytes[4]);
169
+ assert(BM_SUCCESS == ret);
170
+ }
171
+ }
172
+ }
173
+
174
+ void Llama3::deinit() {
175
+ if (false == io_alone) {
176
+ for (int i = 0; i < NUM_LAYERS; i++) {
177
+ bm_free_device(bm_handle, past_key[i]);
178
+ bm_free_device(bm_handle, past_value[i]);
179
+ }
180
+ }
181
+ bmrt_destroy(p_bmrt);
182
+ for (auto h : handles) {
183
+ bm_dev_free(h);
184
+ }
185
+ }
186
+
187
+ void Llama3::head_launch(const bm_net_info_t *net, bm_device_mem_t &logits_mem) {
188
+ std::vector<bm_tensor_t> in_tensors(net->input_num);
189
+ std::vector<bm_tensor_t> out_tensors(net->output_num);
190
+
191
+ bmrt_tensor_with_device(
192
+ &in_tensors[0], logits_mem,
193
+ net->input_dtypes[0], net->stages[0].input_shapes[0]);
194
+
195
+ for (int i = 1; i < net->input_num; i++) {
196
+ bmrt_tensor_with_device(
197
+ &in_tensors[i], net->stages[0].input_mems[i],
198
+ net->input_dtypes[i], net->stages[0].input_shapes[i]);
199
+ }
200
+ for (int i = 0; i < net->output_num; i++) {
201
+ bmrt_tensor_with_device(
202
+ &out_tensors[i], net->stages[0].output_mems[i],
203
+ net->output_dtypes[i], net->stages[0].output_shapes[i]);
204
+ }
205
+ auto ret = bmrt_launch_tensor_ex(p_bmrt, net->name, in_tensors.data(),
206
+ net->input_num, out_tensors.data(),
207
+ net->output_num, true, false);
208
+ assert(ret);
209
+ bm_thread_sync(bm_handle);
210
+ }
211
+
212
+ int Llama3::greedy_search(const bm_net_info_t *net, bm_device_mem_t &logits_mem) {
213
+ auto &out_mem = net->stages[0].output_mems[0];
214
+ head_launch(net, logits_mem);
215
+ int token = 0;
216
+ bm_memcpy_d2s(bm_handle, (void *)&token, out_mem);
217
+ return token;
218
+ }
219
+
220
+ int Llama3::penalty_sample(const bm_net_info_t *net, bm_device_mem_t &logits_mem) {
221
+ auto &in1_mem = net->stages[0].input_mems[1];
222
+ auto &in2_mem = net->stages[0].input_mems[2];
223
+ auto &in3_mem = net->stages[0].input_mems[3];
224
+ auto &in4_mem = net->stages[0].input_mems[4];
225
+ auto &out0_mem = net->stages[0].output_mems[0];
226
+ auto &out1_mem = net->stages[0].output_mems[1];
227
+
228
+ // repeat_penalty + top_p + top_k + temperature
229
+ std::vector<int> generated_tokens(SEQLEN, visited_tokens[token_length - 1]);
230
+ repeat_last_n = std::min(repeat_last_n, token_length);
231
+ std::copy(visited_tokens.begin() + token_length - repeat_last_n,
232
+ visited_tokens.begin() + token_length,
233
+ generated_tokens.begin());
234
+ bm_memcpy_s2d(bm_handle, in1_mem, (void *)generated_tokens.data());
235
+ bm_memcpy_s2d(bm_handle, in2_mem, (void *)&top_p);
236
+ bm_memcpy_s2d(bm_handle, in3_mem, (void *)&temperature);
237
+ bm_memcpy_s2d(bm_handle, in4_mem, (void *)&repeat_penalty);
238
+
239
+ // inference
240
+ head_launch(net, logits_mem);
241
+
242
+ // get logit & token
243
+ int candidate_num = net->stages[0].output_shapes[0].dims[1];
244
+ std::vector<float> probs(candidate_num);
245
+ bm_memcpy_d2s(bm_handle, probs.data(), out0_mem);
246
+ std::vector<int> tokens(candidate_num);
247
+ bm_memcpy_d2s(bm_handle, tokens.data(), out1_mem);
248
+
249
+ // penalty_sample
250
+ std::discrete_distribution<> dist(probs.begin(), probs.end());
251
+ return tokens[dist(sgen)];
252
+ }
253
+
254
+ int Llama3::forward_first(std::vector<int> &tokens) {
255
+ std::vector<int> position_id(SEQLEN, 0);
256
+ std::vector<uint16_t> attention_mask(SEQLEN * SEQLEN, ATTENTION_MASK);
257
+ std::copy(tokens.begin(), tokens.end(), visited_tokens.data());
258
+
259
+ token_length = tokens.size();
260
+
261
+ for (int i = 0; i < token_length; i++) {
262
+ position_id[i] = i;
263
+ }
264
+ for (int i = 0; i < token_length; i++) {
265
+ for (int j = 0; j < SEQLEN; j++) {
266
+ if (j <= i) {
267
+ attention_mask[i * SEQLEN + j] = 0;
268
+ }
269
+ }
270
+ }
271
+
272
+ // forward embeding
273
+ auto &in_mem = net_embed->stages[0].input_mems[0];
274
+ auto &out_mem = net_embed->stages[0].output_mems[0];
275
+ bm_memcpy_s2d(bm_handle, in_mem, (void *)visited_tokens.data());
276
+ net_launch(net_embed); // prefil embedding
277
+
278
+ // forward blocks
279
+ for (int idx = 0; idx < NUM_LAYERS; idx++) {
280
+ auto &in0_mem = net_blocks[idx]->stages[0].input_mems[0];
281
+ auto &in1_mem = net_blocks[idx]->stages[0].input_mems[1];
282
+ auto &in2_mem = net_blocks[idx]->stages[0].input_mems[2];
283
+ d2d(in0_mem, out_mem);
284
+ if (idx == 0) {
285
+ // only first time need copy
286
+ bm_memcpy_s2d(bm_handle, in1_mem, (void *)position_id.data());
287
+ bm_memcpy_s2d(bm_handle, in2_mem, (void *)attention_mask.data());
288
+ }
289
+ net_launch(net_blocks[idx]);
290
+ out_mem = net_blocks[idx]->stages[0].output_mems[0];
291
+ d2d(past_key[idx], net_blocks[idx]->stages[0].output_mems[1]);
292
+ d2d(past_value[idx], net_blocks[idx]->stages[0].output_mems[2]);
293
+ }
294
+
295
+ // forward lmhead
296
+ int bytes = out_mem.size / SEQLEN;
297
+ auto &lm_in_mem = net_lm->stages[0].input_mems[0];
298
+ auto &lm_out_mem = net_lm->stages[0].output_mems[0];
299
+ bm_memcpy_d2d_byte(bm_handle, lm_in_mem, 0, out_mem,
300
+ (token_length - 1) * bytes, bytes);
301
+ net_launch(net_lm);
302
+
303
+ int token = 0;
304
+ if (generation_mode == "greedy") {
305
+ token = greedy_search(net_greedy_head, lm_out_mem);
306
+ } else if (generation_mode == "penalty_sample") {
307
+ token = penalty_sample(net_penalty_sample_head, lm_out_mem);
308
+ }
309
+
310
+ visited_tokens[token_length] = token;
311
+ token_length += 1;
312
+ return token;
313
+ }
314
+
315
+ int Llama3::forward_next() {
316
+ int cur_token = visited_tokens[token_length - 1];
317
+
318
+ std::vector<uint16_t> attention_mask(SEQLEN + 1, 0);
319
+ for (int i = token_length - 1; i < SEQLEN; i++) {
320
+ attention_mask[i] = ATTENTION_MASK;
321
+ }
322
+ int32_t position_id = token_length - 1;
323
+
324
+ // embedding
325
+ auto &in_mem = net_embed_cache->stages[0].input_mems[0];
326
+ auto &out_mem = net_embed_cache->stages[0].output_mems[0];
327
+ bm_memcpy_s2d(bm_handle, in_mem, (void *)&cur_token);
328
+ net_launch(net_embed_cache);
329
+
330
+ // blocks
331
+ int bytes =
332
+ bm_mem_get_device_size(net_blocks_cache[0]->stages[0].output_mems[1]);
333
+ int token_offset = (token_length - 1) * bytes;
334
+ for (int idx = 0; idx < NUM_LAYERS; idx++) {
335
+ auto &in0_mem = net_blocks_cache[idx]->stages[0].input_mems[0];
336
+ auto &in1_mem = net_blocks_cache[idx]->stages[0].input_mems[1];
337
+ auto &in2_mem = net_blocks_cache[idx]->stages[0].input_mems[2];
338
+ auto &in3_mem = net_blocks_cache[idx]->stages[0].input_mems[3];
339
+ auto &in4_mem = net_blocks_cache[idx]->stages[0].input_mems[4];
340
+ auto &out0_mem = net_blocks_cache[idx]->stages[0].output_mems[0];
341
+ auto &out1_mem = net_blocks_cache[idx]->stages[0].output_mems[1];
342
+ auto &out2_mem = net_blocks_cache[idx]->stages[0].output_mems[2];
343
+ d2d(in0_mem, out_mem);
344
+ if (io_alone) {
345
+ if (idx == 0) {
346
+ bm_memcpy_s2d(bm_handle, in1_mem, (void *)&position_id);
347
+ bm_memcpy_s2d(bm_handle, in2_mem, (void *)attention_mask.data());
348
+ } else {
349
+ d2d(in1_mem, net_blocks_cache[0]->stages[0].input_mems[1]);
350
+ d2d(in2_mem, net_blocks_cache[0]->stages[0].input_mems[2]);
351
+ }
352
+ } else {
353
+ if (idx == 0) {
354
+ bm_memcpy_s2d(bm_handle, in1_mem, (void *)&position_id);
355
+ bm_memcpy_s2d(bm_handle, in2_mem, (void *)attention_mask.data());
356
+ }
357
+ d2d(in3_mem, past_key[idx]);
358
+ d2d(in4_mem, past_value[idx]);
359
+ }
360
+ net_launch(net_blocks_cache[idx]);
361
+ out_mem = out0_mem;
362
+ bm_memcpy_d2d_byte(bm_handle, past_key[idx], token_offset, out1_mem, 0,
363
+ bytes);
364
+ bm_memcpy_d2d_byte(bm_handle, past_value[idx], token_offset, out2_mem, 0,
365
+ bytes);
366
+ }
367
+
368
+ // forward lmhead
369
+ auto &lm_in_mem = net_lm->stages[0].input_mems[0];
370
+ auto &lm_out_mem = net_lm->stages[0].output_mems[0];
371
+ d2d(lm_in_mem, out_mem);
372
+ net_launch(net_lm);
373
+
374
+ int token = 0;
375
+ if (generation_mode == "greedy") {
376
+ token = greedy_search(net_greedy_head, lm_out_mem);
377
+ } else if (generation_mode == "penalty_sample") {
378
+ token = penalty_sample(net_penalty_sample_head, lm_out_mem);
379
+ }
380
+
381
+ visited_tokens[token_length] = token;
382
+ token_length += 1;
383
+ return token;
384
+ }
385
+
386
+
387
+ std::vector<int> Llama3::generate(std::vector<int> &history_tokens, int EOS) {
388
+ if (history_tokens.empty()) {
389
+ printf("Sorry: your question is empty!!\n");
390
+ history_tokens.clear();
391
+ return {};
392
+ }
393
+
394
+ // make sure token not too large
395
+ if ((int)history_tokens.size() > SEQLEN - 10) {
396
+ history_tokens.clear();
397
+ printf("Error: your question is too large!\n");
398
+ return {};
399
+ }
400
+
401
+ std::vector<int> result_tokens;
402
+ int token = forward_first(history_tokens);
403
+ while (token != EOS && token_length < SEQLEN) {
404
+ result_tokens.emplace_back(token);
405
+ token = forward_next();
406
+ }
407
+
408
+ return result_tokens;
409
+ }
410
+
411
+ PYBIND11_MODULE(chat, m) {
412
+ pybind11::class_<Llama3>(m, "Llama3")
413
+ .def(pybind11::init<>())
414
+ .def("init", &Llama3::init)
415
+ .def("forward_first", &Llama3::forward_first)
416
+ .def("forward_next", &Llama3::forward_next)
417
+ .def("generate", &Llama3::generate)
418
+ .def("deinit", &Llama3::deinit)
419
+ .def_readwrite("SEQLEN", &Llama3::SEQLEN) // read SEQLEN in pipeline.py
420
+ .def_readwrite("token_length", &Llama3::token_length)
421
+ .def_readwrite("temperature", &Llama3::temperature)
422
+ .def_readwrite("top_p", &Llama3::top_p)
423
+ .def_readwrite("repeat_penalty", &Llama3::repeat_penalty)
424
+ .def_readwrite("repeat_last_n", &Llama3::repeat_last_n)
425
+ .def_readwrite("max_new_tokens", &Llama3::max_new_tokens)
426
+ .def_readwrite("generation_mode", &Llama3::generation_mode)
427
+ .def_readwrite("prompt_mode", &Llama3::prompt_mode);
428
+ }