import os

import gradio as gr
import markdown2
from pdfminer.high_level import extract_text

from langchain_groq import ChatGroq
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.schema import Document
from langchain.prompts import ChatPromptTemplate
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.chains import create_retrieval_chain
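
# API keys are read from environment variables (the Gemini key is stored under GEMINI).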
openai_api_key = os.getenv('OPENAI_API_KEY')
groq_api_key = os.getenv('GROQ_API_KEY')
google_api_key = os.getenv('GEMINI')
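
# One client per provider: GPT-4o rewrites queries and writes the final synthesis,
# Llama 3.1 70B (via Groq) answers from retrieved chunks, and Gemini 1.5 Pro answers from the full document.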
openai_client = ChatOpenAI(model_name="gpt-4o", api_key=openai_api_key)
groq_client = ChatGroq(model="llama-3.1-70b-versatile", temperature=0, api_key=groq_api_key)
gemini_client = ChatGoogleGenerativeAI(model="gemini-1.5-pro", api_key=google_api_key)
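
# Extract the raw text of a PDF with pdfminer.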
def extract_pdf(pdf_path):
    return extract_text(pdf_path)
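
# Split the extracted text into overlapping chunks and wrap each chunk in a Document.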
def split_text(text):
    splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
    return [Document(page_content=t) for t in splitter.split_text(text)]
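
# Embed the chunks with OpenAI embeddings and index them in an in-memory FAISS store.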
def generate_embeddings(docs):
    embeddings = OpenAIEmbeddings(api_key=openai_api_key)
    return FAISS.from_documents(docs, embeddings)
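
# Rewrite the user's question into a keyword-focused statement that reads like
# the regulation itself, which improves retrieval from the vector store.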
def preprocess_query(query):
    prompt = ChatPromptTemplate.from_template("""
    Your role is to optimize user queries for retrieval from a GDPR regulation document.
    Transform the query into a more affirmative, keyword-focused statement.
    The transformed query should look like probable related passages in the official document.

    Query: {query}

    Optimized query:
    """)
    chain = prompt | openai_client
    return chain.invoke({"query": query}).content
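
# Retrieval-augmented chain: the FAISS retriever supplies the context that is stuffed
# into the prompt and answered by the Groq-hosted Llama 3.1 model. Relies on the
# module-level vector_store created at startup.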
def create_rag_chain():
    prompt = ChatPromptTemplate.from_messages([
        ("system", "You are an AI assistant helping with GDPR-related queries. Use the following context from the official GDPR regulation document to answer the user's question:\n\n{context}"),
        ("human", "{input}")
    ])
    document_chain = create_stuff_documents_chain(groq_client, prompt)
    return create_retrieval_chain(vector_store.as_retriever(), document_chain)
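
# Long-context path: pass the entire extracted PDF to Gemini 1.5 Pro instead of retrieved chunks.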
def gemini_response(query):
    prompt = ChatPromptTemplate.from_messages([
        ("system", "You are an AI assistant helping with GDPR-related queries. Use the following full content of the official GDPR regulation document to answer the user's question:\n\n{context}"),
        ("human", "{input}")
    ])
    chain = prompt | gemini_client
    return chain.invoke({"context": full_pdf_content, "input": query}).content
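
# Merge the RAG answer and the long-context answer into a single response with GPT-4o.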
def generate_final_response(response1, response2):
    prompt = ChatPromptTemplate.from_template("""
    You are an AI assistant helping educators understand and implement AI data protection and GDPR compliance.
    Your goal is to provide a simple, practical explanation of, and advice on, how to meet GDPR requirements based on the given responses.
    To do so, analyze the following two responses, combining similar elements and highlighting any differences. This MUST be done
    internally as a hidden state. Only output your own final response.
    If the responses contradict each other on important points, include that in your response.

    Response 1: {response1}

    Response 2: {response2}
    """)
    chain = prompt | openai_client
    return chain.invoke({"response1": response1, "response2": response2}).content
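
# Convert the model's Markdown answer to HTML for the Gradio HTML component.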
def markdown_to_html(content):
    return markdown2.markdown(content)
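
# End-to-end pipeline: rewrite the query, answer it with both the RAG chain and the
# long-context Gemini call, then synthesize the two answers into the final response.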
def process_query(user_query):
    preprocessed_query = preprocess_query(user_query)
    rag_response = rag_chain.invoke({"input": preprocessed_query})["answer"]
    gemini_resp = gemini_response(preprocessed_query)
    final_response = generate_final_response(rag_response, gemini_resp)
    html_content = markdown_to_html(final_response)
    return rag_response, gemini_resp, html_content
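
# Load and index the GDPR document once at startup.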
GDPR_PDF_PATH = "/content/GDPR.pdf"
full_pdf_content = extract_pdf(GDPR_PDF_PATH)
documents = split_text(full_pdf_content)
vector_store = generate_embeddings(documents)
rag_chain = create_rag_chain()
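
# Gradio UI: shows the two intermediate answers alongside the final synthesized one.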
iface = gr.Interface(
    fn=process_query,
    inputs=gr.Textbox(label="Ask your data protection related question"),
    outputs=[
        gr.Textbox(label="RAG Pipeline (Llama 3.1) Response"),
        gr.Textbox(label="Long Context (Gemini 1.5 Pro) Response"),
        gr.HTML(label="Final (GPT-4o) Response")
    ],
    title="Data Protection Team",
    description="Get responses to data protection related questions, combining advanced RAG, long-context, and SOTA models.",
    allow_flagging="never"
)

iface.launch(debug=True)