//===----------------------------------------------------------------------===// // // Copyright (C) 2023 Sophgo Technologies Inc. All rights reserved. // // TPU-MLIR is licensed under the 2-Clause BSD License except for the // third-party components. // //===----------------------------------------------------------------------===// #include "bmruntime_interface.h" #include "memory.h" #include #include #include #include #include #include #include #include #include #include #include #include #include #include "utils.h" static const uint16_t ATTENTION_MASK = 0xF0E2; class ChatGLM { public: void init(const std::vector &devid, std::string model_path); void deinit(); int forward_first(std::vector &tokens); std::string predict_option(std::vector &tokens); int forward_next(); std::vector forward_next_without_topk(); std::mt19937 sgen; ChatGLM() : sgen(std::random_device()()){}; private: void move2end(const bm_device_mem_t &kv); void net_launch(const bm_net_info_t *net, int stage_idx = 0); inline void d2d(bm_device_mem_t &dst, bm_device_mem_t &src); void head_launch(const bm_net_info_t *net, bm_device_mem_t &logits_mem); int greedy_search(const bm_net_info_t *net, bm_device_mem_t &logits_mem); int penalty_sample(const bm_net_info_t *net, bm_device_mem_t &logits_mem); public: int token_length; int SEQLEN; // read from bmodel int NUM_LAYERS; // read from bmodel bool io_alone; std::vector visited_tokens; // generation float temperature; float top_p; float repeat_penalty; int repeat_last_n; int max_new_tokens; std::string generation_mode; std::string prompt_mode; private: std::vector handles; bm_handle_t bm_handle; void *p_bmrt; std::vector net_blocks; std::vector net_blocks_cache; const bm_net_info_t *net_embed; const bm_net_info_t *net_embed_cache; const bm_net_info_t *net_lm, *net_greedy_head, *net_penalty_sample_head; std::vector past_key; std::vector past_value; std::vector option_prompt{316, 347, 319, 367}; // A B C D std::map option_map { {0, "A"}, {1, "B"}, {2, "C"}, {3, "D"} }; }; void ChatGLM::net_launch(const bm_net_info_t *net, int stage_idx) { std::vector in_tensors(net->input_num); std::vector out_tensors(net->output_num); for (int i = 0; i < net->input_num; i++) { bmrt_tensor_with_device( &in_tensors[i], net->stages[stage_idx].input_mems[i], net->input_dtypes[i], net->stages[stage_idx].input_shapes[i]); } for (int i = 0; i < net->output_num; i++) { bmrt_tensor_with_device( &out_tensors[i], net->stages[stage_idx].output_mems[i], net->output_dtypes[i], net->stages[stage_idx].output_shapes[i]); } auto ret = bmrt_launch_tensor_ex(p_bmrt, net->name, in_tensors.data(), net->input_num, out_tensors.data(), net->output_num, true, false); assert(ret); bm_thread_sync(bm_handle); } void ChatGLM::d2d(bm_device_mem_t &dst, bm_device_mem_t &src) { bm_memcpy_d2d_byte(bm_handle, dst, 0, src, 0, bm_mem_get_device_size(src)); } void ChatGLM::init(const std::vector &devices, std::string model_path) { // request bm_handle std::cout << "Device [ "; for (auto d : devices) { std::cout << d << " "; } std::cout << "] loading ....\n"; for (auto d : devices) { bm_handle_t h; bm_status_t status = bm_dev_request(&h, d); assert(BM_SUCCESS == status); handles.push_back(h); } bm_handle = handles[0]; // create bmruntime #ifdef SOC_TARGET p_bmrt = bmrt_create(handles[0]); #else p_bmrt = bmrt_create_ex(handles.data(), handles.size()); #endif assert(NULL != p_bmrt); // load bmodel by file printf("Model[%s] loading ....\n", model_path.c_str()); bool ret = bmrt_load_bmodel(p_bmrt, model_path.c_str()); assert(true == ret); printf("Done!\n"); // net embed and lm_head net_embed = bmrt_get_network_info(p_bmrt, "embedding"); net_embed_cache = bmrt_get_network_info(p_bmrt, "embedding_cache"); net_lm = bmrt_get_network_info(p_bmrt, "lm_head"); net_greedy_head = bmrt_get_network_info(p_bmrt, "greedy_head"); net_penalty_sample_head = bmrt_get_network_info(p_bmrt, "penalty_sample_head"); SEQLEN = net_embed->stages[0].input_shapes[0].dims[1]; // real seqlen auto num_nets = bmrt_get_network_number(p_bmrt); NUM_LAYERS = (num_nets - 5) / 2; // resize visited_tokens.resize(SEQLEN); // net blocks for (int i = 0; i < NUM_LAYERS; i++) { auto block_name = "block_" + std::to_string(i); auto cache_name = "block_cache_" + std::to_string(i); net_blocks.emplace_back(bmrt_get_network_info(p_bmrt, block_name.c_str())); net_blocks_cache.emplace_back( bmrt_get_network_info(p_bmrt, cache_name.c_str())); } // kv cache past_key.resize(NUM_LAYERS); past_value.resize(NUM_LAYERS); auto addr_mode = net_blocks_cache[0]->addr_mode; io_alone = addr_mode == 1; for (int i = 0; i < NUM_LAYERS; i++) { assert(addr_mode == net_blocks_cache[i]->addr_mode); if (io_alone) { past_key[i] = net_blocks_cache[i]->stages[0].input_mems[3]; past_value[i] = net_blocks_cache[i]->stages[0].input_mems[4]; } else { auto ret = bm_malloc_device_byte(bm_handle, &past_key[i], net_blocks_cache[i]->max_input_bytes[3]); assert(BM_SUCCESS == ret); ret = bm_malloc_device_byte(bm_handle, &past_value[i], net_blocks_cache[i]->max_input_bytes[4]); assert(BM_SUCCESS == ret); } } } void ChatGLM::deinit() { if (false == io_alone) { for (int i = 0; i < NUM_LAYERS; i++) { bm_free_device(bm_handle, past_key[i]); bm_free_device(bm_handle, past_value[i]); } } bmrt_destroy(p_bmrt); for (auto h : handles) { bm_dev_free(h); } } // after first block, move real result to end of mem void ChatGLM::move2end(const bm_device_mem_t &kv) { if (token_length >= SEQLEN) { return; } auto total_size = bm_mem_get_device_size(kv); auto bytes = total_size / SEQLEN; auto real_size = token_length * bytes; auto mem = bm_mem_from_device(bm_mem_get_device_addr(kv), real_size); auto buffer = new uint8_t[real_size]; auto dst = new uint8_t[total_size]; bm_memcpy_d2s(bm_handle, (void *)buffer, mem); memset(dst, 0, total_size - real_size); memcpy(dst + total_size - real_size, buffer, real_size); bm_memcpy_s2d(bm_handle, kv, (void *)dst); delete[] buffer; delete[] dst; } void ChatGLM::head_launch(const bm_net_info_t *net, bm_device_mem_t &logits_mem) { std::vector in_tensors(net->input_num); std::vector out_tensors(net->output_num); bmrt_tensor_with_device( &in_tensors[0], logits_mem, net->input_dtypes[0], net->stages[0].input_shapes[0]); for (int i = 1; i < net->input_num; i++) { bmrt_tensor_with_device( &in_tensors[i], net->stages[0].input_mems[i], net->input_dtypes[i], net->stages[0].input_shapes[i]); } for (int i = 0; i < net->output_num; i++) { bmrt_tensor_with_device( &out_tensors[i], net->stages[0].output_mems[i], net->output_dtypes[i], net->stages[0].output_shapes[i]); } auto ret = bmrt_launch_tensor_ex(p_bmrt, net->name, in_tensors.data(), net->input_num, out_tensors.data(), net->output_num, true, false); assert(ret); bm_thread_sync(bm_handle); } int ChatGLM::greedy_search(const bm_net_info_t *net, bm_device_mem_t &logits_mem) { auto &out_mem = net->stages[0].output_mems[0]; head_launch(net, logits_mem); int token = 0; bm_memcpy_d2s(bm_handle, (void *)&token, out_mem); return token; } int ChatGLM::penalty_sample(const bm_net_info_t *net, bm_device_mem_t &logits_mem) { auto &in1_mem = net->stages[0].input_mems[1]; auto &in2_mem = net->stages[0].input_mems[2]; auto &in3_mem = net->stages[0].input_mems[3]; auto &in4_mem = net->stages[0].input_mems[4]; auto &out0_mem = net->stages[0].output_mems[0]; auto &out1_mem = net->stages[0].output_mems[1]; // repeat_penalty + top_p + top_k + temperature std::vector generated_tokens(SEQLEN, visited_tokens[token_length - 1]); repeat_last_n = std::min(repeat_last_n, token_length); std::copy(visited_tokens.begin() + token_length - repeat_last_n, visited_tokens.begin() + token_length, generated_tokens.begin()); bm_memcpy_s2d(bm_handle, in1_mem, (void *)generated_tokens.data()); bm_memcpy_s2d(bm_handle, in2_mem, (void *)&top_p); bm_memcpy_s2d(bm_handle, in3_mem, (void *)&temperature); bm_memcpy_s2d(bm_handle, in4_mem, (void *)&repeat_penalty); // inference head_launch(net, logits_mem); // get logit & token int candidate_num = net->stages[0].output_shapes[0].dims[1]; std::vector probs(candidate_num); bm_memcpy_d2s(bm_handle, probs.data(), out0_mem); std::vector tokens(candidate_num); bm_memcpy_d2s(bm_handle, tokens.data(), out1_mem); // penalty_sample std::discrete_distribution<> dist(probs.begin(), probs.end()); return tokens[dist(sgen)]; } int ChatGLM::forward_first(std::vector &tokens) { std::vector position_id(SEQLEN, 0); std::vector attention_mask(SEQLEN * SEQLEN, 0); std::copy(tokens.begin(), tokens.end(), visited_tokens.data()); token_length = tokens.size(); for (int i = 0; i < token_length; i++) { position_id[i] = i; } for (int i = 0; i < SEQLEN; i++) { for (int j = 0; j < SEQLEN; j++) { if (j <= i && i < token_length) { } else { attention_mask[i * SEQLEN + j] = ATTENTION_MASK; } } } // forward embeding auto &in_mem = net_embed->stages[0].input_mems[0]; auto &out_mem = net_embed->stages[0].output_mems[0]; bm_memcpy_s2d(bm_handle, in_mem, (void *)visited_tokens.data()); net_launch(net_embed); // prefil embedding // forward blocks for (int idx = 0; idx < NUM_LAYERS; idx++) { auto &in0_mem = net_blocks[idx]->stages[0].input_mems[0]; auto &in1_mem = net_blocks[idx]->stages[0].input_mems[1]; auto &in2_mem = net_blocks[idx]->stages[0].input_mems[2]; d2d(in0_mem, out_mem); if (idx == 0) { // only first time need copy bm_memcpy_s2d(bm_handle, in1_mem, (void *)position_id.data()); bm_memcpy_s2d(bm_handle, in2_mem, (void *)attention_mask.data()); } net_launch(net_blocks[idx]); out_mem = net_blocks[idx]->stages[0].output_mems[0]; d2d(past_key[idx], net_blocks[idx]->stages[0].output_mems[1]); d2d(past_value[idx], net_blocks[idx]->stages[0].output_mems[2]); move2end(past_key[idx]); move2end(past_value[idx]); } // forward lmhead int bytes = out_mem.size / SEQLEN; auto &lm_in_mem = net_lm->stages[0].input_mems[0]; auto &lm_out_mem = net_lm->stages[0].output_mems[0]; bm_memcpy_d2d_byte(bm_handle, lm_in_mem, 0, out_mem, (token_length - 1) * bytes, bytes); net_launch(net_lm); int token = 0; if (generation_mode == "greedy") { token = greedy_search(net_greedy_head, lm_out_mem); } else if (generation_mode == "penalty_sample") { token = penalty_sample(net_penalty_sample_head, lm_out_mem); } visited_tokens[token_length] = token; token_length += 1; return token; } // for c-eval std::string ChatGLM::predict_option(std::vector &tokens) { int token = forward_first(tokens); token = forward_next(); std::vector logits = forward_next_without_topk(); std::vector option_logits; for(unsigned int i = 0; i < option_prompt.size(); i++){ option_logits.push_back(logits[option_prompt[i]]); } // get the index of maximum and map the index to option_map auto max_it = std::max_element(option_logits.begin(), option_logits.end()); int max_index = std::distance(option_logits.begin(), max_it); return option_map[max_index]; } int ChatGLM::forward_next() { int cur_token = visited_tokens[token_length - 1]; std::vector attention_mask(SEQLEN + 1, 0); for (int i = 0; i <= SEQLEN - token_length; i++) { attention_mask[i] = ATTENTION_MASK; } int32_t position_id = token_length - 1; // embedding auto &in_mem = net_embed_cache->stages[0].input_mems[0]; auto &out_mem = net_embed_cache->stages[0].output_mems[0]; bm_memcpy_s2d(bm_handle, in_mem, (void *)&cur_token); net_launch(net_embed_cache); // blocks for (int idx = 0; idx < NUM_LAYERS; idx++) { auto &in0_mem = net_blocks_cache[idx]->stages[0].input_mems[0]; auto &in1_mem = net_blocks_cache[idx]->stages[0].input_mems[1]; auto &in2_mem = net_blocks_cache[idx]->stages[0].input_mems[2]; auto &in3_mem = net_blocks_cache[idx]->stages[0].input_mems[3]; auto &in4_mem = net_blocks_cache[idx]->stages[0].input_mems[4]; auto &out0_mem = net_blocks_cache[idx]->stages[0].output_mems[0]; auto &out1_mem = net_blocks_cache[idx]->stages[0].output_mems[1]; auto &out2_mem = net_blocks_cache[idx]->stages[0].output_mems[2]; d2d(in0_mem, out_mem); if (io_alone) { if (idx == 0) { bm_memcpy_s2d(bm_handle, in1_mem, (void *)&position_id); bm_memcpy_s2d(bm_handle, in2_mem, (void *)attention_mask.data()); } else { d2d(in1_mem, net_blocks_cache[0]->stages[0].input_mems[1]); d2d(in2_mem, net_blocks_cache[0]->stages[0].input_mems[2]); } } else { if (idx == 0) { bm_memcpy_s2d(bm_handle, in1_mem, (void *)&position_id); bm_memcpy_s2d(bm_handle, in2_mem, (void *)attention_mask.data()); } in3_mem = past_key[idx]; in4_mem = past_value[idx]; } net_launch(net_blocks_cache[idx]); out_mem = out0_mem; d2d(past_key[idx], out1_mem); d2d(past_value[idx], out2_mem); } // forward lmhead auto &lm_in_mem = net_lm->stages[0].input_mems[0]; auto &lm_out_mem = net_lm->stages[0].output_mems[0]; d2d(lm_in_mem, out_mem); net_launch(net_lm); int token = 0; if (generation_mode == "greedy") { token = greedy_search(net_greedy_head, lm_out_mem); } else if (generation_mode == "penalty_sample") { token = penalty_sample(net_penalty_sample_head, lm_out_mem); } visited_tokens[token_length] = token; token_length += 1; return token; } // for c-eval std::vector ChatGLM::forward_next_without_topk(){ int cur_token = visited_tokens[token_length - 1]; std::vector attention_mask(SEQLEN + 1, 0); for (int i = 0; i <= SEQLEN - token_length; i++) { attention_mask[i] = ATTENTION_MASK; } int32_t position_id = token_length - 1; // embedding auto &in_mem = net_embed_cache->stages[0].input_mems[0]; auto &out_mem = net_embed_cache->stages[0].output_mems[0]; bm_memcpy_s2d(bm_handle, in_mem, (void *)&cur_token); net_launch(net_embed_cache); // blocks for (int idx = 0; idx < NUM_LAYERS; idx++) { auto &in0_mem = net_blocks_cache[idx]->stages[0].input_mems[0]; auto &in1_mem = net_blocks_cache[idx]->stages[0].input_mems[1]; auto &in2_mem = net_blocks_cache[idx]->stages[0].input_mems[2]; auto &in3_mem = net_blocks_cache[idx]->stages[0].input_mems[3]; auto &in4_mem = net_blocks_cache[idx]->stages[0].input_mems[4]; auto &out0_mem = net_blocks_cache[idx]->stages[0].output_mems[0]; auto &out1_mem = net_blocks_cache[idx]->stages[0].output_mems[1]; auto &out2_mem = net_blocks_cache[idx]->stages[0].output_mems[2]; d2d(in0_mem, out_mem); if (io_alone) { if (idx == 0) { bm_memcpy_s2d(bm_handle, in1_mem, (void *)&position_id); bm_memcpy_s2d(bm_handle, in2_mem, (void *)attention_mask.data()); } else { d2d(in1_mem, net_blocks_cache[0]->stages[0].input_mems[1]); d2d(in2_mem, net_blocks_cache[0]->stages[0].input_mems[2]); } } else { if (idx == 0) { bm_memcpy_s2d(bm_handle, in1_mem, (void *)&position_id); bm_memcpy_s2d(bm_handle, in2_mem, (void *)attention_mask.data()); } in3_mem = past_key[idx]; in4_mem = past_value[idx]; } net_launch(net_blocks_cache[idx]); out_mem = out0_mem; d2d(past_key[idx], out1_mem); d2d(past_value[idx], out2_mem); } // forward lmhead auto &lm_in_mem = net_lm->stages[0].input_mems[0]; auto &lm_out_mem = net_lm->stages[0].output_mems[0]; d2d(lm_in_mem, out_mem); net_launch(net_lm); int vocab_size = net_lm->stages[0].output_shapes[0].dims[1]; std::vector logits(vocab_size); bm_memcpy_d2s(bm_handle, logits.data(), lm_out_mem); return logits; } PYBIND11_MODULE(chat, m) { pybind11::class_(m, "ChatGLM") .def(pybind11::init<>()) .def("init", &ChatGLM::init) .def("deinit", &ChatGLM::deinit) .def("predict_option", &ChatGLM::predict_option) .def_readwrite("SEQLEN", &ChatGLM::SEQLEN) // read SEQLEN in pipeline.py .def_readwrite("token_length", &ChatGLM::token_length) .def_readwrite("temperature", &ChatGLM::temperature) .def_readwrite("top_p", &ChatGLM::top_p) .def_readwrite("repeat_penalty", &ChatGLM::repeat_penalty) .def_readwrite("repeat_last_n", &ChatGLM::repeat_last_n) .def_readwrite("max_new_tokens", &ChatGLM::max_new_tokens) .def_readwrite("generation_mode", &ChatGLM::generation_mode) .def_readwrite("prompt_mode", &ChatGLM::prompt_mode); }