# rwkv-7-world-ONNX-RKNN2 / inference.py
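"""Token-by-token inference for an RWKV-7 "World" model exported to ONNX/RKNN2.

Loads the exported graph (optionally with the embedding table kept outside the
model as a raw float32 .emb file), feeds tokens one at a time while carrying
the recurrent state tensors between calls, and streams a sampled completion
for a prompt.
"""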
# import onnxruntime as ort  # uncomment (and comment out the next line) to run with onnxruntime
import ztu_somemodelruntime_rknnlite2 as ort  # rknnlite2 wrapper with an onnxruntime-compatible API
import numpy as np
from pathlib import Path
from rwkv_tokenizer import RWKV_TOKENIZER
import time


class RWKVModel:
    def __init__(self, model_path: str, tokenizer_path: str = None, use_external_embedding: bool = False):
        # Load the ONNX model
        session_options = ort.SessionOptions()
        # session_options.core_mask = 7  # 0b00000111: run on cores 0, 1 and 2
        self.session = ort.InferenceSession(model_path, providers=['CPUExecutionProvider'], session_options=session_options)

        # Print model input info
        print("\nModel inputs:")
        for inp in self.session.get_inputs():
            print(f"{inp.name}: shape={inp.shape}, type={inp.type}")

        # Derive model dimensions from the graph: RWKV-7 carries three state
        # tensors per layer, so the layer count is the state-input count // 3.
        self.n_layer = len([x for x in self.session.get_inputs() if 'state' in x.name]) // 3
        self.n_embd = self.session.get_inputs()[0].shape[-1] if not use_external_embedding else None

        # Record the state tensor shapes declared by the model
        self.state_shapes = {}
        for inp in self.session.get_inputs():
            if 'state' in inp.name:
                self.state_shapes[inp.name] = inp.shape
        print("\nNumber of layers:", self.n_layer)

        # Load the tokenizer
        if tokenizer_path:
            self.tokenizer = RWKV_TOKENIZER(tokenizer_path)
        else:
            self.tokenizer = None

        # Load the external embedding table if requested: the .emb file next
        # to the model is read as a raw float32 dump of the embedding matrix.
        self.use_external_embedding = use_external_embedding
        if use_external_embedding:
            emb_path = Path(model_path).parent / (Path(model_path).stem + '.emb')
            self.embedding = np.fromfile(emb_path, dtype=np.float32)
            # Reshape the flat array to (vocab_size, n_embd); the embedding
            # width is hard-coded to 768 (the hidden size of the 0.1B model).
            vocab_size = len(self.embedding) // 768
            self.embedding = self.embedding.reshape(vocab_size, 768)
            self.n_embd = 768
            print(f"\nEmbedding shape: {self.embedding.shape}")

        # Debug shape info is printed once, on the first forward() call
        self._debug_printed = False

        # Initialize the recurrent state
        self.reset_state()

    def reset_state(self):
        """Reset all recurrent states to zero."""
        self.states = []
        for i in range(self.n_layer * 3):
            state_name = f'state{i}_in'
            state_shape = self.state_shapes[state_name]
            self.states.append(np.zeros(state_shape, dtype=np.float32))
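
    # RWKV-7 keeps three state tensors per layer: two (1, n_embd) token-shift
    # vectors and one WKV state matrix, e.g. (12, 64, 64) for the 0.1B model
    # (heads x head_dim x head_dim). The shapes come verbatim from the ONNX
    # graph inputs; the squeeze logic in forward() below mirrors them.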

    def _prepare_inputs(self, token_id):
        """Assemble the input dict for a single inference step."""
        inputs = {}
        # Main input: a pre-computed embedding row, or the raw token id if the
        # embedding lives inside the model
        if self.use_external_embedding:
            embedding = self.embedding[token_id].reshape(1, 1, self.n_embd)
            inputs['in'] = embedding.astype(np.float32)
        else:
            inputs['in'] = np.array([[token_id]], dtype=np.int64)

        # Attach the recurrent states
        for i in range(len(self.states)):
            inputs[f'state{i}_in'] = self.states[i]

        # Print input shapes once for debugging
        if not self._debug_printed:
            print("\nPrepared input shapes:")
            for k, v in inputs.items():
                print(f"{k}: shape={v.shape}, type={v.dtype}")
        return inputs
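
    # The graph returns the logits first, followed by the updated states in
    # the same order as the state inputs. Because all history lives in these
    # state tensors, each step costs the same regardless of context length:
    # RWKV runs as an RNN at inference time.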
    def forward(self, token_id):
        """Run one inference step and update the recurrent states."""
        inputs = self._prepare_inputs(token_id)
        outputs = self.session.run(None, inputs)

        # Print output info once for debugging
        if not self._debug_printed:
            print("\nModel outputs:")
            for i, out in enumerate(outputs):
                print(f"Output {i}: shape={out.shape}, type={out.dtype}")

        # Update the states; the first output is the logits
        for i in range(len(self.states)):
            new_state = outputs[i + 1]
            # The runtime may return states with an extra leading axis;
            # squeeze it away so the shapes keep matching the state inputs.
            if new_state.shape != self.states[i].shape:
                if not self._debug_printed:
                    print(f"\nState shape mismatch for state{i}_in:")
                    print(f"Expected: {self.states[i].shape}")
                    print(f"Got: {new_state.shape}")
                if len(self.states[i].shape) == 2:    # (1, n_embd)
                    new_state = new_state.squeeze(1)  # (1, 1, n_embd) -> (1, n_embd)
                elif len(self.states[i].shape) == 3:  # (heads, head_dim, head_dim)
                    new_state = new_state.squeeze(0)  # (1, h, d, d) -> (h, d, d)
            self.states[i] = new_state

        self._debug_printed = True
        return outputs[0]  # logits

    def generate(self, prompt: str, max_length: int = 100, temperature: float = 1.0, stop_tokens: set = None):
        """Generate a completion for the given prompt."""
        if not self.tokenizer:
            raise ValueError("A tokenizer is required for text generation")

        # Encode the prompt
        tokens = self.tokenizer.encode(prompt)
        generated = list(tokens)

        # Start from a fresh state
        self.reset_state()

        # Feed the prompt to build up the state. The last prompt token is left
        # for the generation loop below, which always forwards generated[-1];
        # feeding it here as well would process it twice.
        print("\nProcessing prompt...", end='', flush=True)
        t_start = time.time()
        for token in tokens[:-1]:
            self.forward(token)
        t_prompt = time.time() - t_start
        n_prompt = len(tokens) - 1
        print(f" Done. ({n_prompt} tokens, {t_prompt:.2f}s, {n_prompt/max(t_prompt, 1e-9):.2f} tokens/s)")

        # Generate new tokens
        print("\nGenerating:", end='', flush=True)
        t_start = time.time()
        generated_tokens = 0
        for i in range(max_length):
            logits = self.forward(generated[-1])
            if i == 0:
                print(f"\nLogits shape: {logits.shape}")
            # Flatten the logits to one dimension
            logits = logits.reshape(-1)
            if temperature > 0:
                # Temperature sampling: softmax with max-subtraction to avoid
                # exp() overflow; float64 keeps the probability sum tight
                # enough for np.random.choice's sanity check.
                logits = logits.astype(np.float64) / temperature
                logits = logits - np.max(logits)
                probs = np.exp(logits)
                probs = probs / np.sum(probs)
                next_token = int(np.random.choice(len(probs), p=probs))
            else:
                # Greedy decoding
                next_token = int(np.argmax(logits))
            generated.append(next_token)
            generated_tokens += 1

            # Stop if a stop token was produced
            if stop_tokens and next_token in stop_tokens:
                break

            # Stream the newly generated text
            new_text = self.tokenizer.decode([next_token])
            print(new_text, end='', flush=True)

        t_generate = time.time() - t_start
        print(f"\n\nGeneration finished: {generated_tokens} tokens generated in {t_generate:.2f}s ({generated_tokens/t_generate:.2f} tokens/s)")
        return self.tokenizer.decode(generated)
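

# generate() uses plain temperature sampling. A common refinement for RWKV
# text generation is nucleus (top-p) sampling; the helper below is a minimal
# sketch of that idea and is not part of the original script. The name
# `sample_logits` and the 0.9 default cutoff are illustrative choices.
def sample_logits(logits: np.ndarray, temperature: float = 1.0, top_p: float = 0.9) -> int:
    """Sample a token id from raw logits with temperature and top-p truncation."""
    logits = logits.reshape(-1).astype(np.float64)
    if temperature <= 0:
        return int(np.argmax(logits))  # greedy decoding
    # Softmax with max-subtraction for numerical stability
    probs = np.exp((logits - np.max(logits)) / temperature)
    probs /= probs.sum()
    # Keep the smallest set of tokens whose cumulative probability reaches top_p
    order = np.argsort(probs)[::-1]
    cutoff = int(np.searchsorted(np.cumsum(probs[order]), top_p)) + 1
    keep = order[:cutoff]
    return int(np.random.choice(keep, p=probs[keep] / probs[keep].sum()))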


def main():
    # Example usage
    print("Loading model...")
    t_start = time.time()
    model = RWKVModel(
        model_path='RWKV-x070-World-0.1B-v2.8-20241210-ctx4096.onnx',
        tokenizer_path='rwkv_vocab_v20230424.txt',
        use_external_embedding=True
    )
    print(f"Model loaded in {time.time() - t_start:.2f}s")

    prompt = "Here is an example of the Quick Sort algorithm implemented in C++:\n```cpp"
    print(f"\nPrompt: {prompt}")
    generated_text = model.generate(
        prompt=prompt,
        max_length=1024,
        temperature=0.7,
        stop_tokens={0, 1, 2, 3}  # special tokens used as stop markers
    )
    print("\nFull text:")
    print(generated_text)


if __name__ == '__main__':
    main()