# Spaces: Sleeping
# (Hugging Face Spaces page-status text captured along with this file; not program code.)
import math
import os
import tempfile
import time

import cv2
import numpy as np
import PIL
import requests
import streamlit as st
from PIL import Image
from ultralytics import YOLO
# --- Page configuration and styling ------------------------------------------
# set_page_config has to run before any other Streamlit command in the script.
st.set_page_config(page_title="WildfireWatch", page_icon="🔥", layout="wide")

# Light page background, dark tab labels, a scrollable main container and a
# capped image height, injected as raw HTML/CSS.
_PAGE_CSS = """
<style>
.stApp {
background-color: #f5f5f5;
color: #1a1a1a;
}
h1 {
color: #1a1a1a;
}
.stTabs > div > button {
background-color: #e0e0e0;
color: #333333;
font-weight: bold;
}
.stTabs > div > button:hover {
background-color: #d0d0d0;
color: #333333;
}
.stTabs > div > button[aria-selected="true"] {
background-color: #ffffff;
color: #333333;
}
.main .block-container {
max-height: 100vh;
overflow-y: auto;
}
.stImage > img {
max-height: 50vh;
object-fit: contain;
}
</style>
"""
st.markdown(_PAGE_CSS, unsafe_allow_html=True)
# --- Model -------------------------------------------------------------------
model_path = 'https://huggingface.co/spaces/tstone87/ccr-colorado/resolve/main/best.pt'

# Streamlit re-executes this whole script on every widget interaction, so a
# module-level YOLO(...) call rebuilt the model on each rerun. Keep one instance
# per browser session instead. Streamlit runs one script execution at a time per
# session, so the instance is never used from two threads at once. Ultralytics
# recommends one model instance per thread, which rules out sharing a single
# process-wide model across sessions.
if "yolo_model" not in st.session_state:
    try:
        st.session_state.yolo_model = YOLO(model_path)
    except Exception as ex:
        st.error(f"Unable to load model. Check the specified path: {model_path}")
        st.error(ex)
        st.stop()
model = st.session_state.yolo_model
# --- Session state -------------------------------------------------------------
# Per-session flags that survive reruns. Each one is seeded only when absent,
# so values set by earlier runs are left alone.
_SESSION_DEFAULTS = {"monitoring": False, "current_webcam_url": None}
for _state_key, _state_default in _SESSION_DEFAULTS.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
# --- Header ------------------------------------------------------------------
_INTRO_MARKDOWN = """
Wildfires are a major environmental issue, causing substantial losses to ecosystems, human livelihoods, and potentially leading to loss of life. Early detection of wildfires can prevent these losses. Our application uses state-of-the-art YOLOv8 model for real-time wildfire and smoke detection.
"""
st.title("WildfireWatch: Detecting Wildfire using AI")
st.markdown(_INTRO_MARKDOWN)
st.markdown("---")

# --- Tabs --------------------------------------------------------------------
# tabs[0] hosts file upload, tabs[1] hosts webcam monitoring.
tabs = st.tabs(["Upload", "Webcam"])
# --- Tab 1: Upload -----------------------------------------------------------
def _scan_image(uploaded, conf, frame_slot, status_slot):
    """Detect fire/smoke in an uploaded still image and display the annotated result."""
    res = model.predict(Image.open(uploaded), conf=conf)
    # Result.plot() returns a BGR array; reverse the channel axis for st.image (RGB).
    frame_slot.image(res[0].plot()[:, :, ::-1], use_column_width=True)
    status_slot.write(f"Objects detected: {len(res[0].boxes)}")


def _annotate_video(src_path, out_path, conf, target_fps,
                    frame_slot, status_slot, progress_slot):
    """Write an annotated copy of the video at *src_path* to *out_path*.

    Only every N-th frame goes through the model, with N chosen so roughly
    *target_fps* frames per second are analysed (0 means every frame). The
    frames in between reuse the latest annotation, so the output keeps the
    source's frame count and timing.

    Returns ``(processed_count, frame_count)``. Returns ``None`` after reporting
    an error when the input cannot be opened or the output cannot be created.
    The capture and writer are released on every path, including exceptions.
    """
    vidcap = cv2.VideoCapture(src_path)
    out = None
    try:
        if not vidcap.isOpened():
            status_slot.error("Failed to open video file.")
            return None
        total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
        # Keep fractional rates such as 29.97. Truncating to int would make the
        # output play at the wrong speed. A value of 0 or NaN means the
        # container reports no rate.
        fps = vidcap.get(cv2.CAP_PROP_FPS)
        if not fps > 0:
            fps = 30.0
        frame_skip = 1 if target_fps == 0 else max(1, int(fps / target_fps))

        success, frame = vidcap.read()
        # Size the writer from the decoded frames, since those are what get
        # written. The container's reported width/height may disagree (e.g.
        # with rotation metadata) and mismatched frames may be dropped.
        if success:
            frame_size = (frame.shape[1], frame.shape[0])
        else:
            frame_size = (int(vidcap.get(cv2.CAP_PROP_FRAME_WIDTH)),
                          int(vidcap.get(cv2.CAP_PROP_FRAME_HEIGHT)))
        out = cv2.VideoWriter(out_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, frame_size)
        if not out.isOpened():
            status_slot.error("Failed to create output video.")
            return None

        frame_count = 0
        processed_count = 0
        annotated = None  # latest annotated frame, BGR like the input
        while success:
            # Frame 0 always satisfies this, so `annotated` is set before the first write.
            if frame_count % frame_skip == 0:
                res = model.predict(frame, conf=conf)
                annotated = res[0].plot()
                frame_slot.image(annotated[:, :, ::-1], use_column_width=True)
                status_slot.write(f"Frame {frame_count}: Objects detected: {len(res[0].boxes)}")
                processed_count += 1
            # Skipped frames neither re-send the unchanged image nor sleep;
            # they only extend the output with the latest annotation.
            out.write(annotated)
            if total_frames > 0:
                # CAP_PROP_FRAME_COUNT may be an estimate; clamp so progress never exceeds 100%.
                progress_percent = min(100.0, (frame_count + 1) / total_frames * 100)
                progress_slot.write(f"Progress: {progress_percent:.1f}% (Processed {processed_count} frames)")
            else:
                progress_slot.write(f"Progress: {frame_count} frames processed")
            success, frame = vidcap.read()
            frame_count += 1
        return processed_count, frame_count
    finally:
        vidcap.release()
        if out is not None:
            out.release()


def _scan_video(uploaded, conf, target_fps, frame_slot, status_slot,
                progress_slot, download_slot):
    """Annotate an uploaded video and offer the result as a download.

    Both temporary files are removed afterwards, whether processing succeeds or fails.
    """
    # OpenCV reads and writes by path, so round-trip through temp files. The
    # handles are closed immediately; only the paths are used afterwards.
    with tempfile.NamedTemporaryFile(delete=False, suffix='.mp4') as src:
        src.write(uploaded.getvalue())  # getvalue() ignores the current read position
    with tempfile.NamedTemporaryFile(delete=False, suffix='_detected.mp4') as dst:
        out_path = dst.name
    try:
        counts = _annotate_video(src.name, out_path, conf, target_fps,
                                 frame_slot, status_slot, progress_slot)
        if counts is None:
            return
        processed_count, frame_count = counts
        # Read the bytes now so the file can be deleted below.
        with open(out_path, 'rb') as f:
            video_bytes = f.read()
        download_slot.download_button(
            label="Download Analyzed Video",
            data=video_bytes,
            file_name="analyzed_video.mp4",
            mime="video/mp4",
        )
        status_slot.write(f"Video processing complete. Processed {processed_count} of {frame_count} frames.")
    finally:
        for path in (src.name, out_path):
            try:
                os.unlink(path)
            except OSError:
                pass  # best-effort cleanup of a temp file; nothing else to do


with tabs[0]:
    col1, col2 = st.columns(2)
    with col1:
        st.markdown("**Add Your File**")
        st.write("Upload an image or video to scan for fire or smoke.")
        # A non-empty (but hidden) label keeps the widget accessible;
        # Streamlit warns when the label is "".
        source_file = st.file_uploader(
            "Upload an image or video",
            type=["jpg", "jpeg", "png", "mp4"],
            label_visibility="collapsed",
        )
        confidence = st.slider("Detection Threshold", 0.25, 1.0, 0.4, key="upload_conf")
        # Label -> analysed frames per second (0 = every frame).
        sampling_options = {"Every Frame": 0, "1 FPS": 1, "2 FPS": 2, "5 FPS": 5}
        sampling_rate = st.selectbox("Analysis Rate", list(sampling_options), index=1, key="sampling_rate")
    with col2:
        frame_placeholder = st.empty()
        status_placeholder = st.empty()
        progress_placeholder = st.empty()
        download_placeholder = st.empty()
        if source_file:
            st.write(f"File size: {source_file.size / 1024 / 1024:.2f} MB")  # Diagnostic
            if st.button("Detect Wildfire", key="upload_detect"):
                file_type = (source_file.type or "").split('/')[0]
                if file_type == 'image':
                    try:
                        _scan_image(source_file, confidence, frame_placeholder, status_placeholder)
                    except Exception as e:
                        status_placeholder.error(f"Error processing image: {e}")
                elif file_type == 'video':
                    try:
                        _scan_video(
                            source_file,
                            confidence,
                            sampling_options[sampling_rate],
                            frame_placeholder,
                            status_placeholder,
                            progress_placeholder,
                            download_placeholder,
                        )
                    except Exception as e:
                        status_placeholder.error(f"Error processing video: {str(e)}")
# --- Tab 2: Webcam -----------------------------------------------------------
def _detect_frame(frame, conf):
    """Run the detector on a BGR frame; return (RGB annotated image, number of boxes)."""
    res = model.predict(frame, conf=conf)
    return res[0].plot()[:, :, ::-1], len(res[0].boxes)


def _monitor_stream(cap, conf, frame_slot, status_slot):
    """Annotate frames from an opened cv2.VideoCapture until it stops delivering them.

    Returns the number of frames read. OpenCV may open a single-image
    (snapshot) URL as a one-frame "video". A count of 0 or 1 lets the caller
    fall back to polling instead of reporting a dead stream. The "interrupted"
    error is only shown after more than one frame.
    """
    status_slot.write("Connected to video stream...")
    frames_read = 0
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break
        frames_read += 1
        image, n_boxes = _detect_frame(frame, conf)
        frame_slot.image(image, use_column_width=True)
        status_slot.write(f"Objects detected: {n_boxes}")
        time.sleep(0.1)
    if frames_read > 1:
        status_slot.error("Video stream interrupted.")
    return frames_read


def _fetch_snapshot(url):
    """Download and decode one image from *url*.

    Returns ``(bgr_frame, None)`` on success or ``(None, error_message)`` on
    failure. Network-level failures raise ``requests.RequestException``.
    """
    response = requests.get(url, timeout=5)
    if response.status_code != 200:
        return None, f"Fetch failed: HTTP {response.status_code}"
    frame = cv2.imdecode(np.frombuffer(response.content, dtype=np.uint8), cv2.IMREAD_COLOR)
    if frame is None:
        return None, "Image decoding failed."
    return frame, None


def _monitor_snapshots(url, conf, refresh_rate, frame_slot, status_slot, timer_slot):
    """Poll *url* every *refresh_rate* seconds and annotate each image.

    A failed fetch (network error, non-200 status, undecodable body) is
    reported and retried on the next cycle rather than ending monitoring.
    The loop runs until Streamlit stops this script run (any widget
    interaction, e.g. "Stop Monitoring", triggers a rerun) or an unexpected
    exception propagates to the caller.
    """
    status_slot.write("Monitoring image-based webcam...")
    while st.session_state.monitoring:
        deadline = time.time() + refresh_rate
        try:
            frame, error = _fetch_snapshot(url)
        except requests.RequestException as e:
            frame, error = None, f"Image fetch error: {e}"
        if frame is None:
            status_slot.error(error)
        else:
            image, n_boxes = _detect_frame(frame, conf)
            frame_slot.image(image, use_column_width=True)
            status_slot.write(f"Objects detected: {n_boxes}")
        # Count down to the next scan, sleeping exactly the time that is left.
        remaining = deadline - time.time()
        while remaining > 0:
            timer_slot.write(f"Next scan: {math.ceil(remaining)}s")
            time.sleep(min(1.0, remaining))
            remaining = deadline - time.time()


with tabs[1]:
    col1, col2 = st.columns([1, 1])
    with col1:
        st.markdown("**Webcam Feed**")
        st.write("Provide a webcam URL (image or video stream) to monitor for hazards.")
        webcam_url = st.text_input("Webcam URL", "http://<your_webcam_ip>/current.jpg", label_visibility="collapsed")
        confidence = st.slider("Detection Threshold", 0.25, 1.0, 0.4, key="webcam_conf")
        refresh_rate = st.slider("Refresh Rate (seconds)", 1, 60, 30, key="webcam_rate")
        start = st.button("Begin Monitoring", key="webcam_start")
        stop = st.button("Stop Monitoring", key="webcam_stop")
        if start:
            st.session_state.monitoring = True
            st.session_state.current_webcam_url = webcam_url
        # Editing the URL reruns the script, so this check is where a URL
        # change ends monitoring. `webcam_url` cannot change within a single run.
        if stop or (st.session_state.monitoring and webcam_url != st.session_state.current_webcam_url):
            st.session_state.monitoring = False
            st.session_state.current_webcam_url = None
    with col2:
        frame_placeholder = st.empty()
        status_placeholder = st.empty()
        timer_placeholder = st.empty()
        if st.session_state.monitoring and st.session_state.current_webcam_url:
            # NOTE(review): the URL is user-supplied and fetched server-side; on a
            # public deployment consider restricting allowed schemes/hosts.
            poll_snapshots = True
            cap = cv2.VideoCapture(webcam_url)
            try:
                if cap.isOpened():
                    frames_read = _monitor_stream(cap, confidence, frame_placeholder, status_placeholder)
                    # A source that yielded more than one frame was a real stream
                    # that has now ended; otherwise treat the URL as a snapshot endpoint.
                    poll_snapshots = frames_read <= 1
                    if not poll_snapshots:
                        st.session_state.monitoring = False
            except Exception as e:
                status_placeholder.error(f"Video error: {e}")
                st.session_state.monitoring = False
                poll_snapshots = False
            finally:
                cap.release()
            if poll_snapshots:
                try:
                    _monitor_snapshots(webcam_url, confidence, refresh_rate,
                                       frame_placeholder, status_placeholder, timer_placeholder)
                except Exception as e:
                    status_placeholder.error(f"Image fetch error: {e}")
                    st.session_state.monitoring = False
        if not st.session_state.monitoring:
            timer_placeholder.write("Monitoring stopped.")