# Aya Vision (Cohere) + Pix2Struct OCR pipeline: runs a vision-language model
# on a diagram image, extracts structured JSON from its output, and back-fills
# missing event/gateway labels using an OCR pass.
import os
import json
import re
from PIL import Image
import torch
from transformers import AutoProcessor, AutoModelForImageTextToText
from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor
# Set Hugging Face Token
# NOTE(review): hf_token is read here but never passed to any from_pretrained()
# call below — confirm whether gated-model auth is expected to come from
# `huggingface-cli login` / the HF_TOKEN env var picked up implicitly.
hf_token = os.getenv("HF_TOKEN")
# Initialize Aya Vision Model (loaded once at import time; downloads weights
# on first run and shards across available devices via device_map="auto").
model_id = "CohereForAI/aya-vision-8b"
processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForImageTextToText.from_pretrained(
model_id, device_map="auto", torch_dtype=torch.float16
)
# Initialize Pix2Struct OCR Model (used for the post-processing OCR pass).
ocr_processor = Pix2StructProcessor.from_pretrained("google/pix2struct-textcaps-base")
ocr_model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-textcaps-base")
# Load prompt
def load_prompt():
    """Read and return the instruction prompt from prompts/prompt.txt."""
    with open("prompts/prompt.txt", "r", encoding="utf-8") as prompt_file:
        contents = prompt_file.read()
    return contents
# Try extracting JSON from model output
def try_extract_json(text):
    """Parse JSON out of raw model output.

    First attempts to decode the entire string. Failing that, scans for the
    first brace-balanced ``{...}`` region and decodes it. Returns the decoded
    object, or None when no valid JSON can be found.
    """
    if not text or not text.strip():
        return None

    try:
        return json.loads(text)
    except json.JSONDecodeError:
        pass

    # Fallback: locate the first '{' and walk forward until braces balance.
    start = text.find('{')
    if start == -1:
        return None

    depth = 0
    end = None
    for idx in range(start, len(text)):
        ch = text[idx]
        if ch == '{':
            depth += 1
        elif ch == '}':
            depth -= 1
        if depth == 0:
            end = idx
            break

    if end is None:
        # Braces never balanced — the candidate cannot be valid JSON.
        return None

    try:
        return json.loads(text[start:end + 1])
    except json.JSONDecodeError:
        return None
# Extract OCR text using Pix2Struct
def extract_all_text_pix2struct(image: Image.Image):
    """Run the Pix2Struct model over *image* and return its decoded text output,
    stripped of surrounding whitespace."""
    encoded = ocr_processor(images=image, return_tensors="pt")
    token_ids = ocr_model.generate(**encoded, max_new_tokens=512)
    decoded = ocr_processor.decode(token_ids[0], skip_special_tokens=True)
    return decoded.strip()
# Assign event/gateway names from OCR text
def assign_event_gateway_names_from_ocr(json_data: dict, ocr_text: str):
    """Fill in placeholder names for unnamed events and gateways in *json_data*.

    Parameters
    ----------
    json_data : dict or None
        Parsed diagram JSON, expected to optionally contain "events" and
        "gateways" lists of dicts.
    ocr_text : str
        OCR output for the diagram. Currently only used as a gate: when empty,
        the data is returned untouched.
        TODO(review): the OCR text is never actually matched against the
        unnamed elements — a fixed placeholder is assigned instead. Confirm
        whether real OCR-based name matching was intended here.

    Returns
    -------
    The same *json_data* object, mutated in place, or the input unchanged when
    either argument is falsy.
    """
    if not ocr_text or not json_data:
        return json_data

    # Fix: the original computed a `lines` list from ocr_text that was never
    # used — dead code removed.
    def assign_best_guess(obj):
        # Treat a missing, None, or whitespace-only name as unnamed.
        if not obj.get("name") or obj["name"].strip() == "":
            obj["name"] = "(label unknown)"

    for evt in json_data.get("events", []):
        assign_best_guess(evt)
    for gw in json_data.get("gateways", []):
        assign_best_guess(gw)
    return json_data
# Run Aya model on image
def run_model(image: Image.Image):
    """Run the Aya Vision model on *image* and post-process its output.

    Returns a dict with keys "json" (parsed/post-processed JSON, possibly
    None when no valid JSON could be extracted) and "raw" (the untouched
    generated text).
    """
    prompt = load_prompt()
    chat = [{
        "role": "user",
        "content": [
            {"type": "image", "image": image},
            {"type": "text", "text": prompt},
        ],
    }]
    model_inputs = processor.apply_chat_template(
        chat,
        padding=True,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device)
    generated = model.generate(
        **model_inputs,
        max_new_tokens=5000,
        do_sample=True,
        temperature=0.3,
    )
    # Decode only the newly generated tokens (skip the prompt portion).
    prompt_length = model_inputs.input_ids.shape[1]
    raw_text = processor.tokenizer.decode(
        generated[0][prompt_length:],
        skip_special_tokens=True,
    )
    parsed = try_extract_json(raw_text)
    # OCR post-processing: give unnamed events/gateways placeholder labels.
    diagram_text = extract_all_text_pix2struct(image)
    parsed = assign_event_gateway_names_from_ocr(parsed, diagram_text)
    return {"json": parsed, "raw": raw_text}