sadimanna committed on
Commit
ea7464f
·
1 Parent(s): 823a42e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -151
app.py CHANGED
@@ -1,166 +1,36 @@
1
- """Object detection demo with MobileNet SSD.
2
- This model and code are based on
3
- https://github.com/robmarkcole/object-detection-app
4
- """
5
 
6
  import logging
7
  import queue
8
- from pathlib import Path
9
- from typing import List, NamedTuple
10
 
11
- import av
12
- import cv2
13
- import numpy as np
14
  import streamlit as st
15
  from streamlit_webrtc import WebRtcMode, webrtc_streamer
16
 
17
- from sample_utils.download import download_file
18
  from sample_utils.turn import get_ice_servers
19
 
20
- HERE = Path(__file__).parent
21
- ROOT = HERE
22
-
23
  logger = logging.getLogger(__name__)
24
 
25
 
26
- MODEL_URL = "https://github.com/robmarkcole/object-detection-app/raw/master/model/MobileNetSSD_deploy.caffemodel" # noqa: E501
27
- MODEL_LOCAL_PATH = ROOT / "./models/MobileNetSSD_deploy.caffemodel"
28
- PROTOTXT_URL = "https://github.com/robmarkcole/object-detection-app/raw/master/model/MobileNetSSD_deploy.prototxt.txt" # noqa: E501
29
- PROTOTXT_LOCAL_PATH = ROOT / "./models/MobileNetSSD_deploy.prototxt.txt"
30
-
31
- CLASSES = [
32
- "background",
33
- "aeroplane",
34
- "bicycle",
35
- "bird",
36
- "boat",
37
- "bottle",
38
- "bus",
39
- "car",
40
- "cat",
41
- "chair",
42
- "cow",
43
- "diningtable",
44
- "dog",
45
- "horse",
46
- "motorbike",
47
- "person",
48
- "pottedplant",
49
- "sheep",
50
- "sofa",
51
- "train",
52
- "tvmonitor",
53
- ]
54
-
55
-
56
- class Detection(NamedTuple):
57
- class_id: int
58
- label: str
59
- score: float
60
- box: np.ndarray
61
-
62
-
63
- @st.cache_resource # type: ignore
64
- def generate_label_colors():
65
- return np.random.uniform(0, 255, size=(len(CLASSES), 3))
66
-
67
-
68
- COLORS = generate_label_colors()
69
-
70
- download_file(MODEL_URL, MODEL_LOCAL_PATH, expected_size=23147564)
71
- download_file(PROTOTXT_URL, PROTOTXT_LOCAL_PATH, expected_size=29353)
72
-
73
-
74
- # Session-specific caching
75
- cache_key = "object_detection_dnn"
76
- if cache_key in st.session_state:
77
- net = st.session_state[cache_key]
78
- else:
79
- net = cv2.dnn.readNetFromCaffe(str(PROTOTXT_LOCAL_PATH), str(MODEL_LOCAL_PATH))
80
- st.session_state[cache_key] = net
81
-
82
- score_threshold = st.slider("Score threshold", 0.0, 1.0, 0.5, 0.05)
83
-
84
- # NOTE: The callback will be called in another thread,
85
- # so use a queue here for thread-safety to pass the data
86
- # from inside to outside the callback.
87
- # TODO: A general-purpose shared state object may be more useful.
88
- result_queue: "queue.Queue[List[Detection]]" = queue.Queue()
89
-
90
-
91
- def video_frame_callback(frame: av.VideoFrame) -> av.VideoFrame:
92
- image = frame.to_ndarray(format="bgr24")
93
-
94
- # Run inference
95
- # blob = cv2.dnn.blobFromImage(
96
- # cv2.resize(image, (128, 128)), 0.007843, (128, 128), 127.5
97
- # )
98
- # net.setInput(blob)
99
- # output = net.forward()
100
-
101
- # h, w = image.shape[:2]
102
-
103
- # # Convert the output array into a structured form.
104
- # output = output.squeeze() # (1, 1, N, 7) -> (N, 7)
105
- # output = output[output[:, 2] >= score_threshold]
106
- # detections = [
107
- # Detection(
108
- # class_id=int(detection[1]),
109
- # label=CLASSES[int(detection[1])],
110
- # score=float(detection[2]),
111
- # box=(detection[3:7] * np.array([w, h, w, h])),
112
- # )
113
- # for detection in output
114
- # ]
115
-
116
- # # Render bounding boxes and captions
117
- # for detection in detections:
118
- # caption = f"{detection.label}: {round(detection.score * 100, 2)}%"
119
- # color = COLORS[detection.class_id]
120
- # xmin, ymin, xmax, ymax = detection.box.astype("int")
121
-
122
- # cv2.rectangle(image, (xmin, ymin), (xmax, ymax), color, 2)
123
- # cv2.putText(
124
- # image,
125
- # caption,
126
- # (xmin, ymin - 15 if ymin - 15 > 15 else ymin + 15),
127
- # cv2.FONT_HERSHEY_SIMPLEX,
128
- # 0.5,
129
- # color,
130
- # 2,
131
- # )
132
-
133
- result_queue.put([]) #detections)
134
-
135
- return av.VideoFrame.from_ndarray(image, format="bgr24")
136
-
137
-
138
  webrtc_ctx = webrtc_streamer(
139
- key="object-detection",
140
- mode=WebRtcMode.SENDRECV,
141
- rtc_configuration={
142
- "iceServers": get_ice_servers(),
143
- "iceTransportPolicy": "relay",
144
- },
145
- video_frame_callback=video_frame_callback,
146
- media_stream_constraints={"video": True, "audio": False},
147
- async_processing=True,
148
  )
149
 
150
- if st.checkbox("Show the detected labels", value=True):
151
- if webrtc_ctx.state.playing:
152
- labels_placeholder = st.empty()
153
- # NOTE: The video transformation with object detection and
154
- # this loop displaying the result labels are running
155
- # in different threads asynchronously.
156
- # Then the rendered video frames and the labels displayed here
157
- # are not strictly synchronized.
158
- while True:
159
- result = result_queue.get()
160
- labels_placeholder.table(result)
161
-
162
- st.markdown(
163
- "This demo uses a model and code from "
164
- "https://github.com/robmarkcole/object-detection-app. "
165
- "Many thanks to the project."
166
- )
 
1
+ """A sample to use WebRTC in sendonly mode to transfer frames
2
+ from the browser to the server and to render frames via `st.image`."""
 
 
3
 
4
  import logging
5
  import queue
 
 
6
 
 
 
 
7
  import streamlit as st
8
  from streamlit_webrtc import WebRtcMode, webrtc_streamer
9
 
 
10
  from sample_utils.turn import get_ice_servers
11
 
 
 
 
12
  logger = logging.getLogger(__name__)
13
 
14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  webrtc_ctx = webrtc_streamer(
16
+ key="video-sendonly",
17
+ mode=WebRtcMode.SENDONLY,
18
+ rtc_configuration={"iceServers": get_ice_servers()},
19
+ media_stream_constraints={"video": True},
 
 
 
 
 
20
  )
21
 
22
+ image_place = st.empty()
23
+
24
+ while True:
25
+ if webrtc_ctx.video_receiver:
26
+ try:
27
+ video_frame = webrtc_ctx.video_receiver.get_frame(timeout=1)
28
+ except queue.Empty:
29
+ logger.warning("Queue is empty. Abort.")
30
+ break
31
+
32
+ img_rgb = video_frame.to_ndarray(format="rgb24")
33
+ image_place.image(img_rgb)
34
+ else:
35
+ logger.warning("AudioReciver is not set. Abort.")
36
+ break