import os
import shutil

import face_recognition
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from scipy.spatial import distance

app = FastAPI()

# Assumption: only files with these extensions are treated as images.
# Anything else in the input directory is skipped instead of crashing the loader.
IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp")


class ImageToGroup:
    def __init__(self, filename, path):
        self.filename = filename
        self.path = path
        self.embeddings = self.extract_embeddings()

    def extract_embeddings(self):
        img = face_recognition.load_image_file(self.path)
        face_locations = face_recognition.face_locations(img)
        if not face_locations:
            return []  # No face found in the image

        # Encode every detected face in one call. num_jitters=10 re-samples
        # each face 10 times and averages the result (slower, more stable).
        return face_recognition.face_encodings(
            img, known_face_locations=face_locations, num_jitters=10
        )

    def are_similar(self, other_embeddings, threshold=0.6):
        # True if any face in this image lies within `threshold` (Euclidean
        # distance) of any face in `other_embeddings`.
        for other_embedding in other_embeddings:
            for self_embedding in self.embeddings:
                if distance.euclidean(self_embedding, other_embedding) < threshold:
                    return True
        return False


def main(input_dir, output_dir):
    filenames = [
        f for f in os.listdir(input_dir) if f.lower().endswith(IMAGE_EXTENSIONS)
    ]
    images_to_group = [
        ImageToGroup(filename, os.path.join(input_dir, filename))
        for filename in filenames
    ]

    # Group images into clusters based on face embeddings. Each group is a list
    # whose first image is the representative that new images are compared to.
    grouped_images = []
    for image in images_to_group:
        if not image.embeddings:
            continue  # Skip images with no faces

        # Match once per image, not once per embedding, so an image with
        # several faces is never added to the same group twice.
        for group_images in grouped_images:
            if image.are_similar(group_images[0].embeddings):
                group_images.append(image)
                break
        else:
            grouped_images.append([image])  # No match: start a new group

    # Save grouped images; group directories are numbered from 1.
    for i, group_images in enumerate(grouped_images, start=1):
        group_dir = os.path.join(output_dir, f"group_{i}")
        os.makedirs(group_dir, exist_ok=True)
        for image in group_images:
            image_filename = os.path.basename(image.path)
            destination_path = os.path.join(group_dir, image_filename)
            shutil.copy(image.path, destination_path)


class GroupImagesRequest(BaseModel):
    input_dir: str
    output_dir: str


# Declared with plain `def` so FastAPI runs this blocking, CPU-heavy work in
# its thread pool instead of stalling the event loop.
@app.post("/group_images")
def group_images(request: GroupImagesRequest):
    if not os.path.isdir(request.input_dir):
        raise HTTPException(status_code=400, detail="Input directory does not exist")
    os.makedirs(request.output_dir, exist_ok=True)

    main(request.input_dir, request.output_dir)
    return {"message": "Images grouped successfully"}