import torch
from typing import Any
from typing import Optional
from fastapi import FastAPI
from pydantic import BaseModel
from vllm import LLM, SamplingParams, RequestOutput
# Don't forget to set HF_TOKEN in the environment when running this app.
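# A hypothetical sanity check (not part of the original Space): warn early if the token is
# missing, since meta-llama/Llama-3.2-3B-Instruct is a gated repository and the weights
# download would otherwise fail later, once the LLM engines below are constructed.
import os

if not os.environ.get("HF_TOKEN"):
    print("WARNING: HF_TOKEN is not set; downloading gated models will fail.")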
app = FastAPI()
# Initialize the LLM engine
# Replace 'your-model-path' with the actual path or name of your model
# example:
# https://huggingface.co/spaces/damienbenveniste/deploy_vLLM/blob/b210a934d4ff7b68254d42fa28736d74649e610d/app.py#L17-L20
engine_llama_3_2: LLM = LLM(
model='meta-llama/Llama-3.2-3B-Instruct',
revision="0cb88a4f764b7a12671c53f0838cd831a0843b95",
# https://github.com/vllm-project/vllm/blob/v0.6.4/vllm/config.py#L1062-L1065
max_num_batched_tokens=512, # Reduced for T4
max_num_seqs=16, # Reduced for T4
gpu_memory_utilization=0.85, # Slightly increased, adjust if needed
tensor_parallel_size=2,
    # Llama-3.2-3B-Instruct supports a context length of up to 131072 tokens, but we reduce it
    # to 32k because the full context is slower and does not fit in T4 memory.
    # Rough estimate: 32k tokens is about 24k words (roughly 3/4 of the token count), and at an
    # average of 500 (0.5k) words per page that is 24k / 0.5k = about 48 pages.
# https://github.com/vllm-project/vllm/blob/v0.6.4/vllm/config.py#L85-L86
# https://github.com/vllm-project/vllm/blob/v0.6.4/vllm/config.py#L98-L102
# max_model_len=32768,
enforce_eager=True, # Disable CUDA graph
dtype='auto', # Use 'half' if you want half precision
)
# If max_model_len=32768 is enabled above with these settings, vLLM fails at startup with:
#   ValueError: max_num_batched_tokens (512) is smaller than max_model_len (32768).
#   This effectively limits the maximum sequence length to max_num_batched_tokens and makes vLLM reject longer sequences.
#   Please increase max_num_batched_tokens or decrease max_model_len.
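# A hedged alternative (sketch only, not what this Space runs): to keep the 32k window, either
# raise the batching budget to at least the context length, e.g.
#   LLM(model='meta-llama/Llama-3.2-3B-Instruct',
#       max_model_len=32768,
#       max_num_batched_tokens=32768,
#       ...)
# or pass enable_chunked_prefill=True so long prefills can be split across scheduler steps.
# Both options cost more GPU memory or latency than the T4 hardware here comfortably allows,
# hence the reduced settings above.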
engine_sailor_chat: LLM = LLM(
model='sail/Sailor-4B-Chat',
revision="89a866a7041e6ec023dd462adeca8e28dd53c83e",
max_num_batched_tokens=512, # Reduced for T4
max_num_seqs=16, # Reduced for T4
gpu_memory_utilization=0.85, # Slightly increased, adjust if needed
tensor_parallel_size=2,
# max_model_len=32768,
enforce_eager=True, # Disable CUDA graph
dtype='auto', # Use 'half' if you want half precision
)
@app.get("/")
def greet_json():
cuda_info: dict[str, Any] = {}
if torch.cuda.is_available():
cuda_current_device: int = torch.cuda.current_device()
cuda_info = {
"device_count": torch.cuda.device_count(),
"cuda_device": torch.cuda.get_device_name(cuda_current_device),
"cuda_capability": torch.cuda.get_device_capability(cuda_current_device),
"allocated": f"{round(torch.cuda.memory_allocated(cuda_current_device) / 1024 ** 3, 1)} GB",
"cached": f"{round(torch.cuda.memory_reserved(cuda_current_device) / 1024 ** 3, 1)} GB",
}
return {
"message": f"CUDA availability is {torch.cuda.is_available()}",
"cuda_info": cuda_info,
"model": [
{
"name": "meta-llama/Llama-3.2-3B-Instruct",
"revision": "0cb88a4f764b7a12671c53f0838cd831a0843b95",
"max_model_len": engine_llama_3_2.llm_engine.model_config.max_model_len,
},
{
"name": "sail/Sailor-4B-Chat",
"revision": "89a866a7041e6ec023dd462adeca8e28dd53c83e",
"max_model_len": engine_sailor_chat.llm_engine.model_config.max_model_len,
},
]
}
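# Example call (assumption: the app listens on the default Hugging Face Spaces port 7860):
#   curl http://localhost:7860/
# The response reports CUDA availability and memory usage plus both model names, revisions,
# and the max_model_len values that vLLM actually resolved.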
class GenerationRequest(BaseModel):
    prompt: str
    max_tokens: int = 100
    temperature: float = 0.7
    logit_bias: Optional[dict[int, float]] = None


class GenerationResponse(BaseModel):
    text: Optional[str] = None
    error: Optional[str] = None
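# Note: GenerationResponse is declared but not wired into the endpoints below, which return the
# raw vLLM RequestOutput objects. A hedged sketch of how it could be used instead (an assumption,
# not the original behavior):
#   @app.post("/generate-llama3-2", response_model=GenerationResponse)
#   ... return GenerationResponse(text=outputs[0].outputs[0].text, error=None)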
@app.post("/generate-llama3-2")
def generate_text(request: GenerationRequest) -> list[RequestOutput] | dict[str, str]:
try:
sampling_params: SamplingParams = SamplingParams(
temperature=request.temperature,
max_tokens=request.max_tokens,
logit_bias=request.logit_bias,
)
# Generate text
return engine_llama_3_2.generate(
prompts=request.prompt,
sampling_params=sampling_params
)
except Exception as e:
return {
"error": str(e)
}
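# Example request against this endpoint (assumption: default Spaces port 7860):
#   curl -X POST http://localhost:7860/generate-llama3-2 \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "Write a haiku about GPUs.", "max_tokens": 64, "temperature": 0.7}'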
@app.post("/generate-sailor-chat")
def generate_text(request: GenerationRequest) -> list[RequestOutput] | dict[str, str]:
try:
sampling_params: SamplingParams = SamplingParams(
temperature=request.temperature,
max_tokens=request.max_tokens,
logit_bias=request.logit_bias,
)
# Generate text
return engine_sailor_chat.generate(
prompts=request.prompt,
sampling_params=sampling_params
)
except Exception as e:
return {
"error": str(e)
}
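# To run locally (sketch; the Space itself is started by its own Dockerfile/README, which may differ):
#   uvicorn app:app --host 0.0.0.0 --port 7860
# The module name "app" assumes the usual Space layout where this file is app.py.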