File size: 6,578 Bytes
8d7a1e9 8e73b66 8d7a1e9 8e73b66 8d7a1e9 c16e9f5 8d7a1e9 8e73b66 8d7a1e9 8e73b66 8d7a1e9 8e73b66 8d7a1e9 c16e9f5 8d7a1e9 8e73b66 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 |
from typing import List, Iterable, Optional
from pydantic import BaseModel, ValidationInfo, model_validator, Field, field_validator
import instructor
import openai
import asyncio
import os
from groq import AsyncGroq
# LLM backend selection. ``client`` (an instructor-patched async client) and
# ``llm`` (the model name sent with every request) must agree, so both are
# chosen from the same condition: Groq when GROQ_API_KEY is set, otherwise a
# local OpenAI-compatible server on localhost:11434 (presumably Ollama) in
# instructor's JSON mode.
if os.getenv("GROQ_API_KEY"):
    # Enable instructor patches for the Groq client.
    client = instructor.from_groq(AsyncGroq(api_key=os.getenv("GROQ_API_KEY")))
    llm = 'llama-3.1-8b-instant'
else:
    client = instructor.from_openai(
        openai.AsyncOpenAI(
            base_url="http://localhost:11434/v1",
            api_key="ollama",
        ),
        mode=instructor.Mode.JSON,
    )
    # Other local models noted as alternatives: "gemma3:12b", "llama3.2", "deepseek-r1".
    llm = "qwen2.5"
class Tag(BaseModel):
    # One tag predicted for a text; this is the response model the LLM fills in.
    # NOTE: intentionally no class docstring -- pydantic uses a model docstring
    # as its JSON-schema description, which would presumably be sent to the LLM
    # along with the schema and change the prompt.

    # Reasoning steps the model writes before naming the tag.
    chain_of_thought: List[str] = Field(
        default_factory=list,
        description="the chain of thought led to the prediction",
        # Each example is a complete field value, i.e. a list of strings.
        examples=[
            [
                "Let's think step by step. the customer explicitly mention donation, and there is a tag name with donation, tag the text with donation"
            ]
        ],
    )
    name: str
    id: int = Field(..., description="id for the specific tag")
    confidence: float = Field(
        default=0.5,
        ge=0,
        le=1,
        description="The confidence of the prediction(id, name) for the text, 0 is low, 1 is high",
        examples=[0.5, 0.1, 0.9],
    )

    @field_validator('confidence', mode="after")
    @classmethod
    def high_confidence(cls, c: float) -> float:
        """Keep tags with confidence 0.6 or above; reject lower ones.

        Raises:
            ValueError: if ``c`` is below 0.6 (pydantic turns it into a
                validation error, which instructor can retry on).
        """
        # NOTE(review): pydantic does not run field validators on defaults
        # unless validate_default=True, so an omitted confidence (default 0.5)
        # presumably bypasses this check -- confirm that is intended.
        if c < 0.6:
            raise ValueError(f"low confidence `{c}` ")
        return c

    @model_validator(mode="after")
    def validate_ids(self, info: ValidationInfo) -> "Tag":
        """Check this tag's id and name against the allowed tags in the context.

        Only active when validation receives a context holding a ``"tags"``
        entry (``tag_single_request`` passes ``{"tags": tags}``). The id and the
        case-insensitive name are checked independently of each other.

        Raises:
            ValueError: if the id, or the name, is not among the allowed tags.
        """
        context = info.context
        tags = context.get("tags") if context else None
        if tags is not None:
            allowed_ids = {tag.id for tag in tags}
            allowed_names = {tag.name.lower() for tag in tags}
            # Raise rather than ``assert``: asserts are stripped under
            # ``python -O``, which would silently disable this validation.
            if self.id not in allowed_ids:
                raise ValueError(f"Tag ID {self.id} not found in context")
            if self.name.lower() not in allowed_names:
                raise ValueError(f"Tag name {self.name} not found in context")
        return self
class TagWithInstructions(Tag):
    """A ``Tag`` definition that also carries a free-text description of what it covers."""

    # Human-readable guidance for the tag, e.g. "home mortgage".
    instructions: str
class TagRequest(BaseModel):
    """Input to ``tag_request``: the texts to tag and the tags they may receive."""

    # Texts to tag; ``tag_request`` issues one LLM call per entry.
    texts: List[str]
    # The allowed tag catalogue shared by every text in the request.
    tags: List[TagWithInstructions]
from judge import judge_relevance, Judgment
class TagResponse(BaseModel):
    """Tagging results for a batch of texts, optionally with relevance judgments.

    ``predictions[i]`` and ``judgment[i]`` refer to ``texts[i]``; a ``None``
    entry means there was nothing available for that text.
    """

    texts: List[str]
    # One entry per text: the predicted tags, or None when none are available.
    predictions: List[Optional[List[Tag]]] = Field(default_factory=list)
    # Appended to by ``judge()``: one entry per text.
    judgment: List[Optional[List[Judgment]]] = Field(default_factory=list)

    async def judge(self) -> None:
        """Judge every predicted tag against the text it was predicted for.

        Appends one entry to ``self.judgment`` per text: the list of
        ``judge_relevance`` results for that text's tags (run concurrently),
        or ``None`` when the text has no, or an empty, prediction.
        """
        for i, text in enumerate(self.texts):
            # Guard against a predictions list shorter than texts.
            prediction = self.predictions[i] if i < len(self.predictions) else None
            if not prediction:
                self.judgment.append(None)
                continue
            # Judge against this response's own text (not any module-level name).
            results = await asyncio.gather(*[
                judge_relevance(" ".join(tag.chain_of_thought), text, tag.name)
                for tag in prediction
            ])
            self.judgment.append(results)
# Upper bound on concurrent LLM calls made by one ``tag_request`` call.
MAX_CONCURRENT_REQUESTS = 2

# Fallback limiter for direct callers of ``tag_single_request``. asyncio
# synchronization primitives are tied to the event loop they are used in, so
# ``tag_request`` creates a fresh semaphore per call instead of sharing this
# one across separate ``asyncio.run`` loops.
sem = asyncio.Semaphore(MAX_CONCURRENT_REQUESTS)


async def tag_single_request(
    text: str,
    tags: List[Tag],
    semaphore: Optional[asyncio.Semaphore] = None,
) -> Optional[Iterable[Tag]]:
    """Ask the LLM to tag one text using only the given tags.

    Args:
        text: The text to tag.
        tags: Allowed tags. Their ``(id, name)`` pairs are listed in the prompt,
            and the list is passed as ``validation_context`` so ``Tag``'s
            validators can reject anything else.
        semaphore: Limits concurrent LLM calls; defaults to the module-level
            ``sem``.

    Returns:
        The tags produced by the model, or ``None`` if the call failed. The
        error is printed rather than raised, so one failing text does not abort
        a whole batch.
    """
    allowed_tags = [(tag.id, tag.name) for tag in tags]
    allowed_tags_str = ", ".join([f"`{tag}`" for tag in allowed_tags])
    # NOTE(review): ``TagWithInstructions.instructions`` is not put in the
    # prompt -- only (id, name) are. Confirm that is intended.
    system_prompt = (
        "You are a world-class text tagging system for customer feedback in the banking industry to classify banking product/services.\n"
    )
    user_prompt = (
        f"Create minimum multiple Tag according to instruction most appropriate for the following text: `{text}`\n"
        "### Instruction:\n"
        f"Here are the allowed Tag(id, name), do not use any other Tag than these: {allowed_tags_str}\n"
        "Tag the name based on fact stated and directly mention in the text. Do not guess the name, Do not tag if tag not mention in the text. Do not use implication.\n"
        "Calculate the newly created Tag's confidence that Tag fit to the text\n"
        "For each question, show your step-by-step thinking under 'chain_of_thought' in list of string, then clearly state your final answer under 'name'.\n"
    )
    limiter = semaphore if semaphore is not None else sem
    async with limiter:
        try:
            return await client.chat.completions.create(
                model=llm,
                temperature=0.3,
                max_retries=3,
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_prompt},
                ],
                response_model=Iterable[Tag],
                validation_context={"tags": tags},
            )
        except Exception as e:
            # Deliberate best effort: report and let the caller record None.
            print(e)
            return None


def _best_per_name(tags: Iterable[Tag]) -> List[Tag]:
    """Keep the highest-confidence tag for each name, ordered by name."""
    best = {}
    for tag in tags:
        kept = best.get(tag.name)
        # Strict ``>`` keeps the first-seen tag on confidence ties.
        if kept is None or tag.confidence > kept.confidence:
            best[tag.name] = tag
    return [best[name] for name in sorted(best)]


async def tag_request(request: TagRequest) -> TagResponse:
    """Tag every text in ``request`` concurrently and de-duplicate the results.

    Returns:
        A ``TagResponse`` whose ``predictions[i]`` holds the tags for
        ``request.texts[i]``: one tag per name (the most confident), sorted by
        name. The entry is ``None`` when tagging that text failed.
    """
    # A fresh semaphore per call keeps it bound to the currently running loop.
    semaphore = asyncio.Semaphore(MAX_CONCURRENT_REQUESTS)
    predictions = await asyncio.gather(
        *[tag_single_request(text, request.tags, semaphore) for text in request.texts]
    )
    return TagResponse(
        texts=request.texts,
        predictions=[
            _best_per_name(tags) if tags is not None else None
            for tags in predictions
        ],
    )
# Demo tag catalogue; each tag's id is its position in the list (0..4).
tags = [
    TagWithInstructions(id=position, name=label, instructions=guidance)
    for position, (label, guidance) in enumerate(
        (
            ("online", "text related to online banking"),
            ("card", "text related to credit card"),
            ("cars", "auto finance"),
            ("mortgage", "home mortgage"),
            ("insurance", "insurance"),
        )
    )
]
# Demo input: ';'-separated customer comments. Splitting only strips
# whitespace, so the surrounding double quotes stay part of each text.
texts = """
"The online portal makes managing my mortgage payments so convenient."
;"RBC offer great mortgage for my home with competitive rate thank you";
"Low interest rate compared to other cards I’ve used. Highly recommend for responsible spenders.";
"The mobile check deposit feature saves me so much time. Banking made easy!";
"Affordable premiums with great coverage. Switched from my old provider and saved!"
"""
def _split_texts(texts: str) -> List[str]:
    """Split a ``;``-separated string into stripped, non-empty texts."""
    return [part.strip() for part in texts.split(";") if part.strip()]


def judge_response(response):
    """Run ``response.judge()`` to completion in a new event loop.

    ``TagResponse.judge`` is a coroutine function, so it must be awaited;
    calling it without awaiting does nothing.
    """
    asyncio.run(response.judge())


def judge(texts):
    """Tag the ``;``-separated ``texts`` and judge every predicted tag.

    Empty segments (e.g. from a trailing ``;``) are skipped.

    Returns:
        The resulting ``TagResponse`` serialized as indented JSON.
    """
    request = TagRequest(texts=_split_texts(texts), tags=tags)

    async def _tag_and_judge():
        response = await tag_request(request)
        await response.judge()
        return response

    # Tagging and judging share one event loop instead of two asyncio.run calls.
    return asyncio.run(_tag_and_judge()).model_dump_json(indent=2)


def bucket(texts):
    """Tag the ``;``-separated ``texts`` without judging.

    Empty segments (e.g. from a trailing ``;``) are skipped.

    Returns:
        The resulting ``TagResponse`` serialized as indented JSON.
    """
    request = TagRequest(texts=_split_texts(texts), tags=tags)
    return asyncio.run(tag_request(request)).model_dump_json(indent=2)
if __name__ == "__main__":
    # Demo: tag the sample ``texts`` and judge each prediction, printing JSON.
    # Use ``print(bucket(texts))`` instead to tag without judging.
    print(judge(texts))
|