# NOTE(review): the three lines that preceded the imports (a "File size" note,
# a row of git commit hashes, and a run of line numbers "1 2 3 ... 51") were
# extraction/scrape artifacts, not Python — commented out so the module parses.
import logging
from data.load_dataset import load_data
from generator.compute_rmse_auc_roc_metrics import compute_rmse_auc_roc_metrics
from retriever.chunk_documents import chunk_documents
from retriever.embed_documents import embed_documents
from generator.generate_metrics import generate_metrics
from generator.initialize_llm import initialize_generation_llm
from generator.initialize_llm import initialize_validation_llm
# Configure root logging once at import time: timestamped INFO-level messages
# (applies to every logging.info() call in this module).
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
def main(data_set_name: str = 'covidqa', row_num: int = 10) -> None:
    """Run the end-to-end RAG evaluation pipeline.

    Loads a dataset, chunks and embeds its documents into a vector store,
    initializes the generation and validation LLMs, runs metrics for one
    sample question, then computes RMSE / AUC-ROC over the whole dataset.

    Args:
        data_set_name: Dataset identifier passed to ``load_data``.
            Defaults to 'covidqa' (the original hard-coded value).
        row_num: Index of the sample row whose 'question' field is used
            for the single-question smoke test. Defaults to 10.
    """
    logging.info("Starting the RAG pipeline")

    # Load the dataset
    dataset = load_data(data_set_name)
    logging.info("Dataset loaded")

    # Chunk the dataset. 'cuad' needs larger chunks; every other dataset
    # uses the 1000-character default (same behavior as before, expressed
    # as a single conditional instead of an if-reassignment).
    chunk_size = 3000 if data_set_name == 'cuad' else 1000
    documents = chunk_documents(dataset, chunk_size)
    logging.info("Documents chunked")

    # Embed the chunked documents into a vector store for retrieval.
    vector_store = embed_documents(documents)
    logging.info("Documents embedded")

    # Separate LLMs: one generates answers, one validates them.
    gen_llm = initialize_generation_llm()
    val_llm = initialize_validation_llm()

    # Smoke-test the pipeline on a single sample question.
    # Assumes dataset rows are mappings with a 'question' key — TODO confirm
    # against load_data's return type.
    query = dataset[row_num]['question']
    generate_metrics(gen_llm, val_llm, vector_store, query)

    # Compute RMSE and AUC-ROC for the entire dataset (the trailing 10 is
    # presumably a batch/limit parameter; kept as-is — verify against
    # compute_rmse_auc_roc_metrics' signature).
    compute_rmse_auc_roc_metrics(gen_llm, val_llm, dataset, vector_store, 10)
    logging.info("Finished!!!")
if __name__ == "__main__":
main() |