Spaces:
Sleeping
Sleeping
| import PIL | |
| import cv2 | |
| import streamlit as st | |
| from ultralytics import YOLO | |
| import tempfile | |
| import time | |
| import requests | |
| import numpy as np | |
| import os | |
# Page configuration: browser-tab title and icon, full-width layout.
# Issued before any other Streamlit command in this script.
st.set_page_config(
    page_title="WildfireWatch",
    page_icon="🔥",
    layout="wide",
)
# CSS for layout stability and dark tab text.
# Descendant selectors are used instead of child chains such as
# ``.stTabs > div > button``: how deeply Streamlit nests tab buttons and <img>
# tags is an internal detail. ``[role="tab"]`` keeps the tab styling off the
# ordinary buttons (e.g. "Detect Wildfire") rendered inside the tab panels.
_APP_CSS = """
<style>
.stApp {
    background-color: #f5f5f5;
    color: #1a1a1a;
}
h1 {
    color: #1a1a1a;
}
.stTabs button[role="tab"] {
    background-color: #e0e0e0;
    color: #333333;
    font-weight: bold;
}
.stTabs button[role="tab"]:hover {
    background-color: #d0d0d0;
    color: #333333;
}
.stTabs button[role="tab"][aria-selected="true"] {
    background-color: #ffffff;
    color: #333333;
}
.main .block-container {
    max-height: 100vh;
    overflow-y: auto;
}
.stImage img {
    max-height: 50vh;
    object-fit: contain;
}
</style>
"""
st.markdown(_APP_CSS, unsafe_allow_html=True)
# Load Model
model_path = 'https://huggingface.co/spaces/tstone87/ccr-colorado/resolve/main/best.pt'
# Streamlit re-executes this whole script on every widget interaction, so
# constructing YOLO(...) here unconditionally rebuilt the network on every
# click or slider move. The model is now built once per browser session and
# reused from session state. A per-session (rather than process-wide) instance
# is kept deliberately: Ultralytics does not guarantee that one model object
# can safely serve concurrent predict() calls from several user sessions.
if "yolo_model" not in st.session_state:
    try:
        st.session_state.yolo_model = YOLO(model_path)
    except Exception as ex:
        st.error(f"Unable to load model. Check the specified path: {model_path}")
        st.exception(ex)  # renders the exception (type, message, traceback)
        st.stop()
model = st.session_state.yolo_model
# Session-state defaults. They are written only when a key is absent, so
# later reruns keep whatever the webcam controls below have stored.
_SESSION_DEFAULTS = {"monitoring": False, "current_webcam_url": None}
for _state_key, _default in _SESSION_DEFAULTS.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _default
# Header: title, short description of the app, and a horizontal rule.
st.title("WildfireWatch: Detecting Wildfire using AI")
st.markdown("""
Wildfires are a major environmental issue, causing substantial losses to ecosystems and human livelihoods, and potentially leading to loss of life. Early detection of wildfires can help prevent these losses. Our application uses a state-of-the-art YOLOv8 model for real-time wildfire and smoke detection.
""")
st.markdown("---")
# Tabs: tabs[0] hosts the file-upload workflow, tabs[1] the webcam monitor.
TAB_LABELS = ["Upload", "Webcam"]
tabs = st.tabs(TAB_LABELS)
# Tab 1: Upload (Simplified with diagnostics)


def _detect_image(uploaded, conf, frame_slot, status_slot):
    """Detect fire/smoke in an uploaded still image and display the result.

    Args:
        uploaded: Streamlit ``UploadedFile`` containing the image bytes.
        conf: Minimum confidence passed to ``model.predict``.
        frame_slot: ``st.empty()`` placeholder for the annotated image.
        status_slot: ``st.empty()`` placeholder for status/error text.
    """
    # Import the submodule explicitly: a bare ``import PIL`` does not load
    # ``PIL.Image``; ``PIL.Image.open`` only works if some other import has
    # already loaded it.
    from PIL import Image

    try:
        res = model.predict(Image.open(uploaded), conf=conf)
    except Exception as e:
        status_slot.error(f"Error processing image: {str(e)}")
        return
    # plot() returns the annotated image in OpenCV (BGR) order; reverse the
    # channel axis because st.image expects RGB.
    frame_slot.image(res[0].plot()[:, :, ::-1], use_column_width=True)
    status_slot.write(f"Objects detected: {len(res[0].boxes)}")


def _detect_video(uploaded, conf, target_fps, frame_slot, status_slot,
                  progress_slot, download_slot):
    """Annotate an uploaded video and offer the result for download.

    Every ``frame_skip``-th frame is run through the model (a ``target_fps``
    of 0 means every frame). Skipped frames repeat the latest annotated frame
    in the output, so the output has one frame per input frame at the
    source's frame rate.

    Args:
        uploaded: Streamlit ``UploadedFile`` containing the video bytes.
        conf: Minimum confidence passed to ``model.predict``.
        target_fps: Analysis rate in frames per second; 0 analyzes every frame.
        frame_slot, status_slot, progress_slot, download_slot: ``st.empty()``
            placeholders for the preview, status text, progress text and the
            download button.

    Raises:
        Exception: Whatever OpenCV or the model raises; the caller reports it.
    """
    # The input copy and the annotated output both live in a private temporary
    # directory. It is deleted on exit, even if processing fails or Streamlit
    # interrupts this run part-way.
    with tempfile.TemporaryDirectory() as work_dir:
        input_path = os.path.join(work_dir, "input.mp4")
        output_path = os.path.join(work_dir, "analyzed_video.mp4")
        with open(input_path, "wb") as fh:
            fh.write(uploaded.getvalue())

        vidcap = cv2.VideoCapture(input_path)
        out = None
        try:
            if not vidcap.isOpened():
                status_slot.error("Failed to open video file.")
                return

            total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
            # Keep the fractional rate (e.g. 29.97) so the output's duration
            # matches the input. `not fps > 0` also catches a NaN reading.
            fps = vidcap.get(cv2.CAP_PROP_FPS)
            if not fps > 0:
                fps = 30.0
            frame_size = (
                int(vidcap.get(cv2.CAP_PROP_FRAME_WIDTH)),
                int(vidcap.get(cv2.CAP_PROP_FRAME_HEIGHT)),
            )
            frame_skip = 1 if target_fps == 0 else max(1, int(fps / target_fps))

            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
            out = cv2.VideoWriter(output_path, fourcc, fps, frame_size)
            if not out.isOpened():
                # Without this check a missing codec yields an empty download.
                status_slot.error("Failed to create output video.")
                return

            frame_count = 0
            processed_count = 0
            last_detected_frame = None  # BGR, as VideoWriter expects
            success, frame = vidcap.read()
            while success:
                # frame_count starts at 0, so the first frame is always
                # analyzed and last_detected_frame is set before any write.
                if frame_count % frame_skip == 0:
                    res = model.predict(frame, conf=conf)
                    last_detected_frame = res[0].plot()
                    frame_slot.image(last_detected_frame[:, :, ::-1], use_column_width=True)
                    status_slot.write(f"Frame {frame_count}: Objects detected: {len(res[0].boxes)}")
                    processed_count += 1
                out.write(last_detected_frame)

                if total_frames > 0:
                    # CAP_PROP_FRAME_COUNT may be only approximate; clamp at 100%.
                    progress_percent = min(100.0, (frame_count + 1) / total_frames * 100)
                    progress_slot.write(f"Progress: {progress_percent:.1f}% (Processed {processed_count} frames)")
                else:
                    progress_slot.write(f"Progress: {frame_count} frames processed")
                success, frame = vidcap.read()
                frame_count += 1

            out.release()  # finalize the MP4 container before reading it back
            with open(output_path, "rb") as fh:
                video_bytes = fh.read()
        finally:
            vidcap.release()
            if out is not None:
                out.release()

    download_slot.download_button(
        label="Download Analyzed Video",
        data=video_bytes,
        file_name="analyzed_video.mp4",
        mime="video/mp4",
    )
    status_slot.write(f"Video processing complete. Processed {processed_count} of {frame_count} frames.")


with tabs[0]:
    col1, col2 = st.columns(2)
    with col1:
        st.markdown("**Add Your File**")
        st.write("Upload an image or video to scan for fire or smoke.")
        # Streamlit warns on empty labels for accessibility reasons; give a
        # real label and hide it with label_visibility instead.
        source_file = st.file_uploader(
            "Upload an image or video",
            type=["jpg", "jpeg", "png", "mp4"],
            label_visibility="collapsed",
        )
        confidence = st.slider("Detection Threshold", 0.25, 1.0, 0.4, key="upload_conf")
        sampling_options = {"Every Frame": 0, "1 FPS": 1, "2 FPS": 2, "5 FPS": 5}
        sampling_rate = st.selectbox("Analysis Rate", list(sampling_options.keys()), index=1, key="sampling_rate")
    with col2:
        frame_placeholder = st.empty()
        status_placeholder = st.empty()
        progress_placeholder = st.empty()
        download_placeholder = st.empty()
        if source_file:
            st.write(f"File size: {source_file.size / 1024 / 1024:.2f} MB")  # Diagnostic
            if st.button("Detect Wildfire", key="upload_detect"):
                # Top-level part of the MIME type reported for the upload, e.g. "video".
                file_type = (source_file.type or "").split('/')[0]
                if file_type == 'image':
                    _detect_image(source_file, confidence, frame_placeholder, status_placeholder)
                elif file_type == 'video':
                    try:
                        _detect_video(
                            source_file,
                            confidence,
                            sampling_options[sampling_rate],
                            frame_placeholder,
                            status_placeholder,
                            progress_placeholder,
                            download_placeholder,
                        )
                    except Exception as e:
                        status_placeholder.error(f"Error processing video: {str(e)}")
                else:
                    status_placeholder.error(f"Unsupported file type: {source_file.type}")
# Tab 2: Webcam


def _show_webcam_detections(frame, conf, frame_slot, status_slot):
    """Run the detector on one frame and render the annotated result.

    Args:
        frame: Image array in OpenCV (BGR) order, as produced by
            ``cv2.VideoCapture.read`` or ``cv2.imdecode``.
        conf: Minimum confidence passed to ``model.predict``.
        frame_slot: Placeholder that receives the annotated image.
        status_slot: Placeholder that receives the detection count.
    """
    res = model.predict(frame, conf=conf)
    # plot() draws in BGR order; reverse the channel axis for st.image (RGB).
    frame_slot.image(res[0].plot()[:, :, ::-1], use_column_width=True)
    status_slot.write(f"Objects detected: {len(res[0].boxes)}")


def _monitor_video_stream(cap, conf, frame_slot, status_slot):
    """Analyze frames from an already-opened capture until it stops.

    When the stream stops delivering frames or an error occurs, the problem is
    shown and monitoring is switched off in session state. The caller owns
    ``cap`` and is responsible for releasing it.
    """
    status_slot.write("Connected to video stream...")
    while st.session_state.monitoring and cap.isOpened():
        try:
            ret, frame = cap.read()
            if not ret:
                status_slot.error("Video stream interrupted.")
                break
            _show_webcam_detections(frame, conf, frame_slot, status_slot)
            time.sleep(0.1)
        except Exception as e:
            status_slot.error(f"Video error: {e}")
            break
    # Clear the flag on every exit path, not only on exceptions, so the UI
    # shows "Monitoring stopped." and the next unrelated widget change does
    # not silently restart monitoring.
    st.session_state.monitoring = False


def _monitor_image_webcam(url, conf, refresh_rate, frame_slot, status_slot, timer_slot):
    """Poll a still-image webcam URL and analyze each snapshot.

    A new snapshot is fetched every ``refresh_rate`` seconds, measured from the
    start of the previous fetch. On an HTTP error, an undecodable image or any
    exception, the error is shown and monitoring is switched off in session
    state.
    """
    status_slot.write("Monitoring image-based webcam...")
    while st.session_state.monitoring:
        try:
            start_time = time.time()
            response = requests.get(url, timeout=5)
            if response.status_code != 200:
                status_slot.error(f"Fetch failed: HTTP {response.status_code}")
                break
            image_array = np.asarray(bytearray(response.content), dtype=np.uint8)
            frame = cv2.imdecode(image_array, cv2.IMREAD_COLOR)
            if frame is None:
                status_slot.error("Image decoding failed.")
                break
            _show_webcam_detections(frame, conf, frame_slot, status_slot)
            # Count down against a fixed deadline, so each cycle lasts
            # refresh_rate seconds rather than up to one second longer.
            deadline = start_time + refresh_rate
            remaining = deadline - time.time()
            while remaining > 0:
                timer_slot.write(f"Next scan: {remaining:.0f}s")
                time.sleep(min(1.0, remaining))
                remaining = deadline - time.time()
        except Exception as e:
            status_slot.error(f"Image fetch error: {e}")
            break
    st.session_state.monitoring = False


with tabs[1]:
    col1, col2 = st.columns([1, 1])
    with col1:
        st.markdown("**Webcam Feed**")
        st.write("Provide a webcam URL (image or video stream) to monitor for hazards.")
        webcam_url = st.text_input("Webcam URL", "http://<your_webcam_ip>/current.jpg", label_visibility="collapsed")
        confidence = st.slider("Detection Threshold", 0.25, 1.0, 0.4, key="webcam_conf")
        refresh_rate = st.slider("Refresh Rate (seconds)", 1, 60, 30, key="webcam_rate")
        start = st.button("Begin Monitoring", key="webcam_start")
        stop = st.button("Stop Monitoring", key="webcam_stop")
        if start:
            st.session_state.monitoring = True
            st.session_state.current_webcam_url = webcam_url
        # Editing the URL while monitoring cancels monitoring; the user must
        # press "Begin Monitoring" again for the new address. Because
        # webcam_url is fixed for this run and current_webcam_url is written
        # only here, the monitoring loops below need no per-iteration URL check.
        if stop or (st.session_state.monitoring and webcam_url != st.session_state.current_webcam_url):
            st.session_state.monitoring = False
            st.session_state.current_webcam_url = None
    with col2:
        frame_placeholder = st.empty()
        status_placeholder = st.empty()
        timer_placeholder = st.empty()
        if st.session_state.monitoring and st.session_state.current_webcam_url:
            # Try the URL as a video stream first. If OpenCV cannot open it,
            # treat it as a still image that is re-fetched periodically.
            # NOTE(review): OpenCV's FFmpeg backend may open some still-image
            # URLs as a one-frame "stream", which would end after one frame;
            # confirm against the target cameras.
            cap = cv2.VideoCapture(webcam_url)
            if cap.isOpened():
                try:
                    _monitor_video_stream(cap, confidence, frame_placeholder, status_placeholder)
                finally:
                    # Also runs if this script run is interrupted part-way.
                    cap.release()
            else:
                cap.release()
                _monitor_image_webcam(
                    webcam_url,
                    confidence,
                    refresh_rate,
                    frame_placeholder,
                    status_placeholder,
                    timer_placeholder,
                )
        if not st.session_state.monitoring:
            timer_placeholder.write("Monitoring stopped.")