KingNish commited on
Commit
18e3525
·
verified ·
1 Parent(s): bf334ed

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -2
app.py CHANGED
@@ -1,8 +1,15 @@
 
1
  import gradio as gr
2
  import numpy as np
3
  import os
4
  import torch
5
  import random
 
 
 
 
 
 
6
 
7
  from accelerate import infer_auto_device_map, load_checkpoint_and_dispatch, init_empty_weights
8
  from PIL import Image
@@ -17,10 +24,23 @@ from modeling.bagel import (
17
  SiglipVisionConfig, SiglipVisionModel
18
  )
19
  from modeling.qwen2 import Qwen2Tokenizer
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
 
22
  # Model Initialization
23
- model_path = "/path/to/BAGEL-7B-MoT/weights" #Download from https://huggingface.co/ByteDance-Seed/BAGEL-7B-MoT
24
 
25
  llm_config = Qwen2Config.from_json_file(os.path.join(model_path, "llm_config.json"))
26
  llm_config.qk_norm = True
@@ -502,4 +522,4 @@ with gr.Blocks() as demo:
502
  </div>
503
  """)
504
 
505
- demo.launch(share=True)
 
1
+ import spaces
2
  import gradio as gr
3
  import numpy as np
4
  import os
5
  import torch
6
  import random
7
+ import subprocess
8
+ subprocess.run(
9
+ "pip install flash-attn --no-build-isolation",
10
+ env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
11
+ shell=True,
12
+ )
13
 
14
  from accelerate import infer_auto_device_map, load_checkpoint_and_dispatch, init_empty_weights
15
  from PIL import Image
 
24
  SiglipVisionConfig, SiglipVisionModel
25
  )
26
  from modeling.qwen2 import Qwen2Tokenizer
27
+ from huggingface_hub import snapshot_download
28
+
29
+ save_dir = "./model"
30
+ repo_id = "ByteDance-Seed/BAGEL-7B-MoT"
31
+ cache_dir = save_dir + "/cache"
32
+
33
+ snapshot_download(cache_dir=cache_dir,
34
+ local_dir=save_dir,
35
+ repo_id=repo_id,
36
+ local_dir_use_symlinks=False,
37
+ resume_download=True,
38
+ allow_patterns=["*.json", "*.safetensors", "*.bin", "*.py", "*.md", "*.txt"],
39
+ )
40
 
41
 
42
  # Model Initialization
43
+ model_path = "./model"
44
 
45
  llm_config = Qwen2Config.from_json_file(os.path.join(model_path, "llm_config.json"))
46
  llm_config.qk_norm = True
 
522
  </div>
523
  """)
524
 
525
+ demo.launch()