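# NOTE: the next lines appear to be page metadata captured together with the
# file, not Python code; they are kept as comments so the module parses.
# ai / entity.py
# kevinhug's picture
# async correct
# 2ea584d
# raw
# history blame
# 2.81 kB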
from pydantic import BaseModel, Field
from typing import List
import instructor
import os
from groq import Groq
# Build the Groq SDK client from the GROQ_API_KEY environment variable and
# wrap it with instructor in one step; the wrapped client is what the
# completion call below passes `response_model` to.
client = instructor.from_groq(Groq(api_key=os.getenv("GROQ_API_KEY")))

# Disabled alternative, kept for reference: an OpenAI-compatible client aimed
# at http://localhost:11434/v1 in JSON mode. It sits in a bare string literal,
# so none of it is executed.
"""
import openai
client = instructor.from_openai(
openai.OpenAI(
base_url="http://localhost:11434/v1",
api_key="ollama",
),
mode=instructor.Mode.JSON,
)
"""

# Model name sent with every request: 'llama-3.1-8b-instant' when an API key
# is configured, otherwise "qwen2.5".
# NOTE(review): `client` above is Groq-backed in both cases, so the "qwen2.5"
# fallback presumably only makes sense together with the disabled client
# above -- confirm the intended behaviour when GROQ_API_KEY is unset.
if os.getenv("GROQ_API_KEY"):
    llm = 'llama-3.1-8b-instant'
else:
    llm = "qwen2.5"
class Property(BaseModel):
    # One key/value attribute attached to an entity.
    #
    # Documented with `#` comments on purpose: pydantic copies a model's
    # docstring into the generated JSON schema, which would change what is
    # sent along with the request.

    # Name of the attribute (the examples on Entity.properties use "Amount"
    # and "Date").
    key: str = Field(...)

    # Value of the attribute; presumably as written in the document, possibly
    # relative (the examples use "+200" and "-5") -- TODO confirm.
    value: str = Field(...)

    # The same value in resolved form (the examples use "300" and
    # "2018-09-18").
    resolved_absolute_value: str = Field(...)
class Entity(BaseModel):
    # One extracted entity: an id, supporting text, a title, its properties
    # and the ids of the entities it depends on.
    #
    # Documented with `#` comments on purpose: pydantic copies a model's
    # docstring into the generated JSON schema, which would change what is
    # sent along with the request.

    # Identifier used to deduplicate entities and to refer to them from
    # `dependencies`.
    id: int = Field(
        ...,
        description=(
            "Unique identifier for the entity, used for deduplication, "
            "design a scheme allows multiple entities"
        ),
    )

    # Text fragments for the entity; the description asks for a few words of
    # surrounding context to be included.
    subquote_string: List[str] = Field(
        ...,
        description=(
            "Correctly resolved value of the entity, if the entity is a "
            "reference to another entity, this should be the id of the "
            "referenced entity, include a few more words before and after "
            "the value to allow for some context to be used in the resolution"
        ),
    )

    # Human-readable title of the entity (no schema description is attached).
    entity_title: str

    # Attributes of the entity.
    # NOTE(review): each entry in `examples` is a single Property although the
    # field itself is a list of them -- confirm this is the intended shape.
    properties: List[Property] = Field(
        ...,
        description="List of properties of the entity such as date, amount...etc",
        examples=[
            Property(
                key="Amount",
                value="+200",
                resolved_absolute_value="300",
            ),
            Property(
                key="Date",
                value="-5",
                resolved_absolute_value="2018-09-18",
            ),
        ],
    )

    # Ids of other entities needed to resolve this one.
    dependencies: List[int] = Field(
        ...,
        description=(
            "List of entity ids that this entity depends or relies on "
            "to resolve it"
        ),
    )
class DocumentExtraction(BaseModel):
    # Top-level structured result: the list of entities extracted from one
    # document. This is the type passed as `response_model` by entity_graph.
    #
    # Documented with `#` comments on purpose: pydantic copies a model's
    # docstring into the generated JSON schema, which would change what is
    # sent along with the request.
    entities: List[Entity] = Field(
        ...,
        description=(
            "Body of the answer, each fact should be a separate object with "
            "a body and a list of sources such as Organization, Agreement "
            "Date, Asset...etc"
        ),
    )
def entity_graph(content) -> DocumentExtraction:
    """Ask the configured model to extract and resolve entities in *content*.

    A fixed system instruction and the document text are sent through the
    module-level ``client``, with ``DocumentExtraction`` requested as the
    response model.

    Args:
        content: Document text; passed verbatim as the user message.

    Returns:
        The result of ``client.chat.completions.create`` for
        ``response_model=DocumentExtraction``.
    """
    system_message = {
        "role": "system",
        "content": (
            "You're world class entities resolution system. "
            "Ensure that each entity and its attributes are correctly resolved, "
            "meaning duplicates are merged and dependencies are established. "
            "Extract and resolve a list of entities from the following document:"
        ),
    }
    user_message = {"role": "user", "content": content}

    # Other model names noted next to `model` by the original author:
    # "deepseek-r1", "gpt-4", "llama3.2".
    return client.chat.completions.create(
        model=llm,
        response_model=DocumentExtraction,
        temperature=0.1,
        messages=[system_message, user_message],
    )
def resolve(content):
    """Extract entities from *content* and return them as indented JSON.

    Args:
        content: Document text forwarded unchanged to ``entity_graph``.

    Returns:
        The extraction serialized by ``model_dump_json(indent=2)``.
    """
    return entity_graph(content).model_dump_json(indent=2)
# Script entry point: resolve an empty document and print the resulting JSON.
if __name__ == "__main__":
    print(resolve(""))