|
from pydantic import BaseModel, Field |
|
from typing import List |
|
|
|
|
|
import instructor |
|
import os |
|
|
|
from groq import Groq |
|
|
|
# Build the Groq SDK client from the GROQ_API_KEY environment variable and
# wrap it with instructor in a single step; `client` ends up bound to the
# instructor-wrapped client.
client = instructor.from_groq(Groq(api_key=os.environ.get("GROQ_API_KEY")))

# Disabled alternative, kept as an inert string literal (never executed): the
# same instructor wrapper pointed at an OpenAI-compatible endpoint on
# localhost:11434.
"""
import openai
client = instructor.from_openai(
    openai.OpenAI(
        base_url="http://localhost:11434/v1",
        api_key="ollama",
    ),
    mode=instructor.Mode.JSON,
)
"""

# Model name passed to the client on every request.
# NOTE(review): the active client above is always Groq, so the "qwen2.5"
# fallback presumably belongs to the disabled local client - confirm.
if os.environ.get("GROQ_API_KEY"):
    llm = 'llama-3.1-8b-instant'
else:
    llm = "qwen2.5"
|
|
|
class Property(BaseModel):
    # One key/value attribute attached to an entity.
    # Documented with comments rather than a docstring on purpose: pydantic
    # copies a model docstring into the generated JSON schema, which would
    # change the schema.

    # Attribute name, e.g. "Amount" or "Date".
    key: str

    # Attribute value as a string, e.g. "+200" or "-5".
    value: str

    # String holding the resolved form of `value`, e.g. "300" or "2018-09-18".
    resolved_absolute_value: str
|
|
|
|
|
class Entity(BaseModel):
    # One extracted entity together with its properties and the ids of the
    # entities it depends on.
    # Documented with comments rather than a docstring on purpose: pydantic
    # copies a model docstring into the generated JSON schema, which would
    # change the schema.

    # Integer identifier; `dependencies` on other entities refer to this value.
    id: int = Field(
        ...,
        description="Unique identifier for the entity, used for deduplication, design a scheme allows multiple entities",
    )

    # List of strings; see the description text for what each should contain.
    subquote_string: List[str] = Field(
        ...,
        description="Correctly resolved value of the entity, if the entity is a reference to another entity, this should be the id of the referenced entity, include a few more words before and after the value to allow for some context to be used in the resolution",
    )

    # Title string for the entity.
    entity_title: str

    # Each entry of `examples` is one sample value of the whole field. The
    # field is a list of Property objects, so the sample is itself a list;
    # listing bare Property objects would advertise a single object as a valid
    # value for an array-typed field.
    properties: List[Property] = Field(
        ...,
        description="List of properties of the entity such as date, amount...etc",
        examples=[
            [
                Property(key="Amount", value="+200", resolved_absolute_value="300"),
                Property(key="Date", value="-5", resolved_absolute_value="2018-09-18"),
            ]
        ],
    )

    # Ids of other entities that this one relies on.
    dependencies: List[int] = Field(
        ...,
        description="List of entity ids that this entity depends or relies on to resolve it",
    )
|
|
|
|
|
class DocumentExtraction(BaseModel):
    # Top-level model: the complete list of entities for one document.
    # Documented with comments rather than a docstring on purpose: pydantic
    # copies a model docstring into the generated JSON schema, which would
    # change the schema.

    # Required list of Entity objects.
    entities: List[Entity] = Field(
        ...,
        description="Body of the answer, each fact should be a separate object with a body and a list of sources such as Organization, Agreement Date, Asset...etc",
    )
|
|
|
|
|
def entity_graph(content) -> DocumentExtraction:
    """Send ``content`` to the configured chat model for entity extraction.

    Issues a single chat-completion request through the module-level
    ``client`` using model ``llm`` at temperature 0.1, with
    ``DocumentExtraction`` passed as the ``response_model``.

    Args:
        content: Document text, sent unchanged as the user message.

    Returns:
        Whatever ``client.chat.completions.create`` returns for the request;
        annotated as a ``DocumentExtraction``.
    """
    # Fixed instruction describing the extraction/resolution task.
    system_message = {
        "role": "system",
        "content": "You're world class entities resolution system. Ensure that each entity and its attributes are correctly resolved, meaning duplicates are merged and dependencies are established. Extract and resolve a list of entities from the following document:",
    }
    # The caller's document is forwarded verbatim.
    user_message = {
        "role": "user",
        "content": content,
    }

    extraction = client.chat.completions.create(
        model=llm,
        response_model=DocumentExtraction,
        temperature=0.1,
        messages=[system_message, user_message],
    )
    return extraction
|
|
|
def resolve(content):
    """Extract entities from ``content`` and return them as JSON text.

    Passes ``content`` to ``entity_graph`` and serialises the result with
    ``model_dump_json`` using a two-space indent.
    """
    return entity_graph(content).model_dump_json(indent=2)
|
|
|
if __name__ == "__main__":
    # Script entry point: run the extraction on a placeholder (empty) document
    # and print the resulting JSON.
    document_text = ""
    print(resolve(document_text))
|
|