File size: 2,406 Bytes
a326b94
7b58439
6b115d2
7b58439
 
005d8cf
7b58439
 
 
 
 
 
 
 
 
 
005d8cf
 
07f59cf
7b58439
 
 
005d8cf
 
7b58439
005d8cf
7b58439
005d8cf
c1026ff
005d8cf
 
7b58439
 
c1026ff
005d8cf
07f59cf
7b58439
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
005d8cf
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import streamlit as st
from transformers import AutoImageProcessor, AutoModelForImageClassification
import torch
from PIL import Image
import numpy as np

def create_overlay(image, attention_map, alpha=0.5):
    """Blend a min-max normalized score map over *image* as a red heatmap.

    Args:
        image: PIL image to draw on. It is converted to RGB first so that
            grayscale or RGBA inputs blend with the 3-channel heatmap instead
            of failing to broadcast.
        attention_map: 2-D array of scores in any real range. It is rescaled
            to [0, 1], mapped to 0-255 in the red channel, and resized to
            ``image.size``.
        alpha: Weight of the heatmap in the blend (``0`` = image only,
            ``1`` = heatmap only).

    Returns:
        A new RGB PIL image.
    """
    scores = np.asarray(attention_map)
    lo, hi = scores.min(), scores.max()
    # A constant map (hi == lo) would divide by zero and produce NaNs, whose
    # uint8 cast is undefined; treat it as "nothing highlighted" instead.
    # The comparison is also False when the map contains NaN.
    if hi > lo:
        normalized = (scores - lo) / (hi - lo)
    else:
        normalized = np.zeros(scores.shape)

    red = np.asarray(Image.fromarray(np.uint8(255 * normalized)).resize(image.size))
    heatmap = np.stack([red, np.zeros_like(red), np.zeros_like(red)], axis=-1)

    base = np.asarray(image.convert("RGB"))
    blended = base * (1 - alpha) + heatmap * alpha
    # Clip so an alpha outside [0, 1] cannot wrap around in the uint8 cast.
    return Image.fromarray(np.uint8(np.clip(blended, 0, 255)))

@st.cache_resource
def load_model():
    """Fetch the pneumothorax ViT checkpoint and its image processor.

    Decorated with ``st.cache_resource`` so Streamlit reuses the loaded
    objects across script reruns instead of reloading them every time.

    Returns:
        ``(processor, model)``: the ``AutoImageProcessor`` and the
        ``AutoModelForImageClassification`` for the same checkpoint.
    """
    checkpoint = "mrm8488/vit-base-patch16-224_finetuned-pneumothorax"
    image_processor = AutoImageProcessor.from_pretrained(checkpoint)
    classifier = AutoModelForImageClassification.from_pretrained(checkpoint)
    return image_processor, classifier

def _cls_attention_map(model, outputs):
    """Return the last layer's CLS-to-patch attention as a (grid, grid) array.

    Each entry of ``outputs.attentions`` is documented by Transformers as
    ``(batch, heads, seq, seq)``. For ViT, token 0 is the CLS token and tokens
    ``1:`` are the image patches, so row 0 gives how much CLS attends to each
    patch. Heads are averaged into a single map.
    """
    last_layer = outputs.attentions[-1]
    cls_to_patches = last_layer[0, :, 0, 1:].mean(dim=0)
    # Patches per side, e.g. 224 // 16 == 14 for a 224px / 16px-patch ViT.
    grid = model.config.image_size // model.config.patch_size
    return cls_to_patches.reshape(grid, grid).cpu().numpy()


def main():
    """Render the pneumothorax-detection Streamlit page.

    The user uploads a chest X-ray, which is previewed at 224x224. When
    "Analyser" is clicked, the ViT classifier runs and the page shows the
    predicted label, its softmax confidence, and the last layer's CLS
    attention overlaid on the preview.
    """
    st.title("Détection de Pneumothorax")

    processor, model = load_model()

    uploaded_file = st.file_uploader("Télécharger une radiographie", type=["jpg", "jpeg", "png"])
    if not uploaded_file:
        return

    try:
        image = Image.open(uploaded_file).convert('RGB')
    except OSError:  # includes PIL.UnidentifiedImageError and truncated files
        st.error("Impossible de lire ce fichier comme image.")
        return
    resized_image = image.resize((224, 224))
    st.image(resized_image, width=400)

    if not st.button("Analyser"):
        return

    with st.spinner("Analyse en cours..."):
        inputs = processor(images=resized_image, return_tensors="pt")
        # Inference only: no autograd graph needed. output_attentions=True is
        # required, otherwise outputs.attentions is None.
        with torch.no_grad():
            outputs = model(**inputs, output_attentions=True)
        probs = torch.nn.functional.softmax(outputs.logits, dim=-1)[0]
        pred_idx = int(probs.argmax())

        # Attention of the last layer, used as a coarse "where the model looked" map.
        attention_map = _cls_attention_map(model, outputs)
        overlay = create_overlay(resized_image, attention_map)

        col1, col2 = st.columns(2)
        with col1:
            st.write("Résultat:", model.config.id2label[pred_idx])
            st.write(f"Confiance: {probs[pred_idx].item():.2%}")

        with col2:
            st.image(overlay, caption="Zones suspectes", width=400)

if __name__ == "__main__":
    main()