yassonee commited on
Commit
aa66005
·
verified ·
1 Parent(s): c7e5681

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -36
app.py CHANGED
@@ -1,58 +1,33 @@
1
  import streamlit as st
2
- from transformers import AutoImageProcessor, AutoModelForImageClassification
3
- import torch
4
  from PIL import Image
5
- import numpy as np
6
-
7
- def create_overlay(image, attention_map, alpha=0.5):
8
- attention_map = (attention_map - attention_map.min()) / (attention_map.max() - attention_map.min())
9
- heatmap = np.uint8(255 * attention_map)
10
- heatmap = Image.fromarray(heatmap).resize(image.size)
11
- heatmap = np.array(heatmap)
12
- heatmap = np.stack([heatmap, np.zeros_like(heatmap), np.zeros_like(heatmap)], axis=-1)
13
-
14
- image_array = np.array(image)
15
- overlay = Image.fromarray(np.uint8(image_array * (1 - alpha) + heatmap * alpha))
16
- return overlay
17
 
18
  @st.cache_resource
19
  def load_model():
20
- processor = AutoImageProcessor.from_pretrained("mrm8488/vit-base-patch16-224_finetuned-pneumothorax")
21
- model = AutoModelForImageClassification.from_pretrained("mrm8488/vit-base-patch16-224_finetuned-pneumothorax")
22
- return processor, model
23
 
24
  def main():
25
  st.title("Détection de Pneumothorax")
26
 
27
- processor, model = load_model()
28
 
29
  uploaded_file = st.file_uploader("Télécharger une radiographie", type=["jpg", "jpeg", "png"])
30
 
31
  if uploaded_file:
32
  image = Image.open(uploaded_file).convert('RGB')
33
  resized_image = image.resize((224, 224))
34
- st.image(resized_image, width=400)
 
 
 
35
 
36
  if st.button("Analyser"):
37
  with st.spinner("Analyse en cours..."):
38
- inputs = processor(images=resized_image, return_tensors="pt")
39
- outputs = model(**inputs)
40
- probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
41
-
42
- # Obtenir attentions des dernières couches
43
- attention = outputs.hidden_states[-1].mean(1)[0].detach().numpy()
44
- attention_map = attention.reshape(14, 14) # ViT patch size
45
-
46
- # Créer overlay
47
- overlay = create_overlay(resized_image, attention_map)
48
-
49
- col1, col2 = st.columns(2)
50
- with col1:
51
- st.write("Résultat:", model.config.id2label[outputs.logits.argmax(-1).item()])
52
- st.write(f"Confiance: {probs.max().item():.2%}")
53
-
54
  with col2:
55
- st.image(overlay, caption="Zones suspectes", width=400)
 
 
56
 
57
  if __name__ == "__main__":
58
  main()
 
1
  import streamlit as st
2
+ from transformers import pipeline
 
3
  from PIL import Image
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
  @st.cache_resource
6
  def load_model():
7
+ return pipeline("image-classification", model="mrm8488/vit-base-patch16-224_finetuned-pneumothorax", top_k=2)
 
 
8
 
9
  def main():
10
  st.title("Détection de Pneumothorax")
11
 
12
+ model = load_model()
13
 
14
  uploaded_file = st.file_uploader("Télécharger une radiographie", type=["jpg", "jpeg", "png"])
15
 
16
  if uploaded_file:
17
  image = Image.open(uploaded_file).convert('RGB')
18
  resized_image = image.resize((224, 224))
19
+
20
+ col1, col2 = st.columns(2)
21
+ with col1:
22
+ st.image(resized_image, width=300, caption="Image originale")
23
 
24
  if st.button("Analyser"):
25
  with st.spinner("Analyse en cours..."):
26
+ results = model(resized_image)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  with col2:
28
+ st.write("Résultats:")
29
+ for result in results:
30
+ st.write(f"{result['label']}: {result['score']:.2%}")
31
 
32
  if __name__ == "__main__":
33
  main()