Spaces:
Sleeping
Sleeping
import os
import tempfile
import time

import cv2
import numpy as np
import PIL.Image
import requests
import streamlink
import streamlit as st
from ultralytics import YOLO
# Page Config -- kept as the first Streamlit call the script makes.
st.set_page_config(page_title="AI Fire Watch", page_icon="🌍", layout="wide")

# Lighter Background CSS with Darker Text: a light-grey app background plus
# dark, bold styling for the title, tab buttons and regular buttons (with
# slightly darker hover backgrounds). Injected as raw HTML, hence
# unsafe_allow_html=True.
_LIGHT_THEME_CSS = """
<style>
.stApp {
background-color: #f5f5f5;
color: #1a1a1a; /* Dark text for general content */
}
h1 {
color: #1a1a1a; /* Darker title text */
}
.stTabs > div > button {
background-color: #e0e0e0;
color: #1a1a1a; /* Darker tab text */
font-weight: bold;
}
.stTabs > div > button:hover {
background-color: #d0d0d0;
color: #1a1a1a;
}
.stButton > button {
background-color: #e0e0e0;
color: #1a1a1a; /* Darker button text */
font-weight: bold;
}
.stButton > button:hover {
background-color: #d0d0d0;
color: #1a1a1a;
}
</style>
"""
st.markdown(_LIGHT_THEME_CSS, unsafe_allow_html=True)
# Load Model
# Weights for the fire/smoke YOLO detector, fetched by Ultralytics from this URL.
model_path = 'https://huggingface.co/spaces/ankitkupadhyay/fire_and_smoke/resolve/main/best.pt'

# Streamlit re-executes this script top to bottom on every widget interaction,
# so constructing the model unconditionally reloaded the weights on every
# slider move or button click. Keep one instance per browser session instead.
# A per-session copy (rather than a process-wide st.cache_resource singleton)
# is used because Streamlit runs sessions on separate threads and Ultralytics
# advises against sharing a single model instance across threads.
if "yolo_model" not in st.session_state:
    try:
        st.session_state.yolo_model = YOLO(model_path)
    except Exception as ex:
        # Nothing in the app works without the model: report it and halt this run.
        st.error(f"Model loading failed: {ex}")
        st.stop()
model = st.session_state.yolo_model
# Initialize Session State
# Webcam-monitor defaults; values already present survive Streamlit reruns.
for _state_key, _state_default in (("monitoring", False), ("current_webcam_url", None)):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default

# Header
st.title("AI Fire Watch")
st.markdown("Monitor fire and smoke in real-time with AI precision.")

# Tabs -- one per input source; indexed by position in the sections below.
tabs = st.tabs(["Upload", "Webcam", "YouTube"])
# Tab 1: Upload
# Run the detector on a single uploaded image, or frame by frame on an uploaded video.
with tabs[0]:
    col1, col2 = st.columns([1, 1])
    with col1:
        st.markdown("**Add Your File**")
        st.write("Upload an image or video to scan for fire or smoke.")
        # A non-empty label avoids Streamlit's empty-label accessibility warning;
        # label_visibility="collapsed" still hides it on screen.
        uploaded_file = st.file_uploader(
            "Upload an image or video",
            type=["jpg", "jpeg", "png", "mp4"],
            label_visibility="collapsed",
        )
        confidence = st.slider("Detection Threshold", 0.25, 1.0, 0.4, key="upload_conf")
    with col2:
        if uploaded_file:
            # Major part of the MIME type, e.g. "image" or "video".
            file_type = uploaded_file.type.split('/')[0]
            if file_type == 'image':
                image = PIL.Image.open(uploaded_file)
                results = model.predict(image, conf=confidence)
                # Results.plot() yields a BGR array; reverse channels for st.image (RGB).
                detected_image = results[0].plot()[:, :, ::-1]
                st.image(detected_image, use_column_width=True)
                st.write(f"Objects detected: {len(results[0].boxes)}")
            elif file_type == 'video':
                # OpenCV needs a real path. delete=False lets OpenCV reopen the
                # file after we close it (required on Windows); it is removed below.
                suffix = os.path.splitext(uploaded_file.name)[1] or ".mp4"
                with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tfile:
                    tfile.write(uploaded_file.getvalue())
                # The file is closed (hence flushed) here. Previously OpenCV opened
                # it while the tail of the data could still sit in Python's buffer.
                cap = cv2.VideoCapture(tfile.name)
                try:
                    frame_placeholder = st.empty()
                    while cap.isOpened():
                        ret, frame = cap.read()
                        if not ret:
                            break  # end of video (or unreadable frame)
                        results = model.predict(frame, conf=confidence)
                        detected_frame = results[0].plot()[:, :, ::-1]
                        frame_placeholder.image(detected_frame, use_column_width=True)
                        time.sleep(0.05)  # throttle playback a little
                finally:
                    # Always free the capture and delete the temp copy, even if
                    # prediction raised, so uploads do not accumulate on disk.
                    cap.release()
                    os.unlink(tfile.name)
# Tab 2: Webcam
def _scan_webcam_snapshot(url, conf):
    """Fetch one snapshot from *url* and run the fire/smoke model on it.

    Returns ``(annotated_rgb, None)`` on success, or ``(None, message)`` where
    *message* is the text to show the user: a non-200 HTTP status, an image
    that OpenCV could not decode, or any exception raised while fetching,
    decoding or predicting. This helper makes no Streamlit calls, so its broad
    ``except`` cannot intercept Streamlit's own rerun/stop signalling.
    """
    try:
        response = requests.get(url, timeout=5)
        if response.status_code != 200:
            return None, f"Fetch failed: HTTP {response.status_code}"
        image_array = np.frombuffer(response.content, dtype=np.uint8)
        frame = cv2.imdecode(image_array, cv2.IMREAD_COLOR)
        if frame is None:
            return None, "Image decoding failed."
        results = model.predict(frame, conf=conf)
        # Results.plot() yields a BGR array; reverse channels for st.image (RGB).
        return results[0].plot()[:, :, ::-1], None
    except Exception as e:  # UI boundary: surface any failure as a message
        return None, f"Error: {e}"


with tabs[1]:
    col1, col2 = st.columns([1, 1])
    with col1:
        st.markdown("**Webcam Feed**")
        st.write("Provide a webcam URL to check snapshots for hazards every 30 seconds.")
        webcam_url = st.text_input("Webcam URL", "http://<your_webcam_ip>/current.jpg", label_visibility="collapsed")
        confidence = st.slider("Detection Threshold", 0.25, 1.0, 0.4, key="webcam_conf")
        start = st.button("Begin Monitoring", key="webcam_start")
        stop = st.button("Stop Monitoring", key="webcam_stop")

    # Handle monitoring state: "Begin" latches the URL to watch; "Stop", or
    # editing the URL while monitoring, cancels the session.
    if start:
        st.session_state.monitoring = True
        st.session_state.current_webcam_url = webcam_url
    if stop or (st.session_state.monitoring and webcam_url != st.session_state.current_webcam_url):
        st.session_state.monitoring = False
        st.session_state.current_webcam_url = None

    with col2:
        if st.session_state.monitoring and st.session_state.current_webcam_url:
            image_placeholder = st.empty()
            timer_placeholder = st.empty()
            refresh_interval = 30  # seconds between snapshots
            # Scan repeatedly in this loop instead of calling st.experimental_rerun():
            # that API is deprecated, and on releases where it has been removed the
            # old `except Exception` turned its AttributeError into an "Error:"
            # message that ended monitoring after a single scan. Clicking a widget
            # such as "Stop Monitoring" makes Streamlit rerun the script, which is
            # what ends this loop in normal use.
            # NOTE(review): while this loop runs, the code after it (the YouTube
            # tab) is not executed, so that tab stays empty until monitoring stops.
            while True:
                start_time = time.time()
                detected_frame, error_message = _scan_webcam_snapshot(
                    st.session_state.current_webcam_url, confidence
                )
                if error_message is not None:
                    st.error(error_message)
                    # Reset exactly what "Stop Monitoring" resets, for every kind of
                    # failure. Previously, HTTP and decoding errors left monitoring
                    # enabled, so the next rerun silently restarted it.
                    st.session_state.monitoring = False
                    st.session_state.current_webcam_url = None
                    break
                image_placeholder.image(detected_frame, use_column_width=True)

                # Count down to the next scan, refreshing the label once per second.
                elapsed = time.time() - start_time
                remaining = max(0, refresh_interval - elapsed)
                timer_placeholder.write(f"Next scan: {int(remaining)}s")
                while remaining > 0:
                    time.sleep(1)
                    elapsed = time.time() - start_time
                    remaining = max(0, refresh_interval - elapsed)
                    timer_placeholder.write(f"Next scan: {int(remaining)}s")
# Tab 3: YouTube
# Default text of the URL box; analysis starts only once the user replaces it.
YOUTUBE_URL_PLACEHOLDER = "https://www.youtube.com/watch?v=<id>"

with tabs[2]:
    col1, col2 = st.columns([1, 1])
    with col1:
        st.markdown("**YouTube Live**")
        st.write("Enter a live YouTube URL to auto-analyze the stream.")
        youtube_url = st.text_input("YouTube URL", YOUTUBE_URL_PLACEHOLDER, label_visibility="collapsed")
        confidence = st.slider("Detection Threshold", 0.25, 1.0, 0.4, key="yt_conf")
    with col2:
        if youtube_url and youtube_url != YOUTUBE_URL_PLACEHOLDER:
            st.write("Analyzing live stream...")
            cap = None
            try:
                # Resolve the page URL to a direct media URL OpenCV can open.
                streams = streamlink.streams(youtube_url)
                if not streams:
                    st.error("No streams found. Check the URL.")
                else:
                    stream_url = streams["best"].to_url()
                    cap = cv2.VideoCapture(stream_url)
                    if not cap.isOpened():
                        st.error("Unable to open stream.")
                    else:
                        frame_placeholder = st.empty()
                        # Reuse one slot for the count. Calling st.write() in the loop
                        # appended a new line every second, growing the page forever.
                        count_placeholder = st.empty()
                        while cap.isOpened():
                            ret, frame = cap.read()
                            if not ret:
                                st.error("Stream interrupted.")
                                break
                            results = model.predict(frame, conf=confidence)
                            # Results.plot() yields a BGR array; reverse for st.image (RGB).
                            detected_frame = results[0].plot()[:, :, ::-1]
                            frame_placeholder.image(detected_frame, use_column_width=True)
                            count_placeholder.write(f"Objects detected: {len(results[0].boxes)}")
                            # NOTE(review): cap.read() returns the next buffered frame,
                            # not necessarily the newest, so with this sleep the analysis
                            # may drift behind the live edge -- confirm on a real stream.
                            time.sleep(1)  # Check every second
            except Exception as e:
                st.error(f"Error: {e}")
            finally:
                # Release the capture even when streamlink or predict() raised.
                if cap is not None:
                    cap.release()