"""calendar_rag.py - RAG pipeline for querying a Thai academic calendar."""
from haystack import Document
from haystack.components.generators.openai import OpenAIGenerator
from haystack.components.builders import PromptBuilder
from haystack.components.embedders import SentenceTransformersDocumentEmbedder
from haystack.components.retrievers.in_memory import InMemoryEmbeddingRetriever
from haystack.document_stores.in_memory import InMemoryDocumentStore
from haystack.utils import Secret
import logging
from dataclasses import dataclass, field
from typing import List, Dict, Any, Optional
import json
import re
# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
@dataclass
class LocalizationConfig:
"""Configuration for Thai language handling"""
thai_tokenizer_model: str = "thai-tokenizer"
enable_thai_normalization: bool = True
remove_thai_tones: bool = False
keep_english: bool = True
custom_stopwords: List[str] = field(default_factory=list)
custom_synonyms: Dict[str, List[str]] = field(default_factory=dict)
@dataclass
class RetrieverConfig:
"""Configuration for document retrieval"""
top_k: int = 5
similarity_threshold: float = 0.7
filter_duplicates: bool = True
@dataclass
class ModelConfig:
"""Configuration for language models"""
openai_api_key: str
temperature: float = 0.3
max_tokens: int = 2000
model: str = "gpt-4"
embedder_model: str = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
@dataclass
class PipelineConfig:
"""Main configuration for the RAG pipeline"""
model: ModelConfig
retriever: RetrieverConfig = field(default_factory=RetrieverConfig)
localization: LocalizationConfig = field(default_factory=LocalizationConfig)
def __post_init__(self):
if not self.model.openai_api_key:
raise ValueError("OpenAI API key is required")
class ThaiTextPreprocessor:
"""Thai text preprocessing utilities"""
@staticmethod
def normalize_thai_text(text: str) -> str:
"""Normalize Thai text"""
if not text:
return text
# Normalize whitespace
text = re.sub(r'\s+', ' ', text.strip())
# Normalize Thai numerals
thai_digits = '๐๑๒๓๔๕๖๗๘๙'
arabic_digits = '0123456789'
for thai, arabic in zip(thai_digits, arabic_digits):
text = text.replace(thai, arabic)
return text
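# Example (illustrative): ThaiTextPreprocessor.normalize_thai_text("ปี  ๒๕๖๗")
# returns "ปี 2567": whitespace is collapsed and Thai numerals become Arabic digits.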
class CalendarEvent:
"""Represents an academic calendar event"""
def __init__(self,
date: str,
activity: str,
semester: str,
event_type: str = "academic",
note: str = "",
time: str = "",
section: Optional[str] = None):
self.date = date
self.activity = activity
self.semester = semester
self.event_type = event_type
self.note = note
self.time = time
self.section = section
def to_searchable_text(self) -> str:
"""Convert event to searchable text format"""
return f"""
ภาคการศึกษา: {self.semester}
ประเภท: {self.event_type}
วันที่: {self.date}
เวลา: {self.time or '-'}
กิจกรรม: {self.activity}
หมวดหมู่: {self.section or '-'}
หมายเหตุ: {self.note or '-'}
""".strip()
@staticmethod
def from_dict(data: Dict[str, Any]) -> 'CalendarEvent':
"""Create event from dictionary"""
return CalendarEvent(
date=data.get('date', ''),
activity=data.get('activity', ''),
semester=data.get('semester', ''),
event_type=data.get('event_type', 'academic'),
note=data.get('note', ''),
time=data.get('time', ''),
section=data.get('section')
)
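# Example (illustrative) of the dictionary shape read by CalendarEvent.from_dict:
# {"date": "1 ก.ค. 2567", "activity": "ลงทะเบียนเรียน", "semester": "1/2567",
#  "event_type": "registration", "note": "", "time": "09:00-16:00", "section": None}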
class CalendarRAG:
"""Main RAG pipeline for academic calendar"""
def __init__(self, config: PipelineConfig):
"""Initialize the pipeline with configuration"""
self.config = config
self.document_store = InMemoryDocumentStore()
        self.embedder = SentenceTransformersDocumentEmbedder(
            model=config.model.embedder_model
        )
        # Load the sentence-transformers model now; run() raises if the
        # embedder has not been warmed up first.
        self.embedder.warm_up()
        self.text_preprocessor = ThaiTextPreprocessor()
        # Initialize OpenAI components. OpenAIGenerator expects sampling
        # options via generation_kwargs, not as direct keyword arguments.
        self.generator = OpenAIGenerator(
            api_key=Secret.from_token(config.model.openai_api_key),
            model=config.model.model,
            generation_kwargs={
                "temperature": config.model.temperature,
                "max_tokens": config.model.max_tokens,
            }
        )
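        # query_analyzer prompt (Thai). In English: "Analyze this academic-calendar
        # question. Identify: 1) the type of information requested, 2) the relevant
        # semester, 3) the key terms to search for. Answer as JSON with event_type,
        # semester (or null), and at most 3 key_terms."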
self.query_analyzer = PromptBuilder(
template="""
วิเคราะห์คำถามเกี่ยวกับปฏิทินการศึกษานี้:
คำถาม: {{query}}
กรุณาระบุ:
1. ประเภทของข้อมูลที่ต้องการ
2. ภาคการศึกษาที่เกี่ยวข้อง
3. คำสำคัญที่ต้องค้นหา
ตอบในรูปแบบ JSON:
{
"event_type": "registration|deadline|examination|academic|holiday",
"semester": "ภาคการศึกษาที่ระบุ หรือ null",
"key_terms": ["คำสำคัญไม่เกิน 3 คำ"]
}
"""
)
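        # answer_generator prompt (Thai). In English: "You are an academic-calendar
        # assistant. Answer the following question using the information provided.
        # Instructions: 1) answer in Thai, 2) state dates and requirements clearly,
        # 3) include important notes or conditions."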
self.answer_generator = PromptBuilder(
template="""
คุณเป็นผู้ช่วยให้ข้อมูลปฏิทินการศึกษา กรุณาตอบคำถามต่อไปนี้โดยใช้ข้อมูลที่ให้มา:
คำถาม: {{query}}
ข้อมูลที่เกี่ยวข้อง:
{% for doc in documents %}
---
{{doc.content}}
{% endfor %}
คำแนะนำ:
1. ตอบเป็นภาษาไทย
2. ระบุวันที่และข้อกำหนดให้ชัดเจน
3. รวมหมายเหตุหรือเงื่อนไขที่สำคัญ
"""
)
def load_data(self, calendar_data: List[Dict[str, Any]]) -> None:
"""Load calendar data into the system"""
documents = []
for entry in calendar_data:
# Create calendar event
event = CalendarEvent.from_dict(entry)
# Create searchable document
doc = Document(
content=event.to_searchable_text(),
meta={
"event_type": event.event_type,
"semester": event.semester,
"date": event.date
}
)
documents.append(doc)
# Compute embeddings
embedded_docs = self.embedder.run(documents=documents)["documents"]
# Store documents
self.document_store.write_documents(embedded_docs)
def process_query(self, query: str) -> Dict[str, Any]:
"""Process a calendar query and return results"""
try:
# Analyze query
query_info = self._analyze_query(query)
# Retrieve relevant documents
documents = self._retrieve_documents(
query,
event_type=query_info.get("event_type"),
semester=query_info.get("semester")
)
# Generate answer
answer = self._generate_answer(query, documents)
return {
"answer": answer,
"documents": documents,
"query_info": query_info
}
except Exception as e:
logger.error(f"Query processing failed: {str(e)}")
return {
"answer": "ขออภัย ไม่สามารถประมวลผลคำถามได้ในขณะนี้",
"documents": [],
"query_info": {}
}
def _analyze_query(self, query: str) -> Dict[str, Any]:
"""Analyze and extract information from query"""
try:
# Normalize query
normalized_query = self.text_preprocessor.normalize_thai_text(query)
# Get analysis from OpenAI
prompt_result = self.query_analyzer.run(query=normalized_query)
response = self.generator.run(prompt=prompt_result["prompt"])
if not response or not response.get("replies"):
raise ValueError("Empty response from query analyzer")
            # Strip Markdown code fences the model may wrap around the JSON
            raw_reply = re.sub(r'^```(?:json)?\s*|\s*```$', '', response["replies"][0].strip())
            analysis = json.loads(raw_reply)
            analysis["original_query"] = query
            return analysis
except Exception as e:
logger.error(f"Query analysis failed: {str(e)}")
return {
"original_query": query,
"event_type": None,
"semester": None,
"key_terms": []
}
def _retrieve_documents(self,
query: str,
event_type: Optional[str] = None,
semester: Optional[str] = None) -> List[Document]:
"""Retrieve relevant documents"""
# Create retriever
retriever = InMemoryEmbeddingRetriever(
document_store=self.document_store,
top_k=self.config.retriever.top_k
)
        # Embed the query by wrapping it in a Document, so the same document
        # embedder can be reused instead of a separate text embedder
query_doc = Document(content=query)
embedded_query = self.embedder.run(documents=[query_doc])["documents"][0]
# Retrieve documents
results = retriever.run(query_embedding=embedded_query.embedding)["documents"]
        # Filter results by metadata and similarity score if needed
        threshold = self.config.retriever.similarity_threshold
        filtered_results = []
        for doc in results:
            if event_type and doc.meta.get('event_type') != event_type:
                continue
            if semester and doc.meta.get('semester') != semester:
                continue
            if doc.score is not None and doc.score < threshold:
                continue
            filtered_results.append(doc)
        return filtered_results[:self.config.retriever.top_k]
def _generate_answer(self, query: str, documents: List[Document]) -> str:
"""Generate answer from retrieved documents"""
try:
prompt_result = self.answer_generator.run(
query=query,
documents=documents
)
response = self.generator.run(prompt=prompt_result["prompt"])
if not response or not response.get("replies"):
raise ValueError("Empty response from answer generator")
return response["replies"][0]
except Exception as e:
logger.error(f"Answer generation failed: {str(e)}")
return "ขออภัย ไม่สามารถสร้างคำตอบได้ในขณะนี้"
def create_default_config(api_key: str) -> PipelineConfig:
"""Create default pipeline configuration"""
model_config = ModelConfig(openai_api_key=api_key)
return PipelineConfig(
model=model_config,
retriever=RetrieverConfig(),
localization=LocalizationConfig()
)
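# --- Usage sketch (illustrative, not part of the original module) ---
# Assumes an OPENAI_API_KEY environment variable and that the calendar data is a
# list of dicts matching the keys read by CalendarEvent.from_dict; the sample
# entry below is made up for demonstration only.
if __name__ == "__main__":
    import os

    config = create_default_config(os.environ["OPENAI_API_KEY"])
    rag = CalendarRAG(config)

    sample_data = [
        {
            "date": "1 ก.ค. 2567",
            "activity": "วันเปิดภาคการศึกษา",  # first day of the semester
            "semester": "1/2567",
            "event_type": "academic",
        }
    ]
    rag.load_data(sample_data)

    result = rag.process_query("ภาคการศึกษา 1/2567 เปิดเรียนวันไหน")
    print(result["answer"])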