File size: 1,800 Bytes
276d537
 
 
 
 
 
 
e678a10
276d537
 
e678a10
 
276d537
e678a10
276d537
e678a10
 
 
276d537
 
e678a10
 
 
 
 
 
276d537
 
 
 
 
 
 
 
e678a10
276d537
 
 
 
 
 
 
 
 
e678a10
276d537
 
 
 
 
 
 
 
 
 
e678a10
276d537
e678a10
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import os
import torch
from flask import Flask

# Relative weight given to each named feature; the three values add up to 1.0.
# NOTE(review): presumably consumed by the scoring logic elsewhere in the
# project -- not referenced in this file.
FEATURE_WEIGHTS = dict(
    shape=0.4,
    color=0.5,
    texture=0.1,
)

# Cut-off value for the final score.
# NOTE(review): comparison direction (>= vs >) is decided by the caller -- confirm there.
FINAL_SCORE_THRESHOLD = 0.5

# The Flask application object, named after this module. It is created before
# the models are loaded and before the routes module is imported further down.
app = Flask(import_name=__name__)

# Load models
print("=" * 50)
print("🚀 Initializing application and loading models...")

# Device selection.
# The requested device is read from the environment variable named "device"
# (lower-case; environment variable names are case-sensitive on POSIX) and
# defaults to "cpu". CUDA is only used when the request mentions "cuda" AND
# torch reports a usable CUDA runtime; anything else runs on the CPU.
device_name = os.environ.get("device", "cpu")
device = torch.device("cpu")
if "cuda" in device_name and torch.cuda.is_available():
    # Start from the default CUDA device, then honour an explicit index such
    # as "cuda:1" when it parses and refers to a GPU that actually exists.
    # Unparseable strings or out-of-range indices keep the default CUDA
    # device instead of failing at start-up.
    device = torch.device("cuda")
    try:
        _requested_device = torch.device(device_name)
    except RuntimeError:
        _requested_device = None
    if (
        _requested_device is not None
        and _requested_device.type == "cuda"
        and (
            _requested_device.index is None
            or _requested_device.index < torch.cuda.device_count()
        )
    ):
        device = _requested_device
print(f"🧠 Using device: {device}")

from transformers import (
    AutoProcessor,
    AutoModelForZeroShotObjectDetection,
    AutoTokenizer,
    AutoModel
)
from segment_anything import SamPredictor, sam_model_registry

print("...Loading Grounding DINO model...")
# One identifier is shared by the processor and the detection model so the
# two are always loaded from the same checkpoint.
gnd_model_id = "IDEA-Research/grounding-dino-tiny"
processor_gnd = AutoProcessor.from_pretrained(gnd_model_id)
# Load the detector first, then move it onto the selected device.
model_gnd = AutoModelForZeroShotObjectDetection.from_pretrained(gnd_model_id)
model_gnd = model_gnd.to(device)

print("...Loading Segment Anything (SAM) model...")
# IMPORTANT: this is a relative path, so the checkpoint file is expected at
# the root of the project (the directory the application is started from).
sam_checkpoint = "sam_vit_b_01ec64.pth"
# Build the "vit_b" variant from the checkpoint, then move it to the device.
sam_model = sam_model_registry["vit_b"](checkpoint=sam_checkpoint)
sam_model = sam_model.to(device)
# The predictor wraps the SAM model; it is what gets shared via `models`.
predictor = SamPredictor(sam_model)

print("...Loading BGE model for text embeddings...")
# Tokenizer and encoder are loaded from the same identifier.
bge_model_id = "BAAI/bge-small-en-v1.5"
tokenizer_text = AutoTokenizer.from_pretrained(bge_model_id)
# Load the encoder first, then move it onto the selected device.
model_text = AutoModel.from_pretrained(bge_model_id)
model_text = model_text.to(device)

# Bundle every loaded component (plus the selected device) under a stable
# string key so logic functions can receive them all as a single argument.
models = dict(
    processor_gnd=processor_gnd,
    model_gnd=model_gnd,
    predictor=predictor,
    tokenizer_text=tokenizer_text,
    model_text=model_text,
    device=device,
)

# Closing banner: reached only after every model above loaded without raising.
print("✅ All models loaded successfully.")
print("=" * 50)

# Import routes after app and models are defined to avoid circular imports.
# NOTE(review): `routes` is not referenced again in the visible code, so this
# import is presumably for its side effects (registering handlers on `app`)
# -- confirm against pipeline/routes. The lint suppressions below cover the
# deliberate late import (E402) and the otherwise-unused name (F401).
from pipeline import routes  # noqa: E402,F401