cdleong commited on
Commit
bbeab9b
·
verified ·
1 Parent(s): ec2762f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -31
app.py CHANGED
@@ -7,7 +7,7 @@ import contextlib
7
  import requests
8
  import random
9
  from functools import lru_cache
10
- import matplotlib.pyplot as plt
11
 
12
  FORBIDDEN_NAMES ={"Judas",
13
  "Judas Iscariot"
@@ -52,29 +52,6 @@ FORBIDDEN_NAMES ={"Judas",
52
  "Pontius Pilate"
53
  }
54
 
55
- # ----------------------------- Plotting
56
-
57
-
58
- def plot_name_trends(df, names, start_year=None, end_year=None, logscale=False):
59
- name_df = df[df["name"].isin(names)]
60
- if start_year is not None:
61
- name_df = name_df[name_df["year"] >= start_year]
62
- if end_year is not None:
63
- name_df = name_df[name_df["year"] <= end_year]
64
-
65
- pivot = name_df.pivot_table(index="year", columns="name", values="count", aggfunc="sum", fill_value=0)
66
-
67
- fig, ax = plt.subplots(figsize=(12, 6))
68
- pivot.plot(ax=ax)
69
- ax.set_title("Name Usage Over Time")
70
- ax.set_xlabel("Year")
71
- ax.set_ylabel("Count")
72
- if logscale:
73
- ax.set_yscale("log")
74
- ax.grid(True)
75
- ax.legend(title="Name")
76
- fig.tight_layout()
77
- return fig
78
 
79
 
80
  # --- File download & setup ---
@@ -243,6 +220,47 @@ def get_normal_and_bible(
243
 
244
  return ssa_name, bible_name
245
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
246
  # --- Gradio app ---
247
  def generate_names(n, sex, ssa_min_len, ssa_max_len,
248
  ssa_min_year,
@@ -350,13 +368,6 @@ with gr.Blocks() as demo:
350
  plot_button = gr.Button("📊 Plot Trends")
351
  plot_output = gr.Plot(label="Trend Plot")
352
 
353
- def plot_from_inputs(name_text, start_year, end_year, logscale):
354
- names = [n.strip() for n in name_text.split(",") if n.strip()]
355
- if not names:
356
- raise gr.Error("Please enter at least one name.")
357
- full_df = load_all_ssa_names()
358
- return plot_name_trends(full_df, names, start_year, end_year, logscale)
359
-
360
  plot_button.click(
361
  fn=plot_from_inputs,
362
  inputs=[trend_names_input, trend_start_year, trend_end_year, trend_logscale],
 
7
  import requests
8
  import random
9
  from functools import lru_cache
10
+ import plotly.express as px
11
 
12
  FORBIDDEN_NAMES ={"Judas",
13
  "Judas Iscariot"
 
52
  "Pontius Pilate"
53
  }
54
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
 
56
 
57
  # --- File download & setup ---
 
220
 
221
  return ssa_name, bible_name
222
 
223
+
224
+
225
+ # -------------------- Plotting ---
226
+ def plot_name_trends_plotly(df, names, start_year=None, end_year=None, logscale=False):
227
+ # Filter for selected names
228
+ name_df = df[df["name"].isin(names)]
229
+ if start_year is not None:
230
+ name_df = name_df[name_df["year"] >= start_year]
231
+ if end_year is not None:
232
+ name_df = name_df[name_df["year"] <= end_year]
233
+
234
+ if name_df.empty:
235
+ raise gr.Error("No data for selected names and year range.")
236
+
237
+ # Aggregate by year and name
238
+ agg_df = (
239
+ name_df.groupby(["year", "name"])["count"]
240
+ .sum()
241
+ .reset_index()
242
+ )
243
+
244
+ fig = px.line(
245
+ agg_df,
246
+ x="year",
247
+ y="count",
248
+ color="name",
249
+ title="Name Usage Over Time",
250
+ labels={"count": "Count", "year": "Year"},
251
+ log_y=logscale,
252
+ )
253
+ fig.update_layout(height=500)
254
+ return fig
255
+
256
+
257
+ def plot_from_inputs(name_text, start_year, end_year, logscale):
258
+ names = [n.strip() for n in name_text.split(",") if n.strip()]
259
+ if not names:
260
+ raise gr.Error("Please enter at least one name.")
261
+ full_df = load_all_ssa_names()
262
+ return plot_name_trends_plotly(full_df, names, start_year, end_year, logscale)
263
+
264
  # --- Gradio app ---
265
  def generate_names(n, sex, ssa_min_len, ssa_max_len,
266
  ssa_min_year,
 
368
  plot_button = gr.Button("📊 Plot Trends")
369
  plot_output = gr.Plot(label="Trend Plot")
370
 
 
 
 
 
 
 
 
371
  plot_button.click(
372
  fn=plot_from_inputs,
373
  inputs=[trend_names_input, trend_start_year, trend_end_year, trend_logscale],