Gabolozano commited on
Commit
b842d10
·
verified ·
1 Parent(s): 4ce9fc8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -10
app.py CHANGED
@@ -10,8 +10,15 @@ config = DetrConfig.from_pretrained("facebook/detr-resnet-50")
10
  model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50", config=config)
11
  image_processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
12
 
13
- # Initialize the pipeline, adjust the confidence_threshold if possible
14
- od_pipe = pipeline(task='object-detection', model=model, image_processor=image_processor, confidence_threshold=0.5)
 
 
 
 
 
 
 
15
 
16
  def draw_detections(image, detections):
17
  # Convert PIL image to a numpy array
@@ -32,19 +39,19 @@ def draw_detections(image, detections):
32
  # Draw rectangles and text with a larger font
33
  cv2.rectangle(np_image, (x_min, y_min), (x_max, y_max), (0, 255, 0), 2)
34
  label_text = f'{label} {score:.2f}'
35
- # Increase the font size and text thickness
36
- cv2.putText(np_image, label_text, (x_min, y_min - 10), cv2.FONT_HERSHEY_SIMPLEX, 1.5, (255, 255, 255), 4)
37
 
38
  # Convert BGR to RGB for displaying
39
  final_image = cv2.cvtColor(np_image, cv2.COLOR_BGR2RGB)
40
  final_pil_image = Image.fromarray(final_image)
41
  return final_pil_image
42
 
43
- def get_pipeline_prediction(pil_image):
 
 
 
44
  try:
45
- # Ensure PIL image is passed correctly
46
- if isinstance(pil_image, np.ndarray):
47
- pil_image = Image.fromarray(pil_image.astype('uint8'), 'RGB')
48
  pipeline_output = od_pipe(pil_image)
49
  processed_image = draw_detections(pil_image, pipeline_output)
50
  return processed_image, pipeline_output
@@ -57,13 +64,14 @@ with gr.Blocks() as demo:
57
  with gr.Row():
58
  with gr.Column():
59
  inp_image = gr.Image(label="Input image")
 
 
60
  btn_run = gr.Button('Run Detection')
61
  with gr.Column():
62
  with gr.Tab("Annotated Image"):
63
  out_image = gr.Image()
64
  with gr.Tab("Detection Results"):
65
  out_json = gr.JSON()
66
-
67
- btn_run.click(get_pipeline_prediction, inputs=inp_image, outputs=[out_image, out_json])
68
 
69
  demo.launch()
 
10
  model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50", config=config)
11
  image_processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
12
 
13
+ def load_model(threshold):
14
+ # Since changing threshold at runtime for models isn't typically supported directly by the transformers pipeline,
15
+ # we reinitialize the model with the desired configuration when needed.
16
+ config = DetrConfig.from_pretrained("facebook/detr-resnet-50", num_labels=91, threshold=threshold)
17
+ model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50", config=config)
18
+ image_processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
19
+ return pipeline(task='object-detection', model=model, image_processor=image_processor)
20
+
21
+ od_pipe = load_model(0.5) # Default threshold
22
 
23
  def draw_detections(image, detections):
24
  # Convert PIL image to a numpy array
 
39
  # Draw rectangles and text with a larger font
40
  cv2.rectangle(np_image, (x_min, y_min), (x_max, y_max), (0, 255, 0), 2)
41
  label_text = f'{label} {score:.2f}'
42
+ cv2.putText(np_image, label_text, (x_min, y_min - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255, 255, 255), 2)
 
43
 
44
  # Convert BGR to RGB for displaying
45
  final_image = cv2.cvtColor(np_image, cv2.COLOR_BGR2RGB)
46
  final_pil_image = Image.fromarray(final_image)
47
  return final_pil_image
48
 
49
+ def get_pipeline_prediction(threshold, pil_image):
50
+ global od_pipe
51
+ if od_pipe.config.threshold != threshold:
52
+ od_pipe = load_model(threshold)
53
  try:
54
+ pil_image = Image.fromarray(np.array(pil_image))
 
 
55
  pipeline_output = od_pipe(pil_image)
56
  processed_image = draw_detections(pil_image, pipeline_output)
57
  return processed_image, pipeline_output
 
64
  with gr.Row():
65
  with gr.Column():
66
  inp_image = gr.Image(label="Input image")
67
+ slider = gr.Slider(minimum=0, maximum=1, step=0.05, label="Adjust Detection Sensitivity", value=0.5)
68
+ gr.Markdown("Adjust the slider to change the detection sensitivity.")
69
  btn_run = gr.Button('Run Detection')
70
  with gr.Column():
71
  with gr.Tab("Annotated Image"):
72
  out_image = gr.Image()
73
  with gr.Tab("Detection Results"):
74
  out_json = gr.JSON()
75
+ btn_run.click(get_pipeline_prediction, inputs=[slider, inp_image], outputs=[out_image, out_json])
 
76
 
77
  demo.launch()