IZERE HIRWA Roger committed on
Commit
46331c0
·
1 Parent(s): cdcf202
Files changed (1) hide show
  1. app.py +12 -7
app.py CHANGED
@@ -21,6 +21,7 @@ from flask_cors import CORS
21
  import torch
22
  from groundingdino.util.inference import load_model, predict
23
  from segment_anything import sam_model_registry, SamPredictor
 
24
 
25
  # ─── Load models once ───────────────────────────────────────────────────────────
26
  device = torch.device("cpu")
@@ -39,20 +40,24 @@ app = Flask(__name__)
39
  CORS(app)
40
 
41
  def segment(image_pil: Image.Image, prompt: str):
42
- # Convert PIL image to numpy array and normalize
43
- image_np = np.array(image_pil).astype(np.float32) / 255.0 # Normalize to [0, 1]
44
-
45
- # Convert numpy array to torch tensor
46
- image_tensor = torch.tensor(image_np).permute(2, 0, 1).unsqueeze(0).to(device) # Convert to CHW format
 
 
 
 
47
 
48
  # Run GroundingDINO to get boxes for the prompt
49
  boxes, _, _ = predict(
50
  model=grounder,
51
- image=image_tensor, # Pass normalized tensor
52
  caption=prompt,
53
  box_threshold=0.3,
54
  text_threshold=0.25,
55
- device="cpu" # Explicitly set device to CPU
56
  )
57
  if boxes.size == 0:
58
  raise ValueError("No boxes found for prompt.")
 
21
  import torch
22
  from groundingdino.util.inference import load_model, predict
23
  from segment_anything import sam_model_registry, SamPredictor
24
+ import groundingdino.datasets.transforms as T
25
 
26
  # ─── Load models once ───────────────────────────────────────────────────────────
27
  device = torch.device("cpu")
 
40
  CORS(app)
41
 
42
  def segment(image_pil: Image.Image, prompt: str):
43
+ # Use the proper image preprocessing for GroundingDINO
44
+ transform = T.Compose([
45
+ T.RandomResize([800], max_size=1333),
46
+ T.ToTensor(),
47
+ T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
48
+ ])
49
+
50
+ image_transformed, _ = transform(image_pil, None)
51
+ image_transformed = image_transformed.to(device)
52
 
53
  # Run GroundingDINO to get boxes for the prompt
54
  boxes, _, _ = predict(
55
  model=grounder,
56
+ image=image_transformed,
57
  caption=prompt,
58
  box_threshold=0.3,
59
  text_threshold=0.25,
60
+ device="cpu"
61
  )
62
  if boxes.size == 0:
63
  raise ValueError("No boxes found for prompt.")