BhumikaMak commited on
Commit
9d244f3
·
verified ·
1 Parent(s): d00769c

update: netron embedding

Browse files
Files changed (1) hide show
  1. app.py +24 -39
app.py CHANGED
@@ -1,48 +1,38 @@
1
  import gradio as gr
2
- import netron
3
  import os
4
- import threading
5
  import time
6
- from PIL import Image
7
  import cv2
8
  import numpy as np
9
- import torch
10
- import base64
11
  from yolov5 import xai_yolov5
12
  from yolov8 import xai_yolov8s
13
 
14
- # Sample images directory
15
  sample_images = {
16
  "Sample 1": os.path.join(os.getcwd(), "data/xai/sample1.jpeg"),
17
  "Sample 2": os.path.join(os.getcwd(), "data/xai/sample2.jpg"),
18
  }
 
 
19
 
20
- # Preloaded model file path (update this path as needed)
21
- preloaded_model_file = os.path.join(os.getcwd(), "weight_files/yolov5.onnx") # Example path
22
-
23
  def load_sample_image(sample_name):
24
- """Load a sample image based on user selection."""
25
  image_path = sample_images.get(sample_name)
26
  if image_path and os.path.exists(image_path):
27
  return Image.open(image_path)
28
  return None
29
 
 
30
  def process_image(sample_choice, uploaded_image, yolo_versions):
31
- """Process the image using selected YOLO models."""
32
  if uploaded_image is not None:
33
- image = uploaded_image # Use the uploaded image
34
  else:
35
- image = load_sample_image(sample_choice) # Use selected sample image
36
 
37
  image = np.array(image)
38
  image = cv2.resize(image, (640, 640))
39
  result_images = []
40
 
41
- # Encode image to base64
42
- _, buffer = cv2.imencode('.jpg', image)
43
- image_base64 = base64.b64encode(buffer).decode('utf-8')
44
-
45
- # Process image with each selected YOLO version
46
  for yolo_version in yolo_versions:
47
  if yolo_version == "yolov5":
48
  result_images.append(xai_yolov5(image))
@@ -53,19 +43,14 @@ def process_image(sample_choice, uploaded_image, yolo_versions):
53
 
54
  return result_images
55
 
56
- def serve_netron(model_file):
57
- """Start the Netron server in a separate thread."""
58
- threading.Thread(target=netron.start, args=(model_file,), daemon=True).start()
59
- time.sleep(1) # Give some time for the server to start
60
- return "http://localhost:8080" # Default Netron URL
61
-
62
  def view_model():
63
- """Handle model visualization using preloaded model file."""
64
- if not os.path.exists(preloaded_model_file):
65
- return "Model file not found."
66
-
67
- netron_url = serve_netron(preloaded_model_file)
68
- return f'<iframe src="{netron_url}" width="100%" height="600px"></iframe>'
69
 
70
  # Custom CSS for styling (optional)
71
  custom_css = """
@@ -78,13 +63,13 @@ custom_css = """
78
  }
79
  """
80
 
 
81
  with gr.Blocks(css=custom_css) as interface:
82
- gr.Markdown("# XAI: Visualize Object Detection of Your Models")
83
-
84
  default_sample = "Sample 1"
85
 
86
  with gr.Row():
87
- # Left side: Sample selection and upload image
88
  with gr.Column():
89
  sample_selection = gr.Radio(
90
  choices=list(sample_images.keys()),
@@ -112,7 +97,7 @@ with gr.Blocks(css=custom_css) as interface:
112
  label="Selected Sample Image",
113
  )
114
 
115
- # Below the sample image, display results and architecture side by side
116
  with gr.Row():
117
  result_gallery = gr.Gallery(
118
  label="Results",
@@ -120,9 +105,9 @@ with gr.Blocks(css=custom_css) as interface:
120
  rows=1,
121
  height=500,
122
  )
 
123
 
124
- netron_display = gr.HTML(label="Netron Visualization")
125
-
126
  sample_selection.change(
127
  fn=load_sample_image,
128
  inputs=sample_selection,
@@ -135,9 +120,9 @@ with gr.Blocks(css=custom_css) as interface:
135
  outputs=[result_gallery],
136
  )
137
 
138
- # Update Netron display when the interface loads
139
- netron_display.value = view_model() # Directly set the value
140
 
141
- # Launching Gradio app and handling Netron visualization separately.
142
  if __name__ == "__main__":
143
  interface.launch(share=True)
 
1
  import gradio as gr
 
2
  import os
 
3
  import time
 
4
  import cv2
5
  import numpy as np
6
+ from PIL import Image
 
7
  from yolov5 import xai_yolov5
8
  from yolov8 import xai_yolov8s
9
 
10
+ # Paths
11
  sample_images = {
12
  "Sample 1": os.path.join(os.getcwd(), "data/xai/sample1.jpeg"),
13
  "Sample 2": os.path.join(os.getcwd(), "data/xai/sample2.jpg"),
14
  }
15
+ preloaded_model_file = os.path.join(os.getcwd(), "weight_files/yolov5.onnx") # Update as needed
16
+ netron_html_file = os.path.join(os.getcwd(), "model_visualization.html") # Netron exported file
17
 
18
+ # Load sample images
 
 
19
  def load_sample_image(sample_name):
 
20
  image_path = sample_images.get(sample_name)
21
  if image_path and os.path.exists(image_path):
22
  return Image.open(image_path)
23
  return None
24
 
25
+ # Process image with YOLO models
26
  def process_image(sample_choice, uploaded_image, yolo_versions):
 
27
  if uploaded_image is not None:
28
+ image = uploaded_image
29
  else:
30
+ image = load_sample_image(sample_choice)
31
 
32
  image = np.array(image)
33
  image = cv2.resize(image, (640, 640))
34
  result_images = []
35
 
 
 
 
 
 
36
  for yolo_version in yolo_versions:
37
  if yolo_version == "yolov5":
38
  result_images.append(xai_yolov5(image))
 
43
 
44
  return result_images
45
 
46
+ # Embed Netron visualization (Static HTML Export)
 
 
 
 
 
47
  def view_model():
48
+ if os.path.exists(netron_html_file):
49
+ with open(netron_html_file, "r") as f:
50
+ html_content = f.read()
51
+ return html_content
52
+ else:
53
+ return "<p style='color: red;'>Netron export not found. Generate HTML export and place it in the correct path.</p>"
54
 
55
  # Custom CSS for styling (optional)
56
  custom_css = """
 
63
  }
64
  """
65
 
66
+ # Gradio Interface
67
  with gr.Blocks(css=custom_css) as interface:
68
+ gr.Markdown("# XAI: Visualize Object Detection with YOLO Models")
69
+
70
  default_sample = "Sample 1"
71
 
72
  with gr.Row():
 
73
  with gr.Column():
74
  sample_selection = gr.Radio(
75
  choices=list(sample_images.keys()),
 
97
  label="Selected Sample Image",
98
  )
99
 
100
+ # Results and Netron Visualization
101
  with gr.Row():
102
  result_gallery = gr.Gallery(
103
  label="Results",
 
105
  rows=1,
106
  height=500,
107
  )
108
+ netron_display = gr.HTML(label="Model Architecture")
109
 
110
+ # Callbacks
 
111
  sample_selection.change(
112
  fn=load_sample_image,
113
  inputs=sample_selection,
 
120
  outputs=[result_gallery],
121
  )
122
 
123
+ # Load Netron HTML into the iframe
124
+ netron_display.value = view_model()
125
 
126
+ # Launch Gradio
127
  if __name__ == "__main__":
128
  interface.launch(share=True)