# File size: 3,328 bytes
# 51f2dc1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import logging
import os
import sys

from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from fastapi.staticfiles import StaticFiles
from jinja2 import Environment, FileSystemLoader, StrictUndefined, TemplateNotFound
from litellm import acompletion

from schemas import (
    CreateSearchPlanRequest,
    CreateSearchPlanResponse,
    ExtractEntitiesRequest,
    ExtractEntitiesResponse,
    ExtractedRelationsResponse,
)
from utils import build_visjs_graph, fmt_prompt

# Populate the process environment from a .env file (python-dotenv) before
# any of the LLM_* settings are read.
load_dotenv()

# Root logger configuration: INFO and above, each record prefixed with a
# timestamp, the level name and the emitting file:line.
logging.basicConfig(
    format="[%(asctime)s][%(levelname)s][%(filename)s:%(lineno)d]: %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=logging.INFO,
)


# litellm connection settings, read from the process environment.
# LLM_TOKEN and LLM_BASE_URL may be left unset; they are then passed to
# acompletion() as None. NOTE(review): presumably so keyless providers or the
# provider's default endpoint still work — confirm.
LLM_MODEL = os.environ.get("LLM_MODEL")
LLM_TOKEN = os.environ.get("LLM_TOKEN")
LLM_BASE_URL = os.environ.get("LLM_BASE_URL")

# Every acompletion() call passes LLM_MODEL as the model name, so no request
# can succeed without it. Fail fast at startup instead of on the first request,
# regardless of whether a token was supplied.
if not LLM_MODEL:
    logging.error("No LLM_MODEL was provided.")
    sys.exit(-1)

# Jinja2 environment for the prompt templates stored under ./prompts.
# StrictUndefined turns a missing template variable into an error rather than
# an empty string; enable_async lets fmt_prompt() render templates with await.
prompt_env = Environment(
    loader=FileSystemLoader("prompts"),
    undefined=StrictUndefined,
    enable_async=True,
)

# The ASGI application; the API routes below are registered on it.
api = FastAPI()


def _parse_completion(completion, response_model):
    """Validate the first choice of an LLM completion against a Pydantic model.

    Args:
        completion: The object returned by ``litellm.acompletion``.
        response_model: Pydantic model class the reply's JSON must match.

    Returns:
        An instance of ``response_model`` parsed from the reply content.

    Raises:
        HTTPException: 502 when the reply is empty or does not validate
            against ``response_model``, so an upstream LLM failure is reported
            as such instead of escaping as an unhandled error (generic 500).
    """
    content = completion.choices[0].message.content
    if not content:
        raise HTTPException(
            status_code=502,
            detail=f"LLM returned an empty response (expected {response_model.__name__})",
        )
    try:
        return response_model.model_validate_json(content)
    except ValueError as err:
        # pydantic's ValidationError subclasses ValueError; malformed JSON or a
        # schema mismatch here is a bad upstream reply, hence 502 Bad Gateway.
        raise HTTPException(
            status_code=502,
            detail=f"LLM response did not match {response_model.__name__}",
        ) from err


@api.post("/extract_entities")
async def extract_entities(body: ExtractEntitiesRequest):
    """Extract entities, and the relations between them, from the input text.

    Runs two LLM passes over ``body.content``. The first extracts entities.
    The second extracts relations, given those entities. Both results are
    handed to ``build_visjs_graph``, whose result is returned.

    Raises:
        HTTPException: 502 if either LLM reply is empty or does not match
            its response schema.
    """
    # Pass 1: entities.
    entities_completion = await acompletion(
        LLM_MODEL,
        api_key=LLM_TOKEN,
        base_url=LLM_BASE_URL,
        messages=[
            {
                "role": "user",
                "content": await fmt_prompt(
                    prompt_env,
                    "ner/extract_entities",
                    response_format=ExtractEntitiesResponse.model_json_schema(),
                    input_text=body.content,
                ),
            }
        ],
        response_format=ExtractEntitiesResponse,
    )
    extracted_entities = _parse_completion(
        entities_completion, ExtractEntitiesResponse)

    # Pass 2: relations, with the entities from pass 1 given to the prompt.
    # NOTE(review): only this call sets num_retries; confirm whether the other
    # LLM calls should retry as well.
    relations_completion = await acompletion(
        LLM_MODEL,
        api_key=LLM_TOKEN,
        base_url=LLM_BASE_URL,
        messages=[
            {
                "role": "user",
                "content": await fmt_prompt(
                    prompt_env,
                    "ner/extract_relations",
                    response_format=ExtractedRelationsResponse.model_json_schema(),
                    input_text=body.content,
                    entities=extracted_entities.entities,
                ),
            }
        ],
        response_format=ExtractedRelationsResponse,
        num_retries=5,
    )
    relation_model = _parse_completion(
        relations_completion, ExtractedRelationsResponse)

    return build_visjs_graph(
        extracted_entities.entities, relation_model.relations)


@api.post("/create_search_plan")
async def create_search_plan(body: CreateSearchPlanRequest):
    """Ask the LLM for a search plan for ``body.query`` and return it.

    Raises:
        HTTPException: 502 if the LLM reply is empty or does not match
            ``CreateSearchPlanResponse``.
    """
    plan_completion = await acompletion(
        LLM_MODEL,
        api_key=LLM_TOKEN,
        base_url=LLM_BASE_URL,
        messages=[
            {
                "role": "user",
                "content": await fmt_prompt(
                    prompt_env,
                    "search/create_search_plan",
                    response_format=CreateSearchPlanResponse.model_json_schema(),
                    user_query=body.query,
                ),
            }
        ],
        response_format=CreateSearchPlanResponse,
    )
    return _parse_completion(plan_completion, CreateSearchPlanResponse)


# Serve the ./static directory at the site root (html=True makes a directory
# request fall back to its index.html). Registered after the API routes so the
# catch-all "/" mount does not shadow them.
api.mount(
    path="/",
    app=StaticFiles(directory="static", html=True),
    name="static",
)