sabazo commited on
Commit
d6f22d3
·
unverified ·
2 Parent(s): 4544cde 83d98e9

Merge pull request #49 from almutareb/26-add-a-function-for-query-rewriting

Browse files
config.py CHANGED
@@ -8,15 +8,18 @@ load_dotenv()
8
  SQLITE_FILE_NAME = os.getenv('SOURCES_CACHE')
9
  PERSIST_DIRECTORY = os.getenv('VECTOR_DATABASE_LOCATION')
10
  EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL")
11
- SECONDARY_LLM_MODEL = os.getenv("SECONDARY_LLM_MODEL")
12
 
13
 
14
  db = DataBaseHandler()
15
 
16
  db.create_all_tables()
17
 
18
- SECONDARY_LLM = HuggingFaceEndpoint(
19
- repo_id=SECONDARY_LLM_MODEL,
 
 
 
20
  temperature=0.1, # Controls randomness in response generation (lower value means less random)
21
  max_new_tokens=1024, # Maximum number of new tokens to generate in responses
22
  repetition_penalty=1.2, # Penalty for repeating the same words (higher value increases penalty)
 
8
  SQLITE_FILE_NAME = os.getenv('SOURCES_CACHE')
9
  PERSIST_DIRECTORY = os.getenv('VECTOR_DATABASE_LOCATION')
10
  EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL")
11
+ SEVEN_B_LLM_MODEL = os.getenv("SEVEN_B_LLM_MODEL")
12
 
13
 
14
  db = DataBaseHandler()
15
 
16
  db.create_all_tables()
17
 
18
+ # This model is used for tasks that do not require a larger model;
19
+ # currently we have been getting MODEL OVERLOADED errors
20
+ # from HuggingFace
21
+ SEVEN_B_LLM_MODEL = HuggingFaceEndpoint(
22
+ repo_id=SEVEN_B_LLM_MODEL,
23
  temperature=0.1, # Controls randomness in response generation (lower value means less random)
24
  max_new_tokens=1024, # Maximum number of new tokens to generate in responses
25
  repetition_penalty=1.2, # Penalty for repeating the same words (higher value increases penalty)
example.env CHANGED
@@ -25,3 +25,5 @@ EMBEDDING_MODEL="sentence-transformers/distiluse-base-multilingual-cased-v2"
25
  #EMBEDDING_MODEL="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
26
  LLM_MODEL="mistralai/Mixtral-8x7B-Instruct-v0.1"
27
  LLM_MODEL_ARGS=
 
 
 
25
  #EMBEDDING_MODEL="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
26
  LLM_MODEL="mistralai/Mixtral-8x7B-Instruct-v0.1"
27
  LLM_MODEL_ARGS=
28
+
29
+ SEVEN_B_LLM_MODEL="mistralai/Mistral-7B-Instruct-v0.3"
rag_app/chains/__init__.py CHANGED
@@ -1,2 +1,3 @@
1
  from rag_app.chains.user_response_sentiment_chain import user_response_sentiment_prompt
2
- from rag_app.chains.generate_document_summary import generate_document_summary_prompt
 
 
1
  from rag_app.chains.user_response_sentiment_chain import user_response_sentiment_prompt
2
+ from rag_app.chains.generate_document_summary import generate_document_summary_prompt
3
+ from rag_app.chains.query_rewritten_chain import query_rewritting_prompt
rag_app/chains/query_rewritten_chain.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain_core.prompts import PromptTemplate
2
+
3
+
4
+ query_rewritting_template = """
5
+ You will be given a query from a user
6
+ =================
7
+ {user_query}
8
+ ====================
9
+
10
+ You must improve the query to optimize the result
11
+
12
+
13
+ """
14
+
15
+ query_rewritting_prompt = PromptTemplate.from_template(query_rewritting_template)
16
+
rag_app/knowledge_base/utils.py CHANGED
@@ -1,6 +1,6 @@
1
  from langchain_core.documents import Document
2
  from chains import generate_document_summary_prompt
3
- from config import SECONDARY_LLM
4
 
5
 
6
  def generate_document_summaries(
@@ -27,7 +27,7 @@ def generate_document_summaries(
27
 
28
  for doc in new_docs:
29
 
30
- genrate_summary_chain = generate_document_summary_prompt | SECONDARY_LLM
31
  summary = genrate_summary_chain.invoke(
32
  {"document":str(doc.metadata)}
33
  )
 
1
  from langchain_core.documents import Document
2
  from chains import generate_document_summary_prompt
3
+ from config import SEVEN_B_LLM_MODEL
4
 
5
 
6
  def generate_document_summaries(
 
27
 
28
  for doc in new_docs:
29
 
30
+ genrate_summary_chain = generate_document_summary_prompt | SEVEN_B_LLM_MODEL
31
  summary = genrate_summary_chain.invoke(
32
  {"document":str(doc.metadata)}
33
  )
requirements.txt CHANGED
@@ -1,6 +1,6 @@
1
  langchain
2
  langchain-community
3
- langchain-HuggingFace
4
  langchain-text-splitters
5
  langchain_google_community
6
  beautifulsoup4
@@ -15,5 +15,4 @@ gradio
15
  boto3
16
  rich
17
  sqlmodel
18
- python-dotenv
19
- langchain_huggingface
 
1
  langchain
2
  langchain-community
3
+ langchain-huggingface
4
  langchain-text-splitters
5
  langchain_google_community
6
  beautifulsoup4
 
15
  boto3
16
  rich
17
  sqlmodel
18
+ python-dotenv