imflash217 committed
Commit a6ce242 · 1 Parent(s): 4a89cb7

Update app.py

Files changed (1)
  1. app.py +20 -66
app.py CHANGED
@@ -1,86 +1,40 @@
 import gradio as gr
 from transformers import TrOCRProcessor, VisionEncoderDecoderModel
+import requests
 from PIL import Image
-from transformers import TrOCRProcessor
-from transformers import VisionEncoderDecoderModel
-import cv2
-import matplotlib.pyplot as plt
-import numpy as np
-import warnings
-
-warnings.filterwarnings("ignore")

 processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
 model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")

-def extract_text(image):
-    # calling the processor is equivalent to calling the feature extractor
-    pixel_values = processor(image, return_tensors="pt").pixel_values
-    generated_ids = model.generate(pixel_values)
-    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
-    return generated_text
-
-def hand_written(image_raw):
-    image_raw = np.array(image_raw)
-    image = cv2.cvtColor(image_raw,cv2.COLOR_BGR2GRAY)
-    image = cv2.GaussianBlur(image,(5,5),0)
-    image = cv2.threshold(image,200,255,cv2.THRESH_BINARY_INV)[1]
-    kernal = cv2.getStructuringElement(cv2.MORPH_RECT,(10,1))
-    image = cv2.dilate(image,kernal,iterations=5)
-    contours,hier = cv2.findContours(image,cv2.RETR_LIST,cv2.CHAIN_APPROX_SIMPLE)
-    all_box = []
-    for i in contours:
-        bbox = cv2.boundingRect(i)
-        all_box.append(bbox)
-
-    # Calculate maximum rectangle height
-    c = np.array(all_box)
-    max_height = np.max(c[::, 3])
-
-    # Sort the contours by y-value
-    by_y = sorted(all_box, key=lambda x: x[1])  # y values
-
-    line_y = by_y[0][1]  # first y
-    line = 1
-    by_line = []
-
-    # Assign a line number to each contour
-    for x, y, w, h in by_y:
-        if y > line_y + max_height:
-            line_y = y
-            line += 1
-        by_line.append((line, x, y, w, h))
-
-    # This will now sort automatically by line then by x
-    contours_sorted = [(x, y, w, h) for line, x, y, w, h in sorted(by_line)]
-
-    text = ""
-
-    for line in contours_sorted:
-        x,y,w,h = line
-        cropped_image = image_raw[y:y+h,x:x+w]
-        try:
-            extracted = extract_text(cropped_image)
-            if not extracted == "0 0" and not extracted == "0 1":
-                text = "\n".join([text,extracted])
-        except:
-            print("skiping")
-            pass
-    return text
-
-## gradio app
-
-title = "TrOCR + EN_ICR demo"
-description = "TrOCR Handwritten Recognizer"
+# load image examples
+urls = ['https://fki.tic.heia-fr.ch/static/img/a01-122-02.jpg', 'https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcSoolxi9yWGAT5SLZShv8vVd0bz47UWRzQC19fDTeE8GmGv_Rn-PCF1pP1rrUx8kOjA4gg&usqp=CAU',
+        'https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcRNYtTuSBpZPV_nkBYPMFwVVD9asZOPgHww4epu9EqWgDmXW--sE2o8og40ZfDGo87j5w&usqp=CAU']
+for idx, url in enumerate(urls):
+    image = Image.open(requests.get(url, stream=True).raw)
+    image.save(f"image_{idx}.png")
+
+def process_image(image):
+    # prepare image
+    pixel_values = processor(image, return_tensors="pt").pixel_values
+
+    # generate (no beam search)
+    generated_ids = model.generate(pixel_values)
+
+    # decode
+    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+
+    return generated_text
+
+title = "TrOCR + EN_ICR"
+description = "Demo for handwritten TrOCR"
 article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2109.10282'>TrOCR: Transformer-based Optical Character Recognition with Pre-trained Models</a> | <a href='https://github.com/microsoft/unilm/tree/master/trocr'>Github Repo</a></p>"
-examples =[["img_hw_0.png"]]
+examples =[["img_hw_0.png"], ["img_hw_1.png"]]

-iface = gr.Interface(fn=hand_written,
+iface = gr.Interface(fn=process_image,
                      inputs=gr.inputs.Image(type="pil"),
                      outputs=gr.outputs.Textbox(),
                      title=title,
                      description=description,
                      article=article,
                      examples=examples)
-
-iface.launch(debug=True,share=True)
+iface.launch(debug=True)
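
Note: the committed app.py still targets Gradio's legacy gr.inputs / gr.outputs namespaces, which Gradio 3.x deprecated and later releases removed in favor of top-level components. A minimal sketch of the same demo against the newer component API, assuming gradio >= 3 is installed (the OCR logic is copied unchanged from the commit above):

import gradio as gr
from transformers import TrOCRProcessor, VisionEncoderDecoderModel

processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")

def process_image(image):
    # preprocess the PIL image into a batch of pixel values
    pixel_values = processor(image, return_tensors="pt").pixel_values
    # greedy decoding, as in the committed app (no beam search)
    generated_ids = model.generate(pixel_values)
    # strip special tokens and return the single decoded sequence
    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

# gr.Image / gr.Textbox replace the legacy gr.inputs.* / gr.outputs.* classes
iface = gr.Interface(fn=process_image,
                     inputs=gr.Image(type="pil"),
                     outputs=gr.Textbox())
iface.launch()

Only the interface construction differs; the model loading and decoding calls are exactly those in the committed file.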