|
import streamlit as st |
|
import numpy as np |
|
import cv2 |
|
from PIL import Image |
|
import torch |
|
import torch.nn as nn |
|
import torchvision.transforms as transforms |
|
from io import BytesIO |
|
import os |
|
import google.generativeai as genai |
|
from vit_model import VisionTransformer |
|
|
|
|
|
# Maps the ViT's output index to a human-readable label: vit_inference() uses
# the argmax index straight into this list, so the order must match the class
# order used when models/custom_vit.pth was trained (confirm against the
# training script before editing).
CLASS_NAMES = [
    # Bell pepper
    "Pepper__bell___Bacterial_spot",
    "Pepper__bell___healthy",
    # Potato
    "Potato___Early_blight",
    "Potato___Late_blight",
    "Potato___healthy",
    # Tomato
    "Tomato_Bacterial_spot",
    "Tomato_Early_blight",
    "Tomato_Late_blight",
    "Tomato_Leaf_Mold",
    "Tomato_Septoria_leaf_spot",
    "Tomato_Spider_mites_Two_spotted_spider_mite",
    "Tomato__Target_Spot",
    "Tomato__Tomato_YellowLeaf__Curl_Virus",
    "Tomato__Tomato_mosaic_virus",
    "Tomato_healthy",
]
|
|
|
|
|
# The Gemini API key is taken from the GOOGLE_API_KEY environment variable and
# handed to the SDK once at import time.
# NOTE(review): when the variable is unset this passes api_key=None; presumably
# the failure then only surfaces when generate_gemini_response() calls the
# API (where it is turned into an error string) — confirm.
GEMINI_API_KEY = os.environ.get("GOOGLE_API_KEY")
genai.configure(api_key=GEMINI_API_KEY)
|
|
|
|
|
# Run on the GPU when one is visible to PyTorch, otherwise on the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


@st.cache_resource
def _load_vit_model(checkpoint_path="models/custom_vit.pth"):
    """Build the custom ViT, load its weights and return it in eval mode on DEVICE.

    Streamlit re-executes this whole script on every widget interaction;
    ``st.cache_resource`` makes the architecture build and checkpoint read
    happen once per process instead of on every rerun.

    Args:
        checkpoint_path: Path to a ``state_dict`` saved with ``torch.save``.

    Returns:
        The ``VisionTransformer`` instance, moved to DEVICE and set to eval mode.
    """
    vit = VisionTransformer(
        img_size=128,
        patch_size=8,
        num_classes=len(CLASS_NAMES),
        embed_dim=768,
        depth=8,
        num_heads=12,
        mlp_dim=2048,
        dropout=0.1,
    )
    # weights_only=True restricts unpickling to tensors and plain containers,
    # which is all a state_dict needs, so a tampered .pth cannot run code on load.
    state_dict = torch.load(checkpoint_path, map_location=DEVICE, weights_only=True)
    vit.load_state_dict(state_dict)
    vit.to(DEVICE)
    vit.eval()  # disable dropout for inference
    return vit


# Module-level name kept for the rest of the script (vit_inference uses it).
model = _load_vit_model()
|
|
|
|
|
def preprocess_image(image, target_size=(128, 128)):
    """Convert a PIL image into a normalized single-image batch on DEVICE.

    Args:
        image: PIL image. Modes other than "RGB" (e.g. "L", "RGBA", "P") are
            converted to RGB first so the tensor always has 3 channels.
        target_size: ``(height, width)`` passed to ``transforms.Resize``; the
            default matches the ViT's ``img_size=128``.

    Returns:
        Float tensor of shape ``(1, 3, H, W)`` with values in ``[-1, 1]``,
        already moved to DEVICE.
    """
    # Grayscale, RGBA or palette uploads would otherwise produce 1 or 4
    # channels; the ViT is presumably trained on 3-channel RGB input.
    if getattr(image, "mode", "RGB") != "RGB":
        image = image.convert("RGB")

    transform = transforms.Compose([
        transforms.Resize(target_size),
        transforms.ToTensor(),  # HWC uint8 -> CHW float in [0, 1]
        # Single-element mean/std broadcast to every channel: (x - 0.5) / 0.5.
        # NOTE(review): presumably the same normalization used in training.
        transforms.Normalize([0.5], [0.5]),
    ])
    return transform(image).unsqueeze(0).to(DEVICE)
|
|
|
|
|
def vit_inference(image):
    """Classify a plant image with the custom ViT model.

    Args:
        image: Image accepted by ``preprocess_image`` (a PIL image in this app).

    Returns:
        Tuple ``(label, probability)``: the highest-scoring entry of
        CLASS_NAMES and its softmax probability as a Python float.
    """
    batch = preprocess_image(image)
    with torch.no_grad():
        logits = model(batch)

    # preprocess_image builds a batch of one, so only row 0 is scored.
    class_probs = torch.nn.functional.softmax(logits[0], dim=0)
    best_idx = torch.argmax(class_probs).item()
    best_prob = class_probs[best_idx].item()
    return CLASS_NAMES[best_idx], best_prob
|
|
|
|
|
def generate_gemini_response(disease_name, user_context, model_name="gemini-1.5-pro"):
    """Generate a structured diagnosis using the Gemini API.

    Args:
        disease_name: Label predicted by the ViT (an entry of CLASS_NAMES).
        user_context: Free-text observations from the user. Empty or None
            is sent to the model as "None provided.".
        model_name: Gemini model identifier passed to
            ``genai.GenerativeModel``.

    Returns:
        The model's answer text. Failures are reported as a readable string
        instead of being raised, so the Streamlit page keeps rendering:
        - request/SDK errors -> "Error connecting to Gemini API: ...";
        - a response with no usable text (e.g. blocked) -> a notice that
          includes the prompt feedback;
        - an empty answer -> "No response from Gemini.".
    """
    context = "" if user_context is None else str(user_context).strip()
    context = context or "None provided."

    prompt = "\n".join([
        f"You are an expert plant pathologist. The detected crop disease is: {disease_name}.",
        "",
        f"User-provided context: {context}",
        "",
        "Provide a structured analysis including:",
        "- Pathogen details",
        "- Severity level",
        "- Symptoms",
        "- Economic impact",
        "- Treatment options (short-term and long-term)",
        "- Prevention strategies",
    ])

    try:
        # Named gemini_model so it does not shadow the module-level ViT `model`.
        gemini_model = genai.GenerativeModel(model_name)
        response = gemini_model.generate_content(prompt)
    except Exception as e:  # API boundary: network, auth and quota errors
        return f"Error connecting to Gemini API: {str(e)}"

    try:
        text = response.text
    except ValueError:
        # The SDK's `.text` accessor raises ValueError when the response holds
        # no valid text part (e.g. the prompt or answer was safety-blocked);
        # that is not a connection problem, so report it separately.
        feedback = getattr(response, "prompt_feedback", None)
        return f"Gemini returned no usable answer (prompt feedback: {feedback})."

    return text or "No response from Gemini."
|
|
|
|
|
st.title("AI-Powered Crop Disease Detection with User Context")

uploaded_file = st.file_uploader("📤 Upload a plant image", type=["jpg", "jpeg", "png"])

# The hint goes in `placeholder`, not in the positional `value` argument: a
# default value is returned by the widget verbatim, so the hint text used to be
# forwarded to Gemini as if the user had typed it.
user_context = st.text_area(
    "📝 Additional Context (Optional)",
    placeholder="Describe symptoms, location, or any observations...",
)

image = None
if uploaded_file:
    try:
        image = Image.open(uploaded_file).convert("RGB")
    except OSError as err:
        # PIL.UnidentifiedImageError (not an image) and truncated/corrupt
        # files both derive from OSError; report them instead of crashing.
        st.error(f"Could not read the uploaded file as an image: {err}")

if image is not None:
    st.image(image, caption="Uploaded Image", use_container_width=True)

    with st.spinner("Analyzing with Vision Transformer... 🔍"):
        predicted_class, confidence = vit_inference(image)

    st.write(f"**Detected Disease:** {predicted_class} (Confidence: {confidence:.2f})")

    with st.spinner("Generating diagnosis with Gemini AI... 💡"):
        diagnosis = generate_gemini_response(predicted_class, user_context.strip())

    st.subheader("📋 AI Diagnosis")
    st.write(diagnosis)

st.markdown("""
---
### How to Use:
1. Upload an image of a plant leaf with suspected disease.
2. Provide context (optional) about symptoms or concerns.
3. The system detects the disease using AI.
4. Gemini generates a diagnosis with symptoms and treatments.
""")
|
|