import json
import os      # File and path operations
import time    # Sleep while polling for another worker's load to finish
import uuid    # Unique ids for chat completion responses
import fcntl   # Advisory file locking across workers
from typing import List, Dict

import torch
from accelerate import load_checkpoint_and_dispatch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Cap the CUDA allocator's split size to reduce memory fragmentation when loading
# a large model. The caching allocator reads this setting when it first initializes,
# so it must be in place before any tensor is moved to a GPU.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"

# Confirm the environment variable is set as expected
print(f"PYTORCH_CUDA_ALLOC_CONF: {os.environ.get('PYTORCH_CUDA_ALLOC_CONF')}")

# Globals so the model and tokenizer persist across invocations within this worker process
model = None
tokenizer = None

# Function to format chat messages using Qwen's chat template
def format_chat(messages: List[Dict[str, str]], tokenizer) -> str:
    """
    Format chat messages using Qwen's chat template.
    """
    return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
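
# For reference, apply_chat_template renders the messages into Qwen's
# ChatML-style prompt; roughly (the exact special tokens come from the
# checkpoint's own chat template, so treat this as illustrative):
#
#   <|im_start|>system
#   You are a helpful assistant.<|im_end|>
#   <|im_start|>user
#   Hello!<|im_end|>
#   <|im_start|>assistant
#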

def model_fn(model_dir, context=None):
    global model, tokenizer

    # Path to lock file for ensuring single loading
    lock_file = "/tmp/model_load.lock"
    # Path to in-progress file indicating model loading is happening
    in_progress_file = "/tmp/model_loading_in_progress"

    if model is not None and tokenizer is not None:
        print("Model and tokenizer already loaded, skipping reload.")
        return model, tokenizer

    # Attempt to acquire the lock
    with open(lock_file, 'w') as lock:
        print("Attempting to acquire model load lock...")
        fcntl.flock(lock, fcntl.LOCK_EX)  # Exclusive lock

        try:
            # Check whether another worker is already loading the model
            if os.path.exists(in_progress_file):
                print("Another worker is currently loading the model, waiting...")

                # Poll the in-progress flag until the other worker finishes loading
                while os.path.exists(in_progress_file):
                    time.sleep(5)  # Wait 5 seconds before checking again

                print("Loading complete by another worker.")
                # Globals are per-process, so a load done by a worker in another
                # process is not visible here; only return early if this process
                # actually has the model, otherwise fall through and load it.
                if model is not None and tokenizer is not None:
                    return model, tokenizer

            # If no one is loading, start loading the model and set the in-progress flag
            print("No one is loading, proceeding to load the model.")
            with open(in_progress_file, 'w') as f:
                f.write("loading")

            # Loading the model and tokenizer
            if model is None or tokenizer is None:
                print("Loading the model and tokenizer...")

                offload_dir = "/tmp/offload_dir"
                os.makedirs(offload_dir, exist_ok=True)

                # Load the checkpoint and dispatch layers across all visible GPUs.
                # from_pretrained first materializes the weights on CPU; accelerate
                # then assigns whole layers to devices (naive model parallelism,
                # not tensor parallelism), spilling to disk if needed.
                model = AutoModelForCausalLM.from_pretrained(model_dir, torch_dtype="auto")
                model = load_checkpoint_and_dispatch(
                    model,
                    model_dir,
                    device_map="auto",  # Automatically place layers across GPUs
                    offload_folder=offload_dir,  # Offload overflow weights to disk
                    max_memory={i: "15GiB" for i in range(torch.cuda.device_count())},  # Cap per-GPU usage to leave headroom for activations
                    no_split_module_classes=["Qwen2DecoderLayer"]  # Keep each decoder block on one device; use the block class for your Qwen variant, not the *ForCausalLM wrapper
                )
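                # Optional sanity check: dispatch_model (used internally by
                # load_checkpoint_and_dispatch) records the final layer placement
                # on hf_device_map; guard with getattr in case the attribute is
                # absent in older accelerate versions.
                print(f"Device map: {getattr(model, 'hf_device_map', 'not available')}")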

                # Load the tokenizer
                tokenizer = AutoTokenizer.from_pretrained(model_dir)

        except Exception as e:
            print(f"Error loading model and tokenizer: {e}")
            raise

        finally:
            # Remove the in-progress flag once the loading is complete
            if os.path.exists(in_progress_file):
                os.remove(in_progress_file)

            # Release the lock
            fcntl.flock(lock, fcntl.LOCK_UN)

    return model, tokenizer
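
# Example request body this endpoint expects (illustrative values; only
# "messages" is required, the other keys override the generation defaults
# read in predict_fn below):
# {
#   "messages": [
#     {"role": "system", "content": "You are a helpful assistant."},
#     {"role": "user", "content": "What is Amazon SageMaker?"}
#   ],
#   "max_new_tokens": 256,
#   "temperature": 0.7,
#   "top_p": 0.9
# }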

# Custom predict function for SageMaker
def predict_fn(input_data, model_and_tokenizer, context=None):
    """
    Generate predictions for the input data.
    """
    try:
        model, tokenizer = model_and_tokenizer
        data = json.loads(input_data)

        # Format the prompt using Qwen's chat template
        messages = data.get("messages", [])
        formatted_prompt = format_chat(messages, tokenizer)

        # Tokenize the input and move it to the device that holds the first
        # model layers (cuda:0 under device_map="auto")
        inputs = tokenizer([formatted_prompt], return_tensors="pt").to("cuda:0")

        # Generate output
        outputs = model.generate(
            inputs['input_ids'],
            attention_mask=inputs['attention_mask'],
            max_new_tokens=data.get("max_new_tokens", 512),
            temperature=data.get("temperature", 0.7),
            top_p=data.get("top_p", 0.9),
            repetition_penalty=data.get("repetition_penalty", 1.0),
            length_penalty=data.get("length_penalty", 1.0),
            do_sample=True
        )

        # generate() returns prompt + completion for decoder-only models, so
        # strip the prompt tokens before decoding
        prompt_len = inputs['input_ids'].shape[1]
        generated_ids = outputs[0][prompt_len:]
        generated_text = tokenizer.decode(generated_ids, skip_special_tokens=True)

        # Build an OpenAI-style chat completion response
        response = {
            "id": f"chatcmpl-{uuid.uuid4().hex}",
            "object": "chat.completion",
            "model": "qwen-72b",
            "choices": [{
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": generated_text
                },
                "finish_reason": "stop"
            }],
            "usage": {
                "prompt_tokens": prompt_len,
                "completion_tokens": len(generated_ids),
                "total_tokens": prompt_len + len(generated_ids)
            }
        }
        return response

    except Exception as e:
        return {"error": str(e), "details": repr(e)}

# Define input format for SageMaker
def input_fn(serialized_input_data, content_type, context=None):
    """
    Prepare the input data for inference.
    """
    return serialized_input_data

# Define output format for SageMaker
def output_fn(prediction_output, accept, context=None):
    """
    Convert the model output to a JSON response.
    """
    return json.dumps(prediction_output)
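
# Optional local smoke test. SageMaker never executes this block (it calls the
# handler functions above directly); it only illustrates how the pieces chain
# together. Assumes the model artifacts are already unpacked under
# /opt/ml/model, SageMaker's default model directory.
if __name__ == "__main__":
    loaded = model_fn("/opt/ml/model")
    raw_request = json.dumps({
        "messages": [{"role": "user", "content": "Hello!"}],
        "max_new_tokens": 64
    })
    payload = input_fn(raw_request, "application/json")
    prediction = predict_fn(payload, loaded)
    print(output_fn(prediction, "application/json"))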