|
import streamlit as st |
|
from transformers import pipeline |
|
from PIL import Image |
|
import torch |
|
|
|
# Page-level settings: full-width layout and the browser-tab title.
st.set_page_config(layout="wide", page_title="Aide au diagnostic radiologique")
|
|
|
def load_models():
    """Create one image-classification pipeline per supported condition.

    Each pipeline is built with ``transformers.pipeline("image-classification")``
    from a Hugging Face Hub model identifier.

    Returns:
        dict: maps the condition name shown in the UI ("Fracture",
        "Pneumothorax", "Pneumonie") to its pipeline, in that order.
    """
    # UI condition name -> Hugging Face Hub model id.
    checkpoints = {
        'Fracture': "kathleen/vit-base-fracture-detection",
        'Pneumothorax': "nickmuchi/pneumothorax-detection-vit",
        'Pneumonie': "nickmuchi/chest-xray-pneumonia-detection",
    }
    return {
        condition: pipeline("image-classification", model=checkpoint)
        for condition, checkpoint in checkpoints.items()
    }
|
|
|
@st.cache_resource
def get_models():
    """Return the condition -> pipeline mapping built by ``load_models``.

    Wrapped in ``st.cache_resource`` so the pipelines are built once and the
    same objects are reused on every Streamlit rerun instead of being
    reloaded each time the script executes.
    """
    pipelines = load_models()  # the expensive part: builds every pipeline
    return pipelines
|
|
|
def main():
    """Render the radiology-assistant page.

    Shows a file uploader. Once an image is uploaded and successfully decoded,
    it is displayed together with one button per condition. Clicking a button
    runs that condition's image-classification pipeline on the image and
    prints the top predicted label and its confidence.
    """
    st.title("Assistant de diagnostic radiologique")

    models = get_models()

    uploaded_file = st.file_uploader("Télécharger une image radiologique", type=["jpg", "jpeg", "png"])

    # Nothing to analyse until the user provides a file.
    if uploaded_file is None:
        return

    try:
        image = Image.open(uploaded_file)
        # Image.open is lazy: force decoding now so a corrupt or truncated
        # upload is reported cleanly here instead of raising later inside
        # st.image or the model pipeline.
        image.load()
    except OSError:  # also covers PIL.UnidentifiedImageError
        st.error("Impossible de lire ce fichier : image invalide ou corrompue.")
        return

    # NOTE(review): `use_column_width` is deprecated in recent Streamlit
    # releases in favour of `use_container_width`; switch once the pinned
    # Streamlit version supports it.
    st.image(image, caption="Image téléchargée", use_column_width=True)

    # One column per condition; each button runs only its own model.
    conditions = ("Fracture", "Pneumothorax", "Pneumonie")
    for column, condition in zip(st.columns(len(conditions)), conditions):
        with column:
            if st.button(f"Détecter {condition}"):
                with st.spinner("Analyse en cours..."):
                    result = models[condition](image)
                # The image-classification pipeline returns its predictions
                # best-first, so the first entry is the most likely label.
                top = result[0]
                st.write(f"Résultat: {top['label']}")
                st.write(f"Confiance: {top['score']:.2%}")
|
|
|
# Entry point: render the app when this file is executed as the main script.
if __name__ == "__main__":
    main()