EF / app.py
Junior16's picture
Update app.py
15e4aa2 verified
raw
history blame
2.68 kB
from fastapi import FastAPI, File, UploadFile, HTTPException
import cv2
import numpy as np
from PIL import Image
import io
import base64
from transformers import AutoModelForImageClassification, AutoFeatureExtractor
import torch
app = FastAPI()
# Cargar el modelo preentrenado para clasificación de género
model_name = "nateraw/bert-imagenet" # Cambiar a un modelo adecuado si se encuentra uno más específico
model = AutoModelForImageClassification.from_pretrained(model_name)
feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
@app.post("/detect/")
async def detect_face(file: UploadFile = File(...)):
try:
# Leer y preparar la imagen
image_bytes = await file.read()
image = Image.open(io.BytesIO(image_bytes))
img_np = np.array(image)
if img_np.shape[2] == 4:
img_np = cv2.cvtColor(img_np, cv2.COLOR_BGRA2BGR)
# Detección de rostros con OpenCV
face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + 'haarcascade_frontalface_default.xml')
gray = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY)
faces = face_cascade.detectMultiScale(gray, scaleFactor=1.1, minNeighbors=5, minSize=(30, 30))
if len(faces) == 0:
raise HTTPException(status_code=404, detail="No se detectaron rostros en la imagen.")
result_data = []
for (x, y, w, h) in faces:
# Extraer cada rostro
face_img = img_np[y:y+h, x:x+w]
face_img_pil = Image.fromarray(cv2.cvtColor(face_img, cv2.COLOR_BGR2RGB)).resize((224, 224))
# Clasificar el rostro
inputs = feature_extractor(images=face_img_pil, return_tensors="pt")
outputs = model(**inputs)
predicted_class = outputs.logits.argmax(dim=-1).item()
label = model.config.id2label[predicted_class]
# Dibujar rectángulo y agregar datos
cv2.rectangle(img_np, (x, y), (x+w, y+h), (255, 0, 0), 2)
result_data.append({
"coordenadas": [int(x), int(y), int(w), int(h)],
"sexo": label
})
# Preparar imagen resultante
result_image = Image.fromarray(cv2.cvtColor(img_np, cv2.COLOR_BGR2RGB))
img_byte_arr = io.BytesIO()
result_image.save(img_byte_arr, format='JPEG')
img_byte_arr = img_byte_arr.getvalue()
# Respuesta
return {
"message": "Rostros detectados",
"resultados": result_data,
"imagen_base64": base64.b64encode(img_byte_arr).decode('utf-8')
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))