Files changed (1) hide show
  1. inference.py +9 -2
inference.py CHANGED
@@ -5,12 +5,19 @@ import sys
5
  import torch
6
  from typing import List, Dict
7
 
8
- # Ensure vllm is installed
9
  try:
10
  import vllm
11
  except ImportError:
12
- subprocess.check_call([sys.executable, "-m", "pip", "install", "vllm"])
 
 
 
 
 
 
13
 
 
14
  # Import the necessary modules after installation
15
  from vllm import LLM, SamplingParams
16
  from vllm.utils import random_uuid
 
5
  import torch
6
  from typing import List, Dict
7
 
8
+ # Ensure vllm is installed and specify version to match CUDA compatibility
9
  try:
10
  import vllm
11
  except ImportError:
12
+ # Check CUDA version and install the correct vllm version
13
+ cuda_version = torch.version.cuda
14
+ if cuda_version == "11.8":
15
+ vllm_version = "v0.6.1.post1"
16
+ pip_cmd = f"pip install https://github.com/vllm-project/vllm/releases/download/{vllm_version}/vllm-{vllm_version}+cu118-cp310-cp310-manylinux1_x86_64.whl --extra-index-url https://download.pytorch.org/whl/cu118"
17
+ else:
18
+ raise RuntimeError(f"Unsupported CUDA version: {cuda_version}")
19
 
20
+ subprocess.check_call([sys.executable, "-m", "pip", "install", pip_cmd])
21
  # Import the necessary modules after installation
22
  from vllm import LLM, SamplingParams
23
  from vllm.utils import random_uuid