onursavas and qoobeeshy committed
Commit 5f92ac2 · 0 parent(s)

Duplicate from qoobeeshy/yolo-document-layout-analysis


Co-authored-by: saad sakib noor <[email protected]>

.gitattributes ADDED
@@ -0,0 +1,37 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
+ examples/0050a8ee-382b-447e-9c5b-8506d9507bef.png filter=lfs diff=lfs merge=lfs -text
+ examples/019384d0-88c2-46ba-8f1b-bf7432f50ea3.png filter=lfs diff=lfs merge=lfs -text
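
These rules route matching files through Git LFS (filter=lfs diff=lfs merge=lfs) and mark them as non-text, so Git stores only a small pointer in the repository. As a rough illustration of how the patterns behave, here is a minimal Python sketch that approximates the matching with fnmatch; real gitattributes globs differ in edge cases (notably `**`), and the pattern list is abridged, so treat this as an approximation rather than Git's actual matcher.

    from fnmatch import fnmatch

    # Abridged pattern list from the .gitattributes above. Basename patterns
    # such as "*.pt" match in any directory; "saved_model/**/*" is path-scoped.
    # NOTE: fnmatch only approximates gitattributes glob semantics.
    LFS_PATTERNS = ["*.pt", "*.onnx", "*.safetensors", "*.zip", "*tfevents*", "saved_model/**/*"]

    def is_lfs_tracked(path: str) -> bool:
        basename = path.rsplit("/", 1)[-1]
        return any(fnmatch(basename, pat) or fnmatch(path, pat) for pat in LFS_PATTERNS)

    assert is_lfs_tracked("models/e50_aug.pt")   # stored as an LFS pointer
    assert not is_lfs_tracked("app.py")          # stored as a normal blob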
.gitignore ADDED
@@ -0,0 +1 @@
+ yoloenv/
README.md ADDED
@@ -0,0 +1,14 @@
+ ---
+ title: Document Layout Analysis
+ emoji: 🐠
+ colorFrom: gray
+ colorTo: purple
+ sdk: gradio
+ sdk_version: 3.39.0
+ app_file: app.py
+ pinned: false
+ license: mit
+ duplicated_from: qoobeeshy/yolo-document-layout-analysis
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,258 @@
+ import gradio as gr
+ import requests
+ import torch
+ import os
+ from tqdm import tqdm
+ # import wandb
+ from ultralytics import YOLO
+ import cv2
+ import numpy as np
+ import pandas as pd
+ from skimage.transform import resize
+ from skimage import img_as_bool
+ from skimage.morphology import convex_hull_image
+ import json
+
+ # wandb.init(mode='disabled')
+
+ def tableConvexHull(img, masks):
+     mask = np.zeros(masks[0].shape, dtype="bool")
+     for msk in masks:
+         temp = msk.cpu().detach().numpy()
+         chull = convex_hull_image(temp)
+         mask = np.bitwise_or(mask, chull)  # union of per-table convex hulls
+     return mask
+
+ def cls_exists(clss, cls):
+     indices = torch.where(clss == cls)
+     return len(indices[0]) > 0
+
+ def empty_mask(img):
+     mask = np.zeros(img.shape[:2], dtype="uint8")
+     return np.array(mask, dtype=bool)
+
+ def extract_img_mask(img_model, img, config):
+     res_dict = {
+         'status': 1
+     }
+     res = get_predictions(img_model, img, config)
+
+     if res['status'] == -1:
+         res_dict['status'] = -1
+
+     elif res['status'] == 0:
+         res_dict['mask'] = empty_mask(img)
+
+     else:
+         masks = res['masks']
+         boxes = res['boxes']
+         clss = boxes[:, 5]
+         mask = extract_mask(img, masks, boxes, clss, 0)
+         res_dict['mask'] = mask
+     return res_dict
+
+ def get_predictions(model, img2, config):
+     res_dict = {
+         'status': 1
+     }
+     try:
+         for result in model.predict(source=img2, verbose=False, retina_masks=config['rm'],
+                                     imgsz=config['sz'], conf=config['conf'], stream=True,
+                                     classes=config['classes']):
+             try:
+                 res_dict['masks'] = result.masks.data
+                 res_dict['boxes'] = result.boxes.data
+                 del result
+                 return res_dict
+             except Exception:  # no detections for the requested classes
+                 res_dict['status'] = 0
+                 return res_dict
+     except Exception:  # prediction itself failed
+         res_dict['status'] = -1
+         return res_dict
+
+ def extract_mask(img, masks, boxes, clss, cls):
+     if not cls_exists(clss, cls):
+         return empty_mask(img)
+     indices = torch.where(clss == cls)
+     c_masks = masks[indices]
+     mask_arr = torch.any(c_masks, dim=0).bool()
+     mask_arr = mask_arr.cpu().detach().numpy()
+     mask = mask_arr
+     return mask
+
+
+ def get_masks(img, model, img_model, flags, configs):
+     response = {
+         'status': 1
+     }
+     ans_masks = []
+     img2 = img
+
+
+     # ***** Getting paragraph and text masks
+     res = get_predictions(model, img2, configs['paratext'])
+     if res['status'] == -1:
+         response['status'] = -1
+         return response
+     elif res['status'] == 0:
+         for i in range(2): ans_masks.append(empty_mask(img))
+     else:
+         masks, boxes = res['masks'], res['boxes']
+         clss = boxes[:, 5]
+         for cls in range(2):
+             mask = extract_mask(img, masks, boxes, clss, cls)
+             ans_masks.append(mask)
+
+
+     # ***** Getting image and table masks
+     res2 = get_predictions(model, img2, configs['imgtab'])
+     if res2['status'] == -1:
+         response['status'] = -1
+         return response
+     elif res2['status'] == 0:
+         for i in range(2): ans_masks.append(empty_mask(img))
+     else:
+         masks, boxes = res2['masks'], res2['boxes']
+         clss = boxes[:, 5]
+
+         if cls_exists(clss, 2):
+             img_res = extract_img_mask(img_model, img, configs['image'])
+             if img_res['status'] == 1:
+                 img_mask = img_res['mask']
+             else:
+                 response['status'] = -1
+                 return response
+
+         else:
+             img_mask = empty_mask(img)
+         ans_masks.append(img_mask)
+
+         if cls_exists(clss, 3):
+             indices = torch.where(clss == 3)
+             tbl_mask = tableConvexHull(img, masks[indices])
+         else:
+             tbl_mask = empty_mask(img)
+         ans_masks.append(tbl_mask)
+
+     if not configs['paratext']['rm']:  # masks are not image-sized; rescale them
+         h, w, c = img.shape
+         for i in range(4):
+             ans_masks[i] = img_as_bool(resize(ans_masks[i], (h, w)))
+
+
+     response['masks'] = ans_masks
+     return response
+
+ def overlay(image, mask, color, alpha, resize=None):
+     """Combines image and its segmentation mask into a single image.
+     https://www.kaggle.com/code/purplejester/showing-samples-with-segmentation-mask-overlay
+
+     Params:
+         image: Training image. np.ndarray,
+         mask: Segmentation mask. np.ndarray,
+         color: Color for segmentation mask rendering. tuple[int, int, int] = (255, 0, 0)
+         alpha: Segmentation mask's transparency. float = 0.5,
+         resize: If provided, both image and its mask are resized before blending them together.
+             tuple[int, int] = (1024, 1024))
+
+     Returns:
+         image_combined: The combined image. np.ndarray
+
+     """
+     color = color[::-1]
+     colored_mask = np.expand_dims(mask, 0).repeat(3, axis=0)
+     colored_mask = np.moveaxis(colored_mask, 0, -1)
+     masked = np.ma.MaskedArray(image, mask=colored_mask, fill_value=color)
+     image_overlay = masked.filled()
+
+     if resize is not None:
+         image = cv2.resize(image.transpose(1, 2, 0), resize)
+         image_overlay = cv2.resize(image_overlay.transpose(1, 2, 0), resize)
+
+     image_combined = cv2.addWeighted(image, 1 - alpha, image_overlay, alpha, 0)
+
+     return image_combined
+
+
+
+ model_path = 'models'
+ general_model_name = 'e50_aug.pt'
+ image_model_name = 'e100_img.pt'
+
+ general_model = YOLO(os.path.join(model_path, general_model_name))
+ image_model = YOLO(os.path.join(model_path, image_model_name))
+
+ image_path = 'examples'
+ sample_name = ['0040da34-25c8-4a5a-a6aa-36733ea3b8eb.png',
+                '0050a8ee-382b-447e-9c5b-8506d9507bef.png', '0064d3e2-3ba2-4332-a28f-3a165f2b84b1.png']
+
+ sample_path = [os.path.join(image_path, sample) for sample in sample_name]
+
+ flags = {
+     'hist': False,
+     'bz': False
+ }
+
+
+ configs = {}
+ configs['paratext'] = {
+     'sz': 640,
+     'conf': 0.25,
+     'rm': True,
+     'classes': [0, 1]
+ }
+ configs['imgtab'] = {
+     'sz': 640,
+     'conf': 0.35,
+     'rm': True,
+     'classes': [2, 3]
+ }
+ configs['image'] = {
+     'sz': 640,
+     'conf': 0.35,
+     'rm': True,
+     'classes': [0]
+ }
+
+ def evaluate(img_path, model=general_model, img_model=image_model,
+              configs=configs, flags=flags):
+     # print('starting')
+     img = cv2.imread(img_path)
+     res = get_masks(img, model, img_model, flags, configs)
+     if res['status'] == -1:
+         for idx in configs.keys():
+             configs[idx]['rm'] = False  # retry with retina_masks disabled
+         return evaluate(img_path, model, img_model, flags, configs)
+     else:
+         masks = res['masks']
+
+     color_map = {
+         0: (255, 0, 0),
+         1: (0, 255, 0),
+         2: (0, 0, 255),
+         3: (255, 255, 0),
+     }
+     for i, mask in enumerate(masks):
+         img = overlay(image=img, mask=mask, color=color_map[i], alpha=0.4)
+     # print('finishing')
+     return img
+
+ # output = evaluate(img_path=sample_path, model=general_model, img_model=image_model,
+ #                   configs=configs, flags=flags)
+
+
+ inputs_image = [
+     gr.components.Image(type="filepath", label="Input Image"),
+ ]
+ outputs_image = [
+     gr.components.Image(type="numpy", label="Output Image"),
+ ]
+ interface_image = gr.Interface(
+     fn=evaluate,
+     inputs=inputs_image,
+     outputs=outputs_image,
+     title="Document Layout Segmentor",
+     examples=sample_path,
+     cache_examples=True,
+ ).launch()
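
For a quick local check outside the Gradio UI, evaluate can be called directly on one of the bundled examples. A minimal sketch, assuming it runs in the same session as the definitions above (importing app.py would also call launch(), so paste this into the module before the Interface block or into a REPL); the output filename is illustrative:

    # Local smoke test: segment one bundled example and save the overlay.
    # Assumes the models/ and examples/ directories from this commit are on disk.
    result = evaluate(img_path=sample_path[0])   # returns the image with masks blended in
    cv2.imwrite("overlay_sample0.png", result)   # hypothetical output path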
examples/0040da34-25c8-4a5a-a6aa-36733ea3b8eb.png ADDED
examples/0050a8ee-382b-447e-9c5b-8506d9507bef.png ADDED

Git LFS Details

  • SHA256: 4576ea936110706d799c7f0c58792a2fc3e87e29fadfe9f109978a50c0bc3b9e
  • Pointer size: 132 Bytes
  • Size of remote file: 1.85 MB
examples/0064d3e2-3ba2-4332-a28f-3a165f2b84b1.png ADDED
examples/00fb93d5-7c67-4851-ad08-f23ed2159467.png ADDED
examples/019384d0-88c2-46ba-8f1b-bf7432f50ea3.png ADDED

Git LFS Details

  • SHA256: d1607b7e810d11609ced709caf86e5d4ea633e86958ff3356bff5732642d4e19
  • Pointer size: 133 Bytes
  • Size of remote file: 11.5 MB
models/e100_img.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7424265a528fd1a2f741bb48a3586e69496de55f14e4a4c5ba867e83c2d159f8
+ size 54786656
models/e50_aug.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:12dba7a7156750342fb35ef2305a0bffa31615258aced63811e9220990f1f0a3
+ size 54792992
requirements.txt ADDED
Binary file (2.41 kB).