Lucas ARRIESSE commited on
Commit
9e95c26
·
1 Parent(s): e0c1af3

Add retry with max bailout

Browse files
Files changed (2) hide show
  1. app.py +4 -1
  2. util.py +21 -0
app.py CHANGED
@@ -8,6 +8,7 @@ from schemas import _RefinedSolutionModel, _SearchedSolutionModel, _SolutionCrit
8
  from jinja2 import Environment, FileSystemLoader, StrictUndefined
9
  from litellm.router import Router
10
  from dotenv import load_dotenv
 
11
 
12
  logging.basicConfig(
13
  level=logging.INFO,
@@ -119,6 +120,8 @@ async def categorize_reqs(params: ReqGroupingRequest) -> ReqGroupingResponse:
119
  async def search_solutions(params: ReqGroupingResponse) -> SolutionSearchResponse:
120
  """Searches solutions solving the given grouping params using Gemini and grounded on google search"""
121
 
 
 
122
  async def _search_inner(cat: ReqGroupingCategory) -> SolutionModel:
123
  # ================== generate the solution with web grounding
124
  req_prompt = await prompt_env.get_template("search_solution.txt").render_async(**{
@@ -167,7 +170,7 @@ async def search_solutions(params: ReqGroupingResponse) -> SolutionSearchRespons
167
  )
168
  return final_sol
169
 
170
- solutions = await asyncio.gather(*[_search_inner(cat) for cat in params.categories], return_exceptions=True)
171
  logging.info(solutions)
172
  final_solutions = [
173
  sol for sol in solutions if not isinstance(sol, Exception)]
 
8
  from jinja2 import Environment, FileSystemLoader, StrictUndefined
9
  from litellm.router import Router
10
  from dotenv import load_dotenv
11
+ from util import retry_until
12
 
13
  logging.basicConfig(
14
  level=logging.INFO,
 
120
  async def search_solutions(params: ReqGroupingResponse) -> SolutionSearchResponse:
121
  """Searches solutions solving the given grouping params using Gemini and grounded on google search"""
122
 
123
+ logging.info(f"Searching solutions for categories: {params.categories}")
124
+
125
  async def _search_inner(cat: ReqGroupingCategory) -> SolutionModel:
126
  # ================== generate the solution with web grounding
127
  req_prompt = await prompt_env.get_template("search_solution.txt").render_async(**{
 
170
  )
171
  return final_sol
172
 
173
+ solutions = await asyncio.gather(*[retry_until(_search_inner, cat, lambda v: len(v.References) > 0, 2) for cat in params.categories], return_exceptions=True)
174
  logging.info(solutions)
175
  final_solutions = [
176
  sol for sol in solutions if not isinstance(sol, Exception)]
util.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ from typing import Awaitable, Callable, TypeVar
3
+
4
+
5
+ T = TypeVar("T")
6
+ A = TypeVar("A")
7
+
8
+
9
+ async def retry_until(
10
+ func: Callable[[A], Awaitable[T]],
11
+ arg: A,
12
+ predicate: Callable[[T], bool],
13
+ max_retries: int,
14
+ ) -> T:
15
+ """Retries the given async function until the passed in validation predicate returns true."""
16
+ last_value = await func(arg)
17
+ for _ in range(max_retries):
18
+ if predicate(last_value):
19
+ return last_value
20
+ last_value = await func(arg)
21
+ return last_value