ajalisatgi commited on
Commit
73ab43d
·
verified ·
1 Parent(s): 636e240

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -35
app.py CHANGED
@@ -16,22 +16,31 @@ import os
16
  logging.basicConfig(level=logging.INFO)
17
  logger = logging.getLogger(__name__)
18
 
19
- # Download all required NLTK data upfront
20
  nltk.download('punkt')
21
  nltk.download('punkt_tab')
22
  nltk.download('averaged_perceptron_tagger')
23
  nltk.download('stopwords')
24
 
25
  # Initialize OpenAI API key
26
- openai.api_key = 'sk-proj-REDACTED'  # SECURITY: leaked secret key redacted — a committed key must be rotated immediately; load it from an environment variable instead
27
 
28
- # Load the ragbench datasets
 
29
  ragbench = {}
30
- for dataset in ['covidqa', 'cuad', 'delucionqa', 'emanual', 'expertqa', 'finqa', 'hagrid', 'hotpotqa', 'msmarco', 'pubmedqa', 'tatqa', 'techqa']:
31
- ragbench[dataset] = load_dataset("rungalileo/ragbench", dataset)
32
- logger.info(f"Loaded {dataset}")
33
 
34
- # Initialize with a stronger model
 
 
 
 
 
 
 
 
 
 
35
  model_name = 'sentence-transformers/all-mpnet-base-v2'
36
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
37
  embedding_model = HuggingFaceEmbeddings(model_name=model_name)
@@ -65,41 +74,28 @@ def chunk_documents_semantic(documents, max_chunk_size=500):
65
  chunks.append(current_chunk.strip())
66
  return chunks
67
 
68
- # Process documents in batches
69
- batch_size = 1000
70
  documents = []
71
- total_processed = 0
72
-
73
- for dataset_name in tqdm(ragbench.keys(), desc="Processing datasets"):
74
- for split in ragbench[dataset_name].keys():
75
- original_documents = ragbench[dataset_name][split]['documents']
76
-
77
- for i in range(0, len(original_documents), batch_size):
78
- batch = original_documents[i:i + batch_size]
79
- chunked_documents = chunk_documents_semantic(batch)
80
- documents.extend([Document(page_content=chunk) for chunk in chunked_documents])
81
-
82
- if len(documents) >= batch_size:
83
- vectordb = Chroma.from_documents(
84
- documents=documents,
85
- embedding=embedding_model,
86
- persist_directory=f'./docs/chroma_{total_processed}'
87
- )
88
- vectordb.persist()
89
- total_processed += len(documents)
90
- documents = []
91
 
92
- # Final vector store
93
- final_vectordb = Chroma(
94
- persist_directory='./docs/chroma_final/',
95
- embedding_function=embedding_model
 
96
  )
 
97
 
98
  def process_query(query, dataset_choice):
99
  try:
100
  logger.info(f"Processing query for {dataset_choice}: {query}")
101
 
102
- relevant_docs = final_vectordb.max_marginal_relevance_search(
103
  query,
104
  k=5,
105
  fetch_k=10
@@ -123,7 +119,7 @@ def process_query(query, dataset_choice):
123
  logger.error(f"Error processing query: {str(e)}")
124
  return f"Error: {str(e)}"
125
 
126
- # Create Gradio interface with dataset selection
127
  demo = gr.Interface(
128
  fn=process_query,
129
  inputs=[
 
16
  logging.basicConfig(level=logging.INFO)
17
  logger = logging.getLogger(__name__)
18
 
19
+ # Download NLTK data
20
  nltk.download('punkt')
21
  nltk.download('punkt_tab')
22
  nltk.download('averaged_perceptron_tagger')
23
  nltk.download('stopwords')
24
 
25
  # Initialize OpenAI API key
26
+ openai.api_key = 'sk-proj-REDACTED'  # SECURITY: leaked secret key redacted — use os.getenv('OPENAI_API_KEY') and rotate the exposed key
27
 
28
+ # Load selected datasets
29
+ logger.info("Starting dataset loading...")
30
  ragbench = {}
31
+ datasets_to_load = ['covidqa', 'hotpotqa', 'pubmedqa']
 
 
32
 
33
+ for dataset in datasets_to_load:
34
+ try:
35
+ ragbench[dataset] = load_dataset("rungalileo/ragbench", dataset, split='train')
36
+ logger.info(f"Successfully loaded {dataset}")
37
+ except Exception as e:
38
+ logger.error(f"Failed to load {dataset}: {e}")
39
+ continue
40
+
41
+ print(f"Loaded {len(ragbench)} datasets successfully")
42
+
43
+ # Initialize embedding model
44
  model_name = 'sentence-transformers/all-mpnet-base-v2'
45
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
46
  embedding_model = HuggingFaceEmbeddings(model_name=model_name)
 
74
  chunks.append(current_chunk.strip())
75
  return chunks
76
 
77
+ # Process documents
 
78
  documents = []
79
+ for dataset_name, dataset in ragbench.items():
80
+ logger.info(f"Processing {dataset_name}")
81
+ original_documents = dataset['documents']
82
+ chunked_documents = chunk_documents_semantic(original_documents)
83
+ documents.extend([Document(page_content=chunk) for chunk in chunked_documents])
84
+ logger.info(f"Processed {len(chunked_documents)} chunks from {dataset_name}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
 
86
+ # Initialize vectordb
87
+ vectordb = Chroma.from_documents(
88
+ documents=documents,
89
+ embedding=embedding_model,
90
+ persist_directory='./docs/chroma/'
91
  )
92
+ vectordb.persist()
93
 
94
  def process_query(query, dataset_choice):
95
  try:
96
  logger.info(f"Processing query for {dataset_choice}: {query}")
97
 
98
+ relevant_docs = vectordb.max_marginal_relevance_search(
99
  query,
100
  k=5,
101
  fetch_k=10
 
119
  logger.error(f"Error processing query: {str(e)}")
120
  return f"Error: {str(e)}"
121
 
122
+ # Create Gradio interface
123
  demo = gr.Interface(
124
  fn=process_query,
125
  inputs=[