import json
import os
import time
import fcntl
from typing import Dict, List

import torch
from accelerate import load_checkpoint_and_dispatch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Reduce CUDA memory fragmentation when loading a large model in fp16.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"
print(f"PYTORCH_CUDA_ALLOC_CONF: {os.environ.get('PYTORCH_CUDA_ALLOC_CONF')}")

# Module-level cache so repeated model_fn calls in the same worker reuse the loaded model.
model = None
tokenizer = None

def format_chat(messages: List[Dict[str, str]], tokenizer) -> str:
    """
    Format chat messages using Qwen's chat template.
    """
    return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
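
# Illustrative example (not exercised by the handlers): for Qwen-style chat templates,
#   format_chat([{"role": "user", "content": "Hello"}], tokenizer)
# renders roughly "<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\n";
# the exact markers and system-prompt handling come from the tokenizer's own chat template.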


def model_fn(model_dir, context=None):
    global model, tokenizer

    lock_file = "/tmp/model_load.lock"
    in_progress_file = "/tmp/model_loading_in_progress"

    if model is not None and tokenizer is not None:
        print("Model and tokenizer already loaded, skipping reload.")
        return model, tokenizer

    with open(lock_file, 'w') as lock:
        print("Attempting to acquire model load lock...")
        fcntl.flock(lock, fcntl.LOCK_EX)

        try:
            if os.path.exists(in_progress_file):
                # Note: this path only helps when workers share this module's globals
                # (threaded workers); separate worker processes still load their own copy.
                print("Another worker is currently loading the model, waiting...")
                while os.path.exists(in_progress_file):
                    time.sleep(5)
                print("Loading complete by another worker, skipping reload.")
                return model, tokenizer

            print("No one is loading, proceeding to load the model.")
            with open(in_progress_file, 'w') as f:
                f.write("loading")

            if model is None or tokenizer is None:
                print("Loading the model and tokenizer...")

                offload_dir = "/tmp/offload_dir"
                os.makedirs(offload_dir, exist_ok=True)

                # from_pretrained materializes the full fp16 weights; load_checkpoint_and_dispatch
                # then re-reads the checkpoint and shards it across GPUs (with CPU/disk offload).
                # Initializing on the meta device (accelerate.init_empty_weights) would avoid
                # reading the weights twice.
                model = AutoModelForCausalLM.from_pretrained(model_dir, torch_dtype=torch.float16)
                model = load_checkpoint_and_dispatch(
                    model,
                    model_dir,
                    device_map="auto",
                    offload_folder=offload_dir,
                    max_memory={i: "15GiB" for i in range(torch.cuda.device_count())},
                    # Module classes listed here are kept whole on a single device; this is
                    # usually the model's decoder block class rather than the full model class.
                    no_split_module_classes=["QwenForCausalLM"],
                )

                tokenizer = AutoTokenizer.from_pretrained(model_dir)

        except Exception as e:
            print(f"Error loading model and tokenizer: {e}")
            raise

        finally:
            # Always clear the in-progress marker and release the file lock.
            if os.path.exists(in_progress_file):
                os.remove(in_progress_file)
            fcntl.flock(lock, fcntl.LOCK_UN)

    return model, tokenizer
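
# Optional sanity check (illustrative; accelerate typically records the dispatch result
# on the model as `hf_device_map`):
#   mdl, _ = model_fn("/opt/ml/model")
#   print(getattr(mdl, "hf_device_map", {}))  # per-submodule device placement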


def predict_fn(input_data, model_and_tokenizer, context=None):
    """
    Generate predictions for the input data.
    """
    try:
        model, tokenizer = model_and_tokenizer
        data = json.loads(input_data)

        messages = data.get("messages", [])
        formatted_prompt = format_chat(messages, tokenizer)

        # Inputs go to cuda:0, where the first model layers typically sit under device_map="auto".
        inputs = tokenizer([formatted_prompt], return_tensors="pt").to("cuda:0")

        outputs = model.generate(
            inputs['input_ids'],
            attention_mask=inputs['attention_mask'],
            max_new_tokens=data.get("max_new_tokens", 512),
            temperature=data.get("temperature", 0.7),
            top_p=data.get("top_p", 0.9),
            repetition_penalty=data.get("repetition_penalty", 1.0),
            length_penalty=data.get("length_penalty", 1.0),
            do_sample=True
        )

        # Decode only the newly generated tokens so the prompt is not echoed in the reply,
        # and count prompt/completion tokens separately for the usage block.
        prompt_tokens = inputs['input_ids'].shape[1]
        generated_text = tokenizer.batch_decode(
            outputs[:, prompt_tokens:], skip_special_tokens=True
        )[0]
        completion_tokens = outputs.shape[1] - prompt_tokens

        response = {
            "id": "chatcmpl-uuid",
            "object": "chat.completion",
            "model": "qwen-72b",
            "choices": [{
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": generated_text
                },
                "finish_reason": "stop"
            }],
            "usage": {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
                "total_tokens": prompt_tokens + completion_tokens
            }
        }
        return response

    except Exception as e:
        return {"error": str(e), "details": repr(e)}


def input_fn(serialized_input_data, content_type, context=None):
    """
    Prepare the input data for inference.
    """
    return serialized_input_data


def output_fn(prediction_output, accept, context=None):
    """
    Convert the model output to a JSON response.
    """
    return json.dumps(prediction_output)
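

# Minimal local smoke test: a sketch assuming the model artifacts live under
# /opt/ml/model (SageMaker's default model directory) and at least one GPU is visible.
# Not part of the hosting contract; the serving container only calls the *_fn handlers above.
if __name__ == "__main__":
    artifacts = model_fn("/opt/ml/model")
    payload = json.dumps({
        "messages": [{"role": "user", "content": "Hello!"}],
        "max_new_tokens": 32,
    })
    body = input_fn(payload, "application/json")
    print(output_fn(predict_fn(body, artifacts), "application/json"))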