# radpid / app.py
# Last commit: "Update app.py" by yassonee (7b58439, verified) — 2.41 kB
import streamlit as st
from transformers import AutoImageProcessor, AutoModelForImageClassification
import torch
from PIL import Image
import numpy as np
def create_overlay(image, attention_map, alpha=0.5):
    """Blend a red heat map of *attention_map* over *image*.

    The map is min-max normalised to [0, 255], resized to the image size,
    written into the red channel of an RGB array and alpha-blended with the
    image.

    Args:
        image: PIL image to draw on. It is converted to RGB first so that
            grayscale or RGBA inputs blend correctly with the 3-channel map.
        attention_map: 2-D array of scores (any real dtype).
        alpha: weight of the heat map; the image is weighted ``1 - alpha``.

    Returns:
        PIL.Image.Image: an RGB image with the same size as *image*.
    """
    rgb = image.convert("RGB")

    scores = np.asarray(attention_map, dtype=np.float64)
    lo, hi = scores.min(), scores.max()
    span = hi - lo
    # A constant map would divide by zero (NaNs); render it as "no highlight".
    if span > 0:
        normalized = (scores - lo) / span
    else:
        normalized = np.zeros_like(scores)

    heat = Image.fromarray((255 * normalized).astype(np.uint8)).resize(rgb.size)
    red = np.asarray(heat)
    heat_rgb = np.stack([red, np.zeros_like(red), np.zeros_like(red)], axis=-1)

    blended = np.asarray(rgb, dtype=np.float64) * (1 - alpha) + heat_rgb * alpha
    # Clip so an out-of-range alpha cannot wrap around when cast to uint8.
    return Image.fromarray(np.clip(blended, 0, 255).astype(np.uint8))
@st.cache_resource
def load_model(model_name="mrm8488/vit-base-patch16-224_finetuned-pneumothorax"):
    """Load the image processor and classification model for *model_name*.

    Decorated with ``st.cache_resource`` so the pair is cached across
    Streamlit reruns, keyed on *model_name*.

    Args:
        model_name: model id (or anything ``from_pretrained`` accepts).
            Defaults to the pneumothorax fine-tuned ViT this app was built for.

    Returns:
        tuple: ``(processor, model)``.
    """
    processor = AutoImageProcessor.from_pretrained(model_name)
    model = AutoModelForImageClassification.from_pretrained(model_name)
    # The app only runs inference; make sure dropout and similar layers are off.
    model.eval()
    return processor, model
def _cls_attention_map(attentions, grid_size):
    """Return the last layer's CLS-token attention over the patches as a 2-D array.

    Args:
        attentions: per-layer attention tensors from a Transformers model
            called with ``output_attentions=True``. Each is documented as
            (batch, heads, seq_len, seq_len).
        grid_size: number of patches along one side of the image.

    Returns:
        numpy.ndarray of shape (grid_size, grid_size).
    """
    last_layer = attentions[-1][0]  # first (only) image in the batch
    num_patches = grid_size * grid_size
    # Query row 0 is the CLS token. Its last `num_patches` keys are the image
    # patches; for ViT the leading key is the CLS token itself. Average over heads.
    cls_to_patches = last_layer[:, 0, -num_patches:].mean(dim=0)
    return cls_to_patches.reshape(grid_size, grid_size).cpu().numpy()


def main():
    """Render the pneumothorax-detection Streamlit page.

    Flow: upload a radiograph, preview it at 224x224, then on "Analyser" run
    the classifier. The page shows the predicted label, its softmax
    confidence, and a heat map of where the last layer's CLS token attends.
    """
    st.title("Détection de Pneumothorax")
    processor, model = load_model()

    uploaded_file = st.file_uploader("Télécharger une radiographie", type=["jpg", "jpeg", "png"])
    if uploaded_file is None:
        return

    try:
        image = Image.open(uploaded_file).convert('RGB')
    except OSError:
        # Pillow signals unreadable or non-image files with an OSError subclass.
        st.error("Impossible de lire ce fichier comme une image.")
        return

    resized_image = image.resize((224, 224))
    st.image(resized_image, width=400)

    if not st.button("Analyser"):
        return

    with st.spinner("Analyse en cours..."):
        inputs = processor(images=resized_image, return_tensors="pt")
        # Inference only, so no autograd graph. Attention weights are only
        # returned when explicitly requested.
        with torch.no_grad():
            outputs = model(**inputs, output_attentions=True)
        probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
        pred_idx = int(probs[0].argmax().item())
        confidence = probs[0, pred_idx].item()

        overlay = None
        if outputs.attentions is not None:
            # Patch grid side, e.g. 224 // 16 = 14 for ViT-B/16.
            grid_size = model.config.image_size // model.config.patch_size
            attention_map = _cls_attention_map(outputs.attentions, grid_size)
            overlay = create_overlay(resized_image, attention_map)

    col1, col2 = st.columns(2)
    with col1:
        st.write("Résultat:", model.config.id2label[pred_idx])
        st.write(f"Confiance: {confidence:.2%}")
    with col2:
        if overlay is not None:
            st.image(overlay, caption="Zones suspectes", width=400)
        else:
            st.info("Carte d'attention indisponible pour ce modèle.")
# Entry point: build the page when this file is executed directly
# (presumably launched with `streamlit run app.py`).
if __name__ == "__main__":
    main()