File size: 1,225 Bytes
51f2dc1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
from typing import Dict, List
from jinja2 import Environment

from schemas import ExtractedRelation


def build_visjs_graph(
    entities: List[str], relations: "List[ExtractedRelation]"
) -> Dict[str, List[Dict]]:
    """Build a vis.js-style node/edge graph for displaying in the UI.

    Args:
        entities: Entity names. Duplicates are collapsed, keeping the
            position of each name's first occurrence.
        relations: Relations read via their ``start``, ``to``, ``tag`` and
            ``description`` attributes. A relation whose ``start`` or ``to``
            is not one of ``entities`` is skipped.

    Returns:
        ``{"nodes": [...], "edges": [...]}``. Each node is a dict with
        ``id``, ``label`` and ``title`` keys. Each edge is a dict with
        ``from``, ``to``, ``label``, ``title`` and ``arrows`` keys. Node ids
        are assigned 0, 1, 2, ... in first-occurrence order of ``entities``,
        so the output is deterministic for a given input.
    """
    # dict.fromkeys de-duplicates while preserving insertion order. A set
    # would not: its iteration order for strings depends on per-process
    # randomized hashing, which made node ids unstable between runs.
    unique_entities = dict.fromkeys(entities)
    entity_to_id = {entity: idx for idx, entity in enumerate(unique_entities)}
    nodes = [
        {"id": idx, "label": entity, "title": entity}
        for entity, idx in entity_to_id.items()
    ]

    # Create edges list from relations; drop any whose endpoints are unknown.
    edges = []
    for rel in relations:
        start_id = entity_to_id.get(rel.start)
        end_id = entity_to_id.get(rel.to)
        if start_id is None or end_id is None:
            continue
        edges.append({
            "from": start_id,
            "to": end_id,
            "label": rel.tag,
            "title": rel.description,
            "arrows": "to",
        })

    return {"nodes": nodes, "edges": edges}


async def fmt_prompt(env: Environment, prompt_id: str, **args):
    """Render the template named ``prompt_id`` from ``env`` asynchronously.

    All keyword arguments become the template's render context, and the
    rendered text is returned. Rendering goes through ``render_async``.
    NOTE(review): this presumably requires ``env`` to have been created
    with async support enabled; confirm against the caller.
    """
    template = env.get_template(prompt_id)
    rendered = await template.render_async(**args)
    return rendered