tstone87 commited on
Commit
9c54f39
·
verified ·
1 Parent(s): cd07504

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -38
app.py CHANGED
@@ -7,10 +7,8 @@ from ultralytics import YOLO
7
 
8
  # Required libraries: streamlit, opencv-python-headless, ultralytics, Pillow
9
 
10
- # Replace with your model URL or local file path
11
  model_path = 'https://huggingface.co/spaces/tstone87/ccr-colorado/resolve/main/best.pt'
12
 
13
- # Configure page layout for Hugging Face Spaces
14
  st.set_page_config(
15
  page_title="Fire Watch using AI vision models",
16
  page_icon="🔥",
@@ -18,22 +16,17 @@ st.set_page_config(
18
  initial_sidebar_state="expanded"
19
  )
20
 
21
- # Sidebar: Upload file, select confidence and video shortening options.
22
  with st.sidebar:
23
  st.header("IMAGE/VIDEO UPLOAD")
24
- source_file = st.file_uploader(
25
- "Choose an image or video...", type=("jpg", "jpeg", "png", "bmp", "webp", "mp4"))
26
  confidence = float(st.slider("Select Model Confidence", 25, 100, 40)) / 100
27
- video_option = st.selectbox(
28
- "Select Video Shortening Option",
29
- ["Original FPS", "1 fps", "1 frame per 5 seconds", "1 frame per 10 seconds", "1 frame per 15 seconds"]
30
- )
31
  progress_text = st.empty()
32
  progress_bar = st.progress(0)
33
- # Container for our dynamic slider (frame viewer)
34
- slider_container = st.empty()
35
 
36
- # Main page header and intro images
37
  st.title("WildfireWatch: Detecting Wildfire using AI")
38
  col1, col2 = st.columns(2)
39
  with col1:
@@ -47,7 +40,6 @@ Fires in Colorado present a serious challenge, threatening urban communities, hi
47
  st.markdown("---")
48
  st.header("Fire Detection:")
49
 
50
- # Create two columns for displaying the upload and results.
51
  col1, col2 = st.columns(2)
52
  if source_file:
53
  if source_file.type.split('/')[0] == 'image':
@@ -60,22 +52,22 @@ if source_file:
60
  else:
61
  st.info("Please upload an image or video file to begin.")
62
 
63
- # Load YOLO model
64
  try:
65
  model = YOLO(model_path)
66
  except Exception as ex:
67
  st.error(f"Unable to load model. Check the specified path: {model_path}")
68
  st.error(ex)
69
 
70
- # A container to display the currently viewed frame.
71
- viewer_slot = st.empty()
 
 
 
72
 
73
- # When the user clicks the detect button...
74
  if st.sidebar.button("Let's Detect Wildfire"):
75
  if not source_file:
76
  st.warning("No file uploaded!")
77
  elif source_file.type.split('/')[0] == 'image':
78
- # Process image input.
79
  res = model.predict(uploaded_image, conf=confidence)
80
  boxes = res[0].boxes
81
  res_plotted = res[0].plot()[:, :, ::-1]
@@ -85,17 +77,13 @@ if st.sidebar.button("Let's Detect Wildfire"):
85
  for box in boxes:
86
  st.write(box.xywh)
87
  else:
88
- # Process video input.
89
  processed_frames = []
90
  frame_count = 0
91
-
92
- # Get video properties.
93
  orig_fps = vidcap.get(cv2.CAP_PROP_FPS)
94
  total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
95
  width = int(vidcap.get(cv2.CAP_PROP_FRAME_WIDTH))
96
  height = int(vidcap.get(cv2.CAP_PROP_FRAME_HEIGHT))
97
 
98
- # Determine sampling interval and output fps based on the option selected.
99
  if video_option == "Original FPS":
100
  sample_interval = 1
101
  output_fps = orig_fps
@@ -118,12 +106,10 @@ if st.sidebar.button("Let's Detect Wildfire"):
118
  success, image = vidcap.read()
119
  while success:
120
  if frame_count % sample_interval == 0:
121
- # Run detection on current frame.
122
  res = model.predict(image, conf=confidence)
123
  res_plotted = res[0].plot()[:, :, ::-1]
124
  processed_frames.append(res_plotted)
125
 
126
- # Update progress.
127
  if total_frames > 0:
128
  progress_pct = int((frame_count / total_frames) * 100)
129
  progress_text.text(f"Processing frame {frame_count} / {total_frames} ({progress_pct}%)")
@@ -131,35 +117,30 @@ if st.sidebar.button("Let's Detect Wildfire"):
131
  else:
132
  progress_text.text(f"Processing frame {frame_count}")
133
 
134
- # Update slider only if we have at least one processed frame.
135
- if len(processed_frames) > 0:
136
- # Clear previous slider widget.
137
- slider_container.empty()
138
- # Use the slider widget's own value from session_state, or default to the last frame.
139
- default_val = st.session_state.get("frame_slider", len(processed_frames)-1)
140
- # Ensure default value is within bounds.
141
- if default_val > len(processed_frames)-1:
142
- default_val = len(processed_frames)-1
143
- # Create the slider. The key "frame_slider" will automatically update st.session_state.
144
  slider_val = slider_container.slider(
145
  "Frame Viewer",
146
  min_value=0,
147
  max_value=len(processed_frames)-1,
148
  value=default_val,
149
  step=1,
150
- key="frame_slider"
151
  )
152
- # If the user is at the most recent frame, update the viewer.
 
153
  if slider_val == len(processed_frames)-1:
154
  viewer_slot.image(processed_frames[-1], caption=f"Frame {len(processed_frames)-1}", use_column_width=True)
155
  frame_count += 1
156
  success, image = vidcap.read()
157
 
158
- # Finalize progress.
159
  progress_text.text("Video processing complete!")
160
  progress_bar.progress(100)
161
 
162
- # Create and provide the downloadable shortened video.
163
  if processed_frames:
164
  temp_video_file = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4')
165
  fourcc = cv2.VideoWriter_fourcc(*'mp4v')
 
7
 
8
  # Required libraries: streamlit, opencv-python-headless, ultralytics, Pillow
9
 
 
10
  model_path = 'https://huggingface.co/spaces/tstone87/ccr-colorado/resolve/main/best.pt'
11
 
 
12
  st.set_page_config(
13
  page_title="Fire Watch using AI vision models",
14
  page_icon="🔥",
 
16
  initial_sidebar_state="expanded"
17
  )
18
 
 
19
  with st.sidebar:
20
  st.header("IMAGE/VIDEO UPLOAD")
21
+ source_file = st.file_uploader("Choose an image or video...",
22
+ type=("jpg", "jpeg", "png", "bmp", "webp", "mp4"))
23
  confidence = float(st.slider("Select Model Confidence", 25, 100, 40)) / 100
24
+ video_option = st.selectbox("Select Video Shortening Option",
25
+ ["Original FPS", "1 fps", "1 frame per 5 seconds", "1 frame per 10 seconds", "1 frame per 15 seconds"])
 
 
26
  progress_text = st.empty()
27
  progress_bar = st.progress(0)
28
+ slider_container = st.empty() # For dynamic slider widget
 
29
 
 
30
  st.title("WildfireWatch: Detecting Wildfire using AI")
31
  col1, col2 = st.columns(2)
32
  with col1:
 
40
  st.markdown("---")
41
  st.header("Fire Detection:")
42
 
 
43
  col1, col2 = st.columns(2)
44
  if source_file:
45
  if source_file.type.split('/')[0] == 'image':
 
52
  else:
53
  st.info("Please upload an image or video file to begin.")
54
 
 
55
  try:
56
  model = YOLO(model_path)
57
  except Exception as ex:
58
  st.error(f"Unable to load model. Check the specified path: {model_path}")
59
  st.error(ex)
60
 
61
+ viewer_slot = st.empty() # Container for currently viewed frame
62
+
63
+ # Initialize session state for slider value if not already set.
64
+ if "slider_value" not in st.session_state:
65
+ st.session_state.slider_value = 0
66
 
 
67
  if st.sidebar.button("Let's Detect Wildfire"):
68
  if not source_file:
69
  st.warning("No file uploaded!")
70
  elif source_file.type.split('/')[0] == 'image':
 
71
  res = model.predict(uploaded_image, conf=confidence)
72
  boxes = res[0].boxes
73
  res_plotted = res[0].plot()[:, :, ::-1]
 
77
  for box in boxes:
78
  st.write(box.xywh)
79
  else:
 
80
  processed_frames = []
81
  frame_count = 0
 
 
82
  orig_fps = vidcap.get(cv2.CAP_PROP_FPS)
83
  total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
84
  width = int(vidcap.get(cv2.CAP_PROP_FRAME_WIDTH))
85
  height = int(vidcap.get(cv2.CAP_PROP_FRAME_HEIGHT))
86
 
 
87
  if video_option == "Original FPS":
88
  sample_interval = 1
89
  output_fps = orig_fps
 
106
  success, image = vidcap.read()
107
  while success:
108
  if frame_count % sample_interval == 0:
 
109
  res = model.predict(image, conf=confidence)
110
  res_plotted = res[0].plot()[:, :, ::-1]
111
  processed_frames.append(res_plotted)
112
 
 
113
  if total_frames > 0:
114
  progress_pct = int((frame_count / total_frames) * 100)
115
  progress_text.text(f"Processing frame {frame_count} / {total_frames} ({progress_pct}%)")
 
117
  else:
118
  progress_text.text(f"Processing frame {frame_count}")
119
 
120
+ # Only update the slider if we have frames.
121
+ if processed_frames:
122
+ slider_container.empty() # Clear previous slider
123
+ # Use stored slider_value if within bounds, otherwise default to last frame.
124
+ default_val = st.session_state.slider_value if st.session_state.slider_value < len(processed_frames) else len(processed_frames)-1
125
+ # Create slider with a unique key based on number of processed frames.
 
 
 
 
126
  slider_val = slider_container.slider(
127
  "Frame Viewer",
128
  min_value=0,
129
  max_value=len(processed_frames)-1,
130
  value=default_val,
131
  step=1,
132
+ key=f"frame_slider_{len(processed_frames)}"
133
  )
134
+ st.session_state.slider_value = slider_val
135
+
136
  if slider_val == len(processed_frames)-1:
137
  viewer_slot.image(processed_frames[-1], caption=f"Frame {len(processed_frames)-1}", use_column_width=True)
138
  frame_count += 1
139
  success, image = vidcap.read()
140
 
 
141
  progress_text.text("Video processing complete!")
142
  progress_bar.progress(100)
143
 
 
144
  if processed_frames:
145
  temp_video_file = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4')
146
  fourcc = cv2.VideoWriter_fourcc(*'mp4v')