banao-tech commited on
Commit
52027db
·
verified ·
1 Parent(s): a5fb61e

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +22 -21
main.py CHANGED
@@ -23,37 +23,38 @@ from utils import (
23
  )
24
  import torch
25
 
26
- #yolo_model = get_yolo_model(model_path='best.pt')
27
- #caption_model_processor = get_caption_model_processor(model_name="florence2", model_name_or_path="icon_caption_florence")
28
 
29
  from ultralytics import YOLO
30
 
31
- #if not os.path.exists("NewParser/best.pt"):
32
- #os.makedirs("NewParser/best.pt")
 
 
 
 
 
33
 
34
- #try:
35
- #yolo_model = YOLO("best.pt").to("cuda")
36
- #except:
37
- #yolo_model = YOLO("best.pt")
38
  from transformers import AutoProcessor, AutoModelForCausalLM
39
- from transformers import AutoProcessor, Blip2ForConditionalGeneration
40
- # Correctly load the processor and model for Blip-2
 
 
 
41
  try:
42
- processor = AutoProcessor.from_pretrained("microsoft/Florence-2-base", trust_remote_code=True)
43
- model = Blip2ForConditionalGeneration.from_pretrained(
44
  "microsoft/OmniParser",
45
- torch_dtype=torch.float16, # Assuming you're using a GPU
46
- trust_remote_code=True
47
  ).to("cuda")
48
- except Exception as e:
49
- print(f"Error loading caption model: {e}")
50
- processor = AutoProcessor.from_pretrained("microsoft/Florence-2-base", trust_remote_code=True)
51
- model = Blip2ForConditionalGeneration.from_pretrained(
52
  "microsoft/OmniParser",
53
  torch_dtype=torch.float16,
54
- trust_remote_code=True
55
- ).to("cpu") # Fallback to CPU if CUDA fails
56
-
57
  print("finish loading model!!!")
58
 
59
  app = FastAPI()
 
23
  )
24
  import torch
25
 
26
+ # yolo_model = get_yolo_model(model_path='/data/icon_detect/best.pt')
27
+ # caption_model_processor = get_caption_model_processor(model_name="florence2", model_name_or_path="/data/icon_caption_florence")
28
 
29
  from ultralytics import YOLO
30
 
31
+ # if not os.path.exists("/data/icon_detect"):
32
+ # os.makedirs("/data/icon_detect")
33
+
34
+ try:
35
+ yolo_model = YOLO("weights/icon_detect/best.pt").to("cuda")
36
+ except:
37
+ yolo_model = YOLO("weights/icon_detect/best.pt")
38
 
 
 
 
 
39
  from transformers import AutoProcessor, AutoModelForCausalLM
40
+
41
+ processor = AutoProcessor.from_pretrained(
42
+ "microsoft/Florence-2-base", trust_remote_code=True
43
+ )
44
+
45
  try:
46
+ model = AutoModelForCausalLM.from_pretrained(
 
47
  "microsoft/OmniParser",
48
+ torch_dtype=torch.float16,
49
+ trust_remote_code=True,
50
  ).to("cuda")
51
+ except:
52
+ model = AutoModelForCausalLM.from_pretrained(
 
 
53
  "microsoft/OmniParser",
54
  torch_dtype=torch.float16,
55
+ trust_remote_code=True,
56
+ )
57
+ caption_model_processor = {"processor": processor, "model": model}
58
  print("finish loading model!!!")
59
 
60
  app = FastAPI()