Spaces:
Runtime error
Runtime error
import os | |
import shutil | |
import numpy as np | |
from PIL import Image | |
import face_recognition | |
from scipy.spatial import distance | |
class ImageToGroup:
    """An input image file paired with the face encodings detected in it.

    Encodings are computed eagerly when the instance is created.

    Attributes:
        filename: Name of the file as listed in the input directory.
        path: Path used to open the image.
        embeddings: One face encoding per detected face; an empty list when
            no face was found or the image could not be processed.
    """

    # Length in pixels of the longer image side after rescaling.
    target_size = 800
    # Passed to face_recognition.face_encodings as ``num_jitters`` (how many
    # times each face is re-sampled while encoding).
    num_jitters = 10

    def __init__(self, filename, path):
        self.filename = filename
        self.path = path
        self.embeddings = self.extract_embeddings()

    def extract_embeddings(self):
        """Detect every face in the image at ``self.path`` and encode it.

        The image is converted to 8-bit RGB and rescaled uniformly so that its
        longer side is ``target_size`` pixels before detection.

        Returns:
            A list with one encoding per detected face, or ``[]`` if no face is
            found or any error occurs (the error is printed, not raised).
        """
        try:
            # Context manager closes the underlying file; convert() loads the
            # pixels into a new, file-independent image first.
            with Image.open(self.path) as source:
                # dlib (behind face_recognition) only accepts 8-bit RGB or
                # grayscale arrays; RGBA/palette/16-bit files must be converted.
                img = source.convert("RGB")
            # Scale uniformly: stretching to a fixed square distorts facial
            # geometry and therefore the encodings that are compared later.
            width, height = img.size
            scale = self.target_size / max(width, height)
            img = img.resize(
                (max(1, round(width * scale)), max(1, round(height * scale)))
            )
            pixels = np.array(img)
            face_locations = face_recognition.face_locations(pixels)
            if not face_locations:
                return []  # No face found in the image
            # A single call encodes all faces, one encoding per location.
            return face_recognition.face_encodings(
                pixels, face_locations, num_jitters=self.num_jitters
            )
        except Exception as e:
            # Best effort per image: an unreadable file just yields no faces.
            print(f"Error extracting embeddings from {self.path}: {e}")
            return []

    def are_similar(self, other_embeddings, threshold=0.6):
        """Return True if any face here matches any face in *other_embeddings*.

        Two encodings match when their Euclidean distance is strictly less
        than *threshold*. Returns False when either side has no encodings.
        """
        return any(
            distance.euclidean(mine, theirs) < threshold
            for theirs in other_embeddings
            for mine in self.embeddings
        )
def main(input_dir, output_dir, threshold=0.6):
    """Group the images in *input_dir* by the faces they contain.

    Every detected face is matched to the existing group whose representative
    encoding is nearest, provided that Euclidean distance is strictly less
    than *threshold*. Otherwise the face starts a new group and becomes its
    representative. An image with several faces can appear in several groups,
    but at most once in each. Images with no detected face are skipped.

    Each group is copied into ``output_dir/group_<n>``, with n starting at 1.
    Groups are numbered in the order they were created while scanning the
    files in sorted name order.

    Args:
        input_dir: Directory whose regular files are treated as images.
        output_dir: Directory in which the ``group_<n>`` folders are created.
        threshold: Distance below which two face encodings count as the
            same person.

    Returns:
        The groups as a list of lists of ``ImageToGroup`` instances, in
        ``group_<n>`` order.
    """
    # Sorted so group numbering is reproducible; os.listdir order is arbitrary.
    filenames = sorted(
        name for name in os.listdir(input_dir)
        if os.path.isfile(os.path.join(input_dir, name))
    )
    images = [
        ImageToGroup(name, os.path.join(input_dir, name)) for name in filenames
    ]
    groups = _group_images(images, threshold)
    _save_groups(groups, output_dir)
    return groups


def _group_images(images, threshold):
    """Cluster images per face; return a list of image lists, one per group."""
    representatives = []  # representatives[i] is the first face of group i
    groups = []           # groups[i] lists the images whose faces matched it
    for image in images:
        for embedding in image.embeddings:
            # Nearest matching group, not merely the first one under threshold.
            best_index, best_dist = None, threshold
            for index, representative in enumerate(representatives):
                dist = distance.euclidean(embedding, representative)
                if dist < best_dist:
                    best_index, best_dist = index, dist
            if best_index is None:
                representatives.append(embedding)
                groups.append([image])
            elif groups[best_index][-1] is not image:
                # While this image is processed only it gets appended, so if it
                # is already in the group it is the last member: add it once.
                groups[best_index].append(image)
    return groups


def _save_groups(groups, output_dir):
    """Copy each group's images into ``output_dir/group_<n>`` (n from 1)."""
    for number, members in enumerate(groups, start=1):
        group_dir = os.path.join(output_dir, f"group_{number}")
        try:
            os.makedirs(group_dir, exist_ok=True)
            for image in members:
                destination_path = os.path.join(
                    group_dir, os.path.basename(image.path)
                )
                shutil.copy(image.path, destination_path)
        except OSError as e:
            # Best effort: a failing group is reported and the rest still saved.
            print(f"Error saving images to {group_dir}: {e}")
if __name__ == '__main__':
    import sys

    # Fail with a usage message instead of an IndexError on missing arguments.
    if len(sys.argv) < 3:
        print(f"usage: {sys.argv[0]} INPUT_DIR OUTPUT_DIR", file=sys.stderr)
        sys.exit(2)
    input_dir = sys.argv[1]
    output_dir = sys.argv[2]
    main(input_dir, output_dir)