allenocr / app.py
flam123's picture
Update app.py
5fbe48d verified
raw
history blame
4.91 kB
'''import os
import uuid
import time
from threading import Thread
import gradio as gr
import torch
from PIL import Image
from transformers import (
Qwen2VLForConditionalGeneration,
AutoProcessor,
TextIteratorStreamer,
)
# Generation limits. MAX_MAX_NEW_TOKENS caps the "Max new tokens" slider and
# DEFAULT_MAX_NEW_TOKENS is its initial value.
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
# Overridable through the MAX_INPUT_TOKEN_LENGTH environment variable.
MAX_INPUT_TOKEN_LENGTH = int(os.environ.get("MAX_INPUT_TOKEN_LENGTH", "4096"))

# Use the first CUDA device when one is available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load olmOCR-7B-0225-preview: the processor and the half-precision model.
MODEL_ID = "allenai/olmOCR-7B-0225-preview"
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
model = Qwen2VLForConditionalGeneration.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    torch_dtype=torch.float16,
)
model = model.to(device)
model.eval()  # inference mode
def generate_image(text: str, image: Image.Image,
                   max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
                   temperature: float = 0.6,
                   top_p: float = 0.9,
                   top_k: int = 50,
                   repetition_penalty: float = 1.2):
    """Stream an olmOCR-7B-0225-preview response for an image and a text query.

    Generation runs on a background thread; this generator yields the text
    accumulated so far each time the streamer produces a new piece.

    Args:
        text: Query placed after the image in the user chat message.
        image: Image to process. When ``None`` a single "please upload"
            message is yielded and the generator ends.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature forwarded to ``model.generate``.
        top_p: Nucleus-sampling threshold forwarded to ``model.generate``.
        top_k: Top-k sampling cutoff forwarded to ``model.generate``.
        repetition_penalty: Repetition penalty forwarded to
            ``model.generate``.

    Yields:
        ``(buffer, buffer)``: the text generated so far, repeated twice so the
        same value can feed two output components.
    """
    if image is None:
        yield "Please upload an image.", "Please upload an image."
        return

    messages = [{
        "role": "user",
        "content": [
            {"type": "image", "image": image},
            {"type": "text", "text": text},
        ]
    }]
    prompt_full = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = processor(
        text=[prompt_full],
        images=[image],
        return_tensors="pt",
        padding=True,
        truncation=False,
        max_length=MAX_INPUT_TOKEN_LENGTH
    ).to(device)

    streamer = TextIteratorStreamer(
        processor, skip_prompt=True, skip_special_tokens=True
    )
    # Forward the sampling controls: previously they were accepted but never
    # reached model.generate, so changing them had no effect on the output.
    generation_kwargs = {
        **inputs,
        "streamer": streamer,
        "max_new_tokens": max_new_tokens,
        "do_sample": True,
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
        "repetition_penalty": repetition_penalty,
    }
    # generate() blocks until finished, so it runs on a worker thread while
    # this generator consumes the streamer.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    buffer = ""
    for new_text in streamer:
        buffer += new_text
        time.sleep(0.01)
        yield buffer, buffer

    # The streamer is exhausted, so generation is done; reap the worker.
    thread.join()
def save_to_md(output_text: str) -> str:
    """Write *output_text* to a uniquely named Markdown file.

    The file is created relative to the current working directory and is
    named ``result_<uuid4>.md``.

    Args:
        output_text: Text to store in the file.

    Returns:
        The path of the file that was written.
    """
    file_path = f"result_{uuid.uuid4()}.md"
    # Explicit UTF-8: the platform default encoding may be unable to encode
    # non-ASCII characters in the text.
    with open(file_path, "w", encoding="utf-8") as f:
        f.write(output_text)
    return file_path
# Gradio UI: example rows and custom stylesheet.
_PAGE_PROMPT = "Convert this page to doc [text] precisely."
_CHART_PROMPT = "Convert chart to OTSL."

# Each row is [query text, image path], matching the order of the inputs the
# examples are bound to.
image_examples = [
    [_PAGE_PROMPT, "images/3.png"],
    [_PAGE_PROMPT, "images/4.png"],
    [_PAGE_PROMPT, "images/1.png"],
    [_CHART_PROMPT, "images/2.png"],
]

# Styles for the submit button and the bordered output panel.
css = """
.submit-btn {
background-color: #2980b9 !important;
color: white !important;
}
.submit-btn:hover {
background-color: #3498db !important;
}
.canvas-output {
border: 2px solid #4682B4;
border-radius: 10px;
padding: 20px;
}
"""
with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
    # NOTE(review): the nesting below is reconstructed from component order
    # because the indentation of this section was lost -- confirm the layout.
    gr.Markdown("# **Doc OCR - olmOCR-7B-0225-preview**")

    with gr.Row():
        # Input side: query, image, submit button, examples, sampling options.
        with gr.Column():
            image_query = gr.Textbox(
                label="Query Input",
                placeholder="Enter your query here...",
            )
            image_upload = gr.Image(type="pil", label="Upload Image")
            image_submit = gr.Button("Submit", elem_classes="submit-btn")
            gr.Examples(
                examples=image_examples,
                inputs=[image_query, image_upload],
            )
            with gr.Accordion("Advanced options", open=False):
                max_new_tokens = gr.Slider(
                    label="Max new tokens",
                    minimum=1,
                    maximum=MAX_MAX_NEW_TOKENS,
                    step=1,
                    value=DEFAULT_MAX_NEW_TOKENS,
                )
                temperature = gr.Slider(
                    label="Temperature",
                    minimum=0.1,
                    maximum=4.0,
                    step=0.1,
                    value=0.6,
                )
                top_p = gr.Slider(
                    label="Top-p",
                    minimum=0.05,
                    maximum=1.0,
                    step=0.05,
                    value=0.9,
                )
                top_k = gr.Slider(
                    label="Top-k",
                    minimum=1,
                    maximum=1000,
                    step=1,
                    value=50,
                )
                repetition_penalty = gr.Slider(
                    label="Repetition penalty",
                    minimum=1.0,
                    maximum=2.0,
                    step=0.05,
                    value=1.2,
                )

        # Output side: raw stream, rendered Markdown, model information.
        with gr.Column():
            with gr.Column(elem_classes="canvas-output"):
                gr.Markdown("## Output")
                output = gr.Textbox(
                    label="Raw Output Stream",
                    interactive=False,
                    lines=2,
                )
                with gr.Accordion("Result.md", open=False):
                    markdown_output = gr.Markdown(label="(Result.md)")
            gr.Markdown("**Model: olmOCR-7B-0225-preview**")
            gr.Markdown("> [`olmOCR-7B`](https://huggingface.co/allenai/olmOCR-7B-0225-preview) is optimized for high-fidelity document OCR and LaTeX-aware image-to-text tasks.")

    # Submit streams generate_image's output into both output components.
    image_submit.click(
        fn=generate_image,
        inputs=[
            image_query,
            image_upload,
            max_new_tokens,
            temperature,
            top_p,
            top_k,
            repetition_penalty,
        ],
        outputs=[output, markdown_output],
    )
if __name__ == "__main__":
demo.queue(max_size=30).launch(share=True, mcp_server=True, ssr_mode=False, show_error=True)'''