import logging
import time

from fastapi import FastAPI, Request
from vllm import LLM, SamplingParams
from vllm.utils import random_uuid

from chat_template import format_chat

app = FastAPI()

logger = logging.getLogger()
logger.setLevel(logging.INFO)
# Load the model function
def model_fn(model_dir):
    # The model is already in the container, so we don't need to download it
    model = LLM(
        model=model_dir,              # Load from local path
        trust_remote_code=True,
        quantization="fp8",           # FP8 is requested via `quantization`; `dtype` does not accept "fp8"
        gpu_memory_utilization=0.9,   # Use most of the available GPU memory
    )
    return model
# Global model variable
model = None
@app.on_event("startup")
async def startup_event():
    global model
    model = model_fn("/opt/ml/model")  # Ensure the correct path to the model
# Chat completion endpoint
@app.post("/v1/chat/completions")
async def chat_completions(request: Request):
    try:
        data = await request.json()

        # Retrieve messages and format the prompt
        messages = data.get("messages", [])
        formatted_prompt = format_chat(messages)

        # Build sampling parameters with flexibility. Note: vLLM has no
        # `do_sample` flag; pass temperature=0 for greedy decoding instead.
        sampling_params = SamplingParams(
            temperature=data.get("temperature", 0.7),
            top_p=data.get("top_p", 0.9),
            max_tokens=data.get("max_new_tokens", 512),  # vLLM calls this `max_tokens`
            top_k=data.get("top_k", -1),                 # -1 disables top-k sampling
            repetition_penalty=data.get("repetition_penalty", 1.0),
            length_penalty=data.get("length_penalty", 1.0),
            stop_token_ids=data.get("stop_token_ids", None),
            skip_special_tokens=data.get("skip_special_tokens", True),
        )
        # Handle optional vLLM-specific guided-decoding parameters if present.
        # Recent vLLM releases expose these through GuidedDecodingParams rather
        # than as attributes on SamplingParams.
        guided_params = data.get("guided_params", None)
        if guided_params:
            from vllm.sampling_params import GuidedDecodingParams

            sampling_params.guided_decoding = GuidedDecodingParams(
                choice=guided_params.get("guided_choice"),
                json=guided_params.get("guided_json"),
                regex=guided_params.get("guided_regex"),
            )
        # Generate output
        outputs = model.generate(formatted_prompt, sampling_params)
        generated_text = outputs[0].outputs[0].text

        # Build a response similar to the OpenAI chat-completion format
        response = {
            "id": f"chatcmpl-{random_uuid()}",
            "object": "chat.completion",
            "created": int(time.time()),  # wall-clock timestamp; torch has no timestamp helper
            "model": "qwen-72b",
            "choices": [{
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": generated_text
                },
                "finish_reason": "stop"
            }],
            "usage": {
                # Report token counts from vLLM's output rather than character counts
                "prompt_tokens": len(outputs[0].prompt_token_ids),
                "completion_tokens": len(outputs[0].outputs[0].token_ids),
                "total_tokens": len(outputs[0].prompt_token_ids) + len(outputs[0].outputs[0].token_ids)
            }
        }
        return response

    except Exception as e:
        logger.exception("Exception during prediction")
        return {"error": str(e), "details": repr(e)}
# Health check endpoint
@app.get("/ping")
def ping():
    logger.info("Ping request received")
    return {"status": "healthy"}