from fastapi import FastAPI, File, UploadFile, HTTPException import cv2 import numpy as np from PIL import Image import io import base64 from transformers import AutoModelForImageClassification, AutoFeatureExtractor import torch app = FastAPI() # Cargar el modelo preentrenado para clasificación de género model_name = "nateraw/bert-imagenet" # Cambiar a un modelo adecuado si se encuentra uno más específico model = AutoModelForImageClassification.from_pretrained(model_name) feature_extractor = AutoFeatureExtractor.from_pretrained(model_name) @app.post("/detect/") async def detect_face(file: UploadFile = File(...)): try: # Leer y preparar la imagen image_bytes = await file.read() image = Image.open(io.BytesIO(image_bytes)) img_np = np.array(image) if img_np.shape[2] == 4: img_np = cv2.cvtColor(img_np, cv2.COLOR_BGRA2BGR) # Detección de rostros con OpenCV face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + 'haarcascade_frontalface_default.xml') gray = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY) faces = face_cascade.detectMultiScale(gray, scaleFactor=1.1, minNeighbors=5, minSize=(30, 30)) if len(faces) == 0: raise HTTPException(status_code=404, detail="No se detectaron rostros en la imagen.") result_data = [] for (x, y, w, h) in faces: # Extraer cada rostro face_img = img_np[y:y+h, x:x+w] face_img_pil = Image.fromarray(cv2.cvtColor(face_img, cv2.COLOR_BGR2RGB)).resize((224, 224)) # Clasificar el rostro inputs = feature_extractor(images=face_img_pil, return_tensors="pt") outputs = model(**inputs) predicted_class = outputs.logits.argmax(dim=-1).item() label = model.config.id2label[predicted_class] # Dibujar rectángulo y agregar datos cv2.rectangle(img_np, (x, y), (x+w, y+h), (255, 0, 0), 2) result_data.append({ "coordenadas": [int(x), int(y), int(w), int(h)], "sexo": label }) # Preparar imagen resultante result_image = Image.fromarray(cv2.cvtColor(img_np, cv2.COLOR_BGR2RGB)) img_byte_arr = io.BytesIO() result_image.save(img_byte_arr, format='JPEG') img_byte_arr = img_byte_arr.getvalue() # Respuesta return { "message": "Rostros detectados", "resultados": result_data, "imagen_base64": base64.b64encode(img_byte_arr).decode('utf-8') } except Exception as e: raise HTTPException(status_code=500, detail=str(e))