Spaces:
Sleeping
Sleeping
import os | |
import re | |
from logging import Logger | |
from typing import Dict, List, Optional, Tuple | |
from openai import OpenAI | |
from common.configuration import LLMConfiguration | |
class QueryClassification:
    """Classify user queries with an OpenAI-compatible chat model.

    The configured prompt is sent as the system message and the user query as
    the user message. The reply is trimmed by two textual markers and the first
    bracketed number in it (e.g. ``'[1]'``) is returned as the query type.
    ``DEFAULT_QUERY_TYPE`` is returned when no client is configured, when the
    reply contains no such label, or when every attempt fails.
    """

    # Fallback label used whenever no label can be obtained from the model.
    DEFAULT_QUERY_TYPE = '[3]'
    # Default number of LLM calls made before giving up (overridable per instance).
    max_attempts = 5
    # Reply markers: text before the first start marker is dropped, and text
    # after the first end marker is dropped (the end marker itself is kept).
    _START_MARKER = '%%'
    _END_MARKER = 'Конец ответа'

    def __init__(self, config: LLMConfiguration, prompt: str, logger: Logger,
                 *, max_attempts: int = 5):
        """Store settings and create the OpenAI client if a base URL is configured.

        Args:
            config: LLM settings. This class reads ``base_url``, ``api_key_env``,
                ``model``, ``temperature``, ``top_p``, ``frequency_penalty``,
                ``presence_penalty`` and ``seed`` from it.
            prompt: System prompt sent with every classification request.
            logger: Logger used for progress and error messages.
            max_attempts: How many times to call the LLM before falling back
                to ``DEFAULT_QUERY_TYPE``. Must be at least 1.

        Raises:
            ValueError: If ``max_attempts`` is less than 1.
        """
        if max_attempts < 1:
            raise ValueError(f"max_attempts must be >= 1, got {max_attempts}")
        self.config = config
        self.logger = logger
        self.prompt = prompt
        self.max_attempts = max_attempts
        # Query-type label format: a number in square brackets, e.g. "[2]".
        self.pattern = r'\[\d+\]'
        if self.config.base_url is not None:
            # The API key is looked up in the environment variable whose *name*
            # is stored in the config; os.getenv yields None if it is unset.
            self.client = OpenAI(
                base_url=self.config.base_url,
                api_key=os.getenv(self.config.api_key_env),
            )
        else:
            # Without a base URL no LLM is used; classification always falls back.
            self.client = None

    def _extract_query_type(self, answer_llm: str) -> str:
        """Return the first query-type label in a model reply, or the default label.

        Args:
            answer_llm: Raw text content of the model reply.

        Returns:
            The first match of ``self.pattern`` in the trimmed reply, or
            ``DEFAULT_QUERY_TYPE`` if there is none.
        """
        # Keep only the text after the first start marker, when present.
        _, found_start, tail = answer_llm.partition(self._START_MARKER)
        if found_start:
            answer_llm = tail
        # Cut right after the first end marker, when present.
        end = answer_llm.find(self._END_MARKER)
        if end != -1:
            answer_llm = answer_llm[:end + len(self._END_MARKER)]
        match = re.search(self.pattern, answer_llm)
        return match.group(0) if match else self.DEFAULT_QUERY_TYPE

    def query_classification(self, query: str) -> Tuple[str, Optional[Dict], Optional[List]]:
        """Classify the query using the LLM, retrying on failure.

        Args:
            query: User query to classify.

        Returns:
            A tuple ``(query_type, None, None)``. ``query_type`` is a label such
            as ``'[1]'``, or ``DEFAULT_QUERY_TYPE`` when no client is configured,
            no label is found, or all ``max_attempts`` calls fail. The second
            and third elements are always ``None`` in this implementation.
        """
        self.logger.info('Query Classification')
        if self.client is None:
            return self.DEFAULT_QUERY_TYPE, None, None
        for attempt in range(1, self.max_attempts + 1):
            try:
                response = self.client.chat.completions.create(
                    model=self.config.model,
                    messages=[
                        {"role": "system", "content": self.prompt},
                        {"role": "user", "content": query},
                    ],
                    temperature=self.config.temperature,
                    top_p=self.config.top_p,
                    frequency_penalty=self.config.frequency_penalty,
                    presence_penalty=self.config.presence_penalty,
                    seed=self.config.seed,
                )
                answer_llm = response.choices[0].message.content
                self.logger.info(f'Answer LLM {answer_llm}')
                if answer_llm is None:
                    # The message content may be None; treat as a failed attempt
                    # instead of letting the regex helpers raise a TypeError.
                    self.logger.error(f"Attempt {attempt} failed: empty response content")
                    continue
                return self._extract_query_type(answer_llm), None, None
            except Exception as e:
                # Deliberate best-effort: any client/network error triggers a retry.
                self.logger.error(f"Attempt {attempt} failed: {str(e)}")
        self.logger.error("All attempts failed")
        return self.DEFAULT_QUERY_TYPE, None, None