|
import os |
|
import sys |
|
from litellm import acompletion |
|
from dotenv import load_dotenv |
|
from fastapi import FastAPI |
|
from fastapi.staticfiles import StaticFiles |
|
from jinja2 import Environment, FileSystemLoader, StrictUndefined, TemplateNotFound |
|
from schemas import CreateSearchPlanRequest, CreateSearchPlanResponse, ExtractEntitiesRequest, ExtractEntitiesResponse, ExtractedRelationsResponse |
|
from utils import build_visjs_graph, fmt_prompt |
|
import logging |
|
|
|
load_dotenv()

logging.basicConfig(
    level=logging.INFO,
    format='[%(asctime)s][%(levelname)s][%(filename)s:%(lineno)d]: %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)

# LLM connection settings. LLM_MODEL and LLM_TOKEN are required;
# LLM_BASE_URL is optional (None lets litellm use the provider default).
LLM_MODEL = os.environ.get('LLM_MODEL')
LLM_TOKEN = os.environ.get('LLM_TOKEN')
LLM_BASE_URL = os.environ.get('LLM_BASE_URL')

# Fail fast when EITHER required setting is missing. The previous check used
# `and`, which only fired when both were absent, so a half-configured service
# (e.g. model set but no token) would start and fail on the first request.
if not LLM_MODEL or not LLM_TOKEN:
    logging.error("LLM_MODEL and LLM_TOKEN must both be provided.")
    sys.exit(1)

# Async-enabled Jinja environment over the prompts/ directory.
# StrictUndefined turns a missing template variable into an error instead of
# silently rendering it as an empty string.
prompt_env = Environment(loader=FileSystemLoader(
    "prompts"), undefined=StrictUndefined, enable_async=True)

api = FastAPI()
|
|
|
|
|
@api.post("/extract_entities")
async def extract_entities(body: ExtractEntitiesRequest):
    """Extract entities and their relations from the given input text.

    Runs two LLM passes: the first pulls entities out of ``body.content``;
    the second, conditioned on those entities, extracts the relations
    between them. Returns the node/edge display lists produced by
    ``build_visjs_graph`` (presumably vis.js-shaped — matches the helper's
    name).
    """
    # Pass 1: entity extraction. num_retries=5 matches the relations call
    # below, so a transient provider error retries instead of failing the
    # whole request on the first pass only.
    entities_completion = await acompletion(LLM_MODEL, api_key=LLM_TOKEN, base_url=LLM_BASE_URL, messages=[
        {
            "role": "user",
            "content": await fmt_prompt(prompt_env, "ner/extract_entities", **{
                # The prompt embeds the JSON schema so the model knows the
                # exact structure to emit.
                "response_format": ExtractEntitiesResponse.model_json_schema(),
                "input_text": body.content
            })
        }
    ], response_format=ExtractEntitiesResponse, num_retries=5)

    # Parse the model's JSON reply into the typed entities schema.
    extracted_entities = ExtractEntitiesResponse.model_validate_json(
        entities_completion.choices[0].message.content)

    # Pass 2: relation extraction over the entities found above.
    relations_completion = await acompletion(LLM_MODEL, api_key=LLM_TOKEN, base_url=LLM_BASE_URL, messages=[
        {
            "role": "user",
            "content": await fmt_prompt(prompt_env, "ner/extract_relations", **{
                "response_format": ExtractedRelationsResponse.model_json_schema(),
                "input_text": body.content,
                "entities": extracted_entities.entities
            })
        }
    ], response_format=ExtractedRelationsResponse, num_retries=5)

    relation_model = ExtractedRelationsResponse.model_validate_json(
        relations_completion.choices[0].message.content)

    # Combine both results into graph display lists for the frontend.
    display_lists = build_visjs_graph(
        extracted_entities.entities, relation_model.relations)

    return display_lists
|
|
|
|
|
@api.post("/create_search_plan")
async def create_search_plan(body: CreateSearchPlanRequest):
    """Turn the user's query into a structured search plan via one LLM call."""
    # Render the prompt first; the template receives the JSON schema so the
    # model knows the exact structure to produce.
    prompt_text = await fmt_prompt(prompt_env, "search/create_search_plan", **{
        "response_format": CreateSearchPlanResponse.model_json_schema(),
        "user_query": body.query,
    })

    completion = await acompletion(
        LLM_MODEL,
        api_key=LLM_TOKEN,
        base_url=LLM_BASE_URL,
        messages=[{"role": "user", "content": prompt_text}],
        response_format=CreateSearchPlanResponse,
    )

    # Parse the model's JSON reply into the typed plan schema and return it.
    return CreateSearchPlanResponse.model_validate_json(
        completion.choices[0].message.content)
|
|
|
|
|
# Serve the frontend from static/ at the root. Mounted AFTER the API routes
# above so they take precedence over the catch-all "/" mount; html=True makes
# "/" serve static/index.html.
api.mount("/", StaticFiles(directory="static", html=True), name="static")
|
|