Gourisankar Padihary committed
Commit 0ea6d19 · 1 Parent(s): 9bde774

Changes for techqa data set

generator/generate_metrics.py CHANGED
@@ -22,7 +22,7 @@ def generate_metrics(gen_llm, val_llm, vector_store, query):
     logging.info(f"Response from LLM: {response}")
 
     # Add a sleep interval to avoid hitting the rate limit
-    time.sleep(20)  # Adjust the sleep time as needed
+    time.sleep(25)  # Adjust the sleep time as needed
 
     # Step 3: Extract attributes and total sentences for each query
     logging.info(f"Extracting attributes through validation LLM")
generator/initialize_llm.py CHANGED
@@ -2,7 +2,7 @@ import logging
 import os
 from langchain_groq import ChatGroq
 
-def initialize_llm():
+def initialize_generation_llm():
     os.environ["GROQ_API_KEY"] = "your_groq_api_key"
     model_name = "llama3-8b-8192"
     llm = ChatGroq(model=model_name, temperature=0.7)
@@ -11,7 +11,7 @@ def initialize_llm():
 
 def initialize_validation_llm():
     os.environ["GROQ_API_KEY"] = "your_groq_api_key"
-    model_name = "llama-3.1-8b-instant"
+    model_name = "llama3-70b-8192"
     llm = ChatGroq(model=model_name, temperature=0.7)
     logging.info(f'Validation LLM {model_name} initialized')
     return llm
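Two changes here: the generation initializer is renamed from initialize_llm to initialize_generation_llm, and the validation model moves from llama-3.1-8b-instant to the larger llama3-70b-8192, so a stronger model now judges what the 8B model generates. Assuming the module layout above, callers wire the pair up like this (sketch only):

from generator.initialize_llm import (
    initialize_generation_llm,
    initialize_validation_llm,
)

gen_llm = initialize_generation_llm()  # llama3-8b-8192, produces the answers
val_llm = initialize_validation_llm()  # llama3-70b-8192, extracts/validates attributes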
main.py CHANGED
@@ -12,7 +12,7 @@ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(
 
 def main():
     logging.info("Starting the RAG pipeline")
-    data_set_name = 'covidqa'
+    data_set_name = 'techqa'
 
     # Load the dataset
     dataset = load_data(data_set_name)
@@ -36,7 +36,7 @@ def main():
     val_llm = initialize_validation_llm()
 
     # Sample question
-    row_num = 10
+    row_num = 7
     query = dataset[row_num]['question']
 
     # Call generate_metrics for above sample question
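The run configuration switches from covidqa row 10 to techqa row 7, both hardcoded in main(). Since every dataset switch currently means editing main.py, a small argparse front end would make the same change a command-line flag; this is a suggestion, not part of the commit:

import argparse

parser = argparse.ArgumentParser(description="Run the RAG pipeline on one sample question")
parser.add_argument("--dataset", default="techqa", help="dataset name, e.g. covidqa or techqa")
parser.add_argument("--row", type=int, default=7, help="row index of the sample question")
args = parser.parse_args()

data_set_name = args.dataset
row_num = args.row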
retriever/embed_documents.py CHANGED
@@ -2,6 +2,6 @@ from langchain_huggingface import HuggingFaceEmbeddings
 from langchain_community.vectorstores import FAISS
 
 def embed_documents(documents):
-    embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
+    embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/paraphrase-MiniLM-L3-v2")
     vector_store = FAISS.from_texts([doc['text'] for doc in documents], embedding_model)
     return vector_store
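The embedding model swaps from all-MiniLM-L6-v2 (6 transformer layers) to paraphrase-MiniLM-L3-v2 (3 layers); both produce 384-dimensional vectors, so the FAISS index shape is unchanged, but the L3 model embeds roughly twice as fast at some cost in retrieval quality. A quick smoke test of the function above (the two document texts are made-up examples in the spirit of TechQA's IBM support content):

from retriever.embed_documents import embed_documents

docs = [
    {"text": "Restart the WebSphere application server from the admin console."},
    {"text": "Schedule nightly DB2 database backups with the backup utility."},
]
store = embed_documents(docs)
hits = store.similarity_search("how do I restart WebSphere?", k=1)
print(hits[0].page_content)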