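# calendar_rag.py
# The following lines are residue from the hosting site's file viewer and are
# kept only as provenance; they are not part of the module:
#   author avatar caption: "JirasakJo's picture"
#   last commit: "Update calendar_rag.py" (30fbf64, verified)
#   viewer links: raw / history / blame; file size: 38.8 kB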
# NOTE: several wildcard imports are used below, which makes import order
# significant (later bindings win). The original order is therefore preserved
# and new imports are appended at the end instead of regrouping the block.
from haystack import *
from haystack.components.generators.openai import OpenAIGenerator
from haystack.components.builders import PromptBuilder
from haystack.components.embedders import SentenceTransformersDocumentEmbedder
from haystack.components.retrievers.in_memory import InMemoryEmbeddingRetriever
from haystack.document_stores.in_memory import InMemoryDocumentStore
from haystack.utils import Secret
from tenacity import retry, stop_after_attempt, wait_exponential
from pathlib import Path
import hashlib
from datetime import *
from typing import *
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
from rouge_score import rouge_scorer
import pandas as pd
from dataclasses import *
import json
import logging
import os
import re
import pickle
import asyncio
from collections import deque
# Must stay after ``from datetime import *``: that wildcard binds the name
# ``time`` to the ``datetime.time`` class, which has no ``time()`` or
# ``monotonic()`` function. Importing the module afterwards rebinds the name.
import time
# Setup logging
# Importing this module configures the root logger at INFO level (a no-op if
# the root logger already has handlers) and creates the module-level logger
# used by every class below.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class OpenAIDateParser:
    """Parse free-form Thai date strings by prompting an OpenAI model.

    The model is asked to answer with a JSON object containing
    ``start_date``, ``end_date`` and ``is_range``. The reply is validated
    and normalized so callers can rely on all three keys being present.
    """

    def __init__(self, api_key: str, model: str = "gpt-4o"):
        """Create the generator and the prompt template.

        Args:
            api_key: OpenAI API key, wrapped in a Haystack ``Secret``.
            model: Name of the OpenAI model to query.
        """
        self.generator = OpenAIGenerator(
            api_key=Secret.from_token(api_key),
            model=model
        )
        # The template text is kept flush-left so the rendered prompt is
        # exactly the text below, without leading indentation.
        self.prompt_builder = PromptBuilder(
            template="""
Parse the following Thai date range into a structured format:
Date: {{date}}
Return in JSON format:
{
"start_date": "YYYY-MM-DD",
"end_date": "YYYY-MM-DD" (if range),
"is_range": true/false
}
Notes:
- Convert Buddhist Era (BE) to CE
- Handle abbreviated Thai months
- Account for date ranges with dashes
- Return null for end_date if it's a single date
Example inputs and outputs:
Input: "จ 8 ก.ค. – จ 19 ส.ค. 67"
Output: {"start_date": "2024-07-08", "end_date": "2024-08-19", "is_range": true}
Input: "15 มกราคม 2567"
Output: {"start_date": "2024-01-15", "end_date": null, "is_range": false}
"""
        )

    @staticmethod
    def _normalize_reply(reply: str) -> Dict[str, Union[str, bool, None]]:
        """Turn a raw model reply into a validated date dictionary.

        Args:
            reply: Text returned by the model; may be wrapped in a Markdown
                code fence (```` ```json ... ``` ````).

        Returns:
            The parsed object with ``start_date`` (``YYYY-MM-DD``),
            ``end_date`` (``YYYY-MM-DD`` or ``None``) and ``is_range``
            (``bool``) always present.

        Raises:
            ValueError: If the reply is not a JSON object, lacks a start
                date, or contains a date that is not ``YYYY-MM-DD``
                (``json.JSONDecodeError`` is a ``ValueError`` subclass).
        """
        text = reply.strip()
        fenced = re.match(r"^```(?:json)?\s*(.*?)\s*```$", text, re.DOTALL)
        if fenced:
            text = fenced.group(1)

        parsed = json.loads(text)
        if not isinstance(parsed, dict):
            raise ValueError("Expected a JSON object")

        start_date = parsed.get("start_date")
        if not start_date:
            raise ValueError("Missing start_date")
        end_date = parsed.get("end_date") or None

        # strptime raises ValueError on anything that is not YYYY-MM-DD.
        for value in (start_date, end_date):
            if value is not None:
                datetime.strptime(value, '%Y-%m-%d')

        # Only trust an explicit JSON boolean; otherwise infer from end_date.
        is_range = parsed.get("is_range")
        if not isinstance(is_range, bool):
            is_range = end_date is not None

        parsed["start_date"] = start_date
        parsed["end_date"] = end_date
        parsed["is_range"] = is_range
        return parsed

    async def parse_date(self, date_str: str) -> Dict[str, Union[str, bool]]:
        """Parse a complex Thai date string using OpenAI.

        Args:
            date_str: Date or date range as written in the calendar.

        Returns:
            Dictionary with ``start_date``, ``end_date`` and ``is_range``.

        Raises:
            ValueError: If the request fails or the reply cannot be parsed;
                the original exception is attached as ``__cause__``.
        """
        try:
            # Build prompt
            result = self.prompt_builder.run(date=date_str)
            # NOTE(review): this relies on the generator exposing an async
            # ``arun`` method -- confirm against the installed Haystack
            # version; if it is missing, the AttributeError is converted to
            # ValueError below and every date is reported as unparseable.
            response = await self.generator.arun(prompt=result["prompt"])
            if not response or not response.get("replies"):
                raise ValueError("Empty response from OpenAI")
            return self._normalize_reply(response["replies"][0])
        except Exception as e:
            logger.error(f"OpenAI date parsing failed for '{date_str}': {str(e)}")
            raise ValueError(f"Could not parse date: {date_str}") from e
@dataclass
class ValidationResult:
    """Outcome of one validation pass over a calendar event."""

    # Overall verdict for the event
    is_valid: bool
    # Messages describing problems found during validation
    errors: List[str]
    # Messages describing non-blocking observations
    warnings: List[str]
    # Cleaned-up field values keyed by field name
    normalized_data: Dict[str, str]
class ThaiTextPreprocessor:
    """Normalization helpers for Thai text."""

    # Substring replacements applied in insertion order
    CHAR_MAP = {
        'ํา': 'ำ',  # Two-character sara am sequence -> single character
        # NOTE(review): the original comment calls the next entry "yamakkan",
        # but the character appears to be thanthakhat -- confirm which mark
        # is meant to be stripped.
        '์': '',
        '–': '-',  # Dash variants -> ASCII hyphen
        '—': '-',
        '٫': ',',  # Alternative separator -> comma
    }

    @classmethod
    def normalize_thai_text(cls, text: str) -> str:
        """Return ``text`` with characters, whitespace and digits normalized.

        Falsy input (``None`` or ``""``) is returned unchanged. Otherwise the
        ``CHAR_MAP`` replacements are applied, runs of whitespace collapse to
        a single space (with the ends trimmed), and Thai numerals are mapped
        to ASCII digits.
        """
        if not text:
            return text

        normalized = text
        for source, target in cls.CHAR_MAP.items():
            normalized = normalized.replace(source, target)

        # Trim first, then collapse interior whitespace runs
        normalized = re.sub(r'\s+', ' ', normalized.strip())

        # One translation pass covers all ten Thai numerals
        digit_table = str.maketrans('๐๑๒๓๔๕๖๗๘๙', '0123456789')
        return normalized.translate(digit_table)
class CalendarEventValidator:
    """Validates and preprocesses calendar events."""

    # Semester names accepted without a warning
    VALID_SEMESTERS = frozenset({'ภาคต้น', 'ภาคปลาย', 'ภาคฤดูร้อน'})
    # Event types accepted without an error
    VALID_EVENT_TYPES = frozenset(
        {'registration', 'deadline', 'examination', 'academic', 'holiday'}
    )
    # H:MM or HH:MM on a 24-hour clock
    TIME_PATTERN = re.compile(r'^([01]?[0-9]|2[0-3]):([0-5][0-9])$')

    def __init__(self, openai_api_key: str):
        """Create the text preprocessor and the OpenAI-backed date parser."""
        self.preprocessor = ThaiTextPreprocessor()
        self.date_parser = OpenAIDateParser(api_key=openai_api_key)

    async def validate_event(self, event: 'CalendarEvent') -> ValidationResult:
        """Validate a calendar event and collect normalized field values.

        Args:
            event: The event to check; it is read but not modified.

        Returns:
            A ``ValidationResult`` whose ``is_valid`` is true only when no
            errors were recorded. ``normalized_data`` holds the cleaned
            values for the fields that were present.
        """
        errors: List[str] = []
        warnings: List[str] = []
        normalized_data: Dict[str, str] = {}

        # Validate and normalize date using OpenAI
        if event.date:
            try:
                parsed_date = await self.date_parser.parse_date(event.date)
            except ValueError:
                parsed_date = None

            # The parser's reply comes from a language model, so do not
            # assume the expected keys exist: a missing key used to raise
            # KeyError out of this method instead of producing an error.
            start_date = (
                parsed_date.get('start_date')
                if isinstance(parsed_date, dict) else None
            )
            if not start_date:
                errors.append(f"Invalid date format: {event.date}")
            else:
                normalized_data['date'] = start_date
                end_date = parsed_date.get('end_date')
                # If it's a date range, store the end date in the note
                if parsed_date.get('is_range') and end_date:
                    range_note = f"ถึงวันที่ {end_date}"
                    if event.note:
                        # Normalize here too, so a note is cleaned whether
                        # or not the date turned out to be a range.
                        note = self.preprocessor.normalize_thai_text(event.note)
                        normalized_data['note'] = f"{note}; {range_note}"
                    else:
                        normalized_data['note'] = range_note
        else:
            errors.append("Date is required")

        # Validate time format if provided
        if event.time:
            if not self.TIME_PATTERN.match(event.time):
                errors.append(f"Invalid time format: {event.time}")
            normalized_data['time'] = event.time

        # Validate and normalize activity
        if event.activity:
            normalized_activity = self.preprocessor.normalize_thai_text(event.activity)
            if len(normalized_activity) < 3:
                warnings.append("Activity description is very short")
            normalized_data['activity'] = normalized_activity
        else:
            errors.append("Activity is required")

        # Validate semester; an unknown value is only a warning
        if event.semester:
            normalized_semester = self.preprocessor.normalize_thai_text(event.semester)
            if normalized_semester not in self.VALID_SEMESTERS:
                warnings.append(f"Unusual semester value: {event.semester}")
            normalized_data['semester'] = normalized_semester
        else:
            errors.append("Semester is required")

        # Validate event type
        if event.event_type not in self.VALID_EVENT_TYPES:
            errors.append(f"Invalid event type: {event.event_type}")
        normalized_data['event_type'] = event.event_type

        # Normalize note if present and not already set by date range
        if event.note and 'note' not in normalized_data:
            normalized_data['note'] = self.preprocessor.normalize_thai_text(event.note)

        # Normalize section if present
        if event.section:
            normalized_data['section'] = self.preprocessor.normalize_thai_text(event.section)

        return ValidationResult(
            is_valid=not errors,
            errors=errors,
            warnings=warnings,
            normalized_data=normalized_data
        )
# Calendar event record; validation is opt-in through the async initialize()
@dataclass
class CalendarEvent:
    """Structured representation of a calendar event.

    Construction performs no validation. Call :meth:`initialize` to
    validate the event and replace its fields with normalized values.
    """

    @staticmethod
    def classify_event_type(activity: str) -> str:
        """Classify an event by keywords found in its activity text.

        Categories are checked in the order registration, deadline,
        examination, holiday; the first one with a matching keyword wins.
        Sara am is compared in its single-character form on both sides, so
        the result does not depend on how the input text encodes it.

        Args:
            activity: Activity description; ``None`` is treated as empty.

        Returns:
            The matching category name, or ``'academic'`` if none matches.
        """
        # nikhahit + sara aa (two characters) versus sara am (one character)
        decomposed, composed = '\u0e4d\u0e32', '\u0e33'
        activity_lower = (activity or '').lower().replace(decomposed, composed)
        keywords = {
            'registration': ['ลงทะเบียน', 'ชําระเงิน', 'ค่าธรรมเนียม', 'เปิดเรียน'],
            'deadline': ['วันสุดท้าย', 'กําหนด', 'ภายใน', 'ต้องส่ง'],
            'examination': ['สอบ', 'ปริญญานิพนธ์', 'วิทยานิพนธ์', 'สอบปากเปล่า'],
            'holiday': ['วันหยุด', 'ชดเชย', 'เทศกาล'],
        }
        for event_type, terms in keywords.items():
            if any(term.replace(decomposed, composed) in activity_lower
                   for term in terms):
                return event_type
        return 'academic'

    date: str
    time: str
    activity: str
    note: str
    semester: str
    event_type: str
    section: Optional[str] = None

    async def initialize(self, openai_api_key: str):
        """Asynchronously validate and normalize the event in place.

        Args:
            openai_api_key: Key handed to ``CalendarEventValidator``.

        Raises:
            ValueError: If validation reports any error.
        """
        validator = CalendarEventValidator(openai_api_key)
        result = await validator.validate_event(self)
        if not result.is_valid:
            raise ValueError(f"Invalid calendar event: {', '.join(result.errors)}")
        # Update with normalized data
        for name, value in result.normalized_data.items():
            setattr(self, name, value)
        # Log any warnings
        if result.warnings:
            logger.warning(f"Calendar event warnings: {', '.join(result.warnings)}")

    def to_searchable_text(self) -> str:
        """Render the event as labelled lines used for embedding and search.

        Leading and trailing whitespace of the whole text is stripped; a
        missing section is rendered as ``-``.
        """
        lines = [
            f"ภาคการศึกษา: {self.semester}",
            f"ประเภท: {self.event_type}",
            f"วันที่: {self.date}",
            f"เวลา: {self.time}",
            f"กิจกรรม: {self.activity}",
            f"หมวดหมู่: {self.section or '-'}",
            f"หมายเหตุ: {self.note}",
        ]
        return "\n".join(lines).strip()
class CacheManager:
    """Manages caching for different components of the RAG pipeline.

    Three independent caches (embeddings, queries, documents) are kept in
    memory as ``key -> (value, timestamp)`` dictionaries and mirrored to
    ``<cache_type>_cache.pkl`` files in ``cache_dir``. Every ``set_*`` call
    rewrites the whole file for that cache.
    """

    _CACHE_TYPES = ("embeddings", "queries", "documents")

    def __init__(self, cache_dir: Path, ttl: int = 3600):
        """
        Initialize CacheManager

        Args:
            cache_dir: Directory to store cache files; created if missing
            ttl: Time-to-live in seconds for cache entries (default: 1 hour)
        """
        self.cache_dir = Path(cache_dir)
        self.ttl = ttl
        # Without the directory every later save fails (and is only logged),
        # so nothing would ever be persisted.
        try:
            self.cache_dir.mkdir(parents=True, exist_ok=True)
        except OSError as e:
            logger.warning(f"Could not create cache directory {self.cache_dir}: {e}")
        self.embeddings_cache = self._load_cache("embeddings")
        self.query_cache = self._load_cache("queries")
        self.document_cache = self._load_cache("documents")

    def _generate_key(self, data: Union[str, Dict, Any]) -> str:
        """Generate a cache key: MD5 of the string, or of its sorted-key JSON."""
        if isinstance(data, str):
            content = data.encode('utf-8')
        else:
            content = json.dumps(data, sort_keys=True).encode('utf-8')
        return hashlib.md5(content).hexdigest()

    def _load_cache(self, cache_type: str) -> Dict:
        """Load one cache from disk, dropping expired entries.

        Returns an empty dict when the file is missing or unreadable.
        """
        cache_path = self.cache_dir / f"{cache_type}_cache.pkl"
        if not cache_path.exists():
            return {}
        try:
            # SECURITY: pickle executes arbitrary code while loading. Only
            # point cache_dir at a location that untrusted parties cannot
            # write to.
            with open(cache_path, 'rb') as f:
                cache = pickle.load(f)
            # Clean expired entries
            self._clean_expired_entries(cache)
            return cache
        except Exception as e:
            logger.warning(f"Failed to load {cache_type} cache: {e}")
            return {}

    def _save_cache(self, cache_type: str, cache_data: Dict):
        """Save one cache to disk; failures are logged, not raised."""
        cache_path = self.cache_dir / f"{cache_type}_cache.pkl"
        try:
            with open(cache_path, 'wb') as f:
                pickle.dump(cache_data, f)
        except Exception as e:
            logger.error(f"Failed to save {cache_type} cache: {e}")

    def _is_fresh(self, timestamp: datetime) -> bool:
        """Return True while an entry stored at ``timestamp`` is within the TTL."""
        return datetime.now() - timestamp <= timedelta(seconds=self.ttl)

    def _clean_expired_entries(self, cache: Dict):
        """Remove expired entries from ``cache`` in place."""
        expired_keys = [
            key for key, (_, timestamp) in cache.items()
            if not self._is_fresh(timestamp)
        ]
        for key in expired_keys:
            del cache[key]

    def _lookup(self, cache: Dict, key: str) -> Optional[Any]:
        """Return the unexpired value stored under ``key``, else None."""
        entry = cache.get(key)
        if entry is None:
            return None
        value, timestamp = entry
        return value if self._is_fresh(timestamp) else None

    def get_embedding_cache(self, text: str) -> Optional[Any]:
        """Get cached embedding for text, or None if absent or expired."""
        return self._lookup(self.embeddings_cache, self._generate_key(text))

    def set_embedding_cache(self, text: str, embedding: Any):
        """Cache embedding for text and persist the embeddings cache."""
        key = self._generate_key(text)
        self.embeddings_cache[key] = (embedding, datetime.now())
        self._save_cache("embeddings", self.embeddings_cache)

    def get_query_cache(self, query: str) -> Optional[Dict]:
        """Get cached query results, or None if absent or expired."""
        return self._lookup(self.query_cache, self._generate_key(query))

    def set_query_cache(self, query: str, result: Dict):
        """Cache query results and persist the query cache."""
        key = self._generate_key(query)
        self.query_cache[key] = (result, datetime.now())
        self._save_cache("queries", self.query_cache)

    def get_document_cache(self, doc_id: str) -> Optional[Any]:
        """Get cached document by its id (used as the key without hashing)."""
        return self._lookup(self.document_cache, doc_id)

    def set_document_cache(self, doc_id: str, document: Any):
        """Cache document under ``doc_id`` and persist the document cache."""
        self.document_cache[doc_id] = (document, datetime.now())
        self._save_cache("documents", self.document_cache)

    def clear_cache(self, cache_type: Optional[str] = None):
        """Clear one cache, or all caches when ``cache_type`` is None.

        Args:
            cache_type: ``"embeddings"``, ``"queries"``, ``"documents"`` or
                ``None`` for all three.

        Raises:
            ValueError: For any other value. Previously a misspelled name
                silently wiped every cache.
        """
        caches = {
            "embeddings": self.embeddings_cache,
            "queries": self.query_cache,
            "documents": self.document_cache,
        }
        if cache_type is None:
            selected = list(self._CACHE_TYPES)
        elif cache_type in caches:
            selected = [cache_type]
        else:
            raise ValueError(
                f"Unknown cache type: {cache_type!r}; "
                f"expected one of {', '.join(self._CACHE_TYPES)}"
            )
        for name in selected:
            caches[name].clear()
            self._save_cache(name, caches[name])
@dataclass
class ModelConfig:
    """Settings for the OpenAI model and the embedding model.

    In the code visible in this file, ``openai_api_key``, ``embedder_model``
    and ``openai_model`` are passed to the Haystack components; the
    remaining fields are stored but not forwarded to the generators.
    """

    # Required; no default
    openai_api_key: str
    # Sentence-transformers model name used for document embeddings
    embedder_model: str = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
    # OpenAI model name
    openai_model: str = "gpt-4o"
    # Generation parameters
    temperature: float = 0.7
    max_tokens: int = 2000
    top_p: float = 0.95
    frequency_penalty: float = 0.0
    presence_penalty: float = 0.0
@dataclass
class RetrieverConfig:
    """Settings for document retrieval; every field has a default."""

    # Number of documents requested per query
    top_k: int = 5
    similarity_threshold: float = 0.7
    # Reranking is off by default and no model is named
    reranking_enabled: bool = False
    reranking_model: Optional[str] = None
    filter_duplicates: bool = True
    min_document_length: int = 10
@dataclass
class CacheConfig:
    """Settings that control caching; every field has a default."""

    # Master switch for caching
    enabled: bool = True
    # Location of the cache files; a fresh Path is built per instance
    cache_dir: Path = field(default_factory=lambda: Path("./cache"))
    embeddings_cache_ttl: int = 86400  # seconds (24 hours)
    query_cache_ttl: int = 3600  # seconds (1 hour)
    max_cache_size: int = 1000  # entries
    cache_cleanup_interval: int = 3600  # seconds (1 hour)
@dataclass
class ProcessingConfig:
    """Settings for data processing; every field has a default.

    None of these values is read by the code visible in this file.
    """

    batch_size: int = 32
    max_retries: int = 3
    timeout: int = 30
    max_concurrent_requests: int = 5
    # Chunking parameters
    chunk_size: int = 512
    chunk_overlap: int = 50
    preprocessing_workers: int = 4
@dataclass
class MonitoringConfig:
    """Settings for monitoring and logging; every field has a default.

    None of these values is read by the code visible in this file.
    """

    enable_monitoring: bool = True
    log_level: str = "INFO"
    metrics_enabled: bool = True
    trace_enabled: bool = True
    performance_logging: bool = True
    slow_query_threshold: float = 5.0  # seconds
    health_check_interval: int = 300  # seconds (5 minutes)
@dataclass
class LocalizationConfig:
    """Settings for Thai language handling; every field has a default.

    None of these values is read by the code visible in this file.
    """

    thai_tokenizer_model: str = "thai-tokenizer"
    enable_thai_normalization: bool = True
    remove_thai_tones: bool = False
    keep_english: bool = True
    # Mutable defaults use factories so each instance gets its own container
    custom_stopwords: List[str] = field(default_factory=list)
    custom_synonyms: Dict[str, List[str]] = field(default_factory=dict)
@dataclass
class PipelineConfig:
    """Main configuration for the RAG pipeline.

    Groups the per-area configuration objects plus a few top-level
    switches. Creating an instance validates the API key and, when caching
    is enabled, creates the cache directory.
    """

    # Model configurations
    model: ModelConfig
    # Retriever settings
    retriever: RetrieverConfig = field(default_factory=RetrieverConfig)
    # Cache settings
    cache: CacheConfig = field(default_factory=CacheConfig)
    # Processing settings
    processing: ProcessingConfig = field(default_factory=ProcessingConfig)
    # Monitoring settings
    monitoring: MonitoringConfig = field(default_factory=MonitoringConfig)
    # Localization settings
    localization: LocalizationConfig = field(default_factory=LocalizationConfig)
    # Rate limiting
    rate_limit_enabled: bool = True
    requests_per_minute: int = 60
    # System settings
    debug_mode: bool = False
    development_mode: bool = False

    def __post_init__(self):
        """Validate configuration and create necessary directories.

        Raises:
            ValueError: If the model configuration has no OpenAI API key.
        """
        if not self.model.openai_api_key:
            raise ValueError("OpenAI API key is required")
        if self.cache.enabled:
            self.cache.cache_dir.mkdir(parents=True, exist_ok=True)

    def to_dict(self, include_secrets: bool = False) -> Dict[str, Any]:
        """Convert the full configuration to a plain dictionary.

        Args:
            include_secrets: When true, ``openai_api_key`` is included under
                ``model_config``. It is omitted by default so the result can
                be logged or written to disk.

        Returns:
            A dictionary with one ``*_config`` section per nested
            configuration plus the top-level switches. ``cache_dir`` is
            rendered as a string.
        """
        model_section = asdict(self.model)
        if not include_secrets:
            model_section.pop("openai_api_key", None)

        cache_section = asdict(self.cache)
        cache_section["cache_dir"] = str(self.cache.cache_dir)

        return {
            "model_config": model_section,
            "retriever_config": asdict(self.retriever),
            "cache_config": cache_section,
            "processing_config": asdict(self.processing),
            "monitoring_config": asdict(self.monitoring),
            "localization_config": asdict(self.localization),
            "rate_limit_enabled": self.rate_limit_enabled,
            "requests_per_minute": self.requests_per_minute,
            "debug_mode": self.debug_mode,
            "development_mode": self.development_mode,
        }

    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any]) -> 'PipelineConfig':
        """Create a configuration from a dictionary shaped like ``to_dict``.

        Missing sections fall back to their defaults. If ``model_config``
        carries no ``openai_api_key`` (as in the default ``to_dict``
        output), the ``OPENAI_API_KEY`` environment variable is used.

        Raises:
            ValueError: If no API key is available from either source.
            TypeError: If a section contains an unknown field name.
        """
        model_kwargs = dict(config_dict.get("model_config", {}))
        if not model_kwargs.get("openai_api_key"):
            # Empty string lets __post_init__ report the missing key clearly
            model_kwargs["openai_api_key"] = os.environ.get("OPENAI_API_KEY", "")

        cache_kwargs = dict(config_dict.get("cache_config", {}))
        if "cache_dir" in cache_kwargs:
            cache_kwargs["cache_dir"] = Path(cache_kwargs["cache_dir"])

        top_level = {
            name: config_dict[name]
            for name in ("rate_limit_enabled", "requests_per_minute",
                         "debug_mode", "development_mode")
            if name in config_dict
        }

        return cls(
            model=ModelConfig(**model_kwargs),
            retriever=RetrieverConfig(**config_dict.get("retriever_config", {})),
            cache=CacheConfig(**cache_kwargs),
            processing=ProcessingConfig(**config_dict.get("processing_config", {})),
            monitoring=MonitoringConfig(**config_dict.get("monitoring_config", {})),
            localization=LocalizationConfig(**config_dict.get("localization_config", {})),
            **top_level
        )
def create_default_config(api_key: str) -> PipelineConfig:
    """Build a ``PipelineConfig`` that uses defaults for everything but the key.

    Args:
        api_key: OpenAI API key stored in the model configuration.

    Returns:
        A pipeline configuration whose sub-configurations are all freshly
        constructed with their default values.
    """
    default_embedder = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
    return PipelineConfig(
        model=ModelConfig(openai_api_key=api_key, embedder_model=default_embedder),
        retriever=RetrieverConfig(),
        cache=CacheConfig(),
        processing=ProcessingConfig(),
        monitoring=MonitoringConfig(),
        localization=LocalizationConfig()
    )
class CalendarDataProcessor:
    """Process and structure calendar data."""

    # Keys under which a section detail may carry one date per semester
    _SEMESTER_KEYS = ('ภาคต้น', 'ภาคปลาย')

    @staticmethod
    def parse_calendar_json(json_data: List[Dict]) -> List[CalendarEvent]:
        """Flatten the calendar JSON into a list of events.

        Each top-level item must have an ``education`` value (the semester
        label) and may have a ``schedule`` list. Schedule entries that have
        both ``section`` and ``details`` are expanded into one event per
        detail (or per semester date); all other entries become one event.

        Args:
            json_data: Parsed calendar JSON.

        Returns:
            Events in input order.

        Raises:
            KeyError: If a top-level item has no ``education`` key.
        """
        events = []
        for semester_data in json_data:
            semester = semester_data['education']
            for entry in semester_data.get('schedule', []):
                if 'section' in entry and 'details' in entry:
                    events.extend(
                        CalendarDataProcessor._section_events(entry, semester)
                    )
                else:
                    events.append(
                        CalendarDataProcessor._schedule_event(entry, semester)
                    )
        return events

    @staticmethod
    def _section_events(entry: Dict, semester: str) -> list:
        """Build the 'deadline' events for one section-with-details entry."""
        section = entry['section']
        events = []
        for detail in entry['details']:
            # A detail may give a date for either semester or for both. The
            # previous check required both keys, so a detail with only one
            # fell through to the 'date' lookup and produced an empty date.
            present = [
                sem for sem in CalendarDataProcessor._SEMESTER_KEYS
                if sem in detail
            ]
            if present:
                for sem in present:
                    events.append(CalendarEvent(
                        date=detail.get(sem, ''),
                        time='',
                        activity=detail.get('title', ''),
                        note=section,
                        semester=sem,
                        event_type='deadline',
                        section=section
                    ))
            else:
                # Single event dated by the detail's own 'date' value
                events.append(CalendarEvent(
                    date=detail.get('date', ''),
                    time='',
                    activity=detail.get('title', ''),
                    note=section,
                    semester=semester,
                    event_type='deadline',
                    section=section
                ))
        return events

    @staticmethod
    def _schedule_event(entry: Dict, semester: str) -> 'CalendarEvent':
        """Build the event for a regular schedule entry, classifying its type."""
        activity = entry.get('activity', '')
        return CalendarEvent(
            date=entry.get('date', ''),
            time=entry.get('time', ''),
            activity=activity,
            note=entry.get('note', ''),
            semester=semester,
            event_type=CalendarEvent.classify_event_type(activity)
        )
# In-memory document store wrapper that caches embeddings and query results
class EnhancedDocumentStore:
    """Enhanced document store with caching capabilities.

    Wraps an ``InMemoryDocumentStore`` and a sentence-transformers
    embedder. Embeddings, documents and query results go through a
    ``CacheManager`` that persists them in the configured cache directory.
    """

    def __init__(self, config: PipelineConfig):
        """Create the store, the embedder (loaded immediately) and the cache."""
        self.store = InMemoryDocumentStore()
        self.embedder = SentenceTransformersDocumentEmbedder(
            model=config.model.embedder_model
        )
        # NOTE(review): the embeddings TTL is applied to all three caches;
        # config.cache.query_cache_ttl is not used here -- confirm intent.
        self.cache_manager = CacheManager(
            cache_dir=config.cache.cache_dir,
            ttl=config.cache.embeddings_cache_ttl
        )
        # Load the embedding model up front
        self.embedder.warm_up()
        self.events = []
        # event_type / semester -> list of positions in self.events
        self.event_type_index = {}
        self.semester_index = {}

    def _compute_embedding(self, text: str) -> Any:
        """Return the embedding for ``text``, computing it on a cache miss."""
        cached_embedding = self.cache_manager.get_embedding_cache(text)
        if cached_embedding is not None:
            return cached_embedding
        doc = Document(content=text)
        embedding = self.embedder.run(documents=[doc])["documents"][0].embedding
        self.cache_manager.set_embedding_cache(text, embedding)
        return embedding

    def add_events(self, events: List[CalendarEvent]):
        """Index events, embed them and write them to the document store.

        Cached query results are discarded afterwards, because results
        computed before this call do not reflect the new documents.
        """
        documents = []
        for event in events:
            # Store event
            self.events.append(event)
            event_idx = len(self.events) - 1
            # Update indices
            self.event_type_index.setdefault(event.event_type, []).append(event_idx)
            self.semester_index.setdefault(event.semester, []).append(event_idx)
            # Create document with cached embedding
            text = event.to_searchable_text()
            doc = Document(
                content=text,
                embedding=self._compute_embedding(text),
                meta={
                    'event_type': event.event_type,
                    'semester': event.semester,
                    'date': event.date
                }
            )
            documents.append(doc)
            # Cache document
            self.cache_manager.set_document_cache(str(event_idx), doc)
        # Store documents
        self.store.write_documents(documents)
        # Query results cached before this point are now stale
        self.cache_manager.clear_cache("queries")

    def search(self,
               query: str,
               event_type: Optional[str] = None,
               semester: Optional[str] = None,
               top_k: int = 5) -> List[Document]:
        """Return up to ``top_k`` documents most similar to ``query``.

        Args:
            query: Free-text query to embed.
            event_type: If given, only documents with this ``event_type``.
            semester: If given, only documents with this ``semester``.
            top_k: Maximum number of documents to return.

        Returns:
            Matching documents, served from the query cache when available.
        """
        # Check cache first
        cache_key = json.dumps({
            'query': query,
            'event_type': event_type,
            'semester': semester,
            'top_k': top_k
        })
        cached_results = self.cache_manager.get_query_cache(cache_key)
        if cached_results is not None:
            return cached_results

        # Compute query embedding
        query_embedding = self._compute_embedding(query)

        # Apply the metadata constraints inside the retriever. Filtering
        # after retrieving a fixed-size candidate list could return fewer
        # than top_k documents even when enough matching ones exist.
        conditions = []
        if event_type:
            conditions.append(
                {"field": "meta.event_type", "operator": "==", "value": event_type}
            )
        if semester:
            conditions.append(
                {"field": "meta.semester", "operator": "==", "value": semester}
            )
        filters = {"operator": "AND", "conditions": conditions} if conditions else None

        retriever = InMemoryEmbeddingRetriever(
            document_store=self.store,
            top_k=top_k
        )
        final_results = retriever.run(
            query_embedding=query_embedding,
            filters=filters
        )["documents"]

        # Cache results
        self.cache_manager.set_query_cache(cache_key, final_results)
        return final_results
class AdvancedQueryProcessor:
    """Analyze a user query with an OpenAI model before retrieval.

    The model is asked for a JSON object describing the kind of event, the
    semester, key terms and the desired response format. Any failure falls
    back to a default analysis that applies no filters.
    """

    # Event types the rest of the pipeline knows how to filter on
    _VALID_EVENT_TYPES = frozenset(
        {'registration', 'deadline', 'examination', 'academic', 'holiday'}
    )
    _REQUIRED_FIELDS = ("event_type", "semester", "key_terms", "response_format")

    def __init__(self, config: 'PipelineConfig'):
        """Create the generator and the analysis prompt template."""
        self.generator = OpenAIGenerator(
            api_key=Secret.from_token(config.model.openai_api_key),
            model=config.model.openai_model
        )
        # The template text is kept flush-left so the rendered prompt is
        # exactly the text below, without leading indentation.
        self.prompt_builder = PromptBuilder(
            template="""
Analyze this academic calendar query (in Thai):
Query: {{query}}
Determine:
1. The type of information being requested
2. Any specific semester mentioned
3. Key terms to look for
Return as JSON:
{
"event_type": "registration|deadline|examination|academic|holiday",
"semester": "term mentioned or null",
"key_terms": ["up to 3 most important terms"],
"response_format": "list|single|detailed"
}
""")

    @staticmethod
    def _extract_json(reply: str) -> Any:
        """Parse JSON from a model reply, unwrapping a Markdown code fence.

        Raises:
            json.JSONDecodeError: If the remaining text is not valid JSON.
        """
        text = reply.strip()
        fenced = re.match(r"^```(?:json)?\s*(.*?)\s*```$", text, re.DOTALL)
        if fenced:
            text = fenced.group(1)
        return json.loads(text)

    @classmethod
    def _normalize_analysis(cls, analysis: Dict[str, Any]) -> Dict[str, Any]:
        """Return a copy of ``analysis`` that is safe to use as search filters.

        ``event_type`` and ``semester`` are later compared for equality with
        document metadata, so values that cannot match (an unknown event
        type, an empty string, the literal text "null") become ``None``.
        ``key_terms`` is forced to a list.
        """
        normalized = dict(analysis)

        event_type = normalized.get("event_type")
        if not isinstance(event_type, str) or event_type not in cls._VALID_EVENT_TYPES:
            normalized["event_type"] = None

        semester = normalized.get("semester")
        if isinstance(semester, str) and semester.strip() \
                and semester.strip().lower() not in ("null", "none"):
            normalized["semester"] = semester.strip()
        else:
            normalized["semester"] = None

        if not isinstance(normalized.get("key_terms"), list):
            normalized["key_terms"] = []
        return normalized

    def process_query(self, query: str) -> Dict[str, Any]:
        """Analyze ``query`` and return the analysis dictionary.

        Returns:
            A dict with ``original_query``, ``event_type``, ``semester``,
            ``key_terms`` and ``response_format`` (plus any extra keys the
            model returned). On any error the default analysis is returned;
            this method does not raise.
        """
        try:
            # Get analysis
            result = self.prompt_builder.run(query=query)
            response = self.generator.run(prompt=result["prompt"])
            # Guard against an empty response
            if not response or not response.get("replies") or not response["replies"][0]:
                logger.warning("Received empty response from generator")
                return self._get_default_analysis(query)
            try:
                analysis = self._extract_json(response["replies"][0])
            except json.JSONDecodeError as je:
                logger.error(f"JSON parsing failed: {str(je)}")
                return self._get_default_analysis(query)

            if not isinstance(analysis, dict):
                logger.warning("Analysis reply is not a JSON object")
                return self._get_default_analysis(query)

            # Validate required fields
            for name in self._REQUIRED_FIELDS:
                if name not in analysis:
                    logger.warning(f"Missing required field: {name}")
                    return self._get_default_analysis(query)

            return {
                "original_query": query,
                **self._normalize_analysis(analysis)
            }
        except Exception as e:
            logger.error(f"Query processing failed: {str(e)}")
            return self._get_default_analysis(query)

    def _get_default_analysis(self, query: str) -> Dict[str, Any]:
        """Return default analysis when processing fails (no filters applied)."""
        logger.info("Returning default analysis")
        return {
            "original_query": query,
            "event_type": None,
            "semester": None,
            "key_terms": [],
            "response_format": "detailed"
        }
@dataclass
class RateLimitConfig:
    """Settings for rate limiting; every field has a default."""

    # Request budget per minute
    requests_per_minute: int = 60
    max_retries: int = 3
    # Delay bounds
    base_delay: float = 1.0
    max_delay: float = 60.0
    timeout: float = 30.0
    # Number of requests allowed in flight together
    concurrent_requests: int = 5
class APIError(Exception):
    """Root of the API error hierarchy.

    Carries an optional status code and response payload next to the
    message.
    """

    def __init__(
        self,
        message: str,
        status_code: Optional[int] = None,
        response: Optional[Dict] = None,
    ):
        # Only the message is handed to Exception, so ``args`` is (message,)
        super().__init__(message)
        self.response = response
        self.status_code = status_code
class RateLimitExceededError(APIError):
    """Signals that the rate limit was exceeded; adds nothing to APIError."""
class OpenAIRateLimiter:
    """Rate limiter with advanced error handling for OpenAI API.

    Enforces a sliding one-minute request budget and a cap on concurrent
    calls, and keeps counters for usage statistics.
    """

    # Length of the sliding window in seconds
    _WINDOW_SECONDS = 60

    def __init__(self, config: RateLimitConfig):
        """Set up the request window, the concurrency semaphore and counters."""
        self.config = config
        # Start times of recent requests, oldest first
        self.requests = deque(maxlen=config.requests_per_minute)
        self.semaphore = asyncio.Semaphore(config.concurrent_requests)
        self.total_requests = 0
        self.errors = deque(maxlen=1000)  # Store recent errors
        self.start_time = datetime.now()

    async def acquire(self):
        """Wait until a request may be made, then record it.

        The clock is re-read after every wait so the recorded time is the
        moment the request is actually admitted, not the moment waiting
        began.
        """
        window = self._WINDOW_SECONDS
        while True:
            # Monotonic time is unaffected by wall-clock adjustments
            now = time.monotonic()
            # Clean old requests
            while self.requests and self.requests[0] <= now - window:
                self.requests.popleft()
            # Proceed once we are under the limit
            if len(self.requests) < self.config.requests_per_minute:
                break
            wait_time = window - (now - self.requests[0])
            logger.warning(f"Rate limit reached. Waiting {wait_time:.2f} seconds")
            await asyncio.sleep(wait_time)
        # Add new request timestamp
        self.requests.append(now)
        self.total_requests += 1

    def get_usage_stats(self) -> Dict[str, Any]:
        """Get current usage statistics.

        Returns:
            ``total_requests``, ``current_rpm`` (size of the request
            window), ``uptime`` in seconds, and ``error_rate`` (stored
            errors divided by total requests, 0 before the first request).
        """
        return {
            "total_requests": self.total_requests,
            "current_rpm": len(self.requests),
            "uptime": (datetime.now() - self.start_time).total_seconds(),
            "error_rate": len(self.errors) / self.total_requests if self.total_requests > 0 else 0
        }

    # NOTE(review): the attempt count and wait bounds are fixed here and do
    # not follow config.max_retries / base_delay / max_delay -- confirm.
    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=4, max=60),
        reraise=True
    )
    async def execute_with_retry(self, func, *args, **kwargs):
        """Execute an async API call under the limiter, retrying on failure.

        Args:
            func: Coroutine function to call.
            *args, **kwargs: Passed through to ``func``.

        Returns:
            Whatever ``func`` returns.

        Raises:
            APIError: When the failure message mentions a timeout; the
                original exception is attached as ``__cause__``.
            Exception: Any other failure, re-raised after the last attempt.
        """
        try:
            async with self.semaphore:
                await self.acquire()
                return await func(*args, **kwargs)
        except Exception as e:
            self.errors.append({
                "timestamp": datetime.now(),
                "error_type": type(e).__name__,
                "message": str(e)
            })
            if isinstance(e, RateLimitExceededError):
                logger.warning("Rate limit exceeded, backing off...")
                await asyncio.sleep(self.config.base_delay)
                raise
            elif "timeout" in str(e).lower():
                logger.error(f"Timeout error: {str(e)}")
                raise APIError(
                    f"Request timed out after {self.config.timeout} seconds"
                ) from e
            else:
                logger.error(f"API error: {str(e)}")
                raise
class ResponseGenerator:
    """Turn retrieved calendar documents into an answer via OpenAI."""

    def __init__(self, config: PipelineConfig):
        """Create the generator and the answer prompt template."""
        self.generator = OpenAIGenerator(
            api_key=Secret.from_token(config.model.openai_api_key),
            model=config.model.openai_model
        )
        # The template text is kept flush-left so the rendered prompt is
        # exactly the text below, without leading indentation.
        answer_template = """
You are a helpful academic advisor. Answer the following query using the provided calendar information.
Query: {{query}}
Relevant Calendar Information:
{% for doc in context %}
---
{{doc.content}}
{% endfor %}
Format: {{format}}
Guidelines:
1. Answer in Thai language
2. Be specific about dates and requirements
3. Include relevant notes or conditions
4. Format the response according to the specified format
Provide your response:
"""
        self.prompt_builder = PromptBuilder(template=answer_template)

    def generate_response(self,
                          query: str,
                          documents: List[Document],
                          query_info: Dict[str, Any]) -> str:
        """Answer ``query`` from ``documents``.

        Args:
            query: The user's question.
            documents: Retrieved documents rendered into the prompt.
            query_info: Query analysis; its ``response_format`` value is
                inserted into the prompt.

        Returns:
            The first reply from the generator, or a fixed Thai apology
            message if anything fails (the error is logged, not raised).
        """
        try:
            rendered = self.prompt_builder.run(
                query=query,
                context=documents,
                format=query_info["response_format"]
            )
            generation = self.generator.run(prompt=rendered["prompt"])
            first_reply = generation["replies"][0]
        except Exception as e:
            logger.error(f"Response generation failed: {str(e)}")
            return "ขออภัย ไม่สามารถประมวลผลคำตอบได้ในขณะนี้"
        return first_reply
class AcademicCalendarRAG:
    """Top-level pipeline: analyze the query, retrieve documents, answer."""

    def __init__(self, config: PipelineConfig):
        """Build the document store, query processor and response generator."""
        self.config = config
        self.document_store = EnhancedDocumentStore(config)
        self.query_processor = AdvancedQueryProcessor(config)
        self.response_generator = ResponseGenerator(config)

    def load_data(self, json_data: List[Dict]):
        """Parse the calendar JSON into events and index them."""
        parsed_events = CalendarDataProcessor().parse_calendar_json(json_data)
        self.document_store.add_events(parsed_events)

    def process_query(self, query: str) -> Dict[str, Any]:
        """Run the full pipeline for one query.

        Returns:
            A dict with ``answer``, ``documents`` and ``query_info``. If
            any step raises, the error is logged and a fixed Thai apology
            is returned with an empty document list and empty query info.
        """
        try:
            # Step 1: understand the query
            analysis = self.query_processor.process_query(query)
            # Step 2: fetch matching documents
            retrieved = self.document_store.search(
                query=query,
                event_type=analysis["event_type"],
                semester=analysis["semester"],
                top_k=self.config.retriever.top_k
            )
            # Step 3: write the answer
            answer = self.response_generator.generate_response(
                query=query,
                documents=retrieved,
                query_info=analysis
            )
        except Exception as e:
            logger.error(f"Query processing failed: {str(e)}")
            return {
                "answer": "ขออภัย ไม่สามารถประมวลผลคำถามได้ในขณะนี้",
                "documents": [],
                "query_info": {}
            }
        return {
            "answer": answer,
            "documents": retrieved,
            "query_info": analysis
        }
# def main():
# """Main function for processing real calendar queries"""
# try:
# # Load API key
# with open("key.txt", "r") as f:
# openai_api_key = f.read().strip()
# # Use create_default_config instead of direct PipelineConfig initialization
# config = create_default_config(openai_api_key)
# # Customize config for Thai academic calendar use case
# config.localization.enable_thai_normalization = True
# config.retriever.top_k = 5 # Adjust based on your needs
# config.model.temperature = 0.3 # Lower temperature for more focused responses
# # Initialize pipeline with enhanced config
# pipeline = AcademicCalendarRAG(config)
# # Load calendar data
# with open("calendar.json", "r", encoding="utf-8") as f:
# calendar_data = json.load(f)
# pipeline.load_data(calendar_data)
# # Real queries to process
# queries = ["นิสิตที่เข้าศึกษาในภาคเรียนที่ 1 ปีการศึกษา 2567 สามารถถอนรายวิชาได้หรือไม่? เพราะเหตุใด?"]
# print("Processing calendar queries...")
# print("=" * 80)
# for query in queries:
# result = pipeline.process_query(query)
# print(f"\nQuery: {query}")
# print(f"Answer: {result['answer']}")
# # # Print retrieved documents for verification
# # print("\nRetrieved Documents:")
# # for i, doc in enumerate(result['documents'], 1):
# # print(f"\nDocument {i}:")
# # print(doc.content)
# # # Print query understanding info
# # print("\nQuery Understanding:")
# # for key, value in result['query_info'].items():
# # print(f"{key}: {value}")
# print("=" * 80)
# except Exception as e:
# logger.error(f"Pipeline execution failed: {str(e)}")
# raise
# if __name__ == "__main__":
# main()