tstone87 commited on
Commit
6cd7819
·
verified ·
1 Parent(s): 033d048

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +110 -124
app.py CHANGED
@@ -1,63 +1,34 @@
1
  import os
2
  import tempfile
3
- import base64
4
- import time
5
  import cv2
6
  import streamlit as st
 
7
  import requests
8
  from ultralytics import YOLO
9
- from huggingface_hub import hf_hub_download
10
- import imageio
11
  import numpy as np
12
 
13
- # Page config must be first
14
  st.set_page_config(
15
- page_title="Wildfire Detection Demo",
16
  page_icon="🔥",
17
  layout="wide",
18
  initial_sidebar_state="expanded"
19
  )
20
 
21
- # Helper function to display videos
22
- def show_video(video_bytes: bytes, title: str, loop=True):
23
- if not video_bytes:
24
- st.warning(f"No {title} video available.")
25
- return
26
- video_base64 = base64.b64encode(video_bytes).decode()
27
- loop_attr = "loop" if loop else ""
28
- video_html = f"""
29
- <h4>{title}</h4>
30
- <video width="100%" controls autoplay muted {loop_attr}>
31
- <source src="data:video/mp4;base64,{video_base64}" type="video/mp4">
32
- Your browser does not support the video tag.
33
- </video>
34
- """
35
- st.markdown(video_html, unsafe_allow_html=True)
36
 
37
- # Initialize session state
38
- for key in ["processed_video", "processing_complete", "start_time", "progress"]:
39
  if key not in st.session_state:
40
- st.session_state[key] = None if key in ["processed_video", "start_time"] else False if key == "processing_complete" else 0
41
-
42
- # Load model
43
- @st.cache_resource
44
- def load_model():
45
- repo_id = "tstone87/ccr-colorado"
46
- filename = "best.pt"
47
- try:
48
- model_path = hf_hub_download(repo_id=repo_id, filename=filename, repo_type="model")
49
- return YOLO(model_path)
50
- except Exception as e:
51
- st.error(f"Failed to load model: {str(e)}")
52
- return None
53
-
54
- model = load_model()
55
 
56
  # Sidebar
57
  with st.sidebar:
58
- st.header("Process Your Own Video")
59
- uploaded_file = st.file_uploader("Upload a video", type=["mp4"])
60
- confidence = st.slider("Detection Confidence", 0.25, 1.0, 0.4)
61
  fps_options = {
62
  "Original FPS": None,
63
  "3 FPS": 3,
@@ -67,113 +38,128 @@ with st.sidebar:
67
  "1 frame/15s": 0.0667,
68
  "1 frame/30s": 0.0333
69
  }
70
- selected_fps = st.selectbox("Output FPS", list(fps_options.keys()), index=0)
71
- process_button = st.button("Process Video")
72
  progress_bar = st.progress(0)
73
  progress_text = st.empty()
74
  download_slot = st.empty()
75
 
76
- # Main content
77
- st.title("Wildfire Detection Demo")
78
- st.markdown("Watch our example videos below or upload your own in the sidebar!")
 
 
 
 
79
 
80
- # Example videos
81
- example_videos = {
82
- "T Example": ("T1.mp4", "T2.mpg"),
83
- "LA Example": ("LA1.mp4", "LA2.mp4")
84
- }
85
 
86
- for example_name in example_videos:
 
 
87
  col1, col2 = st.columns(2)
88
- orig_file, proc_file = example_videos[example_name]
89
- try:
90
- orig_url = f"https://huggingface.co/spaces/tstone87/ccr-colorado/resolve/main/{orig_file}"
91
- proc_url = f"https://huggingface.co/spaces/tstone87/ccr-colorado/resolve/main/{proc_file}"
92
- orig_data = requests.get(orig_url).content
93
- proc_data = requests.get(proc_url).content
94
-
95
- with col1:
96
- show_video(orig_data, f"{example_name} - Original", loop=True)
97
- with col2:
98
- show_video(proc_data, f"{example_name} - Processed", loop=True)
99
- except Exception as e:
100
- st.error(f"Failed to load {example_name}: {str(e)}")
 
 
 
 
 
 
101
 
102
- # Video processing
103
- def process_video(video_file, target_fps, confidence):
104
- with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
105
- tmp.write(video_file.read())
106
- tmp_path = tmp.name
107
-
108
- try:
109
- reader = imageio.get_reader(tmp_path)
110
- meta = reader.get_meta_data()
111
- original_fps = meta['fps']
112
- width, height = meta['size']
113
- total_frames = meta['nframes'] if meta['nframes'] != float('inf') else 1000 # Fallback for unknown length
 
 
 
 
114
 
115
- output_fps = fps_options[target_fps] if fps_options[target_fps] else original_fps
116
- frame_interval = max(1, int(original_fps / output_fps)) if output_fps else 1
 
 
117
 
118
- out_path = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False).name
119
- writer = cv2.VideoWriter(out_path, cv2.VideoWriter_fourcc(*'mp4v'), output_fps or original_fps, (width, height))
120
 
121
  st.session_state.start_time = time.time()
 
122
  processed_count = 0
123
 
124
- for i, frame in enumerate(reader):
125
- if i % frame_interval == 0:
126
- frame_rgb = np.array(frame)
127
- results = model.predict(frame_rgb, conf=confidence)
128
- processed_frame = results[0].plot()[:, :, ::-1]
129
- writer.write(processed_frame)
130
 
131
  processed_count += 1
132
  elapsed = time.time() - st.session_state.start_time
133
- progress = (i + 1) / total_frames
134
- st.session_state.progress = min(progress, 1.0)
135
 
136
- if elapsed > 0:
137
- frames_left = total_frames - i - 1
138
  time_per_frame = elapsed / processed_count
139
- eta = frames_left * time_per_frame / frame_interval
 
140
  eta_str = f"{int(eta // 60)}m {int(eta % 60)}s"
141
  else:
142
  eta_str = "Calculating..."
143
 
144
- progress_bar.progress(st.session_state.progress)
145
- progress_text.text(f"Progress: {st.session_state.progress:.1%} | ETA: {eta_str}")
 
 
 
146
 
147
- writer.release()
148
- reader.close()
149
 
150
- with open(out_path, 'rb') as f:
151
- return f.read()
152
-
153
- finally:
154
- if os.path.exists(tmp_path):
155
- os.unlink(tmp_path)
156
- if os.path.exists(out_path):
 
 
157
  os.unlink(out_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
158
 
159
- # Process uploaded video
160
- if process_button and uploaded_file and model:
161
- with st.spinner("Processing video..."):
162
- st.session_state.processed_video = process_video(uploaded_file, selected_fps, confidence)
163
- st.session_state.processing_complete = True
164
- progress_bar.progress(1.0)
165
- progress_text.text("Processing complete!")
166
-
167
- # Show processed video and download button
168
- if st.session_state.processing_complete and st.session_state.processed_video:
169
- st.subheader("Your Processed Video")
170
- show_video(st.session_state.processed_video, "Processed Result", loop=False)
171
- download_slot.download_button(
172
- label="Download Processed Video",
173
- data=st.session_state.processed_video,
174
- file_name="processed_wildfire.mp4",
175
- mime="video/mp4"
176
- )
177
-
178
- if not model:
179
- st.error("Model loading failed. Please check the repository and model file availability.")
 
1
  import os
2
  import tempfile
 
 
3
  import cv2
4
  import streamlit as st
5
+ import PIL
6
  import requests
7
  from ultralytics import YOLO
8
+ import time
 
9
  import numpy as np
10
 
11
+ # Page config first
12
  st.set_page_config(
13
+ page_title="WildfireWatch: AI Detection",
14
  page_icon="🔥",
15
  layout="wide",
16
  initial_sidebar_state="expanded"
17
  )
18
 
19
+ # Model path
20
+ model_path = 'https://huggingface.co/spaces/tstone87/ccr-colorado/resolve/main/best.pt'
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
+ # Session state initialization
23
+ for key in ["processed_frames", "slider_value", "processed_video", "start_time"]:
24
  if key not in st.session_state:
25
+ st.session_state[key] = [] if key == "processed_frames" else 0 if key == "slider_value" else None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
  # Sidebar
28
  with st.sidebar:
29
+ st.header("Upload & Settings")
30
+ source_file = st.file_uploader("Upload image/video", type=["jpg", "jpeg", "png", "bmp", "webp", "mp4"])
31
+ confidence = float(st.slider("Confidence Threshold", 25, 100, 40)) / 100
32
  fps_options = {
33
  "Original FPS": None,
34
  "3 FPS": 3,
 
38
  "1 frame/15s": 0.0667,
39
  "1 frame/30s": 0.0333
40
  }
41
+ video_option = st.selectbox("Output Frame Rate", list(fps_options.keys()))
42
+ process_button = st.button("Detect Wildfire")
43
  progress_bar = st.progress(0)
44
  progress_text = st.empty()
45
  download_slot = st.empty()
46
 
47
+ # Main page
48
+ st.title("WildfireWatch: AI-Powered Detection")
49
+ col1, col2 = st.columns(2)
50
+ with col1:
51
+ st.image("https://huggingface.co/spaces/tstone87/ccr-colorado/resolve/main/Fire_1.jpeg", use_column_width=True)
52
+ with col2:
53
+ st.image("https://huggingface.co/spaces/tstone87/ccr-colorado/resolve/main/Fire_3.png", use_column_width=True)
54
 
55
+ st.markdown("""
56
+ Early wildfire detection using YOLOv8 AI vision model. See examples below or upload your own content!
57
+ """)
 
 
58
 
59
+ # Example videos
60
+ st.header("Example Results")
61
+ for example in [("T1.mp4", "T2.mpg"), ("LA1.mp4", "LA2.mp4")]:
62
  col1, col2 = st.columns(2)
63
+ orig_url = f"https://huggingface.co/spaces/tstone87/ccr-colorado/resolve/main/{example[0]}"
64
+ proc_url = f"https://huggingface.co/spaces/tstone87/ccr-colorado/resolve/main/{example[1]}"
65
+ orig_data = requests.get(orig_url).content
66
+ proc_data = requests.get(proc_url).content
67
+ with col1:
68
+ st.video(orig_data)
69
+ with col2:
70
+ st.video(proc_data)
71
+
72
+ st.header("Your Results")
73
+ result_cols = st.columns(2)
74
+ viewer_slot = st.empty()
75
+
76
+ # Load model
77
+ try:
78
+ model = YOLO(model_path)
79
+ except Exception as ex:
80
+ st.error(f"Model loading failed: {str(ex)}")
81
+ model = None
82
 
83
+ # Processing
84
+ if process_button and source_file and model:
85
+ st.session_state.processed_frames = []
86
+ if source_file.type.split('/')[0] == 'image':
87
+ image = PIL.Image.open(source_file)
88
+ res = model.predict(image, conf=confidence)
89
+ result = res[0].plot()[:, :, ::-1]
90
+ with result_cols[0]:
91
+ st.image(image, caption="Original", use_column_width=True)
92
+ with result_cols[1]:
93
+ st.image(result, caption="Detected", use_column_width=True)
94
+ else:
95
+ # Video processing
96
+ with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
97
+ tmp.write(source_file.read())
98
+ vidcap = cv2.VideoCapture(tmp.name)
99
 
100
+ orig_fps = vidcap.get(cv2.CAP_PROP_FPS)
101
+ total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
102
+ width = int(vidcap.get(cv2.CAP_PROP_FRAME_WIDTH))
103
+ height = int(vidcap.get(cv2.CAP_PROP_FRAME_HEIGHT))
104
 
105
+ output_fps = fps_options[video_option] if fps_options[video_option] else orig_fps
106
+ sample_interval = max(1, int(orig_fps / output_fps)) if output_fps else 1
107
 
108
  st.session_state.start_time = time.time()
109
+ frame_count = 0
110
  processed_count = 0
111
 
112
+ success, frame = vidcap.read()
113
+ while success:
114
+ if frame_count % sample_interval == 0:
115
+ res = model.predict(frame, conf=confidence)
116
+ processed_frame = res[0].plot()[:, :, ::-1]
117
+ st.session_state.processed_frames.append(processed_frame)
118
 
119
  processed_count += 1
120
  elapsed = time.time() - st.session_state.start_time
121
+ progress = frame_count / total_frames
 
122
 
123
+ if elapsed > 0 and processed_count > 0:
 
124
  time_per_frame = elapsed / processed_count
125
+ frames_left = (total_frames - frame_count) / sample_interval
126
+ eta = frames_left * time_per_frame
127
  eta_str = f"{int(eta // 60)}m {int(eta % 60)}s"
128
  else:
129
  eta_str = "Calculating..."
130
 
131
+ progress_bar.progress(min(progress, 1.0))
132
+ progress_text.text(f"Progress: {progress:.1%} | ETA: {eta_str}")
133
+
134
+ frame_count += 1
135
+ success, frame = vidcap.read()
136
 
137
+ vidcap.release()
138
+ os.unlink(tmp.name)
139
 
140
+ if st.session_state.processed_frames:
141
+ out_path = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False).name
142
+ writer = cv2.VideoWriter(out_path, cv2.VideoWriter_fourcc(*'mp4v'), output_fps or orig_fps, (width, height))
143
+ for frame in st.session_state.processed_frames:
144
+ writer.write(frame)
145
+ writer.release()
146
+
147
+ with open(out_path, 'rb') as f:
148
+ st.session_state.processed_video = f.read()
149
  os.unlink(out_path)
150
+
151
+ progress_bar.progress(1.0)
152
+ progress_text.text("Processing complete!")
153
+ with result_cols[0]:
154
+ st.video(source_file)
155
+ with result_cols[1]:
156
+ st.video(st.session_state.processed_video)
157
+ download_slot.download_button(
158
+ label="Download Processed Video",
159
+ data=st.session_state.processed_video,
160
+ file_name="processed_wildfire.mp4",
161
+ mime="video/mp4"
162
+ )
163
 
164
+ if not source_file:
165
+ st.info("Please upload a file to begin.")