omvishesh committed on
Commit
a809459
·
verified ·
1 Parent(s): c7da912

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -5
app.py CHANGED
@@ -1,23 +1,30 @@
 
1
  import gradio as gr
2
  from transformers import AutoModel, AutoTokenizer
3
  import os
4
- import re # Import regular expressions module
 
 
 
5
 
6
  # Load the OCR model and tokenizer
7
  tokenizer = AutoTokenizer.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True)
 
 
8
  model = AutoModel.from_pretrained('ucaslcl/GOT-OCR2_0',
9
  trust_remote_code=True,
10
  low_cpu_mem_usage=True,
11
- pad_token_id=tokenizer.eos_token_id).eval() # Removed device_map='cuda' and .cuda()
12
 
13
- # Define the function to process images and extract text
14
  def extract_text_from_image(image):
15
  # Save the uploaded image temporarily
16
  image_path = "temp_image.jpg"
17
  image.save(image_path)
18
 
19
- # Call the model to perform OCR
20
- extracted_text = model.chat(tokenizer, image_path, ocr_type='ocr')
 
21
 
22
  # Remove the temporary image file
23
  os.remove(image_path)
 
1
+ import torch
2
  import gradio as gr
3
  from transformers import AutoModel, AutoTokenizer
4
  import os
5
+ import re # For keyword searching and highlighting
6
+
7
+ # Ensure that the code runs on the CPU
8
+ device = torch.device('cpu')
9
 
10
  # Load the OCR model and tokenizer
11
  tokenizer = AutoTokenizer.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True)
12
+
13
+ # Load the model onto the CPU
14
  model = AutoModel.from_pretrained('ucaslcl/GOT-OCR2_0',
15
  trust_remote_code=True,
16
  low_cpu_mem_usage=True,
17
+ pad_token_id=tokenizer.eos_token_id).eval().to(device)
18
 
19
+ # Function to extract text from an image
20
  def extract_text_from_image(image):
21
  # Save the uploaded image temporarily
22
  image_path = "temp_image.jpg"
23
  image.save(image_path)
24
 
25
+ # Perform OCR with gradient tracking disabled (inference only); the model was already moved to the CPU when it was loaded
26
+ with torch.no_grad():
27
+ extracted_text = model.chat(tokenizer, image_path, ocr_type='ocr')
28
 
29
  # Remove the temporary image file
30
  os.remove(image_path)