import logging
import time

from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from vllm import LLM, SamplingParams
# GuidedDecodingParams ships with recent vLLM releases; adjust if an older
# version is installed.
from vllm.sampling_params import GuidedDecodingParams
from vllm.utils import random_uuid

from chat_template import format_chat

app = FastAPI()

logger = logging.getLogger()
logger.setLevel(logging.INFO)

def model_fn(model_dir):
    """Load the model once with vLLM; called from the startup hook."""
    model = LLM(
        model=model_dir,
        trust_remote_code=True,
        # "fp8" is not a valid dtype value; FP8 is requested through the
        # quantization argument instead (assumes FP8-capable hardware).
        dtype="auto",
        quantization="fp8",
        gpu_memory_utilization=0.9,
    )
    return model


model = None

@app.on_event("startup")
async def startup_event():
    # /opt/ml/model is where SageMaker mounts the model artifacts.
    global model
    model = model_fn("/opt/ml/model")

@app.post("/v1/chat/completions")
async def chat_completions(request: Request):
    """OpenAI-style chat completions endpoint."""
    try:
        data = await request.json()

        messages = data.get("messages", [])
        formatted_prompt = format_chat(messages)

        # Map the request's generation knobs onto vLLM's SamplingParams.
        # vLLM has no do_sample flag: greedy decoding is expressed as
        # temperature=0. max_new_tokens is called max_tokens in vLLM, and
        # length_penalty only applies to beam search, so it is not forwarded.
        do_sample = data.get("do_sample", True)
        sampling_params = SamplingParams(
            temperature=data.get("temperature", 0.7) if do_sample else 0.0,
            top_p=data.get("top_p", 0.9),
            top_k=data.get("top_k", -1),
            repetition_penalty=data.get("repetition_penalty", 1.0),
            max_tokens=data.get("max_new_tokens", 512),
            stop_token_ids=data.get("stop_token_ids", None),
            skip_special_tokens=data.get("skip_special_tokens", True),
        )

        # Optional guided decoding (fixed choices, a JSON schema, or a regex).
        # Recent vLLM releases expose this through GuidedDecodingParams; older
        # versions only support it via the OpenAI-compatible server.
        guided_params = data.get("guided_params", None)
        if guided_params:
            sampling_params.guided_decoding = GuidedDecodingParams(
                choice=guided_params.get("guided_choice"),
                json=guided_params.get("guided_json"),
                regex=guided_params.get("guided_regex"),
            )

        outputs = model.generate(formatted_prompt, sampling_params)
        generated_text = outputs[0].outputs[0].text

        # Report real token counts from vLLM rather than character lengths,
        # and take the timestamp from the system clock.
        prompt_tokens = len(outputs[0].prompt_token_ids)
        completion_tokens = len(outputs[0].outputs[0].token_ids)

        response = {
            "id": f"chatcmpl-{random_uuid()}",
            "object": "chat.completion",
            "created": int(time.time()),
            "model": "qwen-72b",
            "choices": [{
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": generated_text
                },
                "finish_reason": outputs[0].outputs[0].finish_reason or "stop"
            }],
            "usage": {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
                "total_tokens": prompt_tokens + completion_tokens
            }
        }

        return response

    except Exception as e:
        logger.exception("Exception during prediction")
        # Surface failures as HTTP 500 so callers do not mistake them for
        # successful completions.
        return JSONResponse(
            status_code=500,
            content={"error": str(e), "details": repr(e)},
        )

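# For reference, a request to /v1/chat/completions is expected to look roughly
# like the JSON below (field names mirror the handler above; the values are
# illustrative only):
#
#   {
#       "messages": [{"role": "user", "content": "Hello!"}],
#       "temperature": 0.7,
#       "max_new_tokens": 256,
#       "guided_params": {"guided_choice": ["yes", "no"]}
#   }
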
@app.get("/ping")
def ping():
    # SageMaker's health check hits /ping; any 200 response counts as healthy.
    logger.info("Ping request received")
    return {"status": "healthy"}
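

# A minimal local runner, assuming the container executes this module directly;
# SageMaker expects the inference server on port 8080. If a separate serving
# entrypoint (e.g. a uvicorn/gunicorn command) launches `app`, this block is
# simply unused.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8080)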
|