nutellalpb committed on
Commit
a6bd9de
·
verified ·
1 Parent(s): c8409d8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -197
app.py CHANGED
@@ -1,201 +1,48 @@
1
- import gradio as gr
2
- import spaces
3
- from transformers import AutoModel, AutoTokenizer
4
  from PIL import Image
5
- import numpy as np
6
- import os
7
- import base64
8
- import io
9
- import uuid
10
- import tempfile
11
- import time
12
- import shutil
13
- from pathlib import Path
14
-
15
- tokenizer = AutoTokenizer.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True)
16
- model = AutoModel.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True, low_cpu_mem_usage=True, device_map='cuda', use_safetensors=True)
17
- model = model.eval().cuda()
18
-
19
- UPLOAD_FOLDER = "./uploads"
20
- RESULTS_FOLDER = "./results"
21
-
22
- for folder in [UPLOAD_FOLDER, RESULTS_FOLDER]:
23
- if not os.path.exists(folder):
24
- os.makedirs(folder)
25
-
26
- def image_to_base64(image):
27
- buffered = io.BytesIO()
28
- image.save(buffered, format="PNG")
29
- return base64.b64encode(buffered.getvalue()).decode()
30
-
31
- @spaces.GPU
32
- def run_GOT(image, got_mode, fine_grained_mode="", ocr_color="", ocr_box=""):
33
- unique_id = str(uuid.uuid4())
34
- image_path = os.path.join(UPLOAD_FOLDER, f"{unique_id}.png")
35
- result_path = os.path.join(RESULTS_FOLDER, f"{unique_id}.html")
36
-
37
- shutil.copy(image, image_path)
38
-
39
- try:
40
- if got_mode == "plain texts OCR":
41
- res = model.chat(tokenizer, image_path, ocr_type='ocr')
42
- return res, None
43
- elif got_mode == "format texts OCR":
44
- res = model.chat(tokenizer, image_path, ocr_type='format', render=True, save_render_file=result_path)
45
- elif got_mode == "plain multi-crop OCR":
46
- res = model.chat_crop(tokenizer, image_path, ocr_type='ocr')
47
- return res, None
48
- elif got_mode == "format multi-crop OCR":
49
- res = model.chat_crop(tokenizer, image_path, ocr_type='format', render=True, save_render_file=result_path)
50
- elif got_mode == "plain fine-grained OCR":
51
- res = model.chat(tokenizer, image_path, ocr_type='ocr', ocr_box=ocr_box, ocr_color=ocr_color)
52
- return res, None
53
- elif got_mode == "format fine-grained OCR":
54
- res = model.chat(tokenizer, image_path, ocr_type='format', ocr_box=ocr_box, ocr_color=ocr_color, render=True, save_render_file=result_path)
55
-
56
- # res_markdown = f"$$ {res} $$"
57
- res_markdown = res
58
-
59
- if "format" in got_mode and os.path.exists(result_path):
60
- with open(result_path, 'r') as f:
61
- html_content = f.read()
62
- encoded_html = base64.b64encode(html_content.encode('utf-8')).decode('utf-8')
63
- iframe_src = f"data:text/html;base64,{encoded_html}"
64
- iframe = f'<iframe src="{iframe_src}" width="100%" height="600px"></iframe>'
65
- download_link = f'<a href="data:text/html;base64,{encoded_html}" download="result_{unique_id}.html">Download Full Result</a>'
66
- return res_markdown, f"{download_link}<br>{iframe}"
67
- else:
68
- return res_markdown, None
69
- except Exception as e:
70
- return f"Error: {str(e)}", None
71
- finally:
72
- if os.path.exists(image_path):
73
- os.remove(image_path)
74
-
75
- def task_update(task):
76
- if "fine-grained" in task:
77
- return [
78
- gr.update(visible=True),
79
- gr.update(visible=False),
80
- gr.update(visible=False),
81
- ]
82
- else:
83
- return [
84
- gr.update(visible=False),
85
- gr.update(visible=False),
86
- gr.update(visible=False),
87
- ]
88
-
89
- def fine_grained_update(task):
90
- if task == "box":
91
- return [
92
- gr.update(visible=False, value = ""),
93
- gr.update(visible=True),
94
- ]
95
- elif task == 'color':
96
- return [
97
- gr.update(visible=True),
98
- gr.update(visible=False, value = ""),
99
- ]
100
-
101
- def cleanup_old_files():
102
- current_time = time.time()
103
- for folder in [UPLOAD_FOLDER, RESULTS_FOLDER]:
104
- for file_path in Path(folder).glob('*'):
105
- if current_time - file_path.stat().st_mtime > 3600: # 1 hour
106
- file_path.unlink()
107
-
108
- title_html = """
109
- <h2> <span class="gradient-text" id="text">General OCR Theory</span><span class="plain-text">: Towards OCR-2.0 via a Unified End-to-end Model</span></h2>
110
- <a href="https://huggingface.co/ucaslcl/GOT-OCR2_0">[😊 Hugging Face]</a>
111
- <a href="https://arxiv.org/abs/2409.01704">[📜 Paper]</a>
112
- <a href="https://github.com/Ucas-HaoranWei/GOT-OCR2.0/">[🌟 GitHub]</a>
113
- """
114
-
115
- with gr.Blocks() as demo:
116
- gr.HTML(title_html)
117
- gr.Markdown("""
118
- "🔥🔥🔥This is the official online demo of GOT-OCR-2.0 model!!!"
119
 
120
- ### Demo Guidelines
121
- You need to upload your image below and choose one mode of GOT, then click "Submit" to run GOT model. More characters will result in longer wait times.
122
- - **plain texts OCR & format texts OCR**: The two modes are for the image-level OCR.
123
- - **plain multi-crop OCR & format multi-crop OCR**: For images with more complex content, you can achieve higher-quality results with these modes.
124
- - **plain fine-grained OCR & format fine-grained OCR**: In these modes, you can specify fine-grained regions on the input image for more flexible OCR. Fine-grained regions can be coordinates of the box, red color, blue color, or green color.
125
- """)
126
 
127
- with gr.Row():
128
- with gr.Column():
129
- image_input = gr.Image(type="filepath", label="upload your image")
130
- task_dropdown = gr.Dropdown(
131
- choices=[
132
- "plain texts OCR",
133
- "format texts OCR",
134
- "plain multi-crop OCR",
135
- "format multi-crop OCR",
136
- "plain fine-grained OCR",
137
- "format fine-grained OCR",
138
- ],
139
- label="Choose one mode of GOT",
140
- value="plain texts OCR"
141
- )
142
- fine_grained_dropdown = gr.Dropdown(
143
- choices=["box", "color"],
144
- label="fine-grained type",
145
- visible=False
146
- )
147
- color_dropdown = gr.Dropdown(
148
- choices=["red", "green", "blue"],
149
- label="color list",
150
- visible=False
151
- )
152
- box_input = gr.Textbox(
153
- label="input box: [x1,y1,x2,y2]",
154
- placeholder="e.g., [0,0,100,100]",
155
- visible=False
156
- )
157
- submit_button = gr.Button("Submit")
158
-
159
- with gr.Column():
160
- ocr_result = gr.Textbox(label="GOT output")
161
-
162
- with gr.Column():
163
- gr.Markdown("**If you choose the mode with format, the mathpix result will be automatically rendered as follows:**")
164
- html_result = gr.HTML(label="rendered html", show_label=True)
165
-
166
- gr.Examples(
167
- examples=[
168
- ["assets/coco.jpg", "plain texts OCR", "", "", ""],
169
- ["assets/en_30.png", "plain texts OCR", "", "", ""],
170
- ["assets/eq.jpg", "format texts OCR", "", "", ""],
171
- ["assets/table.jpg", "format texts OCR", "", "", ""],
172
- ["assets/giga.jpg", "format multi-crop OCR", "", "", ""],
173
- ["assets/aff2.png", "plain fine-grained OCR", "box", "", "[409,763,756,891]"],
174
- ["assets/color.png", "plain fine-grained OCR", "color", "red", ""],
175
- ],
176
- inputs=[image_input, task_dropdown, fine_grained_dropdown, color_dropdown, box_input],
177
- outputs=[ocr_result, html_result],
178
- fn=run_GOT,
179
- label="examples",
180
- )
181
-
182
- task_dropdown.change(
183
- task_update,
184
- inputs=[task_dropdown],
185
- outputs=[fine_grained_dropdown, color_dropdown, box_input]
186
- )
187
- fine_grained_dropdown.change(
188
- fine_grained_update,
189
- inputs=[fine_grained_dropdown],
190
- outputs=[color_dropdown, box_input]
191
- )
192
 
193
- submit_button.click(
194
- run_GOT,
195
- inputs=[image_input, task_dropdown, fine_grained_dropdown, color_dropdown, box_input],
196
- outputs=[ocr_result, html_result]
197
- )
198
-
199
- if __name__ == "__main__":
200
- cleanup_old_files()
201
- demo.launch()
 
 
1
+ from transformers import AutoTokenizer, AutoModel
 
 
2
  from PIL import Image
3
+ import torch
4
+
5
+ # Charger le modèle GOT-OCR2_0 pour la reconnaissance des plaques d'immatriculation
6
+ ocr_tokenizer = AutoTokenizer.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True)
7
+ ocr_model = AutoModel.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True, low_cpu_mem_usage=True, device_map='cuda', use_safetensors=True)
8
+ ocr_model.eval().cuda()
9
+
10
+ # Charger le modèle de suivi et reconnaissance de couleur des véhicules
11
+ vehicle_model = torch.hub.load('ultralytics/yolov5', 'custom', path='sujithvamshi/Real-Time-Vehicle-Tracking-And-Colour-Recognition/best.pt')
12
+
13
+ # Fonction pour extraire la plaque d'immatriculation avec OCR
14
+ def get_license_plate(image):
15
+ # Utiliser le modèle GOT-OCR pour extraire le texte (plaque d'immatriculation)
16
+ image_tensor = ocr_tokenizer(image, return_tensors="pt").input_ids
17
+ with torch.no_grad():
18
+ output = ocr_model(image_tensor)
19
+ plate_text = ocr_tokenizer.decode(output.logits[0], skip_special_tokens=True)
20
+ return plate_text
21
+
22
+ # Fonction pour extraire la couleur du véhicule
23
+ def get_vehicle_color(image):
24
+ # Utiliser le modèle Real-Time-Vehicle-Tracking-And-Colour-Recognition pour obtenir la couleur
25
+ results = vehicle_model(image)
26
+ color_info = results.pandas().xyxy[0].color # Hypothèse: le modèle retourne une info de couleur
27
+ return color_info
28
+
29
+ # Fusionner les deux résultats
30
+ def process_image(image_path):
31
+ image = Image.open(image_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
+ # 1. Extraire la plaque d'immatriculation
34
+ license_plate = get_license_plate(image)
 
 
 
 
35
 
36
+ # 2. Extraire la couleur du véhicule
37
+ vehicle_color = get_vehicle_color(image)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
+ # 3. Retourner la fusion des résultats
40
+ return {
41
+ "license_plate": license_plate,
42
+ "vehicle_color": vehicle_color
43
+ }
44
+
45
+ # Exemple d'utilisation
46
+ image_path = "path_to_your_image.jpg"
47
+ result = process_image(image_path)
48
+ print(f"Plaque d'immatriculation: {result['license_plate']}, Couleur du véhicule: {result['vehicle_color']}")