Files changed (1)
  1. inference.py +80 -0
inference.py ADDED
@@ -0,0 +1,80 @@
import logging
import time

from fastapi import FastAPI, Request
from vllm import LLM, SamplingParams
from vllm.utils import random_uuid

from chat_template import format_chat

app = FastAPI()
logger = logging.getLogger()
logger.setLevel(logging.INFO)


def model_fn(model_dir):
    # The model weights are already inside the container, so vLLM loads them
    # from the local path instead of downloading them.
    model = LLM(
        model=model_dir,  # local path, e.g. /opt/ml/model
        trust_remote_code=True,
        dtype="float16",
        gpu_memory_utilization=0.9,
    )
    return model


# Global model handle, populated once at startup
model = None


@app.on_event("startup")
async def startup_event():
    global model
    model = model_fn("/opt/ml/model")


@app.post("/v1/chat/completions")
async def chat_completions(request: Request):
    try:
        data = await request.json()

        messages = data.get("messages", [])
        formatted_prompt = format_chat(messages)

        # vLLM's SamplingParams has no do_sample flag and expects max_tokens
        # rather than max_new_tokens.
        sampling_params = SamplingParams(
            temperature=data.get("temperature", 0.7),
            top_p=data.get("top_p", 0.9),
            max_tokens=data.get("max_tokens", data.get("max_new_tokens", 512)),
        )

        outputs = model.generate(formatted_prompt, sampling_params)
        generated_text = outputs[0].outputs[0].text

        # Report real token counts from vLLM instead of character lengths.
        prompt_tokens = len(outputs[0].prompt_token_ids)
        completion_tokens = len(outputs[0].outputs[0].token_ids)

        response = {
            "id": f"chatcmpl-{random_uuid()}",
            "object": "chat.completion",
            "created": int(time.time()),
            "model": "qwen-72b",
            "choices": [{
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": generated_text
                },
                "finish_reason": "stop"
            }],
            "usage": {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
                "total_tokens": prompt_tokens + completion_tokens
            }
        }

        return response

    except Exception as e:
        logger.error(f"Exception during prediction: {e}")
        return {"error": str(e)}


@app.get("/ping")
def ping():
    # Health check endpoint
    return {"status": "healthy"}
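For local testing, a minimal client sketch against the endpoint this file defines. The host and port are assumptions (they depend on how uvicorn serves the app, e.g. `uvicorn inference:app --port 8080`); the payload fields match what the handler reads from the request body.

# Hypothetical smoke test; adjust host/port to your deployment.
import requests

payload = {
    "messages": [
        {"role": "user", "content": "Say hello in one sentence."}
    ],
    "temperature": 0.7,
    "top_p": 0.9,
    "max_tokens": 64,
}

resp = requests.post(
    "http://localhost:8080/v1/chat/completions",  # assumed address
    json=payload,
    timeout=120,
)
resp.raise_for_status()
body = resp.json()
print(body["choices"][0]["message"]["content"])
print(body["usage"])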