# chain.py
import json

import openai
import torch
import transformers
from huggingface_hub import login
from langchain import HuggingFacePipeline, PromptTemplate
from langchain.chains import RetrievalQA
from langchain.memory import ConversationBufferMemory
from transformers import AutoTokenizer
# Load credentials from a local JSON file, then authenticate with the
# Hugging Face Hub so the gated Llama-2 weights can be downloaded.
with open("credentials.json", "r") as file:
    credentials = json.load(file)

access_token_read = credentials["access_token_read"]
openai.api_key = credentials["openai_api_key"]

login(token=access_token_read)
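
# For reference, credentials.json is expected to provide the two keys read
# above; the values here are placeholders, not real secrets:
#
#     {
#         "access_token_read": "hf_...",
#         "openai_api_key": "sk-..."
#     }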


def create_juniper_prompt_template():
    """Build the prompt template used by the retrieval QA chain."""
    template = """You are a network engineer at Juniper Networks, not a language model. Use your own knowledge and the pieces of context (delimited by <ctx></ctx>) to answer the user's question.
Answer as a member of Juniper Networks would.
If you don't know the answer, just say that you don't know; don't try to make up an answer.
Do not indicate that you have access to any context.
Use the chat history (delimited by <hs></hs>) to keep track of the conversation.
----------------
<ctx>
{context}
</ctx>
----------------
<hs>
{history}
</hs>
----------------
{question}
Answer:
"""
    juniper_prompt_template = PromptTemplate(
        input_variables=["history", "context", "question"],
        template=template,
    )
    return juniper_prompt_template
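
# Example (illustrative values only) of how the template is rendered:
#
#     prompt = create_juniper_prompt_template()
#     print(prompt.format(history="", context="<retrieved docs>",
#                         question="What is Junos OS?"))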


def create_question_answering_chain(retriever):
    """
    Create a retrieval question answering (QA) chain.

    This function initializes a QA chain that answers questions based on
    retrieved documents. It uses Meta's 'Llama-2-7b-chat' model as the
    language model (LLM) and a document retriever for finding relevant
    documents.

    Args:
        retriever (obj): The document retriever to use for finding
            relevant documents.

    Returns:
        qa_chain (obj): The initialized retrieval QA chain.
    """
    # Initialize the tokenizer and the text-generation pipeline for
    # Meta's Llama-2-7b-chat model.
    tokenizer = AutoTokenizer.from_pretrained(
        "meta-llama/Llama-2-7b-chat-hf", token=access_token_read
    )
    text_gen_pipeline = transformers.pipeline(
        "text-generation",
        model="meta-llama/Llama-2-7b-chat-hf",
        tokenizer=tokenizer,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
        device_map="auto",
        max_length=4096,
        do_sample=True,
        top_k=10,
        num_return_sequences=1,
        eos_token_id=tokenizer.eos_token_id,
    )

    # Wrap the raw pipeline so LangChain can drive it as an LLM.
    hf_llm = HuggingFacePipeline(
        pipeline=text_gen_pipeline, model_kwargs={"temperature": 0}
    )
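
    # Optional sanity check (illustrative, commented out): the raw pipeline
    # can be exercised directly before wiring it into the chain.
    #
    #     print(text_gen_pipeline("What is Junos OS?")[0]["generated_text"])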

    # Initialize the retrieval QA chain with the language model, the 'stuff'
    # chain type, the document retriever, the custom prompt, and a
    # conversation memory wired to the prompt's {history} variable.
    qa_chain = RetrievalQA.from_chain_type(
        llm=hf_llm,
        chain_type="stuff",
        retriever=retriever,
        verbose=False,
        chain_type_kwargs={
            "verbose": False,
            "prompt": create_juniper_prompt_template(),
            "memory": ConversationBufferMemory(
                memory_key="history",
                input_key="question",
            ),
        },
    )

    return qa_chain
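
# Usage sketch (illustrative; this module does not build a retriever itself).
# The vector-store and embedding classes below are assumptions about the
# surrounding project, not part of this file:
#
#     from langchain.embeddings import HuggingFaceEmbeddings
#     from langchain.vectorstores import Chroma
#
#     vectordb = Chroma(persist_directory="db",
#                       embedding_function=HuggingFaceEmbeddings())
#     qa_chain = create_question_answering_chain(vectordb.as_retriever())
#     result = qa_chain({"query": "How do I configure OSPF on Junos?"})
#     print(result["result"])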