yassonee commited on
Commit
edcb3d3
·
verified ·
1 Parent(s): 9f1f60f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +129 -118
app.py CHANGED
@@ -3,31 +3,33 @@ from transformers import pipeline
3
  from PIL import Image, ImageDraw
4
  import numpy as np
5
  import colorsys
6
- from streamlit.web.server.server import Server
7
- import streamlit.components.v1 as components
8
 
9
- # Add WebSocket headers configuration
10
- if Server.get_current():
11
- Server.get_current()._websocket_headers = {
12
- "Access-Control-Allow-Origin": "*",
13
- "Access-Control-Allow-Methods": "GET, POST, OPTIONS",
14
- "Access-Control-Allow-Headers": "Content-Type"
15
- }
16
 
 
17
  st.set_page_config(
18
  page_title="Fraktur Detektion",
19
  layout="wide",
20
  initial_sidebar_state="collapsed"
21
  )
22
 
23
- # Add Edge WebSocket compatibility
24
- components.html("""
25
- <script>
26
- if (window.WebSocket && navigator.userAgent.indexOf("Edge") > -1) {
27
- window.WebSocket = window.WebSocket || window.MozWebSocket;
28
- }
29
- </script>
30
- """, height=0)
 
 
 
 
 
 
31
 
32
  st.markdown("""
33
  <style>
@@ -115,10 +117,28 @@ st.markdown("""
115
  [data-testid="stExpander"], .element-container:has(>.stAlert) {
116
  display: none !important;
117
  }
 
 
 
 
 
118
  </style>
 
 
 
 
 
 
 
 
 
 
 
 
119
  """, unsafe_allow_html=True)
120
 
121
- @st.cache_resource
 
122
  def load_models():
123
  return {
124
  "KnochenAuge": pipeline("object-detection", model="D3STRON/bone-fracture-detr"),
@@ -145,21 +165,17 @@ def create_heatmap_overlay(image, box, score):
145
  x1, y1 = box['xmin'], box['ymin']
146
  x2, y2 = box['xmax'], box['ymax']
147
 
148
- # Couleur basée sur le score
149
  if score > 0.8:
150
- fill_color = (255, 0, 0, 100) # Rouge
151
  border_color = (255, 0, 0, 255)
152
  elif score > 0.6:
153
- fill_color = (255, 165, 0, 100) # Orange
154
  border_color = (255, 165, 0, 255)
155
  else:
156
- fill_color = (255, 255, 0, 100) # Jaune
157
  border_color = (255, 255, 0, 255)
158
 
159
- # Rectangle semi-transparent
160
  draw.rectangle([x1, y1, x2, y2], fill=fill_color)
161
-
162
- # Bordure
163
  draw.rectangle([x1, y1, x2, y2], outline=border_color, width=2)
164
 
165
  return overlay
@@ -171,20 +187,16 @@ def draw_boxes(image, predictions):
171
  box = pred['box']
172
  score = pred['score']
173
 
174
- # Création de l'overlay
175
  overlay = create_heatmap_overlay(image, box, score)
176
  result_image = Image.alpha_composite(result_image, overlay)
177
 
178
- # Ajout du texte
179
  draw = ImageDraw.Draw(result_image)
180
  temp = 36.5 + (score * 2.5)
181
  label = f"{translate_label(pred['label'])} ({score:.1%} • {temp:.1f}°C)"
182
 
183
- # Fond noir pour le texte
184
  text_bbox = draw.textbbox((box['xmin'], box['ymin']-20), label)
185
  draw.rectangle(text_bbox, fill=(0, 0, 0, 180))
186
 
187
- # Texte en blanc
188
  draw.text(
189
  (box['xmin'], box['ymin']-20),
190
  label,
@@ -194,100 +206,99 @@ def draw_boxes(image, predictions):
194
  return result_image
195
 
196
  def main():
197
- models = load_models()
198
-
199
- with st.container():
200
- st.write("### 📤 Röntgenbild hochladen")
201
- uploaded_file = st.file_uploader("Bild auswählen", type=['png', 'jpg', 'jpeg'], label_visibility="collapsed")
202
 
203
- col1, col2 = st.columns([2, 1])
204
- with col1:
205
- conf_threshold = st.slider(
206
- "Konfidenzschwelle",
207
- min_value=0.0, max_value=1.0,
208
- value=0.60, step=0.05,
209
- label_visibility="visible"
210
- )
211
- with col2:
212
- analyze_button = st.button("Analysieren")
213
-
214
- if uploaded_file and analyze_button:
215
- with st.spinner("Bild wird analysiert..."):
216
- image = Image.open(uploaded_file)
217
- results_container = st.container()
218
 
219
- predictions_watcher = models["KnochenWächter"](image)
220
- predictions_master = models["RöntgenMeister"](image)
221
- predictions_locator = models["KnochenAuge"](image)
222
-
223
- has_fracture = False
224
- max_fracture_score = 0
225
- filtered_locations = [p for p in predictions_locator
226
- if p['score'] >= conf_threshold]
227
-
228
- for pred in predictions_watcher:
229
- if pred['score'] >= conf_threshold and 'fracture' in pred['label'].lower():
230
- has_fracture = True
231
- max_fracture_score = max(max_fracture_score, pred['score'])
232
-
233
- with results_container:
234
- st.write("### 🔍 Analyse Ergebnisse")
235
- col1, col2 = st.columns(2)
236
 
237
- with col1:
238
- st.write("#### 🤖 KI-Diagnose")
239
-
240
- st.markdown("#### 🛡️ KnochenWächter")
241
- # Afficher tous les résultats de KnochenWächter
242
- for pred in predictions_watcher:
243
- confidence_color = '#0066cc' if pred['score'] > 0.7 else '#ffa500'
244
- label_lower = pred['label'].lower()
245
- # Mettre à jour max_fracture_score seulement pour les fractures
246
- if pred['score'] >= conf_threshold and 'fracture' in label_lower:
247
- has_fracture = True
248
- max_fracture_score = max(max_fracture_score, pred['score'])
249
- # Afficher tous les résultats
250
- st.markdown(f"""
251
- <div class="result-box" style="color: #1a1a1a;">
252
- <span style="color: {confidence_color}; font-weight: 500;">
253
- {pred['score']:.1%}
254
- </span> - {translate_label(pred['label'])}
255
- </div>
256
- """, unsafe_allow_html=True)
257
-
258
- st.markdown("#### 🎓 RöntgenMeister")
259
- # Afficher tous les résultats de RöntgenMeister
260
- for pred in predictions_master:
261
- confidence_color = '#0066cc' if pred['score'] > 0.7 else '#ffa500'
262
- st.markdown(f"""
263
- <div class="result-box" style="color: #1a1a1a;">
264
- <span style="color: {confidence_color}; font-weight: 500;">
265
- {pred['score']:.1%}
266
- </span> - {translate_label(pred['label'])}
267
- </div>
268
- """, unsafe_allow_html=True)
269
-
270
- if max_fracture_score > 0:
271
- st.write("#### 📊 Wahrscheinlichkeit")
272
- no_fracture_prob = 1 - max_fracture_score
273
- st.markdown(f"""
274
- <div class="result-box" style="color: #1a1a1a;">
275
- Knochenbruch: <strong style="color: #0066cc">{max_fracture_score:.1%}</strong><br>
276
- Kein Knochenbruch: <strong style="color: #ffa500">{no_fracture_prob:.1%}</strong>
277
- </div>
278
- """, unsafe_allow_html=True)
279
 
280
- with col2:
281
- predictions = models["KnochenAuge"](image)
282
- filtered_preds = [p for p in predictions if p['score'] >= conf_threshold]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
283
 
284
- if filtered_preds:
285
- st.write("#### 🎯 Fraktur Lokalisation")
286
- result_image = draw_boxes(image, filtered_preds)
287
- st.image(result_image, use_container_width=True)
288
- else:
289
- st.write("#### 🖼️ Röntgenbild")
290
- st.image(image, use_container_width=True)
 
 
 
 
 
 
291
 
292
  if __name__ == "__main__":
293
  main()
 
3
  from PIL import Image, ImageDraw
4
  import numpy as np
5
  import colorsys
6
+ import os
 
7
 
8
+ # Configuration des en-têtes pour WebSocket
9
+ os.environ['STREAMLIT_SERVER_WEBSOCKET_HEADERS'] = '{"Access-Control-Allow-Origin": "*"}'
10
+ os.environ['STREAMLIT_SERVER_ENABLE_CORS'] = 'true'
 
 
 
 
11
 
12
+ # Configuration de la page
13
  st.set_page_config(
14
  page_title="Fraktur Detektion",
15
  layout="wide",
16
  initial_sidebar_state="collapsed"
17
  )
18
 
19
+ # Script pour gérer les WebSockets sur Edge
20
+ st.markdown("""
21
+ <script>
22
+ if (window.WebSocket && navigator.userAgent.indexOf("Edge") > -1) {
23
+ const originalWebSocket = window.WebSocket;
24
+ window.WebSocket = function(url, protocols) {
25
+ if (url.includes('_stcore/stream')) {
26
+ url = url.replace('wss://', 'ws://');
27
+ }
28
+ return new originalWebSocket(url, protocols);
29
+ };
30
+ }
31
+ </script>
32
+ """, unsafe_allow_html=True)
33
 
34
  st.markdown("""
35
  <style>
 
117
  [data-testid="stExpander"], .element-container:has(>.stAlert) {
118
  display: none !important;
119
  }
120
+
121
+ /* Fix for WebSocket connection issues */
122
+ iframe {
123
+ visibility: hidden;
124
+ }
125
  </style>
126
+
127
+ <script>
128
+ // Fix for WebSocket connection
129
+ window.addEventListener('load', function() {
130
+ setTimeout(function() {
131
+ const frames = document.getElementsByTagName('iframe');
132
+ for (let frame of frames) {
133
+ frame.style.visibility = 'visible';
134
+ }
135
+ }, 1000);
136
+ });
137
+ </script>
138
  """, unsafe_allow_html=True)
139
 
140
+ # Configuration du cache pour les modèles
141
+ @st.cache_resource(show_spinner=False)
142
  def load_models():
143
  return {
144
  "KnochenAuge": pipeline("object-detection", model="D3STRON/bone-fracture-detr"),
 
165
  x1, y1 = box['xmin'], box['ymin']
166
  x2, y2 = box['xmax'], box['ymax']
167
 
 
168
  if score > 0.8:
169
+ fill_color = (255, 0, 0, 100)
170
  border_color = (255, 0, 0, 255)
171
  elif score > 0.6:
172
+ fill_color = (255, 165, 0, 100)
173
  border_color = (255, 165, 0, 255)
174
  else:
175
+ fill_color = (255, 255, 0, 100)
176
  border_color = (255, 255, 0, 255)
177
 
 
178
  draw.rectangle([x1, y1, x2, y2], fill=fill_color)
 
 
179
  draw.rectangle([x1, y1, x2, y2], outline=border_color, width=2)
180
 
181
  return overlay
 
187
  box = pred['box']
188
  score = pred['score']
189
 
 
190
  overlay = create_heatmap_overlay(image, box, score)
191
  result_image = Image.alpha_composite(result_image, overlay)
192
 
 
193
  draw = ImageDraw.Draw(result_image)
194
  temp = 36.5 + (score * 2.5)
195
  label = f"{translate_label(pred['label'])} ({score:.1%} • {temp:.1f}°C)"
196
 
 
197
  text_bbox = draw.textbbox((box['xmin'], box['ymin']-20), label)
198
  draw.rectangle(text_bbox, fill=(0, 0, 0, 180))
199
 
 
200
  draw.text(
201
  (box['xmin'], box['ymin']-20),
202
  label,
 
206
  return result_image
207
 
208
  def main():
209
+ try:
210
+ models = load_models()
 
 
 
211
 
212
+ with st.container():
213
+ st.write("### 📤 Röntgenbild hochladen")
214
+ uploaded_file = st.file_uploader("Bild auswählen", type=['png', 'jpg', 'jpeg'], label_visibility="collapsed")
 
 
 
 
 
 
 
 
 
 
 
 
215
 
216
+ col1, col2 = st.columns([2, 1])
217
+ with col1:
218
+ conf_threshold = st.slider(
219
+ "Konfidenzschwelle",
220
+ min_value=0.0, max_value=1.0,
221
+ value=0.60, step=0.05,
222
+ label_visibility="visible"
223
+ )
224
+ with col2:
225
+ analyze_button = st.button("Analysieren")
226
+
227
+ if uploaded_file and analyze_button:
228
+ with st.spinner("Bild wird analysiert..."):
229
+ image = Image.open(uploaded_file)
230
+ results_container = st.container()
 
 
231
 
232
+ predictions_watcher = models["KnochenWächter"](image)
233
+ predictions_master = models["RöntgenMeister"](image)
234
+ predictions_locator = models["KnochenAuge"](image)
235
+
236
+ has_fracture = False
237
+ max_fracture_score = 0
238
+ filtered_locations = [p for p in predictions_locator
239
+ if p['score'] >= conf_threshold]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
240
 
241
+ for pred in predictions_watcher:
242
+ if pred['score'] >= conf_threshold and 'fracture' in pred['label'].lower():
243
+ has_fracture = True
244
+ max_fracture_score = max(max_fracture_score, pred['score'])
245
+
246
+ with results_container:
247
+ st.write("### 🔍 Analyse Ergebnisse")
248
+ col1, col2 = st.columns(2)
249
+
250
+ with col1:
251
+ st.write("#### 🤖 KI-Diagnose")
252
+
253
+ st.markdown("#### 🛡️ KnochenWächter")
254
+ for pred in predictions_watcher:
255
+ confidence_color = '#0066cc' if pred['score'] > 0.7 else '#ffa500'
256
+ label_lower = pred['label'].lower()
257
+ if pred['score'] >= conf_threshold and 'fracture' in label_lower:
258
+ has_fracture = True
259
+ max_fracture_score = max(max_fracture_score, pred['score'])
260
+ st.markdown(f"""
261
+ <div class="result-box" style="color: #1a1a1a;">
262
+ <span style="color: {confidence_color}; font-weight: 500;">
263
+ {pred['score']:.1%}
264
+ </span> - {translate_label(pred['label'])}
265
+ </div>
266
+ """, unsafe_allow_html=True)
267
+
268
+ st.markdown("#### 🎓 RöntgenMeister")
269
+ for pred in predictions_master:
270
+ confidence_color = '#0066cc' if pred['score'] > 0.7 else '#ffa500'
271
+ st.markdown(f"""
272
+ <div class="result-box" style="color: #1a1a1a;">
273
+ <span style="color: {confidence_color}; font-weight: 500;">
274
+ {pred['score']:.1%}
275
+ </span> - {translate_label(pred['label'])}
276
+ </div>
277
+ """, unsafe_allow_html=True)
278
+
279
+ if max_fracture_score > 0:
280
+ st.write("#### 📊 Wahrscheinlichkeit")
281
+ no_fracture_prob = 1 - max_fracture_score
282
+ st.markdown(f"""
283
+ <div class="result-box" style="color: #1a1a1a;">
284
+ Knochenbruch: <strong style="color: #0066cc">{max_fracture_score:.1%}</strong><br>
285
+ Kein Knochenbruch: <strong style="color: #ffa500">{no_fracture_prob:.1%}</strong>
286
+ </div>
287
+ """, unsafe_allow_html=True)
288
 
289
+ with col2:
290
+ predictions = models["KnochenAuge"](image)
291
+ filtered_preds = [p for p in predictions if p['score'] >= conf_threshold]
292
+
293
+ if filtered_preds:
294
+ st.write("#### 🎯 Fraktur Lokalisation")
295
+ result_image = draw_boxes(image, filtered_preds)
296
+ st.image(result_image, use_container_width=True)
297
+ else:
298
+ st.write("#### 🖼️ Röntgenbild")
299
+ st.image(image, use_container_width=True)
300
+ except Exception as e:
301
+ st.error(f"Ein Fehler ist aufgetreten: {str(e)}")
302
 
303
  if __name__ == "__main__":
304
  main()