Spaces:
Sleeping
Sleeping
Lucas ARRIESSE
commited on
Commit
·
9e95c26
1
Parent(s):
e0c1af3
Add retry with max bailout
Browse files
app.py
CHANGED
@@ -8,6 +8,7 @@ from schemas import _RefinedSolutionModel, _SearchedSolutionModel, _SolutionCrit
|
|
8 |
from jinja2 import Environment, FileSystemLoader, StrictUndefined
|
9 |
from litellm.router import Router
|
10 |
from dotenv import load_dotenv
|
|
|
11 |
|
12 |
logging.basicConfig(
|
13 |
level=logging.INFO,
|
@@ -119,6 +120,8 @@ async def categorize_reqs(params: ReqGroupingRequest) -> ReqGroupingResponse:
|
|
119 |
async def search_solutions(params: ReqGroupingResponse) -> SolutionSearchResponse:
|
120 |
"""Searches solutions solving the given grouping params using Gemini and grounded on google search"""
|
121 |
|
|
|
|
|
122 |
async def _search_inner(cat: ReqGroupingCategory) -> SolutionModel:
|
123 |
# ================== generate the solution with web grounding
|
124 |
req_prompt = await prompt_env.get_template("search_solution.txt").render_async(**{
|
@@ -167,7 +170,7 @@ async def search_solutions(params: ReqGroupingResponse) -> SolutionSearchRespons
|
|
167 |
)
|
168 |
return final_sol
|
169 |
|
170 |
-
solutions = await asyncio.gather(*[_search_inner(
|
171 |
logging.info(solutions)
|
172 |
final_solutions = [
|
173 |
sol for sol in solutions if not isinstance(sol, Exception)]
|
|
|
8 |
from jinja2 import Environment, FileSystemLoader, StrictUndefined
|
9 |
from litellm.router import Router
|
10 |
from dotenv import load_dotenv
|
11 |
+
from util import retry_until
|
12 |
|
13 |
logging.basicConfig(
|
14 |
level=logging.INFO,
|
|
|
120 |
async def search_solutions(params: ReqGroupingResponse) -> SolutionSearchResponse:
|
121 |
"""Searches solutions solving the given grouping params using Gemini and grounded on google search"""
|
122 |
|
123 |
+
logging.info(f"Searching solutions for categories: {params.categories}")
|
124 |
+
|
125 |
async def _search_inner(cat: ReqGroupingCategory) -> SolutionModel:
|
126 |
# ================== generate the solution with web grounding
|
127 |
req_prompt = await prompt_env.get_template("search_solution.txt").render_async(**{
|
|
|
170 |
)
|
171 |
return final_sol
|
172 |
|
173 |
+
solutions = await asyncio.gather(*[retry_until(_search_inner, cat, lambda v: len(v.References) > 0, 2) for cat in params.categories], return_exceptions=True)
|
174 |
logging.info(solutions)
|
175 |
final_solutions = [
|
176 |
sol for sol in solutions if not isinstance(sol, Exception)]
|
util.py
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import asyncio
|
2 |
+
from typing import Awaitable, Callable, TypeVar
|
3 |
+
|
4 |
+
|
5 |
+
T = TypeVar("T")
|
6 |
+
A = TypeVar("A")
|
7 |
+
|
8 |
+
|
9 |
+
async def retry_until(
|
10 |
+
func: Callable[[A], Awaitable[T]],
|
11 |
+
arg: A,
|
12 |
+
predicate: Callable[[T], bool],
|
13 |
+
max_retries: int,
|
14 |
+
) -> T:
|
15 |
+
"""Retries the given async function until the passed in validation predicate returns true."""
|
16 |
+
last_value = await func(arg)
|
17 |
+
for _ in range(max_retries):
|
18 |
+
if predicate(last_value):
|
19 |
+
return last_value
|
20 |
+
last_value = await func(arg)
|
21 |
+
return last_value
|