"""FastAPI service exposing LLM-backed NER/relation extraction and search-plan endpoints.

Requires the environment variables LLM_MODEL and LLM_TOKEN; LLM_BASE_URL is
optional (falls back to the provider's default endpoint).
"""

import logging
import os
import sys

from dotenv import load_dotenv
from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
from jinja2 import Environment, FileSystemLoader, StrictUndefined, TemplateNotFound
from litellm import acompletion

from schemas import (
    CreateSearchPlanRequest,
    CreateSearchPlanResponse,
    ExtractEntitiesRequest,
    ExtractEntitiesResponse,
    ExtractedRelationsResponse,
)
from utils import build_visjs_graph, fmt_prompt

load_dotenv()

logging.basicConfig(
    level=logging.INFO,
    format='[%(asctime)s][%(levelname)s][%(filename)s:%(lineno)d]: %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
)

# LLM connection settings, read once at startup.
LLM_MODEL = os.environ.get('LLM_MODEL')
LLM_TOKEN = os.environ.get('LLM_TOKEN')
LLM_BASE_URL = os.environ.get('LLM_BASE_URL')

# Both the model name and the token are required for every LLM call below.
# BUG FIX: the original used `and`, so the process only exited when BOTH
# variables were missing; with a single one missing it started up anyway and
# failed later at request time.
if not LLM_MODEL or not LLM_TOKEN:
    logging.error("No LLM_TOKEN and LLM_MODEL were provided.")
    sys.exit(-1)

# Async-enabled Jinja2 environment for prompt templates; StrictUndefined makes
# a missing template variable raise instead of silently rendering empty text.
prompt_env = Environment(
    loader=FileSystemLoader("prompts"),
    undefined=StrictUndefined,
    enable_async=True,
)

api = FastAPI()


@api.post("/extract_entities")
async def extract_entities(body: ExtractEntitiesRequest):
    """Extract entities from the given input text and return them.

    Runs two sequential LLM calls: (1) extract the entities from
    ``body.content``, (2) extract the relations between those entities.
    Returns the node/edge lists produced by ``build_visjs_graph``.
    """
    # Step 1: extract entities from the text.
    entities_completion = await acompletion(
        LLM_MODEL,
        api_key=LLM_TOKEN,
        base_url=LLM_BASE_URL,
        messages=[
            {
                "role": "user",
                "content": await fmt_prompt(
                    prompt_env,
                    "ner/extract_entities",
                    response_format=ExtractEntitiesResponse.model_json_schema(),
                    input_text=body.content,
                ),
            }
        ],
        response_format=ExtractEntitiesResponse,
        # Consistent retry policy across all LLM calls in this service
        # (the original set num_retries only on the relations call).
        num_retries=5,
    )
    extracted_entities = ExtractEntitiesResponse.model_validate_json(
        entities_completion.choices[0].message.content)

    # Step 2: extract relationships between the entities found above.
    relations_completion = await acompletion(
        LLM_MODEL,
        api_key=LLM_TOKEN,
        base_url=LLM_BASE_URL,
        messages=[
            {
                "role": "user",
                "content": await fmt_prompt(
                    prompt_env,
                    "ner/extract_relations",
                    response_format=ExtractedRelationsResponse.model_json_schema(),
                    input_text=body.content,
                    entities=extracted_entities.entities,
                ),
            }
        ],
        response_format=ExtractedRelationsResponse,
        num_retries=5,
    )
    relation_model = ExtractedRelationsResponse.model_validate_json(
        relations_completion.choices[0].message.content)

    # Convert entities + relations into vis.js-compatible node/edge lists.
    display_lists = build_visjs_graph(
        extracted_entities.entities, relation_model.relations)
    return display_lists


@api.post("/create_search_plan")
async def create_search_plan(body: CreateSearchPlanRequest):
    """Ask the LLM to turn ``body.query`` into a structured search plan."""
    plan_completion = await acompletion(
        LLM_MODEL,
        api_key=LLM_TOKEN,
        base_url=LLM_BASE_URL,
        messages=[
            {
                "role": "user",
                "content": await fmt_prompt(
                    prompt_env,
                    "search/create_search_plan",
                    response_format=CreateSearchPlanResponse.model_json_schema(),
                    user_query=body.query,
                ),
            }
        ],
        response_format=CreateSearchPlanResponse,
        num_retries=5,  # match the retry policy of the other endpoints
    )
    plan_model = CreateSearchPlanResponse.model_validate_json(
        plan_completion.choices[0].message.content)
    return plan_model


# Serve the frontend; mounted last so the API routes above take precedence.
api.mount("/", StaticFiles(directory="static", html=True), name="static")