JirasakJo committed on
Commit
bc852b0
·
verified ·
1 Parent(s): 372fdf9

Create calendar_rag.py

Browse files
Files changed (1) hide show
  1. calendar_rag.py +316 -0
calendar_rag.py ADDED
@@ -0,0 +1,316 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from haystack import Pipeline, Document
2
+ from haystack.components.generators.openai import OpenAIGenerator
3
+ from haystack.components.builders import PromptBuilder
4
+ from haystack.components.embedders import SentenceTransformersDocumentEmbedder
5
+ from haystack.components.retrievers.in_memory import InMemoryEmbeddingRetriever
6
+ from haystack.document_stores.in_memory import InMemoryDocumentStore
7
+ from haystack.utils import Secret
8
+ from pathlib import Path
9
+ import logging
10
+ from dataclasses import dataclass, field
11
+ from typing import List, Dict, Any, Optional
12
+ import json
13
+ import asyncio
14
+ from datetime import datetime
15
+ import re
16
+
17
# Module-wide logging: configure the root logger at INFO level on import and
# expose a logger named after this module for the classes below.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
20
+
21
@dataclass
class LocalizationConfig:
    """Language-handling options (Thai-focused) carried by the pipeline config.

    NOTE(review): none of these fields are read by the code visible in this
    file; they are declared settings only. Confirm against other consumers
    before relying on any of them.
    """

    # Identifier of the tokenizer model to use for Thai text.
    thai_tokenizer_model: str = "thai-tokenizer"
    # Toggle for Thai text normalization.
    enable_thai_normalization: bool = True
    # Toggle for stripping Thai tone marks.
    remove_thai_tones: bool = False
    # Toggle for keeping English text alongside Thai.
    keep_english: bool = True
    # Extra stop words; a fresh list is created for every instance.
    custom_stopwords: List[str] = field(default_factory=list)
    # Extra synonym map (term -> synonyms); a fresh dict per instance.
    custom_synonyms: Dict[str, List[str]] = field(default_factory=dict)
30
+
31
@dataclass
class RetrieverConfig:
    """Settings that control how documents are retrieved for a query."""

    # Maximum number of documents handed back by the retriever.
    top_k: int = 5
    # Minimum similarity score setting.
    # NOTE(review): not read by the code visible in this file.
    similarity_threshold: float = 0.7
    # Duplicate-filtering toggle.
    # NOTE(review): not read by the code visible in this file.
    filter_duplicates: bool = True
37
+
38
@dataclass
class ModelConfig:
    """Settings for the generation and embedding models.

    Attributes:
        openai_api_key: Secret used to authenticate against OpenAI. It is
            excluded from the generated ``repr`` so that logging or printing a
            config object never leaks the credential.
        temperature: Sampling temperature for the generator.
        max_tokens: Upper bound on generated tokens.
        model: Name of the OpenAI chat/completion model.
        embedder_model: Sentence-transformers model used for embeddings.
    """

    # repr=False keeps the secret out of repr()/str() and therefore out of
    # logs and tracebacks; equality and the constructor are unaffected.
    openai_api_key: str = field(repr=False)
    temperature: float = 0.3
    max_tokens: int = 2000
    model: str = "gpt-4"
    embedder_model: str = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
46
+
47
@dataclass
class PipelineConfig:
    """Top-level configuration bundle for the RAG pipeline.

    Groups the model settings (required) with the retriever and localization
    settings, each of which falls back to a freshly built default instance.

    Raises:
        ValueError: from ``__post_init__`` when the model config carries an
            empty (falsy) OpenAI API key.
    """

    model: ModelConfig
    retriever: RetrieverConfig = field(default_factory=RetrieverConfig)
    localization: LocalizationConfig = field(default_factory=LocalizationConfig)

    def __post_init__(self):
        """Reject configurations that have no OpenAI API key."""
        api_key = self.model.openai_api_key
        if api_key:
            return
        raise ValueError("OpenAI API key is required")
57
+
58
class ThaiTextPreprocessor:
    """Stateless helpers for cleaning Thai text."""

    # Translation table mapping each Thai digit to its ASCII counterpart.
    # Built once so every call does a single pass over the text instead of
    # ten chained ``str.replace`` scans.
    _THAI_DIGIT_TABLE = str.maketrans('๐๑๒๓๔๕๖๗๘๙', '0123456789')

    @staticmethod
    def normalize_thai_text(text: str) -> str:
        """Collapse whitespace and convert Thai digits to ASCII digits.

        Args:
            text: Text to normalize. Falsy values (``""`` or ``None``) are
                returned unchanged.

        Returns:
            The text with leading/trailing whitespace removed, every run of
            whitespace replaced by a single space, and Thai numerals replaced
            by ``0``-``9``.
        """
        if not text:
            return text

        # Trim the ends, then squeeze any internal whitespace run (spaces,
        # tabs, newlines) down to one space.
        collapsed = re.sub(r'\s+', ' ', text.strip())

        return collapsed.translate(ThaiTextPreprocessor._THAI_DIGIT_TABLE)
77
+
78
class CalendarEvent:
    """A single entry of the academic calendar.

    Attributes:
        date: Date of the event, kept as the string supplied by the caller.
        activity: Description of what happens.
        semester: Semester label the event belongs to.
        event_type: Category label; defaults to ``"academic"``.
        note: Free-form remark; may be empty.
        time: Time of day as a string; may be empty.
        section: Optional grouping label.
    """

    def __init__(self,
                 date: str,
                 activity: str,
                 semester: str,
                 event_type: str = "academic",
                 note: str = "",
                 time: str = "",
                 section: Optional[str] = None):
        self.date = date
        self.activity = activity
        self.semester = semester
        self.event_type = event_type
        self.note = note
        self.time = time
        self.section = section

    def __repr__(self) -> str:
        """Return a constructor-style representation for debugging/logging."""
        return (
            f"{type(self).__name__}(date={self.date!r}, activity={self.activity!r}, "
            f"semester={self.semester!r}, event_type={self.event_type!r}, "
            f"note={self.note!r}, time={self.time!r}, section={self.section!r})"
        )

    def to_searchable_text(self) -> str:
        """Render the event as labelled Thai text, one field per line.

        Empty ``time``, ``section`` and ``note`` values are shown as ``-``.
        The lines are joined explicitly (rather than written as an indented
        triple-quoted literal) so that the text which ends up being embedded
        never depends on how this source file happens to be indented.
        """
        lines = [
            f"ภาคการศึกษา: {self.semester}",
            f"ประเภท: {self.event_type}",
            f"วันที่: {self.date}",
            f"เวลา: {self.time or '-'}",
            f"กิจกรรม: {self.activity}",
            f"หมวดหมู่: {self.section or '-'}",
            f"หมายเหตุ: {self.note or '-'}",
        ]
        return "\n".join(lines).strip()

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'CalendarEvent':
        """Build an event from a mapping, filling gaps with defaults.

        Missing ``date``/``activity``/``semester``/``note``/``time`` keys
        become empty strings, ``event_type`` defaults to ``"academic"`` and
        ``section`` to ``None``. Being a classmethod, subclasses get instances
        of their own type.
        """
        return cls(
            date=data.get('date', ''),
            activity=data.get('activity', ''),
            semester=data.get('semester', ''),
            event_type=data.get('event_type', 'academic'),
            note=data.get('note', ''),
            time=data.get('time', ''),
            section=data.get('section'),
        )
121
+
122
class CalendarRAG:
    """Retrieval-augmented question answering over academic calendar events.

    The flow for one question is: ask the LLM to classify the question
    (event type, semester, key terms), retrieve the most similar calendar
    documents from an in-memory store (restricted by the extracted event type
    and semester when present), then ask the LLM to answer from those
    documents.
    """

    # Prompt asking the LLM to classify a question and reply with JSON.
    _QUERY_ANALYSIS_TEMPLATE = """
วิเคราะห์คำถามเกี่ยวกับปฏิทินการศึกษานี้:
คำถาม: {{query}}

กรุณาระบุ:
1. ประเภทของข้อมูลที่ต้องการ
2. ภาคการศึกษาที่เกี่ยวข้อง
3. คำสำคัญที่ต้องค้นหา

ตอบในรูปแบบ JSON:
{
"event_type": "registration|deadline|examination|academic|holiday",
"semester": "ภาคการศึกษาที่ระบุ หรือ null",
"key_terms": ["คำสำคัญไม่เกิน 3 คำ"]
}
"""

    # Prompt asking the LLM to answer a question from retrieved documents.
    _ANSWER_TEMPLATE = """
คุณเป็นผู้ช่วยให้ข้อมูลปฏิทินการศึกษา กรุณาตอบคำถามต่อไปนี้โดยใช้ข้อมูลที่ให้มา:

คำถาม: {{query}}

ข้อมูลที่เกี่ยวข้อง:
{% for doc in documents %}
---
{{doc.content}}
{% endfor %}

คำแนะนำ:
1. ตอบเป็นภาษาไทย
2. ระบุวันที่และข้อกำหนดให้ชัดเจน
3. รวมหมายเหตุหรือเงื่อนไขที่สำคัญ
"""

    def __init__(self, config: "PipelineConfig"):
        """Build the document store, embedder, generator and prompt builders.

        Args:
            config: Pipeline configuration; ``config.model`` supplies the API
                key, model names and generation settings, ``config.retriever``
                supplies ``top_k``.
        """
        self.config = config
        self.document_store = InMemoryDocumentStore()
        self.embedder = SentenceTransformersDocumentEmbedder(
            model=config.model.embedder_model
        )
        # Haystack embedders load their model in warm_up(); calling run()
        # on a cold embedder fails, so load it once up front.
        self.embedder.warm_up()
        self.text_preprocessor = ThaiTextPreprocessor()

        # Sampling options are not constructor parameters of OpenAIGenerator;
        # they must be forwarded through generation_kwargs. max_tokens from
        # the config is forwarded too so the configured limit takes effect.
        self.generator = OpenAIGenerator(
            api_key=Secret.from_token(config.model.openai_api_key),
            model=config.model.model,
            generation_kwargs={
                "temperature": config.model.temperature,
                "max_tokens": config.model.max_tokens,
            },
        )

        self.query_analyzer = PromptBuilder(template=self._QUERY_ANALYSIS_TEMPLATE)
        self.answer_generator = PromptBuilder(template=self._ANSWER_TEMPLATE)

    def load_data(self, calendar_data: List[Dict[str, Any]]) -> None:
        """Embed calendar entries and write them to the document store.

        Args:
            calendar_data: One dict per event, in the shape accepted by
                ``CalendarEvent.from_dict``. An empty list is a no-op.
        """
        # Imported locally so this method does not depend on a new
        # module-level import.
        from haystack.document_stores.types import DuplicatePolicy

        if not calendar_data:
            logger.warning("load_data called without calendar entries; nothing to index")
            return

        documents = []
        for entry in calendar_data:
            event = CalendarEvent.from_dict(entry)
            documents.append(
                Document(
                    content=event.to_searchable_text(),
                    # Metadata used later to restrict retrieval.
                    meta={
                        "event_type": event.event_type,
                        "semester": event.semester,
                        "date": event.date,
                    },
                )
            )

        embedded_docs = self.embedder.run(documents=documents)["documents"]

        # OVERWRITE makes reloading the same calendar (or a calendar with
        # repeated entries) idempotent instead of raising on duplicate ids.
        self.document_store.write_documents(
            embedded_docs, policy=DuplicatePolicy.OVERWRITE
        )

    def process_query(self, query: str) -> Dict[str, Any]:
        """Answer a calendar question end to end.

        Args:
            query: The user's question.

        Returns:
            A dict with ``"answer"`` (str), ``"documents"`` (retrieved
            documents) and ``"query_info"`` (the query analysis). On any
            failure a Thai apology message is returned with empty
            ``documents`` and ``query_info``; the error is logged.
        """
        try:
            query_info = self._analyze_query(query)

            documents = self._retrieve_documents(
                query,
                event_type=query_info.get("event_type"),
                semester=query_info.get("semester"),
            )

            answer = self._generate_answer(query, documents)

            return {
                "answer": answer,
                "documents": documents,
                "query_info": query_info,
            }

        except Exception as e:
            # Top-level boundary: keep the traceback in the log, but never
            # let a failure escape to the caller.
            logger.exception("Query processing failed: %s", e)
            return {
                "answer": "ขออภัย ไม่สามารถประมวลผลคำถามได้ในขณะนี้",
                "documents": [],
                "query_info": {},
            }

    @staticmethod
    def _clean_filter_value(value: Any) -> Optional[str]:
        """Turn an LLM-provided field into a usable filter value or ``None``.

        Non-strings, blank strings and the literal words ``null``/``none``
        (which the model may emit as text instead of JSON null) mean
        "no filter".
        """
        if not isinstance(value, str):
            return None
        cleaned = value.strip()
        if not cleaned or cleaned.lower() in ("null", "none"):
            return None
        return cleaned

    @staticmethod
    def _parse_analysis(reply: str) -> Dict[str, Any]:
        """Extract and normalise the JSON object contained in an LLM reply.

        The reply may wrap the JSON in a Markdown code fence or surround it
        with prose, so only the span from the first ``{`` to the last ``}``
        is parsed.

        Raises:
            ValueError: if no JSON object can be found or parsed
                (``json.JSONDecodeError`` is a ``ValueError`` subclass).
        """
        start = reply.find("{")
        end = reply.rfind("}")
        if start == -1 or end < start:
            raise ValueError("No JSON object found in query analysis reply")

        analysis = json.loads(reply[start:end + 1])

        for key in ("event_type", "semester"):
            analysis[key] = CalendarRAG._clean_filter_value(analysis.get(key))
        key_terms = analysis.get("key_terms")
        analysis["key_terms"] = key_terms if isinstance(key_terms, list) else []
        return analysis

    @staticmethod
    def _build_filters(event_type: Optional[str],
                       semester: Optional[str]) -> Optional[Dict[str, Any]]:
        """Build a Haystack metadata filter for the given constraints.

        Returns ``None`` when neither constraint is set, otherwise an ``AND``
        of equality conditions on ``meta.event_type`` / ``meta.semester``.
        """
        conditions = [
            {"field": f"meta.{name}", "operator": "==", "value": value}
            for name, value in (("event_type", event_type), ("semester", semester))
            if value
        ]
        if not conditions:
            return None
        return {"operator": "AND", "conditions": conditions}

    def _analyze_query(self, query: str) -> Dict[str, Any]:
        """Ask the LLM to classify the question.

        Returns:
            A dict with ``event_type``, ``semester``, ``key_terms`` and
            ``original_query``. If the LLM call or parsing fails, the error
            is logged and a dict with no constraints is returned.
        """
        try:
            normalized_query = self.text_preprocessor.normalize_thai_text(query)

            prompt_result = self.query_analyzer.run(query=normalized_query)
            response = self.generator.run(prompt=prompt_result["prompt"])

            if not response or not response.get("replies"):
                raise ValueError("Empty response from query analyzer")

            analysis = self._parse_analysis(response["replies"][0])
            analysis["original_query"] = query

            return analysis

        except Exception as e:
            # Best effort: analysis only narrows retrieval, so fall back to
            # an unconstrained search rather than failing the whole query.
            logger.error("Query analysis failed: %s", e)
            return {
                "original_query": query,
                "event_type": None,
                "semester": None,
                "key_terms": [],
            }

    def _retrieve_documents(self,
                            query: str,
                            event_type: Optional[str] = None,
                            semester: Optional[str] = None) -> List["Document"]:
        """Return up to ``top_k`` stored documents most similar to ``query``.

        Args:
            query: The user's question; embedded with the same model as the
                stored documents.
            event_type: If set, only documents whose ``event_type`` metadata
                equals this value are considered.
            semester: If set, only documents whose ``semester`` metadata
                equals this value are considered.
        """
        top_k = self.config.retriever.top_k
        retriever = InMemoryEmbeddingRetriever(
            document_store=self.document_store,
            top_k=top_k,
        )

        # The document embedder is reused for the query by wrapping the text
        # in a throwaway Document.
        query_doc = Document(content=query)
        embedded_query = self.embedder.run(documents=[query_doc])["documents"][0]

        # Filtering inside the retriever (instead of after the top-k cut)
        # means matching documents are not crowded out by non-matching ones.
        return retriever.run(
            query_embedding=embedded_query.embedding,
            filters=self._build_filters(event_type, semester),
            top_k=top_k,
        )["documents"]

    def _generate_answer(self, query: str, documents: List["Document"]) -> str:
        """Ask the LLM to answer ``query`` from ``documents``.

        Returns:
            The first reply of the generator, or a Thai apology message if
            generation fails (the error is logged).
        """
        try:
            prompt_result = self.answer_generator.run(
                query=query,
                documents=documents,
            )

            response = self.generator.run(prompt=prompt_result["prompt"])

            if not response or not response.get("replies"):
                raise ValueError("Empty response from answer generator")

            return response["replies"][0]

        except Exception as e:
            logger.exception("Answer generation failed: %s", e)
            return "ขออภัย ไม่สามารถสร้างคำตอบได้ในขณะนี้"
308
+
309
def create_default_config(api_key: str) -> PipelineConfig:
    """Build a ``PipelineConfig`` that uses default settings everywhere.

    Args:
        api_key: OpenAI API key placed into the model configuration.

    Returns:
        A pipeline configuration combining a default ``ModelConfig`` for
        ``api_key`` with default retriever and localization settings.
    """
    sub_configs = {
        "model": ModelConfig(openai_api_key=api_key),
        "retriever": RetrieverConfig(),
        "localization": LocalizationConfig(),
    }
    return PipelineConfig(**sub_configs)