|
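"""Minimal RWKV-7 text-generation demo on top of an onnxruntime-compatible runtime.

Loads an exported RWKV ONNX graph (optionally with an external embedding table),
runs token-by-token recurrent inference, and samples text from a prompt.
"""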
|
|
import ztu_somemodelruntime_rknnlite2 as ort  # onnxruntime-compatible runtime wrapper
|
import numpy as np |
|
from pathlib import Path
from typing import Optional
|
from rwkv_tokenizer import RWKV_TOKENIZER |
|
import time |
|
|
|
class RWKVModel: |
|
    def __init__(self, model_path: str, tokenizer_path: Optional[str] = None, use_external_embedding: bool = False):
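        """Load the ONNX model and, optionally, a tokenizer and an external embedding table."""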
|
|
|
session_options = ort.SessionOptions() |
|
|
|
self.session = ort.InferenceSession(model_path, providers=['CPUExecutionProvider'], session_options=session_options) |
|
|
|
|
|
print("\nModel inputs:") |
|
for inp in self.session.get_inputs(): |
|
print(f"{inp.name}: shape={inp.shape}, type={inp.type}") |
|
|
|
|
|
        # RWKV-7 graphs expose three recurrent state tensors per block
        # (two token-shift states and one WKV state), so layers = state inputs / 3.
        self.n_layer = len([x for x in self.session.get_inputs() if 'state' in x.name]) // 3
|
        # n_embd is only used when embeddings are fed externally; in that case it
        # is overwritten below once the embedding table has been loaded.
        self.n_embd = self.session.get_inputs()[0].shape[-1] if not use_external_embedding else None
|
|
|
|
|
        # Record the shape of every state input so reset_state() can zero-initialize it.
        self.state_shapes = {}
|
for inp in self.session.get_inputs(): |
|
if 'state' in inp.name: |
|
self.state_shapes[inp.name] = inp.shape |
|
|
|
print("\nNumber of layers:", self.n_layer) |
|
|
|
|
|
if tokenizer_path: |
|
self.tokenizer = RWKV_TOKENIZER(tokenizer_path) |
|
else: |
|
self.tokenizer = None |
|
|
|
|
|
self.use_external_embedding = use_external_embedding |
|
if use_external_embedding: |
|
            # The embedding table lives next to the model as a raw float32 dump
            # (e.g. model.onnx -> model.emb).
            emb_path = Path(model_path).parent / (Path(model_path).stem + '.emb')
            self.embedding = np.fromfile(emb_path, dtype=np.float32)
|
|
|
            # The embedding width is hardcoded to 768, matching the 0.1B model
            # loaded in main(); the vocab size falls out of the file size.
            vocab_size = len(self.embedding) // 768
            self.embedding = self.embedding.reshape(vocab_size, 768)
            self.n_embd = 768
|
print(f"\nEmbedding shape: {self.embedding.shape}") |
|
|
|
|
|
        # Print shape diagnostics only on the first forward() call.
        self._first_forward = True

        self.reset_state()
|
|
|
def reset_state(self): |
|
"""重置所有状态为0""" |
|
self.states = [] |
|
for i in range(self.n_layer * 3): |
|
state_name = f'state{i}_in' |
|
state_shape = self.state_shapes[state_name] |
|
self.states.append(np.zeros(state_shape, dtype=np.float32)) |
|
|
|
def _prepare_inputs(self, token_id): |
|
"""准备模型输入""" |
|
inputs = {} |
|
|
|
|
|
if self.use_external_embedding: |
|
|
|
            # Look up the token's embedding row, shaped (batch=1, seq=1, n_embd).
            embedding = self.embedding[token_id].reshape(1, 1, self.n_embd)
            inputs['in'] = embedding.astype(np.float32)
|
else: |
|
|
|
            # The model embeds internally; feed the raw token id.
            inputs['in'] = np.array([[token_id]], dtype=np.int64)
|
|
|
|
|
        # Attach every recurrent state tensor to its matching input.
        for i in range(len(self.states)):
|
inputs[f'state{i}_in'] = self.states[i] |
|
|
|
|
|
        if self._first_forward:
|
print("\nPrepared input shapes:") |
|
for k, v in inputs.items(): |
|
print(f"{k}: shape={v.shape}, type={v.dtype}") |
|
|
|
return inputs |
|
|
|
def forward(self, token_id): |
|
"""单步推理""" |
|
|
|
inputs = self._prepare_inputs(token_id) |
|
|
|
|
|
outputs = self.session.run(None, inputs) |
|
|
|
|
|
        if self._first_forward:
|
print("\nModel outputs:") |
|
for i, out in enumerate(outputs): |
|
print(f"Output {i}: shape={out.shape}, type={out.dtype}") |
|
|
|
|
|
        # outputs[0] is the logits tensor; outputs[1:] are the updated states,
        # in the same order as the state inputs.
        for i in range(len(self.states)):
            new_state = outputs[i + 1]
|
|
|
if new_state.shape != self.states[i].shape: |
|
                if self._first_forward:
|
print(f"\nState shape mismatch for state{i}_in:") |
|
print(f"Expected: {self.states[i].shape}") |
|
print(f"Got: {new_state.shape}") |
|
|
|
                # Squeeze away the extra axis so the state matches the expected shape.
                if len(self.states[i].shape) == 2:
|
new_state = new_state.squeeze(1) |
|
elif len(self.states[i].shape) == 3: |
|
new_state = new_state.squeeze(0) |
|
self.states[i] = new_state |
|
|
|
        self._first_forward = False

        return outputs[0]
|
|
|
def generate(self, prompt: str, max_length: int = 100, temperature: float = 1.0, stop_tokens: set = None): |
|
"""生成文本""" |
|
if not self.tokenizer: |
|
            raise ValueError("A tokenizer is required for text generation")
|
|
|
|
|
tokens = self.tokenizer.encode(prompt) |
|
generated = list(tokens) |
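        # `generated` accumulates prompt + sampled tokens; generated[-1] is always
        # the next token fed to forward() in the generation loop below.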
|
|
|
|
|
self.reset_state() |
|
|
|
|
|
print("\nProcessing prompt...", end='', flush=True) |
|
t_start = time.time() |
|
        # Feed all prompt tokens except the last; the generation loop below runs
        # forward() on the final prompt token to produce the first logits, so
        # forwarding it here would process it twice.
        for token in tokens[:-1]:
            self.forward(token)
|
t_prompt = time.time() - t_start |
|
print(f" Done. ({len(tokens)} tokens, {t_prompt:.2f}s, {len(tokens)/t_prompt:.2f} tokens/s)") |
|
|
|
|
|
print("\nGenerating:", end='', flush=True) |
|
t_start = time.time() |
|
generated_tokens = 0 |
|
|
|
for i in range(max_length): |
|
|
|
|
logits = self.forward(generated[-1]) |
|
|
|
|
|
if i == 0: |
|
print(f"\nLogits shape: {logits.shape}") |
|
|
|
|
|
logits = logits.reshape(-1) |
|
|
|
            if temperature > 0:
                # Temperature-scaled sampling: softmax with max-subtraction for
                # numerical stability, then draw from the distribution.
                logits = logits / temperature
                logits = logits - np.max(logits)
                probs = np.exp(logits)
                probs = probs / np.sum(probs)
                next_token = int(np.random.choice(len(probs), p=probs))
            else:
                # Greedy decoding.
                next_token = int(np.argmax(logits))
|
|
|
generated.append(next_token) |
|
generated_tokens += 1 |
|
|
|
|
|
if stop_tokens and next_token in stop_tokens: |
|
break |
|
|
|
|
|
new_text = self.tokenizer.decode([next_token]) |
|
print(new_text, end='', flush=True) |
|
|
|
|
|
|
|
t_generate = time.time() - t_start |
|
print(f"\n\nGeneration finished: {generated_tokens} tokens generated in {t_generate:.2f}s ({generated_tokens/t_generate:.2f} tokens/s)") |
|
|
|
return self.tokenizer.decode(generated) |
|
|
|
def main(): |
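    """Demo entry point: load the 0.1B RWKV-7 world model and generate from a fixed prompt."""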
|
|
|
|
|
|
print("Loading model...") |
|
t_start = time.time() |
|
model = RWKVModel( |
|
model_path='RWKV-x070-World-0.1B-v2.8-20241210-ctx4096.onnx', |
|
tokenizer_path='rwkv_vocab_v20230424.txt', |
|
use_external_embedding=True |
|
) |
|
print(f"Model loaded in {time.time() - t_start:.2f}s") |
|
|
|
prompt = "Here is a example of Quick Sort algorithm implemented in C++:\n```cpp" |
|
print(f"\nPrompt: {prompt}") |
|
|
|
generated_text = model.generate( |
|
prompt=prompt, |
|
max_length=1024, |
|
temperature=0.7, |
|
        stop_tokens={0, 1, 2, 3}  # assumed end/control token ids for this vocab
|
) |
|
|
|
print("\nFull text:") |
|
print(generated_text) |
|
|
|
if __name__ == '__main__': |
|
main() |