Commit 390ca68 (parent: 4347569)
Abdur Rahman committed: Deploy to HuggingFace spaces
Files changed:
- README.md +5 -1
- UrduGlyphs.txt +180 -0
- app.py +65 -0
- model.py +57 -0
- modules/.DS_Store +0 -0
- modules/cnn/.DS_Store +0 -0
- modules/cnn/__pycache__/unet.cpython-310.pyc +0 -0
- modules/cnn/unet.py +116 -0
- modules/dropout_layer.py +16 -0
- modules/feature_extraction.py +11 -0
- modules/prediction.py +81 -0
- modules/sequence_modeling.py +36 -0
- read.py +39 -0
- requirements.txt +11 -0
- utils.py +75 -0
README.md
CHANGED
@@ -1,6 +1,6 @@
 ---
 title: UrduOCR UTRNet
-emoji:
+emoji: 📖
 colorFrom: red
 colorTo: red
 sdk: gradio
@@ -8,6 +8,10 @@ sdk_version: 4.16.0
 app_file: app.py
 pinned: false
 license: cc-by-nc-sa-4.0
+references:
+- https://github.com/abdur75648/UTRNet-High-Resolution-Urdu-Text-Recognition
+- https://abdur75648.github.io/UTRNet/
+- https://arxiv.org/abs/2306.15782
 ---
 
 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference

UrduGlyphs.txt
ADDED
@@ -0,0 +1,180 @@
The recognizer vocabulary, one glyph per line (180 lines), in file order:

Lines 1–57 (Urdu/Arabic letters and letter forms):
ا آ ب پ ت ٹ ث ج چ ح خ د ڈ ذ ر ڑ ز ژ س ش ص ض ط ظ ع غ ف ق ک ك گ ل م ن ں و ہ ھ ء ی ے ئ ۓ ي ې ٿ ڐ ڙ أ ؤ ۀ ۃ ة ه ۂ ﮥ ٴ

Lines 58–71 (diacritics):
َ ً ُ ِ ٍ ْ ٗ ٓ ٰ ٖ ٘ ٔ ّ ؔ

Lines 72–81 (Urdu digits):
۰ ۱ ۲ ۳ ۴ ۵ ۶ ۷ ۸ ۹

Lines 82–118 (punctuation, symbols, and Arabic digit/separator forms):
٪ % + = ٤ ٫ , - _ ٥ ٬ " ' / \ > < ؍ ، ؛ : ؟ { } [ ] ( ) . ‘ ’ ٠ ۔ “ ” ! *

Lines 119–144: A–Z
Lines 145–170: a–z
Lines 171–180: 0–9

app.py
ADDED
@@ -0,0 +1,65 @@
import torch
import gradio as gr
from read import text_recognizer
from model import Model
from utils import CTCLabelConverter
from kraken import binarization
from kraken import pageseg as detection_model
from PIL import ImageDraw

""" vocab / character number configuration """
file = open("UrduGlyphs.txt", "r", encoding="utf-8")
content = file.readlines()
content = ''.join([str(elem).strip('\n') for elem in content])
content = content + " "  # the space character is part of the vocabulary

""" model configuration """
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
converter = CTCLabelConverter(content)
recognition_model = Model(num_class=len(converter.character), device=device)
recognition_model = recognition_model.to(device)
recognition_model.load_state_dict(torch.load("best_norm_ED.pth", map_location=device))
recognition_model.eval()

examples = ["1.jpg", "2.jpg", "3.jpg"]

input = gr.Image(type="pil", image_mode="RGB", label="Input Image")

def predict(input):
    # Line detection
    bw_input = binarization.nlbin(input)
    bounding_boxes = detection_model.segment(bw_input)['boxes']
    bounding_boxes.sort(key=lambda x: x[1])

    # Draw the bounding boxes
    draw = ImageDraw.Draw(input)
    for box in bounding_boxes:
        draw.rectangle(box, outline='red', width=3)

    # Crop the detected lines
    cropped_images = []
    for box in bounding_boxes:
        cropped_images.append(input.crop(box))

    # Recognize the text, line by line
    texts = []
    for img in cropped_images:
        texts.append(text_recognizer(img, recognition_model, converter, device))

    # Join the text
    text = "\n".join(texts)

    # Return the image with bounding boxes and the text
    return input, text

output_image = gr.Image(type="pil", image_mode="RGB", label="Detected Lines")
output_text = gr.Textbox(label="Recognized Text", interactive=True, show_copy_button=True)

iface = gr.Interface(predict,
                     inputs=input,
                     outputs=[output_image, output_text],
                     title="End-to-End Urdu OCR",
                     description="Demo Web App For UTRNet (https://github.com/abdur75648/UTRNet-High-Resolution-Urdu-Text-Recognition)",
                     examples=examples,
                     allow_flagging="never")
iface.launch()

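For a quick check of the recognition stage outside the Gradio UI (skipping kraken line detection), the model can be built the same way app.py does and text_recognizer run on a single pre-cropped line image. A minimal sketch, assuming it is run from the Space's root directory with best_norm_ED.pth present; "line.jpg" is a hypothetical stand-in for any cropped line image:

# Recognition-only smoke test (illustrative; "line.jpg" is a hypothetical cropped line image)
import torch
from PIL import Image
from model import Model
from utils import CTCLabelConverter
from read import text_recognizer

glyphs = open("UrduGlyphs.txt", "r", encoding="utf-8").read().replace("\n", "") + " "
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
converter = CTCLabelConverter(glyphs)
recognition_model = Model(num_class=len(converter.character), device=device).to(device)
recognition_model.load_state_dict(torch.load("best_norm_ED.pth", map_location=device))
recognition_model.eval()

line_img = Image.open("line.jpg")
print(text_recognizer(line_img, recognition_model, converter, device))
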
model.py
ADDED
@@ -0,0 +1,57 @@
# A simplified version of the original code - https://github.com/abdur75648/UTRNet-High-Resolution-Urdu-Text-Recognition
import torch.nn as nn
from modules.dropout_layer import dropout_layer
from modules.sequence_modeling import BidirectionalLSTM
from modules.feature_extraction import UNet_FeatureExtractor

class Model(nn.Module):

    def __init__(self, num_class=181, device='cpu'):
        super(Model, self).__init__()
        self.device = device

        """ FeatureExtraction """
        self.FeatureExtraction = UNet_FeatureExtractor(1, 512)
        self.FeatureExtraction_output = 512
        self.AdaptiveAvgPool = nn.AdaptiveAvgPool2d((None, 1))

        """ Temporal Dropout """
        self.dropout1 = dropout_layer(self.device)
        self.dropout2 = dropout_layer(self.device)
        self.dropout3 = dropout_layer(self.device)
        self.dropout4 = dropout_layer(self.device)
        self.dropout5 = dropout_layer(self.device)

        """ Sequence modeling """
        self.SequenceModeling = nn.Sequential(
            BidirectionalLSTM(self.FeatureExtraction_output, 256, 256),
            BidirectionalLSTM(256, 256, 256))
        self.SequenceModeling_output = 256

        """ Prediction """
        self.Prediction = nn.Linear(self.SequenceModeling_output, num_class)

    def forward(self, input, text=None, is_train=True):
        """ Feature extraction stage """
        visual_feature = self.FeatureExtraction(input)
        visual_feature = self.AdaptiveAvgPool(visual_feature.permute(0, 3, 1, 2))
        visual_feature = visual_feature.squeeze(3)

        """ Temporal Dropout + Sequence modeling stage """
        visual_feature_after_dropout1 = self.dropout1(visual_feature)
        visual_feature_after_dropout2 = self.dropout2(visual_feature)
        visual_feature_after_dropout3 = self.dropout3(visual_feature)
        visual_feature_after_dropout4 = self.dropout4(visual_feature)
        visual_feature_after_dropout5 = self.dropout5(visual_feature)
        contextual_feature1 = self.SequenceModeling(visual_feature_after_dropout1)
        contextual_feature2 = self.SequenceModeling(visual_feature_after_dropout2)
        contextual_feature3 = self.SequenceModeling(visual_feature_after_dropout3)
        contextual_feature4 = self.SequenceModeling(visual_feature_after_dropout4)
        contextual_feature5 = self.SequenceModeling(visual_feature_after_dropout5)
        # Average the five independently masked branches
        contextual_feature = (contextual_feature1 + contextual_feature2 + contextual_feature3
                              + contextual_feature4 + contextual_feature5) * (1 / 5)

        """ Prediction stage """
        prediction = self.Prediction(contextual_feature.contiguous())
        return prediction

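As a rough illustration of the shapes involved (not part of the commit): read.py feeds the model grayscale line images padded to 1×32×400, and the forward pass above yields one class distribution per horizontal position. A sketch, assuming 182 classes (180 glyphs + space + CTC blank, as app.py computes):

# Illustrative shape check for Model (not part of the committed code)
import torch
from model import Model

net = Model(num_class=182, device="cpu")   # 180 glyphs + space + CTC blank
net.eval()
dummy = torch.randn(1, 1, 32, 400)         # [batch, channel, height, width], as produced by read.py
with torch.no_grad():
    out = net(dummy)
print(out.shape)                           # expected: torch.Size([1, 400, 182])
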
modules/.DS_Store
ADDED
Binary file (6.15 kB)

modules/cnn/.DS_Store
ADDED
Binary file (6.15 kB)

modules/cnn/__pycache__/unet.cpython-310.pyc
ADDED
Binary file (3.53 kB)

modules/cnn/unet.py
ADDED
@@ -0,0 +1,116 @@
# A simplified version of the original code - https://github.com/abdur75648/UTRNet-High-Resolution-Urdu-Text-Recognition

import torch
import torch.nn as nn
import torch.nn.functional as F

# Code For UNet Feature Extractor - Source - https://github.com/milesial/Pytorch-UNet
class DoubleConv(nn.Module):
    """(convolution => [BN] => ReLU) * 2"""

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)


class Down(nn.Module):
    """Downscaling with maxpool then double conv"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        return self.maxpool_conv(x)


class Up(nn.Module):
    """Upscaling then double conv"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # input is CHW
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]

        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)


class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)

class UNet(nn.Module):
    def __init__(self, n_channels=1, n_classes=512):
        super(UNet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes

        self.inc = DoubleConv(n_channels, 32)
        self.down1 = Down(32, 64)
        self.down2 = Down(64, 128)
        self.down3 = Down(128, 256)
        self.down4 = Down(256, 512)
        self.up1 = Up(512, 256)
        self.up2 = Up(256, 128)
        self.up3 = Up(128, 64)
        self.up4 = Up(64, 32)
        self.outc = OutConv(32, n_classes)

    def forward(self, x):
        # print(x.shape) # torch.Size([1, 1, 32, 400])
        x1 = self.inc(x)
        # print(x1.shape) # torch.Size([1, 32, 32, 400])
        x2 = self.down1(x1)
        # print(x2.shape) # torch.Size([1, 64, 16, 200])
        x3 = self.down2(x2)
        # print(x3.shape) # torch.Size([1, 128, 8, 100])
        x4 = self.down3(x3)
        # print(x4.shape) # torch.Size([1, 256, 4, 50])
        x5 = self.down4(x4)
        # print(x5.shape) # torch.Size([1, 512, 2, 25])

        # print("Upscaling...")
        x = self.up1(x5, x4)
        # print(x.shape) # torch.Size([1, 256, 4, 50])
        x = self.up2(x, x3)
        # print(x.shape) # torch.Size([1, 128, 8, 100])
        x = self.up3(x, x2)
        # print(x.shape) # torch.Size([1, 64, 16, 200])
        x = self.up4(x, x1)
        # print(x.shape) # torch.Size([1, 32, 32, 400])
        logits = self.outc(x)
        # print(logits.shape) # torch.Size([1, 512, 32, 400])
        return logits

# x = torch.randn(1, 1, 32, 400)
# net = UNet()
# out = net(x)
# print(out.shape)

modules/dropout_layer.py
ADDED
@@ -0,0 +1,16 @@
# A simplified version of the original code - https://github.com/abdur75648/UTRNet-High-Resolution-Urdu-Text-Recognition
import torch.nn as nn
import torch
import numpy as np

class dropout_layer(nn.Module):
    def __init__(self, device):
        super(dropout_layer, self).__init__()
        self.device = device

    def forward(self, input):
        nums = (np.random.rand(input.shape[1]) > 0.2).astype(int)  # keep ~80% of the timesteps
        dummy_array_output = torch.from_numpy(nums).to(self.device)
        dummy_array_output_t = torch.reshape(dummy_array_output, (input.shape[1], 1)).to(self.device)  # transpose
        dummy_array_output_f = dummy_array_output_t.repeat(input.shape[0], 1, input.shape[2]).to(self.device)  # same size as input
        output = input * dummy_array_output_f  # element-wise multiplication
        return output

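This layer realizes the "Temporal Dropout" used in model.py: it zeroes a random subset (roughly 20%) of the timestep positions of a [batch, T, channels] feature map, and Model averages five independently masked branches. A small illustrative check, not part of the commit:

# Illustrative check: roughly 20% of timestep positions come out all-zero
import torch
from modules.dropout_layer import dropout_layer

feat = torch.ones(2, 100, 512)                        # [batch, timesteps, channels]
masked = dropout_layer("cpu")(feat)
zeroed = (masked.sum(dim=2) == 0).float().mean().item()
print(f"fraction of zeroed timesteps: {zeroed:.2f}")  # ~0.20 on average
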
modules/feature_extraction.py
ADDED
@@ -0,0 +1,11 @@
# A simplified version of the original code - https://github.com/abdur75648/UTRNet-High-Resolution-Urdu-Text-Recognition
import torch.nn as nn
from modules.cnn.unet import UNet

class UNet_FeatureExtractor(nn.Module):
    def __init__(self, input_channel=1, output_channel=512):
        super(UNet_FeatureExtractor, self).__init__()
        self.ConvNet = UNet(input_channel, output_channel)

    def forward(self, input):
        return self.ConvNet(input)

modules/prediction.py
ADDED
@@ -0,0 +1,81 @@
# A simplified version of the original code - https://github.com/abdur75648/UTRNet-High-Resolution-Urdu-Text-Recognition
import torch
import torch.nn as nn
import torch.nn.functional as F

class Attention(nn.Module):

    def __init__(self, input_size, hidden_size, num_classes, device):
        super(Attention, self).__init__()
        self.attention_cell = AttentionCell(input_size, hidden_size, num_classes)
        self.hidden_size = hidden_size
        self.num_classes = num_classes
        self.generator = nn.Linear(hidden_size, num_classes)
        self.device = device

    def _char_to_onehot(self, input_char, onehot_dim=38):
        input_char = input_char.unsqueeze(1)
        batch_size = input_char.size(0)
        one_hot = torch.FloatTensor(batch_size, onehot_dim).zero_().to(self.device)
        one_hot = one_hot.scatter_(1, input_char, 1)
        return one_hot

    def forward(self, batch_H, text, is_train=True, batch_max_length=25):
        """
        input:
            batch_H : contextual_feature H = hidden state of encoder. [batch_size x num_steps x contextual_feature_channels]
            text : the text-index of each image. [batch_size x (max_length+1)]. +1 for [GO] token. text[:, 0] = [GO].
        output: probability distribution at each step [batch_size x num_steps x num_classes]
        """
        batch_size = batch_H.size(0)
        num_steps = batch_max_length + 1  # +1 for [s] at end of sentence.

        output_hiddens = torch.FloatTensor(batch_size, num_steps, self.hidden_size).fill_(0).to(self.device)
        hidden = (torch.FloatTensor(batch_size, self.hidden_size).fill_(0).to(self.device),
                  torch.FloatTensor(batch_size, self.hidden_size).fill_(0).to(self.device))

        if is_train:
            for i in range(num_steps):
                # one-hot vectors for the i-th char in the batch
                char_onehots = self._char_to_onehot(text[:, i], onehot_dim=self.num_classes)
                # hidden : decoder's hidden s_{t-1}, batch_H : encoder's hidden H, char_onehots : one-hot(y_{t-1})
                hidden, _ = self.attention_cell(hidden, batch_H, char_onehots)
                output_hiddens[:, i, :] = hidden[0]  # LSTM hidden index (0: hidden, 1: Cell)
            probs = self.generator(output_hiddens)

        else:
            targets = torch.LongTensor(batch_size).fill_(0).to(self.device)  # [GO] token
            probs = torch.FloatTensor(batch_size, num_steps, self.num_classes).fill_(0).to(self.device)

            for i in range(num_steps):
                char_onehots = self._char_to_onehot(targets, onehot_dim=self.num_classes)
                hidden, _ = self.attention_cell(hidden, batch_H, char_onehots)
                probs_step = self.generator(hidden[0])
                probs[:, i, :] = probs_step
                _, next_input = probs_step.max(1)
                targets = next_input

        return probs  # batch_size x num_steps x num_classes


class AttentionCell(nn.Module):

    def __init__(self, input_size, hidden_size, num_embeddings):
        super(AttentionCell, self).__init__()
        self.i2h = nn.Linear(input_size, hidden_size, bias=False)
        self.h2h = nn.Linear(hidden_size, hidden_size)  # either i2h or h2h should have bias
        self.score = nn.Linear(hidden_size, 1, bias=False)
        self.rnn = nn.LSTMCell(input_size + num_embeddings, hidden_size)
        self.hidden_size = hidden_size

    def forward(self, prev_hidden, batch_H, char_onehots):
        # [batch_size x num_encoder_step x num_channel] -> [batch_size x num_encoder_step x hidden_size]
        batch_H_proj = self.i2h(batch_H)
        prev_hidden_proj = self.h2h(prev_hidden[0]).unsqueeze(1)
        e = self.score(torch.tanh(batch_H_proj + prev_hidden_proj))  # batch_size x num_encoder_step x 1

        alpha = F.softmax(e, dim=1)
        context = torch.bmm(alpha.permute(0, 2, 1), batch_H).squeeze(1)  # batch_size x num_channel
        concat_context = torch.cat([context, char_onehots], 1)  # batch_size x (num_channel + num_embedding)
        cur_hidden = self.rnn(concat_context, prev_hidden)
        return cur_hidden, alpha

modules/sequence_modeling.py
ADDED
@@ -0,0 +1,36 @@
# A simplified version of the original code - https://github.com/abdur75648/UTRNet-High-Resolution-Urdu-Text-Recognition
import torch.nn as nn

class BidirectionalLSTM(nn.Module):

    def __init__(self, input_size, hidden_size, output_size):
        super(BidirectionalLSTM, self).__init__()
        self.rnn = nn.LSTM(input_size, hidden_size, bidirectional=True, batch_first=True)
        self.linear = nn.Linear(hidden_size * 2, output_size)

    def forward(self, input):
        """
        input : visual feature [batch_size x T x input_size]
        output : contextual feature [batch_size x T x output_size]
        """
        self.rnn.flatten_parameters()
        recurrent, _ = self.rnn(input)  # batch_size x T x input_size -> batch_size x T x (2*hidden_size)
        output = self.linear(recurrent)  # batch_size x T x output_size
        return output

class LSTM(nn.Module):

    def __init__(self, input_size, hidden_size, output_size):
        super(LSTM, self).__init__()
        self.rnn = nn.LSTM(input_size, hidden_size, batch_first=True)
        self.linear = nn.Linear(hidden_size, output_size)

    def forward(self, input):
        """
        input : visual feature [batch_size x T x input_size]
        output : contextual feature [batch_size x T x output_size]
        """
        self.rnn.flatten_parameters()
        recurrent, _ = self.rnn(input)  # batch_size x T x input_size -> batch_size x T x hidden_size
        output = self.linear(recurrent)  # batch_size x T x output_size
        return output

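A brief illustrative shape check for the sequence-modeling block (not part of the commit), matching the docstrings above:

# BidirectionalLSTM maps [batch, T, input_size] to [batch, T, output_size]
import torch
from modules.sequence_modeling import BidirectionalLSTM

rnn = BidirectionalLSTM(input_size=512, hidden_size=256, output_size=256)
x = torch.randn(1, 400, 512)   # e.g. the pooled UNet features from model.py
print(rnn(x).shape)            # expected: torch.Size([1, 400, 256])
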
read.py
ADDED
@@ -0,0 +1,39 @@
# A simplified version of the original code - https://github.com/abdur75648/UTRNet-High-Resolution-Urdu-Text-Recognition
import math
import torch
from PIL import Image
import torch.utils.data
from utils import NormalizePAD

import warnings
warnings.filterwarnings("ignore", category=UserWarning)

def text_recognizer(img_cropped, model, converter, device):
    """ Image processing """
    img = img_cropped.convert('L')
    img = img.transpose(Image.Transpose.FLIP_LEFT_RIGHT)
    w, h = img.size
    ratio = w / float(h)
    if math.ceil(32 * ratio) > 400:
        resized_w = 400
    else:
        resized_w = math.ceil(32 * ratio)
    img = img.resize((resized_w, 32), Image.Resampling.BICUBIC)
    transform = NormalizePAD((1, 32, 400))
    img = transform(img)
    img = img.unsqueeze(0)
    batch_size = 1
    img = img.to(device)

    """ Prediction """
    preds = model(img)
    preds_size = torch.IntTensor([preds.size(1)] * batch_size)
    _, preds_index = preds.max(2)
    preds_str = converter.decode(preds_index.data, preds_size.data)[0]
    return preds_str

# if __name__ == '__main__':
#     image_path = "test.jpg"
#     img_cropped = Image.open(image_path)
#     preds_str = text_recognizer(img_cropped)
#     print(preds_str)

requirements.txt
ADDED
@@ -0,0 +1,11 @@
torch>=1.9.1
torchvision>=0.10.1
PyArabic==0.6.15
arabic-reshaper==3.0.0
numpy==1.21.6
Pillow==9.4.0
tqdm==4.65.0
opencv-python==4.5.1.48
opencv-contrib-python==4.5.1.48
kraken==4.3.13
gradio==4.15.0

utils.py
ADDED
@@ -0,0 +1,75 @@
# A simplified version of the original code - https://github.com/abdur75648/UTRNet-High-Resolution-Urdu-Text-Recognition
import math
import torch
import torchvision.transforms as T

import warnings
warnings.filterwarnings("ignore", category=UserWarning)

class NormalizePAD(object):

    def __init__(self, max_size, PAD_type='right'):
        self.toTensor = T.ToTensor()
        self.max_size = max_size
        self.max_width_half = math.floor(max_size[2] / 2)
        self.PAD_type = PAD_type

    def __call__(self, img):
        img = self.toTensor(img)
        img.sub_(0.5).div_(0.5)
        c, h, w = img.size()
        Pad_img = torch.FloatTensor(*self.max_size).fill_(0)
        Pad_img[:, :, :w] = img  # right pad
        if self.max_size[2] != w:  # add border Pad
            Pad_img[:, :, w:] = img[:, :, w - 1].unsqueeze(2).expand(c, h, self.max_size[2] - w)

        return Pad_img

class CTCLabelConverter(object):
    """ Convert between text-label and text-index """

    def __init__(self, character):
        # character (str): set of the possible characters.
        dict_character = list(character)

        self.dict = {}
        for i, char in enumerate(dict_character):
            # NOTE: 0 is reserved for 'CTCblank' token required by CTCLoss
            self.dict[char] = i + 1

        self.character = ['[CTCblank]'] + dict_character  # dummy '[CTCblank]' token for CTCLoss (index 0)

    def encode(self, text, batch_max_length=25):
        """convert text-label into text-index.
        input:
            text: text labels of each image. [batch_size]
            batch_max_length: max length of text label in the batch. 25 by default

        output:
            text: text index for CTCLoss. [batch_size, batch_max_length]
            length: length of each text. [batch_size]
        """
        length = [len(s) for s in text]

        # The index used for padding (=0) would not affect the CTC loss calculation.
        batch_text = torch.LongTensor(len(text), batch_max_length).fill_(0)
        for i, t in enumerate(text):
            text = list(t)
            text = [self.dict[char] for char in text]
            batch_text[i][:len(text)] = torch.LongTensor(text)
        return (batch_text, torch.IntTensor(length))

    def decode(self, text_index, length):
        """ convert text-index into text-label. """
        texts = []
        for index, l in enumerate(length):
            t = text_index[index, :]

            char_list = []
            for i in range(l):
                if t[i] != 0 and (not (i > 0 and t[i - 1] == t[i])):  # removing repeated characters and blank.
                    char_list.append(self.character[t[i]])
            text = ''.join(char_list)

            texts.append(text)
        return texts

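An illustrative round trip through CTCLabelConverter (not part of the commit): index 0 is reserved for the CTC blank, and decode drops blanks and collapses adjacent repeated indices, so a label with no adjacent repeated characters survives encode followed by decode unchanged. Assumes it is run from the Space's root directory so UrduGlyphs.txt is found:

# Illustrative encode/decode round trip (not part of the committed code)
from utils import CTCLabelConverter

glyphs = open("UrduGlyphs.txt", "r", encoding="utf-8").read().replace("\n", "") + " "
converter = CTCLabelConverter(glyphs)

labels = ["اردو"]                           # no adjacent repeated characters
indices, lengths = converter.encode(labels)
print(indices[0][:4])                       # per-character indices; 0 pads up to batch_max_length
print(converter.decode(indices, lengths))   # ['اردو']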