Update calendar_rag.py
Browse files- calendar_rag.py +23 -58
calendar_rag.py
CHANGED
@@ -5,7 +5,6 @@ from haystack.components.embedders import SentenceTransformersDocumentEmbedder
|
|
5 |
from haystack.components.retrievers.in_memory import *
|
6 |
from haystack.document_stores.in_memory import InMemoryDocumentStore
|
7 |
from haystack.utils import Secret
|
8 |
-
from sentence_transformers import CrossEncoder
|
9 |
from pathlib import Path
|
10 |
import hashlib
|
11 |
from datetime import *
|
@@ -119,17 +118,7 @@ class TuitionFee:
|
|
119 |
event_type: str
|
120 |
regular_fee: RegularFee
|
121 |
late_payment_fee: LatePaymentFee
|
122 |
-
|
123 |
-
class SentenceTransformersCrossEncoder:
|
124 |
-
"""Wrapper for the sentence-transformers CrossEncoder for compatibility with the existing code"""
|
125 |
|
126 |
-
def __init__(self, model_name_or_path: str):
|
127 |
-
"""Initialize the cross-encoder model"""
|
128 |
-
self.model = CrossEncoder(model_name_or_path)
|
129 |
-
|
130 |
-
def predict(self, sentence_pairs: List[Tuple[str, str]]) -> List[float]:
|
131 |
-
"""Predict relevance scores for sentence pairs"""
|
132 |
-
return self.model.predict(sentence_pairs)
|
133 |
|
134 |
class OpenAIDateParser:
|
135 |
"""Uses OpenAI to parse complex Thai date formats"""
|
@@ -425,21 +414,16 @@ class CacheManager:
|
|
425 |
self.document_cache[doc_id] = (document, datetime.now())
|
426 |
self._save_cache("documents", self.document_cache)
|
427 |
|
428 |
-
@dataclass
|
429 |
@dataclass
|
430 |
class ModelConfig:
|
431 |
openai_api_key: str
|
432 |
embedder_model: str = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
|
433 |
openai_model: str = "gpt-4o"
|
434 |
temperature: float = 0.7
|
435 |
-
reranker_model: str = "cross-encoder/mmarco-mMiniLMv2-L12-H384-v1" # Add this
|
436 |
|
437 |
@dataclass
|
438 |
class RetrieverConfig:
|
439 |
top_k: int = 5
|
440 |
-
use_reranking: bool = True # Add this flag
|
441 |
-
top_k_initial: int = 20 # Add this parameter
|
442 |
-
top_k_final: int = 5 # Add this parameter
|
443 |
|
444 |
@dataclass
|
445 |
class CacheConfig:
|
@@ -465,8 +449,7 @@ class PipelineConfig:
|
|
465 |
|
466 |
def create_default_config(api_key: str) -> PipelineConfig:
|
467 |
"""
|
468 |
-
Create a default pipeline configuration with optimized settings for Thai language processing
|
469 |
-
including reranking capabilities.
|
470 |
|
471 |
Args:
|
472 |
api_key (str): OpenAI API key
|
@@ -477,14 +460,10 @@ def create_default_config(api_key: str) -> PipelineConfig:
|
|
477 |
return PipelineConfig(
|
478 |
model=ModelConfig(
|
479 |
openai_api_key=api_key,
|
480 |
-
embedder_model="sentence-transformers/paraphrase-multilingual-mpnet-base-v2",
|
481 |
temperature=0.3 # Lower temperature for more focused responses
|
482 |
),
|
483 |
retriever=RetrieverConfig(
|
484 |
-
top_k=5
|
485 |
-
use_reranking=True, # Enable reranking
|
486 |
-
top_k_initial=20, # Retrieve more initial documents for reranking
|
487 |
-
top_k_final=5 # Final number of documents after reranking
|
488 |
),
|
489 |
cache=CacheConfig(
|
490 |
enabled=True,
|
@@ -1300,13 +1279,13 @@ class HybridDocumentStore:
|
|
1300 |
return sorted_docs[:top_k]
|
1301 |
|
1302 |
def search_with_reranking(self,
|
1303 |
-
|
1304 |
-
|
1305 |
-
|
1306 |
-
|
1307 |
-
|
1308 |
-
|
1309 |
-
|
1310 |
"""
|
1311 |
Two-stage retrieval with hybrid search followed by cross-encoder reranking
|
1312 |
"""
|
@@ -1402,10 +1381,10 @@ class ResponseGenerator:
|
|
1402 |
5. สำหรับคำถามเกี่ยวกับข้อกำหนดภาษาอังกฤษหรือขั้นตอนการสมัคร ให้อธิบายข้อมูลอย่างละเอียด
|
1403 |
6. ใส่ข้อความ "หากมีข้อสงสัยเพิ่มเติม สามารถสอบถามได้" ท้ายคำตอบเสมอ
|
1404 |
7. คำนึงถึงประวัติการสนทนาและให้คำตอบที่ต่อเนื่องกับบทสนทนาก่อนหน้า
|
1405 |
-
8. หากคำถามอ้างอิงถึงข้อมูลในบทสนทนาก่อนหน้า (เช่น "แล้วอันนั้นล่ะ", "มีอะไรอีกบ้าง", "คำถามก่อนหน้า")
|
1406 |
9. กรณีคำถามมีความไม่ชัดเจน ใช้ประวัติการสนทนาเพื่อเข้าใจบริบทของคำถาม
|
1407 |
|
1408 |
-
|
1409 |
|
1410 |
กรุณาตอบเป็นภาษาไทย:
|
1411 |
"""
|
@@ -1623,7 +1602,7 @@ class AdvancedQueryProcessor:
|
|
1623 |
# First, let's modify the AcademicCalendarRAG class to maintain conversation history
|
1624 |
|
1625 |
class AcademicCalendarRAG:
|
1626 |
-
"""Enhanced RAG system for academic calendar and program information with conversation memory
|
1627 |
|
1628 |
def __init__(self, config: PipelineConfig):
|
1629 |
self.config = config
|
@@ -1681,7 +1660,7 @@ class AcademicCalendarRAG:
|
|
1681 |
raise
|
1682 |
|
1683 |
def process_query(self, query: str, conversation_history=None) -> Dict[str, Any]:
|
1684 |
-
"""Process user query using conversation history and hybrid retrieval
|
1685 |
# Use provided conversation history or the internal history
|
1686 |
if conversation_history is not None:
|
1687 |
self.conversation_history = conversation_history
|
@@ -1717,30 +1696,16 @@ class AcademicCalendarRAG:
|
|
1717 |
|
1718 |
weight_semantic = weight_values[attempt - 1]
|
1719 |
|
1720 |
-
# Get relevant documents using
|
1721 |
logger.info(f"Attempt {attempt}: Searching with weight_semantic={weight_semantic}")
|
1722 |
-
|
1723 |
-
|
1724 |
-
|
1725 |
-
|
1726 |
-
|
1727 |
-
|
1728 |
-
|
1729 |
-
|
1730 |
-
top_k_final=self.config.retriever.top_k_final,
|
1731 |
-
weight_semantic=weight_semantic
|
1732 |
-
)
|
1733 |
-
logger.info(f"Using reranking for retrieval, got {len(documents)} documents")
|
1734 |
-
else:
|
1735 |
-
documents = self.document_store.hybrid_search(
|
1736 |
-
query=query_with_context if attempt == 1 else query,
|
1737 |
-
event_type=query_info.get("event_type"),
|
1738 |
-
detail_type=query_info.get("detail_type"),
|
1739 |
-
semester=query_info.get("semester"),
|
1740 |
-
top_k=self.config.retriever.top_k,
|
1741 |
-
weight_semantic=weight_semantic
|
1742 |
-
)
|
1743 |
-
logger.info(f"Using standard hybrid search, got {len(documents)} documents")
|
1744 |
|
1745 |
# Generate response with conversation context
|
1746 |
response = self.response_generator.generate_response(
|
@@ -1792,7 +1757,7 @@ class AcademicCalendarRAG:
|
|
1792 |
|
1793 |
# # Test queries with different semantic weights
|
1794 |
# queries = ["ค่าเทอมเท่าไหร่","เปิดเรียนวันไหน","ขั้นตอนการสมัครที่สาขานี้มีอะไรบ้าง","ต้องใช้ระดับภาษาอังกฤษเท่าไหร่ในการสมัครเรียนที่นี้","ถ้าจะไปติดต่อมาหลายต้องลง mrt อะไร","มีวิชาหลักเเละวิชาเลือกออะไรบ้าง", "ปีที่ 1 เทอม 1 ต้องเรียนอะไรบ้าง", "ปีที่ 2 เทอม 1 ต้องเรียนอะไรบ้าง"]
|
1795 |
-
# # queries = ["
|
1796 |
# print("=" * 80)
|
1797 |
|
1798 |
# for query in queries:
|
|
|
5 |
from haystack.components.retrievers.in_memory import *
|
6 |
from haystack.document_stores.in_memory import InMemoryDocumentStore
|
7 |
from haystack.utils import Secret
|
|
|
8 |
from pathlib import Path
|
9 |
import hashlib
|
10 |
from datetime import *
|
|
|
118 |
event_type: str
|
119 |
regular_fee: RegularFee
|
120 |
late_payment_fee: LatePaymentFee
|
|
|
|
|
|
|
121 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
122 |
|
123 |
class OpenAIDateParser:
|
124 |
"""Uses OpenAI to parse complex Thai date formats"""
|
|
|
414 |
self.document_cache[doc_id] = (document, datetime.now())
|
415 |
self._save_cache("documents", self.document_cache)
|
416 |
|
|
|
417 |
@dataclass
|
418 |
class ModelConfig:
|
419 |
openai_api_key: str
|
420 |
embedder_model: str = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
|
421 |
openai_model: str = "gpt-4o"
|
422 |
temperature: float = 0.7
|
|
|
423 |
|
424 |
@dataclass
|
425 |
class RetrieverConfig:
|
426 |
top_k: int = 5
|
|
|
|
|
|
|
427 |
|
428 |
@dataclass
|
429 |
class CacheConfig:
|
|
|
449 |
|
450 |
def create_default_config(api_key: str) -> PipelineConfig:
|
451 |
"""
|
452 |
+
Create a default pipeline configuration with optimized settings for Thai language processing.
|
|
|
453 |
|
454 |
Args:
|
455 |
api_key (str): OpenAI API key
|
|
|
460 |
return PipelineConfig(
|
461 |
model=ModelConfig(
|
462 |
openai_api_key=api_key,
|
|
|
463 |
temperature=0.3 # Lower temperature for more focused responses
|
464 |
),
|
465 |
retriever=RetrieverConfig(
|
466 |
+
top_k=5 # Optimal number of documents to retrieve
|
|
|
|
|
|
|
467 |
),
|
468 |
cache=CacheConfig(
|
469 |
enabled=True,
|
|
|
1279 |
return sorted_docs[:top_k]
|
1280 |
|
1281 |
def search_with_reranking(self,
|
1282 |
+
query: str,
|
1283 |
+
event_type: Optional[str] = None,
|
1284 |
+
detail_type: Optional[str] = None,
|
1285 |
+
semester: Optional[str] = None,
|
1286 |
+
top_k_initial: int = 20,
|
1287 |
+
top_k_final: int = 5,
|
1288 |
+
weight_semantic: float = 0.5) -> List[Document]:
|
1289 |
"""
|
1290 |
Two-stage retrieval with hybrid search followed by cross-encoder reranking
|
1291 |
"""
|
|
|
1381 |
5. สำหรับคำถามเกี่ยวกับข้อกำหนดภาษาอังกฤษหรือขั้นตอนการสมัคร ให้อธิบายข้อมูลอย่างละเอียด
|
1382 |
6. ใส่ข้อความ "หากมีข้อสงสัยเพิ่มเติม สามารถสอบถามได้" ท้ายคำตอบเสมอ
|
1383 |
7. คำนึงถึงประวัติการสนทนาและให้คำตอบที่ต่อเนื่องกับบทสนทนาก่อนหน้า
|
1384 |
+
8. หากคำถามอ้างอิงถึงข้อมูลในบทสนทนาก่อนหน้า (เช่น "แล้วอันนั้นล่ะ", "มีอะไรอีกบ้าง", "คำถามก่อนหน้า") ให้พิจารณาบริบทและตอบคำถามอย่างตรงประเด็น แต่ไม่ต้องแสดงคำถามก่อนหน้าในคำตอบ
|
1385 |
9. กรณีคำถามมีความไม่ชัดเจน ใช้ประวัติการสนทนาเพื่อเข้าใจบริบทของคำถาม
|
1386 |
|
1387 |
+
สำคัญ: ไม่ต้องใส่คำว่า "คำถามก่อนหน้าคือ [คำถามก่อนหน้า] และคำตอบคือ..." ในคำตอบของคุณ ให้ตอบคำถามโดยตรง
|
1388 |
|
1389 |
กรุณาตอบเป็นภาษาไทย:
|
1390 |
"""
|
|
|
1602 |
# First, let's modify the AcademicCalendarRAG class to maintain conversation history
|
1603 |
|
1604 |
class AcademicCalendarRAG:
|
1605 |
+
"""Enhanced RAG system for academic calendar and program information with conversation memory"""
|
1606 |
|
1607 |
def __init__(self, config: PipelineConfig):
|
1608 |
self.config = config
|
|
|
1660 |
raise
|
1661 |
|
1662 |
def process_query(self, query: str, conversation_history=None) -> Dict[str, Any]:
|
1663 |
+
"""Process user query using conversation history and hybrid retrieval."""
|
1664 |
# Use provided conversation history or the internal history
|
1665 |
if conversation_history is not None:
|
1666 |
self.conversation_history = conversation_history
|
|
|
1696 |
|
1697 |
weight_semantic = weight_values[attempt - 1]
|
1698 |
|
1699 |
+
# Get relevant documents using hybrid search
|
1700 |
logger.info(f"Attempt {attempt}: Searching with weight_semantic={weight_semantic}")
|
1701 |
+
documents = self.document_store.hybrid_search(
|
1702 |
+
query=query_with_context if attempt == 1 else query,
|
1703 |
+
event_type=query_info.get("event_type"),
|
1704 |
+
detail_type=query_info.get("detail_type"),
|
1705 |
+
semester=query_info.get("semester"),
|
1706 |
+
top_k=self.config.retriever.top_k,
|
1707 |
+
weight_semantic=weight_semantic
|
1708 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1709 |
|
1710 |
# Generate response with conversation context
|
1711 |
response = self.response_generator.generate_response(
|
|
|
1757 |
|
1758 |
# # Test queries with different semantic weights
|
1759 |
# queries = ["ค่าเทอมเท่าไหร่","เปิดเรียนวันไหน","ขั้นตอนการสมัครที่สาขานี้มีอะไรบ้าง","ต้องใช้ระดับภาษาอังกฤษเท่าไหร่ในการสมัครเรียนที่นี้","ถ้าจะไปติดต่อมาหลายต้องลง mrt อะไร","มีวิชาหลักเเละวิชาเลือกออะไรบ้าง", "ปีที่ 1 เทอม 1 ต้องเรียนอะไรบ้าง", "ปีที่ 2 เทอม 1 ต้องเรียนอะไรบ้าง"]
|
1760 |
+
# # queries = ["ปีที่ 1 เทอม 1 ต้องเรียนอะไรบ้าง"]
|
1761 |
# print("=" * 80)
|
1762 |
|
1763 |
# for query in queries:
|