import os
import json
import subprocess
import sys
import time
from typing import List, Dict

# Function to install vllm if not already installed
def install_vllm():
    try:
        import vllm
    except ImportError:
        subprocess.check_call([
            sys.executable, "-m", "pip", "install",
            "vllm @ https://github.com/vllm-project/vllm/releases/download/v0.6.1.post1/vllm-0.6.1.post1+cu118-cp310-cp310-manylinux1_x86_64.whl"
        ])
        import vllm

# Call the function to install vllm
install_vllm()

# Import the necessary modules after installation
from vllm import LLM, SamplingParams
from vllm.utils import random_uuid

# Function to format chat messages using Qwen's chat template
def format_chat(messages: List[Dict[str, str]]) -> str:
    """
    Format chat messages using Qwen's chat template
    """
    formatted_text = ""
    for message in messages:
        role = message["role"]
        content = message["content"]
        
        if role == "system":
            formatted_text += f"<|im_start|>system\n{content}<|im_end|>\n"
        elif role == "user":
            formatted_text += f"<|im_start|>user\n{content}<|im_end|>\n"
        elif role == "assistant":
            formatted_text += f"<|im_start|>assistant\n{content}<|im_end|>\n"
    
    # Add the final assistant prompt
    formatted_text += "<|im_start|>assistant\n"
    
    return formatted_text
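
# Example of the ChatML-style prompt produced by format_chat (illustrative only):
#   <|im_start|>system
#   You are a helpful assistant.<|im_end|>
#   <|im_start|>user
#   Hello!<|im_end|>
#   <|im_start|>assistant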

# Model loading function for SageMaker
def model_fn(model_dir, context=None):
    # Load the model with tensor parallelism
    model = LLM(
        model=model_dir,
        trust_remote_code=True,
        tensor_parallel_size=8,  # Shard the model across 8 GPUs; must match the GPU count of the hosting instance
        gpu_memory_utilization=0.9  # Reserve 90% of GPU memory for weights and KV cache
    )
    return model

# Custom predict function for SageMaker
def predict_fn(input_data, model):
    try:
        data = json.loads(input_data)

        # Format the prompt using Qwen's chat template
        messages = data.get("messages", [])
        formatted_prompt = format_chat(messages)

        # Build sampling parameters (vLLM has no do_sample flag; sampling is
        # controlled via temperature/top_p, as in the OpenAI API)
        sampling_params = SamplingParams(
            temperature=data.get("temperature", 0.7),
            top_p=data.get("top_p", 0.9),
            max_tokens=data.get("max_new_tokens", 512),  # vLLM calls this max_tokens
            top_k=data.get("top_k", -1),  # -1 disables top-k sampling
            repetition_penalty=data.get("repetition_penalty", 1.0),
            length_penalty=data.get("length_penalty", 1.0),
            stop_token_ids=data.get("stop_token_ids", None),
            skip_special_tokens=data.get("skip_special_tokens", True)
        )

        # Generate output
        outputs = model.generate(formatted_prompt, sampling_params)
        generated_text = outputs[0].outputs[0].text

        # Build an OpenAI-style chat completion response
        prompt_token_count = len(outputs[0].prompt_token_ids)
        completion_token_count = len(outputs[0].outputs[0].token_ids)
        response = {
            "id": f"chatcmpl-{random_uuid()}",
            "object": "chat.completion",
            "created": int(time.time()),
            "model": "qwen-72b",
            "choices": [{
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": generated_text
                },
                "finish_reason": outputs[0].outputs[0].finish_reason or "stop"
            }],
            "usage": {
                "prompt_tokens": prompt_token_count,
                "completion_tokens": completion_token_count,
                "total_tokens": prompt_token_count + completion_token_count
            }
        }
        return response
    except Exception as e:
        return {"error": str(e), "details": repr(e)}

# Define input and output formats for SageMaker
def input_fn(serialized_input_data, content_type):
    return serialized_input_data

def output_fn(prediction_output, accept):
    return json.dumps(prediction_output)
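
# --- Optional local smoke test (illustrative sketch, not part of the SageMaker handler contract) ---
# Runs only when the script is executed directly, e.g. inside the inference container. It assumes
# vLLM is installed, enough GPUs are available for the tensor_parallel_size above, and that the
# model directory (SM_MODEL_DIR, defaulting to /opt/ml/model) contains the Qwen weights.
if __name__ == "__main__":
    sample_request = json.dumps({
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Give me a one-sentence summary of vLLM."}
        ],
        "temperature": 0.7,
        "max_new_tokens": 128
    })

    # Mirror the SageMaker call sequence: model_fn -> input_fn -> predict_fn -> output_fn
    model = model_fn(os.environ.get("SM_MODEL_DIR", "/opt/ml/model"))
    payload = input_fn(sample_request, "application/json")
    prediction = predict_fn(payload, model)
    print(output_fn(prediction, "application/json"))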