# Spaces:
# Running
# Running
# Import required libraries
import os
import tempfile
import time

import cv2
import PIL
import PIL.Image
import streamlit as st
from ultralytics import YOLO
# URL of the trained YOLO weights file (best.pt) used for fire/smoke detection.
model_path = 'https://huggingface.co/spaces/tstone87/ccr-colorado/resolve/main/best.pt'

# Page setup. Keep this before any other st.* call: Streamlit (at least in
# older releases) rejects set_page_config once other commands have run.
st.set_page_config(
    page_title="WildfireWatch",
    page_icon="🔥",
    layout="wide",
    initial_sidebar_state="expanded",
)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

# Playback rate of the compiled output video: one analysed frame per second.
_OUTPUT_FPS = 1


def _get_model():
    """Return this session's YOLO model, loading it from ``model_path`` on first use.

    Streamlit re-executes the whole script on every widget interaction, so the
    model is kept in ``st.session_state`` instead of being rebuilt each rerun.
    It is stored per session rather than shared across sessions because the
    ultralytics docs advise against using one model instance from several
    threads concurrently.
    """
    if "wildfire_model" not in st.session_state:
        st.session_state["wildfire_model"] = YOLO(model_path)
    return st.session_state["wildfire_model"]


def _frame_step(fps: float, interval_s: float) -> int:
    """Return how many frames to advance between analysed frames (always >= 1).

    ``interval_s`` is the desired gap in seconds between analysed frames;
    ``0`` (or less) means analyse every frame.
    """
    if interval_s <= 0:
        return 1
    return max(1, round(fps * interval_s))


def _show_detections(boxes):
    """List every detected box's ``xywh`` value inside a "Detection Results" expander."""
    with st.expander("Detection Results"):
        if len(boxes) == 0:
            st.write("No detections above the selected confidence.")
        for box in boxes:
            st.write(box.xywh)


def _detect_image(upload, model, conf, out_column):
    """Run the model on an uploaded image and show the annotated result in *out_column*."""
    # The preview may already have read from the upload stream; rewind first.
    upload.seek(0)
    result = model.predict(PIL.Image.open(upload), conf=conf)[0]
    with out_column:
        # plot() returns the annotated image in OpenCV (BGR) channel order.
        st.image(result.plot(), channels="BGR", caption='Detected Image', use_column_width=True)
        _show_detections(result.boxes)


def _annotate_frames(input_path, output_path, model, conf, interval_s, out_column) -> int:
    """Detect on sampled frames of the video at *input_path*; write them to *output_path*.

    The most recent annotated frame is shown in *out_column* and a progress
    line below the columns. Returns the number of analysed frames (0 when the
    video could not be opened or contains no readable frames).
    """
    vidcap = cv2.VideoCapture(input_path)
    writer = None
    processed = 0
    try:
        if not vidcap.isOpened():
            return 0
        total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = vidcap.get(cv2.CAP_PROP_FPS) or 30.0  # fall back when no FPS is reported
        step = _frame_step(fps, interval_s)
        # Reusable placeholders: each analysed frame replaces the previous one
        # instead of appending a new image/expander/progress line to the page.
        frame_slot = out_column.empty()
        status = st.empty()
        frame_index = 0
        while True:
            analyse = frame_index % step == 0
            # grab() advances past skipped frames without handing back the image.
            ok, frame = vidcap.read() if analyse else (vidcap.grab(), None)
            if not ok:
                break
            if analyse:
                result = model.predict(frame, conf=conf)[0]
                plotted = result.plot()  # BGR, same channel order as the OpenCV frame
                with frame_slot.container():
                    st.image(plotted, channels="BGR", caption=f'Detected Frame {frame_index}',
                             use_column_width=True)
                    _show_detections(result.boxes)
                if writer is None:
                    # Size the writer from the first annotated frame so every
                    # written frame matches the writer's frame size.
                    height, width = plotted.shape[:2]
                    writer = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'),
                                             _OUTPUT_FPS, (width, height))
                writer.write(plotted)
                processed += 1
                if total_frames > 0:
                    progress = (frame_index + 1) / total_frames * 100
                    status.write(f"Progress: {progress:.1f}% (Analyzed {processed} frames)")
            frame_index += 1
    finally:
        vidcap.release()
        if writer is not None:
            writer.release()
    return processed


def _detect_video(upload, model, conf, interval_s, out_column):
    """Analyse an uploaded video and offer the annotated frames as an mp4 download.

    The upload is copied to a temporary file because cv2.VideoCapture reads
    from a path; both temporary files are deleted before this returns.
    """
    # Close the copy before OpenCV opens it so every byte is flushed to disk.
    with tempfile.NamedTemporaryFile(delete=False, suffix='.mp4') as src:
        src.write(upload.getvalue())
    with tempfile.NamedTemporaryFile(delete=False, suffix='_detected.mp4') as dst:
        output_path = dst.name  # only reserves a unique name; VideoWriter fills it
    try:
        processed = _annotate_frames(src.name, output_path, model, conf, interval_s, out_column)
        video_bytes = None
        if processed:
            with open(output_path, 'rb') as f:
                video_bytes = f.read()
    finally:
        os.unlink(src.name)
        os.unlink(output_path)

    with out_column:
        if not processed:
            st.error("Could not read any frames from the uploaded video.")
            return
        st.download_button(
            label="Download Analyzed Video",
            data=video_bytes,
            file_name="analyzed_video.mp4",
            mime="video/mp4",
        )
        st.write(f"Video processing complete. Analyzed {processed} frames.")


# ---------------------------------------------------------------------------
# Sidebar: upload and analysis settings
# ---------------------------------------------------------------------------
with st.sidebar:
    st.header("IMAGE/VIDEO UPLOAD")
    source_file = st.file_uploader(
        "Choose an image or video...",
        type=("jpg", "jpeg", "png", 'bmp', 'webp', 'mp4'),
    )
    confidence = float(st.slider("Select Model Confidence", 25, 100, 40)) / 100
    # Every option is expressed in one unit -- seconds between analysed frames
    # (0 = every frame) -- so "N FPS" and "1 frame / Ns" choices cannot collide.
    sampling_options = {
        "Every Frame": 0,
        "1 FPS": 1,
        "2 FPS": 1 / 2,
        "5 FPS": 1 / 5,
        "1 frame / 5s": 5,
        "1 frame / 10s": 10,
        "1 frame / 15s": 15,
    }
    sampling_rate = st.selectbox("Analysis Rate", list(sampling_options.keys()), index=1)

# ---------------------------------------------------------------------------
# Main page heading, illustrative pictures and motivation
# ---------------------------------------------------------------------------
st.title("WildfireWatch: Detecting Wildfire using AI")
col1, col2 = st.columns(2)
with col1:
    st.image("https://huggingface.co/spaces/ankitkupadhyay/fire_and_smoke/resolve/main/Fire_1.jpeg",
             use_column_width=True)
with col2:
    st.image("https://huggingface.co/spaces/ankitkupadhyay/fire_and_smoke/resolve/main/Fire_2.jpeg",
             use_column_width=True)
st.markdown("""
Wildfires are a major environmental issue, causing substantial losses to ecosystems, human livelihoods, and potentially leading to loss of life. Early detection of wildfires can prevent these losses. Our application, WildfireWatch, uses state-of-the-art YOLOv8 model for real-time wildfire and smoke detection in images and videos.
""")
st.markdown("---")
st.header("Let's Detect Wildfire")

# ---------------------------------------------------------------------------
# Detection area: upload preview on the left, model output on the right
# ---------------------------------------------------------------------------
col1, col2 = st.columns(2)
with col1:
    if source_file and source_file.type.startswith('image/'):
        st.image(source_file, caption="Uploaded Image", use_column_width=True)

try:
    model = _get_model()
except Exception as ex:  # UI boundary: report any load failure and halt this run
    st.error(f"Unable to load model. Check the specified path: {model_path}")
    st.error(ex)
    st.stop()

if st.sidebar.button('Let\'s Detect Wildfire'):
    if not source_file:
        st.error("Please upload a file first!")
    elif source_file.type.startswith('image/'):
        _detect_image(source_file, model, confidence, col2)
    else:
        _detect_video(source_file, model, confidence, sampling_options[sampling_rate], col2)