OmidSakaki's picture
Update app.py
dcd89aa verified
raw
history blame
3.42 kB
import streamlit as st
import tensorflow as tf
from PIL import ImageOps, Image
import os # To work with operation system commands
import cv2 # To process images
import random # to generate random choices
import warnings # To avoid python warnings
import numpy as np # To work with arrays
import pandas as pd # To work with DataFrames
import seaborn as sns # To visualizations
from tqdm import tqdm # To represent progress bars
from ultralytics import YOLO # To Create Yolo model
from termcolor import colored # To colorfull outputs
import matplotlib.pyplot as plt # To visualizations
# Load the two trained YOLO models, in order:
#   best_model_plate  <- 'best.pt'   (used by Detect_Plate to find the plate)
#   best_model_digits <- 'best2.pt'  (used by Detect_Digits to find its characters)
best_model_plate, best_model_digits = (
    YOLO(weights) for weights in ('best.pt', 'best2.pt')
)
def Detect_Plate(img):
    '''
    Run the plate detector on a car image.

    Parameters
    ----------
    img : image accepted by ``best_model_plate.predict``
        FINAL passes an RGB PIL image.

    Returns
    -------
    result1 : list
        The prediction results from ``predict``; ``result1[0]`` holds the
        detections for ``img``.
    pred1_path : str
        Path of the most recently modified entry in the predictor's save
        directory -- presumably the annotated image written by this call
        (``save=True``).
    '''
    result1 = best_model_plate.predict([img], save=True, iou=0.7)
    # Directory the annotated prediction image is saved to.
    save_dir = result1[0].save_dir
    # os.listdir returns entries in arbitrary order, so list position says
    # nothing about which file is new; choose the newest by mtime instead.
    saved = [os.path.join(save_dir, name) for name in os.listdir(save_dir)]
    pred1_path = max(saved, key=os.path.getmtime)
    return result1, pred1_path
def Detect_Digits(img2):
    '''
    Run the character detector on a cropped plate image.

    Parameters
    ----------
    img2 : image accepted by ``best_model_digits.predict``
        The plate crop prepared by FINAL.

    Returns
    -------
    result2 : list
        The prediction results from ``predict``; ``result2[0]`` holds the
        character detections (consumed by ``Sort_Digits``).
    pred2_path : str
        Path of the most recently modified entry in the predictor's save
        directory -- presumably the annotated image written by this call
        (``save=True``; confidences are not drawn because ``show_conf=False``).
    '''
    result2 = best_model_digits.predict([img2], save=True, iou=0.7, show_conf=False)
    # Directory the annotated prediction image is saved to.
    save_dir = result2[0].save_dir
    # os.listdir returns entries in arbitrary order, so list position says
    # nothing about which file is new; choose the newest by mtime instead.
    saved = [os.path.join(save_dir, name) for name in os.listdir(save_dir)]
    pred2_path = max(saved, key=os.path.getmtime)
    return result2, pred2_path
def Sort_Digits(result2, class_names=None):
    '''
    Decode the character detections and order them from left to right.

    Parameters
    ----------
    result2 : list
        Output of ``Detect_Digits``; only ``result2[0].boxes`` is read
        (its ``cls`` and ``xyxy`` entries, one per detection).
    class_names : mapping, optional
        Maps a class index written as a string (e.g. ``'3'``) to its label.
        When omitted, the module-level ``classes_dict`` is used; that name is
        not defined in the visible file, so it must be provided elsewhere.

    Returns
    -------
    list
        The labels ordered by the left x-coordinate of their boxes (ties
        broken by label). When exactly 7 labels are found, ``'0'`` is
        appended.
    '''
    # Resolve the global only when no mapping is supplied, so callers that
    # pass ``class_names`` do not depend on ``classes_dict`` existing.
    names = classes_dict if class_names is None else class_names
    boxes = result2[0].boxes
    # (left x, label) pairs; tuple ordering sorts by x first, then label.
    items = sorted(
        (int(xyxy[0]), names[str(int(cls))])
        for cls, xyxy in zip(boxes.cls, boxes.xyxy)
    )
    digits = [label for _, label in items]
    # NOTE(review): presumably pads a 7-character read to a full 8-character
    # plate with '0' -- confirm the intent.
    if len(digits) == 7:
        digits.append('0')
    return digits
# Image path handed to Plot_Result by FINAL (presumably where the plate plot
# is written -- confirm against Plot_Result's definition).
raw_plate = 'raw_plate.png'
def FINAL(img):
    '''
    Full pipeline for phase 3: read a car image and return the plate characters.

    Parameters
    ----------
    img : path or file-like object
        Anything ``PIL.Image.open`` accepts (the Streamlit UI passes the
        uploaded file object).

    Returns
    -------
    list
        The plate characters ordered left to right, as produced by
        ``Sort_Digits``. An empty list is returned when no plate is detected.
    '''
    # STEP-1: read the car image as RGB. It stays a PIL image; Ultralytics
    # treats PIL input as RGB, whereas numpy arrays are assumed to be BGR.
    car = Image.open(img).convert('RGB')

    # STEP-2: detect the plate.
    result1, _ = Detect_Plate(car)

    # STEP-3: plate coordinates (x1, y1, x2, y2) of the first detected box.
    plate_boxes = result1[0].boxes.xyxy.tolist()
    if not plate_boxes:
        # No plate found: nothing to read (previously an IndexError).
        return []
    x1, y1, x2, y2 = plate_boxes[0]

    # Crop the plate. PIL images cannot be sliced like arrays, so use
    # Image.crop with a (left, upper, right, lower) box.
    plate = car.crop((round(x1), round(y1), round(x2), round(y2)))

    # STEP-4: resize the crop for the second model. The original called
    # tf.image.resize(..., (120, 70)), i.e. height=120, width=70, which yields
    # a float TF tensor YOLO cannot consume. PIL's resize takes (width, height)
    # and returns a uint8 PIL image, so the same geometry is (70, 120);
    # bilinear matches tf.image.resize's default method.
    # NOTE(review): a 70 px wide plate is unusually narrow -- confirm this is
    # the size the digit model was trained on.
    plate = plate.resize((70, 120), Image.BILINEAR)

    # Second prediction: detect the characters on the plate.
    result2, _ = Detect_Digits(plate)

    # STEP-5: sort the detected characters left to right.
    digits = Sort_Digits(result2)

    # Plot_Result is not defined in this file; call it only when it exists so
    # the pipeline does not end in a NameError.
    plot_result = globals().get('Plot_Result')
    if plot_result is not None:
        plot_result(car, raw_plate, digits)

    # Return the characters so the caller (st.write) can display them.
    return digits
# ---- Streamlit page: title, upload prompt, and results ----
st.title('Persian Plates Digits Detection')
st.header('Please upload image')

# The uploader yields None until the user picks a file; only then run the pipeline.
if (img := st.file_uploader('', type=['jpeg', 'jpg', 'png'])) is not None:
    # Preview the upload, then display whatever FINAL returns for it.
    image = Image.open(img).convert('RGB')
    st.image(image, use_column_width=True)
    st.write(FINAL(img))