BhumikaMak commited on
Commit
4d51c47
·
verified ·
1 Parent(s): 9f0de77

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +117 -34
app.py CHANGED
@@ -1,13 +1,17 @@
 
1
  import threading
2
  import gradio as gr
 
3
  from PIL import Image
4
- import numpy as np
5
  import cv2
 
 
 
6
 
7
  # Sample images directory
8
  sample_images = {
9
- "Sample 1": "data/xai/sample1.jpeg",
10
- "Sample 2": "data/xai/sample2.jpg",
11
  }
12
 
13
  def load_sample_image(sample_name):
@@ -17,38 +21,51 @@ def load_sample_image(sample_name):
17
  return Image.open(image_path)
18
  return None
19
 
20
- def process_image(sample_choice, uploaded_image, yolo_versions):
21
  """Process the image using selected YOLO models."""
 
22
  if uploaded_image is not None:
23
  image = uploaded_image
24
  else:
25
  image = load_sample_image(sample_choice)
26
 
27
- # Resize image for YOLO model input
28
  image = np.array(image)
29
  image = cv2.resize(image, (640, 640))
30
  result_images = []
31
-
32
  # Apply selected models
33
- for model in yolo_versions:
34
- if model == "yolov5":
35
- result_images.append("YOLOv5 result") # Placeholder for YOLOv5
36
- elif model == "yolov8":
37
- result_images.append("YOLOv8 result") # Placeholder for YOLOv8
 
 
38
  return result_images
39
 
40
- def run_processing(sample_choice, uploaded_image, selected_models):
41
- results = process_image(sample_choice, uploaded_image, selected_models)
42
- return results
43
-
44
- # CSS for a clean, user-centered interface
 
 
 
 
 
 
 
 
 
 
 
45
  custom_css = """
46
  .custom-row {
47
  display: flex;
48
  justify-content: center;
49
  padding: 20px;
50
  }
51
-
52
  .custom-button {
53
  background-color: #6a1b9a;
54
  color: white;
@@ -59,58 +76,124 @@ custom_css = """
59
  cursor: pointer;
60
  margin-top: 10px;
61
  }
62
-
63
  .custom-row img {
64
  border-radius: 10px;
65
  box-shadow: 0 2px 10px rgba(0, 0, 0, 0.1);
66
  }
67
-
68
  #highlighted-text {
69
  font-weight: bold;
70
  color: #1976d2;
71
  }
72
  """
73
 
74
- # Gradio interface with enhanced UX design
 
75
  with gr.Blocks(css=custom_css) as interface:
76
  gr.Markdown("""
77
- ## Welcome to NeuralVista
78
- <p id="highlighted-text">NeuralVista</p> is a powerful tool designed to help you visualize object detection models in action.
79
  """)
 
 
 
80
 
81
  with gr.Row():
 
82
  with gr.Column():
83
  sample_selection = gr.Radio(
84
  choices=list(sample_images.keys()),
85
- label="Select a Sample Image"
 
86
  )
87
 
88
  upload_image = gr.Image(
89
  label="Upload an Image",
90
- type="pil"
91
  )
92
 
93
  selected_models = gr.CheckboxGroup(
94
- choices=["yolov5", "yolov8"],
95
- label="Select Model(s)"
 
96
  )
97
 
98
- run_button = gr.Button("Run", elem_id="run_button", elem_classes="custom-button")
99
 
100
  with gr.Column():
101
- sample_display = gr.Image(label="Sample Image")
 
 
 
102
 
103
  # Results and visualization
104
  with gr.Row(elem_classes="custom-row"):
105
- result_display = gr.Gallery(
106
  label="Results",
107
- columns=2
108
- )
 
 
 
 
 
 
 
 
 
 
 
109
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
  run_button.click(
111
- fn=run_processing,
112
  inputs=[sample_selection, upload_image, selected_models],
113
- outputs=[result_display]
114
  )
115
 
116
- interface.launch(share=True)
 
 
 
1
+ import netron
2
  import threading
3
  import gradio as gr
4
+ import os
5
  from PIL import Image
 
6
  import cv2
7
+ import numpy as np
8
+ from yolov5 import xai_yolov5
9
+ from yolov8 import xai_yolov8s
10
 
11
  # Sample images directory
12
  sample_images = {
13
+ "Sample 1": os.path.join(os.getcwd(), "data/xai/sample1.jpeg"),
14
+ "Sample 2": os.path.join(os.getcwd(), "data/xai/sample2.jpg"),
15
  }
16
 
17
  def load_sample_image(sample_name):
 
21
  return Image.open(image_path)
22
  return None
23
 
24
+ def process_image(sample_choice, uploaded_image, yolo_versions, target_lyr = -5, n_components = 8):
25
  """Process the image using selected YOLO models."""
26
+ # Load sample or uploaded image
27
  if uploaded_image is not None:
28
  image = uploaded_image
29
  else:
30
  image = load_sample_image(sample_choice)
31
 
32
+ # Preprocess image
33
  image = np.array(image)
34
  image = cv2.resize(image, (640, 640))
35
  result_images = []
36
+
37
  # Apply selected models
38
+ for yolo_version in yolo_versions:
39
+ if yolo_version == "yolov5":
40
+ result_images.append(xai_yolov5(image, target_lyr = -5, n_components = 8))
41
+ elif yolo_version == "yolov8s":
42
+ result_images.append(xai_yolov8s(image))
43
+ else:
44
+ result_images.append((Image.fromarray(image), f"{yolo_version} not implemented."))
45
  return result_images
46
 
47
+ def view_model(selected_models):
48
+ """Generate Netron visualization for the selected models."""
49
+ netron_html = ""
50
+ for model in selected_models:
51
+ if model == "yolov5":
52
+ netron_html = f"""
53
+ <iframe
54
+ src="https://netron.app/?url=https://huggingface.co/FFusion/FFusionXL-BASE/blob/main/vae_encoder/model.onnx"
55
+ width="100%"
56
+ height="800"
57
+ frameborder="0">
58
+ </iframe>
59
+ """
60
+ return netron_html if netron_html else "<p>No valid models selected for visualization.</p>"
61
+
62
+ # CSS to style the Gradio components and HTML content
63
  custom_css = """
64
  .custom-row {
65
  display: flex;
66
  justify-content: center;
67
  padding: 20px;
68
  }
 
69
  .custom-button {
70
  background-color: #6a1b9a;
71
  color: white;
 
76
  cursor: pointer;
77
  margin-top: 10px;
78
  }
 
79
  .custom-row img {
80
  border-radius: 10px;
81
  box-shadow: 0 2px 10px rgba(0, 0, 0, 0.1);
82
  }
 
83
  #highlighted-text {
84
  font-weight: bold;
85
  color: #1976d2;
86
  }
87
  """
88
 
89
+ # Then in the Gradio interface:
90
+
91
  with gr.Blocks(css=custom_css) as interface:
92
  gr.Markdown("""
93
+ ## Welcome to NeuralVista!
94
+ <p id="highlighted-text">NeuralVista</p> is a powerful tool designed to help you visualize models in action.
95
  """)
96
+
97
+ # Default sample
98
+ default_sample = "Sample 1"
99
 
100
  with gr.Row():
101
+ # Left side: Sample selection and image upload
102
  with gr.Column():
103
  sample_selection = gr.Radio(
104
  choices=list(sample_images.keys()),
105
+ label="Select a Sample Image",
106
+ value=default_sample,
107
  )
108
 
109
  upload_image = gr.Image(
110
  label="Upload an Image",
111
+ type="pil",
112
  )
113
 
114
  selected_models = gr.CheckboxGroup(
115
+ choices=["yolov5", "yolov8s"],
116
+ value=["yolov5"],
117
+ label="Select Model(s)",
118
  )
119
 
120
+ run_button = gr.Button("Run", elem_classes="custom-button")
121
 
122
  with gr.Column():
123
+ sample_display = gr.Image(
124
+ value=load_sample_image(default_sample),
125
+ label="Selected Sample Image",
126
+ )
127
 
128
  # Results and visualization
129
  with gr.Row(elem_classes="custom-row"):
130
+ result_gallery = gr.Gallery(
131
  label="Results",
132
+ rows=1,
133
+ height="auto", # Adjust height automatically based on content
134
+ columns=1 ,
135
+ object_fit="contain"
136
+ )
137
+ netron_display = gr.HTML(label="Netron Visualization")
138
+
139
+ # Update sample image
140
+ sample_selection.change(
141
+ fn=load_sample_image,
142
+ inputs=sample_selection,
143
+ outputs=sample_display,
144
+ )
145
 
146
+ with gr.Row(elem_classes="custom-row"):
147
+ dff_gallery = gr.Gallery(
148
+ label="Deep Feature Factorization",
149
+ rows=2, # 8 rows
150
+ columns=4, # 1 image per row
151
+ object_fit="fit",
152
+ height="auto" # Adjust as needed
153
+ )
154
+
155
+ # Multi-threaded processing
156
+ def run_both(sample_choice, uploaded_image, selected_models):
157
+ results = []
158
+ netron_html = ""
159
+
160
+ # Thread to process the image
161
+ def process_thread():
162
+ nonlocal results
163
+ target_lyr = -5
164
+ n_components = 8
165
+ results = process_image(sample_choice, uploaded_image, selected_models, target_lyr = -5, n_components = 8)
166
+
167
+ # Thread to generate Netron visualization
168
+ def netron_thread():
169
+ nonlocal netron_html
170
+ netron_html = view_model(selected_models)
171
+
172
+ # Launch threads
173
+ t1 = threading.Thread(target=process_thread)
174
+ t2 = threading.Thread(target=netron_thread)
175
+ t1.start()
176
+ t2.start()
177
+ t1.join()
178
+ t2.join()
179
+ image1, text, image2 = results[0]
180
+ if isinstance(image2, list):
181
+ # Check if image2 contains exactly 8 images
182
+ if len(image2) == 8:
183
+ print("image2 contains 8 images.")
184
+ else:
185
+ print("Warning: image2 does not contain exactly 8 images.")
186
+ else:
187
+ print("Error: image2 is not a list of images.")
188
+ return [(image1, text)], netron_html, image2
189
+
190
+ # Run button click
191
  run_button.click(
192
+ fn=run_both,
193
  inputs=[sample_selection, upload_image, selected_models],
194
+ outputs=[result_gallery, netron_display, dff_gallery],
195
  )
196
 
197
+ # Launch Gradio interface
198
+ if __name__ == "__main__":
199
+ interface.launch(share=True)