# Import required libraries
import PIL.Image
import cv2
import streamlit as st
from ultralytics import YOLO
import tempfile
import time
import os
# Path (or URL) to the trained weight file
model_path = 'https://huggingface.co/spaces/tstone87/ccr-colorado/resolve/main/best.pt'
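# NOTE: with an http(s) URL like the one above, ultralytics is expected to download
# and cache the checkpoint on first load; a local path to best.pt should work here as well.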
# Setting page layout
st.set_page_config(
    page_title="WildfireWatch",
    page_icon="🔥",
    layout="wide",
    initial_sidebar_state="expanded"
)
# Creating sidebar
with st.sidebar:
    st.header("IMAGE/VIDEO UPLOAD")
    source_file = st.file_uploader(
        "Choose an image or video...", type=("jpg", "jpeg", "png", "bmp", "webp", "mp4"))
    confidence = float(st.slider("Select Model Confidence", 25, 100, 40)) / 100
    # Values are seconds between analyzed frames (0 = analyze every frame)
    sampling_options = {
        "Every Frame": 0,
        "1 FPS": 1.0,
        "2 FPS": 0.5,
        "5 FPS": 0.2,
        "1 frame / 5s": 5.0,
        "1 frame / 10s": 10.0,
        "1 frame / 15s": 15.0
    }
    sampling_rate = st.selectbox("Analysis Rate", list(sampling_options.keys()), index=1)
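# Example: for a 30 FPS source video, "1 FPS" analyzes every 30th frame,
# "1 frame / 10s" analyzes every 300th frame, and "Every Frame" analyzes all of them.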
# Creating main page heading
st.title("WildfireWatch: Detecting Wildfires with AI")
# Adding informative pictures and description about the motivation for the app
col1, col2 = st.columns(2)
with col1:
    st.image("https://huggingface.co/spaces/ankitkupadhyay/fire_and_smoke/resolve/main/Fire_1.jpeg", use_column_width=True)
with col2:
    st.image("https://huggingface.co/spaces/ankitkupadhyay/fire_and_smoke/resolve/main/Fire_2.jpeg", use_column_width=True)
st.markdown("""
Wildfires are a major environmental hazard, causing substantial damage to ecosystems and human livelihoods and, in the worst cases, loss of life. Early detection can limit these losses. Our application, WildfireWatch, uses a state-of-the-art YOLOv8 model for real-time wildfire and smoke detection in images and videos.
""")
st.markdown("---")
st.header("Let's Detect Wildfire")
# Creating two columns on the main page
col1, col2 = st.columns(2)
# Show the uploaded image (or open the uploaded video) in the first column
with col1:
    if source_file:
        if source_file.type.split('/')[0] == 'image':
            uploaded_image = PIL.Image.open(source_file)
            st.image(source_file, caption="Uploaded Image", use_column_width=True)
        else:
            # Persist the uploaded video to a temporary file so OpenCV can read it
            tfile = tempfile.NamedTemporaryFile(delete=False)
            tfile.write(source_file.read())
            vidcap = cv2.VideoCapture(tfile.name)

# Load the YOLO model
try:
    model = YOLO(model_path)
except Exception as ex:
    st.error(f"Unable to load model. Check the specified path: {model_path}")
    st.error(ex)
    st.stop()
if st.sidebar.button("Let's Detect Wildfire"):
    if not source_file:
        st.error("Please upload a file first!")
    elif source_file.type.split('/')[0] == 'image':
        res = model.predict(uploaded_image, conf=confidence)
        boxes = res[0].boxes
        res_plotted = res[0].plot()[:, :, ::-1]
        with col2:
            st.image(res_plotted, caption='Detected Image', use_column_width=True)
            try:
                with st.expander("Detection Results"):
                    for box in boxes:
                        st.write(box.xywh)
            except Exception as ex:
                st.write("No detection results available.")
    else:
        # Frame sampling setup
        total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = int(vidcap.get(cv2.CAP_PROP_FPS)) or 30  # fall back to 30 FPS if the source reports none
        frame_width = int(vidcap.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height = int(vidcap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        # Convert the selected analysis rate (seconds between analyzed frames) into a frame-skip interval
        seconds_between = sampling_options[sampling_rate]
        frame_skip = 1 if seconds_between == 0 else max(1, int(fps * seconds_between))

        # Output video setup
        output_tfile = tempfile.NamedTemporaryFile(delete=False, suffix='_detected.mp4')
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        output_fps = 1  # write analyzed frames at 1 FPS so the output is a short summary clip
        out = cv2.VideoWriter(output_tfile.name, fourcc, output_fps, (frame_width, frame_height))
        success, image = vidcap.read()
        frame_count = 0
        processed_count = 0
        while success:
            if frame_count % frame_skip == 0:
                res = model.predict(image, conf=confidence)
                boxes = res[0].boxes
                annotated_bgr = res[0].plot()              # BGR image from ultralytics
                annotated_rgb = annotated_bgr[:, :, ::-1]  # RGB view for Streamlit display
                with col2:
                    st.image(annotated_rgb, caption=f'Detected Frame {frame_count}', use_column_width=True)
                    try:
                        with st.expander("Detection Results"):
                            for box in boxes:
                                st.write(box.xywh)
                    except Exception as ex:
                        st.write("No detection results available.")
                out.write(annotated_bgr)  # write only the analyzed frames to the output video
                processed_count += 1
                if total_frames > 0:
                    progress = (frame_count + 1) / total_frames * 100
                    st.write(f"Progress: {progress:.1f}% (Analyzed {processed_count} frames)")
            success, image = vidcap.read()
            frame_count += 1
            time.sleep(0.05)
        vidcap.release()
        out.release()
        os.unlink(tfile.name)
        with col2:
            with open(output_tfile.name, 'rb') as f:
                st.download_button(
                    label="Download Analyzed Video",
                    data=f,
                    file_name="analyzed_video.mp4",
                    mime="video/mp4"
                )
            st.write(f"Video processing complete. Analyzed {processed_count} frames.")