Shak33l-UiRev committed
Commit dea33ff · verified · 1 Parent(s): 63da31e

updated device management

Files changed (1)
  1. app.py +104 -43
app.py CHANGED
@@ -17,7 +17,14 @@ from datetime import datetime
 
 @st.cache_resource
 def load_model(model_name):
-    """Load the selected model and processor"""
+    """Load the selected model and processor
+
+    Args:
+        model_name (str): Name of the model to load ("Donut", "LayoutLMv3", or "OmniParser")
+
+    Returns:
+        dict: Dictionary containing model components
+    """
     try:
         if model_name == "Donut":
             processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base")
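
Note: after this change every branch of `load_model` returns a dict keyed by component, rather than the old `(model, processor)` tuple, so callers unpack by key. A minimal sketch of the intended call pattern, assuming app.py's existing imports and the key names used in this diff (the real Streamlit wiring sits outside these hunks):

```python
# Hypothetical caller for the reworked load_model.
models_dict = load_model("Donut")

if models_dict is None:
    # load_model has already surfaced the failure via st.error
    st.stop()

model = models_dict['model']           # the Hugging Face model
processor = models_dict['processor']   # its paired processor
```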
@@ -27,63 +34,98 @@ def load_model(model_name):
             model.config.pad_token_id = processor.tokenizer.pad_token_id
             model.config.vocab_size = len(processor.tokenizer)
 
+            return {'model': model, 'processor': processor}
+
         elif model_name == "LayoutLMv3":
             processor = LayoutLMv3Processor.from_pretrained("microsoft/layoutlmv3-base")
             model = LayoutLMv3ForSequenceClassification.from_pretrained("microsoft/layoutlmv3-base")
 
+            return {'model': model, 'processor': processor}
+
         elif model_name == "OmniParser":
             # Load YOLO model for icon detection
-            yolo_model = YOLO("microsoft/OmniParser-icon-detection")
-            # Load BLIP-2 model for captioning
-            processor = AutoProcessor.from_pretrained("microsoft/OmniParser-caption")
-            model = AutoModelForCausalLM.from_pretrained(
-                "microsoft/OmniParser-caption",
+            yolo_model = YOLO("microsoft/OmniParser")
+
+            # Load Florence-2 processor and model for captioning
+            processor = AutoProcessor.from_pretrained(
+                "microsoft/Florence-2-base",
+                trust_remote_code=True
+            )
+
+            # Load the captioning model
+            caption_model = AutoModelForCausalLM.from_pretrained(
+                "microsoft/OmniParser",
                 trust_remote_code=True
             )
 
             return {
                 'yolo': yolo_model,
                 'processor': processor,
-                'model': model
+                'model': caption_model
             }
 
-        return model, processor
-
+        else:
+            raise ValueError(f"Unknown model name: {model_name}")
+
     except Exception as e:
         st.error(f"Error loading model {model_name}: {str(e)}")
-        return None, None
+        return None
 
-def analyze_document(image, model_name, model, processor):
-    """Analyze document using selected model"""
+@spaces.GPU
+@torch.inference_mode()
+def analyze_document(image, model_name, models_dict):
+    """Analyze document using selected model
+
+    Args:
+        image (PIL.Image): Input image to analyze
+        model_name (str): Name of the model to use ("Donut", "LayoutLMv3", or "OmniParser")
+        models_dict (dict): Dictionary containing loaded model components
+
+    Returns:
+        dict: Analysis results including detected elements, text, and/or coordinates
+    """
     try:
+        if models_dict is None:
+            return {"error": "Model failed to load", "type": "model_error"}
+
         if model_name == "OmniParser":
-            # Save image temporarily
+            # Configure detection parameters
+            box_threshold = 0.05  # Confidence threshold for detection
+            iou_threshold = 0.1   # IoU threshold for NMS
+
+            # Save image temporarily for YOLO processing
             temp_path = "temp_image.png"
             image.save(temp_path)
 
-            # Configure box detection parameters
-            box_threshold = 0.05
-            iou_threshold = 0.1
-
             # Run YOLO detection
-            yolo_results = model['yolo'](
+            yolo_results = models_dict['yolo'](
                 temp_path,
                 conf=box_threshold,
                 iou=iou_threshold
             )
 
-            # Process detections
+            # Process detections and generate captions
             results = []
             for det in yolo_results[0].boxes.data:
                 x1, y1, x2, y2, conf, cls = det
 
                 # Get region of interest
-                roi = image.crop((x1, y1, x2, y2))
+                roi = image.crop((int(x1), int(y1), int(x2), int(y2)))
 
                 # Generate caption using the model
-                inputs = processor(images=roi, return_tensors="pt")
-                outputs = model['model'].generate(**inputs, max_length=50)
-                caption = processor.decode(outputs[0], skip_special_tokens=True)
+                inputs = models_dict['processor'](
+                    images=roi,
+                    return_tensors="pt"
+                )
+
+                outputs = models_dict['model'].generate(
+                    **inputs,
+                    max_length=50,
+                    num_beams=4,
+                    temperature=0.7
+                )
+
+                caption = models_dict['processor'].decode(outputs[0], skip_special_tokens=True)
 
                 results.append({
                     "bbox": [float(x) for x in [x1, y1, x2, y2]],
@@ -92,31 +134,40 @@ def analyze_document(image, model_name, model, processor):
                     "caption": caption
                 })
 
+            # Clean up temporary file
+            if os.path.exists(temp_path):
+                os.remove(temp_path)
+
             return {
                 "detected_elements": len(results),
                 "elements": results
             }
 
         elif model_name == "Donut":
-            # Previous Donut code remains the same
-            pixel_values = processor(image, return_tensors="pt").pixel_values
+            # Process image with Donut
+            pixel_values = models_dict['processor'](image, return_tensors="pt").pixel_values
+
             task_prompt = "<s_cord>analyze the document and extract information</s_cord>"
-            decoder_input_ids = processor.tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt").input_ids
+            decoder_input_ids = models_dict['processor'].tokenizer(
+                task_prompt,
+                add_special_tokens=False,
+                return_tensors="pt"
+            ).input_ids
 
-            outputs = model.generate(
+            outputs = models_dict['model'].generate(
                 pixel_values,
                 decoder_input_ids=decoder_input_ids,
                 max_length=512,
                 early_stopping=True,
-                pad_token_id=processor.tokenizer.pad_token_id,
-                eos_token_id=processor.tokenizer.eos_token_id,
+                pad_token_id=models_dict['processor'].tokenizer.pad_token_id,
+                eos_token_id=models_dict['processor'].tokenizer.eos_token_id,
                 use_cache=True,
                 num_beams=4,
-                bad_words_ids=[[processor.tokenizer.unk_token_id]],
+                bad_words_ids=[[models_dict['processor'].tokenizer.unk_token_id]],
                 return_dict_in_generate=True
             )
 
-            sequence = processor.batch_decode(outputs.sequences)[0]
+            sequence = models_dict['processor'].batch_decode(outputs.sequences)[0]
             sequence = sequence.replace(task_prompt, "").replace("</s_cord>", "").strip()
 
             try:
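
The `@spaces.GPU` decorator added above is presumably what the commit message, "updated device management", refers to: on a ZeroGPU Space it attaches a GPU only for the duration of the decorated call, and it requires `import spaces` at the top of app.py (not visible in these hunks). A minimal sketch of the pattern, under those assumptions:

```python
# Minimal ZeroGPU sketch, assuming the app runs on a Hugging Face Space with
# the `spaces` package installed; the import must happen at module load time.
import spaces
import torch

@spaces.GPU                # GPU is attached only while this function runs
@torch.inference_mode()    # no autograd bookkeeping during inference
def double_on_gpu(x: torch.Tensor) -> torch.Tensor:
    device = "cuda" if torch.cuda.is_available() else "cpu"
    return (x.to(device) * 2).cpu()

print(double_on_gpu(torch.ones(3)))    # tensor([2., 2., 2.])
```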
@@ -124,19 +175,22 @@ def analyze_document(image, model_name, model, processor):
             except json.JSONDecodeError:
                 result = {"raw_text": sequence}
 
+            return result
+
         elif model_name == "LayoutLMv3":
-            # Previous LayoutLMv3 code remains the same
-            encoded_inputs = processor(
+            # Process image with LayoutLMv3
+            encoded_inputs = models_dict['processor'](
                 image,
                 return_tensors="pt",
                 add_special_tokens=True,
                 return_offsets_mapping=True
             )
 
-            outputs = model(**encoded_inputs)
+            outputs = models_dict['model'](**encoded_inputs)
             predictions = outputs.logits.argmax(-1).squeeze().tolist()
 
-            words = processor.tokenizer.convert_ids_to_tokens(
+            # Convert predictions to labels
+            words = models_dict['processor'].tokenizer.convert_ids_to_tokens(
                 encoded_inputs.input_ids.squeeze().tolist()
             )
 
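
The Donut branch parses the decoded sequence with `json.loads` and falls back to raw text (the `except json.JSONDecodeError` at the top of this hunk). For Donut's native tag-style output, `DonutProcessor.token2json` converts the sequence into a dict directly; a hedged sketch with a made-up output sequence:

```python
# DonutProcessor.token2json on a hypothetical tag-style output sequence;
# the checkpoint is the one load_model already downloads.
from transformers import DonutProcessor

processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base")

sequence = "<s_menu><s_nm>Latte</s_nm><s_price>4.50</s_price></s_menu>"
print(processor.token2json(sequence))
# roughly: {'menu': {'nm': 'Latte', 'price': '4.50'}}
```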
 
@@ -152,11 +206,19 @@ def analyze_document(image, model_name, model, processor):
                 "confidence_scores": outputs.logits.softmax(-1).max(-1).values.squeeze().tolist()
             }
 
-        return result
+            return result
+
+        else:
+            return {"error": f"Unknown model: {model_name}", "type": "model_error"}
 
     except Exception as e:
-        st.error(f"Error analyzing document: {str(e)}")
-        return {"error": str(e), "type": "analysis_error"}
+        import traceback
+        error_details = traceback.format_exc()
+        return {
+            "error": str(e),
+            "type": "processing_error",
+            "details": error_details
+        }
 
 # Set page config with improved layout
 st.set_page_config(
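
The LayoutLMv3 branch above keeps the raw class indices from `argmax`. Mapping them through `model.config.id2label` gives readable names, though the un-fine-tuned base checkpoint only carries generic placeholders. A small sketch with hypothetical indices:

```python
# Index-to-label mapping for LayoutLMv3 predictions; the indices below are
# stand-ins for the argmax output in the hunks above.
from transformers import LayoutLMv3ForSequenceClassification

model = LayoutLMv3ForSequenceClassification.from_pretrained(
    "microsoft/layoutlmv3-base"
)

predictions = [0, 1, 0]
print([model.config.id2label[p] for p in predictions])
# ['LABEL_0', 'LABEL_1', 'LABEL_0'] until the classification head is fine-tuned
```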
@@ -372,6 +434,7 @@ st.markdown("""
 """)
 
 # Add performance metrics section
+
 if st.checkbox("Show Performance Metrics"):
     st.markdown("""
     ### Model Performance Metrics
@@ -379,8 +442,7 @@ if st.checkbox("Show Performance Metrics"):
     |-------|---------------------|--------------|-----------|
     | Donut | 2-3 seconds | 6-8GB | 85-90% |
     | LayoutLMv3 | 3-4 seconds | 12-15GB | 88-93% |
-    | BROS | 1-2 seconds | 4-6GB | 82-87% |
-    | LLaVA-1.5 | 4-5 seconds | 25-40GB | 90-95% |
+    | OmniParser | 2-3 seconds | 8-10GB | 85-90% |
 
     *Accuracy varies based on document type and quality
     """)
@@ -389,7 +451,7 @@ if st.checkbox("Show Performance Metrics"):
 st.markdown("---")
 st.markdown("""
 v1.1 - Created with Streamlit
-\nFor issues or feedback, please visit our [GitHub repository](https://github.com/yourusername/doc-analysis)
+\nPowered by Hugging Face Spaces 🤗
 """)
 
 # Add model selection guidance
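
Taken together, the reworked functions compose as sketched below: `analyze_document` now reports failures as a structured dict (with a traceback under `details`) instead of calling `st.error` itself, so the caller decides how to surface them. A hypothetical driver, assuming app.py's own definitions:

```python
import streamlit as st
from PIL import Image

image = Image.new("RGB", (640, 480), "white")   # stand-in for an uploaded page

models_dict = load_model("Donut")               # cached via @st.cache_resource
result = analyze_document(image, "Donut", models_dict)

if isinstance(result, dict) and "error" in result:
    st.error(f"{result.get('type', 'error')}: {result['error']}")
    with st.expander("Details"):
        st.code(result.get("details", "no traceback captured"))
else:
    st.json(result)
```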
@@ -398,6 +460,5 @@ if st.checkbox("Show Model Selection Guide"):
     ### How to Choose the Right Model
     1. **Donut**: Choose for structured documents with clear layouts
     2. **LayoutLMv3**: Best for documents with complex layouts and relationships
-    3. **BROS**: Ideal for quick analysis and simple documents
-    4. **LLaVA-1.5**: Perfect for complex documents requiring deep understanding
+    3. **OmniParser**: Best for UI elements and screen parsing
     """)
 