import streamlit as st
import openai
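# NOTE: this app uses the legacy (pre-1.0) OpenAI Python SDK interface
# (openai.Embedding.create / openai.ChatCompletion.create), so it assumes an SDK
# version before 1.0 (e.g. openai==0.28); these calls were removed in openai>=1.0.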
import fitz # PyMuPDF
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
from io import BytesIO

# Function to extract text from the uploaded PDF file
def extract_pdf_text(pdf_file):
    doc = fitz.open(stream=pdf_file.read(), filetype="pdf")
    text = ""
    for page in doc:
        text += page.get_text("text")
    return text

# Function to get embeddings for the text
def get_embeddings(texts):
    response = openai.Embedding.create(
        model="text-embedding-ada-002",
        input=texts
    )
    embeddings = [embedding['embedding'] for embedding in response['data']]
    return embeddings

# Function to get the most relevant context from the PDF for the query
def get_relevant_context(pdf_text, query, num_contexts=3):
    # Split the PDF text into fixed-size chunks for better matching
    pdf_text_chunks = [pdf_text[i:i+1500] for i in range(0, len(pdf_text), 1500)]
    # Get embeddings for both the document chunks and the query
    pdf_embeddings = get_embeddings(pdf_text_chunks)
    query_embedding = get_embeddings([query])[0]
    # Compute cosine similarity between the query and each document chunk
    similarities = cosine_similarity([query_embedding], pdf_embeddings)
    top_indices = similarities[0].argsort()[-num_contexts:][::-1]
    # Combine the top-scoring context pieces
    relevant_context = " ".join([pdf_text_chunks[i] for i in top_indices])
    return relevant_context
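
# Optional, not wired into the app: a minimal sketch of caching the chunk embeddings in
# Streamlit's session state so the uploaded PDF is embedded only once per session, rather
# than on every question as get_relevant_context does above. The function name and cache
# key are illustrative; it assumes the same get_embeddings helper and pre-1.0 SDK as the
# rest of this file.
def get_cached_pdf_embeddings(pdf_text_chunks):
    if 'pdf_embeddings' not in st.session_state:
        st.session_state.pdf_embeddings = get_embeddings(pdf_text_chunks)
    return st.session_state.pdf_embeddings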

# Function to generate a response from the OpenAI chat model
def generate_response(context, question, conversation_history):
    user_message = {"role": "user", "content": f"Context: {context}\nQuestion: {question}"}
    messages = conversation_history + [user_message]
    response = openai.ChatCompletion.create(
        model="gpt-4o-mini",  # chat model used to answer questions about the paper
        messages=messages,
        max_tokens=1200,
        temperature=0.7,
    )
    answer = response['choices'][0]['message']['content'].strip()
    # Record both the user's question and the new answer in the conversation history
    conversation_history.append(user_message)
    conversation_history.append({"role": "assistant", "content": answer})
    return answer, conversation_history
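
# Note: generate_response re-sends the retrieved context plus the full conversation history
# on every turn, so prompts grow as the chat continues; for long sessions you may want to
# truncate older turns to stay within the model's context window.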

# Function to handle irrelevant questions
def is_irrelevant_question(question):
    irrelevant_keywords = ["life", "love", "meaning", "future", "philosophy"]
    return any(keyword in question.lower() for keyword in irrelevant_keywords)
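
# Note: the keyword check above is a crude heuristic; on-topic questions that happen to
# contain one of these words (e.g. "future work discussed in the paper") will also be rejected.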

# Streamlit UI
def main():
    st.title("GPT-4 Research Paper Chatbot")
    st.write("Ask any question related to the GPT-4 paper, and I'll try to answer it!")

    # User input: OpenAI API key
    openai_api_key = st.text_input("Enter your OpenAI API Key:", type="password")

    if openai_api_key:
        openai.api_key = openai_api_key
        st.success("API Key successfully set!")

        # Upload the PDF file
        pdf_file = st.file_uploader("Upload GPT-4 Research Paper PDF", type="pdf")

        if pdf_file is not None:
            # Extract text from the uploaded PDF
            pdf_text = extract_pdf_text(pdf_file)
            st.write("PDF content loaded successfully!")

            # Initialize conversation history (this will persist between interactions)
            if 'conversation_history' not in st.session_state:
                st.session_state.conversation_history = []

            # User input: the question they want to ask
            question = st.text_input("Ask your question:")

            if question:
                # Check if the question is irrelevant
                if is_irrelevant_question(question):
                    st.write("Sorry, I don't know the answer to that. I can only answer questions about the GPT-4 research paper.")
                else:
                    # Get the most relevant context from the document
                    relevant_context = get_relevant_context(pdf_text, question)

                    # Generate the response from the chat model
                    answer, conversation_history = generate_response(relevant_context, question, st.session_state.conversation_history)

                    # Update the conversation history in session state
                    st.session_state.conversation_history = conversation_history

                    # Display the answer
                    st.write(f"Answer: {answer}")

            # End conversation button to reset chat history
            if st.button("END CONVERSATION"):
                st.session_state.conversation_history = []  # Reset conversation history
                st.write("Conversation has been reset. Feel free to ask new questions.")
        else:
            st.warning("Please upload a PDF file to proceed.")
    else:
        st.warning("Please enter your OpenAI API Key to use the chatbot.")


if __name__ == "__main__":
    main()