sadimanna committed
Commit ed8f39e · 1 Parent(s): 4fb7cec

updated app.py

app.py ADDED
@@ -0,0 +1,303 @@
+ """
+ Emotion Detection:
+ Model from: https://github.com/onnx/models/blob/main/vision/body_analysis/emotion_ferplus/model/emotion-ferplus-8.onnx
+ Model name: emotion-ferplus-8.onnx
+ Face detection: RFB-320 ultra-light face detector (Caffe), downloaded below.
+ """
+
+ import cv2
+ import numpy as np
+ import time
+ import os
+
+ from cv2 import dnn
+ from math import ceil
+
+ import logging
+ import queue
+ from pathlib import Path
+ from typing import List, NamedTuple
+
+ import av
+ import streamlit as st
+ from streamlit_webrtc import WebRtcMode, webrtc_streamer
+
+ from sample_utils.download import download_file
+ from sample_utils.turn import get_ice_servers
+
+ HERE = Path(__file__).parent
+ ROOT = HERE.parent
+
+ logger = logging.getLogger(__name__)
+
+ ONNX_MODEL_URL = "https://github.com/spmallick/learnopencv/raw/master/Face-Emotion-Recognition/emotion-ferplus-8.onnx"  # noqa: E501
+ ONNX_MODEL_LOCAL_PATH = ROOT / "./emotion-ferplus-8.onnx"
+ CAFFE_MODEL_URL = "https://github.com/spmallick/learnopencv/raw/master/Face-Emotion-Recognition/RFB-320/RFB-320.caffemodel"  # noqa: E501
+ CAFFE_MODEL_LOCAL_PATH = ROOT / "./RFB-320/RFB-320.caffemodel"
+ PROTOTXT_URL = "https://github.com/spmallick/learnopencv/raw/master/Facial-Emotion-Recognition/RFB-320/RFB-320.prototxt"  # noqa: E501
+ PROTOTXT_LOCAL_PATH = ROOT / "./RFB-320/RFB-320.prototxt.txt"
+
+ download_file(ONNX_MODEL_URL, ONNX_MODEL_LOCAL_PATH, expected_size=None)
+ download_file(CAFFE_MODEL_URL, CAFFE_MODEL_LOCAL_PATH, expected_size=None)
+ download_file(PROTOTXT_URL, PROTOTXT_LOCAL_PATH, expected_size=None)
+
+ # Session-specific caching
+ onnx_cache_key = "onnx_model"
+ caffe_cache_key = "caffe_model"
+
+ if onnx_cache_key in st.session_state and caffe_cache_key in st.session_state:
+     model = st.session_state[onnx_cache_key]
+     net = st.session_state[caffe_cache_key]
+ else:
+     model = cv2.dnn.readNetFromONNX(str(ONNX_MODEL_LOCAL_PATH))
+     net = cv2.dnn.readNetFromCaffe(str(PROTOTXT_LOCAL_PATH), str(CAFFE_MODEL_LOCAL_PATH))
+     st.session_state[onnx_cache_key] = model
+     st.session_state[caffe_cache_key] = net
+
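+ # Hyper-parameters of the RFB-320 SSD-style face detector: per-branch anchor
+ # sizes (min_boxes), feature-map strides, and the box-decoding variances.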
+ image_mean = np.array([127, 127, 127])
+ image_std = 128.0
+ iou_threshold = 0.3
+ center_variance = 0.1
+ size_variance = 0.2
+ min_boxes = [
+     [10.0, 16.0, 24.0],
+     [32.0, 48.0],
+     [64.0, 96.0],
+     [128.0, 192.0, 256.0]
+ ]
+ strides = [8.0, 16.0, 32.0, 64.0]
+ threshold = 0.5
+
+ emotion_dict = {
+     0: 'neutral',
+     1: 'happiness',
+     2: 'surprise',
+     3: 'sadness',
+     4: 'anger',
+     5: 'disgust',
+     6: 'fear'
+ }
+
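+ # Compute the feature-map size for each stride and build the prior (anchor) boxes.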
+ def define_img_size(image_size):
+     shrinkage_list = []
+     feature_map_w_h_list = []
+     for size in image_size:
+         feature_map = [int(ceil(size / stride)) for stride in strides]
+         feature_map_w_h_list.append(feature_map)
+
+     for i in range(0, len(image_size)):
+         shrinkage_list.append(strides)
+     priors = generate_priors(
+         feature_map_w_h_list, shrinkage_list, image_size, min_boxes
+     )
+     return priors
+
+
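+ # Generate center-form priors [x_center, y_center, w, h], normalized to [0, 1].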
+ def generate_priors(
+     feature_map_list, shrinkage_list, image_size, min_boxes
+ ):
+     priors = []
+     for index in range(0, len(feature_map_list[0])):
+         scale_w = image_size[0] / shrinkage_list[0][index]
+         scale_h = image_size[1] / shrinkage_list[1][index]
+         for j in range(0, feature_map_list[1][index]):
+             for i in range(0, feature_map_list[0][index]):
+                 x_center = (i + 0.5) / scale_w
+                 y_center = (j + 0.5) / scale_h
+
+                 for min_box in min_boxes[index]:
+                     w = min_box / image_size[0]
+                     h = min_box / image_size[1]
+                     priors.append([
+                         x_center,
+                         y_center,
+                         w,
+                         h
+                     ])
+     print("priors nums:{}".format(len(priors)))
+     return np.clip(priors, 0.0, 1.0)
+
+
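+ # Greedy non-maximum suppression: repeatedly keep the highest-scoring box and
+ # drop the remaining boxes that overlap it by more than iou_threshold.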
+ def hard_nms(box_scores, iou_threshold, top_k=-1, candidate_size=200):
+     scores = box_scores[:, -1]
+     boxes = box_scores[:, :-1]
+     picked = []
+     indexes = np.argsort(scores)
+     indexes = indexes[-candidate_size:]
+     while len(indexes) > 0:
+         current = indexes[-1]
+         picked.append(current)
+         if 0 < top_k == len(picked) or len(indexes) == 1:
+             break
+         current_box = boxes[current, :]
+         indexes = indexes[:-1]
+         rest_boxes = boxes[indexes, :]
+         iou = iou_of(
+             rest_boxes,
+             np.expand_dims(current_box, axis=0),
+         )
+         indexes = indexes[iou <= iou_threshold]
+     return box_scores[picked, :]
+
+
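+ # Box-area and intersection-over-union helpers used by hard_nms.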
+ def area_of(left_top, right_bottom):
+     hw = np.clip(right_bottom - left_top, 0.0, None)
+     return hw[..., 0] * hw[..., 1]
+
+
+ def iou_of(boxes0, boxes1, eps=1e-5):
+     overlap_left_top = np.maximum(boxes0[..., :2], boxes1[..., :2])
+     overlap_right_bottom = np.minimum(boxes0[..., 2:], boxes1[..., 2:])
+
+     overlap_area = area_of(overlap_left_top, overlap_right_bottom)
+     area0 = area_of(boxes0[..., :2], boxes0[..., 2:])
+     area1 = area_of(boxes1[..., :2], boxes1[..., 2:])
+     return overlap_area / (area0 + area1 - overlap_area + eps)
+
+
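+ # Filter detections by class probability, run NMS per class, and scale the
+ # surviving boxes from normalized coordinates to image pixels.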
+ def predict(
+     width,
+     height,
+     confidences,
+     boxes,
+     prob_threshold,
+     iou_threshold=0.3,
+     top_k=-1
+ ):
+     boxes = boxes[0]
+     confidences = confidences[0]
+     picked_box_probs = []
+     picked_labels = []
+     for class_index in range(1, confidences.shape[1]):
+         probs = confidences[:, class_index]
+         mask = probs > prob_threshold
+         probs = probs[mask]
+         if probs.shape[0] == 0:
+             continue
+         subset_boxes = boxes[mask, :]
+         box_probs = np.concatenate(
+             [subset_boxes, probs.reshape(-1, 1)], axis=1
+         )
+         box_probs = hard_nms(box_probs,
+                              iou_threshold=iou_threshold,
+                              top_k=top_k,
+                              )
+         picked_box_probs.append(box_probs)
+         picked_labels.extend([class_index] * box_probs.shape[0])
+     if not picked_box_probs:
+         return np.array([]), np.array([]), np.array([])
+     picked_box_probs = np.concatenate(picked_box_probs)
+     picked_box_probs[:, 0] *= width
+     picked_box_probs[:, 1] *= height
+     picked_box_probs[:, 2] *= width
+     picked_box_probs[:, 3] *= height
+     return (
+         picked_box_probs[:, :4].astype(np.int32),
+         np.array(picked_labels),
+         picked_box_probs[:, 4]
+     )
+
+
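+ # Decode SSD regression outputs into center-form boxes relative to the priors.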
+ def convert_locations_to_boxes(locations, priors, center_variance,
+                                size_variance):
+     if len(priors.shape) + 1 == len(locations.shape):
+         priors = np.expand_dims(priors, 0)
+     return np.concatenate([
+         locations[..., :2] * center_variance * priors[..., 2:] + priors[..., :2],
+         np.exp(locations[..., 2:] * size_variance) * priors[..., 2:]
+     ], axis=len(locations.shape) - 1)
+
+
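+ # Convert center-form boxes [cx, cy, w, h] to corner-form [x1, y1, x2, y2].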
+ def center_form_to_corner_form(locations):
+     return np.concatenate(
+         [locations[..., :2] - locations[..., 2:] / 2,
+          locations[..., :2] + locations[..., 2:] / 2],
+         len(locations.shape) - 1
+     )
+
+
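+ # Per-frame pipeline: detect faces with the RFB-320 network, crop each face,
+ # classify its emotion with the FER+ ONNX model, and draw boxes and labels.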
+ def video_frame_callback(frame: av.VideoFrame) -> av.VideoFrame:
+
+     frame = frame.to_ndarray(format="bgr24")
+
+     input_size = [320, 240]
+     width = input_size[0]
+     height = input_size[1]
+     priors = define_img_size(input_size)
+
+     img_ori = frame
+     # print("frame size: ", frame.shape)
+     rect = cv2.resize(img_ori, (width, height))
+     rect = cv2.cvtColor(rect, cv2.COLOR_BGR2RGB)
+     net.setInput(dnn.blobFromImage(
+         rect, 1 / image_std, (width, height), 127)
+     )
+     start_time = time.time()
+     boxes, scores = net.forward(["boxes", "scores"])
+     boxes = np.expand_dims(np.reshape(boxes, (-1, 4)), axis=0)
+     scores = np.expand_dims(np.reshape(scores, (-1, 2)), axis=0)
+     boxes = convert_locations_to_boxes(
+         boxes, priors, center_variance, size_variance
+     )
+     boxes = center_form_to_corner_form(boxes)
+     boxes, labels, probs = predict(
+         img_ori.shape[1],
+         img_ori.shape[0],
+         scores,
+         boxes,
+         threshold
+     )
+     gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
+     for (x1, y1, x2, y2) in boxes:
+         w = x2 - x1
+         h = y2 - y1
+         cv2.rectangle(frame, (x1, y1), (x2, y2), (255, 0, 0), 2)
+         resize_frame = cv2.resize(
+             gray[y1:y1 + h, x1:x1 + w], (64, 64)
+         )
+         resize_frame = resize_frame.reshape(1, 1, 64, 64)
+         model.setInput(resize_frame)
+         output = model.forward()
+         end_time = time.time()
+         fps = 1 / (end_time - start_time)
+         print(f"FPS: {fps:.1f}")
+         pred = emotion_dict[list(output[0]).index(max(output[0]))]
+         cv2.rectangle(
+             img_ori,
+             (x1, y1),
+             (x2, y2),
+             (215, 5, 247),
+             2,
+             lineType=cv2.LINE_AA
+         )
+         cv2.putText(
+             frame,
+             pred,
+             (x1, y1 - 10),
+             cv2.FONT_HERSHEY_SIMPLEX,
+             0.8,
+             (215, 5, 247),
+             2,
+             lineType=cv2.LINE_AA
+         )
+
+     return av.VideoFrame.from_ndarray(frame, format="bgr24")
+
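+ # Start the streamlit-webrtc streamer; frames are processed by video_frame_callback.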
+ if __name__ == "__main__":
+     webrtc_ctx = webrtc_streamer(
+         key="object-detection",
+         mode=WebRtcMode.SENDRECV,
+         rtc_configuration={
+             "iceServers": get_ice_servers(),
+             "iceTransportPolicy": "relay",
+         },
+         video_frame_callback=video_frame_callback,
+         media_stream_constraints={"video": True, "audio": False},
+         async_processing=True,
+     )
+
+
+ st.markdown(
+     "This demo uses a model and code from "
+     "https://github.com/spmallick/learnopencv. "
+     "Many thanks to the project."
+ )
sample_utils/__init__.py ADDED
File without changes
sample_utils/download.py ADDED
@@ -0,0 +1,50 @@
+ import urllib.request
+ from pathlib import Path
+
+ import streamlit as st
+
+
+ # This code is based on https://github.com/streamlit/demo-self-driving/blob/230245391f2dda0cb464008195a470751c01770b/streamlit_app.py#L48  # noqa: E501
+ def download_file(url, download_to: Path, expected_size=None):
+     # Don't download the file twice.
+     # (If possible, verify the download using the file length.)
+     if download_to.exists():
+         if expected_size:
+             if download_to.stat().st_size == expected_size:
+                 return
+         else:
+             st.info(f"{url} is already downloaded.")
+             # if not st.button("Download again?"):
+             return
+
+     download_to.parent.mkdir(parents=True, exist_ok=True)
+
+     # These are handles to two visual elements to animate.
+     weights_warning, progress_bar = None, None
+     try:
+         weights_warning = st.warning("Downloading %s..." % url)
+         progress_bar = st.progress(0)
+         with open(download_to, "wb") as output_file:
+             with urllib.request.urlopen(url) as response:
+                 length = int(response.info()["Content-Length"])
+                 counter = 0.0
+                 MEGABYTES = 2.0 ** 20.0
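+                 # Stream the download in 8 KB chunks, updating the progress UI as we go.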
+                 while True:
+                     data = response.read(8192)
+                     if not data:
+                         break
+                     counter += len(data)
+                     output_file.write(data)
+
+                     # We perform animation by overwriting the elements.
+                     weights_warning.warning(
+                         "Downloading %s... (%6.2f/%6.2f MB)"
+                         % (url, counter / MEGABYTES, length / MEGABYTES)
+                     )
+                     progress_bar.progress(min(counter / length, 1.0))
+     # Finally, we remove these visual elements by calling .empty().
+     finally:
+         if weights_warning is not None:
+             weights_warning.empty()
+         if progress_bar is not None:
+             progress_bar.empty()
sample_utils/turn.py ADDED
@@ -0,0 +1,39 @@
+ import logging
+ import os
+
+ import streamlit as st
+ from twilio.base.exceptions import TwilioRestException
+ from twilio.rest import Client
+
+ logger = logging.getLogger(__name__)
+
+
+ def get_ice_servers():
+     """Use Twilio's TURN server because Streamlit Community Cloud has changed
+     its infrastructure and a WebRTC connection can no longer be established without a TURN server.  # noqa: E501
+     We also considered the Open Relay Project (https://www.metered.ca/tools/openrelay/),
+     but it is unstable and often fails, as reported in https://github.com/aiortc/aiortc/issues/832#issuecomment-1482420656  # noqa: E501
+     See https://github.com/whitphx/streamlit-webrtc/issues/1213
+     """
+
+     # Ref: https://www.twilio.com/docs/stun-turn/api
+     try:
+         account_sid = os.environ["TWILIO_ACCOUNT_SID"]
+         auth_token = os.environ["TWILIO_AUTH_TOKEN"]
+     except KeyError:
+         logger.warning(
+             "Twilio credentials are not set. Falling back to a free STUN server from Google."  # noqa: E501
+         )
+         return [{"urls": ["stun:stun.l.google.com:19302"]}]
+
+     client = Client(account_sid, auth_token)
+
+     try:
+         token = client.tokens.create()
+     except TwilioRestException as e:
+         st.warning(
+             f"Error occurred while accessing the Twilio API. Falling back to a free STUN server from Google. ({e})"  # noqa: E501
+         )
+         return [{"urls": ["stun:stun.l.google.com:19302"]}]
+
+     return token.ice_servers