dlflannery commited on
Commit
1a61bb7
·
verified ·
1 Parent(s): b56832d

Update app.py

Browse files

structured output puzzle and math

Files changed (1) hide show
  1. app.py +68 -6
app.py CHANGED
@@ -1,4 +1,3 @@
1
- from ast import Interactive
2
  import os
3
  import gradio as gr
4
  # import openai
@@ -16,6 +15,10 @@ import base64
16
  import json
17
  from PIL import Image
18
  from io import BytesIO
 
 
 
 
19
 
20
  load_dotenv(override=True)
21
  key = os.getenv('OPENAI_API_KEY')
@@ -42,6 +45,30 @@ client = OpenAI(api_key = key)
42
 
43
  abbrevs = {'St. ' : 'Saint ', 'Mr. ': 'mister ', 'Mrs. ':'mussus ', 'Mr. ':'mister ', 'Ms. ':'mizz '}
44
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  def genUsageStats(do_reset=False):
46
  result = []
47
  ttotal4o_in = 0
@@ -172,8 +199,22 @@ def updatePassword(txt):
172
  password = txt.lower().strip()
173
  return [password, "*********"]
174
 
175
- # def setModel(val):
176
- # return val
 
 
 
 
 
 
 
 
 
 
 
 
 
 
177
 
178
  def chat(prompt, user_window, pwd_window, past, response, gptModel, uploaded_image_file=''):
179
  image_gen_model = 'gpt-4o-2024-08-06'
@@ -200,9 +241,20 @@ def chat(prompt, user_window, pwd_window, past, response, gptModel, uploaded_ima
200
  response = f'{log_cnt} log files\n{wav_cnt} .wav files\n{other_cnt} Other files:\n{others}\nlogs: {str(log_list)}'
201
  return [past, response, None, gptModel, uploaded_image_file]
202
  if user_window in unames and pwd_window == pwdList[unames.index(user_window)]:
 
 
 
 
 
 
 
 
203
  past.append({"role":"user", "content":prompt})
204
  gen_image = (uploaded_image_file != '')
205
- if not gen_image:
 
 
 
206
  completion = client.chat.completions.create(model=gptModel,
207
  messages=past)
208
  reporting_model = gptModel
@@ -212,7 +264,17 @@ def chat(prompt, user_window, pwd_window, past, response, gptModel, uploaded_ima
212
  reporting_model = image_gen_model
213
  if not msg == 'ok':
214
  return [past, msg, None, gptModel, uploaded_image_file]
215
- reply = completion.choices[0].message.content
 
 
 
 
 
 
 
 
 
 
216
  tokens_in = completion.usage.prompt_tokens
217
  tokens_out = completion.usage.completion_tokens
218
  tokens = completion.usage.total_tokens
@@ -503,7 +565,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
503
  for s in words_in:
504
  s = s.lstrip('- *@#$%^&_=+-')
505
  if len(s) > 0:
506
- loc = s.index(' ')
507
  if loc > 1:
508
  val = s[0:loc]
509
  isnum = val.replace('.','0').isdecimal()
 
 
1
  import os
2
  import gradio as gr
3
  # import openai
 
15
  import json
16
  from PIL import Image
17
  from io import BytesIO
18
+ from pydantic import BaseModel
19
+ import pprint
20
+ import flatlatex
21
+ lconv = flatlatex.converter()
22
 
23
  load_dotenv(override=True)
24
  key = os.getenv('OPENAI_API_KEY')
 
45
 
46
  abbrevs = {'St. ' : 'Saint ', 'Mr. ': 'mister ', 'Mrs. ':'mussus ', 'Mr. ':'mister ', 'Ms. ':'mizz '}
47
 
48
+ class Step(BaseModel):
49
+ explanation: str
50
+ output: str
51
+
52
+ class MathReasoning(BaseModel):
53
+ steps: list[Step]
54
+ final_answer: str
55
+
56
+
57
+ def solve(prompt, chatType):
58
+ if chatType == 'math':
59
+ instruction = "You are a helpful math tutor. Guide the user through the solution step by step."
60
+ elif chatType == "logic":
61
+ instruction = "you are a helpful tutor expert in logic. Guide the user through the solution step by step"
62
+ completion = client.beta.chat.completions.parse(
63
+ model = 'gpt-4o-2024-08-06',
64
+ messages = [
65
+ {"role": "system", "content": instruction},
66
+ {"role": "user", "content": prompt}
67
+ ],
68
+ response_format=MathReasoning,
69
+ )
70
+ return completion
71
+
72
  def genUsageStats(do_reset=False):
73
  result = []
74
  ttotal4o_in = 0
 
199
  password = txt.lower().strip()
200
  return [password, "*********"]
201
 
202
+ def parse_math(txt):
203
+ ref = 0
204
+ loc = txt.find(r'\\(')
205
+ if loc == -1:
206
+ return txt
207
+ while (True):
208
+ loc2 = txt[ref:].find(r'\\)')
209
+ if loc2 == -1:
210
+ break
211
+ loc = txt[ref:].find(r'\\(')
212
+ if loc > -1:
213
+ loc2 += 2
214
+ frag = lconv.convert(txt[ref:][loc:loc2])
215
+ txt = txt[:loc+ref] + frag + txt[loc2+ref:]
216
+ ref = len(txt[ref:loc]) + len(frag)
217
+ return txt
218
 
219
  def chat(prompt, user_window, pwd_window, past, response, gptModel, uploaded_image_file=''):
220
  image_gen_model = 'gpt-4o-2024-08-06'
 
241
  response = f'{log_cnt} log files\n{wav_cnt} .wav files\n{other_cnt} Other files:\n{others}\nlogs: {str(log_list)}'
242
  return [past, response, None, gptModel, uploaded_image_file]
243
  if user_window in unames and pwd_window == pwdList[unames.index(user_window)]:
244
+ chatType = 'normal'
245
+ prompt = prompt.strip()
246
+ if prompt.startswith('solve'):
247
+ prompt = 'How do I solve ' + prompt[5:] + ' Do not use Latex for math expressions.'
248
+ chatType = 'math'
249
+ elif prompt.startswith('puzzle'):
250
+ chatType = 'logic'
251
+ prompt = prompt[6:]
252
  past.append({"role":"user", "content":prompt})
253
  gen_image = (uploaded_image_file != '')
254
+ if chatType in ['math', 'logic']:
255
+ completion = solve(prompt, chatType)
256
+ reporting_model = image_gen_model
257
+ elif not gen_image:
258
  completion = client.chat.completions.create(model=gptModel,
259
  messages=past)
260
  reporting_model = gptModel
 
264
  reporting_model = image_gen_model
265
  if not msg == 'ok':
266
  return [past, msg, None, gptModel, uploaded_image_file]
267
+ if chatType in ['math', 'logic']:
268
+ dr = completion.choices[0].message.parsed.model_dump()
269
+ reply = pprint.pformat(dr)
270
+ # df = {'final_answer' : parse_math(dr['final_answer'])}
271
+ # df['steps'] = []
272
+ # for x in dr['steps']:
273
+ # df['steps'].append({'explanation': parse_math(x['explanation']), 'output' : parse_math(x['output'])})
274
+
275
+ # reply = pprint.pformat(df)
276
+ else:
277
+ reply = completion.choices[0].message.content
278
  tokens_in = completion.usage.prompt_tokens
279
  tokens_out = completion.usage.completion_tokens
280
  tokens = completion.usage.total_tokens
 
565
  for s in words_in:
566
  s = s.lstrip('- *@#$%^&_=+-')
567
  if len(s) > 0:
568
+ loc = s.find(' ')
569
  if loc > 1:
570
  val = s[0:loc]
571
  isnum = val.replace('.','0').isdecimal()