Junior16 commited on
Commit
c0eba86
verified
1 Parent(s): 15e4aa2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -23
app.py CHANGED
@@ -4,20 +4,19 @@ import numpy as np
4
  from PIL import Image
5
  import io
6
  import base64
7
- from transformers import AutoModelForImageClassification, AutoFeatureExtractor
8
  import torch
9
 
10
  app = FastAPI()
11
 
12
- # Cargar el modelo preentrenado para clasificaci贸n de g茅nero
13
- model_name = "nateraw/bert-imagenet" # Cambiar a un modelo adecuado si se encuentra uno m谩s espec铆fico
14
- model = AutoModelForImageClassification.from_pretrained(model_name)
15
- feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
16
 
17
  @app.post("/detect/")
18
  async def detect_face(file: UploadFile = File(...)):
19
  try:
20
- # Leer y preparar la imagen
21
  image_bytes = await file.read()
22
  image = Image.open(io.BytesIO(image_bytes))
23
  img_np = np.array(image)
@@ -25,7 +24,7 @@ async def detect_face(file: UploadFile = File(...)):
25
  if img_np.shape[2] == 4:
26
  img_np = cv2.cvtColor(img_np, cv2.COLOR_BGRA2BGR)
27
 
28
- # Detecci贸n de rostros con OpenCV
29
  face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + 'haarcascade_frontalface_default.xml')
30
  gray = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY)
31
  faces = face_cascade.detectMultiScale(gray, scaleFactor=1.1, minNeighbors=5, minSize=(30, 30))
@@ -33,38 +32,45 @@ async def detect_face(file: UploadFile = File(...)):
33
  if len(faces) == 0:
34
  raise HTTPException(status_code=404, detail="No se detectaron rostros en la imagen.")
35
 
36
- result_data = []
 
37
  for (x, y, w, h) in faces:
38
- # Extraer cada rostro
39
  face_img = img_np[y:y+h, x:x+w]
40
- face_img_pil = Image.fromarray(cv2.cvtColor(face_img, cv2.COLOR_BGR2RGB)).resize((224, 224))
41
 
42
- # Clasificar el rostro
43
- inputs = feature_extractor(images=face_img_pil, return_tensors="pt")
44
- outputs = model(**inputs)
45
- predicted_class = outputs.logits.argmax(dim=-1).item()
46
- label = model.config.id2label[predicted_class]
47
 
48
- # Dibujar rect谩ngulo y agregar datos
 
 
 
49
  cv2.rectangle(img_np, (x, y), (x+w, y+h), (255, 0, 0), 2)
50
- result_data.append({
51
- "coordenadas": [int(x), int(y), int(w), int(h)],
52
- "sexo": label
 
 
53
  })
54
 
55
- # Preparar imagen resultante
56
  result_image = Image.fromarray(cv2.cvtColor(img_np, cv2.COLOR_BGR2RGB))
57
  img_byte_arr = io.BytesIO()
58
  result_image.save(img_byte_arr, format='JPEG')
59
  img_byte_arr = img_byte_arr.getvalue()
60
 
61
- # Respuesta
62
  return {
63
- "message": "Rostros detectados",
64
- "resultados": result_data,
 
65
  "imagen_base64": base64.b64encode(img_byte_arr).decode('utf-8')
66
  }
67
 
68
  except Exception as e:
69
  raise HTTPException(status_code=500, detail=str(e))
70
 
 
 
4
  from PIL import Image
5
  import io
6
  import base64
7
+ from transformers import ViTFeatureExtractor, ViTForImageClassification
8
  import torch
9
 
10
  app = FastAPI()
11
 
12
+ # Cargar el modelo de clasificaci贸n de edad y el extractor
13
+ model = ViTForImageClassification.from_pretrained('nateraw/vit-age-classifier')
14
+ transforms = ViTFeatureExtractor.from_pretrained('nateraw/vit-age-classifier')
 
15
 
16
  @app.post("/detect/")
17
  async def detect_face(file: UploadFile = File(...)):
18
  try:
19
+ # Leer y procesar la imagen cargada
20
  image_bytes = await file.read()
21
  image = Image.open(io.BytesIO(image_bytes))
22
  img_np = np.array(image)
 
24
  if img_np.shape[2] == 4:
25
  img_np = cv2.cvtColor(img_np, cv2.COLOR_BGRA2BGR)
26
 
27
+ # Cargar el clasificador Haar para detecci贸n de rostros
28
  face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + 'haarcascade_frontalface_default.xml')
29
  gray = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY)
30
  faces = face_cascade.detectMultiScale(gray, scaleFactor=1.1, minNeighbors=5, minSize=(30, 30))
 
32
  if len(faces) == 0:
33
  raise HTTPException(status_code=404, detail="No se detectaron rostros en la imagen.")
34
 
35
+ # Procesar cada rostro detectado
36
+ results = []
37
  for (x, y, w, h) in faces:
38
+ # Extraer el rostro de la imagen
39
  face_img = img_np[y:y+h, x:x+w]
40
+ pil_face_img = Image.fromarray(cv2.cvtColor(face_img, cv2.COLOR_BGR2RGB))
41
 
42
+ # Realizar la predicci贸n de edad
43
+ inputs = transforms(pil_face_img, return_tensors='pt')
44
+ output = model(**inputs)
45
+ proba = output.logits.softmax(1)
46
+ preds = proba.argmax(1)
47
 
48
+ # Asumimos que la predicci贸n est谩 representando un rango de edad (esto puede adaptarse m谩s tarde)
49
+ predicted_age_range = str(preds.item())
50
+
51
+ # Dibujar un rect谩ngulo alrededor del rostro y a帽adir la edad predicha
52
  cv2.rectangle(img_np, (x, y), (x+w, y+h), (255, 0, 0), 2)
53
+ cv2.putText(img_np, f"Edad: {predicted_age_range}", (x, y-10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (255, 0, 0), 2)
54
+
55
+ results.append({
56
+ "edad_predicha": predicted_age_range,
57
+ "coordenadas_rostro": (x, y, w, h)
58
  })
59
 
60
+ # Convertir la imagen procesada a base64
61
  result_image = Image.fromarray(cv2.cvtColor(img_np, cv2.COLOR_BGR2RGB))
62
  img_byte_arr = io.BytesIO()
63
  result_image.save(img_byte_arr, format='JPEG')
64
  img_byte_arr = img_byte_arr.getvalue()
65
 
 
66
  return {
67
+ "message": "Rostros detectados y edad predicha",
68
+ "rostros": len(faces),
69
+ "resultados": results,
70
  "imagen_base64": base64.b64encode(img_byte_arr).decode('utf-8')
71
  }
72
 
73
  except Exception as e:
74
  raise HTTPException(status_code=500, detail=str(e))
75
 
76
+