# ai/classify.py — classify banking customer-feedback texts into product/service
# tags with an LLM (Groq or local Ollama), then judge tag relevance.
from typing import List, Iterable, Optional
from pydantic import BaseModel, ValidationInfo, model_validator, Field, field_validator
import instructor
import openai
import asyncio
import os
from groq import AsyncGroq

# Initialize with API key
client = AsyncGroq(api_key=os.getenv("GROQ_API_KEY"))
# Enable instructor patches for Groq client (adds response_model/validation support)
client = instructor.from_groq(client)

# Alternative client config: local Ollama via its OpenAI-compatible API.
# Kept for reference; swap in when no Groq key is available.
"""
import openai
client = instructor.from_openai(
openai.AsyncOpenAI(
base_url="http://localhost:11434/v1",
api_key="ollama",
),
mode=instructor.Mode.JSON,
)
"""

# Model choice: hosted Groq model when an API key is present, otherwise a local one.
llm = 'llama-3.1-8b-instant' if os.getenv("GROQ_API_KEY") else "qwen2.5" #"gemma3:12b" #"llama3.2" #"deepseek-r1"
class Tag(BaseModel):
    """A single tag predicted for a text, with reasoning and confidence.

    Validation context may carry the allowed tags; any (id, name) outside
    that set is rejected, as is any prediction below 0.6 confidence.
    """
    chain_of_thought: List[str] = Field(default_factory=list, description="the chain of thought led to the prediction", examples=["Let's think step by step. the customer explicitly mention donation, and there is a tag name with donation, tag the text with donation"])
    name: str
    id: int = Field(..., description="id for the specific tag")
    confidence: float = Field(
        default=0.5,
        ge=0,
        le=1,
        description="The confidence of the prediction(id, name) for the text, 0 is low, 1 is high", examples=[0.5, 0.1, 0.9]
    )

    @field_validator('confidence', mode="after")
    @classmethod
    def high_confidence(cls, c: float) -> float:
        """Keep tag with confidence 0.6 or above; lower ones fail validation."""
        if c < 0.6:
            raise ValueError(f"low confidence `{c}` ")
        return c

    @model_validator(mode="after")
    def validate_ids(self, info: ValidationInfo):
        """Check (id, name) against the allowed tags in the validation context.

        Uses explicit raises instead of `assert`: asserts are stripped under
        `python -O`, which would silently disable this validation. Pydantic
        converts ValueError (like AssertionError) into a ValidationError.
        """
        context = info.context
        if context:
            tags: List[Tag] = context.get("tags")
            if self.id not in {tag.id for tag in tags}:
                raise ValueError(f"Tag ID {self.id} not found in context")
            if self.name.lower() not in {tag.name.lower() for tag in tags}:
                raise ValueError(f"Tag name {self.name} not found in context")
        return self
class TagWithInstructions(Tag):
    """A tag definition plus free-text guidance on when it applies."""
    instructions: str

class TagRequest(BaseModel):
    """Batch tagging request: the texts to classify and the allowed tags."""
    texts: List[str]
    tags: List[TagWithInstructions]
from judge import judge_relevance, Judgment
class TagResponse(BaseModel):
    """Result of a tagging run: the input texts, per-text tag predictions,
    and (after judge()) per-tag relevance judgments."""
    texts: List[str]
    # One entry per text: accepted tags, or None when tagging failed.
    # Fix: Field(..., default_factory=list) mixed "required" with a default;
    # default_factory alone expresses the intent.
    predictions: List[Optional[List[Tag]]] = Field(default_factory=list)
    # Filled by judge(): one Judgment per predicted tag, or None per text.
    judgment: List[Optional[List[Judgment]]] = Field(default_factory=list)

    async def judge(self):
        """Judge every predicted tag against its own source text, appending
        one entry per text to self.judgment (None when there were no tags).

        Judgments for a single text's tags run concurrently via gather.
        """
        for text, tags in zip(self.texts, self.predictions):
            if tags:
                self.judgment.append(
                    await asyncio.gather(*[
                        judge_relevance(
                            " ".join(t.chain_of_thought),
                            # fix: original used the module-level `texts`
                            # demo string (indexing its characters) instead
                            # of this response's own text
                            text,
                            t.name,
                        )
                        for t in tags
                    ])
                )
            else:
                self.judgment.append(None)
# Cap concurrent LLM calls to stay under provider rate limits.
sem = asyncio.Semaphore(2)

async def tag_single_request(text: str, tags: List[Tag]) -> Optional[Iterable[Tag]]:
    """Ask the LLM to tag one text, constrained to the allowed tags.

    Args:
        text: the customer-feedback text to classify.
        tags: allowed tag definitions; also passed as validation context so
            the Tag model can reject ids/names outside this set.

    Returns:
        Validated Tag objects, or None when the call (or response
        validation, after retries) fails — callers treat None as "untagged".
    """
    allowed_tags = [(tag.id, tag.name) for tag in tags]
    allowed_tags_str = ", ".join(f"`{tag}`" for tag in allowed_tags)
    async with sem:
        try:
            return await client.chat.completions.create(
                model=llm,  # "gpt-4o-mini", "deepseek-r1", "llama3.2"
                temperature=0.3,
                max_retries=3,  # instructor re-asks on validation failure
                messages=[
                    {
                        "role": "system",
                        "content": """You are a world-class text tagging system for customer feedback in the banking industry to classify banking product/services.
"""
                    },
                    {"role": "user", "content": f"""Create minimum multiple Tag according to instruction most appropriate for the following text: `{text}`
### Instruction:
Here are the allowed Tag(id, name), do not use any other Tag than these: {allowed_tags_str}
Tag the name based on fact stated and directly mention in the text. Do not guess the name, Do not tag if tag not mention in the text. Do not use implication.
Calculate the newly created Tag's confidence that Tag fit to the text
For each question, show your step-by-step thinking under 'chain_of_thought' in list of string, then clearly state your final answer under 'name'.
"""},
                ],
                response_model=Iterable[Tag],
                validation_context={"tags": tags},
            )
        except Exception as e:
            # Best-effort by design: report the failure and return None
            # (was an implicit None; made explicit for readability).
            print(e)
            return None
async def tag_request(request: TagRequest) -> TagResponse:
    """Tag all texts in the request concurrently, then deduplicate by name.

    For each text, only one tag per name is kept — the one with the highest
    confidence. Texts whose LLM call failed map to None.
    """
    predictions = await asyncio.gather(
        *[tag_single_request(text, request.tags) for text in request.texts]
    )
    pred_dedup = []
    for tags in predictions:
        if tags is None:
            pred_dedup.append(None)
            continue
        # Sort so that, within each name, the highest confidence comes first.
        # Fix: the original sorted confidence ascending and kept the first of
        # each name, i.e. the LOWEST-confidence duplicate survived.
        tags_s = sorted(tags, key=lambda t: (t.name, -t.confidence))
        dedup = [
            t for j, t in enumerate(tags_s)
            if j == 0 or tags_s[j - 1].name != t.name
        ]
        pred_dedup.append(dedup)
    return TagResponse(
        texts=request.texts,
        predictions=pred_dedup,
    )
# Demo tag taxonomy for banking products/services; `instructions` guides the LLM.
tags = [
    TagWithInstructions(id=0, name="online", instructions="text related to online banking"),
    TagWithInstructions(id=1, name="card", instructions="text related to credit card"),
    TagWithInstructions(id=2, name="cars", instructions="auto finance"),
    TagWithInstructions(id=3, name="mortgage", instructions="home mortgage"),
    TagWithInstructions(id=4, name="insurance", instructions="insurance"),
]
# Demo input: semicolon-separated feedback snippets (split by judge()/bucket()).
texts = """
"The online portal makes managing my mortgage payments so convenient."
;"RBC offer great mortgage for my home with competitive rate thank you";
"Low interest rate compared to other cards I’ve used. Highly recommend for responsible spenders.";
"The mobile check deposit feature saves me so much time. Banking made easy!";
"Affordable premiums with great coverage. Switched from my old provider and saved!"
"""
def judge_response(response):
    """Run the response's async judge() to completion.

    Fix: the original called response.judge() without awaiting it, so the
    coroutine was created but never executed (judgments were never filled).
    """
    asyncio.run(response.judge())
def judge(texts):
    """Split the `;`-delimited input, tag every snippet, then run the
    relevance judge over the predictions.

    Returns the full TagResponse serialized as pretty-printed JSON.
    """
    snippets = [piece.strip() for piece in texts.split(";")]
    response = asyncio.run(tag_request(TagRequest(texts=snippets, tags=tags)))
    asyncio.run(response.judge())
    return response.model_dump_json(indent=2)
def bucket(texts):
    """Tag each `;`-separated snippet (no judging) and return the
    predictions as pretty-printed JSON."""
    snippets = [piece.strip() for piece in texts.split(";")]
    response = asyncio.run(tag_request(TagRequest(texts=snippets, tags=tags)))
    return response.model_dump_json(indent=2)
if __name__ == "__main__":
    # Demo: tag the sample banking-feedback texts and judge the results.
    # (Removed the unused `from pprint import pprint` import.)
    #print(bucket(texts))
    print(judge(texts))