Nick Vandal committed on
Commit
d6fd1f0
1 Parent(s): 2ea0ee2

Update to support LoRA, add new examples, update prompt

Browse files
Files changed (4) hide show
  1. LLaVA +1 -1
  2. app.py +12 -5
  3. examples/carpus.jpg +0 -0
  4. examples/lateral_wrist.jpg +0 -0
LLaVA CHANGED
@@ -1 +1 @@
1
- Subproject commit a54459c2dd993c4fd2f571cd16bf73cb8dbdcb00
 
1
+ Subproject commit 30e73a40fe42f392dac3ad9466b3d62e1a40ad07
app.py CHANGED
@@ -25,7 +25,7 @@ def start_controller():
25
  return subprocess.Popen(controller_command)
26
 
27
 
28
- def start_worker(model_path: str, bits=16, revision='main', port=21002):
29
  print(f"Starting the model worker for the model {model_path}")
30
  model_name = model_path.strip("/").split("/")[-1]
31
  assert bits in [4, 8, 16], "It can be only loaded with 16-bit, 8-bit, and 4-bit."
@@ -49,8 +49,13 @@ def start_worker(model_path: str, bits=16, revision='main', port=21002):
49
  model_name,
50
  "--use-flash-attn",
51
  "--revision",
52
- revision
53
  ]
 
 
 
 
 
54
  if bits != 16:
55
  worker_command += [f"--load-{bits}bit"]
56
  print(worker_command)
@@ -84,6 +89,7 @@ Set the environment variable `model` to change the model:
84
  print(f"args: {gws.args}")
85
 
86
  model_paths = os.getenv("model", "nvandal/LLaVA-Med-v1.5-7b")
 
87
  revisions = os.getenv("revision", "main")
88
  bits = int(os.getenv("bits", 4))
89
  concurrency_count = int(os.getenv("concurrency_count", 5))
@@ -93,11 +99,12 @@ Set the environment variable `model` to change the model:
93
 
94
  model_paths = model_paths.split(';')
95
  revisions = revisions.split(';')
 
96
  assert(len(model_paths)==len(revisions))
97
  worker_proc = [None]*len(model_paths)
98
- for i, (model_path, revision) in enumerate(zip(model_paths,revisions)):
99
- print(model_path, revision)
100
- worker_proc[i] = start_worker(model_path, bits=bits, revision=revision, port=str(start_worker_port+i))
101
 
102
  # Wait for worker and controller to start
103
  time.sleep(10)
 
25
  return subprocess.Popen(controller_command)
26
 
27
 
28
+ def start_worker(model_path: str, bits=16, revision='main', model_base = None, port=21002):
29
  print(f"Starting the model worker for the model {model_path}")
30
  model_name = model_path.strip("/").split("/")[-1]
31
  assert bits in [4, 8, 16], "It can be only loaded with 16-bit, 8-bit, and 4-bit."
 
49
  model_name,
50
  "--use-flash-attn",
51
  "--revision",
52
+ revision,
53
  ]
54
+ if model_base:
55
+ worker_command += [
56
+ "--model-base",
57
+ model_base
58
+ ]
59
  if bits != 16:
60
  worker_command += [f"--load-{bits}bit"]
61
  print(worker_command)
 
89
  print(f"args: {gws.args}")
90
 
91
  model_paths = os.getenv("model", "nvandal/LLaVA-Med-v1.5-7b")
92
+ model_base = os.getenv("model_base", '')
93
  revisions = os.getenv("revision", "main")
94
  bits = int(os.getenv("bits", 4))
95
  concurrency_count = int(os.getenv("concurrency_count", 5))
 
99
 
100
  model_paths = model_paths.split(';')
101
  revisions = revisions.split(';')
102
+ model_base = model_base.split(';')
103
  assert(len(model_paths)==len(revisions))
104
  worker_proc = [None]*len(model_paths)
105
+ for i, (model_path, revision, model_base) in enumerate(zip(model_paths,revisions,model_base)):
106
+ print(model_path, revision, model_base)
107
+ worker_proc[i] = start_worker(model_path, bits=bits, revision=revision, model_base=model_base, port=str(start_worker_port+i))
108
 
109
  # Wait for worker and controller to start
110
  time.sleep(10)
examples/carpus.jpg ADDED
examples/lateral_wrist.jpg ADDED