import os
import json
import logging
import subprocess
import sys
import time
import torch
from typing import List, Dict

# Ensure vllm is installed
try:
    import vllm
except ImportError:
    subprocess.check_call([sys.executable, "-m", "pip", "install", "vllm"])

# Import the necessary modules after installation
from vllm import LLM, SamplingParams
from vllm.utils import random_uuid

logger = logging.getLogger(__name__)

# Model name reported in every chat.completion response.
MODEL_NAME = "qwen-72b"

# Roles that Qwen's chat template renders; messages with any other role are skipped.
_CHAT_ROLES = ("system", "user", "assistant")


def format_chat(messages: List[Dict[str, str]]) -> str:
    """Format chat messages using Qwen's chat template.

    Every message whose ``role`` is ``system``, ``user`` or ``assistant`` is
    rendered as ``<|im_start|>{role}\\n{content}<|im_end|>\\n``; messages with
    any other role are ignored. An open assistant turn is appended at the end
    so the model generates the assistant's reply.

    Args:
        messages: Chat messages, each a dict with ``role`` and ``content`` keys.

    Returns:
        The prompt string to feed to the model.
    """
    parts = [
        f"<|im_start|>{message['role']}\n{message['content']}<|im_end|>\n"
        for message in messages
        if message["role"] in _CHAT_ROLES
    ]
    # Add the final assistant prompt
    parts.append("<|im_start|>assistant\n")
    return "".join(parts)


def model_fn(model_dir):
    """SageMaker hook: load the model from ``model_dir`` with vLLM.

    Args:
        model_dir: Local directory that SageMaker extracted the model artifacts into.

    Returns:
        A ``vllm.LLM`` instance that is passed to ``predict_fn`` as ``model``.
    """
    model = LLM(
        model=model_dir,
        trust_remote_code=True,
        gpu_memory_utilization=0.9,  # fraction of GPU memory vLLM may reserve
    )
    return model


def predict_fn(input_data, model):
    """SageMaker hook: run one chat completion and build an OpenAI-style response.

    Args:
        input_data: JSON text/bytes (as returned by ``input_fn``) with a
            ``messages`` list plus optional sampling fields: ``temperature``,
            ``top_p``, ``top_k``, ``max_tokens`` (or legacy ``max_new_tokens``),
            ``repetition_penalty``, ``stop``, ``stop_token_ids`` and
            ``skip_special_tokens``.
        model: The ``vllm.LLM`` returned by ``model_fn``.

    Returns:
        A ``chat.completion`` dict. On failure, a dict with ``error`` and
        ``details`` keys is returned instead and the traceback is logged.
    """
    try:
        data = json.loads(input_data)

        # Format the prompt using Qwen's chat template
        messages = data.get("messages", [])
        formatted_prompt = format_chat(messages)

        # vLLM names the generation budget ``max_tokens``. Accept the OpenAI
        # key and the legacy ``max_new_tokens`` key this endpoint used to take.
        max_tokens = data.get("max_tokens", data.get("max_new_tokens", 512))

        # ``length_penalty`` is deliberately not forwarded: it only affected
        # beam search (never enabled here), and newer vLLM releases no longer
        # accept it on SamplingParams.
        sampling_params = SamplingParams(
            temperature=data.get("temperature", 0.7),
            top_p=data.get("top_p", 0.9),
            max_tokens=max_tokens,
            top_k=data.get("top_k", -1),  # -1 disables top-k sampling
            repetition_penalty=data.get("repetition_penalty", 1.0),
            stop=data.get("stop", None),
            stop_token_ids=data.get("stop_token_ids", None),
            skip_special_tokens=data.get("skip_special_tokens", True),
        )

        # A single prompt produces a single RequestOutput with one completion.
        request_output = model.generate(formatted_prompt, sampling_params)[0]
        completion = request_output.outputs[0]

        # Report real token counts, not character counts.
        prompt_tokens = len(request_output.prompt_token_ids or [])
        completion_tokens = len(completion.token_ids)

        return {
            "id": f"chatcmpl-{random_uuid()}",
            "object": "chat.completion",
            "created": int(time.time()),
            "model": MODEL_NAME,
            "choices": [{
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": completion.text,
                },
                # "length" when max_tokens was hit, "stop" otherwise.
                "finish_reason": completion.finish_reason or "stop",
            }],
            "usage": {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
                "total_tokens": prompt_tokens + completion_tokens,
            },
        }
    except Exception as e:
        # Request boundary: keep the endpoint alive, but leave a trace in the logs.
        logger.exception("Inference request failed")
        return {"error": str(e), "details": repr(e)}


def input_fn(serialized_input_data, content_type):
    """SageMaker hook: pass the raw request body through to ``predict_fn`` unchanged."""
    return serialized_input_data


def output_fn(prediction_output, accept):
    """SageMaker hook: serialize the prediction dict to a JSON string."""
    return json.dumps(prediction_output)