cdleong commited on
Commit
0e7a69a
·
verified ·
1 Parent(s): 2151157

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -10
app.py CHANGED
@@ -223,8 +223,9 @@ def get_normal_and_bible(
223
 
224
 
225
  # -------------------- Plotting ---
 
 
226
  def plot_name_trends_plotly(df, names, start_year=None, end_year=None, logscale=False):
227
- # Filter for selected names
228
  name_df = df[df["name"].isin(names)]
229
  if start_year is not None:
230
  name_df = name_df[name_df["year"] >= start_year]
@@ -234,26 +235,45 @@ def plot_name_trends_plotly(df, names, start_year=None, end_year=None, logscale=
234
  if name_df.empty:
235
  raise gr.Error("No data for selected names and year range.")
236
 
237
- # Aggregate by year and name
238
  agg_df = (
239
  name_df.groupby(["year", "name"])["count"]
240
  .sum()
241
  .reset_index()
242
  )
243
 
244
- fig = px.line(
245
- agg_df,
246
- x="year",
247
- y="count",
248
- color="name",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
249
  title="Name Usage Over Time",
250
- labels={"count": "Count", "year": "Year"},
251
- log_y=logscale,
 
 
252
  )
253
- fig.update_layout(height=500)
254
  return fig, agg_df
255
 
256
 
 
257
  def plot_from_inputs(name_text, start_year, end_year, logscale):
258
  names = [n.strip() for n in name_text.split(",") if n.strip()]
259
  if not names:
 
223
 
224
 
225
  # -------------------- Plotting ---
226
+ import plotly.graph_objects as go
227
+
228
  def plot_name_trends_plotly(df, names, start_year=None, end_year=None, logscale=False):
 
229
  name_df = df[df["name"].isin(names)]
230
  if start_year is not None:
231
  name_df = name_df[name_df["year"] >= start_year]
 
235
  if name_df.empty:
236
  raise gr.Error("No data for selected names and year range.")
237
 
 
238
  agg_df = (
239
  name_df.groupby(["year", "name"])["count"]
240
  .sum()
241
  .reset_index()
242
  )
243
 
244
+ # Build figure manually for better control
245
+ fig = go.Figure()
246
+ for name in sorted(agg_df["name"].unique()):
247
+ sub_df = agg_df[agg_df["name"] == name]
248
+ if len(sub_df) > 1:
249
+ fig.add_trace(go.Scatter(
250
+ x=sub_df["year"],
251
+ y=sub_df["count"],
252
+ mode="lines+markers",
253
+ name=name
254
+ ))
255
+ else:
256
+ # Jessca
257
+ fig.add_trace(go.Scatter(
258
+ x=sub_df["year"],
259
+ y=sub_df["count"],
260
+ mode="markers",
261
+ name=name,
262
+ marker=dict(size=10, symbol="circle"),
263
+ ))
264
+
265
+ fig.update_layout(
266
  title="Name Usage Over Time",
267
+ xaxis_title="Year",
268
+ yaxis_title="Count",
269
+ height=500,
270
+ yaxis_type="log" if logscale else "linear",
271
  )
272
+
273
  return fig, agg_df
274
 
275
 
276
+
277
  def plot_from_inputs(name_text, start_year, end_year, logscale):
278
  names = [n.strip() for n in name_text.split(",") if n.strip()]
279
  if not names: