import base64
import html
import os
import tempfile
import time

import cv2
import imageio
import numpy as np
import requests
import streamlit as st
from huggingface_hub import hf_hub_download
from ultralytics import YOLO
# Streamlit requires the page configuration to be the first st.* call in the script.
PAGE_CONFIG = {
    "page_title": "Wildfire Detection Demo",
    "page_icon": "🔥",
    "layout": "wide",
    "initial_sidebar_state": "expanded",
}
st.set_page_config(**PAGE_CONFIG)
# Helpers to display videos
def build_video_html(video_bytes: bytes, title: str, loop: bool = True) -> str:
    """Return an HTML snippet embedding *video_bytes* as an inline MP4 player.

    The video is inlined as a base64 ``data:`` URI, so nothing has to be served
    separately. *title* is HTML-escaped because the snippet is rendered with
    ``unsafe_allow_html=True``.

    Args:
        video_bytes: Encoded video file contents.
        title: Heading shown above the player.
        loop: Whether the player should loop.

    Returns:
        The HTML markup as a string.
    """
    video_base64 = base64.b64encode(video_bytes).decode("ascii")
    loop_attr = " loop" if loop else ""
    # Most browsers refuse to autoplay unmuted media, so ``muted`` accompanies ``autoplay``.
    # Lines are kept unindented so Markdown does not treat them as a code block.
    return f"""
<div style="text-align: center;">
<h4>{html.escape(title)}</h4>
<video width="100%" controls autoplay muted playsinline{loop_attr}>
<source src="data:video/mp4;base64,{video_base64}" type="video/mp4">
</video>
</div>
"""


def show_video(video_bytes: bytes, title: str, loop=True):
    """Render *video_bytes* under *title*, or show a warning when there are no bytes.

    Args:
        video_bytes: Encoded video file contents (empty/None means "not available").
        title: Label used for the heading and for the warning message.
        loop: Whether the player should loop.
    """
    if not video_bytes:
        st.warning(f"No {title} video available.")
        return
    st.markdown(build_video_html(video_bytes, title, loop=loop), unsafe_allow_html=True)
# Seed session state once per browser session; existing values survive reruns.
_SESSION_DEFAULTS = {
    "processed_video": None,      # bytes of the last processed video
    "processing_complete": False,
    "start_time": None,           # wall-clock start of the current processing run
    "progress": 0,                # fraction in [0, 1]
}
for state_key, default_value in _SESSION_DEFAULTS.items():
    if state_key not in st.session_state:
        st.session_state[state_key] = default_value
# Load model
@st.cache_resource
def _load_model_cached(repo_id, filename):
    """Download *filename* from the Hugging Face model repo and build a YOLO model.

    Cached for the lifetime of the server process. Exceptions are deliberately
    allowed to propagate: Streamlit does not cache a call that raises, so a
    transient download failure is retried on the next rerun instead of being
    cached as a permanent ``None``.
    """
    model_path = hf_hub_download(repo_id=repo_id, filename=filename, repo_type="model")
    return YOLO(model_path)


def load_model():
    """Return the wildfire YOLO model, or ``None`` after showing an error if loading fails."""
    repo_id = "tstone87/ccr-colorado"
    filename = "best.pt"
    try:
        return _load_model_cached(repo_id, filename)
    except Exception as e:  # UI boundary: report any download/load failure to the user
        st.error(f"Failed to load model: {str(e)}")
        return None


model = load_model()
# Sidebar: upload widget, processing options, and placeholders that are
# filled in later in the script (progress display and download button).

# Output frame-rate choices: label -> frames per second (None keeps the source rate).
fps_options = {
    "Original FPS": None,
    "3 FPS": 3,
    "1 FPS": 1,
    "1 frame/4s": 0.25,
    "1 frame/10s": 0.1,
    "1 frame/15s": 0.0667,
    "1 frame/30s": 0.0333,
}

with st.sidebar:
    st.header("Process Your Own Video")
    uploaded_file = st.file_uploader("Upload a video", type=["mp4"])
    confidence = st.slider(
        "Detection Confidence", min_value=0.25, max_value=1.0, value=0.4
    )
    selected_fps = st.selectbox("Output FPS", options=list(fps_options), index=0)
    process_button = st.button("Process Video")
    # Placeholders updated in place while a video is being processed.
    progress_bar = st.progress(0)
    progress_text = st.empty()
    download_slot = st.empty()
# Main content
st.title("Wildfire Detection Demo")
st.markdown("Watch our example videos below or upload your own in the sidebar!")

# Example videos: name -> (original file, processed file) in the Space repository.
EXAMPLE_BASE_URL = "https://huggingface.co/spaces/tstone87/ccr-colorado/resolve/main"
example_videos = {
    # NOTE(review): "T2.mpg" differs from the other .mp4 files and is embedded with
    # an mp4 MIME type by show_video — confirm the file name in the Space.
    "T Example": ("T1.mp4", "T2.mpg"),
    "LA Example": ("LA1.mp4", "LA2.mp4")
}


@st.cache_data(show_spinner=False)
def fetch_example_video(url, timeout=30):
    """Download *url* and return its bytes, cached so reruns don't re-download.

    Raises:
        requests.RequestException: on network errors, timeouts, or non-2xx responses
        (so an HTML error page is never mistaken for video data).
    """
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    return response.content


for example_name, (orig_file, proc_file) in example_videos.items():
    col1, col2 = st.columns(2)
    try:
        orig_data = fetch_example_video(f"{EXAMPLE_BASE_URL}/{orig_file}")
        proc_data = fetch_example_video(f"{EXAMPLE_BASE_URL}/{proc_file}")
    except requests.RequestException as e:
        st.error(f"Failed to load {example_name}: {str(e)}")
        continue
    with col1:
        show_video(orig_data, f"{example_name} - Original", loop=True)
    with col2:
        show_video(proc_data, f"{example_name} - Processed", loop=True)
# Video processing
def _estimate_total_frames(meta, fps):
    """Return a best-effort frame count from imageio metadata (used only for progress/ETA)."""
    nframes = meta.get("nframes")
    if isinstance(nframes, (int, float)) and 0 < nframes < float("inf"):
        return int(nframes)
    # The reader may report an unknown (infinite) length; derive it from the duration.
    duration = meta.get("duration")
    if duration and fps:
        return max(1, int(round(duration * fps)))
    return 1000  # last-resort fallback for unknown length


def _format_eta(seconds):
    """Format *seconds* as 'Xm Ys', clamping negative estimates to zero."""
    whole = max(0, int(seconds))
    return f"{whole // 60}m {whole % 60}s"


def process_video(video_file, target_fps, confidence):
    """Run wildfire detection on an uploaded video and return the annotated video bytes.

    Every ``frame_interval``-th frame is annotated by the module-level ``model``.
    The interval is chosen so that the output, written at ``fps_options[target_fps]``
    (or the source rate for "Original FPS"), spans the same wall-clock duration as
    the input. Progress is reported through the sidebar ``progress_bar`` and
    ``progress_text`` placeholders.

    Args:
        video_file: File-like object with ``read()`` (e.g. a Streamlit upload).
        target_fps: A key of ``fps_options``.
        confidence: Minimum detection confidence passed to ``model.predict``.

    Returns:
        The encoded output video as bytes (empty if no frame was written).
    """
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
        tmp.write(video_file.read())
        tmp_path = tmp.name
    # Bound up front so the finally clause can always reference them safely.
    out_path = None
    reader = None
    writer = None
    try:
        reader = imageio.get_reader(tmp_path)
        meta = reader.get_meta_data()
        original_fps = meta["fps"]
        total_frames = _estimate_total_frames(meta, original_fps)

        requested_fps = fps_options[target_fps]
        # Sampling can only drop frames, so never request more fps than the source
        # has; otherwise the output would play faster than real time.
        output_fps = min(requested_fps, original_fps) if requested_fps else original_fps
        frame_interval = max(1, int(original_fps / output_fps)) if output_fps else 1

        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as out_tmp:
            out_path = out_tmp.name

        st.session_state.start_time = time.time()
        processed_count = 0
        for i, frame in enumerate(reader):
            if i % frame_interval == 0:
                # imageio yields RGB(A) frames; Ultralytics expects BGR numpy arrays.
                frame_bgr = np.ascontiguousarray(np.asarray(frame)[:, :, 2::-1])
                results = model.predict(frame_bgr, conf=confidence, verbose=False)
                annotated = results[0].plot()  # BGR, which cv2.VideoWriter expects
                if writer is None:
                    # Size the writer from a real frame: cv2 silently drops frames
                    # whose size differs from the one it was opened with.
                    height, width = annotated.shape[:2]
                    # NOTE(review): 'mp4v' (MPEG-4 Part 2) is not playable by most
                    # browsers' <video> element; consider an H.264 encoder if in-page
                    # playback of the result matters.
                    writer = cv2.VideoWriter(
                        out_path,
                        cv2.VideoWriter_fourcc(*"mp4v"),
                        output_fps or original_fps,
                        (width, height),
                    )
                    if not writer.isOpened():
                        raise RuntimeError(f"Could not open video writer for {out_path}")
                writer.write(annotated)
                processed_count += 1

            progress = min((i + 1) / total_frames, 1.0)
            st.session_state.progress = progress
            elapsed = time.time() - st.session_state.start_time
            if elapsed > 0 and processed_count:
                frames_left = total_frames - i - 1
                eta = frames_left * (elapsed / processed_count) / frame_interval
                eta_str = _format_eta(eta)
            else:
                eta_str = "Calculating..."
            progress_bar.progress(progress)
            progress_text.text(f"Progress: {progress:.1%} | ETA: {eta_str}")

        # Finalize the container before reading it back.
        if writer is not None:
            writer.release()
            writer = None
        with open(out_path, "rb") as f:
            return f.read()
    finally:
        if writer is not None:
            writer.release()
        if reader is not None:
            reader.close()
        for path in (tmp_path, out_path):
            if path and os.path.exists(path):
                os.unlink(path)
# Process uploaded video when the sidebar button is clicked.
if process_button:
    if uploaded_file is None:
        st.warning("Please upload an MP4 video first.")
    elif model is not None:
        # Clear any previous result so a failed run doesn't display stale output.
        st.session_state.processing_complete = False
        st.session_state.processed_video = None
        try:
            with st.spinner("Processing video..."):
                st.session_state.processed_video = process_video(
                    uploaded_file, selected_fps, confidence
                )
        except Exception as e:  # UI boundary: show the failure instead of a traceback
            progress_text.text("Processing failed.")
            st.error(f"Failed to process video: {e}")
        else:
            st.session_state.processing_complete = True
            progress_bar.progress(1.0)
            progress_text.text("Processing complete!")
# Once a run has finished with output, show the result and offer it for download.
result_bytes = st.session_state.processed_video
if st.session_state.processing_complete and result_bytes:
    st.subheader("Your Processed Video")
    show_video(result_bytes, "Processed Result", loop=False)
    download_slot.download_button(
        "Download Processed Video",
        result_bytes,
        file_name="processed_wildfire.mp4",
        mime="video/mp4",
    )
# load_model() returns None when the download or load fails; explain why nothing works.
if model is None:
    st.error("Model loading failed. Please check the repository and model file availability.")