# -*- coding: utf-8 -*-
import torch
import torch.nn.functional as F
import torchvision
import gradio as gr
from PIL import Image
CHARS = "~=" + " abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789,.'-!?:;\""
BLANK = 0
PAD = 1
CHARS_DICT = {c: i for i, c in enumerate(CHARS)}
TEXTLEN = 30
# Note: tokens_list and silence_token are not referenced by the code below.
tokens_list = list(CHARS_DICT.keys())
silence_token = '|'
if silence_token not in tokens_list:
    tokens_list.append(silence_token)
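# Index layout of CHARS: 0 -> '~' (the CTC blank), 1 -> '=' (padding),
# 2 -> ' ', 3-28 -> 'a'-'z', 29-54 -> 'A'-'Z', 55-64 -> '0'-'9',
# 65-73 -> punctuation; e.g. CHARS_DICT['a'] == 3.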
def fit_picture(img):
    """Convert a PIL image to a normalized 1x32x400 grayscale tensor."""
    target_height = 32
    target_width = 400
    # Calculate resize dimensions that preserve the aspect ratio
    aspect_ratio = img.width / img.height
    if aspect_ratio > (target_width / target_height):
        resize_width = target_width
        resize_height = int(target_width / aspect_ratio)
    else:
        resize_height = target_height
        resize_width = int(target_height * aspect_ratio)
    # Resize transformation
    resize_transform = torchvision.transforms.Resize((resize_height, resize_width))
    # Pad on the right/bottom up to the fixed model input size
    padding_height = (target_height - resize_height) if target_height > resize_height else 0
    padding_width = (target_width - resize_width) if target_width > resize_width else 0
    pad_transform = torchvision.transforms.Pad((0, 0, padding_width, padding_height), fill=0, padding_mode='constant')
    transform = torchvision.transforms.Compose([
        torchvision.transforms.Grayscale(num_output_channels=1),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(0.5, 0.5),
        resize_transform,
        pad_transform,
    ])
    return transform(img)
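# Usage sketch ("sample.png" is a hypothetical local file, not part of this Space):
#   img = Image.open("sample.png")
#   x = fit_picture(img)   # -> tensor of shape (1, 32, 400)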
def load_model(filename):
    """Restore recognizer/optimizer state from a checkpoint saved as a dict."""
    data = torch.load(filename, map_location=torch.device('cpu'), weights_only=True)
    recognizer.load_state_dict(data["recognizer"])
    optimizer.load_state_dict(data["optimizer"])
def ctc_decode_sequence(seq):
    """Removes blanks and repetitions from the sequence."""
    ret = []
    prev = BLANK
    for x in seq:
        # Emit a token each time a non-blank run ends.
        if prev != BLANK and prev != x:
            ret.append(prev)
        prev = x
    # Flush the final run; the loop only emits on transitions, so without
    # this the last character of every line would be dropped.
    if prev != BLANK:
        ret.append(prev)
    return ret
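# Worked example: the frame-wise argmax [0, 3, 3, 0, 3, 4] ('~aa~ab')
# collapses to [3, 3, 4]: repeats merge, the blank keeps the two 'a's
# separate, and decode_text below renders it as "aab".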
def ctc_decode(codes):
    """Decode a batch of sequences (codes is time-major, shape T x N)."""
    ret = []
    for cs in codes.T:
        ret.append(ctc_decode_sequence(cs))
    return ret
def decode_text(codes):
    chars = [CHARS[c] for c in codes]
    return ''.join(chars)
class Residual(torch.nn.Module):
    """3x3 conv residual block with batch norm, optional downsampling, and dropout."""
    def __init__(self, in_channels, out_channels, stride, pdrop=0.2):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(in_channels, out_channels, 3, stride, 1)
        self.bn1 = torch.nn.BatchNorm2d(out_channels)
        self.conv2 = torch.nn.Conv2d(out_channels, out_channels, 3, 1, 1)
        self.bn2 = torch.nn.BatchNorm2d(out_channels)
        # 1x1 conv on the skip path when the shape changes, identity otherwise
        if in_channels != out_channels or stride != 1:
            self.skip = torch.nn.Conv2d(in_channels, out_channels, 1, stride, 0)
        else:
            self.skip = torch.nn.Identity()
        self.dropout = torch.nn.Dropout2d(pdrop)

    def forward(self, x):
        y = F.relu(self.bn1(self.conv1(x)))
        y = F.relu(self.bn2(self.conv2(y)) + self.skip(x))
        return self.dropout(y)
class TextRecognizer(torch.nn.Module):
    def __init__(self, labels):
        super().__init__()
        self.feature_extractor = torch.nn.Sequential(
            Residual(1, 32, 1),
            Residual(32, 32, 2),
            Residual(32, 32, 1),
            Residual(32, 64, 2),
            Residual(64, 64, 1),
            Residual(64, 128, (2, 1)),   # (2, 1) strides halve the height only
            Residual(128, 128, 1),
            Residual(128, 128, (2, 1)),
            Residual(128, 128, (2, 1)),
        )
        self.recurrent = torch.nn.LSTM(128, 128, 1, bidirectional=True)
        self.output = torch.nn.Linear(256, labels)

    def forward(self, x):
        x = self.feature_extractor(x)   # (N, 128, 1, W/4)
        x = x.squeeze(2)                # drop the collapsed height dimension
        x = x.permute(2, 0, 1)          # (T, N, C) time-major for the LSTM
        x, _ = self.recurrent(x)
        x = self.output(x)
        return x
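# Shape flow for a (N, 1, 32, 400) input: the five height-halving blocks
# collapse the height 32 -> 1, and the two full stride-2 blocks reduce the
# width 400 -> 100, so the LSTM sees a 100-step sequence of 128-dim features
# and the linear head scores every step over the character vocabulary.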
recognizer = TextRecognizer(len(CHARS))
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", DEVICE)
LR = 1e-3
recognizer.to(DEVICE)
optimizer = torch.optim.Adam(recognizer.parameters(), lr=LR)
load_model('model.pt')
recognizer.eval()
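# The optimizer is created only so load_model can restore the state bundled
# in the checkpoint; it is unused at inference time, and eval() freezes
# dropout and batch-norm statistics.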
def ctc_read(image):
    """Run the recognizer on one uploaded image and return the decoded text."""
    imagefin = fit_picture(image)
    image_tensor = imagefin.unsqueeze(0).to(DEVICE)  # add batch dimension
    with torch.no_grad():
        scores = recognizer(image_tensor)
    predictions = scores.argmax(2).cpu().numpy()
    decoded_sequences = ctc_decode(predictions)
    # The batch holds a single image, so decode its one sequence to text
    return decode_text(decoded_sequences[0])
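# Quick check without the UI (sketch; assumes a hypothetical "sample.png"):
#   print(ctc_read(Image.open("sample.png")))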
# Gradio Interface
iface = gr.Interface(
    fn=ctc_read,
    inputs=gr.Image(type="pil"),  # PIL Image input
    outputs="text",               # Text output
    title="Handwritten Text Recognition",
    description="Upload an image, and the custom AI will extract the text."
)
iface.launch(share=True)