import sys
import os
import argparse
import time
import subprocess

import llava.serve.gradio_web_server as gws
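
# Entry point for the demo: install flash-attn at runtime, start the LLaVA
# controller plus one model worker per requested model, then launch the Gradio
# web UI defined in llava.serve.gradio_web_server.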

# Install flash-attn at runtime; --no-build-isolation lets its build see the
# already-installed torch, and -U upgrades any preinstalled copy.
subprocess.check_call([sys.executable, '-m', 'pip', 'install', 'flash-attn', '--no-build-isolation', '-U'])


def start_controller():
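    """Launch the llava.serve.controller process that model workers register with."""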
    print("Starting the controller")
    controller_command = [
        sys.executable,
        "-m",
        "llava.serve.controller",
        "--host",
        "0.0.0.0",
        "--port",
        "10000",
    ]
    print(controller_command)
    return subprocess.Popen(controller_command)


def start_worker(model_path: str, bits=16, revision='main', port=21002):
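    """Launch a llava.serve.model_worker serving `model_path` and register it
    with the controller at http://localhost:10000.

    `bits` selects 16-, 8-, or 4-bit loading, `revision` is the Hugging Face
    revision to download, and `port` is the worker's listening port.
    """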
    print(f"Starting the model worker for the model {model_path}")
    model_name = model_path.strip("/").split("/")[-1]
    assert bits in [4, 8, 16], "The model can only be loaded in 16-bit, 8-bit, or 4-bit."
    if bits != 16:
        model_name += f"-{bits}bit"
    worker_command = [
        sys.executable,
        "-m",
        "llava.serve.model_worker",
        "--host",
        "0.0.0.0",
        "--port",
        str(port),
        "--worker-address",
        f"http://127.0.0.1:{port}",
        "--controller",
        "http://localhost:10000",
        "--model-path",
        model_path,
        "--model-name",
        model_name,
        "--use-flash-attn",
        "--revision",
        revision
    ]
    if bits != 16:
        worker_command += [f"--load-{bits}bit"]
    print(worker_command)
    return subprocess.Popen(worker_command)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int)
    parser.add_argument("--controller-url", type=str, default="http://localhost:10000")
    parser.add_argument("--concurrency-count", type=int, default=5)
    parser.add_argument("--model-list-mode", type=str, default="reload", choices=["once", "reload"])
    parser.add_argument("--share", action="store_true")
    parser.add_argument("--moderate", action="store_true")
    parser.add_argument("--embed", action="store_true")
    gws.args = parser.parse_args()
    gws.models = []
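    # gradio_web_server reads these module-level globals (args, models) when it
    # builds and serves the demo.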

    gws.title_markdown += """

ONLY WORKS WITH GPU! By default, the model is loaded with 4-bit quantization so it fits on smaller hardware. Set the environment variable `bits` to control the quantization level.

Set the environment variable `model` to change the model:
[`liuhaotian/llava-v1.6-mistral-7b`](https://huggingface.co/liuhaotian/llava-v1.6-mistral-7b),
[`liuhaotian/llava-v1.6-vicuna-7b`](https://huggingface.co/liuhaotian/llava-v1.6-vicuna-7b),
[`liuhaotian/llava-v1.6-vicuna-13b`](https://huggingface.co/liuhaotian/llava-v1.6-vicuna-13b),
[`liuhaotian/llava-v1.6-34b`](https://huggingface.co/liuhaotian/llava-v1.6-34b).
"""

    print(f"args: {gws.args}")

    model_paths = os.getenv("model", "nvandal/LLaVA-Med-v1.5-7b")
    revisions = os.getenv("revision", "main")
    bits = int(os.getenv("bits", 4))
    concurrency_count = int(os.getenv("concurrency_count", 5))

    controller_proc = start_controller()
    start_worker_port = 21002
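    # Workers are assigned consecutive ports starting at 21002, one per model.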

    model_paths = model_paths.split(';')
    revisions = revisions.split(';')
    assert len(model_paths) == len(revisions), "Provide one revision per model path."
    worker_proc = [None] * len(model_paths)
    for i, (model_path, revision) in enumerate(zip(model_paths, revisions)):
        print(model_path, revision)
        worker_proc[i] = start_worker(model_path, bits=bits, revision=revision, port=start_worker_port + i)

    # Give the controller and workers time to start before launching the web UI.
    time.sleep(10)

    exit_status = 0
    try:
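        # Build the Gradio UI from gradio_web_server, throttle concurrent
        # requests with queue(), and serve it on the configured host and port.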
        demo = gws.build_demo(embed_mode=False, cur_dir='./', concurrency_count=concurrency_count)
        demo.queue(
            status_update_rate=10,
            api_open=False
        ).launch(
            server_name=gws.args.host,
            server_port=gws.args.port,
            share=gws.args.share
        )

    except Exception as e:
        print(e)
        exit_status = 1
    finally:
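        # Always shut down the worker and controller subprocesses, even if the
        # Gradio app failed to launch or crashed.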
        for w in worker_proc:
            w.kill()
        controller_proc.kill()

        sys.exit(exit_status)