bfd-rg / app.py
3dredstone's picture
Update app.py
ef18197 verified
raw
history blame
15.3 kB
import base64
import html
import io
import logging
import threading
from urllib.parse import quote
from xml.sax.saxutils import escape as xml_escape

import numpy as np
import uvicorn
from fastapi import FastAPI, File, UploadFile, Form
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import HTMLResponse, StreamingResponse
from PIL import Image, ImageDraw
from reportlab.lib.enums import TA_CENTER
from reportlab.lib.pagesizes import letter
from reportlab.lib.styles import getSampleStyleSheet, ParagraphStyle
from reportlab.lib.units import inch
from reportlab.platypus import SimpleDocTemplate, Paragraph, Spacer, Image as ReportLabImage
from transformers import pipeline
app = FastAPI()


def load_models():
    """Instantiate the Hugging Face pipelines used by the analysis endpoint.

    Returns:
        dict mapping display name -> pipeline. "KnochenAuge" is an
        object-detection pipeline (its predictions carry a "box");
        "KnochenWächter" and "RöntgenMeister" are image-classification
        pipelines.
    """
    # (display name, pipeline task, Hugging Face model id), in load order.
    model_specs = (
        ("KnochenAuge", "object-detection", "D3STRON/bone-fracture-detr"),
        ("KnochenWächter", "image-classification",
         "Heem2/bone-fracture-detection-using-xray"),
        ("RöntgenMeister", "image-classification",
         "nandodeomkar/autotrain-fracture-detection-using-google-vit-base-patch-16-54382127388"),
    )
    return {
        display_name: pipeline(task, model=model_id)
        for display_name, task, model_id in model_specs
    }


# Load every model once at import time; all requests share these pipelines.
models = load_models()
# Lower-cased model label -> German display text. Keys must be lower case
# because translate_label() normalizes the incoming label before the lookup.
_LABEL_TRANSLATIONS = {
    "fracture": "Knochenbruch",
    "no fracture": "Kein Knochenbruch",
    "normal": "Normal",
    "abnormal": "Auffällig",
    "f1": "Knochenbruch",
    "nf": "Kein Knochenbruch",
}


def translate_label(label):
    """Translate a model output label into German for display.

    The lookup ignores case and surrounding whitespace ("F1", " nf",
    "Fracture" all match). Unknown labels are returned unchanged.

    Args:
        label: label string produced by one of the models.

    Returns:
        The German translation, or *label* itself when no translation exists.
    """
    return _LABEL_TRANSLATIONS.get(label.strip().lower(), label)
def create_heatmap_overlay(image, box, score):
    """Return a transparent RGBA layer, sized like *image*, holding one tinted box.

    The rectangle spans box["xmin"]..box["xmax"] by box["ymin"]..box["ymax"].
    Its color encodes the score: red above 0.8, orange above 0.6, yellow
    otherwise. The interior is filled with alpha 100, then a 2 px opaque
    outline of the same color is drawn over it.
    """
    layer = Image.new('RGBA', image.size, (0, 0, 0, 0))
    painter = ImageDraw.Draw(layer)
    corners = [box['xmin'], box['ymin'], box['xmax'], box['ymax']]

    # Confidence bands checked from highest to lowest; yellow is the fallback.
    bands = ((0.8, (255, 0, 0)), (0.6, (255, 165, 0)))
    rgb = next((color for floor, color in bands if score > floor), (255, 255, 0))

    painter.rectangle(corners, fill=rgb + (100,))
    painter.rectangle(corners, outline=rgb + (255,), width=2)
    return layer
def _label_text_bbox(draw, xy, text):
    """Return the (x0, y0, x1, y1) area covered by *text* drawn at *xy* with the default font."""
    try:
        return draw.textbbox(xy, text)
    except AttributeError:
        # Older Pillow without ImageDraw.textbbox: roughly estimate the size of
        # the default font. The box is anchored at *xy*, where the text is drawn.
        font_size = 10
        x, y = xy
        return (x, y, x + len(text) * font_size * 0.6, y + font_size * 1.2)


def draw_boxes(image, predictions):
    """Render detection boxes and their labels on a copy of *image*.

    Args:
        image: PIL image the detections refer to; it is not modified.
        predictions: detection dicts with "box" (a dict with "xmin", "ymin",
            "xmax", "ymax"), "score" and "label" keys.

    Returns:
        A new RGBA image with a score-colored translucent rectangle per
        prediction and a label "<German label> (<score %>, <value>°C)" placed
        above it.
    """
    result_image = image.copy().convert('RGBA')
    for pred in predictions:
        box = pred['box']
        score = pred['score']
        result_image = Image.alpha_composite(
            result_image, create_heatmap_overlay(image, box, score)
        )

        # NOTE(review): this "temperature" is not measured; it is a linear
        # function of the score only (36.5 at 0 %, 39.0 at 100 %).
        temp = 36.5 + (score * 2.5)
        label = f"{translate_label(pred['label'])} ({score:.1%}, {temp:.1f}°C)"

        # 20 px above the box, but never above the top edge of the image,
        # so labels of boxes touching the top stay visible.
        text_xy = (box['xmin'], max(box['ymin'] - 20, 0))

        # ImageDraw overwrites pixels (alpha included) on RGBA images instead of
        # blending, so the translucent background goes on its own layer and is
        # alpha-composited, like the heatmap overlay.
        label_layer = Image.new('RGBA', result_image.size, (0, 0, 0, 0))
        layer_draw = ImageDraw.Draw(label_layer)
        layer_draw.rectangle(
            _label_text_bbox(layer_draw, text_xy, label), fill=(0, 0, 0, 180)
        )
        result_image = Image.alpha_composite(result_image, label_layer)

        ImageDraw.Draw(result_image).text(text_xy, label, fill=(255, 255, 255, 255))
    return result_image
def image_to_base64(image):
    """Encode *image* as a PNG ``data:`` URI usable directly as an HTML ``<img src>``.

    Args:
        image: any object with a PIL-style ``save(fp, format=...)`` method.

    Returns:
        The string ``"data:image/png;base64,<payload>"``.
    """
    with io.BytesIO() as png_bytes:
        image.save(png_bytes, format="PNG")
        payload = base64.b64encode(png_bytes.getvalue()).decode()
    return "data:image/png;base64," + payload
COMMON_STYLES = """
body {
font-family: system-ui, -apple-system, sans-serif;
background: #f0f2f5;
margin: 0;
padding: 20px;
color: #1a1a1a;
}
::-webkit-scrollbar {
width: 8px;
height: 8px;
}
::-webkit-scrollbar-track {
background: transparent;
}
::-webkit-scrollbar-thumb {
background-color: rgba(156, 163, 175, 0.5);
border-radius: 4px;
}
.container {
max-width: 1200px;
margin: 0 auto;
background: white;
padding: 20px;
border-radius: 10px;
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
}
.button {
background: #2d2d2d;
color: white;
border: none;
padding: 12px 30px;
border-radius: 8px;
cursor: pointer;
font-size: 1.1em;
transition: all 0.3s ease;
position: relative;
}
.button:hover {
background: #404040;
}
@keyframes progress {
0% { width: 0; }
100% { width: 100%; }
}
.button-progress {
position: absolute;
bottom: 0;
left: 0;
height: 4px;
background: rgba(255, 255, 255, 0.5);
width: 0;
}
.button:active .button-progress {
animation: progress 2s linear forwards;
}
img {
max-width: 100%;
height: auto;
border-radius: 8px;
}
@keyframes blink {
0% { opacity: 1; }
50% { opacity: 0; }
100% { opacity: 1; }
}
#loading {
display: none;
color: white;
margin-top: 10px;
animation: blink 1s infinite;
text-align: center;
}
"""
@app.get("/", response_class=HTMLResponse)
async def main():
content = f"""
<!DOCTYPE html>
<html>
<head>
<title>Fraktur Detektion</title>
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<style>
{COMMON_STYLES}
.upload-section {{
background: #2d2d2d;
padding: 40px;
border-radius: 12px;
margin: 20px 0;
text-align: center;
border: 2px dashed #404040;
transition: all 0.3s ease;
color: white;
}}
.upload-section:hover {{
border-color: #555;
}}
input[type="file"] {{
font-size: 1.1em;
margin: 20px 0;
color: white;
}}
input[type="file"]::file-selector-button {{
font-size: 1em;
padding: 10px 20px;
border-radius: 8px;
border: 1px solid #404040;
background: #2d2d2d;
color: white;
transition: all 0.3s ease;
cursor: pointer;
}}
input[type="file"]::file-selector-button:hover {{
background: #404040;
}}
.confidence-slider {{
width: 100%;
max-width: 300px;
margin: 20px auto;
}}
input[type="range"] {{
width: 100%;
height: 8px;
border-radius: 4px;
background: #404040;
outline: none;
transition: all 0.3s ease;
-webkit-appearance: none;
}}
input[type="range"]::-webkit-slider-thumb {{
-webkit-appearance: none;
width: 20px;
height: 20px;
border-radius: 50%;
background: white;
cursor: pointer;
border: none;
}}
.input-field {{
margin-bottom: 20px;
}}
.input-field label {{
display: block;
margin-bottom: 5px;
font-size: 1.1em;
}}
.input-field input[type="text"] {{
width: calc(100% - 20px);
padding: 10px;
border-radius: 5px;
border: 1px solid #ccc;
background: #fff;
color: #1a1a1a;
font-size: 1em;
}}
</style>
</head>
<body>
<div class="container">
<div class="upload-section">
<form action="/analyze" method="post" enctype="multipart/form-data" onsubmit="document.getElementById('loading').style.display = 'block';">
<div class="input-field">
<label for="patient_name">Patientenname:</label>
<input type="text" id="patient_name" name="patient_name" required>
</div>
<div>
<input type="file" name="file" accept="image/*" required>
</div>
<div class="confidence-slider">
<label for="threshold">Konfidenzschwelle: <span id="thresholdValue">0.60</span></label>
<input type="range" id="threshold" name="threshold"
min="0" max="1" step="0.05" value="0.60"
oninput="document.getElementById('thresholdValue').textContent = parseFloat(this.value).toFixed(2)">
</div>
<button type="submit" class="button">
Analysieren & PDF Erstellen
<div class="button-progress"></div>
</button>
<div id="loading">Loading...</div>
</form>
</div>
</div>
</body>
</html>
"""
return content
@app.post("/analyze", response_class=StreamingResponse)
async def analyze_file(patient_name: str = Form(...), file: UploadFile = File(...), threshold: float = Form(0.6)):
try:
contents = await file.read()
image = Image.open(io.BytesIO(contents)).convert("RGB") # Ensure RGB for PDF
predictions_watcher = models["KnochenWächter"](image)
predictions_master = models["RöntgenMeister"](image)
predictions_locator = models["KnochenAuge"](image)
filtered_preds = [p for p in predictions_locator if p['score'] >= threshold]
if filtered_preds:
result_image = draw_boxes(image, filtered_preds)
else:
result_image = image
# Generate PDF
buffer = io.BytesIO()
doc = SimpleDocTemplate(buffer, pagesize=letter)
styles = getSampleStyleSheet()
centered_style = ParagraphStyle(
name='Centered',
parent=styles['Normal'],
alignment=TA_CENTER,
fontSize=12,
leading=14
)
heading_style = ParagraphStyle(
name='Heading',
parent=styles['h1'],
alignment=TA_CENTER,
fontSize=24,
spaceAfter=20
)
subheading_style = ParagraphStyle(
name='SubHeading',
parent=styles['h2'],
alignment=TA_CENTER,
fontSize=16,
spaceAfter=10
)
report_text_style = ParagraphStyle(
name='ReportText',
parent=styles['Normal'],
alignment=TA_CENTER,
fontSize=12,
spaceAfter=5
)
story = []
story.append(Paragraph("<b>Fraktur Detektionsbericht</b>", heading_style))
story.append(Spacer(1, 0.2 * inch))
story.append(Paragraph(f"<b>Patientenname:</b> {patient_name}", subheading_style))
story.append(Spacer(1, 0.4 * inch))
# KnochenWächter results
story.append(Paragraph("<b>KnochenWächter Ergebnisse:</b>", subheading_style))
for pred in predictions_watcher:
story.append(Paragraph(
f"{translate_label(pred['label'])}: {pred['score']:.1%}",
report_text_style
))
story.append(Spacer(1, 0.2 * inch))
# RöntgenMeister results
story.append(Paragraph("<b>RöntgenMeister Ergebnisse:</b>", subheading_style))
for pred in predictions_master:
story.append(Paragraph(
f"{translate_label(pred['label'])}: {pred['score']:.1%}",
report_text_style
))
story.append(Spacer(1, 0.4 * inch))
# Analyzed Image
story.append(Paragraph("<b>Röntgenbild Analyse:</b>", subheading_style))
# Save the result image temporarily to a buffer to be added to PDF
img_buffer = io.BytesIO()
result_image.save(img_buffer, format="PNG")
img_buffer.seek(0)
img_rl = ReportLabImage(img_buffer)
# Scale image to fit within page width while maintaining aspect ratio
img_width, img_height = img_rl.drawWidth, img_rl.drawHeight
aspect_ratio = img_height / img_width
max_width = 5 * inch # Adjust as needed for page layout
if img_width > max_width:
img_rl.drawWidth = max_width
img_rl.drawHeight = max_width * aspect_ratio
# Center the image
img_rl.hAlign = 'CENTER'
story.append(img_rl)
story.append(Spacer(1, 0.4 * inch))
# Final report text based on object detection
if filtered_preds:
story.append(Paragraph(
"<b>Die Analyse des Röntgenbildes zeigt eine mögliche Frakturlokalisation.</b>",
report_text_style
))
for pred in filtered_preds:
score = pred['score']
temp = 36.5 + (score * 2.5)
story.append(Paragraph(
f"Detektion: {translate_label(pred['label'])} mit {score:.1%} Konfidenz ({temp:.1f}°C)",
report_text_style
))
else:
story.append(Paragraph(
"<b>Basierend auf der Objektlokalisierungsanalyse wurde keine Fraktur mit ausreichender Konfidenz detektiert.</b>",
report_text_style
))
story.append(Spacer(1, 0.2 * inch))
story.append(Paragraph("Dies ist ein automatisch generierter Bericht und sollte von einem Arzt überprüft werden.", centered_style))
doc.build(story)
buffer.seek(0)
return StreamingResponse(buffer, media_type="application/pdf",
headers={"Content-Disposition": f"attachment; filename=Fraktur_Bericht_{patient_name.replace(' ', '_')}.pdf"})
except Exception as e:
return HTMLResponse(f"""
<!DOCTYPE html>
<html>
<head>
<title>Fehler</title>
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<style>
{COMMON_STYLES}
.error-box {{
background: #fee2e2;
border: 1px solid #ef4444;
padding: 20px;
border-radius: 8px;
margin: 20px 0;
}}
</style>
</head>
<body>
<div class="container">
<div class="error-box">
<h3>Fehler</h3>
<p>{str(e)}</p>
</div>
<a href="/" class="button back-button">
← Zurück
<div class="button-progress"></div>
</a>
</div>
</body>
</html>
""")
if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=7860)