import marimo

__generated_with = "0.12.8"
app = marimo.App()

@app.cell
def _(mo):
    mo.md(
        r"""
        ## Face Embeddings of World Leaders

        This notebook explores face embeddings using a subset of the **Labeled Faces in the Wild** dataset, focused on public figures. We'll use standard Python and scikit-learn libraries to load the data, embed the images, reduce dimensionality, and visualize clustering behavior.

        This example builds on a demo from the Marimo gallery that uses the MNIST dataset. Here, we adapt it to a facial recognition dataset of public figures. While facial recognition has limited responsible use cases, this curated subset includes only world leaders, a group I feel comfortable experimenting with in a technical context.

        We'll start with our imports:
        """
    )
    return

@app.cell
def _():
    from time import time

    import matplotlib.pyplot as plt
    from scipy.stats import loguniform
    from sklearn.datasets import fetch_lfw_people
    from sklearn.decomposition import PCA
    from sklearn.metrics import ConfusionMatrixDisplay, classification_report
    from sklearn.model_selection import RandomizedSearchCV, train_test_split
    from sklearn.preprocessing import StandardScaler
    from sklearn.svm import SVC

    return (
        ConfusionMatrixDisplay,
        PCA,
        RandomizedSearchCV,
        SVC,
        StandardScaler,
        classification_report,
        fetch_lfw_people,
        loguniform,
        plt,
        time,
        train_test_split,
    )

@app.cell
def _(mo):
    mo.md(r"""We're using `fetch_lfw_people` from `sklearn.datasets` to load a curated subset of the LFW dataset, restricted to individuals with at least 70 images, resulting in 7 distinct people and just over 1,200 samples. These happen to be mostly world leaders, which makes the demo both manageable and fun to explore.""")
    return

@app.cell
def _(fetch_lfw_people):
    lfw_people = fetch_lfw_people(min_faces_per_person=70, resize=0.4)

    # introspect the images array to find the shapes (for plotting)
    n_samples, h, w = lfw_people.images.shape

    # for machine learning we use the flattened pixel data directly
    # (relative pixel position info is ignored by this model)
    X = lfw_people.data
    n_features = X.shape[1]

    # the label to predict is the id of the person
    Y = lfw_people.target
    target_names = lfw_people.target_names
    n_classes = target_names.shape[0]

    print("Total dataset size:")
    print("n_samples: %d" % n_samples)
    print("n_features: %d" % n_features)
    print("n_classes: %d" % n_classes)
    return (
        X,
        Y,
        h,
        lfw_people,
        n_classes,
        n_features,
        n_samples,
        target_names,
        w,
    )

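# Optional sketch: a quick visual sanity check on the raw data, showing the first few
# faces with their labels. This cell is illustrative only and is not used later; it
# assumes `lfw_people`, `target_names`, and `plt` from the cells above.
@app.cell
def _(lfw_people, plt, target_names):
    _fig, _axes = plt.subplots(1, 5, figsize=(10, 2.5))
    for _ax, _img, _label in zip(_axes, lfw_people.images[:5], lfw_people.target[:5]):
        _ax.imshow(_img, cmap="gray")
        _ax.set_title(target_names[_label], fontsize=8)
        _ax.axis("off")
    _fig
    return
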
@app.cell
def _(mo):
    mo.md(r"""Next, we embed each face image using a pre-trained FaceNet model (`InceptionResnetV1` trained on `vggface2`). This converts each image into a 512-dimensional vector. Since the original data is grayscale and flattened, we reshape, normalize, and convert it to RGB before feeding it through the model.""")
    return

@app.cell
def _(X, h, w):
    from facenet_pytorch import InceptionResnetV1
    from torchvision import transforms
    from PIL import Image
    import torch
    import numpy as np

    # Load FaceNet model
    model = InceptionResnetV1(pretrained='vggface2').eval()

    # Transform pipeline: resize -> tensor -> repeat grayscale to 3 channels -> normalize
    transform = transforms.Compose([
        transforms.Resize((160, 160)),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.repeat(3, 1, 1) if x.shape[0] == 1 else x),
        transforms.Normalize([0.5], [0.5])
    ])

    # Embed a single flattened row from X
    def embed_flat_row(flat):
        img = flat.reshape(h, w)
        img = (img * 255).astype(np.uint8)
        pil = Image.fromarray(img).convert("L")  # grayscale
        tensor = transform(pil).unsqueeze(0)
        with torch.no_grad():
            return model(tensor).squeeze().numpy()  # 512-dim

    # Generate embeddings for all samples
    embeddings = np.array([embed_flat_row(row) for row in X])
    return (
        Image,
        InceptionResnetV1,
        embed_flat_row,
        embeddings,
        model,
        np,
        torch,
        transform,
        transforms,
    )

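# Optional sketch: the loop above runs the model once per image, which is slow on CPU.
# A batched variant could look like the hypothetical helper below (not used elsewhere);
# it assumes the same `transform`, `model`, `h`, and `w` defined in the cells above.
@app.cell
def _(Image, h, model, np, torch, transform, w):
    def embed_in_batches(flat_rows, batch_size=64):
        vecs = []
        for start in range(0, len(flat_rows), batch_size):
            # preprocess one batch of flattened rows into a (B, 3, 160, 160) tensor
            tensors = [
                transform(Image.fromarray((row.reshape(h, w) * 255).astype(np.uint8)))
                for row in flat_rows[start:start + batch_size]
            ]
            with torch.no_grad():
                vecs.append(model(torch.stack(tensors)).numpy())
        return np.concatenate(vecs)  # (n_samples, 512)
    return (embed_in_batches,)
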
@app.cell
def _(mo):
    mo.md(r"""Now that we have 512-dimensional embeddings, we reduce them to 2D for visualization. Both t-SNE and UMAP are available here; UMAP is active by default, but you can switch to t-SNE by uncommenting the alternate line. This step lets us inspect the structure of the embedding space:""")
    return

@app.cell
def _(embeddings):
    from sklearn.manifold import TSNE
    import umap.umap_ as umap

    # X_embedded = TSNE(n_components=2, perplexity=30, random_state=42).fit_transform(embeddings)
    X_embedded = umap.UMAP(n_components=2, random_state=42).fit_transform(embeddings)
    return TSNE, X_embedded, umap

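# Optional sketch: a rough numeric check of how well the 512-d embeddings separate
# identities, to complement the visual inspection below. Purely illustrative and not
# used by later cells; assumes `embeddings` and the labels `Y` from the cells above.
@app.cell
def _(Y, embeddings):
    from sklearn.metrics import silhouette_score

    # silhouette ranges from -1 to 1; higher means tighter, better-separated clusters
    print("silhouette score (identity labels):", round(float(silhouette_score(embeddings, Y)), 3))
    return
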
@app.cell
def _(mo):
    mo.md(r"""We wrap the 2D embeddings into a Pandas DataFrame for easier manipulation and plotting. Each row includes x/y coordinates and the associated person ID, which we map to names. We then define a simple Altair scatterplot function to visualize the clustered embeddings by identity.""")
    return

@app.cell
def _(X_embedded, Y, target_names):
    import pandas as pd

    embedding_df = pd.DataFrame({
        "x": X_embedded[:, 0],
        "y": X_embedded[:, 1],
        "person": Y
    }).reset_index()
    embedding_df["name"] = embedding_df["person"].map(lambda i: target_names[i])
    return embedding_df, pd

@app.cell
def _():
    import altair as alt

    def scatter(df):
        return (alt.Chart(df)
            .mark_circle()
            .encode(
                x=alt.X("x:Q"),
                y=alt.Y("y:Q"),
                color=alt.Color("name:N"),
            ).properties(width=500, height=300))
    return alt, scatter

@app.cell
def _(mo):
    mo.md(r"""Here's our 2D embedding space of world leader faces! Each point is a facial embedding projected with UMAP and colored by identity. Try selecting a cluster: the notebook will automatically reveal the associated images so you can explore what the model "thinks" belongs together.""")
    return

@app.cell
def _(embedding_df, scatter):
    import marimo as mo

    chart = mo.ui.altair_chart(scatter(embedding_df))
    return chart, mo

@app.cell
def _(mo):
    mo.md(r"""When you select points in the scatterplot, Marimo automatically passes those indices into this cell. Here, we render a preview of the corresponding face images using `matplotlib`, along with a table of all selected metadata, making it easy to inspect clustering quality or outliers at a glance.""")
    return

@app.cell
def _(chart, mo):
    table = mo.ui.table(chart.value)
    return (table,)

@app.cell
def _(X, chart, h, mo, table, w):
    def show_images(indices, max_images=6):
        import matplotlib.pyplot as plt

        # look up the original (unflattened) images for the selected rows
        indices = indices[:max_images]
        images = X.reshape((-1, h, w))[indices]

        fig, axes = plt.subplots(1, len(indices))
        fig.set_size_inches(12.5, 1.5)
        if len(indices) > 1:
            for im, ax in zip(images, axes.flat):
                ax.imshow(im, cmap="gray")
                ax.set_yticks([])
                ax.set_xticks([])
        else:
            axes.imshow(images[0], cmap="gray")
            axes.set_yticks([])
            axes.set_xticks([])
        plt.tight_layout()
        return fig

    def show_selected():
        # prefer rows selected in the table; otherwise fall back to the chart selection
        return (
            show_images(list(chart.value["index"]))
            if not len(table.value)
            else show_images(list(table.value["index"]))
        )

    mo.hstack([chart, show_selected() if len(chart.value) else ""])
    return show_images, show_selected

@app.cell
def _():
    return

if __name__ == "__main__":
    app.run()