omvishesh commited on
Commit
c86b97e
·
verified ·
1 Parent(s): a809459

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -12
app.py CHANGED
@@ -1,30 +1,25 @@
1
- import torch
2
  import gradio as gr
3
  from transformers import AutoModel, AutoTokenizer
4
  import os
5
- import re # For keyword searching and highlighting
6
-
7
- # Ensure that the code runs on the CPU
8
- device = torch.device('cpu')
9
 
10
  # Load the OCR model and tokenizer
11
  tokenizer = AutoTokenizer.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True)
12
-
13
- # Load the model onto the CPU
14
  model = AutoModel.from_pretrained('ucaslcl/GOT-OCR2_0',
15
  trust_remote_code=True,
16
  low_cpu_mem_usage=True,
17
- pad_token_id=tokenizer.eos_token_id).eval().to(device)
 
 
18
 
19
- # Function to extract text from an image
20
  def extract_text_from_image(image):
21
  # Save the uploaded image temporarily
22
  image_path = "temp_image.jpg"
23
  image.save(image_path)
24
 
25
- # Perform OCR using the model, ensuring it runs on CPU
26
- with torch.no_grad():
27
- extracted_text = model.chat(tokenizer, image_path, ocr_type='ocr')
28
 
29
  # Remove the temporary image file
30
  os.remove(image_path)
 
 
1
  import gradio as gr
2
  from transformers import AutoModel, AutoTokenizer
3
  import os
4
+ import re # Import regular expressions module
 
 
 
5
 
6
  # Load the OCR model and tokenizer
7
  tokenizer = AutoTokenizer.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True)
 
 
8
  model = AutoModel.from_pretrained('ucaslcl/GOT-OCR2_0',
9
  trust_remote_code=True,
10
  low_cpu_mem_usage=True,
11
+ device_map='cuda',
12
+ use_safetensors=True,
13
+ pad_token_id=tokenizer.eos_token_id).eval().cuda()
14
 
15
+ # Define the function to process images and extract text
16
  def extract_text_from_image(image):
17
  # Save the uploaded image temporarily
18
  image_path = "temp_image.jpg"
19
  image.save(image_path)
20
 
21
+ # Call the model to perform OCR
22
+ extracted_text = model.chat(tokenizer, image_path, ocr_type='ocr')
 
23
 
24
  # Remove the temporary image file
25
  os.remove(image_path)