"""Core modules for a docx question-answering and summarization app."""
from typing import List, Optional

import gradio as gr
from langchain import hub
from langchain.text_splitter import (
    CharacterTextSplitter,
    NLTKTextSplitter,
    RecursiveCharacterTextSplitter,
)
from langchain_community.document_loaders import Docx2txtLoader
from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_openai import ChatOpenAI, OpenAIEmbeddings

# OpenAIEmbeddings and ChatOpenAI read the OPENAI_API_KEY environment variable.

def doc_to_embeddings(docs: List[Document], split_mode: str = "tiktoken",
                      chunk_size: int = 1000, chunk_overlap: int = 5,
                      faiss_save_path: Optional[str] = None, save_faiss: bool = False):
    """Split the loaded documents into chunks and index them in a FAISS vector store."""
    # Split on a separator, then merge pieces up to the character count
    if split_mode == "character":
        text_splitter = CharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
        )
    # Recursively split until each chunk is below the size limit
    elif split_mode == "recursive_character":
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
        )
    # Split on sentence boundaries detected by NLTK
    elif split_mode == "nltk":
        text_splitter = NLTKTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
        )
    # Split with chunk sizes measured in tiktoken tokens rather than characters
    elif split_mode == "tiktoken":
        text_splitter = CharacterTextSplitter.from_tiktoken_encoder(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
        )
    else:
        raise ValueError(f"Unknown split mode: {split_mode!r}")
    documents = text_splitter.split_documents(docs)
    embeddings = OpenAIEmbeddings()
    faiss_db = FAISS.from_documents(documents, embeddings)
    if save_faiss:
        if faiss_save_path is None:
            raise ValueError("faiss_save_path is required when save_faiss is True.")
        faiss_db.save_local(faiss_save_path)
    return faiss_db
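
# A minimal usage sketch (illustrative only; "interview.docx" and "faiss_index"
# are hypothetical paths). An index saved with save_faiss=True can later be
# reloaded with FAISS.load_local instead of re-embedding the document; recent
# langchain_community releases also require allow_dangerous_deserialization=True
# to load the pickled index metadata:
#
#     docs = Docx2txtLoader("interview.docx").load()
#     db = doc_to_embeddings(docs, faiss_save_path="faiss_index", save_faiss=True)
#     db = FAISS.load_local("faiss_index", OpenAIEmbeddings(),
#                           allow_dangerous_deserialization=True)
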
def format_docs(docs):
    """Join retrieved chunks into a single context string for the prompt."""
    return "\n\n".join(doc.page_content for doc in docs)
def wrap_all(file_path: str, input_prompt: str):
    """Load a .docx file, index it, and run the inquiry through a RAG chain."""
    loader = Docx2txtLoader(file_path)
    data = loader.load()
    db = doc_to_embeddings(data)
    retriever = db.as_retriever()
    # Pull a standard RAG prompt from the LangChain Hub
    prompt = hub.pull("rlm/rag-prompt")
    llm = ChatOpenAI(model_name="gpt-4", temperature=0)
    # LCEL pipeline: retrieve and format context, fill the prompt,
    # query the model, and parse the reply to plain text
    rag_chain = (
        {"context": retriever | format_docs, "question": RunnablePassthrough()}
        | prompt
        | llm
        | StrOutputParser()
    )
    return rag_chain.invoke(input_prompt)
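
# With the File component below set to type="filepath", wrap_all receives a
# plain path string, so it can also be called outside Gradio (hypothetical
# file name):
#
#     answer = wrap_all("interview.docx", "Summarize the main points.")
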
# Define the Gradio interface
iface = gr.Interface(
    fn=wrap_all,
    inputs=[
        gr.File(type="filepath", label=".docx file of the interview"),
        gr.Textbox(label="Enter your inquiry"),
    ],
    outputs="text",
    title="Interviews: QA and summarization",
    description="Upload a .docx file of the interview, then ask a question or request a summary.",
)

if __name__ == "__main__":
    iface.launch()