import os
from datetime import datetime

import cv2
import folium
import numpy as np
import pandas as pd
import requests
import streamlit as st
from geopy.geocoders import Nominatim
from PIL import Image
from streamlit_folium import st_folium
from transformers import (
    AutoFeatureExtractor,
    AutoModelForImageClassification,
    ConvNextConfig,
    pipeline,
)

st.set_page_config(page_title="Skin Cancer Dashboard", layout="wide")

# --- Configuration ---
# Set your Hugging Face token as an environment variable before launching:
#   export HF_TOKEN="YOUR_TOKEN_HERE"
MODEL_NAME = "Anwarkh1/Skin_Cancer-Image_Classification"
LLM_NAME = "google/flan-t5-xl"
HF_TOKEN = os.environ.get("HF_TOKEN")
# Path where you download and unpack the Harvard Dataverse dataset
DATA_DIR = "data/harvard_dataset"
DIARY_CSV = "diary.csv"

# Initialize session state defaults
if 'initialized' not in st.session_state:
    st.session_state['label'] = None
    st.session_state['score'] = None
    st.session_state['mole_id'] = ''
    st.session_state['geo_location'] = ''
    st.session_state['chat_history'] = []
    st.session_state['initialized'] = True

# Initialize geolocator for free geocoding
geolocator = Nominatim(user_agent="skin-dashboard", timeout=10)


@st.cache_resource
def load_image_model(token: str):
    # 1) load the feature extractor from the Hub as usual
    #    (newer transformers releases name this argument `token`)
    extractor = AutoFeatureExtractor.from_pretrained(
        MODEL_NAME,
        use_auth_token=token
    )
    # 2) manually create a ConvNextConfig with the right num_labels / id2label
    config = ConvNextConfig(
        num_labels=2,
        id2label={0: "benign", 1: "malignant"},
        label2id={"benign": 0, "malignant": 1}
    )
    # 3) load the weights with that config override
    model = AutoModelForImageClassification.from_pretrained(
        MODEL_NAME,
        config=config,
        use_auth_token=token
    )
    # 4) build your pipeline
    return pipeline(
        "image-classification",
        model=model,
        feature_extractor=extractor,
        device=0  # or -1 for CPU
    )


@st.cache_resource
def load_llm(token: str):
    # `token` is not used inside this loader; it is kept so both loaders
    # share the same signature.
    return pipeline(
        "text2text-generation",
        model=LLM_NAME,
        device_map="auto",  # or device=0 for single GPU / -1 for CPU
        max_length=10000,
        num_beams=5,
        no_repeat_ngram_size=2,
        early_stopping=True,
    )


classifier = load_image_model(HF_TOKEN) if HF_TOKEN else None
explainer = load_llm(HF_TOKEN) if HF_TOKEN else None

# --- Diary Init ---
if not os.path.exists(DIARY_CSV):
    pd.DataFrame(
        columns=["timestamp", "image_path", "mole_id", "geo_location", "label",
                 "score", "body_location", "prior_consultation", "pain", "itch"]
    ).to_csv(DIARY_CSV, index=False)


# --- Save entry helper ---
def save_entry(img_path: str, mole_id: str, geo_location: str, label: str,
               score: float, body_location: str, prior_consult: str,
               pain: str, itch: str):
    """Append one scan record to the diary CSV."""
    df = pd.read_csv(DIARY_CSV)
    entry = {
        "timestamp": datetime.now().isoformat(),
        "image_path": img_path,
        "mole_id": mole_id,
        "geo_location": geo_location,
        "label": label,
        "score": float(score),
        "body_location": body_location,
        "prior_consultation": prior_consult,
        "pain": pain,
        "itch": itch
    }
    df.loc[len(df)] = entry
    df.to_csv(DIARY_CSV, index=False)


# --- Preprocessing Functions ---
def remove_hair(img: np.ndarray) -> np.ndarray:
    """Inpaint thin dark structures (hair) in a BGR image."""
    # preprocess() passes a BGR array, so convert with BGR2GRAY (not RGB2GRAY)
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (17, 17))
    blackhat = cv2.morphologyEx(gray, cv2.MORPH_BLACKHAT, kernel)
    _, mask = cv2.threshold(blackhat, 10, 255, cv2.THRESH_BINARY)
    return cv2.inpaint(img, mask, 1, cv2.INPAINT_TELEA)


def preprocess(img: Image.Image, size: int = 224) -> Image.Image:
    """Remove hair, denoise, equalize contrast, and letterbox to size x size."""
    arr = np.array(img)
    bgr = cv2.cvtColor(arr, cv2.COLOR_RGB2BGR)
    bgr = remove_hair(bgr)
    bgr = cv2.bilateralFilter(bgr, d=9, sigmaColor=75, sigmaSpace=75)

    # Apply CLAHE to the lightness channel only
    lab = cv2.cvtColor(bgr, cv2.COLOR_BGR2LAB)
    l, a, b = cv2.split(lab)
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    cl = clahe.apply(l)
    merged = cv2.merge((cl, a, b))
    bgr = cv2.cvtColor(merged, cv2.COLOR_LAB2BGR)

    # Resize so the longer side equals `size`, then pad onto a gray canvas
    h, w = bgr.shape[:2]
    scale = size / max(h, w)
    nh, nw = max(1, int(h * scale)), max(1, int(w * scale))
    resized = cv2.resize(bgr, (nw, nh), interpolation=cv2.INTER_AREA)
    canvas = np.full((size, size, 3), 128, dtype=np.uint8)
    top, left = (size - nh) // 2, (size - nw) // 2
    canvas[top:top + nh, left:left + nw] = resized
    rgb = cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB)
    return Image.fromarray(rgb)


# --- Streamlit layout ---
st.title("🩺 Skin Cancer Recognition Dashboard")
menu = ["Scan Mole", "Chat", "Diary", "Dataset Explorer"]
choice = st.sidebar.selectbox("Navigation", menu)

# --- Scan a Mole ---
if choice == "Scan Mole":
    st.header("🔍 Scan a Mole")
    if not classifier:
        st.error("Missing HF_TOKEN.")
        st.stop()

    upload = st.file_uploader("Upload a skin image", type=["jpg", "jpeg", "png"])
    if not upload:
        st.info("Please upload an image to begin.")
        st.stop()

    raw = Image.open(upload).convert("RGB")
    st.image(raw, caption="Original", use_container_width=True)
    proc = preprocess(raw)
    st.image(proc, caption="Preprocessed", use_container_width=True)

    mole = st.text_input("Mole ID")
    city = st.text_input("Geographic location")
    body = st.selectbox(
        "Body location",
        ["Face", "Scalp", "Neck", "Chest", "Back", "Arm", "Hand", "Leg", "Foot", "Other"]
    )
    prior = st.radio("Prior consult?", ["Yes", "No"], horizontal=True)
    pain = st.radio("Pain?", ["Yes", "No"], horizontal=True)
    itch = st.radio("Itch?", ["Yes", "No"], horizontal=True)

    if st.button("Classify"):
        if not mole or not city:
            st.error("Enter ID and location.")
        else:
            with st.spinner("Analyzing..."):
                out = classifier(proc)
            lbl, scr = out[0]["label"], out[0]["score"]
            save_path = os.path.join("scans", f"{mole}_{datetime.now().timestamp()}.png")
            os.makedirs(os.path.dirname(save_path), exist_ok=True)
            raw.save(save_path)
            save_entry(save_path, mole, city, lbl, scr, body, prior, pain, itch)
            st.session_state.update({
                'label': lbl,
                'score': scr,
                'mole_id': mole,
                'geo_location': city
            })

    if st.session_state['label']:
        st.success(
            f"Prediction: {st.session_state['label']} "
            f"(score {st.session_state['score']:.2f})"
        )
        if explainer:
            with st.spinner("Explaining..."):
                text = explainer(
                    f"Explain {st.session_state['label']} and recommendation."
                )[0]['generated_text']
            st.markdown("### Explanation")
            st.write(text)

        loc = geolocator.geocode(st.session_state['geo_location'])
        if loc:
            m = folium.Map([loc.latitude, loc.longitude], zoom_start=12)
            folium.Marker([loc.latitude, loc.longitude], "You").add_to(m)
            # Clinics and doctors within 5 km of the geocoded location
            query = (
                "[out:json];"
                f"node(around:5000,{loc.latitude},{loc.longitude})"
                '[~"^(amenity|healthcare)$"~"clinic|doctors"];'
                "out;"
            )
            try:
                resp = requests.post(
                    "https://overpass-api.de/api/interpreter",
                    data={"data": query},
                    timeout=30
                )
                resp.raise_for_status()
                elements = resp.json().get('elements', [])
            except (requests.RequestException, ValueError):
                # Network failure or a non-JSON error page from the server
                st.warning("Could not load nearby clinics.")
                elements = []
            for el in elements:
                tags = el.get('tags', {})
                lat = el.get('lat') or el['center']['lat']
                lon = el.get('lon') or el['center']['lon']
                folium.Marker([lat, lon], tags.get('name', 'Clinic')).add_to(m)
            st.markdown("### Nearby Clinics")
            st_folium(m, width=700)

# --- Chat Tab ---
elif choice == "Chat":
    st.header("💬 Follow-Up Chat")
    if not st.session_state['label']:
        st.info("Please perform a scan first in the 'Scan Mole' tab.")
    else:
        lbl = st.session_state['label']
        scr = st.session_state['score']
        mid = st.session_state['mole_id']
        gloc = st.session_state['geo_location']
        st.markdown(
            f"**Context:** prediction for **{mid}** at **{gloc}** is **{lbl}** "
            f"(confidence {scr:.2f})."
        )

        # Handle the new user message first so it shows up in the history
        # rendered below during this same rerun
        user_q = st.chat_input("Ask a follow-up question:", key="chat_input")
        if user_q and explainer:
            st.session_state['chat_history'].append({'role': 'user', 'content': user_q})
            system_p = (
                "You are a dermatology assistant. "
                "Provide concise medical advice without clarifying questions."
            )
            tpl = (
                f"{system_p}\nContext: prediction is {lbl} with confidence {scr:.2f}.\n"
                f"User: {user_q}\nAssistant:"
            )
            with st.spinner("Generating response..."):
                reply = explainer(tpl)[0]['generated_text']
            st.session_state['chat_history'].append({'role': 'assistant', 'content': reply})

        # Display the updated chat history
        for msg in st.session_state['chat_history']:
            prefix = 'You' if msg['role'] == 'user' else 'AI'
            st.markdown(f"**{prefix}:** {msg['content']}")

# --- Diary Page ---
elif choice == "Diary":
    st.header("📖 Skin Cancer Diary")
    # Read mole IDs as strings so numeric-looking IDs sort and display consistently
    df = pd.read_csv(DIARY_CSV, dtype={'mole_id': str})
    df['timestamp'] = pd.to_datetime(df['timestamp'])

    if df.empty:
        st.info("No diary entries yet.")
    else:
        mole_ids = sorted(df['mole_id'].unique())
        sel = st.selectbox("Select Mole to View", ['All'] + mole_ids, key="diary_sel")

        if sel == 'All':
            # Display moles in columns (max 3 per row)
            chunks = [mole_ids[i:i + 3] for i in range(0, len(mole_ids), 3)]
            for group in chunks:
                cols = st.columns(len(group))
                for col, mid in zip(cols, group):
                    with col:
                        st.subheader(mid)
                        entries = df[df['mole_id'] == mid].sort_values('timestamp')
                        # Show image timeline
                        for _, row in entries.iterrows():
                            if os.path.exists(row['image_path']):
                                st.image(
                                    row['image_path'],
                                    width=150,
                                    caption=(
                                        f"{row['timestamp'].strftime('%Y-%m-%d')} - "
                                        f"{row['score']:.2f}"
                                    )
                                )
                        st.write(f"Total scans: {len(entries)}")
        else:
            # Detailed view for a single mole
            entries = df[df['mole_id'] == sel].sort_values('timestamp')
            if entries.empty:
                st.warning(f"No entries for {sel}.")
            else:
                # Score over time
                st.line_chart(entries.set_index('timestamp')['score'])

                st.markdown("#### Image Timeline")
                for _, row in entries.iterrows():
                    if os.path.exists(row['image_path']):
                        st.image(
                            row['image_path'],
                            width=200,
                            caption=(
                                f"{row['timestamp'].strftime('%Y-%m-%d %H:%M')} - "
                                f"Score: {row['score']:.2f}"
                            )
                        )

                st.markdown("#### Details")
                st.dataframe(
                    entries[
                        ['timestamp', 'geo_location', 'label', 'score',
                         'body_location', 'prior_consultation', 'pain', 'itch']
                    ]
                    .rename(columns={
                        'timestamp': 'Time', 'geo_location': 'Location',
                        'label': 'Diagnosis', 'score': 'Confidence',
                        'body_location': 'Body Part',
                        'prior_consultation': 'Prior Consult',
                        'pain': 'Pain', 'itch': 'Itch'
                    })
                    .sort_values('Time', ascending=False)
                )

# --- Dataset Explorer ---
else:
    st.header("📂 Dataset Explorer")
    st.write("Preview images from the Harvard Skin Cancer Dataset")

    if not os.path.isdir(DATA_DIR):
        st.warning(f"Dataset directory not found: {DATA_DIR}")
    else:
        # pick up to 15 image files
        image_files = [
            f for f in os.listdir(DATA_DIR)
            if os.path.isfile(os.path.join(DATA_DIR, f))
            and f.lower().endswith((".jpg", ".jpeg", ".png"))
        ][:15]

        for i in range(0, len(image_files), 3):
            cols = st.columns(3)
            for col, fn in zip(cols, image_files[i:i + 3]):
                path = os.path.join(DATA_DIR, fn)
                img = Image.open(path)
                col.image(img, use_container_width=True)
                col.caption(fn)

st.sidebar.markdown("---")
st.sidebar.write("Dataset powered by Harvard Dataverse [DBW86T]")
st.sidebar.write(f"Model: {MODEL_NAME}")
st.sidebar.write(f"LLM: {LLM_NAME}")