File size: 6,578 Bytes
8d7a1e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8e73b66
8d7a1e9
 
 
 
8e73b66
8d7a1e9
 
 
 
 
 
c16e9f5
8d7a1e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8e73b66
8d7a1e9
 
8e73b66
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8d7a1e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8e73b66
 
 
 
 
 
 
 
 
 
 
 
 
8d7a1e9
c16e9f5
8d7a1e9
 
 
 
 
 
8e73b66
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
"""Tag banking customer-feedback texts with product/service labels via an LLM.

Module-level side effects: builds an async Groq client patched by `instructor`
for structured (pydantic) outputs, and selects the model name based on whether
GROQ_API_KEY is set (Groq-hosted model vs. a local Ollama-style model name).
"""
from typing import List, Iterable, Optional
from pydantic import BaseModel, ValidationInfo, model_validator, Field,  field_validator
import instructor
import openai
import asyncio
import os



from groq import AsyncGroq
# Initialize with API key
client = AsyncGroq(api_key=os.getenv("GROQ_API_KEY"))

# Enable instructor patches for Groq client
# (adds `response_model=`, `max_retries=`, `validation_context=` to create()).
client = instructor.from_groq(client)
"""
import openai
client = instructor.from_openai(
    openai.AsyncOpenAI(
        base_url="http://localhost:11434/v1",
        api_key="ollama",
    ),
    mode=instructor.Mode.JSON,
)
"""
# Fall back to a local model name when no Groq key is configured.
llm = 'llama-3.1-8b-instant' if os.getenv("GROQ_API_KEY") else "qwen2.5" #"gemma3:12b" #"llama3.2" #"deepseek-r1"

class Tag(BaseModel):
    """One predicted label for a text: (id, name) plus the model's reasoning
    and a confidence score. Field descriptions/examples are part of the JSON
    schema shown to the LLM, so they are kept verbatim."""

    chain_of_thought: List[str] = Field(
        default_factory=list,
        description="the chain of thought led to the prediction",
        examples=["Let's think step by step. the customer explicitly mention donation, and there is a tag name with donation, tag the text with donation"],
    )
    name: str
    id: int = Field(..., description="id for the specific tag")
    confidence: float = Field(
        default=0.5,
        ge=0,
        le=1,
        description="The confidence of the prediction(id, name) for the text, 0 is low, 1 is high",
        examples=[0.5, 0.1, 0.9],
    )

    @field_validator('confidence', mode="after")
    @classmethod
    def high_confidence(cls, c: float):
        "keep tag with confidence 0.6 or above"
        # Rejecting here makes instructor re-ask the model (retry loop).
        if c < 0.6:
            raise ValueError(f"low confidence `{c}` ")
        return c

    @model_validator(mode="after")
    def validate_ids(self, info: ValidationInfo):
        """Check id/name against the allowed tag list passed via validation_context."""
        context = info.context
        if not context:
            return self
        tags: List[Tag] = context.get("tags")
        known_ids = {tag.id for tag in tags}
        known_names = {tag.name.lower() for tag in tags}
        assert self.id in known_ids, f"Tag ID {self.id} not found in context"
        assert self.name.lower() in known_names, f"Tag name {self.name} not found in context"
        return self



class TagWithInstructions(Tag):
    """A Tag extended with free-text guidance on when the tag applies."""
    instructions: str


class TagRequest(BaseModel):
    """Batch tagging request: the texts to classify and the allowed tag set."""
    texts: List[str]
    tags: List[TagWithInstructions]

from judge import judge_relevance, Judgment
class TagResponse(BaseModel):
    """Batch tagging result.

    predictions[i] holds the (possibly None) tags predicted for texts[i];
    judgment[i] is filled in by judge() with one relevance Judgment per tag,
    or None when there was no prediction for that text.
    """
    texts: List[str]
    # Plain default_factory: `Field(..., default_factory=list)` mixed the
    # required marker with a factory default.
    predictions: List[Optional[List[Tag]]] = Field(default_factory=list)
    judgment: List[Optional[List[Judgment]]] = Field(default_factory=list)

    async def judge(self):
        """Judge each predicted tag's relevance against its own source text.

        Rebuilds self.judgment from scratch so repeated calls don't append
        duplicate entries.
        """
        self.judgment = []
        for text, preds in zip(self.texts, self.predictions):
            if preds:
                self.judgment.append(
                    await asyncio.gather(*[
                        judge_relevance(
                            " ".join(t.chain_of_thought),
                            # BUG FIX: was module-global `texts[i]`, which is the
                            # raw unsplit string — indexing it yields characters.
                            text,
                            t.name,
                        )
                        for t in preds
                    ])
                )
            else:
                self.judgment.append(None)

# Global throttle: at most 2 LLM requests in flight at once.
sem = asyncio.Semaphore(2)

async def tag_single_request(text: str, tags: List[Tag]) -> Iterable[Tag]:
  """Ask the LLM to tag one text, constrained to the allowed (id, name) pairs.

  Validation (confidence, id/name membership) runs inside the instructor-patched
  client, which retries up to max_retries times. Returns the parsed tags, or
  None when the call ultimately fails.
  """
  tag_pairs = [(tag.id, tag.name) for tag in tags]
  allowed_tags_str = ", ".join(f"`{pair}`" for pair in tag_pairs)

  system_msg = {
    "role": "system",
    "content": """You are a world-class text tagging system for customer feedback in the banking industry to classify banking product/services. 
  """
  }
  user_msg = {"role": "user", "content": f"""Create minimum multiple Tag according to instruction most appropriate for the following text: `{text}`
            ### Instruction:
            Here are the allowed Tag(id, name), do not use any other Tag than these: {allowed_tags_str}
            Tag the name based on fact stated and directly mention in the text. Do not guess the name, Do not tag if tag not mention in the text. Do not use implication.
            Calculate the newly created Tag's confidence that Tag  fit to the text
            
            For each question, show your step-by-step thinking under 'chain_of_thought' in list of string, then clearly state your final answer under 'name'.
            """ }

  # Bound concurrency across all in-flight requests.
  async with sem:
    try:
      completion = await client.chat.completions.create(
        model=llm,
        temperature=0.3,
        max_retries=3,
        messages=[system_msg, user_msg],
        response_model=Iterable[Tag],
        validation_context={"tags": tags},
      )
      return completion
    except Exception as e:
      # Best effort: a failed text surfaces as None and is skipped downstream.
      print(e)


async def tag_request(request: TagRequest) -> TagResponse:
  """Tag every text in the request concurrently, then dedupe each text's tags.

  The LLM may emit several tags with the same name for one text; only the
  highest-confidence tag per name is kept, output in name-sorted order.
  Texts whose LLM call failed are represented by None.
  """
  predictions = await asyncio.gather(
    *[tag_single_request(text, request.tags) for text in request.texts]
  )

  pred_dedup = []
  for tags in predictions:
    if tags is None:
      pred_dedup.append(None)
      continue
    # BUG FIX: the old sort-by-(name, confidence)-ascending kept the FIRST
    # duplicate per name, i.e. the LOWEST-confidence one. Keep the best instead.
    best = {}
    for tag in tags:
      kept = best.get(tag.name)
      if kept is None or tag.confidence > kept.confidence:
        best[tag.name] = tag
    # Preserve the previous name-sorted output ordering.
    pred_dedup.append([best[name] for name in sorted(best)])

  return TagResponse(
    texts=request.texts,
    predictions=pred_dedup,
  )


# The closed set of allowed tags, each with guidance on when it applies.
tags = [
    TagWithInstructions(id=0, name="online", instructions="text related to online banking"),
    TagWithInstructions(id=1, name="card", instructions="text related to credit card"),
    TagWithInstructions(id=2, name="cars", instructions="auto finance"),
    TagWithInstructions(id=3, name="mortgage", instructions="home mortgage"),
    TagWithInstructions(id=4, name="insurance", instructions="insurance"),
]


# Demo input: semicolon-separated customer-feedback snippets
# (split on ";" and stripped by judge()/bucket() before tagging).
texts = """
"The online portal makes managing my mortgage payments so convenient."
;"RBC offer great mortgage for my home with competitive rate thank you";
"Low interest rate compared to other cards I’ve used. Highly recommend for responsible spenders.";
"The mobile check deposit feature saves me so much time. Banking made easy!";
"Affordable premiums with great coverage. Switched from my old provider and saved!"
"""

def judge_response(response):
  """Synchronously run the response's async judge() to completion.

  BUG FIX: the original called `response.judge()` without awaiting it, so the
  coroutine was created and discarded ("coroutine was never awaited") and no
  judgments were ever computed.
  """
  asyncio.run(response.judge())

def judge(texts):
  """Tag the ;-separated texts, judge each prediction, return pretty JSON."""
  cleaned = [piece.strip() for piece in texts.split(";")]
  request = TagRequest(texts=cleaned, tags=tags)

  response = asyncio.run(tag_request(request))
  asyncio.run(response.judge())
  return response.model_dump_json(indent=2)

def bucket(texts):
  """Tag the ;-separated texts and return the predictions as pretty JSON."""
  cleaned = [piece.strip() for piece in texts.split(";")]
  request = TagRequest(texts=cleaned, tags=tags)
  response = asyncio.run(tag_request(request))
  return response.model_dump_json(indent=2)

if __name__=="__main__":
  # End-to-end demo: tag the sample texts and judge the predictions.
  # (Removed unused `from pprint import pprint` and dead commented call.)
  print(judge(texts))