# Source: smoldocling-preview / backends/smoldocling.py
# Last commit 9c2f030 by taprosoft: "fix: limit token count" (2.33 kB)
# Prerequisites:
# pip install torch
# pip install docling_core
# pip install transformers
# pip install img2table
# pip install pillow
import torch
from docling_core.types.doc import DoclingDocument
from docling_core.types.doc.document import DocTagsDocument
from img2table.document import PDF
from PIL import Image
from transformers import AutoModelForVision2Seq, AutoProcessor
# Use the GPU whenever torch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Only this many leading pages of each PDF are sent through the model.
MAX_PAGES = 1

# Load the SmolDocling processor (used below for both the text prompt and the
# page images) and the vision-to-sequence model once, at import time, so every
# conversion call shares them.
processor = AutoProcessor.from_pretrained("ds4sd/SmolDocling-256M-preview")
model = AutoModelForVision2Seq.from_pretrained(
    "ds4sd/SmolDocling-256M-preview",
    torch_dtype=torch.bfloat16,
    _attn_implementation="eager",
)
model = model.to(DEVICE)

# Chat-style prompt: a single user turn holding an image placeholder followed
# by the conversion instruction. The page image itself is handed to the
# processor separately at inference time.
messages = [
    dict(
        role="user",
        content=[
            dict(type="image"),
            dict(type="text", text="Convert this page to docling."),
        ],
    ),
]
def convert_smoldocling(path: str, file_name: str):
doc = PDF(path)
output_md = ""
for image in doc.images[:MAX_PAGES]:
# convert ndarray to Image
image = Image.fromarray(image)
# resize image to maximum width of 1200
max_width = 1200
if image.width > max_width:
image = image.resize(
(max_width, int(max_width * image.height / image.width))
)
# Prepare inputs
prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
inputs = processor(text=prompt, images=[image], return_tensors="pt")
inputs = inputs.to(DEVICE)
# Generate outputs
generated_ids = model.generate(**inputs, max_new_tokens=4096)
prompt_length = inputs.input_ids.shape[1]
trimmed_generated_ids = generated_ids[:, prompt_length:]
doctags = processor.batch_decode(
trimmed_generated_ids,
skip_special_tokens=False,
)[0].lstrip()
# Populate document
doctags_doc = DocTagsDocument.from_doctags_and_image_pairs([doctags], [image])
# create a docling document
doc = DoclingDocument(name="Document")
doc.load_from_doctags(doctags_doc)
# export as any format
# HTML
# doc.save_as_html(output_file)
# MD
output_md += doc.export_to_markdown() + "\n\n"
return output_md, []