import os
import tempfile

import cv2
import PIL.Image  # import the submodule explicitly; plain `import PIL` does not guarantee PIL.Image is available
import streamlit as st
from ultralytics import YOLO

# Ensure the model path points directly to the .pt file (not an HTML page).
model_path = 'https://huggingface.co/spaces/tstone87/ccr-colorado/resolve/main/best.pt'
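# If loading weights straight from a URL proves flaky, one option is to cache
# them to a local file first. A minimal sketch using only the standard
# library -- the local filename 'best.pt' is just an illustrative cache name:
#
#   import urllib.request
#   if not os.path.exists('best.pt'):
#       urllib.request.urlretrieve(model_path, 'best.pt')
#   model_path = 'best.pt'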
st.set_page_config(
    page_title="Fire Watch using AI vision models",
    page_icon="🔥",
    layout="wide",
    initial_sidebar_state="expanded"
)
# --- SIDEBAR ---
with st.sidebar:
    st.header("IMAGE/VIDEO UPLOAD")
    source_file = st.file_uploader("Choose an image or video...",
                                   type=("jpg", "jpeg", "png", "bmp", "webp", "mp4"))
    # Slider runs 25-100; dividing by 100 maps it to the 0.25-1.0 range YOLO expects.
    confidence = float(st.slider("Select Model Confidence", 25, 100, 40)) / 100
    video_option = st.selectbox(
        "Select Video Shortening Option",
        ["Original FPS", "1 fps", "1 frame per 5 seconds", "1 frame per 10 seconds", "1 frame per 15 seconds"]
    )
    # Placeholders reserved here so progress updates render in the sidebar during processing.
    progress_text = st.empty()
    progress_bar = st.progress(0)
# --- MAIN PAGE TITLE AND IMAGES ---
st.title("WildfireWatch: Detecting Wildfire using AI")
col1, col2 = st.columns(2)
with col1:
    st.image("https://huggingface.co/spaces/tstone87/ccr-colorado/resolve/main/Fire_1.jpeg", use_column_width=True)
with col2:
    st.image("https://huggingface.co/spaces/tstone87/ccr-colorado/resolve/main/Fire_3.png", use_column_width=True)
st.markdown("""
Fires in Colorado present a serious challenge, threatening urban communities, highways, and even remote areas.
Early detection is critical. WildfireWatch leverages YOLOv8 for real-time fire and smoke detection
in images and videos.
""")
st.markdown("---")
st.header("Fire Detection:")
# --- DISPLAY UPLOADED FILE ---
col1, col2 = st.columns(2)
if source_file:
    file_type = source_file.type.split('/')[0]
    if file_type == 'image':
        uploaded_image = PIL.Image.open(source_file)
        # Show the original in the left column; the detection result goes in col2 below.
        with col1:
            st.image(uploaded_image, caption="Uploaded Image", use_column_width=True)
    else:
        # Persist the uploaded video to a temporary file so OpenCV can open it by path.
        tfile = tempfile.NamedTemporaryFile(delete=False)
        tfile.write(source_file.read())
        vidcap = cv2.VideoCapture(tfile.name)
else:
    st.info("Please upload an image or video file to begin.")
# --- LOAD YOLO MODEL ---
try:
    model = YOLO(model_path)
except Exception as ex:
    st.error(f"Unable to load model. Check the specified path: {model_path}")
    st.error(ex)
    st.stop()  # Halt the script here; everything below assumes a usable model.
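# Because Streamlit reruns this entire script on every widget interaction, the
# eager load above repeats on each rerun. A minimal alternative sketch using
# st.cache_resource (available in Streamlit >= 1.18); the helper below is
# illustrative and unused -- the try/except above could instead call
# `model = load_model(model_path)`:
@st.cache_resource
def load_model(path):
    # Streamlit caches the returned object across reruns for a given path.
    return YOLO(path)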
# --- SESSION STATE FOR PROCESSED FRAMES ---
# Streamlit reruns the script on each interaction (e.g., moving the frame
# slider below), so results are kept in session_state to survive reruns.
if "processed_frames" not in st.session_state:
    st.session_state["processed_frames"] = []
# Keep the detection results for each frame as well, so they can be browsed later.
if "frame_detections" not in st.session_state:
    st.session_state["frame_detections"] = []
# --- WHEN USER CLICKS DETECT ---
if st.sidebar.button("Let's Detect Wildfire"):
    if not source_file:
        st.warning("No file uploaded!")
    elif file_type == 'image':
        # IMAGE DETECTION
        res = model.predict(uploaded_image, conf=confidence)
        boxes = res[0].boxes
        # plot() draws annotations in BGR (OpenCV order); [:, :, ::-1] flips to RGB for st.image.
        res_plotted = res[0].plot()[:, :, ::-1]
        with col2:
            st.image(res_plotted, caption='Detected Image', use_column_width=True)
            with st.expander("Detection Results"):
                for box in boxes:
                    # xywh holds (x_center, y_center, width, height) per detection.
                    st.write(box.xywh)
    else:
        # VIDEO DETECTION
        # Clear frames from any previous run
        st.session_state["processed_frames"] = []
        st.session_state["frame_detections"] = []
        processed_frames = st.session_state["processed_frames"]
        frame_detections = st.session_state["frame_detections"]

        frame_count = 0
        # Read the source timing and geometry needed for sampling and for the output writer.
        orig_fps = vidcap.get(cv2.CAP_PROP_FPS)
        total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
        width = int(vidcap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(vidcap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        # Determine the sampling interval (in frames) and the FPS of the output
        # video. Guard against containers that report a zero or unreadable FPS.
        if video_option == "Original FPS":
            sample_interval = 1
            output_fps = orig_fps if orig_fps > 0 else 30  # assume 30 if metadata is missing
        else:
            seconds_per_sample = {
                "1 fps": 1,
                "1 frame per 5 seconds": 5,
                "1 frame per 10 seconds": 10,
                "1 frame per 15 seconds": 15,
            }.get(video_option, 1)
            sample_interval = max(1, int(orig_fps * seconds_per_sample)) if orig_fps > 0 else seconds_per_sample
            output_fps = 1
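        # Worked example: a 30 fps clip with "1 frame per 5 seconds" gives
        # sample_interval = int(30 * 5) = 150, so a 60-second clip (1800 frames)
        # keeps 12 frames, played back at 1 fps in the shortened video.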
        success, image = vidcap.read()
        while success:
            if frame_count % sample_interval == 0:
                # Run detection on this sampled frame.
                res = model.predict(image, conf=confidence)
                # Keep plot()'s native BGR ordering: cv2.VideoWriter expects BGR,
                # and st.image can display it later via channels="BGR".
                res_plotted = res[0].plot()
                processed_frames.append(res_plotted)
                # Store the bounding boxes for each sampled frame as well.
                frame_detections.append(res[0].boxes)
                # Update progress
                if total_frames > 0:
                    progress_pct = int((frame_count / total_frames) * 100)
                    progress_text.text(f"Processing frame {frame_count} / {total_frames} ({progress_pct}%)")
                    progress_bar.progress(min(100, progress_pct))
                else:
                    progress_text.text(f"Processing frame {frame_count}")
            frame_count += 1
            success, image = vidcap.read()

        # Processing complete
        vidcap.release()
        progress_text.text("Video processing complete!")
        progress_bar.progress(100)
        # Create the shortened video from the processed frames.
        if processed_frames:
            temp_video_file = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4')
            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
            out = cv2.VideoWriter(temp_video_file.name, fourcc, output_fps, (width, height))
            for frame in processed_frames:
                # Frames are already BGR, which is what VideoWriter expects.
                out.write(frame)
            out.release()
            st.success("Shortened video created successfully!")
            with open(temp_video_file.name, 'rb') as video_file:
                st.download_button(
                    label="Download Shortened Video",
                    data=video_file.read(),
                    file_name="shortened_video.mp4",
                    mime="video/mp4"
                )
            # The bytes are already held by the download button, so the
            # temporary files can be cleaned up.
            os.remove(temp_video_file.name)
            os.remove(tfile.name)
        else:
            st.error("No frames were processed from the video.")
# --- DISPLAY THE PROCESSED FRAMES AFTER DETECTION ---
if st.session_state["processed_frames"]:
    st.markdown("### Browse Detected Frames")
    num_frames = len(st.session_state["processed_frames"])
    if num_frames == 1:
        # Only one frame was processed (frames are stored in BGR order).
        st.image(st.session_state["processed_frames"][0], caption="Frame 0",
                 channels="BGR", use_column_width=True)
        if st.session_state["frame_detections"]:
            with st.expander("Detection Results for Frame 0"):
                for box in st.session_state["frame_detections"][0]:
                    st.write(box.xywh)
    else:
        # Multiple frames: let the user scrub through them with a slider.
        frame_idx = st.slider(
            "Select Frame",
            min_value=0,
            max_value=num_frames - 1,
            value=0,
            step=1
        )
        st.image(st.session_state["processed_frames"][frame_idx],
                 caption=f"Frame {frame_idx}",
                 channels="BGR",
                 use_column_width=True)
        if st.session_state["frame_detections"]:
            with st.expander(f"Detection Results for Frame {frame_idx}"):
                for box in st.session_state["frame_detections"][frame_idx]:
                    st.write(box.xywh)
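# To run this app locally (assuming the usual Streamlit workflow; the package
# names below are the standard PyPI distributions, and 'app.py' stands in for
# whatever this file is saved as):
#
#   pip install streamlit ultralytics opencv-python-headless pillow
#   streamlit run app.py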