# ccr-colorado / app.py
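"""WildfireWatch: a Streamlit app that runs a YOLOv8 model to detect fire and smoke
in uploaded images and videos, or in a webcam feed (either a video stream or a
periodically fetched still image)."""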
import PIL
import cv2
import streamlit as st
from ultralytics import YOLO
import tempfile
import time
import requests
import numpy as np
import os
# Page Config
st.set_page_config(page_title="WildfireWatch", page_icon="🔥", layout="wide")
# CSS for layout stability and dark tab text
st.markdown(
"""
<style>
.stApp {
background-color: #f5f5f5;
color: #1a1a1a;
}
h1 {
color: #1a1a1a;
}
.stTabs > div > button {
background-color: #e0e0e0;
color: #333333;
font-weight: bold;
}
.stTabs > div > button:hover {
background-color: #d0d0d0;
color: #333333;
}
.stTabs > div > button[aria-selected="true"] {
background-color: #ffffff;
color: #333333;
}
.main .block-container {
max-height: 100vh;
overflow-y: auto;
}
.stImage > img {
max-height: 50vh;
object-fit: contain;
}
</style>
""",
unsafe_allow_html=True
)
# Load Model
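# Note: the weights are referenced by URL; ultralytics is expected to download them locally on first load.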
model_path = 'https://huggingface.co/spaces/tstone87/ccr-colorado/resolve/main/best.pt'
try:
model = YOLO(model_path)
except Exception as ex:
st.error(f"Unable to load model. Check the specified path: {model_path}")
st.error(ex)
st.stop()
# Initialize Session State
if 'monitoring' not in st.session_state:
st.session_state.monitoring = False
if 'current_webcam_url' not in st.session_state:
st.session_state.current_webcam_url = None
# Header
st.title("WildfireWatch: Detecting Wildfire using AI")
st.markdown("""
Wildfires are a major environmental issue, causing substantial losses to ecosystems and human livelihoods and potentially leading to loss of life. Early detection of wildfires can prevent these losses. Our application uses a state-of-the-art YOLOv8 model for real-time wildfire and smoke detection.
""")
st.markdown("---")
# Tabs
tabs = st.tabs(["Upload", "Webcam"])
# Tab 1: Upload
with tabs[0]:
col1, col2 = st.columns(2)
with col1:
st.markdown("**Add Your File**")
st.write("Upload an image or video to scan for fire or smoke.")
source_file = st.file_uploader("", type=["jpg", "jpeg", "png", "mp4"], label_visibility="collapsed")
confidence = st.slider("Detection Threshold", 0.25, 1.0, 0.4, key="upload_conf")
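# Map the analysis-rate label to a target sampling rate in frames per second (0 = analyze every frame).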
sampling_options = {"Every Frame": 0, "1 FPS": 1, "2 FPS": 2, "5 FPS": 5}
sampling_rate = st.selectbox("Analysis Rate", list(sampling_options.keys()), index=1, key="sampling_rate")
with col2:
frame_placeholder = st.empty()
status_placeholder = st.empty()
progress_placeholder = st.empty()
download_placeholder = st.empty()
if source_file:
st.write(f"File size: {source_file.size / 1024 / 1024:.2f} MB") # Diagnostic
if st.button("Detect Wildfire", key="upload_detect"):
file_type = source_file.type.split('/')[0]
if file_type == 'image':
uploaded_image = PIL.Image.open(source_file)
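# Run detection on the uploaded image; result.plot() returns an annotated BGR array, so reverse the channels to RGB for display.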
res = model.predict(uploaded_image, conf=confidence)
detected_image = res[0].plot()[:, :, ::-1]
frame_placeholder.image(detected_image, use_column_width=True)
status_placeholder.write(f"Objects detected: {len(res[0].boxes)}")
elif file_type == 'video':
try:
# Save input video
input_tfile = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4')
input_tfile.write(source_file.read())
input_tfile.close()
# Open video
vidcap = cv2.VideoCapture(input_tfile.name)
if not vidcap.isOpened():
status_placeholder.error("Failed to open video file.")
else:
total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
fps = int(vidcap.get(cv2.CAP_PROP_FPS)) or 30
frame_width = int(vidcap.get(cv2.CAP_PROP_FRAME_WIDTH))
frame_height = int(vidcap.get(cv2.CAP_PROP_FRAME_HEIGHT))
# Frame sampling
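# e.g. a 30 FPS video analyzed at "1 FPS" gives frame_skip = 30, so every 30th frame goes through the model.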
target_fps = sampling_options[sampling_rate]
frame_skip = 1 if target_fps == 0 else max(1, int(fps / target_fps))
# Output video
output_tfile = tempfile.NamedTemporaryFile(delete=False, suffix='_detected.mp4')
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
out = cv2.VideoWriter(output_tfile.name, fourcc, fps, (frame_width, frame_height))
success, frame = vidcap.read()
frame_count = 0
processed_count = 0
last_detected_frame = None
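# Main loop: run detection on sampled frames, re-show the last annotated frame otherwise,
# and write a frame to the output video on every iteration so it keeps the source frame rate.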
while success:
if frame_count % frame_skip == 0:
res = model.predict(frame, conf=confidence)
detected_frame = res[0].plot()[:, :, ::-1]
last_detected_frame = detected_frame
frame_placeholder.image(detected_frame, use_column_width=True)
status_placeholder.write(f"Frame {frame_count}: Objects detected: {len(res[0].boxes)}")
processed_count += 1
elif last_detected_frame is not None:
frame_placeholder.image(last_detected_frame, use_column_width=True)
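# Convert the RGB display frame back to BGR before writing with OpenCV.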
if last_detected_frame is not None:
out.write(last_detected_frame[:, :, ::-1])
# Progress
if total_frames > 0:
progress_percent = (frame_count + 1) / total_frames * 100
progress_placeholder.write(f"Progress: {progress_percent:.1f}% (Processed {processed_count} frames)")
else:
progress_placeholder.write(f"Progress: {frame_count} frames processed")
success, frame = vidcap.read()
frame_count += 1
time.sleep(0.05)
vidcap.release()
out.release()
os.unlink(input_tfile.name)
with open(output_tfile.name, 'rb') as f:
download_placeholder.download_button(
label="Download Analyzed Video",
data=f,
file_name="analyzed_video.mp4",
mime="video/mp4"
)
status_placeholder.write(f"Video processing complete. Processed {processed_count} of {frame_count} frames.")
except Exception as e:
status_placeholder.error(f"Error processing video: {str(e)}")
# Tab 2: Webcam
with tabs[1]:
col1, col2 = st.columns([1, 1])
with col1:
st.markdown("**Webcam Feed**")
st.write("Provide a webcam URL (image or video stream) to monitor for hazards.")
webcam_url = st.text_input("Webcam URL", "http://<your_webcam_ip>/current.jpg", label_visibility="collapsed")
confidence = st.slider("Detection Threshold", 0.25, 1.0, 0.4, key="webcam_conf")
refresh_rate = st.slider("Refresh Rate (seconds)", 1, 60, 30, key="webcam_rate")
start = st.button("Begin Monitoring", key="webcam_start")
stop = st.button("Stop Monitoring", key="webcam_stop")
if start:
st.session_state.monitoring = True
st.session_state.current_webcam_url = webcam_url
if stop or (st.session_state.monitoring and webcam_url != st.session_state.current_webcam_url):
st.session_state.monitoring = False
st.session_state.current_webcam_url = None
with col2:
frame_placeholder = st.empty()
status_placeholder = st.empty()
timer_placeholder = st.empty()
if st.session_state.monitoring and st.session_state.current_webcam_url:
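# Try the URL as a video stream first; if it cannot be opened, fall back to polling it as a still image.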
cap = cv2.VideoCapture(webcam_url)
is_video_stream = cap.isOpened()
if is_video_stream:
status_placeholder.write("Connected to video stream...")
while st.session_state.monitoring and cap.isOpened():
try:
ret, frame = cap.read()
if not ret:
status_placeholder.error("Video stream interrupted.")
break
if webcam_url != st.session_state.current_webcam_url:
status_placeholder.write("URL changed. Stopping video monitoring.")
break
res = model.predict(frame, conf=confidence)
detected_frame = res[0].plot()[:, :, ::-1]
frame_placeholder.image(detected_frame, use_column_width=True)
status_placeholder.write(f"Objects detected: {len(res[0].boxes)}")
time.sleep(0.1)
except Exception as e:
status_placeholder.error(f"Video error: {e}")
st.session_state.monitoring = False
break
cap.release()
else:
status_placeholder.write("Monitoring image-based webcam...")
while st.session_state.monitoring:
try:
start_time = time.time()
if webcam_url != st.session_state.current_webcam_url:
status_placeholder.write("URL changed. Stopping image monitoring.")
break
response = requests.get(webcam_url, timeout=5)
if response.status_code != 200:
status_placeholder.error(f"Fetch failed: HTTP {response.status_code}")
break
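# Decode the fetched bytes into a BGR image for the detector.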
image_array = np.asarray(bytearray(response.content), dtype=np.uint8)
frame = cv2.imdecode(image_array, cv2.IMREAD_COLOR)
if frame is None:
status_placeholder.error("Image decoding failed.")
break
res = model.predict(frame, conf=confidence)
detected_frame = res[0].plot()[:, :, ::-1]
frame_placeholder.image(detected_frame, use_column_width=True)
status_placeholder.write(f"Objects detected: {len(res[0].boxes)}")
elapsed = time.time() - start_time
remaining = max(0, refresh_rate - elapsed)
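# Count down to the next fetch, checking each second whether monitoring was stopped or the URL changed.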
for i in range(int(remaining), -1, -1):
if not st.session_state.monitoring or webcam_url != st.session_state.current_webcam_url:
status_placeholder.write("Monitoring interrupted or URL changed.")
break
timer_placeholder.write(f"Next scan: {i}s")
time.sleep(1)
except Exception as e:
status_placeholder.error(f"Image fetch error: {e}")
st.session_state.monitoring = False
break
if not st.session_state.monitoring:
timer_placeholder.write("Monitoring stopped.")