HaohuaLv committed
Commit 293d766 · Parent: a3597eb

Update app.py

Files changed (1): app.py (+4, -18)
app.py CHANGED
@@ -1,19 +1,16 @@
 import gradio as gr
 from PIL import Image, ImageDraw
 import torch
-from transformers import OwlViTProcessor, OwlViTForObjectDetection, OwlViTModel, OwlViTImageProcessor
+from transformers import OwlViTProcessor, OwlViTForObjectDetection
 from transformers.image_transforms import center_to_corners_format
 from transformers.models.owlvit.modeling_owlvit import box_iou
 from functools import partial

-# from utils import iou

 processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32")
 model = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch32")

-from transformers.models.owlvit.modeling_owlvit import OwlViTImageGuidedObjectDetectionOutput, OwlViTClassPredictionHead
-
-
+from transformers.models.owlvit.modeling_owlvit import OwlViTImageGuidedObjectDetectionOutput



@@ -69,8 +66,6 @@ def class_predictor(



-
-
 def get_max_iou_indice(target_pred_boxes, query_box, target_sizes):
     boxes = center_to_corners_format(target_pred_boxes)
     img_h, img_w = target_sizes.unbind(1)
@@ -109,12 +104,6 @@ def box_guided_detection(
     batch_size, num_patches, num_patches, hidden_dim = feature_map.shape
     image_feats = torch.reshape(feature_map, (batch_size, num_patches * num_patches, hidden_dim))

-    # batch_size, num_patches, num_patches, hidden_dim = query_feature_map.shape
-    # query_image_feats = torch.reshape(query_feature_map, (batch_size, num_patches * num_patches, hidden_dim))
-    # # Get top class embedding and best box index for each query image in batch
-    # query_embeds, best_box_indices, query_pred_boxes = self.embed_image_query(query_image_feats, query_feature_map)
-
-    # Predict object boxes
     target_pred_boxes = self.box_predictor(image_feats, feature_map)

     # Get MAX IOU box corresponding embedding
@@ -124,9 +113,6 @@
     (pred_logits, class_embeds) = self.class_predictor(image_feats=image_feats, query_indice=query_indice)


-
-
-
     if not return_dict:
         output = (
             feature_map,
@@ -179,7 +165,7 @@ def threshold_change(xmin, ymin, xmax, ymax, image, threshold, nms):
     labels = list(zip(boxes, scores))
     labels.append((manul_box, "manual"))

-    cnt = len(boxes) - 1
+    cnt = len(boxes)

     return (image, labels), cnt

@@ -198,7 +184,7 @@ def one_shot_detect(xmin, ymin, xmax, ymax, image, threshold, nms):
     labels = list(zip(boxes, scores))
     labels.append((manul_box, "manual"))

-    cnt = len(boxes) - 1
+    cnt = len(boxes)

     return (image, labels), cnt
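Note on the retained functools.partial import: box_guided_detection takes self as its first argument and calls self.box_predictor and self.class_predictor, which suggests app.py binds the standalone function onto the loaded OWL-ViT model. That binding is not part of this diff; a hypothetical sketch of the wiring:

    from functools import partial

    # Hypothetical (not shown in this diff): attach the standalone
    # box_guided_detection function to the loaded model so that `self`
    # resolves to `model` inside it.
    model.box_guided_detection = partial(box_guided_detection, model)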
 
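The middle hunks route one-shot detection through a max-IoU lookup: get_max_iou_indice picks the predicted box that best overlaps the user-drawn query box, and the resulting query_indice selects the class embedding passed to class_predictor. A minimal sketch of that lookup, assuming OWL-ViT's normalized (cx, cy, w, h) box format; only the first two lines are taken verbatim from the diff, and the rescaling and box_iou steps are assumptions:

    import torch
    from transformers.image_transforms import center_to_corners_format
    from transformers.models.owlvit.modeling_owlvit import box_iou

    def get_max_iou_indice_sketch(target_pred_boxes, query_box, target_sizes):
        # (batch, num_boxes, 4) in (cx, cy, w, h) -> corner format,
        # as in the first two lines of the diff's implementation.
        boxes = center_to_corners_format(target_pred_boxes)
        img_h, img_w = target_sizes.unbind(1)
        # Assumed: scale normalized coordinates to pixel space so they
        # are comparable with the pixel-space query box.
        scale = torch.stack([img_w, img_h, img_w, img_h], dim=1).unsqueeze(1)
        boxes = boxes * scale
        # Assumed: IoU of every predicted box against the single query
        # box; the index of the best overlap selects the embedding.
        ious, _ = box_iou(boxes[0], query_box)  # (num_boxes, 1)
        return ious.argmax()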
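The last two hunks are the substantive fix: an off-by-one in the detection count reported by threshold_change and one_shot_detect. The user's query box is appended to labels under the "manual" tag only after cnt is computed, so len(boxes) already counts exactly the model's detections, and the old `- 1` under-reported them by one. A self-contained illustration with made-up values (in app.py, boxes and scores come from the model's post-processed outputs):

    boxes = [(10, 10, 50, 50), (60, 60, 120, 120)]  # model detections
    scores = ["0.91", "0.87"]                       # per-box score labels
    manul_box = (5, 5, 55, 55)                      # user-drawn box (name as in app.py)

    labels = list(zip(boxes, scores))
    labels.append((manul_box, "manual"))  # query box annotated separately

    # Old: cnt = len(boxes) - 1 would report 1 detection instead of 2.
    # New: the manual box never enters `boxes`, so no correction is needed.
    cnt = len(boxes)
    assert cnt == 2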