# Hugging Face Space "sailor2-3b-chat", file main.py (4.96 kB)
# Last commit 0ef012d by yusufs: "fix(remove-params): Removing max_model_len"
import threading
from typing import Any
from typing import Optional

import torch
from fastapi import FastAPI
from pydantic import BaseModel
from vllm import LLM, SamplingParams, RequestOutput
# The FastAPI application that hosts every endpoint in this module.
# Set HF_TOKEN in the environment before starting the server (presumably
# needed to pull the models below from the Hugging Face Hub).
app: FastAPI = FastAPI()
# vLLM engine serving meta-llama/Llama-3.2-3B-Instruct, created once at import time.
# Setup pattern adapted from:
# https://huggingface.co/spaces/damienbenveniste/deploy_vLLM/blob/b210a934d4ff7b68254d42fa28736d74649e610d/app.py#L17-L20
engine_llama_3_2: LLM = LLM(
    model='meta-llama/Llama-3.2-3B-Instruct',
    # Pinned to a specific Hub commit hash rather than a moving branch.
    revision="0cb88a4f764b7a12671c53f0838cd831a0843b95",
    # Scheduler limits; see
    # https://github.com/vllm-project/vllm/blob/v0.6.4/vllm/config.py#L1062-L1065
    max_num_batched_tokens=512, # Reduced for T4
    max_num_seqs=16, # Reduced for T4
    gpu_memory_utilization=0.85, # Slightly increased, adjust if needed
    tensor_parallel_size=2,  # split the model across 2 GPUs
    # max_model_len is deliberately NOT set (the line below is disabled), so vLLM
    # derives it from the model config. The original note says Llama-3.2-3B-Instruct
    # supports 131072 tokens. An earlier revision capped it at 32k because:
    #   32k tokens ~= 24k words (3/4 word per token); at ~500 words per page that is
    #   24k / 0.5k ~= 48 pages, and the full context was slower and did not fit in T4 memory.
    # https://github.com/vllm-project/vllm/blob/v0.6.4/vllm/config.py#L85-L86
    # https://github.com/vllm-project/vllm/blob/v0.6.4/vllm/config.py#L98-L102
    # NOTE(review): vLLM has rejected this setup before with
    # "max_num_batched_tokens (512) is smaller than max_model_len (32768)". With
    # max_model_len unset, the derived length is presumably larger still. Confirm the
    # engine actually starts, or enable chunked prefill / raise max_num_batched_tokens.
    # max_model_len=32768,
    enforce_eager=True, # Disable CUDA graph
    # 'auto' takes the precision from the model checkpoint; pass 'half' to force float16.
    # NOTE(review): if the checkpoint is bfloat16, GPUs below compute capability 8.0
    # (e.g. T4) may reject it. Confirm on the target hardware.
    dtype='auto',
)
# vLLM engine serving sail/Sailor-4B-Chat, created once at import time.
# History: with max_model_len=32768 set on the engines, vLLM refused to start with:
#   ValueError: max_num_batched_tokens (512) is smaller than max_model_len (32768).
#   This effectively limits the maximum sequence length to max_num_batched_tokens and
#   makes vLLM reject longer sequences.
#   Please increase max_num_batched_tokens or decrease max_model_len.
# max_model_len is therefore left unset below.
# NOTE(review): an unset max_model_len falls back to the model's own limit, which may
# still exceed 512. Confirm this engine starts with the current values.
# NOTE(review): this engine and engine_llama_3_2 both request gpu_memory_utilization=0.85
# with tensor_parallel_size=2 in the same process. If they share the same GPUs, the
# combined reservation exceeds 100%. Confirm this engine still gets KV-cache memory.
engine_sailor_chat: LLM = LLM(
    model='sail/Sailor-4B-Chat',
    # Pinned to a specific Hub commit hash rather than a moving branch.
    revision="89a866a7041e6ec023dd462adeca8e28dd53c83e",
    max_num_batched_tokens=512, # Reduced for T4
    max_num_seqs=16, # Reduced for T4
    gpu_memory_utilization=0.85, # Slightly increased, adjust if needed
    tensor_parallel_size=2,  # split the model across 2 GPUs
    # max_model_len=32768,
    enforce_eager=True, # Disable CUDA graph
    # 'auto' takes the precision from the model checkpoint; pass 'half' to force float16.
    dtype='auto',
)
def _describe_engine(engine: LLM) -> dict[str, Any]:
    """Summarise one vLLM engine from the model config it was actually built with."""
    model_config = engine.llm_engine.model_config
    return {
        "name": model_config.model,
        "revision": model_config.revision,
        "max_model_len": model_config.max_model_len,
    }


@app.get("/")
def greet_json():
    """Report CUDA availability, current-device memory usage, and the loaded models.

    Model name, revision and max_model_len are read from each engine's model
    config, so the response cannot drift from what was passed to ``LLM(...)``.
    """
    # Query once so the message and cuda_info always agree.
    cuda_available: bool = torch.cuda.is_available()
    cuda_info: dict[str, Any] = {}
    if cuda_available:
        device: int = torch.cuda.current_device()
        cuda_info = {
            "device_count": torch.cuda.device_count(),
            "cuda_device": torch.cuda.get_device_name(device),
            "cuda_capability": torch.cuda.get_device_capability(device),
            # Both figures are in GiB, rounded to one decimal place.
            "allocated": f"{round(torch.cuda.memory_allocated(device) / 1024 ** 3, 1)} GB",
            "cached": f"{round(torch.cuda.memory_reserved(device) / 1024 ** 3, 1)} GB",
        }
    return {
        "message": f"CUDA availability is {cuda_available}",
        "cuda_info": cuda_info,
        "model": [
            _describe_engine(engine_llama_3_2),
            _describe_engine(engine_sailor_chat),
        ],
    }
class GenerationRequest(BaseModel):
    # JSON request body shared by the text-generation endpoints.
    # The prompt text, handed to the engine unchanged.
    prompt: str
    # Upper bound on generated tokens (SamplingParams.max_tokens).
    max_tokens: int = 100
    # Sampling temperature (SamplingParams.temperature).
    temperature: float = 0.7
    # Optional token-id -> bias mapping (SamplingParams.logit_bias).
    logit_bias: dict[int, float] | None = None
class GenerationResponse(BaseModel):
    # Result envelope: either the generated text or an error message.
    # Both fields accept None, and neither declares a default value.
    # NOTE(review): no endpoint in the visible code returns this model.
    text: str | None
    error: str | None
# Serialises access to engine_llama_3_2. FastAPI runs plain `def` endpoints in a
# worker threadpool, so without this lock concurrent requests would call
# LLM.generate on the same engine from several threads at once. LLM.generate
# drives the engine loop itself and is presumably not safe to run concurrently.
_llama_3_2_generate_lock = threading.Lock()


# response_model=None: the return annotation includes vLLM's RequestOutput, which
# is not a Pydantic type, so FastAPI must not try to build a response model from it.
@app.post("/generate-llama3-2", response_model=None)
def generate_llama3_2_text(request: GenerationRequest) -> list[RequestOutput] | dict[str, str]:
    """Generate a completion for ``request.prompt`` with Llama-3.2-3B-Instruct.

    Named distinctly from the Sailor endpoint's ``generate_text`` so the module
    does not define the same function twice.

    Returns:
        The list of ``RequestOutput`` objects from ``LLM.generate`` on success, or
        ``{"error": "<message>"}`` if building the sampling parameters or
        generation raises (sent with the default success status, as before).
    """
    try:
        sampling_params: SamplingParams = SamplingParams(
            temperature=request.temperature,
            max_tokens=request.max_tokens,
            logit_bias=request.logit_bias,
        )
        with _llama_3_2_generate_lock:
            return engine_llama_3_2.generate(
                prompts=request.prompt,
                sampling_params=sampling_params,
            )
    except Exception as e:
        # Deliberate API contract: report failures in the JSON body.
        return {
            "error": str(e)
        }
# Serialises access to engine_sailor_chat. FastAPI runs plain `def` endpoints in a
# worker threadpool, so without this lock concurrent requests would call
# LLM.generate on the same engine from several threads at once. LLM.generate
# drives the engine loop itself and is presumably not safe to run concurrently.
_sailor_chat_generate_lock = threading.Lock()


# response_model=None: the return annotation includes vLLM's RequestOutput, which
# is not a Pydantic type, so FastAPI must not try to build a response model from it.
@app.post("/generate-sailor-chat", response_model=None)
def generate_text(request: GenerationRequest) -> list[RequestOutput] | dict[str, str]:
    """Generate a completion for ``request.prompt`` with Sailor-4B-Chat.

    Returns:
        The list of ``RequestOutput`` objects from ``LLM.generate`` on success, or
        ``{"error": "<message>"}`` if building the sampling parameters or
        generation raises (sent with the default success status, as before).
    """
    try:
        sampling_params: SamplingParams = SamplingParams(
            temperature=request.temperature,
            max_tokens=request.max_tokens,
            logit_bias=request.logit_bias,
        )
        with _sailor_chat_generate_lock:
            return engine_sailor_chat.generate(
                prompts=request.prompt,
                sampling_params=sampling_params,
            )
    except Exception as e:
        # Deliberate API contract: report failures in the JSON body.
        return {
            "error": str(e)
        }