tstone87 committed · verified
Commit af04f31 · Parent: cb79d6c

Update app.py

Files changed (1):
  1. app.py (+37 -23)

app.py CHANGED
@@ -5,7 +5,7 @@ import streamlit as st
 import PIL
 from ultralytics import YOLO
 
-# Ensure your model path points directly to the .pt file (not an HTML page)
+# Ensure your model path points directly to the .pt file
 model_path = 'https://huggingface.co/spaces/tstone87/ccr-colorado/resolve/main/best.pt'
 
 st.set_page_config(
@@ -29,7 +29,7 @@ with st.sidebar:
     progress_bar = st.progress(0)
 
 # --- MAIN PAGE TITLE AND IMAGES ---
-st.title("WildfireWatch: Detecting Wildfire using AI")
+st.title("Fire Watch: Detecting fire using AI vision models")
 col1, col2 = st.columns(2)
 with col1:
     st.image("https://huggingface.co/spaces/tstone87/ccr-colorado/resolve/main/Fire_1.jpeg", use_column_width=True)
@@ -66,19 +66,29 @@ except Exception as ex:
     st.error(f"Unable to load model. Check the specified path: {model_path}")
     st.error(ex)
 
-# --- SESSION STATE FOR PROCESSED FRAMES ---
+# --- SESSION STATE SETUP ---
 if "processed_frames" not in st.session_state:
     st.session_state["processed_frames"] = []
 
-# We'll keep the detection results for each frame (if you want them)
+# If you want bounding box data per frame:
 if "frame_detections" not in st.session_state:
     st.session_state["frame_detections"] = []
 
-# --- WHEN USER CLICKS DETECT ---
-if st.sidebar.button("Let's Detect Wildfire"):
+# We'll store the shortened video data so the download button remains visible
+if "shortened_video_data" not in st.session_state:
+    st.session_state["shortened_video_data"] = None
+if "shortened_video_ready" not in st.session_state:
+    st.session_state["shortened_video_ready"] = False
+
+# --- DETECT BUTTON ---
+if st.sidebar.button("Let's Detect fire"):
     if not source_file:
         st.warning("No file uploaded!")
     elif file_type == 'image':
+        # Reset previous video data
+        st.session_state["shortened_video_ready"] = False
+        st.session_state["shortened_video_data"] = None
+
         # IMAGE DETECTION
         res = model.predict(uploaded_image, conf=confidence)
         boxes = res[0].boxes
@@ -89,10 +99,11 @@ if st.sidebar.button("Let's Detect Wildfire"):
         for box in boxes:
             st.write(box.xywh)
     else:
-        # VIDEO DETECTION
-        # Clear previous frames from session_state
+        # Reset previous frames and video data
         st.session_state["processed_frames"] = []
         st.session_state["frame_detections"] = []
+        st.session_state["shortened_video_ready"] = False
+        st.session_state["shortened_video_data"] = None
 
         processed_frames = st.session_state["processed_frames"]
         frame_detections = st.session_state["frame_detections"]
@@ -131,8 +142,7 @@ if st.sidebar.button("Let's Detect Wildfire"):
             res_plotted = res[0].plot()[:, :, ::-1]
 
             processed_frames.append(res_plotted)
-            # If you want to store bounding boxes for each frame:
-            frame_detections.append(res[0].boxes)
+            frame_detections.append(res[0].boxes)  # optional
 
             # Update progress
             if total_frames > 0:
@@ -149,7 +159,7 @@ if st.sidebar.button("Let's Detect Wildfire"):
         progress_text.text("Video processing complete!")
         progress_bar.progress(100)
 
-        # Create shortened video from processed frames
+        # Create shortened video
        if processed_frames:
            temp_video_file = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4')
            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
@@ -158,33 +168,36 @@ if st.sidebar.button("Let's Detect Wildfire"):
                out.write(frame)
            out.release()
 
-           st.success("Shortened video created successfully!")
+           # Store the video data in session_state
            with open(temp_video_file.name, 'rb') as video_file:
-               st.download_button(
-                   label="Download Shortened Video",
-                   data=video_file.read(),
-                   file_name="shortened_video.mp4",
-                   mime="video/mp4"
-               )
+               st.session_state["shortened_video_data"] = video_file.read()
+               st.session_state["shortened_video_ready"] = True
 
+           st.success("Shortened video created successfully!")
        else:
            st.error("No frames were processed from the video.")
 
-# --- DISPLAY THE PROCESSED FRAMES AFTER DETECTION ---
+# --- SHOW THE DOWNLOAD BUTTON IF READY ---
+if st.session_state["shortened_video_ready"] and st.session_state["shortened_video_data"]:
+    st.download_button(
+        label="Download Shortened Video",
+        data=st.session_state["shortened_video_data"],
+        file_name="shortened_video.mp4",
+        mime="video/mp4"
+    )
+
+# --- DISPLAY PROCESSED FRAMES IF ANY ---
 if st.session_state["processed_frames"]:
     st.markdown("### Browse Detected Frames")
     num_frames = len(st.session_state["processed_frames"])
 
     if num_frames == 1:
-        # Only one frame was processed
         st.image(st.session_state["processed_frames"][0], caption="Frame 0", use_column_width=True)
-        # If you want to show bounding boxes:
        if st.session_state["frame_detections"]:
            with st.expander("Detection Results for Frame 0"):
                for box in st.session_state["frame_detections"][0]:
                    st.write(box.xywh)
     else:
-        # Multiple frames
        frame_idx = st.slider(
            "Select Frame",
            min_value=0,
@@ -195,7 +208,8 @@ if st.session_state["processed_frames"]:
         st.image(st.session_state["processed_frames"][frame_idx],
                  caption=f"Frame {frame_idx}",
                  use_column_width=True)
-        # If you want to show bounding boxes:
+
+        # Optionally show bounding box data
         if st.session_state["frame_detections"]:
             with st.expander(f"Detection Results for Frame {frame_idx}"):
                 for box in st.session_state["frame_detections"][frame_idx]:
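A note on the model path kept in the first hunk: the Hugging Face /resolve/ URL returns the raw best.pt weights, whereas a /blob/ URL would return an HTML page that YOLO cannot load. If your ultralytics version does not accept a remote URL directly, a minimal sketch of fetching the weights to a local file first (the local filename is illustrative, not part of the commit):

import urllib.request
from pathlib import Path

from ultralytics import YOLO

model_url = 'https://huggingface.co/spaces/tstone87/ccr-colorado/resolve/main/best.pt'
local_weights = Path('best.pt')  # illustrative local cache path

# Fetch the raw .pt file once; /resolve/ serves the binary, /blob/ serves HTML.
if not local_weights.exists():
    urllib.request.urlretrieve(model_url, str(local_weights))

model = YOLO(str(local_weights))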
 
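The main behavioral change in this commit is that the encoded video is no longer offered for download inside the detection branch; its bytes are stashed in st.session_state and the download button is drawn on every rerun, so it survives the rerun Streamlit triggers when the button is clicked. A stripped-down sketch of that pattern, with a placeholder build_video_bytes() standing in for the frame-processing code above:

import streamlit as st

def build_video_bytes() -> bytes:
    # Placeholder for the YOLO frame-processing / video-encoding step above.
    return b"...mp4 bytes..."

# One-time session-state setup, mirroring the app.
if "shortened_video_data" not in st.session_state:
    st.session_state["shortened_video_data"] = None
if "shortened_video_ready" not in st.session_state:
    st.session_state["shortened_video_ready"] = False

if st.sidebar.button("Let's Detect fire"):
    # Store the result; a download button created only inside this branch
    # would disappear on the rerun caused by clicking it.
    st.session_state["shortened_video_data"] = build_video_bytes()
    st.session_state["shortened_video_ready"] = True

# Rendered on every rerun, so the button persists after detection finishes.
if st.session_state["shortened_video_ready"] and st.session_state["shortened_video_data"]:
    st.download_button(
        label="Download Shortened Video",
        data=st.session_state["shortened_video_data"],
        file_name="shortened_video.mp4",
        mime="video/mp4"
    )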
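For reference, the "create shortened video" step in the last hunks reduces to the standard OpenCV writer pattern. A sketch assuming processed_frames is a non-empty list of equally sized uint8 image arrays in the channel order OpenCV expects; the helper name and the 30 fps value are illustrative:

import tempfile

import cv2

def frames_to_mp4_bytes(processed_frames, fps=30):
    # Frame size is taken from the first frame; all frames must match it.
    height, width = processed_frames[0].shape[:2]
    temp_video_file = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4')
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(temp_video_file.name, fourcc, fps, (width, height))
    for frame in processed_frames:
        out.write(frame)
    out.release()
    with open(temp_video_file.name, 'rb') as video_file:
        return video_file.read()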