Spaces:
Running
on
Zero
Running
on
Zero
File size: 4,114 Bytes
46282cc 83baad4 b95dd01 46282cc cd043e1 46282cc 83baad4 46282cc 83baad4 46282cc cd043e1 46282cc 4d068e0 83baad4 46282cc 83baad4 46282cc b95dd01 46282cc 83baad4 46282cc ebf946b 46282cc 83baad4 46282cc a74784f 46282cc 83baad4 4d068e0 46282cc 4d068e0 46282cc a74784f 46282cc 4d068e0 83baad4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 |
import sys
import os
import argparse
import time
import subprocess
import spaces
import gradio_web_server as gws
# Execute the pip install command with additional options
# subprocess.check_call([sys.executable, '-m', 'pip', 'install', 'flash-attn', '--no-build-isolation', '-U'])
def start_controller():
print("Starting the controller")
controller_command = [
sys.executable,
"-m",
"llava.serve.controller",
"--host",
"0.0.0.0",
"--port",
"10000",
]
print(controller_command)
return subprocess.Popen(controller_command)
@spaces.GPU
def start_worker(model_path: str, model_name: str, bits=16, device=0):
    """Spawn a llava.serve.model_worker child process for *model_path*.

    The worker registers itself with the controller at
    http://localhost:10000. An integer *device* is translated to a
    "cuda:N" string; quantized loading (4/8-bit) appends a --load-Nbit
    flag and tags the model name with the bit width.

    Returns the subprocess.Popen handle with stdout/stderr piped.
    """
    print(f"Starting the model worker for the model {model_path}")
    # model_name = model_path.strip("/").split("/")[-1]
    if isinstance(device, int):
        device = f"cuda:{device}"
    assert bits in [4, 8, 16], "It can be only loaded with 16-bit, 8-bit, and 4-bit."
    quantized = bits != 16
    if quantized:
        # Tag the advertised model name so quantized variants are distinguishable.
        model_name = f"{model_name}-{bits}bit"
    worker_command = [
        sys.executable, "-m", "llava.serve.model_worker",
        "--host", "0.0.0.0",
        "--controller", "http://localhost:10000",
        "--model-path", model_path,
        "--model-name", model_name,
        "--use-flash-attn",
        "--device", device,
    ]
    if quantized:
        worker_command.append(f"--load-{bits}bit")
    print(worker_command)
    return subprocess.Popen(worker_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int)
    parser.add_argument("--controller-url", type=str, default="http://localhost:10000")
    parser.add_argument("--concurrency-count", type=int, default=5)
    parser.add_argument("--model-list-mode", type=str, default="reload", choices=["once", "reload"])
    parser.add_argument("--share", action="store_true")
    parser.add_argument("--moderate", action="store_true")
    parser.add_argument("--embed", action="store_true")
    gws.args = parser.parse_args()
    gws.models = []

    gws.title_markdown += """
ONLY WORKS WITH GPU!
Set the environment variable `model` to change the model:
['AIML-TUDA/LlavaGuard-7B'](https://huggingface.co/AIML-TUDA/LlavaGuard-7B),
['AIML-TUDA/LlavaGuard-13B'](https://huggingface.co/AIML-TUDA/LlavaGuard-13B),
['AIML-TUDA/LlavaGuard-34B'](https://huggingface.co/AIML-TUDA/LlavaGuard-34B),
"""
    # set_up_env_and_token(read=True)
    print(f"args: {gws.args}")

    # Start the controller first so the worker has something to register with.
    controller_proc = start_controller()

    # NOTE: the env var intentionally overrides the --concurrency-count CLI arg.
    concurrency_count = int(os.getenv("concurrency_count", 5))

    # Hugging Face login: prefer an explicit token from the environment.
    api_key = os.getenv("token")
    if api_key:
        # SECURITY: pass argv as a list (shell=False) so the token value can
        # never be interpreted by a shell (the previous os.system f-string
        # was vulnerable to injection via the `token` env var).
        subprocess.run(
            ["huggingface-cli", "login", "--token", api_key, "--add-to-git-credential"]
        )
    else:
        # Fall back to a project-local helper available under /workspace.
        if '/workspace' not in sys.path:
            sys.path.append('/workspace')
        from llavaguard.hf_utils import set_up_env_and_token
        set_up_env_and_token(read=True, write=False)

    models = [
        'LukasHug/LlavaGuard-7B-hf',
        'LukasHug/LlavaGuard-13B-hf',
        'LukasHug/LlavaGuard-34B-hf',
    ]
    bits = int(os.getenv("bits", 16))
    model = os.getenv("model", models[0])
    # rsplit is robust against a `model` value without a "/" (plain split("/")[1]
    # would raise IndexError); identical result for org/name repo ids.
    model_path, model_name = model, model.rsplit("/", 1)[-1]

    worker_proc = start_worker(model_path, model_name, bits=bits)

    # Give the controller and worker time to come up before serving the UI.
    time.sleep(50)

    exit_status = 0
    try:
        demo = gws.build_demo(embed_mode=False, cur_dir='./', concurrency_count=concurrency_count)
        demo.queue(
            status_update_rate=10,
            api_open=False
        ).launch(
            server_name=gws.args.host,
            server_port=gws.args.port,
            share=gws.args.share
        )
    except Exception as e:
        print(e)
        exit_status = 1
    finally:
        # Always tear down the child processes, even on launch failure.
        worker_proc.kill()
        controller_proc.kill()
        sys.exit(exit_status)