import json

import torch
import transformers
import openai
from transformers import AutoTokenizer
from huggingface_hub import login
from langchain.chains import RetrievalQA
from langchain.llms import HuggingFacePipeline
from langchain.memory import ConversationBufferMemory
from langchain.prompts import PromptTemplate


# Load the Hugging Face read token and the OpenAI API key from a local JSON file.
with open("credentials.json", "r") as file:
    credentials = json.load(file)

access_token_read = credentials["access_token_read"]
openai.api_key = credentials["openai_api_key"]

# Authenticate with the Hugging Face Hub so the gated Llama 2 weights can be downloaded.
login(token=access_token_read)
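# For reference, a minimal credentials.json would hold the two keys read above;
# the values shown here are placeholders, not real credentials:
#
# {
#     "access_token_read": "hf_...",
#     "openai_api_key": "sk-..."
# }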
|
def create_juniper_prompt_template():
    """Build the prompt template used by the retrieval QA chain."""
    template = """You are a network engineer at Juniper Networks, not a language model. Use your own knowledge and the pieces of context (delimited by <ctx></ctx>) to answer the user's question, answering as a member of Juniper Networks would.
If you don't know the answer, just say that you don't know; don't try to make up an answer.
Do not indicate that you have access to any context.
Use the chat history (delimited by <hs></hs>) to keep track of the conversation.
------
<ctx>
{context}
</ctx>
------
<hs>
{history}
</hs>
------
{question}
Answer:
"""

    juniper_prompt_template = PromptTemplate(
        input_variables=["history", "context", "question"],
        template=template,
    )
    return juniper_prompt_template
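# A quick sketch of how the rendered prompt looks once the chain fills in the
# template variables; the history, context, and question values here are
# hypothetical examples, not output from a real session:
#
# prompt = create_juniper_prompt_template()
# print(prompt.format(
#     history="Human: hi\nAI: Hello, how can I help?",
#     context="Junos OS is the operating system that powers Juniper devices.",
#     question="What operating system do Juniper routers run?",
# ))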
|
def create_question_answering_chain(retriever):
    """
    Create a retrieval question answering (QA) chain.

    This function initializes a QA chain that can be used to answer questions based on
    retrieved documents. It uses Meta's 'Llama-2-7b-chat' model as the language model (LLM)
    and a document retriever for finding relevant documents.

    Args:
        retriever (obj): The document retriever to use for finding relevant documents.

    Returns:
        qa_chain (obj): The initialized retrieval QA chain.
    """
    # Load the tokenizer that matches the chat model; the Hub token grants access to the gated repo.
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf", token=access_token_read)

    # Build a text-generation pipeline. bfloat16 halves the memory footprint,
    # device_map="auto" spreads the model across available devices, and
    # max_length matches Llama 2's 4096-token context window.
    pipeline = transformers.pipeline(
        "text-generation",
        model="meta-llama/Llama-2-7b-chat-hf",
        tokenizer=tokenizer,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
        device_map="auto",
        max_length=4096,
        do_sample=True,
        top_k=10,
        num_return_sequences=1,
        eos_token_id=tokenizer.eos_token_id,
    )

    # Wrap the pipeline so LangChain can drive it as an LLM.
    hf_llm = HuggingFacePipeline(pipeline=pipeline, model_kwargs={"temperature": 0})

    # Assemble the retrieval QA chain: the "stuff" chain type packs all retrieved
    # documents into the {context} slot, and the buffer memory feeds prior turns
    # into the {history} slot of the prompt.
    qa_chain = RetrievalQA.from_chain_type(
        llm=hf_llm,
        chain_type="stuff",
        retriever=retriever,
        verbose=False,
        chain_type_kwargs={
            "verbose": False,
            "prompt": create_juniper_prompt_template(),
            "memory": ConversationBufferMemory(
                memory_key="history",
                input_key="question",
            ),
        },
    )

    return qa_chain
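

# A minimal usage sketch, assuming the documents have already been embedded into
# a Chroma vector store; the embedding model, persist directory, and question
# below are hypothetical stand-ins, not part of the code above.
#
# from langchain.embeddings import HuggingFaceEmbeddings
# from langchain.vectorstores import Chroma
#
# embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
# db = Chroma(persist_directory="./chroma_db", embedding_function=embeddings)
# qa_chain = create_question_answering_chain(db.as_retriever())
# print(qa_chain.run("What is Junos OS?"))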