Update app.py
Browse files
app.py
CHANGED
@@ -1,58 +1,33 @@
|
|
1 |
import streamlit as st
|
2 |
-
from transformers import
|
3 |
-
import torch
|
4 |
from PIL import Image
|
5 |
-
import numpy as np
|
6 |
-
|
7 |
-
def create_overlay(image, attention_map, alpha=0.5):
|
8 |
-
attention_map = (attention_map - attention_map.min()) / (attention_map.max() - attention_map.min())
|
9 |
-
heatmap = np.uint8(255 * attention_map)
|
10 |
-
heatmap = Image.fromarray(heatmap).resize(image.size)
|
11 |
-
heatmap = np.array(heatmap)
|
12 |
-
heatmap = np.stack([heatmap, np.zeros_like(heatmap), np.zeros_like(heatmap)], axis=-1)
|
13 |
-
|
14 |
-
image_array = np.array(image)
|
15 |
-
overlay = Image.fromarray(np.uint8(image_array * (1 - alpha) + heatmap * alpha))
|
16 |
-
return overlay
|
17 |
|
18 |
@st.cache_resource
|
19 |
def load_model():
|
20 |
-
|
21 |
-
model = AutoModelForImageClassification.from_pretrained("mrm8488/vit-base-patch16-224_finetuned-pneumothorax")
|
22 |
-
return processor, model
|
23 |
|
24 |
def main():
|
25 |
st.title("Détection de Pneumothorax")
|
26 |
|
27 |
-
|
28 |
|
29 |
uploaded_file = st.file_uploader("Télécharger une radiographie", type=["jpg", "jpeg", "png"])
|
30 |
|
31 |
if uploaded_file:
|
32 |
image = Image.open(uploaded_file).convert('RGB')
|
33 |
resized_image = image.resize((224, 224))
|
34 |
-
|
|
|
|
|
|
|
35 |
|
36 |
if st.button("Analyser"):
|
37 |
with st.spinner("Analyse en cours..."):
|
38 |
-
|
39 |
-
outputs = model(**inputs)
|
40 |
-
probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
|
41 |
-
|
42 |
-
# Obtenir attentions des dernières couches
|
43 |
-
attention = outputs.hidden_states[-1].mean(1)[0].detach().numpy()
|
44 |
-
attention_map = attention.reshape(14, 14) # ViT patch size
|
45 |
-
|
46 |
-
# Créer overlay
|
47 |
-
overlay = create_overlay(resized_image, attention_map)
|
48 |
-
|
49 |
-
col1, col2 = st.columns(2)
|
50 |
-
with col1:
|
51 |
-
st.write("Résultat:", model.config.id2label[outputs.logits.argmax(-1).item()])
|
52 |
-
st.write(f"Confiance: {probs.max().item():.2%}")
|
53 |
-
|
54 |
with col2:
|
55 |
-
st.
|
|
|
|
|
56 |
|
57 |
if __name__ == "__main__":
|
58 |
main()
|
|
|
1 |
import streamlit as st
|
2 |
+
from transformers import pipeline
|
|
|
3 |
from PIL import Image
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
|
5 |
@st.cache_resource
|
6 |
def load_model():
|
7 |
+
return pipeline("image-classification", model="mrm8488/vit-base-patch16-224_finetuned-pneumothorax", top_k=2)
|
|
|
|
|
8 |
|
9 |
def main():
|
10 |
st.title("Détection de Pneumothorax")
|
11 |
|
12 |
+
model = load_model()
|
13 |
|
14 |
uploaded_file = st.file_uploader("Télécharger une radiographie", type=["jpg", "jpeg", "png"])
|
15 |
|
16 |
if uploaded_file:
|
17 |
image = Image.open(uploaded_file).convert('RGB')
|
18 |
resized_image = image.resize((224, 224))
|
19 |
+
|
20 |
+
col1, col2 = st.columns(2)
|
21 |
+
with col1:
|
22 |
+
st.image(resized_image, width=300, caption="Image originale")
|
23 |
|
24 |
if st.button("Analyser"):
|
25 |
with st.spinner("Analyse en cours..."):
|
26 |
+
results = model(resized_image)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
27 |
with col2:
|
28 |
+
st.write("Résultats:")
|
29 |
+
for result in results:
|
30 |
+
st.write(f"{result['label']}: {result['score']:.2%}")
|
31 |
|
32 |
if __name__ == "__main__":
|
33 |
main()
|