AminFaraji committed on
Commit c9f4236 · verified · 1 Parent(s): 87995aa

Update app.py

Files changed (1)
  1. app.py +110 -56
app.py CHANGED
@@ -1,4 +1,4 @@
- print(9)
+ print(5)
  import argparse
  # from dataclasses import dataclass
  from langchain.prompts import ChatPromptTemplate
@@ -20,6 +20,16 @@ from dotenv import load_dotenv
  import os
  import shutil
  import torch
+
+ from transformers import AutoModel, AutoTokenizer
+ model2 = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
+ tokenizer2 = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
+
+
+ # this should be used when we cannot use sentence_transformers (which requires transformers==4.39; we cannot use
+ # that version since it causes a large amount of RAM to be used when loading the falcon model)
+ # a custom embedding
+ # from sentence_transformers import SentenceTransformer
  from langchain_experimental.text_splitter import SemanticChunker
  from typing import List
  import re
@@ -40,35 +50,6 @@ from transformers import (
  pipeline,
  )

-
- import subprocess
- import sys
-
- def install(package):
-     subprocess.check_call([sys.executable, "-m", "pip", "install", package])
- install('accelerate')
- MODEL_NAME = "tiiuae/falcon-7b-instruct"
-
- llama_pipeline = pipeline(
-     "text-generation",
-     model=MODEL_NAME,
-     torch_dtype=torch.float16,
-     device_map="auto",
- )
-
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
-
-
- from transformers import AutoModel,AutoTokenizer
- model2 = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
- tokenizer2 = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
-
-
- # this shoub be used when we can not use sentence_transformers (which reqiures transformers==4.39. we cannot use
- # this version since causes using large amount of RAm when loading falcon model)
- # a custom embedding
- #from sentence_transformers import SentenceTransformer
-
  warnings.filterwarnings("ignore", category=UserWarning)


@@ -110,10 +91,26 @@ db = Chroma(persist_directory=CHROMA_PATH, embedding_function=embeddings)



+ MODEL_NAME = "tiiuae/falcon-7b-instruct"

+ model = AutoModelForCausalLM.from_pretrained(
+     MODEL_NAME, trust_remote_code=True, device_map="auto", offload_folder="offload"
+ )
+ model = model.eval()

+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+ print(f"Model device: {model.device}")


+ generation_config = model.generation_config
+ generation_config.temperature = 0
+ generation_config.num_return_sequences = 1
+ generation_config.max_new_tokens = 256
+ generation_config.use_cache = False
+ generation_config.repetition_penalty = 1.7
+ generation_config.pad_token_id = tokenizer.eos_token_id
+ generation_config.eos_token_id = tokenizer.eos_token_id
+ generation_config


  prompt = """
@@ -124,9 +121,58 @@ Current conversation:
  Human: Who is Dwight K Schrute?
  AI:
  """.strip()
+ input_ids = tokenizer(prompt, return_tensors="pt").input_ids
+ input_ids = input_ids.to(model.device)
+
+
+
+ class StopGenerationCriteria(StoppingCriteria):
+     def __init__(
+         self, tokens: List[List[str]], tokenizer: AutoTokenizer, device: torch.device
+     ):
+         stop_token_ids = [tokenizer.convert_tokens_to_ids(t) for t in tokens]
+         self.stop_token_ids = [
+             torch.tensor(x, dtype=torch.long, device=device) for x in stop_token_ids
+         ]
+
+     def __call__(
+         self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
+     ) -> bool:
+         for stop_ids in self.stop_token_ids:
+             if torch.eq(input_ids[0][-len(stop_ids):], stop_ids).all():
+                 return True
+         return False
+
+
+ stop_tokens = [["Human", ":"], ["AI", ":"]]
+ stopping_criteria = StoppingCriteriaList(
+     [StopGenerationCriteria(stop_tokens, tokenizer, model.device)]
+ )
+
+ generation_pipeline = pipeline(
+     model=model,
+     tokenizer=tokenizer,
+     return_full_text=True,
+     task="text-generation",
+     stopping_criteria=stopping_criteria,
+     generation_config=generation_config,
+ )
+
+ llm = HuggingFacePipeline(pipeline=generation_pipeline)


+ class CleanupOutputParser(BaseOutputParser):
+     def parse(self, text: str) -> str:
+         user_pattern = r"\nUser"
+         text = re.sub(user_pattern, "", text)
+         human_pattern = r"\nHuman:"
+         text = re.sub(human_pattern, "", text)
+         ai_pattern = r"\nAI:"
+         return re.sub(ai_pattern, "", text).strip()

+     @property
+     def _type(self) -> str:
+         return "output_parser"


  template = """
@@ -137,42 +183,50 @@ Current conversation:

  Human: {input}
  AI:""".strip()
+ prompt = PromptTemplate(input_variables=["history", "input"], template=template)
+
+ memory = ConversationBufferWindowMemory(
+     memory_key="history", k=6, return_only_outputs=True
+ )
+
+ chain = ConversationChain(
+     llm=llm,
+     memory=memory,
+     prompt=prompt,
+     output_parser=CleanupOutputParser(),
+     verbose=True,
+ )


  def get_llama_response(message: str, history: list) -> str:
      query_text = message

-     results = db.similarity_search_with_relevance_scores(query_text, k=3)
+     results = db.similarity_search_with_relevance_scores(query_text, k=1)
      if len(results) == 0 or results[0][1] < 0.5:
          print(f"Unable to find matching results.")


      context_text = "\n\n---\n\n".join([doc.page_content for doc, _score in results])
-     query = """
-     Answer the question based only on the following context. Dont provide any information out of the context:
-
-     {context}
-
-     ---
-
-     Answer the question based on the above context: {question}
-     """
-
-
-     query=query.format(context=context_text,question=message)
-
-     sequences = llama_pipeline(
-         query,
-         do_sample=True,
-         top_k=10,
-         num_return_sequences=1,
-         eos_token_id=tokenizer.eos_token_id,
-         max_length=1024,
-     )
-
-     generated_text = sequences[0]['generated_text']
-     response = generated_text[len(query):]
-     return response.strip()
+     template = """
+     The following is a conversation between a human and an AI. The AI acts exactly like Dwight K Schrute from the TV show The Office.
+     Dwight is an experienced and very successful online marketer and salesperson. He is witty, persuasive, uncensored, direct and practical.
+     Dwight helps with every marketing task that is given to him. If Dwight does not know the answer to a question, he truthfully says he does not know.
+
+     Current conversation:
+     """
+     s = """
+     {history}
+     Human: {input}
+     AI:""".strip()
+
+
+     prompt = PromptTemplate(input_variables=["history", "input"], template=template + context_text + s)
+
+     # print(template)
+     chain.prompt = prompt
+     res = chain(query_text)
+     print("response:", res["response"])
+     return res["response"]

  import gradio as gr

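The updated get_llama_response(message, history) signature matches what Gradio's chat components call on each turn, and the last hunk ends right at "import gradio as gr". The launch code itself is not part of this commit, so the following is only a minimal sketch of how the function might be wired up, assuming a plain gr.ChatInterface front end; the title string and the __main__ guard are illustrative assumptions, not taken from the diff.

# Hypothetical wiring sketch -- the actual Gradio launch code is outside this diff.
import gradio as gr

# get_llama_response(message, history) -> str is defined in app.py above.
demo = gr.ChatInterface(
    fn=get_llama_response,                  # called with (message, history) on every user turn
    title="Dwight K. Schrute RAG chatbot",  # illustrative title, not from the commit
)

if __name__ == "__main__":
    demo.launch()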