tstone87 commited on
Commit
cb79d6c
·
verified ·
1 Parent(s): 137fb06

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +71 -63
app.py CHANGED
@@ -5,12 +5,7 @@ import streamlit as st
5
  import PIL
6
  from ultralytics import YOLO
7
 
8
- # Required libraries:
9
- # streamlit
10
- # opencv-python-headless
11
- # ultralytics
12
- # Pillow
13
-
14
  model_path = 'https://huggingface.co/spaces/tstone87/ccr-colorado/resolve/main/best.pt'
15
 
16
  st.set_page_config(
@@ -20,17 +15,20 @@ st.set_page_config(
20
  initial_sidebar_state="expanded"
21
  )
22
 
 
23
  with st.sidebar:
24
  st.header("IMAGE/VIDEO UPLOAD")
25
- source_file = st.file_uploader("Choose an image or video...", type=("jpg", "jpeg", "png", "bmp", "webp", "mp4"))
 
26
  confidence = float(st.slider("Select Model Confidence", 25, 100, 40)) / 100
27
- video_option = st.selectbox("Select Video Shortening Option",
28
- ["Original FPS", "1 fps", "1 frame per 5 seconds", "1 frame per 10 seconds", "1 frame per 15 seconds"])
 
 
29
  progress_text = st.empty()
30
  progress_bar = st.progress(0)
31
- # A container where our dynamic slider (frame viewer) will be placed.
32
- slider_container = st.empty()
33
 
 
34
  st.title("WildfireWatch: Detecting Wildfire using AI")
35
  col1, col2 = st.columns(2)
36
  with col1:
@@ -39,49 +37,49 @@ with col2:
39
  st.image("https://huggingface.co/spaces/tstone87/ccr-colorado/resolve/main/Fire_3.png", use_column_width=True)
40
 
41
  st.markdown("""
42
- Fires in Colorado present a serious challenge, threatening urban communities, highways, and even remote areas. Early detection is critical.
43
- WildfireWatch leverages YOLOv8 for realtime fire and smoke detection in images and videos.
 
44
  """)
45
  st.markdown("---")
46
  st.header("Fire Detection:")
47
 
48
- # Left column for the uploaded file, right for detection results.
49
  col1, col2 = st.columns(2)
50
  if source_file:
51
- if source_file.type.split('/')[0] == 'image':
 
52
  uploaded_image = PIL.Image.open(source_file)
53
  st.image(uploaded_image, caption="Uploaded Image", use_column_width=True)
54
  else:
 
55
  tfile = tempfile.NamedTemporaryFile(delete=False)
56
  tfile.write(source_file.read())
57
  vidcap = cv2.VideoCapture(tfile.name)
58
  else:
59
  st.info("Please upload an image or video file to begin.")
60
 
61
- # Attempt to load the model.
62
  try:
63
  model = YOLO(model_path)
64
  except Exception as ex:
65
  st.error(f"Unable to load model. Check the specified path: {model_path}")
66
  st.error(ex)
67
 
68
- # We'll store processed frames persistently in session_state.
69
  if "processed_frames" not in st.session_state:
70
  st.session_state["processed_frames"] = []
71
 
72
- # Also store the last slider value (if the user manually changes it).
73
- if "slider_value" not in st.session_state:
74
- st.session_state["slider_value"] = 0
75
-
76
- # Container to display the currently viewed frame.
77
- viewer_slot = st.empty()
78
 
79
- # --- Processing and Viewer Update ---
80
  if st.sidebar.button("Let's Detect Wildfire"):
81
  if not source_file:
82
  st.warning("No file uploaded!")
83
- elif source_file.type.split('/')[0] == 'image':
84
- # Process image input.
85
  res = model.predict(uploaded_image, conf=confidence)
86
  boxes = res[0].boxes
87
  res_plotted = res[0].plot()[:, :, ::-1]
@@ -91,16 +89,21 @@ if st.sidebar.button("Let's Detect Wildfire"):
91
  for box in boxes:
92
  st.write(box.xywh)
93
  else:
94
- # For video input, process frames.
 
 
 
 
95
  processed_frames = st.session_state["processed_frames"]
96
- frame_count = 0
97
 
98
- # Video properties.
99
  orig_fps = vidcap.get(cv2.CAP_PROP_FPS)
100
  total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
101
  width = int(vidcap.get(cv2.CAP_PROP_FRAME_WIDTH))
102
  height = int(vidcap.get(cv2.CAP_PROP_FRAME_HEIGHT))
103
 
 
104
  if video_option == "Original FPS":
105
  sample_interval = 1
106
  output_fps = orig_fps
@@ -123,12 +126,15 @@ if st.sidebar.button("Let's Detect Wildfire"):
123
  success, image = vidcap.read()
124
  while success:
125
  if frame_count % sample_interval == 0:
 
126
  res = model.predict(image, conf=confidence)
127
  res_plotted = res[0].plot()[:, :, ::-1]
 
128
  processed_frames.append(res_plotted)
129
- st.session_state["processed_frames"] = processed_frames
 
130
 
131
- # Update progress info.
132
  if total_frames > 0:
133
  progress_pct = int((frame_count / total_frames) * 100)
134
  progress_text.text(f"Processing frame {frame_count} / {total_frames} ({progress_pct}%)")
@@ -136,44 +142,14 @@ if st.sidebar.button("Let's Detect Wildfire"):
136
  else:
137
  progress_text.text(f"Processing frame {frame_count}")
138
 
139
- # --- Update the frame viewer slider dynamically ---
140
- # Retrieve user's last slider selection.
141
- last_slider = st.session_state.slider_value
142
- # If the user was at the end, default to the new end.
143
- if last_slider >= len(processed_frames):
144
- default_val = len(processed_frames) - 1
145
- else:
146
- default_val = last_slider
147
-
148
- # Clear the slider container and recreate the slider.
149
- slider_container.empty()
150
- # Use a dynamic key to avoid duplicate key errors.
151
- slider_key = f"frame_slider_{len(processed_frames)}"
152
- slider_val = slider_container.slider("Frame Viewer",
153
- min_value=0,
154
- max_value=len(processed_frames) - 1,
155
- value=default_val,
156
- step=1,
157
- key=slider_key)
158
- st.session_state.slider_value = slider_val
159
-
160
- # If the slider is at the most recent frame, update the viewer.
161
- if slider_val == len(processed_frames) - 1:
162
- viewer_slot.image(processed_frames[-1],
163
- caption=f"Frame {len(processed_frames) - 1}",
164
- use_column_width=True)
165
- else:
166
- # Otherwise, show the frame corresponding to the slider.
167
- viewer_slot.image(processed_frames[slider_val],
168
- caption=f"Frame {slider_val}",
169
- use_column_width=True)
170
  frame_count += 1
171
  success, image = vidcap.read()
172
 
 
173
  progress_text.text("Video processing complete!")
174
  progress_bar.progress(100)
175
 
176
- # --- Video Download Section ---
177
  if processed_frames:
178
  temp_video_file = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4')
179
  fourcc = cv2.VideoWriter_fourcc(*'mp4v')
@@ -190,5 +166,37 @@ if st.sidebar.button("Let's Detect Wildfire"):
190
  file_name="shortened_video.mp4",
191
  mime="video/mp4"
192
  )
 
193
  else:
194
  st.error("No frames were processed from the video.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  import PIL
6
  from ultralytics import YOLO
7
 
8
+ # Ensure your model path points directly to the .pt file (not an HTML page)
 
 
 
 
 
9
  model_path = 'https://huggingface.co/spaces/tstone87/ccr-colorado/resolve/main/best.pt'
10
 
11
  st.set_page_config(
 
15
  initial_sidebar_state="expanded"
16
  )
17
 
18
+ # --- SIDEBAR ---
19
  with st.sidebar:
20
  st.header("IMAGE/VIDEO UPLOAD")
21
+ source_file = st.file_uploader("Choose an image or video...",
22
+ type=("jpg", "jpeg", "png", "bmp", "webp", "mp4"))
23
  confidence = float(st.slider("Select Model Confidence", 25, 100, 40)) / 100
24
+ video_option = st.selectbox(
25
+ "Select Video Shortening Option",
26
+ ["Original FPS", "1 fps", "1 frame per 5 seconds", "1 frame per 10 seconds", "1 frame per 15 seconds"]
27
+ )
28
  progress_text = st.empty()
29
  progress_bar = st.progress(0)
 
 
30
 
31
+ # --- MAIN PAGE TITLE AND IMAGES ---
32
  st.title("WildfireWatch: Detecting Wildfire using AI")
33
  col1, col2 = st.columns(2)
34
  with col1:
 
37
  st.image("https://huggingface.co/spaces/tstone87/ccr-colorado/resolve/main/Fire_3.png", use_column_width=True)
38
 
39
  st.markdown("""
40
+ Fires in Colorado present a serious challenge, threatening urban communities, highways, and even remote areas.
41
+ Early detection is critical. WildfireWatch leverages YOLOv8 for real-time fire and smoke detection
42
+ in images and videos.
43
  """)
44
  st.markdown("---")
45
  st.header("Fire Detection:")
46
 
47
+ # --- DISPLAY UPLOADED FILE ---
48
  col1, col2 = st.columns(2)
49
  if source_file:
50
+ file_type = source_file.type.split('/')[0]
51
+ if file_type == 'image':
52
  uploaded_image = PIL.Image.open(source_file)
53
  st.image(uploaded_image, caption="Uploaded Image", use_column_width=True)
54
  else:
55
+ # Temporarily store the uploaded video
56
  tfile = tempfile.NamedTemporaryFile(delete=False)
57
  tfile.write(source_file.read())
58
  vidcap = cv2.VideoCapture(tfile.name)
59
  else:
60
  st.info("Please upload an image or video file to begin.")
61
 
62
+ # --- LOAD YOLO MODEL ---
63
  try:
64
  model = YOLO(model_path)
65
  except Exception as ex:
66
  st.error(f"Unable to load model. Check the specified path: {model_path}")
67
  st.error(ex)
68
 
69
+ # --- SESSION STATE FOR PROCESSED FRAMES ---
70
  if "processed_frames" not in st.session_state:
71
  st.session_state["processed_frames"] = []
72
 
73
+ # We'll keep the detection results for each frame (if you want them)
74
+ if "frame_detections" not in st.session_state:
75
+ st.session_state["frame_detections"] = []
 
 
 
76
 
77
+ # --- WHEN USER CLICKS DETECT ---
78
  if st.sidebar.button("Let's Detect Wildfire"):
79
  if not source_file:
80
  st.warning("No file uploaded!")
81
+ elif file_type == 'image':
82
+ # IMAGE DETECTION
83
  res = model.predict(uploaded_image, conf=confidence)
84
  boxes = res[0].boxes
85
  res_plotted = res[0].plot()[:, :, ::-1]
 
89
  for box in boxes:
90
  st.write(box.xywh)
91
  else:
92
+ # VIDEO DETECTION
93
+ # Clear previous frames from session_state
94
+ st.session_state["processed_frames"] = []
95
+ st.session_state["frame_detections"] = []
96
+
97
  processed_frames = st.session_state["processed_frames"]
98
+ frame_detections = st.session_state["frame_detections"]
99
 
100
+ frame_count = 0
101
  orig_fps = vidcap.get(cv2.CAP_PROP_FPS)
102
  total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
103
  width = int(vidcap.get(cv2.CAP_PROP_FRAME_WIDTH))
104
  height = int(vidcap.get(cv2.CAP_PROP_FRAME_HEIGHT))
105
 
106
+ # Determine sampling interval
107
  if video_option == "Original FPS":
108
  sample_interval = 1
109
  output_fps = orig_fps
 
126
  success, image = vidcap.read()
127
  while success:
128
  if frame_count % sample_interval == 0:
129
+ # Run detection
130
  res = model.predict(image, conf=confidence)
131
  res_plotted = res[0].plot()[:, :, ::-1]
132
+
133
  processed_frames.append(res_plotted)
134
+ # If you want to store bounding boxes for each frame:
135
+ frame_detections.append(res[0].boxes)
136
 
137
+ # Update progress
138
  if total_frames > 0:
139
  progress_pct = int((frame_count / total_frames) * 100)
140
  progress_text.text(f"Processing frame {frame_count} / {total_frames} ({progress_pct}%)")
 
142
  else:
143
  progress_text.text(f"Processing frame {frame_count}")
144
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
145
  frame_count += 1
146
  success, image = vidcap.read()
147
 
148
+ # Processing complete
149
  progress_text.text("Video processing complete!")
150
  progress_bar.progress(100)
151
 
152
+ # Create shortened video from processed frames
153
  if processed_frames:
154
  temp_video_file = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4')
155
  fourcc = cv2.VideoWriter_fourcc(*'mp4v')
 
166
  file_name="shortened_video.mp4",
167
  mime="video/mp4"
168
  )
169
+
170
  else:
171
  st.error("No frames were processed from the video.")
172
+
173
+ # --- DISPLAY THE PROCESSED FRAMES AFTER DETECTION ---
174
+ if st.session_state["processed_frames"]:
175
+ st.markdown("### Browse Detected Frames")
176
+ num_frames = len(st.session_state["processed_frames"])
177
+
178
+ if num_frames == 1:
179
+ # Only one frame was processed
180
+ st.image(st.session_state["processed_frames"][0], caption="Frame 0", use_column_width=True)
181
+ # If you want to show bounding boxes:
182
+ if st.session_state["frame_detections"]:
183
+ with st.expander("Detection Results for Frame 0"):
184
+ for box in st.session_state["frame_detections"][0]:
185
+ st.write(box.xywh)
186
+ else:
187
+ # Multiple frames
188
+ frame_idx = st.slider(
189
+ "Select Frame",
190
+ min_value=0,
191
+ max_value=num_frames - 1,
192
+ value=0,
193
+ step=1
194
+ )
195
+ st.image(st.session_state["processed_frames"][frame_idx],
196
+ caption=f"Frame {frame_idx}",
197
+ use_column_width=True)
198
+ # If you want to show bounding boxes:
199
+ if st.session_state["frame_detections"]:
200
+ with st.expander(f"Detection Results for Frame {frame_idx}"):
201
+ for box in st.session_state["frame_detections"][frame_idx]:
202
+ st.write(box.xywh)