TheKnight115 commited on
Commit
26f78a4
·
verified ·
1 Parent(s): 4de3665

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +317 -31
app.py CHANGED
@@ -46,6 +46,8 @@ class_colors = {
46
  processor = AutoProcessor.from_pretrained("stepfun-ai/GOT-OCR2_0", trust_remote_code=True)
47
  model_ocr = AutoModel.from_pretrained("stepfun-ai/GOT-OCR2_0", trust_remote_code=True).to('cuda')
48
 
 
 
49
 
50
  # YOLO inference function
51
  def run_yolo(image):
@@ -90,46 +92,143 @@ def process_image(uploaded_file):
90
 
91
  # Process and save uploaded videos
92
  @st.cache_data
 
93
  def process_video_and_save(uploaded_file):
94
- with tempfile.NamedTemporaryFile(delete=False, suffix='.mp4') as temp_file:
95
- temp_file.write(uploaded_file.read())
96
- temp_file_path = temp_file.name
97
 
98
- video = cv2.VideoCapture(temp_file_path)
99
- total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
100
- frames = []
101
- current_frame = 0
102
- start_time = time.time()
103
 
104
- progress_bar = st.progress(0)
105
- progress_text = st.empty()
106
 
107
- while True:
108
- ret, frame = video.read()
109
- if not ret:
110
- break
111
- results = run_yolo(frame)
112
- processed_frame = process_results(results, frame)
113
- processed_frame_rgb = cv2.cvtColor(processed_frame, cv2.COLOR_BGR2RGB)
114
- frames.append(processed_frame_rgb)
115
 
116
- current_frame += 1
117
- progress_percentage = int((current_frame / total_frames) * 100)
118
- progress_bar.progress(progress_percentage)
119
- progress_text.text(f"Processing frame {current_frame}/{total_frames} ({progress_percentage}%)")
120
 
121
- video.release()
122
- output_path = 'processed_video.mp4'
123
- height, width, _ = frames[0].shape
124
- out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), 30, (width, height))
 
 
 
 
 
 
 
 
 
 
125
 
126
- for frame in frames:
127
- frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
128
- out.write(frame_bgr)
129
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
  out.release()
131
-
132
- return output_path
 
133
 
134
 
135
 
@@ -206,6 +305,193 @@ def send_email(license_text, violation_image_path, violation_type):
206
  server.sendmail(FROM_EMAIL, TO_EMAIL, msg.as_string())
207
  print("Email with attachment sent successfully!")
208
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
209
 
210
  # Streamlit app main function
211
  def main():
 
46
  processor = AutoProcessor.from_pretrained("stepfun-ai/GOT-OCR2_0", trust_remote_code=True)
47
  model_ocr = AutoModel.from_pretrained("stepfun-ai/GOT-OCR2_0", trust_remote_code=True).to('cuda')
48
 
49
+ # Define lane area coordinates (example coordinates)
50
+ red_lane = np.array([[2, 1583], [1, 1131], [1828, 1141], [1912, 1580]], np.int32)
51
 
52
  # YOLO inference function
53
  def run_yolo(image):
 
92
 
93
  # Process and save uploaded videos
94
  @st.cache_data
95
+ # Define the function to process the video
96
  def process_video_and_save(uploaded_file):
97
+ # Path for Arabic font
98
+ font_path = "/kaggle/input/fontss/alfont_com_arial-1.ttf"
 
99
 
100
+ # Paths for saving violation images
101
+ violation_image_path = '/kaggle/working/violation.jpg'
 
 
 
102
 
103
+ # Track emails already sent to avoid duplicate emails
104
+ sent_emails = {}
105
 
106
+ # Dictionary to track violations per license plate
107
+ violations_dict = {}
 
 
 
 
 
 
108
 
109
+ # Video path (input)
110
+ video_path = "/kaggle/working/uploaded_video.mp4" # Save the uploaded video file to this path
111
+ with open(video_path, "wb") as f:
112
+ f.write(uploaded_file.getbuffer())
113
 
114
+ cap = cv2.VideoCapture(video_path)
115
+
116
+ # Check if the video file opened successfully
117
+ if not cap.isOpened():
118
+ print("Error opening video file")
119
+ return None
120
+
121
+ # Define codec and output video settings
122
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
123
+ output_video_path = '/kaggle/working/output_violation.mp4'
124
+ fps = cap.get(cv2.CAP_PROP_FPS)
125
+ width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) # Frame width
126
+ height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) # Frame height
127
+ out = cv2.VideoWriter(output_video_path, fourcc, fps, (width, height))
128
 
129
+ margin_y = 50
 
 
130
 
131
+ # Process the video frame by frame
132
+ while cap.isOpened():
133
+ ret, frame = cap.read()
134
+ if not ret:
135
+ break # End of video
136
+
137
+ # Draw the red lane rectangle on each frame
138
+ cv2.polylines(frame, [red_lane], isClosed=True, color=(0, 0, 255), thickness=3) # Red lane
139
+
140
+ # Perform detection using YOLO on the current frame
141
+ results = model.track(frame)
142
+
143
+ # Process each detection in the results
144
+ for box in results[0].boxes:
145
+ x1, y1, x2, y2 = map(int, box.xyxy[0].cpu().numpy()) # Bounding box coordinates
146
+ label = model.names[int(box.cls)] # Class name (MotorbikeDelivery, Helmet, etc.)
147
+ color = (255, 0, 0) # Use a fixed color for bounding boxes
148
+ confidence = box.conf[0].item()
149
+
150
+ # Initialize flags and variables for the violations
151
+ helmet_violation = False
152
+ lane_violation = False
153
+ violation_type = []
154
+
155
+ # Draw bounding box around detected object
156
+ cv2.rectangle(frame, (x1, y1), (x2, y2), color, 3) # 3 is the thickness of the rectangle
157
+
158
+ # Add label to the box (e.g., 'MotorbikeDelivery')
159
+ cv2.putText(frame, f'{label}: {confidence:.2f}', (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.6, color, 2)
160
+
161
+ # Detect MotorbikeDelivery
162
+ if label == 'MotorbikeDelivery' and confidence >= 0.4:
163
+ motorbike_crop = frame[max(0, y1 - margin_y):y2, x1:x2]
164
+ delivery_center = ((x1 + x2) // 2, (y2))
165
+ in_red_lane = cv2.pointPolygonTest(red_lane, delivery_center, False)
166
+ if in_red_lane >= 0:
167
+ lane_violation = True
168
+ violation_type.append("In Red Lane")
169
+
170
+ # Perform detection within the cropped motorbike region
171
+ sub_results = model(motorbike_crop)
172
+
173
+ for result in sub_results[0].boxes:
174
+ sub_x1, sub_y1, sub_x2, sub_y2 = map(int, result.xyxy[0].cpu().numpy()) # Bounding box coordinates
175
+ sub_label = model.names[int(result.cls)]
176
+ sub_color = (255, 0, 0) # Red color for the bounding box of sub-objects
177
+
178
+ # Draw bounding box around sub-detected objects (No_Helmet, License_plate, etc.)
179
+ cv2.rectangle(motorbike_crop, (sub_x1, sub_y1), (sub_x2, sub_y2), sub_color, 2)
180
+ cv2.putText(motorbike_crop, sub_label, (sub_x1, sub_y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.6, sub_color, 2)
181
+
182
+ if sub_label == 'No_Helmet':
183
+ helmet_violation = True
184
+ violation_type.append("No Helmet")
185
+ continue
186
+ if sub_label == 'License_plate':
187
+ license_crop = motorbike_crop[sub_y1:sub_y2, sub_x1:sub_x2]
188
+
189
+ # Apply OCR if a violation is detected
190
+ if helmet_violation or lane_violation:
191
+ # Perform OCR on the license plate
192
+ cv2.imwrite(violation_image_path, frame)
193
+ license_plate_pil = Image.fromarray(cv2.cvtColor(license_crop, cv2.COLOR_BGR2RGB))
194
+ temp_image_path = '/kaggle/working/license_plate.png'
195
+ license_plate_pil.save(temp_image_path)
196
+ license_plate_text = model_ocr.chat(processor, temp_image_path, ocr_type='ocr')
197
+ filtered_text = filter_license_plate_text(license_plate_text)
198
+ # Check if the license plate is already detected and saved
199
+ if filtered_text:
200
+ # Get the email from the database
201
+ email = get_vehicle_information(filtered_text, lane_violation, helmet_violation, violation_image_path, "Riyadh")
202
+
203
+ # Add the license plate and its violations to the violations dictionary
204
+ if filtered_text not in violations_dict:
205
+ violations_dict[filtered_text] = violation_type #{"1234AB":[no_Helmet,In_red_Lane]}
206
+ send_email(filtered_text, violation_image_path, ', '.join(violation_type), email)
207
+ else:
208
+ # Update the violations for the license plate if new ones are found
209
+ current_violations = set(violations_dict[filtered_text]) # no helmet
210
+ new_violations = set(violation_type) # red lane, no helmet
211
+ updated_violations = list(current_violations | new_violations) # red_lane, no helmet
212
+
213
+ # If new violations are found, update and send email
214
+ if updated_violations != violations_dict[filtered_text]:
215
+ violations_dict[filtered_text] = updated_violations
216
+ send_email(filtered_text, violation_image_path, ', '.join(updated_violations), email)
217
+
218
+ # Draw OCR text (English and Arabic) on the original frame
219
+ arabic_text = convert_to_arabic(filtered_text)
220
+ frame = draw_text_pil(frame, filtered_text, (x1, y2 + 30), font_path, font_size=30, color=(255, 255, 255))
221
+ frame = draw_text_pil(frame, arabic_text, (x1, y2 + 60), font_path, font_size=30, color=(0, 255, 0))
222
+
223
+ # Write the processed frame to the output video
224
+ out.write(frame)
225
+
226
+ # Release resources when done
227
+ cap.release()
228
  out.release()
229
+
230
+ return output_video_path # Return the path of the processed video
231
+
232
 
233
 
234
 
 
305
  server.sendmail(FROM_EMAIL, TO_EMAIL, msg.as_string())
306
  print("Email with attachment sent successfully!")
307
 
308
+ def draw_text_pil(img, text, position, font_path, font_size, color):
309
+ img_pil = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
310
+
311
+ draw = ImageDraw.Draw(img_pil)
312
+
313
+ try:
314
+ font = ImageFont.truetype(font_path, size=font_size)
315
+ except IOError:
316
+ print(f"Font file not found at {font_path}. Using default font.")
317
+ font = ImageFont.load_default()
318
+
319
+ draw.text(position, text, font=font, fill=color)
320
+
321
+ img_np = cv2.cvtColor(np.array(img_pil), cv2.COLOR_RGB2BGR)
322
+ return img_np
323
+
324
+ import sqlite3
325
+ from datetime import datetime
326
+ from sqlalchemy import create_engine
327
+ from sqlalchemy.orm import sessionmaker
328
+ from sqlalchemy.ext.declarative import declarative_base
329
+ from sqlalchemy import Column, String, Integer, Boolean
330
+
331
+ def get_vehicle_information(detected_license_plate, lane_violation, no_helmet, image_link, city):
332
+ # Get current date and time
333
+ current_datetime = datetime.now()
334
+ current_date = current_datetime.strftime('%Y-%m-%d')
335
+ current_time = current_datetime.strftime('%H:%M:%S')
336
+
337
+ # Connect to SQLite database (motorbike_detections.db)
338
+ try:
339
+ motorbike_conn = sqlite3.connect('motorbike_detections.db')
340
+ motorbike_cursor = motorbike_conn.cursor()
341
+
342
+ # Check if the detection already exists
343
+ check_query = '''
344
+ SELECT DetectionID FROM MotorbikeDetections
345
+ WHERE LicensePlate = ? AND Date = ? AND Time = ? AND ImageLink = ?
346
+ '''
347
+ motorbike_cursor.execute(check_query, (detected_license_plate, current_date, current_time, image_link))
348
+ existing_detection = motorbike_cursor.fetchone()
349
+
350
+ if existing_detection:
351
+ print(f"Detection for license plate {detected_license_plate} at {current_date} {current_time} already exists.")
352
+ else:
353
+ # Insert the new detection record
354
+ insert_query = '''
355
+ INSERT INTO MotorbikeDetections (Date, Time, City, LicensePlate, LaneViolation, NoHelmet, ImageLink)
356
+ VALUES (?, ?, ?, ?, ?, ?, ?)
357
+ '''
358
+ motorbike_cursor.execute(insert_query, (
359
+ current_date,
360
+ current_time,
361
+ city,
362
+ detected_license_plate,
363
+ int(lane_violation), # Convert boolean to integer (1 or 0)
364
+ int(no_helmet), # Convert boolean to integer (1 or 0)
365
+ image_link
366
+ ))
367
+
368
+ # Commit the transaction
369
+ motorbike_conn.commit()
370
+ print(f"Detection data for license plate {detected_license_plate} inserted successfully.")
371
+
372
+ except sqlite3.IntegrityError as e:
373
+ print(f"Integrity Error: {e}. This detection may already exist.")
374
+ except sqlite3.Error as e:
375
+ print(f"An error occurred while inserting detection data: {e}")
376
+ motorbike_conn.rollback()
377
+ finally:
378
+ # Close the motorbike detections database connection
379
+ motorbike_conn.close()
380
+
381
+ # Retrieve email from 'vehicle_information.db'
382
+ try:
383
+ # Create an engine and session for SQLAlchemy
384
+ engine = create_engine('sqlite:///vehicle_information.db')
385
+ Session = sessionmaker(bind=engine)
386
+ session = Session()
387
+
388
+ # Query the VehicleInformation table for the detected license plate
389
+ vehicle_info = session.query(VehicleInformation).filter_by(license_plate=detected_license_plate).first()
390
+
391
+ if vehicle_info:
392
+ print(f"Email found for license plate {detected_license_plate}: {vehicle_info.email}")
393
+ return vehicle_info.email
394
+ else:
395
+ print(f"No vehicle information found for license plate {detected_license_plate}.")
396
+ return None
397
+
398
+ except Exception as e:
399
+ print(f"An error occurred while retrieving vehicle information: {e}")
400
+ return None
401
+
402
+ finally:
403
+ # Close the SQLAlchemy session
404
+ session.close()
405
+
406
+ import sqlite3
407
+ from datetime import datetime
408
+
409
+ def setup_motorbike_detections_db():
410
+ # Connect to SQLite database (or create it if it doesn't exist)
411
+ conn = sqlite3.connect('motorbike_detections.db')
412
+ cursor = conn.cursor()
413
+
414
+ # Drop the old table if it exists, useful for restructuring
415
+ cursor.execute('DROP TABLE IF EXISTS MotorbikeDetections')
416
+
417
+ # Create the new table with a unique constraint on LicensePlate, Date, Time, and ImageLink
418
+ cursor.execute('''
419
+ CREATE TABLE IF NOT EXISTS MotorbikeDetections (
420
+ DetectionID INTEGER PRIMARY KEY AUTOINCREMENT,
421
+ Date DATE NOT NULL,
422
+ Time TIME NOT NULL,
423
+ City VARCHAR(100),
424
+ LicensePlate VARCHAR(100),
425
+ LaneViolation BOOLEAN NOT NULL,
426
+ NoHelmet BOOLEAN NOT NULL,
427
+ ImageLink VARCHAR(255),
428
+ UNIQUE(LicensePlate, Date, Time, ImageLink)
429
+ )
430
+ ''')
431
+
432
+ # Commit changes
433
+ conn.commit()
434
+
435
+ # Close the connection
436
+ conn.close()
437
+
438
+
439
+ from sqlalchemy import create_engine, Column, String, Integer
440
+ from sqlalchemy.ext.declarative import declarative_base
441
+ from sqlalchemy.orm import sessionmaker
442
+
443
+ # Define the base for model creation
444
+ Base = declarative_base()
445
+
446
+ # Define the VehicleInformation table using SQLAlchemy ORM
447
+ class VehicleInformation(Base):
448
+ __tablename__ = 'VehicleInformation'
449
+
450
+ id = Column(Integer, primary_key=True, autoincrement=True)
451
+ license_plate = Column(String(100), unique=True, nullable=False)
452
+ email = Column(String(255), nullable=False)
453
+ phone_number = Column(String(15), nullable=False)
454
+ driver_id = Column(String(50), nullable=False) # Government ID or identity number
455
+
456
+ def setup_database():
457
+ # Create an SQLite database (this could be any database like PostgreSQL, MySQL, etc.)
458
+ engine = create_engine('sqlite:///vehicle_information.db')
459
+
460
+ # Drop existing tables and recreate them (for ensuring clean database on rerun)
461
+ Base.metadata.drop_all(engine)
462
+ Base.metadata.create_all(engine)
463
+
464
+ # Create a session to interact with the database
465
+ Session = sessionmaker(bind=engine)
466
+ session = Session()
467
+
468
+ # Insert some dummy data into the VehicleInformation table
469
+ vehicle_data_1 = VehicleInformation(
470
+ license_plate='1234 AB',
471
+ email='[email protected]',
472
+ phone_number='0559947203',
473
+ driver_id='ID1110000000'
474
+ )
475
+
476
+ vehicle_data_2 = VehicleInformation(
477
+ license_plate='3321 AR',
478
+ email='[email protected]',
479
+ phone_number='0539003545',
480
+ driver_id='ID2220000000'
481
+ )
482
+
483
+ # Add records to the session
484
+ session.add(vehicle_data_1)
485
+ session.add(vehicle_data_2)
486
+
487
+ # Commit the records to the database
488
+ session.commit()
489
+
490
+ # Query the table to confirm data
491
+ vehicles = session.query(VehicleInformation).all()
492
+
493
+ # Close the session
494
+ session.close()
495
 
496
  # Streamlit app main function
497
  def main():