Nick Vandal commited on
Commit
6c420e0
1 Parent(s): 0c371b7

added mupliple models and revisions

Browse files
Files changed (2) hide show
  1. LLaVA +1 -1
  2. app.py +20 -4
LLaVA CHANGED
@@ -1 +1 @@
1
- Subproject commit 3e83206ab58f79936da2742c85c93dfd3890451c
 
1
+ Subproject commit 3c2f6ba15ed0477f4149fd582d2b640e19da2a57
app.py CHANGED
@@ -25,7 +25,7 @@ def start_controller():
25
  return subprocess.Popen(controller_command)
26
 
27
 
28
- def start_worker(model_path: str, bits=16):
29
  print(f"Starting the model worker for the model {model_path}")
30
  model_name = model_path.strip("/").split("/")[-1]
31
  assert bits in [4, 8, 16], "It can be only loaded with 16-bit, 8-bit, and 4-bit."
@@ -37,6 +37,10 @@ def start_worker(model_path: str, bits=16):
37
  "llava.serve.model_worker",
38
  "--host",
39
  "0.0.0.0",
 
 
 
 
40
  "--controller",
41
  "http://localhost:10000",
42
  "--model-path",
@@ -44,6 +48,8 @@ def start_worker(model_path: str, bits=16):
44
  "--model-name",
45
  model_name,
46
  "--use-flash-attn",
 
 
47
  ]
48
  if bits != 16:
49
  worker_command += [f"--load-{bits}bit"]
@@ -77,12 +83,21 @@ Set the environment variable `model` to change the model:
77
 
78
  print(f"args: {gws.args}")
79
 
80
- model_path = os.getenv("model", "liuhaotian/llava-v1.6-mistral-7b")
 
81
  bits = int(os.getenv("bits", 4))
82
  concurrency_count = int(os.getenv("concurrency_count", 5))
83
 
84
  controller_proc = start_controller()
85
- worker_proc = start_worker(model_path, bits=bits)
 
 
 
 
 
 
 
 
86
 
87
  # Wait for worker and controller to start
88
  time.sleep(10)
@@ -103,7 +118,8 @@ Set the environment variable `model` to change the model:
103
  print(e)
104
  exit_status = 1
105
  finally:
106
- worker_proc.kill()
 
107
  controller_proc.kill()
108
 
109
  sys.exit(exit_status)
 
25
  return subprocess.Popen(controller_command)
26
 
27
 
28
+ def start_worker(model_path: str, bits=16, revision='main', port=21002):
29
  print(f"Starting the model worker for the model {model_path}")
30
  model_name = model_path.strip("/").split("/")[-1]
31
  assert bits in [4, 8, 16], "It can be only loaded with 16-bit, 8-bit, and 4-bit."
 
37
  "llava.serve.model_worker",
38
  "--host",
39
  "0.0.0.0",
40
+ "--port",
41
+ port,
42
+ "--worker-address",
43
+ f"http://127.0.0.1:{port}",
44
  "--controller",
45
  "http://localhost:10000",
46
  "--model-path",
 
48
  "--model-name",
49
  model_name,
50
  "--use-flash-attn",
51
+ "--revision",
52
+ revision
53
  ]
54
  if bits != 16:
55
  worker_command += [f"--load-{bits}bit"]
 
83
 
84
  print(f"args: {gws.args}")
85
 
86
+ model_paths = os.getenv("model", "nvandal/LLaVA-Med-v1.5-7b")
87
+ revisions = os.getenv("revision", "main")
88
  bits = int(os.getenv("bits", 4))
89
  concurrency_count = int(os.getenv("concurrency_count", 5))
90
 
91
  controller_proc = start_controller()
92
+ start_worker_port = 21002
93
+
94
+ model_paths = model_paths.split(';')
95
+ revisions = revisions.split(';')
96
+ assert(len(model_paths)==len(revisions))
97
+ worker_proc = [None]*len(model_paths)
98
+ for i, (model_path, revision) in enumerate(zip(model_paths,revisions)):
99
+ print(model_path, revision)
100
+ worker_proc[i] = start_worker(model_path, bits=bits, revision=revision, port=str(start_worker_port+i))
101
 
102
  # Wait for worker and controller to start
103
  time.sleep(10)
 
118
  print(e)
119
  exit_status = 1
120
  finally:
121
+ for w in worker_proc:
122
+ w.kill()
123
  controller_proc.kill()
124
 
125
  sys.exit(exit_status)