yassonee commited on
Commit
e59f527
·
verified ·
1 Parent(s): 1d4ce47

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +122 -109
app.py CHANGED
@@ -2,8 +2,6 @@ import streamlit as st
2
  from transformers import pipeline
3
  from PIL import Image, ImageDraw
4
  import torch
5
- from typing import List, Dict
6
- import time
7
 
8
  st.set_page_config(
9
  page_title="Fraktur Detektion",
@@ -11,90 +9,104 @@ st.set_page_config(
11
  initial_sidebar_state="collapsed"
12
  )
13
 
14
- # CSS avec animations
15
  st.markdown("""
16
  <style>
 
17
  .stApp {
18
- background-color: #f8fafc !important;
19
- padding: 0 !important;
20
  }
21
 
22
  .block-container {
23
- padding: 0.5rem !important;
24
- max-width: 100% !important;
25
- }
26
-
27
- .upload-section {
28
- display: flex;
29
- flex-direction: column;
30
- align-items: center;
31
- justify-content: center;
32
- min-height: 50vh;
33
- animation: fadeIn 0.5s ease-in;
34
  }
35
 
36
- .results-section {
37
- animation: slideUp 0.5s ease-out;
 
 
 
 
 
 
38
  }
39
 
40
- .detection-box {
41
  background: white;
42
- border-radius: 8px;
43
- padding: 1rem;
44
- box-shadow: 0 1px 3px rgba(0,0,0,0.1);
45
  margin-bottom: 1rem;
46
- transform-origin: top;
47
- animation: scaleIn 0.3s ease-out;
48
  }
49
 
50
- .result-item {
51
- padding: 0.5rem;
52
- border-radius: 4px;
53
- margin: 0.25rem 0;
54
- background: #f1f5f9;
55
- animation: fadeIn 0.3s ease-out;
56
  }
57
 
58
- .image-grid {
59
- display: grid;
60
- grid-template-columns: repeat(auto-fill, minmax(200px, 1fr));
61
- gap: 1rem;
62
- margin-top: 1rem;
63
  }
64
 
65
- .image-container {
 
66
  background: white;
67
- border-radius: 8px;
68
  padding: 0.5rem;
 
69
  box-shadow: 0 1px 3px rgba(0,0,0,0.1);
70
- animation: scaleIn 0.3s ease-out;
71
  }
72
 
73
- @keyframes fadeIn {
74
- from { opacity: 0; }
75
- to { opacity: 1; }
 
 
76
  }
77
 
78
- @keyframes slideUp {
79
- from { transform: translateY(20px); opacity: 0; }
80
- to { transform: translateY(0); opacity: 1; }
 
 
 
 
 
 
 
81
  }
82
 
83
- @keyframes scaleIn {
84
- from { transform: scale(0.95); opacity: 0; }
85
- to { transform: scale(1); opacity: 1; }
86
  }
87
 
88
- /* Compact image style */
89
- .stImage > img {
90
- max-height: 300px !important;
91
- width: auto !important;
92
- margin: 0 auto;
93
- object-fit: contain;
94
  }
95
 
96
- #MainMenu, footer, header {
97
- display: none !important;
 
 
 
 
 
 
 
 
 
 
 
 
98
  }
99
  </style>
100
  """, unsafe_allow_html=True)
@@ -117,12 +129,12 @@ def translate_label(label):
117
  }
118
  return translations.get(label.lower(), label)
119
 
120
- def draw_boxes(image: Image, predictions: List[Dict]) -> Image:
121
  draw = ImageDraw.Draw(image)
122
  for pred in predictions:
123
  box = pred['box']
124
  label = f"{translate_label(pred['label'])} ({pred['score']:.2%})"
125
- color = "#2563eb" if pred['score'] > 0.7 else "#eab308"
126
 
127
  draw.rectangle(
128
  [(box['xmin'], box['ymin']), (box['xmax'], box['ymax'])],
@@ -138,83 +150,84 @@ def draw_boxes(image: Image, predictions: List[Dict]) -> Image:
138
  def main():
139
  models = load_models()
140
 
141
- if 'analyzed_images' not in st.session_state:
142
- st.session_state.analyzed_images = []
143
-
144
- # Section upload centrée
145
- st.markdown('<div class="upload-section">', unsafe_allow_html=True)
146
  st.markdown("### 📤 Röntgenbild Upload")
147
  uploaded_files = st.file_uploader("", type=['png', 'jpg', 'jpeg'], accept_multiple_files=True)
 
148
  conf_threshold = st.slider(
149
  "Konfidenzschwelle",
150
  min_value=0.0, max_value=1.0,
151
- value=0.60, step=0.05
 
152
  )
153
- analyze_button = st.button("Analysieren")
154
- st.markdown('</div>', unsafe_allow_html=True)
155
 
 
 
 
 
156
  if analyze_button and uploaded_files:
157
- st.markdown('<div class="results-section">', unsafe_allow_html=True)
158
-
159
- for uploaded_file in uploaded_files:
160
- image = Image.open(uploaded_file)
161
 
162
- # Animation de chargement
163
- with st.spinner("Analyse läuft..."):
164
- time.sleep(0.5) # Animation effect
165
 
166
- col1, col2, col3 = st.columns([1, 1, 1])
 
167
 
 
168
  with col1:
169
- st.markdown("### 📋 Bild Details")
170
  st.image(image, use_column_width=True)
171
 
 
172
  with col2:
173
- st.markdown("### 🎯 KI-Analyse")
174
 
175
- # KnochenWächter
176
- with st.container():
177
- st.markdown("#### 🛡️ KnochenWächter")
178
- predictions = models["KnochenWächter"](image)
179
- for pred in predictions:
180
- if pred['score'] >= conf_threshold:
181
- st.markdown(f"""
182
- <div class="result-item">
183
- <span style='color: {"#22c55e" if pred["score"] > 0.7 else "#eab308"}; font-weight: 500;'>
184
- {pred['score']:.1%}
185
- </span> - {translate_label(pred['label'])}
186
- </div>
187
- """, unsafe_allow_html=True)
188
 
189
- # RöntgenMeister
190
- with st.container():
191
- st.markdown("#### 🎓 RöntgenMeister")
192
- predictions = models["RöntgenMeister"](image)
193
- for pred in predictions:
194
- if pred['score'] >= conf_threshold:
195
- st.markdown(f"""
196
- <div class="result-item">
197
- <span style='color: {"#22c55e" if pred["score"] > 0.7 else "#eab308"}; font-weight: 500;'>
198
- {pred['score']:.1%}
199
- </span> - {translate_label(pred['label'])}
200
- </div>
201
- """, unsafe_allow_html=True)
202
 
203
- # Afficher la localisation uniquement si une fracture est détectée
204
  with col3:
205
- predictions_location = models["KnochenAuge"](image)
206
- fractures_detected = any(p['score'] >= conf_threshold and 'fracture' in p['label'].lower()
207
- for p in predictions_location)
 
 
208
 
209
- if fractures_detected:
210
  st.markdown("### 🔍 Fraktur Lokalisation")
211
- filtered_preds = [p for p in predictions_location if p['score'] >= conf_threshold]
212
  if filtered_preds:
213
  result_image = image.copy()
214
  result_image = draw_boxes(result_image, filtered_preds)
215
  st.image(result_image, use_column_width=True)
216
-
217
- st.markdown('</div>', unsafe_allow_html=True)
218
 
219
  if __name__ == "__main__":
220
  main()
 
2
  from transformers import pipeline
3
  from PIL import Image, ImageDraw
4
  import torch
 
 
5
 
6
  st.set_page_config(
7
  page_title="Fraktur Detektion",
 
9
  initial_sidebar_state="collapsed"
10
  )
11
 
 
12
  st.markdown("""
13
  <style>
14
+ /* Base styles */
15
  .stApp {
16
+ background: #f0f2f5 !important;
 
17
  }
18
 
19
  .block-container {
20
+ padding: 1rem !important;
21
+ max-width: 1400px !important;
22
+ margin: 0 auto !important;
 
 
 
 
 
 
 
 
23
  }
24
 
25
+ /* Custom containers */
26
+ .center-upload {
27
+ background: white;
28
+ padding: 2rem;
29
+ border-radius: 10px;
30
+ box-shadow: 0 2px 4px rgba(0,0,0,0.1);
31
+ margin-bottom: 2rem;
32
+ text-align: center;
33
  }
34
 
35
+ .analysis-container {
36
  background: white;
37
+ padding: 1.5rem;
38
+ border-radius: 10px;
39
+ box-shadow: 0 2px 4px rgba(0,0,0,0.1);
40
  margin-bottom: 1rem;
41
+ animation: slideIn 0.5s ease-out;
 
42
  }
43
 
44
+ .result-box {
45
+ background: #f8f9fa;
46
+ padding: 0.75rem;
47
+ border-radius: 8px;
48
+ margin: 0.5rem 0;
49
+ border: 1px solid #e9ecef;
50
  }
51
 
52
+ /* Text styles */
53
+ h1, h2, h3, h4, p {
54
+ color: #1a1a1a !important;
55
+ margin: 0.5rem 0 !important;
 
56
  }
57
 
58
+ /* Image styles */
59
+ .stImage {
60
  background: white;
 
61
  padding: 0.5rem;
62
+ border-radius: 8px;
63
  box-shadow: 0 1px 3px rgba(0,0,0,0.1);
 
64
  }
65
 
66
+ .stImage > img {
67
+ max-height: 250px !important;
68
+ width: auto !important;
69
+ margin: 0 auto !important;
70
+ display: block !important;
71
  }
72
 
73
+ /* Animations */
74
+ @keyframes slideIn {
75
+ from {
76
+ opacity: 0;
77
+ transform: translateY(-10px);
78
+ }
79
+ to {
80
+ opacity: 1;
81
+ transform: translateY(0);
82
+ }
83
  }
84
 
85
+ /* Hide unnecessary elements */
86
+ #MainMenu, footer {
87
+ display: none !important;
88
  }
89
 
90
+ /* Custom columns spacing */
91
+ [data-testid="column"] {
92
+ padding: 0.5rem !important;
93
+ background: transparent !important;
 
 
94
  }
95
 
96
+ /* Button styling */
97
+ .stButton > button {
98
+ width: 200px;
99
+ background-color: #0066cc !important;
100
+ color: white !important;
101
+ border: none !important;
102
+ padding: 0.5rem 1rem !important;
103
+ border-radius: 5px !important;
104
+ transition: all 0.3s ease !important;
105
+ }
106
+
107
+ .stButton > button:hover {
108
+ background-color: #0052a3 !important;
109
+ transform: translateY(-1px);
110
  }
111
  </style>
112
  """, unsafe_allow_html=True)
 
129
  }
130
  return translations.get(label.lower(), label)
131
 
132
+ def draw_boxes(image, predictions):
133
  draw = ImageDraw.Draw(image)
134
  for pred in predictions:
135
  box = pred['box']
136
  label = f"{translate_label(pred['label'])} ({pred['score']:.2%})"
137
+ color = "#0066cc" if pred['score'] > 0.7 else "#ffa500"
138
 
139
  draw.rectangle(
140
  [(box['xmin'], box['ymin']), (box['xmax'], box['ymax'])],
 
150
  def main():
151
  models = load_models()
152
 
153
+ # Initial upload section
154
+ st.markdown('<div class="center-upload">', unsafe_allow_html=True)
 
 
 
155
  st.markdown("### 📤 Röntgenbild Upload")
156
  uploaded_files = st.file_uploader("", type=['png', 'jpg', 'jpeg'], accept_multiple_files=True)
157
+
158
  conf_threshold = st.slider(
159
  "Konfidenzschwelle",
160
  min_value=0.0, max_value=1.0,
161
+ value=0.60, step=0.05,
162
+ key='confidence'
163
  )
 
 
164
 
165
+ analyze_button = st.button("Analysieren", key='analyze')
166
+ st.markdown('</div>', unsafe_allow_html=True)
167
+
168
+ # Analysis section
169
  if analyze_button and uploaded_files:
170
+ for idx, uploaded_file in enumerate(uploaded_files):
171
+ st.markdown(f'<div class="analysis-container">', unsafe_allow_html=True)
 
 
172
 
173
+ with st.spinner("Analysiere Bild..."):
174
+ image = Image.open(uploaded_file)
 
175
 
176
+ # Create three columns
177
+ col1, col2, col3 = st.columns(3)
178
 
179
+ # Column 1: Original Image
180
  with col1:
181
+ st.markdown("### 🖼️ Original")
182
  st.image(image, use_column_width=True)
183
 
184
+ # Column 2: AI Analysis
185
  with col2:
186
+ st.markdown("### 🤖 KI-Analyse")
187
 
188
+ # KnochenWächter results
189
+ predictions = models["KnochenWächter"](image)
190
+ st.markdown("#### 🛡️ KnochenWächter")
191
+ for pred in predictions:
192
+ if pred['score'] >= conf_threshold:
193
+ st.markdown(f"""
194
+ <div class="result-box">
195
+ <span style='color: {"#0066cc" if pred["score"] > 0.7 else "#ffa500"}; font-weight: 500;'>
196
+ {pred['score']:.1%}
197
+ </span> - {translate_label(pred['label'])}
198
+ </div>
199
+ """, unsafe_allow_html=True)
 
200
 
201
+ # RöntgenMeister results
202
+ predictions = models["RöntgenMeister"](image)
203
+ st.markdown("#### 🎓 RöntgenMeister")
204
+ for pred in predictions:
205
+ if pred['score'] >= conf_threshold:
206
+ st.markdown(f"""
207
+ <div class="result-box">
208
+ <span style='color: {"#0066cc" if pred["score"] > 0.7 else "#ffa500"}; font-weight: 500;'>
209
+ {pred['score']:.1%}
210
+ </span> - {translate_label(pred['label'])}
211
+ </div>
212
+ """, unsafe_allow_html=True)
 
213
 
214
+ # Column 3: Localization (only if fracture detected)
215
  with col3:
216
+ predictions = models["KnochenAuge"](image)
217
+ has_fracture = any(
218
+ p['score'] >= conf_threshold and 'fracture' in p['label'].lower()
219
+ for p in predictions
220
+ )
221
 
222
+ if has_fracture:
223
  st.markdown("### 🔍 Fraktur Lokalisation")
224
+ filtered_preds = [p for p in predictions if p['score'] >= conf_threshold]
225
  if filtered_preds:
226
  result_image = image.copy()
227
  result_image = draw_boxes(result_image, filtered_preds)
228
  st.image(result_image, use_column_width=True)
229
+
230
+ st.markdown('</div>', unsafe_allow_html=True)
231
 
232
  if __name__ == "__main__":
233
  main()