ilovetensor committed
Commit 13ebdd6 · verified · 1 parent: 9038ad8

Update app.py

Files changed (1): app.py (+43, -74)
app.py CHANGED
@@ -1,90 +1,59 @@
 import gradio as gr
-from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
+from transformers import AutoTokenizer, AutoModel
+from PIL import Image
+import os
+import re
 import torch
 
-# Check if CUDA is available and set device
+# Load the GOT model
 device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
-# Load the Qwen2-VL model and processor
-
-model = Qwen2VLForConditionalGeneration.from_pretrained(
-    "Qwen/Qwen2-VL-2B-Instruct",
-    trust_remote_code=True,
-    torch_dtype=torch.bfloat16
-).to(device).eval()
-
-processor = AutoProcessor.from_pretrained(
-    "Qwen/Qwen2-VL-2B-Instruct",
-    trust_remote_code=True
-)
+tokenizer = AutoTokenizer.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True)
+model = AutoModel.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True, low_cpu_mem_usage=True, device_map='auto', use_safetensors=True)
+model = model.eval().to(device)
 
 def extract_text(image):
-    # Prompt for OCR extraction
-    prompt = "Please extract all the text from the image, including any text in Hindi and English."
-    # Prepare inputs
-    inputs = processor(images=[image], text=prompt, return_tensors="pt").to(device)
-    # Generate outputs
-    outputs = model.generate(**inputs, max_new_tokens=500)
-    # Decode the generated text
-    generated_text = processor.batch_decode(outputs, skip_special_tokens=True)[0]
-    return generated_text
-
-def search_text(extracted_text, keyword):
-    import re
-    # Compile regex pattern for case-insensitive search
+    # Save the image to a temporary file
+    image_path = 'temp_image.png'
+    image.save(image_path)
+    # Use the GOT model to extract text
+    try:
+        res = model.chat(tokenizer, image_path, ocr_type='ocr')
+        return res, res  # Return the extracted text and also set it in the state variable
+    except Exception as e:
+        return f"Error: {str(e)}", ""
+    finally:
+        if os.path.exists(image_path):
+            os.remove(image_path)
+
+def keyword_search(extracted_text, keyword):
+    if not extracted_text:
+        return "No text extracted yet."
+    if not keyword:
+        return extracted_text
+    # Escape HTML special characters
+    extracted_text = extracted_text.replace('&', '&amp;').replace('<', '&lt;').replace('>', '&gt;')
+    # Use regular expressions to find matches, ignoring case
     pattern = re.compile(re.escape(keyword), re.IGNORECASE)
-    matches = pattern.finditer(extracted_text)
-    # Highlight matching keywords
-    highlighted_text = extracted_text
-    offset = 0
-    for match in matches:
-        start, end = match.start() + offset, match.end() + offset
-        # Insert HTML tags for highlighting
-        highlighted_text = highlighted_text[:start] + "<mark>" + highlighted_text[start:end] + "</mark>" + highlighted_text[end:]
-        offset += len("<mark></mark>")
+    highlighted_text = pattern.sub(lambda x: f"<mark>{x.group()}</mark>", extracted_text)
     return highlighted_text
 
 with gr.Blocks() as demo:
-    gr.Markdown("# OCR and Keyword Search Web Application Prototype")
-    with gr.Row():
-        image_input = gr.Image(type='pil', label="Upload an image containing text in Hindi and English")
+    gr.Markdown("# OCR and Document Search Web Application")
+
+    extracted_text_state = gr.State()
+
+    with gr.Column():
+        image_input = gr.Image(type="pil", label="Upload an image")
         extract_button = gr.Button("Extract Text")
+
         extracted_text_output = gr.Textbox(label="Extracted Text", lines=10)
-    with gr.Row():
-        keyword_input = gr.Textbox(label="Enter keyword to search within the extracted text")
-        search_button = gr.Button("Search")
+        keyword_input = gr.Textbox(label="Enter keyword to search")
+        search_button = gr.Button("Search")
         search_results_output = gr.HTML(label="Search Results")
-
-    # State to store the extracted text
-    extracted_text_state = gr.State()
-
-    # Function to extract text and display
-    def extract_and_display(image):
-        extracted_text = extract_text(image)
-        extracted_text_state.value = extracted_text
-        return extracted_text
-
-    # Function to search within the extracted text
-    def search_and_display(keyword):
-        extracted_text = extracted_text_state.value
-        if not extracted_text:
-            return "No extracted text available. Please upload an image and extract text first."
-        highlighted_text = search_text(extracted_text, keyword)
-        return highlighted_text
-
-    # Set up button click events
-    extract_button.click(
-        fn=extract_and_display,
-        inputs=image_input,
-        outputs=extracted_text_output
-    )
-
-    search_button.click(
-        fn=search_and_display,
-        inputs=keyword_input,
-        outputs=search_results_output
-    )
 
-# Launch the Gradio app
+    extract_button.click(fn=extract_text, inputs=image_input, outputs=[extracted_text_output, extracted_text_state])
+    search_button.click(fn=keyword_search, inputs=[extracted_text_state, keyword_input], outputs=search_results_output)
 
-demo.launch()
+if __name__ == "__main__":
+    demo.launch()
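
The rewritten highlighter replaces the old manual offset bookkeeping with a single pattern.sub pass, so match indices can never drift as tags are inserted. A minimal sketch of the technique in isolation; the sample text and keyword below are placeholders, not part of the commit:

import re

text = "Invoice total: 42. See the invoice footer."  # placeholder input
keyword = "invoice"                                   # placeholder keyword

# re.escape makes regex metacharacters in the keyword match literally;
# the lambda wraps each case-insensitive match in <mark> tags in one pass.
pattern = re.compile(re.escape(keyword), re.IGNORECASE)
print(pattern.sub(lambda m: f"<mark>{m.group()}</mark>", text))
# <mark>Invoice</mark> total: 42. See the <mark>invoice</mark> footer.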
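Because the commit adds an if __name__ == "__main__" guard, app.py can now be imported without launching the UI, which permits a quick smoke test of the new callbacks. A sketch under that assumption, run from the same directory as app.py; sample.png and the keyword 'ocr' are hypothetical placeholders:

from PIL import Image
from app import extract_text, keyword_search  # importing app.py loads the GOT model

image = Image.open('sample.png')  # placeholder test image

# extract_text returns (display_text, state_text), mirroring its two Gradio
# outputs: the visible Textbox and the gr.State value.
display_text, state_text = extract_text(image)
print(display_text)

# keyword_search HTML-escapes the text, then highlights case-insensitive matches.
print(keyword_search(state_text, 'ocr'))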