sohamnk commited on
Commit
ee57fa4
·
verified ·
1 Parent(s): 8b6199a

Update pipeline/__init__.py

Browse files
Files changed (1) hide show
  1. pipeline/__init__.py +11 -3
pipeline/__init__.py CHANGED
@@ -1,6 +1,7 @@
1
  import os
2
  import torch
3
  from flask import Flask
 
4
 
5
  FEATURE_WEIGHTS = {"shape": 0.4, "color": 0.5, "texture": 0.1}
6
  FINAL_SCORE_THRESHOLD = 0.5
@@ -24,7 +25,6 @@ processor_gnd = AutoProcessor.from_pretrained(gnd_model_id)
24
  model_gnd = AutoModelForZeroShotObjectDetection.from_pretrained(gnd_model_id).to(device)
25
 
26
  print("...Loading Segment Anything (SAM) model...")
27
- # IMPORTANT: The path is now relative to the root of the project
28
  sam_checkpoint = "sam_vit_b_01ec64.pth"
29
  sam_model = sam_model_registry["vit_b"](checkpoint=sam_checkpoint).to(device)
30
  predictor = SamPredictor(sam_model)
@@ -34,7 +34,7 @@ bge_model_id = "BAAI/bge-small-en-v1.5"
34
  tokenizer_text = AutoTokenizer.from_pretrained(bge_model_id)
35
  model_text = AutoModel.from_pretrained(bge_model_id).to(device)
36
 
37
- # Store models in a dictionary to pass to logic functions
38
  models = {
39
  "processor_gnd": processor_gnd,
40
  "model_gnd": model_gnd,
@@ -44,8 +44,16 @@ models = {
44
  "device": device
45
  }
46
 
 
 
 
 
 
 
 
 
 
47
  print("✅ All models loaded successfully.")
48
  print("="*50)
49
 
50
- # Import routes after app and models are defined to avoid circular imports
51
  from pipeline import routes
 
1
  import os
2
  import torch
3
  from flask import Flask
4
+ from sentence_transformers.cross_encoder import CrossEncoder
5
 
6
  FEATURE_WEIGHTS = {"shape": 0.4, "color": 0.5, "texture": 0.1}
7
  FINAL_SCORE_THRESHOLD = 0.5
 
25
  model_gnd = AutoModelForZeroShotObjectDetection.from_pretrained(gnd_model_id).to(device)
26
 
27
  print("...Loading Segment Anything (SAM) model...")
 
28
  sam_checkpoint = "sam_vit_b_01ec64.pth"
29
  sam_model = sam_model_registry["vit_b"](checkpoint=sam_checkpoint).to(device)
30
  predictor = SamPredictor(sam_model)
 
34
  tokenizer_text = AutoTokenizer.from_pretrained(bge_model_id)
35
  model_text = AutoModel.from_pretrained(bge_model_id).to(device)
36
 
37
+
38
  models = {
39
  "processor_gnd": processor_gnd,
40
  "model_gnd": model_gnd,
 
44
  "device": device
45
  }
46
 
47
+
48
+ print("...Loading Cross-Encoder model for re-ranking...")
49
+
50
+ cross_encoder_model = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2', device=device)
51
+
52
+
53
+ models["cross_encoder"] = cross_encoder_model
54
+
55
+
56
  print("✅ All models loaded successfully.")
57
  print("="*50)
58
 
 
59
  from pipeline import routes