import os
import json
import subprocess
import sys
import time
from typing import List, Dict
# Function to install vllm if not already installed
def install_vllm():
    try:
        import vllm
    except ImportError:
        subprocess.check_call([
            sys.executable, "-m", "pip", "install",
            "vllm @ https://github.com/vllm-project/vllm/releases/download/v0.6.1.post1/vllm-0.6.1.post1+cu118-cp310-cp310-manylinux1_x86_64.whl"
        ])
        import vllm

# Call the function to install vllm
install_vllm()
# Import the necessary modules after installation
from vllm import LLM, SamplingParams
from vllm.utils import random_uuid
# Function to format chat messages using Qwen's chat template
def format_chat(messages: List[Dict[str, str]]) -> str:
    """
    Format chat messages using Qwen's chat template.
    """
    formatted_text = ""
    for message in messages:
        role = message["role"]
        content = message["content"]
        if role == "system":
            formatted_text += f"<|im_start|>system\n{content}<|im_end|>\n"
        elif role == "user":
            formatted_text += f"<|im_start|>user\n{content}<|im_end|>\n"
        elif role == "assistant":
            formatted_text += f"<|im_start|>assistant\n{content}<|im_end|>\n"
    # Add the final assistant prompt so the model generates the reply
    formatted_text += "<|im_start|>assistant\n"
    return formatted_text
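
# Illustrative example (not executed): for
# messages = [{"role": "user", "content": "Hello"}]
# format_chat returns
# "<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\n",
# which matches the ChatML format Qwen chat models are trained on.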
# Model loading function for SageMaker
def model_fn(model_dir, context=None):
    # Load the model with tensor parallelism across the instance's GPUs
    model = LLM(
        model=model_dir,
        trust_remote_code=True,
        tensor_parallel_size=8,       # Shard the model across 8 GPUs
        gpu_memory_utilization=0.9    # Let vLLM use up to 90% of GPU memory
    )
    return model
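
# Note: tensor_parallel_size=8 assumes an 8-GPU instance (e.g. ml.p4d.24xlarge);
# adjust it to match the GPU count of the target endpoint instance.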
# Custom predict function for SageMaker
def predict_fn(input_data, model):
    try:
        data = json.loads(input_data)
        # Format the prompt using Qwen's chat template
        messages = data.get("messages", [])
        formatted_prompt = format_chat(messages)
        # Build sampling parameters (vLLM has no do_sample flag; sampling is
        # controlled by temperature/top_p, matching the OpenAI API)
        sampling_params = SamplingParams(
            temperature=data.get("temperature", 0.7),
            top_p=data.get("top_p", 0.9),
            max_tokens=data.get("max_new_tokens", 512),  # vLLM's name for max new tokens
            top_k=data.get("top_k", -1),                 # -1 disables top-k sampling
            repetition_penalty=data.get("repetition_penalty", 1.0),
            length_penalty=data.get("length_penalty", 1.0),
            stop_token_ids=data.get("stop_token_ids", None),
            skip_special_tokens=data.get("skip_special_tokens", True)
        )
        # Generate output
        outputs = model.generate(formatted_prompt, sampling_params)
        generated_text = outputs[0].outputs[0].text
        # Count tokens from the request output rather than string lengths
        prompt_tokens = len(outputs[0].prompt_token_ids)
        completion_tokens = len(outputs[0].outputs[0].token_ids)
        # Build an OpenAI-style chat completion response
        response = {
            "id": f"chatcmpl-{random_uuid()}",
            "object": "chat.completion",
            "created": int(time.time()),
            "model": "qwen-72b",
            "choices": [{
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": generated_text
                },
                "finish_reason": "stop"
            }],
            "usage": {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
                "total_tokens": prompt_tokens + completion_tokens
            }
        }
        return response
    except Exception as e:
        return {"error": str(e), "details": repr(e)}
# Define input and output formats for SageMaker
def input_fn(serialized_input_data, content_type):
    # Pass the raw request body through; predict_fn parses the JSON itself
    return serialized_input_data

def output_fn(prediction_output, accept):
    # Serialize the prediction dict back to JSON for the response
    return json.dumps(prediction_output)
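
# Minimal local smoke test (a sketch, not part of the SageMaker contract).
# It assumes a local model path in the MODEL_DIR environment variable (a
# hypothetical override; /opt/ml/model is SageMaker's default model dir) and
# enough GPUs for the tensor_parallel_size configured in model_fn.
if __name__ == "__main__":
    model = model_fn(os.environ.get("MODEL_DIR", "/opt/ml/model"))
    request = json.dumps({
        "messages": [{"role": "user", "content": "Say hello in one sentence."}],
        "max_new_tokens": 64
    })
    prediction = predict_fn(input_fn(request, "application/json"), model)
    print(output_fn(prediction, "application/json"))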