"""Streamlit app: crop disease detection with a custom ViT and a Gemini diagnosis."""

import os
from io import BytesIO

import cv2
import google.generativeai as genai
import numpy as np
import streamlit as st
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image

from vit_model import VisionTransformer

# Load class names
CLASS_NAMES = [
    'Pepper__bell___Bacterial_spot',
    'Pepper__bell___healthy',
    'Potato___Early_blight',
    'Potato___Late_blight',
    'Potato___healthy',
    'Tomato_Bacterial_spot',
    'Tomato_Early_blight',
    'Tomato_Late_blight',
    'Tomato_Leaf_Mold',
    'Tomato_Septoria_leaf_spot',
    'Tomato_Spider_mites_Two_spotted_spider_mite',
    'Tomato__Target_Spot',
    'Tomato__Tomato_YellowLeaf__Curl_Virus',
    'Tomato__Tomato_mosaic_virus',
    'Tomato_healthy',
]

# Configure Google Gemini API
GEMINI_API_KEY = os.getenv("GOOGLE_API_KEY")
genai.configure(api_key=GEMINI_API_KEY)

# Load the model
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


@st.cache_resource
def load_vit_model():
    """Build the ViT classifier, load its weights, and put it in eval mode.

    Streamlit re-executes this script on every user interaction; caching the
    loaded model as a resource avoids re-reading the checkpoint each time.

    Returns:
        The ``VisionTransformer`` instance on ``DEVICE``, in eval mode.
    """
    vit = VisionTransformer(
        img_size=128,
        patch_size=8,
        num_classes=len(CLASS_NAMES),
        embed_dim=768,
        depth=8,
        num_heads=12,
        mlp_dim=2048,
        dropout=0.1,
    )
    vit.load_state_dict(torch.load("models/custom_vit.pth", map_location=DEVICE))
    vit.to(DEVICE)
    vit.eval()
    return vit


model = load_vit_model()


# Function to preprocess images for ViT
def preprocess_image(image: Image.Image, target_size=(128, 128)) -> torch.Tensor:
    """Convert a PIL image into a normalized, batched tensor on ``DEVICE``.

    Args:
        image: Input PIL image.
        target_size: ``(height, width)`` the image is resized to.

    Returns:
        The transformed image with a leading batch dimension of size 1.
    """
    transform = transforms.Compose([
        transforms.Resize(target_size),
        transforms.ToTensor(),
        # (x - 0.5) / 0.5 rescales ToTensor's [0, 1] output to [-1, 1].
        transforms.Normalize([0.5], [0.5]),
    ])
    return transform(image).unsqueeze(0).to(DEVICE)


# Function for inference using ViT
def vit_inference(image: Image.Image) -> tuple[str, float]:
    """Predict the crop disease for ``image`` using the custom ViT model.

    Args:
        image: Input PIL image.

    Returns:
        A ``(class_name, confidence)`` pair, where ``confidence`` is the
        softmax probability of the predicted class.
    """
    input_tensor = preprocess_image(image)
    with torch.no_grad():
        output = model(input_tensor)
        # output[0] is the score vector of the single image in the batch.
        probabilities = torch.nn.functional.softmax(output[0], dim=0)
        predicted_class = torch.argmax(probabilities).item()
    return CLASS_NAMES[predicted_class], probabilities[predicted_class].item()


# Generate response from Gemini AI
def generate_gemini_response(disease_name, user_context):
    """Generate a structured diagnosis using the Gemini API.

    Args:
        disease_name: Label of the detected disease.
        user_context: Free-text context supplied by the user; may be empty
            or ``None``.

    Returns:
        The text of Gemini's response, or a human-readable error message if
        the request fails (errors are reported as text, never raised).
    """
    # Tell the model explicitly when the user supplied nothing, rather than
    # leaving a dangling empty field in the prompt.
    context = (user_context or "").strip() or "None provided."
    try:
        # Named gemini_model so it does not shadow the module-level ViT `model`.
        gemini_model = genai.GenerativeModel("gemini-1.5-pro")
        prompt = f"""
        You are an expert plant pathologist. The detected crop disease is: {disease_name}.
        User-provided context: {context}

        Provide a structured analysis including:
        - Pathogen details
        - Severity level
        - Symptoms
        - Economic impact
        - Treatment options (short-term and long-term)
        - Prevention strategies
        """
        response = gemini_model.generate_content(prompt)
        return response.text if response else "No response from Gemini."
    except Exception as e:
        # Top-level UI boundary: surface any API failure as text in the app.
        return f"Error connecting to Gemini API: {str(e)}"


# Initialize Streamlit app
st.title("AI-Powered Crop Disease Detection with User Context")

# Upload image
uploaded_file = st.file_uploader("📤 Upload a plant image", type=["jpg", "jpeg", "png"])
# The hint is a placeholder, not the widget's value, so it is never sent to
# Gemini as if the user had typed it.
user_context = st.text_area(
    "📝 Additional Context (Optional)",
    placeholder="Describe symptoms, location, or any observations...",
)

if uploaded_file:
    try:
        image = Image.open(uploaded_file).convert("RGB")
    except OSError as exc:
        # Pillow reports unreadable or corrupt image data as OSError
        # (UnidentifiedImageError is a subclass).
        st.error(f"Could not read the uploaded image: {exc}")
    else:
        st.image(image, caption="Uploaded Image", use_container_width=True)

        # Run ViT inference
        with st.spinner("Analyzing with Vision Transformer... 🔍"):
            predicted_class, confidence = vit_inference(image)

        st.write(f"**Detected Disease:** {predicted_class} (Confidence: {confidence:.2f})")

        # Connect to Gemini AI for diagnosis
        with st.spinner("Generating diagnosis with Gemini AI... 💡"):
            diagnosis = generate_gemini_response(predicted_class, user_context)

        st.subheader("📋 AI Diagnosis")
        st.write(diagnosis)

# Instructions for users
st.markdown("""
---
### How to Use:
1. Upload an image of a plant leaf with suspected disease.
2. Provide context (optional) about symptoms or concerns.
3. The system detects the disease using AI.
4. Gemini generates a diagnosis with symptoms and treatments.
""")