# Fire Watch — Streamlit app for fire/smoke detection with a YOLOv8 model
# (deployed as a Hugging Face Space).
# Standard library
import base64
import os
import tempfile
import time

# Third-party
import cv2
import imageio_ffmpeg as ffmpeg
import numpy as np
import PIL
import PIL.Image  # `import PIL` alone does not guarantee the Image submodule
import requests
import streamlit as st
import streamlit.components.v1 as components
from ultralytics import YOLO
# Page config first | |
st.set_page_config( | |
page_title="Fire Watch: Fire and Smoke Detection with an AI Vision Model", | |
page_icon="🔥", | |
layout="wide", | |
initial_sidebar_state="expanded" | |
) | |
# Model path | |
model_path = 'https://huggingface.co/spaces/tstone87/ccr-colorado/resolve/main/best.pt' | |
# Session state initialization | |
for key in ["processed_frames", "slider_value", "processed_video", "start_time"]: | |
if key not in st.session_state: | |
st.session_state[key] = [] if key == "processed_frames" else 0 if key == "slider_value" else None | |
# Sidebar | |
with st.sidebar: | |
st.header("Upload & Settings") | |
source_file = st.file_uploader("Upload image/video", type=["jpg", "jpeg", "png", "bmp", "webp", "mp4"]) | |
confidence = float(st.slider("Confidence Threshold", 10, 100, 20)) / 100 | |
fps_options = { | |
"Original FPS": None, | |
"3 FPS": 3, | |
"1 FPS": 1, | |
"1 frame/4s": 0.25, | |
"1 frame/10s": 0.1, | |
"1 frame/15s": 0.0667, | |
"1 frame/30s": 0.0333 | |
} | |
video_option = st.selectbox("Output Frame Rate", list(fps_options.keys())) | |
process_button = st.button("Detect fire") | |
progress_bar = st.progress(0) | |
progress_text = st.empty() | |
download_slot = st.empty() | |
# Main page | |
st.title("Fire Watch: AI-Powered Fire and Smoke Detection") | |
col1, col2 = st.columns(2) | |
with col1: | |
st.image("https://huggingface.co/spaces/tstone87/ccr-colorado/resolve/main/Fire_1.jpeg", use_column_width=True) | |
with col2: | |
st.image("https://huggingface.co/spaces/tstone87/ccr-colorado/resolve/main/Fire_3.png", use_column_width=True) | |
st.markdown(""" | |
Early wildfire detection using YOLOv8 AI vision model. See examples below or upload your own content! | |
""") | |
# Function to create synchronized video pair HTML | |
def create_synced_video_pair(orig_url, proc_url, pair_id): | |
orig_bytes = requests.get(orig_url).content | |
proc_bytes = requests.get(proc_url).content | |
orig_b64 = base64.b64encode(orig_bytes).decode('utf-8') | |
proc_b64 = base64.b64encode(proc_bytes).decode('utf-8') | |
html = f""" | |
<div style="display: flex; justify-content: space-between;"> | |
<div style="width: 48%;"> | |
<h4>Original</h4> | |
<video id="orig_{pair_id}" width="100%" controls> | |
<source src="data:video/mp4;base64,{orig_b64}" type="video/mp4"> | |
</video> | |
</div> | |
<div style="width: 48%;"> | |
<h4>Processed</h4> | |
<video id="proc_{pair_id}" width="100%" controls> | |
<source src="data:video/mp4;base64,{proc_b64}" type="video/mp4"> | |
</video> | |
</div> | |
</div> | |
<script> | |
const origVideo_{pair_id} = document.getElementById('orig_{pair_id}'); | |
const procVideo_{pair_id} = document.getElementById('proc_{pair_id}'); | |
origVideo_{pair_id}.addEventListener('play', function() {{ | |
procVideo_{pair_id}.currentTime = origVideo_{pair_id}.currentTime; | |
procVideo_{pair_id}.play(); | |
}}); | |
procVideo_{pair_id}.addEventListener('play', function() {{ | |
origVideo_{pair_id}.currentTime = procVideo_{pair_id}.currentTime; | |
origVideo_{pair_id}.play(); | |
}}); | |
origVideo_{pair_id}.addEventListener('pause', function() {{ | |
procVideo_{pair_id}.pause(); | |
}}); | |
procVideo_{pair_id}.addEventListener('pause', function() {{ | |
origVideo_{pair_id}.pause(); | |
}}); | |
origVideo_{pair_id}.addEventListener('seeked', function() {{ | |
procVideo_{pair_id}.currentTime = origVideo_{pair_id}.currentTime; | |
}}); | |
procVideo_{pair_id}.addEventListener('seeked', function() {{ | |
origVideo_{pair_id}.currentTime = procVideo_{pair_id}.currentTime; | |
}}); | |
</script> | |
""" | |
return html | |
st.header("Your Results") | |
result_cols = st.columns(2) | |
viewer_slot = st.empty() | |
# Example videos with synchronization | |
st.header("Example Results") | |
examples = [ | |
("T Example", "T1.mp4", "T2.mp4"), | |
("LA Example", "LA1.mp4", "LA2.mp4") | |
] | |
for title, orig_file, proc_file in examples: | |
st.subheader(title) | |
orig_url = f"https://huggingface.co/spaces/tstone87/ccr-colorado/resolve/main/{orig_file}" | |
proc_url = f"https://huggingface.co/spaces/tstone87/ccr-colorado/resolve/main/{proc_file}" | |
pair_id = title.replace(" ", "").lower() | |
video_html = create_synced_video_pair(orig_url, proc_url, pair_id) | |
st.markdown(video_html, unsafe_allow_html=True) | |
# Load model | |
try: | |
model = YOLO(model_path) | |
except Exception as ex: | |
st.error(f"Model loading failed: {str(ex)}") | |
model = None | |
# Processing | |
if process_button and source_file and model: | |
st.session_state.processed_frames = [] | |
if source_file.type.split('/')[0] == 'image': | |
image = PIL.Image.open(source_file) | |
res = model.predict(image, conf=confidence) | |
result = res[0].plot()[:, :, ::-1] | |
with result_cols[0]: | |
st.image(image, caption="Original", use_column_width=True) | |
with result_cols[1]: | |
st.image(result, caption="Detected", use_column_width=True) | |
else: | |
# Video processing | |
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp: | |
tmp.write(source_file.read()) | |
vidcap = cv2.VideoCapture(tmp.name) | |
orig_fps = vidcap.get(cv2.CAP_PROP_FPS) | |
total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT)) | |
width = int(vidcap.get(cv2.CAP_PROP_FRAME_WIDTH)) | |
height = int(vidcap.get(cv2.CAP_PROP_FRAME_HEIGHT)) | |
output_fps = fps_options[video_option] if fps_options[video_option] else orig_fps | |
sample_interval = max(1, int(orig_fps / output_fps)) if output_fps else 1 | |
st.session_state.start_time = time.time() | |
frame_count = 0 | |
processed_count = 0 | |
success, frame = vidcap.read() | |
while success: | |
if frame_count % sample_interval == 0: | |
res = model.predict(frame, conf=confidence) | |
processed_frame = res[0].plot()[:, :, ::-1] | |
if not processed_frame.flags['C_CONTIGUOUS']: | |
processed_frame = np.ascontiguousarray(processed_frame) | |
st.session_state.processed_frames.append(processed_frame) | |
processed_count += 1 | |
elapsed = time.time() - st.session_state.start_time | |
progress = frame_count / total_frames | |
if elapsed > 0 and processed_count > 0: | |
time_per_frame = elapsed / processed_count | |
frames_left = (total_frames - frame_count) / sample_interval | |
eta = frames_left * time_per_frame | |
eta_str = f"{int(eta // 60)}m {int(eta % 60)}s" | |
else: | |
eta_str = "Calculating..." | |
progress_bar.progress(min(progress, 1.0)) | |
progress_text.text(f"Progress: {progress:.1%} | ETA: {eta_str}") | |
frame_count += 1 | |
success, frame = vidcap.read() | |
vidcap.release() | |
os.unlink(tmp.name) | |
if st.session_state.processed_frames: | |
out_path = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False).name | |
writer = ffmpeg.write_frames( | |
out_path, | |
(width, height), | |
fps=output_fps or orig_fps, | |
codec='libx264', | |
pix_fmt_in='bgr24', | |
pix_fmt_out='yuv420p' | |
) | |
writer.send(None) # Initialize writer | |
for frame in st.session_state.processed_frames: | |
writer.send(frame) | |
writer.close() | |
with open(out_path, 'rb') as f: | |
st.session_state.processed_video = f.read() | |
os.unlink(out_path) | |
progress_bar.progress(1.0) | |
progress_text.text("Processing complete!") | |
with result_cols[0]: | |
st.video(source_file) | |
with result_cols[1]: | |
st.video(st.session_state.processed_video) | |
download_slot.download_button( | |
label="Download Processed Video", | |
data=st.session_state.processed_video, | |
file_name="processed_wildfire.mp4", | |
mime="video/mp4" | |
) | |
if not source_file: | |
st.info("Please upload a file to begin.") |