JirasakJo commited on
Commit
b16b6fc
·
verified ·
1 Parent(s): 798cabe

Update calendar_rag.py

Browse files
Files changed (1) hide show
  1. calendar_rag.py +23 -58
calendar_rag.py CHANGED
@@ -5,7 +5,6 @@ from haystack.components.embedders import SentenceTransformersDocumentEmbedder
5
  from haystack.components.retrievers.in_memory import *
6
  from haystack.document_stores.in_memory import InMemoryDocumentStore
7
  from haystack.utils import Secret
8
- from sentence_transformers import CrossEncoder
9
  from pathlib import Path
10
  import hashlib
11
  from datetime import *
@@ -119,17 +118,7 @@ class TuitionFee:
119
  event_type: str
120
  regular_fee: RegularFee
121
  late_payment_fee: LatePaymentFee
122
-
123
- class SentenceTransformersCrossEncoder:
124
- """Wrapper for the sentence-transformers CrossEncoder for compatibility with the existing code"""
125
 
126
- def __init__(self, model_name_or_path: str):
127
- """Initialize the cross-encoder model"""
128
- self.model = CrossEncoder(model_name_or_path)
129
-
130
- def predict(self, sentence_pairs: List[Tuple[str, str]]) -> List[float]:
131
- """Predict relevance scores for sentence pairs"""
132
- return self.model.predict(sentence_pairs)
133
 
134
  class OpenAIDateParser:
135
  """Uses OpenAI to parse complex Thai date formats"""
@@ -425,21 +414,16 @@ class CacheManager:
425
  self.document_cache[doc_id] = (document, datetime.now())
426
  self._save_cache("documents", self.document_cache)
427
 
428
- @dataclass
429
  @dataclass
430
  class ModelConfig:
431
  openai_api_key: str
432
  embedder_model: str = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
433
  openai_model: str = "gpt-4o"
434
  temperature: float = 0.7
435
- reranker_model: str = "cross-encoder/mmarco-mMiniLMv2-L12-H384-v1" # Add this
436
 
437
  @dataclass
438
  class RetrieverConfig:
439
  top_k: int = 5
440
- use_reranking: bool = True # Add this flag
441
- top_k_initial: int = 20 # Add this parameter
442
- top_k_final: int = 5 # Add this parameter
443
 
444
  @dataclass
445
  class CacheConfig:
@@ -465,8 +449,7 @@ class PipelineConfig:
465
 
466
  def create_default_config(api_key: str) -> PipelineConfig:
467
  """
468
- Create a default pipeline configuration with optimized settings for Thai language processing,
469
- including reranking capabilities.
470
 
471
  Args:
472
  api_key (str): OpenAI API key
@@ -477,14 +460,10 @@ def create_default_config(api_key: str) -> PipelineConfig:
477
  return PipelineConfig(
478
  model=ModelConfig(
479
  openai_api_key=api_key,
480
- embedder_model="sentence-transformers/paraphrase-multilingual-mpnet-base-v2",
481
  temperature=0.3 # Lower temperature for more focused responses
482
  ),
483
  retriever=RetrieverConfig(
484
- top_k=5, # Optimal number of documents to retrieve
485
- use_reranking=True, # Enable reranking
486
- top_k_initial=20, # Retrieve more initial documents for reranking
487
- top_k_final=5 # Final number of documents after reranking
488
  ),
489
  cache=CacheConfig(
490
  enabled=True,
@@ -1300,13 +1279,13 @@ class HybridDocumentStore:
1300
  return sorted_docs[:top_k]
1301
 
1302
  def search_with_reranking(self,
1303
- query: str,
1304
- event_type: Optional[str] = None,
1305
- detail_type: Optional[str] = None,
1306
- semester: Optional[str] = None,
1307
- top_k_initial: int = 20,
1308
- top_k_final: int = 5,
1309
- weight_semantic: float = 0.5) -> List[Document]:
1310
  """
1311
  Two-stage retrieval with hybrid search followed by cross-encoder reranking
1312
  """
@@ -1402,10 +1381,10 @@ class ResponseGenerator:
1402
  5. สำหรับคำถามเกี่ยวกับข้อกำหนดภาษาอังกฤษหรือขั้นตอนการสมัคร ให้อธิบายข้อมูลอย่างละเอียด
1403
  6. ใส่ข้อความ "หากมีข้อสงสัยเพิ่มเติม สามารถสอบถามได้" ท้ายคำตอบเสมอ
1404
  7. คำนึงถึงประวัติการสนทนาและให้คำตอบที่ต่อเนื่องกับบทสนทนาก่อนหน้า
1405
- 8. หากคำถามอ้างอิงถึงข้อมูลในบทสนทนาก่อนหน้า (เช่น "แล้วอันนั้นล่ะ", "มีอะไรอีกบ้าง", "คำถามก่อนหน้า") ใ���้พิจารณาบริบทและตอบคำถามอย่างตรงประเด็น
1406
  9. กรณีคำถามมีความไม่ชัดเจน ใช้ประวัติการสนทนาเพื่อเข้าใจบริบทของคำถาม
1407
 
1408
- สิ่งสำคัญพิเศษ: หากคำถามอ้างอิงถึงคำถามก่อนหน้า ให้แสดงคำถามก่อนหน้านั้นในคำตอบด้วย เช่น "คำถามก่อนหน้าคือ [คำถามก่อนหน้า] และคำตอบคือ..."
1409
 
1410
  กรุณาตอบเป็นภาษาไทย:
1411
  """
@@ -1623,7 +1602,7 @@ class AdvancedQueryProcessor:
1623
  # First, let's modify the AcademicCalendarRAG class to maintain conversation history
1624
 
1625
  class AcademicCalendarRAG:
1626
- """Enhanced RAG system for academic calendar and program information with conversation memory and reranking"""
1627
 
1628
  def __init__(self, config: PipelineConfig):
1629
  self.config = config
@@ -1681,7 +1660,7 @@ class AcademicCalendarRAG:
1681
  raise
1682
 
1683
  def process_query(self, query: str, conversation_history=None) -> Dict[str, Any]:
1684
- """Process user query using conversation history and hybrid retrieval with reranking."""
1685
  # Use provided conversation history or the internal history
1686
  if conversation_history is not None:
1687
  self.conversation_history = conversation_history
@@ -1717,30 +1696,16 @@ class AcademicCalendarRAG:
1717
 
1718
  weight_semantic = weight_values[attempt - 1]
1719
 
1720
- # Get relevant documents using reranking if enabled
1721
  logger.info(f"Attempt {attempt}: Searching with weight_semantic={weight_semantic}")
1722
-
1723
- if self.config.retriever.use_reranking:
1724
- documents = self.document_store.search_with_reranking(
1725
- query=query_with_context if attempt == 1 else query,
1726
- event_type=query_info.get("event_type"),
1727
- detail_type=query_info.get("detail_type"),
1728
- semester=query_info.get("semester"),
1729
- top_k_initial=self.config.retriever.top_k_initial,
1730
- top_k_final=self.config.retriever.top_k_final,
1731
- weight_semantic=weight_semantic
1732
- )
1733
- logger.info(f"Using reranking for retrieval, got {len(documents)} documents")
1734
- else:
1735
- documents = self.document_store.hybrid_search(
1736
- query=query_with_context if attempt == 1 else query,
1737
- event_type=query_info.get("event_type"),
1738
- detail_type=query_info.get("detail_type"),
1739
- semester=query_info.get("semester"),
1740
- top_k=self.config.retriever.top_k,
1741
- weight_semantic=weight_semantic
1742
- )
1743
- logger.info(f"Using standard hybrid search, got {len(documents)} documents")
1744
 
1745
  # Generate response with conversation context
1746
  response = self.response_generator.generate_response(
@@ -1792,7 +1757,7 @@ class AcademicCalendarRAG:
1792
 
1793
  # # Test queries with different semantic weights
1794
  # queries = ["ค่าเทอมเท่าไหร่","เปิดเรียนวันไหน","ขั้นตอนการสมัครที่สาขานี้มีอะไรบ้าง","ต้องใช้ระดับภาษาอังกฤษเท่าไหร่ในการสมัครเรียนที่นี้","ถ้าจะไปติดต่อมาหลายต้องลง mrt อะไร","มีวิชาหลักเเละวิชาเลือกออะไรบ้าง", "ปีที่ 1 เทอม 1 ต้องเรียนอะไรบ้าง", "ปีที่ 2 เทอม 1 ต้องเรียนอะไรบ้าง"]
1795
- # # queries = ["ต้องใช้ระดับภาษาอังกฤษเท่าไหร่ในการสมัครเรียนที่นี้"]
1796
  # print("=" * 80)
1797
 
1798
  # for query in queries:
 
5
  from haystack.components.retrievers.in_memory import *
6
  from haystack.document_stores.in_memory import InMemoryDocumentStore
7
  from haystack.utils import Secret
 
8
  from pathlib import Path
9
  import hashlib
10
  from datetime import *
 
118
  event_type: str
119
  regular_fee: RegularFee
120
  late_payment_fee: LatePaymentFee
 
 
 
121
 
 
 
 
 
 
 
 
122
 
123
  class OpenAIDateParser:
124
  """Uses OpenAI to parse complex Thai date formats"""
 
414
  self.document_cache[doc_id] = (document, datetime.now())
415
  self._save_cache("documents", self.document_cache)
416
 
 
417
  @dataclass
418
  class ModelConfig:
419
  openai_api_key: str
420
  embedder_model: str = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
421
  openai_model: str = "gpt-4o"
422
  temperature: float = 0.7
 
423
 
424
  @dataclass
425
  class RetrieverConfig:
426
  top_k: int = 5
 
 
 
427
 
428
  @dataclass
429
  class CacheConfig:
 
449
 
450
  def create_default_config(api_key: str) -> PipelineConfig:
451
  """
452
+ Create a default pipeline configuration with optimized settings for Thai language processing.
 
453
 
454
  Args:
455
  api_key (str): OpenAI API key
 
460
  return PipelineConfig(
461
  model=ModelConfig(
462
  openai_api_key=api_key,
 
463
  temperature=0.3 # Lower temperature for more focused responses
464
  ),
465
  retriever=RetrieverConfig(
466
+ top_k=5 # Optimal number of documents to retrieve
 
 
 
467
  ),
468
  cache=CacheConfig(
469
  enabled=True,
 
1279
  return sorted_docs[:top_k]
1280
 
1281
  def search_with_reranking(self,
1282
+ query: str,
1283
+ event_type: Optional[str] = None,
1284
+ detail_type: Optional[str] = None,
1285
+ semester: Optional[str] = None,
1286
+ top_k_initial: int = 20,
1287
+ top_k_final: int = 5,
1288
+ weight_semantic: float = 0.5) -> List[Document]:
1289
  """
1290
  Two-stage retrieval with hybrid search followed by cross-encoder reranking
1291
  """
 
1381
  5. สำหรับคำถามเกี่ยวกับข้อกำหนดภาษาอังกฤษหรือขั้นตอนการสมัคร ให้อธิบายข้อมูลอย่างละเอียด
1382
  6. ใส่ข้อความ "หากมีข้อสงสัยเพิ่มเติม สามารถสอบถามได้" ท้ายคำตอบเสมอ
1383
  7. คำนึงถึงประวัติการสนทนาและให้คำตอบที่ต่อเนื่องกับบทสนทนาก่อนหน้า
1384
+ 8. หากคำถามอ้างอิงถึงข้อมูลในบทสนทนาก่อนหน้า (เช่น "แล้วอันนั้นล่ะ", "มีอะไรอีกบ้าง", "คำถามก่อนหน้า") ให้พิจารณาบริบทและตอบคำถามอย่างตรงประเด็น แต่ไม่ต้องแสดงคำถามก่อนหน้าในคำตอบ
1385
  9. กรณีคำถามมีความไม่ชัดเจน ใช้ประวัติการสนทนาเพื่อเข้าใจบริบทของคำถาม
1386
 
1387
+ สำคัญ: ไม่ต้องใส่คำว่า "คำถามก่อนหน้าคือ [คำถามก่อนหน้า] และคำตอบคือ..." ในคำตอบของคุณ ให้ตอบคำถามโดยตรง
1388
 
1389
  กรุณาตอบเป็นภาษาไทย:
1390
  """
 
1602
  # First, let's modify the AcademicCalendarRAG class to maintain conversation history
1603
 
1604
  class AcademicCalendarRAG:
1605
+ """Enhanced RAG system for academic calendar and program information with conversation memory"""
1606
 
1607
  def __init__(self, config: PipelineConfig):
1608
  self.config = config
 
1660
  raise
1661
 
1662
  def process_query(self, query: str, conversation_history=None) -> Dict[str, Any]:
1663
+ """Process user query using conversation history and hybrid retrieval."""
1664
  # Use provided conversation history or the internal history
1665
  if conversation_history is not None:
1666
  self.conversation_history = conversation_history
 
1696
 
1697
  weight_semantic = weight_values[attempt - 1]
1698
 
1699
+ # Get relevant documents using hybrid search
1700
  logger.info(f"Attempt {attempt}: Searching with weight_semantic={weight_semantic}")
1701
+ documents = self.document_store.hybrid_search(
1702
+ query=query_with_context if attempt == 1 else query,
1703
+ event_type=query_info.get("event_type"),
1704
+ detail_type=query_info.get("detail_type"),
1705
+ semester=query_info.get("semester"),
1706
+ top_k=self.config.retriever.top_k,
1707
+ weight_semantic=weight_semantic
1708
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1709
 
1710
  # Generate response with conversation context
1711
  response = self.response_generator.generate_response(
 
1757
 
1758
  # # Test queries with different semantic weights
1759
  # queries = ["ค่าเทอมเท่าไหร่","เปิดเรียนวันไหน","ขั้นตอนการสมัครที่สาขานี้มีอะไรบ้าง","ต้องใช้ระดับภาษาอังกฤษเท่าไหร่ในการสมัครเรียนที่นี้","ถ้าจะไปติดต่อมาหลายต้องลง mrt อะไร","มีวิชาหลักเเละวิชาเลือกออะไรบ้าง", "ปีที่ 1 เทอม 1 ต้องเรียนอะไรบ้าง", "ปีที่ 2 เทอม 1 ต้องเรียนอะไรบ้าง"]
1760
+ # # queries = ["ปีที่ 1 เทอม 1 ต้องเรียนอะไรบ้าง"]
1761
  # print("=" * 80)
1762
 
1763
  # for query in queries: