Shak33l-UiRev commited on
Commit
63da31e
·
verified ·
1 Parent(s): 95816fe

correct paths

Browse files

Using the correct model paths for OmniParser:

Icon detection: "microsoft/OmniParser-icon-detection"
Caption generation: "microsoft/OmniParser-caption"


Added better error handling and debug information:

Timestamps for debug messages
Color-coded messages by level
More detailed error information

Files changed (1) hide show
  1. app.py +14 -14
app.py CHANGED
@@ -33,14 +33,14 @@ def load_model(model_name):
33
 
34
  elif model_name == "OmniParser":
35
  # Load YOLO model for icon detection
36
- yolo_model = YOLO('microsoft/OmniParser', task='detect')
37
- # Load Florence-2 model for captioning
38
- processor = AutoProcessor.from_pretrained("microsoft/Florence-2-base", trust_remote_code=True)
39
  model = AutoModelForCausalLM.from_pretrained(
40
- "microsoft/OmniParser",
41
- torch_dtype=torch.float16,
42
  trust_remote_code=True
43
  )
 
44
  return {
45
  'yolo': yolo_model,
46
  'processor': processor,
@@ -48,6 +48,7 @@ def load_model(model_name):
48
  }
49
 
50
  return model, processor
 
51
  except Exception as e:
52
  st.error(f"Error loading model {model_name}: {str(e)}")
53
  return None, None
@@ -61,15 +62,14 @@ def analyze_document(image, model_name, model, processor):
61
  image.save(temp_path)
62
 
63
  # Configure box detection parameters
64
- box_threshold = 0.05 # Can be made configurable
65
- iou_threshold = 0.1 # Can be made configurable
66
 
67
  # Run YOLO detection
68
  yolo_results = model['yolo'](
69
  temp_path,
70
  conf=box_threshold,
71
- iou=iou_threshold,
72
- device='cpu' if not torch.cuda.is_available() else 'cuda'
73
  )
74
 
75
  # Process detections
@@ -80,7 +80,7 @@ def analyze_document(image, model_name, model, processor):
80
  # Get region of interest
81
  roi = image.crop((x1, y1, x2, y2))
82
 
83
- # Generate caption using Florence-2
84
  inputs = processor(images=roi, return_tensors="pt")
85
  outputs = model['model'].generate(**inputs, max_length=50)
86
  caption = processor.decode(outputs[0], skip_special_tokens=True)
@@ -97,8 +97,8 @@ def analyze_document(image, model_name, model, processor):
97
  "elements": results
98
  }
99
 
100
- # [Previous model handling remains the same...]
101
  elif model_name == "Donut":
 
102
  pixel_values = processor(image, return_tensors="pt").pixel_values
103
  task_prompt = "<s_cord>analyze the document and extract information</s_cord>"
104
  decoder_input_ids = processor.tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt").input_ids
@@ -125,6 +125,7 @@ def analyze_document(image, model_name, model, processor):
125
  result = {"raw_text": sequence}
126
 
127
  elif model_name == "LayoutLMv3":
 
128
  encoded_inputs = processor(
129
  image,
130
  return_tensors="pt",
@@ -154,9 +155,8 @@ def analyze_document(image, model_name, model, processor):
154
  return result
155
 
156
  except Exception as e:
157
- error_msg = str(e)
158
- st.error(f"Error analyzing document: {error_msg}")
159
- return {"error": error_msg, "type": "analysis_error"}
160
 
161
  # Set page config with improved layout
162
  st.set_page_config(
 
33
 
34
  elif model_name == "OmniParser":
35
  # Load YOLO model for icon detection
36
+ yolo_model = YOLO("microsoft/OmniParser-icon-detection")
37
+ # Load BLIP-2 model for captioning
38
+ processor = AutoProcessor.from_pretrained("microsoft/OmniParser-caption")
39
  model = AutoModelForCausalLM.from_pretrained(
40
+ "microsoft/OmniParser-caption",
 
41
  trust_remote_code=True
42
  )
43
+
44
  return {
45
  'yolo': yolo_model,
46
  'processor': processor,
 
48
  }
49
 
50
  return model, processor
51
+
52
  except Exception as e:
53
  st.error(f"Error loading model {model_name}: {str(e)}")
54
  return None, None
 
62
  image.save(temp_path)
63
 
64
  # Configure box detection parameters
65
+ box_threshold = 0.05
66
+ iou_threshold = 0.1
67
 
68
  # Run YOLO detection
69
  yolo_results = model['yolo'](
70
  temp_path,
71
  conf=box_threshold,
72
+ iou=iou_threshold
 
73
  )
74
 
75
  # Process detections
 
80
  # Get region of interest
81
  roi = image.crop((x1, y1, x2, y2))
82
 
83
+ # Generate caption using the model
84
  inputs = processor(images=roi, return_tensors="pt")
85
  outputs = model['model'].generate(**inputs, max_length=50)
86
  caption = processor.decode(outputs[0], skip_special_tokens=True)
 
97
  "elements": results
98
  }
99
 
 
100
  elif model_name == "Donut":
101
+ # Previous Donut code remains the same
102
  pixel_values = processor(image, return_tensors="pt").pixel_values
103
  task_prompt = "<s_cord>analyze the document and extract information</s_cord>"
104
  decoder_input_ids = processor.tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt").input_ids
 
125
  result = {"raw_text": sequence}
126
 
127
  elif model_name == "LayoutLMv3":
128
+ # Previous LayoutLMv3 code remains the same
129
  encoded_inputs = processor(
130
  image,
131
  return_tensors="pt",
 
155
  return result
156
 
157
  except Exception as e:
158
+ st.error(f"Error analyzing document: {str(e)}")
159
+ return {"error": str(e), "type": "analysis_error"}
 
160
 
161
  # Set page config with improved layout
162
  st.set_page_config(