# Uploaded by gr3yshadow via huggingface_hub ("Upload folder using huggingface_hub"),
# commit dd2bcb8 (verified), 3.09 kB.
import asyncio
import os
import shutil

import face_recognition
import numpy as np
from fastapi import FastAPI, HTTPException
from PIL import Image
from pydantic import BaseModel
from scipy.spatial import distance
# ASGI application object; the POST /group_images route is registered on it.
app: FastAPI = FastAPI()
class ImageToGroup:
    """One image file together with the face embeddings detected in it.

    Attributes:
        filename: Name supplied by the caller (stored only; not used here).
        path: Path passed to ``face_recognition.load_image_file``.
        embeddings: One encoding per detected face, in the order returned by
            ``face_recognition.face_locations``; ``[]`` when no face is found.
            Computed eagerly when the object is constructed.
    """

    def __init__(self, filename, path):
        self.filename, self.path = filename, path
        # Detection runs right away, so constructing an instance reads the file.
        self.embeddings = self.extract_embeddings()

    def extract_embeddings(self):
        """Detect the faces in ``self.path`` and encode each of them.

        Each face is encoded with ``num_jitters=10``. Per the face_recognition
        docs, this re-samples the face 10 times when computing the encoding
        (slower, more accurate). Each face still yields a single encoding.

        Returns:
            A list holding one encoding per detected face, or ``[]`` if none.
        """
        pixels = face_recognition.load_image_file(self.path)
        boxes = face_recognition.face_locations(pixels)
        if not boxes:
            return []  # no face detected in this image
        encodings = []
        for box in boxes:
            # Passing exactly one known location gives a one-element list back.
            encodings.append(
                face_recognition.face_encodings(pixels, [box], num_jitters=10)[0]
            )
        return encodings

    def are_similar(self, other_embeddings, threshold=0.6):
        """Tell whether any face here is close to any of ``other_embeddings``.

        Closeness is Euclidean distance strictly below ``threshold``. The result
        is False whenever either side has no embeddings.
        """
        return any(
            distance.euclidean(own, other) < threshold
            for other in other_embeddings
            for own in self.embeddings
        )
def _group_by_face(images, threshold):
    """Cluster ``images`` into one group per distinct face.

    Each group is represented by the embedding of the face that created it.
    Every face of every image is compared with the group representatives in
    creation order. The face joins the first group whose representative is
    within ``threshold`` (Euclidean, strict ``<``); otherwise it starts a new
    group. An image with several faces can therefore appear in several groups,
    but appears at most once in any one group. Images with no faces are ignored.

    Args:
        images: Objects exposing an ``embeddings`` sequence of face vectors.
        threshold: Distance below which two faces are considered the same person.

    Returns:
        A list of groups (each a list of images), in creation order.
    """
    groups = []  # (representative embedding, member images) pairs
    for image in images:
        for embedding in image.embeddings:
            for representative, members in groups:
                if distance.euclidean(representative, embedding) < threshold:
                    # Two faces of one image may match the same group; list it once.
                    if image not in members:
                        members.append(image)
                    break
            else:
                # No existing face matched: this face founds a new group.
                groups.append((embedding, [image]))
    return [members for _, members in groups]


def main(input_dir, output_dir, threshold=0.6):
    """Copy the images in ``input_dir`` into per-face ``group_N`` folders.

    Only regular files directly inside ``input_dir`` are considered, so
    sub-directories are skipped (e.g. an ``output_dir`` nested in
    ``input_dir``). Files are processed in sorted name order so that group
    numbering is reproducible. Images in which no face is detected are not
    copied. An image showing several people is copied into each matching group.

    Args:
        input_dir: Directory containing the images to group.
        output_dir: Directory in which ``group_1``, ``group_2``, ... are created.
        threshold: Face-embedding distance below which two faces are treated
            as the same person (0.6 is also ``ImageToGroup.are_similar``'s default).

    Raises:
        Any error raised while loading a file as an image; nothing here catches it.
    """
    filenames = sorted(
        name
        for name in os.listdir(input_dir)
        if os.path.isfile(os.path.join(input_dir, name))
    )
    images_to_group = [
        ImageToGroup(name, os.path.join(input_dir, name)) for name in filenames
    ]

    groups = _group_by_face(images_to_group, threshold)

    # Folder names start at group_1.
    for group_number, group_images in enumerate(groups, start=1):
        group_dir = os.path.join(output_dir, f"group_{group_number}")
        os.makedirs(group_dir, exist_ok=True)
        for image in group_images:
            destination_path = os.path.join(group_dir, os.path.basename(image.path))
            shutil.copy(image.path, destination_path)
class GroupImagesRequest(BaseModel):
    # JSON request body for POST /group_images.
    input_dir: str  # directory whose files are scanned for faces
    output_dir: str  # directory that receives the group_N sub-folders; created if missing
@app.post("/group_images")
async def group_images(request: GroupImagesRequest):
    """Group the images in ``request.input_dir`` by face into ``request.output_dir``.

    Validates both paths, creates ``output_dir`` if needed, then runs ``main``,
    which writes ``group_N`` sub-folders into ``output_dir``.

    Returns:
        ``{"message": "Images grouped successfully"}`` once ``main`` has finished.

    Raises:
        HTTPException: 400 if ``input_dir`` is not an existing directory, or if
            ``output_dir`` exists but is not a directory.
    """
    # NOTE(review): both paths come straight from the client, so any caller can
    # make the server read and write any directory the process can access.
    # Restrict them to an allowed root before exposing this endpoint publicly.
    if not os.path.isdir(request.input_dir):
        raise HTTPException(status_code=400, detail="Input directory does not exist")
    if os.path.exists(request.output_dir) and not os.path.isdir(request.output_dir):
        raise HTTPException(
            status_code=400, detail="Output path exists and is not a directory"
        )
    # exist_ok: another request may create the directory between check and call.
    os.makedirs(request.output_dir, exist_ok=True)
    # main() does blocking file I/O and CPU-heavy face detection. Running it in
    # the default thread pool keeps the event loop free for other requests.
    loop = asyncio.get_running_loop()
    await loop.run_in_executor(None, main, request.input_dir, request.output_dir)
    return {"message": "Images grouped successfully"}