Update app.py
Browse files
app.py
CHANGED
@@ -1,19 +1,16 @@
|
|
1 |
import gradio as gr
|
2 |
from PIL import Image, ImageDraw
|
3 |
import torch
|
4 |
-
from transformers import OwlViTProcessor, OwlViTForObjectDetection
|
5 |
from transformers.image_transforms import center_to_corners_format
|
6 |
from transformers.models.owlvit.modeling_owlvit import box_iou
|
7 |
from functools import partial
|
8 |
|
9 |
-
# from utils import iou
|
10 |
|
11 |
processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32")
|
12 |
model = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch32")
|
13 |
|
14 |
-
from transformers.models.owlvit.modeling_owlvit import OwlViTImageGuidedObjectDetectionOutput
|
15 |
-
|
16 |
-
|
17 |
|
18 |
|
19 |
|
@@ -69,8 +66,6 @@ def class_predictor(
|
|
69 |
|
70 |
|
71 |
|
72 |
-
|
73 |
-
|
74 |
def get_max_iou_indice(target_pred_boxes, query_box, target_sizes):
|
75 |
boxes = center_to_corners_format(target_pred_boxes)
|
76 |
img_h, img_w = target_sizes.unbind(1)
|
@@ -109,12 +104,6 @@ def box_guided_detection(
|
|
109 |
batch_size, num_patches, num_patches, hidden_dim = feature_map.shape
|
110 |
image_feats = torch.reshape(feature_map, (batch_size, num_patches * num_patches, hidden_dim))
|
111 |
|
112 |
-
# batch_size, num_patches, num_patches, hidden_dim = query_feature_map.shape
|
113 |
-
# query_image_feats = torch.reshape(query_feature_map, (batch_size, num_patches * num_patches, hidden_dim))
|
114 |
-
# # Get top class embedding and best box index for each query image in batch
|
115 |
-
# query_embeds, best_box_indices, query_pred_boxes = self.embed_image_query(query_image_feats, query_feature_map)
|
116 |
-
|
117 |
-
# Predict object boxes
|
118 |
target_pred_boxes = self.box_predictor(image_feats, feature_map)
|
119 |
|
120 |
# Get MAX IOU box corresponding embedding
|
@@ -124,9 +113,6 @@ def box_guided_detection(
|
|
124 |
(pred_logits, class_embeds) = self.class_predictor(image_feats=image_feats, query_indice=query_indice)
|
125 |
|
126 |
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
if not return_dict:
|
131 |
output = (
|
132 |
feature_map,
|
@@ -179,7 +165,7 @@ def threshold_change(xmin, ymin, xmax, ymax, image, threshold, nms):
|
|
179 |
labels = list(zip(boxes, scores))
|
180 |
labels.append((manul_box, "manual"))
|
181 |
|
182 |
-
cnt = len(boxes)
|
183 |
|
184 |
return (image, labels), cnt
|
185 |
|
@@ -198,7 +184,7 @@ def one_shot_detect(xmin, ymin, xmax, ymax, image, threshold, nms):
|
|
198 |
labels = list(zip(boxes, scores))
|
199 |
labels.append((manul_box, "manual"))
|
200 |
|
201 |
-
cnt = len(boxes)
|
202 |
|
203 |
return (image, labels), cnt
|
204 |
|
|
|
1 |
import gradio as gr
|
2 |
from PIL import Image, ImageDraw
|
3 |
import torch
|
4 |
+
from transformers import OwlViTProcessor, OwlViTForObjectDetection
|
5 |
from transformers.image_transforms import center_to_corners_format
|
6 |
from transformers.models.owlvit.modeling_owlvit import box_iou
|
7 |
from functools import partial
|
8 |
|
|
|
9 |
|
10 |
processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32")
|
11 |
model = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch32")
|
12 |
|
13 |
+
from transformers.models.owlvit.modeling_owlvit import OwlViTImageGuidedObjectDetectionOutput
|
|
|
|
|
14 |
|
15 |
|
16 |
|
|
|
66 |
|
67 |
|
68 |
|
|
|
|
|
69 |
def get_max_iou_indice(target_pred_boxes, query_box, target_sizes):
|
70 |
boxes = center_to_corners_format(target_pred_boxes)
|
71 |
img_h, img_w = target_sizes.unbind(1)
|
|
|
104 |
batch_size, num_patches, num_patches, hidden_dim = feature_map.shape
|
105 |
image_feats = torch.reshape(feature_map, (batch_size, num_patches * num_patches, hidden_dim))
|
106 |
|
|
|
|
|
|
|
|
|
|
|
|
|
107 |
target_pred_boxes = self.box_predictor(image_feats, feature_map)
|
108 |
|
109 |
# Get MAX IOU box corresponding embedding
|
|
|
113 |
(pred_logits, class_embeds) = self.class_predictor(image_feats=image_feats, query_indice=query_indice)
|
114 |
|
115 |
|
|
|
|
|
|
|
116 |
if not return_dict:
|
117 |
output = (
|
118 |
feature_map,
|
|
|
165 |
labels = list(zip(boxes, scores))
|
166 |
labels.append((manul_box, "manual"))
|
167 |
|
168 |
+
cnt = len(boxes)
|
169 |
|
170 |
return (image, labels), cnt
|
171 |
|
|
|
184 |
labels = list(zip(boxes, scores))
|
185 |
labels.append((manul_box, "manual"))
|
186 |
|
187 |
+
cnt = len(boxes)
|
188 |
|
189 |
return (image, labels), cnt
|
190 |
|