Shak33l-UiRev committed on
Commit 956f2af · verified · 1 Parent(s): 932131a

donut str error & omniparser path error

Files changed (1): app.py +65 -65
app.py CHANGED
@@ -45,17 +45,17 @@ def load_model(model_name):
 
     elif model_name == "OmniParser":
         # Load YOLO model for icon detection
-        yolo_model = YOLO("microsoft/OmniParser")
+        yolo_model = YOLO("microsoft/OmniParser-icon-detection")
 
         # Load Florence-2 processor and model for captioning
         processor = AutoProcessor.from_pretrained(
-            "microsoft/Florence-2-base",
+            "microsoft/OmniParser-caption",
             trust_remote_code=True
         )
 
         # Load the captioning model
         caption_model = AutoModelForCausalLM.from_pretrained(
-            "microsoft/OmniParser",
+            "microsoft/OmniParser-caption",
             trust_remote_code=True
         )
 
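The fix swaps in dedicated repo ids for the detector and the captioner. One caveat worth keeping in mind: ultralytics' `YOLO()` constructor expects a local weights file or a built-in model name rather than resolving a Hugging Face repo id on its own, so an explicit Hub download is a common alternative. A minimal sketch, with an assumed repo id and a hypothetical weights filename:

```python
# Sketch only: fetch the detector weights explicitly, then pass the local
# path to YOLO. repo_id and filename are assumptions to verify on the Hub.
from huggingface_hub import hf_hub_download
from ultralytics import YOLO

weights_path = hf_hub_download(
    repo_id="microsoft/OmniParser",   # assumed upstream repo
    filename="icon_detect/model.pt",  # hypothetical filename
)
yolo_model = YOLO(weights_path)       # YOLO accepts a local .pt path
```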
@@ -75,16 +75,7 @@ def load_model(model_name):
 @spaces.GPU
 @torch.inference_mode()
 def analyze_document(image, model_name, models_dict):
-    """Analyze document using selected model
-
-    Args:
-        image (PIL.Image): Input image to analyze
-        model_name (str): Name of the model to use ("Donut", "LayoutLMv3", or "OmniParser")
-        models_dict (dict): Dictionary containing loaded model components
-
-    Returns:
-        dict: Analysis results including detected elements, text, and/or coordinates
-    """
+    """Analyze document using selected model"""
     try:
         if models_dict is None:
             return {"error": "Model failed to load", "type": "model_error"}
@@ -98,77 +89,82 @@ def analyze_document(image, model_name, models_dict):
             temp_path = "temp_image.png"
             image.save(temp_path)
 
-            # Run YOLO detection
-            yolo_results = models_dict['yolo'](
-                temp_path,
-                conf=box_threshold,
-                iou=iou_threshold
-            )
-
-            # Process detections and generate captions
-            results = []
-            for det in yolo_results[0].boxes.data:
-                x1, y1, x2, y2, conf, cls = det
-
-                # Get region of interest
-                roi = image.crop((int(x1), int(y1), int(x2), int(y2)))
-
-                # Generate caption using the model
-                inputs = models_dict['processor'](
-                    images=roi,
-                    return_tensors="pt"
-                )
-
-                outputs = models_dict['model'].generate(
-                    **inputs,
-                    max_length=50,
-                    num_beams=4,
-                    temperature=0.7
-                )
-
-                caption = models_dict['processor'].decode(outputs[0], skip_special_tokens=True)
-
-                results.append({
-                    "bbox": [float(x) for x in [x1, y1, x2, y2]],
-                    "confidence": float(conf),
-                    "class": int(cls),
-                    "caption": caption
-                })
-
-            # Clean up temporary file
-            if os.path.exists(temp_path):
-                os.remove(temp_path)
-
-            return {
-                "detected_elements": len(results),
-                "elements": results
-            }
+            try:
+                # Run YOLO detection
+                yolo_results = models_dict['yolo'](
+                    temp_path,
+                    conf=box_threshold,
+                    iou=iou_threshold
+                )
+
+                # Process detections and generate captions
+                results = []
+                for det in yolo_results[0].boxes.data:
+                    x1, y1, x2, y2, conf, cls = det
+
+                    # Get region of interest
+                    roi = image.crop((int(x1), int(y1), int(x2), int(y2)))
+
+                    # Generate caption using the model
+                    inputs = models_dict['processor'](
+                        images=roi,
+                        return_tensors="pt"
+                    )
+
+                    outputs = models_dict['model'].generate(
+                        **inputs,
+                        max_length=50,
+                        num_beams=4,
+                        temperature=0.7
+                    )
+
+                    caption = models_dict['processor'].decode(outputs[0], skip_special_tokens=True)
+
+                    results.append({
+                        "bbox": [float(x) for x in [x1, y1, x2, y2]],
+                        "confidence": float(conf),
+                        "class": int(cls),
+                        "caption": caption
+                    })
+
+                return {
+                    "detected_elements": len(results),
+                    "elements": results
+                }
+
+            finally:
+                # Clean up temporary file
+                if os.path.exists(temp_path):
+                    os.remove(temp_path)
 
         elif model_name == "Donut":
+            model = models_dict['model']
+            processor = models_dict['processor']
+
             # Process image with Donut
-            pixel_values = models_dict['processor'](image, return_tensors="pt").pixel_values
+            pixel_values = processor(image, return_tensors="pt").pixel_values
 
             task_prompt = "<s_cord>analyze the document and extract information</s_cord>"
-            decoder_input_ids = models_dict['processor'].tokenizer(
+            decoder_input_ids = processor.tokenizer(
                 task_prompt,
                 add_special_tokens=False,
                 return_tensors="pt"
             ).input_ids
 
-            outputs = models_dict['model'].generate(
+            outputs = model.generate(
                 pixel_values,
                 decoder_input_ids=decoder_input_ids,
                 max_length=512,
                 early_stopping=True,
-                pad_token_id=models_dict['processor'].tokenizer.pad_token_id,
-                eos_token_id=models_dict['processor'].tokenizer.eos_token_id,
+                pad_token_id=processor.tokenizer.pad_token_id,
+                eos_token_id=processor.tokenizer.eos_token_id,
                 use_cache=True,
                 num_beams=4,
-                bad_words_ids=[[models_dict['processor'].tokenizer.unk_token_id]],
+                bad_words_ids=[[processor.tokenizer.unk_token_id]],
                 return_dict_in_generate=True
             )
 
-            sequence = models_dict['processor'].batch_decode(outputs.sequences)[0]
+            sequence = processor.batch_decode(outputs.sequences)[0]
             sequence = sequence.replace(task_prompt, "").replace("</s_cord>", "").strip()
 
             try:
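Wrapping the detection-and-captioning pipeline in `try`/`finally` guarantees the temporary screenshot is removed even when YOLO or `generate` raises mid-loop. The standard library's `tempfile` module gives the same guarantee without a fixed filename, which also avoids collisions between concurrent requests; a minimal sketch reusing the names from this diff:

```python
import os
import tempfile

# Sketch only: a unique temp path instead of the fixed "temp_image.png".
# image, models_dict, box_threshold and iou_threshold are the objects
# already in scope inside analyze_document.
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
    temp_path = tmp.name
try:
    image.save(temp_path)
    yolo_results = models_dict['yolo'](temp_path, conf=box_threshold, iou=iou_threshold)
finally:
    os.remove(temp_path)
```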
@@ -179,19 +175,22 @@ def analyze_document(image, model_name, models_dict):
             return result
 
         elif model_name == "LayoutLMv3":
+            model = models_dict['model']
+            processor = models_dict['processor']
+
             # Process image with LayoutLMv3
-            encoded_inputs = models_dict['processor'](
+            encoded_inputs = processor(
                 image,
                 return_tensors="pt",
                 add_special_tokens=True,
                 return_offsets_mapping=True
             )
 
-            outputs = models_dict['model'](**encoded_inputs)
+            outputs = model(**encoded_inputs)
             predictions = outputs.logits.argmax(-1).squeeze().tolist()
 
             # Convert predictions to labels
-            words = models_dict['processor'].tokenizer.convert_ids_to_tokens(
+            words = processor.tokenizer.convert_ids_to_tokens(
                 encoded_inputs.input_ids.squeeze().tolist()
             )
 
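With `model` and `processor` aliased once at the top of the branch, the raw class indices in `predictions` can be mapped back to readable tags through the checkpoint's config. A short sketch, assuming the LayoutLMv3 checkpoint was fine-tuned with an `id2label` mapping:

```python
# Sketch only: pair each token with its predicted label name, skipping
# special tokens. Assumes model.config.id2label exists on this checkpoint.
labels = [model.config.id2label[pred] for pred in predictions]
for word, label in zip(words, labels):
    if word not in ("<s>", "</s>", "<pad>"):
        print(word, label)
```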
@@ -215,6 +214,7 @@ def analyze_document(image, model_name, models_dict):
     except Exception as e:
         import traceback
         error_details = traceback.format_exc()
+        logger.error(f"Analysis error: {str(e)}\n{error_details}")
         return {
             "error": str(e),
             "type": "processing_error",
 
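The new `logger.error` call assumes a module-level logger is defined elsewhere in app.py. If it is not, a conventional setup near the imports would look like this (sketch; the names follow the standard logging idiom):

```python
import logging

# Module-level logger for app.py. basicConfig is a no-op if the host
# runtime has already installed logging handlers.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
```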