# Author: Sobit
# Commit: "Update app.py" (eaaa174, verified)
import streamlit as st
import numpy as np
import cv2
from PIL import Image
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from io import BytesIO
import os
import google.generativeai as genai
from vit_model import VisionTransformer
# Output labels of the ViT classifier: index i of the model's logits maps to
# CLASS_NAMES[i]. The labels are hard-coded here (not loaded from disk) and are
# kept in sorted (ASCII) order — presumably matching the class-index order used
# when the model was trained (e.g. a directory-sorted dataset); confirm before
# reordering or adding entries.
CLASS_NAMES = [
    'Pepper__bell___Bacterial_spot',
    'Pepper__bell___healthy',
    'Potato___Early_blight',
    'Potato___Late_blight',
    'Potato___healthy',
    'Tomato_Bacterial_spot',
    'Tomato_Early_blight',
    'Tomato_Late_blight',
    'Tomato_Leaf_Mold',
    'Tomato_Septoria_leaf_spot',
    'Tomato_Spider_mites_Two_spotted_spider_mite',
    'Tomato__Target_Spot',
    'Tomato__Tomato_YellowLeaf__Curl_Virus',
    'Tomato__Tomato_mosaic_virus',
    'Tomato_healthy',
]
# Google Gemini credentials come from the GOOGLE_API_KEY environment variable.
# When the variable is unset, GEMINI_API_KEY is None and that None is passed to
# genai.configure — NOTE(review): confirm how the SDK reports a missing key.
GEMINI_API_KEY = os.environ.get("GOOGLE_API_KEY")
genai.configure(api_key=GEMINI_API_KEY)
# Run inference on the GPU when one is available, otherwise on the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


@st.cache_resource(show_spinner="Loading Vision Transformer model...")
def _load_vit_model():
    """Build the custom ViT, load its trained weights, and prepare it for inference.

    Streamlit re-executes this whole script on every widget interaction.
    Caching the model as a resource means the checkpoint is read from disk and
    the network is constructed once per server process, instead of on every
    rerun.

    Returns:
        The VisionTransformer with weights from ``models/custom_vit.pth``,
        moved to DEVICE and switched to eval mode.
    """
    # These hyper-parameters must match the ones the checkpoint was trained
    # with, otherwise load_state_dict will fail on shape mismatches.
    vit = VisionTransformer(
        img_size=128,
        patch_size=8,
        num_classes=len(CLASS_NAMES),
        embed_dim=768,
        depth=8,
        num_heads=12,
        mlp_dim=2048,
        dropout=0.1,
    )
    # map_location lets a checkpoint saved on GPU load on a CPU-only host.
    state_dict = torch.load("models/custom_vit.pth", map_location=DEVICE)
    vit.load_state_dict(state_dict)
    vit.to(DEVICE)
    # eval() disables dropout so predictions are deterministic.
    vit.eval()
    return vit


# Module-level handle used by vit_inference(); cached across Streamlit reruns.
model = _load_vit_model()
def preprocess_image(image, target_size=(128, 128)):
    """Convert a PIL image into a normalized one-image batch on DEVICE.

    Args:
        image: PIL image to classify (the upload handler converts it to RGB
            before calling the classifier).
        target_size: size passed to ``transforms.Resize``; defaults to the
            128x128 input the ViT was built for.

    Returns:
        A tensor with a leading batch dimension of size 1, moved to DEVICE.
    """
    steps = [
        transforms.Resize(target_size),
        transforms.ToTensor(),
        # A single mean/std value is applied to every channel, mapping
        # ToTensor's [0, 1] range onto [-1, 1].
        # NOTE(review): confirm this matches the normalization used in training.
        transforms.Normalize([0.5], [0.5]),
    ]
    pipeline = transforms.Compose(steps)
    tensor = pipeline(image)
    return tensor.unsqueeze(0).to(DEVICE)
def vit_inference(image):
    """Classify a leaf image with the custom ViT model.

    Args:
        image: PIL image of a plant leaf.

    Returns:
        A ``(label, confidence)`` tuple: the CLASS_NAMES entry with the highest
        softmax probability, and that probability as a Python float.
    """
    batch = preprocess_image(image)
    with torch.no_grad():
        logits = model(batch)
        # The batch holds a single image, so take row 0 and softmax over classes.
        probs = torch.softmax(logits[0], dim=0)
        best = probs.argmax().item()
        score = probs[best].item()
    return CLASS_NAMES[best], score
def generate_gemini_response(disease_name, user_context):
    """Ask Gemini for a structured diagnosis of a detected crop disease.

    Args:
        disease_name: Class label predicted by the ViT (an entry of CLASS_NAMES).
        user_context: Free-text notes from the user; may be empty or None.

    Returns:
        Gemini's text answer, or a human-readable message describing why no
        answer is available. This function never raises: every failure is
        turned into a string so the Streamlit page can display it.
    """
    # Make an empty text box explicit in the prompt instead of a blank line.
    context = (user_context or "").strip() or "None provided"
    prompt = (
        f"You are an expert plant pathologist. The detected crop disease is: {disease_name}.\n"
        f"User-provided context: {context}\n"
        "Provide a structured analysis including:\n"
        "- Pathogen details\n"
        "- Severity level\n"
        "- Symptoms\n"
        "- Economic impact\n"
        "- Treatment options (short-term and long-term)\n"
        "- Prevention strategies\n"
    )
    try:
        # Named `gemini` (not `model`) so it does not shadow the module-level
        # ViT classifier called `model`.
        gemini = genai.GenerativeModel("gemini-1.5-pro")
        response = gemini.generate_content(prompt)
        if not response:
            return "No response from Gemini."
        try:
            return response.text
        except ValueError:
            # The SDK's `.text` accessor raises ValueError when the response
            # carries no usable text part (for example when it was blocked);
            # report that instead of a misleading "connection" error.
            return "Gemini returned no text for this request (the response may have been blocked)."
    except Exception as e:  # network, authentication, quota, bad model name, ...
        return f"Error connecting to Gemini API: {str(e)}"
# --- Streamlit UI -----------------------------------------------------------
# NOTE(review): the emoji in the labels below look mis-encoded (mojibake, e.g.
# "πŸ“€"); confirm the file is saved as UTF-8 and restore the intended emoji.
st.title("AI-Powered Crop Disease Detection with User Context")

uploaded_file = st.file_uploader("πŸ“€ Upload a plant image", type=["jpg", "jpeg", "png"])
# The hint is a `placeholder`, not the default `value`: a default value is
# returned verbatim when the user types nothing, which would send the hint text
# to Gemini as if it were real user context.
user_context = st.text_area(
    "πŸ“ Additional Context (Optional)",
    placeholder="Describe symptoms, location, or any observations...",
)

if uploaded_file:
    try:
        image = Image.open(uploaded_file).convert("RGB")
    except OSError:
        # PIL raises UnidentifiedImageError (an OSError subclass) for files it
        # cannot decode, and OSError for truncated or corrupt image data.
        image = None
        st.error("Could not read the uploaded file as an image. Please upload a valid JPG or PNG file.")

    if image is not None:
        st.image(image, caption="Uploaded Image", use_container_width=True)

        # Run ViT inference
        with st.spinner("Analyzing with Vision Transformer... πŸ”"):
            predicted_class, confidence = vit_inference(image)
        st.write(f"**Detected Disease:** {predicted_class} (Confidence: {confidence:.2f})")

        # Connect to Gemini AI for diagnosis; an empty text box means the user
        # supplied no context, so say so explicitly.
        context = (user_context or "").strip() or "None provided"
        with st.spinner("Generating diagnosis with Gemini AI... πŸ’‘"):
            diagnosis = generate_gemini_response(predicted_class, context)
        st.subheader("πŸ“‹ AI Diagnosis")
        st.write(diagnosis)
# Markdown usage guide rendered at the bottom of the page.
_USAGE_INSTRUCTIONS = """
---
### How to Use:
1. Upload an image of a plant leaf with suspected disease.
2. Provide context (optional) about symptoms or concerns.
3. The system detects the disease using AI.
4. Gemini generates a diagnosis with symptoms and treatments.
"""
st.markdown(_USAGE_INSTRUCTIONS)