|
import importlib
import json
import os
import subprocess
import sys
import time
from typing import Dict, List

import torch
|
|
|
|
|
# Best-effort bootstrap: install vLLM at import time if the container lacks it.
# This needs network access and slows cold start; baking vllm into the image
# (or a requirements.txt) is preferable.
try:
    import vllm
except ImportError:
    subprocess.check_call([sys.executable, "-m", "pip", "install", "vllm"])
    # The import system caches directory listings; without invalidating them a
    # package installed after interpreter start-up may not be found by the
    # `from vllm import ...` statements that follow.
    importlib.invalidate_caches()
    # Re-import now so a failed/ineffective install surfaces here, not later.
    import vllm  # noqa: F401
|
|
|
|
|
from vllm import LLM, SamplingParams |
|
from vllm.utils import random_uuid |
|
|
|
|
|
def format_chat(messages: List[Dict[str, str]], add_generation_prompt: bool = True) -> str:
    r"""Render chat messages as a Qwen ChatML prompt.

    Each message becomes ``<|im_start|>{role}\n{content}<|im_end|>\n``, in
    order. When *add_generation_prompt* is true (the default), an open
    ``<|im_start|>assistant\n`` turn is appended so the model continues as
    the assistant.

    Args:
        messages: dicts with ``"role"`` and ``"content"`` keys. ``role`` must be
            ``"system"``, ``"user"`` or ``"assistant"``.
        add_generation_prompt: whether to append the trailing open assistant turn.

    Returns:
        The formatted prompt string.

    Raises:
        KeyError: a message is missing ``"role"`` or ``"content"``.
        ValueError: a message has an unsupported role. Such messages used to be
            dropped silently, which quietly removed context from the prompt.
    """
    parts = []
    for message in messages:
        role = message["role"]
        if role not in ("system", "user", "assistant"):
            raise ValueError(f"Unsupported chat role: {role!r}")
        parts.append(f"<|im_start|>{role}\n{message['content']}<|im_end|>\n")

    if add_generation_prompt:
        parts.append("<|im_start|>assistant\n")

    # join() avoids the quadratic cost of repeated string concatenation.
    return "".join(parts)
|
|
|
|
|
def model_fn(model_dir):
    """Load the model stored at *model_dir* into a vLLM ``LLM`` engine.

    Engine sizing can be tuned per deployment through environment variables.
    The defaults reproduce the previous hard-coded configuration:

    * ``GPU_MEMORY_UTILIZATION``: fraction of each GPU's memory vLLM may
      reserve (default ``0.9``).
    * ``TENSOR_PARALLEL_SIZE``: number of GPUs to shard the model across
      (default ``1``). Large models generally need more than one GPU.

    NOTE(review): the single-argument signature is kept on purpose; serving
    toolkits may choose how to call ``model_fn`` based on its arity.

    Returns:
        The constructed ``LLM`` instance.
    """
    gpu_memory_utilization = float(os.environ.get("GPU_MEMORY_UTILIZATION", "0.9"))
    tensor_parallel_size = int(os.environ.get("TENSOR_PARALLEL_SIZE", "1"))

    return LLM(
        model=model_dir,
        # Executes model code shipped in model_dir; only load trusted artifacts.
        trust_remote_code=True,
        gpu_memory_utilization=gpu_memory_utilization,
        tensor_parallel_size=tensor_parallel_size,
    )
|
|
|
|
|
def _build_sampling_params(data):
    """Translate a request's generation options into vLLM ``SamplingParams``.

    Keys missing from *data* fall back to the defaults below.
    """
    params = {
        "temperature": data.get("temperature", 0.7),
        "top_p": data.get("top_p", 0.9),
        # vLLM calls the generation budget ``max_tokens``. ``max_new_tokens`` is
        # the Hugging Face name and is not a SamplingParams argument. Accept
        # either request key so existing clients keep working.
        "max_tokens": data.get("max_new_tokens", data.get("max_tokens", 512)),
        "top_k": data.get("top_k", -1),
        "repetition_penalty": data.get("repetition_penalty", 1.0),
        "stop_token_ids": data.get("stop_token_ids", None),
        "skip_special_tokens": data.get("skip_special_tokens", True),
    }
    # length_penalty only affects beam search and is not a SamplingParams field
    # in newer vLLM releases, so forward it only when the client sends it.
    if "length_penalty" in data:
        params["length_penalty"] = data["length_penalty"]
    return SamplingParams(**params)


def _build_response(request_output):
    """Wrap one vLLM ``RequestOutput`` in an OpenAI-style chat.completion dict."""
    completion = request_output.outputs[0]
    # Report token counts, not character counts, so "usage" reflects what the
    # model actually consumed and produced.
    prompt_tokens = len(request_output.prompt_token_ids or [])
    completion_tokens = len(completion.token_ids)
    return {
        "id": f"chatcmpl-{random_uuid()}",
        "object": "chat.completion",
        # Unix time in seconds (torch.cuda has no current_timestamp()).
        "created": int(time.time()),
        "model": "qwen-72b",
        "choices": [
            {
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": completion.text,
                },
                # Propagate vLLM's reason (e.g. "stop" vs "length") instead of
                # always claiming "stop".
                "finish_reason": completion.finish_reason or "stop",
            }
        ],
        "usage": {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens,
        },
    }


def predict_fn(input_data, model):
    """Run one chat-completion request through vLLM.

    Args:
        input_data: the JSON request body (``str``/``bytes``) as returned by
            ``input_fn``, or an already-decoded ``dict``. It is read for
            ``messages`` plus optional sampling fields (``temperature``,
            ``top_p``, ``max_new_tokens``/``max_tokens``, ``top_k``,
            ``repetition_penalty``, ``length_penalty``, ``stop_token_ids``,
            ``skip_special_tokens``).
        model: the ``LLM`` engine returned by ``model_fn``.

    Returns:
        An OpenAI-style ``chat.completion`` dict on success. If anything raises,
        it returns ``{"error": str(e), "details": repr(e)}`` instead; errors are
        reported in-band rather than propagated.
    """
    try:
        data = input_data if isinstance(input_data, dict) else json.loads(input_data)
        formatted_prompt = format_chat(data.get("messages", []))
        sampling_params = _build_sampling_params(data)
        outputs = model.generate(formatted_prompt, sampling_params)
        return _build_response(outputs[0])
    except Exception as e:
        return {"error": str(e), "details": repr(e)}
|
|
|
|
|
def input_fn(serialized_input_data, content_type):
    """Hand the raw request body to the next stage without modification.

    JSON decoding happens later in ``predict_fn``. *content_type* is part of
    the handler interface but is not inspected here.
    """
    payload = serialized_input_data
    return payload
|
|
|
def output_fn(prediction_output, accept):
    """Serialize the prediction to a JSON string.

    *accept* is part of the handler interface but ignored: the response body
    is always JSON produced by ``json.dumps`` with default settings.
    """
    body = json.dumps(prediction_output)
    return body
|
|