andrewkatumba committed on
Commit
6c899dd
·
verified ·
1 Parent(s): ee1bcff

Add plotters for bounding boxes

Browse files
Files changed (1) hide show
  1. app.py +42 -30
app.py CHANGED
@@ -2,7 +2,9 @@ from transformers import pipeline, SamModel, SamProcessor
2
  import torch
3
  import numpy as np
4
  import spaces
 
5
 
 
6
  checkpoint = "google/owlvit-base-patch16"
7
  detector = pipeline(model=checkpoint, task="zero-shot-object-detection")
8
  sam_model = SamModel.from_pretrained("facebook/sam-vit-base").to("cuda")
@@ -10,40 +12,50 @@ sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
10
 
11
  @spaces.GPU
12
  def query(image, texts, threshold):
13
- texts = texts.split(",")
14
-
15
- predictions = detector(
16
- image,
17
- candidate_labels=texts,
18
- threshold=threshold
19
- )
20
-
21
- result_labels = []
22
- for pred in predictions:
23
-
24
- box = pred["box"]
25
- score = pred["score"]
26
- label = pred["label"]
27
- box = [round(pred["box"]["xmin"], 2), round(pred["box"]["ymin"], 2),
28
- round(pred["box"]["xmax"], 2), round(pred["box"]["ymax"], 2)]
29
-
30
- inputs = sam_processor(
 
 
 
 
31
  image,
32
- input_boxes=[[[box]]],
33
  return_tensors="pt"
34
  ).to("cuda")
35
 
36
- with torch.no_grad():
37
- outputs = sam_model(**inputs)
38
-
39
- mask = sam_processor.image_processor.post_process_masks(
40
- outputs.pred_masks.cpu(),
41
- inputs["original_sizes"].cpu(),
42
- inputs["reshaped_input_sizes"].cpu()
43
- )[0][0][0].numpy()
44
- mask = mask[np.newaxis, ...]
45
- result_labels.append((mask, label))
46
- return image, result_labels
 
 
 
 
 
 
47
 
48
  import gradio as gr
49
 
 
2
  import torch
3
  import numpy as np
4
  import spaces
5
+ from PIL import Image, ImageDraw
6
 
7
+ # Load models (unchanged)
8
  checkpoint = "google/owlvit-base-patch16"
9
  detector = pipeline(model=checkpoint, task="zero-shot-object-detection")
10
  sam_model = SamModel.from_pretrained("facebook/sam-vit-base").to("cuda")
 
12
 
13
  @spaces.GPU
14
  def query(image, texts, threshold):
15
+ texts = texts.split(",")
16
+
17
+ # --- Object Detection (unchanged) ---
18
+ predictions = detector(
19
+ image,
20
+ candidate_labels=texts,
21
+ threshold=threshold
22
+ )
23
+
24
+ result_labels = []
25
+ draw = ImageDraw.Draw(image) # Create a drawing object for the image
26
+
27
+ for pred in predictions:
28
+ box = pred["box"]
29
+ score = pred["score"]
30
+ label = pred["label"]
31
+
32
+ # Round box coordinates for display and SAM input (mostly unchanged)
33
+ box = [round(coord, 2) for coord in list(box.values())]
34
+
35
+ # --- Segmentation (unchanged) ---
36
+ inputs = sam_processor(
37
  image,
38
+ input_boxes=[[[box]]], # Note: SAM expects a nested list
39
  return_tensors="pt"
40
  ).to("cuda")
41
 
42
+ with torch.no_grad():
43
+ outputs = sam_model(**inputs)
44
+
45
+ mask = sam_processor.image_processor.post_process_masks(
46
+ outputs.pred_masks.cpu(),
47
+ inputs["original_sizes"].cpu(),
48
+ inputs["reshaped_input_sizes"].cpu()
49
+ )[0][0][0].numpy()
50
+ mask = mask[np.newaxis, ...]
51
+ result_labels.append((mask, label))
52
+
53
+ # --- Draw Bounding Box ---
54
+ draw.rectangle(box, outline="red", width=3) # Draw rectangle with a red outline
55
+ draw.text((box[0], box[1] - 10), label, fill="red") # Add label above the box
56
+
57
+ return image, result_labels # Return the modified image
58
+
59
 
60
  import gradio as gr
61