
FlexPrefill

This repository provides the code for the paper "FlexPrefill: A Context-Aware Sparse Attention Mechanism for Efficient Long-Sequence Inference".

TL;DR

FlexPrefill is a dynamic, context-aware sparse attention mechanism that improves computational efficiency during long-sequence inference for large language models (LLMs). It does so by adjusting sparse attention patterns and computational budgets in real time, based on the input and the requirements of each attention head.
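As a rough, illustrative sketch only (not the paper's actual algorithm), the budget-selection idea can be pictured as each attention head keeping just enough key blocks to cover a target fraction gamma of its estimated attention mass; select_blocks and block_scores below are hypothetical names introduced for this toy example:

import torch

def select_blocks(block_scores: torch.Tensor, gamma: float) -> torch.Tensor:
    # Toy heuristic: keep the smallest set of key blocks whose normalized
    # scores sum to at least gamma. Illustration only, not FlexPrefill itself.
    probs = torch.softmax(block_scores, dim=-1)
    sorted_probs, order = probs.sort(descending=True)
    cumulative = sorted_probs.cumsum(dim=-1)
    num_keep = int((cumulative < gamma).sum().item()) + 1
    return order[:num_keep]

# Example: 8 key blocks for one head, keep enough to cover 90% of the mass.
kept_block_indices = select_blocks(torch.randn(8), gamma=0.9)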

Requirements

To use FlexPrefill, you will need the following packages:

  • torch==2.4.0
  • triton==3.0.0
  • transformers==4.44.0
  • flash_attn (optional)
  • vllm==0.5.4 (optional)

Quick Start

Example Test

To run tests using a specific model, you can use the test script located in tests/test_llm.py:

python tests/test_llm.py --model meta-llama/Llama-3.1-8B-Instruct --pattern flex_prefill --engine hf

FlexPrefill Sparse Attention Function

import torch
from flex_prefill import flex_prefill_attention

B, N, H, D = 1, 64000, 32, 64  # batch size, sequence length, query heads, head dimension
gamma = 0.9  # FlexPrefill hyperparameters passed to flex_prefill_attention below
tau = 0.1

q = torch.randn(B, N, H, D, device="cuda", dtype=torch.bfloat16)
# Grouped-query attention layout: key/value tensors use H // 4 heads.
k = torch.randn(B, N, H // 4, D, device="cuda", dtype=torch.bfloat16)
v = torch.randn(B, N, H // 4, D, device="cuda", dtype=torch.bfloat16)

flex_prefill_output, computational_ratio = flex_prefill_attention(
    q,
    k,
    v,
    gamma,
    tau,
    min_budget=1024,
    max_budget=None,
    gqa_interleave=False,
    block_size=128,
    return_computational_ratio=True,
)
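With return_computational_ratio=True as above, the second return value reports how much of the full attention computation was actually performed; a quick sanity check (assuming the output keeps q's (B, N, H, D) layout):

print(flex_prefill_output.shape)  # expected to match q: torch.Size([1, 64000, 32, 64])
print(computational_ratio)        # fraction of the full attention computation performed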

Faster Hugging Face Transformers Model Inference

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from flex_prefill import patch_model


tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct",
    torch_dtype=torch.bfloat16,
    _attn_implementation="flash_attention_2",
).cuda()

flex_prefill_config = {
    "block_size": 128,
    "flex_prefill_gamma": 0.9,
    "flex_prefill_tau": 0.1,
    "flex_prefill_min_budget": 1024,
    "flex_prefill_max_budget": None,
}

patch_model(model, "flex_prefill", flex_prefill_config)

prompt = "..."  # replace with your (long) input prompt
input_ids = tokenizer(prompt, return_tensors="pt", return_attention_mask=False).input_ids.cuda()
output_ids = model.generate(input_ids, max_new_tokens=64)
output = tokenizer.decode(output_ids[0], skip_special_tokens=True)

vLLM Model Inference

from vllm import LLM, SamplingParams
from flex_prefill import patch_model


model = LLM("meta-llama/Llama-3.1-8B-Instruct", enable_chunked_prefill=False, max_num_seqs=1)
sampling_params = SamplingParams(temperature=0, max_tokens=64)

flex_prefill_config = {
    "block_size": 128,
    "flex_prefill_gamma": 0.9,
    "flex_prefill_tau": 0.1,
    "flex_prefill_min_budget": 1024,
    "flex_prefill_max_budget": None,
}

patch_model(model, "flex_prefill", flex_prefill_config)

prompt = "..."  # replace with your (long) input prompt
outputs = model.generate(prompts=[prompt], sampling_params=sampling_params)
output = outputs[0].outputs[0].text

Experiments

Experiment scripts are provided in the experiments folder. First, install the dependencies and download the necessary data and models:

bash install.sh
bash experiments/download_data.sh
bash experiments/download_model.sh

Then, you can run benchmark experiments:

bash experiments/scripts/flex_prefill/ruler.sh
bash experiments/scripts/flex_prefill/infinitebench.sh

The results will be saved in the experiments/result directory.

Supported Models

Currently, flex_prefill.patch_model supports only a limited set of model architectures.

flex_prefill can be used with both Hugging Face Transformers models and vLLM models, but note that the batch size must be 1.
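Because of the batch-size-1 restriction, multiple long prompts should be processed one at a time. A minimal sketch reusing the patched Transformers model and tokenizer from above (the prompts list is purely illustrative):

prompts = ["first long document ...", "second long document ..."]  # illustrative placeholders
outputs = []
for prompt in prompts:
    # Keep the effective batch size at 1: encode and generate each prompt separately.
    input_ids = tokenizer(prompt, return_tensors="pt", return_attention_mask=False).input_ids.cuda()
    output_ids = model.generate(input_ids, max_new_tokens=64)
    outputs.append(tokenizer.decode(output_ids[0], skip_special_tokens=True))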

License

This project is licensed under the MIT License - see the LICENSE file for details.

Citation

If you use this code in your research, please cite the following paper:

@article{FlexPrefill2024,
  title={FlexPrefill: A Context-Aware Sparse Attention Mechanism for Efficient Long-Sequence Inference},
  author={Your Name and Collaborators},
  journal={arXiv preprint},
  year={2024}
}

Acknowledgments

We acknowledge the support from our collaborators and the community. Thank you for your contributions and feedback.

Contact

For any questions or comments about the paper or the code, please contact [email protected].

Enjoy using FlexPrefill, and feel free to contribute to the project by opening issues or submitting pull requests!
