Files changed (1)
  1. inference.py +9 -9
inference.py CHANGED
@@ -9,15 +9,15 @@ from typing import List, Dict
9
  try:
10
  import vllm
11
  except ImportError:
12
- # Check CUDA version and install the correct vllm version
13
- cuda_version = torch.version.cuda
14
- if cuda_version == "11.8":
15
- vllm_version = "v0.6.1.post1"
16
- pip_cmd = f"pip install https://github.com/vllm-project/vllm/releases/download/{vllm_version}/vllm-{vllm_version}+cu118-cp310-cp310-manylinux1_x86_64.whl --extra-index-url https://download.pytorch.org/whl/cu118"
17
- else:
18
- raise RuntimeError(f"Unsupported CUDA version: {cuda_version}")
19
-
20
- subprocess.check_call([sys.executable, "-m", "pip", "install", pip_cmd])
21
  # Import the necessary modules after installation
22
  from vllm import LLM, SamplingParams
23
  from vllm.utils import random_uuid
 
9
try:
    import vllm
except ImportError:
    # vllm is not installed: fetch the prebuilt CUDA 11.8 wheel straight from
    # the GitHub release, then fall through to the imports below.
    #
    # BUG FIX: the GitHub release *tag* has a leading "v" (v0.6.1.post1), but
    # the wheel *filename* uses the bare PEP 427 version (vllm-0.6.1.post1+...).
    # Reusing the tag in both places built a URL that 404s, so the fallback
    # install could never succeed. Keep the two forms separate.
    vllm_tag = "v0.6.1.post1"                   # GitHub release tag (with "v")
    vllm_version = vllm_tag.removeprefix("v")   # bare version used in the wheel filename
    # NOTE(review): this wheel is pinned to CPython 3.10 (cp310) and CUDA 11.8
    # (cu118) — confirm the runtime matches before relying on this path.
    pip_cmd = [
        sys.executable,              # install into the same interpreter that is running
        "-m", "pip", "install",
        f"https://github.com/vllm-project/vllm/releases/download/{vllm_tag}/"
        f"vllm-{vllm_version}+cu118-cp310-cp310-manylinux1_x86_64.whl",
        "--extra-index-url", "https://download.pytorch.org/whl/cu118",
    ]
    # argv list (shell=False) — each option is its own element, nothing is
    # shell-interpreted; raises CalledProcessError if the install fails.
    subprocess.check_call(pip_cmd)

# Import the necessary modules after installation
from vllm import LLM, SamplingParams
from vllm.utils import random_uuid