File size: 3,831 Bytes
d54b6e0
 
df9d248
 
d54b6e0
df9d248
 
a6861fa
df9d248
 
 
53b4f42
 
 
 
 
 
 
 
 
df9d248
d54b6e0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import json
import os
import subprocess
import sys
import time
from typing import List, Dict

import torch

# Ensure vllm is importable; if it is missing, install a pinned CUDA 11.8 build.
try:
    import vllm
except ImportError:
    import importlib

    # The GitHub release tag carries a leading "v", but the wheel filename
    # uses the bare PEP 440 version, so keep the two separate.
    vllm_version = "0.6.1.post1"
    vllm_tag = f"v{vllm_version}"
    # NOTE(review): the Python/ABI tag must match the asset actually published
    # for this release (abi3 wheels install on any CPython >= 3.8) - confirm
    # against the vLLM releases page when bumping the version.
    wheel_url = (
        "https://github.com/vllm-project/vllm/releases/download/"
        f"{vllm_tag}/vllm-{vllm_version}+cu118-cp38-abi3-manylinux1_x86_64.whl"
    )
    pip_cmd = [
        sys.executable,
        "-m", "pip", "install",
        wheel_url,
        "--extra-index-url", "https://download.pytorch.org/whl/cu118",
    ]
    subprocess.check_call(pip_cmd)
    # Let the import system see the package that was just installed.
    importlib.invalidate_caches()
# Import the necessary modules after installation
from vllm import LLM, SamplingParams
from vllm.utils import random_uuid

# Function to format chat messages using Qwen's chat template
def format_chat(messages: List[Dict[str, str]]) -> str:
    """
    Render chat messages as a ChatML-style prompt for Qwen models.

    Each message whose role is "system", "user" or "assistant" is emitted as
    an ``<|im_start|>{role}`` line, followed by its content and ``<|im_end|>``.
    Messages with any other role are skipped. An open assistant turn is
    appended at the end so the model continues speaking as the assistant.

    Args:
        messages: dicts with "role" and "content" keys.

    Returns:
        The formatted prompt string.

    Raises:
        KeyError: if any message lacks a "role" or "content" key.
    """
    supported_roles = ("system", "user", "assistant")
    parts = []
    for message in messages:
        # Both keys are read before the role check, so a malformed message
        # raises KeyError even when its role would have been skipped.
        role = message["role"]
        content = message["content"]
        if role in supported_roles:
            parts.append(f"<|im_start|>{role}\n{content}<|im_end|>\n")

    # Open the assistant turn the model is expected to complete.
    parts.append("<|im_start|>assistant\n")

    # join() is linear; repeated += on a str can be quadratic.
    return "".join(parts)

# Model loading function for SageMaker
def model_fn(model_dir):
    """Build a vLLM ``LLM`` engine from the weights stored in ``model_dir``.

    This is the SageMaker model-loading hook; the engine it returns is what
    ``predict_fn`` receives as its ``model`` argument.
    """
    engine_options = dict(
        model=model_dir,
        # Allow custom modelling code shipped alongside the weights.
        trust_remote_code=True,
        # Fraction of GPU memory vLLM may reserve for weights and KV cache.
        gpu_memory_utilization=0.9,
    )
    return LLM(**engine_options)

# Custom predict function for SageMaker
def predict_fn(input_data, model):
    """Run one chat-completion request through vLLM.

    Args:
        input_data: JSON request body (str or bytes), passed through unchanged
            by ``input_fn``. Recognised keys:
            - ``messages``: list of {"role", "content"} dicts.
            - Optional sampling settings: ``temperature``, ``top_p``,
              ``top_k``, ``max_tokens`` (or legacy ``max_new_tokens``),
              ``repetition_penalty``, ``length_penalty``, ``stop``,
              ``stop_token_ids`` and ``skip_special_tokens``.
        model: the ``LLM`` engine returned by ``model_fn``.

    Returns:
        An OpenAI-style ``chat.completion`` dict. If anything in request
        handling raises, returns ``{"error": ..., "details": ...}`` instead.
    """
    try:
        data = json.loads(input_data)

        # Format the prompt using Qwen's chat template
        formatted_prompt = format_chat(data.get("messages", []))
        sampling_params = _build_sampling_params(data)

        # generate() returns one result per prompt; exactly one is sent here.
        request_output = model.generate(formatted_prompt, sampling_params)[0]
        completion = request_output.outputs[0]

        return _build_chat_response(request_output, completion)
    except Exception as e:
        # Deliberate boundary: report failures in the response body rather
        # than letting them propagate out of the handler.
        return {"error": str(e), "details": repr(e)}


def _build_sampling_params(data):
    """Translate request fields into a vLLM ``SamplingParams``."""
    # vLLM names the generation limit ``max_tokens`` (as OpenAI does). The
    # legacy ``max_new_tokens`` key is still honoured as a fallback.
    max_tokens = data.get("max_tokens", data.get("max_new_tokens", 512))
    kwargs = {
        "temperature": data.get("temperature", 0.7),
        "top_p": data.get("top_p", 0.9),
        "max_tokens": max_tokens,
        "top_k": data.get("top_k", -1),  # -1 disables top-k filtering
        "repetition_penalty": data.get("repetition_penalty", 1.0),
        "stop": data.get("stop"),
        "stop_token_ids": data.get("stop_token_ids"),
        "skip_special_tokens": data.get("skip_special_tokens", True),
    }
    # length_penalty is a beam-search setting that not every vLLM release
    # accepts, so it is forwarded only when the caller explicitly sets it.
    if "length_penalty" in data:
        kwargs["length_penalty"] = data["length_penalty"]
    return SamplingParams(**kwargs)


def _build_chat_response(request_output, completion):
    """Package a vLLM generation result as an OpenAI-style chat.completion."""
    # Usage is counted in tokens, as the OpenAI schema specifies, not characters.
    prompt_tokens = len(request_output.prompt_token_ids or [])
    completion_tokens = len(completion.token_ids)
    return {
        "id": f"chatcmpl-{random_uuid()}",
        "object": "chat.completion",
        # Unix epoch seconds; torch.cuda has no timestamp API.
        "created": int(time.time()),
        "model": "qwen-72b",
        "choices": [{
            "index": 0,
            "message": {
                "role": "assistant",
                "content": completion.text,
            },
            # Report vLLM's actual reason (e.g. "stop" or "length");
            # fall back to "stop" if it is unset.
            "finish_reason": completion.finish_reason or "stop",
        }],
        "usage": {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens,
        },
    }

# Define input and output formats for SageMaker
def input_fn(serialized_input_data, content_type):
    """Hand the raw request body to ``predict_fn`` unchanged.

    JSON decoding happens in ``predict_fn``; ``content_type`` is not inspected.
    """
    raw_body = serialized_input_data
    return raw_body

def output_fn(prediction_output, accept):
    """Serialise the prediction dict to a JSON string.

    ``accept`` is ignored; JSON is always produced with the encoder's
    default settings.
    """
    encoder = json.JSONEncoder()
    return encoder.encode(prediction_output)