# Spaces: Running
import base64
import math
import os
import tempfile
import time

import cv2
import imageio
import numpy as np
import requests
import streamlit as st
from huggingface_hub import hf_hub_download
from ultralytics import YOLO

# Page config must be first
_PAGE_SETTINGS = {
    "page_title": "Wildfire Detection Demo",
    "page_icon": "🔥",
    "layout": "wide",
    "initial_sidebar_state": "expanded",
}
st.set_page_config(**_PAGE_SETTINGS)

# Helper function to display videos
def show_video(video_bytes: bytes, title: str, loop=True):
    """Render a video inline as a base64-embedded HTML5 ``<video>`` element.

    Args:
        video_bytes: Raw bytes of the clip. A falsy value produces a warning
            instead of a player.
        title: Heading shown above the player; also used in the warning text.
        loop: When truthy, the ``loop`` attribute is added to the player.
    """
    if video_bytes:
        encoded = base64.b64encode(video_bytes).decode()
        loop_flag = "loop" if loop else ""
        # The markup lines are joined flush-left, with a leading and trailing
        # newline, exactly as the original triple-quoted template produced them.
        markup = "\n".join(
            [
                "",
                f"<h4>{title}</h4>",
                f'<video width="100%" controls autoplay muted {loop_flag}>',
                f'<source src="data:video/mp4;base64,{encoded}" type="video/mp4">',
                "Your browser does not support the video tag.",
                "</video>",
                "",
            ]
        )
        st.markdown(markup, unsafe_allow_html=True)
    else:
        st.warning(f"No {title} video available.")


# Initialize session state
# Each key is seeded only when absent, so values already stored in the
# session are left untouched.
_SESSION_DEFAULTS = {
    "processed_video": None,
    "processing_complete": False,
    "start_time": None,
    "progress": 0,
}
for state_key, default_value in _SESSION_DEFAULTS.items():
    if state_key not in st.session_state:
        st.session_state[state_key] = default_value

# Load model
def load_model():
    """Download the YOLO weights from the Hugging Face Hub and build the model.

    Returns:
        A ``YOLO`` instance, or ``None`` when the download or load fails (the
        failure is reported to the user through ``st.error``).
    """
    weights_source = {
        "repo_id": "tstone87/ccr-colorado",
        "filename": "best.pt",
        "repo_type": "model",
    }
    try:
        detector = YOLO(hf_hub_download(**weights_source))
    except Exception as e:
        st.error(f"Failed to load model: {str(e)}")
        detector = None
    return detector


# NOTE(review): this call sits at module level, so it presumably runs on every
# Streamlit rerun — consider caching the loaded model.
model = load_model()

# Sidebar
# NOTE(review): SOURCE indentation was lost; this assumes the progress widgets
# and the download slot belong to the sidebar together with the other
# controls — confirm against the deployed layout.
sidebar = st.sidebar
sidebar.header("Process Your Own Video")
uploaded_file = sidebar.file_uploader("Upload a video", type=["mp4"])
confidence = sidebar.slider(
    "Detection Confidence", min_value=0.25, max_value=1.0, value=0.4
)

# Menu label -> output frames per second (None keeps the source frame rate).
fps_options = {
    "Original FPS": None,
    "3 FPS": 3,
    "1 FPS": 1,
    "1 frame/4s": 0.25,
    "1 frame/10s": 0.1,
    "1 frame/15s": 0.0667,
    "1 frame/30s": 0.0333,
}
selected_fps = sidebar.selectbox("Output FPS", list(fps_options), index=0)
process_button = sidebar.button("Process Video")

# Placeholders that the processing step fills in later.
progress_bar = sidebar.progress(0)
progress_text = sidebar.empty()
download_slot = sidebar.empty()

# Main content
st.title("Wildfire Detection Demo")
st.markdown("Watch our example videos below or upload your own in the sidebar!")

# Example videos: display name -> (original clip, processed clip).
example_videos = {
    "T Example": ("T1.mp4", "T2.mpg"),
    "LA Example": ("LA1.mp4", "LA2.mp4")
}
_EXAMPLE_BASE_URL = "https://huggingface.co/spaces/tstone87/ccr-colorado/resolve/main"
_EXAMPLE_TIMEOUT_S = 30  # connect/read timeout so a stalled download cannot hang the page


def _fetch_example(file_name):
    """Download one example clip and return its raw bytes.

    Raises:
        requests.RequestException: on connection errors, timeouts, or a
            non-success HTTP status (so an error page is never embedded as if
            it were a video).
    """
    response = requests.get(
        f"{_EXAMPLE_BASE_URL}/{file_name}", timeout=_EXAMPLE_TIMEOUT_S
    )
    response.raise_for_status()
    return response.content


for example_name, (orig_file, proc_file) in example_videos.items():
    col1, col2 = st.columns(2)
    # Best-effort display: a failure for one example is reported and the page
    # continues with the next one.
    try:
        orig_data = _fetch_example(orig_file)
        proc_data = _fetch_example(proc_file)
        with col1:
            show_video(orig_data, f"{example_name} - Original", loop=True)
        with col2:
            show_video(proc_data, f"{example_name} - Processed", loop=True)
    except Exception as e:
        st.error(f"Failed to load {example_name}: {str(e)}")

# Video processing
def _estimate_total_frames(meta, fps):
    """Return a best-effort frame count used only for progress/ETA display.

    Prefers the reader's ``nframes`` when it is a finite positive number, then
    ``duration * fps`` when both are present, and finally falls back to 1000.
    """
    frame_count = meta.get("nframes")
    if (
        isinstance(frame_count, (int, float))
        and math.isfinite(frame_count)
        and frame_count > 0
    ):
        return int(frame_count)
    duration = meta.get("duration")
    if duration and fps:
        return max(1, int(duration * fps))
    return 1000  # Fallback for unknown length


def _format_eta(seconds):
    """Format a duration in seconds as ``"<minutes>m <seconds>s"``."""
    return f"{int(seconds // 60)}m {int(seconds % 60)}s"


def process_video(video_file, target_fps, confidence):
    """Run detection on an uploaded video and return the annotated MP4 bytes.

    Relies on the module-level ``model``, ``fps_options``, ``progress_bar`` and
    ``progress_text`` objects, and records ``start_time`` / ``progress`` in
    ``st.session_state`` while it runs.

    Args:
        video_file: File-like object whose ``read()`` returns the video bytes.
        target_fps: Key into ``fps_options`` selecting the output frame rate.
        confidence: Confidence threshold passed to ``model.predict``.

    Returns:
        The bytes of the annotated output video.
    """
    # Everything the cleanup step touches is pre-bound so that ``finally`` can
    # never fail with an unbound name and hide the real error.
    tmp_path = None
    out_path = None
    reader = None
    writer = None
    try:
        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
            tmp_path = tmp.name
            tmp.write(video_file.read())

        reader = imageio.get_reader(tmp_path)
        meta = reader.get_meta_data()
        original_fps = meta['fps']
        width, height = meta['size']
        total_frames = _estimate_total_frames(meta, original_fps)

        output_fps = fps_options[target_fps] or original_fps
        # Keep every Nth source frame so the sampling matches the output rate.
        frame_interval = max(1, int(original_fps / output_fps)) if output_fps else 1

        # Reserve the output path and close the handle right away so the video
        # writer is the only thing holding the file open.
        with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as out_file:
            out_path = out_file.name
        # NOTE(review): 'mp4v' output may not play in every browser — confirm.
        writer = cv2.VideoWriter(
            out_path, cv2.VideoWriter_fourcc(*'mp4v'), output_fps, (width, height)
        )

        st.session_state.start_time = time.time()
        processed_count = 0
        for i, frame in enumerate(reader):
            if i % frame_interval:
                continue

            # The reader yields RGB frames, while the model and the OpenCV
            # writer both work with BGR; convert once up front so the detector
            # sees correct colours and the plotted frame can be written as-is.
            frame_bgr = np.ascontiguousarray(np.asarray(frame)[:, :, ::-1])
            results = model.predict(frame_bgr, conf=confidence)
            writer.write(results[0].plot())
            processed_count += 1

            elapsed = time.time() - st.session_state.start_time
            st.session_state.progress = min((i + 1) / total_frames, 1.0)
            if elapsed > 0:
                # Clamp at zero: the frame total is only an estimate.
                frames_left = max(0, total_frames - i - 1)
                time_per_frame = elapsed / processed_count
                eta_str = _format_eta(frames_left * time_per_frame / frame_interval)
            else:
                eta_str = "Calculating..."
            progress_bar.progress(st.session_state.progress)
            progress_text.text(
                f"Progress: {st.session_state.progress:.1%} | ETA: {eta_str}"
            )

        # Finalise the container before reading the file back.
        writer.release()
        writer = None
        with open(out_path, 'rb') as f:
            return f.read()
    finally:
        if writer is not None:
            writer.release()
        if reader is not None:
            reader.close()
        for path in (tmp_path, out_path):
            if path and os.path.exists(path):
                os.unlink(path)


# Process uploaded video
# Runs only when the button was pressed, a file is present, and the model
# loaded successfully.
ready_to_process = process_button and uploaded_file and model
if ready_to_process:
    with st.spinner("Processing video..."):
        result_bytes = process_video(uploaded_file, selected_fps, confidence)
        st.session_state.processed_video = result_bytes
        st.session_state.processing_complete = True
        progress_bar.progress(1.0)
        progress_text.text("Processing complete!")

# Show processed video and download button
finished_video = st.session_state.processed_video
if st.session_state.processing_complete and finished_video:
    st.subheader("Your Processed Video")
    show_video(finished_video, "Processed Result", loop=False)
    download_args = {
        "label": "Download Processed Video",
        "data": finished_video,
        "file_name": "processed_wildfire.mp4",
        "mime": "video/mp4",
    }
    download_slot.download_button(**download_args)

if not model:
    st.error("Model loading failed. Please check the repository and model file availability.")