# Header captured from the file-viewer page this source was copied from (not code):
# author: Sobit · commit 5610138 "Update app.py" (verified) · size 3.96 kB
import streamlit as st
import numpy as np
import cv2
from PIL import Image
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from io import BytesIO
import os
import google.generativeai as genai
from vit_model import VisionTransformer
# Classifier output labels, in class-index order. Entry i is the label for
# output logit i of the ViT (vit_inference indexes this list with the argmax),
# and len(CLASS_NAMES) sets the model's num_classes, so the order must match
# the label order used when custom_vit.pth was trained.
CLASS_NAMES = [
    'Pepper__bell___Bacterial_spot',
    'Pepper__bell___healthy',
    'Potato___Early_blight',
    'Potato___Late_blight',
    'Potato___healthy',
    'Tomato_Bacterial_spot',
    'Tomato_Early_blight',
    'Tomato_Late_blight',
    'Tomato_Leaf_Mold',
    'Tomato_Septoria_leaf_spot',
    'Tomato_Spider_mites_Two_spotted_spider_mite',
    'Tomato__Target_Spot',
    'Tomato__Tomato_YellowLeaf__Curl_Virus',
    'Tomato__Tomato_mosaic_virus',
    'Tomato_healthy',
]
# Configure the Google Gemini client from the GOOGLE_API_KEY environment
# variable. The value is None when the variable is unset; in that case API
# calls will fail later and generate_gemini_response reports the error.
GEMINI_API_KEY = os.environ.get("GOOGLE_API_KEY")
genai.configure(api_key=GEMINI_API_KEY)
# Load the model
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


@st.cache_resource
def _load_vit_model():
    """Build the ViT classifier and load its trained weights.

    Streamlit re-executes this whole script on every user interaction.
    Without caching, the network would be rebuilt and ``custom_vit.pth``
    re-read from disk on every rerun. ``st.cache_resource`` keeps a single
    loaded instance that is reused across reruns and sessions.

    Returns:
        The ``VisionTransformer`` moved to ``DEVICE`` and set to eval mode.
    """
    vit = VisionTransformer(
        img_size=128,
        patch_size=8,
        num_classes=len(CLASS_NAMES),
        embed_dim=768,
        depth=8,
        num_heads=12,
        mlp_dim=2048,
        dropout=0.1,
    )
    # map_location lets a checkpoint saved on GPU load on a CPU-only host.
    vit.load_state_dict(torch.load("custom_vit.pth", map_location=DEVICE))
    vit.to(DEVICE)
    vit.eval()  # disable dropout for inference
    return vit


model = _load_vit_model()
# Function to preprocess images for ViT
def preprocess_image(image, target_size=(128, 128)):
    """Turn a PIL image into a normalized single-image batch for the ViT.

    Args:
        image: A ``PIL.Image.Image``. It is converted to RGB first, so
            grayscale, palette or RGBA uploads still yield 3 channels.
        target_size: ``(height, width)`` passed to ``transforms.Resize``.
            The default matches the ``img_size=128`` the model is built with.

    Returns:
        A float tensor of shape ``(1, 3, H, W)`` on ``DEVICE``, with pixel
        values mapped from [0, 1] to [-1, 1].
    """
    transform = transforms.Compose([
        transforms.Resize(target_size),
        transforms.ToTensor(),  # HWC uint8 [0, 255] -> CHW float [0, 1]
        # A one-element mean/std broadcasts over every channel:
        # (x - 0.5) / 0.5 maps [0, 1] to [-1, 1]. Must match training.
        transforms.Normalize([0.5], [0.5]),
    ])
    # convert("RGB") is a plain copy for RGB input, so results for the
    # existing caller are unchanged.
    return transform(image.convert("RGB")).unsqueeze(0).to(DEVICE)
# Function for inference using ViT
def vit_inference(image):
    """Classify a plant-leaf image with the custom ViT model.

    Args:
        image: PIL image to classify; it is preprocessed by
            ``preprocess_image``.

    Returns:
        ``(label, confidence)``: the ``CLASS_NAMES`` entry with the highest
        softmax score, and that score as a Python float.
    """
    batch = preprocess_image(image)
    with torch.no_grad():
        # Take the first (only) item of the batch; presumably a 1-D vector of
        # per-class logits — softmax below is taken over dim 0.
        logits = model(batch)[0]
    scores = torch.nn.functional.softmax(logits, dim=0)
    best = torch.argmax(scores).item()
    confidence = scores[best].item()
    return CLASS_NAMES[best], confidence
# Generate response from Gemini AI
def generate_gemini_response(disease_name):
    """Ask Gemini for a structured diagnosis of a classifier prediction.

    Args:
        disease_name: Label predicted by the ViT (an entry of
            ``CLASS_NAMES``, e.g. ``"Tomato_Early_blight"``).

    Returns:
        The generated diagnosis text. Failures are returned as a human-readable
        message instead of being raised, so the Streamlit page keeps rendering:
        API/transport errors, an empty response, and a response with no text
        part (for example, one blocked by safety filters).
    """
    label = str(disease_name)
    # Labels are underscore-joined identifiers; give the LLM a readable name
    # and keep the raw label for reference.
    readable_name = " ".join(label.replace("_", " ").split())
    prompt = (
        "You are an expert plant pathologist. "
        f"The detected crop disease is: {readable_name} "
        f"(classifier label: {label}).\n"
        "Provide a structured analysis including:\n"
        "- Pathogen details\n"
        "- Severity level\n"
        "- Symptoms\n"
        "- Economic impact\n"
        "- Treatment options (short-term and long-term)\n"
        "- Prevention strategies\n"
        # Several labels (e.g. Tomato_healthy) mean "no disease detected".
        "If the label indicates a healthy plant, say so and give general "
        "care and prevention advice instead of a disease analysis.\n"
    )
    try:
        # Named gemini_model so it does not shadow the module-level ViT `model`.
        gemini_model = genai.GenerativeModel("gemini-1.5-pro")
        response = gemini_model.generate_content(prompt)
    except Exception as e:  # UI boundary: report any API failure as text
        return f"Error connecting to Gemini API: {e}"
    if not response:
        return "No response from Gemini."
    try:
        return response.text
    except ValueError:
        # The `.text` accessor raises ValueError when the response has no
        # text part (e.g. the candidate was blocked). This is not a
        # connection problem, so report it separately.
        return "Gemini returned no text for this request (the response may have been blocked)."
# Initialize Streamlit app
st.title("🌱 AI-Powered Crop Disease Detection")

# Upload image
uploaded_file = st.file_uploader("📤 Upload a plant image", type=["jpg", "jpeg", "png"])

if uploaded_file:
    try:
        image = Image.open(uploaded_file).convert("RGB")
    except OSError:
        # Pillow raises UnidentifiedImageError (an OSError subclass) for
        # non-image data and OSError for truncated/corrupt files; show a
        # message instead of crashing the page.
        st.error("Could not read the uploaded file as an image. Please upload a valid JPG or PNG.")
    else:
        # NOTE(review): newer Streamlit releases deprecate use_column_width;
        # switch to the replacement once the pinned Streamlit version supports it.
        st.image(image, caption="Uploaded Image", use_column_width=True)

        # Run ViT inference
        with st.spinner("Analyzing with Vision Transformer... 🔍"):
            predicted_class, confidence = vit_inference(image)
        st.write(f"✅ **Detected Disease:** {predicted_class} (Confidence: {confidence:.2f})")

        # Connect to Gemini AI for diagnosis
        with st.spinner("Generating diagnosis with Gemini AI... 💡"):
            diagnosis = generate_gemini_response(predicted_class)
        st.subheader("📋 AI Diagnosis")
        st.write(diagnosis)

# Instructions for users (only describe features this page actually offers)
st.markdown("""
---
### How to Use:
1. Upload an image of a plant leaf with suspected disease (JPG or PNG).
2. The system detects the disease using a Vision Transformer model.
3. Gemini generates a diagnosis covering the pathogen, symptoms, treatments, and prevention.
""")