Spaces:
Sleeping
Sleeping
import os | |
import streamlit as st | |
from langchain.embeddings import Embedding | |
from groq import Groq | |
from langchain.chains import RetrievalQA | |
from langchain.vectorstores import FAISS | |
from langchain.document_loaders import PyPDFLoader | |
from langchain.text_splitter import RecursiveCharacterTextSplitter | |
from langchain.llms import OpenAI | |
from langchain.agents import initialize_agent | |
from langchain.agents import Tool | |
# Shared module-level Groq client; the API key is read from the
# GROQ_API_KEY environment variable (None is passed through when it is unset).
groq_client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
# Define a custom embedding class for Groq
class GroqEmbedding(Embedding):
    """Embedding adapter that requests vectors through the Groq SDK client.

    Exposes the two methods a LangChain vector store calls on its embedding
    object: ``embed_documents`` (batch) and ``embed_query`` (single text).

    NOTE(review): LangChain's embedding interface is normally named
    ``Embeddings`` (``langchain.embeddings.base``); confirm that the
    ``Embedding`` name imported at the top of this file actually resolves.
    """

    def __init__(self, model="groq-embedding-model", api_key=None):
        """Create the adapter and its Groq client.

        Args:
            model: Name of the embedding model sent with every request.
                NOTE(review): the default looks like a placeholder; verify
                it against the models enabled for your Groq account.
            api_key: Groq API key. When falsy, the ``GROQ_API_KEY``
                environment variable is used instead.
        """
        self.model = model
        self.client = Groq(api_key=api_key or os.getenv("GROQ_API_KEY"))

    def embed_documents(self, texts):
        """Return one embedding vector per entry of ``texts``.

        Args:
            texts: Iterable of strings to embed.

        Returns:
            A list of embedding vectors, in the order the API returned them.
            An empty input yields an empty list without calling the API.
        """
        batch = list(texts)
        if not batch:
            # Avoid a pointless (and likely rejected) request with no input.
            return []
        # The Groq client has no ``embed_documents`` method; embeddings are
        # requested through the ``embeddings.create`` endpoint instead.
        response = self.client.embeddings.create(input=batch, model=self.model)
        return [item.embedding for item in response.data]

    def embed_query(self, query):
        """Return the embedding vector for a single query string."""
        # Reuse the batch path so both methods hit the same endpoint.
        return self.embed_documents([query])[0]
# Streamlit App UI
# Flow: upload a PDF -> split it into chunks -> embed the chunks into a FAISS
# index -> answer questions with a retrieval-augmented QA chain.
import tempfile  # stdlib; only needed by the upload handling below

st.title("PDF Question-Answering with Groq Embeddings")
uploaded_file = st.file_uploader("Upload a PDF", type="pdf")

# Process the uploaded PDF
if uploaded_file is not None:
    # PyPDFLoader expects a filesystem path, not the in-memory upload object,
    # so the uploaded bytes are written to a temporary file first. The file
    # is closed before loading (delete=False) so it can be reopened by path.
    with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as tmp_pdf:
        tmp_pdf.write(uploaded_file.getvalue())
        pdf_path = tmp_pdf.name
    try:
        documents = PyPDFLoader(pdf_path).load()
    finally:
        # Always remove the temporary copy, even if parsing fails.
        os.remove(pdf_path)

    # Split documents into smaller chunks for better processing
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    split_docs = text_splitter.split_documents(documents)

    if not split_docs:
        # Nothing to index (e.g. a PDF with no extractable text); building a
        # vector store from an empty list would fail.
        st.warning("No extractable text was found in the uploaded PDF.")
    else:
        # Create embeddings using Groq
        embeddings = GroqEmbedding(api_key=os.getenv("GROQ_API_KEY"))

        # Create a FAISS vector store
        vector_db = FAISS.from_documents(split_docs, embeddings)

        # Initialize the retrieval-based QA system. RetrievalQA takes a
        # retriever object, so the vector store is wrapped via as_retriever().
        # NOTE(review): the answering LLM is OpenAI(), not Groq; presumably
        # this requires OpenAI credentials to be configured -- confirm.
        qa = RetrievalQA.from_chain_type(
            llm=OpenAI(),
            chain_type="stuff",
            retriever=vector_db.as_retriever(),
        )

        # User input for querying the PDF content
        query = st.text_input("Ask a question about the PDF:")
        if query:
            result = qa.run(query)
            st.write("Answer:", result)