sadimanna committed
Commit 348ad13 · 1 Parent(s): 40a433b

Update app.py

Files changed (1)
  1. app.py +141 -40
app.py CHANGED
@@ -3,65 +3,166 @@ import os
 os.environ['TWILIO_ACCOUNT_SID'] = 'AC9f1e8d20a3c92c12340cf1cb543dfc45'
 os.environ['TWILIO_AUTH_TOKEN'] = '78b931e178545a8d22c33afae4c1b23c'
 
+"""Object detection demo with MobileNet SSD.
+This model and code are based on
+https://github.com/robmarkcole/object-detection-app
+"""
+
+import logging
+import queue
+from pathlib import Path
+from typing import List, NamedTuple
+
 import av
 import cv2
+import numpy as np
 import streamlit as st
 from streamlit_webrtc import WebRtcMode, webrtc_streamer
 
+from sample_utils.download import download_file
 from sample_utils.turn import get_ice_servers
 
-_type = st.radio("Select transform type", ("noop", "cartoon", "edges", "rotate"))
-
-
-def callback(frame: av.VideoFrame) -> av.VideoFrame:
-    img = frame.to_ndarray(format="bgr24")
-
-    if _type == "noop":
-        pass
-    elif _type == "cartoon":
-        # prepare color
-        img_color = cv2.pyrDown(cv2.pyrDown(img))
-        for _ in range(6):
-            img_color = cv2.bilateralFilter(img_color, 9, 9, 7)
-        img_color = cv2.pyrUp(cv2.pyrUp(img_color))
-
-        # prepare edges
-        img_edges = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
-        img_edges = cv2.adaptiveThreshold(
-            cv2.medianBlur(img_edges, 7),
-            255,
-            cv2.ADAPTIVE_THRESH_MEAN_C,
-            cv2.THRESH_BINARY,
-            9,
+HERE = Path(__file__).parent
+ROOT = HERE.parent
+
+logger = logging.getLogger(__name__)
+
+
+MODEL_URL = "https://github.com/robmarkcole/object-detection-app/raw/master/model/MobileNetSSD_deploy.caffemodel"  # noqa: E501
+MODEL_LOCAL_PATH = ROOT / "./models/MobileNetSSD_deploy.caffemodel"
+PROTOTXT_URL = "https://github.com/robmarkcole/object-detection-app/raw/master/model/MobileNetSSD_deploy.prototxt.txt"  # noqa: E501
+PROTOTXT_LOCAL_PATH = ROOT / "./models/MobileNetSSD_deploy.prototxt.txt"
+
+CLASSES = [
+    "background",
+    "aeroplane",
+    "bicycle",
+    "bird",
+    "boat",
+    "bottle",
+    "bus",
+    "car",
+    "cat",
+    "chair",
+    "cow",
+    "diningtable",
+    "dog",
+    "horse",
+    "motorbike",
+    "person",
+    "pottedplant",
+    "sheep",
+    "sofa",
+    "train",
+    "tvmonitor",
+]
+
+
+class Detection(NamedTuple):
+    class_id: int
+    label: str
+    score: float
+    box: np.ndarray
+
+
+@st.cache_resource  # type: ignore
+def generate_label_colors():
+    return np.random.uniform(0, 255, size=(len(CLASSES), 3))
+
+
+COLORS = generate_label_colors()
+
+download_file(MODEL_URL, MODEL_LOCAL_PATH, expected_size=23147564)
+download_file(PROTOTXT_URL, PROTOTXT_LOCAL_PATH, expected_size=29353)
+
+
+# Session-specific caching
+cache_key = "object_detection_dnn"
+if cache_key in st.session_state:
+    net = st.session_state[cache_key]
+else:
+    net = cv2.dnn.readNetFromCaffe(str(PROTOTXT_LOCAL_PATH), str(MODEL_LOCAL_PATH))
+    st.session_state[cache_key] = net
+
+score_threshold = st.slider("Score threshold", 0.0, 1.0, 0.5, 0.05)
+
+# NOTE: The callback will be called in another thread,
+# so use a queue here for thread-safety to pass the data
+# from inside to outside the callback.
+# TODO: A general-purpose shared state object may be more useful.
+result_queue: "queue.Queue[List[Detection]]" = queue.Queue()
+
+
+def video_frame_callback(frame: av.VideoFrame) -> av.VideoFrame:
+    image = frame.to_ndarray(format="bgr24")
+
+    # Run inference
+    blob = cv2.dnn.blobFromImage(
+        cv2.resize(image, (300, 300)), 0.007843, (300, 300), 127.5
+    )
+    net.setInput(blob)
+    output = net.forward()
+
+    h, w = image.shape[:2]
+
+    # Convert the output array into a structured form.
+    output = output.squeeze()  # (1, 1, N, 7) -> (N, 7)
+    output = output[output[:, 2] >= score_threshold]
+    detections = [
+        Detection(
+            class_id=int(detection[1]),
+            label=CLASSES[int(detection[1])],
+            score=float(detection[2]),
+            box=(detection[3:7] * np.array([w, h, w, h])),
+        )
+        for detection in output
+    ]
+
+    # Render bounding boxes and captions
+    for detection in detections:
+        caption = f"{detection.label}: {round(detection.score * 100, 2)}%"
+        color = COLORS[detection.class_id]
+        xmin, ymin, xmax, ymax = detection.box.astype("int")
+
+        cv2.rectangle(image, (xmin, ymin), (xmax, ymax), color, 2)
+        cv2.putText(
+            image,
+            caption,
+            (xmin, ymin - 15 if ymin - 15 > 15 else ymin + 15),
+            cv2.FONT_HERSHEY_SIMPLEX,
+            0.5,
+            color,
             2,
         )
-        img_edges = cv2.cvtColor(img_edges, cv2.COLOR_GRAY2RGB)
 
-        # combine color and edges
-        img = cv2.bitwise_and(img_color, img_edges)
-    elif _type == "edges":
-        # perform edge detection
-        img = cv2.cvtColor(cv2.Canny(img, 100, 200), cv2.COLOR_GRAY2BGR)
-    elif _type == "rotate":
-        # rotate image
-        rows, cols, _ = img.shape
-        M = cv2.getRotationMatrix2D((cols / 2, rows / 2), frame.time * 45, 1)
-        img = cv2.warpAffine(img, M, (cols, rows))
+    result_queue.put(detections)
 
-    return av.VideoFrame.from_ndarray(img, format="bgr24")
+    return av.VideoFrame.from_ndarray(image, format="bgr24")
 
 
-webrtc_streamer(
-    key="opencv-filter",
+webrtc_ctx = webrtc_streamer(
+    key="object-detection",
     mode=WebRtcMode.SENDRECV,
     rtc_configuration={"iceServers": get_ice_servers()},
-    video_frame_callback=callback,
+    video_frame_callback=video_frame_callback,
     media_stream_constraints={"video": True, "audio": False},
    async_processing=True,
 )
 
+if st.checkbox("Show the detected labels", value=True):
+    if webrtc_ctx.state.playing:
+        labels_placeholder = st.empty()
+        # NOTE: The video transformation with object detection and
+        # this loop displaying the result labels are running
+        # in different threads asynchronously.
+        # Then the rendered video frames and the labels displayed here
+        # are not strictly synchronized.
+        while True:
+            result = result_queue.get()
+            labels_placeholder.table(result)
+
 st.markdown(
-    "This demo is based on "
-    "https://github.com/aiortc/aiortc/blob/2362e6d1f0c730a0f8c387bbea76546775ad2fe8/examples/server/server.py#L34. "  # noqa: E501
+    "This demo uses a model and code from "
+    "https://github.com/robmarkcole/object-detection-app. "
    "Many thanks to the project."
 )
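
Note: the new app.py imports download_file from sample_utils.download, which is not part of this diff. Below is a minimal sketch of what such a helper might look like, assuming it only needs to fetch each file once and check its size; the body is an illustration, not the repository's actual sample_utils/download.py.

from pathlib import Path
from typing import Optional
import urllib.request


def download_file(url: str, download_to: Path, expected_size: Optional[int] = None) -> None:
    # Skip the download if a file of the expected size already exists locally.
    if download_to.exists() and (
        expected_size is None or download_to.stat().st_size == expected_size
    ):
        return
    download_to.parent.mkdir(parents=True, exist_ok=True)
    # Fetch the model weights / prototxt and write them to disk.
    urllib.request.urlretrieve(url, download_to)
    if expected_size is not None and download_to.stat().st_size != expected_size:
        raise RuntimeError(f"Size mismatch for {url}: got {download_to.stat().st_size} bytes")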
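A note on the inference constants used in video_frame_callback: cv2.dnn.blobFromImage applies scalefactor * (pixel - mean), so with scalefactor 0.007843 ≈ 1/127.5 and mean 127.5 each pixel is mapped to roughly the [-1, 1] range this Caffe MobileNet SSD expects. Each row of the squeezed (N, 7) output is laid out as [image_id, class_id, score, x_min, y_min, x_max, y_max] with coordinates normalized to [0, 1], which is why the callback multiplies columns 3:7 by [w, h, w, h]. A small standalone check of that arithmetic (the sample row is illustrative, not from a real run):

import numpy as np

# Same arithmetic blobFromImage applies: scalefactor * (pixel - mean)
scale, mean = 0.007843, 127.5
print(scale * (np.array([0.0, 127.5, 255.0]) - mean))  # -> approximately [-1.  0.  1.]

# Illustrative SSD output row: [image_id, class_id, score, x_min, y_min, x_max, y_max]
row = np.array([0.0, 15.0, 0.92, 0.10, 0.20, 0.60, 0.90])  # class_id 15 -> "person" in CLASSES
w, h = 640, 480
print(row[3:7] * np.array([w, h, w, h]))  # -> [ 64.  96. 384. 432.]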