ARCQUB committed on
Commit
9ac9cd3
·
verified ·
1 Parent(s): e636262

Update models/gpt4o.py

Browse files
Files changed (1) hide show
  1. models/gpt4o.py +17 -16
models/gpt4o.py CHANGED
@@ -4,16 +4,13 @@ import os
4
  import json
5
  import base64
6
  from PIL import Image
7
- from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor
8
- import numpy as np
9
-
10
  import openai
 
11
 
12
  model = "gpt-4o"
13
 
14
- # Load Pix2Struct model + processor (vision-language OCR)
15
- processor = Pix2StructProcessor.from_pretrained("google/pix2struct-textcaps-base")
16
- pix2struct_model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-textcaps-base")
17
 
18
 
19
  def load_prompt(prompt_file="prompts/prompt.txt"):
@@ -52,23 +49,29 @@ def encode_image_base64(image: Image.Image):
52
 
53
 
54
  def extract_all_text_pix2struct(image: Image.Image):
55
- inputs = processor(images=image, return_tensors="pt")
 
 
 
 
 
 
 
 
 
 
56
  predictions = pix2struct_model.generate(**inputs, max_new_tokens=512)
57
  output_text = processor.decode(predictions[0], skip_special_tokens=True)
58
  return output_text.strip()
59
 
60
 
61
- # Optional: assign best-matching label from full extracted text using proximity (simplified version)
62
  def assign_event_gateway_names_from_ocr(image: Image.Image, json_data, ocr_text):
63
  if not ocr_text:
64
  return json_data
65
 
66
- # You could use NLP matching or regex in complex cases
67
- words = ocr_text.split()
68
-
69
  def guess_name_fallback(obj):
70
  if not obj.get("name") or obj["name"].strip() == "":
71
- obj["name"] = "(label unknown)" # fallback if matching logic isn't yet implemented
72
 
73
  for evt in json_data.get("events", []):
74
  guess_name_fallback(evt)
@@ -83,9 +86,7 @@ def run_model(image: Image.Image, api_key: str = None):
83
  prompt_text = load_prompt()
84
  encoded_image = encode_image_base64(image)
85
 
86
- # Pull from environment if not passed explicitly
87
  api_key = api_key or os.getenv("OPENAI_API_KEY")
88
-
89
  if not api_key:
90
  return {"json": None, "raw": "⚠️ API key is missing. Please set it as a secret in your Space or upload it as a file."}
91
 
@@ -107,8 +108,8 @@ def run_model(image: Image.Image, api_key: str = None):
107
  output_text = response.choices[0].message.content.strip()
108
  parsed_json = try_extract_json(output_text)
109
 
110
- # Vision-language OCR assist step (Pix2Struct)
111
  full_ocr_text = extract_all_text_pix2struct(image)
112
  parsed_json = assign_event_gateway_names_from_ocr(image, parsed_json, full_ocr_text)
113
 
114
- return {"json": parsed_json, "raw": output_text}
 
4
  import json
5
  import base64
6
  from PIL import Image
 
 
 
7
  import openai
8
+ import torch
9
 
10
  model = "gpt-4o"
11
 
12
+ pix2struct_model = None
13
+ processor = None
 
14
 
15
 
16
  def load_prompt(prompt_file="prompts/prompt.txt"):
 
49
 
50
 
51
def extract_all_text_pix2struct(image: Image.Image):
    """Run Pix2Struct OCR over *image* and return the decoded text, stripped.

    The processor and model live in module-level globals that start as
    ``None`` and are materialised lazily on the first call, so importing
    this module stays cheap and the heavyweight download only happens
    when OCR is actually needed.
    """
    global pix2struct_model, processor

    # First call: import transformers and build the processor/model pair.
    if pix2struct_model is None or processor is None:
        from transformers import Pix2StructProcessor, Pix2StructForConditionalGeneration

        processor = Pix2StructProcessor.from_pretrained("google/pix2struct-textcaps-base")
        device = "cuda" if torch.cuda.is_available() else "cpu"
        pix2struct_model = Pix2StructForConditionalGeneration.from_pretrained(
            "google/pix2struct-textcaps-base"
        ).to(device)

    # Encode the image and move the batch onto the model's device before generating.
    batch = processor(images=image, return_tensors="pt").to(pix2struct_model.device)
    generated_ids = pix2struct_model.generate(**batch, max_new_tokens=512)
    decoded = processor.decode(generated_ids[0], skip_special_tokens=True)
    return decoded.strip()
66
 
67
 
 
68
  def assign_event_gateway_names_from_ocr(image: Image.Image, json_data, ocr_text):
69
  if not ocr_text:
70
  return json_data
71
 
 
 
 
72
  def guess_name_fallback(obj):
73
  if not obj.get("name") or obj["name"].strip() == "":
74
+ obj["name"] = "(label unknown)"
75
 
76
  for evt in json_data.get("events", []):
77
  guess_name_fallback(evt)
 
86
  prompt_text = load_prompt()
87
  encoded_image = encode_image_base64(image)
88
 
 
89
  api_key = api_key or os.getenv("OPENAI_API_KEY")
 
90
  if not api_key:
91
  return {"json": None, "raw": "⚠️ API key is missing. Please set it as a secret in your Space or upload it as a file."}
92
 
 
108
  output_text = response.choices[0].message.content.strip()
109
  parsed_json = try_extract_json(output_text)
110
 
111
+ # Use Pix2Struct OCR enrichment
112
  full_ocr_text = extract_all_text_pix2struct(image)
113
  parsed_json = assign_event_gateway_names_from_ocr(image, parsed_json, full_ocr_text)
114
 
115
+ return {"json": parsed_json, "raw": output_text}