OmniSVG commited on
Commit
1948259
·
verified ·
1 Parent(s): 1245532

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -3
app.py CHANGED
@@ -12,7 +12,7 @@ import glob
12
 
13
 
14
  from decoder import SketchDecoder
15
- from transformers import AutoTokenizer, AutoProcessor
16
  from qwen_vl_utils import process_vision_info
17
  from tokenizer import SVGTokenizer
18
 
@@ -52,8 +52,9 @@ def load_models():
52
  processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct", padding_side="left")
53
 
54
  sketch_decoder = SketchDecoder()
55
-
56
- sketch_weight_path = "https://huggingface.co/OmniSVG/OmniSVG/resolve/main/OmniSVG-3B.bin"
 
57
  sketch_decoder.load_state_dict(torch.load(sketch_weight_path))
58
  sketch_decoder = sketch_decoder.to(device).eval()
59
 
 
12
 
13
 
14
  from decoder import SketchDecoder
15
+ from transformers import AutoTokenizer, AutoProcessor, Qwen2_5_VLForConditionalGeneration
16
  from qwen_vl_utils import process_vision_info
17
  from tokenizer import SVGTokenizer
18
 
 
52
  processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct", padding_side="left")
53
 
54
  sketch_decoder = SketchDecoder()
55
+
56
+ sketch_weight_path = Qwen2_5_VLForConditionalGeneration.from_pretrained("OmniSVG/OmniSVG")
57
+ #sketch_weight_path = "https://huggingface.co/OmniSVG/OmniSVG/resolve/main/OmniSVG-3B.bin"
58
  sketch_decoder.load_state_dict(torch.load(sketch_weight_path))
59
  sketch_decoder = sketch_decoder.to(device).eval()
60