tstone87 committed on
Commit
137fb06
·
verified ·
1 Parent(s): 9c54f39

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +60 -27
app.py CHANGED
@@ -5,7 +5,11 @@ import streamlit as st
5
  import PIL
6
  from ultralytics import YOLO
7
 
8
- # Required libraries: streamlit, opencv-python-headless, ultralytics, Pillow
 
 
 
 
9
 
10
  model_path = 'https://huggingface.co/spaces/tstone87/ccr-colorado/resolve/main/best.pt'
11
 
@@ -18,14 +22,14 @@ st.set_page_config(
18
 
19
  with st.sidebar:
20
  st.header("IMAGE/VIDEO UPLOAD")
21
- source_file = st.file_uploader("Choose an image or video...",
22
- type=("jpg", "jpeg", "png", "bmp", "webp", "mp4"))
23
  confidence = float(st.slider("Select Model Confidence", 25, 100, 40)) / 100
24
  video_option = st.selectbox("Select Video Shortening Option",
25
  ["Original FPS", "1 fps", "1 frame per 5 seconds", "1 frame per 10 seconds", "1 frame per 15 seconds"])
26
  progress_text = st.empty()
27
  progress_bar = st.progress(0)
28
- slider_container = st.empty() # For dynamic slider widget
 
29
 
30
  st.title("WildfireWatch: Detecting Wildfire using AI")
31
  col1, col2 = st.columns(2)
@@ -35,11 +39,13 @@ with col2:
35
  st.image("https://huggingface.co/spaces/tstone87/ccr-colorado/resolve/main/Fire_3.png", use_column_width=True)
36
 
37
  st.markdown("""
38
- Fires in Colorado present a serious challenge, threatening urban communities, highways, and even remote areas. Early detection is critical. WildfireWatch leverages YOLOv8 for real‐time fire and smoke detection in images and videos.
 
39
  """)
40
  st.markdown("---")
41
  st.header("Fire Detection:")
42
 
 
43
  col1, col2 = st.columns(2)
44
  if source_file:
45
  if source_file.type.split('/')[0] == 'image':
@@ -52,22 +58,30 @@ if source_file:
52
  else:
53
  st.info("Please upload an image or video file to begin.")
54
 
 
55
  try:
56
  model = YOLO(model_path)
57
  except Exception as ex:
58
  st.error(f"Unable to load model. Check the specified path: {model_path}")
59
  st.error(ex)
60
 
61
- viewer_slot = st.empty() # Container for currently viewed frame
 
 
62
 
63
- # Initialize session state for slider value if not already set.
64
  if "slider_value" not in st.session_state:
65
- st.session_state.slider_value = 0
66
 
 
 
 
 
67
  if st.sidebar.button("Let's Detect Wildfire"):
68
  if not source_file:
69
  st.warning("No file uploaded!")
70
  elif source_file.type.split('/')[0] == 'image':
 
71
  res = model.predict(uploaded_image, conf=confidence)
72
  boxes = res[0].boxes
73
  res_plotted = res[0].plot()[:, :, ::-1]
@@ -77,8 +91,11 @@ if st.sidebar.button("Let's Detect Wildfire"):
77
  for box in boxes:
78
  st.write(box.xywh)
79
  else:
80
- processed_frames = []
 
81
  frame_count = 0
 
 
82
  orig_fps = vidcap.get(cv2.CAP_PROP_FPS)
83
  total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
84
  width = int(vidcap.get(cv2.CAP_PROP_FRAME_WIDTH))
@@ -109,7 +126,9 @@ if st.sidebar.button("Let's Detect Wildfire"):
109
  res = model.predict(image, conf=confidence)
110
  res_plotted = res[0].plot()[:, :, ::-1]
111
  processed_frames.append(res_plotted)
 
112
 
 
113
  if total_frames > 0:
114
  progress_pct = int((frame_count / total_frames) * 100)
115
  progress_text.text(f"Processing frame {frame_count} / {total_frames} ({progress_pct}%)")
@@ -117,30 +136,44 @@ if st.sidebar.button("Let's Detect Wildfire"):
117
  else:
118
  progress_text.text(f"Processing frame {frame_count}")
119
 
120
- # Only update the slider if we have frames.
121
- if processed_frames:
122
- slider_container.empty() # Clear previous slider
123
- # Use stored slider_value if within bounds, otherwise default to last frame.
124
- default_val = st.session_state.slider_value if st.session_state.slider_value < len(processed_frames) else len(processed_frames)-1
125
- # Create slider with a unique key based on number of processed frames.
126
- slider_val = slider_container.slider(
127
- "Frame Viewer",
128
- min_value=0,
129
- max_value=len(processed_frames)-1,
130
- value=default_val,
131
- step=1,
132
- key=f"frame_slider_{len(processed_frames)}"
133
- )
134
- st.session_state.slider_value = slider_val
135
-
136
- if slider_val == len(processed_frames)-1:
137
- viewer_slot.image(processed_frames[-1], caption=f"Frame {len(processed_frames)-1}", use_column_width=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
138
  frame_count += 1
139
  success, image = vidcap.read()
140
 
141
  progress_text.text("Video processing complete!")
142
  progress_bar.progress(100)
143
 
 
144
  if processed_frames:
145
  temp_video_file = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4')
146
  fourcc = cv2.VideoWriter_fourcc(*'mp4v')
 
5
  import PIL
6
  from ultralytics import YOLO
7
 
8
+ # Required libraries:
9
+ # streamlit
10
+ # opencv-python-headless
11
+ # ultralytics
12
+ # Pillow
13
 
14
  model_path = 'https://huggingface.co/spaces/tstone87/ccr-colorado/resolve/main/best.pt'
15
 
 
22
 
23
  with st.sidebar:
24
  st.header("IMAGE/VIDEO UPLOAD")
25
+ source_file = st.file_uploader("Choose an image or video...", type=("jpg", "jpeg", "png", "bmp", "webp", "mp4"))
 
26
  confidence = float(st.slider("Select Model Confidence", 25, 100, 40)) / 100
27
  video_option = st.selectbox("Select Video Shortening Option",
28
  ["Original FPS", "1 fps", "1 frame per 5 seconds", "1 frame per 10 seconds", "1 frame per 15 seconds"])
29
  progress_text = st.empty()
30
  progress_bar = st.progress(0)
31
+ # A container where our dynamic slider (frame viewer) will be placed.
32
+ slider_container = st.empty()
33
 
34
  st.title("WildfireWatch: Detecting Wildfire using AI")
35
  col1, col2 = st.columns(2)
 
39
  st.image("https://huggingface.co/spaces/tstone87/ccr-colorado/resolve/main/Fire_3.png", use_column_width=True)
40
 
41
  st.markdown("""
42
+ Fires in Colorado present a serious challenge, threatening urban communities, highways, and even remote areas. Early detection is critical.
43
+ WildfireWatch leverages YOLOv8 for real‐time fire and smoke detection in images and videos.
44
  """)
45
  st.markdown("---")
46
  st.header("Fire Detection:")
47
 
48
+ # Left column for the uploaded file, right for detection results.
49
  col1, col2 = st.columns(2)
50
  if source_file:
51
  if source_file.type.split('/')[0] == 'image':
 
58
  else:
59
  st.info("Please upload an image or video file to begin.")
60
 
61
+ # Attempt to load the model.
62
  try:
63
  model = YOLO(model_path)
64
  except Exception as ex:
65
  st.error(f"Unable to load model. Check the specified path: {model_path}")
66
  st.error(ex)
67
 
68
+ # We'll store processed frames persistently in session_state.
69
+ if "processed_frames" not in st.session_state:
70
+ st.session_state["processed_frames"] = []
71
 
72
+ # Also store the last slider value (if the user manually changes it).
73
  if "slider_value" not in st.session_state:
74
+ st.session_state["slider_value"] = 0
75
 
76
+ # Container to display the currently viewed frame.
77
+ viewer_slot = st.empty()
78
+
79
+ # --- Processing and Viewer Update ---
80
  if st.sidebar.button("Let's Detect Wildfire"):
81
  if not source_file:
82
  st.warning("No file uploaded!")
83
  elif source_file.type.split('/')[0] == 'image':
84
+ # Process image input.
85
  res = model.predict(uploaded_image, conf=confidence)
86
  boxes = res[0].boxes
87
  res_plotted = res[0].plot()[:, :, ::-1]
 
91
  for box in boxes:
92
  st.write(box.xywh)
93
  else:
94
+ # For video input, process frames.
95
+ processed_frames = st.session_state["processed_frames"]
96
  frame_count = 0
97
+
98
+ # Video properties.
99
  orig_fps = vidcap.get(cv2.CAP_PROP_FPS)
100
  total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
101
  width = int(vidcap.get(cv2.CAP_PROP_FRAME_WIDTH))
 
126
  res = model.predict(image, conf=confidence)
127
  res_plotted = res[0].plot()[:, :, ::-1]
128
  processed_frames.append(res_plotted)
129
+ st.session_state["processed_frames"] = processed_frames
130
 
131
+ # Update progress info.
132
  if total_frames > 0:
133
  progress_pct = int((frame_count / total_frames) * 100)
134
  progress_text.text(f"Processing frame {frame_count} / {total_frames} ({progress_pct}%)")
 
136
  else:
137
  progress_text.text(f"Processing frame {frame_count}")
138
 
139
+ # --- Update the frame viewer slider dynamically ---
140
+ # Retrieve user's last slider selection.
141
+ last_slider = st.session_state.slider_value
142
+ # If the user was at the end, default to the new end.
143
+ if last_slider >= len(processed_frames):
144
+ default_val = len(processed_frames) - 1
145
+ else:
146
+ default_val = last_slider
147
+
148
+ # Clear the slider container and recreate the slider.
149
+ slider_container.empty()
150
+ # Use a dynamic key to avoid duplicate key errors.
151
+ slider_key = f"frame_slider_{len(processed_frames)}"
152
+ slider_val = slider_container.slider("Frame Viewer",
153
+ min_value=0,
154
+ max_value=len(processed_frames) - 1,
155
+ value=default_val,
156
+ step=1,
157
+ key=slider_key)
158
+ st.session_state.slider_value = slider_val
159
+
160
+ # If the slider is at the most recent frame, update the viewer.
161
+ if slider_val == len(processed_frames) - 1:
162
+ viewer_slot.image(processed_frames[-1],
163
+ caption=f"Frame {len(processed_frames) - 1}",
164
+ use_column_width=True)
165
+ else:
166
+ # Otherwise, show the frame corresponding to the slider.
167
+ viewer_slot.image(processed_frames[slider_val],
168
+ caption=f"Frame {slider_val}",
169
+ use_column_width=True)
170
  frame_count += 1
171
  success, image = vidcap.read()
172
 
173
  progress_text.text("Video processing complete!")
174
  progress_bar.progress(100)
175
 
176
+ # --- Video Download Section ---
177
  if processed_frames:
178
  temp_video_file = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4')
179
  fourcc = cv2.VideoWriter_fourcc(*'mp4v')