ajalisatgi commited on
Commit
8dfd657
·
verified ·
1 Parent(s): 1bbf06d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -73
app.py CHANGED
@@ -13,108 +13,67 @@ import nltk
13
  logging.basicConfig(level=logging.INFO)
14
  logger = logging.getLogger(__name__)
15
 
16
- # Initialize OpenAI API key
17
- openai.api_key = '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' # Replace with your API key
 
 
 
18
 
19
- # Download NLTK data
20
- nltk.download('punkt')
21
-
22
- # Initialize models and configurations
23
- model_name = 'intfloat/e5-small'
24
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
25
  embedding_model = HuggingFaceEmbeddings(model_name=model_name)
26
  embedding_model.client.to(device)
27
 
28
- # Initialize Chroma with existing database
29
- vectordb = Chroma(
30
- persist_directory='./docs/chroma/',
31
- embedding_function=embedding_model
32
- )
33
-
34
- def process_query(query):
35
  try:
36
- logger.info(f"Processing query: {query}")
37
 
38
- # Get relevant documents
39
- relevant_docs = vectordb.similarity_search(query, k=30)
40
- context = " ".join([doc.page_content for doc in relevant_docs])
 
 
 
41
 
42
- # Add delay to respect API rate limits
43
- time.sleep(1)
44
 
45
- # Generate response using OpenAI
46
  response = openai.chat.completions.create(
47
  model="gpt-4",
48
  messages=[
49
- {"role": "system", "content": "You are a helpful assistant."},
50
- {"role": "user", "content": f"Given the document: {context}\n\nGenerate a response to the query: {query}"}
51
  ],
52
  max_tokens=300,
53
  temperature=0.7,
54
  )
55
 
56
- answer = response.choices[0].message.content.strip()
57
- logger.info("Successfully generated response")
58
-
59
- # Extract and display metrics
60
- metrics = extract_metrics(query, answer, relevant_docs)
61
-
62
- return answer, metrics
63
 
64
  except Exception as e:
65
  logger.error(f"Error processing query: {str(e)}")
66
- return f"Error: {str(e)}", "Metrics unavailable"
67
 
68
- def extract_metrics(query, response, relevant_docs):
69
- try:
70
- context = " ".join([doc.page_content for doc in relevant_docs])
71
- metrics_prompt = f"""
72
- Question: {query}
73
- Context: {context}
74
- Response: {response}
75
-
76
- Extract metrics for:
77
- - Context Relevance
78
- - Context Utilization
79
- - Completeness
80
- - Response Quality
81
- """
82
-
83
- metrics_response = openai.chat.completions.create(
84
- model="gpt-4",
85
- messages=[{"role": "user", "content": metrics_prompt}],
86
- max_tokens=150,
87
- temperature=0.7,
88
- )
89
-
90
- return metrics_response.choices[0].message.content.strip()
91
- except Exception as e:
92
- return "Metrics calculation failed"
93
-
94
- # Create Gradio interface
95
  demo = gr.Interface(
96
  fn=process_query,
97
  inputs=[
98
- gr.Textbox(
99
- label="Enter your question",
100
- placeholder="Type your question here...",
101
- lines=2
 
102
  )
103
  ],
104
- outputs=[
105
- gr.Textbox(label="Answer", lines=5),
106
- gr.Textbox(label="Metrics", lines=4)
107
- ],
108
- title="RAG-Powered Question Answering System",
109
- description="Ask questions and get answers based on the embedded document knowledge.",
110
  examples=[
111
- ["What role does T-cell count play in severe human adenovirus type 55 (HAdV-55) infection?"],
112
- ["In what school district is Governor John R. Rogers High School located?"],
113
- ["Is there a functional neural correlate of individual differences in cardiovascular reactivity?"],
114
- ["How do I select Natural mode?"]
115
  ]
116
  )
117
 
118
- # Launch with debugging enabled
119
  if __name__ == "__main__":
120
  demo.launch(debug=True)
 
13
  logging.basicConfig(level=logging.INFO)
14
  logger = logging.getLogger(__name__)
15
 
16
+ # Load the ragbench datasets
17
+ ragbench = {}
18
+ for dataset in ['covidqa', 'cuad', 'delucionqa', 'emanual', 'expertqa', 'finqa', 'hagrid', 'hotpotqa', 'msmarco', 'pubmedqa', 'tatqa', 'techqa']:
19
+ ragbench[dataset] = load_dataset("rungalileo/ragbench", dataset)
20
+ logger.info(f"Loaded {dataset}")
21
 
22
+ # Initialize with a stronger model for better semantic understanding
23
+ model_name = 'sentence-transformers/all-mpnet-base-v2'
 
 
 
24
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
25
  embedding_model = HuggingFaceEmbeddings(model_name=model_name)
26
  embedding_model.client.to(device)
27
 
28
+ def process_query(query, dataset_choice):
 
 
 
 
 
 
29
  try:
30
+ logger.info(f"Processing query for {dataset_choice}: {query}")
31
 
32
+ # Get relevant documents specific to the chosen dataset
33
+ relevant_docs = vectordb.max_marginal_relevance_search(
34
+ query,
35
+ k=5, # Top 5 most relevant documents
36
+ fetch_k=10 # Fetch top 10 then select most diverse 5
37
+ )
38
 
39
+ context = " ".join([doc.page_content for doc in relevant_docs])
 
40
 
 
41
  response = openai.chat.completions.create(
42
  model="gpt-4",
43
  messages=[
44
+ {"role": "system", "content": "You are a specialized assistant for the RagBench dataset. Provide precise answers based solely on the given context."},
45
+ {"role": "user", "content": f"Dataset: {dataset_choice}\nContext: {context}\nQuestion: {query}\n\nProvide a detailed answer using only the information from the context above."}
46
  ],
47
  max_tokens=300,
48
  temperature=0.7,
49
  )
50
 
51
+ return response.choices[0].message.content.strip()
 
 
 
 
 
 
52
 
53
  except Exception as e:
54
  logger.error(f"Error processing query: {str(e)}")
55
+ return f"Error: {str(e)}"
56
 
57
+ # Create Gradio interface with dataset selection
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  demo = gr.Interface(
59
  fn=process_query,
60
  inputs=[
61
+ gr.Textbox(label="Question", placeholder="Type your question here...", lines=2),
62
+ gr.Dropdown(
63
+ choices=list(ragbench.keys()),
64
+ label="Select Dataset",
65
+ value="hotpotqa"
66
  )
67
  ],
68
+ outputs=gr.Textbox(label="Answer", lines=5),
69
+ title="RagBench Question Answering System",
70
+ description="Ask questions across different RagBench datasets",
 
 
 
71
  examples=[
72
+ ["What role does T-cell count play in severe human adenovirus type 55 (HAdV-55) infection?", "covidqa"],
73
+ ["In what school district is Governor John R. Rogers High School located?", "hotpotqa"],
74
+ ["Is there a functional neural correlate of individual differences in cardiovascular reactivity?", "pubmedqa"]
 
75
  ]
76
  )
77
 
 
78
  if __name__ == "__main__":
79
  demo.launch(debug=True)