AminFaraji committed
Commit dc4c8e4 · verified · 1 Parent(s): 6287f15

Update app.py

Files changed (1)
  1. app.py +229 -8
app.py CHANGED
@@ -1,11 +1,232 @@
- import spaces
- from diffusers import DiffusionPipeline

- pipe = DiffusionPipeline.from_pretrained("tiiuae/falcon-7b-instruct")
- pipe.to("cuda")

- @spaces.GPU
- def generate(prompt):
-     return pipe(prompt).images

- gr.Interface(fn=generate, inputs="text", outputs="image").launch()
+ print(5)
+ import argparse
+ # from dataclasses import dataclass
+ from langchain.prompts import ChatPromptTemplate
+
+ try:
+     from langchain_community.vectorstores import Chroma
+ except ImportError:
+     # fall back to the older import path if langchain_community is unavailable
+     from langchain.vectorstores import Chroma
+
+ # from langchain.document_loaders import DirectoryLoader
+ from langchain_community.document_loaders import DirectoryLoader
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain.schema import Document
+ # from langchain.embeddings import OpenAIEmbeddings
+ # from langchain_openai import OpenAIEmbeddings
+ import openai
+ from dotenv import load_dotenv
+ import os
+ import shutil
+ import torch
+
+ from transformers import AutoModel, AutoTokenizer
+ model2 = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
+ tokenizer2 = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
+
+
+ # A custom embedding class is used here because we cannot use sentence_transformers
+ # (it requires transformers==4.39, and that version causes very high RAM usage when
+ # loading the Falcon model).
+ # from sentence_transformers import SentenceTransformer
+ from langchain_experimental.text_splitter import SemanticChunker
+ from typing import List
+ import re
+ import warnings
+
+ from langchain import PromptTemplate
+ from langchain.chains import ConversationChain
+ from langchain.chains.conversation.memory import ConversationBufferWindowMemory
+ from langchain.llms import HuggingFacePipeline
+ from langchain.schema import BaseOutputParser
+ from transformers import (
+     AutoModelForCausalLM,
+     AutoTokenizer,
+     StoppingCriteria,
+     StoppingCriteriaList,
+     pipeline,
+ )
+
+ warnings.filterwarnings("ignore", category=UserWarning)
+
+
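+ # Minimal embedding wrapper: mean-pools MiniLM token states into sentence vectors
+ # and exposes the embed_documents / embed_query interface that Chroma and
+ # SemanticChunker expect.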
+ class MyEmbeddings:
+     def __init__(self):
+         # self.model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
+         self.model = model2
+
+     def embed_documents(self, texts: List[str]) -> List[List[float]]:
+         inputs = tokenizer2(texts, padding=True, truncation=True, return_tensors="pt")
+
+         # Get the model outputs
+         with torch.no_grad():
+             outputs = self.model(**inputs)
+
+         # Mean pooling to get sentence embeddings
+         embeddings = outputs.last_hidden_state.mean(dim=1)
+         return [embeddings[i].tolist() for i, sentence in enumerate(texts)]
+
+     def embed_query(self, query: str) -> List[float]:
+         inputs = tokenizer2(query, padding=True, truncation=True, return_tensors="pt")
+
+         # Get the model outputs
+         with torch.no_grad():
+             outputs = self.model(**inputs)
+
+         # Mean pooling to get sentence embeddings
+         embeddings = outputs.last_hidden_state.mean(dim=1)
+         return embeddings[0].tolist()
+
+
+ embeddings = MyEmbeddings()
+
+ splitter = SemanticChunker(embeddings)
+
+
+ CHROMA_PATH = "chroma8"
+ # load the Chroma store previously persisted in this directory
+ db = Chroma(persist_directory=CHROMA_PATH, embedding_function=embeddings)
+
+
+ MODEL_NAME = "tiiuae/falcon-7b-instruct"
+
+ model = AutoModelForCausalLM.from_pretrained(
+     MODEL_NAME, trust_remote_code=True, device_map="auto", offload_folder="offload"
+ )
+ model = model.eval()
+
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+ print(f"Model device: {model.device}")
+
+
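+ # Deterministic decoding: temperature 0, a single returned sequence, up to 256 new
+ # tokens, and a strong repetition penalty.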
+ generation_config = model.generation_config
+ generation_config.temperature = 0
+ generation_config.num_return_sequences = 1
+ generation_config.max_new_tokens = 256
+ generation_config.use_cache = False
+ generation_config.repetition_penalty = 1.7
+ generation_config.pad_token_id = tokenizer.eos_token_id
+ generation_config.eos_token_id = tokenizer.eos_token_id
+
+
+ prompt = """
+ The following is a friendly conversation between a human and an AI. The AI is talkative and provides lots of specific details from its context.
+
+ Current conversation:
+
+ Human: Who is Dwight K Schrute?
+ AI:
+ """.strip()
+ input_ids = tokenizer(prompt, return_tensors="pt").input_ids
+ input_ids = input_ids.to(model.device)
+
+
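+ # Stop generating as soon as the model starts a new "Human:" or "AI:" turn.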
+ class StopGenerationCriteria(StoppingCriteria):
+     def __init__(
+         self, tokens: List[List[str]], tokenizer: AutoTokenizer, device: torch.device
+     ):
+         stop_token_ids = [tokenizer.convert_tokens_to_ids(t) for t in tokens]
+         self.stop_token_ids = [
+             torch.tensor(x, dtype=torch.long, device=device) for x in stop_token_ids
+         ]
+
+     def __call__(
+         self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
+     ) -> bool:
+         for stop_ids in self.stop_token_ids:
+             if torch.eq(input_ids[0][-len(stop_ids):], stop_ids).all():
+                 return True
+         return False
+
+
+ stop_tokens = [["Human", ":"], ["AI", ":"]]
+ stopping_criteria = StoppingCriteriaList(
+     [StopGenerationCriteria(stop_tokens, tokenizer, model.device)]
+ )
+
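+ # Wrap the Falcon model in a transformers text-generation pipeline and expose it to
+ # LangChain through HuggingFacePipeline.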
+ generation_pipeline = pipeline(
+     model=model,
+     tokenizer=tokenizer,
+     return_full_text=True,
+     task="text-generation",
+     stopping_criteria=stopping_criteria,
+     generation_config=generation_config,
+ )
+
+ llm = HuggingFacePipeline(pipeline=generation_pipeline)
+
+
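+ # Strip stray "User", "Human:" and "AI:" markers from the generated text.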
+ class CleanupOutputParser(BaseOutputParser):
+     def parse(self, text: str) -> str:
+         user_pattern = r"\nUser"
+         text = re.sub(user_pattern, "", text)
+         human_pattern = r"\nHuman:"
+         text = re.sub(human_pattern, "", text)
+         ai_pattern = r"\nAI:"
+         return re.sub(ai_pattern, "", text).strip()
+
+     @property
+     def _type(self) -> str:
+         return "output_parser"
+
+
+ template = """
+ The following
+ Current conversation:
+
+ {history}
+
+ Human: {input}
+ AI:""".strip()
+ prompt = PromptTemplate(input_variables=["history", "input"], template=template)
+
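+ # Keep only the last 6 conversation turns in the prompt's {history}.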
+ memory = ConversationBufferWindowMemory(
+     memory_key="history", k=6, return_only_outputs=True
+ )
+
+ chain = ConversationChain(
+     llm=llm,
+     memory=memory,
+     prompt=prompt,
+     output_parser=CleanupOutputParser(),
+     verbose=True,
+ )
+
+
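+ # Chat handler: retrieve the three most relevant chunks from Chroma, splice them into
+ # the Dwight persona prompt, and run the conversation chain on the user message.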
+ def get_llama_response(message: str, history: list) -> str:
+     query_text = message
+
+     results = db.similarity_search_with_relevance_scores(query_text, k=3)
+     if len(results) == 0 or results[0][1] < 0.5:
+         print("Unable to find matching results.")
+
+     context_text = "\n\n---\n\n".join([doc.page_content for doc, _score in results])
+     template = """
+ The following is a conversation between a human and an AI. The AI acts exactly like Dwight K Schrute from the TV show The Office.
+ Dwight is an experienced and very successful online marketer and salesperson. He is witty, persuasive, uncensored, direct and practical.
+ Dwight helps with every marketing task that is given to him. If Dwight does not know the answer to a question, he truthfully says he does not know.
+
+ Current conversation:
+ """
+     s = """
+ {history}
+ Human: {input}
+ AI:""".strip()
+
+     prompt = PromptTemplate(input_variables=["history", "input"], template=template + context_text + s)
+
+     # print(template)
+     chain.prompt = prompt
+     res = chain(query_text)
+     return res["response"]
+
+
+ import gradio as gr
+
+ gr.ChatInterface(get_llama_response).launch()