optimize
Browse files- .DS_Store +0 -0
- configs/.DS_Store +0 -0
- controller.log +8 -0
- model_worker_ad9563.log +0 -17
- star-vector-dev/.DS_Store +0 -0
- star-vector-dev/.gitattributes +35 -0
- star-vector-dev/.gitignore +181 -0
- start.sh +0 -1
- starvector/.DS_Store +0 -0
- starvector/__pycache__/__init__.cpython-311.pyc +0 -0
- starvector/serve/.DS_Store +0 -0
- starvector/serve/__pycache__/__init__.cpython-311.pyc +0 -0
- starvector/serve/__pycache__/constants.cpython-311.pyc +0 -0
- starvector/serve/__pycache__/conversation.cpython-311.pyc +0 -0
- starvector/serve/__pycache__/util.cpython-311.pyc +0 -0
- starvector/serve/controller.py +293 -0
- starvector/serve/gradio_demo_with_updated_gradio.py +432 -0
- starvector/serve/gradio_web_server.py +562 -0
- starvector/serve/model_worker.py +269 -0
- starvector/serve/vllm_api_gradio/__pycache__/controller.cpython-311.pyc +0 -0
- starvector/serve/vllm_api_gradio/__pycache__/gradio_web_server.cpython-311.pyc +0 -0
- starvector/serve/vllm_api_gradio/__pycache__/model_worker.cpython-311.pyc +0 -0
- starvector/serve/vllm_api_gradio/gradio_web_server.py +17 -8
- starvector/serve/vllm_api_gradio/model_worker.py +3 -1
.DS_Store
CHANGED
Binary files a/.DS_Store and b/.DS_Store differ
|
|
configs/.DS_Store
CHANGED
Binary files a/configs/.DS_Store and b/configs/.DS_Store differ
|
|
controller.log
CHANGED
@@ -29,3 +29,11 @@
|
|
29 |
2025-03-23 15:04:32 | ERROR | stderr | [32mINFO[0m: Waiting for application startup.
|
30 |
2025-03-23 15:04:32 | ERROR | stderr | [32mINFO[0m: Application startup complete.
|
31 |
2025-03-23 15:04:32 | ERROR | stderr | [32mINFO[0m: Uvicorn running on [1mhttp://0.0.0.0:10000[0m (Press CTRL+C to quit)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
2025-03-23 15:04:32 | ERROR | stderr | [32mINFO[0m: Waiting for application startup.
|
30 |
2025-03-23 15:04:32 | ERROR | stderr | [32mINFO[0m: Application startup complete.
|
31 |
2025-03-23 15:04:32 | ERROR | stderr | [32mINFO[0m: Uvicorn running on [1mhttp://0.0.0.0:10000[0m (Press CTRL+C to quit)
|
32 |
+
2025-03-24 14:06:11 | INFO | controller | args: Namespace(host='0.0.0.0', port=10000, dispatch_method='shortest_queue')
|
33 |
+
2025-03-24 14:06:11 | INFO | controller | Init controller
|
34 |
+
2025-03-24 14:06:11 | ERROR | stderr | [32mINFO[0m: Started server process [[36m95537[0m]
|
35 |
+
2025-03-24 14:06:11 | ERROR | stderr | [32mINFO[0m: Waiting for application startup.
|
36 |
+
2025-03-24 14:06:11 | ERROR | stderr | [32mINFO[0m: Application startup complete.
|
37 |
+
2025-03-24 14:06:11 | ERROR | stderr | [31mERROR[0m: [Errno 48] error while attempting to bind on address ('0.0.0.0', 10000): address already in use
|
38 |
+
2025-03-24 14:06:11 | ERROR | stderr | [32mINFO[0m: Waiting for application shutdown.
|
39 |
+
2025-03-24 14:06:11 | ERROR | stderr | [32mINFO[0m: Application shutdown complete.
|
model_worker_ad9563.log
DELETED
@@ -1,17 +0,0 @@
|
|
1 |
-
2025-03-23 15:01:04 | INFO | model_worker | args: Namespace(host='0.0.0.0', port=40000, worker_address='http://localhost:40000', controller_address='http://localhost:10000', model_name='/home/agent_h/data/starvector-1b-im2svg', multi_modal=False, limit_model_concurrency=5, stream_interval=1, no_register=False, openai_api_key='EMPTY', vllm_base_url='http://localhost:8000')
|
2 |
-
2025-03-23 15:01:04 | INFO | model_worker | Loading the model /home/agent_h/data/starvector-1b-im2svg on worker ad9563 ...
|
3 |
-
2025-03-23 15:01:04 | INFO | model_worker | Register to controller
|
4 |
-
2025-03-23 15:01:04 | ERROR | stderr | [32mINFO[0m: Started server process [[36m48407[0m]
|
5 |
-
2025-03-23 15:01:04 | ERROR | stderr | [32mINFO[0m: Waiting for application startup.
|
6 |
-
2025-03-23 15:01:04 | ERROR | stderr | [32mINFO[0m: Application startup complete.
|
7 |
-
2025-03-23 15:01:04 | ERROR | stderr | [32mINFO[0m: Uvicorn running on [1mhttp://0.0.0.0:40000[0m (Press CTRL+C to quit)
|
8 |
-
2025-03-23 15:01:19 | INFO | model_worker | Send heart beat. Models: ['/home/agent_h/data/starvector-1b-im2svg']. Semaphore: None. global_counter: 0
|
9 |
-
2025-03-23 15:01:34 | INFO | model_worker | Send heart beat. Models: ['/home/agent_h/data/starvector-1b-im2svg']. Semaphore: None. global_counter: 0
|
10 |
-
2025-03-23 15:01:49 | INFO | model_worker | Send heart beat. Models: ['/home/agent_h/data/starvector-1b-im2svg']. Semaphore: None. global_counter: 0
|
11 |
-
2025-03-23 15:02:04 | INFO | model_worker | Send heart beat. Models: ['/home/agent_h/data/starvector-1b-im2svg']. Semaphore: None. global_counter: 0
|
12 |
-
2025-03-23 15:02:19 | INFO | model_worker | Send heart beat. Models: ['/home/agent_h/data/starvector-1b-im2svg']. Semaphore: None. global_counter: 0
|
13 |
-
2025-03-23 15:02:34 | INFO | model_worker | Send heart beat. Models: ['/home/agent_h/data/starvector-1b-im2svg']. Semaphore: None. global_counter: 0
|
14 |
-
2025-03-23 15:02:45 | ERROR | stderr | [32mINFO[0m: Shutting down
|
15 |
-
2025-03-23 15:02:45 | ERROR | stderr | [32mINFO[0m: Waiting for application shutdown.
|
16 |
-
2025-03-23 15:02:45 | ERROR | stderr | [32mINFO[0m: Application shutdown complete.
|
17 |
-
2025-03-23 15:02:45 | ERROR | stderr | [32mINFO[0m: Finished server process [[36m48407[0m]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
star-vector-dev/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
star-vector-dev/.gitattributes
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
star-vector-dev/.gitignore
ADDED
@@ -0,0 +1,181 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Byte-compiled / optimized / DLL files
|
2 |
+
__pycache__/
|
3 |
+
*.py[cod]
|
4 |
+
*$py.class
|
5 |
+
|
6 |
+
# C extensions
|
7 |
+
*.so
|
8 |
+
|
9 |
+
# Distribution / packaging
|
10 |
+
.Python
|
11 |
+
build/
|
12 |
+
develop-eggs/
|
13 |
+
dist/
|
14 |
+
downloads/
|
15 |
+
eggs/
|
16 |
+
.eggs/
|
17 |
+
lib/
|
18 |
+
lib64/
|
19 |
+
parts/
|
20 |
+
sdist/
|
21 |
+
var/
|
22 |
+
wheels/
|
23 |
+
share/python-wheels/
|
24 |
+
*.egg-info/
|
25 |
+
.installed.cfg
|
26 |
+
*.egg
|
27 |
+
MANIFEST
|
28 |
+
|
29 |
+
# PyInstaller
|
30 |
+
# Usually these files are written by a python script from a template
|
31 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
32 |
+
*.manifest
|
33 |
+
*.spec
|
34 |
+
|
35 |
+
# Installer logs
|
36 |
+
pip-log.txt
|
37 |
+
pip-delete-this-directory.txt
|
38 |
+
|
39 |
+
# Unit test / coverage reports
|
40 |
+
htmlcov/
|
41 |
+
.tox/
|
42 |
+
.nox/
|
43 |
+
.coverage
|
44 |
+
.coverage.*
|
45 |
+
.cache
|
46 |
+
nosetests.xml
|
47 |
+
coverage.xml
|
48 |
+
*.cover
|
49 |
+
*.py,cover
|
50 |
+
.hypothesis/
|
51 |
+
.pytest_cache/
|
52 |
+
cover/
|
53 |
+
|
54 |
+
# Translations
|
55 |
+
*.mo
|
56 |
+
*.pot
|
57 |
+
|
58 |
+
# Django stuff:
|
59 |
+
*.log
|
60 |
+
local_settings.py
|
61 |
+
db.sqlite3
|
62 |
+
db.sqlite3-journal
|
63 |
+
|
64 |
+
# Flask stuff:
|
65 |
+
instance/
|
66 |
+
.webassets-cache
|
67 |
+
|
68 |
+
# Scrapy stuff:
|
69 |
+
.scrapy
|
70 |
+
|
71 |
+
# Sphinx documentation
|
72 |
+
docs/_build/
|
73 |
+
|
74 |
+
# PyBuilder
|
75 |
+
.pybuilder/
|
76 |
+
target/
|
77 |
+
|
78 |
+
# Jupyter Notebook
|
79 |
+
.ipynb_checkpoints
|
80 |
+
|
81 |
+
# IPython
|
82 |
+
profile_default/
|
83 |
+
ipython_config.py
|
84 |
+
|
85 |
+
# pyenv
|
86 |
+
# For a library or package, you might want to ignore these files since the code is
|
87 |
+
# intended to run in multiple environments; otherwise, check them in:
|
88 |
+
# .python-version
|
89 |
+
|
90 |
+
# pipenv
|
91 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
92 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
93 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
94 |
+
# install all needed dependencies.
|
95 |
+
#Pipfile.lock
|
96 |
+
|
97 |
+
# poetry
|
98 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
99 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
100 |
+
# commonly ignored for libraries.
|
101 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
102 |
+
#poetry.lock
|
103 |
+
|
104 |
+
# pdm
|
105 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
106 |
+
#pdm.lock
|
107 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
108 |
+
# in version control.
|
109 |
+
# https://pdm.fming.dev/#use-with-ide
|
110 |
+
.pdm.toml
|
111 |
+
|
112 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
113 |
+
__pypackages__/
|
114 |
+
|
115 |
+
# Celery stuff
|
116 |
+
celerybeat-schedule
|
117 |
+
celerybeat.pid
|
118 |
+
|
119 |
+
# SageMath parsed files
|
120 |
+
*.sage.py
|
121 |
+
|
122 |
+
# Environments
|
123 |
+
.env
|
124 |
+
.venv
|
125 |
+
env/
|
126 |
+
venv/
|
127 |
+
ENV/
|
128 |
+
env.bak/
|
129 |
+
venv.bak/
|
130 |
+
|
131 |
+
# Spyder project settings
|
132 |
+
.spyderproject
|
133 |
+
.spyproject
|
134 |
+
|
135 |
+
# Rope project settings
|
136 |
+
.ropeproject
|
137 |
+
|
138 |
+
# mkdocs documentation
|
139 |
+
/site
|
140 |
+
|
141 |
+
# mypy
|
142 |
+
.mypy_cache/
|
143 |
+
.dmypy.json
|
144 |
+
dmypy.json
|
145 |
+
|
146 |
+
# Pyre type checker
|
147 |
+
.pyre/
|
148 |
+
|
149 |
+
# pytype static type analyzer
|
150 |
+
.pytype/
|
151 |
+
|
152 |
+
# Cython debug symbols
|
153 |
+
cython_debug/
|
154 |
+
|
155 |
+
# PyCharm
|
156 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
157 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
158 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
159 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
160 |
+
#.idea/
|
161 |
+
|
162 |
+
# Other
|
163 |
+
*vscode*
|
164 |
+
*egg*
|
165 |
+
*nfs*
|
166 |
+
*conv.json*
|
167 |
+
*rebuttal*
|
168 |
+
*.log*
|
169 |
+
*remove_files*
|
170 |
+
*wandb*
|
171 |
+
*tmp*
|
172 |
+
*vscode*
|
173 |
+
*.csv
|
174 |
+
*avoid_samples*
|
175 |
+
*logs*
|
176 |
+
*results*
|
177 |
+
*.pickle
|
178 |
+
*.pkl
|
179 |
+
*internal*
|
180 |
+
*test.png*
|
181 |
+
assets/reward_assets
|
start.sh
CHANGED
@@ -2,7 +2,6 @@
|
|
2 |
|
3 |
bash -c "$SSH_TUNNEL_CMD_1" &
|
4 |
|
5 |
-
echo "SSH tunnel started, PID: $SSH_PID"
|
6 |
python -m starvector.serve.vllm_api_gradio.controller --host 0.0.0.0 --port 10000 &
|
7 |
python -m starvector.serve.vllm_api_gradio.model_worker --host 0.0.0.0 --controller http://localhost:10000 --port 40000 --worker http://localhost:40000 --model-name /home/agent_h/data/starvector-1b-im2svg --vllm-base-url http://localhost:8000 &
|
8 |
python -m starvector.serve.vllm_api_gradio.gradio_web_server --controller http://localhost:10000 --model-list-mode reload --port 7860
|
|
|
2 |
|
3 |
bash -c "$SSH_TUNNEL_CMD_1" &
|
4 |
|
|
|
5 |
python -m starvector.serve.vllm_api_gradio.controller --host 0.0.0.0 --port 10000 &
|
6 |
python -m starvector.serve.vllm_api_gradio.model_worker --host 0.0.0.0 --controller http://localhost:10000 --port 40000 --worker http://localhost:40000 --model-name /home/agent_h/data/starvector-1b-im2svg --vllm-base-url http://localhost:8000 &
|
7 |
python -m starvector.serve.vllm_api_gradio.gradio_web_server --controller http://localhost:10000 --model-list-mode reload --port 7860
|
starvector/.DS_Store
CHANGED
Binary files a/starvector/.DS_Store and b/starvector/.DS_Store differ
|
|
starvector/__pycache__/__init__.cpython-311.pyc
CHANGED
Binary files a/starvector/__pycache__/__init__.cpython-311.pyc and b/starvector/__pycache__/__init__.cpython-311.pyc differ
|
|
starvector/serve/.DS_Store
CHANGED
Binary files a/starvector/serve/.DS_Store and b/starvector/serve/.DS_Store differ
|
|
starvector/serve/__pycache__/__init__.cpython-311.pyc
CHANGED
Binary files a/starvector/serve/__pycache__/__init__.cpython-311.pyc and b/starvector/serve/__pycache__/__init__.cpython-311.pyc differ
|
|
starvector/serve/__pycache__/constants.cpython-311.pyc
CHANGED
Binary files a/starvector/serve/__pycache__/constants.cpython-311.pyc and b/starvector/serve/__pycache__/constants.cpython-311.pyc differ
|
|
starvector/serve/__pycache__/conversation.cpython-311.pyc
CHANGED
Binary files a/starvector/serve/__pycache__/conversation.cpython-311.pyc and b/starvector/serve/__pycache__/conversation.cpython-311.pyc differ
|
|
starvector/serve/__pycache__/util.cpython-311.pyc
CHANGED
Binary files a/starvector/serve/__pycache__/util.cpython-311.pyc and b/starvector/serve/__pycache__/util.cpython-311.pyc differ
|
|
starvector/serve/controller.py
ADDED
@@ -0,0 +1,293 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
A controller manages distributed workers.
|
3 |
+
It sends worker addresses to clients.
|
4 |
+
"""
|
5 |
+
import argparse
|
6 |
+
import asyncio
|
7 |
+
import dataclasses
|
8 |
+
from enum import Enum, auto
|
9 |
+
import json
|
10 |
+
import logging
|
11 |
+
import time
|
12 |
+
from typing import List, Union
|
13 |
+
import threading
|
14 |
+
|
15 |
+
from fastapi import FastAPI, Request
|
16 |
+
from fastapi.responses import StreamingResponse
|
17 |
+
import numpy as np
|
18 |
+
import requests
|
19 |
+
import uvicorn
|
20 |
+
|
21 |
+
from starvector.serve.constants import CONTROLLER_HEART_BEAT_EXPIRATION
|
22 |
+
from starvector.serve.util import build_logger, server_error_msg
|
23 |
+
|
24 |
+
logger = build_logger("controller", "controller.log")
|
25 |
+
|
26 |
+
class DispatchMethod(Enum):
|
27 |
+
LOTTERY = auto()
|
28 |
+
SHORTEST_QUEUE = auto()
|
29 |
+
|
30 |
+
@classmethod
|
31 |
+
def from_str(cls, name):
|
32 |
+
if name == "lottery":
|
33 |
+
return cls.LOTTERY
|
34 |
+
elif name == "shortest_queue":
|
35 |
+
return cls.SHORTEST_QUEUE
|
36 |
+
else:
|
37 |
+
raise ValueError(f"Invalid dispatch method")
|
38 |
+
|
39 |
+
|
40 |
+
@dataclasses.dataclass
|
41 |
+
class WorkerInfo:
|
42 |
+
model_names: List[str]
|
43 |
+
speed: int
|
44 |
+
queue_length: int
|
45 |
+
check_heart_beat: bool
|
46 |
+
last_heart_beat: str
|
47 |
+
|
48 |
+
|
49 |
+
def heart_beat_controller(controller):
|
50 |
+
while True:
|
51 |
+
time.sleep(CONTROLLER_HEART_BEAT_EXPIRATION)
|
52 |
+
controller.remove_stable_workers_by_expiration()
|
53 |
+
|
54 |
+
|
55 |
+
class Controller:
|
56 |
+
def __init__(self, dispatch_method: str):
|
57 |
+
# Dict[str -> WorkerInfo]
|
58 |
+
self.worker_info = {}
|
59 |
+
self.dispatch_method = DispatchMethod.from_str(dispatch_method)
|
60 |
+
|
61 |
+
self.heart_beat_thread = threading.Thread(
|
62 |
+
target=heart_beat_controller, args=(self,))
|
63 |
+
self.heart_beat_thread.start()
|
64 |
+
|
65 |
+
logger.info("Init controller")
|
66 |
+
|
67 |
+
def register_worker(self, worker_name: str, check_heart_beat: bool,
|
68 |
+
worker_status: dict):
|
69 |
+
if worker_name not in self.worker_info:
|
70 |
+
logger.info(f"Register a new worker: {worker_name}")
|
71 |
+
else:
|
72 |
+
logger.info(f"Register an existing worker: {worker_name}")
|
73 |
+
|
74 |
+
if not worker_status:
|
75 |
+
worker_status = self.get_worker_status(worker_name)
|
76 |
+
if not worker_status:
|
77 |
+
return False
|
78 |
+
|
79 |
+
self.worker_info[worker_name] = WorkerInfo(
|
80 |
+
worker_status["model_names"], worker_status["speed"], worker_status["queue_length"],
|
81 |
+
check_heart_beat, time.time())
|
82 |
+
|
83 |
+
logger.info(f"Register done: {worker_name}, {worker_status}")
|
84 |
+
return True
|
85 |
+
|
86 |
+
def get_worker_status(self, worker_name: str):
|
87 |
+
try:
|
88 |
+
r = requests.post(worker_name + "/worker_get_status", timeout=5)
|
89 |
+
except requests.exceptions.RequestException as e:
|
90 |
+
logger.error(f"Get status fails: {worker_name}, {e}")
|
91 |
+
return None
|
92 |
+
|
93 |
+
if r.status_code != 200:
|
94 |
+
logger.error(f"Get status fails: {worker_name}, {r}")
|
95 |
+
return None
|
96 |
+
|
97 |
+
return r.json()
|
98 |
+
|
99 |
+
def remove_worker(self, worker_name: str):
|
100 |
+
del self.worker_info[worker_name]
|
101 |
+
|
102 |
+
def refresh_all_workers(self):
|
103 |
+
old_info = dict(self.worker_info)
|
104 |
+
self.worker_info = {}
|
105 |
+
|
106 |
+
for w_name, w_info in old_info.items():
|
107 |
+
if not self.register_worker(w_name, w_info.check_heart_beat, None):
|
108 |
+
logger.info(f"Remove stale worker: {w_name}")
|
109 |
+
|
110 |
+
def list_models(self):
|
111 |
+
model_names = set()
|
112 |
+
|
113 |
+
for w_name, w_info in self.worker_info.items():
|
114 |
+
model_names.update(w_info.model_names)
|
115 |
+
|
116 |
+
return list(model_names)
|
117 |
+
|
118 |
+
def get_worker_address(self, model_name: str):
|
119 |
+
if self.dispatch_method == DispatchMethod.LOTTERY:
|
120 |
+
worker_names = []
|
121 |
+
worker_speeds = []
|
122 |
+
for w_name, w_info in self.worker_info.items():
|
123 |
+
if model_name in w_info.model_names:
|
124 |
+
worker_names.append(w_name)
|
125 |
+
worker_speeds.append(w_info.speed)
|
126 |
+
worker_speeds = np.array(worker_speeds, dtype=np.float32)
|
127 |
+
norm = np.sum(worker_speeds)
|
128 |
+
if norm < 1e-4:
|
129 |
+
return ""
|
130 |
+
worker_speeds = worker_speeds / norm
|
131 |
+
if True: # Directly return address
|
132 |
+
pt = np.random.choice(np.arange(len(worker_names)),
|
133 |
+
p=worker_speeds)
|
134 |
+
worker_name = worker_names[pt]
|
135 |
+
return worker_name
|
136 |
+
|
137 |
+
# Check status before returning
|
138 |
+
while True:
|
139 |
+
pt = np.random.choice(np.arange(len(worker_names)),
|
140 |
+
p=worker_speeds)
|
141 |
+
worker_name = worker_names[pt]
|
142 |
+
|
143 |
+
if self.get_worker_status(worker_name):
|
144 |
+
break
|
145 |
+
else:
|
146 |
+
self.remove_worker(worker_name)
|
147 |
+
worker_speeds[pt] = 0
|
148 |
+
norm = np.sum(worker_speeds)
|
149 |
+
if norm < 1e-4:
|
150 |
+
return ""
|
151 |
+
worker_speeds = worker_speeds / norm
|
152 |
+
continue
|
153 |
+
return worker_name
|
154 |
+
elif self.dispatch_method == DispatchMethod.SHORTEST_QUEUE:
|
155 |
+
worker_names = []
|
156 |
+
worker_qlen = []
|
157 |
+
for w_name, w_info in self.worker_info.items():
|
158 |
+
if model_name in w_info.model_names:
|
159 |
+
worker_names.append(w_name)
|
160 |
+
worker_qlen.append(w_info.queue_length / w_info.speed)
|
161 |
+
if len(worker_names) == 0:
|
162 |
+
return ""
|
163 |
+
min_index = np.argmin(worker_qlen)
|
164 |
+
w_name = worker_names[min_index]
|
165 |
+
self.worker_info[w_name].queue_length += 1
|
166 |
+
logger.info(f"names: {worker_names}, queue_lens: {worker_qlen}, ret: {w_name}")
|
167 |
+
return w_name
|
168 |
+
else:
|
169 |
+
raise ValueError(f"Invalid dispatch method: {self.dispatch_method}")
|
170 |
+
|
171 |
+
def receive_heart_beat(self, worker_name: str, queue_length: int):
|
172 |
+
if worker_name not in self.worker_info:
|
173 |
+
logger.info(f"Receive unknown heart beat. {worker_name}")
|
174 |
+
return False
|
175 |
+
|
176 |
+
self.worker_info[worker_name].queue_length = queue_length
|
177 |
+
self.worker_info[worker_name].last_heart_beat = time.time()
|
178 |
+
logger.info(f"Receive heart beat. {worker_name}")
|
179 |
+
return True
|
180 |
+
|
181 |
+
def remove_stable_workers_by_expiration(self):
|
182 |
+
expire = time.time() - CONTROLLER_HEART_BEAT_EXPIRATION
|
183 |
+
to_delete = []
|
184 |
+
for worker_name, w_info in self.worker_info.items():
|
185 |
+
if w_info.check_heart_beat and w_info.last_heart_beat < expire:
|
186 |
+
to_delete.append(worker_name)
|
187 |
+
|
188 |
+
for worker_name in to_delete:
|
189 |
+
self.remove_worker(worker_name)
|
190 |
+
|
191 |
+
def worker_api_generate_stream(self, params):
|
192 |
+
worker_addr = self.get_worker_address(params["model"])
|
193 |
+
if not worker_addr:
|
194 |
+
logger.info(f"no worker: {params['model']}")
|
195 |
+
ret = {
|
196 |
+
"text": server_error_msg,
|
197 |
+
"error_code": 2,
|
198 |
+
}
|
199 |
+
yield json.dumps(ret).encode() + b"\0"
|
200 |
+
|
201 |
+
try:
|
202 |
+
response = requests.post(worker_addr + "/worker_generate_stream",
|
203 |
+
json=params, stream=True, timeout=5)
|
204 |
+
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
|
205 |
+
if chunk:
|
206 |
+
yield chunk + b"\0"
|
207 |
+
except requests.exceptions.RequestException as e:
|
208 |
+
logger.info(f"worker timeout: {worker_addr}")
|
209 |
+
ret = {
|
210 |
+
"text": server_error_msg,
|
211 |
+
"error_code": 3,
|
212 |
+
}
|
213 |
+
yield json.dumps(ret).encode() + b"\0"
|
214 |
+
|
215 |
+
|
216 |
+
# Let the controller act as a worker to achieve hierarchical
|
217 |
+
# management. This can be used to connect isolated sub networks.
|
218 |
+
def worker_api_get_status(self):
|
219 |
+
model_names = set()
|
220 |
+
speed = 0
|
221 |
+
queue_length = 0
|
222 |
+
|
223 |
+
for w_name in self.worker_info:
|
224 |
+
worker_status = self.get_worker_status(w_name)
|
225 |
+
if worker_status is not None:
|
226 |
+
model_names.update(worker_status["model_names"])
|
227 |
+
speed += worker_status["speed"]
|
228 |
+
queue_length += worker_status["queue_length"]
|
229 |
+
|
230 |
+
return {
|
231 |
+
"model_names": list(model_names),
|
232 |
+
"speed": speed,
|
233 |
+
"queue_length": queue_length,
|
234 |
+
}
|
235 |
+
|
236 |
+
|
237 |
+
app = FastAPI()
|
238 |
+
|
239 |
+
@app.post("/register_worker")
|
240 |
+
async def register_worker(request: Request):
|
241 |
+
data = await request.json()
|
242 |
+
controller.register_worker(
|
243 |
+
data["worker_name"], data["check_heart_beat"],
|
244 |
+
data.get("worker_status", None))
|
245 |
+
|
246 |
+
@app.post("/refresh_all_workers")
|
247 |
+
async def refresh_all_workers():
|
248 |
+
models = controller.refresh_all_workers()
|
249 |
+
|
250 |
+
|
251 |
+
@app.post("/list_models")
|
252 |
+
async def list_models():
|
253 |
+
models = controller.list_models()
|
254 |
+
return {"models": models}
|
255 |
+
|
256 |
+
|
257 |
+
@app.post("/get_worker_address")
|
258 |
+
async def get_worker_address(request: Request):
|
259 |
+
data = await request.json()
|
260 |
+
addr = controller.get_worker_address(data["model"])
|
261 |
+
return {"address": addr}
|
262 |
+
|
263 |
+
@app.post("/receive_heart_beat")
|
264 |
+
async def receive_heart_beat(request: Request):
|
265 |
+
data = await request.json()
|
266 |
+
exist = controller.receive_heart_beat(
|
267 |
+
data["worker_name"], data["queue_length"])
|
268 |
+
return {"exist": exist}
|
269 |
+
|
270 |
+
|
271 |
+
@app.post("/worker_generate_stream")
|
272 |
+
async def worker_api_generate_stream(request: Request):
|
273 |
+
params = await request.json()
|
274 |
+
generator = controller.worker_api_generate_stream(params)
|
275 |
+
return StreamingResponse(generator)
|
276 |
+
|
277 |
+
|
278 |
+
@app.post("/worker_get_status")
|
279 |
+
async def worker_api_get_status(request: Request):
|
280 |
+
return controller.worker_api_get_status()
|
281 |
+
|
282 |
+
|
283 |
+
if __name__ == "__main__":
|
284 |
+
parser = argparse.ArgumentParser()
|
285 |
+
parser.add_argument("--host", type=str, default="localhost")
|
286 |
+
parser.add_argument("--port", type=int, default=21001)
|
287 |
+
parser.add_argument("--dispatch-method", type=str, choices=[
|
288 |
+
"lottery", "shortest_queue"], default="shortest_queue")
|
289 |
+
args = parser.parse_args()
|
290 |
+
logger.info(f"args: {args}")
|
291 |
+
|
292 |
+
controller = Controller(args.dispatch_method)
|
293 |
+
uvicorn.run(app, host=args.host, port=args.port, log_level="info")
|
starvector/serve/gradio_demo_with_updated_gradio.py
ADDED
@@ -0,0 +1,432 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import datetime
|
3 |
+
import json
|
4 |
+
import os
|
5 |
+
import time
|
6 |
+
import gradio as gr
|
7 |
+
import requests
|
8 |
+
from starvector.serve.conversation import default_conversation
|
9 |
+
from starvector.serve.constants import LOGDIR, CLIP_QUERY_LENGTH
|
10 |
+
from starvector.serve.util import (build_logger, server_error_msg)
|
11 |
+
|
12 |
+
logger = build_logger("gradio_web_server", "gradio_web_server.log")
|
13 |
+
headers = {"User-Agent": "StarVector Client"}
|
14 |
+
|
15 |
+
no_change_btn = gr.Button()
|
16 |
+
enable_btn = gr.Button(interactive=True)
|
17 |
+
disable_btn = gr.Button(interactive=False)
|
18 |
+
|
19 |
+
priority = {
|
20 |
+
"starvector-1.4b": "aaaaaaa",
|
21 |
+
}
|
22 |
+
|
23 |
+
def get_conv_log_filename():
|
24 |
+
t = datetime.datetime.now()
|
25 |
+
name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
|
26 |
+
return name
|
27 |
+
|
28 |
+
def get_model_list():
|
29 |
+
ret = requests.post(args.controller_url + "/refresh_all_workers")
|
30 |
+
assert ret.status_code == 200
|
31 |
+
ret = requests.post(args.controller_url + "/list_models")
|
32 |
+
models = ret.json()["models"]
|
33 |
+
models.sort(key=lambda x: priority.get(x, x))
|
34 |
+
logger.info(f"Models: {models}")
|
35 |
+
return models
|
36 |
+
|
37 |
+
get_window_url_params = """
|
38 |
+
function() {
|
39 |
+
const params = new URLSearchParams(window.location.search);
|
40 |
+
url_params = Object.fromEntries(params);
|
41 |
+
console.log(url_params);
|
42 |
+
return url_params;
|
43 |
+
}
|
44 |
+
"""
|
45 |
+
|
46 |
+
def load_demo(url_params, request: gr.Request):
|
47 |
+
logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")
|
48 |
+
|
49 |
+
dropdown_update = gr.Dropdown(visible=True)
|
50 |
+
if "model" in url_params:
|
51 |
+
model = url_params["model"]
|
52 |
+
if model in models:
|
53 |
+
dropdown_update = gr.Dropdown(value=model, visible=True)
|
54 |
+
|
55 |
+
state = default_conversation.copy()
|
56 |
+
return state, dropdown_update
|
57 |
+
|
58 |
+
|
59 |
+
def load_demo_refresh_model_list(request: gr.Request):
|
60 |
+
logger.info(f"load_demo. ip: {request.client.host}")
|
61 |
+
models = get_model_list()
|
62 |
+
state = default_conversation.copy()
|
63 |
+
dropdown_update = gr.Dropdown(
|
64 |
+
choices=models,
|
65 |
+
value=models[0] if len(models) > 0 else ""
|
66 |
+
)
|
67 |
+
return state, dropdown_update
|
68 |
+
|
69 |
+
def vote_last_response(state, vote_type, model_selector, request: gr.Request):
|
70 |
+
with open(get_conv_log_filename(), "a") as fout:
|
71 |
+
data = {
|
72 |
+
"tstamp": round(time.time(), 4),
|
73 |
+
"type": vote_type,
|
74 |
+
"model": model_selector,
|
75 |
+
"state": state.dict(),
|
76 |
+
"ip": request.client.host,
|
77 |
+
}
|
78 |
+
fout.write(json.dumps(data) + "\n")
|
79 |
+
|
80 |
+
def _log_vote_and_disable(state, vote_type, model_selector, request):
    # Shared body of the three vote handlers: log the vote, then clear the
    # textbox and disable the three vote buttons.
    logger.info(f"{vote_type}. ip: {request.client.host}")
    vote_last_response(state, vote_type, model_selector, request)
    return ("",) + (disable_btn,) * 3


def upvote_last_response(state, model_selector, request: gr.Request):
    """Record an upvote for the last response."""
    return _log_vote_and_disable(state, "upvote", model_selector, request)


def downvote_last_response(state, model_selector, request: gr.Request):
    """Record a downvote for the last response."""
    return _log_vote_and_disable(state, "downvote", model_selector, request)


def flag_last_response(state, model_selector, request: gr.Request):
    """Record a flag report for the last response."""
    return _log_vote_and_disable(state, "flag", model_selector, request)
|
94 |
+
|
95 |
+
def regenerate(state, image_process_mode, request: gr.Request):
    """Drop the last model reply so the same prompt can be generated again."""
    logger.info(f"regenerate. ip: {request.client.host}")
    state.messages[-1][-1] = None
    previous_human = state.messages[-2]
    # Re-attach the (possibly changed) image-preprocess mode to the image tuple.
    if type(previous_human[1]) in (tuple, list):
        previous_human[1] = (previous_human[1][:2], image_process_mode)
    state.skip_next = False
    return (state, None, None, None) + (disable_btn,) * 6
|
103 |
+
|
104 |
+
def clear_history(request: gr.Request):
    """Reset the conversation and blank the code/render panes."""
    logger.info(f"clear_history. ip: {request.client.host}")
    fresh_state = default_conversation.copy()
    return (fresh_state, None, None) + (disable_btn,) * 6
|
108 |
+
|
109 |
+
def send_image(state, image, image_process_mode, request: gr.Request):
    """Append the uploaded image to the conversation and prime an empty reply slot."""
    logger.info(f"send_image. ip: {request.client.host}.")
    state.stop_sampling = False

    if image is None:
        # Nothing to vectorize: mark the follow-up generate call to be skipped.
        state.skip_next = True
        return (state, None, None, image) + (no_change_btn,) * 6

    state.append_message(state.roles[0], (image, image_process_mode))
    state.append_message(state.roles[1], "β")
    state.skip_next = False
    current_code = state.to_gradio_svg_code()[0][1]
    return (state, current_code, state.to_gradio_svg_render(), image) + (no_change_btn,) * 6
|
123 |
+
|
124 |
+
def stop_sampling(state, image, request: gr.Request):
    """Signal the streaming loop in http_bot to halt at its next check."""
    logger.info(f"stop_sampling. ip: {request.client.host}")
    state.stop_sampling = True
    return (state, None, None, image) + (disable_btn,) * 6
|
128 |
+
|
129 |
+
def http_bot(state, model_selector, num_beams, temperature, len_penalty, top_p, max_new_tokens, request: gr.Request):
    """Stream an SVG generation from the selected worker and update the UI.

    Generator used as a Gradio event handler: yields
    (state, svg_code_text, svg_render_image, *button_updates) tuples so the
    SVG renders progressively while tokens stream in. On completion the
    exchange is appended to today's conversation log.
    """
    logger.info(f"http_bot. ip: {request.client.host}")
    start_tstamp = time.time()
    model_name = model_selector

    if state.skip_next:
        # This generate call is skipped due to invalid inputs (no image sent).
        yield (state, None, None) + (no_change_btn,) * 6
        return

    # Query worker address from the controller.
    controller_url = args.controller_url
    ret = requests.post(controller_url + "/get_worker_address",
                        json={"model": model_name})
    worker_addr = ret.json()["address"]
    logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}")

    # No available worker for this model.
    if worker_addr == "":
        state.messages[-1][-1] = server_error_msg
        yield (state, None, None, disable_btn, disable_btn, disable_btn, enable_btn, enable_btn, disable_btn)
        return

    # Construct prompt.
    prompt = state.get_prompt()

    # Build the generation request payload.
    pload = {
        "model": model_name,
        "prompt": prompt,
        "num_beams": int(num_beams),
        "temperature": float(temperature),
        "len_penalty": float(len_penalty),
        "top_p": float(top_p),
        # Cap output so prompt + image queries + output fit the 8192 context.
        "max_new_tokens": min(int(max_new_tokens), 8192 - CLIP_QUERY_LENGTH),
    }
    logger.info(f"==== request ====\n{pload}")

    pload['images'] = state.get_images()

    # BUGFIX: keep `output` defined even if the stream yields no chunks, so the
    # final logger.info below cannot raise NameError.
    output = ""
    state.messages[-1][-1] = "β"
    yield (state, state.messages[-1][-1], state.to_gradio_svg_render()) + (disable_btn, disable_btn, disable_btn, disable_btn, disable_btn, enable_btn)

    try:
        # Honor a stop request issued before streaming began.
        if state.stop_sampling:
            # BUGFIX: was `state.messages[1][-1]`; the streaming cursor belongs
            # on the *last* message.
            state.messages[-1][-1] = "β"
            yield (state, state.messages[-1][-1], state.to_gradio_svg_render()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn, disable_btn)
            return

        # Stream output; chunks are NUL-delimited JSON objects.
        response = requests.post(worker_addr + "/worker_generate_stream",
                                 headers=headers, json=pload, stream=True, timeout=100)
        for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
            if chunk:
                data = json.loads(chunk.decode())
                if data["error_code"] == 0:
                    output = data["text"].strip()
                    state.messages[-1][-1] = output + "β"
                    st = state.to_gradio_svg_code()
                    yield (state, st[-1][1], state.to_gradio_svg_render()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn, enable_btn)
                else:
                    output = data["text"] + f" (error_code: {data['error_code']})"
                    state.messages[-1][-1] = output
                    # BUGFIX: `st` was only assigned on the success path and was
                    # a NameError here if the first chunk carried an error.
                    st = state.to_gradio_svg_code()
                    yield (state, st[-1][1], state.to_gradio_svg_render()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn, disable_btn)
                    return
            time.sleep(0.03)
    except requests.exceptions.RequestException as e:
        state.messages[-1][-1] = server_error_msg
        yield (state, None, None) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn, disable_btn)
        return

    yield (state, state.messages[-1][-1], state.to_gradio_svg_render()) + (enable_btn,) * 6

    finish_tstamp = time.time()
    logger.info(f"{output}")

    with open(get_conv_log_filename(), "a") as fout:
        data = {
            "tstamp": round(finish_tstamp, 4),
            "type": "chat",
            "model": model_name,
            "start": round(start_tstamp, 4),
            "finish": round(finish_tstamp, 4),
            "svg": state.messages[-1][-1],
            "ip": request.client.host,
        }
        fout.write(json.dumps(data) + "\n")
|
218 |
+
|
219 |
+
# --- Static UI copy -------------------------------------------------------
# Title banner with paper/code/model links.
title_markdown = ("""
# π« StarVector: Generating Scalable Vector Graphics Code from Images and Text
[[Project Page](https://starvector.github.io)] [[Code](https://github.com/joanrod/star-vector)] [[Model](https://huggingface.co/joanrodai/starvector-1.4b)] | π [[StarVector](https://arxiv.org/abs/2312.11556)]
""")

# One-line usage hint rendered under the title.
sub_title_markdown = (""" Throw an image and vectorize it! The model expects vector-like images to generate the corresponding svg code.""")
# Terms-of-use text, shown only when the demo is not embedded.
tos_markdown = ("""
### Terms of use
By using this service, users are required to agree to the following terms:
The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research.
Please click the "Flag" button if you get any inappropriate answer! We will collect those to keep improving our moderator.
For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality.
""")


# License footer, shown only when the demo is not embedded.
learn_more_markdown = ("""
### License
The service is a research preview intended for non-commercial use only. Please contact us if you find any potential violation.
""")

# Custom CSS injected into gr.Blocks: sizes the buttons row, constrains the
# page width, and makes the SVG code panel scrollable.
block_css = """

#buttons button {
    min-width: min(120px,100%);
}

.gradio-container{
    max-width: 1200px!important
}

#svg_render{
    padding: 20px !important;
}

#svg_code{
    height: 200px !important;
    overflow: scroll !important;
    white-space: unset !important;
    flex-shrink: unset !important;
}


h1{display: flex;align-items: center;justify-content: center;gap: .25em}
*{transition: width 0.5s ease, flex-grow 0.5s ease}
"""
|
264 |
+
|
265 |
+
def build_demo(embed_mode, concurrency_count=10):
    """Assemble the Gradio Blocks UI and wire all event handlers.

    Args:
        embed_mode: when True, hide the title/terms markdown (iframe embedding).
        concurrency_count: concurrency limit for the generate (http_bot) events.

    Returns:
        The configured gr.Blocks demo (not yet launched).
    """
    with gr.Blocks(title="StarVector", theme=gr.themes.Default(), css=block_css) as demo:
        state = gr.State()
        if not embed_mode:
            gr.Markdown(title_markdown)
            gr.Markdown(sub_title_markdown)
        with gr.Row():
            with gr.Column(scale=3):
                with gr.Row(elem_id="model_selector_row"):
                    model_selector = gr.Dropdown(
                        choices=models,
                        value=models[0] if len(models) > 0 else "",
                        interactive=True,
                        show_label=False,
                        container=False)
                imagebox = gr.Image(type="pil")
                image_process_mode = gr.Radio(
                    ["Resize", "Pad", "Default"],
                    value="Pad",
                    label="Preprocess for non-square image", visible=False)

                cur_dir = os.path.dirname(os.path.abspath(__file__))
                gr.Examples(examples=[
                    [f"{cur_dir}/examples/sample-4.png"],
                    [f"{cur_dir}/examples/sample-7.png"],
                    [f"{cur_dir}/examples/sample-16.png"],
                    [f"{cur_dir}/examples/sample-17.png"],
                    [f"{cur_dir}/examples/sample-18.png"],
                    [f"{cur_dir}/examples/sample-0.png"],
                    [f"{cur_dir}/examples/sample-1.png"],
                    [f"{cur_dir}/examples/sample-6.png"],
                ], inputs=[imagebox])

                with gr.Column(scale=1, min_width=50):
                    submit_btn = gr.Button(value="Send", variant="primary")

                with gr.Accordion("Parameters", open=True) as parameter_row:
                    num_beams = gr.Slider(minimum=1, maximum=10, value=1, step=1, interactive=True, label="Num Beams", visible=False,)
                    temperature = gr.Slider(minimum=0.0, maximum=2.0, value=0.8, step=0.05, interactive=True, label="Temperature",)
                    len_penalty = gr.Slider(minimum=0.0, maximum=2.0, value=0.6, step=0.05, interactive=True, label="Length Penalty",)
                    top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.9, step=0.05, interactive=True, label="Top P",)
                    max_output_tokens = gr.Slider(minimum=0, maximum=8192, value=2000, step=64, interactive=True, label="Max output tokens",)

            with gr.Column(scale=8):
                with gr.Row():
                    svg_code = gr.Code(label="SVG Code", elem_id='svg_code', min_width=200, interactive=False, lines=5)
                with gr.Row():
                    # BUGFIX: this component was created without being assigned,
                    # but every listener below references `svg_render` -- that
                    # raised NameError when the listeners were registered.
                    svg_render = gr.Image(width=50, height=256, label="Rendered SVG", elem_id='svg_render')
                with gr.Row(elem_id="buttons") as button_row:
                    upvote_btn = gr.Button(value="π Upvote", interactive=False)
                    downvote_btn = gr.Button(value="π Downvote", interactive=False)
                    flag_btn = gr.Button(value="β οΈ Flag", interactive=False)
                    stop_btn = gr.Button(value="βΉοΈ Stop Generation", interactive=False, visible=False)
                    regenerate_btn = gr.Button(value="π Regenerate", interactive=False, visible=False)
                    clear_btn = gr.Button(value="ποΈ Clear", interactive=False)

        if not embed_mode:
            gr.Markdown(tos_markdown)
            gr.Markdown(learn_more_markdown)
        url_params = gr.JSON(visible=False)

        # Register listeners. Order matters: it matches the button tuples
        # yielded by the handlers above.
        btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn, stop_btn]
        upvote_btn.click(
            upvote_last_response,
            [state, model_selector],
            [upvote_btn, downvote_btn, flag_btn],
            queue=False
        )
        downvote_btn.click(
            downvote_last_response,
            [state, model_selector],
            [upvote_btn, downvote_btn, flag_btn],
            queue=False
        )
        flag_btn.click(
            flag_last_response,
            [state, model_selector],
            [upvote_btn, downvote_btn, flag_btn],
            queue=False
        )

        regenerate_btn.click(
            regenerate,
            [state, image_process_mode],
            [state, svg_code, svg_render, imagebox] + btn_list,
            queue=False
        ).then(
            http_bot,
            [state, model_selector, num_beams, temperature, len_penalty, top_p, max_output_tokens],
            [state, svg_code, svg_render] + btn_list,
            concurrency_limit=concurrency_count
        )

        submit_btn.click(
            send_image,
            [state, imagebox, image_process_mode],
            [state, svg_code, svg_render, imagebox] + btn_list,
            queue=False
        ).then(
            http_bot,
            [state, model_selector, num_beams, temperature, len_penalty, top_p, max_output_tokens],
            [state, svg_code, svg_render] + btn_list,
            concurrency_limit=concurrency_count
        )

        clear_btn.click(
            clear_history,
            None,
            [state, svg_code, svg_render] + btn_list,
            queue=False
        )

        stop_btn.click(
            stop_sampling,
            [state, imagebox],
            [state, imagebox] + btn_list,
            queue=False
        ).then(
            clear_history,
            None,
            [state, svg_code, svg_render] + btn_list,
            queue=False
        )

        if args.model_list_mode == "once":
            demo.load(
                load_demo,
                [url_params],
                [state, model_selector],
                # NOTE(review): `_js` is the legacy kwarg name; newer Gradio
                # releases spell it `js` -- confirm against the pinned version.
                _js=get_window_url_params,
            )
        elif args.model_list_mode == "reload":
            demo.load(
                load_demo_refresh_model_list,
                None,
                [state, model_selector],
                queue=False
            )
        else:
            raise ValueError(f"Unknown model list mode: {args.model_list_mode}")

    return demo
|
408 |
+
|
409 |
+
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int)
    # Address of the controller that tracks the available model workers.
    parser.add_argument("--controller-url", type=str, default="http://localhost:21001")
    parser.add_argument("--concurrency-count", type=int, default=15)
    # "once": fetch the model list at startup only; "reload": refresh it on
    # every page load.
    parser.add_argument("--model-list-mode", type=str, default="once", choices=["once", "reload"])
    parser.add_argument("--share", action="store_true")
    parser.add_argument("--moderate", action="store_true")
    parser.add_argument("--embed", action="store_true")
    args = parser.parse_args()
    logger.info(f"args: {args}")

    # Module-level model list read by build_demo and the load_demo callbacks.
    models = get_model_list()

    logger.info(args)
    demo = build_demo(args.embed, concurrency_count=args.concurrency_count)
    demo.queue(
        api_open=False
    ).launch(
        server_name=args.host,
        server_port=args.port,
        share=args.share
    )
|
starvector/serve/gradio_web_server.py
ADDED
@@ -0,0 +1,562 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import datetime
|
3 |
+
import json
|
4 |
+
import os
|
5 |
+
import time
|
6 |
+
import gradio as gr
|
7 |
+
import requests
|
8 |
+
from starvector.serve.conversation import default_conversation
|
9 |
+
from starvector.serve.constants import LOGDIR, CLIP_QUERY_LENGTH
|
10 |
+
from starvector.serve.util import (build_logger, server_error_msg)
|
11 |
+
|
12 |
+
# Module logger; writes to gradio_web_server.log under the project LOGDIR.
logger = build_logger("gradio_web_server", "gradio_web_server.log")
# Sent with every request to the model workers.
headers = {"User-Agent": "StarVector Client"}

# Reusable Gradio button-state updates (legacy gr.Button.update API).
no_change_btn = gr.Button.update()
enable_btn = gr.Button.update(interactive=True)
disable_btn = gr.Button.update(interactive=False)

# Sort-key overrides for get_model_list(): models listed here sort before
# every other model name.
priority = {
    "starvector-1b-im2svg": "aaaaaaa",
}
|
22 |
+
|
23 |
+
def get_conv_log_filename():
    """Path of today's conversation log (one file per calendar day)."""
    today = datetime.datetime.now()
    return os.path.join(LOGDIR, f"{today.year}-{today.month:02d}-{today.day:02d}-conv.json")
|
27 |
+
|
28 |
+
def get_model_list():
    """Refresh the controller's worker registry and return the model names.

    Names are sorted with the `priority` table as an override key.
    """
    status = requests.post(args.controller_url + "/refresh_all_workers")
    assert status.status_code == 200
    reply = requests.post(args.controller_url + "/list_models")
    model_names = reply.json()["models"]
    model_names.sort(key=lambda m: priority.get(m, m))
    logger.info(f"Models: {model_names}")
    return model_names
|
36 |
+
|
37 |
+
def load_demo(url_params, request: gr.Request):
    """Initialize a session; honor an optional ?model=<name> URL parameter."""
    logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")

    if "model" in url_params and url_params["model"] in models:
        dropdown_update = gr.Dropdown.update(value=url_params["model"], visible=True)
    else:
        dropdown_update = gr.Dropdown.update(visible=True)

    return default_conversation.copy(), dropdown_update
|
49 |
+
|
50 |
+
# Maps the UI task names to the substring that identifies matching model
# names (get_models_dropdown_from_task filters the worker list with these).
mapping_model_task = {
    'Image2SVG': 'im2svg',
    'Text2SVG': 'text2svg'
}
|
54 |
+
|
55 |
+
def get_models_dropdown_from_task(task):
    """Build a Dropdown update listing only the models that serve *task*."""
    task_tag = mapping_model_task[task]
    matching = [name for name in get_model_list() if task_tag in name]
    return gr.Dropdown.update(
        choices=matching,
        value=matching[0] if matching else ""
    )
|
63 |
+
|
64 |
+
|
65 |
+
def load_demo_refresh_model_list(task, request: gr.Request):
    """Initialize a session with the model list filtered for *task*."""
    logger.info(f"load_demo. ip: {request.client.host}")
    dropdown_update = get_models_dropdown_from_task(task)
    return default_conversation.copy(), dropdown_update
|
70 |
+
|
71 |
+
def vote_last_response(state, vote_type, model_selector, request: gr.Request):
    """Append one vote record (JSON line) to today's conversation log."""
    record = {
        "tstamp": round(time.time(), 4),
        "type": vote_type,
        "model": model_selector,
        "state": state.dict(),
        "ip": request.client.host,
    }
    with open(get_conv_log_filename(), "a") as log_file:
        log_file.write(json.dumps(record) + "\n")
|
81 |
+
|
82 |
+
def _log_vote_and_disable(state, vote_type, model_selector, request):
    # Shared body of the three vote handlers: log the vote, then clear the
    # textbox and disable the seven action buttons.
    logger.info(f"{vote_type}. ip: {request.client.host}")
    vote_last_response(state, vote_type, model_selector, request)
    return ("",) + (disable_btn,) * 7


def upvote_last_response(state, model_selector, request: gr.Request):
    """Record an upvote for the last response."""
    return _log_vote_and_disable(state, "upvote", model_selector, request)


def downvote_last_response(state, model_selector, request: gr.Request):
    """Record a downvote for the last response."""
    return _log_vote_and_disable(state, "downvote", model_selector, request)


def flag_last_response(state, model_selector, request: gr.Request):
    """Record a flag report for the last response."""
    return _log_vote_and_disable(state, "flag", model_selector, request)
|
96 |
+
|
97 |
+
def regenerate(state, image_process_mode, request: gr.Request):
    """Drop the last model reply so the same prompt can be generated again."""
    logger.info(f"regenerate. ip: {request.client.host}")
    state.messages[-1][-1] = None
    previous_human = state.messages[-2]
    # Re-attach the (possibly changed) image-preprocess mode to the image tuple.
    if type(previous_human[1]) in (tuple, list):
        previous_human[1] = (previous_human[1][:2], image_process_mode)
    state.skip_next = False
    return (state, None, None, None) + (disable_btn,) * 7
|
105 |
+
|
106 |
+
def clear_history(request: gr.Request):
    """Reset the conversation and blank the code/render panes."""
    logger.info(f"clear_history. ip: {request.client.host}")
    fresh_state = default_conversation.copy()
    return (fresh_state, None, None) + (disable_btn,) * 7
|
110 |
+
|
111 |
+
def send_data(state, image, image_process_mode, text_caption, task, request: gr.Request):
    """Queue the user input for generation: an image (Image2SVG) or a caption (Text2SVG)."""
    logger.info(f"send_data. ip: {request.client.host}.")

    if task == 'Image2SVG':
        if image is None:
            # Nothing to vectorize; skip the follow-up generate call.
            state.skip_next = True
            return (state, None, None, image) + (no_change_btn,) * 7
        user_payload = (image, image_process_mode)
    else:
        if text_caption is None:
            # No caption given; skip the follow-up generate call.
            state.skip_next = True
            return (state, None, None, image) + (no_change_btn,) * 7
        user_payload = text_caption

    state.append_message(state.roles[0], user_payload)
    state.append_message(state.roles[1], "β")
    state.skip_next = False
    msg = state.to_gradio_svg_code()[0][1]
    return (state, msg, state.to_gradio_svg_render(), image) + (no_change_btn,) * 7
|
135 |
+
|
136 |
+
def download_files(state, request: gr.Request):
    """Prepare the current SVG code and its render for download.

    NOTE(review): currently a stub -- the files are produced by
    state.download_files() but never returned or served to the browser
    (see the TODO below), so this handler has no visible effect.
    """
    logger.info(f"download_files. ip: {request.client.host}")
    svg_str, image = state.download_files()

    # TODO: Figure out how to download the SVG in the users browser, idk how to do it now
|
141 |
+
|
142 |
+
def update_task(task):
    """Return per-task default sampling values plus a model-dropdown update.

    The three floats feed the sampling sliders -- presumably
    (len_penalty, temperature, top_p) in the caller's order; TODO confirm
    against the component wiring in build_demo.
    """
    dropdown_update = get_models_dropdown_from_task(task)
    slider_defaults = (1.0, 0.9, 0.95) if task == "Text2SVG" else (0.6, 0.9, 0.95)
    return slider_defaults + (dropdown_update,)
|
149 |
+
|
150 |
+
|
151 |
+
def stop_sampling(state, image, request: gr.Request):
    """Signal the streaming loop in http_bot to halt at its next check."""
    logger.info(f"stop_sampling. ip: {request.client.host}")
    state.stop_sampling = True
    return (state, None, None, image) + (disable_btn,) * 7
|
155 |
+
|
156 |
+
def http_bot(state, task_selector, text_caption, model_selector, num_beams, temperature, len_penalty, top_p, max_new_tokens, request: gr.Request):
    """Stream an SVG generation from the selected worker and update the UI.

    Generator used as a Gradio event handler: yields
    (state, svg_code_text, svg_render_image, *button_updates) tuples so the
    SVG renders progressively. The prompt comes from the image conversation
    (Image2SVG) or directly from *text_caption* (Text2SVG). On completion the
    exchange is appended to today's conversation log.
    """
    logger.info(f"http_bot. ip: {request.client.host}")
    start_tstamp = time.time()
    model_name = model_selector

    if state.skip_next:
        # This generate call is skipped due to invalid inputs.
        yield (state, None, None) + (no_change_btn,) * 7
        return

    # Query worker address from the controller.
    controller_url = args.controller_url
    ret = requests.post(controller_url + "/get_worker_address",
                        json={"model": model_name})
    worker_addr = ret.json()["address"]
    logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}")

    # No available worker for this model.
    if worker_addr == "":
        state.messages[-1][-1] = server_error_msg
        yield (state, None, None, disable_btn, disable_btn, disable_btn, enable_btn, enable_btn, disable_btn, disable_btn)
        return

    # Construct prompt depending on the selected task.
    if task_selector == "Image2SVG":
        prompt = state.get_image_prompt()
    else:
        prompt = text_caption

    # Build the generation request payload.
    pload = {
        "model": model_name,
        "prompt": prompt,
        "num_beams": int(num_beams),
        "temperature": float(temperature),
        "len_penalty": float(len_penalty),
        "top_p": float(top_p),
        # Cap output so prompt + image queries + output fit the 8192 context.
        "max_new_tokens": min(int(max_new_tokens), 8192 - CLIP_QUERY_LENGTH),
    }
    logger.info(f"==== request ====\n{pload}")

    pload['images'] = state.get_images()

    # BUGFIX: keep `output` defined even if the stream yields no chunks, so the
    # final logger.info below cannot raise NameError.
    output = ""
    state.messages[-1][-1] = "β"
    yield (state, state.messages[-1][-1], state.to_gradio_svg_render()) + (disable_btn, disable_btn, disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)

    try:
        # Honor a stop request issued before streaming began.
        if state.stop_sampling:
            # BUGFIX: was `state.messages[1][-1]`; the streaming cursor belongs
            # on the *last* message.
            state.messages[-1][-1] = "β"
            yield (state, state.messages[-1][-1], state.to_gradio_svg_render()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn, disable_btn, enable_btn)
            return

        # Stream output; chunks are NUL-delimited JSON objects.
        response = requests.post(worker_addr + "/worker_generate_stream",
                                 headers=headers, json=pload, stream=True, timeout=10)
        for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
            if chunk:
                data = json.loads(chunk.decode())
                if data["error_code"] == 0:
                    output = data["text"].strip()
                    state.messages[-1][-1] = output + "β"
                    st = state.to_gradio_svg_code()
                    yield (state, st[-1][1], state.to_gradio_svg_render()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn, enable_btn, enable_btn)
                else:
                    output = data["text"] + f" (error_code: {data['error_code']})"
                    state.messages[-1][-1] = output
                    # BUGFIX: `st` was only assigned on the success path and was
                    # a NameError here if the first chunk carried an error.
                    st = state.to_gradio_svg_code()
                    yield (state, st[-1][1], state.to_gradio_svg_render()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn, disable_btn, disable_btn)
                    return
            time.sleep(0.03)
    except requests.exceptions.RequestException as e:
        state.messages[-1][-1] = server_error_msg
        yield (state, None, None) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn, disable_btn, disable_btn)
        return

    yield (state, state.messages[-1][-1], state.to_gradio_svg_render()) + (enable_btn,) * 7

    finish_tstamp = time.time()
    logger.info(f"{output}")

    with open(get_conv_log_filename(), "a") as fout:
        data = {
            "tstamp": round(finish_tstamp, 4),
            "type": "chat",
            "model": model_name,
            "start": round(start_tstamp, 4),
            "finish": round(finish_tstamp, 4),
            "svg": state.messages[-1][-1],
            "ip": request.client.host,
        }
        fout.write(json.dumps(data) + "\n")
|
248 |
+
|
249 |
+
# --- Static UI copy -------------------------------------------------------
# Title banner with paper/code/model links.
title_markdown = ("""
# π« StarVector: Generating Scalable Vector Graphics Code from Images and Text

[[Project Page](https://starvector.github.io)] [[Code](https://github.com/joanrod/star-vector)] [[Model](https://huggingface.co/joanrodai/starvector-1.4b)] | π [[StarVector](https://arxiv.org/abs/2312.11556)]""")

# Usage hint describing the two tasks (Image2SVG / Text2SVG).
sub_title_markdown = ("""**How does it work?** Select the task you want to perform, and the model will be automatically set. For **Text2SVG**, introduce a prompt in Text Caption. For **Image2SVG**, select an image and vectorize it. \
**Note**: The current model works on vector-like images like icons and or vector-like designs.""")
# Terms-of-use text, shown only when the demo is not embedded.
tos_markdown = ("""
### Terms of use
By using this service, users are required to agree to the following terms:
The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research.
Please click the "Flag" button if you get any inappropriate answer! We will collect those to keep improving our moderator.
For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality.
""")

# License footer, shown only when the demo is not embedded.
learn_more_markdown = ("""
### License
The service is a research preview intended for non-commercial use only. Please contact us if you find any potential violation.
""")

# Custom CSS injected into gr.Blocks. The ".ΝΌ1"/".ΝΌ2p" selectors target
# Gradio's generated CodeMirror class names -- presumably version-specific;
# verify against the pinned gradio release.
block_css = """

#buttons button {
    min-width: min(120px,100%);
}

.gradio-container{
    max-width: 1200px!important
}

.ΝΌ1 .cm-content {
    white-space: unset !important;
    flex-shrink: unset !important;
}

.ΝΌ2p .cm-scroller {
    max-height: 200px;
    overflow: scroll;
}

#svg_render{
    padding: 20px !important;
}

#submit_btn{
    max-height: 40px;
}

.selector{
    max-height: 100px;
}
h1{display: flex;align-items: center;justify-content: center;gap: .25em}
*{transition: width 0.5s ease, flex-grow 0.5s ease}
"""
|
303 |
+
def build_demo(embed_mode):
    """Assemble and return the StarVector Gradio UI.

    Args:
        embed_mode: when True, the title/subtitle and the ToS/license markdown
            are omitted so the page can be embedded elsewhere.

    Returns:
        The constructed ``gr.Blocks`` application.

    Relies on module-level names: ``models`` (list fetched from the
    controller), ``args`` (CLI namespace, read for ``model_list_mode``), the
    markdown/CSS constants above, and the event callbacks
    (``http_bot``, ``send_data``, ``regenerate``, ``clear_history``, ...).
    """
    # Components are created up front and .render()-ed later so they can be
    # referenced by event handlers before their position in the layout.
    svg_render = gr.Image(label="Rendered SVG", elem_id='svg_render', height=300)
    svg_code = gr.Code(label="SVG Code", elem_id='svg_code', interactive=True, lines=5)

    with gr.Blocks(title="StarVector", theme=gr.themes.Default(), css=block_css) as demo:
        # Per-session conversation state (populated by the callbacks).
        state = gr.State()
        if not embed_mode:
            gr.Markdown(title_markdown)
            gr.Markdown(sub_title_markdown)
        with gr.Row():
            # Left column: task/model selection and generation inputs.
            with gr.Column(scale=4):
                task_selector = gr.Dropdown(
                    choices=["Image2SVG", "Text2SVG"],
                    value="Image2SVG",
                    label="Task",
                    interactive=True,
                    show_label=True,
                    container=True,
                    elem_id="task_selector",
                    elem_classes=["selector"],
                )
                model_selector = gr.Dropdown(
                    choices=models,
                    value=models[0] if len(models) > 0 else "",
                    label="Model",
                    interactive=True,
                    show_label=True,
                    container=True,
                    elem_classes=["selector"],
                )

                # Image input for Image2SVG (toggled client-side by the JS below).
                imagebox = gr.Image(type="pil", visible=True, elem_id="imagebox")
                image_process_mode = gr.Radio(
                    ["Resize", "Pad", "Default"],
                    value="Pad",
                    label="Preprocess for non-square image", visible=False)

                # Text input for Text2SVG (toggled client-side by the JS below).
                text_caption = gr.Textbox(label="Text Caption", visible=True, value="The icon of a yellow star", elem_id="text_caption")

                # Clickable example images shipped next to this module.
                cur_dir = os.path.dirname(os.path.abspath(__file__))
                gr.Examples(examples=[
                    [f"{cur_dir}/examples/sample-4.png"],
                    [f"{cur_dir}/examples/sample-7.png"],
                    [f"{cur_dir}/examples/sample-16.png"],
                    [f"{cur_dir}/examples/sample-17.png"],
                    [f"{cur_dir}/examples/sample-18.png"],
                    [f"{cur_dir}/examples/sample-0.png"],
                    [f"{cur_dir}/examples/sample-1.png"],
                    [f"{cur_dir}/examples/sample-6.png"],
                ], inputs=[imagebox], elem_id="examples")

                submit_btn = gr.Button(value="Send", variant="primary", elem_id="submit_btn", interactive=True)

                # Sampling parameters forwarded to http_bot on every request.
                with gr.Accordion("Parameters", open=False):
                    num_beams = gr.Slider(minimum=1, maximum=10, value=1, step=1, interactive=True, label="Num Beams", visible=False,)
                    temperature = gr.Slider(minimum=0.0, maximum=2.0, value=0.9, step=0.05, interactive=True, label="Temperature",)
                    len_penalty = gr.Slider(minimum=0.0, maximum=2.0, value=0.6, step=0.05, interactive=True, label="Length Penalty",)
                    top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.95, step=0.05, interactive=True, label="Top P",)
                    max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=1024, step=64, interactive=True, label="Max output tokens",)

            # Right column: generated SVG code and its rendering.
            with gr.Column(scale=9):
                with gr.Row():
                    svg_code.render()
                with gr.Row():
                    svg_render.render()

                # NOTE(review): button labels below appear mojibake-mangled in
                # this copy (likely originally emoji such as 👍/👎) — confirm
                # against the original source.
                with gr.Row(elem_id="buttons") as button_row:
                    upvote_btn = gr.Button(value="π Upvote", interactive=False)
                    downvote_btn = gr.Button(value="π Downvote", interactive=False)
                    flag_btn = gr.Button(value="β οΈ Flag", interactive=False)
                    stop_btn = gr.Button(value="βΉοΈ Stop Generation", interactive=False, visible=False)
                    regenerate_btn = gr.Button(value="π Regenerate", interactive=False, visible=False)
                    clear_btn = gr.Button(value="ποΈ Clear", interactive=False)
                    download_btn = gr.Button(value="Download SVG", interactive=False, visible=False)

        if not embed_mode:
            gr.Markdown(tos_markdown)
            gr.Markdown(learn_more_markdown)
        # Hidden JSON holding the URL query parameters (filled by the _js hook).
        url_params = gr.JSON(visible=False)

        # Register listeners
        btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn, stop_btn, download_btn]
        upvote_btn.click(
            upvote_last_response,
            [state, model_selector],
            [upvote_btn, downvote_btn, flag_btn],
            queue=False
        )
        downvote_btn.click(
            downvote_last_response,
            [state, model_selector],
            [upvote_btn, downvote_btn, flag_btn],
            queue=False
        )
        flag_btn.click(
            flag_last_response,
            [state, model_selector],
            [upvote_btn, downvote_btn, flag_btn],
            queue=False
        )

        # Regenerate re-runs http_bot after resetting the last turn.
        regenerate_btn.click(
            regenerate,
            [state, image_process_mode],
            [state, svg_code, svg_render, imagebox] + btn_list,
            queue=False
        ).then(
            http_bot,
            [state, task_selector, text_caption, model_selector, num_beams, temperature, len_penalty, top_p, max_output_tokens],
            [state, svg_code, svg_render] + btn_list)

        # Submit: stash the inputs into state, then stream the model reply.
        submit_btn.click(
            send_data,
            [state, imagebox, image_process_mode, text_caption, task_selector],
            [state, svg_code, svg_render, imagebox] + btn_list,
            queue=False
        ).then(
            http_bot,
            [state, task_selector, text_caption, model_selector, num_beams, temperature, len_penalty, top_p, max_output_tokens],
            [state, svg_code, svg_render] + btn_list
        )

        clear_btn.click(
            clear_history,
            None,
            [state, svg_code, svg_render] + btn_list,
            queue=False
        )

        # Stop cancels sampling, then clears the conversation.
        stop_btn.click(
            stop_sampling,
            [state, imagebox],
            [state, imagebox] + btn_list,
            queue=False
        ).then(
            clear_history,
            None,
            [state, svg_code, svg_render] + btn_list,
            queue=False
        )

        download_btn.click(
            download_files,
            [state],
            None,
            queue=False
        )
        # Switching task updates sampling defaults server-side and toggles the
        # image/text inputs client-side.
        # NOTE(review): `_js=` is the Gradio 3.x keyword (renamed `js=` in 4.x)
        # — confirm the pinned gradio version.
        task_selector.change(
            update_task,
            inputs=[task_selector],
            outputs=[len_penalty, temperature, top_p, model_selector],
            queue=False,
            _js="""
            function(task) {
                var imageBoxElement = document.getElementById("imagebox");
                var textCaptionElement = document.getElementById("text_caption");
                var examplesElement = document.getElementById("examples");
                if (task === "Text2SVG") {
                    imageBoxElement.style.display = "none";
                    textCaptionElement.style.display = "block";
                    examplesElement.style.display = "none";
                } else if (task === "Image2SVG") {
                    imageBoxElement.style.display = "block";
                    textCaptionElement.style.display = "none";
                    examplesElement.style.display = "block";
                }
                return task;
            }
            """
        )

        # Page-load hook: either use the model list fetched once at startup or
        # refresh it from the controller on every load.
        if args.model_list_mode == "once":
            demo.load(
                load_demo,
                [url_params, task_selector],
                [state, model_selector],
                _js="""
                function() {
                    const params = new URLSearchParams(window.location.search);
                    url_params = Object.fromEntries(params);
                    console.log(url_params);
                    return url_params;
                }
                """,
                queue=False
            )
        elif args.model_list_mode == "reload":
            demo.load(
                load_demo_refresh_model_list,
                [task_selector],
                [state, model_selector],
                _js="""
                function(task) {
                    var textCaptionElement = document.getElementById("text_caption");
                    var autoScrollBottom = true;
                    textCaptionElement.style.display = "none";
                    function updateScroll(){
                        if (autoScrollBottom) {
                            var element = document.getElementsByClassName("cm-scroller")[0];
                            element.scrollTop = element.scrollHeight;
                        }
                    }
                    function handleScroll() {
                        var element = document.getElementsByClassName("cm-scroller")[0];
                        //if (element.scrollHeight - element.scrollTop === element.clientHeight) {
                        if (element.scrollHeight - (element.scrollTop + element.clientHeight) < 0.2*(element.scrollTop)) {
                            // User has scrolled to the bottom, enable auto-scrolling
                            autoScrollBottom = true;
                            console.log("bottom");
                        } else {
                            console.log("not bottom");
                            // User has scrolled away from the bottom, disable auto-scrolling
                            autoScrollBottom = false;
                        }
                    }
                    setInterval(updateScroll,500);
                    var element = document.getElementsByClassName("cm-scroller")[0];
                    element.addEventListener("scroll", handleScroll);

                    return task;
                }

                """,
                queue=False,
            )
        else:
            raise ValueError(f"Unknown model list mode: {args.model_list_mode}")

    return demo
|
535 |
+
|
536 |
+
if __name__ == "__main__":
    # CLI entry point: parse options, fetch the worker model list from the
    # controller, build the UI, and launch the Gradio server.
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int)
    parser.add_argument("--controller-url", type=str, default="http://localhost:21001")
    parser.add_argument("--concurrency-count", type=int, default=10)
    parser.add_argument("--model-list-mode", type=str, default="once",
        choices=["once", "reload"])
    parser.add_argument("--share", action="store_true")
    parser.add_argument("--moderate", action="store_true")
    parser.add_argument("--embed", action="store_true")
    args = parser.parse_args()
    logger.info(f"args: {args}")

    # Ask the controller which model workers are currently registered.
    models = get_model_list()

    logger.info(args)
    demo = build_demo(args.embed)
    # NOTE(review): `queue(concurrency_count=...)` is the Gradio 3.x signature
    # (removed in 4.x) — confirm the pinned gradio version.
    demo.queue(
        concurrency_count=args.concurrency_count,
        api_open=False
    ).launch(
        server_name=args.host,
        server_port=args.port,
        share=args.share
    )
|
starvector/serve/model_worker.py
ADDED
@@ -0,0 +1,269 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
A model worker executes the model.
|
3 |
+
"""
|
4 |
+
import argparse
|
5 |
+
import asyncio
|
6 |
+
import json
|
7 |
+
import time
|
8 |
+
import threading
|
9 |
+
import uuid
|
10 |
+
from fastapi import FastAPI, Request, BackgroundTasks
|
11 |
+
from fastapi.responses import StreamingResponse
|
12 |
+
import requests
|
13 |
+
import torch
|
14 |
+
import uvicorn
|
15 |
+
from functools import partial
|
16 |
+
from starvector.serve.constants import WORKER_HEART_BEAT_INTERVAL, CLIP_QUERY_LENGTH
|
17 |
+
from starvector.serve.util import (build_logger, server_error_msg,
|
18 |
+
pretty_print_semaphore)
|
19 |
+
from starvector.model.builder import load_pretrained_model
|
20 |
+
from starvector.serve.util import process_images, load_image_from_base64
|
21 |
+
from threading import Thread
|
22 |
+
from transformers import TextIteratorStreamer
|
23 |
+
|
24 |
+
# One gibibyte, used as a unit for memory reporting.
GB = 1 << 30

# Short random id so multiple workers on one host get distinct log files.
worker_id = str(uuid.uuid4())[:6]
logger = build_logger("model_worker", f"model_worker_{worker_id}.log")
# Total number of generation requests served by this process (for heartbeats).
global_counter = 0
# Lazily-created asyncio.Semaphore bounding concurrent generations; stays None
# until the first /worker_generate_stream request.
model_semaphore = None
|
30 |
+
|
31 |
+
def heart_beat_worker(controller):
    """Daemon loop: periodically push a heartbeat to the controller.

    Runs forever in a background thread; `controller` is the ModelWorker
    instance whose `send_heart_beat()` is invoked every
    WORKER_HEART_BEAT_INTERVAL seconds.
    """
    while True:
        time.sleep(WORKER_HEART_BEAT_INTERVAL)
        controller.send_heart_beat()
|
35 |
+
|
36 |
+
class ModelWorker:
    """Hosts one StarVector model and serves streamed SVG generation.

    On construction the model is loaded, optionally registered with the
    controller, and a background heartbeat thread is started. Generation is
    exposed through `generate_stream_gate`, which yields NUL-delimited JSON
    chunks suitable for a StreamingResponse.
    """

    def __init__(self, controller_addr, worker_addr,
                 worker_id, no_register,
                 model_path, model_base, model_name,
                 load_8bit, load_4bit, device):
        self.controller_addr = controller_addr
        self.worker_addr = worker_addr
        self.worker_id = worker_id
        # Derive a display name from the path when none is given; keep the
        # parent directory for "checkpoint-NNN" leaf directories.
        if model_path.endswith("/"):
            model_path = model_path[:-1]
        if model_name is None:
            model_paths = model_path.split("/")
            if model_paths[-1].startswith('checkpoint-'):
                self.model_name = model_paths[-2] + "_" + model_paths[-1]
            else:
                self.model_name = model_paths[-1]
        else:
            self.model_name = model_name

        # Task is inferred from the model name.
        # NOTE(review): if the name contains neither "text2svg" nor "im2svg",
        # self.task is never set and generate_stream would raise
        # AttributeError — confirm all deployed model names match.
        if "text2svg" in self.model_name.lower():
            self.task = "Text2SVG"
        elif "im2svg" in self.model_name.lower():
            self.task = "Image2SVG"

        self.device = device
        logger.info(f"Loading the model {self.model_name} on worker {worker_id} ...")
        self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
            model_path, device=self.device, load_in_8bit=load_8bit, load_in_4bit=load_4bit)
        # NOTE(review): model weights are cast to bfloat16 here, but inputs in
        # generate_stream are cast to float16 — confirm this mixed-precision
        # combination is intended.
        self.model.to(torch.bfloat16)
        # Image input is only processed for multimodal (starvector) models.
        self.is_multimodal = 'starvector' in self.model_name.lower()

        if not no_register:
            self.register_to_controller()
            self.heart_beat_thread = threading.Thread(
                target=heart_beat_worker, args=(self,))
            self.heart_beat_thread.start()

    def register_to_controller(self):
        """Announce this worker (address + status) to the controller."""
        logger.info("Register to controller")

        url = self.controller_addr + "/register_worker"
        data = {
            "worker_name": self.worker_addr,
            "check_heart_beat": True,
            "worker_status": self.get_status()
        }
        r = requests.post(url, json=data)
        assert r.status_code == 200

    def send_heart_beat(self):
        """Report queue length to the controller; re-register if unknown.

        Retries forever with a 5 s backoff on network errors.
        """
        logger.info(f"Send heart beat. Models: {[self.model_name]}. "
                    f"Semaphore: {pretty_print_semaphore(model_semaphore)}. "
                    f"global_counter: {global_counter}")

        url = self.controller_addr + "/receive_heart_beat"

        while True:
            try:
                ret = requests.post(url, json={
                    "worker_name": self.worker_addr,
                    "queue_length": self.get_queue_length()}, timeout=5)
                exist = ret.json()["exist"]
                break
            except requests.exceptions.RequestException as e:
                logger.error(f"heart beat error: {e}")
            time.sleep(5)

        # Controller lost track of us (e.g. after a restart): register again.
        if not exist:
            self.register_to_controller()

    def get_queue_length(self):
        """Number of requests running or waiting on the model semaphore.

        NOTE(review): reads asyncio.Semaphore private attributes (_value,
        _waiters); these are implementation details and may change across
        Python versions.
        """
        if model_semaphore is None:
            return 0
        else:
            return args.limit_model_concurrency - model_semaphore._value + (len(
                model_semaphore._waiters) if model_semaphore._waiters is not None else 0)

    def get_status(self):
        """Status dict in the shape the controller expects."""
        return {
            "model_names": [self.model_name],
            "speed": 1,
            "queue_length": self.get_queue_length(),
        }

    @torch.inference_mode()
    def generate_stream(self, params):
        """Run generation and yield NUL-terminated JSON chunks.

        `params` is the request payload: "prompt", optional "images"
        (base64-encoded), and sampling knobs ("num_beams", "temperature",
        "len_penalty", "top_p", "max_new_tokens"). Generation runs on a
        background thread; tokens are relayed here via TextIteratorStreamer.
        """
        tokenizer, model, image_processor, task = self.tokenizer, self.model, self.image_processor, self.task

        num_beams = int(params.get("num_beams", 1))
        temperature = float(params.get("temperature", 1.0))
        len_penalty = float(params.get("len_penalty", 1.0))
        top_p = float(params.get("top_p", 1.0))
        max_context_length = getattr(model.config, 'max_position_embeddings', 8192)
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=False, skip_special_tokens=True, timeout=15)
        prompt = params["prompt"]

        if task == "Image2SVG":
            # NOTE(review): iterating params["images"] raises TypeError if the
            # key is absent (default None), and `image` is unbound if the list
            # is empty — callers apparently always send exactly one image;
            # confirm.
            images = params.get("images", None)
            for b64_image in images:
                if b64_image is not None and self.is_multimodal:
                    image = load_image_from_base64(b64_image)
                    image = process_images(image, image_processor)
                    image = image.to(self.model.device, dtype=torch.float16)
                else:
                    image = None

            max_new_tokens = min(int(params.get("max_new_tokens", 256)), 8192)
            # Reserve room for the CLIP image-query tokens in the context.
            max_new_tokens = min(max_new_tokens, max_context_length - CLIP_QUERY_LENGTH)
            # The prompt is echoed back as the prefix of the streamed SVG.
            pre_pend = prompt
            batch = {}
            batch["image"] = image
            generate_method = model.model.generate_im2svg
        else:
            max_new_tokens = min(int(params.get("max_new_tokens", 128)), 8192)
            pre_pend = ""
            batch = {}
            batch['caption'] = [prompt]
            # White PIL image
            batch['image'] = torch.zeros((3, 256, 256), dtype=torch.float16).to(self.model.device)
            generate_method = model.model.generate_text2svg

        if max_new_tokens < 1:
            yield json.dumps({"text": prompt + "Exceeds max token length. Please start a new conversation, thanks.", "error_code": 0}).encode() + b"\0"
            return

        # Run generation on a worker thread so this generator can stream
        # partial output as the streamer produces it.
        thread = Thread(target=generate_method, kwargs=dict(
            batch=batch,
            prompt=prompt,
            use_nucleus_sampling=True,
            num_beams=num_beams,
            temperature=temperature,
            length_penalty=len_penalty,
            top_p=top_p,
            max_length=max_new_tokens,
            streamer=streamer,
        ))
        thread.start()

        # Accumulate and re-send the full text on every chunk; the client
        # renders the latest message wholesale.
        generated_text = pre_pend
        for new_text in streamer:
            if new_text == " ":
                continue
            generated_text += new_text
            # if generated_text.endswith(stop_str):
            #     generated_text = generated_text[:-len(stop_str)]
            yield json.dumps({"text": generated_text, "error_code": 0}).encode() + b"\0"

    def generate_stream_gate(self, params):
        """Wrap generate_stream, converting exceptions into error chunks.

        Any failure yields a single JSON chunk with error_code 1 and the
        shared server_error_msg instead of propagating to the HTTP layer.
        """
        try:
            for x in self.generate_stream(params):
                yield x
        except ValueError as e:
            print("Caught ValueError:", e)
            ret = {
                "text": server_error_msg,
                "error_code": 1,
            }
            yield json.dumps(ret).encode() + b"\0"
        except torch.cuda.CudaError as e:
            print("Caught torch.cuda.CudaError:", e)
            ret = {
                "text": server_error_msg,
                "error_code": 1,
            }
            yield json.dumps(ret).encode() + b"\0"
        except Exception as e:
            print("Caught Unknown Error", e)
            ret = {
                "text": server_error_msg,
                "error_code": 1,
            }
            yield json.dumps(ret).encode() + b"\0"
|
208 |
+
|
209 |
+
# FastAPI application exposing this worker's HTTP endpoints.
app = FastAPI()
|
210 |
+
|
211 |
+
def release_model_semaphore(fn=None):
    """Free one slot of the global model semaphore.

    Args:
        fn: optional zero-argument callback run right after the release
            (used to push a fresh heartbeat once a request slot opens up).
    """
    model_semaphore.release()
    if fn:
        fn()
|
215 |
+
|
216 |
+
@app.post("/worker_generate_stream")
async def generate_stream(request: Request):
    """Stream SVG generation for one request.

    Acquires the global model semaphore (created lazily with the CLI
    concurrency limit), then streams NUL-delimited JSON chunks from the
    worker. The semaphore is released — and a heartbeat sent — by a
    background task once the response finishes.
    """
    global model_semaphore, global_counter
    global_counter += 1
    params = await request.json()

    # Lazy init: the semaphore must be created on the running event loop.
    if model_semaphore is None:
        model_semaphore = asyncio.Semaphore(args.limit_model_concurrency)
    await model_semaphore.acquire()
    worker.send_heart_beat()
    generator = worker.generate_stream_gate(params)
    background_tasks = BackgroundTasks()
    # Runs after the stream completes (or the client disconnects).
    background_tasks.add_task(partial(release_model_semaphore, fn=worker.send_heart_beat))
    return StreamingResponse(generator, background=background_tasks)
|
230 |
+
|
231 |
+
@app.post("/worker_get_status")
async def get_status(request: Request):
    """Return this worker's model names, speed, and queue length."""
    return worker.get_status()
|
234 |
+
|
235 |
+
if __name__ == "__main__":
    # CLI entry point: parse options, load the model into a ModelWorker
    # (registering with the controller unless --no-register), then serve the
    # FastAPI app with uvicorn.
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=21002)
    parser.add_argument("--worker-address", type=str,
                        default="http://localhost:21002")
    parser.add_argument("--controller-address", type=str,
                        default="http://localhost:21001")
    parser.add_argument("--model-path", type=str, default="joanrodai/starvector-1.4b")
    parser.add_argument("--model-base", type=str, default=None)
    parser.add_argument("--model-name", type=str)
    parser.add_argument("--device", type=str, default="cuda")
    # Kept only for backward compatibility; multimodality is inferred from the
    # model name (see ModelWorker.is_multimodal).
    parser.add_argument("--multi-modal", action="store_true", help="Multimodal mode is automatically detected with model name, please make sure `starvector` is included in the model path.")
    parser.add_argument("--limit-model-concurrency", type=int, default=5)
    parser.add_argument("--stream-interval", type=int, default=1)
    parser.add_argument("--no-register", action="store_true")
    parser.add_argument("--load-8bit", action="store_true")
    parser.add_argument("--load-4bit", action="store_true")
    args = parser.parse_args()
    logger.info(f"args: {args}")

    if args.multi_modal:
        logger.warning("Multimodal mode is automatically detected with model name, please make sure `starvector` is included in the model path.")

    worker = ModelWorker(args.controller_address,
                         args.worker_address,
                         worker_id,
                         args.no_register,
                         args.model_path,
                         args.model_base,
                         args.model_name,
                         args.load_8bit,
                         args.load_4bit,
                         args.device)
    uvicorn.run(app, host=args.host, port=args.port, log_level="info")
|
starvector/serve/vllm_api_gradio/__pycache__/controller.cpython-311.pyc
CHANGED
Binary files a/starvector/serve/vllm_api_gradio/__pycache__/controller.cpython-311.pyc and b/starvector/serve/vllm_api_gradio/__pycache__/controller.cpython-311.pyc differ
|
|
starvector/serve/vllm_api_gradio/__pycache__/gradio_web_server.cpython-311.pyc
CHANGED
Binary files a/starvector/serve/vllm_api_gradio/__pycache__/gradio_web_server.cpython-311.pyc and b/starvector/serve/vllm_api_gradio/__pycache__/gradio_web_server.cpython-311.pyc differ
|
|
starvector/serve/vllm_api_gradio/__pycache__/model_worker.cpython-311.pyc
CHANGED
Binary files a/starvector/serve/vllm_api_gradio/__pycache__/model_worker.cpython-311.pyc and b/starvector/serve/vllm_api_gradio/__pycache__/model_worker.cpython-311.pyc differ
|
|
starvector/serve/vllm_api_gradio/gradio_web_server.py
CHANGED
@@ -204,7 +204,6 @@ def http_bot(state, task_selector, text_caption, model_selector, num_beams, temp
|
|
204 |
|
205 |
state.messages[-1][-1] = "β"
|
206 |
yield (state, state.messages[-1][-1], state.to_gradio_svg_render()) + (disable_btn, disable_btn, disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
|
207 |
-
|
208 |
try:
|
209 |
# Stream output
|
210 |
if state.stop_sampling:
|
@@ -214,23 +213,33 @@ def http_bot(state, task_selector, text_caption, model_selector, num_beams, temp
|
|
214 |
|
215 |
response = requests.post(worker_addr + "/worker_generate_stream",
|
216 |
headers=headers, json=pload, stream=True, timeout=10)
|
|
|
|
|
|
|
|
|
217 |
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
|
218 |
if chunk:
|
219 |
data = json.loads(chunk.decode())
|
220 |
if data["error_code"] == 0:
|
221 |
-
# output = data["text"].strip().replace('<', '<').replace('>', '>') # trick to avoid the SVG getting rendered
|
222 |
output = data["text"].strip()
|
223 |
state.messages[-1][-1] = output + "β"
|
224 |
-
|
225 |
-
#
|
226 |
-
|
|
|
|
|
|
|
|
|
227 |
else:
|
|
|
228 |
output = data["text"] + f" (error_code: {data['error_code']})"
|
229 |
state.messages[-1][-1] = output
|
230 |
st = state.to_gradio_svg_code()
|
231 |
-
|
232 |
-
yield (state, st[-1][1], state.to_gradio_svg_render()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn, disable_btn, disable_btn)
|
233 |
return
|
|
|
|
|
|
|
234 |
except requests.exceptions.RequestException as e:
|
235 |
state.messages[-1][-1] = server_error_msg
|
236 |
yield (state, None, None) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn, disable_btn, disable_btn)
|
@@ -576,7 +585,7 @@ def build_demo(embed_mode):
|
|
576 |
temperature = gr.Slider(minimum=0.0, maximum=2.0, value=0.2, step=0.05, interactive=True, label="Temperature",)
|
577 |
len_penalty = gr.Slider(minimum=0.0, maximum=2.0, value=1.0, step=0.05, interactive=True, label="Length Penalty",)
|
578 |
top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.95, step=0.05, interactive=True, label="Top P",)
|
579 |
-
max_output_tokens = gr.Slider(minimum=0, maximum=
|
580 |
|
581 |
with gr.Column(scale=9):
|
582 |
with gr.Row():
|
|
|
204 |
|
205 |
state.messages[-1][-1] = "β"
|
206 |
yield (state, state.messages[-1][-1], state.to_gradio_svg_render()) + (disable_btn, disable_btn, disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
|
|
|
207 |
try:
|
208 |
# Stream output
|
209 |
if state.stop_sampling:
|
|
|
213 |
|
214 |
response = requests.post(worker_addr + "/worker_generate_stream",
|
215 |
headers=headers, json=pload, stream=True, timeout=10)
|
216 |
+
|
217 |
+
update_interval = 2 # seconds
|
218 |
+
last_update_time = time.time()
|
219 |
+
|
220 |
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
|
221 |
if chunk:
|
222 |
data = json.loads(chunk.decode())
|
223 |
if data["error_code"] == 0:
|
|
|
224 |
output = data["text"].strip()
|
225 |
state.messages[-1][-1] = output + "β"
|
226 |
+
|
227 |
+
# Only update if sufficient time has passed
|
228 |
+
current_time = time.time()
|
229 |
+
if current_time - last_update_time >= update_interval:
|
230 |
+
st = state.to_gradio_svg_code()
|
231 |
+
yield (state, st[-1][1], state.to_gradio_svg_render()) + (disable_btn,) * 7
|
232 |
+
last_update_time = current_time
|
233 |
else:
|
234 |
+
# handle errors and yield immediately if needed
|
235 |
output = data["text"] + f" (error_code: {data['error_code']})"
|
236 |
state.messages[-1][-1] = output
|
237 |
st = state.to_gradio_svg_code()
|
238 |
+
yield (state, st[-1][1], state.to_gradio_svg_render()) + (disable_btn,) * 7
|
|
|
239 |
return
|
240 |
+
# Final yield to ensure the last state is rendered
|
241 |
+
st = state.to_gradio_svg_code()
|
242 |
+
yield (state, st[-1][1], state.to_gradio_svg_render()) + (enable_btn,) * 7
|
243 |
except requests.exceptions.RequestException as e:
|
244 |
state.messages[-1][-1] = server_error_msg
|
245 |
yield (state, None, None) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn, disable_btn, disable_btn)
|
|
|
585 |
temperature = gr.Slider(minimum=0.0, maximum=2.0, value=0.2, step=0.05, interactive=True, label="Temperature",)
|
586 |
len_penalty = gr.Slider(minimum=0.0, maximum=2.0, value=1.0, step=0.05, interactive=True, label="Length Penalty",)
|
587 |
top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.95, step=0.05, interactive=True, label="Top P",)
|
588 |
+
max_output_tokens = gr.Slider(minimum=0, maximum=8192, value=2048, step=64, interactive=True, label="Max output tokens",)
|
589 |
|
590 |
with gr.Column(scale=9):
|
591 |
with gr.Row():
|
starvector/serve/vllm_api_gradio/model_worker.py
CHANGED
@@ -117,7 +117,7 @@ class ModelWorker:
|
|
117 |
temperature = float(params.get("temperature", 1.0))
|
118 |
len_penalty = float(params.get("len_penalty", 1.0))
|
119 |
top_p = float(params.get("top_p", 1.0))
|
120 |
-
max_context_length =
|
121 |
|
122 |
# prompt = params["prompt"]
|
123 |
prompt = "<svg "
|
@@ -132,6 +132,8 @@ class ModelWorker:
|
|
132 |
|
133 |
max_new_tokens = min(int(params.get("max_new_tokens", 256)), 8192)
|
134 |
max_new_tokens = min(max_new_tokens, max_context_length - CLIP_QUERY_LENGTH)
|
|
|
|
|
135 |
|
136 |
# Use the chat completions endpoint
|
137 |
vllm_endpoint = f"{self.vllm_base_url}/v1/chat/completions"
|
|
|
117 |
temperature = float(params.get("temperature", 1.0))
|
118 |
len_penalty = float(params.get("len_penalty", 1.0))
|
119 |
top_p = float(params.get("top_p", 1.0))
|
120 |
+
max_context_length = 8192
|
121 |
|
122 |
# prompt = params["prompt"]
|
123 |
prompt = "<svg "
|
|
|
132 |
|
133 |
max_new_tokens = min(int(params.get("max_new_tokens", 256)), 8192)
|
134 |
max_new_tokens = min(max_new_tokens, max_context_length - CLIP_QUERY_LENGTH)
|
135 |
+
# log max new token
|
136 |
+
logger.info(f"max_new_tokens: {max_new_tokens}")
|
137 |
|
138 |
# Use the chat completions endpoint
|
139 |
vllm_endpoint = f"{self.vllm_base_url}/v1/chat/completions"
|