Gourisankar Padihary committed
Commit b58a992 · 1 Parent(s): 11b4c9f

Fix for other datasets

data/load_dataset.py CHANGED
@@ -1,9 +1,9 @@
  import logging
  from datasets import load_dataset

- def load_data():
+ def load_data(data_set_name):
      logging.info("Loading dataset")
-     dataset = load_dataset("rungalileo/ragbench", 'covidqa', split="test")
+     dataset = load_dataset("rungalileo/ragbench", data_set_name, split="test")
      logging.info("Dataset loaded successfully")
      logging.info(f"Number of documents found: {dataset.num_rows}")
      return dataset
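
The loader now takes the ragbench configuration name as an argument. A minimal usage sketch, assuming the config names referenced elsewhere in this commit ('covidqa', 'techqa', 'cuad') are all valid ragbench configurations and that the module imports as data.load_dataset:

# Usage sketch (assumed import path; config names taken from this commit):
from data.load_dataset import load_data

for name in ['covidqa', 'techqa', 'cuad']:
    dataset = load_data(name)        # loads the "test" split of the chosen config
    print(name, dataset.num_rows)    # dataset is a datasets.Dataset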
generator/compute_metrics.py CHANGED
@@ -32,6 +32,7 @@ def compute_metrics(attributes, total_sentences):

  def get_metrics(attributes, total_sentences):
      if attributes.content:
+         #print(attributes)
          result_content = attributes.content # Access the content attribute
          # Extract the JSON part from the result_content
          json_start = result_content.find("{")
@@ -40,8 +41,6 @@ def get_metrics(attributes, total_sentences):

          try:
              result_json = json.loads(json_str)
-             #print(json.dumps(result_json, indent=2))
-
              # Compute metrics using the extracted attributes
              metrics = compute_metrics(result_json, total_sentences)
              print(metrics)
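
For context, get_metrics pulls a JSON object out of the raw LLM response before scoring it. A standalone sketch of that extraction step; only find("{") appears in the hunk above, so the closing-brace lookup and the example payload and keys are assumptions:

# Sketch of the JSON-extraction step (payload and key names are illustrative):
import json

result_content = 'Scores: {"relevance_score": 0.9, "utilization_score": 0.5} done'
json_start = result_content.find("{")
json_end = result_content.rfind("}") + 1          # assumed counterpart to json_start
result_json = json.loads(result_content[json_start:json_end])
print(result_json["relevance_score"])             # 0.9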
generator/document_utils.py CHANGED
@@ -7,7 +7,7 @@ class Document:

  def apply_sentence_keys_documents(relevant_docs: List[Document]):
      result = []
-     for i, doc in enumerate(relevant_docs):
+     '''for i, doc in enumerate(relevant_docs):
          doc_id = str(i)
          title_passage = doc.page_content.split('\nPassage: ')
          title = title_passage[0]
@@ -19,7 +19,13 @@ def apply_sentence_keys_documents(relevant_docs: List[Document]):
          for j, passage in enumerate(passages):
              doc_result.append([f"{doc_id}{chr(98 + j)}", passage])

-         result.append(doc_result)
+         result.append(doc_result)'''
+
+     for relevant_doc_index, relevant_doc in enumerate(relevant_docs):
+         sentences = []
+         for sentence_index, sentence in enumerate(relevant_doc.page_content.split(".")):
+             sentences.append([str(relevant_doc_index)+chr(97 + sentence_index), sentence])
+         result.append(sentences)

      return result

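The rewritten apply_sentence_keys_documents keys each period-delimited sentence of every retrieved document instead of the old title/passage pairs. A rough sketch of what it produces; the Document stub below is an assumption (only page_content is used):

# Illustrative input/output for the new sentence-keying logic (Document stub assumed):
class Document:
    def __init__(self, page_content):
        self.page_content = page_content

docs = [Document("First sentence. Second sentence."),
        Document("Another doc. With two sentences.")]

result = []
for relevant_doc_index, relevant_doc in enumerate(docs):
    sentences = []
    for sentence_index, sentence in enumerate(relevant_doc.page_content.split(".")):
        sentences.append([str(relevant_doc_index) + chr(97 + sentence_index), sentence])
    result.append(sentences)

print(result)
# [[['0a', 'First sentence'], ['0b', ' Second sentence'], ['0c', '']],
#  [['1a', 'Another doc'], ['1b', ' With two sentences'], ['1c', '']]]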
generator/initialize_llm.py CHANGED
@@ -1,12 +1,17 @@
+ import logging
  import os
  from langchain_groq import ChatGroq

  def initialize_llm():
      os.environ["GROQ_API_KEY"] = "your_groq_api_key"
-     llm = ChatGroq(model="llama3-8b-8192", temperature=0.7)
+     model_name = "llama3-8b-8192"
+     llm = ChatGroq(model=model_name, temperature=0.7)
+     logging.info(f'Generation LLM {model_name} initialized')
      return llm

  def initialize_validation_llm():
      os.environ["GROQ_API_KEY"] = "your_groq_api_key"
-     llm = ChatGroq(model="llama3-70b-8192", temperature=0.7)
+     model_name = "llama-3.1-8b-instant"
+     llm = ChatGroq(model=model_name, temperature=0.7)
+     logging.info(f'Validation LLM {model_name} initialized')
      return llm
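
A short usage sketch of the updated initializers, assuming GROQ_API_KEY is set to a real key and that the module imports as generator.initialize_llm:

# Usage sketch (assumed import path; requires a valid GROQ_API_KEY):
import logging
from generator.initialize_llm import initialize_llm, initialize_validation_llm

logging.basicConfig(level=logging.INFO)
llm = initialize_llm()                        # logs: Generation LLM llama3-8b-8192 initialized
validation_llm = initialize_validation_llm()  # logs: Validation LLM llama-3.1-8b-instant initialized
print(llm.invoke("Reply with one word.").content)   # ChatGroq returns an AIMessage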
main.py CHANGED
@@ -11,13 +11,17 @@ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(

  def main():
      logging.info("Starting the RAG pipeline")
+     data_set_name = 'techqa'

      # Load the dataset
-     dataset = load_data()
+     dataset = load_data(data_set_name)
      logging.info("Dataset loaded")

      # Chunk the dataset
-     documents = chunk_documents(dataset)
+     chunk_size = 1000 # default value
+     if data_set_name == 'cuad':
+         chunk_size = 3000
+     documents = chunk_documents(dataset, chunk_size)
      logging.info("Documents chunked")

      # Embed the documents
@@ -26,17 +30,16 @@ def main():

      # Initialize the Generation LLM
      llm = initialize_llm()
-     logging.info("LLM initialized")

      # Sample question
-     row_num = 43
+     row_num = 10
      sample_question = dataset[row_num]['question']

      # Call generate_metrics for above sample question
      generate_metrics(llm, vector_store, sample_question)

      #Compute RMSE and AUC-ROC for entire dataset
-     compute_rmse_auc_roc_metrics(llm, dataset, vector_store, dataset.num_rows)
+     #compute_rmse_auc_roc_metrics(llm, dataset, vector_store, 10)

      logging.info("Finished!!!")

retriever/chunk_documents.py CHANGED
@@ -1,7 +1,7 @@
- from langchain.text_splitter import CharacterTextSplitter
+ from langchain.text_splitter import RecursiveCharacterTextSplitter

  def chunk_documents(dataset, chunk_size=1000, chunk_overlap=200):
-     text_splitter = CharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
+     text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
      documents = []
      for data in dataset:
          text_list = data['documents']
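
The splitter swap matters because RecursiveCharacterTextSplitter falls back through its default separators ("\n\n", "\n", " ", "") until chunks fit, whereas CharacterTextSplitter splits only on its single separator and can emit chunks larger than chunk_size. A quick sketch; the sample text is illustrative:

# Quick check of the new splitter's behavior (sample text is illustrative):
from langchain.text_splitter import RecursiveCharacterTextSplitter

text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
chunks = text_splitter.split_text("A long passage with no newlines " * 100)
print(len(chunks), max(len(c) for c in chunks))   # every chunk stays <= 1000 characters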