"""Core modules"""
import os
import uuid
from typing import List

import gradio as gr
from pydantic import BaseModel
from unstructured.partition.pdf import partition_pdf

from langchain import hub
from langchain.docstore.document import Document
from langchain.retrievers.multi_vector import MultiVectorRetriever
from langchain.storage import InMemoryStore
from langchain.text_splitter import RecursiveCharacterTextSplitter, NLTKTextSplitter, CharacterTextSplitter
from langchain.vectorstores.faiss import FAISS
from langchain_community.document_loaders import Docx2txtLoader, UnstructuredPDFLoader
from langchain_community.vectorstores import Chroma
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
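# NOTE: the app reads the OpenAI key from the OpenAI_APIKEY environment variable
# (the author's variable name), so it must be set before launching.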
# The vectorstore to use to index the child chunks
vectorstore = Chroma(
    collection_name="rag_app",
    embedding_function=OpenAIEmbeddings(api_key=os.environ['OpenAI_APIKEY']),
)

# The storage layer for the parent documents
store = InMemoryStore()
id_key = "doc_id"

# The retriever (empty to start)
retriever = MultiVectorRetriever(
    vectorstore=vectorstore,
    docstore=store,
    id_key=id_key,
)
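# Multi-vector pattern: compact summaries are embedded and searched in Chroma,
# while the full-size parent elements live in the InMemoryStore; the shared
# "doc_id" metadata key links a matched summary back to its parent document.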
def split_text(doc: List[Document], split_mode: str = 'tiktoken',
               chunk_size: int = 1000, chunk_overlap: int = 5):
    # Split by separator and merge by character count
    if split_mode == "character":
        text_splitter = CharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
        )
    # Recursively split until below the chunk size limit
    elif split_mode == "recursive_character":
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
        )
    # Split on sentence boundaries using NLTK
    elif split_mode == "nltk":
        text_splitter = NLTKTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
        )
    # Measure chunk size in tiktoken tokens rather than characters
    elif split_mode == "tiktoken":
        text_splitter = CharacterTextSplitter.from_tiktoken_encoder(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
        )
    else:
        raise ValueError(f"Unknown split_mode: {split_mode!r}")
    documents = text_splitter.split_documents(doc)
    return documents
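# Example usage (sketch; `loaded_docs` is a hypothetical list of Documents):
# chunks = split_text(loaded_docs, split_mode="recursive_character", chunk_size=1000)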
# Concatenate retrieved Documents into a single context string for the prompt
def format_docs(docs):
    return "\n\n".join(doc.page_content for doc in docs)
# Lightweight container for a categorized PDF element (text or table)
class Element(BaseModel):
    type: str
    text: str
# Alternative indexing path: embed documents directly into a FAISS store
# (not used by the Gradio app below, which relies on the multi-vector retriever)
def save_documents_faiss(documents, faiss_save_path: str = None, save_faiss: bool = False):
    embeddings = OpenAIEmbeddings(openai_api_key=os.environ['OpenAI_APIKEY'])
    faiss_db = FAISS.from_documents(documents, embeddings)
    if save_faiss:
        faiss_db.save_local(faiss_save_path)
    return faiss_db
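# Example usage (sketch; assumes `chunks` produced by split_text above):
# faiss_db = save_documents_faiss(chunks, faiss_save_path="faiss_index", save_faiss=True)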
def save_documents(texts, text_summaries, tables, table_summaries):
    # Add texts
    doc_ids = [str(uuid.uuid4()) for _ in texts]
    summary_texts = [
        Document(page_content=s, metadata={id_key: doc_ids[i]})
        for i, s in enumerate(text_summaries)
    ]
    retriever.vectorstore.add_documents(summary_texts)
    retriever.docstore.mset(list(zip(doc_ids, texts)))
    # Add tables
    table_ids = [str(uuid.uuid4()) for _ in tables]
    summary_tables = [
        Document(page_content=s, metadata={id_key: table_ids[i]})
        for i, s in enumerate(table_summaries)
    ]
    retriever.vectorstore.add_documents(summary_tables)
    retriever.docstore.mset(list(zip(table_ids, tables)))
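# At query time the retriever embeds the question, matches it against the
# indexed summaries, and returns the corresponding full texts and tables.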
def doc_processing(files: List[str]):
    docs = []
    tables = []
    # Identify each file's type and process it accordingly
    for file in files:
        if file.endswith(".pdf"):
            # Partition the PDF into elements, inferring table structure
            raw_pdf_elements = partition_pdf(
                filename=file,
                extract_images_in_pdf=False,
                infer_table_structure=True,
                chunking_strategy="by_title",
                max_characters=4000,
                new_after_n_chars=3800,
                combine_text_under_n_chars=2000,
                image_output_dir_path='/tmp',  # Change this to your desired path
            )
            categorized_elements = []
            for element in raw_pdf_elements:
                if "unstructured.documents.elements.Table" in str(type(element)):
                    categorized_elements.append(Element(type="table", text=str(element)))
                elif "unstructured.documents.elements.CompositeElement" in str(type(element)):
                    categorized_elements.append(Element(type="text", text=str(element)))
            # Separate text and table elements
            text_elements = [e for e in categorized_elements if e.type == "text"]
            table_elements = [e for e in categorized_elements if e.type == "table"]
            docs.extend(text_elements)
            tables.extend(table_elements)
        elif file.endswith(".docx"):
            # Process DOCX files using LangChain's Docx2txtLoader
            loader = Docx2txtLoader(file)
            data = loader.load()
            docs.extend(data)
    # Summarization prompt
    prompt_text = """You are an assistant tasked with summarizing tables and text. \
Give a concise summary of the table or text. Table or text chunk: {element}"""
    prompt = ChatPromptTemplate.from_template(prompt_text)

    # Summary chain: pull the raw content out of each element
    # (Element carries .text, loader Documents carry .page_content)
    model = ChatOpenAI(temperature=0, model="gpt-4")
    summarize_chain = (
        {"element": lambda x: getattr(x, "text", None) or x.page_content}
        | prompt | model | StrOutputParser()
    )
    table_summaries = summarize_chain.batch(tables, config={"max_concurrency": 5})
    text_summaries = summarize_chain.batch(docs, config={"max_concurrency": 5})
    return docs, tables, text_summaries, table_summaries
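# Example usage (sketch; "interview.pdf" is a hypothetical local file):
# docs, tables, text_summaries, table_summaries = doc_processing(["interview.pdf"])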
# End-to-end pipeline: ingest the files, index them, then answer the question
def wrap_all(files: List[str], input_prompt: str):
    save_documents(*doc_processing(files))
    # Prompt template
    template = """You are an assistant for question-answering tasks. \
Use the following pieces of retrieved context to answer the question. \
If you don't know the answer, just say that you don't know. \
Use three sentences maximum and keep the answer concise. \
Please cite the text on which you base your arguments whenever possible.
Question: {question}
Context: {context}
Answer:
"""
    prompt = ChatPromptTemplate.from_template(template)
    # Alternatively, load a community prompt: prompt = hub.pull("rlm/rag-prompt")

    # Language model
    llm = ChatOpenAI(model_name="gpt-4o", openai_api_key=os.environ['OpenAI_APIKEY'], temperature=0)

    # RAG chain: retrieve parent docs for the question, format them as context,
    # fill the prompt, query the model, and parse the output to a plain string
    rag_chain = (
        {"context": retriever | format_docs, "question": RunnablePassthrough()}
        | prompt
        | llm
        | StrOutputParser()
    )

    # Invoke the chain with the input prompt
    return rag_chain.invoke(input_prompt)
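# Example usage (sketch, with hypothetical file paths):
# answer = wrap_all(["interview1.pdf", "interview2.docx"], "What themes recur across speakers?")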
# Define the Gradio interface
iface = gr.Interface(
    fn=wrap_all,
    inputs=[
        gr.File(type="filepath", label=".pdf or .docx file(s) of the interview", file_count='multiple'),
        gr.Textbox(label="Enter your inquiry"),
    ],
    outputs="text",
    title="Interviews: QA and summarization",
    description="Upload .pdf or .docx files of the interview, then enter your question or ask for a summary.",
)
iface.launch()