|
import streamlit as st |
|
from transformers import AutoImageProcessor, AutoModelForImageClassification |
|
import torch |
|
from PIL import Image |
|
import numpy as np |
|
|
|
def create_overlay(image, attention_map, alpha=0.5):
    """Blend a red heatmap built from ``attention_map`` over ``image``.

    Args:
        image: PIL image to draw on. It is converted to RGB first so the
            three-channel blend also works for grayscale or RGBA inputs.
        attention_map: 2-D array of scores. It is min-max normalised to
            [0, 1], scaled to 0-255 and resized to ``image.size``.
        alpha: weight of the heatmap in the blend (0 keeps only the image,
            1 keeps only the heatmap).

    Returns:
        A new RGB PIL image containing the blended overlay.
    """
    rgb = image.convert("RGB")

    scores = np.asarray(attention_map, dtype=np.float32)
    lo, hi = float(scores.min()), float(scores.max())
    span = hi - lo
    if span > 0:
        scores = (scores - lo) / span
    else:
        # A constant map carries no spatial information; normalising it
        # would compute 0/0 -> NaN, so render an empty heatmap instead.
        scores = np.zeros_like(scores)

    heat = Image.fromarray(np.uint8(255 * scores)).resize(rgb.size)
    red = np.asarray(heat, dtype=np.float32)
    # Put the intensity in the red channel only.
    heatmap = np.stack([red, np.zeros_like(red), np.zeros_like(red)], axis=-1)

    blended = np.asarray(rgb, dtype=np.float32) * (1 - alpha) + heatmap * alpha
    # Clip so that an alpha outside [0, 1] cannot wrap around in uint8.
    return Image.fromarray(np.clip(blended, 0, 255).astype(np.uint8))
|
|
|
# Hugging Face Hub id of the fine-tuned ViT pneumothorax classifier.
MODEL_NAME = "mrm8488/vit-base-patch16-224_finetuned-pneumothorax"


@st.cache_resource
def load_model(model_name=MODEL_NAME):
    """Load the image processor and classification model once.

    ``st.cache_resource`` keeps the returned objects across Streamlit reruns,
    keyed on ``model_name``, so the checkpoint is loaded only once.

    Args:
        model_name: Hub id or local path of the checkpoint. Defaults to the
            pneumothorax ViT used by this app.

    Returns:
        ``(processor, model)``, with the model switched to evaluation mode.
    """
    processor = AutoImageProcessor.from_pretrained(model_name)
    model = AutoModelForImageClassification.from_pretrained(model_name)
    model.eval()  # inference only: make sure dropout is disabled
    return processor, model
|
|
|
def _cls_attention_map(attentions):
    """Return last-layer CLS->patch attention as a square 2-D array, or None.

    ``attentions`` is the ``outputs.attentions`` tuple returned when the
    model is called with ``output_attentions=True``. Each element is
    presumably shaped (batch, heads, seq, seq), with token 0 being [CLS]
    followed by a square grid of patch tokens (ViT layout) — TODO confirm
    for checkpoints other than the ViT one this app loads.
    Returns None when no attentions are available or the patch count is
    not a perfect square.
    """
    if not attentions:
        return None
    last_layer = attentions[-1]
    # Attention from [CLS] (query 0) to every patch token, averaged over heads.
    cls_to_patches = last_layer[0, :, 0, 1:].mean(dim=0)
    n_patches = cls_to_patches.numel()
    side = int(round(n_patches ** 0.5))
    if side * side != n_patches:
        return None
    return cls_to_patches.reshape(side, side).detach().cpu().numpy()


def main():
    """Render the pneumothorax-detection page.

    The user uploads a radiograph and clicks "Analyser". The page then shows
    the predicted label, its softmax confidence, and a heatmap overlay built
    from the model's last-layer CLS attention.
    """
    from PIL import UnidentifiedImageError

    st.title("Détection de Pneumothorax")

    processor, model = load_model()

    uploaded_file = st.file_uploader("Télécharger une radiographie", type=["jpg", "jpeg", "png"])
    if uploaded_file is None:
        return

    try:
        image = Image.open(uploaded_file).convert('RGB')
    except UnidentifiedImageError:
        st.error("Impossible de lire ce fichier comme une image.")
        return

    resized_image = image.resize((224, 224))
    st.image(resized_image, width=400)

    if not st.button("Analyser"):
        return

    with st.spinner("Analyse en cours..."):
        inputs = processor(images=resized_image, return_tensors="pt")
        # Pure inference: skip building an autograd graph. Attention weights
        # are requested explicitly; without them there is no spatial map
        # (the original code read hidden_states, which were never returned).
        with torch.no_grad():
            outputs = model(**inputs, output_attentions=True)

        probs = torch.nn.functional.softmax(outputs.logits, dim=-1)[0]
        pred_idx = int(probs.argmax().item())
        confidence = probs[pred_idx].item()

        attention_map = _cls_attention_map(outputs.attentions)
        overlay = None if attention_map is None else create_overlay(resized_image, attention_map)

    col1, col2 = st.columns(2)
    with col1:
        st.write("Résultat:", model.config.id2label[pred_idx])
        st.write(f"Confiance: {confidence:.2%}")
    with col2:
        if overlay is not None:
            st.image(overlay, caption="Zones suspectes", width=400)
        else:
            st.warning("Carte d'attention indisponible pour ce modèle.")
|
|
|
if __name__ == "__main__":
    # Script entry point: build the page only when this file is executed
    # directly, not when it is imported as a module.
    main()