"""Streamlit app: classify a plant-leaf image with a custom ViT, then ask Gemini for a diagnosis."""

import os
from io import BytesIO

import cv2
import google.generativeai as genai
import numpy as np
import streamlit as st
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image

from vit_model import VisionTransformer

# Load class names (should match training labels).
# The list index is used directly as the predicted class id in vit_inference().
CLASS_NAMES = [
    'Pepper__bell___Bacterial_spot',
    'Pepper__bell___healthy',
    'Potato___Early_blight',
    'Potato___Late_blight',
    'Potato___healthy',
    'Tomato_Bacterial_spot',
    'Tomato_Early_blight',
    'Tomato_Late_blight',
    'Tomato_Leaf_Mold',
    'Tomato_Septoria_leaf_spot',
    'Tomato_Spider_mites_Two_spotted_spider_mite',
    'Tomato__Target_Spot',
    'Tomato__Tomato_YellowLeaf__Curl_Virus',
    'Tomato__Tomato_mosaic_virus',
    'Tomato_healthy',
]  # Update with actual class names

# Configure Google Gemini API (key is read from the GOOGLE_API_KEY env var).
GEMINI_API_KEY = os.getenv("GOOGLE_API_KEY")
genai.configure(api_key=GEMINI_API_KEY)

# Run on GPU when one is available, otherwise on CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


@st.cache_resource
def _load_vit_model():
    """Build the ViT, load its weights from ``models/custom_vit.pth`` and return it in eval mode.

    Streamlit re-executes this script on every user interaction; caching the
    loaded model as a resource avoids rebuilding the network and re-reading
    the checkpoint from disk on each rerun.
    """
    vit = VisionTransformer(
        img_size=128,
        patch_size=8,
        num_classes=len(CLASS_NAMES),
        embed_dim=768,
        depth=8,
        num_heads=12,
        mlp_dim=2048,
        dropout=0.1,
    )
    vit.load_state_dict(torch.load("models/custom_vit.pth", map_location=DEVICE))
    vit.to(DEVICE)
    vit.eval()
    return vit


# Load the model (module-level name kept so existing references still work).
model = _load_vit_model()


# Function to preprocess images for ViT
def preprocess_image(image, target_size=(128, 128)):
    """Turn a PIL image into a normalized, batched tensor on ``DEVICE``.

    Args:
        image: PIL image (the app passes an RGB image).
        target_size: ``(height, width)`` handed to ``transforms.Resize``.

    Returns:
        The transformed image with a leading batch dimension added via
        ``unsqueeze(0)``, moved to ``DEVICE``.
    """
    transform = transforms.Compose([
        transforms.Resize(target_size),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),  # Normalize to match training
    ])
    return transform(image).unsqueeze(0).to(DEVICE)


# Function for inference using ViT
def vit_inference(image):
    """Predicts crop disease using the custom ViT model.

    Args:
        image: PIL image to classify.

    Returns:
        A ``(class_name, probability)`` tuple: the entry of ``CLASS_NAMES``
        with the highest softmax probability, and that probability as a float.
    """
    input_tensor = preprocess_image(image)
    with torch.no_grad():
        output = model(input_tensor)
        # output[0] is the single batch element; softmax over its class scores.
        probabilities = torch.nn.functional.softmax(output[0], dim=0)
        predicted_class = torch.argmax(probabilities).item()
    return CLASS_NAMES[predicted_class], probabilities[predicted_class].item()


# Generate response from Gemini AI
def generate_gemini_response(disease_name):
    """Generate a structured diagnosis using Gemini API.

    Args:
        disease_name: Detected disease label to embed in the prompt.

    Returns:
        The response text from Gemini, ``"No response from Gemini."`` when the
        response object is falsy, or an ``"Error connecting to Gemini API: ..."``
        message when any exception is raised (errors are reported to the user
        as text rather than propagated).
    """
    try:
        # Named distinctly so it does not shadow the module-level ViT ``model``.
        gemini_model = genai.GenerativeModel("gemini-1.5-pro")
        prompt = f"""
        You are an expert plant pathologist. The detected crop disease is: {disease_name}.
        Provide a structured analysis including:
        - Pathogen details
        - Severity level
        - Symptoms
        - Economic impact
        - Treatment options (short-term and long-term)
        - Prevention strategies
        """
        response = gemini_model.generate_content(prompt)
        return response.text if response else "No response from Gemini."
    except Exception as e:
        return f"Error connecting to Gemini API: {str(e)}"


# Initialize Streamlit app
st.title("🌱 AI-Powered Crop Disease Detection")

# Upload image
uploaded_file = st.file_uploader("📤 Upload a plant image", type=["jpg", "jpeg", "png"])

if uploaded_file:
    # Force three channels so PNGs with alpha / grayscale inputs are accepted.
    image = Image.open(uploaded_file).convert("RGB")
    st.image(image, caption="Uploaded Image", use_column_width=True)

    # Run ViT inference
    with st.spinner("Analyzing with Vision Transformer... 🔍"):
        predicted_class, confidence = vit_inference(image)
    st.write(f"✅ **Detected Disease:** {predicted_class} (Confidence: {confidence:.2f})")

    # Connect to Gemini AI for diagnosis
    with st.spinner("Generating diagnosis with Gemini AI... 💡"):
        diagnosis = generate_gemini_response(predicted_class)
    st.subheader("📋 AI Diagnosis")
    st.write(diagnosis)

# Instructions for users
st.markdown("""
---
### How to Use:
1. Upload an image of a plant leaf with suspected disease.
2. Provide context (optional) about symptoms or concerns.
3. The system detects the disease using AI.
4. Gemini generates a diagnosis with symptoms and treatments.
5. Ask follow-up questions, and the AI will remember previous responses.
""")