prithivMLmods committed (verified)
Commit 9180057 · 1 Parent(s): b537cee

Update app.py

Files changed (1): app.py (+79 -5)
app.py CHANGED
@@ -10,10 +10,13 @@ import re
 import fitz # PyMuPDF
 import gradio as gr
 import requests
+import torch
 from PIL import Image, ImageDraw, ImageFont
+from transformers import AutoModelForCausalLM, AutoProcessor, VisionEncoderDecoderModel
+from huggingface_hub import snapshot_download
+from qwen_vl_utils import process_vision_info
 
-from model import load_model, inference_dots_ocr, inference_dolphin
-
+# JavaScript for theme refresh
 js_func = """
 function refresh() {
     const url = new URL(window.location);
@@ -29,7 +32,7 @@ MIN_PIXELS = 3136
 MAX_PIXELS = 11289600
 IMAGE_FACTOR = 28
 
-# Prompts
+# Prompt for dots.ocr
 prompt = """Please output the layout information from the PDF image, including each layout element's bbox, its category, and the corresponding text content within the bbox.
 
 1. Bbox format: [x1, y1, x2, y2]
@@ -45,6 +48,77 @@ prompt = """Please output the layout information from the PDF image, including e
 5. Final Output: Single JSON object
 """
 
+# Model loading functions (from model.py)
+def load_model(model_name):
+    if model_name == "dots.ocr":
+        model_id = "rednote-hilab/dots.ocr"
+        model_path = "./models/dots-ocr-local"
+        snapshot_download(
+            repo_id=model_id,
+            local_dir=model_path,
+            local_dir_use_symlinks=False,
+        )
+        model = AutoModelForCausalLM.from_pretrained(
+            model_path,
+            attn_implementation="flash_attention_2",
+            torch_dtype=torch.bfloat16,
+            device_map="auto",
+            trust_remote_code=True
+        )
+        processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
+    elif model_name == "Dolphin":
+        model_id = "ByteDance/Dolphin"
+        processor = AutoProcessor.from_pretrained(model_id)
+        model = VisionEncoderDecoderModel.from_pretrained(model_id)
+        model.eval()
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+        model.to(device)
+        model = model.half() # Use half precision
+    else:
+        raise ValueError(f"Unknown model: {model_name}")
+    return model, processor
+
+# Inference functions (from model.py)
+def inference_dots_ocr(model, processor, image, prompt, max_new_tokens):
+    messages = [
+        {
+            "role": "user",
+            "content": [
+                {"type": "image", "image": image},
+                {"type": "text", "text": prompt}
+            ]
+        }
+    ]
+    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+    image_inputs, video_inputs = process_vision_info(messages)
+    inputs = processor(
+        text=[text],
+        images=image_inputs,
+        videos=video_inputs,
+        padding=True,
+        return_tensors="pt",
+    )
+    inputs = inputs.to(model.device)
+    with torch.no_grad():
+        generated_ids = model.generate(
+            **inputs,
+            max_new_tokens=max_new_tokens,
+            do_sample=False # Removed temperature=0.1 to fix the warning
+        )
+    generated_ids_trimmed = [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
+    output_text = processor.batch_decode(
+        generated_ids_trimmed,
+        skip_special_tokens=True,
+        clean_up_tokenization_spaces=False
+    )
+    return output_text[0] if output_text else ""
+
+def inference_dolphin(model, processor, image):
+    pixel_values = processor(image, return_tensors="pt").pixel_values.to(model.device).half()
+    generated_ids = model.generate(pixel_values)
+    generated_text = processor.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
+    return generated_text
+
 # Load models at startup
 models = {
     "dots.ocr": load_model("dots.ocr"),
@@ -217,7 +291,7 @@ def layoutjson2md(image: Image.Image, layout_data: List[Dict], text_key: str = '
             markdown_lines.append(f"**Table:** {text}\n")
         elif category == 'Formula':
             if text.strip().startswith('$') or '\\' in text:
-                markdown_lines.append(f"$$ \n{text}\n $$\n")
+                markdown_lines.append(f"$$ \n{text}\n $$\n")
             else:
                 markdown_lines.append(f"**Formula:** {text}\n")
         elif category == 'Caption':
@@ -463,4 +537,4 @@ def create_gradio_interface():
 
 if __name__ == "__main__":
     demo = create_gradio_interface()
-    demo.queue(max_size=30).launch(share=False, debug=True, show_error=True)
+    demo.queue(max_size=30).launch(share=True, debug=True, show_error=True)
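
For reference, a minimal sketch of how the helpers this commit inlines from model.py chain together on the dots.ocr path. The page image file name and the max_new_tokens budget are illustrative assumptions, not values taken from the commit:

    from PIL import Image

    # Load once at startup, exactly as app.py now does via the models dict;
    # the first call also downloads the weights into ./models/dots-ocr-local.
    model, processor = load_model("dots.ocr")

    # Layout-aware OCR on a single page image, using the module-level prompt.
    image = Image.open("page.png").convert("RGB")  # illustrative input file
    layout_json = inference_dots_ocr(model, processor, image, prompt, max_new_tokens=4096)
    print(layout_json)  # one JSON object listing bbox, category, and text per element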
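
The Dolphin path is a plain VisionEncoderDecoderModel generate call, with no chat template or process_vision_info step; a sketch under the same assumptions:

    model, processor = load_model("Dolphin")
    image = Image.open("page.png").convert("RGB")  # illustrative input file
    text = inference_dolphin(model, processor, image)  # decoded text only, no layout JSON
    print(text)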