flam123 commited on
Commit
77fb47d
·
verified ·
1 Parent(s): 682baee

Create main.py

Browse files
Files changed (1) hide show
  1. main.py +138 -0
main.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import uuid
3
+ import time
4
+ from threading import Thread
5
+
6
+ import gradio as gr
7
+ import torch
8
+ from PIL import Image
9
+ from transformers import (
10
+ Qwen2VLForConditionalGeneration,
11
+ AutoProcessor,
12
+ TextIteratorStreamer,
13
+ )
14
+
15
+ # Constants
16
+ MAX_MAX_NEW_TOKENS = 2048
17
+ DEFAULT_MAX_NEW_TOKENS = 1024
18
+ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
19
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
20
+
21
+ # Load olmOCR-7B-0225-preview
22
+ MODEL_ID = "allenai/olmOCR-7B-0225-preview"
23
+ processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
24
+ model = Qwen2VLForConditionalGeneration.from_pretrained(
25
+ MODEL_ID,
26
+ trust_remote_code=True,
27
+ torch_dtype=torch.float16
28
+ ).to(device).eval()
29
+
30
+ def generate_image(text: str, image: Image.Image,
31
+ max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
32
+ temperature: float = 0.6,
33
+ top_p: float = 0.9,
34
+ top_k: int = 50,
35
+ repetition_penalty: float = 1.2):
36
+ """
37
+ Generates responses using olmOCR-7B-0225-preview for image input.
38
+ """
39
+ if image is None:
40
+ yield "Please upload an image.", "Please upload an image."
41
+ return
42
+
43
+ messages = [{
44
+ "role": "user",
45
+ "content": [
46
+ {"type": "image", "image": image},
47
+ {"type": "text", "text": text},
48
+ ]
49
+ }]
50
+ prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
51
+ inputs = processor(
52
+ text=[prompt_full],
53
+ images=[image],
54
+ return_tensors="pt",
55
+ padding=True,
56
+ truncation=False,
57
+ max_length=MAX_INPUT_TOKEN_LENGTH
58
+ ).to(device)
59
+
60
+ streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
61
+ generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens}
62
+
63
+ thread = Thread(target=model.generate, kwargs=generation_kwargs)
64
+ thread.start()
65
+
66
+ buffer = ""
67
+ for new_text in streamer:
68
+ buffer += new_text
69
+ time.sleep(0.01)
70
+ yield buffer, buffer
71
+
72
+ def save_to_md(output_text):
73
+ file_path = f"result_{uuid.uuid4()}.md"
74
+ with open(file_path, "w") as f:
75
+ f.write(output_text)
76
+ return file_path
77
+
78
+ # Gradio UI
79
+ image_examples = [
80
+ ["Convert this page to doc [text] precisely.", "images/3.png"],
81
+ ["Convert this page to doc [text] precisely.", "images/4.png"],
82
+ ["Convert this page to doc [text] precisely.", "images/1.png"],
83
+ ["Convert chart to OTSL.", "images/2.png"]
84
+ ]
85
+
86
+ css = """
87
+ .submit-btn {
88
+ background-color: #2980b9 !important;
89
+ color: white !important;
90
+ }
91
+ .submit-btn:hover {
92
+ background-color: #3498db !important;
93
+ }
94
+ .canvas-output {
95
+ border: 2px solid #4682B4;
96
+ border-radius: 10px;
97
+ padding: 20px;
98
+ }
99
+ """
100
+
101
+ with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
102
+ gr.Markdown("# **Doc OCR - olmOCR-7B-0225-preview**")
103
+
104
+ with gr.Row():
105
+ with gr.Column():
106
+ image_query = gr.Textbox(label="Query Input", placeholder="Enter your query here...")
107
+ image_upload = gr.Image(type="pil", label="Upload Image")
108
+ image_submit = gr.Button("Submit", elem_classes="submit-btn")
109
+ gr.Examples(
110
+ examples=image_examples,
111
+ inputs=[image_query, image_upload]
112
+ )
113
+
114
+ with gr.Accordion("Advanced options", open=False):
115
+ max_new_tokens = gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS)
116
+ temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6)
117
+ top_p = gr.Slider(label="Top-p", minimum=0.05, maximum=1.0, step=0.05, value=0.9)
118
+ top_k = gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50)
119
+ repetition_penalty = gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2)
120
+
121
+ with gr.Column():
122
+ with gr.Column(elem_classes="canvas-output"):
123
+ gr.Markdown("## Output")
124
+ output = gr.Textbox(label="Raw Output Stream", interactive=False, lines=2)
125
+ with gr.Accordion("Result.md", open=False):
126
+ markdown_output = gr.Markdown(label="(Result.md)")
127
+
128
+ gr.Markdown("**Model: olmOCR-7B-0225-preview**")
129
+ gr.Markdown("> [`olmOCR-7B`](https://huggingface.co/allenai/olmOCR-7B-0225-preview) is optimized for high-fidelity document OCR and LaTeX-aware image-to-text tasks.")
130
+
131
+ image_submit.click(
132
+ fn=generate_image,
133
+ inputs=[image_query, image_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
134
+ outputs=[output, markdown_output]
135
+ )
136
+
137
+ if __name__ == "__main__":
138
+ demo.queue(max_size=30).launch(share=True, mcp_server=True, ssr_mode=False, show_error=True)