Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -46,6 +46,8 @@ class_colors = {
|
|
46 |
processor = AutoProcessor.from_pretrained("stepfun-ai/GOT-OCR2_0", trust_remote_code=True)
|
47 |
model_ocr = AutoModel.from_pretrained("stepfun-ai/GOT-OCR2_0", trust_remote_code=True).to('cuda')
|
48 |
|
|
|
|
|
49 |
|
50 |
# YOLO inference function
|
51 |
def run_yolo(image):
|
@@ -90,46 +92,143 @@ def process_image(uploaded_file):
|
|
90 |
|
91 |
# Process and save uploaded videos
|
92 |
@st.cache_data
|
|
|
93 |
def process_video_and_save(uploaded_file):
|
94 |
-
|
95 |
-
|
96 |
-
temp_file_path = temp_file.name
|
97 |
|
98 |
-
|
99 |
-
|
100 |
-
frames = []
|
101 |
-
current_frame = 0
|
102 |
-
start_time = time.time()
|
103 |
|
104 |
-
|
105 |
-
|
106 |
|
107 |
-
|
108 |
-
|
109 |
-
if not ret:
|
110 |
-
break
|
111 |
-
results = run_yolo(frame)
|
112 |
-
processed_frame = process_results(results, frame)
|
113 |
-
processed_frame_rgb = cv2.cvtColor(processed_frame, cv2.COLOR_BGR2RGB)
|
114 |
-
frames.append(processed_frame_rgb)
|
115 |
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
125 |
|
126 |
-
|
127 |
-
frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
|
128 |
-
out.write(frame_bgr)
|
129 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
130 |
out.release()
|
131 |
-
|
132 |
-
|
|
|
133 |
|
134 |
|
135 |
|
@@ -206,6 +305,193 @@ def send_email(license_text, violation_image_path, violation_type):
|
|
206 |
server.sendmail(FROM_EMAIL, TO_EMAIL, msg.as_string())
|
207 |
print("Email with attachment sent successfully!")
|
208 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
209 |
|
210 |
# Streamlit app main function
|
211 |
def main():
|
|
|
46 |
processor = AutoProcessor.from_pretrained("stepfun-ai/GOT-OCR2_0", trust_remote_code=True)
|
47 |
model_ocr = AutoModel.from_pretrained("stepfun-ai/GOT-OCR2_0", trust_remote_code=True).to('cuda')
|
48 |
|
49 |
+
# Define lane area coordinates (example coordinates)
|
50 |
+
red_lane = np.array([[2, 1583], [1, 1131], [1828, 1141], [1912, 1580]], np.int32)
|
51 |
|
52 |
# YOLO inference function
|
53 |
def run_yolo(image):
|
|
|
92 |
|
93 |
# Process and save uploaded videos
|
94 |
@st.cache_data
|
95 |
+
# Define the function to process the video
|
96 |
def process_video_and_save(uploaded_file):
|
97 |
+
# Path for Arabic font
|
98 |
+
font_path = "/kaggle/input/fontss/alfont_com_arial-1.ttf"
|
|
|
99 |
|
100 |
+
# Paths for saving violation images
|
101 |
+
violation_image_path = '/kaggle/working/violation.jpg'
|
|
|
|
|
|
|
102 |
|
103 |
+
# Track emails already sent to avoid duplicate emails
|
104 |
+
sent_emails = {}
|
105 |
|
106 |
+
# Dictionary to track violations per license plate
|
107 |
+
violations_dict = {}
|
|
|
|
|
|
|
|
|
|
|
|
|
108 |
|
109 |
+
# Video path (input)
|
110 |
+
video_path = "/kaggle/working/uploaded_video.mp4" # Save the uploaded video file to this path
|
111 |
+
with open(video_path, "wb") as f:
|
112 |
+
f.write(uploaded_file.getbuffer())
|
113 |
|
114 |
+
cap = cv2.VideoCapture(video_path)
|
115 |
+
|
116 |
+
# Check if the video file opened successfully
|
117 |
+
if not cap.isOpened():
|
118 |
+
print("Error opening video file")
|
119 |
+
return None
|
120 |
+
|
121 |
+
# Define codec and output video settings
|
122 |
+
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
|
123 |
+
output_video_path = '/kaggle/working/output_violation.mp4'
|
124 |
+
fps = cap.get(cv2.CAP_PROP_FPS)
|
125 |
+
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) # Frame width
|
126 |
+
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) # Frame height
|
127 |
+
out = cv2.VideoWriter(output_video_path, fourcc, fps, (width, height))
|
128 |
|
129 |
+
margin_y = 50
|
|
|
|
|
130 |
|
131 |
+
# Process the video frame by frame
|
132 |
+
while cap.isOpened():
|
133 |
+
ret, frame = cap.read()
|
134 |
+
if not ret:
|
135 |
+
break # End of video
|
136 |
+
|
137 |
+
# Draw the red lane rectangle on each frame
|
138 |
+
cv2.polylines(frame, [red_lane], isClosed=True, color=(0, 0, 255), thickness=3) # Red lane
|
139 |
+
|
140 |
+
# Perform detection using YOLO on the current frame
|
141 |
+
results = model.track(frame)
|
142 |
+
|
143 |
+
# Process each detection in the results
|
144 |
+
for box in results[0].boxes:
|
145 |
+
x1, y1, x2, y2 = map(int, box.xyxy[0].cpu().numpy()) # Bounding box coordinates
|
146 |
+
label = model.names[int(box.cls)] # Class name (MotorbikeDelivery, Helmet, etc.)
|
147 |
+
color = (255, 0, 0) # Use a fixed color for bounding boxes
|
148 |
+
confidence = box.conf[0].item()
|
149 |
+
|
150 |
+
# Initialize flags and variables for the violations
|
151 |
+
helmet_violation = False
|
152 |
+
lane_violation = False
|
153 |
+
violation_type = []
|
154 |
+
|
155 |
+
# Draw bounding box around detected object
|
156 |
+
cv2.rectangle(frame, (x1, y1), (x2, y2), color, 3) # 3 is the thickness of the rectangle
|
157 |
+
|
158 |
+
# Add label to the box (e.g., 'MotorbikeDelivery')
|
159 |
+
cv2.putText(frame, f'{label}: {confidence:.2f}', (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.6, color, 2)
|
160 |
+
|
161 |
+
# Detect MotorbikeDelivery
|
162 |
+
if label == 'MotorbikeDelivery' and confidence >= 0.4:
|
163 |
+
motorbike_crop = frame[max(0, y1 - margin_y):y2, x1:x2]
|
164 |
+
delivery_center = ((x1 + x2) // 2, (y2))
|
165 |
+
in_red_lane = cv2.pointPolygonTest(red_lane, delivery_center, False)
|
166 |
+
if in_red_lane >= 0:
|
167 |
+
lane_violation = True
|
168 |
+
violation_type.append("In Red Lane")
|
169 |
+
|
170 |
+
# Perform detection within the cropped motorbike region
|
171 |
+
sub_results = model(motorbike_crop)
|
172 |
+
|
173 |
+
for result in sub_results[0].boxes:
|
174 |
+
sub_x1, sub_y1, sub_x2, sub_y2 = map(int, result.xyxy[0].cpu().numpy()) # Bounding box coordinates
|
175 |
+
sub_label = model.names[int(result.cls)]
|
176 |
+
sub_color = (255, 0, 0) # Red color for the bounding box of sub-objects
|
177 |
+
|
178 |
+
# Draw bounding box around sub-detected objects (No_Helmet, License_plate, etc.)
|
179 |
+
cv2.rectangle(motorbike_crop, (sub_x1, sub_y1), (sub_x2, sub_y2), sub_color, 2)
|
180 |
+
cv2.putText(motorbike_crop, sub_label, (sub_x1, sub_y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.6, sub_color, 2)
|
181 |
+
|
182 |
+
if sub_label == 'No_Helmet':
|
183 |
+
helmet_violation = True
|
184 |
+
violation_type.append("No Helmet")
|
185 |
+
continue
|
186 |
+
if sub_label == 'License_plate':
|
187 |
+
license_crop = motorbike_crop[sub_y1:sub_y2, sub_x1:sub_x2]
|
188 |
+
|
189 |
+
# Apply OCR if a violation is detected
|
190 |
+
if helmet_violation or lane_violation:
|
191 |
+
# Perform OCR on the license plate
|
192 |
+
cv2.imwrite(violation_image_path, frame)
|
193 |
+
license_plate_pil = Image.fromarray(cv2.cvtColor(license_crop, cv2.COLOR_BGR2RGB))
|
194 |
+
temp_image_path = '/kaggle/working/license_plate.png'
|
195 |
+
license_plate_pil.save(temp_image_path)
|
196 |
+
license_plate_text = model_ocr.chat(processor, temp_image_path, ocr_type='ocr')
|
197 |
+
filtered_text = filter_license_plate_text(license_plate_text)
|
198 |
+
# Check if the license plate is already detected and saved
|
199 |
+
if filtered_text:
|
200 |
+
# Get the email from the database
|
201 |
+
email = get_vehicle_information(filtered_text, lane_violation, helmet_violation, violation_image_path, "Riyadh")
|
202 |
+
|
203 |
+
# Add the license plate and its violations to the violations dictionary
|
204 |
+
if filtered_text not in violations_dict:
|
205 |
+
violations_dict[filtered_text] = violation_type #{"1234AB":[no_Helmet,In_red_Lane]}
|
206 |
+
send_email(filtered_text, violation_image_path, ', '.join(violation_type), email)
|
207 |
+
else:
|
208 |
+
# Update the violations for the license plate if new ones are found
|
209 |
+
current_violations = set(violations_dict[filtered_text]) # no helmet
|
210 |
+
new_violations = set(violation_type) # red lane, no helmet
|
211 |
+
updated_violations = list(current_violations | new_violations) # red_lane, no helmet
|
212 |
+
|
213 |
+
# If new violations are found, update and send email
|
214 |
+
if updated_violations != violations_dict[filtered_text]:
|
215 |
+
violations_dict[filtered_text] = updated_violations
|
216 |
+
send_email(filtered_text, violation_image_path, ', '.join(updated_violations), email)
|
217 |
+
|
218 |
+
# Draw OCR text (English and Arabic) on the original frame
|
219 |
+
arabic_text = convert_to_arabic(filtered_text)
|
220 |
+
frame = draw_text_pil(frame, filtered_text, (x1, y2 + 30), font_path, font_size=30, color=(255, 255, 255))
|
221 |
+
frame = draw_text_pil(frame, arabic_text, (x1, y2 + 60), font_path, font_size=30, color=(0, 255, 0))
|
222 |
+
|
223 |
+
# Write the processed frame to the output video
|
224 |
+
out.write(frame)
|
225 |
+
|
226 |
+
# Release resources when done
|
227 |
+
cap.release()
|
228 |
out.release()
|
229 |
+
|
230 |
+
return output_video_path # Return the path of the processed video
|
231 |
+
|
232 |
|
233 |
|
234 |
|
|
|
305 |
server.sendmail(FROM_EMAIL, TO_EMAIL, msg.as_string())
|
306 |
print("Email with attachment sent successfully!")
|
307 |
|
308 |
+
def draw_text_pil(img, text, position, font_path, font_size, color):
|
309 |
+
img_pil = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
|
310 |
+
|
311 |
+
draw = ImageDraw.Draw(img_pil)
|
312 |
+
|
313 |
+
try:
|
314 |
+
font = ImageFont.truetype(font_path, size=font_size)
|
315 |
+
except IOError:
|
316 |
+
print(f"Font file not found at {font_path}. Using default font.")
|
317 |
+
font = ImageFont.load_default()
|
318 |
+
|
319 |
+
draw.text(position, text, font=font, fill=color)
|
320 |
+
|
321 |
+
img_np = cv2.cvtColor(np.array(img_pil), cv2.COLOR_RGB2BGR)
|
322 |
+
return img_np
|
323 |
+
|
324 |
+
import sqlite3
|
325 |
+
from datetime import datetime
|
326 |
+
from sqlalchemy import create_engine
|
327 |
+
from sqlalchemy.orm import sessionmaker
|
328 |
+
from sqlalchemy.ext.declarative import declarative_base
|
329 |
+
from sqlalchemy import Column, String, Integer, Boolean
|
330 |
+
|
331 |
+
def get_vehicle_information(detected_license_plate, lane_violation, no_helmet, image_link, city):
|
332 |
+
# Get current date and time
|
333 |
+
current_datetime = datetime.now()
|
334 |
+
current_date = current_datetime.strftime('%Y-%m-%d')
|
335 |
+
current_time = current_datetime.strftime('%H:%M:%S')
|
336 |
+
|
337 |
+
# Connect to SQLite database (motorbike_detections.db)
|
338 |
+
try:
|
339 |
+
motorbike_conn = sqlite3.connect('motorbike_detections.db')
|
340 |
+
motorbike_cursor = motorbike_conn.cursor()
|
341 |
+
|
342 |
+
# Check if the detection already exists
|
343 |
+
check_query = '''
|
344 |
+
SELECT DetectionID FROM MotorbikeDetections
|
345 |
+
WHERE LicensePlate = ? AND Date = ? AND Time = ? AND ImageLink = ?
|
346 |
+
'''
|
347 |
+
motorbike_cursor.execute(check_query, (detected_license_plate, current_date, current_time, image_link))
|
348 |
+
existing_detection = motorbike_cursor.fetchone()
|
349 |
+
|
350 |
+
if existing_detection:
|
351 |
+
print(f"Detection for license plate {detected_license_plate} at {current_date} {current_time} already exists.")
|
352 |
+
else:
|
353 |
+
# Insert the new detection record
|
354 |
+
insert_query = '''
|
355 |
+
INSERT INTO MotorbikeDetections (Date, Time, City, LicensePlate, LaneViolation, NoHelmet, ImageLink)
|
356 |
+
VALUES (?, ?, ?, ?, ?, ?, ?)
|
357 |
+
'''
|
358 |
+
motorbike_cursor.execute(insert_query, (
|
359 |
+
current_date,
|
360 |
+
current_time,
|
361 |
+
city,
|
362 |
+
detected_license_plate,
|
363 |
+
int(lane_violation), # Convert boolean to integer (1 or 0)
|
364 |
+
int(no_helmet), # Convert boolean to integer (1 or 0)
|
365 |
+
image_link
|
366 |
+
))
|
367 |
+
|
368 |
+
# Commit the transaction
|
369 |
+
motorbike_conn.commit()
|
370 |
+
print(f"Detection data for license plate {detected_license_plate} inserted successfully.")
|
371 |
+
|
372 |
+
except sqlite3.IntegrityError as e:
|
373 |
+
print(f"Integrity Error: {e}. This detection may already exist.")
|
374 |
+
except sqlite3.Error as e:
|
375 |
+
print(f"An error occurred while inserting detection data: {e}")
|
376 |
+
motorbike_conn.rollback()
|
377 |
+
finally:
|
378 |
+
# Close the motorbike detections database connection
|
379 |
+
motorbike_conn.close()
|
380 |
+
|
381 |
+
# Retrieve email from 'vehicle_information.db'
|
382 |
+
try:
|
383 |
+
# Create an engine and session for SQLAlchemy
|
384 |
+
engine = create_engine('sqlite:///vehicle_information.db')
|
385 |
+
Session = sessionmaker(bind=engine)
|
386 |
+
session = Session()
|
387 |
+
|
388 |
+
# Query the VehicleInformation table for the detected license plate
|
389 |
+
vehicle_info = session.query(VehicleInformation).filter_by(license_plate=detected_license_plate).first()
|
390 |
+
|
391 |
+
if vehicle_info:
|
392 |
+
print(f"Email found for license plate {detected_license_plate}: {vehicle_info.email}")
|
393 |
+
return vehicle_info.email
|
394 |
+
else:
|
395 |
+
print(f"No vehicle information found for license plate {detected_license_plate}.")
|
396 |
+
return None
|
397 |
+
|
398 |
+
except Exception as e:
|
399 |
+
print(f"An error occurred while retrieving vehicle information: {e}")
|
400 |
+
return None
|
401 |
+
|
402 |
+
finally:
|
403 |
+
# Close the SQLAlchemy session
|
404 |
+
session.close()
|
405 |
+
|
406 |
+
import sqlite3
|
407 |
+
from datetime import datetime
|
408 |
+
|
409 |
+
def setup_motorbike_detections_db():
|
410 |
+
# Connect to SQLite database (or create it if it doesn't exist)
|
411 |
+
conn = sqlite3.connect('motorbike_detections.db')
|
412 |
+
cursor = conn.cursor()
|
413 |
+
|
414 |
+
# Drop the old table if it exists, useful for restructuring
|
415 |
+
cursor.execute('DROP TABLE IF EXISTS MotorbikeDetections')
|
416 |
+
|
417 |
+
# Create the new table with a unique constraint on LicensePlate, Date, Time, and ImageLink
|
418 |
+
cursor.execute('''
|
419 |
+
CREATE TABLE IF NOT EXISTS MotorbikeDetections (
|
420 |
+
DetectionID INTEGER PRIMARY KEY AUTOINCREMENT,
|
421 |
+
Date DATE NOT NULL,
|
422 |
+
Time TIME NOT NULL,
|
423 |
+
City VARCHAR(100),
|
424 |
+
LicensePlate VARCHAR(100),
|
425 |
+
LaneViolation BOOLEAN NOT NULL,
|
426 |
+
NoHelmet BOOLEAN NOT NULL,
|
427 |
+
ImageLink VARCHAR(255),
|
428 |
+
UNIQUE(LicensePlate, Date, Time, ImageLink)
|
429 |
+
)
|
430 |
+
''')
|
431 |
+
|
432 |
+
# Commit changes
|
433 |
+
conn.commit()
|
434 |
+
|
435 |
+
# Close the connection
|
436 |
+
conn.close()
|
437 |
+
|
438 |
+
|
439 |
+
from sqlalchemy import create_engine, Column, String, Integer
|
440 |
+
from sqlalchemy.ext.declarative import declarative_base
|
441 |
+
from sqlalchemy.orm import sessionmaker
|
442 |
+
|
443 |
+
# Define the base for model creation
|
444 |
+
Base = declarative_base()
|
445 |
+
|
446 |
+
# Define the VehicleInformation table using SQLAlchemy ORM
|
447 |
+
class VehicleInformation(Base):
|
448 |
+
__tablename__ = 'VehicleInformation'
|
449 |
+
|
450 |
+
id = Column(Integer, primary_key=True, autoincrement=True)
|
451 |
+
license_plate = Column(String(100), unique=True, nullable=False)
|
452 |
+
email = Column(String(255), nullable=False)
|
453 |
+
phone_number = Column(String(15), nullable=False)
|
454 |
+
driver_id = Column(String(50), nullable=False) # Government ID or identity number
|
455 |
+
|
456 |
+
def setup_database():
|
457 |
+
# Create an SQLite database (this could be any database like PostgreSQL, MySQL, etc.)
|
458 |
+
engine = create_engine('sqlite:///vehicle_information.db')
|
459 |
+
|
460 |
+
# Drop existing tables and recreate them (for ensuring clean database on rerun)
|
461 |
+
Base.metadata.drop_all(engine)
|
462 |
+
Base.metadata.create_all(engine)
|
463 |
+
|
464 |
+
# Create a session to interact with the database
|
465 |
+
Session = sessionmaker(bind=engine)
|
466 |
+
session = Session()
|
467 |
+
|
468 |
+
# Insert some dummy data into the VehicleInformation table
|
469 |
+
vehicle_data_1 = VehicleInformation(
|
470 |
+
license_plate='1234 AB',
|
471 |
+
email='[email protected]',
|
472 |
+
phone_number='0559947203',
|
473 |
+
driver_id='ID1110000000'
|
474 |
+
)
|
475 |
+
|
476 |
+
vehicle_data_2 = VehicleInformation(
|
477 |
+
license_plate='3321 AR',
|
478 |
+
email='[email protected]',
|
479 |
+
phone_number='0539003545',
|
480 |
+
driver_id='ID2220000000'
|
481 |
+
)
|
482 |
+
|
483 |
+
# Add records to the session
|
484 |
+
session.add(vehicle_data_1)
|
485 |
+
session.add(vehicle_data_2)
|
486 |
+
|
487 |
+
# Commit the records to the database
|
488 |
+
session.commit()
|
489 |
+
|
490 |
+
# Query the table to confirm data
|
491 |
+
vehicles = session.query(VehicleInformation).all()
|
492 |
+
|
493 |
+
# Close the session
|
494 |
+
session.close()
|
495 |
|
496 |
# Streamlit app main function
|
497 |
def main():
|