import os
import tempfile

import cv2
# `import PIL` alone does not load the `Image` submodule; import it explicitly
# so `PIL.Image.open` does not depend on another package having loaded it.
import PIL.Image
import streamlit as st
from ultralytics import YOLO

# Required libraries: streamlit, opencv-python-headless, ultralytics, Pillow

# Replace with your model URL or local file path
model_path = 'https://huggingface.co/spaces/tstone87/ccr-colorado/resolve/main/best.pt'

# Configure page layout for Hugging Face Spaces
st.set_page_config(
    page_title="Fire Watch using AI vision models",
    page_icon="🔥",
    layout="wide",
    initial_sidebar_state="expanded",
)

# Sidebar: upload file, select confidence and video shortening options.
with st.sidebar:
    st.header("IMAGE/VIDEO UPLOAD")
    source_file = st.file_uploader(
        "Choose an image or video...",
        type=("jpg", "jpeg", "png", "bmp", "webp", "mp4"),
    )
    # The slider works in whole percent; the model expects a 0-1 fraction.
    confidence = float(st.slider("Select Model Confidence", 25, 100, 40)) / 100
    video_option = st.selectbox(
        "Select Video Shortening Option",
        [
            "Original FPS",
            "1 fps",
            "1 frame per 5 seconds",
            "1 frame per 10 seconds",
            "1 frame per 15 seconds",
        ],
    )
    # Placeholders that the video-processing code fills in later.
    progress_text = st.empty()
    progress_bar = st.progress(0)
    # Container for our dynamic slider (frame viewer)
    slider_container = st.empty()

# Main page header and intro images
st.title("WildfireWatch: Detecting Wildfire using AI")
col1, col2 = st.columns(2)
with col1:
    st.image(
        "https://huggingface.co/spaces/tstone87/ccr-colorado/resolve/main/Fire_1.jpeg",
        use_column_width=True,
    )
with col2:
    st.image(
        "https://huggingface.co/spaces/tstone87/ccr-colorado/resolve/main/Fire_3.png",
        use_column_width=True,
    )

st.markdown("""
Fires in Colorado present a serious challenge, threatening urban communities, highways, and even remote areas.
Early detection is critical. WildfireWatch leverages YOLOv8 for real‐time fire and smoke detection in images and videos.
""")
st.markdown("---")
st.header("Fire Detection:")

# Create two columns for displaying the upload and results.
col1, col2 = st.columns(2)

# Streamlit re-executes this script from the top on every widget interaction,
# so anything that must outlive the "detect" click (annotated frames, the
# encoded video, the viewer position) is kept in st.session_state.

# Seconds of source video represented by each sampled frame.
# None means "keep every frame at the source frame rate".
_SECONDS_PER_SAMPLE = {
    "Original FPS": None,
    "1 fps": 1,
    "1 frame per 5 seconds": 5,
    "1 frame per 10 seconds": 10,
    "1 frame per 15 seconds": 15,
}
# Output frame rate used when the container does not report a usable FPS.
_FALLBACK_FPS = 30.0


def _sampling_plan(option, source_fps):
    """Return ``(sample_interval, output_fps)`` for a shortening option.

    ``sample_interval`` is always >= 1 so it is safe as a modulus, and
    ``output_fps`` is always > 0 so it is a valid VideoWriter frame rate.
    Unknown options behave like "Original FPS".
    """
    fps_known = source_fps > 0  # False for 0 and NaN
    seconds = _SECONDS_PER_SAMPLE.get(option)
    if seconds is None:
        return 1, (source_fps if fps_known else _FALLBACK_FPS)
    if not fps_known:
        return seconds, 1
    # int() truncates to 0 for sources slower than 1/seconds fps.
    return max(1, int(source_fps * seconds)), 1


def _remove_quietly(path):
    """Delete a temporary file; cleanup is best-effort."""
    try:
        os.remove(path)
    except OSError:
        pass


def _annotate_video(video_path, detector, conf, option):
    """Run detection on sampled frames of the video at ``video_path``.

    Updates the sidebar progress widgets and the live viewer as it goes.
    Returns ``(frames, output_fps)`` where ``frames`` is a list of annotated
    frames in BGR channel order, as produced by ``Results.plot()``.
    """
    frames = []
    vidcap = cv2.VideoCapture(video_path)
    try:
        orig_fps = vidcap.get(cv2.CAP_PROP_FPS)
        total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
        sample_interval, output_fps = _sampling_plan(option, orig_fps)

        frame_count = 0
        success, image = vidcap.read()
        while success:
            if frame_count % sample_interval == 0:
                # Run detection on the current frame.
                res = detector.predict(image, conf=conf)
                annotated = res[0].plot()
                frames.append(annotated)

                # Update progress.
                if total_frames > 0:
                    progress_pct = int((frame_count / total_frames) * 100)
                    progress_text.text(
                        f"Processing frame {frame_count} / {total_frames} ({progress_pct}%)"
                    )
                    progress_bar.progress(min(100, progress_pct))
                else:
                    progress_text.text(f"Processing frame {frame_count}")

                # Live preview of the most recent annotated frame. The
                # interactive slider is created once, after processing:
                # a keyed widget cannot be re-created within a single run.
                viewer_slot.image(
                    annotated,
                    channels="BGR",
                    caption=f"Frame {len(frames) - 1}",
                    use_column_width=True,
                )
            frame_count += 1
            success, image = vidcap.read()
    finally:
        vidcap.release()
    return frames, output_fps


def _encode_video(frames, fps):
    """Encode BGR ``frames`` as MP4 and return the bytes.

    Returns ``None`` if the video writer could not be opened.
    """
    height, width = frames[0].shape[:2]
    temp_video_file = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4')
    # Close our handle first so OpenCV can open the path itself.
    temp_video_file.close()
    try:
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(temp_video_file.name, fourcc, fps, (width, height))
        try:
            if not out.isOpened():
                return None
            for frame in frames:
                out.write(frame)
        finally:
            out.release()
        with open(temp_video_file.name, 'rb') as video_file:
            return video_file.read()
    finally:
        _remove_quietly(temp_video_file.name)


is_image = bool(source_file) and source_file.type.split('/')[0] == 'image'
uploaded_image = None
if not source_file:
    st.info("Please upload an image or video file to begin.")
elif is_image:
    uploaded_image = PIL.Image.open(source_file)
    with col1:
        st.image(uploaded_image, caption="Uploaded Image", use_column_width=True)

# Load YOLO model. `model` stays None on failure so later code can check it
# instead of hitting a NameError.
model = None
try:
    model = YOLO(model_path)
except Exception as ex:
    st.error(f"Unable to load model. Check the specified path: {model_path}")
    st.error(ex)

# We'll use a session_state variable to remember the current slider value.
if "frame_slider" not in st.session_state:
    st.session_state.frame_slider = 0

# A container to display the currently viewed frame.
viewer_slot = st.empty()

# When the user clicks the detect button...
if st.sidebar.button("Let's Detect Wildfire"):
    if not source_file:
        st.warning("No file uploaded!")
    elif model is None:
        st.error("Detection is unavailable because the model failed to load.")
    elif is_image:
        # Process image input.
        res = model.predict(uploaded_image, conf=confidence)
        boxes = res[0].boxes
        res_plotted = res[0].plot()
        with col2:
            st.image(res_plotted, channels="BGR", caption='Detected Image',
                     use_column_width=True)
            with st.expander("Detection Results"):
                for box in boxes:
                    st.write(box.xywh)
    else:
        # Process video input. Drop any previous results first so a failed
        # run cannot leave stale frames on screen.
        st.session_state.pop("processed_frames", None)
        st.session_state.pop("shortened_video", None)

        # OpenCV reads from a path, so spill the upload to disk. The file is
        # closed (and therefore flushed) before VideoCapture opens it.
        tfile = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4')
        try:
            with tfile:
                tfile.write(source_file.getvalue())
            processed_frames, output_fps = _annotate_video(
                tfile.name, model, confidence, video_option
            )
        finally:
            _remove_quietly(tfile.name)

        # Finalize progress.
        progress_text.text("Video processing complete!")
        progress_bar.progress(100)

        if processed_frames:
            # NOTE: every annotated frame is held in memory for the viewer.
            st.session_state.processed_frames = processed_frames
            # Start the viewer on the most recent frame. Safe here because
            # the slider widget has not been created yet in this run.
            st.session_state.frame_slider = len(processed_frames) - 1
            video_bytes = _encode_video(processed_frames, output_fps)
            if video_bytes:
                st.session_state.shortened_video = video_bytes
                st.success("Shortened video created successfully!")
            else:
                st.error("Unable to encode the shortened video.")
        else:
            st.error("No frames were processed from the video.")

# Frame viewer and download. Rendered on every run from session state so the
# results survive the reruns triggered by moving the slider or downloading.
stored_frames = st.session_state.get("processed_frames")
if stored_frames and source_file and not is_image:
    last_index = len(stored_frames) - 1
    if last_index > 0:
        # Keep the remembered position within bounds before the widget exists.
        if st.session_state.frame_slider > last_index:
            st.session_state.frame_slider = last_index
        frame_index = slider_container.slider(
            "Frame Viewer",
            min_value=0,
            max_value=last_index,
            step=1,
            key="frame_slider",
        )
    else:
        # A slider needs min < max; with a single frame there is nothing to pick.
        frame_index = 0
    viewer_slot.image(
        stored_frames[frame_index],
        channels="BGR",
        caption=f"Frame {frame_index}",
        use_column_width=True,
    )

    shortened_video = st.session_state.get("shortened_video")
    if shortened_video:
        st.download_button(
            label="Download Shortened Video",
            data=shortened_video,
            file_name="shortened_video.mp4",
            mime="video/mp4",
        )