import os
import tempfile
import cv2
import streamlit as st
import PIL.Image  # import the submodule explicitly; `import PIL` alone does not expose PIL.Image
from ultralytics import YOLO

# Required libraries: streamlit, opencv-python-headless, ultralytics, Pillow
# Replace with your model URL or local file path
model_path = 'https://huggingface.co/spaces/tstone87/ccr-colorado/resolve/main/best.pt'
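# NOTE: Ultralytics accepts a URL here and downloads the weights on first
# load, so the first request after a cold start will be slower.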
# Configure page layout for Hugging Face Spaces
st.set_page_config(
    page_title="Fire Watch using AI vision models",
    page_icon="🔥",
    layout="wide",
    initial_sidebar_state="expanded"
)
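# NOTE: st.set_page_config must be the first Streamlit call in the script;
# invoking any other st.* function before it raises a StreamlitAPIException.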
# Sidebar: upload file, select confidence and video-shortening options.
with st.sidebar:
    st.header("IMAGE/VIDEO UPLOAD")
    source_file = st.file_uploader(
        "Choose an image or video...",
        type=("jpg", "jpeg", "png", "bmp", "webp", "mp4")
    )
    confidence = float(st.slider("Select Model Confidence", 25, 100, 40)) / 100
    video_option = st.selectbox(
        "Select Video Shortening Option",
        ["Original FPS", "1 fps", "1 frame per 5 seconds",
         "1 frame per 10 seconds", "1 frame per 15 seconds"]
    )
    progress_text = st.empty()
    progress_bar = st.progress(0)
    # Container for our dynamic slider (frame viewer).
    slider_container = st.empty()
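    # NOTE: progress_text and slider_container are st.empty() placeholders;
    # writing to them later replaces their contents in place, which is how
    # the progress readout and frame viewer update during processing.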
# Main page header and intro images
st.title("WildfireWatch: Detecting Wildfire using AI")
col1, col2 = st.columns(2)
with col1:
    st.image("https://huggingface.co/spaces/tstone87/ccr-colorado/resolve/main/Fire_1.jpeg", use_column_width=True)
with col2:
    st.image("https://huggingface.co/spaces/tstone87/ccr-colorado/resolve/main/Fire_3.png", use_column_width=True)
st.markdown(""" | |
Fires in Colorado present a serious challenge, threatening urban communities, highways, and even remote areas. Early detection is critical. WildfireWatch leverages YOLOv8 for real‐time fire and smoke detection in images and videos. | |
""") | |
st.markdown("---") | |
st.header("Fire Detection:") | |
# Create two columns for displaying the upload and the results.
col1, col2 = st.columns(2)

if source_file:
    if source_file.type.split('/')[0] == 'image':
        uploaded_image = PIL.Image.open(source_file)
        st.image(uploaded_image, caption="Uploaded Image", use_column_width=True)
    else:
        # Write the upload to a temp file so OpenCV can open it by path.
        tfile = tempfile.NamedTemporaryFile(delete=False)
        tfile.write(source_file.read())
        vidcap = cv2.VideoCapture(tfile.name)
else:
    st.info("Please upload an image or video file to begin.")
# Load YOLO model
try:
    model = YOLO(model_path)
except Exception as ex:
    st.error(f"Unable to load model. Check the specified path: {model_path}")
    st.error(ex)
    st.stop()  # nothing below can run without a model
# We'll use a session_state variable to remember the current slider value.
if "frame_slider" not in st.session_state:
    st.session_state.frame_slider = 0

# A container to display the currently viewed frame.
viewer_slot = st.empty()
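# NOTE: each viewer_slot.image() call replaces the previous frame, so the
# app streams the latest annotated frame in place instead of stacking images.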
# When the user clicks the detect button...
if st.sidebar.button("Let's Detect Wildfire"):
    if not source_file:
        st.warning("No file uploaded!")
    elif source_file.type.split('/')[0] == 'image':
        # Process image input.
        res = model.predict(uploaded_image, conf=confidence)
        boxes = res[0].boxes
        res_plotted = res[0].plot()[:, :, ::-1]  # BGR -> RGB for display
        with col2:
            st.image(res_plotted, caption='Detected Image', use_column_width=True)
            with st.expander("Detection Results"):
                for box in boxes:
                    st.write(box.xywh)
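                # NOTE: box.xywh is a tensor of [x_center, y_center, width,
                # height] in pixels; box.conf and box.cls hold the confidence
                # score and class index for each detection.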
    else:
        # Process video input.
        processed_frames = []
        frame_count = 0

        # Get video properties.
        orig_fps = vidcap.get(cv2.CAP_PROP_FPS)
        total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
        width = int(vidcap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(vidcap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        # Map the selected option to a sampling interval (in frames) and an
        # output fps. Anything not in the map means "Original FPS".
        seconds_per_sample = {
            "1 fps": 1,
            "1 frame per 5 seconds": 5,
            "1 frame per 10 seconds": 10,
            "1 frame per 15 seconds": 15,
        }
        if video_option in seconds_per_sample:
            secs = seconds_per_sample[video_option]
            sample_interval = max(1, int(orig_fps * secs)) if orig_fps > 0 else secs
            output_fps = 1
        else:  # "Original FPS"
            sample_interval = 1
            output_fps = orig_fps if orig_fps > 0 else 30  # assume 30 fps if metadata is missing
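        # Example: for a 30 fps source, "1 frame per 5 seconds" gives
        # sample_interval = int(30 * 5) = 150, so frames 0, 150, 300, ... are
        # kept and played back at 1 fps in the shortened output.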
        success, image = vidcap.read()
        while success:
            if frame_count % sample_interval == 0:
                # Run detection on the current frame.
                res = model.predict(image, conf=confidence)
                res_plotted = res[0].plot()[:, :, ::-1].copy()  # BGR -> RGB; copy() makes it contiguous
                processed_frames.append(res_plotted)

                # Update progress.
                if total_frames > 0:
                    progress_pct = int((frame_count / total_frames) * 100)
                    progress_text.text(f"Processing frame {frame_count} / {total_frames} ({progress_pct}%)")
                    progress_bar.progress(min(100, progress_pct))
                else:
                    progress_text.text(f"Processing frame {frame_count}")

                # Rebuild the frame-viewer slider now that a new frame exists.
                # Skip it until there are at least two frames: a one-frame
                # slider has nothing to scrub, and equal min/max bounds can
                # error in some Streamlit versions.
                if len(processed_frames) > 1:
                    # Clear the previous slider widget.
                    slider_container.empty()
                    # Clamp the remembered slider value to the new bounds.
                    curr_slider_val = min(st.session_state.get("frame_slider", len(processed_frames) - 1),
                                          len(processed_frames) - 1)
                    # A unique key per rebuild avoids Streamlit's
                    # duplicate-widget-key error when one script run creates
                    # the slider more than once.
                    slider_val = slider_container.slider(
                        "Frame Viewer",
                        min_value=0,
                        max_value=len(processed_frames) - 1,
                        value=curr_slider_val,
                        step=1,
                        key=f"frame_slider_{len(processed_frames)}"
                    )
                    st.session_state.frame_slider = slider_val
                else:
                    slider_val = 0
                # If the user is at the most recent frame, update the viewer.
                if slider_val == len(processed_frames) - 1:
                    viewer_slot.image(processed_frames[-1],
                                      caption=f"Frame {len(processed_frames) - 1}",
                                      use_column_width=True)
            frame_count += 1
            success, image = vidcap.read()

        # Finalize progress.
        progress_text.text("Video processing complete!")
        progress_bar.progress(100)
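        # NOTE: 'mp4v' (MPEG-4 Part 2) is used below because pip-installed
        # OpenCV builds usually ship without an H.264 ('avc1') encoder for
        # licensing reasons; most desktop players still handle mp4v files.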
        # Create and provide the downloadable shortened video.
        if processed_frames:
            temp_video_file = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4')
            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
            out = cv2.VideoWriter(temp_video_file.name, fourcc, output_fps, (width, height))
            for frame in processed_frames:
                # Frames were flipped to RGB for display; convert back to BGR,
                # the channel order cv2.VideoWriter expects.
                out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
            out.release()
            st.success("Shortened video created successfully!")
            with open(temp_video_file.name, 'rb') as video_file:
                st.download_button(
                    label="Download Shortened Video",
                    data=video_file.read(),
                    file_name="shortened_video.mp4",
                    mime="video/mp4"
                )
        else:
            st.error("No frames were processed from the video.")