File size: 2,406 Bytes
a326b94 7b58439 6b115d2 7b58439 005d8cf 7b58439 005d8cf 07f59cf 7b58439 005d8cf 7b58439 005d8cf 7b58439 005d8cf c1026ff 005d8cf 7b58439 c1026ff 005d8cf 07f59cf 7b58439 005d8cf |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 |
import streamlit as st
from transformers import AutoImageProcessor, AutoModelForImageClassification
import torch
from PIL import Image
import numpy as np
def create_overlay(image, attention_map, alpha=0.5):
    """Blend a min-max normalised score map, drawn as a red heatmap, over ``image``.

    Args:
        image: PIL image to draw on. It is converted to RGB so the blend always
            operates on three channels, whatever the input mode.
        attention_map: 2-D array-like of scores (anything ``np.asarray`` accepts).
            It is upsampled to ``image.size`` before blending.
        alpha: weight of the heatmap in the blend (0 = image only, 1 = heatmap only).

    Returns:
        A new RGB PIL image with the same size as ``image``.
    """
    scores = np.asarray(attention_map, dtype=np.float64)
    lo = float(scores.min())
    span = float(scores.max()) - lo
    # A constant map would divide by zero and produce NaNs; render it as "no heat".
    normalized = (scores - lo) / span if span > 0 else np.zeros_like(scores)

    heat = Image.fromarray(np.uint8(255 * normalized)).resize(image.size)
    red = np.asarray(heat, dtype=np.float64)
    # Scores go into the red channel only; green and blue stay at zero.
    heatmap = np.stack([red, np.zeros_like(red), np.zeros_like(red)], axis=-1)

    base = np.asarray(image.convert("RGB"), dtype=np.float64)
    blended = base * (1 - alpha) + heatmap * alpha
    # Clip so an alpha outside [0, 1] cannot wrap around in the uint8 cast.
    return Image.fromarray(np.uint8(np.clip(blended, 0, 255)))
@st.cache_resource
def load_model():
    """Load the fine-tuned pneumothorax ViT classifier and its image processor.

    ``st.cache_resource`` keeps the returned objects, so Streamlit does not
    reload the checkpoint on every script rerun.

    Returns:
        A ``(processor, model)`` tuple for the same Hub checkpoint.
    """
    checkpoint = "mrm8488/vit-base-patch16-224_finetuned-pneumothorax"
    image_processor = AutoImageProcessor.from_pretrained(checkpoint)
    classifier = AutoModelForImageClassification.from_pretrained(checkpoint)
    return image_processor, classifier
def _cls_attention_map(attentions, grid_size):
    """Reduce a ViT's per-layer attentions to a (grid_size, grid_size) numpy map.

    Only the last layer is used: its heads are averaged, and the CLS query row
    (index 0) is read over the trailing ``grid_size**2`` key positions, which
    are assumed to be the patch tokens. Each layer's tensor is assumed to be
    (batch, heads, tokens, tokens), as documented for ``output_attentions=True``
    in transformers -- TODO confirm if the checkpoint changes. Only the first
    batch item is used.
    """
    num_patches = grid_size * grid_size
    head_mean = attentions[-1][0].mean(dim=0)
    cls_to_patches = head_mean[0, -num_patches:]
    return cls_to_patches.reshape(grid_size, grid_size).cpu().numpy()


def main():
    """Run the Streamlit app: upload an X-ray, classify it, show an attention overlay.

    The result column shows the top label and its softmax probability. The
    second column shows the image with the last layer's CLS attention blended
    in red, when the model returns attention weights.
    """
    st.title("Détection de Pneumothorax")
    processor, model = load_model()

    uploaded_file = st.file_uploader("Télécharger une radiographie", type=["jpg", "jpeg", "png"])
    if uploaded_file is None:
        return

    image = Image.open(uploaded_file).convert('RGB')
    resized_image = image.resize((224, 224))
    st.image(resized_image, width=400)

    if not st.button("Analyser"):
        return

    with st.spinner("Analyse en cours..."):
        inputs = processor(images=resized_image, return_tensors="pt")
        # Inference only, so skip autograd bookkeeping. Attention weights are
        # returned only when requested. Averaging the last hidden state over
        # tokens would give hidden_size values, not one score per patch.
        with torch.no_grad():
            outputs = model(**inputs, output_attentions=True)
        probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
        predicted = int(probs[0].argmax().item())
        confidence = probs[0, predicted].item()

        overlay = None
        # NOTE(review): ``attentions`` may be None if the attention backend does
        # not expose weights -- confirm for the installed transformers version.
        if outputs.attentions:
            # A ViT splits the input into (image_size / patch_size)^2 patches.
            grid_size = model.config.image_size // model.config.patch_size
            attention_map = _cls_attention_map(outputs.attentions, grid_size)
            overlay = create_overlay(resized_image, attention_map)

    col1, col2 = st.columns(2)
    with col1:
        st.write("Résultat:", model.config.id2label[predicted])
        st.write(f"Confiance: {confidence:.2%}")
    with col2:
        if overlay is None:
            st.info("Carte d'attention indisponible pour ce modèle.")
        else:
            st.image(overlay, caption="Zones suspectes", width=400)
# Run the app when this file is executed directly (e.g. via `streamlit run`).
if __name__ == "__main__":
    main()