leolaish committed
Commit e9f0f5e · verified · Parent(s): 7e90894

Upload folder using huggingface_hub

Files changed (5)
  1. .github/workflows/update_space.yml +28 -0
  2. README.md +1 -7
  3. app.css +197 -0
  4. app.py +937 -0
  5. requirements.txt +4 -0
.github/workflows/update_space.yml ADDED
@@ -0,0 +1,28 @@
+ name: Run Python script
+
+ on:
+   push:
+     branches:
+       - no
+
+ jobs:
+   build:
+     runs-on: ubuntu-latest
+
+     steps:
+       - name: Checkout
+         uses: actions/checkout@v2
+
+       - name: Set up Python
+         uses: actions/setup-python@v2
+         with:
+           python-version: '3.9'
+
+       - name: Install Gradio
+         run: python -m pip install gradio
+
+       - name: Log in to Hugging Face
+         run: python -c 'import huggingface_hub; huggingface_hub.login(token="${{ secrets.hf_token }}")'
+
+       - name: Deploy to Spaces
+         run: gradio deploy
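
For reference, the last two workflow steps map onto ordinary `huggingface_hub` / Gradio calls. A rough local equivalent (illustrative only, assuming a valid `HF_TOKEN` in the environment and the app in the current directory):

```python
# Illustrative sketch mirroring the CI steps "Log in to Hugging Face"
# and "Deploy to Spaces" from the workflow above.
import os
import subprocess

import huggingface_hub

# Same call the workflow makes with ${{ secrets.hf_token }}
huggingface_hub.login(token=os.environ["HF_TOKEN"])

# `gradio deploy` packages the current directory and pushes it to the Space
subprocess.run(["gradio", "deploy"], check=True)
```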
README.md CHANGED
@@ -1,12 +1,6 @@
  ---
  title: MathsPro
- emoji: 🏆
- colorFrom: yellow
- colorTo: indigo
+ app_file: app.py
  sdk: gradio
  sdk_version: 4.38.1
- app_file: app.py
- pinned: false
  ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.css ADDED
@@ -0,0 +1,197 @@
+ .title-content {
+     font-size: 24px;
+     font-weight: 600;
+     display: flex;
+     align-items: center;
+     justify-content: center;
+ }
+
+ .sub-title-content {
+     display: flex;
+     align-items: center;
+     justify-content: center;
+     font-size: 18px;
+ }
+
+ .main-area {
+     border-radius: 10px;
+     padding: 20px;
+     border: 1px solid rgba(0, 0, 0, 0.15);
+ }
+
+ .probelm-example-container {
+     position: relative;
+     padding: 38px 12px 12px 12px;
+     border: 1px solid rgba(0, 0, 0, 0.15);
+     border-radius: 10px;
+ }
+
+ .probelm-example-title-content {
+     position: absolute;
+     top: 0;
+     left: 0;
+     border-right: 1px solid rgba(0, 0, 0, 0.15);
+     border-bottom: 1px solid rgba(0, 0, 0, 0.15);
+     border-radius: 10px 0 10px 0;
+     padding: 6px 12px;
+ }
+
+ .probelm-example-another {
+     position: absolute;
+     top: 0;
+     right: 50px;
+     border-left: 1px solid rgba(0, 0, 0, 0.15);
+     border-bottom: 1px solid rgba(0, 0, 0, 0.15);
+     border-top: 0;
+     border-radius: 0 0 0 10px;
+     display: flex;
+     align-items: center;
+     justify-content: center;
+     height: 26.89px;
+     width: 36px;
+     cursor: pointer;
+     font-size: 12px;
+     background: transparent;
+     min-width: auto;
+     padding: 0;
+ }
+
+ .probelm-example-another > img {
+     height: 16px;
+     width: 16px;
+     margin-right: 0;
+ }
+
+ .probelm-example-copy {
+     position: absolute;
+     top: 0;
+     right: 0;
+     border-left: 1px solid rgba(0, 0, 0, 0.15);
+     border-bottom: 1px solid rgba(0, 0, 0, 0.15);
+     border-top: 0;
+     border-right: 0;
+     border-radius: 0 10px 0 0;
+     padding: 4px 12px;
+     cursor: pointer;
+     font-size: 12px;
+     background: transparent;
+     min-width: auto;
+ }
+
+ .right {
+     border: 1px solid rgba(0, 0, 0, 0.15);
+     border-radius: 10px;
+     position: relative;
+     padding: 40px 14px 14px 14px;
+     min-height: 500px;
+ }
+
+ .solution-title-content {
+     position: absolute;
+     top: 0;
+     left: 0;
+     border-right: 1px solid rgba(0, 0, 0, 0.15);
+     border-bottom: 1px solid rgba(0, 0, 0, 0.15);
+     border-radius: 10px 0 10px 0;
+     padding: 4px 8px;
+ }
+
+ .solution-content {
+     border-radius: 10px;
+     border: 1px solid rgba(0, 0, 0, 0.15);
+     height: 480px;
+     display: flex;
+     flex-direction: column-reverse;
+     overflow: hidden;
+     overflow-y: auto;
+     scroll-snap-type: y mandatory;
+ }
+
+ .solution-content .solution-content {
+     padding: 12px;
+     display: flex;
+     scroll-snap-align: end;
+ }
+
+ .run-btn {
+     background: linear-gradient(to right, #ce7e53, #bb470b);
+     color: white;
+ }
+
+ .running-btn {
+     background: linear-gradient(to right, #ce7e53, #bb470b);
+     color: white;
+     display: flex;
+     align-items: center;
+     justify-content: center;
+     position: relative;
+ }
+
+ .running-btn::before {
+     content: "";
+     position: absolute;
+     width: 22px;
+     height: 22px;
+     border-radius: 50%;
+     border: 3px solid white;
+     border-top-color: transparent;
+     /* border-right-color: transparent;
+        border-bottom-color: transparent; */
+     animation: spin 1s linear infinite;
+ }
+
+ @keyframes spin {
+     0% {
+         transform: rotate(0deg);
+     }
+     100% {
+         transform: rotate(360deg);
+     }
+ }
+
+ .probelm-input-container {
+     border: 1px solid rgba(0, 0, 0, 0.15);
+     border-radius: 10px;
+ }
+
+ .probelm-input-container .probelm-input-container {
+     border: none;
+ }
+
+ .problem-input-markdown {
+     padding: 38px 12px 12px 12px;
+     border: 1px solid rgba(0, 0, 0, 0.15);
+     border-radius: 10px;
+     position: relative;
+     max-height: max-content;
+     overflow-y: auto;
+ }
+
+ .problem-input-markdown::before {
+     content: "Problem rendered";
+     border: 1px solid rgba(0, 0, 0, 0.15);
+     border-radius: 10px 0 10px 0;
+     position: absolute;
+     left: 0;
+     top: -1px;
+     border-right: 1px solid rgba(0, 0, 0, 0.15);
+     border-bottom: 1px solid rgba(0, 0, 0, 0.15);
+     border-left: 0;
+     border-top: 0;
+     padding: 6px 12px;
+ }
+
+ .problem-input-markdown .problem-input-markdown {
+     border: none;
+     padding: 0;
+     max-height: 280px;
+     min-height: 150px;
+ }
+
+ .problem-input-markdown .problem-input-markdown::before {
+     content: "";
+     border: none;
+     display: none;
+ }
app.py ADDED
@@ -0,0 +1,937 @@
+ import gradio as gr
+
+ from dataclasses import dataclass
+ from concurrent.futures import ThreadPoolExecutor, TimeoutError
+ from huggingface_hub import InferenceClient
+ import os
+ import re
+ import subprocess
+ import tempfile
+ import json
+ import datasets
+ from datasets import load_dataset
+ from datasets import Value, Features
+ import random
+ import time
+ from typing import Tuple, Dict, Any, List
+ from sympy import N, simplify
+ from sympy.parsing.latex import parse_latex
+ # from openai import OpenAI
+ import base64
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
+ from transformers import AutoModelForPreTraining
+
+ # client = OpenAI(
+ #     base_url=os.environ.get("SERVER_URL"),
+ #     api_key=os.environ.get("HF_TOKEN"),
+ # )
+ client = InferenceClient("mistralai/mathstral-7B-v0.1")
+
+
+
+ @dataclass
+ class Config:
+     debug: bool = False
+     push_to_hub: bool = False
+     model_id: str = None
+     revision: str = None
+     system_prompt: str = None
+     validation_set: str = None
+     is_quantized: bool = False
+     restart_on_fail: bool = False
+     is_submission: bool = False
+     num_samples: int = 1
+     num_generations: int = 1
+     do_sample: bool = True
+     temperature: float = 1.0
+     top_p: float = 0.9
+     top_k: int = 50
+     max_new_tokens: int = 100
+     # Load the pre-trained math BERT model and tokenizer
+     tokenizer = AutoTokenizer.from_pretrained("AnReu/math_pretrained_bert")
+     model = AutoModelForPreTraining.from_pretrained("AnReu/math_pretrained_bert")
+
+ class PythonREPL:
+     def __init__(self, timeout=5):
+         self.timeout = timeout
+
+     def execute(self, query: str) -> Tuple[bool, str]:
+         query = "import math\nimport numpy as np\nimport sympy as sp\n" + query
+         query = query.strip().split("\n")
+         if "print(" not in query[-1]:
+             if "#" in query[-1]:
+                 query[-1] = query[-1].split("#")[0]
+             query[-1] = "print(" + query[-1] + ")"
+         query = "\n".join(query)
+
+         with tempfile.TemporaryDirectory() as temp_dir:
+             temp_file_path = os.path.join(temp_dir, "tmp.py")
+
+             with open(temp_file_path, "w") as f:
+                 f.write(query)
+
+             result = subprocess.run(
+                 ["python3", temp_file_path],
+                 capture_output=True,
+                 check=False,
+                 text=True,
+                 timeout=self.timeout,
+             )
+
+             if result.returncode == 0:
+                 output = result.stdout
+                 return True, output.strip()
+             else:
+                 error_msg = result.stderr.strip()
+                 msgs = error_msg.split("\n")
+                 new_msgs = []
+                 want_next = False
+                 for m in msgs:
+                     if "Traceback" in m:
+                         new_msgs.append(m)
+                     elif m == msgs[-1]:
+                         new_msgs.append(m)
+                     elif temp_file_path in m:
+                         st = m.index('"/') + 1 if '"/' in m else 0
+                         ed = m.index(temp_file_path) + 1 if temp_file_path in m else None
+                         clr = m[st:ed] if not ed else m[st:]
+                         m = m.replace(clr, "")
+                         new_msgs.append(m)
+                         want_next = True
+                     elif want_next:
+                         new_msgs.append(m)
+                         want_next = False
+                 error_msg = "\n".join(new_msgs)
+                 return False, error_msg.strip()
+
+     def __call__(self, query: str) -> Tuple[bool, str]:
+         with ThreadPoolExecutor() as executor:
+             future = executor.submit(self.execute, query)
+             try:
+                 return future.result(timeout=self.timeout)
+             except (TimeoutError, subprocess.TimeoutExpired):
+                 # subprocess.run can raise TimeoutExpired before the future's own timeout fires
+                 return False, f"Timed out after {self.timeout} seconds."
+
+
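
A quick illustration of the sandbox above (illustrative, not part of the diff): the last bare expression is wrapped in `print(...)`, and failures come back as a cleaned-up traceback string.

```python
# Sketch: exercising PythonREPL locally.
repl = PythonREPL(timeout=5)
print(repl("x = 2 ** 10\nx"))  # (True, '1024') -- last line gets wrapped in print()
print(repl("1 / 0"))           # (False, 'Traceback ... ZeroDivisionError: division by zero')
```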
+ def execute_completion(
+     executor: PythonREPL,
+     completion: str,
+     return_status: bool = False,
+     last_code_block: bool = False,
+ ) -> str | Tuple[str, bool]:
+     # executions = ["!" + code for code in re.findall(r"```bash(.*?)```", completion, re.DOTALL) if "!" not in code]
+     executions = re.findall(r"```python(.*?)```", completion, re.DOTALL)
+
+     if len(executions) == 0:  # directly return cot result
+         return (completion, False) if return_status else completion
+     else:
+         if last_code_block:
+             executions = [executions[-1]]
+
+         # Python
+         execution_outputs = []
+         successes = []
+         for code in executions:
+             success = False
+
+             if "subprocess" in code:
+                 output = "subprocess is not allowed"
+                 execution_outputs.append(output)
+                 successes.append(success)
+                 continue
+
+             if "venv" in code:
+                 output = "venv is not allowed"
+                 execution_outputs.append(output)
+                 successes.append(success)
+                 continue
+
+             try:
+                 success, output = executor(code)
+             except TimeoutError as e:
+                 print("time out")
+                 output = e
+
+             if not success and not return_status:
+                 output = ""
+
+             execution_outputs.append(output)
+             successes.append(success)
+
+         output = str(execution_outputs[-1]).strip()
+         success = successes[-1]
+
+         if return_status:
+             return output, success
+         else:
+             return output
+
+
+ def postprocess_completion(
+     text: str, return_status: bool = False, last_code_block=False, timeout=5
+ ) -> str | Tuple[str, bool]:
+     executor = PythonREPL(timeout=timeout)
+
+     result = execute_completion(executor, text, return_status=return_status, last_code_block=last_code_block)
+     del executor
+
+     return result
+
+
+ def apply_template(example: Dict[str, Any], prompt: str) -> str:
+     return prompt.format(example["prompt"], "{}")
+
+
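
Taken together, these helpers extract the last Python block from a completion and run it in the sandbox. An illustrative call (a sketch, not part of the diff; the fence string is built at runtime only to avoid nesting backticks here):

```python
# Sketch: run the final code block of a model completion.
fence = "`" * 3
completion = f"Let's compute.\n{fence}python\nfrom sympy import factorial\nprint(factorial(5))\n{fence}"
result, ok = postprocess_completion(completion, return_status=True, last_code_block=True)
print(result, ok)  # 120 True
```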
+ def last_boxed_only_string(string):
+     """
+     Extracts the last LaTeX boxed or framed expression from a string.
+     Args:
+         string (str): The input string containing LaTeX expressions.
+     Returns:
+         str or None: The last boxed or framed expression, if found;
+         otherwise, None.
+     """
+
+     idx = string.rfind("\\boxed")
+     if idx < 0:
+         idx = string.rfind("\\fbox")
+         if idx < 0:
+             return None
+
+     i = idx
+     right_brace_idx = None
+     num_left_braces_open = 0
+     while i < len(string):
+         if string[i] == "{":
+             num_left_braces_open += 1
+         if string[i] == "}":
+             num_left_braces_open -= 1
+             if num_left_braces_open == 0:
+                 right_brace_idx = i
+                 break
+         i += 1
+
+     if right_brace_idx is None:
+         retval = None
+     else:
+         retval = string[idx : right_brace_idx + 1]
+
+     return retval
+
+
+ def remove_boxed(s):
+     """
+     Removes the LaTeX boxed command, returning the content inside the braces.
+     Args:
+         s (str): The string containing a LaTeX boxed expression.
+     Returns:
+         str or None: The content inside the boxed command, if valid;
+         otherwise, None.
+     """
+
+     left = "\\boxed{"
+     try:
+         assert s[: len(left)] == left
+         assert s[-1] == "}"
+         length = len(left)
+         return s[length:-1]
+     except Exception:
+         return None
+
+
+ def extract_boxed_answer(pred_str, strip_double_curly_brace=False):
+     """
+     Extracts the answer from a LaTeX boxed expression within
+     a prediction string.
+     Args:
+         pred_str (str): The string containing one or more LaTeX
+             boxed expressions.
+         strip_double_curly_brace (bool): If True, removes an additional
+             layer of braces.
+     Returns:
+         str or None: The extracted answer, if any; otherwise, None.
+     """
+
+     boxed_str = last_boxed_only_string(pred_str)
+     if boxed_str is None:
+         return None
+     answer = remove_boxed(boxed_str)
+     if answer is None:
+         return None
+     if strip_double_curly_brace:
+         match = re.match(r"^\{(.*)\}$", answer)
+         if match:
+             answer = match.group(1)
+     return answer
+
+
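
These three functions peel a final answer out of the last `\boxed{...}` in a generation, e.g. (illustrative):

```python
text = r"Thus the total is $\boxed{14}$."
print(last_boxed_only_string(text))  # '\boxed{14}'
print(extract_boxed_answer(text))    # '14'
```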
+ def normalize_final_answer(final_answer: str) -> str:
+     """
+     Normalizes a final answer string by removing or replacing various LaTeX
+     and text elements.
+     Args:
+         final_answer (str): The answer string to normalize.
+     Returns:
+         str: The normalized answer string.
+     """
+
+     match = re.search(r"(.*?)Problem:", final_answer, flags=re.S)
+     if match:
+         final_answer = match.group(1)  # keep everything before the first "Problem:"
+     # Normalize a final answer to a quantitative reasoning question.
+     # final_answer = final_answer.split('=')[-1]
+     SUBSTITUTIONS = [
+         ("an ", ""),
+         ("a ", ""),
+         (".$", "$"),
+         ("\\$", ""),
+         (r"\ ", ""),
+         (" ", ""),
+         ("mbox", "text"),
+         (",\\text{and}", ","),
+         ("\\text{and}", ","),
+         ("\\text{m}", "\\text{}"),
+         ("\\le", "<"),
+     ]
+     REMOVED_EXPRESSIONS = [
+         "square",
+         "ways",
+         "integers",
+         "dollars",
+         "mph",
+         "inches",
+         "ft",
+         "hours",
+         "km",
+         "units",
+         "\\ldots",
+         "sue",
+         "points",
+         "feet",
+         "minutes",
+         "digits",
+         "cents",
+         "degrees",
+         "cm",
+         "gm",
+         "pounds",
+         "meters",
+         "meals",
+         "edges",
+         "students",
+         "childrentickets",
+         "multiples",
+         "\\text{s}",
+         "\\text{.}",
+         "\\text{\ns}",
+         "\\text{}^2",
+         "\\text{}^3",
+         "\\text{\n}",
+         "\\text{}",
+         r"\mathrm{th}",
+         r"^\circ",
+         r"^{\circ}",
+         r"\;",
+         r",\!",
+         "{,}",
+         '"',
+         "\\dots",
+         "\n",
+         "\r",
+         "\f",
+         r"\%",
+     ]
+     for before, after in SUBSTITUTIONS:
+         final_answer = final_answer.replace(before, after)
+     for expr in REMOVED_EXPRESSIONS:
+         final_answer = final_answer.replace(expr, "")
+
+     # Extract answer that is in LaTeX math, is bold,
+     # is surrounded by a box, etc.
+     final_answer = re.sub(r"(\\text\{)(.*?)(\})", "\\2", final_answer)
+     final_answer = re.sub(r"(\\textbf\{)(.*?)(\})", "\\2", final_answer)
+     final_answer = re.sub(r"(\\overline\{)(.*?)(\})", "\\2", final_answer)
+     final_answer = re.sub(r"(\\boxed\{)(.*)(\})", "\\2", final_answer)
+     assert "\n" not in final_answer
+     assert "\r" not in final_answer
+     assert "\f" not in final_answer
+     if len(re.findall(r"finalansweris(.*)", final_answer)) > 0:
+         final_answer = re.findall(r"finalansweris(.*)", final_answer)[-1]
+
+     if len(re.findall(r"answer?is:?(.*)", final_answer)) > 0:
+         final_answer = re.findall(r"answer?is:?(.*)", final_answer)[-1]
+
+     if len(re.findall(r"oxed\{(.*?)\}", final_answer)) > 0:
+         final_answer = re.findall(r"oxed\{(.*?)\}", final_answer)[-1]
+
+     if len(re.findall(r"\$(.*?)\$", final_answer)) > 0:
+         final_answer = re.findall(r"\$(.*?)\$", final_answer)[-1]
+     final_answer = final_answer.strip()
+     if "rac" in final_answer and "\\frac" not in final_answer:
+         final_answer = final_answer.replace("rac", "\\frac")
+
+     final_answer = re.sub(r"(frac)([^{])(.)", "frac{\\2}{\\3}", final_answer)
+     final_answer = re.sub(r"(sqrt)([^{])", "sqrt{\\2}", final_answer)
+     final_answer = final_answer.replace("$", "")
+
+     if final_answer.replace(",", "").isdigit():
+         final_answer = final_answer.replace(",", "")
+
+     return final_answer
+
+
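
Illustrative inputs and outputs for the normalizer above (a sketch, not part of the diff):

```python
print(normalize_final_answer(r"The final answer is $\boxed{\frac{1}{2}}$"))  # '\frac{1}{2}'
print(normalize_final_answer("1,024 dollars"))                               # '1024'
```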
+ def naive_parse(answer: str) -> str:
+     """
+     Extracts the last contiguous run of digits from the input string by
+     scanning it in reverse, and returns those digits in their original order.
+
+     Args:
+         answer (str): The input string to parse.
+
+     Returns:
+         str: A string consisting of the numeric digits extracted from the input, in their original order.
+
+     Example:
+         >>> naive_parse("abc123def")
+         '123'
+         >>> naive_parse("def456ghi")
+         '456'
+         >>> naive_parse("no numbers here")
+         ''
+     """
+     out = []
+     start = False
+     end = False
+     for l in reversed(list(answer)):
+         if l in "0123456789" and not end:
+             start = True
+             out.append(l)
+         else:
+             if start:
+                 end = True
+
+     out = reversed(out)
+     return "".join(out)
+
+
+ def validate_answer_is_numeric(x: str | int | float) -> int:
+     FLOAT_TOLERANCE = 0.2
+     try:
+         f = float(x)
+         x = round(f)
+         if abs(x - f) > FLOAT_TOLERANCE:
+             x = -1
+     except Exception:
+         x = -1
+     return x
+
+
+ def filter_answers(answers: List[str]) -> List[int]:
+     formatted_answers = [validate_answer_is_numeric(a) for a in answers]
+
+     # Filter for non-negative answers
+     formatted_answers = [a for a in formatted_answers if a >= 0]
+     # Compute modulo
+     formatted_answers = [a % 1_000 for a in formatted_answers]
+     # Keep answers in the 0-999 range expected after the modulo
+     formatted_answers = [a for a in formatted_answers if a <= 999]
+     return formatted_answers
+
+
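
Used for majority voting: non-numeric and negative candidates are dropped and the rest reduced modulo 1000, e.g. (illustrative):

```python
print(filter_answers(["42", "-3", "1729", "abc"]))  # [42, 729]
```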
+ def check_sympy_equivalence(ref_answer: str, model_answer: str) -> bool:
+     def do_answers_match(ref_answer: str, model_answer: str) -> bool:
+         ref_sympy = parse_latex(ref_answer)
+         model_sympy = parse_latex(model_answer)
+         diff = simplify(ref_sympy - model_sympy)
+         return True if -1e-12 < N(diff) < 1e-12 or diff.is_zero else False
+
+     try:
+         result = do_answers_match(ref_answer, model_answer)
+         return result
+     except Exception as e:
+         print(e)
+         return False
+
+
+ def check_string_match(ref_answer: str, model_answer: str) -> bool:
+     try:
+         return ref_answer == model_answer
+     except Exception as e:
+         print(e)
+         return False
+
+
+ def check_answer(ref_answer: str, model_answer: str) -> bool:
+     # check if strings are the same
+     correct = check_string_match(ref_answer, model_answer)
+     if correct:
+         return True
+
+     # use the sympy library to check if the expressions are the same
+     correct = check_sympy_equivalence(ref_answer, model_answer)
+     if correct:
+         return True
+
+     return False
+
+
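
The fallback relies on `sympy.parsing.latex.parse_latex` (which needs the ANTLR4 Python runtime installed) to catch answers that differ only in form, e.g. (illustrative):

```python
print(check_answer(r"\frac{1}{2}", r"\frac{1}{2}"))  # True -- exact string match
print(check_answer(r"\frac{1}{2}", "0.5"))           # True -- sympy finds the difference is zero
```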
+ debug = False
+ model_id = "mathstral-7B-v0.1"
+ revision = "main"
+ system_prompt = "{}"
+ validation_set = "kaggle-validation-set-medium"
+ is_submission = True
+ num_samples = 4
+ num_generations = 4
+ temperature = 0.8
+ is_quantized = False
+ restart_on_fail = False
+ top_p = 1.0
+ top_k = 0
+ max_new_tokens = 2048
+ # Papermill related variables
+ push_to_hub = False
+ notebook_name = ""
+
+ config = Config(
+     debug=False,
+     push_to_hub=False,
+     model_id=model_id,
+     revision=revision,
+     system_prompt=system_prompt,
+     validation_set=validation_set,
+     is_quantized=is_quantized,
+     restart_on_fail=restart_on_fail,
+     is_submission=is_submission,
+     num_samples=num_samples,
+     num_generations=num_generations,
+     do_sample=True,
+     temperature=temperature,
+     top_p=top_p,
+     top_k=top_k,
+     max_new_tokens=max_new_tokens,
+ )
+
+
+ print(f"=== Running submission with config ===\n\n{config}")
+
+
+ def parse_data_chunk(data_chunk):
+     """
+     Parse a given data chunk string into a list of individual data entries.
+
+     The function splits the input string by the delimiter "data:" and removes any
+     leading or trailing whitespace from each resulting chunk. Empty chunks are
+     filtered out from the final list.
+
+     Parameters:
+         data_chunk (str): The input string containing data chunks separated by "data:".
+
+     Returns:
+         list: A list of individual data entries with whitespace stripped.
+     """
+     chunks = data_chunk.split("data:")
+     stripped_chunks = map(lambda chunk: chunk.strip(), chunks)
+     return [chunk for chunk in stripped_chunks if chunk]
+
+
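
`generate()` below feeds this function raw server-sent-events text; one network chunk can carry several `data:` records (illustrative payload):

```python
chunk = (
    'data: {"choices": [{"delta": {"content": "Hel"}}]}\n\n'
    'data: {"choices": [{"delta": {"content": "lo"}}]}\n\n'
)
for record in parse_data_chunk(chunk):
    print(record)  # each element is one JSON string, ready for json.loads
```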
+ def generate(message, temperature):
+     """
+     Generates a chat completion response by streaming data from the client chat model.
+
+     This function streams the response from the client chat model and yields the content
+     of the response chunk by chunk. If an error occurs, it yields the error message.
+
+     Parameters:
+         message (list): The chat messages to send to the model.
+         temperature (float): The sampling temperature to use. Higher values mean the model will take more risks.
+
+     Yields:
+         tuple: A tuple containing the content of the response and a boolean flag indicating if an error occurred.
+             If no error occurred, the boolean flag will be False and the content will be the response text.
+             If an error occurred, the boolean flag will be True and the content will be the error message.
+     """
+     stream = client.chat.completions.create(
+         model="tgi",
+         messages=message,
+         stream=False,
+         max_tokens=1024,
+         stop=["```output\n"],
+         temperature=temperature,
+         # timeout=30,
+     )
+
+     response = stream.response
+
+     # The reason why the library method is not used here is that if an error occurs,
+     # the returned data will not be a stream, and using the official library will result in an error.
+     for chunk in response.iter_bytes():
+         chunk = chunk.decode("utf-8")
+         data_chunks = parse_data_chunk(chunk)
+
+         try:
+             for data_chunk in data_chunks:
+                 chunk_json = json.loads(data_chunk)
+
+                 if "error" in chunk_json and chunk_json["error"]:
+                     yield chunk_json["error"], True
+                     break
+
+                 delta = chunk_json["choices"][0]["delta"]
+                 content = delta["content"] if "content" in delta else ""
+
+                 if content != "":
+                     yield content, False
+         except Exception as e:
+             print(f"func: generate error occurred\nchunk: {chunk}\nerror: {e}")
+             raise e
+
+
+ def get_majority_text(data):
+     from collections import Counter
+
+     # Count the frequency of each answer in model_answers
+     answer_counts = Counter(data["model_answers"])
+
+     # Find the majority response
+     majority_response = answer_counts.most_common(1)[0][0]
+
+     # Find the index of the first occurrence of the majority response
+     majority_index = data["model_answers"].index(majority_response)
+
+     # Return the corresponding text in gen_texts
+     return data["gen_texts"][majority_index]
+
+
+ def extract_solution(text):
+     # Split the text at "### Solution:"
+     parts = text.split("### Solution:", 1)
+     if len(parts) > 1:
+         # Return everything after "### Solution:"
+         return parts[1].strip()
+     else:
+         # Return an empty string if "### Solution:" is not found
+         return ""
+
+
+ def process_code(
+     example: Dict[str, Any],
+     config: Config,
+     restart_on_fail: bool = False,
+     last_step: bool = False,
+ ) -> Dict[str, Any]:
+     gen_text = example["gen_texts"]
+     num_python_blocks = len(re.findall(r"```python(.*?)```", gen_text, re.DOTALL))
+
+     if num_python_blocks == 0:
+         if restart_on_fail:
+             print("no code has ever been generated, RESTARTING")
+             # reset the text to the original
+             example["gen_texts"] = example["text"]
+         else:
+             print("no code has ever been generated, STOP")
+             example["should_prune"] = True
+             example["has_code"] = False
+         return example
+
+     if gen_text[-10:] != "```output\n" and ("answer is" in gen_text[-100:] or "\\boxed" in gen_text[-100:]):
+         num_output_blocks = len(re.findall(r"```output(.*?)```", gen_text, re.DOTALL))
+         if num_output_blocks == 0:
+             print("the model hallucinated the code answer")
+             example["should_prune"] = True
+             return example
+
+         if "boxed" in gen_text[-100:]:
+             try:
+                 answer = normalize_final_answer(extract_boxed_answer(gen_text[-100:]))
+             except Exception:
+                 answer = "-1"
+         else:
+             answer = normalize_final_answer(gen_text[-100:])
+
+         example["model_answers"] = answer
+         if not config.is_submission:
+             example["corrects"] = check_answer(example["ground_truth"], answer)
+         example["should_prune"] = True
+         print("Answer is: ", answer, example["ground_truth"], example["corrects"])
+         return example
+
+     if last_step:
+         # no point in continuing if we are at the last step
+         return example
+
+     if gen_text[-10:] != "```output\n":
+         # something else has gone wrong with the generation
+         print("warning: output block not found: ", gen_text[-40:])
+         if restart_on_fail:
+             example["gen_texts"] = example["text"]
+         else:
+             example["should_prune"] = True
+         return example
+
+     code_result, status = postprocess_completion(gen_text, return_status=True, last_code_block=True)
+     # add the code result for the next round of generation
+     TRUNCATION_LIMIT = 200
+     if len(code_result) > TRUNCATION_LIMIT:
+         code_result = code_result[:TRUNCATION_LIMIT] + " ... (output truncated)"
+     example["gen_texts"] = gen_text + f"{code_result}\n```"
+
+     return example
+
+
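
One round trip through `process_code` (a sketch with a hand-built sample; the field names are the ones used in `solve_problem` below, and the fence string is built at runtime only to avoid nesting backticks here): when the generation ends in an open output fence, the last code block is executed and its stdout is appended.

```python
fence = "`" * 3
sample = {
    "text": "## Solution:\n",
    "gen_texts": f"## Solution:\n{fence}python\nprint(1 + 2)\n{fence}\n{fence}output\n",
    "should_prune": False,
    "has_code": True,
    "model_answers": "-1",
    "ground_truth": "unknown",
    "corrects": False,
}
sample = process_code(sample, config=config)
print(sample["gen_texts"])  # ...the executed result "3" and a closing fence are appended
```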
+ def solve_problem(problem, temperature, progress=gr.Progress()):
+     """
+     yield token: string, stop: bool
+     """
+     problem = apply_template({"prompt": problem}, prompt=config.system_prompt)
+     print(f"Problem: {problem}")
+
+     sample = {
+         "problem": problem,  # not used for the submission TODO Remove
+         "ground_truth": "unknown",  # not used for the submission TODO Remove
+         "text": "## Solution:\n",
+         "gen_texts": "",  # used to store all the generated text
+         "should_prune": False,
+         "problem_index": -1,  # not used for the submission TODO Remove
+         "model_answers": "-1",
+         "has_code": True,
+         "corrects": False,  # not used for the submission TODO Remove
+     }
+
+     for step in progress.tqdm(
+         range(config.num_generations), desc="Generating candidates"
+     ):  # Depth of the tree (e.g. 6 steps = 5 code blocks)
+
+         step_response = sample["gen_texts"]
+
+         messages = [
+             {"role": "user", "content": sample["problem"]},
+             {"role": "assistant", "content": sample["gen_texts"]},
+         ]
+
+         stop = False
+
+         for response_message, error in generate(messages, temperature):
+             if response_message is not None:
+                 step_response += response_message
+                 yield step_response, False
+
+             if error:
+                 stop = True
+
+         sample["gen_texts"] = step_response
+
+         # TODO: Maybe it should just return the result of running the code
+         sample = process_code(
+             sample,
+             config=config,
+             restart_on_fail=config.restart_on_fail,
+             last_step=(step == (config.num_generations - 1)),
+         )
+         sample["gen_texts"] = sample["gen_texts"] + "\n"
+
+         run_code_response = sample["gen_texts"].replace(step_response, "")
+
+         for output_message in run_code_response:
+             if output_message is not None:
+                 step_response += output_message
+                 yield step_response, False
+
+         if sample["should_prune"] or stop:
+             break
+
+     yield sample["gen_texts"], True
+
+ features = Features({
+     'id': Value('int64'),
+     'problem': Value('string'),
+     'answer': Value('string'),
+     # 'prompt': Value('string'),  # Ensure this matches the actual data type of 'prompt' in your dataset
+     # 'level': Value('string')
+ })
+
+ # Now load the dataset using the defined schema
+ example_data = datasets.load_dataset(
+     "AI-MO/aimo-validation-math-level-5",
+     split="train",
+     use_auth_token=os.environ.get("HF_DATASET_TOKEN", None),
+     features=features,  # Pass the schema definition here
+ )
+
+
+
+ # app.css is uploaded alongside app.py, so load it relative to this file
+ with open(os.path.join(os.path.dirname(__file__), "app.css"), "r") as f:
+     css = f.read()
+
+
+ latex_delimiters = [
+     {"left": "[", "right": "]", "display": True},
+ ]
+
+
+ def get_random_problem():
+     example = random.choice(list(example_data))
+     problem = example["problem"]
+     return problem
+
+
+ def update_example_problem():
+     problem_example_text = get_random_problem()
+     return problem_example_text, problem_example_text
+
+
+ def clear():
+     problem_example_text = get_random_problem()
+     return "", 0.1, "", problem_example_text, problem_example_text
+
+
+ def preprocess_output(text):
+     return text.replace(r"\(", r"\\(").replace(r"\)", r"\\)")
+
+
+ with gr.Blocks(css=css, title="Math Olympiad Solver") as demo:
+     btn_list = []
+     problem_input_ele_list = []
+
+     problem_example_text = get_random_problem()
+
+     with gr.Row(elem_classes="title"):
+         gr.HTML("Math Olympiad Solver", elem_classes="title-content")
+
+     with gr.Row(elem_classes="sub-title"):
+         gr.HTML(
+             "<div>Demo of the <a href='https://huggingface.co/AI-MO/NuminaMath-7B-TIR'>Numina-Math-7B-TIR</a>. Example problems are drawn randomly from AMC12, years 2022-2023.</div>",
+             elem_classes="sub-title-content",
+         )
+
+     with gr.Row(elem_classes="main-area"):
+         with gr.Column(scale=1, elem_classes="left"):
+             with gr.Row(elem_classes="probelm-example-container"):
+                 with gr.Blocks(elem_classes="probelm-example-title"):
+                     gr.HTML("Problem example", elem_classes="probelm-example-title-content")
+
+                 with gr.Blocks(elem_classes="action-container"):
+                     another_btn = gr.Button(
+                         "",
+                         elem_classes="probelm-example-another",
+                         icon="./static/images/reset.png",
+                     )
+                     copy_btn = gr.Button("Copy", elem_classes="probelm-example-copy")
+
+                 problem_example = gr.HTML(
+                     problem_example_text,
+                     elem_classes="probelm-example-content",
+                 )
+
+             with gr.Row(elem_classes="probelm-input-container"):
+                 inp = gr.Textbox(placeholder="Problem", label="Problem input", lines=5, visible=True)
+                 problem_markdown = gr.Markdown(
+                     visible=False,
+                     latex_delimiters=[
+                         {"left": "[", "right": "]", "display": True},
+                         {"left": "$", "right": "$", "display": False},
+                         {"left": r"\(", "right": r"\)", "display": False},
+                     ],
+                 )
+
+                 inp.change(fn=lambda text: text, inputs=[inp], outputs=[problem_markdown])
+                 problem_input_ele_list.append(inp)
+                 problem_input_ele_list.append(problem_markdown)
+
+             with gr.Accordion("Advanced Options", open=False):
+                 temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.1, step=0.1, label="Temperature")
+
+             with gr.Row() as btn_area:
+                 btn_clear = gr.Button("Clear", elem_classes="clear-btn")
+                 btn_run = gr.Button("Run", elem_classes="run-btn")
+                 btn_list.append(btn_clear)
+                 btn_list.append(btn_run)
+
+         with gr.Column(scale=1, elem_classes="right"):
+             gr.HTML("Solution", elem_classes="solution-title-content")
+             out = gr.Markdown(
+                 elem_classes="solution-content",
+                 latex_delimiters=[
+                     {"left": "[", "right": "]", "display": True},
+                     {"left": "$", "right": "$", "display": False},
+                     {"left": r"\(", "right": r"\)", "display": False},
+                 ],
+             )
+
+     problem_example_text_hidden = gr.Markdown(value=problem_example_text, visible=False)
+
+     def solve_problem_wrapper(inp_text, temperature):
+         new_running_btn = gr.Button("", elem_classes="run-btn running-btn")
+
+         try:
+             for after_tokens, stop in solve_problem(inp_text, temperature):
+                 yield preprocess_output(after_tokens), new_running_btn
+
+                 if stop:
+                     btn_run = gr.Button("Run", elem_classes="run-btn")
+                     yield preprocess_output(after_tokens), btn_run
+
+         except Exception as e:
+             raise e
+
+     def mount_run_btn(btn):
+         btn.click(fn=solve_problem_wrapper, inputs=[inp, temperature], outputs=[out, btn_list[1]])
+         btn.click(get_run_after_problem_input, None, outputs=problem_input_ele_list)
+
+     def get_run_after_problem_input():
+         return gr.Textbox(placeholder="Problem", label="Problem input", lines=5, visible=False), gr.Markdown(
+             visible=True,
+             latex_delimiters=[
+                 {"left": "[", "right": "]", "display": True},
+                 {"left": "$", "right": "$", "display": False},
+             ],
+             elem_classes="problem-input-markdown",
+         )
+
+     def get_init_problem_input():
+         return gr.Textbox(placeholder="Problem", label="Problem input", lines=5, visible=True), gr.Markdown(
+             visible=False,
+             latex_delimiters=[
+                 {"left": "[", "right": "]", "display": True},
+                 {"left": "$", "right": "$", "display": False},
+             ],
+         )
+
+     copy_btn.click(fn=lambda example: example, inputs=[problem_example_text_hidden], outputs=[inp])
+
+     btn_clear.click(
+         fn=clear,
+         inputs=[],
+         outputs=[
+             inp,
+             temperature,
+             out,
+             problem_example,
+             problem_example_text_hidden,
+         ],
+     )
+
+     btn_clear.click(get_init_problem_input, None, outputs=problem_input_ele_list)
+
+     mount_run_btn(btn_run)
+
+     demo.load(
+         update_example_problem,
+         inputs=None,
+         outputs=[
+             problem_example,
+             problem_example_text_hidden,
+         ],
+     )
+
+     another_btn.click(
+         fn=update_example_problem,
+         inputs=[],
+         outputs=[
+             problem_example,
+             problem_example_text_hidden,
+         ],
+     )
+
+ if __name__ == "__main__":
+     demo.queue(default_concurrency_limit=5).launch(share=True)
requirements.txt ADDED
@@ -0,0 +1,4 @@
+ openai
+ sympy
+ transformers
+ datasets==2.4.0