# Last commit by cdleong: 0e7a69a "Update app.py" (verified)
import gradio as gr
import pandas as pd
from pathlib import Path
from zipfile import ZipFile
import io
import contextlib
import requests
import random
from functools import lru_cache
import plotly.express as px
# Default list of names to exclude. The UI pre-fills its "FORBIDDEN NAMES" textbox
# from this set, and the generator drops matching SSA and Bible names.
# Keep a comma after EVERY entry: adjacent string literals silently concatenate.
# Entries must not contain commas, because the textbox is parsed by splitting on ",".
FORBIDDEN_NAMES = {
    "Judas",
    "Judas Iscariot",
    "Maher-shalal-hash-baz",
    "Bathsheba",
    "Jephthah",
    "Jehoshaphat",
    "Tiebreaker",
    "Boanerges",
    "Jezebel",
    "Gomorrah",
    "Hymenaeus",
    "Herod",
    "Pilate",
    "Doeg",
    "Ziph",
    "Phygelus",
    "Hermogenes",
    "Philetus",
    "Balaam",
    "Achan",
    "Caiaphas",
    "Pontius",
    "Ahab",
    "Manasseh",
    "Rehoboam",
    "Nebuchadnezzar",
    "Delilah",
    "Lo-ammi",
    "Lo-ruhamah",
    "Beelzebub",
    "Ichabod",
    "Saphira",
    "Jushab-hesed",
    "Benjarman",
    "Cain",
    "Esau",
    "Machiavelli",  # found
    "Barabbas",
    "Sapphira",
    "Shur",
    "Pontius Pilate",
}
# --- File download & setup ---
def download_file(url: str, dest_path: Path, timeout: float = 60.0):
    """Download ``url`` to ``dest_path`` unless the destination already exists.

    Args:
        url: Remote file URL.
        dest_path: Local file to create.
        timeout: Seconds ``requests`` waits on the server before giving up.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.RequestException: On connection failures or timeouts.
    """
    if dest_path.exists():
        print(f"{dest_path.name} already exists. Skipping download.")
        return
    print(f"Downloading {url}")
    # Without a timeout, requests can block indefinitely on an unresponsive server.
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    # Write to a sibling temp file, then rename. An interrupted write then never
    # leaves a partial file that the exists() check above would accept as complete.
    tmp_path = dest_path.with_name(dest_path.name + ".part")
    tmp_path.write_bytes(response.content)
    tmp_path.replace(dest_path)
    print(f"Saved to {dest_path}")
def extract_names_zip():
    """Unpack ``names.zip`` from the working directory into the working directory.

    Raises:
        FileNotFoundError: If ``names.zip`` is not present.
    """
    archive_path = Path("names.zip")
    if archive_path.exists():
        with ZipFile(archive_path, 'r') as archive:
            archive.extractall(".")
        print("Unzipped names.zip")
    else:
        raise FileNotFoundError("names.zip not found. Please upload it manually to the repo.")
# --- One-time data setup (runs at import time) ---
# Unpack the SSA per-year files shipped inside names.zip.
extract_names_zip()

# Fetch the Bible person tables unless they are already present locally.
download_file(
    url="https://raw.githubusercontent.com/BradyStephenson/bible-data/refs/heads/main/BibleData-Person.csv",
    dest_path=Path("BibleData-Person.csv"),
)
download_file(
    url="https://raw.githubusercontent.com/BradyStephenson/bible-data/refs/heads/main/BibleData-PersonLabel.csv",
    dest_path=Path("BibleData-PersonLabel.csv"),
)

# --- Load datasets ---
# Per-year SSA files (yob<YEAR>.txt) in the working directory, in sorted order.
ssa_name_txt_files = sorted(Path().glob("yob*.txt"))
@lru_cache(maxsize=1)
def load_all_ssa_names():
    """Load every SSA file listed in ``ssa_name_txt_files`` into one DataFrame.

    Each file is read as headerless ``name,sex,count`` rows and tagged with the
    year parsed from its ``yob<YEAR>`` file stem. The result is cached, so every
    call returns the same DataFrame object; callers must not mutate it.

    Returns:
        DataFrame with columns ``name``, ``sex``, ``count``, ``year``. It is
        empty (but keeps those columns) when no files are listed.
    """
    columns = ["name", "sex", "count"]
    frames = []
    for path in ssa_name_txt_files:
        year = int(path.stem.replace("yob", ""))
        # keep_default_na=False: otherwise pandas parses names that match its NA
        # sentinels (e.g. "NA", "NULL") as NaN and they silently disappear.
        frame = pd.read_csv(path, names=columns, keep_default_na=False)
        frame["year"] = year
        frames.append(frame)
    if not frames:
        # pd.concat([]) raises ValueError, so return an empty frame with the usual schema.
        return pd.DataFrame(columns=[*columns, "year"])
    return pd.concat(frames, ignore_index=True)
# Bounded: the UI sliders can request many distinct (min_year, max_year) pairs, and an
# unbounded cache would keep a DataFrame alive for every one of them.
@lru_cache(maxsize=64)
def load_ssa_names(min_year=0, max_year=9999):
    """Return SSA rows in the inclusive year range plus per-(name, sex) totals.

    Args:
        min_year: First year to include.
        max_year: Last year to include.

    Returns:
        ``(filtered_df, agg_df)``. ``filtered_df`` holds the raw rows in range.
        ``agg_df`` has columns ``name``, ``sex``, ``count`` with counts summed
        over the range, sorted by ``count`` descending. Both are empty, but keep
        their columns, when no rows fall in the range. Results are cached, so the
        returned frames are shared between calls and must not be mutated.
    """
    full_df = load_all_ssa_names()
    filtered_df = full_df[full_df["year"].between(min_year, max_year)]
    if filtered_df.empty:
        return filtered_df, pd.DataFrame(columns=["name", "sex", "count"])
    agg_df = (
        filtered_df
        .groupby(["name", "sex"], as_index=False)["count"]
        .sum()
        .sort_values("count", ascending=False)
    )
    return filtered_df, agg_df
def load_bible_names(person_csv="BibleData-Person.csv", person_label_csv="BibleData-PersonLabel.csv"):
    """Load Bible person labels that are proper names, tagged with an M/F sex code.

    Each label row gets its person's ``sex`` via a left join on ``person_id``.
    Only rows whose ``label_type`` is ``"proper name"`` are kept, and
    ``"male"``/``"female"`` are mapped to ``"M"``/``"F"`` so the column can be
    compared against the SSA ``sex`` codes.

    Args:
        person_csv: Person table; must have ``person_id`` and ``sex`` columns.
        person_label_csv: Label table; must have ``person_id`` and ``label_type`` columns.

    Returns:
        The filtered label DataFrame with the added ``sex`` column.
    """
    persons_df = pd.read_csv(person_csv)
    labels_df = pd.read_csv(person_label_csv)
    labels_df = labels_df.merge(persons_df[["person_id", "sex"]], on="person_id", how="left")
    labels_df = labels_df[labels_df["label_type"] == "proper name"]
    # assign() builds a new frame instead of writing a column into a filtered slice,
    # which pandas flags as chained assignment (SettingWithCopyWarning).
    return labels_df.assign(sex=labels_df["sex"].replace({"male": "M", "female": "F"}))
# Proper-name Bible labels, loaded once at import time.
bible_names_personlabel_df = load_bible_names()

# --- Name generation logic ---
# Surnames to choose from at random when the caller gives no last name.
last_names = [
    "Smith",
    "Johnson",
    "Williams",
    "Taylor",
    "Brown",
]
def _filter_ssa_candidates(aggregated_df, names_col, min_length, max_length,
                           popularity_percentile, sex, forbidden_names, debug):
    """Narrow aggregated SSA names by forbidden list, length, sex and popularity band."""
    filtered = aggregated_df[~aggregated_df[names_col].isin(forbidden_names)]
    if debug:
        print(f"SSA names after FORBIDDEN NAMES filter: {len(filtered)}")
    filtered = filtered[filtered[names_col].str.len().between(min_length, max_length)]
    if sex:
        filtered = filtered[filtered["sex"] == sex]
    if debug:
        print(f"SSA names after length/sex filter: {len(filtered)}")
    # Sort rarest-first so the (low, high) fractions select a popularity band;
    # e.g. (0.95, 1.0) keeps the most common 5% of the remaining names.
    total = len(filtered)
    filtered = filtered.sort_values("count")
    low, high = popularity_percentile
    filtered = filtered.iloc[int(total * low):int(total * high)]
    if debug:
        print(f"SSA names after popularity percentile slice: {len(filtered)}")
    return filtered


def _filter_bible_candidates(bible_names_df, names_col, min_length, max_length,
                             sex, forbidden_names, debug):
    """Narrow Bible names by length, sex and the forbidden list."""
    filtered = bible_names_df
    if debug:
        print(f"Bible names before filtering: {len(filtered)}")
    filtered = filtered[filtered[names_col].str.len().between(min_length, max_length)]
    if debug:
        print(f"Bible names after length filtering: {len(filtered)}")
    if sex:
        filtered = filtered[filtered["sex"] == sex]
    if debug:
        print(f"Bible names after sex filtering: {len(filtered)}")
    filtered = filtered[~filtered[names_col].isin(forbidden_names)]
    if debug:
        print(f"Bible names after FORBIDDEN NAMES filtering: {len(filtered)}")
    return filtered


def get_normal_and_bible(
    bible_names_df,
    min_length_ssa=3,
    max_length_ssa=8,
    min_length_bible=3,
    max_length_bible=8,
    min_year_ssa=0,
    max_year_ssa=9999,
    ssa_popularity_percentile=(0.95, 1.0),
    sex=None,
    forbidden_names=None,
    ssa_names_col="name",
    bible_names_col="english_label",
    debug=False,
):
    """Pick one random SSA name and one random Bible name that pass all filters.

    Args:
        bible_names_df: Bible names with ``bible_names_col`` and ``sex`` columns.
        min_length_ssa, max_length_ssa: Inclusive length bounds for SSA names.
        min_length_bible, max_length_bible: Inclusive length bounds for Bible names.
        min_year_ssa, max_year_ssa: Inclusive SSA year range to aggregate over.
        ssa_popularity_percentile: ``(low, high)`` fractions of the count-sorted
            (rarest first) SSA pool to sample from.
        sex: ``"M"``/``"F"`` to restrict both pools, or falsy for no restriction.
        forbidden_names: Exact names excluded from both pools.
        ssa_names_col, bible_names_col: Name columns of the two frames.
        debug: Print pool sizes after each filter step.

    Returns:
        ``(ssa_name, bible_name)``.

    Raises:
        ValueError: If either pool is empty after filtering.
    """
    if forbidden_names is None:
        forbidden_names = set()
    _, ssa_names_aggregated_df = load_ssa_names(min_year=min_year_ssa, max_year=max_year_ssa)
    if debug:
        print(f"There are {len(ssa_names_aggregated_df)} SSA names from the years {min_year_ssa} to {max_year_ssa}")
    if ssa_names_aggregated_df.empty:
        # Nothing in the year range; the frame may not even carry a name column.
        raise ValueError("No valid names found after filtering.")

    filtered_ssa = _filter_ssa_candidates(
        ssa_names_aggregated_df, ssa_names_col, min_length_ssa, max_length_ssa,
        ssa_popularity_percentile, sex, forbidden_names, debug,
    )
    filtered_bible = _filter_bible_candidates(
        bible_names_df, bible_names_col, min_length_bible, max_length_bible,
        sex, forbidden_names, debug,
    )
    # Check both pools BEFORE sampling: DataFrame.sample(1) on an empty frame raises
    # a less helpful pandas error instead of this message.
    if filtered_ssa.empty or filtered_bible.empty:
        raise ValueError("No valid names found after filtering.")
    ssa_name = filtered_ssa.sample(1)[ssa_names_col].values[0]
    bible_name = filtered_bible.sample(1)[bible_names_col].values[0]
    return ssa_name, bible_name
# -------------------- Plotting ---
import plotly.graph_objects as go
def plot_name_trends_plotly(df, names, start_year=None, end_year=None, logscale=False):
    """Plot yearly counts for ``names`` (summed across sex) as one trace per name.

    Args:
        df: Rows with ``name``, ``year`` and ``count`` columns.
        names: Exact names to include.
        start_year, end_year: Optional inclusive year bounds; ``None`` means unbounded.
        logscale: Use a logarithmic y axis.

    Returns:
        ``(figure, data)``, where ``data`` has ``year``, ``name``, ``count`` columns.

    Raises:
        gr.Error: If no rows match the names and year range.
    """
    keep = df["name"].isin(names)
    if start_year is not None:
        keep &= (df["year"] >= start_year)
    if end_year is not None:
        keep &= (df["year"] <= end_year)
    selected = df[keep]
    if selected.empty:
        raise gr.Error("No data for selected names and year range.")

    yearly = selected.groupby(["year", "name"])["count"].sum().reset_index()

    fig = go.Figure()
    for trace_name in sorted(yearly["name"].unique()):
        points = yearly[yearly["name"] == trace_name]
        trace_style = {"mode": "lines+markers"}
        if len(points) <= 1:
            # One data point has no line to draw, so show an enlarged marker instead.
            trace_style = {"mode": "markers", "marker": dict(size=10, symbol="circle")}
        fig.add_trace(go.Scatter(x=points["year"], y=points["count"], name=trace_name, **trace_style))

    fig.update_layout(
        title="Name Usage Over Time",
        xaxis_title="Year",
        yaxis_title="Count",
        height=500,
        yaxis_type="log" if logscale else "linear",
    )
    return fig, yearly
def plot_from_inputs(name_text, start_year, end_year, logscale):
    """Gradio callback: split a comma-separated name list and plot its SSA trends.

    Raises:
        gr.Error: If the text contains no non-blank names.
    """
    requested = list(filter(None, map(str.strip, name_text.split(","))))
    if len(requested) == 0:
        raise gr.Error("Please enter at least one name.")
    return plot_name_trends_plotly(load_all_ssa_names(), requested, start_year, end_year, logscale)
# --- Gradio app ---
def generate_names(n, sex, ssa_min_len, ssa_max_len,
                   ssa_min_year,
                   ssa_max_year,
                   min_bible_len, max_bible_len, pop_low, pop_high, debug_flag, last,
                   forbidden_names_text, bible_names_first_flag):
    """Gradio callback: build ``n`` "first middle last" names from SSA + Bible pools.

    Args:
        n: Number of names to generate.
        sex: ``"M"``, ``"F"``; anything else means no sex filter.
        ssa_min_len, ssa_max_len, ssa_min_year, ssa_max_year,
        min_bible_len, max_bible_len, pop_low, pop_high: Filter settings passed to
            ``get_normal_and_bible``.
        debug_flag: Trace filter steps for the first name only.
        last: Surname to use. If blank or None, a random one from ``last_names``
            is drawn for each name.
        forbidden_names_text: Comma-separated names to exclude.
        bible_names_first_flag: Put the Bible name first and the SSA name in the middle.

    Returns:
        ``(names_text, debug_text)``: newline-joined names (a failed draw appears as
        an ``[Error: ...]`` line) and everything printed while generating.
    """
    results = []
    debug_output = io.StringIO()
    forbidden_names = {name.strip() for name in (forbidden_names_text or "").split(",") if name.strip()}
    # None, "" and whitespace-only all mean "no last name given" (an empty Textbox
    # presumably delivers "" rather than None).
    fixed_last = last.strip() if last and last.strip() else None
    with contextlib.redirect_stdout(debug_output):
        # int(): slider values may arrive as floats, which range() rejects.
        for i in range(int(n)):
            try:
                normal, bible = get_normal_and_bible(
                    bible_names_personlabel_df,
                    min_length_ssa=ssa_min_len,
                    max_length_ssa=ssa_max_len,
                    min_year_ssa=ssa_min_year,
                    max_year_ssa=ssa_max_year,
                    min_length_bible=min_bible_len,
                    max_length_bible=max_bible_len,
                    ssa_popularity_percentile=(pop_low, pop_high),
                    sex=sex if sex in {"M", "F"} else None,
                    forbidden_names=forbidden_names,
                    debug=(i == 0 and debug_flag),
                )
                # Without a user-supplied surname, draw a fresh one for every name.
                surname = fixed_last if fixed_last is not None else random.choice(last_names)
                first, middle = (bible, normal) if bible_names_first_flag else (normal, bible)
                results.append(f"{first} {middle} {surname}")
            except Exception as e:
                # UI boundary: report the failure inline and keep generating the rest.
                results.append(f"[Error: {e}]")
    return "\n".join(results), debug_output.getvalue()
# NOTE(review): which Row/Tab each widget sits in is inferred; confirm it matches
# the intended layout.
with gr.Blocks() as demo:
    with gr.Tabs():
        with gr.Tab("🔀 Generate Names"):
            gr.Markdown("# 📜 Random Bible + SSA Name Generator")
            with gr.Row():
                n_slider = gr.Slider(1, 20, value=5, step=1, label="How many names?")
                sex_choice = gr.Radio(["M", "F", "Any"], label="Sex", value="Any")
            with gr.Row():
                ssa_min_len = gr.Slider(1, 40, value=1, step=1, label="SSA name min length")
                ssa_max_len = gr.Slider(1, 40, value=40, step=1, label="SSA name max length")
            with gr.Row():
                ssa_min_year = gr.Slider(1880, 2024, value=1880, step=1, label="SSA name min year")
                ssa_max_year = gr.Slider(1880, 2024, value=2024, step=1, label="SSA name max year")
            with gr.Row():
                bible_len = gr.Slider(1, 40, value=1, step=1, label="Bible name min length")
                bible_max_len = gr.Slider(1, 40, value=40, step=1, label="Bible name max length")
            with gr.Row():
                pop_low_slider = gr.Slider(0.0, 1.0, value=0.95, step=0.01, label="SSA Popularity: Low Percentile")
                pop_high_slider = gr.Slider(0.0, 1.0, value=1.0, step=0.01, label="SSA Popularity: High Percentile")
            with gr.Row():
                last_name_input = gr.Textbox(label="Last Name")
            with gr.Row():
                # sorted(): iterating a set of strings follows hash order, which Python
                # randomizes per process, so the default text would otherwise
                # reshuffle on every restart.
                forbidden_names_input = gr.Textbox(
                    label="FORBIDDEN NAMES (comma-separated)",
                    value=",".join(sorted(FORBIDDEN_NAMES)),
                )
            debug_checkbox = gr.Checkbox(label="Show debug output", value=True)
            bible_name_first_checkbox = gr.Checkbox(label="Bible name first?", value=True)
            generate_btn = gr.Button("🔀 Generate Names")
            output_box = gr.Textbox(label="Generated Names", lines=10)
            debug_box = gr.Textbox(label="Debug Output", lines=10)
            # Input order must match generate_names' positional parameters.
            generate_btn.click(
                fn=generate_names,
                inputs=[
                    n_slider,
                    sex_choice,
                    ssa_min_len,
                    ssa_max_len,
                    ssa_min_year,
                    ssa_max_year,
                    bible_len,
                    bible_max_len,
                    pop_low_slider,
                    pop_high_slider,
                    debug_checkbox,
                    last_name_input,
                    forbidden_names_input,
                    bible_name_first_checkbox,
                ],
                outputs=[output_box, debug_box],
            )
        with gr.Tab("📈 Name Trends"):
            gr.Markdown("# 📈 SSA Name Trends Over Time")
            with gr.Row():
                trend_names_input = gr.Textbox(label="Name(s) to plot (comma-separated)", placeholder="e.g. Zebediah, Remington, Jessca, Jielle")
            with gr.Row():
                trend_start_year = gr.Slider(1880, 2024, value=1950, step=1, label="Start Year")
                trend_end_year = gr.Slider(1880, 2024, value=2024, step=1, label="End Year")
            trend_logscale = gr.Checkbox(label="Log scale?", value=False)
            plot_button = gr.Button("📊 Plot Trends")
            plot_output = gr.Plot(label="Trend Plot")
            table_output = gr.Dataframe(label="Underlying Data")
            plot_button.click(
                fn=plot_from_inputs,
                inputs=[trend_names_input, trend_start_year, trend_end_year, trend_logscale],
                outputs=[plot_output, table_output],
            )
demo.launch()