# snap-assist / app.py
# ilovetensor's picture
# Update app.py
# d445667 verified
import html
import os
import re
import tempfile

import gradio as gr
import torch
from PIL import Image
from transformers import AutoTokenizer, AutoModel
# Load the GOT OCR tokenizer and model once, at import time, so every
# request handled by extract_text reuses the same instances.
_GOT_MODEL_ID = 'ucaslcl/GOT-OCR2_0'

# Use the GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

tokenizer = AutoTokenizer.from_pretrained(_GOT_MODEL_ID, trust_remote_code=True)
model = AutoModel.from_pretrained(
    _GOT_MODEL_ID,
    trust_remote_code=True,
    low_cpu_mem_usage=True,
    device_map='auto',
    use_safetensors=True,
)
# Switch to evaluation mode, then move the model to the selected device.
model = model.eval().to(device)
def extract_text(image):
    """Run OCR on an uploaded image and return the recognised text.

    Args:
        image: An object exposing ``save(path)`` (the UI input is declared
            with ``type="pil"``), or ``None`` when nothing was uploaded.

    Returns:
        A ``(textbox_value, state_value)`` tuple. On success both entries
        are the OCR result. When no image is given, or when saving the
        image or running the model fails, the first entry is a
        human-readable message and the second is an empty string.
    """
    if image is None:
        return "No image uploaded", ""

    # model.chat is given a file path, so the image has to be written to
    # disk first. Each call gets its own uniquely named file: with a fixed
    # filename, overlapping calls would overwrite or delete each other's
    # image.
    fd, image_path = tempfile.mkstemp(suffix='.png')
    os.close(fd)  # only the path is needed; image.save reopens it
    try:
        # Saving happens inside the try so that a failed save is reported
        # like any other error and the temporary file is still removed.
        image.save(image_path)
        res = model.chat(tokenizer, image_path, ocr_type='ocr')
        return res, res
    except Exception as e:
        return f"Error: {str(e)}", ""
    finally:
        if os.path.exists(image_path):
            os.remove(image_path)
def keyword_search(extracted_text, keyword):
    """Return ``extracted_text`` as HTML with every ``keyword`` hit marked.

    Matching is literal (the keyword is not treated as a regular
    expression) and case-insensitive. Each match is wrapped in
    ``<mark>...</mark>``; all text, matched or not, has ``&``, ``<`` and
    ``>`` escaped.

    Args:
        extracted_text: The text to search; may be ``None`` or empty.
        keyword: The literal string to highlight; may be ``None`` or empty.

    Returns:
        ``"No text extracted yet."`` when there is no text, the escaped
        text unchanged when there is no keyword, otherwise the escaped
        text with matches wrapped in ``<mark>`` tags.
    """
    if not extracted_text:
        return "No text extracted yet."
    if not keyword:
        # Still escape: the result is rendered as HTML, so raw '<' or '&'
        # in the text must not be interpreted as markup.
        return html.escape(extracted_text, quote=False)

    pattern = re.compile(re.escape(keyword), re.IGNORECASE)

    # Search the raw text and escape each segment separately. Escaping the
    # whole text first would stop keywords containing '&', '<' or '>' from
    # matching, and would let keywords such as "amp" match inside the
    # generated entities and corrupt them.
    pieces = []
    cursor = 0
    for match in pattern.finditer(extracted_text):
        before = extracted_text[cursor:match.start()]
        pieces.append(html.escape(before, quote=False))
        pieces.append(f"<mark>{html.escape(match.group(), quote=False)}</mark>")
        cursor = match.end()
    pieces.append(html.escape(extracted_text[cursor:], quote=False))
    return "".join(pieces)
# Assemble the page: upload an image, extract its text, then search it.
with gr.Blocks() as demo:
    gr.Markdown("# OCR and Document Search Web Application")

    # Receives extract_text's second return value and is later passed to
    # keyword_search as the text to search.
    extracted_text_state = gr.State()

    with gr.Column():
        # Components appear on the page in the order they are created.
        image_input = gr.Image(
            type="pil",
            label="Upload an image",
        )
        extract_button = gr.Button("Extract Text")
        extracted_text_output = gr.Textbox(
            label="Extracted Text",
            lines=10,
        )
        keyword_input = gr.Textbox(label="Enter keyword to search")
        search_button = gr.Button("Search")
        search_results_output = gr.HTML(label="Search Results")

    # "Extract Text" fills both the visible textbox and the hidden state.
    extract_button.click(
        fn=extract_text,
        inputs=image_input,
        outputs=[extracted_text_output, extracted_text_state],
    )

    # "Search" reads the stored text and the keyword, and renders the
    # highlighted result as HTML.
    search_button.click(
        fn=keyword_search,
        inputs=[extracted_text_state, keyword_input],
        outputs=search_results_output,
    )

# Start the app only when this file is executed directly.
if __name__ == "__main__":
    demo.launch()