codelion committed
Commit 5482ab4 · verified · 1 Parent(s): c62b2e8

Create app.py

Files changed (1): app.py +212 -0
app.py ADDED
@@ -0,0 +1,212 @@
import streamlit as st
import cv2
import torch
from PIL import Image
from transformers import BlipProcessor, BlipForConditionalGeneration
from transformers import ViltProcessor, ViltForQuestionAnswering
import time
import tempfile
import threading
import queue
from datetime import datetime

# Set page config to wide mode
st.set_page_config(layout="wide", page_title="Securade.ai Sentinel")

def initialize_state():
    if 'initialized' not in st.session_state:
        st.session_state.frame = None
        st.session_state.captions = []
        st.session_state.stop_event = threading.Event()
        # Start in the "stopped" state so the first button press starts surveillance
        st.session_state.stop_event.set()
        st.session_state.frame_queue = queue.Queue(maxsize=1)
        st.session_state.caption_queue = queue.Queue(maxsize=10)
        st.session_state.processor = None
        st.session_state.thread = None
        st.session_state.initialized = True

@st.cache_resource
def load_processor():
    class VideoProcessor:
        def __init__(self):
            # BLIP handles image captioning, ViLT handles visual question answering
            self.caption_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
            self.caption_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large")
            self.vqa_processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
            self.vqa_model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")

            # Check for available devices
            if torch.cuda.is_available():
                self.device = "cuda"
            elif torch.backends.mps.is_available():
                self.device = "mps"
            else:
                self.device = "cpu"

            self.caption_model.to(self.device)
            self.vqa_model.to(self.device)

        def generate_caption(self, image):
            inputs = self.caption_processor(images=image, return_tensors="pt").to(self.device)
            output = self.caption_model.generate(**inputs, max_new_tokens=50)
            return self.caption_processor.decode(output[0], skip_special_tokens=True)

        def answer_question(self, image, question):
            inputs = self.vqa_processor(image, question, return_tensors="pt").to(self.device)
            with torch.no_grad():
                outputs = self.vqa_model(**inputs)
            # The predicted answer is the label with the highest logit
            idx = outputs.logits.argmax(-1).item()
            return self.vqa_model.config.id2label[idx]

    return VideoProcessor()

def get_video_source(source_type, source_path=None):
    if source_type == "Webcam":
        return cv2.VideoCapture(0)
    elif source_type in ("Video File", "RTSP Stream"):
        # Both cases open from a path/URL; guard against a missing source
        return cv2.VideoCapture(source_path) if source_path else None
    return None

def process_video(stop_event, frame_queue, caption_queue, processor, source_type, source_path=None):
    cap = get_video_source(source_type, source_path)
    if cap is None or not cap.isOpened():
        return
    last_caption_time = time.time()

    while not stop_event.is_set():
        ret, frame = cap.read()
        if not ret:
            break

        frame = cv2.resize(frame, (800, 600))
        current_time = time.time()

        # Generate a caption every 3 seconds
        if current_time - last_caption_time >= 3.0:
            img = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            caption = processor.generate_caption(img)
            timestamp = datetime.now().strftime("%H:%M:%S")

            try:
                # Drop the oldest caption if the queue is full
                if caption_queue.full():
                    caption_queue.get_nowait()
                caption_queue.put_nowait({'timestamp': timestamp, 'caption': caption})
                last_caption_time = current_time
            except queue.Full:
                pass

        try:
            # Keep only the most recent frame in the queue
            if frame_queue.full():
                frame_queue.get_nowait()
            frame_queue.put_nowait(frame)
        except queue.Full:
            pass

        time.sleep(0.03)

    cap.release()

def main():
    initialize_state()

    # Main title
    st.title("Securade.ai Sentinel")

    # Create three columns for layout
    video_col, caption_col, qa_col = st.columns([0.4, 0.3, 0.3])

    # Video column
    with video_col:
        st.subheader("Video Feed")

        # Video source selection
        source_type = st.selectbox(
            "Select Video Source",
            ["Webcam", "Video File", "RTSP Stream"]
        )

        source_path = None
        if source_type == "Video File":
            source_file = st.file_uploader("Choose a video file", type=['mp4', 'avi', 'mov'])
            if source_file:
                # cv2.VideoCapture needs a filesystem path, so persist the upload to a temp file
                temp_file = tempfile.NamedTemporaryFile(delete=False, suffix="." + source_file.name.split(".")[-1])
                temp_file.write(source_file.read())
                temp_file.close()
                source_path = temp_file.name
        elif source_type == "RTSP Stream":
            source_path = st.text_input("Enter RTSP URL", placeholder="rtsp://your-camera-url")

        start_stop = st.button("Start/Stop Surveillance")
        video_placeholder = st.empty()

        if start_stop:
            if st.session_state.stop_event.is_set():
                # Start surveillance
                if st.session_state.processor is None:
                    st.session_state.processor = load_processor()
                st.session_state.stop_event.clear()
                st.session_state.thread = threading.Thread(
                    target=process_video,
                    args=(
                        st.session_state.stop_event,
                        st.session_state.frame_queue,
                        st.session_state.caption_queue,
                        st.session_state.processor,
                        source_type,
                        source_path
                    ),
                    daemon=True
                )
                st.session_state.thread.start()
            else:
                # Stop surveillance
                st.session_state.stop_event.set()
                if st.session_state.thread:
                    st.session_state.thread.join(timeout=1.0)
                st.session_state.frame = None
                video_placeholder.empty()

    # Caption column
    with caption_col:
        st.subheader("Scene Analysis")
        caption_placeholder = st.empty()

    # Q&A column
    with qa_col:
        st.subheader("Visual Q&A")
        question = st.text_input("Ask a question about the scene:")
        ask_button = st.button("Ask")
        answer_placeholder = st.empty()

        if ask_button and question and st.session_state.frame is not None:
            img = Image.fromarray(cv2.cvtColor(st.session_state.frame, cv2.COLOR_BGR2RGB))
            answer = st.session_state.processor.answer_question(img, question)
            answer_placeholder.markdown(f"**Answer:** {answer}")

    # Update loop: pull frames and captions from the worker thread until stopped
    while not st.session_state.stop_event.is_set():
        try:
            # Update video frame
            frame = st.session_state.frame_queue.get_nowait()
            st.session_state.frame = frame
            video_placeholder.image(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))

            # Update captions, keeping only the last 5
            while not st.session_state.caption_queue.empty():
                new_caption = st.session_state.caption_queue.get_nowait()
                st.session_state.captions.append(new_caption)
                st.session_state.captions = st.session_state.captions[-5:]

            if st.session_state.captions:
                caption_text = "\n\n".join([
                    f"**[{cap['timestamp']}]** {cap['caption']}"
                    for cap in reversed(st.session_state.captions)
                ])
                caption_placeholder.markdown(caption_text)

        except queue.Empty:
            time.sleep(0.01)

if __name__ == "__main__":
    main()
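
For reference, the app is launched with streamlit run app.py. As a quick sanity check of the captioning path outside Streamlit, a minimal sketch, assuming the BLIP checkpoint can be downloaded and a local test image exists (frame.jpg is a placeholder name):

# Hypothetical smoke test -- not part of app.py; the image path is a placeholder
from PIL import Image
from transformers import BlipProcessor, BlipForConditionalGeneration

processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large")
image = Image.open("frame.jpg")  # placeholder path
inputs = processor(images=image, return_tensors="pt")
out = model.generate(**inputs, max_new_tokens=50)
print(processor.decode(out[0], skip_special_tokens=True))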