Obai33 committed (verified)
Commit ff2d389 · 1 Parent(s): d199483

Upload app.py

Files changed (1)
app.py +176 -0
app.py ADDED
@@ -0,0 +1,176 @@
+ # -*- coding: utf-8 -*-
+ """app.ipynb
+ 
+ Automatically generated by Colab.
+ 
+ Original file is located at
+     https://colab.research.google.com/drive/17w1I1LKrJAebkjqIeNAKHQDirlY8Xxsw
+ """
+ 
+ import torch
+ import torch.nn.functional as F
+ import torchvision
+ from torchvision import transforms  # fit_picture uses the bare `transforms` name
+ import matplotlib.pyplot as plt
+ import zipfile
+ import os
+ import gradio as gr
+ from PIL import Image
+ 
+ CHARS = "~=" + " abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789,.'-!?:;\""
+ BLANK = 0
+ PAD = 1
+ CHARS_DICT = {c: i for i, c in enumerate(CHARS)}
+ TEXTLEN = 30
+ 
+ tokens_list = list(CHARS_DICT.keys())
+ silence_token = '|'
+ 
+ if silence_token not in tokens_list:
+     tokens_list.append(silence_token)
+ 
+ 
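+ # Note: CHARS[0] == '~' serves as the CTC blank and CHARS[1] == '=' as padding,
+ # matching BLANK = 0 and PAD = 1 above; real text characters start at index 2.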
+ def fit_picture(img):
+     """Letterbox a PIL image into the 32x400 input the recognizer expects."""
+     target_height = 32
+     target_width = 400
+ 
+     # Calculate resize dimensions that preserve the aspect ratio
+     aspect_ratio = img.width / img.height
+     if aspect_ratio > (target_width / target_height):
+         resize_width = target_width
+         resize_height = int(target_width / aspect_ratio)
+     else:
+         resize_height = target_height
+         resize_width = int(target_height * aspect_ratio)
+ 
+     # Resize transformation
+     resize_transform = transforms.Resize((resize_height, resize_width))
+ 
+     # Pad transformation (right and bottom only, up to the target size)
+     padding_height = (target_height - resize_height) if target_height > resize_height else 0
+     padding_width = (target_width - resize_width) if target_width > resize_width else 0
+     pad_transform = transforms.Pad((0, 0, padding_width, padding_height), fill=0, padding_mode='constant')
+ 
+     transform = transforms.Compose([
+         transforms.Grayscale(num_output_channels=1),
+         transforms.ToTensor(),
+         transforms.Normalize(0.5, 0.5),
+         resize_transform,
+         pad_transform,
+     ])
+ 
+     fin_img = transform(img)
+     return fin_img
+ 
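+ # For example, a 64x200 (height x width) crop resizes to 32x100 and is padded
+ # on the right to 32x400; the returned tensor has shape (1, 32, 400).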
+ def load_model(filename):
+     # map_location lets a checkpoint saved on GPU load on a CPU-only host
+     data = torch.load(filename, map_location=DEVICE)
+     recognizer.load_state_dict(data["recognizer"])
+     optimizer.load_state_dict(data["optimizer"])
+ 
+ def ctc_decode_sequence(seq):
+     """Removes blanks and repetitions from the sequence."""
+     ret = []
+     prev = BLANK
+     for x in seq:
+         if prev != BLANK and prev != x:
+             ret.append(prev)
+         prev = x
+     # The loop never emits the final symbol; index 66 is '.' in CHARS, so a
+     # trailing period is special-cased and kept.
+     if seq[-1] == 66:
+         ret.append(66)
+     return ret
+ 
+ def ctc_decode(codes):
+     """Decode a batch of sequences (codes has shape time x batch)."""
+     ret = []
+     for cs in codes.T:
+         ret.append(ctc_decode_sequence(cs))
+     return ret
+ 
+ 
+ def decode_text(codes):
+     chars = [CHARS[c] for c in codes]
+     return ''.join(chars)
+ 
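+ # Worked example: with BLANK = 0, ctc_decode_sequence([0, 3, 3, 0, 4, 4, 0])
+ # collapses repeats and blanks to [3, 4], and decode_text([3, 4]) returns "ab"
+ # (CHARS[3] == 'a', CHARS[4] == 'b').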
+ class Residual(torch.nn.Module):
+     def __init__(self, in_channels, out_channels, stride, pdrop=0.2):
+         super().__init__()
+         self.conv1 = torch.nn.Conv2d(in_channels, out_channels, 3, stride, 1)
+         self.bn1 = torch.nn.BatchNorm2d(out_channels)
+         self.conv2 = torch.nn.Conv2d(out_channels, out_channels, 3, 1, 1)
+         self.bn2 = torch.nn.BatchNorm2d(out_channels)
+         # 1x1 convolution on the skip path whenever the output shape changes
+         if in_channels != out_channels or stride != 1:
+             self.skip = torch.nn.Conv2d(in_channels, out_channels, 1, stride, 0)
+         else:
+             self.skip = torch.nn.Identity()
+         self.dropout = torch.nn.Dropout2d(pdrop)
+ 
+     def forward(self, x):
+         y = torch.nn.functional.relu(self.bn1(self.conv1(x)))
+         y = torch.nn.functional.relu(self.bn2(self.conv2(y)) + self.skip(x))
+         y = self.dropout(y)
+         return y
+ 
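+ # Shape note: with kernel 3 and padding 1, a stride of 2 (or (2, 1)) halves the
+ # corresponding spatial dimension, rounding up; the 1x1 skip matches channels.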
+ class TextRecognizer(torch.nn.Module):
+     def __init__(self, labels):
+         super().__init__()
+         self.feature_extractor = torch.nn.Sequential(
+             Residual(1, 32, 1),
+             Residual(32, 32, 2),
+             Residual(32, 32, 1),
+             Residual(32, 64, 2),
+             Residual(64, 64, 1),
+             Residual(64, 128, (2, 1)),
+             Residual(128, 128, 1),
+             Residual(128, 128, (2, 1)),
+             Residual(128, 128, (2, 1)),
+         )
+         self.recurrent = torch.nn.LSTM(128, 128, 1, bidirectional=True)
+         self.output = torch.nn.Linear(256, labels)
+ 
+     def forward(self, x):
+         x = self.feature_extractor(x)  # (batch, 128, 1, width/4)
+         x = x.squeeze(2)               # drop the height dimension
+         x = x.permute(2, 0, 1)         # (time, batch, features) for the LSTM
+         x, _ = self.recurrent(x)
+         x = self.output(x)
+         return x
+ 
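+ # For a 1x1x32x400 input, the extractor halves the height five times (32 -> 1)
+ # and the width twice (400 -> 100), so the LSTM sees 100 time steps and the
+ # output scores have shape (100, 1, len(CHARS)) == (100, 1, 74).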
+ recognizer = TextRecognizer(len(CHARS))
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+ print("Device:", DEVICE)
+ LR = 1e-3
+ 
+ recognizer.to(DEVICE)
+ optimizer = torch.optim.Adam(recognizer.parameters(), lr=LR)
+ 
+ load_model('model.pt')
+ recognizer.eval()
+ 
+ def ctc_read(image):
+     imagefin = fit_picture(image)
+     image_tensor = imagefin.unsqueeze(0).to(DEVICE)  # add a batch dimension
+     print(image_tensor.size())  # debug: expect torch.Size([1, 1, 32, 400])
+ 
+     with torch.no_grad():
+         scores = recognizer(image_tensor)
+ 
+     predictions = scores.argmax(2).cpu().numpy()
+ 
+     decoded_sequences = ctc_decode(predictions)
+ 
+     # Convert decoded sequences to text (the batch holds a single image)
+     decoded_text = ""
+     for i in decoded_sequences:
+         decoded_text = decode_text(i)
+ 
+     return decoded_text
+ 
+ 
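+ # Quick local check (sketch; "line.png" is a placeholder path):
+ #   text = ctc_read(Image.open("line.png"))
+ #   print(text)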
+ # Gradio Interface
+ iface = gr.Interface(
+     fn=ctc_read,
+     inputs=gr.Image(type="pil"),  # PIL Image input
+     outputs="text",               # Text output
+     title="Handwritten Text Recognition",
+     description="Upload an image, and the custom AI model will extract the text.",
+ )
+ 
+ iface.launch(share=True)
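+ # share=True also exposes a temporary public Gradio URL in addition to the
+ # local one; drop the flag to serve locally only.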