sugiv commited on
Commit
d8f18d9
·
1 Parent(s): 7c25703

Fixing generate and random problem API

Browse files
Files changed (1) hide show
  1. app.py +11 -4
app.py CHANGED
@@ -9,6 +9,7 @@ import jwt
9
  from typing import Dict, Any
10
  import autopep8
11
  import textwrap
 
12
 
13
  from datasets import load_dataset
14
  from fastapi.responses import StreamingResponse
@@ -141,7 +142,7 @@ Here's the complete Python function implementation:
141
  formatted_code = extract_and_format_code(generated_text)
142
  return {"solution": formatted_code}
143
 
144
- def stream_solution(instruction: str, token: str) -> Dict[str, Any]:
145
  if not verify_token(token):
146
  return {"error": "Invalid token"}
147
 
@@ -160,10 +161,16 @@ Here's the complete Python function implementation:
160
  """
161
 
162
  def generate():
 
163
  for chunk in llm(full_prompt, stream=True, **generation_kwargs):
164
- yield chunk["choices"][0]["text"]
165
-
166
- return generate()
 
 
 
 
 
167
 
168
  def random_problem(token: str) -> Dict[str, Any]:
169
  if not verify_token(token):
 
9
  from typing import Dict, Any
10
  import autopep8
11
  import textwrap
12
+ import json
13
 
14
  from datasets import load_dataset
15
  from fastapi.responses import StreamingResponse
 
142
  formatted_code = extract_and_format_code(generated_text)
143
  return {"solution": formatted_code}
144
 
145
+ def stream_solution(instruction: str, token: str):
146
  if not verify_token(token):
147
  return {"error": "Invalid token"}
148
 
 
161
  """
162
 
163
  def generate():
164
+ generated_text = ""
165
  for chunk in llm(full_prompt, stream=True, **generation_kwargs):
166
+ token = chunk["choices"][0]["text"]
167
+ generated_text += token
168
+ yield json.dumps({"token": token, "generated_text": generated_text}) + "\n"
169
+
170
+ formatted_code = extract_and_format_code(generated_text)
171
+ yield json.dumps({"complete": True, "formatted_code": formatted_code}) + "\n"
172
+
173
+ return StreamingResponse(generate(), media_type="application/x-ndjson")
174
 
175
  def random_problem(token: str) -> Dict[str, Any]:
176
  if not verify_token(token):