David Day committed
Commit 9e1deca (unverified) · 1 Parent(s): 43f2643

Setup ZeroGPU

Files changed (5)
  1. README.md +1 -1
  2. app.py +14 -27
  3. model_builder.py +1 -1
  4. model_worker.py +5 -1
  5. requirements.txt +2 -4
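
Taken together, the five diffs below migrate the Space from a dedicated-GPU setup to ZeroGPU: the Gradio SDK moves to 4.x, the demo entry point is inlined into the main block, models are loaded on CPU at startup, and the GPU-bound generation path is wrapped in the spaces.GPU decorator.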
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: 💬
 colorFrom: yellow
 colorTo: purple
 sdk: gradio
-sdk_version: 3.35.2
+sdk_version: 4.16.0
 app_file: app.py
 pinned: false
 license: apache-2.0
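
The SDK bump from 3.35.2 to 4.16.0 drives most of the app.py changes below: Gradio 4 removed the `concurrency_count` argument from `Blocks.queue()`, replacing it with a per-event `concurrency_limit` (plus a `default_concurrency_limit` on `queue()`). A minimal sketch of the Gradio 4 pattern, with a hypothetical handler standing in for the real bot logic:

import gradio as gr

def echo(text):
    # hypothetical handler, not part of this repo
    return text.upper()

with gr.Blocks() as demo:
    box = gr.Textbox()
    out = gr.Textbox()
    # Gradio 4: concurrency is configured per event listener,
    # not via queue(concurrency_count=...) as in Gradio 3.
    box.submit(echo, box, out, concurrency_limit=4)

demo.queue(api_open=False).launch(server_name="0.0.0.0")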
app.py CHANGED
@@ -332,7 +332,7 @@ This is the demo for Dr-LLaVA. So far it could only be used for H&E stained Bone
             </ul>
             """)
             # Replace 'path_to_image' with the path to your image file
-            gr.Image(value="https://i.postimg.cc/tJzyq5Dh/Dr-LLa-VA-Fig-1.png",
+            gr.Image(value="https://davidday.tw/wp-content/uploads/2024/08/Dr-LLa-VA-Fig-1.jpg",
                      width=600, interactive=False, type="pil")
         with gr.Column(scale=3):
             with gr.Row(elem_id="model_selector_row"):
@@ -497,7 +497,7 @@ def start_worker():
     ]
     return subprocess.Popen(worker_command)
 
-def get_args():
+if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--host", type=str, default="0.0.0.0")
     parser.add_argument("--port", type=int)
@@ -510,35 +510,22 @@ def get_args():
     parser.add_argument("--moderate", action="store_true")
     parser.add_argument("--embed", action="store_true")
     args = parser.parse_args()
-
-    return args
-
-
-def start_demo(args):
-    demo = build_demo(args.embed)
-    demo.queue(
-        concurrency_count=args.concurrency_count, status_update_rate=10, api_open=False
-    ).launch(server_name=args.host, server_port=args.port, share=args.share)
-
-
-if __name__ == "__main__":
-    args = get_args()
     logger.info(f"args: {args}")
 
     controller_proc = start_controller()
     worker_proc = start_worker()
 
     # Wait for worker and controller to start
-    time.sleep(10)
+    time.sleep(60)
 
-    exit_status = 0
-    try:
-        start_demo(args)
-    except Exception as e:
-        print(e)
-        exit_status = 1
-    finally:
-        worker_proc.kill()
-        controller_proc.kill()
-
-    sys.exit(exit_status)
+    models = get_model_list()
+
+    logger.info(args)
+    demo = build_demo(args.embed, concurrency_count=args.concurrency_count)
+    demo.queue(
+        api_open=False
+    ).launch(
+        server_name=args.host,
+        server_port=args.port,
+        share=args.share
+    )
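
Note that inlining `start_demo` drops the old `try/finally` teardown, so the controller and worker subprocesses are no longer killed when the demo exits. If that cleanup is still wanted, an `atexit` hook is one way to restore it (a sketch, not part of this commit):

import atexit

def _kill_procs(*procs):
    # Best-effort teardown of the controller/worker subprocesses,
    # mirroring the finally-block the commit removed.
    for p in procs:
        if p.poll() is None:  # still running
            p.kill()

atexit.register(_kill_procs, controller_proc, worker_proc)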
model_builder.py CHANGED
@@ -23,7 +23,7 @@ from llava.model import *
 from llava.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
 
 
-def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", load_bf16=False):
+def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="cpu", load_bf16=False):
     kwargs = {"device_map": device_map}
 
     if load_8bit:
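
Changing the default `device_map` from "auto" to "cpu" matches the ZeroGPU execution model: no GPU is attached when the process starts, so weights must load on CPU, and a GPU is only available inside functions decorated with `spaces.GPU`. A minimal sketch of that pattern, with a hypothetical model name:

import spaces
import torch
from transformers import AutoModelForCausalLM

# Load on CPU at startup: on ZeroGPU no GPU exists yet.
model = AutoModelForCausalLM.from_pretrained(
    "some-org/some-model",  # hypothetical checkpoint
    device_map="cpu",
    torch_dtype=torch.float16,
)

@spaces.GPU  # a GPU is attached only for the duration of this call
def generate(input_ids):
    model.to("cuda")  # move weights onto the just-attached GPU
    output = model.generate(input_ids.to("cuda"), max_new_tokens=64)
    return output.cpu()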
model_worker.py CHANGED
@@ -14,6 +14,7 @@ import requests
 import torch
 import uvicorn
 from functools import partial
+import spaces
 
 from peft import PeftModel
 
@@ -72,6 +73,8 @@ class ModelWorker:
         self.model = PeftModel.from_pretrained(
             self.model,
             lora_path,
+            torch_device='cpu',
+            device_map="cpu",
         )
 
         if not no_register:
@@ -127,9 +130,10 @@ class ModelWorker:
             "queue_length": self.get_queue_length(),
         }
 
-    @torch.inference_mode()
+    @spaces.GPU
     def generate_stream(self, params):
         tokenizer, model, image_processor = self.tokenizer, self.model, self.image_processor
+        logger.info(f'Model devices: {self.model.device}')
 
         prompt = params["prompt"]
         ori_prompt = prompt
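
Two small observations on this file. First, `spaces.GPU` also accepts a duration hint for long-running calls, which may matter for token streaming. Second, the commit replaces `@torch.inference_mode()` rather than stacking it under the new decorator; keeping both would retain the no-grad savings (a suggestion, not part of the commit):

    @spaces.GPU(duration=120)  # request a longer GPU window, in seconds
    @torch.inference_mode()    # restore the no-grad context the commit dropped
    def generate_stream(self, params):
        ...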
requirements.txt CHANGED
@@ -2,9 +2,7 @@
 tokenizers>=0.12.1
 torch==2.0.1
 torchvision==0.15.2
-deepspeed==0.9.5
-pydantic<2.0.0
-peft==0.4.0
+peft
 transformers==4.31.0
 accelerate==0.21.0
 bitsandbytes==0.41.0
@@ -12,5 +10,5 @@ sentencepiece==0.1.99
 einops==0.6.1
 einops-exts==0.0.4
 timm==0.6.13
-numpy<2
+httpx==0.24.0
 scipy
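
Two things worth noting here: the `spaces` package imported in model_worker.py is not added to requirements.txt, presumably because it comes preinstalled on ZeroGPU Spaces; and dropping the `pydantic<2.0.0` pin is consistent with the Gradio 4 upgrade, since Gradio 4 requires pydantic 2. The `httpx==0.24.0` pin likely works around a compatibility issue with this Gradio version, though the commit does not say.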