from tempfile import NamedTemporaryFile
from typing import Tuple, List, Optional, Dict

from langchain.callbacks.base import BaseCallbackHandler
from langchain.chains import RetrievalQA, LLMChain
from langchain.chat_models import (
    AzureChatOpenAI,
    ChatOpenAI,
    ChatAnthropic,
    ChatAnyscale,
)
from langchain.document_loaders import PyPDFLoader
from langchain.embeddings import OpenAIEmbeddings
from langchain.retrievers import BM25Retriever, EnsembleRetriever
from langchain.schema import Document, BaseRetriever
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS

from defaults import DEFAULT_CHUNK_SIZE, DEFAULT_CHUNK_OVERLAP, DEFAULT_RETRIEVER_K
from qagen import get_rag_qa_gen_chain
from summarize import get_rag_summarization_chain


def get_runnable(
    use_document_chat: bool,
    document_chat_chain_type: str,
    llm,
    retriever,
    memory,
    chat_prompt,
    summarization_prompt,
):
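    """Return the runnable for the current app configuration.

    Without document chat, this is a plain conversational LLMChain piped
    into a lambda that extracts the generated text. With document chat,
    it routes to a Q&A-generation chain, a summarization chain, or a
    RetrievalQA chain of the requested type.
    """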
    if not use_document_chat:
        return LLMChain(
            prompt=chat_prompt,
            llm=llm,
            memory=memory,
        ) | (lambda output: output["text"])

    if document_chat_chain_type == "Q&A Generation":
        return get_rag_qa_gen_chain(
            retriever,
            llm,
        )
    elif document_chat_chain_type == "Summarization":
        return get_rag_summarization_chain(
            summarization_prompt,
            retriever,
            llm,
        )
    else:
        return RetrievalQA.from_chain_type(
            llm=llm,
            chain_type=document_chat_chain_type,
            retriever=retriever,
            memory=memory,
            output_key="output_text",
        ) | (lambda output: output["output_text"])


def get_llm(
    provider: str,
    model: str,
    provider_api_key: str,
    temperature: float,
    max_tokens: int,
    azure_available: bool,
    azure_dict: Dict[str, str],
):
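    """Build a streaming chat model for the selected provider.

    Azure OpenAI is preferred whenever its configuration is available;
    otherwise an OpenAI, Anthropic, or Anyscale Endpoints chat model is
    created from the provider API key. Returns None when no provider
    can be configured.
    """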
    if azure_available and provider == "Azure OpenAI":
        return AzureChatOpenAI(
            openai_api_base=azure_dict["AZURE_OPENAI_BASE_URL"],
            openai_api_version=azure_dict["AZURE_OPENAI_API_VERSION"],
            deployment_name=azure_dict["AZURE_OPENAI_DEPLOYMENT_NAME"],
            openai_api_key=azure_dict["AZURE_OPENAI_API_KEY"],
            openai_api_type="azure",
            model_version=azure_dict["AZURE_OPENAI_MODEL_VERSION"],
            temperature=temperature,
            streaming=True,
            max_tokens=max_tokens,
        )

    elif provider_api_key:
        if provider == "OpenAI":
            return ChatOpenAI(
                model_name=model,
                openai_api_key=provider_api_key,
                temperature=temperature,
                streaming=True,
                max_tokens=max_tokens,
            )
        elif provider == "Anthropic":
            return ChatAnthropic(
                model=model,
                anthropic_api_key=provider_api_key,
                temperature=temperature,
                streaming=True,
                max_tokens_to_sample=max_tokens,
            )
        elif provider == "Anyscale Endpoints":
            return ChatAnyscale(
                model_name=model,
                anyscale_api_key=provider_api_key,
                temperature=temperature,
                streaming=True,
                max_tokens=max_tokens,
            )

    return None


def get_texts_and_retriever(
    uploaded_file_bytes: bytes,
    openai_api_key: str,
    chunk_size: int = DEFAULT_CHUNK_SIZE,
    chunk_overlap: int = DEFAULT_CHUNK_OVERLAP,
    k: int = DEFAULT_RETRIEVER_K,
    azure_kwargs: Optional[Dict[str, str]] = None,
    use_azure: bool = False,
) -> Tuple[List[Document], BaseRetriever]:
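    """Split an uploaded PDF into chunks and build a hybrid retriever.

    The raw bytes are written to a temporary file so PyPDFLoader can read
    them, the pages are chunked with RecursiveCharacterTextSplitter, and
    the chunks are indexed twice: sparsely with BM25 and densely with
    FAISS. Both retrievers are combined in an EnsembleRetriever.
    """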
    with NamedTemporaryFile() as temp_file:
        temp_file.write(uploaded_file_bytes)
        temp_file.seek(0)

        loader = PyPDFLoader(temp_file.name)
        documents = loader.load()
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
        )
        texts = text_splitter.split_documents(documents)
        embeddings_kwargs = {"openai_api_key": openai_api_key}
        if use_azure and azure_kwargs:
            embeddings_kwargs.update(azure_kwargs)
        embeddings = OpenAIEmbeddings(**embeddings_kwargs)
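
        # Sparse keyword retriever (BM25) over the split chunks.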
        bm25_retriever = BM25Retriever.from_documents(texts)
        bm25_retriever.k = k
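
        # Dense retriever backed by an in-memory FAISS vector index.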
        faiss_vectorstore = FAISS.from_documents(texts, embeddings)
        faiss_retriever = faiss_vectorstore.as_retriever(search_kwargs={"k": k})
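
        # Hybrid retrieval: merge the sparse and dense results with equal weight.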
        ensemble_retriever = EnsembleRetriever(
            retrievers=[bm25_retriever, faiss_retriever],
            weights=[0.5, 0.5],
        )

        return texts, ensemble_retriever


class StreamHandler(BaseCallbackHandler):
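    """Callback handler that streams LLM tokens into a UI container.

    The container is assumed to expose a ``markdown`` method (e.g. a
    Streamlit placeholder) that re-renders the accumulated text.
    """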

    def __init__(self, container, initial_text=""):
        self.container = container
        self.text = initial_text

    def on_llm_new_token(self, token: str, **kwargs) -> None:
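        """Append the new token and re-render the accumulated text."""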
        self.text += token
        self.container.markdown(self.text)