File size: 2,419 Bytes
f7c2fa3
5485d7c
026aeba
e384879
026aeba
 
9bde774
 
e234b58
bd69eee
 
 
026aeba
 
bd69eee
 
5184c29
 
 
5485d7c
 
 
5184c29
 
bd69eee
5184c29
5485d7c
5184c29
 
 
 
5485d7c
5184c29
 
 
 
 
 
 
 
 
026aeba
5184c29
bd69eee
79dcf63
e384879
2889c96
9bde774
 
2889c96
bd69eee
f7c2fa3
5485d7c
 
5184c29
e384879
e234b58
2889c96
5485d7c
e234b58
e384879
 
026aeba
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import logging
from config import AppConfig, ConfigConstants
from data.load_dataset import load_data
from generator.compute_rmse_auc_roc_metrics import compute_rmse_auc_roc_metrics
from retriever.chunk_documents import chunk_documents
from retriever.embed_documents import embed_documents
from generator.initialize_llm import initialize_generation_llm
from generator.initialize_llm import initialize_validation_llm
from app import launch_gradio

# Root logger setup: emit INFO and above, prefixing each record with its
# timestamp and level name.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

def main():
    """Run the RAG pipeline from dataset loading through to the Gradio app.

    In order: load every dataset named in ``ConfigConstants.DATA_SET_NAMES``,
    split each one into chunks, embed the combined chunks into one vector
    store, initialize the generation and validation LLMs, then wrap all three
    in an ``AppConfig`` and pass it to ``launch_gradio``.
    """
    logging.info("Starting the RAG pipeline")

    # Datasets that need a non-default chunk size; anything not listed here
    # falls back to ConfigConstants.DEFAULT_CHUNK_SIZE.
    chunk_size_overrides = {'cuad': 4000}

    all_chunked_documents = []  # chunks from every dataset, in load order
    datasets = {}  # loaded datasets keyed by dataset name

    for data_set_name in ConfigConstants.DATA_SET_NAMES:
        logging.info(f"Loading dataset: {data_set_name}")
        loaded = load_data(data_set_name)
        datasets[data_set_name] = loaded

        size = chunk_size_overrides.get(
            data_set_name, ConfigConstants.DEFAULT_CHUNK_SIZE
        )

        # Chunk this dataset and append its chunks to the combined list.
        all_chunked_documents.extend(
            chunk_documents(
                loaded,
                chunk_size=size,
                chunk_overlap=ConfigConstants.CHUNK_OVERLAP,
            )
        )

    # Debug aid (disabled): report the row count of each loaded dataset.
    # for name, dataset in datasets.items():
    #     logging.info(f"Loaded {name} with {dataset.num_rows} rows")

    logging.info(f"Total chunked documents: {len(all_chunked_documents)}")

    # One vector store is built over the chunks of all datasets together.
    store = embed_documents(all_chunked_documents)
    logging.info("Documents embedded")

    # Two separate models: one generates, the other validates.
    generator_llm = initialize_generation_llm(ConfigConstants.GENERATION_MODEL_NAME)
    validator_llm = initialize_validation_llm(ConfigConstants.VALIDATION_MODEL_NAME)

    # Optional evaluation (disabled): RMSE and AUC-ROC over an entire dataset.
    # Uncomment to run it before the app is launched:
    # compute_rmse_auc_roc_metrics(generator_llm, validator_llm, datasets['covidqa'], store, 10)

    # Bundle everything the UI needs and start the Gradio app.
    app_config = AppConfig(
        vector_store=store,
        gen_llm=generator_llm,
        val_llm=validator_llm,
    )
    launch_gradio(app_config)

    logging.info("Finished!!!")

# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()