Tonic commited on
Commit
3b378ec
Β·
unverified Β·
1 Parent(s): 3f89105

add sliders for variables , add cpu support

Browse files
Files changed (1) hide show
  1. app.py +16 -10
app.py CHANGED
@@ -20,10 +20,12 @@ import re
20
 
21
  model_name = 'ucaslcl/GOT-OCR2_0'
22
 
 
 
23
  tokenizer = AutoTokenizer.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True)
24
  config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
25
- model = AutoModel.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True, low_cpu_mem_usage=True, device_map='cuda', use_safetensors=True, pad_token_id=tokenizer.eos_token_id)
26
- model = model.eval().cuda()
27
  model.config.pad_token_id = tokenizer.eos_token_id
28
 
29
  UPLOAD_FOLDER = "./uploads"
@@ -40,7 +42,7 @@ def image_to_base64(image):
40
 
41
 
42
  @spaces.GPU()
43
- def process_image(image, task, ocr_type=None, ocr_box=None, ocr_color=None):
44
  if image is None:
45
  return "Error: No image provided", None, None
46
 
@@ -49,7 +51,7 @@ def process_image(image, task, ocr_type=None, ocr_box=None, ocr_color=None):
49
  result_path = os.path.join(RESULTS_FOLDER, f"{unique_id}.html")
50
 
51
  try:
52
- if isinstance(image, dict): # If image is from ImageEditor
53
  composite_image = image.get("composite")
54
  if composite_image is not None:
55
  if isinstance(composite_image, np.ndarray):
@@ -68,19 +70,19 @@ def process_image(image, task, ocr_type=None, ocr_box=None, ocr_color=None):
68
  return "Error: Unsupported image format", None, None
69
 
70
  if task == "Plain Text OCR":
71
- res = model.chat(tokenizer, image_path, ocr_type='ocr')
72
  return res, None, unique_id
73
  else:
74
  if task == "Format Text OCR":
75
- res = model.chat(tokenizer, image_path, ocr_type='format', render=True, save_render_file=result_path)
76
  elif task == "Fine-grained OCR (Box)":
77
- res = model.chat(tokenizer, image_path, ocr_type=ocr_type, ocr_box=ocr_box, render=True, save_render_file=result_path)
78
  elif task == "Fine-grained OCR (Color)":
79
- res = model.chat(tokenizer, image_path, ocr_type=ocr_type, ocr_color=ocr_color, render=True, save_render_file=result_path)
80
  elif task == "Multi-crop OCR":
81
- res = model.chat_crop(tokenizer, image_path, ocr_type='format', render=True, save_render_file=result_path)
82
  elif task == "Render Formatted OCR":
83
- res = model.chat(tokenizer, image_path, ocr_type='format', render=True, save_render_file=result_path)
84
 
85
  if os.path.exists(result_path):
86
  with open(result_path, 'r') as f:
@@ -249,6 +251,10 @@ with gr.Blocks(theme=gr.themes.Base()) as demo:
249
  label="OCR Color",
250
  visible=False
251
  )
 
 
 
 
252
  submit_button = gr.Button("Process")
253
  editor_submit_button = gr.Button("Process Edited Image", visible=False)
254
 
 
20
 
21
  model_name = 'ucaslcl/GOT-OCR2_0'
22
 
23
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
24
+
25
  tokenizer = AutoTokenizer.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True)
26
  config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
27
+ model = AutoModel.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True, low_cpu_mem_usage=True, device_map=device, use_safetensors=True, pad_token_id=tokenizer.eos_token_id)
28
+ model = model.eval().to(device)
29
  model.config.pad_token_id = tokenizer.eos_token_id
30
 
31
  UPLOAD_FOLDER = "./uploads"
 
42
 
43
 
44
  @spaces.GPU()
45
+ def process_image(image, task, max_new_tokens, no_repeat_ngram_size, ocr_type=None, ocr_box=None, ocr_color=None):
46
  if image is None:
47
  return "Error: No image provided", None, None
48
 
 
51
  result_path = os.path.join(RESULTS_FOLDER, f"{unique_id}.html")
52
 
53
  try:
54
+ if isinstance(image, dict):
55
  composite_image = image.get("composite")
56
  if composite_image is not None:
57
  if isinstance(composite_image, np.ndarray):
 
70
  return "Error: Unsupported image format", None, None
71
 
72
  if task == "Plain Text OCR":
73
+ res = model.chat(tokenizer, image_path, ocr_type='ocr', max_new_tokens=max_new_tokens, no_repeat_ngram_size=no_repeat_ngram_size)
74
  return res, None, unique_id
75
  else:
76
  if task == "Format Text OCR":
77
+ res = model.chat(tokenizer, image_path, ocr_type='format', render=True, save_render_file=result_path, max_new_tokens=max_new_tokens, no_repeat_ngram_size=no_repeat_ngram_size)
78
  elif task == "Fine-grained OCR (Box)":
79
+ res = model.chat(tokenizer, image_path, ocr_type=ocr_type, ocr_box=ocr_box, render=True, save_render_file=result_path, max_new_tokens=max_new_tokens, no_repeat_ngram_size=no_repeat_ngram_size)
80
  elif task == "Fine-grained OCR (Color)":
81
+ res = model.chat(tokenizer, image_path, ocr_type=ocr_type, ocr_color=ocr_color, render=True, save_render_file=result_path, max_new_tokens=max_new_tokens, no_repeat_ngram_size=no_repeat_ngram_size)
82
  elif task == "Multi-crop OCR":
83
+ res = model.chat_crop(tokenizer, image_path, ocr_type='format', render=True, save_render_file=result_path, max_new_tokens=max_new_tokens, no_repeat_ngram_size=no_repeat_ngram_size)
84
  elif task == "Render Formatted OCR":
85
+ res = model.chat(tokenizer, image_path, ocr_type='format', render=True, save_render_file=result_path, max_new_tokens=max_new_tokens, no_repeat_ngram_size=no_repeat_ngram_size)
86
 
87
  if os.path.exists(result_path):
88
  with open(result_path, 'r') as f:
 
251
  label="OCR Color",
252
  visible=False
253
  )
254
+ with gr.Row():
255
+ max_new_tokens_slider = gr.Slider(50, 500, step=10, value=150, label="Max New Tokens")
256
+ no_repeat_ngram_size_slider = gr.Slider(1, 10, step=1, value=2, label="No Repeat N-gram Size")
257
+
258
  submit_button = gr.Button("Process")
259
  editor_submit_button = gr.Button("Process Edited Image", visible=False)
260