import logging
import time

from fastapi import FastAPI, Request
from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams
from vllm.utils import random_uuid

from chat_template import format_chat

app = FastAPI()
logger = logging.getLogger()
logger.setLevel(logging.INFO)


def model_fn(model_dir):
    """Load the model from the local path already present in the container."""
    model = LLM(
        model=model_dir,             # Load from local path; no download needed
        trust_remote_code=True,
        quantization="fp8",          # FP8 quantization (vLLM takes this via `quantization`, not `dtype`)
        gpu_memory_utilization=0.9,  # Use most of the GPU memory for weights and KV cache
    )
    return model


# Global model handle, populated at startup
model = None


@app.on_event("startup")
async def startup_event():
    global model
    model = model_fn("/opt/ml/model")  # SageMaker mounts the model artifacts here


# Chat completion endpoint
@app.post("/v1/chat/completions")
async def chat_completions(request: Request):
    try:
        data = await request.json()

        # Retrieve messages and format the prompt
        messages = data.get("messages", [])
        formatted_prompt = format_chat(messages)

        # Handle optional guided-decoding parameters if present
        # (requires a vLLM version that exposes GuidedDecodingParams)
        guided_decoding = None
        guided_params = data.get("guided_params")
        if guided_params:
            guided_decoding = GuidedDecodingParams(
                choice=guided_params.get("guided_choice"),
                json=guided_params.get("guided_json"),
                regex=guided_params.get("guided_regex"),
            )

        # Build sampling parameters with flexibility
        # (vLLM has no `do_sample` flag; greedy decoding is temperature=0)
        sampling_params = SamplingParams(
            temperature=data.get("temperature", 0.7),
            top_p=data.get("top_p", 0.9),
            top_k=data.get("top_k", -1),                 # -1 disables top-k sampling
            max_tokens=data.get("max_new_tokens", 512),  # vLLM calls this `max_tokens`
            repetition_penalty=data.get("repetition_penalty", 1.0),
            length_penalty=data.get("length_penalty", 1.0),
            stop_token_ids=data.get("stop_token_ids", None),
            skip_special_tokens=data.get("skip_special_tokens", True),
            guided_decoding=guided_decoding,
        )

        # Generate output
        outputs = model.generate(formatted_prompt, sampling_params)
        generated_text = outputs[0].outputs[0].text
        prompt_tokens = len(outputs[0].prompt_token_ids)
        completion_tokens = len(outputs[0].outputs[0].token_ids)

        # Build an OpenAI-style response
        response = {
            "id": f"chatcmpl-{random_uuid()}",
            "object": "chat.completion",
            "created": int(time.time()),
            "model": "qwen-72b",
            "choices": [{
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": generated_text,
                },
                "finish_reason": "stop",
            }],
            "usage": {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
                "total_tokens": prompt_tokens + completion_tokens,
            },
        }
        return response
    except Exception as e:
        logger.exception("Exception during prediction")
        return {"error": str(e), "details": repr(e)}


# Health check endpoint
@app.get("/ping")
def ping():
    logger.info("Ping request received")
    return {"status": "healthy"}
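

# Local entry point: a minimal sketch for running the app outside the container.
# It is not required on SageMaker itself, where the serving command typically
# launches uvicorn/gunicorn directly; port 8080 is assumed here because
# SageMaker expects inference containers to listen on that port.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8080)

# Example request against the running server (adjust host/port as needed):
#   curl -X POST http://localhost:8080/v1/chat/completions \
#        -H "Content-Type: application/json" \
#        -d '{"messages": [{"role": "user", "content": "Hello!"}], "max_new_tokens": 64}'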