ccr-colorado / app.py
import os
import tempfile
import cv2
import streamlit as st
import PIL
import requests
from ultralytics import YOLO
import time
import numpy as np
# Page config first
st.set_page_config(
    page_title="WildfireWatch: AI Detection",
    page_icon="🔥",
    layout="wide",
    initial_sidebar_state="expanded"
)
# Model path
model_path = 'https://huggingface.co/spaces/tstone87/ccr-colorado/resolve/main/best.pt'
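# Note: YOLO() is given a URL here; the fine-tuned weights are fetched from this Space's repo at load time.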
# Session state initialization
for key in ["processed_frames", "slider_value", "processed_video", "start_time"]:
    if key not in st.session_state:
        st.session_state[key] = [] if key == "processed_frames" else 0 if key == "slider_value" else None
# Sidebar
with st.sidebar:
    st.header("Upload & Settings")
    source_file = st.file_uploader("Upload image/video", type=["jpg", "jpeg", "png", "bmp", "webp", "mp4"])
    confidence = float(st.slider("Confidence Threshold", 25, 100, 40)) / 100
    fps_options = {
        "Original FPS": None,
        "3 FPS": 3,
        "1 FPS": 1,
        "1 frame/4s": 0.25,
        "1 frame/10s": 0.1,
        "1 frame/15s": 0.0667,
        "1 frame/30s": 0.0333
    }
    video_option = st.selectbox("Output Frame Rate", list(fps_options.keys()))
    process_button = st.button("Detect Wildfire")
    progress_bar = st.progress(0)
    progress_text = st.empty()
    download_slot = st.empty()
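# The selected output rate controls frame sampling during video processing:
# e.g. a 30 FPS source with "3 FPS" selected keeps roughly every 10th frame.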
# Main page
st.title("WildfireWatch: AI-Powered Detection")
col1, col2 = st.columns(2)
with col1:
    st.image("https://huggingface.co/spaces/tstone87/ccr-colorado/resolve/main/Fire_1.jpeg", use_column_width=True)
with col2:
    st.image("https://huggingface.co/spaces/tstone87/ccr-colorado/resolve/main/Fire_3.png", use_column_width=True)
st.markdown("""
Early wildfire detection using the YOLOv8 vision model. See the examples below or upload your own content!
""")
# Example videos
st.header("Example Results")
for example in [("T1.mp4", "T2.mpg"), ("LA1.mp4", "LA2.mp4")]:
    col1, col2 = st.columns(2)
    orig_url = f"https://huggingface.co/spaces/tstone87/ccr-colorado/resolve/main/{example[0]}"
    proc_url = f"https://huggingface.co/spaces/tstone87/ccr-colorado/resolve/main/{example[1]}"
    orig_data = requests.get(orig_url).content
    proc_data = requests.get(proc_url).content
    with col1:
        st.video(orig_data)
    with col2:
        st.video(proc_data)
st.header("Your Results")
result_cols = st.columns(2)
viewer_slot = st.empty()
# Load model
try:
    model = YOLO(model_path)
except Exception as ex:
    st.error(f"Model loading failed: {str(ex)}")
    model = None
# Processing
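# Images get a single prediction rendered next to the original; videos are sampled
# at the chosen rate, annotated frame by frame, then re-encoded for playback and download.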
if process_button and source_file and model is not None:
    st.session_state.processed_frames = []
    if source_file.type.split('/')[0] == 'image':
        image = PIL.Image.open(source_file)
        res = model.predict(image, conf=confidence)
        result = res[0].plot()[:, :, ::-1]  # BGR -> RGB for st.image
        with result_cols[0]:
            st.image(image, caption="Original", use_column_width=True)
        with result_cols[1]:
            st.image(result, caption="Detected", use_column_width=True)
    else:
        # Video processing: buffer the upload to a temp file so OpenCV can read it
        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
            tmp.write(source_file.read())
        vidcap = cv2.VideoCapture(tmp.name)
        orig_fps = vidcap.get(cv2.CAP_PROP_FPS)
        total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
        width = int(vidcap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(vidcap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        output_fps = fps_options[video_option] if fps_options[video_option] else orig_fps
        # Number of source frames to skip between processed frames
        sample_interval = max(1, int(orig_fps / output_fps)) if output_fps else 1

        st.session_state.start_time = time.time()
        frame_count = 0
        processed_count = 0
        success, frame = vidcap.read()
        while success:
            if frame_count % sample_interval == 0:
                res = model.predict(frame, conf=confidence)
                # Keep the annotated frame in BGR order so cv2.VideoWriter encodes correct colors
                processed_frame = res[0].plot()
                st.session_state.processed_frames.append(processed_frame)
                processed_count += 1
            elapsed = time.time() - st.session_state.start_time
            progress = frame_count / total_frames if total_frames else 0
            if elapsed > 0 and processed_count > 0:
                time_per_frame = elapsed / processed_count
                frames_left = (total_frames - frame_count) / sample_interval
                eta = frames_left * time_per_frame
                eta_str = f"{int(eta // 60)}m {int(eta % 60)}s"
            else:
                eta_str = "Calculating..."
            progress_bar.progress(min(progress, 1.0))
            progress_text.text(f"Progress: {progress:.1%} | ETA: {eta_str}")
            frame_count += 1
            success, frame = vidcap.read()
        vidcap.release()
        os.unlink(tmp.name)

        if st.session_state.processed_frames:
            out_path = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False).name
            writer = cv2.VideoWriter(out_path, cv2.VideoWriter_fourcc(*'mp4v'), output_fps or orig_fps, (width, height))
            for frame in st.session_state.processed_frames:
                writer.write(frame)
            writer.release()
            with open(out_path, 'rb') as f:
                st.session_state.processed_video = f.read()
            os.unlink(out_path)

            progress_bar.progress(1.0)
            progress_text.text("Processing complete!")
            with result_cols[0]:
                st.video(source_file)
            with result_cols[1]:
                st.video(st.session_state.processed_video)
            download_slot.download_button(
                label="Download Processed Video",
                data=st.session_state.processed_video,
                file_name="processed_wildfire.mp4",
                mime="video/mp4"
            )
if not source_file:
    st.info("Please upload a file to begin.")