//===----------------------------------------------------------------------===//
//
// Copyright (C) 2023 Sophgo Technologies Inc. All rights reserved.
//
// TPU-MLIR is licensed under the 2-Clause BSD License except for the
// third-party components.
//
//===----------------------------------------------------------------------===//

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <iostream>
#include <random>
#include <string>
#include <utility>
#include <vector>

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include "bmruntime_interface.h"
#include "memory.h"

// speculative decoding constants
static const int K = 4;                        // draft tokens guessed per round
static const int GUESS_LEN = K + 1;            // guesses + one bonus position
static const uint16_t ATTENTION_MASK = 0xF0E2; // 16-bit mask value for blocked positions

class Qwen {
public:
  void init(const std::vector<int> &devid, std::string draft_model_path,
            std::string target_model_path);
  void deinit();
  int draft_forward_first(std::vector<int> &tokens);
  int draft_forward_next(int index);
  std::pair<std::vector<float>, std::vector<int>>
  target_forward_first(std::vector<int> &tokens);
  std::pair<std::vector<float>, std::vector<int>> target_forward_next();
  std::vector<int> generate(std::vector<int> &history_tokens, int EOS);
  std::mt19937 sgen;
  Qwen() : sgen(42) {}

private:
  void net_launch(void *p_bmrt, const bm_net_info_t *net, int stage_idx = 0);
  inline void d2d(bm_device_mem_t &dst, bm_device_mem_t &src);

  void head_launch(void *p_bmrt, const bm_net_info_t *net,
                   bm_device_mem_t &logits_mem);
  int greedy_search(void *p_bmrt, const bm_net_info_t *net,
                    bm_device_mem_t &logits_mem);
  std::pair<std::vector<float>, std::vector<int>>
  penalty_sample(void *p_bmrt, const bm_net_info_t *net,
                 bm_device_mem_t &logits_mem,
                 const std::vector<int> &visited_tokens, int token_length);
  std::pair<std::vector<float>, std::vector<int>>
  batch_penalty_sample(void *p_bmrt, const bm_net_info_t *net,
                       bm_device_mem_t &logits_mem,
                       const std::vector<int> &visited_tokens,
                       int token_length);
  void roll_back(std::vector<float> &probs, std::vector<int> &tokens,
                 std::vector<float> &prob_history, int index);
  int sample_from_probs(std::vector<float> &probs, std::vector<int> &tokens);
  int verify(std::vector<int> &guess_tokens,
             std::uniform_real_distribution<float> &udist);
  int resample(std::vector<float> &probs, std::vector<int> &tokens,
               int accepted);

public:
  int SEQLEN;            // read from bmodel
  int DRAFT_NUM_LAYERS;  // read from bmodel
  int TARGET_NUM_LAYERS; // read from bmodel
  int candidate_num;     // read from bmodel
  bool draft_io_alone;
  bool target_io_alone;
  int VOCAB_SIZE;
  std::vector<int> draft_visited_tokens;
  std::vector<int> target_visited_tokens;
  int draft_token_length;
  int target_token_length;
  std::vector<float> draft_prob_history;
  std::vector<float> target_prob_history;

  // generation
  float temperature;
  float top_p;
  float repeat_penalty;
  int repeat_last_n;
  int max_new_tokens;
  std::string generation_mode;
  std::string prompt_mode;

private:
  std::vector<bm_handle_t> handles;
  bm_handle_t bm_handle;

  // draft model runtime
  void *d_bmrt;
  std::vector<const bm_net_info_t *> draft_net_blocks;
  std::vector<const bm_net_info_t *> draft_net_blocks_cache;
  const bm_net_info_t *draft_net_embed;
  const bm_net_info_t *draft_net_embed_cache;
  const bm_net_info_t *draft_net_lm, *draft_net_greedy_head,
      *draft_net_penalty_sample_head;
  std::vector<bm_device_mem_t> draft_past_key;
  std::vector<bm_device_mem_t> draft_past_value;

  // target model runtime
  void *t_bmrt;
  std::vector<const bm_net_info_t *> target_net_blocks;
  std::vector<const bm_net_info_t *> target_net_blocks_cache;
  const bm_net_info_t *target_net_embed;
  const bm_net_info_t *target_net_embed_cache;
  const bm_net_info_t *target_net_lm, *target_net_greedy_head,
      *target_net_penalty_sample_head;
  std::vector<bm_device_mem_t> target_past_key;
  std::vector<bm_device_mem_t> target_past_value;
};
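
//===------------------------------------------------------------===//
// Speculative decoding flow (as implemented below):
//   1. The draft model proposes K tokens one by one
//      (draft_forward_first / draft_forward_next) and records the
//      probability of each proposal in draft_prob_history.
//   2. The target model scores the prompt plus all K guesses in a
//      single pass (target_forward_first / target_forward_next) and
//      records its probabilities in target_prob_history.
//   3. verify() accepts guess i while u <= p_target(i) / p_draft(i)
//      for u ~ U(0, 1); resample() then draws one extra token, so each
//      round emits between 1 and GUESS_LEN tokens.
//===------------------------------------------------------------===//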

void Qwen::net_launch(void *p_bmrt, const bm_net_info_t *net, int stage_idx) {
  std::vector<bm_tensor_t> in_tensors(net->input_num);
  std::vector<bm_tensor_t> out_tensors(net->output_num);
  for (int i = 0; i < net->input_num; i++) {
    bmrt_tensor_with_device(
        &in_tensors[i], net->stages[stage_idx].input_mems[i],
        net->input_dtypes[i], net->stages[stage_idx].input_shapes[i]);
  }
  for (int i = 0; i < net->output_num; i++) {
    bmrt_tensor_with_device(
        &out_tensors[i], net->stages[stage_idx].output_mems[i],
        net->output_dtypes[i], net->stages[stage_idx].output_shapes[i]);
  }
  auto ret = bmrt_launch_tensor_ex(p_bmrt, net->name, in_tensors.data(),
                                   net->input_num, out_tensors.data(),
                                   net->output_num, true, false);
  assert(ret);
  bm_thread_sync(bm_handle);
}

void Qwen::d2d(bm_device_mem_t &dst, bm_device_mem_t &src) {
  bm_memcpy_d2d_byte(bm_handle, dst, 0, src, 0, bm_mem_get_device_size(src));
}

void Qwen::init(const std::vector<int> &devices, std::string draft_model_path,
                std::string target_model_path) {
  // request bm_handle
  std::cout << "Device [ ";
  for (auto d : devices) {
    std::cout << d << " ";
  }
  std::cout << "] loading ....\n";
  for (auto d : devices) {
    bm_handle_t h;
    bm_status_t status = bm_dev_request(&h, d);
    assert(BM_SUCCESS == status);
    handles.push_back(h);
  }
  bm_handle = handles[0];

  // create bmruntime
#ifdef SOC_TARGET
  d_bmrt = bmrt_create(handles[0]);
  t_bmrt = bmrt_create(handles[0]);
#else
  d_bmrt = bmrt_create_ex(handles.data(), handles.size());
  t_bmrt = bmrt_create_ex(handles.data(), handles.size());
#endif
  assert(NULL != d_bmrt);
  assert(NULL != t_bmrt);

  // load bmodel by file
  printf("Model[%s] loading ....\n", draft_model_path.c_str());
  assert(true == bmrt_load_bmodel(d_bmrt, draft_model_path.c_str()));
  printf("Model[%s] loading ....\n", target_model_path.c_str());
  assert(true == bmrt_load_bmodel(t_bmrt, target_model_path.c_str()));
  printf("Done!\n");

  // draft net embed and lm_head
  draft_net_embed = bmrt_get_network_info(d_bmrt, "embedding");
  draft_net_embed_cache = bmrt_get_network_info(d_bmrt, "embedding_cache");
  draft_net_lm = bmrt_get_network_info(d_bmrt, "lm_head");
  draft_net_greedy_head = bmrt_get_network_info(d_bmrt, "greedy_head");
  draft_net_penalty_sample_head =
      bmrt_get_network_info(d_bmrt, "penalty_sample_head");
  auto draft_num_nets = bmrt_get_network_number(d_bmrt);
  DRAFT_NUM_LAYERS = (draft_num_nets - 5) / 2;

  // draft net blocks
  for (int i = 0; i < DRAFT_NUM_LAYERS; i++) {
    auto block_name = "block_" + std::to_string(i);
    auto cache_name = "block_cache_" + std::to_string(i);
    draft_net_blocks.emplace_back(
        bmrt_get_network_info(d_bmrt, block_name.c_str()));
    draft_net_blocks_cache.emplace_back(
        bmrt_get_network_info(d_bmrt, cache_name.c_str()));
  }

  // draft kv cache
  draft_past_key.resize(DRAFT_NUM_LAYERS);
  draft_past_value.resize(DRAFT_NUM_LAYERS);
  auto draft_addr_mode = draft_net_blocks_cache[0]->addr_mode;
  draft_io_alone = draft_addr_mode == 1;
  for (int i = 0; i < DRAFT_NUM_LAYERS; i++) {
    assert(draft_addr_mode == draft_net_blocks_cache[i]->addr_mode);
    if (draft_io_alone) {
      // io_alone: the cache networks own their kv buffers on device
      draft_past_key[i] = draft_net_blocks_cache[i]->stages[0].input_mems[3];
      draft_past_value[i] = draft_net_blocks_cache[i]->stages[0].input_mems[4];
    } else {
      auto ret =
          bm_malloc_device_byte(bm_handle, &draft_past_key[i],
                                draft_net_blocks_cache[i]->max_input_bytes[3]);
      assert(BM_SUCCESS == ret);
      ret = bm_malloc_device_byte(bm_handle, &draft_past_value[i],
                                  draft_net_blocks_cache[i]->max_input_bytes[4]);
      assert(BM_SUCCESS == ret);
    }
  }

  // target net embed and lm_head
  target_net_embed = bmrt_get_network_info(t_bmrt, "embedding");
  target_net_embed_cache = bmrt_get_network_info(t_bmrt, "embedding_cache");
  target_net_lm = bmrt_get_network_info(t_bmrt, "lm_head");
  auto target_num_nets = bmrt_get_network_number(t_bmrt);
  TARGET_NUM_LAYERS = (target_num_nets - 3) / 2;

  // target net blocks
  for (int i = 0; i < TARGET_NUM_LAYERS; i++) {
    auto block_name = "block_" + std::to_string(i);
    auto cache_name = "block_cache_" + std::to_string(i);
    target_net_blocks.emplace_back(
        bmrt_get_network_info(t_bmrt, block_name.c_str()));
    target_net_blocks_cache.emplace_back(
        bmrt_get_network_info(t_bmrt, cache_name.c_str()));
  }

  // target kv cache
  target_past_key.resize(TARGET_NUM_LAYERS);
  target_past_value.resize(TARGET_NUM_LAYERS);
  auto target_addr_mode = target_net_blocks_cache[0]->addr_mode;
  target_io_alone = target_addr_mode == 1;
  for (int i = 0; i < TARGET_NUM_LAYERS; i++) {
    assert(target_addr_mode == target_net_blocks_cache[i]->addr_mode);
    if (target_io_alone) {
      target_past_key[i] = target_net_blocks_cache[i]->stages[0].input_mems[3];
      target_past_value[i] =
          target_net_blocks_cache[i]->stages[0].input_mems[4];
    } else {
      auto ret =
          bm_malloc_device_byte(bm_handle, &target_past_key[i],
                                target_net_blocks_cache[i]->max_input_bytes[3]);
      assert(BM_SUCCESS == ret);
      ret = bm_malloc_device_byte(bm_handle, &target_past_value[i],
                                  target_net_blocks_cache[i]->max_input_bytes[4]);
      assert(BM_SUCCESS == ret);
    }
  }

  // resize
  assert(draft_net_embed->stages[0].input_shapes[0].dims[1] ==
         target_net_embed->stages[0].input_shapes[0].dims[1]);
  SEQLEN = draft_net_embed->stages[0].input_shapes[0].dims[1];
  VOCAB_SIZE = draft_net_lm->stages[0].output_shapes[0].dims[1];
  candidate_num =
      draft_net_penalty_sample_head->stages[0].output_shapes[0].dims[1];
  draft_visited_tokens.resize(SEQLEN);
  target_visited_tokens.resize(SEQLEN);
  // GUESS_LEN slots so the bonus position used by resample() stays in bounds
  draft_prob_history.resize(GUESS_LEN * VOCAB_SIZE);
  target_prob_history.resize(GUESS_LEN * VOCAB_SIZE);
}

void Qwen::deinit() {
  if (false == draft_io_alone) {
    for (int i = 0; i < DRAFT_NUM_LAYERS; i++) {
      bm_free_device(bm_handle, draft_past_key[i]);
      bm_free_device(bm_handle, draft_past_value[i]);
    }
  }
  if (false == target_io_alone) {
    for (int i = 0; i < TARGET_NUM_LAYERS; i++) {
      bm_free_device(bm_handle, target_past_key[i]);
      bm_free_device(bm_handle, target_past_value[i]);
    }
  }
  bmrt_destroy(d_bmrt);
  bmrt_destroy(t_bmrt);
  for (auto h : handles) {
    bm_dev_free(h);
  }
}

void Qwen::head_launch(void *p_bmrt, const bm_net_info_t *net,
                       bm_device_mem_t &logits_mem) {
  std::vector<bm_tensor_t> in_tensors(net->input_num);
  std::vector<bm_tensor_t> out_tensors(net->output_num);
  // input 0 is the logits buffer produced by the lm_head network
  bmrt_tensor_with_device(&in_tensors[0], logits_mem, net->input_dtypes[0],
                          net->stages[0].input_shapes[0]);
  for (int i = 1; i < net->input_num; i++) {
    bmrt_tensor_with_device(&in_tensors[i], net->stages[0].input_mems[i],
                            net->input_dtypes[i],
                            net->stages[0].input_shapes[i]);
  }
  for (int i = 0; i < net->output_num; i++) {
    bmrt_tensor_with_device(&out_tensors[i], net->stages[0].output_mems[i],
                            net->output_dtypes[i],
                            net->stages[0].output_shapes[i]);
  }
  auto ret = bmrt_launch_tensor_ex(p_bmrt, net->name, in_tensors.data(),
                                   net->input_num, out_tensors.data(),
                                   net->output_num, true, false);
  assert(ret);
  bm_thread_sync(bm_handle);
}

int Qwen::greedy_search(void *p_bmrt, const bm_net_info_t *net,
                        bm_device_mem_t &logits_mem) {
  auto &out_mem = net->stages[0].output_mems[0];
  head_launch(p_bmrt, net, logits_mem);
  int token = 0;
  bm_memcpy_d2s(bm_handle, (void *)&token, out_mem);
  return token;
}
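
// penalty_sample / batch_penalty_sample drive the bmodel sampling head:
// input 1 takes the last repeat_last_n generated tokens (for the repetition
// penalty), inputs 2-4 take the top_p, temperature and repeat_penalty scalars,
// and the head returns the top candidate_num probabilities (output 0) with
// their token ids (output 1). The batch variant returns GUESS_LEN such groups,
// one per scored position.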
std::pair<std::vector<float>, std::vector<int>>
Qwen::penalty_sample(void *p_bmrt, const bm_net_info_t *net,
                     bm_device_mem_t &logits_mem,
                     const std::vector<int> &visited_tokens, int token_length) {
  auto &in1_mem = net->stages[0].input_mems[1];
  auto &in2_mem = net->stages[0].input_mems[2];
  auto &in3_mem = net->stages[0].input_mems[3];
  auto &in4_mem = net->stages[0].input_mems[4];
  auto &out0_mem = net->stages[0].output_mems[0];
  auto &out1_mem = net->stages[0].output_mems[1];

  // repeat_penalty + top_p + top_k + temperature
  std::vector<int> generated_tokens(SEQLEN, visited_tokens[token_length - 1]);
  repeat_last_n = std::min(repeat_last_n, token_length);
  std::copy(visited_tokens.begin() + token_length - repeat_last_n,
            visited_tokens.begin() + token_length, generated_tokens.begin());
  bm_memcpy_s2d(bm_handle, in1_mem, (void *)generated_tokens.data());
  bm_memcpy_s2d(bm_handle, in2_mem, (void *)&top_p);
  bm_memcpy_s2d(bm_handle, in3_mem, (void *)&temperature);
  bm_memcpy_s2d(bm_handle, in4_mem, (void *)&repeat_penalty);

  // inference
  head_launch(p_bmrt, net, logits_mem);

  // get logit & token
  int candidate_num = net->stages[0].output_shapes[0].dims[1];
  std::vector<float> probs(candidate_num);
  bm_memcpy_d2s(bm_handle, probs.data(), out0_mem);
  std::vector<int> tokens(candidate_num);
  bm_memcpy_d2s(bm_handle, tokens.data(), out1_mem);
  return std::make_pair(probs, tokens);
}

std::pair<std::vector<float>, std::vector<int>> Qwen::batch_penalty_sample(
    void *p_bmrt, const bm_net_info_t *net, bm_device_mem_t &logits_mem,
    const std::vector<int> &visited_tokens, int token_length) {
  auto &in1_mem = net->stages[0].input_mems[1];
  auto &in2_mem = net->stages[0].input_mems[2];
  auto &in3_mem = net->stages[0].input_mems[3];
  auto &in4_mem = net->stages[0].input_mems[4];
  auto &out0_mem = net->stages[0].output_mems[0];
  auto &out1_mem = net->stages[0].output_mems[1];

  // repeat_penalty + top_p + top_k + temperature
  std::vector<int> generated_tokens(SEQLEN, visited_tokens[token_length - 1]);
  repeat_last_n = std::min(repeat_last_n, token_length);
  std::copy(visited_tokens.begin() + token_length - repeat_last_n,
            visited_tokens.begin() + token_length, generated_tokens.begin());
  bm_memcpy_s2d(bm_handle, in1_mem, (void *)generated_tokens.data());
  bm_memcpy_s2d(bm_handle, in2_mem, (void *)&top_p);
  bm_memcpy_s2d(bm_handle, in3_mem, (void *)&temperature);
  bm_memcpy_s2d(bm_handle, in4_mem, (void *)&repeat_penalty);

  // inference
  head_launch(p_bmrt, net, logits_mem);

  // get logit & token
  std::vector<float> probs(candidate_num * GUESS_LEN);
  bm_memcpy_d2s(bm_handle, probs.data(), out0_mem);
  std::vector<int> tokens(candidate_num * GUESS_LEN);
  bm_memcpy_d2s(bm_handle, tokens.data(), out1_mem);
  return std::make_pair(probs, tokens);
}

// record the candidate probabilities of position `index` into prob_history,
// indexed by token id
void Qwen::roll_back(std::vector<float> &probs, std::vector<int> &tokens,
                     std::vector<float> &prob_history, int index) {
  for (size_t i = 0; i < tokens.size(); i++) {
    prob_history[tokens[i] + index * VOCAB_SIZE] = probs[i];
  }
}

int Qwen::sample_from_probs(std::vector<float> &probs,
                            std::vector<int> &tokens) {
  std::discrete_distribution<> dist(probs.begin(), probs.end());
  return tokens[dist(sgen)];
}

//===------------------------------------------------------------===//
// Draft Model Forward
//===------------------------------------------------------------===//
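// draft_forward_first prefills the draft model: it uploads the whole prompt,
// builds a causal SEQLEN x SEQLEN mask, runs every block once to populate the
// draft kv cache, then samples the first guess from the penalty head and
// records its probabilities in slot 0 of draft_prob_history.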
int Qwen::draft_forward_first(std::vector<int> &tokens) {
  std::vector<int> position_id(SEQLEN, 0);
  std::vector<uint16_t> attention_mask(SEQLEN * SEQLEN, ATTENTION_MASK);
  std::copy(tokens.begin(), tokens.end(), draft_visited_tokens.data());
  draft_token_length = tokens.size();
  for (int i = 0; i < draft_token_length; i++) {
    position_id[i] = i;
  }
  for (int i = 0; i < draft_token_length; i++) {
    for (int j = 0; j < SEQLEN; j++) {
      if (j <= i) {
        attention_mask[i * SEQLEN + j] = 0;
      }
    }
  }

  // forward embedding
  auto &in_mem = draft_net_embed->stages[0].input_mems[0];
  auto &out_mem = draft_net_embed->stages[0].output_mems[0];
  bm_memcpy_s2d(bm_handle, in_mem, (void *)draft_visited_tokens.data());
  net_launch(d_bmrt, draft_net_embed); // prefill embedding

  // forward blocks
  for (int idx = 0; idx < DRAFT_NUM_LAYERS; idx++) {
    auto &in0_mem = draft_net_blocks[idx]->stages[0].input_mems[0];
    auto &in1_mem = draft_net_blocks[idx]->stages[0].input_mems[1];
    auto &in2_mem = draft_net_blocks[idx]->stages[0].input_mems[2];
    d2d(in0_mem, out_mem);
    if (idx == 0) {
      // only the first block needs the position ids and mask copied in
      bm_memcpy_s2d(bm_handle, in1_mem, (void *)position_id.data());
      bm_memcpy_s2d(bm_handle, in2_mem, (void *)attention_mask.data());
    }
    net_launch(d_bmrt, draft_net_blocks[idx]);
    out_mem = draft_net_blocks[idx]->stages[0].output_mems[0];
    d2d(draft_past_key[idx], draft_net_blocks[idx]->stages[0].output_mems[1]);
    d2d(draft_past_value[idx], draft_net_blocks[idx]->stages[0].output_mems[2]);
  }

  // forward lmhead on the hidden state of the last prompt token
  int bytes = out_mem.size / SEQLEN;
  auto &lm_in_mem = draft_net_lm->stages[0].input_mems[0];
  auto &lm_out_mem = draft_net_lm->stages[0].output_mems[0];
  bm_memcpy_d2d_byte(bm_handle, lm_in_mem, 0, out_mem,
                     (draft_token_length - 1) * bytes, bytes);
  net_launch(d_bmrt, draft_net_lm);

  auto pair = penalty_sample(d_bmrt, draft_net_penalty_sample_head, lm_out_mem,
                             draft_visited_tokens, draft_token_length);
  auto &candidate_probs = pair.first;
  auto &candidate_tokens = pair.second;

  // roll back
  roll_back(candidate_probs, candidate_tokens, draft_prob_history, 0);
  auto token = sample_from_probs(candidate_probs, candidate_tokens);
  draft_visited_tokens[draft_token_length] = token;
  draft_token_length += 1;
  return token;
}
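
// draft_forward_next decodes a single draft token with the cache networks:
// only the newest token is embedded, its key/value rows are appended to
// draft_past_key / draft_past_value at the current offset, and `index`
// selects which slot of draft_prob_history receives the sampled candidates.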
int Qwen::draft_forward_next(int index) {
  int cur_token = draft_visited_tokens[draft_token_length - 1];
  std::vector<uint16_t> attention_mask(SEQLEN + 1, 0);
  for (int i = draft_token_length - 1; i < SEQLEN; i++) {
    attention_mask[i] = ATTENTION_MASK;
  }
  int32_t position_id = draft_token_length - 1;

  // embedding
  auto &in_mem = draft_net_embed_cache->stages[0].input_mems[0];
  auto &out_mem = draft_net_embed_cache->stages[0].output_mems[0];
  bm_memcpy_s2d(bm_handle, in_mem, (void *)&cur_token);
  net_launch(d_bmrt, draft_net_embed_cache);

  // blocks
  int bytes = bm_mem_get_device_size(
      draft_net_blocks_cache[0]->stages[0].output_mems[1]);
  int token_offset = (draft_token_length - 1) * bytes;
  for (int idx = 0; idx < DRAFT_NUM_LAYERS; idx++) {
    auto &in0_mem = draft_net_blocks_cache[idx]->stages[0].input_mems[0];
    auto &in1_mem = draft_net_blocks_cache[idx]->stages[0].input_mems[1];
    auto &in2_mem = draft_net_blocks_cache[idx]->stages[0].input_mems[2];
    auto &in3_mem = draft_net_blocks_cache[idx]->stages[0].input_mems[3];
    auto &in4_mem = draft_net_blocks_cache[idx]->stages[0].input_mems[4];
    auto &out0_mem = draft_net_blocks_cache[idx]->stages[0].output_mems[0];
    auto &out1_mem = draft_net_blocks_cache[idx]->stages[0].output_mems[1];
    auto &out2_mem = draft_net_blocks_cache[idx]->stages[0].output_mems[2];
    d2d(in0_mem, out_mem);
    if (draft_io_alone) {
      if (idx == 0) {
        bm_memcpy_s2d(bm_handle, in1_mem, (void *)&position_id);
        bm_memcpy_s2d(bm_handle, in2_mem, (void *)attention_mask.data());
      } else {
        d2d(in1_mem, draft_net_blocks_cache[0]->stages[0].input_mems[1]);
        d2d(in2_mem, draft_net_blocks_cache[0]->stages[0].input_mems[2]);
      }
    } else {
      if (idx == 0) {
        bm_memcpy_s2d(bm_handle, in1_mem, (void *)&position_id);
        bm_memcpy_s2d(bm_handle, in2_mem, (void *)attention_mask.data());
      }
      d2d(in3_mem, draft_past_key[idx]);
      d2d(in4_mem, draft_past_value[idx]);
    }
    net_launch(d_bmrt, draft_net_blocks_cache[idx]);
    out_mem = out0_mem;
    // append the new kv row to the cache
    bm_memcpy_d2d_byte(bm_handle, draft_past_key[idx], token_offset, out1_mem,
                       0, bytes);
    bm_memcpy_d2d_byte(bm_handle, draft_past_value[idx], token_offset, out2_mem,
                       0, bytes);
  }

  // forward lmhead
  auto &lm_in_mem = draft_net_lm->stages[0].input_mems[0];
  auto &lm_out_mem = draft_net_lm->stages[0].output_mems[0];
  d2d(lm_in_mem, out_mem);
  net_launch(d_bmrt, draft_net_lm);

  auto pair = penalty_sample(d_bmrt, draft_net_penalty_sample_head, lm_out_mem,
                             draft_visited_tokens, draft_token_length);
  auto &candidate_probs = pair.first;
  auto &candidate_tokens = pair.second;

  // roll back
  roll_back(candidate_probs, candidate_tokens, draft_prob_history, index);
  auto token = sample_from_probs(candidate_probs, candidate_tokens);
  draft_visited_tokens[draft_token_length] = token;
  draft_token_length += 1;
  return token;
}

//===------------------------------------------------------------===//
// Target Model Forward
//===------------------------------------------------------------===//
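// target_forward_first prefills the target model over the prompt plus the K
// draft guesses. Its lm_head takes the last GUESS_LEN hidden states together
// with the sampling parameters and returns candidate_num probabilities and
// token ids per position; positions 0..K-1 are rolled back into
// target_prob_history for verification.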
std::pair<std::vector<float>, std::vector<int>>
Qwen::target_forward_first(std::vector<int> &tokens) {
  std::vector<int> position_id(SEQLEN, 0);
  std::vector<uint16_t> attention_mask(SEQLEN * SEQLEN, ATTENTION_MASK);
  std::copy(tokens.begin(), tokens.end(), target_visited_tokens.data());
  target_token_length = tokens.size();
  for (int i = 0; i < target_token_length; i++) {
    position_id[i] = i;
  }
  for (int i = 0; i < target_token_length; i++) {
    for (int j = 0; j < SEQLEN; j++) {
      if (j <= i) {
        attention_mask[i * SEQLEN + j] = 0;
      }
    }
  }

  // forward embedding
  auto &in_mem = target_net_embed->stages[0].input_mems[0];
  auto &out_mem = target_net_embed->stages[0].output_mems[0];
  bm_memcpy_s2d(bm_handle, in_mem, (void *)target_visited_tokens.data());
  net_launch(t_bmrt, target_net_embed); // prefill embedding

  // forward blocks
  for (int idx = 0; idx < TARGET_NUM_LAYERS; idx++) {
    auto &in0_mem = target_net_blocks[idx]->stages[0].input_mems[0];
    auto &in1_mem = target_net_blocks[idx]->stages[0].input_mems[1];
    auto &in2_mem = target_net_blocks[idx]->stages[0].input_mems[2];
    d2d(in0_mem, out_mem);
    if (idx == 0) {
      // only the first block needs the position ids and mask copied in
      bm_memcpy_s2d(bm_handle, in1_mem, (void *)position_id.data());
      bm_memcpy_s2d(bm_handle, in2_mem, (void *)attention_mask.data());
    }
    net_launch(t_bmrt, target_net_blocks[idx]);
    out_mem = target_net_blocks[idx]->stages[0].output_mems[0];
    d2d(target_past_key[idx], target_net_blocks[idx]->stages[0].output_mems[1]);
    d2d(target_past_value[idx],
        target_net_blocks[idx]->stages[0].output_mems[2]);
  }

  // forward lmhead on the last GUESS_LEN hidden states
  int bytes = out_mem.size / SEQLEN;
  auto &lm_in0_mem = target_net_lm->stages[0].input_mems[0];
  auto &lm_in1_mem = target_net_lm->stages[0].input_mems[1];
  auto &lm_in2_mem = target_net_lm->stages[0].input_mems[2];
  auto &lm_in3_mem = target_net_lm->stages[0].input_mems[3];
  auto &lm_in4_mem = target_net_lm->stages[0].input_mems[4];
  auto &lm_out0_mem = target_net_lm->stages[0].output_mems[0];
  auto &lm_out1_mem = target_net_lm->stages[0].output_mems[1];

  // repeat_penalty + top_p + top_k + temperature
  bm_memcpy_d2d_byte(bm_handle, lm_in0_mem, 0, out_mem,
                     (target_token_length - GUESS_LEN) * bytes,
                     GUESS_LEN * bytes);
  std::vector<int> generated_tokens(
      SEQLEN, target_visited_tokens[target_token_length - 1]);
  repeat_last_n = std::min(repeat_last_n, target_token_length);
  std::copy(target_visited_tokens.begin() + target_token_length - repeat_last_n,
            target_visited_tokens.begin() + target_token_length,
            generated_tokens.begin());
  bm_memcpy_s2d(bm_handle, lm_in1_mem, (void *)generated_tokens.data());
  bm_memcpy_s2d(bm_handle, lm_in2_mem, (void *)&top_p);
  bm_memcpy_s2d(bm_handle, lm_in3_mem, (void *)&temperature);
  bm_memcpy_s2d(bm_handle, lm_in4_mem, (void *)&repeat_penalty);

  // inference
  net_launch(t_bmrt, target_net_lm);

  // get logit & token
  std::vector<float> batch_probs(candidate_num * GUESS_LEN);
  bm_memcpy_d2s(bm_handle, batch_probs.data(), lm_out0_mem);
  std::vector<int> batch_tokens(candidate_num * GUESS_LEN);
  bm_memcpy_d2s(bm_handle, batch_tokens.data(), lm_out1_mem);

  for (int i = 0; i < K; i++) {
    std::vector<float> candidate_probs(
        batch_probs.begin() + i * candidate_num,
        batch_probs.begin() + (i + 1) * candidate_num);
    std::vector<int> candidate_tokens(
        batch_tokens.begin() + i * candidate_num,
        batch_tokens.begin() + (i + 1) * candidate_num);
    roll_back(candidate_probs, candidate_tokens, target_prob_history, i);
  }
  target_token_length += 1;
  return std::make_pair(batch_probs, batch_tokens);
}
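
// target_forward_next scores the GUESS_LEN most recent tokens (the last
// accepted token plus the K new guesses) in one pass: the mask exposes the
// cached prefix to every row and is causal within the guess window, and the
// new kv rows are written back at token_offset so the cache stays aligned
// with target_visited_tokens.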
std::pair<std::vector<float>, std::vector<int>> Qwen::target_forward_next() {
  std::vector<int> cur_tokens(
      target_visited_tokens.begin() + target_token_length - GUESS_LEN,
      target_visited_tokens.begin() + target_token_length);
  std::vector<uint16_t> attention_mask((GUESS_LEN) * (SEQLEN + GUESS_LEN),
                                       ATTENTION_MASK);
  std::vector<int> position_ids(GUESS_LEN, 0);
  for (int i = 0; i < GUESS_LEN; i++) {
    for (int j = 0; j < target_token_length - GUESS_LEN; j++) {
      attention_mask[i * (SEQLEN + GUESS_LEN) + j] = 0;
    }
    for (int j = SEQLEN; j < SEQLEN + i + 1; j++) {
      attention_mask[i * (SEQLEN + GUESS_LEN) + j] = 0;
    }
    position_ids[i] = target_token_length + i - GUESS_LEN;
  }

  // embedding
  auto &in_mem = target_net_embed_cache->stages[0].input_mems[0];
  auto &out_mem = target_net_embed_cache->stages[0].output_mems[0];
  bm_memcpy_s2d(bm_handle, in_mem, (void *)cur_tokens.data());
  net_launch(t_bmrt, target_net_embed_cache);

  // blocks
  int bytes = bm_mem_get_device_size(
                  target_net_blocks_cache[0]->stages[0].output_mems[1]) /
              GUESS_LEN;
  int token_offset = (target_token_length - GUESS_LEN) * bytes;
  for (int idx = 0; idx < TARGET_NUM_LAYERS; idx++) {
    auto &in0_mem = target_net_blocks_cache[idx]->stages[0].input_mems[0];
    auto &in1_mem = target_net_blocks_cache[idx]->stages[0].input_mems[1];
    auto &in2_mem = target_net_blocks_cache[idx]->stages[0].input_mems[2];
    auto &in3_mem = target_net_blocks_cache[idx]->stages[0].input_mems[3];
    auto &in4_mem = target_net_blocks_cache[idx]->stages[0].input_mems[4];
    auto &out0_mem = target_net_blocks_cache[idx]->stages[0].output_mems[0];
    auto &out1_mem = target_net_blocks_cache[idx]->stages[0].output_mems[1];
    auto &out2_mem = target_net_blocks_cache[idx]->stages[0].output_mems[2];
    d2d(in0_mem, out_mem);
    if (target_io_alone) {
      if (idx == 0) {
        bm_memcpy_s2d(bm_handle, in1_mem, (void *)position_ids.data());
        bm_memcpy_s2d(bm_handle, in2_mem, (void *)attention_mask.data());
      } else {
        d2d(in1_mem, target_net_blocks_cache[0]->stages[0].input_mems[1]);
        d2d(in2_mem, target_net_blocks_cache[0]->stages[0].input_mems[2]);
      }
    } else {
      if (idx == 0) {
        bm_memcpy_s2d(bm_handle, in1_mem, (void *)position_ids.data());
        bm_memcpy_s2d(bm_handle, in2_mem, (void *)attention_mask.data());
      }
      d2d(in3_mem, target_past_key[idx]);
      d2d(in4_mem, target_past_value[idx]);
    }
    net_launch(t_bmrt, target_net_blocks_cache[idx]);
    out_mem = out0_mem;
    // append GUESS_LEN new kv rows to the cache
    bm_memcpy_d2d_byte(bm_handle, target_past_key[idx], token_offset, out1_mem,
                       0, GUESS_LEN * bytes);
    bm_memcpy_d2d_byte(bm_handle, target_past_value[idx], token_offset,
                       out2_mem, 0, GUESS_LEN * bytes);
  }

  // forward lmhead
  auto &lm_in0_mem = target_net_lm->stages[0].input_mems[0];
  auto &lm_in1_mem = target_net_lm->stages[0].input_mems[1];
  auto &lm_in2_mem = target_net_lm->stages[0].input_mems[2];
  auto &lm_in3_mem = target_net_lm->stages[0].input_mems[3];
  auto &lm_in4_mem = target_net_lm->stages[0].input_mems[4];
  auto &lm_out0_mem = target_net_lm->stages[0].output_mems[0];
  auto &lm_out1_mem = target_net_lm->stages[0].output_mems[1];

  // repeat_penalty + top_p + top_k + temperature
  d2d(lm_in0_mem, out_mem);
  std::vector<int> generated_tokens(
      SEQLEN, target_visited_tokens[target_token_length - 1]);
  repeat_last_n = std::min(repeat_last_n, target_token_length);
  std::copy(target_visited_tokens.begin() + target_token_length - repeat_last_n,
            target_visited_tokens.begin() + target_token_length,
            generated_tokens.begin());
  bm_memcpy_s2d(bm_handle, lm_in1_mem, (void *)generated_tokens.data());
  bm_memcpy_s2d(bm_handle, lm_in2_mem, (void *)&top_p);
  bm_memcpy_s2d(bm_handle, lm_in3_mem, (void *)&temperature);
  bm_memcpy_s2d(bm_handle, lm_in4_mem, (void *)&repeat_penalty);

  // inference
  net_launch(t_bmrt, target_net_lm);

  // get logit & token
  std::vector<float> batch_probs(candidate_num * GUESS_LEN);
  bm_memcpy_d2s(bm_handle, batch_probs.data(), lm_out0_mem);
  std::vector<int> batch_tokens(candidate_num * GUESS_LEN);
  bm_memcpy_d2s(bm_handle, batch_tokens.data(), lm_out1_mem);

  for (int i = 0; i < K; i++) {
    std::vector<float> candidate_probs(
        batch_probs.begin() + i * candidate_num,
        batch_probs.begin() + (i + 1) * candidate_num);
    std::vector<int> candidate_tokens(
        batch_tokens.begin() + i * candidate_num,
        batch_tokens.begin() + (i + 1) * candidate_num);
    roll_back(candidate_probs, candidate_tokens, target_prob_history, i);
  }
  target_token_length += 1;
  return std::make_pair(batch_probs, batch_tokens);
}
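
//===------------------------------------------------------------===//
// Verification & resampling: guess i is accepted while
// u <= p_target(guess_i) / p_draft(guess_i) with u ~ U(0, 1); the first
// rejection stops the scan. resample() then draws one more token: from the
// target candidates at the first rejected position, or, when every guess was
// accepted, from the bonus position (after stepping the draft model once via
// draft_forward_next(0) so its kv cache keeps up).
//===------------------------------------------------------------===//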
int Qwen::verify(std::vector<int> &guess_tokens,
                 std::uniform_real_distribution<float> &udist) {
  int accepted = 0;
  for (int i = 0; i < K; i++) {
    float randomValue = udist(sgen);
    if (randomValue > target_prob_history[guess_tokens[i] + VOCAB_SIZE * i] /
                          draft_prob_history[guess_tokens[i] + VOCAB_SIZE * i]) {
      break;
    }
    accepted += 1;
  }
  return accepted;
}

int Qwen::resample(std::vector<float> &probs, std::vector<int> &tokens,
                   int accepted) {
  std::vector<float> modified_probs(candidate_num, 0);
  std::vector<int> modified_tokens(tokens.begin() + accepted * candidate_num,
                                   tokens.begin() +
                                       (accepted + 1) * candidate_num);
  if (accepted == K) {
    for (int i = 0; i < candidate_num; i++) {
      modified_probs[i] =
          probs[accepted * candidate_num + i] -
          draft_prob_history[tokens[accepted * candidate_num + i] +
                             accepted * VOCAB_SIZE];
    }
    draft_forward_next(0); // important !!!
  } else {
    std::copy(probs.begin() + accepted * candidate_num,
              probs.begin() + (accepted + 1) * candidate_num,
              modified_probs.begin());
  }
  return sample_from_probs(modified_probs, modified_tokens);
}

std::vector<int> Qwen::generate(std::vector<int> &history_tokens, int EOS) {
  if (history_tokens.empty()) {
    printf("Sorry: your question is empty!!\n");
    history_tokens.clear();
    return {};
  }

  // make sure token not too large
  if ((int)history_tokens.size() > SEQLEN - 10) {
    history_tokens.clear();
    printf("Error: your question is too large!\n");
    return {};
  }

  int accepted = 0;
  std::vector<int> guess_tokens;
  std::vector<int> result_tokens;
  std::uniform_real_distribution<float> udist(0.0f, 1.0f);

  // 1. Prefill
  // draft_model forward K
  guess_tokens.emplace_back(draft_forward_first(history_tokens));
  for (int i = 1; i < K; i++) {
    guess_tokens.emplace_back(draft_forward_next(i));
  }

  // target_model forward
  std::vector<int> target_tokens(history_tokens);
  target_tokens.insert(target_tokens.end(), guess_tokens.begin(),
                       guess_tokens.end());
  auto pair = target_forward_first(target_tokens);

  // Verify
  accepted = verify(guess_tokens, udist);
  for (int i = 0; i < accepted; i++) {
    result_tokens.emplace_back(guess_tokens[i]);
  }

  // Resample
  int last_token = resample(pair.first, pair.second, accepted);
  result_tokens.emplace_back(last_token);

  // 2. Decode
  while (std::find(result_tokens.end() -
                       std::min<size_t>(GUESS_LEN, result_tokens.size()),
                   result_tokens.end(), EOS) == result_tokens.end() &&
         result_tokens.size() < SEQLEN - history_tokens.size() - 10) {
    guess_tokens.clear();
    // reset probability history for the new round
    std::fill(draft_prob_history.begin(), draft_prob_history.end(), 0.0f);
    std::fill(target_prob_history.begin(), target_prob_history.end(), 0.0f);

    // draft model forward
    draft_token_length = history_tokens.size() + result_tokens.size();
    draft_visited_tokens[draft_token_length - 1] = last_token;
    for (int i = 0; i < K; i++) {
      guess_tokens.emplace_back(draft_forward_next(i));
    }

    // target model forward
    target_token_length = draft_token_length;
    target_visited_tokens = draft_visited_tokens;
    pair = target_forward_next();

    // verify
    accepted = verify(guess_tokens, udist);
    for (int i = 0; i < accepted; i++) {
      result_tokens.emplace_back(guess_tokens[i]);
    }

    // resample
    last_token = resample(pair.first, pair.second, accepted);
    result_tokens.emplace_back(last_token);
  }

  return result_tokens;
}

PYBIND11_MODULE(chat_speculative, m) {
  pybind11::class_<Qwen>(m, "Qwen")
      .def(pybind11::init<>())
      .def("init", &Qwen::init)
      // .def("forward_first", &Qwen::forward_first)
      // .def("forward_next", &Qwen::forward_next)
      .def("generate", &Qwen::generate)
      .def("deinit", &Qwen::deinit)
      .def_readwrite("SEQLEN", &Qwen::SEQLEN) // read SEQLEN in pipeline.py
      // .def_readwrite("token_length", &Qwen::token_length)
      .def_readwrite("temperature", &Qwen::temperature)
      .def_readwrite("top_p", &Qwen::top_p)
      .def_readwrite("repeat_penalty", &Qwen::repeat_penalty)
      .def_readwrite("repeat_last_n", &Qwen::repeat_last_n)
      .def_readwrite("max_new_tokens", &Qwen::max_new_tokens)
      .def_readwrite("generation_mode", &Qwen::generation_mode)
      .def_readwrite("prompt_mode", &Qwen::prompt_mode);
}