|
from typing import Dict, List |
|
from jinja2 import Environment |
|
|
|
from schemas import ExtractedRelation |
|
|
|
|
|
def build_visjs_graph(
    entities: List[str], relations: List["ExtractedRelation"]
) -> Dict[str, List[Dict]]:
    """Build a vis.js-style node/edge graph for displaying in the UI.

    Args:
        entities: Entity names to render as nodes. Duplicates are collapsed
            and node ids are assigned in order of first appearance, so the
            same input always produces the same ids.
        relations: Relation objects exposing ``start``, ``to``, ``tag`` and
            ``description`` attributes. A relation is skipped when either
            endpoint is not one of ``entities``.

    Returns:
        ``{"nodes": [...], "edges": [...]}``. Each node has
        ``id``/``label``/``title``. Each edge has ``from``/``to`` (node ids),
        ``label`` (the relation tag), ``title`` (the relation description)
        and ``arrows`` set to ``"to"``.
    """
    # dict.fromkeys de-duplicates while preserving insertion order. Iterating
    # a plain set of strings depends on per-process hash randomization, which
    # made node ids (and node order) differ between runs for the same input.
    unique_entities = list(dict.fromkeys(entities))
    entity_to_id = {entity: idx for idx, entity in enumerate(unique_entities)}

    nodes = [
        {"id": idx, "label": entity, "title": entity}
        for idx, entity in enumerate(unique_entities)
    ]

    edges = []
    for rel in relations:
        start_id = entity_to_id.get(rel.start)
        end_id = entity_to_id.get(rel.to)
        # Drop edges whose endpoints are not known nodes; vis.js would
        # otherwise reference ids that do not exist.
        if start_id is None or end_id is None:
            continue
        edges.append({
            "from": start_id,
            "to": end_id,
            "label": rel.tag,
            "title": rel.description,
            "arrows": "to",
        })

    return {"nodes": nodes, "edges": edges}
|
|
|
|
|
async def fmt_prompt(env: Environment, prompt_id: str, **args):
    """Look up template *prompt_id* in *env* and return its rendered output.

    All keyword arguments are gathered into one mapping and passed to the
    template's ``render_async`` as the render context.
    """
    # NOTE(review): jinja2's render_async needs an Environment created with
    # enable_async=True; confirm that callers construct env that way.
    rendered = await env.get_template(prompt_id).render_async(args)
    return rendered
|
|