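"""
SageMaker inference handler for a Qwen chat model.

Implements the model_fn / input_fn / predict_fn / output_fn entry points so the
serving container loads the model once per worker (guarded by a file lock) and
answers chat-completion style requests.
"""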
import fcntl  # For file locking
import json
import os  # For file operations
import time  # For sleep function
import uuid  # For generating response IDs
from typing import Dict, List

import torch
from accelerate import load_checkpoint_and_dispatch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Set the max_split_size globally at the start
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"
# Print to verify the environment variable is correctly set
print(f"PYTORCH_CUDA_ALLOC_CONF: {os.environ.get('PYTORCH_CUDA_ALLOC_CONF')}")
# Global variables to persist the model and tokenizer between invocations
model = None
tokenizer = None


# Function to format chat messages using Qwen's chat template
def format_chat(messages: List[Dict[str, str]], tokenizer) -> str:
    """
    Format chat messages using Qwen's chat template.
    """
    return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)


def model_fn(model_dir, context=None):
    global model, tokenizer
    # Lock file used to ensure only one worker loads the model at a time
    lock_file = "/tmp/model_load.lock"
    # Flag file indicating that a model load is in progress
    in_progress_file = "/tmp/model_loading_in_progress"
    if model is not None and tokenizer is not None:
        print("Model and tokenizer already loaded, skipping reload.")
        return model, tokenizer
    # Acquire an exclusive lock before loading
    with open(lock_file, "w") as lock:
        print("Attempting to acquire model load lock...")
        fcntl.flock(lock, fcntl.LOCK_EX)  # Exclusive lock
        try:
            # Check whether another worker is in the process of loading
            if os.path.exists(in_progress_file):
                print("Another worker is currently loading the model, waiting...")
                # Poll the in-progress flag until the other worker finishes loading
                while os.path.exists(in_progress_file):
                    time.sleep(5)  # Wait 5 seconds before checking again
                print("Loading complete by another worker.")
                if model is not None and tokenizer is not None:
                    return model, tokenizer
                # Globals are still unset in this worker, so fall through and load.
            else:
                # No one is loading yet; set the in-progress flag and proceed
                print("No one is loading, proceeding to load the model.")
                with open(in_progress_file, "w") as f:
                    f.write("loading")
            # Load the model and tokenizer
            if model is None or tokenizer is None:
                print("Loading the model and tokenizer...")
                offload_dir = "/tmp/offload_dir"
                os.makedirs(offload_dir, exist_ok=True)
                # Load the weights, then shard the layers across the available GPUs
                # (accelerate's big-model dispatch), offloading to disk if needed
                model = AutoModelForCausalLM.from_pretrained(model_dir, torch_dtype="auto")
                model = load_checkpoint_and_dispatch(
                    model,
                    model_dir,
                    device_map="auto",  # Automatically map layers across GPUs
                    offload_folder=offload_dir,  # Offload parts to disk if needed
                    max_memory={i: "15GiB" for i in range(torch.cuda.device_count())},  # Cap usage per GPU
                    no_split_module_classes=["QwenForCausalLM"]  # Module classes that must not be split across devices
                )
                # Load the tokenizer
                tokenizer = AutoTokenizer.from_pretrained(model_dir)
        except Exception as e:
            print(f"Error loading model and tokenizer: {e}")
            raise
        finally:
            # Remove the in-progress flag once loading is complete
            if os.path.exists(in_progress_file):
                os.remove(in_progress_file)
            # Release the lock
            fcntl.flock(lock, fcntl.LOCK_UN)
    return model, tokenizer


# Custom predict function for SageMaker
def predict_fn(input_data, model_and_tokenizer, context=None):
    """
    Generate predictions for the input data.
    """
    try:
        model, tokenizer = model_and_tokenizer
        data = json.loads(input_data)
        # Format the prompt using Qwen's chat template
        messages = data.get("messages", [])
        formatted_prompt = format_chat(messages, tokenizer)
        # Tokenize the input and send it to GPU 0, where generation starts
        inputs = tokenizer([formatted_prompt], return_tensors="pt").to("cuda:0")
        # Generate output
        outputs = model.generate(
            inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_new_tokens=data.get("max_new_tokens", 512),
            temperature=data.get("temperature", 0.7),
            top_p=data.get("top_p", 0.9),
            repetition_penalty=data.get("repetition_penalty", 1.0),
            length_penalty=data.get("length_penalty", 1.0),
            do_sample=True
        )
        # Decode only the newly generated tokens (generate() returns prompt + completion)
        prompt_length = inputs["input_ids"].shape[1]
        generated_ids = outputs[0][prompt_length:]
        generated_text = tokenizer.decode(generated_ids, skip_special_tokens=True)
        # Build an OpenAI-style chat completion response
        response = {
            "id": f"chatcmpl-{uuid.uuid4()}",
            "object": "chat.completion",
            "model": "qwen-72b",
            "choices": [{
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": generated_text
                },
                "finish_reason": "stop"
            }],
            "usage": {
                "prompt_tokens": prompt_length,
                "completion_tokens": len(generated_ids),
                "total_tokens": prompt_length + len(generated_ids)
            }
        }
        return response
    except Exception as e:
        return {"error": str(e), "details": repr(e)}


# Define input format for SageMaker
def input_fn(serialized_input_data, content_type, context=None):
    """
    Prepare the input data for inference.
    """
    return serialized_input_data


# Define output format for SageMaker
def output_fn(prediction_output, accept, context=None):
    """
    Convert the model output to a JSON response.
    """
    return json.dumps(prediction_output)
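

# Example request payload accepted by predict_fn (a minimal sketch; field values
# are illustrative, not part of this handler):
#
#   {
#       "messages": [{"role": "user", "content": "Hello!"}],
#       "max_new_tokens": 256,
#       "temperature": 0.7
#   }
#
# Once deployed, a client could invoke the endpoint roughly like this (the
# endpoint name below is a hypothetical placeholder):
#
#   import boto3, json
#   runtime = boto3.client("sagemaker-runtime")
#   payload = {"messages": [{"role": "user", "content": "Hello!"}]}
#   result = runtime.invoke_endpoint(
#       EndpointName="qwen-72b-endpoint",  # hypothetical endpoint name
#       ContentType="application/json",
#       Body=json.dumps(payload),
#   )
#   print(json.loads(result["Body"].read()))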