# Qwen2.5-72B-Instruct-FP8 / inference.py
# FastAPI + vLLM inference server for the FP8-quantized Qwen2.5-72B-Instruct model.
import logging
import time

from fastapi import FastAPI, Request
from vllm import LLM, SamplingParams
from vllm.utils import random_uuid

from chat_template import format_chat

app = FastAPI()

logger = logging.getLogger()
logger.setLevel(logging.INFO)
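
# `format_chat` (imported from the sibling chat_template.py, which is not shown in
# this file) is expected to flatten the OpenAI-style `messages` list into a single
# prompt string. A minimal sketch, assuming Qwen's ChatML convention (the real
# template lives in chat_template.py and may differ):
#
#   def format_chat(messages):
#       prompt = ""
#       for message in messages:
#           prompt += f"<|im_start|>{message['role']}\n{message['content']}<|im_end|>\n"
#       return prompt + "<|im_start|>assistant\n"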

# Load the model
def model_fn(model_dir):
    # The model weights are already baked into the container, so no download is needed.
    model = LLM(
        model=model_dir,             # Load from the local path
        trust_remote_code=True,
        quantization="fp8",          # FP8 quantization (vLLM takes this via `quantization`; `dtype` only accepts float types)
        gpu_memory_utilization=0.9,  # Use up to 90% of GPU memory
    )
    return model

# Global model handle, initialized at application startup
model = None


@app.on_event("startup")
async def startup_event():
    global model
    model = model_fn("/opt/ml/model")  # SageMaker's default model directory inside the container

# Chat completion endpoint
@app.post("/v1/chat/completions")
async def chat_completions(request: Request):
    try:
        data = await request.json()

        # Retrieve messages and format the prompt
        messages = data.get("messages", [])
        formatted_prompt = format_chat(messages)

        # Build sampling parameters from client-supplied overrides. Note that vLLM's
        # SamplingParams uses `max_tokens` rather than `max_new_tokens`, and has no
        # `do_sample` flag (greedy decoding is temperature=0), so those HF-generate-style
        # options from the original request schema are mapped or dropped here.
        sampling_params = SamplingParams(
            temperature=data.get("temperature", 0.7),
            top_p=data.get("top_p", 0.9),
            top_k=data.get("top_k", -1),  # -1 disables top-k filtering
            max_tokens=data.get("max_new_tokens", 512),
            repetition_penalty=data.get("repetition_penalty", 1.0),
            stop_token_ids=data.get("stop_token_ids", None),
            skip_special_tokens=data.get("skip_special_tokens", True),
        )
        # Optional vLLM guided-decoding hints, if the client supplies them
        # (whether these attributes are honoured depends on the installed vLLM version).
        guided_params = data.get("guided_params", None)
        if guided_params:
            sampling_params.guided_choice = guided_params.get("guided_choice")
            sampling_params.guided_json = guided_params.get("guided_json")
            sampling_params.guided_regex = guided_params.get("guided_regex")

        # Generate output and take the first completion of the first prompt
        outputs = model.generate(formatted_prompt, sampling_params)
        generated_text = outputs[0].outputs[0].text
        # Build a response in the OpenAI chat-completions format
        response = {
            "id": f"chatcmpl-{random_uuid()}",
            "object": "chat.completion",
            "created": int(time.time()),
            "model": "qwen-72b",
            "choices": [{
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": generated_text,
                },
                "finish_reason": "stop",
            }],
            # Rough character-based counts; use tokenizer counts for real token usage.
            "usage": {
                "prompt_tokens": len(formatted_prompt),
                "completion_tokens": len(generated_text),
                "total_tokens": len(formatted_prompt) + len(generated_text),
            },
        }
        return response
    except Exception as e:
        logger.exception("Exception during prediction")
        return {"error": str(e), "details": repr(e)}

# Health check endpoint
@app.get("/ping")
def ping():
    logger.info("Ping request received")
    return {"status": "healthy"}