|
import os

import torch
import transformers
from transformers import AutoTokenizer

import openai
from langchain.chains import RetrievalQA
from langchain.llms import HuggingFacePipeline
from langchain.memory import ConversationBufferMemory
from langchain.prompts import PromptTemplate

# Read the OpenAI key from the environment instead of hardcoding a secret.
openai.api_key = os.environ.get("OPENAI_API_KEY")
|
def create_juniper_prompt_template():
    """Build the prompt template used by the retrieval QA chain."""
    template = """You are a network engineer at Juniper Networks, not a language model. Use your own knowledge and the pieces of context (delimited by <ctx></ctx>) to answer the user's question, answering as a member of Juniper Networks would.
If you don't know the answer, just say that you don't know; don't try to make up an answer.
Do not indicate that you have access to any context.
Use the chat history (delimited by <hs></hs>) to keep track of the conversation.
----------------
<ctx>
{context}
</ctx>
----------------
<hs>
{history}
</hs>
----------------
{question}
Answer:
"""

    juniper_prompt_template = PromptTemplate(
        input_variables=["history", "context", "question"],
        template=template,
    )
    return juniper_prompt_template
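
# Rendering sketch (illustrative values; in the chain these fields are filled
# by the retriever and the conversation memory at runtime):
#
#   prompt = create_juniper_prompt_template()
#   print(prompt.format(history="", context="<retrieved docs>", question="What is EVPN?"))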
|
def create_question_answering_chain(retriever):
    """
    Create a retrieval question answering (QA) chain.

    This function initializes a QA chain that can be used to answer questions
    based on retrieved documents. It uses the 'meta-llama/Llama-2-7b-chat-hf'
    model, served through a Hugging Face text-generation pipeline, as the
    language model (LLM), and a document retriever for finding relevant
    documents.

    Args:
        retriever (obj): The document retriever to use for finding relevant documents.

    Returns:
        qa_chain (obj): The initialized retrieval QA chain.
    """
    model = "meta-llama/Llama-2-7b-chat-hf"
    tokenizer = AutoTokenizer.from_pretrained(model)

    # Build a text-generation pipeline; transformers accepts the model id as a
    # string and loads the weights itself.
    pipeline = transformers.pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        torch_dtype=torch.bfloat16,  # halves memory use on supported hardware
        trust_remote_code=True,
        device_map="auto",  # spread layers across available devices
        max_length=1000,  # cap on prompt + completion tokens
        do_sample=True,
        top_k=10,
        num_return_sequences=1,
        eos_token_id=tokenizer.eos_token_id,
    )

    llm = HuggingFacePipeline(pipeline=pipeline, model_kwargs={"temperature": 0})
|
    # The prompt's "history" and "question" variables must match the memory's
    # memory_key and input_key so the buffer lands in the right template slot.
    qa_chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        verbose=False,
        chain_type_kwargs={
            "verbose": False,
            "prompt": create_juniper_prompt_template(),
            "memory": ConversationBufferMemory(
                memory_key="history",
                input_key="question",
            ),
        },
    )

    return qa_chain
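

# A minimal usage sketch, assuming a Chroma vector store already populated with
# Juniper documentation. The persist directory, embedding model, and question
# below are illustrative assumptions, not part of this module.
if __name__ == "__main__":
    from langchain.embeddings import HuggingFaceEmbeddings
    from langchain.vectorstores import Chroma

    embeddings = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2"  # hypothetical choice
    )
    store = Chroma(persist_directory="./chroma_db", embedding_function=embeddings)

    chain = create_question_answering_chain(store.as_retriever())
    print(chain.run("How do I configure an EVPN-VXLAN fabric?"))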