File size: 2,824 Bytes
57cf043
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import os
import re
from logging import Logger
from typing import Dict, List, Optional, Tuple

from openai import OpenAI

from common.configuration import LLMConfiguration


class QueryClassification:
    """Classify a user query into a bracketed numeric tag such as ``[1]`` via an LLM.

    The model is reached through an OpenAI-compatible chat-completions client.
    When the configuration has no ``base_url``, no client is created and every
    query receives ``DEFAULT_QUERY_TYPE``.
    """

    # Tag returned when the model cannot be asked or its answer holds no tag.
    DEFAULT_QUERY_TYPE = '[3]'
    # Number of times the chat-completions request is tried before giving up.
    MAX_ATTEMPTS = 5
    # Everything up to and including the first occurrence is discarded.
    _ANSWER_START_MARKER = '%%'
    # Everything after the first occurrence is discarded
    # (the literal is Russian for "End of answer").
    _ANSWER_END_MARKER = 'Конец ответа'

    def __init__(self, config: "LLMConfiguration", prompt: str, logger: Logger):
        """
        Store the settings and build the LLM client when an endpoint is configured.

        Args:
            config: LLM settings; this class reads ``base_url``, ``api_key_env``,
                ``model``, ``temperature``, ``top_p``, ``frequency_penalty``,
                ``presence_penalty`` and ``seed`` from it.
            prompt: System prompt sent with every classification request.
            logger: Logger that receives progress and error messages.
        """
        self.config = config
        self.logger = logger
        self.prompt = prompt
        # Regex for a classification tag: digits enclosed in square brackets.
        self.pattern = r'\[\d+\]'

        # Initialize OpenAI client
        if self.config.base_url is not None:
            self.client = OpenAI(
                base_url=self.config.base_url,
                # The API key is read from the environment variable named in the config.
                api_key=os.getenv(self.config.api_key_env)
            )
        else:
            self.client = None

    def query_classification(self, query: str) -> Tuple[str, Optional[Dict], Optional[List]]:
        """
        Classify the query using LLM

        The request is attempted up to ``MAX_ATTEMPTS`` times. An attempt fails
        when the request raises or when the response carries no text content.

        Args:
            query: User query to classify

        Returns:
            Tuple ``(query_type, None, None)``. ``query_type`` is the first
            bracketed tag found in the model's answer, or ``DEFAULT_QUERY_TYPE``
            when there is no client, no tag, or every attempt failed. The second
            and third items are always ``None`` in this implementation.
        """
        self.logger.info('Query Classification')
        if self.client is None:
            return self.DEFAULT_QUERY_TYPE, None, None

        for attempt in range(1, self.MAX_ATTEMPTS + 1):
            # Only the request itself is guarded, so a parsing problem can never
            # trigger another API call.
            try:
                answer_llm = self._request_answer(query)
            except Exception as e:
                # Deliberately broad: any client failure counts as a failed
                # attempt and must not propagate to the caller.
                self.logger.error(f"Attempt {attempt} failed: {str(e)}")
                continue

            self.logger.info(f'Answer LLM {answer_llm}')
            if answer_llm is None:
                # The message content can be None; there is nothing to parse.
                self.logger.error(f"Attempt {attempt} failed: response has no text content")
                continue

            return self._extract_query_type(answer_llm), None, None

        self.logger.error("All attempts failed")
        return self.DEFAULT_QUERY_TYPE, None, None

    def _request_answer(self, query: str) -> Optional[str]:
        """Send one chat-completions request and return the first choice's content."""
        response = self.client.chat.completions.create(
            model=self.config.model,
            messages=[
                {"role": "system", "content": self.prompt},
                {"role": "user", "content": query}
            ],
            temperature=self.config.temperature,
            top_p=self.config.top_p,
            frequency_penalty=self.config.frequency_penalty,
            presence_penalty=self.config.presence_penalty,
            seed=self.config.seed
        )
        return response.choices[0].message.content

    def _extract_query_type(self, answer_llm: str) -> str:
        """Return the first tag between the answer markers, or ``DEFAULT_QUERY_TYPE``."""
        # Drop everything up to and including the first start marker, if present.
        _, start_marker, after_start = answer_llm.partition(self._ANSWER_START_MARKER)
        if start_marker:
            answer_llm = after_start

        # Keep the text up to and including the first end marker, if present.
        before_end, end_marker, _ = answer_llm.partition(self._ANSWER_END_MARKER)
        if end_marker:
            answer_llm = before_end + end_marker

        # Extract query type
        tag = re.search(self.pattern, answer_llm)
        return tag.group(0) if tag else self.DEFAULT_QUERY_TYPE