ccr-colorado / app.py
tstone87's picture
Update app.py
ebb9e28 verified
raw
history blame
7.83 kB
import os
import tempfile
import cv2
import streamlit as st
import PIL
from ultralytics import YOLO
# Dependencies: streamlit, opencv-python-headless, ultralytics and Pillow.
# model_path may point at a remote URL or at a local weights file.
model_path = 'https://huggingface.co/spaces/tstone87/ccr-colorado/resolve/main/best.pt'

# Page-wide settings for the Hugging Face Space, applied before anything else
# is drawn.
page_settings = {
    "page_title": "Fire Watch using AI vision models",
    "page_icon": "🔥",
    "layout": "wide",
    "initial_sidebar_state": "expanded",
}
st.set_page_config(**page_settings)
# Sidebar: file upload, confidence threshold and video shortening mode,
# followed by progress indicators and a placeholder for the frame viewer.
sidebar = st.sidebar
sidebar.header("IMAGE/VIDEO UPLOAD")
source_file = sidebar.file_uploader(
    "Choose an image or video...",
    type=("jpg", "jpeg", "png", "bmp", "webp", "mp4"),
)
# The slider works in whole percent (25-100, default 40); convert it to a
# 0-1 fraction for the model's `conf` argument.
confidence_percent = sidebar.slider("Select Model Confidence", 25, 100, 40)
confidence = float(confidence_percent) / 100
video_option = sidebar.selectbox(
    "Select Video Shortening Option",
    [
        "Original FPS",
        "1 fps",
        "1 frame per 5 seconds",
        "1 frame per 10 seconds",
        "1 frame per 15 seconds",
    ],
)
progress_text = sidebar.empty()
progress_bar = sidebar.progress(0)
# Placeholder that the frame-viewer slider is drawn into.
slider_container = sidebar.empty()
# Main page: title, two illustrative wildfire photos side by side, and a short
# description of the app, followed by the detection section header.
st.title("WildfireWatch: Detecting Wildfire using AI")
intro_image_urls = (
    "https://huggingface.co/spaces/tstone87/ccr-colorado/resolve/main/Fire_1.jpeg",
    "https://huggingface.co/spaces/tstone87/ccr-colorado/resolve/main/Fire_3.png",
)
for intro_column, intro_url in zip(st.columns(2), intro_image_urls):
    intro_column.image(intro_url, use_column_width=True)
st.markdown("""
Fires in Colorado present a serious challenge, threatening urban communities, highways, and even remote areas. Early detection is critical. WildfireWatch leverages YOLOv8 for real‐time fire and smoke detection in images and videos.
""")
st.markdown("---")
st.header("Fire Detection:")
# Create two columns for displaying the upload and results.
col1, col2 = st.columns(2)
if source_file:
    if source_file.type.split('/')[0] == 'image':
        uploaded_image = PIL.Image.open(source_file)
        # The upload goes in the left column; detections are drawn in col2.
        with col1:
            st.image(uploaded_image, caption="Uploaded Image", use_column_width=True)
    else:
        # cv2.VideoCapture reads from a path, so the upload is copied to a temp
        # file. Streamlit reruns this script on every interaction, so the copy
        # made for the current upload is remembered in session_state and
        # reused instead of writing (and leaking) a new file on each rerun.
        upload_id = (source_file.name, source_file.size)
        cached_upload = st.session_state.get("uploaded_video")  # (upload_id, path) or None
        if cached_upload and cached_upload[0] == upload_id and os.path.exists(cached_upload[1]):
            video_path = cached_upload[1]
        else:
            if cached_upload and os.path.exists(cached_upload[1]):
                # Copy belonging to a previously uploaded file.
                os.remove(cached_upload[1])
            suffix = os.path.splitext(source_file.name)[1]
            # Leaving the with-block closes the file, which flushes the
            # buffered tail to disk before OpenCV opens it by name.
            with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tfile:
                # getvalue() returns the whole upload regardless of the
                # current read position.
                tfile.write(source_file.getvalue())
            video_path = tfile.name
            st.session_state["uploaded_video"] = (upload_id, video_path)
        vidcap = cv2.VideoCapture(video_path)
else:
    st.info("Please upload an image or video file to begin.")
# Load the YOLO model once per browser session. Streamlit reruns the whole
# script on every widget interaction, so the loaded model is kept in
# session_state instead of being rebuilt from model_path on each rerun.
if "yolo_model" not in st.session_state:
    try:
        st.session_state["yolo_model"] = YOLO(model_path)
    except Exception as ex:
        st.error(f"Unable to load model. Check the specified path: {model_path}")
        st.error(ex)
        # Everything below needs the model; end this run here instead of
        # failing later with a NameError on `model`.
        st.stop()
model = st.session_state["yolo_model"]
# Remember the frame-viewer slider position across reruns, starting at 0.
st.session_state.setdefault("frame_slider", 0)
# Main-page placeholder that shows the currently viewed frame.
viewer_slot = st.empty()
def _sampling_plan(option, fps):
    """Return ``(sample_interval, output_fps)`` for a video shortening option.

    ``sample_interval`` is how many source frames to step between detections;
    ``output_fps`` is the frame rate of the written video. ``fps`` is the
    source rate reported by OpenCV; values that are not > 0 are treated as
    unknown.
    """
    # Seconds of source footage per kept frame. Unlisted options (including
    # "Original FPS") keep every frame.
    seconds_per_sample = {
        "1 fps": 1,
        "1 frame per 5 seconds": 5,
        "1 frame per 10 seconds": 10,
        "1 frame per 15 seconds": 15,
    }.get(option)
    if seconds_per_sample is None:
        # Avoid opening the writer at 0 fps when the source does not report a
        # rate; assume 30 fps in that case.
        return 1, (fps if fps > 0 else 30.0)
    if not fps > 0:
        # Unknown source rate: treat each frame as one second of footage.
        return seconds_per_sample, 1
    # max(1, ...) keeps the step positive for sources slower than 1 fps, where
    # int() would truncate to 0 and the frame-loop modulo would divide by zero.
    return max(1, int(fps * seconds_per_sample)), 1


# When the user clicks the detect button, run detection on the uploaded file.
# Every widget interaction (including dragging the viewer slider or clicking
# download) reruns this script, so video results are stored in
# st.session_state and rendered by the viewer section further below.
if st.sidebar.button("Let's Detect Wildfire"):
    if not source_file:
        st.warning("No file uploaded!")
    elif source_file.type.split('/')[0] == 'image':
        # Process image input.
        res = model.predict(uploaded_image, conf=confidence)
        boxes = res[0].boxes
        # plot() returns BGR; reverse the channel axis so st.image gets RGB.
        res_plotted = res[0].plot()[:, :, ::-1]
        with col2:
            st.image(res_plotted, caption='Detected Image', use_column_width=True)
            with st.expander("Detection Results"):
                for box in boxes:
                    st.write(box.xywh)
    elif not vidcap.isOpened():
        st.error("Unable to read the uploaded video.")
    else:
        # Process video input.
        orig_fps = vidcap.get(cv2.CAP_PROP_FPS)
        total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
        sample_interval, output_fps = _sampling_plan(video_option, orig_fps)

        # Annotated frames are kept JPEG-encoded; holding raw arrays for every
        # frame of a long video quickly exhausts memory.
        frame_jpegs = []
        video_bytes = None
        writer = None
        output_path = None
        try:
            frame_count = 0
            success, image = vidcap.read()
            while success:
                if frame_count % sample_interval == 0:
                    # Run detection on current frame.
                    res = model.predict(image, conf=confidence)
                    # plot() returns BGR, the channel order both
                    # VideoWriter.write and imencode expect.
                    annotated = res[0].plot()
                    if writer is None:
                        # Size the writer from a real frame rather than the
                        # container metadata, which may disagree with the
                        # decoded frames.
                        frame_height, frame_width = annotated.shape[:2]
                        with tempfile.NamedTemporaryFile(delete=False, suffix='.mp4') as tmp:
                            output_path = tmp.name
                        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
                        writer = cv2.VideoWriter(output_path, fourcc, output_fps, (frame_width, frame_height))
                    writer.write(annotated)
                    frame_jpegs.append(cv2.imencode('.jpg', annotated)[1].tobytes())
                    # Live preview of the newest annotated frame.
                    viewer_slot.image(frame_jpegs[-1], caption=f"Frame {len(frame_jpegs)-1}", use_column_width=True)
                # Update progress.
                if total_frames > 0:
                    progress_pct = min(100, int((frame_count / total_frames) * 100))
                    progress_text.text(f"Processing frame {frame_count} / {total_frames} ({progress_pct}%)")
                    progress_bar.progress(progress_pct)
                else:
                    progress_text.text(f"Processing frame {frame_count}")
                frame_count += 1
                success, image = vidcap.read()
            if writer is not None:
                # Finalize the mp4 before reading it back.
                writer.release()
                writer = None
                with open(output_path, 'rb') as video_file:
                    video_bytes = video_file.read()
        finally:
            # Release OpenCV handles and delete the temp output even when the
            # run is interrupted (e.g. by a rerun) or predict() raises.
            vidcap.release()
            if writer is not None:
                writer.release()
            if output_path is not None and os.path.exists(output_path):
                os.remove(output_path)
        # Finalize progress.
        progress_text.text("Video processing complete!")
        progress_bar.progress(100)
        if frame_jpegs:
            st.session_state["video_results"] = {
                "source": (source_file.name, source_file.size),
                "frames": frame_jpegs,
                "video": video_bytes,
            }
            # Point the viewer at the newest frame. This must happen before
            # the "frame_slider" widget is created in this run (it is, below).
            st.session_state.frame_slider = len(frame_jpegs) - 1
            if video_bytes:
                st.success("Shortened video created successfully!")
            else:
                st.error("Unable to write the shortened video.")
        else:
            st.error("No frames were processed from the video.")

# Frame viewer and download for the latest video results of the current
# upload; drawn on every rerun so scrubbing and downloading keep working.
video_results = st.session_state.get("video_results")
if (
    source_file
    and video_results
    and video_results["source"] == (source_file.name, source_file.size)
):
    frames = video_results["frames"]
    last_index = len(frames) - 1
    # Clamp a stale slider position (e.g. from a longer earlier run) into range
    # before the widget is created.
    if st.session_state.get("frame_slider", 0) > last_index:
        st.session_state.frame_slider = last_index
    if last_index > 0:
        # Created exactly once per run; its value lives in
        # st.session_state["frame_slider"], so no `value=` is passed.
        selected_index = slider_container.slider(
            "Frame Viewer",
            min_value=0,
            max_value=last_index,
            step=1,
            key="frame_slider",
        )
    else:
        selected_index = 0
    viewer_slot.image(frames[selected_index], caption=f"Frame {selected_index}", use_column_width=True)
    if video_results["video"]:
        st.download_button(
            label="Download Shortened Video",
            data=video_results["video"],
            file_name="shortened_video.mp4",
            mime="video/mp4",
        )