Imran1 committed
Commit
e222084
1 Parent(s): 54088e7

Create inference.py

Files changed (1)
  1. inference.py +63 -0
inference.py ADDED
@@ -0,0 +1,63 @@
+ import subprocess
+ import sys
+
+ # Install vLLM from a prebuilt CUDA 11.8 wheel before importing it
+ subprocess.check_call([
+     sys.executable, "-m", "pip", "install",
+     "vllm @ https://github.com/vllm-project/vllm/releases/download/v0.6.1.post1/vllm-0.6.1.post1+cu118-cp310-cp310-manylinux1_x86_64.whl"
+ ])
+
+ import json
+ import logging
+ import os
+ from vllm import LLM, SamplingParams
+
+ logger = logging.getLogger()
+ logger.setLevel(logging.INFO)
+
+
+ def model_fn(model_dir, context=None):
+     # Load the model from the SageMaker model directory with vLLM
+     model = LLM(
+         model=model_dir,
+         trust_remote_code=True,
+         gpu_memory_utilization=0.9,
+         tensor_parallel_size=4,  # Shard the model across 4 GPUs
+     )
+     return model
+
+
+ def predict_fn(data, model, context=None):
+     try:
+         input_text = data.pop("inputs", data)
+         parameters = data.pop("parameters", {})
+
+         # Qwen 2.5 chat template
+         chat_template = (
+             "<|im_start|>system\nYou are a helpful AI assistant.<|im_end|>\n"
+             f"<|im_start|>user\n{input_text}<|im_end|>\n"
+             "<|im_start|>assistant\n"
+         )
+
+         sampling_params = SamplingParams(
+             temperature=parameters.get("temperature", 0.7),
+             top_p=parameters.get("top_p", 0.9),
+             max_tokens=parameters.get("max_new_tokens", 512),
+             stop=["<|im_end|>", "<|im_start|>"],  # Stop tokens for Qwen
+         )
+
+         outputs = model.generate(chat_template, sampling_params)
+         generated_text = outputs[0].outputs[0].text
+
+         # Remove any trailing stop tokens if they were generated
+         for stop_token in ["<|im_end|>", "<|im_start|>"]:
+             if generated_text.endswith(stop_token):
+                 generated_text = generated_text[:-len(stop_token)].strip()
+
+         return {"generated_text": generated_text}
+     except Exception as e:
+         logger.error(f"Exception during prediction: {e}")
+         return {"error": str(e)}
+
+
+ def input_fn(request_body, request_content_type, context=None):
+     if request_content_type == "application/json":
+         return json.loads(request_body)
+     else:
+         raise ValueError(f"Unsupported content type: {request_content_type}")
+
+
+ def output_fn(prediction, accept, context=None):
+     return json.dumps(prediction)
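
For reference, below is a minimal sketch of how these SageMaker-style handlers could be smoke-tested locally. The handler chain (input_fn → model_fn → predict_fn → output_fn) mirrors what the SageMaker inference toolkit invokes; the ./model path and the sample request are illustrative assumptions and are not part of this commit.

# Hypothetical local smoke test -- not part of inference.py
import json
from inference import model_fn, predict_fn, input_fn, output_fn

# Load the model once, as SageMaker would do at container start-up
model = model_fn("./model")  # assumed path to the unpacked model weights

# Simulate an incoming JSON request
request_body = json.dumps({
    "inputs": "What is the capital of France?",
    "parameters": {"temperature": 0.7, "max_new_tokens": 128},
})
data = input_fn(request_body, "application/json")

# Run generation and serialize the response
prediction = predict_fn(data, model)
print(output_fn(prediction, "application/json"))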