File size: 2,083 Bytes
3a8de33
d1bb7e2
c6111b8
3a8de33
e6b9318
c6111b8
e6b9318
 
 
c6111b8
d1bb7e2
30abd6a
d1bb7e2
e6b9318
 
c6111b8
e6b9318
2653a83
d1bb7e2
 
 
e6b9318
 
d1bb7e2
e6b9318
2653a83
 
e6b9318
 
2653a83
e6b9318
 
 
 
d1bb7e2
2653a83
fa36a00
e6b9318
 
 
3a8de33
e6b9318
 
c6111b8
9164d6d
c6111b8
9164d6d
c6111b8
9164d6d
 
 
c6111b8
 
2653a83
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import logging

import gradio as gr
import numpy as np
import torch
from PIL import Image
from transformers import TrOCRProcessor, VisionEncoderDecoderModel

# Primary OCR pair: every request is decoded with this checkpoint first.
processor1, model1 = (
    TrOCRProcessor.from_pretrained("DeepDiveDev/transformodocs-ocr"),
    VisionEncoderDecoderModel.from_pretrained("DeepDiveDev/transformodocs-ocr"),
)

# Fallback OCR pair: consulted only when the primary output is nearly empty.
processor2, model2 = (
    TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten"),
    VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten"),
)

def _to_rgb_image(image):
    """Normalize an uploaded image into a new 3-channel RGB ``PIL.Image``.

    Accepts a NumPy array (2-D grayscale, or 3-D with 1, 3 or 4 channels),
    an existing ``PIL.Image.Image``, or anything ``Image.open`` accepts
    (a file path or binary file-like object).
    """
    if isinstance(image, np.ndarray):
        # Pillow cannot build an image from a trailing singleton channel axis.
        if image.ndim == 3 and image.shape[-1] == 1:
            image = image[..., 0]
        # convert("RGB") handles grayscale and RGBA arrays in one step.
        return Image.fromarray(image).convert("RGB")
    if isinstance(image, Image.Image):
        # convert() returns a copy, so later in-place resizing never
        # mutates the caller's image.
        return image.convert("RGB")
    return Image.open(image).convert("RGB")


def _run_ocr(processor, model, image):
    """Return the first decoded string produced by one processor/model pair."""
    pixel_values = processor(images=image, return_tensors="pt").pixel_values.to(torch.float32)
    generated_ids = model.generate(pixel_values)
    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]


def extract_text(image):
    """Extract text from an uploaded image, falling back to a second model.

    The primary model (``processor1``/``model1``) is tried first. If its
    output has fewer than two non-whitespace-stripped characters, the
    fallback model (``processor2``/``model2``) is used instead.

    Args:
        image: A NumPy array, ``PIL.Image.Image``, file path or file-like
            object, as supplied by the Gradio image input; ``None`` when
            the user submits without uploading anything.

    Returns:
        The decoded text, or a string starting with ``"Error: "`` when no
        image was given or processing failed (the UI shows it verbatim).
    """
    if image is None:
        return "Error: no image provided."

    try:
        image = _to_rgb_image(image)

        # Shrink in place (aspect ratio preserved; never enlarges).
        image.thumbnail((640, 640))

        extracted_text = _run_ocr(processor1, model1, image)

        # A near-empty result is treated as a failed read by the primary model.
        if len(extracted_text.strip()) < 2:
            extracted_text = _run_ocr(processor2, model2, image)

        return extracted_text

    except Exception as e:
        # UI boundary: keep the traceback in the server log, show a short
        # message to the user instead of crashing the request.
        logging.getLogger(__name__).exception("OCR extraction failed")
        return f"Error: {str(e)}"

# Gradio UI: one image input wired to extract_text, plain-text output.
iface = gr.Interface(
    fn=extract_text,
    inputs="image",
    outputs="text",
    title="TransformoDocs - AI OCR",
    description="Upload a handwritten document and get the extracted text.",
)

# Start the web server only when run as a script, so importing this module
# (tests, reload tooling) does not block on a running server.
if __name__ == "__main__":
    iface.launch()