Spaces:
Running
Running
File size: 5,483 Bytes
93d9ee0 dd06629 4eb7c20 dd06629 4eb7c20 dd06629 4eb7c20 2ef412f 93d0893 2ef412f 93d0893 a044f10 93d0893 2ef412f 4eb7c20 2ef412f dd06629 2ef412f 4eb7c20 2ef412f 4eb7c20 2ef412f 4eb7c20 2ef412f 4eb7c20 8fcf400 2ef412f 4eb7c20 2ef412f 93d0893 4eb7c20 2ef412f 4eb7c20 93d0893 4eb7c20 62e4e64 4eb7c20 93d0893 2ef412f 4eb7c20 62e4e64 4eb7c20 2ef412f 4eb7c20 2ef412f 4eb7c20 2ef412f 4eb7c20 2ef412f 4eb7c20 2ef412f 4eb7c20 2ef412f 92af12f 2ef412f 4eb7c20 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 |
# Copyright (C) 2021-2025, Mindee.
# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
import cv2
import matplotlib.pyplot as plt
import numpy as np
import streamlit as st
from doctr.file_utils import is_tf_available
from doctr.io import DocumentFile
from doctr.utils.visualization import visualize_page
# Pick the backend implementation (TensorFlow if installed, otherwise PyTorch) and the
# device that `load_predictor` / `forward_image` will run on.
if is_tf_available():
    import tensorflow as tf
    from backend.tensorflow import DET_ARCHS, RECO_ARCHS, forward_image, load_predictor

    # TensorFlow reports physical device types in upper case ("GPU" / "CPU") and filters on an
    # exact string match, so the device type must be spelled "GPU" for a GPU to be detected.
    if any(tf.config.list_physical_devices("GPU")):
        forward_device = tf.device("/gpu:0")
    else:
        forward_device = tf.device("/cpu:0")
else:
    import torch
    from backend.pytorch import DET_ARCHS, RECO_ARCHS, forward_image, load_predictor

    # Use the first CUDA device when one is available, otherwise run on CPU
    forward_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
def main(det_archs, reco_archs):
    """Build the Streamlit layout of the docTR demo and run OCR when requested.

    The sidebar lets the user upload a document (PDF or image), choose a page, pick the
    detection / recognition architectures and tune the predictor parameters. Clicking
    "Analyze page" loads a predictor and fills four columns: the input page, the raw
    segmentation heatmap, the OCR visualization and (when pages are assumed straight or
    straightened) a synthesized page reconstitution. The JSON export is shown below.

    Args:
        det_archs: options offered in the "Text detection model" selectbox.
        reco_archs: options offered in the "Text recognition model" selectbox.
    """
    # Wide mode
    st.set_page_config(layout="wide")

    # Designing the interface
    st.title("docTR: Document Text Recognition")
    # For newline
    st.write("\n")
    # Instructions
    st.markdown("*Hint: click on the top-right corner of an image to enlarge it!*")
    # Set the columns
    cols = st.columns((1, 1, 1, 1))
    cols[0].subheader("Input page")
    cols[1].subheader("Segmentation heatmap")
    cols[2].subheader("OCR output")
    cols[3].subheader("Page reconstitution")

    # Sidebar
    # File selection
    st.sidebar.title("Document selection")
    # Choose your own image
    uploaded_file = st.sidebar.file_uploader("Upload files", type=["pdf", "png", "jpeg", "jpg"])
    if uploaded_file is not None:
        # Compare the extension case-insensitively so that e.g. "SCAN.PDF" goes through the
        # PDF decoder instead of being handed to the image decoder.
        if uploaded_file.name.lower().endswith(".pdf"):
            doc = DocumentFile.from_pdf(uploaded_file.read())
        else:
            doc = DocumentFile.from_images(uploaded_file.read())
        # The selectbox displays 1-based page numbers; convert back to a 0-based index
        page_idx = st.sidebar.selectbox("Page selection", [idx + 1 for idx in range(len(doc))]) - 1
        page = doc[page_idx]
        cols[0].image(page)

    # Model selection
    st.sidebar.title("Model selection")
    st.sidebar.markdown("**Backend**: " + ("TensorFlow" if is_tf_available() else "PyTorch"))
    det_arch = st.sidebar.selectbox("Text detection model", det_archs)
    reco_arch = st.sidebar.selectbox("Text recognition model", reco_archs)

    # For newline
    st.sidebar.write("\n")
    # Only straight pages or possible rotation
    st.sidebar.title("Parameters")
    assume_straight_pages = st.sidebar.checkbox("Assume straight pages", value=True)
    # Disable page orientation detection
    disable_page_orientation = st.sidebar.checkbox("Disable page orientation detection", value=False)
    # Disable crop orientation detection
    disable_crop_orientation = st.sidebar.checkbox("Disable crop orientation detection", value=False)
    # Straighten pages
    straighten_pages = st.sidebar.checkbox("Straighten pages", value=False)
    # Export as straight boxes
    export_straight_boxes = st.sidebar.checkbox("Export as straight boxes", value=False)
    st.sidebar.write("\n")
    # Binarization threshold
    bin_thresh = st.sidebar.slider("Binarization threshold", min_value=0.1, max_value=0.9, value=0.3, step=0.1)
    st.sidebar.write("\n")
    # Box threshold
    box_thresh = st.sidebar.slider("Box threshold", min_value=0.1, max_value=0.9, value=0.1, step=0.1)
    st.sidebar.write("\n")

    if st.sidebar.button("Analyze page"):
        if uploaded_file is None:
            # `page` only exists once a document has been uploaded
            st.sidebar.write("Please upload a document")
        else:
            with st.spinner("Loading model..."):
                predictor = load_predictor(
                    det_arch=det_arch,
                    reco_arch=reco_arch,
                    assume_straight_pages=assume_straight_pages,
                    straighten_pages=straighten_pages,
                    export_as_straight_boxes=export_straight_boxes,
                    disable_page_orientation=disable_page_orientation,
                    disable_crop_orientation=disable_crop_orientation,
                    bin_thresh=bin_thresh,
                    box_thresh=box_thresh,
                    device=forward_device,
                )

            with st.spinner("Analyzing..."):
                # Forward the image to the model
                seg_map = forward_image(predictor, page, forward_device)
                seg_map = np.squeeze(seg_map)
                # cv2.resize takes (width, height): stretch the heatmap to the input page size
                seg_map = cv2.resize(seg_map, (page.shape[1], page.shape[0]), interpolation=cv2.INTER_LINEAR)

                # Plot the raw heatmap
                fig, ax = plt.subplots()
                ax.imshow(seg_map)
                ax.axis("off")
                cols[1].pyplot(fig)
                # Once Streamlit has rendered it, release the figure so pyplot's global figure
                # registry does not keep growing across analyses / script reruns.
                plt.close(fig)

                # Plot OCR output
                out = predictor([page])
                fig = visualize_page(out.pages[0].export(), out.pages[0].page, interactive=False, add_labels=False)
                cols[2].pyplot(fig)
                plt.close(fig)

                # Page reconstitution under input page
                page_export = out.pages[0].export()
                # Only synthesize when pages are assumed straight or have been straightened
                # (equivalent to the former `A or (not A and B)` condition)
                if assume_straight_pages or straighten_pages:
                    img = out.pages[0].synthesize()
                    cols[3].image(img, clamp=True)

                # Display JSON
                st.markdown("\nHere are your analysis results in JSON format:")
                st.json(page_export, expanded=False)
if __name__ == "__main__":
    # Script entry point: launch the demo with the architecture lists imported from the
    # active backend module.
    main(det_archs=DET_ARCHS, reco_archs=RECO_ARCHS)
|