Spaces:
Running
on
Zero
Running
on
Zero
David Day
committed on
Setup ZeroGPU
Browse files- README.md +1 -1
- app.py +14 -27
- model_builder.py +1 -1
- model_worker.py +5 -1
- requirements.txt +2 -4
README.md
CHANGED
@@ -4,7 +4,7 @@ emoji: 💬
|
|
4 |
colorFrom: yellow
|
5 |
colorTo: purple
|
6 |
sdk: gradio
|
7 |
-
sdk_version:
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
license: apache-2.0
|
|
|
4 |
colorFrom: yellow
|
5 |
colorTo: purple
|
6 |
sdk: gradio
|
7 |
+
sdk_version: 4.16.0
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
license: apache-2.0
|
app.py
CHANGED
@@ -332,7 +332,7 @@ This is the demo for Dr-LLaVA. So far it could only be used for H&E stained Bone
|
|
332 |
</ul>
|
333 |
""")
|
334 |
# Replace 'path_to_image' with the path to your image file
|
335 |
-
gr.Image(value="https://
|
336 |
width=600, interactive=False, type="pil")
|
337 |
with gr.Column(scale=3):
|
338 |
with gr.Row(elem_id="model_selector_row"):
|
@@ -497,7 +497,7 @@ def start_worker():
|
|
497 |
]
|
498 |
return subprocess.Popen(worker_command)
|
499 |
|
500 |
-
|
501 |
parser = argparse.ArgumentParser()
|
502 |
parser.add_argument("--host", type=str, default="0.0.0.0")
|
503 |
parser.add_argument("--port", type=int)
|
@@ -510,35 +510,22 @@ def get_args():
|
|
510 |
parser.add_argument("--moderate", action="store_true")
|
511 |
parser.add_argument("--embed", action="store_true")
|
512 |
args = parser.parse_args()
|
513 |
-
|
514 |
-
return args
|
515 |
-
|
516 |
-
|
517 |
-
def start_demo(args):
|
518 |
-
demo = build_demo(args.embed)
|
519 |
-
demo.queue(
|
520 |
-
concurrency_count=args.concurrency_count, status_update_rate=10, api_open=False
|
521 |
-
).launch(server_name=args.host, server_port=args.port, share=args.share)
|
522 |
-
|
523 |
-
|
524 |
-
if __name__ == "__main__":
|
525 |
-
args = get_args()
|
526 |
logger.info(f"args: {args}")
|
527 |
|
528 |
controller_proc = start_controller()
|
529 |
worker_proc = start_worker()
|
530 |
|
531 |
# Wait for worker and controller to start
|
532 |
-
time.sleep(
|
533 |
|
534 |
-
|
535 |
-
|
536 |
-
|
537 |
-
|
538 |
-
|
539 |
-
|
540 |
-
|
541 |
-
|
542 |
-
|
543 |
-
|
544 |
-
|
|
|
332 |
</ul>
|
333 |
""")
|
334 |
# Replace 'path_to_image' with the path to your image file
|
335 |
+
gr.Image(value="https://davidday.tw/wp-content/uploads/2024/08/Dr-LLa-VA-Fig-1.jpg",
|
336 |
width=600, interactive=False, type="pil")
|
337 |
with gr.Column(scale=3):
|
338 |
with gr.Row(elem_id="model_selector_row"):
|
|
|
497 |
]
|
498 |
return subprocess.Popen(worker_command)
|
499 |
|
500 |
+
if __name__ == "__main__":
|
501 |
parser = argparse.ArgumentParser()
|
502 |
parser.add_argument("--host", type=str, default="0.0.0.0")
|
503 |
parser.add_argument("--port", type=int)
|
|
|
510 |
parser.add_argument("--moderate", action="store_true")
|
511 |
parser.add_argument("--embed", action="store_true")
|
512 |
args = parser.parse_args()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
513 |
logger.info(f"args: {args}")
|
514 |
|
515 |
controller_proc = start_controller()
|
516 |
worker_proc = start_worker()
|
517 |
|
518 |
# Wait for worker and controller to start
|
519 |
+
time.sleep(60)
|
520 |
|
521 |
+
models = get_model_list()
|
522 |
+
|
523 |
+
logger.info(args)
|
524 |
+
demo = build_demo(args.embed, concurrency_count=args.concurrency_count)
|
525 |
+
demo.queue(
|
526 |
+
api_open=False
|
527 |
+
).launch(
|
528 |
+
server_name=args.host,
|
529 |
+
server_port=args.port,
|
530 |
+
share=args.share
|
531 |
+
)
|
model_builder.py
CHANGED
@@ -23,7 +23,7 @@ from llava.model import *
|
|
23 |
from llava.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
|
24 |
|
25 |
|
26 |
-
def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="
|
27 |
kwargs = {"device_map": device_map}
|
28 |
|
29 |
if load_8bit:
|
|
|
23 |
from llava.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
|
24 |
|
25 |
|
26 |
+
def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="cpu", load_bf16=False):
|
27 |
kwargs = {"device_map": device_map}
|
28 |
|
29 |
if load_8bit:
|
model_worker.py
CHANGED
@@ -14,6 +14,7 @@ import requests
|
|
14 |
import torch
|
15 |
import uvicorn
|
16 |
from functools import partial
|
|
|
17 |
|
18 |
from peft import PeftModel
|
19 |
|
@@ -72,6 +73,8 @@ class ModelWorker:
|
|
72 |
self.model = PeftModel.from_pretrained(
|
73 |
self.model,
|
74 |
lora_path,
|
|
|
|
|
75 |
)
|
76 |
|
77 |
if not no_register:
|
@@ -127,9 +130,10 @@ class ModelWorker:
|
|
127 |
"queue_length": self.get_queue_length(),
|
128 |
}
|
129 |
|
130 |
-
@
|
131 |
def generate_stream(self, params):
|
132 |
tokenizer, model, image_processor = self.tokenizer, self.model, self.image_processor
|
|
|
133 |
|
134 |
prompt = params["prompt"]
|
135 |
ori_prompt = prompt
|
|
|
14 |
import torch
|
15 |
import uvicorn
|
16 |
from functools import partial
|
17 |
+
import spaces
|
18 |
|
19 |
from peft import PeftModel
|
20 |
|
|
|
73 |
self.model = PeftModel.from_pretrained(
|
74 |
self.model,
|
75 |
lora_path,
|
76 |
+
torch_device='cpu',
|
77 |
+
device_map="cpu",
|
78 |
)
|
79 |
|
80 |
if not no_register:
|
|
|
130 |
"queue_length": self.get_queue_length(),
|
131 |
}
|
132 |
|
133 |
+
@spaces.GPU
|
134 |
def generate_stream(self, params):
|
135 |
tokenizer, model, image_processor = self.tokenizer, self.model, self.image_processor
|
136 |
+
logger.info(f'Model devices: {self.model.device}')
|
137 |
|
138 |
prompt = params["prompt"]
|
139 |
ori_prompt = prompt
|
requirements.txt
CHANGED
@@ -2,9 +2,7 @@
|
|
2 |
tokenizers>=0.12.1
|
3 |
torch==2.0.1
|
4 |
torchvision==0.15.2
|
5 |
-
|
6 |
-
pydantic<2.0.0
|
7 |
-
peft==0.4.0
|
8 |
transformers==4.31.0
|
9 |
accelerate==0.21.0
|
10 |
bitsandbytes==0.41.0
|
@@ -12,5 +10,5 @@ sentencepiece==0.1.99
|
|
12 |
einops==0.6.1
|
13 |
einops-exts==0.0.4
|
14 |
timm==0.6.13
|
15 |
-
|
16 |
scipy
|
|
|
2 |
tokenizers>=0.12.1
|
3 |
torch==2.0.1
|
4 |
torchvision==0.15.2
|
5 |
+
peft
|
|
|
|
|
6 |
transformers==4.31.0
|
7 |
accelerate==0.21.0
|
8 |
bitsandbytes==0.41.0
|
|
|
10 |
einops==0.6.1
|
11 |
einops-exts==0.0.4
|
12 |
timm==0.6.13
|
13 |
+
httpx==0.24.0
|
14 |
scipy
|