|
import asyncio
import hashlib
import json
import logging
import os
import pickle
import re
from collections import deque
from dataclasses import *
from datetime import *
from pathlib import Path
from time import monotonic
from typing import *

import numpy as np
import pandas as pd
from haystack import *
from haystack.components.builders import PromptBuilder
from haystack.components.embedders import SentenceTransformersDocumentEmbedder
from haystack.components.generators.openai import OpenAIGenerator
from haystack.components.retrievers.in_memory import InMemoryEmbeddingRetriever
from haystack.document_stores.in_memory import InMemoryDocumentStore
from haystack.document_stores.types import DuplicatePolicy
from haystack.utils import Secret
from rouge_score import rouge_scorer
from sklearn.metrics.pairwise import cosine_similarity
from tenacity import retry, stop_after_attempt, wait_exponential
|
|
|
|
|
# Module-scoped logger used by every class below.
logger = logging.getLogger(__name__)

# NOTE(review): configuring the root logger at import time affects every
# importer of this module; consider moving this to the application entry point.
logging.basicConfig(level=logging.INFO)
|
|
|
class OpenAIDateParser:
    """Parse free-form Thai date strings into ISO dates with an OpenAI model.

    The model is prompted to answer with a JSON object holding ``start_date``,
    ``end_date`` and ``is_range``. The reply is validated and normalized before
    it is returned, so callers can rely on all three keys being present.
    """

    def __init__(self, api_key: str, model: str = "gpt-4o"):
        """
        Args:
            api_key: OpenAI API key (wrapped in a Haystack ``Secret``).
            model: OpenAI model name passed to ``OpenAIGenerator``.
        """
        self.generator = OpenAIGenerator(
            api_key=Secret.from_token(api_key),
            model=model
        )
        self.prompt_builder = PromptBuilder(
            template="""
Parse the following Thai date range into a structured format:
Date: {{date}}

Return in JSON format:
{
"start_date": "YYYY-MM-DD",
"end_date": "YYYY-MM-DD" (if range),
"is_range": true/false
}

Notes:
- Convert Buddhist Era (BE) to CE
- Handle abbreviated Thai months
- Account for date ranges with dashes
- Return null for end_date if it's a single date

Example inputs and outputs:
Input: "จ 8 ก.ค. – จ 19 ส.ค. 67"
Output: {"start_date": "2024-07-08", "end_date": "2024-08-19", "is_range": true}

Input: "15 มกราคม 2567"
Output: {"start_date": "2024-01-15", "end_date": null, "is_range": false}
"""
        )

    @staticmethod
    def _strip_code_fence(reply: str) -> str:
        """Return the payload of a Markdown code fence (e.g. ```json ... ```), or the stripped reply."""
        text = reply.strip()
        fenced = re.match(r"^```[A-Za-z]*\s*(.*?)\s*```$", text, re.DOTALL)
        return fenced.group(1) if fenced else text

    async def _generate(self, prompt: str) -> Dict[str, Any]:
        """Invoke the generator without blocking the event loop.

        Uses the component's native ``run_async`` when it has one; otherwise the
        synchronous ``run`` is executed in the loop's default executor.
        """
        run_async = getattr(self.generator, "run_async", None)
        if run_async is not None:
            return await run_async(prompt=prompt)
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, lambda: self.generator.run(prompt=prompt))

    async def parse_date(self, date_str: str) -> Dict[str, Optional[Union[str, bool]]]:
        """Parse a Thai date (or date range) string.

        Args:
            date_str: Raw date text, e.g. ``"15 มกราคม 2567"``.

        Returns:
            The model's JSON object with ``start_date`` (ISO string), ``end_date``
            (ISO string or ``None``) and ``is_range`` (bool) guaranteed present.

        Raises:
            ValueError: if the model reply is empty, not a JSON object, lacks a
                ``start_date``, or contains dates not in ``YYYY-MM-DD`` form.
        """
        try:
            prompt = self.prompt_builder.run(date=date_str)["prompt"]
            response = await self._generate(prompt)

            if not response or not response.get("replies"):
                raise ValueError("Empty response from OpenAI")

            # Chat models often wrap JSON in a ```json fence; json.loads rejects that.
            parsed = json.loads(self._strip_code_fence(response["replies"][0]))
            if not isinstance(parsed, dict):
                raise ValueError("Expected a JSON object")

            start_date = parsed.get("start_date")
            if not start_date:
                raise ValueError("Missing start_date")
            end_date = parsed.get("end_date") or None

            # Raises ValueError/TypeError for anything that is not an ISO date.
            for value in (start_date, end_date):
                if value is not None:
                    datetime.strptime(value, '%Y-%m-%d')

            # Guarantee the keys downstream code indexes directly.
            is_range = parsed.get("is_range")
            if not isinstance(is_range, bool):
                is_range = end_date is not None
            parsed["start_date"] = start_date
            parsed["end_date"] = end_date
            parsed["is_range"] = is_range
            return parsed

        except Exception as e:
            logger.error(f"OpenAI date parsing failed for '{date_str}': {str(e)}")
            raise ValueError(f"Could not parse date: {date_str}") from e
|
|
|
@dataclass
class ValidationResult:
    """Container for the outcome of a single validation pass."""

    # Overall verdict for the validated item.
    is_valid: bool
    # Blocking problems found during validation.
    errors: List[str]
    # Non-blocking issues worth reporting.
    warnings: List[str]
    # Cleaned-up field values keyed by field name.
    normalized_data: Dict[str, str]
|
|
|
class ThaiTextPreprocessor:
    """Normalize Thai text: character variants, whitespace and Thai digits."""

    # Literal substring replacements, applied in insertion order.
    CHAR_MAP = {
        'ํา': 'ำ',  # NIKHAHIT + SARA AA -> precomposed SARA AM
        '์': '',  # NOTE(review): deletes THANTHAKHAT outright, which alters spelling; confirm intended
        '–': '-',  # en dash -> hyphen
        '—': '-',  # em dash -> hyphen
        '٫': ',',  # Arabic decimal separator -> comma
    }

    # Decomposed SARA AM, optionally with a tone mark typed *between* NIKHAHIT and
    # SARA AA (common in PDF-extracted text, e.g. "นํ้า"). The plain CHAR_MAP
    # replacement cannot see through the tone mark, so this is handled first.
    _DECOMPOSED_SARA_AM = re.compile('\u0e4d([\u0e48-\u0e4b]?)\u0e32')

    # Thai digits ๐-๙ -> ASCII 0-9, applied in a single pass.
    _THAI_DIGITS = str.maketrans('๐๑๒๓๔๕๖๗๘๙', '0123456789')

    @classmethod
    def normalize_thai_text(cls, text: str) -> str:
        """Return *text* normalized for matching and display.

        Steps: recompose SARA AM (moving an interleaved tone mark before it),
        apply ``CHAR_MAP``, collapse whitespace runs to single spaces and trim
        the ends, then convert Thai digits to ASCII. Falsy input is returned
        unchanged.
        """
        if not text:
            return text

        text = cls._DECOMPOSED_SARA_AM.sub(lambda m: m.group(1) + '\u0e33', text)

        for old, new in cls.CHAR_MAP.items():
            text = text.replace(old, new)

        text = re.sub(r'\s+', ' ', text.strip())

        return text.translate(cls._THAI_DIGITS)
|
|
|
class CalendarEventValidator:
    """Validate a ``CalendarEvent`` and collect normalized values for its fields.

    Date strings are parsed with ``OpenAIDateParser``; free-text fields are
    cleaned with ``ThaiTextPreprocessor``.
    """

    # Semester names accepted without a warning.
    VALID_SEMESTERS = frozenset({'ภาคต้น', 'ภาคปลาย', 'ภาคฤดูร้อน'})
    # Event categories accepted without an error.
    VALID_EVENT_TYPES = frozenset({'registration', 'deadline', 'examination', 'academic', 'holiday'})
    # 24-hour clock time, "H:MM" or "HH:MM".
    TIME_PATTERN = re.compile(r'^([01]?[0-9]|2[0-3]):([0-5][0-9])$')

    def __init__(self, openai_api_key: str):
        self.preprocessor = ThaiTextPreprocessor()
        self.date_parser = OpenAIDateParser(api_key=openai_api_key)

    async def validate_event(self, event: 'CalendarEvent') -> ValidationResult:
        """Validate *event* and return errors, warnings and normalized fields.

        Missing date/activity/semester and malformed date/time/event type are
        errors; a very short activity or an unknown semester name are warnings.
        For a date range, ``normalized_data['date']`` holds the start date and
        the end date is appended to the note.
        """
        errors: List[str] = []
        warnings: List[str] = []
        normalized_data: Dict[str, str] = {}

        # Normalize the note once so the date-range path and the plain path
        # store the same cleaned text.
        note = self.preprocessor.normalize_thai_text(event.note) if event.note else event.note

        # --- date -----------------------------------------------------------
        if event.date:
            try:
                parsed_date = await self.date_parser.parse_date(event.date)
            except ValueError:
                parsed_date = None
            # .get(): tolerate a parser result that lacks optional keys.
            start_date = parsed_date.get('start_date') if parsed_date else None
            if not start_date:
                errors.append(f"Invalid date format: {event.date}")
            else:
                normalized_data['date'] = start_date
                end_date = parsed_date.get('end_date')
                if parsed_date.get('is_range') and end_date:
                    range_note = f"ถึงวันที่ {end_date}"
                    normalized_data['note'] = f"{note}; {range_note}" if note else range_note
        else:
            errors.append("Date is required")

        # --- time (optional) -------------------------------------------------
        if event.time:
            time_value = event.time.strip()
            if time_value and not self.TIME_PATTERN.match(time_value):
                errors.append(f"Invalid time format: {event.time}")
            normalized_data['time'] = time_value

        # --- activity --------------------------------------------------------
        if event.activity:
            normalized_activity = self.preprocessor.normalize_thai_text(event.activity)
            if len(normalized_activity) < 3:
                warnings.append("Activity description is very short")
            normalized_data['activity'] = normalized_activity
        else:
            errors.append("Activity is required")

        # --- semester --------------------------------------------------------
        if event.semester:
            normalized_semester = self.preprocessor.normalize_thai_text(event.semester)
            if normalized_semester not in self.VALID_SEMESTERS:
                warnings.append(f"Unusual semester value: {event.semester}")
            normalized_data['semester'] = normalized_semester
        else:
            errors.append("Semester is required")

        # --- event type ------------------------------------------------------
        if event.event_type not in self.VALID_EVENT_TYPES:
            errors.append(f"Invalid event type: {event.event_type}")
        normalized_data['event_type'] = event.event_type

        # --- note (unless the date-range path already produced one) ----------
        if event.note and 'note' not in normalized_data:
            normalized_data['note'] = note

        # --- section (optional) ----------------------------------------------
        if event.section:
            normalized_data['section'] = self.preprocessor.normalize_thai_text(event.section)

        return ValidationResult(
            is_valid=len(errors) == 0,
            errors=errors,
            warnings=warnings,
            normalized_data=normalized_data
        )
|
|
|
|
|
@dataclass
class CalendarEvent:
    """One academic-calendar entry, with helpers to classify, validate and index it."""

    # Keyword lists per event type, checked in insertion order: the first type
    # with a matching keyword wins and 'academic' is the fallback.
    _EVENT_KEYWORDS: ClassVar[Dict[str, List[str]]] = {
        'registration': ['ลงทะเบียน', 'ชําระเงิน', 'ค่าธรรมเนียม', 'เปิดเรียน'],
        'deadline': ['วันสุดท้าย', 'กําหนด', 'ภายใน', 'ต้องส่ง'],
        'examination': ['สอบ', 'ปริญญานิพนธ์', 'วิทยานิพนธ์', 'สอบปากเปล่า'],
        'holiday': ['วันหยุด', 'ชดเชย', 'เทศกาล'],
    }

    date: str
    time: str
    activity: str
    note: str
    semester: str
    event_type: str
    section: Optional[str] = None

    @staticmethod
    def _fold_sara_am(text: str) -> str:
        """Rewrite decomposed SARA AM (NIKHAHIT + SARA AA) as the precomposed character."""
        return text.replace('\u0e4d\u0e32', '\u0e33')

    @staticmethod
    def classify_event_type(activity: str) -> str:
        """Classify an activity description into an event type by keyword.

        Both the activity and the keywords are folded to the same SARA AM
        encoding, so the decomposed and precomposed spellings match each other.

        Returns:
            One of 'registration', 'deadline', 'examination', 'holiday', or
            'academic' when no keyword matches (or *activity* is empty/None).
        """
        if not activity:
            return 'academic'
        folded = CalendarEvent._fold_sara_am(activity.lower())
        for event_type, terms in CalendarEvent._EVENT_KEYWORDS.items():
            if any(CalendarEvent._fold_sara_am(term) in folded for term in terms):
                return event_type
        return 'academic'

    async def initialize(self, openai_api_key: str,
                         validator: Optional['CalendarEventValidator'] = None):
        """Validate this event and overwrite its fields with normalized values.

        Args:
            openai_api_key: Used to build a validator when *validator* is None.
            validator: Optional validator to reuse across many events, which
                avoids constructing a new OpenAI client per event.

        Raises:
            ValueError: if validation reports any errors.
        """
        if validator is None:
            validator = CalendarEventValidator(openai_api_key)
        result = await validator.validate_event(self)

        if not result.is_valid:
            raise ValueError(f"Invalid calendar event: {', '.join(result.errors)}")

        for name, value in result.normalized_data.items():
            setattr(self, name, value)

        if result.warnings:
            logger.warning(f"Calendar event warnings: {', '.join(result.warnings)}")

    def to_searchable_text(self) -> str:
        """Render the event as labelled Thai lines for embedding and prompting.

        Lines are joined explicitly, so source-code indentation never leaks
        into the embedded text.
        """
        lines = [
            f"ภาคการศึกษา: {self.semester}",
            f"ประเภท: {self.event_type}",
            f"วันที่: {self.date}",
            f"เวลา: {self.time}",
            f"กิจกรรม: {self.activity}",
            f"หมวดหมู่: {self.section or '-'}",
            f"หมายเหตุ: {self.note}",
        ]
        return "\n".join(lines).strip()
|
|
|
class CacheManager:
    """Disk-backed TTL caches for embeddings, query results and documents.

    Each cache is a dict ``key -> (value, stored_at: datetime)`` persisted as
    ``<cache_dir>/<type>_cache.pkl``. Entries older than ``ttl`` seconds are
    dropped on load and ignored on lookup.
    """

    def __init__(self, cache_dir: Path, ttl: int = 3600):
        """
        Args:
            cache_dir: Directory for cache files (``str`` or ``Path``); created on first save.
            ttl: Time-to-live in seconds for cache entries (default: 1 hour).
        """
        self.cache_dir = Path(cache_dir)
        self.ttl = ttl
        self.embeddings_cache = self._load_cache("embeddings")
        self.query_cache = self._load_cache("queries")
        self.document_cache = self._load_cache("documents")

    def _generate_key(self, data: Union[str, Dict, Any]) -> str:
        """Return a stable hex digest for *data* (strings hashed directly, others via sorted JSON)."""
        if isinstance(data, str):
            content = data.encode('utf-8')
        else:
            content = json.dumps(data, sort_keys=True).encode('utf-8')
        # MD5 is used only as a cache key, not for security.
        return hashlib.md5(content).hexdigest()

    def _cache_path(self, cache_type: str) -> Path:
        """Path of the pickle file backing *cache_type*."""
        return self.cache_dir / f"{cache_type}_cache.pkl"

    def _load_cache(self, cache_type: str) -> Dict:
        """Load a cache from disk, dropping expired entries; ``{}`` if absent or unreadable."""
        cache_path = self._cache_path(cache_type)
        if not cache_path.exists():
            return {}
        try:
            # SECURITY: pickle.load can execute arbitrary code. Only point
            # cache_dir at a directory that untrusted users cannot write to.
            with open(cache_path, 'rb') as f:
                cache = pickle.load(f)
            self._clean_expired_entries(cache)
            return cache
        except Exception as e:
            logger.warning(f"Failed to load {cache_type} cache: {e}")
            return {}

    def _save_cache(self, cache_type: str, cache_data: Dict):
        """Persist a cache atomically: write a temp file, then rename it over the target.

        A crash mid-write therefore never leaves a truncated cache file behind.
        Failures are logged, not raised, because caching is best-effort.
        """
        cache_path = self._cache_path(cache_type)
        tmp_path = cache_path.with_name(cache_path.name + ".tmp")
        try:
            self.cache_dir.mkdir(parents=True, exist_ok=True)
            with open(tmp_path, 'wb') as f:
                pickle.dump(cache_data, f)
            os.replace(tmp_path, cache_path)
        except Exception as e:
            logger.error(f"Failed to save {cache_type} cache: {e}")
            try:
                tmp_path.unlink()
            except OSError:
                pass  # best-effort cleanup of a partial temp file

    def _clean_expired_entries(self, cache: Dict):
        """Remove, in place, entries older than ``ttl``."""
        now = datetime.now()
        max_age = timedelta(seconds=self.ttl)
        expired_keys = [key for key, (_, timestamp) in cache.items() if now - timestamp > max_age]
        for key in expired_keys:
            del cache[key]

    def _lookup(self, cache: Dict, key: str) -> Optional[Any]:
        """Return the value under *key* if present and not older than ``ttl``, else None."""
        entry = cache.get(key)
        if entry is None:
            return None
        value, timestamp = entry
        if datetime.now() - timestamp <= timedelta(seconds=self.ttl):
            return value
        return None

    def get_embedding_cache(self, text: str) -> Optional[Any]:
        """Get the cached embedding for *text*, or None."""
        return self._lookup(self.embeddings_cache, self._generate_key(text))

    def set_embedding_cache(self, text: str, embedding: Any):
        """Cache *embedding* for *text* and persist the embeddings cache."""
        self.embeddings_cache[self._generate_key(text)] = (embedding, datetime.now())
        self._save_cache("embeddings", self.embeddings_cache)

    def get_query_cache(self, query: str) -> Optional[Any]:
        """Get cached results for *query*, or None."""
        return self._lookup(self.query_cache, self._generate_key(query))

    def set_query_cache(self, query: str, result: Any):
        """Cache *result* for *query* and persist the query cache."""
        self.query_cache[self._generate_key(query)] = (result, datetime.now())
        self._save_cache("queries", self.query_cache)

    def get_document_cache(self, doc_id: str) -> Optional[Any]:
        """Get the cached document stored under *doc_id*, or None."""
        return self._lookup(self.document_cache, doc_id)

    def set_document_cache(self, doc_id: str, document: Any):
        """Cache *document* under *doc_id* and persist the document cache."""
        self.document_cache[doc_id] = (document, datetime.now())
        self._save_cache("documents", self.document_cache)

    def clear_cache(self, cache_type: Optional[str] = None):
        """Clear one cache ('embeddings', 'queries', 'documents') or, for any other value, all of them."""
        caches = {
            "embeddings": self.embeddings_cache,
            "queries": self.query_cache,
            "documents": self.document_cache,
        }
        selected = [cache_type] if cache_type in caches else list(caches)
        for name in selected:
            caches[name].clear()
            self._save_cache(name, caches[name])
|
|
|
@dataclass
class ModelConfig:
    """Settings for the OpenAI generator and the sentence-transformers embedder."""

    # Excluded from repr() so printing or logging a config never leaks the secret.
    openai_api_key: str = field(repr=False)
    embedder_model: str = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
    openai_model: str = "gpt-4o"
    # Sampling / length settings intended for the OpenAI completion call.
    temperature: float = 0.7
    max_tokens: int = 2000
    top_p: float = 0.95
    frequency_penalty: float = 0.0
    presence_penalty: float = 0.0
|
|
|
@dataclass
class RetrieverConfig:
    """Settings that control document retrieval.

    Attributes:
        top_k: Number of documents requested per search.
        similarity_threshold: Score threshold setting for retrieval.
        reranking_enabled: Whether a reranking step is requested.
        reranking_model: Model name for reranking, if any.
        filter_duplicates: Whether duplicate documents should be filtered.
        min_document_length: Minimum document length setting.
    """

    top_k: int = 5
    similarity_threshold: float = 0.7
    reranking_enabled: bool = False
    reranking_model: Optional[str] = None
    filter_duplicates: bool = True
    min_document_length: int = 10
|
|
|
@dataclass
class CacheConfig:
    """Settings for on-disk caching.

    ``cache_dir`` may be given as a ``str`` or any path-like; it is always
    stored as a ``pathlib.Path`` so ``.mkdir`` and ``/`` work downstream.
    """

    enabled: bool = True
    cache_dir: Path = field(default_factory=lambda: Path("./cache"))
    embeddings_cache_ttl: int = 86400
    query_cache_ttl: int = 3600
    max_cache_size: int = 1000
    cache_cleanup_interval: int = 3600

    def __post_init__(self):
        # Coerce str/PathLike (e.g. from a JSON config) into a Path.
        self.cache_dir = Path(self.cache_dir)
|
|
|
@dataclass
class ProcessingConfig:
    """Tuning knobs for data processing.

    Attributes:
        batch_size: Batch size setting.
        max_retries: Maximum retry count.
        timeout: Timeout setting.
        max_concurrent_requests: Concurrency limit for outbound requests.
        chunk_size: Text chunk size.
        chunk_overlap: Overlap between consecutive chunks.
        preprocessing_workers: Number of preprocessing workers.
    """

    batch_size: int = 32
    max_retries: int = 3
    timeout: int = 30
    max_concurrent_requests: int = 5
    chunk_size: int = 512
    chunk_overlap: int = 50
    preprocessing_workers: int = 4
|
|
|
@dataclass
class MonitoringConfig:
    """Switches and thresholds for monitoring and logging.

    Attributes:
        enable_monitoring: Master switch for monitoring.
        log_level: Logging level name.
        metrics_enabled: Whether metrics collection is on.
        trace_enabled: Whether tracing is on.
        performance_logging: Whether performance logging is on.
        slow_query_threshold: Threshold for flagging slow queries.
        health_check_interval: Interval between health checks.
    """

    enable_monitoring: bool = True
    log_level: str = "INFO"
    metrics_enabled: bool = True
    trace_enabled: bool = True
    performance_logging: bool = True
    slow_query_threshold: float = 5.0
    health_check_interval: int = 300
|
|
|
@dataclass
class LocalizationConfig:
    """Thai-language handling options.

    Attributes:
        thai_tokenizer_model: Tokenizer model name.
        enable_thai_normalization: Whether Thai normalization is on.
        remove_thai_tones: Whether tone marks should be removed.
        keep_english: Whether English text should be kept.
        custom_stopwords: Extra stopwords (fresh list per instance).
        custom_synonyms: Extra synonym map (fresh dict per instance).
    """

    thai_tokenizer_model: str = "thai-tokenizer"
    enable_thai_normalization: bool = True
    remove_thai_tones: bool = False
    keep_english: bool = True
    # default_factory avoids sharing one mutable container across instances.
    custom_stopwords: List[str] = field(default_factory=list)
    custom_synonyms: Dict[str, List[str]] = field(default_factory=dict)
|
|
|
@dataclass
class PipelineConfig:
    """Top-level configuration for the academic-calendar RAG pipeline.

    Groups the per-component sections plus a few pipeline-wide flags.
    ``__post_init__`` rejects an empty API key and creates the cache
    directory when caching is enabled.
    """

    model: ModelConfig
    retriever: RetrieverConfig = field(default_factory=RetrieverConfig)
    cache: CacheConfig = field(default_factory=CacheConfig)
    processing: ProcessingConfig = field(default_factory=ProcessingConfig)
    monitoring: MonitoringConfig = field(default_factory=MonitoringConfig)
    localization: LocalizationConfig = field(default_factory=LocalizationConfig)
    rate_limit_enabled: bool = True
    requests_per_minute: int = 60
    debug_mode: bool = False
    development_mode: bool = False

    # Top-level scalar settings serialized directly by to_dict()/from_dict().
    _SCALAR_KEYS: ClassVar[Tuple[str, ...]] = (
        "rate_limit_enabled", "requests_per_minute", "debug_mode", "development_mode",
    )

    def __post_init__(self):
        """Validate configuration and create necessary directories."""
        if not self.model.openai_api_key:
            raise ValueError("OpenAI API key is required")

        if self.cache.enabled:
            # Path() also accepts a plain-string cache_dir.
            Path(self.cache.cache_dir).mkdir(parents=True, exist_ok=True)

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the configuration to plain, JSON-compatible values.

        Every section is included. The OpenAI API key is deliberately omitted;
        pass it to ``from_dict(..., openai_api_key=...)`` when rebuilding.
        """
        model = asdict(self.model)
        model.pop("openai_api_key", None)
        cache = asdict(self.cache)
        cache["cache_dir"] = str(self.cache.cache_dir)

        data: Dict[str, Any] = {
            "model_config": model,
            "retriever_config": asdict(self.retriever),
            "cache_config": cache,
            "processing_config": asdict(self.processing),
            "monitoring_config": asdict(self.monitoring),
            "localization_config": asdict(self.localization),
        }
        data.update({key: getattr(self, key) for key in self._SCALAR_KEYS})
        return data

    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any],
                  openai_api_key: Optional[str] = None) -> 'PipelineConfig':
        """Create a configuration from a dictionary (e.g. the output of ``to_dict``).

        Args:
            config_dict: Mapping with optional ``*_config`` sections and
                top-level scalar flags; missing entries fall back to defaults.
            openai_api_key: Overrides/supplies ``model_config['openai_api_key']``,
                which ``to_dict`` never emits.

        Raises:
            TypeError: if no API key is available from either source.
            ValueError: if the resulting API key is empty.
        """
        model_kwargs = dict(config_dict.get("model_config", {}))
        if openai_api_key is not None:
            model_kwargs["openai_api_key"] = openai_api_key

        cache_kwargs = dict(config_dict.get("cache_config", {}))
        if "cache_dir" in cache_kwargs:
            cache_kwargs["cache_dir"] = Path(cache_kwargs["cache_dir"])

        scalars = {key: config_dict[key] for key in cls._SCALAR_KEYS if key in config_dict}

        return cls(
            model=ModelConfig(**model_kwargs),
            retriever=RetrieverConfig(**config_dict.get("retriever_config", {})),
            cache=CacheConfig(**cache_kwargs),
            processing=ProcessingConfig(**config_dict.get("processing_config", {})),
            monitoring=MonitoringConfig(**config_dict.get("monitoring_config", {})),
            localization=LocalizationConfig(**config_dict.get("localization_config", {})),
            **scalars,
        )
|
|
|
def create_default_config(api_key: str) -> PipelineConfig:
    """Build a PipelineConfig with default settings everywhere except the API key."""
    model = ModelConfig(
        openai_api_key=api_key,
        embedder_model="sentence-transformers/paraphrase-multilingual-mpnet-base-v2",
    )
    sections = dict(
        retriever=RetrieverConfig(),
        cache=CacheConfig(),
        processing=ProcessingConfig(),
        monitoring=MonitoringConfig(),
        localization=LocalizationConfig(),
    )
    return PipelineConfig(model=model, **sections)
|
|
|
class CalendarDataProcessor:
    """Convert raw academic-calendar JSON into ``CalendarEvent`` objects."""

    # Per-semester date columns that a section detail row may carry, in output order.
    SEMESTER_DATE_KEYS = ('ภาคต้น', 'ภาคปลาย', 'ภาคฤดูร้อน')

    @staticmethod
    def parse_calendar_json(json_data: List[Dict]) -> List[CalendarEvent]:
        """Flatten calendar JSON into a list of events.

        Shape inferred from the keys read here: a list of semester objects, each
        with an ``'education'`` label and an optional ``'schedule'`` list. A
        schedule entry is either a section (``'section'`` + ``'details'``) or a
        plain event (``'date'``, ``'time'``, ``'activity'``, ``'note'``).
        """
        events: List[CalendarEvent] = []
        for semester_data in json_data:
            semester = semester_data['education']
            for entry in semester_data.get('schedule', []):
                if 'section' in entry and 'details' in entry:
                    events.extend(CalendarDataProcessor._section_events(entry, semester))
                else:
                    events.append(CalendarDataProcessor._plain_event(entry, semester))
        return events

    @staticmethod
    def _section_events(entry: Dict, semester: str) -> List[CalendarEvent]:
        """Build 'deadline' events for each detail row of a section entry."""
        section = entry['section']
        events: List[CalendarEvent] = []
        for detail in entry['details']:
            title = detail.get('title', '')
            # Emit one event per semester column the row actually has, so rows
            # carrying only one column (or the summer column) keep their date.
            present = [key for key in CalendarDataProcessor.SEMESTER_DATE_KEYS if key in detail]
            if present:
                for sem in present:
                    events.append(CalendarEvent(
                        date=detail.get(sem, ''),
                        time='',
                        activity=title,
                        note=section,
                        semester=sem,
                        event_type='deadline',
                        section=section
                    ))
            else:
                events.append(CalendarEvent(
                    date=detail.get('date', ''),
                    time='',
                    activity=title,
                    note=section,
                    semester=semester,
                    event_type='deadline',
                    section=section
                ))
        return events

    @staticmethod
    def _plain_event(entry: Dict, semester: str) -> CalendarEvent:
        """Build one event from a plain schedule entry, classifying it by its activity text."""
        activity = entry.get('activity', '')
        return CalendarEvent(
            date=entry.get('date', ''),
            time=entry.get('time', ''),
            activity=activity,
            note=entry.get('note', ''),
            semester=semester,
            event_type=CalendarEvent.classify_event_type(activity)
        )
|
|
|
|
|
class EnhancedDocumentStore:
    """In-memory Haystack document store for calendar events, with embedding and query caching."""

    def __init__(self, config: PipelineConfig):
        self.store = InMemoryDocumentStore()
        self.embedder = SentenceTransformersDocumentEmbedder(
            model=config.model.embedder_model
        )
        # NOTE(review): query results share this manager, so they expire after
        # embeddings_cache_ttl; config.cache.query_cache_ttl is not applied here.
        self.cache_manager = CacheManager(
            cache_dir=config.cache.cache_dir,
            ttl=config.cache.embeddings_cache_ttl
        )

        self.embedder.warm_up()

        # One retriever bound to the store; top_k and filters are passed per call.
        self.retriever = InMemoryEmbeddingRetriever(document_store=self.store)

        self.events = []
        self.event_type_index = {}
        self.semester_index = {}

        # Rolling digest of indexed document ids. It is part of every query-cache
        # key, so persisted results are never reused after the corpus changes.
        self._corpus_fingerprint = ""

    def _compute_embedding(self, text: str) -> Any:
        """Return the embedding of *text*, using the embedding cache when possible."""
        cached_embedding = self.cache_manager.get_embedding_cache(text)
        if cached_embedding is not None:
            return cached_embedding

        doc = Document(content=text)
        embedding = self.embedder.run(documents=[doc])["documents"][0].embedding
        self.cache_manager.set_embedding_cache(text, embedding)
        return embedding

    def _embed_texts(self, texts: List[str]) -> List[Any]:
        """Embed many texts, sending all cache misses to the embedder in one batch."""
        embeddings = [self.cache_manager.get_embedding_cache(text) for text in texts]
        missing = [i for i, embedding in enumerate(embeddings) if embedding is None]
        if missing:
            embedded = self.embedder.run(
                documents=[Document(content=texts[i]) for i in missing]
            )["documents"]
            for i, doc in zip(missing, embedded):
                embeddings[i] = doc.embedding
                self.cache_manager.set_embedding_cache(texts[i], doc.embedding)
        return embeddings

    def add_events(self, events: List[CalendarEvent]):
        """Index *events*: update the in-memory indices, embed them, and write them to the store.

        Documents whose id already exists in the store are skipped instead of
        raising, so identical events can be loaded without aborting the batch.
        """
        if not events:
            return

        texts = [event.to_searchable_text() for event in events]
        embeddings = self._embed_texts(texts)

        documents = []
        for event, text, embedding in zip(events, texts, embeddings):
            self.events.append(event)
            event_idx = len(self.events) - 1
            self.event_type_index.setdefault(event.event_type, []).append(event_idx)
            self.semester_index.setdefault(event.semester, []).append(event_idx)

            doc = Document(
                content=text,
                embedding=embedding,
                meta={
                    'event_type': event.event_type,
                    'semester': event.semester,
                    'date': event.date
                }
            )
            documents.append(doc)
            self.cache_manager.set_document_cache(str(event_idx), doc)
            self._corpus_fingerprint = hashlib.md5(
                f"{self._corpus_fingerprint}\n{doc.id}".encode('utf-8')
            ).hexdigest()

        self.store.write_documents(documents, policy=DuplicatePolicy.SKIP)

    @staticmethod
    def _build_filters(event_type: Optional[str], semester: Optional[str]) -> Optional[Dict[str, Any]]:
        """Build a Haystack metadata filter for the given (truthy) constraints, or None."""
        conditions = []
        if event_type:
            conditions.append({"field": "meta.event_type", "operator": "==", "value": event_type})
        if semester:
            conditions.append({"field": "meta.semester", "operator": "==", "value": semester})
        if not conditions:
            return None
        if len(conditions) == 1:
            return conditions[0]
        return {"operator": "AND", "conditions": conditions}

    def search(self,
               query: str,
               event_type: Optional[str] = None,
               semester: Optional[str] = None,
               top_k: int = 5) -> List[Document]:
        """Return up to *top_k* documents most similar to *query*.

        *event_type* and *semester* are applied as metadata filters inside the
        store before ranking, so matching documents are never crowded out by
        non-matching ones. Results are cached per query/filters/top_k/corpus.
        """
        cache_key = json.dumps({
            'query': query,
            'event_type': event_type,
            'semester': semester,
            'top_k': top_k,
            'corpus': self._corpus_fingerprint,
        })
        cached_results = self.cache_manager.get_query_cache(cache_key)
        if cached_results is not None:
            return cached_results

        query_embedding = self._compute_embedding(query)
        results = self.retriever.run(
            query_embedding=query_embedding,
            filters=self._build_filters(event_type, semester),
            top_k=top_k,
        )["documents"]

        self.cache_manager.set_query_cache(cache_key, results)
        return results
|
|
|
class AdvancedQueryProcessor:
    """Use an LLM to extract search hints (event type, semester, key terms) from a Thai query."""

    # Event types the document store indexes; any other value becomes None (no filter).
    VALID_EVENT_TYPES = frozenset({'registration', 'deadline', 'examination', 'academic', 'holiday'})
    REQUIRED_FIELDS = ("event_type", "semester", "key_terms", "response_format")

    def __init__(self, config: PipelineConfig):
        model_cfg = config.model
        self.generator = OpenAIGenerator(
            api_key=Secret.from_token(model_cfg.openai_api_key),
            model=model_cfg.openai_model,
            # Apply the sampling settings from ModelConfig instead of silently ignoring them.
            generation_kwargs={
                "temperature": model_cfg.temperature,
                "max_tokens": model_cfg.max_tokens,
                "top_p": model_cfg.top_p,
                "frequency_penalty": model_cfg.frequency_penalty,
                "presence_penalty": model_cfg.presence_penalty,
            },
        )
        self.prompt_builder = PromptBuilder(
            template="""
Analyze this academic calendar query (in Thai):
Query: {{query}}

Determine:
1. The type of information being requested
2. Any specific semester mentioned
3. Key terms to look for

Return as JSON:
{
"event_type": "registration|deadline|examination|academic|holiday",
"semester": "term mentioned or null",
"key_terms": ["up to 3 most important terms"],
"response_format": "list|single|detailed"
}
""")

    @staticmethod
    def _strip_code_fence(reply: str) -> str:
        """Return the payload of a Markdown code fence (e.g. ```json ... ```), or the stripped reply."""
        text = reply.strip()
        fenced = re.match(r"^```[A-Za-z]*\s*(.*?)\s*```$", text, re.DOTALL)
        return fenced.group(1) if fenced else text

    def process_query(self, query: str) -> Dict[str, Any]:
        """Analyze *query* with the LLM.

        Returns:
            ``{"original_query": query, **analysis}``, where ``analysis`` has the
            keys in ``REQUIRED_FIELDS``. ``event_type`` is None unless it is one
            of ``VALID_EVENT_TYPES``. Any failure yields
            ``_get_default_analysis(query)``.
        """
        try:
            result = self.prompt_builder.run(query=query)
            response = self.generator.run(prompt=result["prompt"])

            if not response or not response.get("replies") or not response["replies"][0]:
                logger.warning("Received empty response from generator")
                return self._get_default_analysis(query)

            try:
                # Models frequently wrap JSON in a ```json fence.
                analysis = json.loads(self._strip_code_fence(response["replies"][0]))
            except json.JSONDecodeError as je:
                logger.error(f"JSON parsing failed: {str(je)}")
                return self._get_default_analysis(query)

            if not isinstance(analysis, dict):
                logger.warning("Analysis is not a JSON object")
                return self._get_default_analysis(query)

            for required in self.REQUIRED_FIELDS:
                if required not in analysis:
                    logger.warning(f"Missing required field: {required}")
                    return self._get_default_analysis(query)

            # An unknown type (e.g. the literal "registration|deadline") would
            # filter out every document downstream; treat it as "no filter".
            event_type = analysis.get("event_type")
            if not (isinstance(event_type, str) and event_type in self.VALID_EVENT_TYPES):
                analysis["event_type"] = None

            return {
                "original_query": query,
                **analysis
            }

        except Exception as e:
            logger.error(f"Query processing failed: {str(e)}")
            return self._get_default_analysis(query)

    def _get_default_analysis(self, query: str) -> Dict[str, Any]:
        """Return default analysis when processing fails."""
        logger.info("Returning default analysis")
        return {
            "original_query": query,
            "event_type": None,
            "semester": None,
            "key_terms": [],
            "response_format": "detailed"
        }
|
|
|
@dataclass
class RateLimitConfig:
    """Rate-limiting and retry settings for outbound API calls.

    Attributes:
        requests_per_minute: Size of the sliding one-minute request window.
        max_retries: Maximum retry count.
        base_delay: Base back-off delay.
        max_delay: Upper bound for back-off delay.
        timeout: Per-request timeout.
        concurrent_requests: Maximum simultaneous in-flight requests.
    """

    requests_per_minute: int = 60
    max_retries: int = 3
    base_delay: float = 1.0
    max_delay: float = 60.0
    timeout: float = 30.0
    concurrent_requests: int = 5
|
|
|
class APIError(Exception):
    """Base exception for errors raised around external API calls.

    Args:
        message: Human-readable description; becomes ``str(exc)``.
        status_code: Optional HTTP-style status code.
        response: Optional raw response payload.
    """

    def __init__(self, message: str, status_code: Optional[int] = None, response: Optional[Dict] = None):
        super().__init__(message)
        # Extra context kept for handlers and logging.
        self.response = response
        self.status_code = status_code
|
|
|
class RateLimitExceededError(APIError):
    """Signals that a call was rejected because a rate limit was hit."""
|
|
|
class OpenAIRateLimiter:
    """Sliding-window rate limiter plus retry/timeout wrapper for async API calls.

    At most ``requests_per_minute`` calls are admitted in any 60-second window,
    and at most ``concurrent_requests`` run at the same time.
    """

    # Length of the rate-limit window in seconds.
    WINDOW_SECONDS = 60.0

    def __init__(self, config: RateLimitConfig):
        self.config = config
        # monotonic() timestamps of admitted requests, oldest first.
        self.requests = deque(maxlen=config.requests_per_minute)
        self.semaphore = asyncio.Semaphore(config.concurrent_requests)
        self.total_requests = 0
        self.errors = deque(maxlen=1000)
        self.start_time = datetime.now()

    def _prune(self, now: float) -> None:
        """Drop timestamps that have left the window ending at *now*."""
        cutoff = now - self.WINDOW_SECONDS
        while self.requests and self.requests[0] <= cutoff:
            self.requests.popleft()

    async def acquire(self):
        """Wait until a request may be sent within the limit, then record it.

        The limit is re-checked after every sleep, and the timestamp recorded
        is the actual admission time. No await happens between the check and
        the append, so concurrent coroutines cannot both take the last slot.
        """
        while True:
            now = monotonic()
            self._prune(now)
            if len(self.requests) < self.config.requests_per_minute:
                self.requests.append(now)
                self.total_requests += 1
                return
            wait_time = self.WINDOW_SECONDS - (now - self.requests[0])
            logger.warning(f"Rate limit reached. Waiting {wait_time:.2f} seconds")
            await asyncio.sleep(max(wait_time, 0.0))

    def get_usage_stats(self) -> Dict[str, Any]:
        """Return usage counters; ``current_rpm`` counts only requests inside the current window."""
        cutoff = monotonic() - self.WINDOW_SECONDS
        return {
            "total_requests": self.total_requests,
            "current_rpm": sum(1 for ts in self.requests if ts > cutoff),
            "uptime": (datetime.now() - self.start_time).total_seconds(),
            "error_rate": len(self.errors) / self.total_requests if self.total_requests > 0 else 0
        }

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=4, max=60),
        reraise=True
    )
    async def execute_with_retry(self, func, *args, **kwargs):
        """Run ``await func(*args, **kwargs)`` under the rate limit, a timeout and retries.

        ``config.timeout`` (seconds, when > 0) is enforced with ``asyncio.wait_for``.

        Raises:
            RateLimitExceededError: re-raised after a ``base_delay`` pause.
            APIError: when the call times out.
            Exception: any other error from *func*, after the final attempt.
        """
        timeout = self.config.timeout if self.config.timeout and self.config.timeout > 0 else None
        try:
            async with self.semaphore:
                await self.acquire()
                return await asyncio.wait_for(func(*args, **kwargs), timeout=timeout)

        except Exception as e:
            self.errors.append({
                "timestamp": datetime.now(),
                "error_type": type(e).__name__,
                "message": str(e)
            })

            if isinstance(e, RateLimitExceededError):
                logger.warning("Rate limit exceeded, backing off...")
                await asyncio.sleep(self.config.base_delay)
                raise

            # str(asyncio.TimeoutError()) is empty, so check the type as well as the text.
            if isinstance(e, asyncio.TimeoutError) or "timeout" in str(e).lower():
                logger.error(f"Timeout error: {str(e)}")
                raise APIError(f"Request timed out after {self.config.timeout} seconds") from e

            logger.error(f"API error: {str(e)}")
            raise
|
|
|
class ResponseGenerator:
    """Generate a Thai answer from retrieved calendar documents with an OpenAI model."""

    # Returned whenever generation fails or yields nothing.
    _FALLBACK_ANSWER = "ขออภัย ไม่สามารถประมวลผลคำตอบได้ในขณะนี้"

    def __init__(self, config: PipelineConfig):
        model_cfg = config.model
        self.generator = OpenAIGenerator(
            api_key=Secret.from_token(model_cfg.openai_api_key),
            model=model_cfg.openai_model,
            # Apply the sampling settings from ModelConfig instead of silently ignoring them.
            generation_kwargs={
                "temperature": model_cfg.temperature,
                "max_tokens": model_cfg.max_tokens,
                "top_p": model_cfg.top_p,
                "frequency_penalty": model_cfg.frequency_penalty,
                "presence_penalty": model_cfg.presence_penalty,
            },
        )
        self.prompt_builder = PromptBuilder(
            template="""
You are a helpful academic advisor. Answer the following query using the provided calendar information.

Query: {{query}}

Relevant Calendar Information:
{% for doc in context %}
---
{{doc.content}}
{% endfor %}

Format: {{format}}

Guidelines:
1. Answer in Thai language
2. Be specific about dates and requirements
3. Include relevant notes or conditions
4. Format the response according to the specified format

Provide your response:
""")

    def generate_response(self,
                          query: str,
                          documents: List[Document],
                          query_info: Dict[str, Any]) -> str:
        """Generate an answer to *query* grounded in *documents*.

        Args:
            query: The user's question.
            documents: Retrieved documents; their ``content`` is placed in the prompt.
            query_info: Analysis dict; ``response_format`` defaults to "detailed"
                when missing or empty.

        Returns:
            The model's first reply, or a Thai apology message on any failure.
        """
        try:
            response_format = (query_info or {}).get("response_format") or "detailed"
            result = self.prompt_builder.run(
                query=query,
                context=documents,
                format=response_format
            )

            response = self.generator.run(prompt=result["prompt"])
            replies = response.get("replies") if response else None
            if not replies or not replies[0]:
                logger.warning("Received empty response from generator")
                return self._FALLBACK_ANSWER
            return replies[0]

        except Exception as e:
            logger.error(f"Response generation failed: {str(e)}")
            return self._FALLBACK_ANSWER
|
|
|
class AcademicCalendarRAG:
    """End-to-end pipeline: analyze a Thai query, retrieve calendar events, generate an answer."""

    def __init__(self, config: PipelineConfig):
        self.config = config
        self.document_store = EnhancedDocumentStore(config)
        self.query_processor = AdvancedQueryProcessor(config)
        self.response_generator = ResponseGenerator(config)

    def load_data(self, json_data: List[Dict]):
        """Parse calendar JSON and index the resulting events."""
        events = CalendarDataProcessor.parse_calendar_json(json_data)
        self.document_store.add_events(events)

    def process_query(self, query: str) -> Dict[str, Any]:
        """Answer *query*.

        Returns:
            ``{"answer": str, "documents": list, "query_info": dict}``. On
            failure, the answer is a Thai apology and the other values are empty.
        """
        try:
            query_info = self.query_processor.process_query(query)
            top_k = self.config.retriever.top_k
            event_type = query_info.get("event_type")
            semester = query_info.get("semester")

            documents = self.document_store.search(
                query=query,
                event_type=event_type,
                semester=semester,
                top_k=top_k
            )
            # LLM-derived filters may be wrong or phrased differently from the
            # indexed metadata; rather than answer with no context, retry unfiltered.
            if not documents and (event_type or semester):
                documents = self.document_store.search(
                    query=query,
                    event_type=None,
                    semester=None,
                    top_k=top_k
                )

            response = self.response_generator.generate_response(
                query=query,
                documents=documents,
                query_info=query_info
            )

            return {
                "answer": response,
                "documents": documents,
                "query_info": query_info
            }

        except Exception as e:
            # logger.exception also records the traceback.
            logger.exception(f"Query processing failed: {str(e)}")
            return {
                "answer": "ขออภัย ไม่สามารถประมวลผลคำถามได้ในขณะนี้",
                "documents": [],
                "query_info": {}
            }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|