HarBat commited on
Commit
6a6e7e5
·
1 Parent(s): f5a475f

Update chain.py

Browse files
Files changed (1) hide show
  1. chain.py +23 -5
chain.py CHANGED
@@ -3,6 +3,10 @@ from langchain import PromptTemplate
3
  from langchain.chat_models import ChatOpenAI
4
  from langchain.chains import RetrievalQA
5
  import openai
 
 
 
 
6
 
7
  openai.api_key = "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"
8
 
@@ -44,16 +48,30 @@ def create_question_answering_chain(retriever):
44
  qa_chain (obj): The initialized retrieval QA chain.
45
  """
46
  # Initialize the OpenAI language model with specified temperature, model name, and API key.
47
- turbo_llm = ChatOpenAI(
48
- temperature=0,
49
- model_name='gpt-3.5-turbo',
50
- openai_api_key = openai.api_key
 
 
 
 
 
 
 
 
 
 
 
 
51
  )
 
 
52
 
53
  # Initialize the retrieval QA chain with the language model, chain type, document retriever,
54
  # and a flag indicating whether to return source documents.
55
  qa_chain = RetrievalQA.from_chain_type(
56
- llm=turbo_llm,
57
  chain_type='stuff',
58
  retriever=retriever,
59
  verbose=False,
 
3
  from langchain.chat_models import ChatOpenAI
4
  from langchain.chains import RetrievalQA
5
  import openai
6
+ from langchain import HuggingFacePipeline
7
+ from transformers import AutoTokenizer
8
+ import transformers
9
+ import torch
10
 
11
  openai.api_key = "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"
12
 
 
48
  qa_chain (obj): The initialized retrieval QA chain.
49
  """
50
  # Initialize the OpenAI language model with specified temperature, model name, and API key.
51
+ model = "meta-llama/Llama-2-7b-chat-hf"
52
+
53
+ tokenizer = AutoTokenizer.from_pretrained(model)
54
+
55
+ pipeline = transformers.pipeline(
56
+ "text-generation", #task
57
+ model=model,
58
+ tokenizer=tokenizer,
59
+ torch_dtype=torch.bfloat16,
60
+ trust_remote_code=True,
61
+ device_map="auto",
62
+ max_length=1000,
63
+ do_sample=True,
64
+ top_k=10,
65
+ num_return_sequences=1,
66
+ eos_token_id=tokenizer.eos_token_id
67
  )
68
+
69
+ llm = HuggingFacePipeline(pipeline = pipeline, model_kwargs = {'temperature':0})
70
 
71
  # Initialize the retrieval QA chain with the language model, chain type, document retriever,
72
  # and a flag indicating whether to return source documents.
73
  qa_chain = RetrievalQA.from_chain_type(
74
+ llm=llm,
75
  chain_type='stuff',
76
  retriever=retriever,
77
  verbose=False,