from tempfile import NamedTemporaryFile
from typing import Tuple, List, Optional, Dict

from langchain.callbacks.base import BaseCallbackHandler
from langchain.chains import RetrievalQA, LLMChain
from langchain.chat_models import (
    AzureChatOpenAI,
    ChatOpenAI,
    ChatAnthropic,
    ChatAnyscale,
)
from langchain.document_loaders import PyPDFLoader
from langchain.embeddings import OpenAIEmbeddings
from langchain.retrievers import BM25Retriever, EnsembleRetriever
from langchain.schema import Document, BaseRetriever
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS

from defaults import DEFAULT_CHUNK_SIZE, DEFAULT_CHUNK_OVERLAP, DEFAULT_RETRIEVER_K
from qagen import get_rag_qa_gen_chain
from summarize import get_rag_summarization_chain


def get_runnable(
    use_document_chat: bool,
    document_chat_chain_type: str,
    llm,
    retriever,
    memory,
    chat_prompt,
    summarization_prompt,
):
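    """Return the chain that powers a single chat turn.

    Without document chat, this is a plain LLMChain piped into a function
    that pulls out its "text" field. With document chat, it is a
    retrieval-backed chain: Q&A generation, summarization, or a
    RetrievalQA chain of the requested chain_type (e.g. "stuff").
    """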
    if not use_document_chat:
        # Plain conversation: run the chat prompt and surface only the text.
        return LLMChain(
            prompt=chat_prompt,
            llm=llm,
            memory=memory,
        ) | (lambda output: output["text"])
    if document_chat_chain_type == "Q&A Generation":
        return get_rag_qa_gen_chain(
            retriever,
            llm,
        )
    elif document_chat_chain_type == "Summarization":
        return get_rag_summarization_chain(
            summarization_prompt,
            retriever,
            llm,
        )
    else:
        # Any other value is treated as a RetrievalQA chain_type ("stuff", etc.).
        return RetrievalQA.from_chain_type(
            llm=llm,
            chain_type=document_chat_chain_type,
            retriever=retriever,
            memory=memory,
            output_key="output_text",
        ) | (lambda output: output["output_text"])


def get_llm(
    provider: str,
    model: str,
    provider_api_key: str,
    temperature: float,
    max_tokens: int,
    azure_available: bool,
    azure_dict: Dict[str, str],
):
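    """Instantiate a streaming chat model for the chosen provider.

    Azure OpenAI is preferred when Azure credentials are available;
    otherwise the provider API key selects OpenAI, Anthropic, or Anyscale
    Endpoints. Returns None if no provider can be configured.
    """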
    if azure_available and provider == "Azure OpenAI":
        return AzureChatOpenAI(
            openai_api_base=azure_dict["AZURE_OPENAI_BASE_URL"],
            openai_api_version=azure_dict["AZURE_OPENAI_API_VERSION"],
            deployment_name=azure_dict["AZURE_OPENAI_DEPLOYMENT_NAME"],
            openai_api_key=azure_dict["AZURE_OPENAI_API_KEY"],
            openai_api_type="azure",
            model_version=azure_dict["AZURE_OPENAI_MODEL_VERSION"],
            temperature=temperature,
            streaming=True,
            max_tokens=max_tokens,
        )
    elif provider_api_key:
        if provider == "OpenAI":
            return ChatOpenAI(
                model_name=model,
                openai_api_key=provider_api_key,
                temperature=temperature,
                streaming=True,
                max_tokens=max_tokens,
            )
        elif provider == "Anthropic":
            return ChatAnthropic(
                model=model,
                anthropic_api_key=provider_api_key,
                temperature=temperature,
                streaming=True,
                # Anthropic caps output via max_tokens_to_sample, not max_tokens.
                max_tokens_to_sample=max_tokens,
            )
        elif provider == "Anyscale Endpoints":
            return ChatAnyscale(
                model_name=model,
                anyscale_api_key=provider_api_key,
                temperature=temperature,
                streaming=True,
                max_tokens=max_tokens,
            )
    return None


def get_texts_and_retriever(
    uploaded_file_bytes: bytes,
    openai_api_key: str,
    chunk_size: int = DEFAULT_CHUNK_SIZE,
    chunk_overlap: int = DEFAULT_CHUNK_OVERLAP,
    k: int = DEFAULT_RETRIEVER_K,
    azure_kwargs: Optional[Dict[str, str]] = None,
    use_azure: bool = False,
) -> Tuple[List[Document], BaseRetriever]:
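    """Load an uploaded PDF, split it into chunks, and build a hybrid retriever.

    The returned EnsembleRetriever blends lexical (BM25) and semantic
    (FAISS over OpenAI or Azure OpenAI embeddings) search with equal
    weights, so exact keyword matches and paraphrased queries both retrieve well.
    """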
    with NamedTemporaryFile() as temp_file:
        # PyPDFLoader needs a path on disk, so spool the upload to a temp file.
        temp_file.write(uploaded_file_bytes)
        temp_file.seek(0)
        loader = PyPDFLoader(temp_file.name)
        documents = loader.load()
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
        )
        texts = text_splitter.split_documents(documents)
        embeddings_kwargs = {"openai_api_key": openai_api_key}
        if use_azure and azure_kwargs:
            embeddings_kwargs.update(azure_kwargs)
        embeddings = OpenAIEmbeddings(**embeddings_kwargs)
        # Lexical retriever: BM25 keyword matching over the raw chunks.
        bm25_retriever = BM25Retriever.from_documents(texts)
        bm25_retriever.k = k
        # Semantic retriever: FAISS similarity search over embedded chunks.
        faiss_vectorstore = FAISS.from_documents(texts, embeddings)
        faiss_retriever = faiss_vectorstore.as_retriever(search_kwargs={"k": k})
        ensemble_retriever = EnsembleRetriever(
            retrievers=[bm25_retriever, faiss_retriever],
            weights=[0.5, 0.5],
        )
        return texts, ensemble_retriever


class StreamHandler(BaseCallbackHandler):
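    """Stream LLM tokens into a Streamlit container as they arrive.

    Illustrative usage (a sketch; assumes a Streamlit app where `st` is
    imported, which this module does not do):

        handler = StreamHandler(st.empty())
        # then pass callbacks=[handler] when invoking the LLM or chain
    """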
    def __init__(self, container, initial_text=""):
        self.container = container
        self.text = initial_text

    def on_llm_new_token(self, token: str, **kwargs) -> None:
        # Append each streamed token and re-render the accumulated text.
        self.text += token
        self.container.markdown(self.text)
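

# Illustrative wiring of the helpers above (a sketch, not executed by the app;
# `api_key` and `pdf_bytes` are placeholders, and the None prompts/memory are
# assumptions that only suit the RetrievalQA branch):
#
#     llm = get_llm("OpenAI", "gpt-3.5-turbo", api_key, 0.7, 1000, False, {})
#     texts, retriever = get_texts_and_retriever(pdf_bytes, api_key)
#     chain = get_runnable(
#         use_document_chat=True,
#         document_chat_chain_type="stuff",
#         llm=llm,
#         retriever=retriever,
#         memory=None,
#         chat_prompt=None,
#         summarization_prompt=None,
#     )
#     result = chain.invoke({"query": "What is this document about?"})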