import marimo
__generated_with = "0.12.8"
app = marimo.App()
@app.cell(hide_code=True)
def _(mo):
mo.md(
r"""
## Face Embeddings of World Leaders
This notebook explores face embeddings using a subset of the **Labeled Faces in the Wild** dataset, focused on public figures. We'll use standard Python and scikit-learn libraries to load the data, embed images, reduce dimensionality, and visualize clustering behavior.
This example builds on a demo from the Marimo gallery using the MNIST dataset. Here, we adapt it to work with a facial recognition dataset of public figures. While facial recognition has limited responsible use cases, this curated subset includes only world leaders, a group I feel comfortable experimenting with in a technical context.
We'll start with our imports:
"""
)
return
@app.cell
def _():
from time import time
import matplotlib.pyplot as plt
from scipy.stats import loguniform
from sklearn.datasets import fetch_lfw_people
from sklearn.decomposition import PCA
from sklearn.metrics import ConfusionMatrixDisplay, classification_report
from sklearn.model_selection import RandomizedSearchCV, train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC
return (
ConfusionMatrixDisplay,
PCA,
RandomizedSearchCV,
SVC,
StandardScaler,
classification_report,
fetch_lfw_people,
loguniform,
plt,
time,
train_test_split,
)
@app.cell(hide_code=True)
def _(mo):
mo.md(r"""We're using `fetch_lfw_people` from `sklearn.datasets` to load a curated subset of the LFW dataset, restricted to individuals with at least 70 images, resulting in 7 distinct people and just over 1,200 samples. These happen to be mostly world leaders, which makes the demo both manageable and fun to explore.""")
return
@app.cell
def _(fetch_lfw_people):
lfw_people = fetch_lfw_people(min_faces_per_person=70, resize=0.4)
# introspect the images arrays to find the shapes (for plotting)
n_samples, h, w = lfw_people.images.shape
# for machine learning we use the flattened data directly (relative pixel
# position info is ignored by this model)
X = lfw_people.data
n_features = X.shape[1]
# the label to predict is the id of the person
Y = lfw_people.target
target_names = lfw_people.target_names
n_classes = target_names.shape[0]
print("Total dataset size:")
print("n_samples: %d" % n_samples)
print("n_features: %d" % n_features)
print("n_classes: %d" % n_classes)
return (
X,
Y,
h,
lfw_people,
n_classes,
n_features,
n_samples,
target_names,
w,
)
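@app.cell(hide_code=True)
def _(mo):
    mo.md(r"""As a quick aside (a small sketch added here, not part of the original demo), we can print the number of images per person to make the `min_faces_per_person=70` filter concrete.""")
    return
@app.cell
def _(Y, target_names):
    # Sketch: per-person image counts for the filtered subset.
    for _i, _name in enumerate(target_names):
        print(f"{_name}: {int((Y == _i).sum())} images")
    return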
@app.cell(hide_code=True)
def _(mo):
mo.md(r"""Next, we embed each face image using a pre-trained FaceNet model (`InceptionResnetV1` trained on `vggface2`). This converts each image into a 512-dimensional vector. Since the original data is grayscale and flattened, we reshape, normalize, and convert it to RGB before feeding it through the model.""")
return
@app.cell
def _(X, h, w):
from facenet_pytorch import InceptionResnetV1
from torchvision import transforms
from PIL import Image
import torch
import numpy as np
# Load FaceNet model
model = InceptionResnetV1(pretrained='vggface2').eval()
# Transform pipeline: resize -> tensor -> grayscale-to-RGB -> normalize
transform = transforms.Compose([
transforms.Resize((160, 160)),
transforms.ToTensor(),
transforms.Lambda(lambda x: x.repeat(3, 1, 1) if x.shape[0] == 1 else x),
transforms.Normalize([0.5], [0.5])
])
# Embed a single flattened row from X
def embed_flat_row(flat):
img = flat.reshape(h, w)
img = (img * 255).astype(np.uint8)
pil = Image.fromarray(img).convert("L") # grayscale
tensor = transform(pil).unsqueeze(0)
with torch.no_grad():
return model(tensor).squeeze().numpy() # 512-dim
# Generate embeddings for all samples
embeddings = np.array([embed_flat_row(row) for row in X])
return (
Image,
InceptionResnetV1,
embed_flat_row,
embeddings,
model,
np,
torch,
transform,
transforms,
)
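@app.cell(hide_code=True)
def _(mo):
    mo.md(r"""As a sanity check (a sketch, not part of the original demo), we can compare cosine similarities: an embedding should be closer to another image of the same person than to an image of someone else.""")
    return
@app.cell
def _(Y, embeddings, np):
    # Sketch: cosine similarity between one embedding and a same-person /
    # different-person counterpart, using only names defined above.
    def _cosine(a, b):
        return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))
    _same_idx = np.where(Y == Y[0])[0][1]  # another image of the same person
    _diff_idx = np.where(Y != Y[0])[0][0]  # an image of a different person
    print("same person:      %.3f" % _cosine(embeddings[0], embeddings[_same_idx]))
    print("different person: %.3f" % _cosine(embeddings[0], embeddings[_diff_idx]))
    return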
@app.cell
def _(mo):
mo.md(r"""Now that we have 512-dimensional embeddings, we reduce them to 2D for visualization. Both t-SNE and UMAP are available here; UMAP is active by default, but you can switch to t-SNE by uncommenting the alternate line. This step lets us inspect the structure of the embedding space:""")
return
@app.cell
def _(embeddings):
from sklearn.manifold import TSNE
import umap.umap_ as umap
# X_embedded = TSNE(n_components=2, perplexity=30, random_state=42).fit_transform(embeddings)
X_embedded = umap.UMAP(n_components=2, random_state=42).fit_transform(embeddings)
return TSNE, X_embedded, umap
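@app.cell(hide_code=True)
def _(mo):
    mo.md(r"""As a rough quantitative companion to the visual inspection below (a sketch, not part of the original demo), a silhouette score over the 2D projection, grouped by identity, summarizes how well the identities separate.""")
    return
@app.cell
def _(X_embedded, Y):
    # Sketch: silhouette score of the 2D projection, using identity as the label.
    from sklearn.metrics import silhouette_score
    print("silhouette score (2D projection, by identity): %.3f" % silhouette_score(X_embedded, Y))
    return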
@app.cell
def _(mo):
mo.md(r"""We wrap the 2D embeddings into a Pandas DataFrame for easier manipulation and plotting. Each row includes x/y coordinates and the associated person ID, which we map to names. We then define a simple Altair scatterplot function to visualize the clustered embeddings by identity.""")
return
@app.cell
def _(X_embedded, Y, target_names):
import pandas as pd
embedding_df = pd.DataFrame({
"x": X_embedded[:, 0],
"y": X_embedded[:, 1],
"person": Y
}).reset_index()
embedding_df["name"] = embedding_df["person"].map(lambda i: target_names[i])
return embedding_df, pd
@app.cell
def _():
import altair as alt
def scatter(df):
return (alt.Chart(df)
.mark_circle()
.encode(
x=alt.X("x:Q"),
y=alt.Y("y:Q"),
color=alt.Color("name:N"),
).properties(width=500, height=300))
return alt, scatter
@app.cell(hide_code=True)
def _(mo):
mo.md(r"""Here's our 2D embedding space of world leader faces! Each point is a facial embedding projected with UMAP and colored by identity. Try selecting a cluster: the notebook will automatically reveal the associated images so you can explore what the model "thinks" belongs together.""")
return
@app.cell
def _(embedding_df, scatter):
import marimo as mo
chart = mo.ui.altair_chart(scatter(embedding_df))
return chart, mo
@app.cell(hide_code=True)
def _(mo):
mo.md(r"""When you select points in the scatterplot, Marimo automatically passes the selected rows into this cell. Here, we render a preview of the corresponding face images using `matplotlib`, along with a table of all selected metadata, making it easy to inspect clustering quality or outliers at a glance.""")
return
@app.cell
def _(chart, mo):
table = mo.ui.table(chart.value)
return (table,)
@app.cell
def _(X, chart, h, mo, table, w):
def show_images(indices, max_images=6):
import matplotlib.pyplot as plt
indices = indices[:max_images]
images = X.reshape((-1, h, w))[indices]
fig, axes = plt.subplots(1, len(indices))
fig.set_size_inches(12.5, 1.5)
if len(indices) > 1:
for im, ax in zip(images, axes.flat):
ax.imshow(im, cmap="gray")
ax.set_yticks([])
ax.set_xticks([])
else:
axes.imshow(images[0], cmap="gray")
axes.set_yticks([])
axes.set_xticks([])
plt.tight_layout()
return fig
def show_selected():
return (
show_images(list(chart.value["index"]))
if not len(table.value)
else show_images(list(table.value["index"]))
)
mo.hstack([chart, show_selected() if len(chart.value) else ""])
return show_images, show_selected
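@app.cell(hide_code=True)
def _(mo):
    mo.md(r"""Finally, as a sketch (not part of the original demo), we can summarize the current selection by identity, which complements the image preview when judging how pure a selected cluster is.""")
    return
@app.cell
def _(chart, table):
    # Sketch: identity counts within the current selection; the table selection
    # takes precedence, mirroring the logic in show_selected above.
    _selected = table.value if len(table.value) else chart.value
    if len(_selected):
        print(_selected["name"].value_counts())
    return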
@app.cell
def _():
return
if __name__ == "__main__":
app.run()