techysanoj committed on
Commit
65f769b
·
1 Parent(s): b2f7aa2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -34
app.py CHANGED
@@ -1,34 +1,14 @@
1
  import gradio as gr
2
  import cv2
3
  import torch
4
- from torchvision import transforms
5
  from PIL import Image
 
6
 
7
- # Load the pre-trained object detection model (replace with your own model)
8
- # For example, using a torchvision model for demonstration purposes
9
- model = torch.hub.load('pytorch/vision:v0.10.0', 'fasterrcnn_resnet50_fpn', pretrained=True)
10
  model.eval()
11
 
12
- # Define the transformations for the input image
13
- transform = transforms.Compose([
14
- transforms.ToTensor(),
15
- ])
16
-
17
- # Function to perform object detection on an image
18
- def detect_objects(image):
19
- # Convert image to tensor
20
- input_tensor = transform(image).unsqueeze(0)
21
-
22
- # Perform object detection
23
- with torch.no_grad():
24
- predictions = model(input_tensor)
25
-
26
- # Extract bounding boxes and labels from predictions
27
- boxes = predictions[0]['boxes'].numpy()
28
- labels = predictions[0]['labels'].numpy()
29
-
30
- return boxes, labels
31
-
32
  # Function for live object detection from the camera
33
  def live_object_detection():
34
  # Open a connection to the camera (replace with your own camera setup)
@@ -41,14 +21,21 @@ def live_object_detection():
41
  # Convert the frame to PIL Image
42
  frame_pil = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
43
 
44
- # Perform object detection
45
- boxes, labels = detect_objects(frame_pil)
 
 
 
 
 
 
46
 
47
  # Draw bounding boxes on the frame
48
- for box, label in zip(boxes, labels):
49
- box = [int(coord) for coord in box]
50
  cv2.rectangle(frame, (box[0], box[1]), (box[2], box[3]), (0, 255, 0), 2)
51
- cv2.putText(frame, f"Label: {label}", (box[0], box[1] - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)
 
52
 
53
  # Display the resulting frame
54
  cv2.imshow('Object Detection', frame)
@@ -63,11 +50,8 @@ def live_object_detection():
63
 
64
  # Define the Gradio interface
65
  iface = gr.Interface(
66
- fn=[detect_objects, live_object_detection],
67
- inputs=[
68
- gr.Image(type="pil", label="Upload a photo for object detection"),
69
- "webcam",
70
- ],
71
  outputs="image",
72
  live=True,
73
  )
 
1
  import gradio as gr
2
  import cv2
3
  import torch
 
4
  from PIL import Image
5
+ from transformers import DetrImageProcessor, DetrForObjectDetection
6
 
7
+ # Load the pre-trained DETR model
8
+ processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50", revision="no_timm")
9
+ model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50", revision="no_timm")
10
  model.eval()
11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  # Function for live object detection from the camera
13
  def live_object_detection():
14
  # Open a connection to the camera (replace with your own camera setup)
 
21
  # Convert the frame to PIL Image
22
  frame_pil = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
23
 
24
+ # Process the frame with the DETR model
25
+ inputs = processor(images=frame_pil, return_tensors="pt")
26
+ outputs = model(**inputs)
27
+
28
+ # convert raw outputs (bounding boxes and class logits) into scored, labeled boxes in frame pixel coordinates
29
+ # let's only keep detections with score > 0.9
30
+ target_sizes = torch.tensor([frame_pil.size[::-1]])
31
+ results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.9)[0]
32
 
33
  # Draw bounding boxes on the frame
34
+ for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
35
+ box = [int(round(i)) for i in box.tolist()]
36
  cv2.rectangle(frame, (box[0], box[1]), (box[2], box[3]), (0, 255, 0), 2)
37
+ cv2.putText(frame, f"{model.config.id2label[label.item()]}: {round(score.item(), 3)}",
38
+ (box[0], box[1] - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)
39
 
40
  # Display the resulting frame
41
  cv2.imshow('Object Detection', frame)
 
50
 
51
  # Define the Gradio interface
52
  iface = gr.Interface(
53
+ fn=live_object_detection,
54
+ inputs="webcam",
 
 
 
55
  outputs="image",
56
  live=True,
57
  )