hz2475 commited on
Commit
8d3de58
Β·
1 Parent(s): 105ac3b
.DS_Store CHANGED
Binary files a/.DS_Store and b/.DS_Store differ
 
configs/.DS_Store CHANGED
Binary files a/configs/.DS_Store and b/configs/.DS_Store differ
 
controller.log CHANGED
@@ -29,3 +29,11 @@
29
  2025-03-23 15:04:32 | ERROR | stderr | INFO: Waiting for application startup.
30
  2025-03-23 15:04:32 | ERROR | stderr | INFO: Application startup complete.
31
  2025-03-23 15:04:32 | ERROR | stderr | INFO: Uvicorn running on http://0.0.0.0:10000 (Press CTRL+C to quit)
 
 
 
 
 
 
 
 
 
29
  2025-03-23 15:04:32 | ERROR | stderr | INFO: Waiting for application startup.
30
  2025-03-23 15:04:32 | ERROR | stderr | INFO: Application startup complete.
31
  2025-03-23 15:04:32 | ERROR | stderr | INFO: Uvicorn running on http://0.0.0.0:10000 (Press CTRL+C to quit)
32
+ 2025-03-24 14:06:11 | INFO | controller | args: Namespace(host='0.0.0.0', port=10000, dispatch_method='shortest_queue')
33
+ 2025-03-24 14:06:11 | INFO | controller | Init controller
34
+ 2025-03-24 14:06:11 | ERROR | stderr | INFO: Started server process [95537]
35
+ 2025-03-24 14:06:11 | ERROR | stderr | INFO: Waiting for application startup.
36
+ 2025-03-24 14:06:11 | ERROR | stderr | INFO: Application startup complete.
37
+ 2025-03-24 14:06:11 | ERROR | stderr | ERROR: [Errno 48] error while attempting to bind on address ('0.0.0.0', 10000): address already in use
38
+ 2025-03-24 14:06:11 | ERROR | stderr | INFO: Waiting for application shutdown.
39
+ 2025-03-24 14:06:11 | ERROR | stderr | INFO: Application shutdown complete.
model_worker_ad9563.log DELETED
@@ -1,17 +0,0 @@
1
- 2025-03-23 15:01:04 | INFO | model_worker | args: Namespace(host='0.0.0.0', port=40000, worker_address='http://localhost:40000', controller_address='http://localhost:10000', model_name='/home/agent_h/data/starvector-1b-im2svg', multi_modal=False, limit_model_concurrency=5, stream_interval=1, no_register=False, openai_api_key='EMPTY', vllm_base_url='http://localhost:8000')
2
- 2025-03-23 15:01:04 | INFO | model_worker | Loading the model /home/agent_h/data/starvector-1b-im2svg on worker ad9563 ...
3
- 2025-03-23 15:01:04 | INFO | model_worker | Register to controller
4
- 2025-03-23 15:01:04 | ERROR | stderr | INFO: Started server process [48407]
5
- 2025-03-23 15:01:04 | ERROR | stderr | INFO: Waiting for application startup.
6
- 2025-03-23 15:01:04 | ERROR | stderr | INFO: Application startup complete.
7
- 2025-03-23 15:01:04 | ERROR | stderr | INFO: Uvicorn running on http://0.0.0.0:40000 (Press CTRL+C to quit)
8
- 2025-03-23 15:01:19 | INFO | model_worker | Send heart beat. Models: ['/home/agent_h/data/starvector-1b-im2svg']. Semaphore: None. global_counter: 0
9
- 2025-03-23 15:01:34 | INFO | model_worker | Send heart beat. Models: ['/home/agent_h/data/starvector-1b-im2svg']. Semaphore: None. global_counter: 0
10
- 2025-03-23 15:01:49 | INFO | model_worker | Send heart beat. Models: ['/home/agent_h/data/starvector-1b-im2svg']. Semaphore: None. global_counter: 0
11
- 2025-03-23 15:02:04 | INFO | model_worker | Send heart beat. Models: ['/home/agent_h/data/starvector-1b-im2svg']. Semaphore: None. global_counter: 0
12
- 2025-03-23 15:02:19 | INFO | model_worker | Send heart beat. Models: ['/home/agent_h/data/starvector-1b-im2svg']. Semaphore: None. global_counter: 0
13
- 2025-03-23 15:02:34 | INFO | model_worker | Send heart beat. Models: ['/home/agent_h/data/starvector-1b-im2svg']. Semaphore: None. global_counter: 0
14
- 2025-03-23 15:02:45 | ERROR | stderr | INFO: Shutting down
15
- 2025-03-23 15:02:45 | ERROR | stderr | INFO: Waiting for application shutdown.
16
- 2025-03-23 15:02:45 | ERROR | stderr | INFO: Application shutdown complete.
17
- 2025-03-23 15:02:45 | ERROR | stderr | INFO: Finished server process [48407]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
star-vector-dev/.DS_Store ADDED
Binary file (6.15 kB). View file
 
star-vector-dev/.gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
star-vector-dev/.gitignore ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # poetry
98
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
+ #poetry.lock
103
+
104
+ # pdm
105
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
+ #pdm.lock
107
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
+ # in version control.
109
+ # https://pdm.fming.dev/#use-with-ide
110
+ .pdm.toml
111
+
112
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113
+ __pypackages__/
114
+
115
+ # Celery stuff
116
+ celerybeat-schedule
117
+ celerybeat.pid
118
+
119
+ # SageMath parsed files
120
+ *.sage.py
121
+
122
+ # Environments
123
+ .env
124
+ .venv
125
+ env/
126
+ venv/
127
+ ENV/
128
+ env.bak/
129
+ venv.bak/
130
+
131
+ # Spyder project settings
132
+ .spyderproject
133
+ .spyproject
134
+
135
+ # Rope project settings
136
+ .ropeproject
137
+
138
+ # mkdocs documentation
139
+ /site
140
+
141
+ # mypy
142
+ .mypy_cache/
143
+ .dmypy.json
144
+ dmypy.json
145
+
146
+ # Pyre type checker
147
+ .pyre/
148
+
149
+ # pytype static type analyzer
150
+ .pytype/
151
+
152
+ # Cython debug symbols
153
+ cython_debug/
154
+
155
+ # PyCharm
156
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
159
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160
+ #.idea/
161
+
162
+ # Other
163
+ *vscode*
164
+ *egg*
165
+ *nfs*
166
+ *conv.json*
167
+ *rebuttal*
168
+ *.log*
169
+ *remove_files*
170
+ *wandb*
171
+ *tmp*
172
+ *vscode*
173
+ *.csv
174
+ *avoid_samples*
175
+ *logs*
176
+ *results*
177
+ *.pickle
178
+ *.pkl
179
+ *internal*
180
+ *test.png*
181
+ assets/reward_assets
start.sh CHANGED
@@ -2,7 +2,6 @@
2
 
3
  bash -c "$SSH_TUNNEL_CMD_1" &
4
 
5
- echo "SSH tunnel started, PID: $SSH_PID"
6
  python -m starvector.serve.vllm_api_gradio.controller --host 0.0.0.0 --port 10000 &
7
  python -m starvector.serve.vllm_api_gradio.model_worker --host 0.0.0.0 --controller http://localhost:10000 --port 40000 --worker http://localhost:40000 --model-name /home/agent_h/data/starvector-1b-im2svg --vllm-base-url http://localhost:8000 &
8
  python -m starvector.serve.vllm_api_gradio.gradio_web_server --controller http://localhost:10000 --model-list-mode reload --port 7860
 
2
 
3
  bash -c "$SSH_TUNNEL_CMD_1" &
4
 
 
5
  python -m starvector.serve.vllm_api_gradio.controller --host 0.0.0.0 --port 10000 &
6
  python -m starvector.serve.vllm_api_gradio.model_worker --host 0.0.0.0 --controller http://localhost:10000 --port 40000 --worker http://localhost:40000 --model-name /home/agent_h/data/starvector-1b-im2svg --vllm-base-url http://localhost:8000 &
7
  python -m starvector.serve.vllm_api_gradio.gradio_web_server --controller http://localhost:10000 --model-list-mode reload --port 7860
starvector/.DS_Store CHANGED
Binary files a/starvector/.DS_Store and b/starvector/.DS_Store differ
 
starvector/__pycache__/__init__.cpython-311.pyc CHANGED
Binary files a/starvector/__pycache__/__init__.cpython-311.pyc and b/starvector/__pycache__/__init__.cpython-311.pyc differ
 
starvector/serve/.DS_Store CHANGED
Binary files a/starvector/serve/.DS_Store and b/starvector/serve/.DS_Store differ
 
starvector/serve/__pycache__/__init__.cpython-311.pyc CHANGED
Binary files a/starvector/serve/__pycache__/__init__.cpython-311.pyc and b/starvector/serve/__pycache__/__init__.cpython-311.pyc differ
 
starvector/serve/__pycache__/constants.cpython-311.pyc CHANGED
Binary files a/starvector/serve/__pycache__/constants.cpython-311.pyc and b/starvector/serve/__pycache__/constants.cpython-311.pyc differ
 
starvector/serve/__pycache__/conversation.cpython-311.pyc CHANGED
Binary files a/starvector/serve/__pycache__/conversation.cpython-311.pyc and b/starvector/serve/__pycache__/conversation.cpython-311.pyc differ
 
starvector/serve/__pycache__/util.cpython-311.pyc CHANGED
Binary files a/starvector/serve/__pycache__/util.cpython-311.pyc and b/starvector/serve/__pycache__/util.cpython-311.pyc differ
 
starvector/serve/controller.py ADDED
@@ -0,0 +1,293 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ A controller manages distributed workers.
3
+ It sends worker addresses to clients.
4
+ """
5
+ import argparse
6
+ import asyncio
7
+ import dataclasses
8
+ from enum import Enum, auto
9
+ import json
10
+ import logging
11
+ import time
12
+ from typing import List, Union
13
+ import threading
14
+
15
+ from fastapi import FastAPI, Request
16
+ from fastapi.responses import StreamingResponse
17
+ import numpy as np
18
+ import requests
19
+ import uvicorn
20
+
21
+ from starvector.serve.constants import CONTROLLER_HEART_BEAT_EXPIRATION
22
+ from starvector.serve.util import build_logger, server_error_msg
23
+
24
+ logger = build_logger("controller", "controller.log")
25
+
26
+ class DispatchMethod(Enum):
27
+ LOTTERY = auto()
28
+ SHORTEST_QUEUE = auto()
29
+
30
+ @classmethod
31
+ def from_str(cls, name):
32
+ if name == "lottery":
33
+ return cls.LOTTERY
34
+ elif name == "shortest_queue":
35
+ return cls.SHORTEST_QUEUE
36
+ else:
37
+ raise ValueError(f"Invalid dispatch method")
38
+
39
+
40
+ @dataclasses.dataclass
41
+ class WorkerInfo:
42
+ model_names: List[str]
43
+ speed: int
44
+ queue_length: int
45
+ check_heart_beat: bool
46
+ last_heart_beat: str
47
+
48
+
49
+ def heart_beat_controller(controller):
50
+ while True:
51
+ time.sleep(CONTROLLER_HEART_BEAT_EXPIRATION)
52
+ controller.remove_stable_workers_by_expiration()
53
+
54
+
55
+ class Controller:
56
+ def __init__(self, dispatch_method: str):
57
+ # Dict[str -> WorkerInfo]
58
+ self.worker_info = {}
59
+ self.dispatch_method = DispatchMethod.from_str(dispatch_method)
60
+
61
+ self.heart_beat_thread = threading.Thread(
62
+ target=heart_beat_controller, args=(self,))
63
+ self.heart_beat_thread.start()
64
+
65
+ logger.info("Init controller")
66
+
67
+ def register_worker(self, worker_name: str, check_heart_beat: bool,
68
+ worker_status: dict):
69
+ if worker_name not in self.worker_info:
70
+ logger.info(f"Register a new worker: {worker_name}")
71
+ else:
72
+ logger.info(f"Register an existing worker: {worker_name}")
73
+
74
+ if not worker_status:
75
+ worker_status = self.get_worker_status(worker_name)
76
+ if not worker_status:
77
+ return False
78
+
79
+ self.worker_info[worker_name] = WorkerInfo(
80
+ worker_status["model_names"], worker_status["speed"], worker_status["queue_length"],
81
+ check_heart_beat, time.time())
82
+
83
+ logger.info(f"Register done: {worker_name}, {worker_status}")
84
+ return True
85
+
86
+ def get_worker_status(self, worker_name: str):
87
+ try:
88
+ r = requests.post(worker_name + "/worker_get_status", timeout=5)
89
+ except requests.exceptions.RequestException as e:
90
+ logger.error(f"Get status fails: {worker_name}, {e}")
91
+ return None
92
+
93
+ if r.status_code != 200:
94
+ logger.error(f"Get status fails: {worker_name}, {r}")
95
+ return None
96
+
97
+ return r.json()
98
+
99
+ def remove_worker(self, worker_name: str):
100
+ del self.worker_info[worker_name]
101
+
102
+ def refresh_all_workers(self):
103
+ old_info = dict(self.worker_info)
104
+ self.worker_info = {}
105
+
106
+ for w_name, w_info in old_info.items():
107
+ if not self.register_worker(w_name, w_info.check_heart_beat, None):
108
+ logger.info(f"Remove stale worker: {w_name}")
109
+
110
+ def list_models(self):
111
+ model_names = set()
112
+
113
+ for w_name, w_info in self.worker_info.items():
114
+ model_names.update(w_info.model_names)
115
+
116
+ return list(model_names)
117
+
118
+ def get_worker_address(self, model_name: str):
119
+ if self.dispatch_method == DispatchMethod.LOTTERY:
120
+ worker_names = []
121
+ worker_speeds = []
122
+ for w_name, w_info in self.worker_info.items():
123
+ if model_name in w_info.model_names:
124
+ worker_names.append(w_name)
125
+ worker_speeds.append(w_info.speed)
126
+ worker_speeds = np.array(worker_speeds, dtype=np.float32)
127
+ norm = np.sum(worker_speeds)
128
+ if norm < 1e-4:
129
+ return ""
130
+ worker_speeds = worker_speeds / norm
131
+ if True: # Directly return address
132
+ pt = np.random.choice(np.arange(len(worker_names)),
133
+ p=worker_speeds)
134
+ worker_name = worker_names[pt]
135
+ return worker_name
136
+
137
+ # Check status before returning
138
+ while True:
139
+ pt = np.random.choice(np.arange(len(worker_names)),
140
+ p=worker_speeds)
141
+ worker_name = worker_names[pt]
142
+
143
+ if self.get_worker_status(worker_name):
144
+ break
145
+ else:
146
+ self.remove_worker(worker_name)
147
+ worker_speeds[pt] = 0
148
+ norm = np.sum(worker_speeds)
149
+ if norm < 1e-4:
150
+ return ""
151
+ worker_speeds = worker_speeds / norm
152
+ continue
153
+ return worker_name
154
+ elif self.dispatch_method == DispatchMethod.SHORTEST_QUEUE:
155
+ worker_names = []
156
+ worker_qlen = []
157
+ for w_name, w_info in self.worker_info.items():
158
+ if model_name in w_info.model_names:
159
+ worker_names.append(w_name)
160
+ worker_qlen.append(w_info.queue_length / w_info.speed)
161
+ if len(worker_names) == 0:
162
+ return ""
163
+ min_index = np.argmin(worker_qlen)
164
+ w_name = worker_names[min_index]
165
+ self.worker_info[w_name].queue_length += 1
166
+ logger.info(f"names: {worker_names}, queue_lens: {worker_qlen}, ret: {w_name}")
167
+ return w_name
168
+ else:
169
+ raise ValueError(f"Invalid dispatch method: {self.dispatch_method}")
170
+
171
+ def receive_heart_beat(self, worker_name: str, queue_length: int):
172
+ if worker_name not in self.worker_info:
173
+ logger.info(f"Receive unknown heart beat. {worker_name}")
174
+ return False
175
+
176
+ self.worker_info[worker_name].queue_length = queue_length
177
+ self.worker_info[worker_name].last_heart_beat = time.time()
178
+ logger.info(f"Receive heart beat. {worker_name}")
179
+ return True
180
+
181
+ def remove_stable_workers_by_expiration(self):
182
+ expire = time.time() - CONTROLLER_HEART_BEAT_EXPIRATION
183
+ to_delete = []
184
+ for worker_name, w_info in self.worker_info.items():
185
+ if w_info.check_heart_beat and w_info.last_heart_beat < expire:
186
+ to_delete.append(worker_name)
187
+
188
+ for worker_name in to_delete:
189
+ self.remove_worker(worker_name)
190
+
191
+ def worker_api_generate_stream(self, params):
192
+ worker_addr = self.get_worker_address(params["model"])
193
+ if not worker_addr:
194
+ logger.info(f"no worker: {params['model']}")
195
+ ret = {
196
+ "text": server_error_msg,
197
+ "error_code": 2,
198
+ }
199
+ yield json.dumps(ret).encode() + b"\0"
200
+
201
+ try:
202
+ response = requests.post(worker_addr + "/worker_generate_stream",
203
+ json=params, stream=True, timeout=5)
204
+ for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
205
+ if chunk:
206
+ yield chunk + b"\0"
207
+ except requests.exceptions.RequestException as e:
208
+ logger.info(f"worker timeout: {worker_addr}")
209
+ ret = {
210
+ "text": server_error_msg,
211
+ "error_code": 3,
212
+ }
213
+ yield json.dumps(ret).encode() + b"\0"
214
+
215
+
216
+ # Let the controller act as a worker to achieve hierarchical
217
+ # management. This can be used to connect isolated sub networks.
218
+ def worker_api_get_status(self):
219
+ model_names = set()
220
+ speed = 0
221
+ queue_length = 0
222
+
223
+ for w_name in self.worker_info:
224
+ worker_status = self.get_worker_status(w_name)
225
+ if worker_status is not None:
226
+ model_names.update(worker_status["model_names"])
227
+ speed += worker_status["speed"]
228
+ queue_length += worker_status["queue_length"]
229
+
230
+ return {
231
+ "model_names": list(model_names),
232
+ "speed": speed,
233
+ "queue_length": queue_length,
234
+ }
235
+
236
+
237
+ app = FastAPI()
238
+
239
+ @app.post("/register_worker")
240
+ async def register_worker(request: Request):
241
+ data = await request.json()
242
+ controller.register_worker(
243
+ data["worker_name"], data["check_heart_beat"],
244
+ data.get("worker_status", None))
245
+
246
+ @app.post("/refresh_all_workers")
247
+ async def refresh_all_workers():
248
+ models = controller.refresh_all_workers()
249
+
250
+
251
+ @app.post("/list_models")
252
+ async def list_models():
253
+ models = controller.list_models()
254
+ return {"models": models}
255
+
256
+
257
+ @app.post("/get_worker_address")
258
+ async def get_worker_address(request: Request):
259
+ data = await request.json()
260
+ addr = controller.get_worker_address(data["model"])
261
+ return {"address": addr}
262
+
263
+ @app.post("/receive_heart_beat")
264
+ async def receive_heart_beat(request: Request):
265
+ data = await request.json()
266
+ exist = controller.receive_heart_beat(
267
+ data["worker_name"], data["queue_length"])
268
+ return {"exist": exist}
269
+
270
+
271
+ @app.post("/worker_generate_stream")
272
+ async def worker_api_generate_stream(request: Request):
273
+ params = await request.json()
274
+ generator = controller.worker_api_generate_stream(params)
275
+ return StreamingResponse(generator)
276
+
277
+
278
+ @app.post("/worker_get_status")
279
+ async def worker_api_get_status(request: Request):
280
+ return controller.worker_api_get_status()
281
+
282
+
283
+ if __name__ == "__main__":
284
+ parser = argparse.ArgumentParser()
285
+ parser.add_argument("--host", type=str, default="localhost")
286
+ parser.add_argument("--port", type=int, default=21001)
287
+ parser.add_argument("--dispatch-method", type=str, choices=[
288
+ "lottery", "shortest_queue"], default="shortest_queue")
289
+ args = parser.parse_args()
290
+ logger.info(f"args: {args}")
291
+
292
+ controller = Controller(args.dispatch_method)
293
+ uvicorn.run(app, host=args.host, port=args.port, log_level="info")
starvector/serve/gradio_demo_with_updated_gradio.py ADDED
@@ -0,0 +1,432 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import datetime
3
+ import json
4
+ import os
5
+ import time
6
+ import gradio as gr
7
+ import requests
8
+ from starvector.serve.conversation import default_conversation
9
+ from starvector.serve.constants import LOGDIR, CLIP_QUERY_LENGTH
10
+ from starvector.serve.util import (build_logger, server_error_msg)
11
+
12
+ logger = build_logger("gradio_web_server", "gradio_web_server.log")
13
+ headers = {"User-Agent": "StarVector Client"}
14
+
15
+ no_change_btn = gr.Button()
16
+ enable_btn = gr.Button(interactive=True)
17
+ disable_btn = gr.Button(interactive=False)
18
+
19
+ priority = {
20
+ "starvector-1.4b": "aaaaaaa",
21
+ }
22
+
23
+ def get_conv_log_filename():
24
+ t = datetime.datetime.now()
25
+ name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
26
+ return name
27
+
28
+ def get_model_list():
29
+ ret = requests.post(args.controller_url + "/refresh_all_workers")
30
+ assert ret.status_code == 200
31
+ ret = requests.post(args.controller_url + "/list_models")
32
+ models = ret.json()["models"]
33
+ models.sort(key=lambda x: priority.get(x, x))
34
+ logger.info(f"Models: {models}")
35
+ return models
36
+
37
+ get_window_url_params = """
38
+ function() {
39
+ const params = new URLSearchParams(window.location.search);
40
+ url_params = Object.fromEntries(params);
41
+ console.log(url_params);
42
+ return url_params;
43
+ }
44
+ """
45
+
46
+ def load_demo(url_params, request: gr.Request):
47
+ logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")
48
+
49
+ dropdown_update = gr.Dropdown(visible=True)
50
+ if "model" in url_params:
51
+ model = url_params["model"]
52
+ if model in models:
53
+ dropdown_update = gr.Dropdown(value=model, visible=True)
54
+
55
+ state = default_conversation.copy()
56
+ return state, dropdown_update
57
+
58
+
59
+ def load_demo_refresh_model_list(request: gr.Request):
60
+ logger.info(f"load_demo. ip: {request.client.host}")
61
+ models = get_model_list()
62
+ state = default_conversation.copy()
63
+ dropdown_update = gr.Dropdown(
64
+ choices=models,
65
+ value=models[0] if len(models) > 0 else ""
66
+ )
67
+ return state, dropdown_update
68
+
69
+ def vote_last_response(state, vote_type, model_selector, request: gr.Request):
70
+ with open(get_conv_log_filename(), "a") as fout:
71
+ data = {
72
+ "tstamp": round(time.time(), 4),
73
+ "type": vote_type,
74
+ "model": model_selector,
75
+ "state": state.dict(),
76
+ "ip": request.client.host,
77
+ }
78
+ fout.write(json.dumps(data) + "\n")
79
+
80
+ def upvote_last_response(state, model_selector, request: gr.Request):
81
+ logger.info(f"upvote. ip: {request.client.host}")
82
+ vote_last_response(state, "upvote", model_selector, request)
83
+ return ("",) + (disable_btn,) * 3
84
+
85
+ def downvote_last_response(state, model_selector, request: gr.Request):
86
+ logger.info(f"downvote. ip: {request.client.host}")
87
+ vote_last_response(state, "downvote", model_selector, request)
88
+ return ("",) + (disable_btn,) * 3
89
+
90
+ def flag_last_response(state, model_selector, request: gr.Request):
91
+ logger.info(f"flag. ip: {request.client.host}")
92
+ vote_last_response(state, "flag", model_selector, request)
93
+ return ("",) + (disable_btn,) * 3
94
+
95
+ def regenerate(state, image_process_mode, request: gr.Request):
96
+ logger.info(f"regenerate. ip: {request.client.host}")
97
+ state.messages[-1][-1] = None
98
+ prev_human_msg = state.messages[-2]
99
+ if type(prev_human_msg[1]) in (tuple, list):
100
+ prev_human_msg[1] = (prev_human_msg[1][:2], image_process_mode)
101
+ state.skip_next = False
102
+ return (state, None, None, None) + (disable_btn,) * 6
103
+
104
+ def clear_history(request: gr.Request):
105
+ logger.info(f"clear_history. ip: {request.client.host}")
106
+ state = default_conversation.copy()
107
+ return (state, None, None) + (disable_btn,) * 6
108
+
109
+ def send_image(state, image, image_process_mode, request: gr.Request):
110
+ logger.info(f"send_image. ip: {request.client.host}.")
111
+ state.stop_sampling = False
112
+ if image is None:
113
+ state.skip_next = True
114
+ return (state, None, None, image) + (no_change_btn,) * 6
115
+
116
+ if image is not None:
117
+ text = (image, image_process_mode)
118
+ state.append_message(state.roles[0], text)
119
+ state.append_message(state.roles[1], "β–Œ")
120
+ state.skip_next = False
121
+ msg = state.to_gradio_svg_code()[0][1]
122
+ return (state, msg, state.to_gradio_svg_render(), image) + (no_change_btn,) * 6
123
+
124
+ def stop_sampling(state, image, request: gr.Request):
125
+ logger.info(f"stop_sampling. ip: {request.client.host}")
126
+ state.stop_sampling = True
127
+ return (state, None, None, image) + (disable_btn,) * 6
128
+
129
+ def http_bot(state, model_selector, num_beams, temperature, len_penalty, top_p, max_new_tokens, request: gr.Request):
130
+ logger.info(f"http_bot. ip: {request.client.host}")
131
+ start_tstamp = time.time()
132
+ model_name = model_selector
133
+
134
+ if state.skip_next:
135
+ # This generate call is skipped due to invalid inputs
136
+ yield (state, None, None) + (no_change_btn,) * 6
137
+ return
138
+
139
+ # Query worker address
140
+ controller_url = args.controller_url
141
+ ret = requests.post(controller_url + "/get_worker_address",
142
+ json={"model": model_name})
143
+ worker_addr = ret.json()["address"]
144
+ logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}")
145
+
146
+ # No available worker
147
+ if worker_addr == "":
148
+ state.messages[-1][-1] = server_error_msg
149
+ yield (state, None, None, disable_btn, disable_btn, disable_btn, enable_btn, enable_btn, disable_btn)
150
+ return
151
+
152
+ # Construct prompt
153
+ prompt = state.get_prompt()
154
+
155
+ # Make requests
156
+ pload = {
157
+ "model": model_name,
158
+ "prompt": prompt,
159
+ "num_beams": int(num_beams),
160
+ "temperature": float(temperature),
161
+ "len_penalty": float(len_penalty),
162
+ "top_p": float(top_p),
163
+ "max_new_tokens": min(int(max_new_tokens), 8192-CLIP_QUERY_LENGTH),
164
+ }
165
+ logger.info(f"==== request ====\n{pload}")
166
+
167
+ pload['images'] = state.get_images()
168
+
169
+ state.messages[-1][-1] = "β–Œ"
170
+ yield (state, state.messages[-1][-1], state.to_gradio_svg_render()) + (disable_btn, disable_btn, disable_btn, disable_btn, disable_btn, enable_btn)
171
+
172
+ try:
173
+ # Stream output
174
+ if state.stop_sampling:
175
+ state.messages[1][-1] = "β–Œ"
176
+ yield (state, state.messages[-1][-1], state.to_gradio_svg_render()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn, disable_btn)
177
+ return
178
+
179
+ response = requests.post(worker_addr + "/worker_generate_stream",
180
+ headers=headers, json=pload, stream=True, timeout=100)
181
+ for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
182
+ if chunk:
183
+ data = json.loads(chunk.decode())
184
+ if data["error_code"] == 0:
185
+ # output = data["text"].strip().replace('<', '&lt;').replace('>', '&gt;') # trick to avoid the SVG getting rendered
186
+ output = data["text"].strip()
187
+ state.messages[-1][-1] = output + "β–Œ"
188
+ st = state.to_gradio_svg_code()
189
+ yield (state, st[-1][1], state.to_gradio_svg_render()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn, enable_btn)
190
+ else:
191
+ output = data["text"] + f" (error_code: {data['error_code']})"
192
+ state.messages[-1][-1] = output
193
+
194
+ yield (state, st[-1][1], state.to_gradio_svg_render()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn, disable_btn)
195
+ return
196
+ time.sleep(0.03)
197
+ except requests.exceptions.RequestException as e:
198
+ state.messages[-1][-1] = server_error_msg
199
+ yield (state, None, None) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn, disable_btn)
200
+ return
201
+
202
+ yield (state, state.messages[-1][-1], state.to_gradio_svg_render()) + (enable_btn,) * 6
203
+
204
+ finish_tstamp = time.time()
205
+ logger.info(f"{output}")
206
+
207
+ with open(get_conv_log_filename(), "a") as fout:
208
+ data = {
209
+ "tstamp": round(finish_tstamp, 4),
210
+ "type": "chat",
211
+ "model": model_name,
212
+ "start": round(start_tstamp, 4),
213
+ "finish": round(finish_tstamp, 4),
214
+ "svg": state.messages[-1][-1],
215
+ "ip": request.client.host,
216
+ }
217
+ fout.write(json.dumps(data) + "\n")
218
+
219
+ title_markdown = ("""
220
+ # πŸ’« StarVector: Generating Scalable Vector Graphics Code from Images and Text
221
+ [[Project Page](https://starvector.github.io)] [[Code](https://github.com/joanrod/star-vector)] [[Model](https://huggingface.co/joanrodai/starvector-1.4b)] | πŸ“š [[StarVector](https://arxiv.org/abs/2312.11556)]
222
+ """)
223
+
224
+ sub_title_markdown = (""" Throw an image and vectorize it! The model expects vector-like images to generate the corresponding svg code.""")
225
+ tos_markdown = ("""
226
+ ### Terms of use
227
+ By using this service, users are required to agree to the following terms:
228
+ The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research.
229
+ Please click the "Flag" button if you get any inappropriate answer! We will collect those to keep improving our moderator.
230
+ For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality.
231
+ """)
232
+
233
+
234
+ learn_more_markdown = ("""
235
+ ### License
236
+ The service is a research preview intended for non-commercial use only. Please contact us if you find any potential violation.
237
+ """)
238
+
239
+ block_css = """
240
+
241
+ #buttons button {
242
+ min-width: min(120px,100%);
243
+ }
244
+
245
+ .gradio-container{
246
+ max-width: 1200px!important
247
+ }
248
+
249
+ #svg_render{
250
+ padding: 20px !important;
251
+ }
252
+
253
+ #svg_code{
254
+ height: 200px !important;
255
+ overflow: scroll !important;
256
+ white-space: unset !important;
257
+ flex-shrink: unset !important;
258
+ }
259
+
260
+
261
+ h1{display: flex;align-items: center;justify-content: center;gap: .25em}
262
+ *{transition: width 0.5s ease, flex-grow 0.5s ease}
263
+ """
264
+
265
+ def build_demo(embed_mode, concurrency_count=10):
266
+ with gr.Blocks(title="StarVector", theme=gr.themes.Default(), css=block_css) as demo:
267
+ state = gr.State()
268
+ if not embed_mode:
269
+ gr.Markdown(title_markdown)
270
+ gr.Markdown(sub_title_markdown)
271
+ with gr.Row():
272
+ with gr.Column(scale=3):
273
+ with gr.Row(elem_id="model_selector_row"):
274
+ model_selector = gr.Dropdown(
275
+ choices=models,
276
+ value=models[0] if len(models) > 0 else "",
277
+ interactive=True,
278
+ show_label=False,
279
+ container=False)
280
+ imagebox = gr.Image(type="pil")
281
+ image_process_mode = gr.Radio(
282
+ ["Resize", "Pad", "Default"],
283
+ value="Pad",
284
+ label="Preprocess for non-square image", visible=False)
285
+
286
+ cur_dir = os.path.dirname(os.path.abspath(__file__))
287
+ gr.Examples(examples=[
288
+ [f"{cur_dir}/examples/sample-4.png"],
289
+ [f"{cur_dir}/examples/sample-7.png"],
290
+ [f"{cur_dir}/examples/sample-16.png"],
291
+ [f"{cur_dir}/examples/sample-17.png"],
292
+ [f"{cur_dir}/examples/sample-18.png"],
293
+ [f"{cur_dir}/examples/sample-0.png"],
294
+ [f"{cur_dir}/examples/sample-1.png"],
295
+ [f"{cur_dir}/examples/sample-6.png"],
296
+ ], inputs=[imagebox])
297
+
298
+ with gr.Column(scale=1, min_width=50):
299
+ submit_btn = gr.Button(value="Send", variant="primary")
300
+
301
+ with gr.Accordion("Parameters", open=True) as parameter_row:
302
+ num_beams = gr.Slider(minimum=1, maximum=10, value=1, step=1, interactive=True, label="Num Beams", visible=False,)
303
+ temperature = gr.Slider(minimum=0.0, maximum=2.0, value=0.8, step=0.05, interactive=True, label="Temperature",)
304
+ len_penalty = gr.Slider(minimum=0.0, maximum=2.0, value=0.6, step=0.05, interactive=True, label="Length Penalty",)
305
+ top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.9, step=0.05, interactive=True, label="Top P",)
306
+ max_output_tokens = gr.Slider(minimum=0, maximum=8192, value=2000, step=64, interactive=True, label="Max output tokens",)
307
+
308
+ with gr.Column(scale=8):
309
+ with gr.Row():
310
+ svg_code = gr.Code(label="SVG Code", elem_id='svg_code', min_width=200, interactive=False, lines=5)
311
+ with gr.Row():
312
+ gr.Image(width=50, height=256, label="Rendered SVG", elem_id='svg_render')
313
+ with gr.Row(elem_id="buttons") as button_row:
314
+ upvote_btn = gr.Button(value="πŸ‘ Upvote", interactive=False)
315
+ downvote_btn = gr.Button(value="πŸ‘Ž Downvote", interactive=False)
316
+ flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
317
+ stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False, visible=False)
318
+ regenerate_btn = gr.Button(value="πŸ”„ Regenerate", interactive=False, visible=False)
319
+ clear_btn = gr.Button(value="πŸ—‘οΈ Clear", interactive=False)
320
+
321
+ if not embed_mode:
322
+ gr.Markdown(tos_markdown)
323
+ gr.Markdown(learn_more_markdown)
324
+ url_params = gr.JSON(visible=False)
325
+
326
+ # Register listeners
327
+ btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn, stop_btn]
328
+ upvote_btn.click(
329
+ upvote_last_response,
330
+ [state, model_selector],
331
+ [upvote_btn, downvote_btn, flag_btn],
332
+ queue=False
333
+ )
334
+ downvote_btn.click(
335
+ downvote_last_response,
336
+ [state, model_selector],
337
+ [upvote_btn, downvote_btn, flag_btn],
338
+ queue=False
339
+ )
340
+ flag_btn.click(
341
+ flag_last_response,
342
+ [state, model_selector],
343
+ [upvote_btn, downvote_btn, flag_btn],
344
+ queue=False
345
+ )
346
+
347
+ regenerate_btn.click(
348
+ regenerate,
349
+ [state, image_process_mode],
350
+ [state, svg_code, svg_render, imagebox] + btn_list,
351
+ queue=False
352
+ ).then(
353
+ http_bot,
354
+ [state, model_selector, num_beams, temperature, len_penalty, top_p, max_output_tokens],
355
+ [state, svg_code, svg_render] + btn_list,
356
+ concurrency_limit=concurrency_count
357
+ )
358
+
359
+ submit_btn.click(
360
+ send_image,
361
+ [state, imagebox, image_process_mode],
362
+ [state, svg_code, svg_render, imagebox] + btn_list,
363
+ queue=False
364
+ ).then(
365
+ http_bot,
366
+ [state, model_selector, num_beams, temperature, len_penalty, top_p, max_output_tokens],
367
+ [state, svg_code, svg_render] + btn_list,
368
+ concurrency_limit=concurrency_count
369
+ )
370
+
371
+ clear_btn.click(
372
+ clear_history,
373
+ None,
374
+ [state, svg_code, svg_render] + btn_list,
375
+ queue=False
376
+ )
377
+
378
+ stop_btn.click(
379
+ stop_sampling,
380
+ [state, imagebox],
381
+ [state, imagebox] + btn_list,
382
+ queue=False
383
+ ).then(
384
+ clear_history,
385
+ None,
386
+ [state, svg_code, svg_render] + btn_list,
387
+ queue=False
388
+ )
389
+
390
+ if args.model_list_mode == "once":
391
+ demo.load(
392
+ load_demo,
393
+ [url_params],
394
+ [state, model_selector],
395
+ _js=get_window_url_params,
396
+ )
397
+ elif args.model_list_mode == "reload":
398
+ demo.load(
399
+ load_demo_refresh_model_list,
400
+ None,
401
+ [state, model_selector],
402
+ queue=False
403
+ )
404
+ else:
405
+ raise ValueError(f"Unknown model list mode: {args.model_list_mode}")
406
+
407
+ return demo
408
+
409
+ if __name__ == "__main__":
410
+ parser = argparse.ArgumentParser()
411
+ parser.add_argument("--host", type=str, default="0.0.0.0")
412
+ parser.add_argument("--port", type=int)
413
+ parser.add_argument("--controller-url", type=str, default="http://localhost:21001")
414
+ parser.add_argument("--concurrency-count", type=int, default=15)
415
+ parser.add_argument("--model-list-mode", type=str, default="once", choices=["once", "reload"])
416
+ parser.add_argument("--share", action="store_true")
417
+ parser.add_argument("--moderate", action="store_true")
418
+ parser.add_argument("--embed", action="store_true")
419
+ args = parser.parse_args()
420
+ logger.info(f"args: {args}")
421
+
422
+ models = get_model_list()
423
+
424
+ logger.info(args)
425
+ demo = build_demo(args.embed, concurrency_count=args.concurrency_count)
426
+ demo.queue(
427
+ api_open=False
428
+ ).launch(
429
+ server_name=args.host,
430
+ server_port=args.port,
431
+ share=args.share
432
+ )
starvector/serve/gradio_web_server.py ADDED
@@ -0,0 +1,562 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import datetime
3
+ import json
4
+ import os
5
+ import time
6
+ import gradio as gr
7
+ import requests
8
+ from starvector.serve.conversation import default_conversation
9
+ from starvector.serve.constants import LOGDIR, CLIP_QUERY_LENGTH
10
+ from starvector.serve.util import (build_logger, server_error_msg)
11
+
12
+ logger = build_logger("gradio_web_server", "gradio_web_server.log")
13
+ headers = {"User-Agent": "StarVector Client"}
14
+
15
+ no_change_btn = gr.Button.update()
16
+ enable_btn = gr.Button.update(interactive=True)
17
+ disable_btn = gr.Button.update(interactive=False)
18
+
19
+ priority = {
20
+ "starvector-1b-im2svg": "aaaaaaa",
21
+ }
22
+
23
+ def get_conv_log_filename():
24
+ t = datetime.datetime.now()
25
+ name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
26
+ return name
27
+
28
+ def get_model_list():
29
+ ret = requests.post(args.controller_url + "/refresh_all_workers")
30
+ assert ret.status_code == 200
31
+ ret = requests.post(args.controller_url + "/list_models")
32
+ models = ret.json()["models"]
33
+ models.sort(key=lambda x: priority.get(x, x))
34
+ logger.info(f"Models: {models}")
35
+ return models
36
+
37
+ def load_demo(url_params, request: gr.Request):
38
+ logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")
39
+
40
+ dropdown_update = gr.Dropdown.update(visible=True)
41
+ if "model" in url_params:
42
+ model = url_params["model"]
43
+ if model in models:
44
+ dropdown_update = gr.Dropdown.update(
45
+ value=model, visible=True)
46
+
47
+ state = default_conversation.copy()
48
+ return state, dropdown_update
49
+
50
+ mapping_model_task = {
51
+ 'Image2SVG': 'im2svg',
52
+ 'Text2SVG': 'text2svg'
53
+ }
54
+
55
+ def get_models_dropdown_from_task(task):
56
+ models = get_model_list()
57
+ models = [model for model in models if mapping_model_task[task] in model]
58
+ dropdown_update = gr.Dropdown.update(
59
+ choices=models,
60
+ value=models[0] if len(models) > 0 else ""
61
+ )
62
+ return dropdown_update
63
+
64
+
65
+ def load_demo_refresh_model_list(task, request: gr.Request):
66
+ logger.info(f"load_demo. ip: {request.client.host}")
67
+ dropdown_update = get_models_dropdown_from_task(task)
68
+ state = default_conversation.copy()
69
+ return state, dropdown_update
70
+
71
+ def vote_last_response(state, vote_type, model_selector, request: gr.Request):
72
+ with open(get_conv_log_filename(), "a") as fout:
73
+ data = {
74
+ "tstamp": round(time.time(), 4),
75
+ "type": vote_type,
76
+ "model": model_selector,
77
+ "state": state.dict(),
78
+ "ip": request.client.host,
79
+ }
80
+ fout.write(json.dumps(data) + "\n")
81
+
82
+ def upvote_last_response(state, model_selector, request: gr.Request):
83
+ logger.info(f"upvote. ip: {request.client.host}")
84
+ vote_last_response(state, "upvote", model_selector, request)
85
+ return ("",) + (disable_btn,) * 7
86
+
87
+ def downvote_last_response(state, model_selector, request: gr.Request):
88
+ logger.info(f"downvote. ip: {request.client.host}")
89
+ vote_last_response(state, "downvote", model_selector, request)
90
+ return ("",) + (disable_btn,) * 7
91
+
92
+ def flag_last_response(state, model_selector, request: gr.Request):
93
+ logger.info(f"flag. ip: {request.client.host}")
94
+ vote_last_response(state, "flag", model_selector, request)
95
+ return ("",) + (disable_btn,) * 7
96
+
97
+ def regenerate(state, image_process_mode, request: gr.Request):
98
+ logger.info(f"regenerate. ip: {request.client.host}")
99
+ state.messages[-1][-1] = None
100
+ prev_human_msg = state.messages[-2]
101
+ if type(prev_human_msg[1]) in (tuple, list):
102
+ prev_human_msg[1] = (prev_human_msg[1][:2], image_process_mode)
103
+ state.skip_next = False
104
+ return (state, None, None, None) + (disable_btn,) * 7
105
+
106
+ def clear_history(request: gr.Request):
107
+ logger.info(f"clear_history. ip: {request.client.host}")
108
+ state = default_conversation.copy()
109
+ return (state, None, None) + (disable_btn,) * 7
110
+
111
+ def send_data(state, image, image_process_mode, text_caption, task, request: gr.Request):
112
+ logger.info(f"send_data. ip: {request.client.host}.")
113
+ if task == 'Image2SVG':
114
+ if image is None:
115
+ state.skip_next = True
116
+ return (state, None, None, image) + (no_change_btn,) * 7
117
+
118
+ if image is not None:
119
+ image_message = (image, image_process_mode)
120
+ state.append_message(state.roles[0], image_message)
121
+ state.append_message(state.roles[1], "β–Œ")
122
+ state.skip_next = False
123
+ msg = state.to_gradio_svg_code()[0][1]
124
+ return (state, msg, state.to_gradio_svg_render(), image) + (no_change_btn,) * 7
125
+ else:
126
+ if text_caption is None:
127
+ state.skip_next = True
128
+ return (state, None, None, image) + (no_change_btn,) * 7
129
+
130
+ state.append_message(state.roles[0], text_caption)
131
+ state.append_message(state.roles[1], "β–Œ")
132
+ state.skip_next = False
133
+ msg = state.to_gradio_svg_code()[0][1]
134
+ return (state, msg, state.to_gradio_svg_render(), image) + (no_change_btn,) * 7
135
+
136
+ def download_files(state, request: gr.Request):
137
+ logger.info(f"download_files. ip: {request.client.host}")
138
+ svg_str, image = state.download_files()
139
+
140
+ # TODO: Figure out how to download the SVG in the users browser, idk how to do it now
141
+
142
+ def update_task(task):
143
+ dropdown_update = get_models_dropdown_from_task(task)
144
+
145
+ if task == "Text2SVG":
146
+ return 1.0, 0.9, 0.95, dropdown_update
147
+ else:
148
+ return 0.6, 0.9, 0.95, dropdown_update
149
+
150
+
151
+ def stop_sampling(state, image, request: gr.Request):
152
+ logger.info(f"stop_sampling. ip: {request.client.host}")
153
+ state.stop_sampling = True
154
+ return (state, None, None, image) + (disable_btn,) * 7
155
+
156
+ def http_bot(state, task_selector, text_caption, model_selector, num_beams, temperature, len_penalty, top_p, max_new_tokens, request: gr.Request):
157
+ logger.info(f"http_bot. ip: {request.client.host}")
158
+ start_tstamp = time.time()
159
+ model_name = model_selector
160
+
161
+ if state.skip_next:
162
+ # This generate call is skipped due to invalid inputs
163
+ yield (state, None, None) + (no_change_btn,) * 7
164
+ return
165
+
166
+ # Query worker address
167
+ controller_url = args.controller_url
168
+ ret = requests.post(controller_url + "/get_worker_address",
169
+ json={"model": model_name})
170
+ worker_addr = ret.json()["address"]
171
+ logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}")
172
+
173
+ # No available worker
174
+ if worker_addr == "":
175
+ state.messages[-1][-1] = server_error_msg
176
+ yield (state, None, None, disable_btn, disable_btn, disable_btn, enable_btn, enable_btn, disable_btn, disable_btn)
177
+ return
178
+
179
+ # Construct prompt
180
+ if task_selector == "Image2SVG":
181
+ prompt = state.get_image_prompt()
182
+ else:
183
+ prompt = text_caption
184
+
185
+ # Make requests
186
+ pload = {
187
+ "model": model_name,
188
+ "prompt": prompt,
189
+ "num_beams": int(num_beams),
190
+ "temperature": float(temperature),
191
+ "len_penalty": float(len_penalty),
192
+ "top_p": float(top_p),
193
+ "max_new_tokens": min(int(max_new_tokens), 8192-CLIP_QUERY_LENGTH),
194
+ }
195
+ logger.info(f"==== request ====\n{pload}")
196
+
197
+ pload['images'] = state.get_images()
198
+
199
+ state.messages[-1][-1] = "β–Œ"
200
+ yield (state, state.messages[-1][-1], state.to_gradio_svg_render()) + (disable_btn, disable_btn, disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
201
+
202
+ try:
203
+ # Stream output
204
+ if state.stop_sampling:
205
+ state.messages[1][-1] = "β–Œ"
206
+ yield (state, state.messages[-1][-1], state.to_gradio_svg_render()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn, disable_btn, enable_btn)
207
+ return
208
+
209
+ response = requests.post(worker_addr + "/worker_generate_stream",
210
+ headers=headers, json=pload, stream=True, timeout=10)
211
+ for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
212
+ if chunk:
213
+ data = json.loads(chunk.decode())
214
+ if data["error_code"] == 0:
215
+ # output = data["text"].strip().replace('<', '&lt;').replace('>', '&gt;') # trick to avoid the SVG getting rendered
216
+ output = data["text"].strip()
217
+ state.messages[-1][-1] = output + "β–Œ"
218
+ st = state.to_gradio_svg_code()
219
+ yield (state, st[-1][1], state.to_gradio_svg_render()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn, enable_btn, enable_btn)
220
+ else:
221
+ output = data["text"] + f" (error_code: {data['error_code']})"
222
+ state.messages[-1][-1] = output
223
+
224
+ yield (state, st[-1][1], state.to_gradio_svg_render()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn, disable_btn, disable_btn)
225
+ return
226
+ time.sleep(0.03)
227
+ except requests.exceptions.RequestException as e:
228
+ state.messages[-1][-1] = server_error_msg
229
+ yield (state, None, None) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn, disable_btn, disable_btn)
230
+ return
231
+
232
+ yield (state, state.messages[-1][-1], state.to_gradio_svg_render()) + (enable_btn,) * 7
233
+
234
+ finish_tstamp = time.time()
235
+ logger.info(f"{output}")
236
+
237
+ with open(get_conv_log_filename(), "a") as fout:
238
+ data = {
239
+ "tstamp": round(finish_tstamp, 4),
240
+ "type": "chat",
241
+ "model": model_name,
242
+ "start": round(start_tstamp, 4),
243
+ "finish": round(finish_tstamp, 4),
244
+ "svg": state.messages[-1][-1],
245
+ "ip": request.client.host,
246
+ }
247
+ fout.write(json.dumps(data) + "\n")
248
+
249
+ title_markdown = ("""
250
+ # πŸ’« StarVector: Generating Scalable Vector Graphics Code from Images and Text
251
+
252
+ [[Project Page](https://starvector.github.io)] [[Code](https://github.com/joanrod/star-vector)] [[Model](https://huggingface.co/joanrodai/starvector-1.4b)] | πŸ“š [[StarVector](https://arxiv.org/abs/2312.11556)]""")
253
+
254
+ sub_title_markdown = ("""**How does it work?** Select the task you want to perform, and the model will be automatically set. For **Text2SVG**, introduce a prompt in Text Caption. For **Image2SVG**, select an image and vectorize it. \
255
+ **Note**: The current model works on vector-like images like icons and or vector-like designs.""")
256
+ tos_markdown = ("""
257
+ ### Terms of use
258
+ By using this service, users are required to agree to the following terms:
259
+ The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research.
260
+ Please click the "Flag" button if you get any inappropriate answer! We will collect those to keep improving our moderator.
261
+ For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality.
262
+ """)
263
+
264
+ learn_more_markdown = ("""
265
+ ### License
266
+ The service is a research preview intended for non-commercial use only. Please contact us if you find any potential violation.
267
+ """)
268
+
269
+ block_css = """
270
+
271
+ #buttons button {
272
+ min-width: min(120px,100%);
273
+ }
274
+
275
+ .gradio-container{
276
+ max-width: 1200px!important
277
+ }
278
+
279
+ .ΝΌ1 .cm-content {
280
+ white-space: unset !important;
281
+ flex-shrink: unset !important;
282
+ }
283
+
284
+ .ΝΌ2p .cm-scroller {
285
+ max-height: 200px;
286
+ overflow: scroll;
287
+ }
288
+
289
+ #svg_render{
290
+ padding: 20px !important;
291
+ }
292
+
293
+ #submit_btn{
294
+ max-height: 40px;
295
+ }
296
+
297
+ .selector{
298
+ max-height: 100px;
299
+ }
300
+ h1{display: flex;align-items: center;justify-content: center;gap: .25em}
301
+ *{transition: width 0.5s ease, flex-grow 0.5s ease}
302
+ """
303
+ def build_demo(embed_mode):
304
+ svg_render = gr.Image(label="Rendered SVG", elem_id='svg_render', height=300)
305
+ svg_code = gr.Code(label="SVG Code", elem_id='svg_code', interactive=True, lines=5)
306
+
307
+ with gr.Blocks(title="StarVector", theme=gr.themes.Default(), css=block_css) as demo:
308
+ state = gr.State()
309
+ if not embed_mode:
310
+ gr.Markdown(title_markdown)
311
+ gr.Markdown(sub_title_markdown)
312
+ with gr.Row():
313
+ with gr.Column(scale=4):
314
+ task_selector = gr.Dropdown(
315
+ choices=["Image2SVG", "Text2SVG"],
316
+ value="Image2SVG",
317
+ label="Task",
318
+ interactive=True,
319
+ show_label=True,
320
+ container=True,
321
+ elem_id="task_selector",
322
+ elem_classes=["selector"],
323
+ )
324
+ model_selector = gr.Dropdown(
325
+ choices=models,
326
+ value=models[0] if len(models) > 0 else "",
327
+ label="Model",
328
+ interactive=True,
329
+ show_label=True,
330
+ container=True,
331
+ elem_classes=["selector"],
332
+ )
333
+
334
+ imagebox = gr.Image(type="pil", visible=True, elem_id="imagebox")
335
+ image_process_mode = gr.Radio(
336
+ ["Resize", "Pad", "Default"],
337
+ value="Pad",
338
+ label="Preprocess for non-square image", visible=False)
339
+
340
+ # Text input
341
+ text_caption = gr.Textbox(label="Text Caption", visible=True, value="The icon of a yellow star", elem_id="text_caption")
342
+
343
+ cur_dir = os.path.dirname(os.path.abspath(__file__))
344
+ gr.Examples(examples=[
345
+ [f"{cur_dir}/examples/sample-4.png"],
346
+ [f"{cur_dir}/examples/sample-7.png"],
347
+ [f"{cur_dir}/examples/sample-16.png"],
348
+ [f"{cur_dir}/examples/sample-17.png"],
349
+ [f"{cur_dir}/examples/sample-18.png"],
350
+ [f"{cur_dir}/examples/sample-0.png"],
351
+ [f"{cur_dir}/examples/sample-1.png"],
352
+ [f"{cur_dir}/examples/sample-6.png"],
353
+ ], inputs=[imagebox], elem_id="examples")
354
+
355
+ submit_btn = gr.Button(value="Send", variant="primary", elem_id="submit_btn", interactive=True)
356
+
357
+ with gr.Accordion("Parameters", open=False):
358
+ num_beams = gr.Slider(minimum=1, maximum=10, value=1, step=1, interactive=True, label="Num Beams", visible=False,)
359
+ temperature = gr.Slider(minimum=0.0, maximum=2.0, value=0.9, step=0.05, interactive=True, label="Temperature",)
360
+ len_penalty = gr.Slider(minimum=0.0, maximum=2.0, value=0.6, step=0.05, interactive=True, label="Length Penalty",)
361
+ top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.95, step=0.05, interactive=True, label="Top P",)
362
+ max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=1024, step=64, interactive=True, label="Max output tokens",)
363
+
364
+ with gr.Column(scale=9):
365
+ with gr.Row():
366
+ svg_code.render()
367
+ with gr.Row():
368
+ svg_render.render()
369
+
370
+ with gr.Row(elem_id="buttons") as button_row:
371
+ upvote_btn = gr.Button(value="πŸ‘ Upvote", interactive=False)
372
+ downvote_btn = gr.Button(value="πŸ‘Ž Downvote", interactive=False)
373
+ flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
374
+ stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False, visible=False)
375
+ regenerate_btn = gr.Button(value="πŸ”„ Regenerate", interactive=False, visible=False)
376
+ clear_btn = gr.Button(value="πŸ—‘οΈ Clear", interactive=False)
377
+ download_btn = gr.Button(value="Download SVG", interactive=False, visible=False)
378
+
379
+ if not embed_mode:
380
+ gr.Markdown(tos_markdown)
381
+ gr.Markdown(learn_more_markdown)
382
+ url_params = gr.JSON(visible=False)
383
+
384
+ # Register listeners
385
+ btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn, stop_btn, download_btn]
386
+ upvote_btn.click(
387
+ upvote_last_response,
388
+ [state, model_selector],
389
+ [upvote_btn, downvote_btn, flag_btn],
390
+ queue=False
391
+ )
392
+ downvote_btn.click(
393
+ downvote_last_response,
394
+ [state, model_selector],
395
+ [upvote_btn, downvote_btn, flag_btn],
396
+ queue=False
397
+ )
398
+ flag_btn.click(
399
+ flag_last_response,
400
+ [state, model_selector],
401
+ [upvote_btn, downvote_btn, flag_btn],
402
+ queue=False
403
+ )
404
+
405
+ regenerate_btn.click(
406
+ regenerate,
407
+ [state, image_process_mode],
408
+ [state, svg_code, svg_render, imagebox] + btn_list,
409
+ queue=False
410
+ ).then(
411
+ http_bot,
412
+ [state, task_selector, text_caption, model_selector, num_beams, temperature, len_penalty, top_p, max_output_tokens],
413
+ [state, svg_code, svg_render] + btn_list)
414
+
415
+ submit_btn.click(
416
+ send_data,
417
+ [state, imagebox, image_process_mode, text_caption, task_selector],
418
+ [state, svg_code, svg_render, imagebox] + btn_list,
419
+ queue=False
420
+ ).then(
421
+ http_bot,
422
+ [state, task_selector, text_caption, model_selector, num_beams, temperature, len_penalty, top_p, max_output_tokens],
423
+ [state, svg_code, svg_render] + btn_list
424
+ )
425
+
426
+ clear_btn.click(
427
+ clear_history,
428
+ None,
429
+ [state, svg_code, svg_render] + btn_list,
430
+ queue=False
431
+ )
432
+
433
+ stop_btn.click(
434
+ stop_sampling,
435
+ [state, imagebox],
436
+ [state, imagebox] + btn_list,
437
+ queue=False
438
+ ).then(
439
+ clear_history,
440
+ None,
441
+ [state, svg_code, svg_render] + btn_list,
442
+ queue=False
443
+ )
444
+
445
+ download_btn.click(
446
+ download_files,
447
+ [state],
448
+ None,
449
+ queue=False
450
+ )
451
+ task_selector.change(
452
+ update_task,
453
+ inputs=[task_selector],
454
+ outputs=[len_penalty, temperature, top_p, model_selector],
455
+ queue=False,
456
+ _js="""
457
+ function(task) {
458
+ var imageBoxElement = document.getElementById("imagebox");
459
+ var textCaptionElement = document.getElementById("text_caption");
460
+ var examplesElement = document.getElementById("examples");
461
+ if (task === "Text2SVG") {
462
+ imageBoxElement.style.display = "none";
463
+ textCaptionElement.style.display = "block";
464
+ examplesElement.style.display = "none";
465
+ } else if (task === "Image2SVG") {
466
+ imageBoxElement.style.display = "block";
467
+ textCaptionElement.style.display = "none";
468
+ examplesElement.style.display = "block";
469
+ }
470
+ return task;
471
+ }
472
+ """
473
+ )
474
+
475
+ if args.model_list_mode == "once":
476
+ demo.load(
477
+ load_demo,
478
+ [url_params, task_selector],
479
+ [state, model_selector],
480
+ _js="""
481
+ function() {
482
+ const params = new URLSearchParams(window.location.search);
483
+ url_params = Object.fromEntries(params);
484
+ console.log(url_params);
485
+ return url_params;
486
+
487
+ }
488
+ """,
489
+ queue=False
490
+ )
491
+ elif args.model_list_mode == "reload":
492
+ demo.load(
493
+ load_demo_refresh_model_list,
494
+ [task_selector],
495
+ [state, model_selector],
496
+ _js="""
497
+ function(task) {
498
+ var textCaptionElement = document.getElementById("text_caption");
499
+ var autoScrollBottom = true;
500
+ textCaptionElement.style.display = "none";
501
+ function updateScroll(){
502
+ if (autoScrollBottom) {
503
+ var element = document.getElementsByClassName("cm-scroller")[0];
504
+ element.scrollTop = element.scrollHeight;
505
+ }
506
+ }
507
+ function handleScroll() {
508
+ var element = document.getElementsByClassName("cm-scroller")[0];
509
+ //if (element.scrollHeight - element.scrollTop === element.clientHeight) {
510
+ if (element.scrollHeight - (element.scrollTop + element.clientHeight) < 0.2*(element.scrollTop)) {
511
+ // User has scrolled to the bottom, enable auto-scrolling
512
+ autoScrollBottom = true;
513
+ console.log("bottom");
514
+ } else {
515
+ console.log("not bottom");
516
+ // User has scrolled away from the bottom, disable auto-scrolling
517
+ autoScrollBottom = false;
518
+ }
519
+ }
520
+ setInterval(updateScroll,500);
521
+ var element = document.getElementsByClassName("cm-scroller")[0];
522
+ element.addEventListener("scroll", handleScroll);
523
+
524
+ return task;
525
+ }
526
+
527
+ """,
528
+ queue=False,
529
+ )
530
+
531
+ else:
532
+ raise ValueError(f"Unknown model list mode: {args.model_list_mode}")
533
+
534
+ return demo
535
+
536
+ if __name__ == "__main__":
537
+
538
+ parser = argparse.ArgumentParser()
539
+ parser.add_argument("--host", type=str, default="0.0.0.0")
540
+ parser.add_argument("--port", type=int)
541
+ parser.add_argument("--controller-url", type=str, default="http://localhost:21001")
542
+ parser.add_argument("--concurrency-count", type=int, default=10)
543
+ parser.add_argument("--model-list-mode", type=str, default="once",
544
+ choices=["once", "reload"])
545
+ parser.add_argument("--share", action="store_true")
546
+ parser.add_argument("--moderate", action="store_true")
547
+ parser.add_argument("--embed", action="store_true")
548
+ args = parser.parse_args()
549
+ logger.info(f"args: {args}")
550
+
551
+ models = get_model_list()
552
+
553
+ logger.info(args)
554
+ demo = build_demo(args.embed)
555
+ demo.queue(
556
+ concurrency_count=args.concurrency_count,
557
+ api_open=False
558
+ ).launch(
559
+ server_name=args.host,
560
+ server_port=args.port,
561
+ share=args.share
562
+ )
starvector/serve/model_worker.py ADDED
@@ -0,0 +1,269 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ A model worker executes the model.
3
+ """
4
+ import argparse
5
+ import asyncio
6
+ import json
7
+ import time
8
+ import threading
9
+ import uuid
10
+ from fastapi import FastAPI, Request, BackgroundTasks
11
+ from fastapi.responses import StreamingResponse
12
+ import requests
13
+ import torch
14
+ import uvicorn
15
+ from functools import partial
16
+ from starvector.serve.constants import WORKER_HEART_BEAT_INTERVAL, CLIP_QUERY_LENGTH
17
+ from starvector.serve.util import (build_logger, server_error_msg,
18
+ pretty_print_semaphore)
19
+ from starvector.model.builder import load_pretrained_model
20
+ from starvector.serve.util import process_images, load_image_from_base64
21
+ from threading import Thread
22
+ from transformers import TextIteratorStreamer
23
+
24
# 1 GiB in bytes (used for memory reporting in sibling workers; unused here).
GB = 1 << 30

# Short random id so each worker process gets its own log file.
worker_id = str(uuid.uuid4())[:6]
logger = build_logger("model_worker", f"model_worker_{worker_id}.log")
# Total generation requests served by this process (incremented per request).
global_counter = 0
# Concurrency limiter; created lazily on the first /worker_generate_stream call.
model_semaphore = None
30
+
31
def heart_beat_worker(controller):
    """Background-thread loop: periodically ping the controller so it
    knows this worker is still alive. Runs forever; never returns."""
    while True:
        time.sleep(WORKER_HEART_BEAT_INTERVAL)
        controller.send_heart_beat()
35
+
36
class ModelWorker:
    """Hosts one StarVector model and serves streamed SVG generation.

    On construction it loads the model, optionally registers itself with
    the controller, and starts a heart-beat thread. Generation is exposed
    through ``generate_stream`` / ``generate_stream_gate``.
    """

    def __init__(self, controller_addr, worker_addr,
                 worker_id, no_register,
                 model_path, model_base, model_name,
                 load_8bit, load_4bit, device):
        # NOTE(review): `model_base` is accepted but never used below --
        # confirm whether load_pretrained_model should receive it.
        self.controller_addr = controller_addr
        self.worker_addr = worker_addr
        self.worker_id = worker_id
        if model_path.endswith("/"):
            model_path = model_path[:-1]
        if model_name is None:
            # Derive a display name from the path; keep the parent folder
            # name when pointing at an intermediate "checkpoint-*" dir.
            model_paths = model_path.split("/")
            if model_paths[-1].startswith('checkpoint-'):
                self.model_name = model_paths[-2] + "_" + model_paths[-1]
            else:
                self.model_name = model_paths[-1]
        else:
            self.model_name = model_name

        # Task is inferred from the model name.
        # NOTE(review): if the name contains neither "text2svg" nor
        # "im2svg", self.task is never set and generate_stream will raise
        # AttributeError -- consider raising a clear error here instead.
        if "text2svg" in self.model_name.lower():
            self.task = "Text2SVG"
        elif "im2svg" in self.model_name.lower():
            self.task = "Image2SVG"

        self.device = device
        logger.info(f"Loading the model {self.model_name} on worker {worker_id} ...")
        self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
            model_path, device=self.device, load_in_8bit=load_8bit, load_in_4bit=load_4bit)
        # NOTE(review): weights are cast to bfloat16 here while inputs in
        # generate_stream are moved to float16 -- confirm intended dtype.
        self.model.to(torch.bfloat16)
        self.is_multimodal = 'starvector' in self.model_name.lower()

        if not no_register:
            self.register_to_controller()
            # Thread that pings the controller forever (see heart_beat_worker).
            self.heart_beat_thread = threading.Thread(
                target=heart_beat_worker, args=(self,))
            self.heart_beat_thread.start()

    def register_to_controller(self):
        """Announce this worker (address + current status) to the controller."""
        logger.info("Register to controller")

        url = self.controller_addr + "/register_worker"
        data = {
            "worker_name": self.worker_addr,
            "check_heart_beat": True,
            "worker_status": self.get_status()
        }
        r = requests.post(url, json=data)
        assert r.status_code == 200

    def send_heart_beat(self):
        """Report queue length to the controller; re-register if unknown.

        Retries forever on network errors, sleeping 5 s between attempts.
        """
        logger.info(f"Send heart beat. Models: {[self.model_name]}. "
                    f"Semaphore: {pretty_print_semaphore(model_semaphore)}. "
                    f"global_counter: {global_counter}")

        url = self.controller_addr + "/receive_heart_beat"

        while True:
            try:
                ret = requests.post(url, json={
                    "worker_name": self.worker_addr,
                    "queue_length": self.get_queue_length()}, timeout=5)
                exist = ret.json()["exist"]
                break
            except requests.exceptions.RequestException as e:
                logger.error(f"heart beat error: {e}")
            time.sleep(5)

        # Controller restarted and lost us: register again.
        if not exist:
            self.register_to_controller()

    def get_queue_length(self):
        """Number of requests currently holding or waiting on the semaphore."""
        if model_semaphore is None:
            return 0
        else:
            # Relies on asyncio.Semaphore internals (_value, _waiters).
            return args.limit_model_concurrency - model_semaphore._value + (len(
                model_semaphore._waiters) if model_semaphore._waiters is not None else 0)

    def get_status(self):
        """Status payload the controller uses for scheduling decisions."""
        return {
            "model_names": [self.model_name],
            "speed": 1,
            "queue_length": self.get_queue_length(),
        }

    @torch.inference_mode()
    def generate_stream(self, params):
        """Run generation in a background thread, yielding incremental results.

        Each yielded chunk is a JSON object (``text``, ``error_code``)
        encoded to bytes and terminated by a NUL byte, matching the
        delimiter the web server splits the HTTP stream on.
        """
        tokenizer, model, image_processor, task = self.tokenizer, self.model, self.image_processor, self.task

        # Sampling hyper-parameters, with defaults for missing keys.
        num_beams = int(params.get("num_beams", 1))
        temperature = float(params.get("temperature", 1.0))
        len_penalty = float(params.get("len_penalty", 1.0))
        top_p = float(params.get("top_p", 1.0))
        max_context_length = getattr(model.config, 'max_position_embeddings', 8192)
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=False, skip_special_tokens=True, timeout=15)
        prompt = params["prompt"]

        if task == "Image2SVG":
            images = params.get("images", None)
            # NOTE(review): the loop overwrites `image` every iteration, so
            # only the LAST entry of `images` is used -- confirm that a
            # single image per request is the intent.
            for b64_image in images:
                if b64_image is not None and self.is_multimodal:
                    image = load_image_from_base64(b64_image)
                    image = process_images(image, image_processor)
                    image = image.to(self.model.device, dtype=torch.float16)
                else:
                    image = None

            max_new_tokens = min(int(params.get("max_new_tokens", 256)), 8192)
            # Reserve room for the CLIP image-query tokens in the context.
            max_new_tokens = min(max_new_tokens, max_context_length - CLIP_QUERY_LENGTH)
            pre_pend = prompt
            batch = {}
            batch["image"] = image
            generate_method = model.model.generate_im2svg
        else:
            max_new_tokens = min(int(params.get("max_new_tokens", 128)), 8192)
            pre_pend = ""
            batch = {}
            batch['caption'] = [prompt]
            # Placeholder image tensor; NOTE(review): zeros render black,
            # not white as the original comment claimed.
            batch['image'] = torch.zeros((3, 256, 256), dtype=torch.float16).to(self.model.device)
            generate_method = model.model.generate_text2svg

        if max_new_tokens < 1:
            yield json.dumps({"text": prompt + "Exceeds max token length. Please start a new conversation, thanks.", "error_code": 0}).encode() + b"\0"
            return

        # Generation runs in a worker thread; tokens arrive via `streamer`.
        thread = Thread(target=generate_method, kwargs=dict(
            batch=batch,
            prompt=prompt,
            use_nucleus_sampling=True,
            num_beams=num_beams,
            temperature=temperature,
            length_penalty=len_penalty,
            top_p=top_p,
            max_length=max_new_tokens,
            streamer=streamer,
        ))
        thread.start()

        # Re-send the accumulated text on every chunk (not deltas).
        generated_text = pre_pend
        for new_text in streamer:
            # Drop bare-space chunks.
            if new_text == " ":
                continue
            generated_text += new_text
            yield json.dumps({"text": generated_text, "error_code": 0}).encode() + b"\0"

    def generate_stream_gate(self, params):
        """Wrap generate_stream, converting exceptions to error chunks
        (error_code 1) so the HTTP stream always terminates cleanly."""
        try:
            for x in self.generate_stream(params):
                yield x
        except ValueError as e:
            print("Caught ValueError:", e)
            ret = {
                "text": server_error_msg,
                "error_code": 1,
            }
            yield json.dumps(ret).encode() + b"\0"
        except torch.cuda.CudaError as e:
            print("Caught torch.cuda.CudaError:", e)
            ret = {
                "text": server_error_msg,
                "error_code": 1,
            }
            yield json.dumps(ret).encode() + b"\0"
        except Exception as e:
            print("Caught Unknown Error", e)
            ret = {
                "text": server_error_msg,
                "error_code": 1,
            }
            yield json.dumps(ret).encode() + b"\0"
208
+
209
app = FastAPI()

def release_model_semaphore(fn=None):
    """Free one slot on the global concurrency semaphore, then invoke the
    optional callback *fn* (used to push a fresh heart beat)."""
    model_semaphore.release()
    if fn is None:
        return
    fn()
215
+
216
+ @app.post("/worker_generate_stream")
217
+ async def generate_stream(request: Request):
218
+ global model_semaphore, global_counter
219
+ global_counter += 1
220
+ params = await request.json()
221
+
222
+ if model_semaphore is None:
223
+ model_semaphore = asyncio.Semaphore(args.limit_model_concurrency)
224
+ await model_semaphore.acquire()
225
+ worker.send_heart_beat()
226
+ generator = worker.generate_stream_gate(params)
227
+ background_tasks = BackgroundTasks()
228
+ background_tasks.add_task(partial(release_model_semaphore, fn=worker.send_heart_beat))
229
+ return StreamingResponse(generator, background=background_tasks)
230
+
231
+ @app.post("/worker_get_status")
232
+ async def get_status(request: Request):
233
+ return worker.get_status()
234
+
235
if __name__ == "__main__":
    # CLI for launching a standalone model worker process.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--host", type=str, default="localhost")
    arg_parser.add_argument("--port", type=int, default=21002)
    arg_parser.add_argument("--worker-address", type=str,
                            default="http://localhost:21002")
    arg_parser.add_argument("--controller-address", type=str,
                            default="http://localhost:21001")
    arg_parser.add_argument("--model-path", type=str, default="joanrodai/starvector-1.4b")
    arg_parser.add_argument("--model-base", type=str, default=None)
    arg_parser.add_argument("--model-name", type=str)
    arg_parser.add_argument("--device", type=str, default="cuda")
    arg_parser.add_argument("--multi-modal", action="store_true", help="Multimodal mode is automatically detected with model name, please make sure `starvector` is included in the model path.")
    arg_parser.add_argument("--limit-model-concurrency", type=int, default=5)
    arg_parser.add_argument("--stream-interval", type=int, default=1)
    arg_parser.add_argument("--no-register", action="store_true")
    arg_parser.add_argument("--load-8bit", action="store_true")
    arg_parser.add_argument("--load-4bit", action="store_true")
    args = arg_parser.parse_args()
    logger.info(f"args: {args}")

    if args.multi_modal:
        logger.warning("Multimodal mode is automatically detected with model name, please make sure `starvector` is included in the model path.")

    # Build the worker (loads the model) and serve the FastAPI app with uvicorn.
    worker = ModelWorker(args.controller_address,
                         args.worker_address,
                         worker_id,
                         args.no_register,
                         args.model_path,
                         args.model_base,
                         args.model_name,
                         args.load_8bit,
                         args.load_4bit,
                         args.device)
    uvicorn.run(app, host=args.host, port=args.port, log_level="info")
starvector/serve/vllm_api_gradio/__pycache__/controller.cpython-311.pyc CHANGED
Binary files a/starvector/serve/vllm_api_gradio/__pycache__/controller.cpython-311.pyc and b/starvector/serve/vllm_api_gradio/__pycache__/controller.cpython-311.pyc differ
 
starvector/serve/vllm_api_gradio/__pycache__/gradio_web_server.cpython-311.pyc CHANGED
Binary files a/starvector/serve/vllm_api_gradio/__pycache__/gradio_web_server.cpython-311.pyc and b/starvector/serve/vllm_api_gradio/__pycache__/gradio_web_server.cpython-311.pyc differ
 
starvector/serve/vllm_api_gradio/__pycache__/model_worker.cpython-311.pyc CHANGED
Binary files a/starvector/serve/vllm_api_gradio/__pycache__/model_worker.cpython-311.pyc and b/starvector/serve/vllm_api_gradio/__pycache__/model_worker.cpython-311.pyc differ
 
starvector/serve/vllm_api_gradio/gradio_web_server.py CHANGED
@@ -204,7 +204,6 @@ def http_bot(state, task_selector, text_caption, model_selector, num_beams, temp
204
 
205
  state.messages[-1][-1] = "β–Œ"
206
  yield (state, state.messages[-1][-1], state.to_gradio_svg_render()) + (disable_btn, disable_btn, disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
207
-
208
  try:
209
  # Stream output
210
  if state.stop_sampling:
@@ -214,23 +213,33 @@ def http_bot(state, task_selector, text_caption, model_selector, num_beams, temp
214
 
215
  response = requests.post(worker_addr + "/worker_generate_stream",
216
  headers=headers, json=pload, stream=True, timeout=10)
 
 
 
 
217
  for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
218
  if chunk:
219
  data = json.loads(chunk.decode())
220
  if data["error_code"] == 0:
221
- # output = data["text"].strip().replace('<', '&lt;').replace('>', '&gt;') # trick to avoid the SVG getting rendered
222
  output = data["text"].strip()
223
  state.messages[-1][-1] = output + "β–Œ"
224
- st = state.to_gradio_svg_code()
225
- # Explicitly set the string value without HTML escaping
226
- yield (state, st[-1][1], state.to_gradio_svg_render()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn, enable_btn, enable_btn)
 
 
 
 
227
  else:
 
228
  output = data["text"] + f" (error_code: {data['error_code']})"
229
  state.messages[-1][-1] = output
230
  st = state.to_gradio_svg_code()
231
-
232
- yield (state, st[-1][1], state.to_gradio_svg_render()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn, disable_btn, disable_btn)
233
  return
 
 
 
234
  except requests.exceptions.RequestException as e:
235
  state.messages[-1][-1] = server_error_msg
236
  yield (state, None, None) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn, disable_btn, disable_btn)
@@ -576,7 +585,7 @@ def build_demo(embed_mode):
576
  temperature = gr.Slider(minimum=0.0, maximum=2.0, value=0.2, step=0.05, interactive=True, label="Temperature",)
577
  len_penalty = gr.Slider(minimum=0.0, maximum=2.0, value=1.0, step=0.05, interactive=True, label="Length Penalty",)
578
  top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.95, step=0.05, interactive=True, label="Top P",)
579
- max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=1024, step=64, interactive=True, label="Max output tokens",)
580
 
581
  with gr.Column(scale=9):
582
  with gr.Row():
 
204
 
205
  state.messages[-1][-1] = "β–Œ"
206
  yield (state, state.messages[-1][-1], state.to_gradio_svg_render()) + (disable_btn, disable_btn, disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
 
207
  try:
208
  # Stream output
209
  if state.stop_sampling:
 
213
 
214
  response = requests.post(worker_addr + "/worker_generate_stream",
215
  headers=headers, json=pload, stream=True, timeout=10)
216
+
217
+ update_interval = 2 # seconds
218
+ last_update_time = time.time()
219
+
220
  for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
221
  if chunk:
222
  data = json.loads(chunk.decode())
223
  if data["error_code"] == 0:
 
224
  output = data["text"].strip()
225
  state.messages[-1][-1] = output + "β–Œ"
226
+
227
+ # Only update if sufficient time has passed
228
+ current_time = time.time()
229
+ if current_time - last_update_time >= update_interval:
230
+ st = state.to_gradio_svg_code()
231
+ yield (state, st[-1][1], state.to_gradio_svg_render()) + (disable_btn,) * 7
232
+ last_update_time = current_time
233
  else:
234
+ # handle errors and yield immediately if needed
235
  output = data["text"] + f" (error_code: {data['error_code']})"
236
  state.messages[-1][-1] = output
237
  st = state.to_gradio_svg_code()
238
+ yield (state, st[-1][1], state.to_gradio_svg_render()) + (disable_btn,) * 7
 
239
  return
240
+ # Final yield to ensure the last state is rendered
241
+ st = state.to_gradio_svg_code()
242
+ yield (state, st[-1][1], state.to_gradio_svg_render()) + (enable_btn,) * 7
243
  except requests.exceptions.RequestException as e:
244
  state.messages[-1][-1] = server_error_msg
245
  yield (state, None, None) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn, disable_btn, disable_btn)
 
585
  temperature = gr.Slider(minimum=0.0, maximum=2.0, value=0.2, step=0.05, interactive=True, label="Temperature",)
586
  len_penalty = gr.Slider(minimum=0.0, maximum=2.0, value=1.0, step=0.05, interactive=True, label="Length Penalty",)
587
  top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.95, step=0.05, interactive=True, label="Top P",)
588
+ max_output_tokens = gr.Slider(minimum=0, maximum=8192, value=2048, step=64, interactive=True, label="Max output tokens",)
589
 
590
  with gr.Column(scale=9):
591
  with gr.Row():
starvector/serve/vllm_api_gradio/model_worker.py CHANGED
@@ -117,7 +117,7 @@ class ModelWorker:
117
  temperature = float(params.get("temperature", 1.0))
118
  len_penalty = float(params.get("len_penalty", 1.0))
119
  top_p = float(params.get("top_p", 1.0))
120
- max_context_length = 1000
121
 
122
  # prompt = params["prompt"]
123
  prompt = "<svg "
@@ -132,6 +132,8 @@ class ModelWorker:
132
 
133
  max_new_tokens = min(int(params.get("max_new_tokens", 256)), 8192)
134
  max_new_tokens = min(max_new_tokens, max_context_length - CLIP_QUERY_LENGTH)
 
 
135
 
136
  # Use the chat completions endpoint
137
  vllm_endpoint = f"{self.vllm_base_url}/v1/chat/completions"
 
117
  temperature = float(params.get("temperature", 1.0))
118
  len_penalty = float(params.get("len_penalty", 1.0))
119
  top_p = float(params.get("top_p", 1.0))
120
+ max_context_length = 8192
121
 
122
  # prompt = params["prompt"]
123
  prompt = "<svg "
 
132
 
133
  max_new_tokens = min(int(params.get("max_new_tokens", 256)), 8192)
134
  max_new_tokens = min(max_new_tokens, max_context_length - CLIP_QUERY_LENGTH)
135
+ # log max new token
136
+ logger.info(f"max_new_tokens: {max_new_tokens}")
137
 
138
  # Use the chat completions endpoint
139
  vllm_endpoint = f"{self.vllm_base_url}/v1/chat/completions"