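"""Launcher for the LlavaGuard Gradio demo.

Starts the LLaVA controller and a model worker as subprocesses and then serves
the Gradio web UI from `gradio_web_server`. The served model, quantization
bits, Hugging Face token, and concurrency count are read from the environment
variables `model`, `bits`, `token`, and `concurrency_count`.
"""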
import sys
import os
import argparse
import time
import subprocess

import gradio_web_server as gws

# Optionally install flash-attn at startup (left disabled here):
# subprocess.check_call([sys.executable, '-m', 'pip', 'install', 'flash-attn', '--no-build-isolation', '-U'])


def start_controller():
    print("Starting the controller")
    controller_command = [
        sys.executable,
        "-m",
        "llava.serve.controller",
        "--host",
        "0.0.0.0",
        "--port",
        "10000",
    ]
    print(controller_command)
    return subprocess.Popen(controller_command)


def start_worker(model_path: str, model_name: str, bits=16, device=0):
    print(f"Starting the model worker for the model {model_path}")
    # model_name = model_path.strip("/").split("/")[-1]
    device = f"cuda:{device}" if isinstance(device, int) else device
    assert bits in [4, 8, 16], "The model can only be loaded in 4-, 8-, or 16-bit precision."
    if bits != 16:
        model_name += f"-{bits}bit"
    worker_command = [
        sys.executable,
        "-m",
        "llava.serve.model_worker",
        "--host",
        "0.0.0.0",
        "--controller",
        "http://localhost:10000",
        "--model-path",
        model_path,
        "--model-name",
        model_name,
        "--use-flash-attn",
        "--device",
        device,
    ]
    if bits != 16:
        worker_command += [f"--load-{bits}bit"]
    print(worker_command)
    # Note: stdout/stderr are piped here but never read by the parent process.
    return subprocess.Popen(worker_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)

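# Optional helper (not called below): instead of the fixed `time.sleep(10)` in
# __main__, the launcher could poll the controller until the worker has
# registered. This is a sketch; it assumes the controller exposes the POST
# /list_models endpoint used by the upstream LLaVA serving code and that the
# `requests` package is available.
def wait_for_worker(model_name: str,
                    controller_url: str = "http://localhost:10000",
                    timeout: float = 120.0,
                    interval: float = 2.0) -> bool:
    import requests

    deadline = time.time() + timeout
    while time.time() < deadline:
        try:
            ret = requests.post(controller_url + "/list_models", timeout=5)
            if model_name in ret.json().get("models", []):
                return True
        except requests.RequestException:
            pass  # controller not reachable yet
        time.sleep(interval)
    return False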

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int)
    parser.add_argument("--controller-url", type=str, default="http://localhost:10000")
    parser.add_argument("--concurrency-count", type=int, default=5)
    parser.add_argument("--model-list-mode", type=str, default="reload", choices=["once", "reload"])
    parser.add_argument("--share", action="store_true")
    parser.add_argument("--moderate", action="store_true")
    parser.add_argument("--embed", action="store_true")
    gws.args = parser.parse_args()
    gws.models = []

    gws.title_markdown += """

ONLY WORKS WITH GPU!

Set the environment variable `model` to change the model:
[AIML-TUDA/LlavaGuard-7B](https://huggingface.co/AIML-TUDA/LlavaGuard-7B),
[AIML-TUDA/LlavaGuard-13B](https://huggingface.co/AIML-TUDA/LlavaGuard-13B),
[AIML-TUDA/LlavaGuard-34B](https://huggingface.co/AIML-TUDA/LlavaGuard-34B),
"""
    # set_up_env_and_token(read=True)
    print(f"args: {gws.args}")
    controller_proc = start_controller()
    concurrency_count = int(os.getenv("concurrency_count", 5))

    # Log in to the Hugging Face Hub: either with a token passed via the
    # `token` environment variable, or via the project-specific helper below.
    api_key = os.getenv("token")
    if api_key:
        cmd = f"huggingface-cli login --token {api_key} --add-to-git-credential"
        os.system(cmd)
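        # A programmatic alternative to shelling out (assuming `huggingface_hub`
        # is installed) would be:
        #   from huggingface_hub import login
        #   login(token=api_key, add_to_git_credential=True)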
    else:
        if '/workspace' not in sys.path:
            sys.path.append('/workspace')
        from llavaguard.hf_utils import set_up_env_and_token
        set_up_env_and_token(read=True, write=False)

    models = [
        'LukasHug/LlavaGuard-7B-hf',
        'LukasHug/LlavaGuard-13B-hf',
        'LukasHug/LlavaGuard-34B-hf',
    ]
    bits = int(os.getenv("bits", 16))
    model = os.getenv("model", models[0])
    available_devices = os.getenv("CUDA_VISIBLE_DEVICES", "0")
    model_path, model_name = model, model.split("/")[-1]

    worker_proc = start_worker(model_path, model_name, bits=bits)

    # Give the controller and worker a few seconds to start and register.
    time.sleep(10)

    exit_status = 0
    try:
        demo = gws.build_demo(embed_mode=False, cur_dir='./', concurrency_count=concurrency_count)
        demo.queue(
            status_update_rate=10,
            api_open=False
        ).launch(
            server_name=gws.args.host,
            server_port=gws.args.port,
            share=gws.args.share
        )

    except Exception as e:
        print(e)
        exit_status = 1
    finally:
        worker_proc.kill()
        controller_proc.kill()

        sys.exit(exit_status)
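# Example invocation (the filename `app.py` is assumed for illustration):
#   model=LukasHug/LlavaGuard-13B-hf bits=8 token=<hf_token> \
#       python app.py --port 7860 --share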