|
from haystack import Pipeline, Document |
|
from haystack.components.generators.openai import OpenAIGenerator |
|
from haystack.components.builders import PromptBuilder |
|
from haystack.components.embedders import SentenceTransformersDocumentEmbedder |
|
from haystack.components.retrievers.in_memory import InMemoryEmbeddingRetriever |
|
from haystack.document_stores.in_memory import InMemoryDocumentStore |
|
from haystack.utils import Secret |
|
from pathlib import Path |
|
import logging |
|
from dataclasses import dataclass, field |
|
from typing import List, Dict, Any, Optional |
|
import json |
|
import asyncio |
|
from datetime import datetime |
|
import re |
|
|
|
|
|
# Module logger, named after this module so records can be traced to it.
logger = logging.getLogger(__name__)

# Configure the root logger to emit INFO and above (a no-op when the root
# logger already has handlers installed).
logging.basicConfig(level=logging.INFO)
|
|
|
@dataclass
class LocalizationConfig:
    """Options for Thai-language text handling.

    NOTE(review): none of these fields is read by the code visible in this
    file; the descriptions below are inferred from the field names and
    should be confirmed against whatever consumes this config.

    Attributes:
        thai_tokenizer_model: Name of the tokenizer model
            (default ``"thai-tokenizer"``).
        enable_thai_normalization: Presumably toggles Thai text
            normalization (default ``True``).
        remove_thai_tones: Presumably toggles stripping of tone marks
            (default ``False``).
        keep_english: Presumably controls whether English text is retained
            (default ``True``).
        custom_stopwords: Extra stopwords; a fresh empty list per instance.
        custom_synonyms: Term -> synonyms mapping; a fresh empty dict per
            instance.
    """

    thai_tokenizer_model: str = "thai-tokenizer"
    enable_thai_normalization: bool = True
    remove_thai_tones: bool = False
    keep_english: bool = True
    # default_factory gives every instance its own container.
    custom_stopwords: List[str] = field(default_factory=list)
    custom_synonyms: Dict[str, List[str]] = field(default_factory=dict)
|
|
|
@dataclass
class RetrieverConfig:
    """Tuning knobs for document retrieval.

    Attributes:
        top_k: Maximum number of documents to return per query (default 5).
        similarity_threshold: Similarity cut-off value (default 0.7).
            NOTE(review): not read by the code visible in this file.
        filter_duplicates: Duplicate-filtering switch (default ``True``).
            NOTE(review): not read by the code visible in this file.
    """

    top_k: int = 5
    similarity_threshold: float = 0.7
    filter_duplicates: bool = True
|
|
|
@dataclass
class ModelConfig:
    """Settings for the language models used by the pipeline.

    Attributes:
        openai_api_key: OpenAI API key (required). It is excluded from the
            generated ``repr`` so that logging or printing a config object
            cannot leak the secret.
        temperature: Sampling temperature for the generator (default 0.3).
        max_tokens: Token limit for a generated completion (default 2000).
        model: Name of the OpenAI model (default ``"gpt-4"``).
        embedder_model: Name of the sentence-transformers embedding model.
    """

    # repr=False keeps the key out of repr()/str() output; the field is still
    # a required, first positional constructor argument.
    openai_api_key: str = field(repr=False)
    temperature: float = 0.3
    max_tokens: int = 2000
    model: str = "gpt-4"
    embedder_model: str = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
|
|
|
@dataclass
class PipelineConfig:
    """Top-level configuration bundle for the RAG pipeline.

    Attributes:
        model: Language-model settings (required).
        retriever: Retrieval settings; defaults to ``RetrieverConfig()``.
        localization: Thai-language settings; defaults to
            ``LocalizationConfig()``.

    Raises:
        ValueError: If ``model.openai_api_key`` is empty or otherwise falsy.
    """

    model: ModelConfig
    retriever: RetrieverConfig = field(default_factory=RetrieverConfig)
    localization: LocalizationConfig = field(default_factory=LocalizationConfig)

    def __post_init__(self):
        """Validate the configuration right after dataclass initialisation."""
        api_key = self.model.openai_api_key
        if api_key:
            return
        raise ValueError("OpenAI API key is required")
|
|
|
class ThaiTextPreprocessor:
    """Helpers for cleaning Thai text."""

    @staticmethod
    def normalize_thai_text(text: str) -> str:
        """Return *text* with whitespace collapsed and Thai digits converted.

        Leading/trailing whitespace is removed, every run of whitespace
        becomes a single space, and the Thai digits ๐-๙ are replaced by the
        ASCII digits 0-9. Falsy input (``""`` or ``None``) is returned as is.
        """
        if not text:
            return text

        # One translation pass handles all ten digit substitutions.
        digit_table = str.maketrans('๐๑๒๓๔๕๖๗๘๙', '0123456789')
        collapsed = re.sub(r'\s+', ' ', text.strip())
        return collapsed.translate(digit_table)
|
|
|
class CalendarEvent:
    """A single academic calendar event.

    Attributes:
        date: Event date, stored as given.
        activity: Description of the activity.
        semester: Semester the event belongs to.
        event_type: Category label (default ``"academic"``).
        note: Free-form note (default ``""``).
        time: Time of day (default ``""``).
        section: Optional section/grouping label.
    """

    def __init__(self,
                 date: str,
                 activity: str,
                 semester: str,
                 event_type: str = "academic",
                 note: str = "",
                 time: str = "",
                 section: Optional[str] = None):
        self.date = date
        self.activity = activity
        self.semester = semester
        self.event_type = event_type
        self.note = note
        self.time = time
        self.section = section

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}(date={self.date!r}, "
            f"activity={self.activity!r}, semester={self.semester!r}, "
            f"event_type={self.event_type!r}, note={self.note!r}, "
            f"time={self.time!r}, section={self.section!r})"
        )

    def to_searchable_text(self) -> str:
        """Render the event as seven ``label: value`` lines for indexing.

        Empty ``time``, ``section`` and ``note`` are shown as ``-``. The text
        is assembled line by line so it never carries indentation or blank
        lines from the source layout.
        """
        labelled_values = (
            ("ภาคการศึกษา", self.semester),
            ("ประเภท", self.event_type),
            ("วันที่", self.date),
            ("เวลา", self.time or "-"),
            ("กิจกรรม", self.activity),
            ("หมวดหมู่", self.section or "-"),
            ("หมายเหตุ", self.note or "-"),
        )
        return "\n".join(f"{label}: {value}" for label, value in labelled_values)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'CalendarEvent':
        """Build an event from a plain dictionary.

        Missing keys fall back to the constructor defaults. Keys that are
        present but hold ``None`` or an empty value (e.g. JSON ``null``) get
        the same defaults, so the literal text "None" never reaches the
        searchable text or the document metadata.

        Args:
            data: Mapping with optional keys ``date``, ``activity``,
                ``semester``, ``event_type``, ``note``, ``time``, ``section``.

        Returns:
            A new instance of the class this method is called on.
        """
        return cls(
            date=data.get('date') or '',
            activity=data.get('activity') or '',
            semester=data.get('semester') or '',
            event_type=data.get('event_type') or 'academic',
            note=data.get('note') or '',
            time=data.get('time') or '',
            section=data.get('section'),
        )
|
|
|
class CalendarRAG:
    """Retrieval-augmented question answering over academic calendar events.

    Events are rendered to text, embedded, and kept in an in-memory document
    store. A query is first analysed by the LLM to extract an event type and
    semester, those values are used as metadata filters during embedding
    retrieval, and the retrieved documents are passed to the LLM to write
    the final answer.
    """

    # Prompt asking the LLM to classify a query and reply with a JSON object.
    _QUERY_ANALYSIS_TEMPLATE = "\n".join((
        "วิเคราะห์คำถามเกี่ยวกับปฏิทินการศึกษานี้:",
        "คำถาม: {{query}}",
        "",
        "กรุณาระบุ:",
        "1. ประเภทของข้อมูลที่ต้องการ",
        "2. ภาคการศึกษาที่เกี่ยวข้อง",
        "3. คำสำคัญที่ต้องค้นหา",
        "",
        "ตอบในรูปแบบ JSON:",
        "{",
        '    "event_type": "registration|deadline|examination|academic|holiday",',
        '    "semester": "ภาคการศึกษาที่ระบุ หรือ null",',
        '    "key_terms": ["คำสำคัญไม่เกิน 3 คำ"]',
        "}",
    ))

    # Prompt asking the LLM to answer from the retrieved documents.
    _ANSWER_TEMPLATE = "\n".join((
        "คุณเป็นผู้ช่วยให้ข้อมูลปฏิทินการศึกษา กรุณาตอบคำถามต่อไปนี้โดยใช้ข้อมูลที่ให้มา:",
        "",
        "คำถาม: {{query}}",
        "",
        "ข้อมูลที่เกี่ยวข้อง:",
        "{% for doc in documents %}",
        "---",
        "{{doc.content}}",
        "{% endfor %}",
        "",
        "คำแนะนำ:",
        "1. ตอบเป็นภาษาไทย",
        "2. ระบุวันที่และข้อกำหนดให้ชัดเจน",
        "3. รวมหมายเหตุหรือเงื่อนไขที่สำคัญ",
    ))

    # String spellings an LLM may use when it means "no value".
    _NULL_MARKERS = frozenset({"", "null", "none"})

    def __init__(self, config: "PipelineConfig"):
        """Build all pipeline components from *config*.

        Args:
            config: Pipeline configuration; ``config.model`` supplies the
                API key, model names and sampling options, and
                ``config.retriever.top_k`` the retrieval depth.
        """
        self.config = config
        self.document_store = InMemoryDocumentStore()

        self.embedder = SentenceTransformersDocumentEmbedder(
            model=config.model.embedder_model
        )
        # The embedder loads its model in warm_up(); do it once up front so
        # load_data() and queries can call run() directly.
        self.embedder.warm_up()

        self.text_preprocessor = ThaiTextPreprocessor()

        # OpenAIGenerator has no `temperature` keyword; sampling options are
        # forwarded to the OpenAI API through `generation_kwargs`.
        self.generator = OpenAIGenerator(
            api_key=Secret.from_token(config.model.openai_api_key),
            model=config.model.model,
            generation_kwargs={
                "temperature": config.model.temperature,
                "max_tokens": config.model.max_tokens,
            },
        )

        # Built once; top_k and filters are supplied per query at run time.
        self.retriever = InMemoryEmbeddingRetriever(
            document_store=self.document_store,
            top_k=config.retriever.top_k,
        )

        self.query_analyzer = PromptBuilder(template=self._QUERY_ANALYSIS_TEMPLATE)
        self.answer_generator = PromptBuilder(template=self._ANSWER_TEMPLATE)

    def load_data(self, calendar_data: List[Dict[str, Any]]) -> None:
        """Embed calendar entries and write them to the document store.

        Args:
            calendar_data: Entries in the dictionary format accepted by
                ``CalendarEvent.from_dict``. An empty list is a no-op.
        """
        if not calendar_data:
            logger.warning("load_data called without any calendar entries")
            return

        documents = []
        for entry in calendar_data:
            event = CalendarEvent.from_dict(entry)
            documents.append(
                Document(
                    content=event.to_searchable_text(),
                    # Metadata mirrors the fields used for query-time filtering.
                    meta={
                        "event_type": event.event_type,
                        "semester": event.semester,
                        "date": event.date,
                    },
                )
            )

        embedded_docs = self.embedder.run(documents=documents)["documents"]
        self.document_store.write_documents(embedded_docs)

    def process_query(self, query: str) -> Dict[str, Any]:
        """Answer a calendar question.

        Args:
            query: The user's question.

        Returns:
            A dict with keys ``answer`` (str), ``documents`` (the retrieved
            documents) and ``query_info`` (the query analysis). If any step
            raises, an apology message with empty ``documents`` and
            ``query_info`` is returned instead of propagating the error.
        """
        try:
            query_info = self._analyze_query(query)
            documents = self._retrieve_documents(
                query,
                event_type=query_info.get("event_type"),
                semester=query_info.get("semester"),
            )
            answer = self._generate_answer(query, documents)
            return {
                "answer": answer,
                "documents": documents,
                "query_info": query_info,
            }
        except Exception as exc:
            # Top-level boundary: log with traceback, return a safe reply.
            logger.exception("Query processing failed: %s", exc)
            return {
                "answer": "ขออภัย ไม่สามารถประมวลผลคำถามได้ในขณะนี้",
                "documents": [],
                "query_info": {},
            }

    @classmethod
    def _parse_analysis(cls, reply: str) -> Dict[str, Any]:
        """Extract and sanitise the JSON object contained in an LLM reply.

        The reply may wrap the object in a Markdown code fence or in prose,
        so the text from the first ``{`` to the last ``}`` is parsed.
        ``event_type`` and ``semester`` become ``None`` unless they are
        non-empty strings other than "null"/"none"; ``key_terms`` becomes
        ``[]`` unless it is a list.

        Raises:
            ValueError: If no JSON object is present or it cannot be parsed
                (``json.JSONDecodeError`` is a ``ValueError`` subclass).
        """
        found = re.search(r"\{.*\}", reply, flags=re.DOTALL)
        if found is None:
            raise ValueError("No JSON object found in query analysis reply")

        analysis = json.loads(found.group(0))
        for key in ("event_type", "semester"):
            value = analysis.get(key)
            if isinstance(value, str) and value.strip().lower() not in cls._NULL_MARKERS:
                analysis[key] = value.strip()
            else:
                analysis[key] = None
        if not isinstance(analysis.get("key_terms"), list):
            analysis["key_terms"] = []
        return analysis

    @staticmethod
    def _build_filters(event_type: Optional[str],
                       semester: Optional[str]) -> Optional[Dict[str, Any]]:
        """Build a document-store metadata filter, or ``None`` if unfiltered."""
        conditions = [
            {"field": f"meta.{name}", "operator": "==", "value": value}
            for name, value in (("event_type", event_type), ("semester", semester))
            if value
        ]
        if not conditions:
            return None
        return {"operator": "AND", "conditions": conditions}

    def _analyze_query(self, query: str) -> Dict[str, Any]:
        """Ask the LLM to classify *query*.

        Returns:
            A dict with ``event_type``, ``semester``, ``key_terms`` and
            ``original_query``. On any failure the first two are ``None``
            and ``key_terms`` is empty, so retrieval runs unfiltered.
        """
        try:
            normalized_query = self.text_preprocessor.normalize_thai_text(query)
            prompt = self.query_analyzer.run(query=normalized_query)["prompt"]
            response = self.generator.run(prompt=prompt)

            if not response or not response.get("replies"):
                raise ValueError("Empty response from query analyzer")

            analysis = self._parse_analysis(response["replies"][0])
            analysis["original_query"] = query
            return analysis
        except Exception as exc:
            logger.error("Query analysis failed: %s", exc)
            return {
                "original_query": query,
                "event_type": None,
                "semester": None,
                "key_terms": [],
            }

    def _retrieve_documents(self,
                            query: str,
                            event_type: Optional[str] = None,
                            semester: Optional[str] = None) -> List["Document"]:
        """Return up to ``top_k`` documents most similar to *query*.

        Metadata constraints are passed to the retriever as filters so they
        are applied together with the similarity search; filtering after the
        top-k cut could leave fewer matches than actually exist in the store.

        Args:
            query: Text to embed and search with.
            event_type: If set, only documents with this ``event_type``.
            semester: If set, only documents with this ``semester``.
        """
        # The document embedder is reused for the query so both sides share
        # one model; the query is wrapped in a Document for that purpose.
        query_doc = Document(content=query)
        embedded_query = self.embedder.run(documents=[query_doc])["documents"][0]

        result = self.retriever.run(
            query_embedding=embedded_query.embedding,
            filters=self._build_filters(event_type, semester),
            top_k=self.config.retriever.top_k,
        )
        return result["documents"]

    def _generate_answer(self, query: str, documents: List["Document"]) -> str:
        """Produce the final answer text from *query* and *documents*.

        Returns:
            The first LLM reply, or an apology message if generation fails
            or the reply list is empty.
        """
        try:
            prompt = self.answer_generator.run(
                query=query,
                documents=documents,
            )["prompt"]
            response = self.generator.run(prompt=prompt)

            if not response or not response.get("replies"):
                raise ValueError("Empty response from answer generator")

            return response["replies"][0]
        except Exception as exc:
            logger.error("Answer generation failed: %s", exc)
            return "ขออภัย ไม่สามารถสร้างคำตอบได้ในขณะนี้"
|
|
|
def create_default_config(api_key: str) -> PipelineConfig:
    """Return a ``PipelineConfig`` that uses defaults for everything but the key.

    Args:
        api_key: OpenAI API key placed into the model configuration.

    Returns:
        A pipeline configuration with a default ``ModelConfig`` (apart from
        the key), ``RetrieverConfig`` and ``LocalizationConfig``.
    """
    return PipelineConfig(
        model=ModelConfig(openai_api_key=api_key),
        retriever=RetrieverConfig(),
        localization=LocalizationConfig(),
    )