ccr-colorado / app.py
tstone87's picture
Update app.py
cb79d6c verified
raw
history blame
7.86 kB
import os
import tempfile
import cv2
import streamlit as st
import PIL
from ultralytics import YOLO
# Weights are fetched straight from the Space repository. The URL must be the
# "resolve" (direct download) link to the .pt file, not the HTML viewer page.
model_path = 'https://huggingface.co/spaces/tstone87/ccr-colorado/resolve/main/best.pt'

# Browser-tab title/icon and overall page layout for the app.
page_settings = {
    "page_title": "Fire Watch using AI vision models",
    "page_icon": "🔥",
    "layout": "wide",
    "initial_sidebar_state": "expanded",
}
st.set_page_config(**page_settings)
# --- SIDEBAR ---
# All input controls live in the sidebar. The names bound here
# (source_file, confidence, video_option, progress_text, progress_bar)
# are read by the later sections of the script.
st.sidebar.header("IMAGE/VIDEO UPLOAD")

accepted_extensions = ("jpg", "jpeg", "png", "bmp", "webp", "mp4")
source_file = st.sidebar.file_uploader(
    "Choose an image or video...",
    type=accepted_extensions,
)

# The slider works in whole percent (25-100, default 40); the model wants a
# fraction in [0.25, 1.0].
confidence_percent = st.sidebar.slider(
    "Select Model Confidence", min_value=25, max_value=100, value=40
)
confidence = float(confidence_percent) / 100

# How densely the uploaded video is sampled before detection is run.
shortening_choices = [
    "Original FPS",
    "1 fps",
    "1 frame per 5 seconds",
    "1 frame per 10 seconds",
    "1 frame per 15 seconds",
]
video_option = st.sidebar.selectbox("Select Video Shortening Option", shortening_choices)

# Placeholders that the video-processing loop updates while it runs.
progress_text = st.sidebar.empty()
progress_bar = st.sidebar.progress(0)
# --- MAIN PAGE TITLE AND IMAGES ---
st.title("WildfireWatch: Detecting Wildfire using AI")

# Two banner images shown side by side, one per column.
banner_urls = (
    "https://huggingface.co/spaces/tstone87/ccr-colorado/resolve/main/Fire_1.jpeg",
    "https://huggingface.co/spaces/tstone87/ccr-colorado/resolve/main/Fire_3.png",
)
for banner_column, banner_url in zip(st.columns(2), banner_urls):
    with banner_column:
        st.image(banner_url, use_column_width=True)

# Short introduction, then a rule separating it from the detection section.
intro_text = """
Fires in Colorado present a serious challenge, threatening urban communities, highways, and even remote areas.
Early detection is critical. WildfireWatch leverages YOLOv8 for real-time fire and smoke detection
in images and videos.
"""
st.markdown(intro_text)
st.markdown("---")
st.header("Fire Detection:")
# --- DISPLAY UPLOADED FILE ---
# col2 is where the detect handler later renders the annotated image.
col1, col2 = st.columns(2)
if source_file:
    # MIME type looks like "image/png" or "video/mp4"; only the major type matters.
    file_type = source_file.type.split('/')[0]
    if file_type == 'image':
        uploaded_image = PIL.Image.open(source_file)
        st.image(uploaded_image, caption="Uploaded Image", use_column_width=True)
    else:
        # OpenCV can only open videos by path, so spill the upload to disk.
        # The handle is closed (and therefore flushed) by the `with` block
        # BEFORE VideoCapture opens the path; otherwise OpenCV may see a
        # truncated file. The .mp4 suffix helps backends that sniff the
        # container from the extension. getvalue() returns the full payload
        # regardless of the current stream position.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as tfile:
            tfile.write(source_file.getvalue())
        vidcap = cv2.VideoCapture(tfile.name)
else:
    st.info("Please upload an image or video file to begin.")
# --- LOAD YOLO MODEL ---
# Broad catch is deliberate: this is the top-level boundary for anything the
# model download/deserialisation can raise, and the error is shown to the user.
try:
    model = YOLO(model_path)
except Exception as ex:
    st.error(f"Unable to load model. Check the specified path: {model_path}")
    st.error(ex)
    # Without a model nothing below can work; halt this script run instead of
    # failing later with a NameError on `model`.
    st.stop()
# --- SESSION STATE FOR PROCESSED FRAMES ---
# "processed_frames" holds the annotated frames and "frame_detections" the
# matching per-frame detection results. Both start as empty lists and are kept
# in session state so they survive script reruns.
for state_key in ("processed_frames", "frame_detections"):
    if state_key not in st.session_state:
        st.session_state[state_key] = []
# --- WHEN USER CLICKS DETECT ---
# Seconds of source video represented by one sampled frame, per sidebar option.
# Options not listed here (i.e. "Original FPS") keep every frame.
_SECONDS_PER_SAMPLE = {
    "1 fps": 1,
    "1 frame per 5 seconds": 5,
    "1 frame per 10 seconds": 10,
    "1 frame per 15 seconds": 15,
}
# Playback rate used when the container does not report a usable FPS.
_FALLBACK_FPS = 30.0


def _sampling_plan(option, source_fps):
    """Return ``(sample_interval, output_fps)`` for a shortening option.

    ``sample_interval`` is how many source frames to advance per processed
    frame and is always at least 1, so it is safe as a modulo divisor even
    for sources slower than 1 fps. ``output_fps`` is the playback rate of the
    generated video and is always positive.
    """
    fps_known = source_fps > 0
    seconds = _SECONDS_PER_SAMPLE.get(option)
    if seconds is None:
        # Keep every frame and play back at the source rate.
        return 1, (source_fps if fps_known else _FALLBACK_FPS)
    if not fps_known:
        # No FPS metadata: fall back to a fixed frame stride.
        return seconds, 1
    return max(1, int(source_fps * seconds)), 1


if st.sidebar.button("Let's Detect Wildfire"):
    if not source_file:
        st.warning("No file uploaded!")
    elif file_type == 'image':
        # IMAGE DETECTION
        res = model.predict(uploaded_image, conf=confidence)
        boxes = res[0].boxes
        # plot() yields BGR; reverse the channel axis to RGB for st.image.
        res_plotted = res[0].plot()[:, :, ::-1]
        with col2:
            st.image(res_plotted, caption='Detected Image', use_column_width=True)
            with st.expander("Detection Results"):
                for box in boxes:
                    st.write(box.xywh)
    else:
        # VIDEO DETECTION
        # Clear previous frames from session_state, then append to the fresh
        # lists so the results survive later reruns.
        st.session_state["processed_frames"] = []
        st.session_state["frame_detections"] = []
        processed_frames = st.session_state["processed_frames"]
        frame_detections = st.session_state["frame_detections"]

        orig_fps = vidcap.get(cv2.CAP_PROP_FPS)
        total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
        sample_interval, output_fps = _sampling_plan(video_option, orig_fps)

        frame_count = 0
        try:
            success, image = vidcap.read()
            while success:
                if frame_count % sample_interval == 0:
                    # Run detection; keep an RGB copy of the annotated frame
                    # for display plus the raw boxes for the results expander.
                    res = model.predict(image, conf=confidence)
                    processed_frames.append(res[0].plot()[:, :, ::-1])
                    frame_detections.append(res[0].boxes)
                    # Update progress
                    if total_frames > 0:
                        progress_pct = int((frame_count / total_frames) * 100)
                        progress_text.text(
                            f"Processing frame {frame_count} / {total_frames} ({progress_pct}%)"
                        )
                        # Frame-count metadata can be off; never exceed 100.
                        progress_bar.progress(min(100, progress_pct))
                    else:
                        progress_text.text(f"Processing frame {frame_count}")
                frame_count += 1
                success, image = vidcap.read()
        finally:
            # Release the decoder even if detection raised part-way through.
            vidcap.release()

        # Processing complete
        progress_text.text("Video processing complete!")
        progress_bar.progress(100)

        # Create shortened video from processed frames
        if processed_frames:
            # Size the writer from the frames actually produced rather than
            # from container metadata, which may be missing or wrong.
            frame_height, frame_width = processed_frames[0].shape[:2]

            # Reserve a path and close the handle before OpenCV writes to it.
            with tempfile.NamedTemporaryFile(delete=False, suffix='.mp4') as temp_video_file:
                output_path = temp_video_file.name

            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
            out = cv2.VideoWriter(output_path, fourcc, output_fps, (frame_width, frame_height))
            try:
                for frame in processed_frames:
                    # Stored frames are RGB (for st.image); VideoWriter
                    # expects BGR, so convert back or red/blue are swapped.
                    out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
            finally:
                out.release()

            st.success("Shortened video created successfully!")
            with open(output_path, 'rb') as video_file:
                video_bytes = video_file.read()
            # The bytes are in memory now; do not leave the temp file behind.
            os.remove(output_path)
            st.download_button(
                label="Download Shortened Video",
                data=video_bytes,
                file_name="shortened_video.mp4",
                mime="video/mp4"
            )
        else:
            st.error("No frames were processed from the video.")
# --- DISPLAY THE PROCESSED FRAMES AFTER DETECTION ---
# Reads from session state so the browser keeps working on reruns triggered by
# moving the slider.
stored_frames = st.session_state["processed_frames"]
if stored_frames:
    st.markdown("### Browse Detected Frames")
    last_index = len(stored_frames) - 1

    if last_index == 0:
        # Only one frame was processed: show it directly, no slider.
        # NOTE(review): presumably because a slider needs min < max - confirm.
        frame_idx = 0
    else:
        # Multiple frames: let the user pick which one to view.
        frame_idx = st.slider(
            "Select Frame",
            min_value=0,
            max_value=last_index,
            value=0,
            step=1
        )

    st.image(
        stored_frames[frame_idx],
        caption=f"Frame {frame_idx}",
        use_column_width=True,
    )

    # Bounding boxes recorded for the selected frame, if any were stored.
    stored_detections = st.session_state["frame_detections"]
    if stored_detections:
        with st.expander(f"Detection Results for Frame {frame_idx}"):
            for box in stored_detections[frame_idx]:
                st.write(box.xywh)