Gourisankar Padihary committed
Commit: b58a992
Parent(s): 11b4c9f

Fix for other datasets

Files changed:
- data/load_dataset.py +2 -2
- generator/compute_metrics.py +1 -2
- generator/document_utils.py +8 -2
- generator/initialize_llm.py +7 -2
- main.py +8 -5
- retriever/chunk_documents.py +2 -2
data/load_dataset.py
CHANGED
@@ -1,9 +1,9 @@
 import logging
 from datasets import load_dataset
 
-def load_data():
+def load_data(data_set_name):
     logging.info("Loading dataset")
-    dataset = load_dataset("rungalileo/ragbench",
+    dataset = load_dataset("rungalileo/ragbench", data_set_name, split="test")
     logging.info("Dataset loaded successfully")
     logging.info(f"Number of documents found: {dataset.num_rows}")
     return dataset
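A minimal usage sketch of the new load_data signature (the 'techqa' config name and the 'question' field come from the main.py change below; the import path is an assumption based on the file layout):

# Hedged sketch: load one ragbench config by name, as main.py now does.
from data.load_dataset import load_data  # assumed module path (data/load_dataset.py)

dataset = load_data('techqa')        # test split of the named ragbench config
print(dataset.num_rows)              # row count, as logged inside load_data
print(dataset[0]['question'])        # rows expose a 'question' field (used in main.py)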
generator/compute_metrics.py
CHANGED
@@ -32,6 +32,7 @@ def compute_metrics(attributes, total_sentences):
 
 def get_metrics(attributes, total_sentences):
     if attributes.content:
+        #print(attributes)
         result_content = attributes.content  # Access the content attribute
         # Extract the JSON part from the result_content
         json_start = result_content.find("{")
@@ -40,8 +41,6 @@ def get_metrics(attributes, total_sentences):
 
         try:
             result_json = json.loads(json_str)
-            #print(json.dumps(result_json, indent=2))
-
             # Compute metrics using the extracted attributes
             metrics = compute_metrics(result_json, total_sentences)
             print(metrics)
generator/document_utils.py
CHANGED
@@ -7,7 +7,7 @@ class Document:
 
 def apply_sentence_keys_documents(relevant_docs: List[Document]):
     result = []
-    for i, doc in enumerate(relevant_docs):
+    '''for i, doc in enumerate(relevant_docs):
         doc_id = str(i)
         title_passage = doc.page_content.split('\nPassage: ')
         title = title_passage[0]
@@ -19,7 +19,13 @@ def apply_sentence_keys_documents(relevant_docs: List[Document]):
         for j, passage in enumerate(passages):
             doc_result.append([f"{doc_id}{chr(98 + j)}", passage])
 
-    result.append(doc_result)
+    result.append(doc_result)'''
+
+    for relevant_doc_index, relevant_doc in enumerate(relevant_docs):
+        sentences = []
+        for sentence_index, sentence in enumerate(relevant_doc.page_content.split(".")):
+            sentences.append([str(relevant_doc_index)+chr(97 + sentence_index), sentence])
+        result.append(sentences)
 
     return result
 
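For reference, the rewritten apply_sentence_keys_documents keys each document's sentences as 0a, 0b, 0c, ... by splitting page_content on periods. A self-contained sketch (with a hypothetical stand-in for the Document class) of what it produces:

class Doc:
    # stand-in for Document, for illustration only
    def __init__(self, page_content):
        self.page_content = page_content

relevant_docs = [Doc("First sentence. Second sentence."), Doc("Another doc.")]

result = []
for i, doc in enumerate(relevant_docs):
    sentences = []
    for j, sentence in enumerate(doc.page_content.split(".")):
        sentences.append([str(i) + chr(97 + j), sentence])
    result.append(sentences)

print(result)
# [[['0a', 'First sentence'], ['0b', ' Second sentence'], ['0c', '']],
#  [['1a', 'Another doc'], ['1b', '']]]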
generator/initialize_llm.py
CHANGED
@@ -1,12 +1,17 @@
+import logging
 import os
 from langchain_groq import ChatGroq
 
 def initialize_llm():
     os.environ["GROQ_API_KEY"] = "your_groq_api_key"
-
+    model_name = "llama3-8b-8192"
+    llm = ChatGroq(model=model_name, temperature=0.7)
+    logging.info(f'Generation LLM {model_name} initialized')
     return llm
 
 def initialize_validation_llm():
     os.environ["GROQ_API_KEY"] = "your_groq_api_key"
-
+    model_name = "llama-3.1-8b-instant"
+    llm = ChatGroq(model=model_name, temperature=0.7)
+    logging.info(f'Validation LLM {model_name} initialized')
     return llm
main.py
CHANGED
@@ -11,13 +11,17 @@ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(
 
 def main():
     logging.info("Starting the RAG pipeline")
+    data_set_name = 'techqa'
 
     # Load the dataset
-    dataset = load_data()
+    dataset = load_data(data_set_name)
     logging.info("Dataset loaded")
 
     # Chunk the dataset
-
+    chunk_size = 1000 # default value
+    if data_set_name == 'cuad':
+        chunk_size = 3000
+    documents = chunk_documents(dataset, chunk_size)
     logging.info("Documents chunked")
 
     # Embed the documents
@@ -26,17 +30,16 @@ def main():
 
     # Initialize the Generation LLM
     llm = initialize_llm()
-    logging.info("LLM initialized")
 
     # Sample question
-    row_num =
+    row_num = 10
     sample_question = dataset[row_num]['question']
 
     # Call generate_metrics for above sample question
    generate_metrics(llm, vector_store, sample_question)
 
     #Compute RMSE and AUC-ROC for entire dataset
-    compute_rmse_auc_roc_metrics(llm, dataset, vector_store,
+    #compute_rmse_auc_roc_metrics(llm, dataset, vector_store, 10)
 
     logging.info("Finished!!!")
 
retriever/chunk_documents.py
CHANGED
@@ -1,7 +1,7 @@
-from langchain.text_splitter import
+from langchain.text_splitter import RecursiveCharacterTextSplitter
 
 def chunk_documents(dataset, chunk_size=1000, chunk_overlap=200):
-    text_splitter =
+    text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
     documents = []
     for data in dataset:
         text_list = data['documents']
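A hedged sketch of the splitter that chunk_documents now builds, using the defaults from its signature (main.py overrides chunk_size to 3000 for the 'cuad' dataset); the sample text is an illustration only:

from langchain.text_splitter import RecursiveCharacterTextSplitter

# Same configuration chunk_documents creates with its default arguments.
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
chunks = text_splitter.split_text("some long document text " * 200)
print(len(chunks), max(len(c) for c in chunks))  # chunks of roughly <= 1000 characters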