Update calendar_rag.py
calendar_rag.py CHANGED (+113 -17)
@@ -5,6 +5,7 @@ from haystack.components.embedders import SentenceTransformersDocumentEmbedder
 from haystack.components.retrievers.in_memory import *
 from haystack.document_stores.in_memory import InMemoryDocumentStore
 from haystack.utils import Secret
+from sentence_transformers import CrossEncoder
 from pathlib import Path
 import hashlib
 from datetime import *
@@ -118,7 +119,17 @@ class TuitionFee:
     event_type: str
     regular_fee: RegularFee
     late_payment_fee: LatePaymentFee
+
+class SentenceTransformersCrossEncoder:
+    """Wrapper for the sentence-transformers CrossEncoder for compatibility with the existing code"""
 
+    def __init__(self, model_name_or_path: str):
+        """Initialize the cross-encoder model"""
+        self.model = CrossEncoder(model_name_or_path)
+
+    def predict(self, sentence_pairs: List[Tuple[str, str]]) -> List[float]:
+        """Predict relevance scores for sentence pairs"""
+        return self.model.predict(sentence_pairs)
 
 class OpenAIDateParser:
     """Uses OpenAI to parse complex Thai date formats"""
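For orientation, a minimal sketch of the new wrapper used in isolation (the SentenceTransformersCrossEncoder class from the hunk above is assumed to be in scope, the example passages are made up, and the model weights download on first instantiation):

from typing import List, Tuple
from sentence_transformers import CrossEncoder

reranker = SentenceTransformersCrossEncoder("cross-encoder/mmarco-mMiniLMv2-L12-H384-v1")

pairs: List[Tuple[str, str]] = [
    ("เปิดเรียนวันไหน", "Semester 1 classes begin on 13 August."),     # on-topic (query: "when does the semester start?")
    ("เปิดเรียนวันไหน", "The late payment fee is 500 baht per day."),  # off-topic
]
scores = reranker.predict(pairs)  # one relevance score per (query, passage) pair
# Higher score means more relevant, so the on-topic passage should rank first.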
@@ -414,17 +425,21 @@ class CacheManager:
         self.document_cache[doc_id] = (document, datetime.now())
         self._save_cache("documents", self.document_cache)
 
 @dataclass
 class ModelConfig:
     openai_api_key: str
-
-    embedder_model: str = "sentence-transformers/mUSE"
+    embedder_model: str = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
     openai_model: str = "gpt-4o"
     temperature: float = 0.7
+    reranker_model: str = "cross-encoder/mmarco-mMiniLMv2-L12-H384-v1"  # cross-encoder used for second-stage reranking
 
 @dataclass
 class RetrieverConfig:
     top_k: int = 5
+    use_reranking: bool = True  # enable two-stage retrieval with cross-encoder reranking
+    top_k_initial: int = 20  # candidates fetched by hybrid search before reranking
+    top_k_final: int = 5  # documents kept after reranking
 
 @dataclass
 class CacheConfig:
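How the new retriever knobs relate, shown as a small sketch (construction only; the values are the dataclass defaults above):

retriever = RetrieverConfig()  # top_k=5, use_reranking=True, top_k_initial=20, top_k_final=5
# The cross-encoder only reorders and narrows: 20 candidates in, 5 documents out.
assert retriever.top_k_initial >= retriever.top_k_final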
@@ -450,7 +465,8 @@ class PipelineConfig:
 
 def create_default_config(api_key: str) -> PipelineConfig:
     """
-    Create a default pipeline configuration with optimized settings for Thai language processing
+    Create a default pipeline configuration with optimized settings for Thai language processing,
+    including reranking capabilities.
 
     Args:
         api_key (str): OpenAI API key
@@ -461,10 +477,14 @@ def create_default_config(api_key: str) -> PipelineConfig:
     return PipelineConfig(
         model=ModelConfig(
             openai_api_key=api_key,
+            embedder_model="sentence-transformers/paraphrase-multilingual-mpnet-base-v2",
             temperature=0.3  # Lower temperature for more focused responses
         ),
         retriever=RetrieverConfig(
-            top_k=5
+            top_k=5,  # number of documents handed to the generator
+            use_reranking=True,  # enable reranking
+            top_k_initial=20,  # retrieve more initial documents for reranking
+            top_k_final=5  # final number of documents after reranking
         ),
         cache=CacheConfig(
             enabled=True,
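Call-site shape for the updated factory, as a hedged sketch (sourcing the key from an environment variable is an assumption, not something the diff shows):

import os

config = create_default_config(api_key=os.environ["OPENAI_API_KEY"])
# With these defaults, every query retrieves 20 candidates and reranks them down to 5.
pipeline = AcademicCalendarRAG(config)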
@@ -1278,6 +1298,68 @@ class HybridDocumentStore:
         )
 
         return sorted_docs[:top_k]
+
+    def search_with_reranking(self,
+                              query: str,
+                              event_type: Optional[str] = None,
+                              detail_type: Optional[str] = None,
+                              semester: Optional[str] = None,
+                              top_k_initial: int = 20,
+                              top_k_final: int = 5,
+                              weight_semantic: float = 0.5) -> List[Document]:
+        """
+        Two-stage retrieval: hybrid search followed by cross-encoder reranking
+        """
+        # Generate cache key for the reranked query
+        cache_key = json.dumps({
+            'query': query,
+            'event_type': event_type,
+            'detail_type': detail_type,
+            'semester': semester,
+            'top_k_initial': top_k_initial,
+            'top_k_final': top_k_final,
+            'weight_semantic': weight_semantic,
+            'reranked': True  # Distinguish reranked entries from plain hybrid-search entries
+        })
+
+        # Check cache first
+        cached_results = self.cache_manager.get_query_cache(cache_key)
+        if cached_results is not None:
+            return cached_results
+
+        # 1. Get a larger initial result set
+        initial_results = self.hybrid_search(
+            query=query,
+            event_type=event_type,
+            detail_type=detail_type,
+            semester=semester,
+            top_k=top_k_initial,
+            weight_semantic=weight_semantic
+        )
+
+        # If we don't have enough initial results, just return what we have
+        if len(initial_results) <= top_k_final:
+            return initial_results
+
+        try:
+            # 2. Rerank with a cross-encoder, constructed on demand
+            # (note: this reloads the model on every call; cache the instance if latency matters)
+            cross_encoder = SentenceTransformersCrossEncoder("cross-encoder/mmarco-mMiniLMv2-L12-H384-v1")
+            pairs = [(query, doc.content) for doc in initial_results]
+            scores = cross_encoder.predict(pairs)
+
+            for doc, score in zip(initial_results, scores):
+                doc.score = float(score)  # Ensure score is a plain float
+
+            reranked_results = sorted(initial_results, key=lambda x: x.score, reverse=True)[:top_k_final]
+
+            # Cache the results
+            self.cache_manager.set_query_cache(cache_key, reranked_results)
+
+            return reranked_results
+
+        except Exception as e:
+            logger.error(f"Reranking failed: {str(e)}. Falling back to hybrid search results.")
+            return initial_results[:top_k_final]
 
 class ResponseGenerator:
     """Generate responses with enhanced conversation context awareness"""
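search_with_reranking is the standard retrieve-then-rerank pattern; distilled to its essentials as a self-contained sketch (stand-in passages, no caching or metadata filters):

from typing import List
from sentence_transformers import CrossEncoder

def rerank(query: str, passages: List[str], top_k: int = 5) -> List[str]:
    """Score (query, passage) pairs with a cross-encoder and keep the best top_k."""
    model = CrossEncoder("cross-encoder/mmarco-mMiniLMv2-L12-H384-v1")
    scores = model.predict([(query, p) for p in passages])
    ranked = sorted(zip(scores, passages), key=lambda x: x[0], reverse=True)
    return [p for _, p in ranked[:top_k]]

Loading the model inside the function mirrors the method above; in practice you would construct it once and reuse it across queries.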
@@ -1541,7 +1623,7 @@ class AdvancedQueryProcessor:
 # First, let's modify the AcademicCalendarRAG class to maintain conversation history
 
 class AcademicCalendarRAG:
-    """Enhanced RAG system for academic calendar and program information with conversation memory"""
+    """Enhanced RAG system for academic calendar and program information with conversation memory and reranking"""
 
     def __init__(self, config: PipelineConfig):
         self.config = config
@@ -1599,7 +1681,7 @@ class AcademicCalendarRAG:
             raise
 
     def process_query(self, query: str, conversation_history=None) -> Dict[str, Any]:
-        """Process user query using conversation history and hybrid retrieval."""
+        """Process user query using conversation history and hybrid retrieval with reranking."""
         # Use provided conversation history or the internal history
         if conversation_history is not None:
             self.conversation_history = conversation_history
@@ -1635,16 +1717,30 @@ class AcademicCalendarRAG:
 
             weight_semantic = weight_values[attempt - 1]
 
-            # Get relevant documents using
+            # Get relevant documents using reranking if enabled
             logger.info(f"Attempt {attempt}: Searching with weight_semantic={weight_semantic}")
-            documents = self.document_store.hybrid_search(
-                query=query_with_context if attempt == 1 else query,
-                event_type=query_info.get("event_type"),
-                detail_type=query_info.get("detail_type"),
-                semester=query_info.get("semester"),
-                top_k=self.config.retriever.top_k,
-                weight_semantic=weight_semantic
-            )
+
+            if self.config.retriever.use_reranking:
+                documents = self.document_store.search_with_reranking(
+                    query=query_with_context if attempt == 1 else query,
+                    event_type=query_info.get("event_type"),
+                    detail_type=query_info.get("detail_type"),
+                    semester=query_info.get("semester"),
+                    top_k_initial=self.config.retriever.top_k_initial,
+                    top_k_final=self.config.retriever.top_k_final,
+                    weight_semantic=weight_semantic
+                )
+                logger.info(f"Using reranking for retrieval, got {len(documents)} documents")
+            else:
+                documents = self.document_store.hybrid_search(
+                    query=query_with_context if attempt == 1 else query,
+                    event_type=query_info.get("event_type"),
+                    detail_type=query_info.get("detail_type"),
+                    semester=query_info.get("semester"),
+                    top_k=self.config.retriever.top_k,
+                    weight_semantic=weight_semantic
+                )
+                logger.info(f"Using standard hybrid search, got {len(documents)} documents")
 
             # Generate response with conversation context
             response = self.response_generator.generate_response(
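End to end, the flag only changes the retrieval stage; a hypothetical interaction (placeholder key, data loading elided, Thai query glossed in the comment):

rag = AcademicCalendarRAG(create_default_config(api_key="sk-..."))
# ... load calendar documents as elsewhere in the file ...
result = rag.process_query("ค่าเทอมเท่าไหร่")  # "How much is tuition?"
# result is a Dict[str, Any] per the signature above, built from the top 5 reranked documents.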
@@ -1695,8 +1791,8 @@ class AcademicCalendarRAG:
 # pipeline.load_data(raw_data)
 
 # # Test queries with different semantic weights
-#
-# queries = ["
+# queries = ["ค่าเทอมเท่าไหร่","เปิดเรียนวันไหน","ขั้นตอนการสมัครที่สาขานี้มีอะไรบ้าง","ต้องใช้ระดับภาษาอังกฤษเท่าไหร่ในการสมัครเรียนที่นี้","ถ้าจะไปติดต่อมาหลายต้องลง mrt อะไร","มีวิชาหลักเเละวิชาเลือกออะไรบ้าง", "ปีที่ 1 เทอม 1 ต้องเรียนอะไรบ้าง", "ปีที่ 2 เทอม 1 ต้องเรียนอะไรบ้าง"]
+# # queries = ["ต้องใช้ระดับภาษาอังกฤษเท่าไหร่ในการสมัครเรียนที่นี้"]
 # print("=" * 80)
 
 # for query in queries:

(In English, the commented-out test queries ask: how much is tuition; when does the semester start; what are the application steps for this program; what English level is required to apply; which MRT station to get off at when visiting the university; what the core and elective courses are; and what is studied in Year 1 Semester 1 and Year 2 Semester 1.)