import asyncio
import json
import logging
import os
import sys

import uvicorn
from dotenv import load_dotenv
from fastapi import APIRouter, FastAPI
from httpx import AsyncClient
from jinja2 import Environment, FileSystemLoader, StrictUndefined
from litellm.router import Router

from schemas import (
    _RefinedSolutionModel,
    _ReqGroupingCategory,
    _ReqGroupingOutput,
    _SearchedSolutionModel,
    _SolutionCriticismOutput,
    CriticizeSolutionsRequest,
    CritiqueResponse,
    InsightFinderConstraintsList,
    ReqGroupingCategory,
    ReqGroupingRequest,
    ReqGroupingResponse,
    RequirementInfo,
    SolutionCriticism,
    SolutionModel,
    SolutionSearchResponse,
    SolutionSearchV2Request,
    TechnologyData,
)
from util import retry_until
logging.basicConfig(
    level=logging.INFO,
    format='[%(asctime)s][%(levelname)s][%(filename)s:%(lineno)d]: %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)

# Load .env files
load_dotenv()

if "LLM_MODEL" not in os.environ or "LLM_API_KEY" not in os.environ:
    logging.error(
        "No LLM API key (`LLM_API_KEY`) and/or LLM model (`LLM_MODEL`) were provided in the env vars. Exiting.")
    sys.exit(-1)
# LiteLLM router
llm_router = Router(model_list=[
    {
        "model_name": "chat",
        "litellm_params": {
            "model": os.environ.get("LLM_MODEL"),
            "api_key": os.environ.get("LLM_API_KEY"),
            "rpm": 15,
            "max_parallel_requests": 4,
            "allowed_fails": 1,
            "cooldown_time": 60,
            "max_retries": 10,
        }
    }
], num_retries=10, retry_after=30)
# HTTP client
INSIGHT_FINDER_BASE_URL = "https://organizedprogrammers-insight-finder.hf.space/"

# disable TLS certificate verification only when NO_SSL is set to "1"
http_client = AsyncClient(
    verify=os.environ.get("NO_SSL", "0") != "1", timeout=None)
# Jinja2 environment to load prompt templates
prompt_env = Environment(
    loader=FileSystemLoader('prompts'), enable_async=True, undefined=StrictUndefined)
api = FastAPI(docs_url="/", title="Reqxtract-API",
              description=open("docs/docs.md").read())

# requirements routes
requirements_router = APIRouter(prefix="/reqs", tags=["requirements"])
# solution routes
solution_router = APIRouter(prefix="/solution", tags=["solution"])
async def format_prompt(prompt_name: str, **args) -> str:
    """Helper to format a prompt"""
    return await prompt_env.get_template(prompt_name).render_async(args)
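
# Route registration for the handler below; no decorator is visible in this
# source, so the router, path, and HTTP method here are assumptions.
@requirements_router.post("/categorize")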
async def categorize_reqs(params: ReqGroupingRequest) -> ReqGroupingResponse:
    """Categorize the given service requirements into categories."""
    MAX_ATTEMPTS = 5

    categories: list[_ReqGroupingCategory] = []
    messages = []

    # categorize the requirements using their indices
    req_prompt = await prompt_env.get_template("classify.txt").render_async(**{
        "requirements": [rq.model_dump() for rq in params.requirements],
        "max_n_categories": params.max_n_categories,
        "response_schema": _ReqGroupingOutput.model_json_schema()})

    # add the classification prompt (with the requirements) as the first user message
    messages.append({"role": "user", "content": req_prompt})

    # ensure all requirement items are processed
    for attempt in range(MAX_ATTEMPTS):
        req_completion = await llm_router.acompletion(
            model="chat", messages=messages, response_format=_ReqGroupingOutput)
        output = _ReqGroupingOutput.model_validate_json(
            req_completion.choices[0].message.content)

        # quick check that no requirement was left out by the LLM:
        # every requirement ID must appear in at least one category
        valid_ids_universe = set(range(0, len(params.requirements)))
        assigned_ids = {
            req_id for cat in output.categories for req_id in cat.items}

        # keep only non-hallucinated, valid assigned ids
        valid_assigned_ids = assigned_ids.intersection(valid_ids_universe)

        # check for remaining requirements assigned to none of the categories
        unassigned_ids = valid_ids_universe - valid_assigned_ids

        if len(unassigned_ids) == 0:
            categories.extend(output.categories)
            break
        else:
            messages.append(req_completion.choices[0].message)
            messages.append(
                {"role": "user", "content": f"You haven't categorized the following requirements into at least one category: {unassigned_ids}. Please do so."})

        if attempt == MAX_ATTEMPTS - 1:
            raise Exception("Failed to classify all requirements")

    # build the final category objects
    # remove the invalid (likely hallucinated) requirement IDs
    final_categories = []
    for idx, cat in enumerate(output.categories):
        final_categories.append(ReqGroupingCategory(
            id=idx,
            title=cat.title,
            requirements=[params.requirements[i]
                          for i in cat.items if i < len(params.requirements)]
        ))

    return ReqGroupingResponse(categories=final_categories)

# ========================================================= Solution Endpoints ===========================================================
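
# Route registration for the handler below; no decorator is visible in this
# source, so the router, path, and HTTP method here are assumptions.
@solution_router.post("/search_solutions")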
async def search_solutions(params: ReqGroupingResponse) -> SolutionSearchResponse:
    """Search for solutions addressing the given requirement categories, using Gemini grounded on Google Search."""
    logging.info(f"Searching solutions for categories: {params.categories}")

    async def _search_inner(cat: ReqGroupingCategory) -> SolutionModel:
        # ================== generate the solution with web grounding ==================
        req_prompt = await prompt_env.get_template("search_solution.txt").render_async(**{
            "category": cat.model_dump(),
        })

        # generate the completion in non-structured mode.
        # the googleSearch tool grounds Gemini on Google Search,
        # and tool_choice="required" forces Gemini to perform the tool call
        req_completion = await llm_router.acompletion(model="chat", messages=[
            {"role": "user", "content": req_prompt}
        ], tools=[{"googleSearch": {}}], tool_choice="required")

        # ==================== structure the solution as JSON ====================
        structured_prompt = await prompt_env.get_template("structure_solution.txt").render_async(**{
            "solution": req_completion.choices[0].message.content,
            "response_schema": _SearchedSolutionModel.model_json_schema()
        })

        structured_completion = await llm_router.acompletion(model="chat", messages=[
            {"role": "user", "content": structured_prompt}
        ], response_format=_SearchedSolutionModel)

        solution_model = _SearchedSolutionModel.model_validate_json(
            structured_completion.choices[0].message.content)

        # ==================== build the final solution object ====================
        sources_metadata = []

        # extract the source metadata from the grounding chunks, if Gemini actually
        # performed the search tool call ... and didn't hallucinate
        try:
            sources_metadata.extend([{"name": a["web"]["title"], "url": a["web"]["uri"]}
                                     for a in req_completion["vertex_ai_grounding_metadata"][0]['groundingChunks']])
        except KeyError:
            pass

        final_sol = SolutionModel(
            Context="",
            Requirements=[
                cat.requirements[i].requirement for i in solution_model.requirement_ids
            ],
            Problem_Description=solution_model.problem_description,
            Solution_Description=solution_model.solution_description,
            References=sources_metadata,
            Category_Id=cat.id,
        )
        return final_sol

    solutions = await asyncio.gather(
        *[retry_until(_search_inner, cat, lambda v: len(v.References) > 0, 2)
          for cat in params.categories],
        return_exceptions=True)
    logging.info(solutions)

    final_solutions = [
        sol for sol in solutions if not isinstance(sol, Exception)]

    return SolutionSearchResponse(solutions=final_solutions)
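
# Route registration for the handler below; no decorator is visible in this
# source, so the router, path, and HTTP method here are assumptions.
@solution_router.post("/search_solutions_v2")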
async def search_solutions(params: SolutionSearchV2Request) -> SolutionSearchResponse:
    """Search for solutions addressing the given requirement categories and respecting the user constraints, using Gemini grounded on Google Search."""
    logging.info(f"Searching solutions for categories: {params}")

    async def _search_inner(cat: ReqGroupingCategory) -> SolutionModel:
        # ================== generate the solution with web grounding ==================
        req_prompt = await prompt_env.get_template("search_solution_v2.txt").render_async(**{
            "category": cat.model_dump(),
            "user_constraints": params.user_constraints,
        })

        # generate the completion in non-structured mode.
        # the googleSearch tool grounds Gemini on Google Search,
        # and tool_choice="required" forces Gemini to perform the tool call
        req_completion = await llm_router.acompletion(model="chat", messages=[
            {"role": "user", "content": req_prompt}
        ], tools=[{"googleSearch": {}}], tool_choice="required")

        # ==================== structure the solution as JSON ====================
        structured_prompt = await prompt_env.get_template("structure_solution.txt").render_async(**{
            "solution": req_completion.choices[0].message.content,
            "response_schema": _SearchedSolutionModel.model_json_schema()
        })

        structured_completion = await llm_router.acompletion(model="chat", messages=[
            {"role": "user", "content": structured_prompt}
        ], response_format=_SearchedSolutionModel)

        solution_model = _SearchedSolutionModel.model_validate_json(
            structured_completion.choices[0].message.content)

        # ==================== build the final solution object ====================
        sources_metadata = []

        # extract the source metadata from the grounding chunks, if Gemini actually
        # performed the search tool call ... and didn't hallucinate
        try:
            sources_metadata.extend([{"name": a["web"]["title"], "url": a["web"]["uri"]}
                                     for a in req_completion["vertex_ai_grounding_metadata"][0]['groundingChunks']])
        except KeyError:
            pass

        final_sol = SolutionModel(
            Context="",
            Requirements=[
                cat.requirements[i].requirement for i in solution_model.requirement_ids
            ],
            Problem_Description=solution_model.problem_description,
            Solution_Description=solution_model.solution_description,
            References=sources_metadata,
            Category_Id=cat.id,
        )
        return final_sol

    solutions = await asyncio.gather(
        *[retry_until(_search_inner, cat, lambda v: len(v.References) > 0, 2)
          for cat in params.categories],
        return_exceptions=True)
    logging.info(solutions)

    final_solutions = [
        sol for sol in solutions if not isinstance(sol, Exception)]

    return SolutionSearchResponse(solutions=final_solutions)

# =================================================================================================================
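
# Route registration for the handler below; no decorator is visible in this
# source, so the router, path, and HTTP method here are assumptions.
@solution_router.post("/criticize_solutions")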
async def criticize_solution(params: CriticizeSolutionsRequest) -> CritiqueResponse:
    """Criticize the challenges, weaknesses and limitations of the provided solutions."""

    async def __criticize_single(solution: SolutionModel):
        req_prompt = await prompt_env.get_template("criticize.txt").render_async(**{
            "solutions": [solution.model_dump()],
            "response_schema": _SolutionCriticismOutput.model_json_schema()
        })

        req_completion = await llm_router.acompletion(
            model="chat",
            messages=[{"role": "user", "content": req_prompt}],
            response_format=_SolutionCriticismOutput
        )
        criticism_out = _SolutionCriticismOutput.model_validate_json(
            req_completion.choices[0].message.content
        )

        return SolutionCriticism(solution=solution, criticism=criticism_out.criticisms[0])

    critiques = await asyncio.gather(
        *[__criticize_single(sol) for sol in params.solutions],
        return_exceptions=False)

    return CritiqueResponse(critiques=critiques)
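
# Route registration for the handler below; no decorator is visible in this
# source, so the router, path, and HTTP method here are assumptions.
@solution_router.post("/refine_solutions")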
async def refine_solutions(params: CritiqueResponse) -> SolutionSearchResponse:
    """Refines the previously critiqued solutions."""

    async def __refine_solution(crit: SolutionCriticism):
        req_prompt = await prompt_env.get_template("refine_solution.txt").render_async(**{
            "solution": crit.solution.model_dump(),
            "criticism": crit.criticism,
            "response_schema": _RefinedSolutionModel.model_json_schema(),
        })

        req_completion = await llm_router.acompletion(model="chat", messages=[
            {"role": "user", "content": req_prompt}
        ], response_format=_RefinedSolutionModel)

        req_model = _RefinedSolutionModel.model_validate_json(
            req_completion.choices[0].message.content)

        # copy the previous solution model and overwrite the refined fields
        refined_solution = crit.solution.model_copy(deep=True)
        refined_solution.Problem_Description = req_model.problem_description
        refined_solution.Solution_Description = req_model.solution_description

        return refined_solution

    refined_solutions = await asyncio.gather(
        *[__refine_solution(crit) for crit in params.critiques],
        return_exceptions=False)

    return SolutionSearchResponse(solutions=refined_solutions)

# ======================================== Solution generation using Insights Finder ==================
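
# Route registration for the handler below; no decorator is visible in this
# source, so the router, path, and HTTP method here are assumptions.
@solution_router.post("/search_solutions_if")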
async def search_solutions_if(req: SolutionSearchV2Request) -> SolutionSearchResponse:
    """Generate a solution per category from technologies retrieved via Insight Finder."""

    async def _search_solution_inner(cat: ReqGroupingCategory):
        # process the requirements into the Insight Finder constraints format
        fmt_completion = await llm_router.acompletion(model="chat", messages=[
            {
                "role": "user",
                "content": await format_prompt("if/format_requirements.txt", **{
                    "category": cat.model_dump(),
                    "response_schema": InsightFinderConstraintsList.model_json_schema()
                })
            }], response_format=InsightFinderConstraintsList)

        fmt_model = InsightFinderConstraintsList.model_validate_json(
            fmt_completion.choices[0].message.content)

        # translate the structured output into the dict payload Insight Finder expects
        out = {'constraints': {
            cons.title: cons.description for cons in fmt_model.constraints}}

        # fetch technologies from Insight Finder
        technologies_req = await http_client.post(
            INSIGHT_FINDER_BASE_URL + "process-constraints", content=json.dumps(out))
        technologies = TechnologyData.model_validate(technologies_req.json())

        # =============================== synthesize a solution using the LLM ===============================
        format_solution = await llm_router.acompletion(model="chat", messages=[{
            "role": "user",
            "content": await format_prompt("if/synthesize_solution.txt", **{
                "category": cat.model_dump(),
                "technologies": technologies.model_dump()["technologies"],
                "user_constraints": None,
                "response_schema": _SearchedSolutionModel.model_json_schema()
            })}
        ], response_format=_SearchedSolutionModel)

        format_solution_model = _SearchedSolutionModel.model_validate_json(
            format_solution.choices[0].message.content)

        final_solution = SolutionModel(
            Context="",
            Requirements=[
                cat.requirements[i].requirement for i in format_solution_model.requirement_ids
            ],
            Problem_Description=format_solution_model.problem_description,
            Solution_Description=format_solution_model.solution_description,
            References=[],
            Category_Id=cat.id,
        )
        # ====================================================================================================
        return final_solution

    tasks = await asyncio.gather(
        *[_search_solution_inner(cat) for cat in req.categories],
        return_exceptions=True)
    final_solutions = [sol for sol in tasks if not isinstance(sol, Exception)]

    return SolutionSearchResponse(solutions=final_solutions)
api.include_router(requirements_router)
api.include_router(solution_router)

uvicorn.run(api, host="0.0.0.0", port=8000)