#!/usr/bin/env python3
# ==============================================================================
#
# Copyright (C) 2023 Sophgo Technologies Inc. All rights reserved.
#
# TPU-MLIR is licensed under the 2-Clause BSD License except for the
# third-party components.
#
# ==============================================================================
import os
import torch
import argparse
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer

parser = argparse.ArgumentParser(description='export onnx')
parser.add_argument('-m', '--model_path', type=str, help='path to the torch model')
parser.add_argument('-s', '--seq_length', type=int, default=512, help="sequence length")
parser.add_argument('-d', '--device', type=str, choices=["cpu", "cuda"], default="cpu")
args = parser.parse_args()

model_path = args.model_path
folder = "./tmp/onnx"

device = torch.device(args.device)
origin_model = AutoModel.from_pretrained(
    model_path, trust_remote_code=True, torch_dtype=torch.float,
    device_map='auto').eval()
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

for param in origin_model.parameters():
    param.requires_grad = False

config = origin_model.config
transformer = origin_model.transformer
layers = transformer.encoder.layers

SEQ_LENGTH = transformer.seq_length
NUM_LAYERS = config.num_layers
HIDDEN_SIZE = config.hidden_size
NUM_ATTENTION_HEADS = config.num_attention_heads
HEAD_DIM = HIDDEN_SIZE // NUM_ATTENTION_HEADS
VOCAB_SIZE = config.vocab_size
print(f'Layers: {NUM_LAYERS}\nHidden size: {HIDDEN_SIZE}\n')

# the graphs are traced with a fixed sequence length, so it must match the checkpoint
if transformer.seq_length is not None:
    assert transformer.seq_length == args.seq_length
if config.seq_length is not None:
    assert config.seq_length == args.seq_length


# word-embedding lookup only; token ids in, hidden states out
class Embedding(torch.nn.Module):

    def __init__(self):
        super().__init__()

    def forward(self, input_ids):
        return transformer.embedding.word_embeddings(input_ids)


# one transformer layer for the prefill pass over the full sequence
class Block(torch.nn.Module):

    def __init__(self, layer_id):
        super().__init__()
        self.layer_id = layer_id
        self.layer = layers[layer_id]

    def forward(self, hidden_states, position_ids, attention_mask):
        rotary_pos_emb = transformer.rotary_pos_emb(SEQ_LENGTH)[position_ids]
        rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()
        hidden_states, past_kv = self.layer(hidden_states,
                                            attention_mask,
                                            rotary_pos_emb=rotary_pos_emb)
        return hidden_states, past_kv


# one transformer layer for the decode pass, consuming the cached K/V
class BlockCache(torch.nn.Module):

    def __init__(self, layer_id):
        super().__init__()
        self.layer_id = layer_id
        self.layer = layers[layer_id]

    def forward(self, hidden_states, position_ids, attention_mask, past_k, past_v):
        rotary_pos_emb = transformer.rotary_pos_emb(SEQ_LENGTH)[position_ids]
        rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()
        hidden_states, past_kv = self.layer(hidden_states,
                                            attention_mask,
                                            kv_cache=(past_k, past_v),
                                            rotary_pos_emb=rotary_pos_emb)
        present_k, present_v = past_kv
        # the layer appends the new key/value to the cache, giving SEQ_LENGTH + 1
        # entries; drop the oldest slot so the cache keeps a fixed size
        return hidden_states, present_k[1:], present_v[1:]


# final layernorm + output projection to vocabulary logits
class LmHead(torch.nn.Module):

    def __init__(self):
        super().__init__()

    def forward(self, hidden_states):
        hidden_states = transformer.encoder.final_layernorm(hidden_states)
        m_logits = transformer.output_layer(hidden_states)
        return m_logits


# greedy decoding: pick the top-1 logit
class GreedyHead(torch.nn.Module):

    def __init__(self):
        super().__init__()

    def forward(self, m_logits):
        _, token = torch.topk(m_logits.float(), 1)
        return token


# repetition penalty + top-k / top-p / temperature sampling
# refs: https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py
class PenaltySampleHead(torch.nn.Module):

    def __init__(self, top_k=50, min_tokens_to_keep=5):
        super().__init__()
        self.top_k = top_k
        self.min_tokens_to_keep = min_tokens_to_keep
        self.keep_matrix = torch.zeros((1, self.top_k), dtype=torch.bool)
        self.keep_matrix[0, :self.min_tokens_to_keep] = True

    def forward(self, m_logits, input_ids, top_p, temperature, penalty):
        # repeat penalty
        logits = torch.gather(m_logits, 1, input_ids)
        logits = torch.where(logits < 0, logits * penalty, logits / penalty)
        m_logits.scatter_(1, input_ids, logits)
        # top_k
        logits, token = torch.topk(m_logits.float(), self.top_k)
        # temperature
        logits = logits / temperature
        # top_p
        cumulative_probs = logits.softmax(dim=1).cumsum(dim=1)
        mask = cumulative_probs < top_p
        mask = mask + self.keep_matrix
        filtered_logits = torch.where(mask, logits, torch.FloatTensor([-1000.]))
        probs = filtered_logits.softmax(dim=1)
        return probs, token


# export one prefill layer with dummy full-sequence inputs
def convert_block(layer_id):
    model = Block(layer_id)
    hidden_states = torch.randn((SEQ_LENGTH, 1, HIDDEN_SIZE), dtype=torch.float).to(device)
    position_ids = torch.tensor([range(SEQ_LENGTH)], dtype=torch.long).to(device)
    attention_mask = -1000 * torch.ones(
        (1, 1, SEQ_LENGTH, SEQ_LENGTH), dtype=torch.float).triu(diagonal=1).to(device)
    torch.onnx.export(
        model, (hidden_states, position_ids, attention_mask),
        f'{folder}/block_{layer_id}.onnx',
        verbose=False,
        input_names=['input_states', 'position_ids', 'attention_mask'],
        output_names=['hidden_states', 'past_k', 'past_v'],
        do_constant_folding=True,
        opset_version=15)


# export one decode layer with a single-token input plus the cached K/V
def convert_block_cache(layer_id):
    model = BlockCache(layer_id)
    hidden_states = torch.randn((1, 1, HIDDEN_SIZE), dtype=torch.float).to(device)
    position_ids = torch.tensor([range(1)], dtype=torch.long).to(device)
    attention_mask = -1000 * torch.ones(
        (1, 1, 1, SEQ_LENGTH + 1), dtype=torch.float).triu(diagonal=1).to(device)
    past_k = torch.randn((SEQ_LENGTH, 1, 2, HEAD_DIM), dtype=torch.float).to(device)
    past_v = torch.randn((SEQ_LENGTH, 1, 2, HEAD_DIM), dtype=torch.float).to(device)
    torch.onnx.export(
        model, (hidden_states, position_ids, attention_mask, past_k, past_v),
        f'{folder}/block_cache_{layer_id}.onnx',
        verbose=False,
        input_names=[
            'input_states', 'position_ids', 'attention_mask', 'history_k', 'history_v'
        ],
        output_names=['hidden_states', 'past_k', 'past_v'],
        do_constant_folding=True,
        opset_version=15)


def convert_embedding():
    model = Embedding()
    input_ids = torch.tensor([range(SEQ_LENGTH)]).to(device)
    torch.onnx.export(model, (input_ids),
                      f'{folder}/embedding.onnx',
                      verbose=False,
                      input_names=['input_ids'],
                      output_names=['input_embed'],
                      do_constant_folding=True,
                      opset_version=15)


def convert_lm_head():
    model = LmHead()
    hidden_states = torch.randn(1, HIDDEN_SIZE).to(device)
    torch.onnx.export(model, (hidden_states),
                      f'{folder}/lm_head.onnx',
                      verbose=False,
                      input_names=['hidden_states'],
                      output_names=['token'],
                      do_constant_folding=True,
                      opset_version=15)


def convert_greedy_head():
    model = GreedyHead()
    m_logits = torch.randn(1, VOCAB_SIZE)
    torch.onnx.export(
        model, (m_logits),
        f'{folder}/greedy_head.onnx',
        verbose=False,
        input_names=['m_logits'],
        output_names=['token'],
        do_constant_folding=True,
        opset_version=15)


def convert_penalty_sample_head():
    model = PenaltySampleHead()
    m_logits = torch.randn(1, VOCAB_SIZE)
    input_ids = torch.tensor([range(SEQ_LENGTH)])
    top_p = torch.tensor([0.8])
    temperature = torch.tensor([0.98])
    penalty = torch.tensor([0.98])
    torch.onnx.export(
        model, (m_logits, input_ids, top_p, temperature, penalty),
        f'{folder}/penalty_sample_head.onnx',
        verbose=False,
        input_names=[
            'm_logits', 'input_ids', 'top_p', 'temperature', 'penalty'
        ],
        output_names=['probs', 'token'],
        do_constant_folding=True,
        opset_version=15)
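
# Optional sanity check for one exported prefill block. This is only a sketch and is
# never called by the export flow; it assumes onnxruntime and numpy are installed
# (neither is required by the rest of this script), and the helper name is made up here.
def check_block_onnx(layer_id):
    import numpy as np
    import onnxruntime as ort
    sess = ort.InferenceSession(f'{folder}/block_{layer_id}.onnx',
                                providers=['CPUExecutionProvider'])
    inputs = {
        # same shapes/dtypes that convert_block used for tracing
        'input_states': np.random.randn(SEQ_LENGTH, 1, HIDDEN_SIZE).astype(np.float32),
        'position_ids': np.arange(SEQ_LENGTH, dtype=np.int64).reshape(1, SEQ_LENGTH),
        'attention_mask': np.triu(
            np.full((1, 1, SEQ_LENGTH, SEQ_LENGTH), -1000.0, dtype=np.float32), k=1),
    }
    outputs = sess.run(None, inputs)
    print([o.shape for o in outputs])
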
# runs the full prefill + decode loop in torch to verify the wrapper modules
# before exporting them; disabled by default (see the commented call below)
def test_net_with_mask():
    embed = Embedding()
    blocks = [Block(i) for i in range(NUM_LAYERS)]
    block_kvs = [BlockCache(i) for i in range(NUM_LAYERS)]
    # optional system prompt:
    # system_prompt = "You are ChatGLM3, a large language model trained by Zhipu.AI. Follow the user's instructions carefully. Respond using markdown."
    # history = [{"role": "system", "content": system_prompt}]
    history = []
    visited_token = []
    query = '你好'
    print(query)
    ids = tokenizer.build_chat_input(query, history=history, role="user")
    ids = ids["input_ids"][0].tolist()
    visited_token = visited_token + ids
    print("input ids:{}".format(ids))
    token_len = len(ids)

    # prefill: pad ids/positions to SEQ_LENGTH and build a causal mask over the prompt
    ids = ids + (SEQ_LENGTH - token_len) * [0]
    input_ids = torch.tensor(ids).view(SEQ_LENGTH)
    out = embed(input_ids).view(SEQ_LENGTH, 1, HIDDEN_SIZE)
    position_ids = list(range(token_len)) + (SEQ_LENGTH - token_len) * [0]
    position_ids = torch.tensor([position_ids])
    attention_mask = torch.full((SEQ_LENGTH, SEQ_LENGTH), -1000.0)
    for i in range(token_len):
        for j in range(token_len):
            if j <= i:
                attention_mask[i][j] = 0
    attention_mask = attention_mask.view(1, 1, SEQ_LENGTH, SEQ_LENGTH)

    k_cache = []
    v_cache = []
    for i in range(NUM_LAYERS):
        out, kv_cache = blocks[i](out, position_ids, attention_mask)
        k, v = kv_cache
        # move the valid cache entries to the tail and zero the unused head,
        # matching the decode-time mask layout
        k[SEQ_LENGTH - token_len:] = k[:token_len]
        v[SEQ_LENGTH - token_len:] = v[:token_len]
        k[:SEQ_LENGTH - token_len] = 0
        v[:SEQ_LENGTH - token_len] = 0
        k_cache.append(k)
        v_cache.append(v)

    out = out[token_len - 1:token_len].view(1, HIDDEN_SIZE)
    lm = LmHead()
    greedyhead = GreedyHead()
    lm_out = lm(out)
    token = greedyhead(lm_out)
    visited_token.append(int(token))
    out_ids = [int(token)]
    word = tokenizer._convert_id_to_token(int(token[0]))
    print(word, end="")

    # decode: stop on EOS (id 2) or after 640 positions
    while token > 2 and token_len < 640:
        token_len += 1
        input_ids = torch.tensor([token])
        out = embed(input_ids).view(1, 1, HIDDEN_SIZE)
        position_ids = torch.tensor([[token_len - 1]])
        attention_mask = torch.ones((1, 1, 1, SEQ_LENGTH + 1)) * -1000
        attention_mask[:, :, :, SEQ_LENGTH + 1 - token_len:] = 0
        for i in range(NUM_LAYERS):
            out, k_cache[i], v_cache[i] = block_kvs[i](out, position_ids, attention_mask,
                                                       k_cache[i], v_cache[i])
            # zero the still-unused slots at the head of the cache
            k_cache[i][:SEQ_LENGTH - token_len] = 0
            v_cache[i][:SEQ_LENGTH - token_len] = 0
        lm_out = lm(out)
        token = greedyhead(lm_out)
        visited_token.append(int(token))
        out_ids.append(int(token))
        word = tokenizer._convert_id_to_token(int(token[0]))
        print(word, end="")
    print("\noutput_ids:{}".format(out_ids))


# test_net_with_mask()

# create folder to store onnx
if not os.path.exists(folder):
    os.makedirs(folder)

# export models
print('Convert block & block_cache')
for i in tqdm(range(NUM_LAYERS)):
    convert_block(i)
    convert_block_cache(i)

print('Convert embedding')
convert_embedding()

print('Convert lm_head')
convert_lm_head()
convert_greedy_head()
convert_penalty_sample_head()
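
# Example invocation (the model path below is a placeholder; point it at a local
# ChatGLM3 checkout):
#   python3 export_onnx.py --model_path ./chatglm3-6b --seq_length 512 --device cpu
# The ONNX files written to ./tmp/onnx are intended as inputs to the subsequent
# TPU-MLIR compilation flow.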