jeremierostan committed (verified)
Commit fc75c0c · 1 Parent(s): c692389

Update app.py

Files changed (1): app.py (+45, -44)
app.py CHANGED
@@ -13,10 +13,10 @@ from langchain.chains import create_retrieval_chain
 import os
 import markdown2
 
-# Retrieve API keys from HF secrets
-openai_api_key = os.getenv('OPENAI_API_KEY')
-groq_api_key = os.getenv('GROQ_API_KEY')
-google_api_key = os.getenv('GEMINI_API_KEY')
+# Retrieve API keys from Hugging Face Spaces secrets
+openai_api_key = os.environ.get('OPENAI_API_KEY')
+groq_api_key = os.environ.get('GROQ_API_KEY')
+google_api_key = os.environ.get('GEMINI_API_KEY')
 
 # Initialize API clients with the API keys
 openai_client = ChatOpenAI(model_name="gpt-4o", api_key=openai_api_key)
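
Note: both os.getenv and os.environ.get return None when a secret is unset, so a misconfigured Space only fails later, inside the client constructors. A minimal fail-fast sketch (hypothetical, not part of this commit; names taken from the code above):

    # Hypothetical startup guard, not in this commit
    required = ['OPENAI_API_KEY', 'GROQ_API_KEY', 'GEMINI_API_KEY']
    missing = [name for name in required if not os.environ.get(name)]
    if missing:
        raise RuntimeError(f"Missing Spaces secrets: {', '.join(missing)}")
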
@@ -25,7 +25,11 @@ gemini_client = ChatGoogleGenerativeAI(model="gemini-1.5-pro", api_key=google_api_key)
 
 # Function to extract text from PDF
 def extract_pdf(pdf_path):
-    return extract_text(pdf_path)
+    try:
+        return extract_text(pdf_path)
+    except Exception as e:
+        print(f"Error extracting text from {pdf_path}: {str(e)}")
+        return ""
 
 # Function to split text into chunks
 def split_text(text):
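
Note: with this change, a PDF that fails to parse yields an empty string instead of raising. split_text then simply produces zero chunks for that file; a caller that prefers to skip failed files explicitly could guard on the sentinel (illustrative sketch, not in this commit, reusing the loading loop shown further down):

    for pdf_path in pdf_paths:
        extracted_text = extract_pdf(pdf_path)
        if extracted_text:  # skip files whose extraction failed
            all_documents.extend(split_text(extracted_text))
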
@@ -37,33 +41,29 @@ def generate_embeddings(docs):
     embeddings = OpenAIEmbeddings(api_key=openai_api_key)
     return FAISS.from_documents(docs, embeddings)
 
-# Function for query preprocessing and simple HyDE-Lite
+# Function for query preprocessing
 def preprocess_query(query):
     prompt = ChatPromptTemplate.from_template("""
-    Your role is to optimize user queries for retrieval from official regulation documents about data protection.
-    Transform the query into a more affirmative, keyword-focused statement.
-    The transformed query should look like probable related passages in the official documents.
-
+    Transform the following query into a more detailed, keyword-rich statement that could appear in official data protection regulation documents:
     Query: {query}
-
-    Optimized query:
+    Transformed query:
     """)
     chain = prompt | openai_client
     return chain.invoke({"query": query}).content
 
 # Function to create RAG chain with Groq
-def create_rag_chain():
+def create_rag_chain(vector_store):
     prompt = ChatPromptTemplate.from_messages([
-        ("system", "You are an AI assistant helping with data protection related queries. Use the following context from the official regulation documents to answer the user's question:\n\n{context}"),
+        ("system", "You are an AI assistant helping with data protection and regulation compliance related queries. Use the following context to answer the user's question:\n\n{context}"),
         ("human", "{input}")
     ])
     document_chain = create_stuff_documents_chain(groq_client, prompt)
     return create_retrieval_chain(vector_store.as_retriever(), document_chain)
 
 # Function for Gemini response with long context
-def gemini_response(query):
+def gemini_response(query, full_pdf_content):
     prompt = ChatPromptTemplate.from_messages([
-        ("system", "You are an AI assistant helping with data protection related queries. Use the following full content of the official regulation documents to answer the user's question:\n\n{context}"),
+        ("system", "You are an AI assistant helping with data protection and regulation compliance related queries. Use the following full content of official regulation documents to answer the user's question:\n\n{context}"),
         ("human", "{input}")
     ])
     chain = prompt | gemini_client
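
Note: create_rag_chain now receives the vector store explicitly instead of reading a global. A minimal standalone invocation, assuming the chain shapes shown above (the query text is illustrative; the "answer" key is the same one process_query reads below):

    rag_chain = create_rag_chain(vector_store)
    result = rag_chain.invoke({"input": "verifiable parental consent under COPPA"})
    print(result["answer"])
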
@@ -72,35 +72,39 @@ def gemini_response(query):
 # Function to generate final response
 def generate_final_response(query, response1, response2):
     prompt = ChatPromptTemplate.from_template("""
-    You are an AI assistant helping educators understand and implement data protection and compliance with official regulations when using AI.
-    Your goal is to provide simple, practical explanation of and advice on how to meet these regulatory requirements based on the 2 given responses.
-    To do so:
-    1. Read the user query
-    2. Analyze the following two responses. Inspect their content, and highlight their differences. This MUST be done
-    internally as a hidden state.
-    2. Then, use this information to output your own response to the user query, synthesizing the responses all while maintaining their strengths
-    If the responses differ or contradict each other on important points, include that in your response as this could be a sign of hallucination.
-    Only output your own final response to the user query.
+    As an AI assistant specializing in data protection and compliance for educators:
+    1. Analyze the following two AI-generated responses to the user query.
+    2. Synthesize a comprehensive answer that combines the strengths of both responses.
+    3. If the responses contradict each other, highlight this and explain potential reasons.
+    4. Provide practical advice on how to meet regulatory requirements in the context of the user question based on the information given.
+
+    User Query: {query}
+
+    Response 1: {response1}
+
+    Response 2: {response2}
+
+    Your synthesized response:
     """)
     chain = prompt | openai_client
     return chain.invoke({"query": query, "response1": response1, "response2": response2}).content
 
 # Function to process the query
 def process_query(user_query):
-    preprocessed_query = preprocess_query(user_query)
-    print(f"Original query: {user_query}")
-    print(f"Preprocessed query: {preprocessed_query}")
-
-    # Get RAG response using Groq with the preprocessed query
-    rag_response = rag_chain.invoke({"input": preprocessed_query})["answer"]
-
-    # Get Gemini response with full PDF content and preprocessed query
-    gemini_resp = gemini_response(preprocessed_query)
-
-    final_response = generate_final_response(user_query, rag_response, gemini_resp)
-    html_content = markdown_to_html(final_response)
-
-    return rag_response, gemini_resp, html_content
+    try:
+        preprocessed_query = preprocess_query(user_query)
+        print(f"Original query: {user_query}")
+        print(f"Preprocessed query: {preprocessed_query}")
+
+        rag_response = rag_chain.invoke({"input": preprocessed_query})["answer"]
+        gemini_resp = gemini_response(preprocessed_query, full_pdf_content)
+        final_response = generate_final_response(user_query, rag_response, gemini_resp)
+        html_content = markdown2.markdown(final_response)
+
+        return rag_response, gemini_resp, html_content
+    except Exception as e:
+        error_message = f"An error occurred: {str(e)}"
+        return error_message, error_message, error_message
 
 # Initialize
 pdf_paths = ["GDPR.pdf", "FERPA.pdf", "COPPA.pdf"]
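
Note: process_query now traps any exception and returns the error text in all three slots, keeping the return shape stable for the Gradio interface. An illustrative direct call, bypassing Gradio (the query string is hypothetical; the tuple presumably maps onto the interface's three outputs, which this diff does not show):

    rag_out, gemini_out, html_out = process_query("Does FERPA cover AI tutoring tools?")
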
@@ -113,12 +117,8 @@ for pdf_path in pdf_paths:
     all_documents.extend(split_text(extracted_text))
 
 vector_store = generate_embeddings(all_documents)
-rag_chain = create_rag_chain()
+rag_chain = create_rag_chain(vector_store)
 
-# Function to output the final response as markdown
-def markdown_to_html(content):
-    return markdown2.markdown(content)
-
 # Gradio interface
 iface = gr.Interface(
     fn=process_query,
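
Note: the removed markdown_to_html helper was a one-line wrapper; process_query now calls markdown2.markdown directly. For reference, markdown2.markdown takes a Markdown string and returns the rendered HTML string, e.g.:

    html = markdown2.markdown("**GDPR** applies")  # -> '<p><strong>GDPR</strong> applies</p>\n'
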
@@ -133,4 +133,5 @@
     allow_flagging="never"
 )
 
+# Launch the interface
 iface.launch()
 