OmidSakaki commited on
Commit
aa1c443
·
verified ·
1 Parent(s): a47807e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +104 -0
app.py CHANGED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from PIL import ImageOps, Image
3
+ import numpy as np
4
+ from tensorflow.keras.models import load_model
5
+ import tensorflow as tf
6
+
7
+ #load weights
8
+ best_model_plate = YOLO('best.pt')
9
+ best_model_digits = YOLO('best2.pt')
10
+
11
+ def Detect_Plate(img):
12
+ '''
13
+ Predict on a car image and return result folder and predicted image path.
14
+ '''
15
+ result1 = best_model_plate.predict([img], save=True, iou=0.7)
16
+
17
+ # Location of saved predicted images
18
+ result_path1 = result1[0].save_dir
19
+ pred1_path = os.path.join(result_path1, os.listdir(result_path1)[-1])
20
+ return result1, pred1_path
21
+
22
+ def Detect_Digits(img2) :
23
+ '''
24
+ Predict on a croped plate and return result folder and predicted image.
25
+ '''
26
+ result2 = best_model_digits.predict([img2], save=True, iou=0.7, show_conf=False)
27
+
28
+ # Location of saved predicted images
29
+ result_path = result2[0].save_dir
30
+ pred2_path = os.path.join(result_path, os.listdir(result_path)[0])
31
+ return result2, pred2_path
32
+
33
+ def Sort_Digits(result2) :
34
+ '''
35
+ Sort detected labels base on their X-cordinate,
36
+ --- Sort from lef to right ---
37
+ '''
38
+ list_of_items = []
39
+ for i in range(len(result2[0])) :
40
+ # Class labels decoded
41
+ class_label = classes_dict[str(int(result2[0][i].boxes.cls))]
42
+ # X cordinate to find the situation on X-axis
43
+ cord_x = int(result2[0][i].boxes.xyxy[0][0])
44
+ list_of_items.append((cord_x, class_label))
45
+
46
+ list_of_items.sort()
47
+ list_of_items
48
+
49
+ digits = []
50
+ for digit in list_of_items :
51
+ digits.append(digit[1])
52
+ if len(digits) == 7 : digits.append('0')
53
+ return digits
54
+
55
+
56
+ raw_plate = ('raw_plate.png')
57
+
58
+ def FINAL(img_path) :
59
+ '''
60
+ A pipeline for all parts of phase 3.
61
+ start with a car image.
62
+ result is digits and char on car plate.
63
+ '''
64
+ # Read car image ( STEP-1 )
65
+ img = cv2.imread(img_path)
66
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
67
+
68
+ # First prediction -> Detect car-plate ( STEP-2 )
69
+ result1, _ = Detect_Plate(img)
70
+
71
+ # Plate Cordinates ( STEP-3 )
72
+ pts = result1[0].boxes.xyxy.tolist()[0]
73
+
74
+ # Crop plate
75
+ img2 = img[round(pts[1]):round(pts[3]), round(pts[0]):round(pts[2])]
76
+
77
+ # Resize plate to feed to second model ( STEP-4 )
78
+ img2 = cv2.resize(img2, dsize=(120, 70))
79
+
80
+ # Second prediction -> Detect digits in plate
81
+ result2, _ = Detect_Digits(img2)
82
+
83
+ # Sort detected digits ( STEP-5 )
84
+ digits = Sort_Digits(result2)
85
+
86
+ Plot_Result(img, raw_plate, digits)
87
+
88
+ # set title
89
+ st.title('Persian Plates Digits Detection')
90
+
91
+ # set header
92
+ st.header('Please upload image')
93
+
94
+ # upload file
95
+ img = st.file_uploader('', type=['jpeg', 'jpg', 'png'])
96
+
97
+ # display image
98
+ if img is not None:
99
+ image = Image.open(img).convert('RGB')
100
+ st.image(img, use_column_width=True)
101
+
102
+
103
+ # write classification
104
+ st.write("## {}".format(FINAL(img)))