arjunguha committed
Commit 457470f · 1 Parent(s): a92220b
Files changed (3):
  1. app.py +120 -30
  2. metrics.py +79 -5
  3. puzzles_cleaned.csv +2 -2
app.py CHANGED
@@ -19,7 +19,10 @@ app that displays the following:
 Note that not every model has a response for every puzzle.
 """
 import gradio as gr
-from metrics import load_results
+import pandas as pd
+import numpy as np
+from metrics import load_results, accuracy_by_model_and_time
+import metrics


 def get_model_response(prompt_id, model_name):
@@ -39,12 +42,18 @@ def display_puzzle(puzzle_id):
     puzzle = conn.sql(query).fetchone()
     return puzzle if puzzle else (None, None,None, None, None)

-def display_model_response(puzzle_id, model_name):
+def display_model_response(puzzle_id, model_name, show_thoughts):
     response = get_model_response(puzzle_id, model_name)
+    if response is None:
+        return "No response from this model."
     split_thoughts = response.split("</think>")
     if len(split_thoughts) > 1:
-        response = split_thoughts[-1].strip()
-    return "From " + model_name + ":\n" + response if response else "No response from this model."
+        if show_thoughts:
+            return response.strip()
+        else:
+            return split_thoughts[-1].strip()
+    else:
+        return response.strip()


 conn = load_results()
@@ -108,36 +117,117 @@ model_columns = {
 valid_model_indices = list(model_columns.keys())
 default_model = model_columns[valid_model_indices[0]]

-def create_interface():
-    with gr.Blocks() as demo:
-        # Using "markdown" as the datatype makes Gradio interpret newlines.
-        puzzle_list = gr.DataFrame(
-            value=relabelled_df,
-            datatype=["number", "str", "markdown", *["str"] * len(model_correct_columns)],
-            # headers=["ID", "Challenge", "Answer", *cleaned_model_names],
-        )
-        model_response = gr.Textbox(label="Model Response", interactive=False)
-        challenge = gr.Textbox(label="Challenge", interactive=False)
-        answer = gr.Textbox(label="Answer", interactive=False)
-        explanation = gr.Textbox(label="Explanation", interactive=False)
-        editors_note = gr.Textbox(label="Editor's Note", interactive=False)
+def summary_view():
+    accuracy_over_time = accuracy_by_model_and_time(conn).to_df()
+    accuracy_over_time["model"] = accuracy_over_time["model"].apply(lambda x: x.replace("completions-", ""))
+    # This hack so that Gradio doesn't render a year 2020 as "2,020.0".
+    accuracy_over_time["year"] = accuracy_over_time["year"].astype(str)
+    accuracy_over_time.rename(columns={"model": "Model", "year": "Year", "accuracy": "Accuracy"}, inplace=True)
+    gr.LinePlot(
+        accuracy_over_time,
+        x="Year",
+        y="Accuracy",
+        color="Model",
+        title="Model Accuracy Over Time",
+        y_lim=[0, 1],
+        x_label="Year",
+        y_label="Accuracy",
+    )
+
+
+def r1_accuracy_by_completion_length():
+    r1_completions = metrics.r1_accuracy_by_completion_length(conn).to_df()
+    r1_completions["length"] = r1_completions["length"] / 3.2
+    r1_completions.rename(columns={"length": "Response Length", "cumulative_correct": "Cumulative Correct"}, inplace=True)
+
+    gr.LinePlot(
+        r1_completions,
+        x="Response Length",
+        y="Cumulative Correct",
+        title="R1 Accuracy by Completion Length",
+        x_label="Max Response Length (tokens)",
+        y_label="# Correct Answers",
+        x_lim=[0, 32_768],
+    )
+
+def all_challenges_view():
+    # Using "markdown" as the datatype makes Gradio interpret newlines.
+    puzzle_list = gr.DataFrame(
+        value=relabelled_df,
+        datatype=["number", "str", "markdown", *["str"] * len(model_correct_columns)],
+        # headers=["ID", "Challenge", "Answer", *cleaned_model_names],
+    )
+    with gr.Row(scale=2):
+        model_name = gr.State(value=default_model)
+        challenge_id = gr.State(value=0)
+        show_thoughts = gr.State(value=False)
+        with gr.Column():
+            challenge = gr.Textbox(label="Challenge", interactive=False)
+            answer = gr.Textbox(label="Answer", interactive=False)
+            explanation = gr.Textbox(label="Explanation", interactive=False)
+            editors_note = gr.Textbox(label="Editor's Note", interactive=False)
+        with gr.Column():
+            gr.Checkbox(
+                label="Show Thoughts", value=False
+            ).change(
+                fn=lambda x: x, inputs=[show_thoughts], outputs=[show_thoughts]
+            )
+            model_response = gr.Textbox(label="Model Response", interactive=False)
         transcript = gr.Textbox(label="Transcript", interactive=False)
-
-        def update_puzzle(evt: gr.SelectData):
-            row = evt.index[0]
-            model_index = evt.index[1]
-            model_name = model_columns[model_index] if model_index in valid_model_indices else default_model
-            return (*display_puzzle(row), display_model_response(row, model_name))
-
-        puzzle_list.select(
-            fn=update_puzzle,
-            inputs=[],
-            outputs=[challenge, answer, transcript, explanation, editors_note, model_response]
-        )
+
+    def select_table_item(evt: gr.SelectData):
+        model_index = evt.index[1]
+        challenge_id = evt.index[0]
+        model_name = model_columns[model_index] if model_index in valid_model_indices else default_model
+        return (model_name, challenge_id)
+
+    def update_puzzle(challenge_id: str, model_name: str, show_thoughts: bool):
+        return (*display_puzzle(challenge_id),
+            gr.Textbox(
+                value=display_model_response(challenge_id, model_name, show_thoughts),
+                label=model_name
+            ))
+
+    puzzle_list.select(
+        fn=select_table_item,
+        inputs=[],
+        outputs=[model_name, challenge_id]
+    )
+
+    model_name.change(
+        fn=update_puzzle,
+        inputs=[challenge_id, model_name, show_thoughts],
+        outputs=[challenge, answer, transcript, explanation, editors_note, model_response]
+    )
+
+    challenge_id.change(
+        fn=update_puzzle,
+        inputs=[challenge_id, model_name, show_thoughts],
+        outputs=[challenge, answer, transcript, explanation, editors_note, model_response]
+    )
+
+    show_thoughts.change(
+        fn=update_puzzle,
+        inputs=[challenge_id, model_name, show_thoughts],
+        outputs=[challenge, answer, transcript, explanation, editors_note, model_response]
+    )
+

-    demo.launch()


+def create_interface():
+    with gr.Blocks() as demo:
+        with gr.Tabs():
+            with gr.TabItem("All Challenges"):
+                all_challenges_view()
+            with gr.TabItem("Accuracy by Model"):
+                gr.DataFrame(metrics.accuracy_by_model(conn).to_df())
+            with gr.TabItem("Accuracy Over Time"):
+                summary_view()
+            with gr.TabItem("DeepSeek R1 Analysis"):
+                r1_accuracy_by_completion_length()
+    demo.launch()
+

 if __name__ == "__main__":
     create_interface()

 
metrics.py CHANGED
@@ -2,6 +2,7 @@ import re
 import duckdb
 import textwrap
 from typing import List, Tuple
+import argparse

 def _parse_answer(text: str) -> List[List[str]]:
     """
@@ -55,15 +56,81 @@ def _wrap_text(text: str, width: int) -> str:

 def load_results():
     conn = duckdb.connect(":memory:")
-    conn.execute("ATTACH DATABASE 'results.duckdb' AS results")
+    conn.execute("ATTACH DATABASE 'results.duckdb' AS results (READ_ONLY)")
     conn.execute("CREATE TABLE challenges as SELECT * FROM 'puzzles_cleaned.csv'")
     conn.create_function("check_answer", _check_answer)
     conn.create_function("clip_text", _clip_text)
     conn.create_function("wrap_text", _wrap_text)
     return conn

-def accuracy_by_model(conn):
+def r1_accuracy_by_completion_length(conn):
+    """
+    For the responses from the completions-r1 model:
+    1. We calculate completion length and correctness for each problem.
+    2. We sort by length.
+    3. We compute cumulative number of correct responses.
+    """
+    # Use CTEs
+    r1_completions = conn.sql("""
+        WITH LengthsAndCorrectness AS (
+            SELECT
+                LENGTH(results.completion) AS length,
+                CAST(check_answer(results.completion, challenges.answer) AS INT32) AS correct
+            FROM results.completions results JOIN challenges
+            ON results.prompt_id = challenges.ID
+            WHERE results.parent_dir = 'completions-r1'
+        )
+        SELECT
+            length,
+            COUNT(*) OVER (ORDER BY length) AS cumulative_correct
+        FROM LengthsAndCorrectness
+    """)
+    return r1_completions
+
+
+def accuracy_by_model_and_time(conn):
     model_accuracies = conn.sql("""
+        WITH ChallengesWithDates AS (
+            SELECT
+                ID,
+                answer,
+                EXTRACT(YEAR FROM CAST(date AS DATE)) AS year
+            FROM
+                challenges
+        ),
+        DateAnswerCheck AS (
+            SELECT
+                results.parent_dir AS model,
+                dates.year,
+                COUNT(*) AS total,
+                SUM(CAST(check_answer(results.completion, dates.answer) AS INTEGER)) AS correct
+            FROM
+                results.completions results
+            JOIN
+                ChallengesWithDates dates
+            ON
+                results.prompt_id = dates.ID
+            GROUP BY
+                results.parent_dir,
+                dates.year
+        )
+        SELECT
+            model,
+            year,
+            total,
+            correct,
+            ROUND(correct / total, 2) AS accuracy
+        FROM
+            DateAnswerCheck
+        ORDER BY
+            model,
+            year
+    """)
+
+    return model_accuracies
+
+def accuracy_by_model(conn):
+    return conn.sql("""
         WITH AnswerCheck AS (
             SELECT
                 results.parent_dir AS model,
@@ -87,8 +154,15 @@ def accuracy_by_model(conn):
            AnswerCheck
     """)

-    print(model_accuracies)
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--by-model-and-time", action="store_true")
+    args = parser.parse_args()
+    conn = load_results()
+    if args.by_model_and_time:
+        print(accuracy_by_model_and_time(conn))
+    else:
+        print(accuracy_by_model(conn))

 if __name__ == "__main__":
-    conn = load_results()
-    accuracy_by_model(conn)
+    main()
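A short usage sketch for the new metrics entry points; it assumes results.duckdb and puzzles_cleaned.csv sit next to metrics.py, which is what load_results expects:

    # Programmatic use of the queries added in this commit.
    from metrics import load_results, accuracy_by_model_and_time, r1_accuracy_by_completion_length

    conn = load_results()

    # Per-model, per-year accuracy as a pandas DataFrame (feeds the "Accuracy Over Time" tab).
    print(accuracy_by_model_and_time(conn).to_df().head())

    # Cumulative correct answers for completions-r1, ordered by completion length.
    print(r1_accuracy_by_completion_length(conn).to_df().tail())

    # Equivalent command line:
    #   python metrics.py                      # accuracy by model
    #   python metrics.py --by-model-and-time  # accuracy by model and year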
 
puzzles_cleaned.csv CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:7efd3a2897270124ecc8a299b96d14fb54600f3c0faf27b790d8b0312720f3cd
-size 1132332
+oid sha256:257753179c4b2a5be8716ac03da2617c48d9037290cc39b4896ad55304e13337
+size 1119397
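puzzles_cleaned.csv is tracked with Git LFS, so the diff above only touches the pointer file. A small sketch, assuming the real CSV has been pulled into the working directory, for checking that the downloaded data matches the new pointer:

    # Verify a pulled puzzles_cleaned.csv against the LFS pointer's oid and size above.
    import hashlib

    EXPECTED_OID = "257753179c4b2a5be8716ac03da2617c48d9037290cc39b4896ad55304e13337"
    EXPECTED_SIZE = 1119397

    with open("puzzles_cleaned.csv", "rb") as f:
        data = f.read()

    assert len(data) == EXPECTED_SIZE, f"unexpected size: {len(data)}"
    assert hashlib.sha256(data).hexdigest() == EXPECTED_OID, "sha256 does not match the pointer oid"
    print("puzzles_cleaned.csv matches the updated LFS pointer")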