# Spaces: Runtime error
# (Status banner captured from the hosting page together with the source; it is not part of the program.)
import base64
import html
import os
import tempfile

import cv2
import numpy as np
import streamlit as st
from PIL import Image, ImageDraw
from ultralytics import YOLO
# Module-level detector shared by the image and video code paths.
# The weights are read from "best.pt" when this module is executed.
model = YOLO("best.pt")
# Function to perform object detection on an image
def detect_objects_image(image):
    """Run the module-level YOLO ``model`` on one image and tally detections.

    Args:
        image: The input handed straight to ``model``; callers in this file
            pass a PIL image.

    Returns:
        A 4-tuple ``(output, num_potholes, num_cracks, num_alligator_cracks)``.
        ``output`` is a list with one ``[x1, y1, x2, y2, class_name, prob]``
        entry per detected box (coordinates rounded to integers, ``prob``
        rounded to two decimals). The three counts cover the class names
        ``"pothole"``, ``"crack"`` and ``"alligator-crack"``; boxes of any
        other class appear in ``output`` but are not counted.
    """
    # Only the first result is used, as a single image is passed in.
    result = model(image)[0]

    output = []
    counts = {"pothole": 0, "crack": 0, "alligator-crack": 0}

    for box in result.boxes:
        x1, y1, x2, y2 = [round(x) for x in box.xyxy[0].tolist()]
        # ``.item()`` yields a plain Python number (typically a float for the
        # class tensor); cast to int so the lookup is a real index/key and
        # works whether ``names`` is a dict or a list.
        class_id = int(box.cls[0].item())
        prob = round(box.conf[0].item(), 2)
        class_name = result.names[class_id]
        output.append([x1, y1, x2, y2, class_name, prob])

        # Count detections by class; unknown classes are ignored here.
        if class_name in counts:
            counts[class_name] += 1

    return output, counts["pothole"], counts["crack"], counts["alligator-crack"]
# Function to process and annotate a video
def process_video(video_path, output_path, frame_interval):
    """Annotate a video with detections and write the result as MP4.

    Every frame of the input is copied to the output. Detection is run only
    on every N-th frame, where N is ``fps * frame_interval`` (at least 1);
    those frames get red boxes and ``"<class> <prob>"`` labels drawn on them,
    all other frames are written unchanged.

    Args:
        video_path: Path of the video to read.
        output_path: Path of the annotated video to write (``mp4v`` codec).
        frame_interval: Seconds between analysed frames. Values that yield
            an interval below one frame fall back to analysing every frame.

    Returns:
        A dict with keys ``'potholes'``, ``'cracks'`` and
        ``'alligator_cracks'`` holding the detection counts summed over all
        analysed frames. If the input cannot be opened the counts are all
        zero and no output is written.
    """
    detections_summary = {
        'potholes': 0,
        'cracks': 0,
        'alligator_cracks': 0
    }

    cap = cv2.VideoCapture(video_path)
    out = None
    try:
        if not cap.isOpened():
            return detections_summary

        fps = cap.get(cv2.CAP_PROP_FPS)
        if not fps > 0:
            # The reported rate can be 0 (or NaN) when it is unknown; use a
            # conventional default so the writer and the interval stay valid.
            fps = 30.0

        # Never let the interval reach 0: it is used as a modulo divisor.
        frame_interval_count = max(1, int(fps * frame_interval))

        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        # Pass the untruncated rate so e.g. 29.97 fps input keeps its timing.
        out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))

        frame_count = 0
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break

            if frame_count % frame_interval_count == 0:
                # OpenCV frames are BGR; PIL and the drawing code use RGB.
                image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
                detections, num_potholes, num_cracks, num_alligator_cracks = detect_objects_image(image)
                detections_summary['potholes'] += num_potholes
                detections_summary['cracks'] += num_cracks
                detections_summary['alligator_cracks'] += num_alligator_cracks

                draw = ImageDraw.Draw(image)
                for x1, y1, x2, y2, class_name, prob in detections:
                    draw.rectangle([x1, y1, x2, y2], outline="red", width=3)
                    draw.text((x1, y1), f"{class_name} {prob:.2f}", fill="red")
                annotated_frame = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
            else:
                annotated_frame = frame

            out.write(annotated_frame)
            frame_count += 1
    finally:
        # Release both handles even if detection or writing raised.
        cap.release()
        if out is not None:
            out.release()

    return detections_summary
# Function to generate a download link for a file
def get_download_link(file_path, text, file_type):
    """Build an HTML anchor that downloads a file embedded as a data URI.

    Args:
        file_path: Path of the file whose bytes are embedded in the link.
        text: Visible label of the link.
        file_type: MIME type placed in the data URI (e.g. ``"video/mp4"``).

    Returns:
        An ``<a>`` element as a string. Its ``href`` is a base64 data URI of
        the file contents and its ``download`` attribute is the file's base
        name, so the saved file keeps its extension.

    Raises:
        OSError: If ``file_path`` cannot be opened or read.
    """
    with open(file_path, 'rb') as f:
        file_b64 = base64.b64encode(f.read()).decode("ascii")

    # The markup is rendered with unsafe_allow_html, so escape every value
    # that is interpolated into it.
    mime = html.escape(file_type, quote=True)
    filename = html.escape(os.path.basename(file_path), quote=True)
    label = html.escape(text)

    return f'<a href="data:{mime};base64,{file_b64}" download="{filename}">{label}</a>'
# Streamlit app
def _remove_quietly(path):
    """Delete the temporary file at *path* on a best-effort basis."""
    try:
        os.remove(path)
    except OSError:
        # The file may never have been created or may still be locked;
        # failing to clean up must not break the page.
        pass


def _handle_image(source_path, mime_type):
    """Show an uploaded image and, on request, its annotated version."""
    with Image.open(source_path) as image:
        st.image(image, caption='Uploaded Image', use_column_width=True)

        if not st.button('Detect Objects (Image)'):
            return

        detections, num_potholes, num_cracks, num_alligator_cracks = detect_objects_image(image)

        draw = ImageDraw.Draw(image)
        for x1, y1, x2, y2, class_name, prob in detections:
            draw.rectangle([x1, y1, x2, y2], outline="red", width=3)
            draw.text((x1, y1), f"{class_name} {prob:.2f}", fill="red")
        st.image(image, caption='Annotated Image', use_column_width=True)

        st.subheader("Detection Summary")
        if num_potholes > 0:
            st.write(f"Potholes Detected: {num_potholes}")
        if num_cracks > 0:
            st.write(f"Cracks Detected: {num_cracks}")
        if num_alligator_cracks > 0:
            st.write(f"Alligator Cracks Detected: {num_alligator_cracks}")

        # Split on the extension only; a plain str.replace(".") would also
        # rewrite dots that occur in directory names.
        stem, ext = os.path.splitext(source_path)
        annotated_image_path = f"{stem}_annotated{ext}"
        try:
            image.save(annotated_image_path)
            # The file is saved with the upload's extension, so advertise the
            # upload's MIME type rather than a fixed one.
            st.markdown(
                get_download_link(annotated_image_path, "Download Annotated Image", mime_type),
                unsafe_allow_html=True,
            )
        finally:
            _remove_quietly(annotated_image_path)


def _handle_video(source_path):
    """Show an uploaded video and, on request, annotate it for download."""
    with open(source_path, 'rb') as video_file:
        st.video(video_file.read())

    if not st.button('Detect Objects (Video)'):
        return

    stem, _ = os.path.splitext(source_path)
    annotated_video_path = f"{stem}_annotated.mp4"
    try:
        detections_summary = process_video(source_path, annotated_video_path, frame_interval=1)
        st.subheader("Annotated Video Download")
        st.markdown(
            get_download_link(annotated_video_path, "Download Annotated Video", "video/mp4"),
            unsafe_allow_html=True,
        )
    finally:
        # The link embeds the bytes, so the file is no longer needed.
        _remove_quietly(annotated_video_path)

    st.subheader("Detection Summary")
    if detections_summary['potholes'] > 0:
        st.write(f"Total Potholes Detected: {detections_summary['potholes']}")
    if detections_summary['cracks'] > 0:
        st.write(f"Total Cracks Detected: {detections_summary['cracks']}")
    if detections_summary['alligator_cracks'] > 0:
        st.write(f"Total Alligator Cracks Detected: {detections_summary['alligator_cracks']}")


def main():
    """Render the Streamlit page: upload a file, detect, show and download.

    The upload is written to a temporary file (keeping its extension) so it
    can be opened by path, dispatched by MIME type to the image or video
    handler, and removed again once the page has been rendered.
    """
    st.title("Road Condition Inspection")
    st.subheader("Upload an image or video to detect objects")

    # File uploader for image and video
    uploaded_file = st.file_uploader("Choose a file...", type=["jpg", "jpeg", "png", "mp4"])
    if uploaded_file is None:
        return

    file_type = uploaded_file.type
    suffix = os.path.splitext(uploaded_file.name)[1]
    # delete=False: the file must outlive this handle so it can be reopened
    # by name; it is removed explicitly below.
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
        temp_file.write(uploaded_file.read())

    try:
        if file_type.startswith("image"):
            _handle_image(temp_file.name, file_type)
        elif file_type.startswith("video"):
            _handle_video(temp_file.name)
    finally:
        _remove_quietly(temp_file.name)
# Launch the app only when this file is executed as the entry script.
if __name__ == "__main__":
    main()