t3_mejorado / app.py
Junior16's picture
Update app.py
1f75b80 verified
raw
history blame
1.54 kB
from fastapi import FastAPI, File, UploadFile
import numpy as np
from PIL import Image
import io
import cv2
from datasets import load_dataset
app = FastAPI()
# Cargar el PlantVillage dataset predefinido
dataset = load_dataset("plant-village")
# Aqu铆 puedes entrenar un modelo o usar uno preentrenado. Para simplificar, vamos a usar el dataset solo para mostrar ejemplos.
# Por ejemplo, podemos ver algunas im谩genes del dataset, pero en producci贸n deber铆as tener un modelo entrenado.
train_data = dataset["train"]
@app.post("/detect_disease/")
async def detect_disease(file: UploadFile = File(...)):
# Leer imagen cargada
image_bytes = await file.read()
image = Image.open(io.BytesIO(image_bytes))
img_np = np.array(image)
# Convertir la imagen a escala de grises
gray = cv2.cvtColor(img_np, cv2.COLOR_RGB2GRAY)
edges = cv2.Canny(gray, 100, 200)
# Aqu铆 realizar铆as la predicci贸n usando tu modelo, en lugar de solo mostrar bordes.
# Simularemos un diagn贸stico simple usando el promedio de los bordes para ilustrar la idea.
disease_detected = "Enfermedad detectada" if np.mean(edges) > 50 else "Saludable"
# Visualizaci贸n de la primera imagen del dataset de ejemplo
example_image = train_data[0]['image'] # Imagen del dataset para ejemplo
example_label = train_data[0]['label'] # Etiqueta de la enfermedad de la imagen de ejemplo
return {
"diagnosis": disease_detected,
"example_image": example_image,
"example_label": example_label
}