import json

import torch
import transformers
import openai
from transformers import AutoTokenizer
from huggingface_hub import login
from langchain.chains import RetrievalQA
from langchain.llms import HuggingFacePipeline
from langchain.memory import ConversationBufferMemory
from langchain.prompts import PromptTemplate


# Load the HuggingFace and OpenAI credentials from a local JSON file.
with open("credentials.json", "r") as file:
    credentials = json.load(file)

access_token_read = credentials["access_token_read"]
openai.api_key = credentials["openai_api_key"]

login(token=access_token_read)
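
# For reference, credentials.json is expected to look roughly like this
# (an assumed example: the keys match what the code above reads, the
# values are placeholders):
#
# {
#     "access_token_read": "hf_...",
#     "openai_api_key": "sk-..."
# }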


def create_juniper_prompt_template():
    """
    Create the prompt template used by the retrieval QA chain.

    The template instructs the model to answer as a Juniper Networks engineer,
    grounding its answers in the retrieved context (delimited by <ctx></ctx>)
    and the chat history (delimited by <hs></hs>).

    Returns:
        juniper_prompt_template (PromptTemplate): A template expecting the
            'history', 'context', and 'question' input variables.
    """
    template = """You are a network engineer from Juniper Networks, not a language model. Use your knowledge and the pieces of context below (delimited by <ctx></ctx>) to answer the user's question, responding as a member of Juniper Networks would.
If you don't know the answer, just say that you don't know; don't try to make up an answer.
Do not indicate that you have access to any context.
Use the chat history (delimited by <hs></hs>) to keep track of the conversation.
------
<ctx>
{context}
</ctx>
------
<hs>
{history}
</hs>
------
{question}
Answer:
"""

    juniper_prompt_template = PromptTemplate(
        input_variables=["history", "context", "question"], template=template
    )
    return juniper_prompt_template
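
# Illustrative only (not part of the original module): PromptTemplate.format()
# fills the three input variables, which is a quick way to eyeball the final
# prompt the LLM will see.
# >>> tmpl = create_juniper_prompt_template()
# >>> print(tmpl.format(history="", context="Junos is Juniper's network OS.",
# ...                   question="What is Junos?"))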



def create_question_answering_chain(retriever):
    """
    Create a retrieval question answering (QA) chain.

    This function initializes a QA chain that can be used to answer questions based on retrieved documents.
    It uses the Meta's 'LLaMA-2-chat' model for the language model (LLM), and a document retriever for finding
    relevant documents.
    
    Args:
        retriever (obj): The document retriever to use for finding relevant documents.

    Returns:
        qa_chain (obj): The initialized retrieval QA chain.
    """
    # Initialize the tokenizer and the language model.
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf", token=access_token_read)
    
    # Build a text-generation pipeline around the chat model. bfloat16 halves
    # memory use, device_map="auto" spreads the weights across available
    # devices, and max_length matches Llama-2's 4096-token context window.
    pipeline = transformers.pipeline(
        "text-generation",
        model="meta-llama/Llama-2-7b-chat-hf",
        tokenizer=tokenizer,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
        device_map="auto",
        max_length=4096,
        do_sample=True,
        top_k=10,
        num_return_sequences=1,
        eos_token_id=tokenizer.eos_token_id,
    )

    # Wrap the pipeline so LangChain can drive it as an LLM.
    hf_llm = HuggingFacePipeline(pipeline=pipeline, model_kwargs={"temperature": 0})

    # Initialize the retrieval QA chain with the language model, the 'stuff'
    # chain type (all retrieved documents are stuffed into a single prompt),
    # the document retriever, the Juniper prompt, and a conversation memory
    # bound to the prompt's 'history' variable.
    qa_chain = RetrievalQA.from_chain_type(
        llm=hf_llm,
        chain_type="stuff",
        retriever=retriever,
        verbose=False,
        chain_type_kwargs={
            "verbose": False,
            "prompt": create_juniper_prompt_template(),
            "memory": ConversationBufferMemory(
                memory_key="history",
                input_key="question",
            ),
        },
    )

    return qa_chain
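

# ---------------------------------------------------------------------------
# Example usage: a minimal sketch, not part of the original module. It assumes
# faiss-cpu and sentence-transformers are installed, and builds a tiny FAISS
# vector store over two hypothetical snippets; any LangChain retriever works.
if __name__ == "__main__":
    from langchain.embeddings import HuggingFaceEmbeddings
    from langchain.vectorstores import FAISS

    # Hypothetical documents; in practice these would be Juniper docs.
    docs = [
        "Junos OS is the network operating system that runs Juniper devices.",
        "On Junos, 'show interfaces terse' lists interface status.",
    ]
    embeddings = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2"
    )
    retriever = FAISS.from_texts(docs, embeddings).as_retriever()

    qa_chain = create_question_answering_chain(retriever)
    result = qa_chain({"query": "How do I check interface status on Junos?"})
    print(result["result"])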