Update chain.py
Browse files
chain.py
CHANGED
@@ -3,6 +3,10 @@ from langchain import PromptTemplate
|
|
3 |
from langchain.chat_models import ChatOpenAI
|
4 |
from langchain.chains import RetrievalQA
|
5 |
import openai
|
|
|
|
|
|
|
|
|
6 |
|
7 |
openai.api_key = "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"
|
8 |
|
@@ -44,16 +48,30 @@ def create_question_answering_chain(retriever):
|
|
44 |
qa_chain (obj): The initialized retrieval QA chain.
|
45 |
"""
|
46 |
# Initialize the OpenAI language model with specified temperature, model name, and API key.
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
51 |
)
|
|
|
|
|
52 |
|
53 |
# Initialize the retrieval QA chain with the language model, chain type, document retriever,
|
54 |
# and a flag indicating whether to return source documents.
|
55 |
qa_chain = RetrievalQA.from_chain_type(
|
56 |
-
llm=
|
57 |
chain_type='stuff',
|
58 |
retriever=retriever,
|
59 |
verbose=False,
|
|
|
3 |
from langchain.chat_models import ChatOpenAI
|
4 |
from langchain.chains import RetrievalQA
|
5 |
import openai
|
6 |
+
from langchain import HuggingFacePipeline
|
7 |
+
from transformers import AutoTokenizer
|
8 |
+
import transformers
|
9 |
+
import torch
|
10 |
|
11 |
openai.api_key = "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"
|
12 |
|
|
|
48 |
qa_chain (obj): The initialized retrieval QA chain.
|
49 |
"""
|
50 |
# Initialize the OpenAI language model with specified temperature, model name, and API key.
|
51 |
+
model = "meta-llama/Llama-2-7b-chat-hf"
|
52 |
+
|
53 |
+
tokenizer = AutoTokenizer.from_pretrained(model)
|
54 |
+
|
55 |
+
pipeline = transformers.pipeline(
|
56 |
+
"text-generation", #task
|
57 |
+
model=model,
|
58 |
+
tokenizer=tokenizer,
|
59 |
+
torch_dtype=torch.bfloat16,
|
60 |
+
trust_remote_code=True,
|
61 |
+
device_map="auto",
|
62 |
+
max_length=1000,
|
63 |
+
do_sample=True,
|
64 |
+
top_k=10,
|
65 |
+
num_return_sequences=1,
|
66 |
+
eos_token_id=tokenizer.eos_token_id
|
67 |
)
|
68 |
+
|
69 |
+
llm = HuggingFacePipeline(pipeline = pipeline, model_kwargs = {'temperature':0})
|
70 |
|
71 |
# Initialize the retrieval QA chain with the language model, chain type, document retriever,
|
72 |
# and a flag indicating whether to return source documents.
|
73 |
qa_chain = RetrievalQA.from_chain_type(
|
74 |
+
llm=llm,
|
75 |
chain_type='stuff',
|
76 |
retriever=retriever,
|
77 |
verbose=False,
|