atlury commited on
Commit
34e259e
·
verified ·
1 Parent(s): d2cecea

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -35
app.py CHANGED
@@ -1,47 +1,56 @@
1
  import gradio as gr
2
  from ultralytics import YOLO
3
- import spaces
 
 
 
4
  import torch
5
-
6
- # Load pre-trained YOLOv8 model
7
- model = YOLO("yolov8x-doclaynet-epoch64-imgsz640-initiallr1e-4-finallr1e-5.pt")
8
-
9
- # Get class names from model
10
- class_names = model.names
11
-
12
- @spaces.GPU(duration=60)
 
 
 
 
 
 
 
 
13
  def process_image(image):
14
- try:
15
- # Process the image
16
- results = model(source=image, save=False, show_labels=True, show_conf=True, show_boxes=True)
17
- result = results[0] # Get the first result
18
 
19
- # Extract annotated image and labels with class names
20
- annotated_image = result.plot()
 
21
 
22
- # Use cls attribute for labels and get class name from model, DO NOT use .item() on box.conf
23
- detected_areas_labels = "\n".join([
24
- f"{class_names[int(box.cls.item())].upper()}: {box.conf:.2f}" for box in result.boxes
25
- ])
26
 
27
- return annotated_image, detected_areas_labels
28
- except Exception as e:
29
- return None, f"Error processing image: {e}
30
 
31
-
32
- # Create the Gradio Interface
33
- with gr.Blocks() as demo:
34
- gr.Markdown("# Document Segmentation Demo (ZeroGPU)")
35
- # Input Components
36
- input_image = gr.Image(type="pil", label="Upload Image")
37
-
38
- # Output Components
39
  output_image = gr.Image(type="pil", label="Annotated Image")
40
  output_text = gr.Textbox(label="Detected Areas and Labels")
41
 
42
- # Button to trigger inference
43
- btn = gr.Button("Run Document Segmentation")
44
- btn.click(fn=process_image, inputs=input_image, outputs=[output_image, output_text])
 
 
 
 
45
 
46
- # Launch the demo
47
- demo.queue(max_size=1).launch() # Queue to handle concurrent requests
 
1
  import gradio as gr
2
  from ultralytics import YOLO
3
+ import cv2
4
+ import numpy as np
5
+ import os
6
+ import requests
7
  import torch
8
+ import spaces # Import spaces to use ZeroGPU functionality
9
+
10
+ # Ensure the model file is in the correct location
11
+ model_path = "yolov8x-doclaynet-epoch64-imgsz640-initiallr1e-4-finallr1e-5.pt"
12
+ if not os.path.exists(model_path):
13
+ # Download the model file if it doesn't exist
14
+ model_url = "https://huggingface.co/DILHTWD/documentlayoutsegmentation_YOLOv8_ondoclaynet/resolve/main/yolov8x-doclaynet-epoch64-imgsz640-initiallr1e-4-finallr1e-5.pt"
15
+ response = requests.get(model_url)
16
+ with open(model_path, "wb") as f:
17
+ f.write(response.content)
18
+
19
+ # Load the document segmentation model
20
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
21
+ docseg_model = YOLO(model_path).to(device)
22
+
23
+ @spaces.GPU
24
  def process_image(image):
25
+ # Convert image to the format YOLO model expects
26
+ image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
27
+ results = docseg_model(image)
 
28
 
29
+ # Extract annotated image from results
30
+ annotated_img = results[0].plot()
31
+ annotated_img = cv2.cvtColor(annotated_img, cv2.COLOR_BGR2RGB)
32
 
33
+ # Prepare detected areas and labels as text output
34
+ detected_areas_labels = "\n".join(
35
+ [f"{box.label}: {box.conf:.2f}" for box in results[0].boxes]
36
+ )
37
 
38
+ return annotated_img, detected_areas_labels
 
 
39
 
40
+ # Define the Gradio interface
41
+ with gr.Blocks() as interface:
42
+ gr.Markdown("### Document Segmentation using YOLOv8")
43
+ input_image = gr.Image(type="pil", label="Input Image")
 
 
 
 
44
  output_image = gr.Image(type="pil", label="Annotated Image")
45
  output_text = gr.Textbox(label="Detected Areas and Labels")
46
 
47
+ gr.Button("Run").click(
48
+ fn=process_image,
49
+ inputs=input_image,
50
+ outputs=[output_image, output_text]
51
+ )
52
+
53
+ interface.launch()
54
 
55
+ if __name__ == "__main__":
56
+ interface.launch()