
Model card for RWKV-4 | 3B parameters trained on the Pile dataset

RWKV is a project led by Bo Peng. Learn more about the model architecture in the two blog posts by Johan Wind. Learn more about the project by joining the RWKV Discord server.

Table of contents

  1. TL;DR
  2. Model Details
  3. Usage
  4. Citation

TL;DR

Below is the description from the original repository:

RWKV is an RNN with transformer-level LLM performance. It can be trained directly like a GPT (parallelizable). It combines the best of RNNs and transformers: great performance, fast inference, saves VRAM, fast training, "infinite" ctx_len, and free sentence embedding.
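The "RNN that can be trained like a GPT" claim rests on the fact that RWKV-4's core "WKV" operator can be written two equivalent ways. Below is a simplified, single-channel sketch, assuming the RWKV-4 formulation where `w > 0` is a per-channel decay and `u` is a bonus applied to the current token. It deliberately omits receptance gating, token shift, output projections, and the running-maximum trick that real implementations use for numerical stability. The function names are illustrative only, not part of any library.

```python
import math

# Simplified single-channel WKV operator (RWKV-4 style), for illustration only.
# Assumptions: w > 0 is the decay, u is the current-token bonus, ks/vs are the
# key/value sequences. No receptance, token shift or stability tricks.

def wkv_parallel(w, u, ks, vs):
    """GPT-style form: each output is computed from the whole prefix at once."""
    out = []
    for t in range(len(ks)):
        # The current token gets the bonus u instead of a decay.
        num = math.exp(u + ks[t]) * vs[t]
        den = math.exp(u + ks[t])
        # Past token i is weighted by exp(-(t - 1 - i) * w + k_i).
        for i in range(t):
            weight = math.exp(-(t - 1 - i) * w + ks[i])
            num += weight * vs[i]
            den += weight
        out.append(num / den)
    return out

def wkv_recurrent(w, u, ks, vs):
    """RNN form: same outputs, carrying only a fixed-size state (a, b)."""
    a, b = 0.0, 0.0  # decayed running sums of exp(k_i) * v_i and exp(k_i)
    decay = math.exp(-w)
    out = []
    for k, v in zip(ks, vs):
        out.append((a + math.exp(u + k) * v) / (b + math.exp(u + k)))
        a = decay * a + math.exp(k) * v
        b = decay * b + math.exp(k)
    return out

ks = [0.1, -0.5, 0.3, 0.8]
vs = [1.0, 2.0, -1.0, 0.5]
par = wkv_parallel(0.5, 0.2, ks, vs)
rec = wkv_recurrent(0.5, 0.2, ks, vs)
# Expected to be ~0: the two forms differ only by floating-point rounding.
print(max(abs(p - r) for p, r in zip(par, rec)))
```

The parallel form is what makes GPT-style training possible, while the recurrent form needs only a constant-size state per channel at inference time, regardless of how long the sequence grows.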

Model Details

The details of the architecture can be found in the blog posts mentioned above and in the Hugging Face blog post about the integration.

Usage

Convert the raw weights to the HF format

You can use the convert_rwkv_checkpoint_to_hf.py script by specifying the repo_id of the original weights (--repo_id), the checkpoint filename (--checkpoint_file) and the output directory (--output_dir). You can also push the converted model directly to the Hub by passing the --push_to_hub flag together with the --model_name argument, which specifies where to push the converted weights.

python convert_rwkv_checkpoint_to_hf.py --repo_id RAW_HUB_REPO --checkpoint_file RAW_FILE --output_dir OUTPUT_DIR --push_to_hub --model_name dummy_user/converted-rwkv
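If you need to run the conversion from Python (for example, inside a larger script), one option is to assemble the same command line as a list of arguments. This is a minimal sketch: `build_convert_command` is a hypothetical helper, it uses only the flags shown above, and the placeholder values are the ones from the example command.

```python
import shlex
import sys

def build_convert_command(repo_id, checkpoint_file, output_dir, model_name=None):
    """Build the argv for convert_rwkv_checkpoint_to_hf.py (hypothetical helper).

    When model_name is given, --push_to_hub and --model_name are added so the
    converted weights are also pushed to the Hub under that name.
    """
    cmd = [
        sys.executable, "convert_rwkv_checkpoint_to_hf.py",
        "--repo_id", repo_id,
        "--checkpoint_file", checkpoint_file,
        "--output_dir", output_dir,
    ]
    if model_name is not None:
        cmd += ["--push_to_hub", "--model_name", model_name]
    return cmd

cmd = build_convert_command(
    "RAW_HUB_REPO", "RAW_FILE", "OUTPUT_DIR", model_name="dummy_user/converted-rwkv"
)
print(shlex.join(cmd))  # shell-quoted preview of the command
```

To actually run it, pass the list to `subprocess.run(cmd, check=True)`; keeping the arguments as a list avoids shell-quoting issues with paths that contain spaces.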

Generate text

You can use the AutoModelForCausalLM and AutoTokenizer classes to generate text from the model. The sections below show how to run the model in different scenarios:

Running the model on a CPU

from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("RWKV/rwkv-4-3b-pile")
tokenizer = AutoTokenizer.from_pretrained("RWKV/rwkv-4-3b-pile")

prompt = "\nIn a shocking finding, scientist discovered a herd of dragons living in a remote, previously unexplored valley, in Tibet. Even more surprising to the researchers was the fact that the dragons spoke perfect Chinese."

inputs = tokenizer(prompt, return_tensors="pt")
output = model.generate(inputs["input_ids"], max_new_tokens=40)
print(tokenizer.decode(output[0].tolist(), skip_special_tokens=True))
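Conceptually, generation with an RNN like RWKV consumes the prompt into a fixed-size state and then produces each of the max_new_tokens with a single additional step. The sketch below illustrates that loop with greedy decoding, assuming no sampling is configured. `toy_step` is a hypothetical stand-in for one model step, not the transformers API. In practice the prompt can also be processed in one parallel pass.

```python
def toy_step(state, token):
    """Hypothetical stand-in for one RWKV step: returns (logits, new_state).

    The state is a fixed-size tuple; this toy just remembers the last token.
    """
    vocab_size = 5
    new_state = (token,)
    # Favor the token after the current one, wrapping around the toy vocabulary.
    logits = [1.0 if i == (token + 1) % vocab_size else 0.0 for i in range(vocab_size)]
    return logits, new_state

def greedy_generate(step, prompt_ids, max_new_tokens, initial_state=None):
    if not prompt_ids:
        raise ValueError("prompt_ids must contain at least one token")
    state, logits = initial_state, None
    # 1) Feed the prompt, updating the fixed-size state token by token.
    for token in prompt_ids:
        logits, state = step(state, token)
    output = list(prompt_ids)
    # 2) Each new token costs one step, however long the sequence already is.
    for _ in range(max_new_tokens):
        next_id = max(range(len(logits)), key=logits.__getitem__)  # greedy pick
        output.append(next_id)
        logits, state = step(state, next_id)
    return output

print(greedy_generate(toy_step, [0, 1], max_new_tokens=4))  # [0, 1, 2, 3, 4, 0]
```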

Running the model on a single GPU

from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("RWKV/rwkv-4-3b-pile").to(0)
tokenizer = AutoTokenizer.from_pretrained("RWKV/rwkv-4-3b-pile")

prompt = "\nIn a shocking finding, scientist discovered a herd of dragons living in a remote, previously unexplored valley, in Tibet. Even more surprising to the researchers was the fact that the dragons spoke perfect Chinese."

inputs = tokenizer(prompt, return_tensors="pt").to(0)
output = model.generate(inputs["input_ids"], max_new_tokens=40)
print(tokenizer.decode(output[0].tolist(), skip_special_tokens=True))

Running the model in half-precision, on GPU

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("RWKV/rwkv-4-3b-pile", torch_dtype=torch.float16).to(0)
tokenizer = AutoTokenizer.from_pretrained("RWKV/rwkv-4-3b-pile")

prompt = "\nIn a shocking finding, scientist discovered a herd of dragons living in a remote, previously unexplored valley, in Tibet. Even more surprising to the researchers was the fact that the dragons spoke perfect Chinese."

inputs = tokenizer(prompt, return_tensors="pt").to(0)
output = model.generate(inputs["input_ids"], max_new_tokens=40)
print(tokenizer.decode(output[0].tolist(), skip_special_tokens=True))
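Passing torch_dtype=torch.float16 stores each weight in 2 bytes instead of the 4 bytes used by float32 (PyTorch's usual default), so the weights need roughly half the GPU memory. The back-of-the-envelope estimate below assumes the nominal "3B" parameter count; the exact count of this checkpoint may differ. It also ignores activations, the recurrent state, and framework overhead.

```python
def weight_memory_gb(num_params, bytes_per_param):
    """Memory needed just to hold the weights, in GB (1 GB = 1e9 bytes)."""
    return num_params * bytes_per_param / 1e9

NUM_PARAMS = 3e9  # nominal "3B"; an approximation, not the exact parameter count
fp32 = weight_memory_gb(NUM_PARAMS, 4)  # torch.float32: 4 bytes per parameter
fp16 = weight_memory_gb(NUM_PARAMS, 2)  # torch.float16: 2 bytes per parameter
print(f"float32: ~{fp32:.0f} GB, float16: ~{fp16:.0f} GB")  # float32: ~12 GB, float16: ~6 GB
```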

Running the model on multiple GPUs

# pip install accelerate
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("RWKV/rwkv-4-3b-pile", device_map="auto")
tokenizer = AutoTokenizer.from_pretrained("RWKV/rwkv-4-3b-pile")

prompt = "\nIn a shocking finding, scientist discovered a herd of dragons living in a remote, previously unexplored valley, in Tibet. Even more surprising to the researchers was the fact that the dragons spoke perfect Chinese."

inputs = tokenizer(prompt, return_tensors="pt").to(0)
output = model.generate(inputs["input_ids"], max_new_tokens=40)
print(tokenizer.decode(output[0].tolist(), skip_special_tokens=True))
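With device_map="auto", the accelerate library distributes the model's layers across the available devices. The function below is a deliberately simplified, hypothetical illustration of the idea: it packs consecutive layers onto devices in order until each device's memory budget is full. The real logic in accelerate is more involved. Because the first layers (including the embeddings) typically land on the first GPU, the example above moves the inputs to device 0.

```python
def assign_layers(layer_sizes_gb, device_budgets_gb):
    """Greedily place consecutive layers on devices in order (illustration only).

    Returns a dict mapping layer index -> device index.
    """
    placement = {}
    device, used = 0, 0.0
    for layer, size in enumerate(layer_sizes_gb):
        # Move on to the next device when this layer would overflow the current one.
        while used + size > device_budgets_gb[device]:
            device += 1
            used = 0.0
            if device == len(device_budgets_gb):
                raise ValueError("model does not fit in the given device budgets")
        placement[layer] = device
        used += size
    return placement

# Hypothetical numbers: 8 layers of 1 GB each, two GPUs with 5 GB available each.
print(assign_layers([1.0] * 8, [5.0, 5.0]))
# {0: 0, 1: 0, 2: 0, 3: 0, 4: 0, 5: 1, 6: 1, 7: 1}
```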

Citation

If you use this model, please consider citing the original work, as listed in the original repository.
