ibaiGorordo committed
Commit fa14840 · 1 Parent(s): 7bcacfa

initial commit

app.py ADDED
@@ -0,0 +1,35 @@
+ import gradio as gr
+ import cv2
+ import numpy as np
+ from PIL import Image
+ from lstr import LSTR
+ model_path = "models/model_float32.onnx"
+
+ title = "Lane Shape Prediction with Transformers (LSTR)"
+ description = "Demo for performing lane detection using the LSTR model. To use it, simply upload your image, or click one of the examples to load them. Read more at the links below."
+ article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2011.04233'>End-to-end Lane Shape Prediction with Transformers</a> | <a href='https://github.com/liuruijin17/LSTR'>Original Model</a></p>"
+
+ # Initialize lane detection model
+ lane_detector = LSTR(model_path)
+
+ def inference(image):
+     image = np.array(image, dtype=np.uint8)
+     input_img = image.copy()
+     detected_lanes, lane_ids = lane_detector.detect_lanes(input_img)
+     output_img = lane_detector.draw_lanes(image)
+
+     # output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2RGB)
+     im_pil = Image.fromarray(output_img)
+     return im_pil
+
+ gr.Interface(
+     inference,
+     [gr.inputs.Image(type="pil", label="Input")],
+     gr.outputs.Image(type="pil", label="Output"),
+     title=title,
+     description=description,
+     article=article,
+     examples=[
+         ["dog_road.jpg"],
+         ["swiss_road.jpeg"]
+     ]).launch()
dog_road.jpg ADDED
lstr/__init__.py ADDED
@@ -0,0 +1 @@
+ from lstr.lstr import LSTR
lstr/lstr.py ADDED
@@ -0,0 +1,159 @@
+ import sys
+ import cv2
+ import time
+ import numpy as np
+ import onnxruntime
+ print(onnxruntime.get_device())
+
+ lane_colors = [(249,65,68),(243,114,44),(248,150,30),(249,132,74),(249,199,79),(144,190,109),(77, 144, 142),(39, 125, 161)]
+ log_space = np.logspace(0,2, 50, base=1/10, endpoint=True)
+
+ class LSTR():
+
+     def __init__(self, model_path):
+
+         # Initialize model
+         self.model = self.initialize_model(model_path)
+
+     def __call__(self, image):
+
+         return self.detect_lanes(image)
+
+     def initialize_model(self, model_path):
+
+         opts = onnxruntime.SessionOptions()
+         opts.intra_op_num_threads = 16
+         self.session = onnxruntime.InferenceSession(model_path,sess_options=opts)
+
+         # Get model info
+         self.getModel_input_details()
+         self.getModel_output_details()
+
+     def detect_lanes(self, image):
+
+         input_tensor, mask_tensor = self.prepare_inputs(image)
+
+         outputs = self.inference(input_tensor, mask_tensor)
+
+         detected_lanes, good_lanes = self.process_output(outputs)
+
+         return detected_lanes, good_lanes
+
+     def prepare_inputs(self, img):
+
+         self.img_height, self.img_width, self.img_channels = img.shape
+
+         # Transform the image for inference
+         # img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+         img = cv2.resize(img,(self.input_width, self.input_height))
+
+         # Scale input pixel values to -1 to 1
+         mean=[0.485, 0.456, 0.406]
+         std=[0.229, 0.224, 0.225]
+
+         img = ((img/ 255.0 - mean) / std)
+         # img = img/ 255.0
+
+         img = img.transpose(2, 0, 1)
+         input_tensor = img[np.newaxis,:,:,:].astype(np.float32)
+
+         mask_tensor = np.zeros((1, 1, self.input_height, self.input_width), dtype=np.float32)
+
+         return input_tensor, mask_tensor
+
+     def inference(self, input_tensor, mask_tensor):
+         start = time.time()
+         outputs = self.session.run(self.output_names, {self.rgb_input_name: input_tensor,
+                                                        self.mask_input_name: mask_tensor})
+         # print(time.time() - start)
+         return outputs
+
+     @staticmethod
+     def softmax(x):
+         """Compute softmax values for each sets of scores in x."""
+         e_x = np.exp(x - np.max(x))
+         return e_x / e_x.sum(axis=-1).T
+
+     def process_output(self, outputs):
+
+         pred_logits = outputs[0]
+         pred_curves = outputs[1]
+
+         # Filter good lanes based on the probability
+         prob = self.softmax(pred_logits)
+         good_detections = np.where(np.argmax(prob,axis=-1)==1)
+
+         pred_logits = pred_logits[good_detections]
+         pred_curves = pred_curves[good_detections]
+
+         lanes = []
+         for lane_data in pred_curves:
+             bounds = lane_data[:2]
+             k_2, f_2, m_2, n_1, b_2, b_3 = lane_data[2:]
+
+             # Calculate the points for the lane
+             y_norm = bounds[0]+log_space*(bounds[1]-bounds[0])
+             x_norm = (k_2 / (y_norm - f_2) ** 2 + m_2 / (y_norm - f_2) + n_1 + b_2 * y_norm - b_3)
+             lane_points = np.vstack((x_norm*self.img_width, y_norm*self.img_height)).astype(int)
+
+             lanes.append(lane_points)
+
+         self.lanes = lanes
+         self.good_lanes = good_detections[1]
+
+         return lanes, self.good_lanes
+
+     def getModel_input_details(self):
+
+         model_inputs = self.session.get_inputs()
+         self.rgb_input_name = self.session.get_inputs()[0].name
+         self.mask_input_name = self.session.get_inputs()[1].name
+
+         self.input_shape = self.session.get_inputs()[0].shape
+         self.input_height = self.input_shape[2]
+         self.input_width = self.input_shape[3]
+
+     def getModel_output_details(self):
+
+         model_outputs = self.session.get_outputs()
+         self.output_names = [model_outputs[i].name for i in range(len(model_outputs))]
+         # print(self.output_names)
+
+     def draw_lanes(self,input_img):
+
+         # Write the detected line points in the image
+         visualization_img = input_img.copy()
+
+         # Draw a mask for the current lane
+         right_lane = np.where(self.good_lanes==0)[0]
+         left_lane = np.where(self.good_lanes==5)[0]
+
+         if(len(left_lane) and len(right_lane)):
+
+             lane_segment_img = visualization_img.copy()
+
+             points = np.vstack((self.lanes[left_lane[0]].T,
+                                 np.flipud(self.lanes[right_lane[0]].T)))
+             cv2.fillConvexPoly(lane_segment_img, points, color =(0,191,255))
+             visualization_img = cv2.addWeighted(visualization_img, 0.7, lane_segment_img, 0.3, 0)
+
+         for lane_num,lane_points in zip(self.good_lanes, self.lanes):
+             for lane_point in lane_points.T:
+                 cv2.circle(visualization_img, (lane_point[0],lane_point[1]), 3, lane_colors[lane_num], -1)
+
+         return visualization_img
+
+ if __name__ == '__main__':
+     model_path='../models/model_float32.onnx'
+     lane_detector = LSTR(model_path)
+
+     img = cv2.imread("../dog_road.jpg")
+     detected_lanes, lane_ids = lane_detector(img)
+     print(lane_ids)
+
+     lane_img = lane_detector.draw_lanes(img)
+     cv2.namedWindow("Detected lanes", cv2.WINDOW_NORMAL)
+     cv2.imshow("Detected lanes",lane_img)
+     cv2.imwrite("out.jpg", lane_img)
+     cv2.waitKey(0)
+
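
For reference, here is a minimal usage sketch of the LSTR class committed above (not part of the commit). It assumes the script is run from the repository root, so the model and example image paths match the ones used in app.py, and it writes the result to disk rather than opening a window, since requirements.txt pins opencv-python-headless, which has no GUI support.

# Usage sketch (not part of the commit): run LSTR headlessly from the repo root.
import cv2
from lstr import LSTR

lane_detector = LSTR("models/model_float32.onnx")   # same ONNX model app.py loads

img = cv2.imread("dog_road.jpg")                     # BGR image, as OpenCV loads it
detected_lanes, lane_ids = lane_detector.detect_lanes(img)
print("Detected lane ids:", lane_ids)

# Overlay the lane points (and the drivable-area mask when both ego lanes are found)
lane_img = lane_detector.draw_lanes(img)
cv2.imwrite("lanes_out.jpg", lane_img)               # imwrite works with the headless OpenCV build
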
models/model_float32.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b189aa31c389ae974469a563f493a4985f5354f5455dc363a340ed6f6cfd39dc
+ size 3074878
requirements.txt ADDED
@@ -0,0 +1,4 @@
+ opencv-python-headless
+ onnx
+ onnxruntime
+ onnxruntime-gpu
swiss_road.jpeg ADDED