JirasakJo commited on
Commit
798cabe
·
verified ·
1 Parent(s): d42b14e

Update calendar_rag.py

Browse files
Files changed (1) hide show
  1. calendar_rag.py +113 -17
calendar_rag.py CHANGED
@@ -5,6 +5,7 @@ from haystack.components.embedders import SentenceTransformersDocumentEmbedder
5
  from haystack.components.retrievers.in_memory import *
6
  from haystack.document_stores.in_memory import InMemoryDocumentStore
7
  from haystack.utils import Secret
 
8
  from pathlib import Path
9
  import hashlib
10
  from datetime import *
@@ -118,7 +119,17 @@ class TuitionFee:
118
  event_type: str
119
  regular_fee: RegularFee
120
  late_payment_fee: LatePaymentFee
 
 
 
121
 
 
 
 
 
 
 
 
122
 
123
  class OpenAIDateParser:
124
  """Uses OpenAI to parse complex Thai date formats"""
@@ -414,17 +425,21 @@ class CacheManager:
414
  self.document_cache[doc_id] = (document, datetime.now())
415
  self._save_cache("documents", self.document_cache)
416
 
 
417
  @dataclass
418
  class ModelConfig:
419
  openai_api_key: str
420
- # embedder_model: str = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
421
- embedder_model: str = "sentence-transformers/mUSE"
422
  openai_model: str = "gpt-4o"
423
  temperature: float = 0.7
 
424
 
425
  @dataclass
426
  class RetrieverConfig:
427
  top_k: int = 5
 
 
 
428
 
429
  @dataclass
430
  class CacheConfig:
@@ -450,7 +465,8 @@ class PipelineConfig:
450
 
451
  def create_default_config(api_key: str) -> PipelineConfig:
452
  """
453
- Create a default pipeline configuration with optimized settings for Thai language processing.
 
454
 
455
  Args:
456
  api_key (str): OpenAI API key
@@ -461,10 +477,14 @@ def create_default_config(api_key: str) -> PipelineConfig:
461
  return PipelineConfig(
462
  model=ModelConfig(
463
  openai_api_key=api_key,
 
464
  temperature=0.3 # Lower temperature for more focused responses
465
  ),
466
  retriever=RetrieverConfig(
467
- top_k=5 # Optimal number of documents to retrieve
 
 
 
468
  ),
469
  cache=CacheConfig(
470
  enabled=True,
@@ -1278,6 +1298,68 @@ class HybridDocumentStore:
1278
  )
1279
 
1280
  return sorted_docs[:top_k]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1281
 
1282
  class ResponseGenerator:
1283
  """Generate responses with enhanced conversation context awareness"""
@@ -1541,7 +1623,7 @@ class AdvancedQueryProcessor:
1541
  # First, let's modify the AcademicCalendarRAG class to maintain conversation history
1542
 
1543
  class AcademicCalendarRAG:
1544
- """Enhanced RAG system for academic calendar and program information with conversation memory"""
1545
 
1546
  def __init__(self, config: PipelineConfig):
1547
  self.config = config
@@ -1599,7 +1681,7 @@ class AcademicCalendarRAG:
1599
  raise
1600
 
1601
  def process_query(self, query: str, conversation_history=None) -> Dict[str, Any]:
1602
- """Process user query using conversation history and hybrid retrieval."""
1603
  # Use provided conversation history or the internal history
1604
  if conversation_history is not None:
1605
  self.conversation_history = conversation_history
@@ -1635,16 +1717,30 @@ class AcademicCalendarRAG:
1635
 
1636
  weight_semantic = weight_values[attempt - 1]
1637
 
1638
- # Get relevant documents using hybrid search
1639
  logger.info(f"Attempt {attempt}: Searching with weight_semantic={weight_semantic}")
1640
- documents = self.document_store.hybrid_search(
1641
- query=query_with_context if attempt == 1 else query,
1642
- event_type=query_info.get("event_type"),
1643
- detail_type=query_info.get("detail_type"),
1644
- semester=query_info.get("semester"),
1645
- top_k=self.config.retriever.top_k,
1646
- weight_semantic=weight_semantic
1647
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1648
 
1649
  # Generate response with conversation context
1650
  response = self.response_generator.generate_response(
@@ -1695,8 +1791,8 @@ class AcademicCalendarRAG:
1695
  # pipeline.load_data(raw_data)
1696
 
1697
  # # Test queries with different semantic weights
1698
- # # queries = ["ค่าเทอมเท่าไหร่","เปิดเรียนวันไหน","ขั้นตอนการสมัครที่สาขานี้มีอะไรบ้าง","ต้องใช้ระดับภาษาอังกฤษเท่าไหร่ในการสมัครเรียนที่นี้","ถ้าจะไปติดต่อมาหลายต้องลง mrt อะไร","มีวิชาหลักเเละวิชาเลือกออะไรบ้าง", "ปีที่ 1 เทอม 1 ต้องเรียนอะไรบ้าง", "ปีที่ 2 เทอม 1 ต้องเรียนอะไรบ้าง"]
1699
- # queries = ["ปีที่ 1 เทอม 1 ต้องเรียนอะไรบ้าง"]
1700
  # print("=" * 80)
1701
 
1702
  # for query in queries:
 
5
  from haystack.components.retrievers.in_memory import *
6
  from haystack.document_stores.in_memory import InMemoryDocumentStore
7
  from haystack.utils import Secret
8
+ from sentence_transformers import CrossEncoder
9
  from pathlib import Path
10
  import hashlib
11
  from datetime import *
 
119
  event_type: str
120
  regular_fee: RegularFee
121
  late_payment_fee: LatePaymentFee
122
+
123
class SentenceTransformersCrossEncoder:
    """Wrapper for the sentence-transformers CrossEncoder for compatibility with the existing code."""

    def __init__(self, model_name_or_path: str):
        """Load the cross-encoder model by name or local path."""
        self.model = CrossEncoder(model_name_or_path)

    def predict(self, sentence_pairs: List[Tuple[str, str]]) -> List[float]:
        """Predict relevance scores for (query, document) pairs; higher is more relevant.

        CrossEncoder.predict returns a numpy array, while this method is
        declared to return List[float]; convert explicitly so callers get
        plain Python floats (sortable, JSON/cache-serializable) as promised.
        """
        return [float(score) for score in self.model.predict(sentence_pairs)]
133
 
134
  class OpenAIDateParser:
135
  """Uses OpenAI to parse complex Thai date formats"""
 
425
  self.document_cache[doc_id] = (document, datetime.now())
426
  self._save_cache("documents", self.document_cache)
427
 
428
@dataclass
class ModelConfig:
    """Model configuration: OpenAI chat model, sentence-transformers embedder,
    and the cross-encoder used for reranking.

    Note: the committed version applied @dataclass twice; a single decorator
    is sufficient (the duplicate re-processed the already-generated class).
    """
    openai_api_key: str
    # Multilingual embedder — chosen for Thai-language support.
    embedder_model: str = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
    openai_model: str = "gpt-4o"
    temperature: float = 0.7
    # Multilingual MS MARCO cross-encoder used by search_with_reranking.
    reranker_model: str = "cross-encoder/mmarco-mMiniLMv2-L12-H384-v1"
436
 
437
@dataclass
class RetrieverConfig:
    """Retrieval settings for the two-stage (hybrid search + rerank) pipeline.

    top_k is used by plain hybrid search; when use_reranking is enabled,
    top_k_initial candidates are fetched and cut down to top_k_final.
    """
    top_k: int = 5
    use_reranking: bool = True
    top_k_initial: int = 20
    top_k_final: int = 5
443
 
444
  @dataclass
445
  class CacheConfig:
 
465
 
466
  def create_default_config(api_key: str) -> PipelineConfig:
467
  """
468
+ Create a default pipeline configuration with optimized settings for Thai language processing,
469
+ including reranking capabilities.
470
 
471
  Args:
472
  api_key (str): OpenAI API key
 
477
  return PipelineConfig(
478
  model=ModelConfig(
479
  openai_api_key=api_key,
480
+ embedder_model="sentence-transformers/paraphrase-multilingual-mpnet-base-v2",
481
  temperature=0.3 # Lower temperature for more focused responses
482
  ),
483
  retriever=RetrieverConfig(
484
+ top_k=5, # Optimal number of documents to retrieve
485
+ use_reranking=True, # Enable reranking
486
+ top_k_initial=20, # Retrieve more initial documents for reranking
487
+ top_k_final=5 # Final number of documents after reranking
488
  ),
489
  cache=CacheConfig(
490
  enabled=True,
 
1298
  )
1299
 
1300
  return sorted_docs[:top_k]
1301
+
1302
+ def search_with_reranking(self,
1303
+ query: str,
1304
+ event_type: Optional[str] = None,
1305
+ detail_type: Optional[str] = None,
1306
+ semester: Optional[str] = None,
1307
+ top_k_initial: int = 20,
1308
+ top_k_final: int = 5,
1309
+ weight_semantic: float = 0.5) -> List[Document]:
1310
+ """
1311
+ Two-stage retrieval with hybrid search followed by cross-encoder reranking
1312
+ """
1313
+ # Generate cache key for the reranked query
1314
+ cache_key = json.dumps({
1315
+ 'query': query,
1316
+ 'event_type': event_type,
1317
+ 'semester': semester,
1318
+ 'top_k_initial': top_k_initial,
1319
+ 'top_k_final': top_k_final,
1320
+ 'weight_semantic': weight_semantic,
1321
+ 'reranked': True # Indicate this is a reranked query
1322
+ })
1323
+
1324
+ # Check cache first
1325
+ cached_results = self.cache_manager.get_query_cache(cache_key)
1326
+ if cached_results is not None:
1327
+ return cached_results
1328
+
1329
+ # 1. Get larger initial result set
1330
+ initial_results = self.hybrid_search(
1331
+ query=query,
1332
+ event_type=event_type,
1333
+ detail_type=detail_type,
1334
+ semester=semester,
1335
+ top_k=top_k_initial,
1336
+ weight_semantic=weight_semantic
1337
+ )
1338
+
1339
+ # If we don't have enough initial results, just return what we have
1340
+ if len(initial_results) <= top_k_final:
1341
+ return initial_results
1342
+
1343
+ try:
1344
+ # We'll lazily initialize the cross encoder to save memory
1345
+ cross_encoder = SentenceTransformersCrossEncoder("cross-encoder/mmarco-mMiniLMv2-L12-H384-v1")
1346
+ pairs = [(query, doc.content) for doc in initial_results]
1347
+ scores = cross_encoder.predict(pairs)
1348
+
1349
+ for doc, score in zip(initial_results, scores):
1350
+ doc.score = float(score) # Ensure score is a regular float
1351
+
1352
+ reranked_results = sorted(initial_results, key=lambda x: x.score, reverse=True)[:top_k_final]
1353
+
1354
+ # Cache the results
1355
+ self.cache_manager.set_query_cache(cache_key, reranked_results)
1356
+
1357
+ return reranked_results
1358
+
1359
+ except Exception as e:
1360
+ logger.error(f"Reranking failed: {str(e)}. Falling back to hybrid search results.")
1361
+
1362
+ return initial_results[:top_k_final]
1363
 
1364
  class ResponseGenerator:
1365
  """Generate responses with enhanced conversation context awareness"""
 
1623
  # First, let's modify the AcademicCalendarRAG class to maintain conversation history
1624
 
1625
  class AcademicCalendarRAG:
1626
+ """Enhanced RAG system for academic calendar and program information with conversation memory and reranking"""
1627
 
1628
  def __init__(self, config: PipelineConfig):
1629
  self.config = config
 
1681
  raise
1682
 
1683
  def process_query(self, query: str, conversation_history=None) -> Dict[str, Any]:
1684
+ """Process user query using conversation history and hybrid retrieval with reranking."""
1685
  # Use provided conversation history or the internal history
1686
  if conversation_history is not None:
1687
  self.conversation_history = conversation_history
 
1717
 
1718
  weight_semantic = weight_values[attempt - 1]
1719
 
1720
+ # Get relevant documents using reranking if enabled
1721
  logger.info(f"Attempt {attempt}: Searching with weight_semantic={weight_semantic}")
1722
+
1723
+ if self.config.retriever.use_reranking:
1724
+ documents = self.document_store.search_with_reranking(
1725
+ query=query_with_context if attempt == 1 else query,
1726
+ event_type=query_info.get("event_type"),
1727
+ detail_type=query_info.get("detail_type"),
1728
+ semester=query_info.get("semester"),
1729
+ top_k_initial=self.config.retriever.top_k_initial,
1730
+ top_k_final=self.config.retriever.top_k_final,
1731
+ weight_semantic=weight_semantic
1732
+ )
1733
+ logger.info(f"Using reranking for retrieval, got {len(documents)} documents")
1734
+ else:
1735
+ documents = self.document_store.hybrid_search(
1736
+ query=query_with_context if attempt == 1 else query,
1737
+ event_type=query_info.get("event_type"),
1738
+ detail_type=query_info.get("detail_type"),
1739
+ semester=query_info.get("semester"),
1740
+ top_k=self.config.retriever.top_k,
1741
+ weight_semantic=weight_semantic
1742
+ )
1743
+ logger.info(f"Using standard hybrid search, got {len(documents)} documents")
1744
 
1745
  # Generate response with conversation context
1746
  response = self.response_generator.generate_response(
 
1791
  # pipeline.load_data(raw_data)
1792
 
1793
  # # Test queries with different semantic weights
1794
+ # queries = ["ค่าเทอมเท่าไหร่","เปิดเรียนวันไหน","ขั้นตอนการสมัครที่สาขานี้มีอะไรบ้าง","ต้องใช้ระดับภาษาอังกฤษเท่าไหร่ในการสมัครเรียนที่นี้","ถ้าจะไปติดต่อมาหลายต้องลง mrt อะไร","มีวิชาหลักเเละวิชาเลือกออะไรบ้าง", "ปีที่ 1 เทอม 1 ต้องเรียนอะไรบ้าง", "ปีที่ 2 เทอม 1 ต้องเรียนอะไรบ้าง"]
1795
+ # # queries = ["ต้องใช้ระดับภาษาอังกฤษเท่าไหร่ในการสมัครเรียนที่นี้"]
1796
  # print("=" * 80)
1797
 
1798
  # for query in queries: