import json

import openai
import torch
import transformers
from huggingface_hub import login
from transformers import AutoTokenizer

from langchain.chains import RetrievalQA
from langchain.llms import HuggingFacePipeline
from langchain.memory import ConversationBufferMemory
from langchain.prompts import PromptTemplate
# Load API credentials from a local JSON file.
with open("credentials.json", "r") as file:
    credentials = json.load(file)

access_token_read = credentials["access_token_read"]
openai.api_key = credentials["openai_api_key"]

# Log in to the Hugging Face Hub so the gated Llama 2 weights can be downloaded.
login(token=access_token_read)
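# For reference, credentials.json is assumed to have the following shape; the
# values are placeholders, and only the two keys read above are required:
#
#     {
#         "access_token_read": "hf_...",
#         "openai_api_key": "sk-..."
#     }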
def create_juniper_prompt_template():
    """Build the prompt template used by the retrieval QA chain."""
    template = """You are a network engineer at Juniper Networks, not a language model. Use your own knowledge and the pieces of context (delimited by <ctx></ctx>) to answer the user's question. Respond as a member of Juniper Networks would.
If you don't know the answer, just say that you don't know; don't try to make up an answer.
Do not indicate that you have access to any context.
Use the chat history (delimited by <hs></hs>) to keep track of the conversation.
------
<ctx>
{context}
</ctx>
------
<hs>
{history}
</hs>
------
{question}
Answer:
"""
    juniper_prompt_template = PromptTemplate(
        input_variables=["history", "context", "question"],
        template=template,
    )
    return juniper_prompt_template
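# Quick sanity check of the template (illustrative only; in normal operation
# the QA chain fills in these variables automatically at query time):
#
#     prompt = create_juniper_prompt_template()
#     print(prompt.format(history="", context="<retrieved doc text>",
#                         question="What is BGP?"))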
def create_question_answering_chain(retriever):
    """
    Create a retrieval question answering (QA) chain.

    This function initializes a QA chain that can be used to answer questions
    based on retrieved documents. It uses Meta's 'Llama-2-7b-chat' model as the
    language model (LLM) and a document retriever for finding relevant documents.

    Args:
        retriever (obj): The document retriever to use for finding relevant documents.

    Returns:
        qa_chain (obj): The initialized retrieval QA chain.
    """
    # Initialize the tokenizer and the text-generation pipeline for Llama 2.
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf", token=access_token_read)
    pipeline = transformers.pipeline(
        "text-generation",
        model="meta-llama/Llama-2-7b-chat-hf",
        tokenizer=tokenizer,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
        device_map="auto",
        max_length=4096,
        do_sample=True,
        top_k=10,
        num_return_sequences=1,
        eos_token_id=tokenizer.eos_token_id,
    )
    # Wrap the pipeline so LangChain can drive it as an LLM.
    hf_llm = HuggingFacePipeline(pipeline=pipeline, model_kwargs={"temperature": 0})
    # Initialize the retrieval QA chain with the language model, the 'stuff'
    # chain type (all retrieved documents are stuffed into a single prompt),
    # the document retriever, the custom prompt, and a conversation memory
    # keyed to the prompt's {history} variable.
    qa_chain = RetrievalQA.from_chain_type(
        llm=hf_llm,
        chain_type="stuff",
        retriever=retriever,
        verbose=False,
        chain_type_kwargs={
            "verbose": False,
            "prompt": create_juniper_prompt_template(),
            "memory": ConversationBufferMemory(
                memory_key="history",
                input_key="question",
            ),
        },
    )
    return qa_chain
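

# Minimal usage sketch, not part of the original pipeline: it assumes a FAISS
# index was built and saved elsewhere; both the index path and the embedding
# model name below are illustrative placeholders, not defined by this file.
if __name__ == "__main__":
    from langchain.embeddings import HuggingFaceEmbeddings
    from langchain.vectorstores import FAISS

    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    vectorstore = FAISS.load_local("faiss_index", embeddings)

    qa_chain = create_question_answering_chain(vectorstore.as_retriever())
    # RetrievalQA expects its input under the "query" key and returns the
    # generated answer under "result".
    result = qa_chain({"query": "How do I configure OSPF on a Juniper MX router?"})
    print(result["result"])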