import streamlit as st import cv2 import numpy as np import os import json import joblib import pickle from typing import Dict, Any from sentence_transformers import SentenceTransformer, CrossEncoder from langdetect import detect from tensorflow.keras.models import Model, load_model from tensorflow.keras.applications.mobilenet_v2 import preprocess_input as mobilenet_preprocess from vit_keras.layers import ClassToken, AddPositionEmbs, TransformerBlock # ================== CACHING ================== @st.cache_resource def load_all_models(): cnn_model = load_model("Main_py/banana_cnn_model.keras", compile=False) vit_model = load_model( "Main_py/banana_vit_model.keras", compile=False, custom_objects={ 'ClassToken': ClassToken, 'AddPositionEmbs': AddPositionEmbs, 'TransformerBlock': TransformerBlock } ) cnn_feat_ext = Model(inputs=cnn_model.input, outputs=cnn_model.get_layer(index=-4).output) vit_feat_ext = Model(inputs=vit_model.input, outputs=vit_model.get_layer(index=-4).output) return cnn_model, vit_model, cnn_feat_ext, vit_feat_ext @st.cache_resource def load_all_assets(): scaler = joblib.load("Main_py/feature_scaler.pkl") mlp_model = joblib.load("Main_py/lightgbm_model.pkl") outlier_detector = joblib.load("Main_py/isolation_forest.pkl") with open("Main_py/label_encoder.pkl", "rb") as f: le = pickle.load(f) with open("Main_py/banana_disease_knowledge_base_DL.json", "r", encoding="utf-8") as f: kb_data_image = {entry["Disease"]: entry for entry in json.load(f)} with open("Main_py/banana_disease_knowledge_base.json", "r", encoding="utf-8") as f: kb_data_text = json.load(f) return scaler, mlp_model, le, kb_data_image, kb_data_text, outlier_detector @st.cache_resource def load_nlp_models(): embedder = SentenceTransformer('sentence-transformers/paraphrase-xlm-r-multilingual-v1') cross_encoder = CrossEncoder('cross-encoder/mmarco-mMiniLMv2-L12-H384-v1') return embedder, cross_encoder # ================== IMAGE DIAGNOSIS ================== def identify_disease_from_image(image_path): try: img = cv2.imread(image_path) if img is None: raise ValueError("Image not loaded.") img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) img_resized = cv2.resize(img_rgb, (224, 224)) cnn_input = np.expand_dims(img_resized / 255.0, axis=0) vit_input = np.expand_dims(mobilenet_preprocess(img_resized), axis=0) cnn_feat = cnn_feature_extractor.predict(cnn_input, verbose=0) vit_feat = vit_feature_extractor.predict(vit_input, verbose=0) combined_feat = np.concatenate([cnn_feat, vit_feat], axis=1) combined_scaled = scaler.transform(combined_feat) y_pred = mlp_model.predict(combined_scaled) predicted_label = le.inverse_transform(y_pred)[0] confidence = None try: probs = mlp_model.predict_proba(combined_scaled) confidence = np.max(probs) except: probs = None st.image(img_rgb, caption="Uploaded Image", use_column_width=True) st.write(f"**Predicted Disease**: {predicted_label} ({confidence:.2f} confidence)" if confidence else predicted_label) result = { "predicted_disease": predicted_label, "confidence": confidence, "info_available": False, "all_probabilities": probs[0].tolist() if probs is not None else None } normalized_pred = predicted_label.lower().replace(" ", "") for disease in kb_data_image: if normalized_pred in disease.lower().replace(" ", ""): matched = kb_data_image[disease] result["info_available"] = True st.subheader("Image-Based Prediction (Marathi)") st.write(f"**रोग**: {matched['Local_Name']['mr']}") st.write(f"**लक्षणे**: {matched['Symptoms_MR']}") st.write(f"**कारण**: {matched['Cause_MR']}") st.write(f"**कीटकनाशक शिफारस**: {matched['Pesticide_Recommendation_MR']}") st.write(f"**कीटकनाशक**: {matched['Pesticide']}") st.write(f"**परजीवी**: {matched['Pathogen']}") st.write(f"**व्यवस्थापन उपाय**: {matched['Management_Practices_MR']}") break else: st.warning("❌ Disease not found in knowledge base.") return result except Exception as e: st.error(f"Error: {e}") return {"error": str(e), "predicted_disease": None} # ================== TEXT DIAGNOSIS ================== def detect_language(query: str) -> str: try: lang = detect(query) return lang if lang in ["mr", "hi"] else "en" except: return "en" def predict_disease(query: str) -> Dict[str, Any]: lang = detect_language(query) query_emb = embedder.encode([query], normalize_embeddings=True) symptom_key = f"Symptoms_{lang.upper()}" if lang != "en" else "Symptoms" pairs = [[query, entry[symptom_key]] for entry in kb_data_text] scores = cross_encoder.predict(pairs) best_idx = np.argmax(scores) if scores[best_idx] < 0.2: return { "message": { "mr": "हा रोग आमच्या डेटाबेसमध्ये नाही.", "hi": "यह रोग हमारे डेटाबेस में नहीं है।", "en": "This disease is not in our database." }[lang] } entry = kb_data_text[best_idx] return { "Crop": entry["Crop"], "Disease": entry["Local_Name"][lang], "Symptoms": entry[symptom_key], "Cause": entry.get(f"Cause_{lang.upper()}", entry["Cause"]), "Pesticide_Recommendation": entry.get(f"Pesticide_Recommendation_{lang.upper()}", entry["Pesticide_Recommendation"]), "Pesticide": entry["Pesticide"], "Pathogen": entry["Pathogen"], "Management_Practices": entry.get(f"Management_{lang.upper()}", entry["Management_Practices"]) } # ================== UI ================== st.set_page_config(page_title="Banana Disease Detection", layout="centered") st.title("🍌 Banana Disease Detection App") st.write("Detect banana crop diseases using image or symptom query in Marathi, Hindi, or English.") option = st.radio("Choose input method:", ("Image Only", "Text Only", "Both")) # Load all models & assets once cnn_model, vit_model, cnn_feature_extractor, vit_feature_extractor = load_all_models() scaler, mlp_model, le, kb_data_image, kb_data_text, outlier_detector = load_all_assets() embedder, cross_encoder = load_nlp_models() # Image input if option in ["Image Only", "Both"]: st.subheader("📷 Upload Banana Leaf Image") uploaded_image = st.file_uploader("Upload Image", type=["jpg", "jpeg", "png"]) if uploaded_image: temp_path = "temp_image.jpg" with open(temp_path, "wb") as f: f.write(uploaded_image.getbuffer()) identify_disease_from_image(temp_path) os.remove(temp_path) # Text input if option in ["Text Only", "Both"]: st.subheader("📝 Enter Symptoms") symptoms = st.text_area("Describe the symptoms (Marathi, Hindi, or English):") if symptoms and st.button("Predict Disease from Text"): result = predict_disease(symptoms) if "message" in result: st.warning(result["message"]) else: st.subheader("Text-Based Prediction") st.write(f"**Crop**: {result['Crop']}") st.write(f"**Disease**: {result['Disease']}") st.write(f"**Symptoms**: {result['Symptoms']}") st.write(f"**Cause**: {result['Cause']}") st.write(f"**Pesticide Recommendation**: {result['Pesticide_Recommendation']}") st.write(f"**Pesticide**: {result['Pesticide']}") st.write(f"**Pathogen**: {result['Pathogen']}") st.write(f"**Management Practices**: {result['Management_Practices']}")