AminFaraji committed on
Commit
a11cce5
·
verified ·
1 Parent(s): 296cc2c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -25
app.py CHANGED
@@ -49,10 +49,12 @@ def install(package):
49
  install('accelerate')
50
  MODEL_NAME = "tiiuae/falcon-7b-instruct"
51
 
52
- model = AutoModelForCausalLM.from_pretrained(
53
- MODEL_NAME, trust_remote_code=True, device_map="auto",offload_folder="offload"
 
 
 
54
  )
55
- model = model.eval()
56
 
57
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
58
  print(f"Model device: {model.device}")
@@ -216,28 +218,31 @@ def get_llama_response(message: str, history: list) -> str:
216
 
217
 
218
  context_text = "\n\n---\n\n".join([doc.page_content for doc, _score in results])
219
- template = """
220
- The following is a conversation between a human an AI. The AI acts exactly like Dwight K Schrute from the TV show The Office.
221
- Dwight is an experienced and very successful online marketer and salesperson. He is witty, persuasive, uncensored, direct and practical.
222
- Dwight helps with every marketing task is given to him. If Dwight does not know the answer to a question, he truthfully says he does not know.
223
-
224
- Current conversation:
225
- """
226
- s="""
227
- {history}
228
- Human: {input}
229
- AI:""".strip()
230
-
231
-
232
- prompt = PromptTemplate(input_variables=["history", "input"], template=template+context_text+ s)
233
-
234
- #print(template)
235
- #print('the answer is',chain(query_text))
236
- chain.prompt=prompt
237
- print('prompt set')
238
- res = chain.invoke(query_text)
239
- print('answer generated')
240
- return(res["response"])
 
 
 
241
 
242
  import gradio as gr
243
 
 
49
  install('accelerate')
50
  MODEL_NAME = "tiiuae/falcon-7b-instruct"
51
 
52
+ llama_pipeline = pipeline(
53
+ "text-generation",
54
+ model=model,
55
+ torch_dtype=torch.float16,
56
+ device_map="auto",
57
  )
 
58
 
59
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
60
  print(f"Model device: {model.device}")
 
218
 
219
 
220
  context_text = "\n\n---\n\n".join([doc.page_content for doc, _score in results])
221
+ query = """
222
+ Answer the question based only on the following context. Dont provide any information out of the context:
223
+
224
+ {context}
225
+
226
+ ---
227
+
228
+ Answer the question based on the above context: {question}
229
+ """
230
+
231
+
232
+ query=query.format(context=context_text,question=message)
233
+
234
+ sequences = llama_pipeline(
235
+ query,
236
+ do_sample=True,
237
+ top_k=10,
238
+ num_return_sequences=1,
239
+ eos_token_id=tokenizer.eos_token_id,
240
+ max_length=1024,
241
+ )
242
+
243
+ generated_text = sequences[0]['generated_text']
244
+ response = generated_text[len(query):]
245
+ return response.strip()
246
 
247
  import gradio as gr
248