// File size: 22,457 Bytes
// 7c071a8
//===----------------------------------------------------------------------===//
//
// Copyright (C) 2023 Sophgo Technologies Inc. All rights reserved.
//
// TPU-MLIR is licensed under the 2-Clause BSD License except for the
// third-party components.
//
//===----------------------------------------------------------------------===//
#include "bmruntime_interface.h"
#include <iostream>
#include <cstdlib>
#include <vector>
#include <assert.h>
#include <chrono>
#include <algorithm>
#include "memory.h"
#include "sentencepiece/sentencepiece_processor.h"
#include <getopt.h>
#include <fstream>
#include <map>
#include <random>
#include <vector>
// Raw 16-bit value written into attention-mask entries that must be masked
// out; entries that may be attended to are written as 0 instead.
// Interpreted as an IEEE-754 half-precision float, the bit pattern 0xF0E2
// equals -10000.0.
// NOTE(review): assumes the bmodel reads the mask as fp16 — confirm.
static const uint16_t ATTENTION_MASK = 0xF0E2;
// Interactive chat driver: owns the device handles, the bmruntime instance,
// the tensors used by the embedding / block / lm_head networks, the
// SentencePiece tokenizer and the conversation history.
//
// Lifecycle: init() -> chat() -> deinit(). All scalar and pointer members
// carry default initializers so that an object that has not been through
// init() yet holds well-defined values instead of indeterminate ones.
class Llama2Chat {
public:
  // Loads the tokenizer, requests the devices, loads the bmodel and prepares
  // all tensors.
  void init(const std::vector<int> &devices,
            const std::string &model_path,
            const std::string &tokenizer_path);
  // Runs the interactive question/answer loop on stdin/stdout.
  void chat();
  // Releases device memory, the runtime and the device handles.
  void deinit();

private:
  // Generates and prints the reply for one user input, then records the
  // exchange in history_vector.
  void answer(const std::string &input_str);
  // Runs the full-prompt pass and returns the first generated token id.
  int forward_first(std::vector<int> &tokens);
  // Runs one incremental decoding step and returns the next token id.
  int forward_next(int cur_token);
  // Launches a network by name and waits for it to finish.
  void net_launch(const std::string &net_name,
                  std::vector<bm_tensor_t> &inputs,
                  std::vector<bm_tensor_t> &outputs,
                  int stage_idx = 0);
  // Loads the tokenizer model and caches its EOS id.
  void load_sentencepiece(std::string tokenizer_path);
  // Assembles the prompt text from sys_config, the past exchanges and `query`.
  std::string build_prompt(std::string query, std::vector<std::pair<std::string, std::string>> history);

private:
  // One handle per requested device; bm_handle is handles[0].
  std::vector<bm_handle_t> handles;
  bm_handle_t bm_handle{};
  void *p_bmrt = nullptr;

  // Network names looked up in the bmodel.
  std::string name_embed;
  std::string name_embed_cache;
  std::string name_lm;
  std::vector<std::string> name_blocks;
  std::vector<std::string> name_blocks_cache;

  // Network descriptions returned by the runtime (not owned by this class).
  const bm_net_info_t *net_embed = nullptr;
  const bm_net_info_t *net_embed_cache = nullptr;
  const bm_net_info_t *net_lm = nullptr;
  std::vector<const bm_net_info_t *> net_blocks;
  std::vector<const bm_net_info_t *> net_blocks_cache;

  // Per-device tensors (index = position of the device in `handles`).
  std::vector<bm_tensor_t> inputs_embed, inputs_embed_cache;
  std::vector<bm_tensor_t> hidden_states, hidden_states_cache;
  std::vector<bm_tensor_t> inputs_pid, next_pid;
  std::vector<bm_tensor_t> inputs_attention, next_attention;
  // Indexed [layer][device].
  std::vector<std::vector<bm_tensor_t>> past_keys, past_values;
  std::vector<bm_tensor_t> present_key_cache, present_value_cache;
  std::vector<bm_tensor_t> inputs_lm, outputs_lm;

  int device_num = 0;
  int token_length = 0; // number of tokens processed so far in the current answer
  int SEQLEN = 0;       // read from bmodel
  int NUM_LAYERS = 0;   // read from bmodel
  // True when addr_mode of the cache block network is 1; in that case the
  // past key/value tensors alias network input memory and are not freed in
  // deinit().
  bool io_alone = false;

  sentencepiece::SentencePieceProcessor sentencepiece;
  // Past (question, answer) pairs, oldest first.
  std::vector<std::pair<std::string, std::string>> history_vector;
  // System prompt prefix prepended to every prompt.
  // NOTE(review): this is a raw string literal, so each `\n` below is a
  // literal backslash followed by 'n', not a newline character; only the
  // physical line break inside the literal is a real newline. Confirm this is
  // what the tokenizer/model expects before changing it.
  std::string sys_config = R"(<s>[INST] <<SYS>>\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n\n)";
  int EOS = 0; // end-of-sequence token id, set by load_sentencepiece()
};
// Loads the SentencePiece model found at `tokenizer_path` into the member
// processor and stores its end-of-sequence id in EOS.
// If loading fails, the status text is printed and the process exits with
// code -1.
void Llama2Chat::load_sentencepiece(std::string tokenizer_path) {
  printf("Load %s ... ", tokenizer_path.c_str());
  const auto load_status = sentencepiece.Load(tokenizer_path);
  if (load_status.ok()) {
    EOS = sentencepiece.eos_id();
    printf("Done!\n");
    return;
  }
  // Loading failed: report why and give up.
  std::cout << load_status.ToString() << std::endl;
  exit(-1);
}
// Launches the network called `net_name` with the given input and output
// tensors and blocks until the primary device handle has synchronized.
//
// A failed launch is reported on stderr and terminates the process with
// code -1. The check is explicit (rather than an assert) so that it also
// takes effect in builds compiled with NDEBUG.
//
// `stage_idx` is accepted for interface compatibility but is not used.
void Llama2Chat::net_launch(const std::string &net_name,
                            std::vector<bm_tensor_t> &inputs,
                            std::vector<bm_tensor_t> &outputs,
                            int stage_idx) {
  (void)stage_idx;
  const bool launched = bmrt_launch_tensor_ex(
      p_bmrt, net_name.c_str(), inputs.data(), inputs.size(), outputs.data(),
      outputs.size(), true, false);
  if (!launched) {
    std::cerr << "Error: failed to launch network [" << net_name << "]"
              << std::endl;
    exit(-1);
  }
  bm_thread_sync(bm_handle);
}
// Prepares everything needed for inference.
//
// Steps, in order:
//   1. load the tokenizer,
//   2. request one handle per entry of `devices` (handles[0] becomes the
//      primary handle) and create the runtime over all of them,
//   3. load the bmodel at `model_path` and look up the "embedding",
//      "embedding_cache", "lm_head", "block_<i>" and "block_cache_<i>"
//      networks,
//   4. derive SEQLEN from the last dimension of the embedding input and
//      NUM_LAYERS from the number of networks,
//   5. create the per-device tensors used by forward_first()/forward_next().
//
// Per-device tensor layout used below (offset within one device's slice):
//   block inputs : 0 hidden states, 1 position ids, 2 attention mask,
//                  3 past key, 4 past value (the last two only for the
//                  cache variant)
//   block outputs: 0 hidden states, 1 key, 2 value
//
// An empty device list, a missing network or a non-positive layer count is
// reported on stderr and terminates the process with code -1.
void Llama2Chat::init(const std::vector<int> &devices,
                      const std::string &model_path,
                      const std::string &tokenizer_path) {
  // handles[0] is used as the primary handle, so at least one device id is
  // required.
  if (devices.empty()) {
    std::cerr << "Error: no device id given" << std::endl;
    exit(-1);
  }

  // load tokenizer
  load_sentencepiece(tokenizer_path);

  // request bm_handle
  device_num = devices.size();
  std::cout << "Device [ ";
  for (auto d : devices) {
    std::cout << d << " ";
  }
  std::cout << "] loading ....\n";
  for (auto d : devices) {
    bm_handle_t h;
    bm_status_t status = bm_dev_request(&h, d);
    assert(BM_SUCCESS == status);
    handles.push_back(h);
  }
  bm_handle = handles[0];

  // create bmruntime
  p_bmrt = bmrt_create_ex(handles.data(), handles.size());
  assert(NULL != p_bmrt);

  // load bmodel by file
  printf("Model[%s] loading ....\n", model_path.c_str());
  bool ret = bmrt_load_bmodel(p_bmrt, model_path.c_str());
  assert(true == ret);
  printf("Done!\n");

  // embed, lm_head
  name_embed = "embedding";
  name_embed_cache = "embedding_cache";
  name_lm = "lm_head";
  net_embed = bmrt_get_network_info(p_bmrt, name_embed.c_str());
  net_embed_cache = bmrt_get_network_info(p_bmrt, name_embed_cache.c_str());
  net_lm = bmrt_get_network_info(p_bmrt, name_lm.c_str());
  if (net_embed == nullptr || net_embed_cache == nullptr ||
      net_lm == nullptr) {
    std::cerr << "Error: model does not contain the networks [" << name_embed
              << "], [" << name_embed_cache << "] and [" << name_lm << "]"
              << std::endl;
    exit(-1);
  }
  int num_dims = net_embed->stages[0].input_shapes[0].num_dims;
  SEQLEN = net_embed->stages[0].input_shapes[0].dims[num_dims - 1]; // real seqlen
  // Every layer contributes two networks (block + block_cache); the other
  // three are embedding, embedding_cache and lm_head.
  auto num_nets = bmrt_get_network_number(p_bmrt);
  NUM_LAYERS = (num_nets - 3) / 2;
  if (NUM_LAYERS <= 0) {
    std::cerr << "Error: model does not contain any block network"
              << std::endl;
    exit(-1);
  }

  // blocks
  name_blocks.resize(NUM_LAYERS);
  name_blocks_cache.resize(NUM_LAYERS);
  net_blocks.resize(NUM_LAYERS);
  net_blocks_cache.resize(NUM_LAYERS);
  for (int i = 0; i < NUM_LAYERS; i++) {
    name_blocks[i] = "block_" + std::to_string(i);
    name_blocks_cache[i] = "block_cache_" + std::to_string(i);
    net_blocks[i] = bmrt_get_network_info(p_bmrt, name_blocks[i].c_str());
    net_blocks_cache[i] =
        bmrt_get_network_info(p_bmrt, name_blocks_cache[i].c_str());
    if (net_blocks[i] == nullptr || net_blocks_cache[i] == nullptr) {
      std::cerr << "Error: model does not contain the networks ["
                << name_blocks[i] << "] and [" << name_blocks_cache[i] << "]"
                << std::endl;
      exit(-1);
    }
  }

  // net device mem: these tensors wrap memory that already belongs to the
  // networks, so they are not freed in deinit().
  hidden_states.resize(device_num);
  hidden_states_cache.resize(device_num);
  inputs_embed.resize(device_num);
  inputs_embed_cache.resize(device_num);
  int out_num = net_blocks[0]->output_num / device_num;
  int out_num_cache = net_blocks_cache[0]->output_num / device_num;
  for (int i = 0; i < device_num; ++i) {
    // The shape index follows the same per-device offset as the memory and
    // dtype indices.
    bmrt_tensor_with_device(
        &hidden_states[i],
        net_blocks[0]->stages[0].output_mems[0 + i * out_num],
        net_blocks[0]->output_dtypes[0 + i * out_num],
        net_blocks[0]->stages[0].output_shapes[0 + i * out_num]);
    bmrt_tensor_with_device(
        &hidden_states_cache[i],
        net_blocks_cache[0]->stages[0].output_mems[0 + i * out_num_cache],
        net_blocks_cache[0]->output_dtypes[0 + i * out_num_cache],
        net_blocks_cache[0]->stages[0].output_shapes[0 + i * out_num_cache]);
    bmrt_tensor_with_device(
        &inputs_embed[i],
        net_embed->stages[0].input_mems[i],
        net_embed->input_dtypes[i],
        net_embed->stages[0].input_shapes[i]);
    bmrt_tensor_with_device(
        &inputs_embed_cache[i],
        net_embed_cache->stages[0].input_mems[i],
        net_embed_cache->input_dtypes[i],
        net_embed_cache->stages[0].input_shapes[i]);
  }

  // Position-id and attention-mask tensors are allocated here and released
  // in deinit().
  inputs_pid.resize(device_num);
  inputs_attention.resize(device_num);
  next_pid.resize(device_num);
  next_attention.resize(device_num);
  int in_num = net_blocks[0]->input_num / device_num;
  int in_num_cache = net_blocks_cache[0]->input_num / device_num;
  for (int i = 0; i < device_num; ++i) {
    ret = bmrt_tensor_ex(
        &inputs_pid[i], p_bmrt,
        net_blocks[0]->input_loc_devices[1 + i * in_num],
        net_blocks[0]->input_dtypes[1 + i * in_num],
        net_blocks[0]->stages[0].input_shapes[1 + i * in_num]);
    assert(true == ret);
    ret = bmrt_tensor_ex(
        &inputs_attention[i], p_bmrt,
        net_blocks[0]->input_loc_devices[2 + i * in_num],
        net_blocks[0]->input_dtypes[2 + i * in_num],
        net_blocks[0]->stages[0].input_shapes[2 + i * in_num]);
    assert(true == ret);
    ret = bmrt_tensor_ex(
        &next_pid[i], p_bmrt,
        net_blocks_cache[0]->input_loc_devices[1 + i * in_num_cache],
        net_blocks_cache[0]->input_dtypes[1 + i * in_num_cache],
        net_blocks_cache[0]->stages[0].input_shapes[1 + i * in_num_cache]);
    assert(true == ret);
    ret = bmrt_tensor_ex(
        &next_attention[i], p_bmrt,
        net_blocks_cache[0]->input_loc_devices[2 + i * in_num_cache],
        net_blocks_cache[0]->input_dtypes[2 + i * in_num_cache],
        net_blocks_cache[0]->stages[0].input_shapes[2 + i * in_num_cache]);
    assert(true == ret);
  }

  // Past key/value tensors, indexed [layer][device]. With addr_mode == 1 they
  // wrap the cache network's own input memory; otherwise they are allocated
  // here and released in deinit().
  auto addr_mode = net_blocks_cache[0]->addr_mode;
  io_alone = (addr_mode == 1);
  past_keys.resize(NUM_LAYERS);
  past_values.resize(NUM_LAYERS);
  if (io_alone) {
    for (int i = 0; i < NUM_LAYERS; i++) {
      past_keys[i].resize(device_num);
      past_values[i].resize(device_num);
      auto &net = net_blocks_cache[i];
      for (int j = 0; j < device_num; j++) {
        bmrt_tensor_with_device(
            &past_keys[i][j],
            net->stages[0].input_mems[3 + j * in_num_cache],
            net->input_dtypes[3 + j * in_num_cache],
            net->stages[0].input_shapes[3 + j * in_num_cache]);
        bmrt_tensor_with_device(
            &past_values[i][j],
            net->stages[0].input_mems[4 + j * in_num_cache],
            net->input_dtypes[4 + j * in_num_cache],
            net->stages[0].input_shapes[4 + j * in_num_cache]);
      }
    }
  } else {
    for (int i = 0; i < NUM_LAYERS; i++) {
      past_keys[i].resize(device_num);
      past_values[i].resize(device_num);
      auto &net = net_blocks_cache[i];
      for (int j = 0; j < device_num; j++) {
        ret = bmrt_tensor_ex(
            &past_keys[i][j], p_bmrt,
            net->input_loc_devices[3 + j * in_num_cache],
            net->input_dtypes[3 + j * in_num_cache],
            net->stages[0].input_shapes[3 + j * in_num_cache]);
        assert(true == ret);
        ret = bmrt_tensor_ex(
            &past_values[i][j], p_bmrt,
            net->input_loc_devices[4 + j * in_num_cache],
            net->input_dtypes[4 + j * in_num_cache],
            net->stages[0].input_shapes[4 + j * in_num_cache]);
        assert(true == ret);
      }
    }
  }

  // Templates for the single-token key/value outputs of the cache blocks:
  // copies of the layer-0 tensors with dims[1] forced to 1. forward_next()
  // re-points their device memory at the current token slot.
  present_key_cache.resize(device_num);
  present_value_cache.resize(device_num);
  inputs_lm.resize(device_num);
  outputs_lm.resize(device_num);
  for (int i = 0; i < device_num; ++i) {
    present_key_cache[i] = past_keys[0][i];
    present_value_cache[i] = past_values[0][i];
    present_key_cache[i].shape.dims[1] = 1;
    present_value_cache[i].shape.dims[1] = 1;
    ret = bmrt_tensor_ex(&inputs_lm[i], p_bmrt, i, net_lm->input_dtypes[0],
                         net_lm->stages[0].input_shapes[0]);
    assert(true == ret);
    ret = bmrt_tensor_ex(&outputs_lm[i], p_bmrt, i, net_lm->output_dtypes[0],
                         net_lm->stages[0].output_shapes[0]);
    assert(true == ret);
  }
}
void Llama2Chat::deinit() {
for (int i = 0; i < device_num; ++i) {
bm_free_device(handles[i], inputs_pid[i].device_mem);
bm_free_device(handles[i], next_pid[i].device_mem);
bm_free_device(handles[i], inputs_attention[i].device_mem);
bm_free_device(handles[i], next_attention[i].device_mem);
bm_free_device(handles[i], inputs_lm[i].device_mem);
bm_free_device(handles[i], outputs_lm[i].device_mem);
}
if (!io_alone) {
for (int i = 0; i < NUM_LAYERS; i++) {
for (int j = 0; j < device_num; j++) {
bm_free_device(handles[j], past_keys[i][j].device_mem);
bm_free_device(handles[j], past_values[i][j].device_mem);
}
}
}
bmrt_destroy(p_bmrt);
for (auto h : handles) {
bm_dev_free(h);
}
}
// Builds the text prompt for one turn.
//
// The result is sys_config, followed by every past exchange rendered as
// "<question> [/INST] <answer></s><s>[INST] ", followed by
// "<query> [/INST] ".
//
// @param query           the new user input
// @param history_vector  past (question, answer) pairs, oldest first
// @return the assembled prompt string
std::string Llama2Chat::build_prompt(std::string query, std::vector<std::pair<std::string, std::string>> history_vector) {
  std::string prompt = sys_config;
  for (const auto &item : history_vector) {
    // The next turn is opened with "[INST] " (a single closing bracket).
    prompt += item.first + " [/INST] " + item.second + "</s><s>[INST] ";
  }
  prompt += query + " [/INST] ";
  return prompt;
}
int Llama2Chat::forward_first(std::vector<int> &tokens) {
std::vector<int> input_ids(SEQLEN, 0);
std::vector<int> position_id(SEQLEN, 0);
std::vector<uint16_t> attention_mask(SEQLEN * SEQLEN, ATTENTION_MASK);
std::copy(tokens.begin(), tokens.end(), input_ids.data());
for (int i = 0; i < token_length; i++) {
position_id[i] = i;
}
for (int i = 0; i < token_length; i++) {
for (int j = 0; j < SEQLEN; j++) {
if (j <= i) {
attention_mask[i * SEQLEN + j] = 0;
}
}
}
// forward embeding
std::vector<int> input_nums(device_num, 1);
std::vector<void*> datas(device_num, (void*)input_ids.data());
bmrt_memcpy_s2d_parallel(p_bmrt, inputs_embed.data(), datas.data(),
input_nums.data(), device_num);
auto output_embeds = hidden_states;
for (int i = 0; i < device_num; ++i) {
output_embeds[i].shape = net_embed[0].stages[0].output_shapes[0];
}
auto ret = bmrt_launch_tensor_ex(p_bmrt, name_embed.c_str(),
inputs_embed.data(), inputs_embed.size(),
output_embeds.data(), output_embeds.size(),
true, false);
assert(ret);
bm_thread_sync(bm_handle);
// forward blocks
std::vector<void*> pos_id_datas(device_num, (void*)position_id.data());
std::vector<void*> in_attn_datas(device_num, (void*)attention_mask.data());
bmrt_memcpy_s2d_parallel(p_bmrt, inputs_pid.data(), pos_id_datas.data(),
input_nums.data(), device_num);
bmrt_memcpy_s2d_parallel(p_bmrt, inputs_attention.data(),in_attn_datas.data(),
input_nums.data(), device_num);
auto tmp_hidden_states = hidden_states;
std::vector<bm_tensor_t> inputs_block;
std::vector<bm_tensor_t> outputs_block;
for (int i = 0; i < device_num; ++i) {
tmp_hidden_states[i].shape = net_blocks[0]->stages[0].input_shapes[0];
inputs_block.push_back(tmp_hidden_states[i]);
inputs_block.push_back(inputs_pid[i]);
inputs_block.push_back(inputs_attention[i]);
outputs_block.push_back(tmp_hidden_states[i]);
outputs_block.push_back(past_keys[0][i]);
outputs_block.push_back(past_values[0][i]);
}
for (int i = 0; i < NUM_LAYERS; i++) {
for (int j = 0; j < device_num; ++j) {
outputs_block[1 + j * 3] = past_keys[i][j];
outputs_block[2 + j * 3] = past_values[i][j];
}
net_launch(name_blocks[i], inputs_block, outputs_block);
}
int bytes = hidden_states[0].device_mem.size / SEQLEN;
bm_memcpy_d2d_byte(bm_handle, inputs_lm[0].device_mem, 0,
hidden_states[0].device_mem, (token_length - 1) * bytes,
bytes);
ret = bmrt_launch_tensor_ex(p_bmrt, name_lm.c_str(), &inputs_lm[0], 1,
&outputs_lm[0], 1,
true, false);
assert(ret);
bm_thread_sync(bm_handle);
int token = 0;
bm_memcpy_d2s(bm_handle, (void *)&token, outputs_lm[0].device_mem);
return token;
}
int Llama2Chat::forward_next(int cur_token) {
std::vector<uint16_t> attention_mask(SEQLEN + 1, 0);
for (int i = token_length - 1; i < SEQLEN; i++) {
attention_mask[i] = ATTENTION_MASK;
}
int32_t position_id = token_length - 1;
// embedding
// std::vector<bm_tensor_t> inputs_embed;
std::vector<void*> input_datas;
std::vector<int> input_nums(device_num, 1);
for (int i = 0; i < device_num; ++i) {
// inputs_embed_cache.push_back(outputs_lm[i]); // token_id
// inputs_embed_cache[i].shape = net_embed_cache->stages[0].input_shapes[0];
input_datas.push_back((void*)(&cur_token));
}
bmrt_memcpy_s2d_parallel(p_bmrt, inputs_embed_cache.data(), input_datas.data(),
input_nums.data(), device_num);
auto outputs_embed_cache = hidden_states_cache;
for (int i = 0; i < device_num; ++i) {
outputs_embed_cache[i].shape = net_embed_cache[0].stages[0].output_shapes[0];
}
auto ret = bmrt_launch_tensor_ex(
p_bmrt, name_embed_cache.c_str(), inputs_embed_cache.data(),
inputs_embed_cache.size(), outputs_embed_cache.data(),
outputs_embed_cache.size(), true, false);
assert(ret);
bm_thread_sync(bm_handle);
// blocks
std::vector<void*> attn_datas(device_num, attention_mask.data());
std::vector<void*> pid_datas(device_num, &position_id);
bmrt_memcpy_s2d_parallel(p_bmrt, next_attention.data(), attn_datas.data(),
input_nums.data(), device_num);
bmrt_memcpy_s2d_parallel(p_bmrt, next_pid.data(), pid_datas.data(),
input_nums.data(), device_num);
// WARNING: make inputs_lm device_num
std::vector<bm_tensor_t> embed_1 = hidden_states_cache;
for (int i = 0; i < device_num; ++i) {
embed_1[i].shape = net_blocks_cache[0]->stages[0].input_shapes[0];
}
std::vector<bm_tensor_t> inputs_block;
std::vector<bm_tensor_t> outputs_block;
for (int i = 0; i < device_num; ++i) {
inputs_block.push_back(embed_1[i]);
inputs_block.push_back(next_pid[i]);
inputs_block.push_back(next_attention[i]);
inputs_block.push_back(past_keys[0][i]);
inputs_block.push_back(past_values[0][i]);
outputs_block.push_back(embed_1[i]);
outputs_block.push_back(present_key_cache[i]);
outputs_block.push_back(present_value_cache[i]);
}
for (int i = 0; i < NUM_LAYERS; i++) {
for (int j = 0; j < device_num; ++j) {
inputs_block[3 + j * 5] = past_keys[i][j];
inputs_block[4 + j * 5] = past_values[i][j];
int bytes = bm_mem_get_device_size(past_keys[0][j].device_mem) / SEQLEN;
int token_offset = (token_length - 1) * bytes;
bm_set_device_mem(&outputs_block[1 + j * 3].device_mem, bytes,
bm_mem_get_device_addr(past_keys[i][j].device_mem) + token_offset);
bm_set_device_mem(&outputs_block[2 + j * 3].device_mem, bytes,
bm_mem_get_device_addr(past_values[i][j].device_mem) + token_offset);
}
net_launch(name_blocks_cache[i], inputs_block, outputs_block);
}
ret = bmrt_launch_tensor_ex(
p_bmrt, name_lm.c_str(), &hidden_states_cache[0], 1, &outputs_lm[0], 1,
true, false);
assert(ret);
bm_thread_sync(bm_handle);
int token = 0;
bm_memcpy_d2s(bm_handle, (void *)&token, outputs_lm[0].device_mem);
return token;
}
// Interactive loop: prompts for a question on stdout, reads one line from
// stdin and prints the generated answer.
// The loop ends when the user types "exit" or when stdin reaches end-of-file
// or fails (without this check a closed stdin would make the loop answer an
// empty question forever).
void Llama2Chat::chat() {
  while (true) {
    std::cout << "\nQuestion: ";
    std::string input_str;
    if (!std::getline(std::cin, input_str)) {
      break;
    }
    if (input_str == "exit") {
      break;
    }
    std::cout << "\nAnswer: " << std::flush;
    answer(input_str);
    std::cout << std::endl;
  }
}
// Generates the reply to `input_str`, streams it to stdout, prints latency
// and throughput figures, and appends the exchange to history_vector.
//
// The prompt is built from the system prompt, the history and the input. If
// its token count exceeds SEQLEN, the oldest exchanges are dropped one at a
// time until it fits, because forward_first() copies the tokens into a
// SEQLEN-sized buffer. If the prompt still does not fit with an empty
// history (or encodes to no tokens at all), an error is printed and nothing
// is generated.
//
// Generation stops at the EOS token or when token_length reaches SEQLEN; in
// the latter case the older half of the history is discarded afterwards.
void Llama2Chat::answer(const std::string &input_str) {
  std::vector<int> tokens;
  while (true) {
    tokens.clear();
    const std::string sentence_input =
        build_prompt(input_str, history_vector);
    sentencepiece.Encode(sentence_input, &tokens);
    if (static_cast<int>(tokens.size()) <= SEQLEN || history_vector.empty()) {
      break;
    }
    // Too long: forget the oldest exchange and try again.
    history_vector.erase(history_vector.begin());
  }
  if (tokens.empty()) {
    printf("Error: the prompt produced no tokens.\n");
    return;
  }
  if (static_cast<int>(tokens.size()) > SEQLEN) {
    printf("Error: the prompt has %zu tokens but the model accepts at most "
           "%d.\n",
           tokens.size(), SEQLEN);
    return;
  }

  // Each token is decoded together with a fixed leading token (id 0) and the
  // decoded text of that leading token alone is stripped from the front.
  // NOTE(review): presumably this keeps the piece's leading whitespace, which
  // decoding a single token on its own would drop — confirm.
  const int pre_token = 0;
  std::string pre_word;
  sentencepiece.Decode(std::vector<int>{pre_token}, &pre_word);

  int tok_num = 1;
  // steady_clock is monotonic, so the measured intervals are not affected by
  // wall-clock adjustments.
  const auto t0 = std::chrono::steady_clock::now();
  token_length = tokens.size();
  int token = forward_first(tokens);
  const auto t1 = std::chrono::steady_clock::now();

  std::string result;
  while (token != EOS && token_length < SEQLEN) {
    std::string word;
    sentencepiece.Decode(std::vector<int>{pre_token, token}, &word);
    const std::string diff = word.substr(pre_word.size());
    result += diff;
    std::cout << diff << std::flush;
    tok_num++;
    token_length++;
    token = forward_next(token);
  }
  const auto t2 = std::chrono::steady_clock::now();

  const auto use0 =
      std::chrono::duration_cast<std::chrono::microseconds>(t1 - t0);
  const auto use1 =
      std::chrono::duration_cast<std::chrono::microseconds>(t2 - t1);
  printf("\n\nfirst token latency: %f s", (use0.count() * 1e-6));
  printf("\nspeed: %f token/s\n", tok_num / (use1.count() * 1e-6));

  history_vector.push_back({input_str, result});
  if (token_length >= SEQLEN) {
    // The context is full: keep only the newer half of the history.
    const size_t half_size = history_vector.size() / 2;
    history_vector.erase(history_vector.begin(),
                         history_vector.begin() + half_size);
  }
}
// Splits `s` at every character that occurs in `delim` and appends the pieces
// to `ret` (existing contents of `ret` are kept).
// Empty pieces at the start or between two delimiters are kept; an empty
// piece after the final delimiter is not appended.
static void split(const std::string &s, const std::string &delim,
                  std::vector<std::string> &ret) {
  std::string::size_type field_begin = 0;
  for (;;) {
    const std::string::size_type hit = s.find_first_of(delim, field_begin);
    if (hit == std::string::npos) {
      break;
    }
    ret.push_back(s.substr(field_begin, hit - field_begin));
    field_begin = hit + 1;
  }
  // Whatever follows the last delimiter, unless nothing is left.
  if (field_begin < s.length()) {
    ret.push_back(s.substr(field_begin));
  }
}
// Converts a comma-separated list such as "1,2" into device ids.
// Each piece produced by split() is converted with std::atoi, so a piece
// that is not a number becomes 0.
static std::vector<int> parseCascadeDevices(const std::string &str) {
  std::vector<std::string> fields;
  split(str, ",", fields);

  std::vector<int> devices;
  devices.reserve(fields.size());
  for (const std::string &field : fields) {
    devices.push_back(std::atoi(field.c_str()));
  }
  return devices;
}
// Prints the command-line help text to stdout.
void Usage() {
  static const char *const kHelpText =
      "Usage:\n"
      " --help : Show help info.\n"
      " --model : Set model path \n"
      " --tokenizer : Set tokenizer path \n"
      " --devid : Set devices to run for model, e.g. 1,2. if not "
      "set, use 0\n";
  printf("%s", kHelpText);
}
// Parses the command line.
//
// Recognized options (long and short form):
//   --model/-m <path>      stored in `model_path`
//   --tokenizer/-t <path>  stored in `tokenizer_path`
//   --devid/-d <list>      comma-separated device ids, stored in `devices`
//   --help/-h              prints the usage text and exits successfully
//
// An unknown option or a missing option argument prints the usage text and
// exits with EXIT_FAILURE. Outputs that are not mentioned on the command line
// are left untouched, so callers can pre-fill defaults.
void processArguments(int argc, char *argv[],
                      std::string &model_path,
                      std::string &tokenizer_path,
                      std::vector<int> &devices) {
  struct option longOptions[] = {{"model", required_argument, nullptr, 'm'},
                                 {"tokenizer", required_argument, nullptr, 't'},
                                 {"devid", required_argument, nullptr, 'd'},
                                 {"help", no_argument, nullptr, 'h'},
                                 {nullptr, 0, nullptr, 0}};
  int optionIndex = 0;
  int option;
  // 'h' carries no trailing ':' because --help/-h takes no argument; with
  // "h:" a bare -h would be rejected as "missing argument".
  while ((option = getopt_long(argc, argv, "m:t:d:h", longOptions,
                               &optionIndex)) != -1) {
    switch (option) {
    case 'm':
      model_path = optarg;
      break;
    case 't':
      tokenizer_path = optarg;
      break;
    case 'd':
      devices = parseCascadeDevices(optarg);
      break;
    case 'h':
      Usage();
      exit(EXIT_SUCCESS);
    case '?':
      Usage();
      exit(EXIT_FAILURE);
    default:
      exit(EXIT_FAILURE);
    }
  }
}
// Entry point: parses the command line, then runs init -> chat -> deinit.
// The device list defaults to {0}; a missing model path prints the usage text
// and exits with EXIT_FAILURE.
int main(int argc, char **argv) {
  printf("Demo for Llama2Chat in BM1684X\n");

  std::vector<int> devices = {0};
  std::string model_path;
  std::string tokenizer_path;
  processArguments(argc, argv, model_path, tokenizer_path, devices);

  const bool have_model = !model_path.empty();
  if (!have_model) {
    Usage();
    exit(EXIT_FAILURE);
  }

  Llama2Chat llama2;
  printf("Init Environment ...\n");
  llama2.init(devices, model_path, tokenizer_path);
  printf("==========================\n");
  llama2.chat();
  llama2.deinit();
  return 0;
}
// (end of listing)