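"""FastAPI backend for LLM-based entity/relation extraction and search
planning.

Exposes two JSON endpoints, /extract_entities and /create_search_plan,
backed by litellm completions with Jinja2-templated prompts, and serves
the static frontend from ./static.
"""
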
import logging
import os
import sys

from dotenv import load_dotenv
from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
from jinja2 import Environment, FileSystemLoader, StrictUndefined
from litellm import acompletion

from schemas import (
    CreateSearchPlanRequest,
    CreateSearchPlanResponse,
    ExtractEntitiesRequest,
    ExtractEntitiesResponse,
    ExtractedRelationsResponse,
)
from utils import build_visjs_graph, fmt_prompt

load_dotenv()

logging.basicConfig(
    level=logging.INFO,
    format='[%(asctime)s][%(levelname)s][%(filename)s:%(lineno)d]: %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
)

# LLM connection settings; LLM_BASE_URL is optional (e.g. for self-hosted
# OpenAI-compatible endpoints).
LLM_MODEL = os.environ.get('LLM_MODEL')
LLM_TOKEN = os.environ.get('LLM_TOKEN')
LLM_BASE_URL = os.environ.get('LLM_BASE_URL')

if not LLM_MODEL or not LLM_TOKEN:
    logging.error("LLM_MODEL and LLM_TOKEN must both be set.")
    sys.exit(1)

prompt_env = Environment(
    loader=FileSystemLoader("prompts"),
    undefined=StrictUndefined,
    enable_async=True,
)

api = FastAPI()

@api.post("/extract_entities")
async def extract_entities(body: ExtractEntitiesRequest):
"""Extract entities from the given input text and return them"""
# Extract entities from the text
entities_completion = await acompletion(LLM_MODEL, api_key=LLM_TOKEN, base_url=LLM_BASE_URL, messages=[
{
"role": "user",
"content": await fmt_prompt(prompt_env, "ner/extract_entities", **{
"response_format": ExtractEntitiesResponse.model_json_schema(),
"input_text": body.content
})
}
], response_format=ExtractEntitiesResponse)
extracted_entities = ExtractEntitiesResponse.model_validate_json(
entities_completion.choices[0].message.content)
# Extract relationships in a second step
relations_completion = await acompletion(LLM_MODEL, api_key=LLM_TOKEN, base_url=LLM_BASE_URL, messages=[
{
"role": "user",
"content": await fmt_prompt(prompt_env, "ner/extract_relations", **{
"response_format": ExtractedRelationsResponse.model_json_schema(),
"input_text": body.content,
"entities": extracted_entities.entities
})
}
], response_format=ExtractedRelationsResponse, num_retries=5)
relation_model = ExtractedRelationsResponse.model_validate_json(
relations_completion.choices[0].message.content)
display_lists = build_visjs_graph(
extracted_entities.entities, relation_model.relations)
return display_lists
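
# Example request (hypothetical host/port; ExtractEntitiesRequest carries the
# text in its `content` field):
#   curl -X POST http://localhost:8000/extract_entities \
#        -H 'Content-Type: application/json' \
#        -d '{"content": "Marie Curie studied radioactivity in Paris."}'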
@api.post("/create_search_plan")
async def create_search_plan(body: CreateSearchPlanRequest):
plan_completion = await acompletion(LLM_MODEL, api_key=LLM_TOKEN, base_url=LLM_BASE_URL, messages=[
{
"role": "user",
"content": await fmt_prompt(prompt_env, "search/create_search_plan", **{
"response_format": CreateSearchPlanResponse.model_json_schema(),
"user_query": body.query,
})
}
], response_format=CreateSearchPlanResponse)
plan_model = CreateSearchPlanResponse.model_validate_json(
plan_completion.choices[0].message.content)
return plan_model
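
# Example request (hypothetical host/port; CreateSearchPlanRequest carries the
# query in its `query` field):
#   curl -X POST http://localhost:8000/create_search_plan \
#        -H 'Content-Type: application/json' \
#        -d '{"query": "Who collaborated with Marie Curie?"}'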


# Mount the static frontend last so the API routes above take precedence
# over this catch-all "/" mount.
api.mount("/", StaticFiles(directory="static", html=True), name="static")
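
# A minimal entry point for local development; assumes uvicorn is installed
# (the standard ASGI server for FastAPI). Host and port are arbitrary choices.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(api, host="127.0.0.1", port=8000)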