File size: 1,800 Bytes
276d537 e678a10 276d537 e678a10 276d537 e678a10 276d537 e678a10 276d537 e678a10 276d537 e678a10 276d537 e678a10 276d537 e678a10 276d537 e678a10 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 |
import os
import torch
from flask import Flask
# Per-feature weights and a score cut-off.
# NOTE(review): neither constant is read in this module; presumably they are
# imported by the scoring code in pipeline.routes — confirm before changing.
FEATURE_WEIGHTS = {"shape": 0.4, "color": 0.5, "texture": 0.1}
FINAL_SCORE_THRESHOLD = 0.5

# The Flask application object (presumably routes are registered on it by
# pipeline.routes, which is imported at the bottom of this module).
app = Flask(__name__)

# Start-up banner printed before any model is loaded.
print("=" * 50)
print("π Initializing application and loading models...")

# Pick the torch device from the lower-case `device` environment variable
# (default "cpu"). CUDA is chosen only when the value contains "cuda" AND a
# GPU is actually available; every other case falls back to the CPU.
# NOTE(review): any device index in the value (e.g. "cuda:1") is discarded —
# the result is always the bare "cuda" device.
device_name = os.environ.get("device", "cpu")
if "cuda" in device_name and torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"π§ Using device: {device}")
from transformers import (
AutoProcessor,
AutoModelForZeroShotObjectDetection,
AutoTokenizer,
AutoModel
)
from segment_anything import SamPredictor, sam_model_registry
# Load every model once, at import time, and move each one onto `device`.

# Grounding DINO: processor (not device-bound) plus the detection model.
print("...Loading Grounding DINO model...")
gnd_model_id = "IDEA-Research/grounding-dino-tiny"
processor_gnd = AutoProcessor.from_pretrained(gnd_model_id)
model_gnd = AutoModelForZeroShotObjectDetection.from_pretrained(gnd_model_id).to(device)

print("...Loading Segment Anything (SAM) model...")
# The checkpoint path can be overridden with the SAM_CHECKPOINT environment
# variable. A relative path (including the default) is resolved against the
# current working directory, NOT this file's location, so without the
# override the app must be started from the directory that holds the
# checkpoint (the project root).
sam_checkpoint = os.environ.get("SAM_CHECKPOINT", "sam_vit_b_01ec64.pth")
if not os.path.isfile(sam_checkpoint):
    # Fail early with a message that names the resolved path and the fix.
    raise FileNotFoundError(
        f"SAM checkpoint not found: {os.path.abspath(sam_checkpoint)!r}. "
        "Start the app from the project root or set SAM_CHECKPOINT."
    )
sam_model = sam_model_registry["vit_b"](checkpoint=sam_checkpoint).to(device)
predictor = SamPredictor(sam_model)

# BGE: tokenizer plus encoder used for text embeddings.
print("...Loading BGE model for text embeddings...")
bge_model_id = "BAAI/bge-small-en-v1.5"
tokenizer_text = AutoTokenizer.from_pretrained(bge_model_id)
model_text = AutoModel.from_pretrained(bge_model_id).to(device)

# Bundle the loaded components (and the device) in one dict so they can be
# handed to the logic functions as a single argument.
models = {
    "processor_gnd": processor_gnd,
    "model_gnd": model_gnd,
    "predictor": predictor,
    "tokenizer_text": tokenizer_text,
    "model_text": model_text,
    "device": device,
}
print("✅ All models loaded successfully.")
print("=" * 50)
# Import routes after app and models are defined to avoid circular imports
from pipeline import routes
|