3dredstone commited on
Commit
ef18197
·
verified ·
1 Parent(s): babd90f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +452 -45
app.py CHANGED
@@ -1,50 +1,457 @@
1
- import gradio as gr
2
- from transformers import AutoImageProcessor, AutoModelForImageClassification
3
- from PIL import Image
4
- import torch
 
 
 
 
5
 
6
- # Load model and processor from the Hugging Face Hub
7
- model_name = "prithivMLmods/Bone-Fracture-Detection"
8
- model = AutoModelForImageClassification.from_pretrained(model_name)
9
- processor = AutoImageProcessor.from_pretrained(model_name)
 
10
 
11
- def detect_fracture(image):
12
- """
13
- Takes a NumPy image array, processes it, and returns the model's prediction.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  """
15
- # Convert NumPy array to a PIL Image
16
- image = Image.fromarray(image).convert("RGB")
17
-
18
- # Process the image and prepare it as input for the model
19
- inputs = processor(images=image, return_tensors="pt")
20
-
21
- # Perform inference without calculating gradients
22
- with torch.no_grad():
23
- outputs = model(**inputs)
24
- logits = outputs.logits
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
- # Apply softmax to get probabilities and convert to a list
27
- probs = torch.nn.functional.softmax(logits, dim=1).squeeze().tolist()
28
-
29
- # Create a dictionary of labels and their corresponding probabilities
30
- # This now correctly uses the labels from the model's configuration
31
- prediction = {model.config.id2label[i]: round(probs[i], 3) for i in range(len(probs))}
32
-
33
- return prediction
34
-
35
- # Create the Gradio Interface
36
- iface = gr.Interface(
37
- fn=detect_fracture,
38
- inputs=gr.Image(type="numpy", label="Upload Bone X-ray"),
39
- outputs=gr.Label(num_top_classes=2, label="Detection Result"),
40
- title="🔬 Bone Fracture Detection",
41
- description="Upload a bone X-ray image to detect if there is a fracture. The model will return the probability for 'Fractured' and 'Not Fractured'.",
42
- examples=[
43
- ["fractured_example.png"],
44
- ["not_fractured_example.png"]
45
- ] # Note: You would need to have these image files in the same directory for the examples to work.
46
- )
47
-
48
- # Launch the app
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  if __name__ == "__main__":
50
- iface.launch()
 
 
1
+ from fastapi import FastAPI, File, UploadFile, Form
2
+ from fastapi.responses import HTMLResponse, StreamingResponse
3
+ from transformers import pipeline
4
+ from PIL import Image, ImageDraw
5
+ import numpy as np
6
+ import io
7
+ import uvicorn
8
+ import base64
9
 
10
+ from reportlab.lib.pagesizes import letter
11
+ from reportlab.platypus import SimpleDocTemplate, Paragraph, Spacer, Image as ReportLabImage
12
+ from reportlab.lib.styles import getSampleStyleSheet, ParagraphStyle
13
+ from reportlab.lib.enums import TA_CENTER
14
+ from reportlab.lib.units import inch
15
 
16
+ app = FastAPI()
17
+
18
+ # Chargement des modèles
19
+ def load_models():
20
+ return {
21
+ "KnochenAuge": pipeline("object-detection", model="D3STRON/bone-fracture-detr"),
22
+ "KnochenWächter": pipeline("image-classification", model="Heem2/bone-fracture-detection-using-xray"),
23
+ "RöntgenMeister": pipeline("image-classification",
24
+ model="nandodeomkar/autotrain-fracture-detection-using-google-vit-base-patch-16-54382127388")
25
+ }
26
+
27
+ models = load_models()
28
+
29
+ def translate_label(label):
30
+ translations = {
31
+ "fracture": "Knochenbruch",
32
+ "no fracture": "Kein Knochenbruch",
33
+ "normal": "Normal",
34
+ "abnormal": "Auffällig",
35
+ "F1": "Knochenbruch",
36
+ "NF": "Kein Knochenbruch"
37
+ }
38
+ return translations.get(label.lower(), label)
39
+
40
+ def create_heatmap_overlay(image, box, score):
41
+ overlay = Image.new('RGBA', image.size, (0, 0, 0, 0))
42
+ draw = ImageDraw.Draw(overlay)
43
+
44
+ x1, y1 = box['xmin'], box['ymin']
45
+ x2, y2 = box['xmax'], box['ymax']
46
+
47
+ if score > 0.8:
48
+ fill_color = (255, 0, 0, 100)
49
+ border_color = (255, 0, 0, 255)
50
+ elif score > 0.6:
51
+ fill_color = (255, 165, 0, 100)
52
+ border_color = (255, 165, 0, 255)
53
+ else:
54
+ fill_color = (255, 255, 0, 100)
55
+ border_color = (255, 255, 0, 255)
56
+
57
+ draw.rectangle([x1, y1, x2, y2], fill=fill_color)
58
+ draw.rectangle([x1, y1, x2, y2], outline=border_color, width=2)
59
+
60
+ return overlay
61
+
62
+ def draw_boxes(image, predictions):
63
+ result_image = image.copy().convert('RGBA')
64
+
65
+ for pred in predictions:
66
+ box = pred['box']
67
+ score = pred['score']
68
+
69
+ overlay = create_heatmap_overlay(image, box, score)
70
+ result_image = Image.alpha_composite(result_image, overlay)
71
+
72
+ draw = ImageDraw.Draw(result_image)
73
+ temp = 36.5 + (score * 2.5)
74
+ label = f"{translate_label(pred['label'])} ({score:.1%} • {temp:.1f}°C)"
75
+
76
+ # Calculate text bounding box more accurately
77
+ # Temporarily create a dummy draw object to get text size if draw.textbbox is not accurate enough or available for current Pillow version
78
+ try:
79
+ text_bbox = draw.textbbox((box['xmin'], box['ymin'] - 20), label)
80
+ except AttributeError: # Fallback for older Pillow versions
81
+ # Estimate text size if textbbox is not available
82
+ font_size = 10 # This might need to be adjusted based on actual font used
83
+ text_width = len(label) * font_size * 0.6 # rough estimation
84
+ text_height = font_size * 1.2 # rough estimation
85
+ text_bbox = (box['xmin'], box['ymin'] - text_height, box['xmin'] + text_width, box['ymin'])
86
+
87
+
88
+ draw.rectangle(text_bbox, fill=(0, 0, 0, 180))
89
+
90
+ draw.text(
91
+ (box['xmin'], box['ymin']-20),
92
+ label,
93
+ fill=(255, 255, 255, 255)
94
+ )
95
+
96
+ return result_image
97
+
98
+ def image_to_base64(image):
99
+ buffered = io.BytesIO()
100
+ image.save(buffered, format="PNG")
101
+ img_str = base64.b64encode(buffered.getvalue()).decode()
102
+ return f"data:image/png;base64,{img_str}"
103
+
104
+ COMMON_STYLES = """
105
+ body {
106
+ font-family: system-ui, -apple-system, sans-serif;
107
+ background: #f0f2f5;
108
+ margin: 0;
109
+ padding: 20px;
110
+ color: #1a1a1a;
111
+ }
112
+ ::-webkit-scrollbar {
113
+ width: 8px;
114
+ height: 8px;
115
+ }
116
+
117
+ ::-webkit-scrollbar-track {
118
+ background: transparent;
119
+ }
120
+
121
+ ::-webkit-scrollbar-thumb {
122
+ background-color: rgba(156, 163, 175, 0.5);
123
+ border-radius: 4px;
124
+ }
125
+
126
+ .container {
127
+ max-width: 1200px;
128
+ margin: 0 auto;
129
+ background: white;
130
+ padding: 20px;
131
+ border-radius: 10px;
132
+ box-shadow: 0 2px 4px rgba(0,0,0,0.1);
133
+ }
134
+ .button {
135
+ background: #2d2d2d;
136
+ color: white;
137
+ border: none;
138
+ padding: 12px 30px;
139
+ border-radius: 8px;
140
+ cursor: pointer;
141
+ font-size: 1.1em;
142
+ transition: all 0.3s ease;
143
+ position: relative;
144
+ }
145
+ .button:hover {
146
+ background: #404040;
147
+ }
148
+ @keyframes progress {
149
+ 0% { width: 0; }
150
+ 100% { width: 100%; }
151
+ }
152
+ .button-progress {
153
+ position: absolute;
154
+ bottom: 0;
155
+ left: 0;
156
+ height: 4px;
157
+ background: rgba(255, 255, 255, 0.5);
158
+ width: 0;
159
+ }
160
+ .button:active .button-progress {
161
+ animation: progress 2s linear forwards;
162
+ }
163
+ img {
164
+ max-width: 100%;
165
+ height: auto;
166
+ border-radius: 8px;
167
+ }
168
+ @keyframes blink {
169
+ 0% { opacity: 1; }
170
+ 50% { opacity: 0; }
171
+ 100% { opacity: 1; }
172
+ }
173
+ #loading {
174
+ display: none;
175
+ color: white;
176
+ margin-top: 10px;
177
+ animation: blink 1s infinite;
178
+ text-align: center;
179
+ }
180
+ """
181
+
182
+ @app.get("/", response_class=HTMLResponse)
183
+ async def main():
184
+ content = f"""
185
+ <!DOCTYPE html>
186
+ <html>
187
+ <head>
188
+ <title>Fraktur Detektion</title>
189
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
190
+ <style>
191
+ {COMMON_STYLES}
192
+
193
+ .upload-section {{
194
+ background: #2d2d2d;
195
+ padding: 40px;
196
+ border-radius: 12px;
197
+ margin: 20px 0;
198
+ text-align: center;
199
+ border: 2px dashed #404040;
200
+ transition: all 0.3s ease;
201
+ color: white;
202
+ }}
203
+ .upload-section:hover {{
204
+ border-color: #555;
205
+ }}
206
+ input[type="file"] {{
207
+ font-size: 1.1em;
208
+ margin: 20px 0;
209
+ color: white;
210
+ }}
211
+ input[type="file"]::file-selector-button {{
212
+ font-size: 1em;
213
+ padding: 10px 20px;
214
+ border-radius: 8px;
215
+ border: 1px solid #404040;
216
+ background: #2d2d2d;
217
+ color: white;
218
+ transition: all 0.3s ease;
219
+ cursor: pointer;
220
+ }}
221
+ input[type="file"]::file-selector-button:hover {{
222
+ background: #404040;
223
+ }}
224
+ .confidence-slider {{
225
+ width: 100%;
226
+ max-width: 300px;
227
+ margin: 20px auto;
228
+ }}
229
+ input[type="range"] {{
230
+ width: 100%;
231
+ height: 8px;
232
+ border-radius: 4px;
233
+ background: #404040;
234
+ outline: none;
235
+ transition: all 0.3s ease;
236
+ -webkit-appearance: none;
237
+ }}
238
+ input[type="range"]::-webkit-slider-thumb {{
239
+ -webkit-appearance: none;
240
+ width: 20px;
241
+ height: 20px;
242
+ border-radius: 50%;
243
+ background: white;
244
+ cursor: pointer;
245
+ border: none;
246
+ }}
247
+ .input-field {{
248
+ margin-bottom: 20px;
249
+ }}
250
+ .input-field label {{
251
+ display: block;
252
+ margin-bottom: 5px;
253
+ font-size: 1.1em;
254
+ }}
255
+ .input-field input[type="text"] {{
256
+ width: calc(100% - 20px);
257
+ padding: 10px;
258
+ border-radius: 5px;
259
+ border: 1px solid #ccc;
260
+ background: #fff;
261
+ color: #1a1a1a;
262
+ font-size: 1em;
263
+ }}
264
+ </style>
265
+ </head>
266
+ <body>
267
+ <div class="container">
268
+ <div class="upload-section">
269
+ <form action="/analyze" method="post" enctype="multipart/form-data" onsubmit="document.getElementById('loading').style.display = 'block';">
270
+ <div class="input-field">
271
+ <label for="patient_name">Patientenname:</label>
272
+ <input type="text" id="patient_name" name="patient_name" required>
273
+ </div>
274
+ <div>
275
+ <input type="file" name="file" accept="image/*" required>
276
+ </div>
277
+ <div class="confidence-slider">
278
+ <label for="threshold">Konfidenzschwelle: <span id="thresholdValue">0.60</span></label>
279
+ <input type="range" id="threshold" name="threshold"
280
+ min="0" max="1" step="0.05" value="0.60"
281
+ oninput="document.getElementById('thresholdValue').textContent = parseFloat(this.value).toFixed(2)">
282
+ </div>
283
+ <button type="submit" class="button">
284
+ Analysieren & PDF Erstellen
285
+ <div class="button-progress"></div>
286
+ </button>
287
+ <div id="loading">Loading...</div>
288
+ </form>
289
+ </div>
290
+ </div>
291
+ </body>
292
+ </html>
293
  """
294
+ return content
295
+
296
+ @app.post("/analyze", response_class=StreamingResponse)
297
+ async def analyze_file(patient_name: str = Form(...), file: UploadFile = File(...), threshold: float = Form(0.6)):
298
+ try:
299
+ contents = await file.read()
300
+ image = Image.open(io.BytesIO(contents)).convert("RGB") # Ensure RGB for PDF
301
+
302
+ predictions_watcher = models["KnochenWächter"](image)
303
+ predictions_master = models["RöntgenMeister"](image)
304
+ predictions_locator = models["KnochenAuge"](image)
305
+
306
+ filtered_preds = [p for p in predictions_locator if p['score'] >= threshold]
307
+ if filtered_preds:
308
+ result_image = draw_boxes(image, filtered_preds)
309
+ else:
310
+ result_image = image
311
+
312
+ # Generate PDF
313
+ buffer = io.BytesIO()
314
+ doc = SimpleDocTemplate(buffer, pagesize=letter)
315
+ styles = getSampleStyleSheet()
316
+ centered_style = ParagraphStyle(
317
+ name='Centered',
318
+ parent=styles['Normal'],
319
+ alignment=TA_CENTER,
320
+ fontSize=12,
321
+ leading=14
322
+ )
323
+ heading_style = ParagraphStyle(
324
+ name='Heading',
325
+ parent=styles['h1'],
326
+ alignment=TA_CENTER,
327
+ fontSize=24,
328
+ spaceAfter=20
329
+ )
330
+ subheading_style = ParagraphStyle(
331
+ name='SubHeading',
332
+ parent=styles['h2'],
333
+ alignment=TA_CENTER,
334
+ fontSize=16,
335
+ spaceAfter=10
336
+ )
337
+ report_text_style = ParagraphStyle(
338
+ name='ReportText',
339
+ parent=styles['Normal'],
340
+ alignment=TA_CENTER,
341
+ fontSize=12,
342
+ spaceAfter=5
343
+ )
344
+
345
+ story = []
346
+
347
+ story.append(Paragraph("<b>Fraktur Detektionsbericht</b>", heading_style))
348
+ story.append(Spacer(1, 0.2 * inch))
349
+ story.append(Paragraph(f"<b>Patientenname:</b> {patient_name}", subheading_style))
350
+ story.append(Spacer(1, 0.4 * inch))
351
+
352
+ # KnochenWächter results
353
+ story.append(Paragraph("<b>KnochenWächter Ergebnisse:</b>", subheading_style))
354
+ for pred in predictions_watcher:
355
+ story.append(Paragraph(
356
+ f"{translate_label(pred['label'])}: {pred['score']:.1%}",
357
+ report_text_style
358
+ ))
359
+ story.append(Spacer(1, 0.2 * inch))
360
+
361
+ # RöntgenMeister results
362
+ story.append(Paragraph("<b>RöntgenMeister Ergebnisse:</b>", subheading_style))
363
+ for pred in predictions_master:
364
+ story.append(Paragraph(
365
+ f"{translate_label(pred['label'])}: {pred['score']:.1%}",
366
+ report_text_style
367
+ ))
368
+ story.append(Spacer(1, 0.4 * inch))
369
+
370
+ # Analyzed Image
371
+ story.append(Paragraph("<b>Röntgenbild Analyse:</b>", subheading_style))
372
+
373
+ # Save the result image temporarily to a buffer to be added to PDF
374
+ img_buffer = io.BytesIO()
375
+ result_image.save(img_buffer, format="PNG")
376
+ img_buffer.seek(0)
377
+ img_rl = ReportLabImage(img_buffer)
378
+
379
+ # Scale image to fit within page width while maintaining aspect ratio
380
+ img_width, img_height = img_rl.drawWidth, img_rl.drawHeight
381
+ aspect_ratio = img_height / img_width
382
+ max_width = 5 * inch # Adjust as needed for page layout
383
+ if img_width > max_width:
384
+ img_rl.drawWidth = max_width
385
+ img_rl.drawHeight = max_width * aspect_ratio
386
 
387
+ # Center the image
388
+ img_rl.hAlign = 'CENTER'
389
+
390
+ story.append(img_rl)
391
+ story.append(Spacer(1, 0.4 * inch))
392
+
393
+
394
+ # Final report text based on object detection
395
+ if filtered_preds:
396
+ story.append(Paragraph(
397
+ "<b>Die Analyse des Röntgenbildes zeigt eine mögliche Frakturlokalisation.</b>",
398
+ report_text_style
399
+ ))
400
+ for pred in filtered_preds:
401
+ score = pred['score']
402
+ temp = 36.5 + (score * 2.5)
403
+ story.append(Paragraph(
404
+ f"Detektion: {translate_label(pred['label'])} mit {score:.1%} Konfidenz ({temp:.1f}°C)",
405
+ report_text_style
406
+ ))
407
+ else:
408
+ story.append(Paragraph(
409
+ "<b>Basierend auf der Objektlokalisierungsanalyse wurde keine Fraktur mit ausreichender Konfidenz detektiert.</b>",
410
+ report_text_style
411
+ ))
412
+ story.append(Spacer(1, 0.2 * inch))
413
+ story.append(Paragraph("Dies ist ein automatisch generierter Bericht und sollte von einem Arzt überprüft werden.", centered_style))
414
+
415
+
416
+ doc.build(story)
417
+ buffer.seek(0)
418
+
419
+ return StreamingResponse(buffer, media_type="application/pdf",
420
+ headers={"Content-Disposition": f"attachment; filename=Fraktur_Bericht_{patient_name.replace(' ', '_')}.pdf"})
421
+
422
+ except Exception as e:
423
+ return HTMLResponse(f"""
424
+ <!DOCTYPE html>
425
+ <html>
426
+ <head>
427
+ <title>Fehler</title>
428
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
429
+ <style>
430
+ {COMMON_STYLES}
431
+ .error-box {{
432
+ background: #fee2e2;
433
+ border: 1px solid #ef4444;
434
+ padding: 20px;
435
+ border-radius: 8px;
436
+ margin: 20px 0;
437
+ }}
438
+ </style>
439
+ </head>
440
+ <body>
441
+ <div class="container">
442
+ <div class="error-box">
443
+ <h3>Fehler</h3>
444
+ <p>{str(e)}</p>
445
+ </div>
446
+ <a href="/" class="button back-button">
447
+ ← Zurück
448
+ <div class="button-progress"></div>
449
+ </a>
450
+ </div>
451
+ </body>
452
+ </html>
453
+ """)
454
+
455
  if __name__ == "__main__":
456
+ uvicorn.run(app, host="0.0.0.0", port=7860)
457
+