Update app.py
app.py
CHANGED
@@ -43,9 +43,11 @@ def load_model(model_name):
 
     tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
     tokenizer.pad_token = tokenizer.eos_token
-
+
+    print(f"Model Loading Time : {time.time() - start_time}.")
     logger.info(f"Model Loading Time : {time.time() - start_time} .")
 
+
     return model, tokenizer
 
 
@@ -76,6 +78,7 @@ def load_db(device, local_embed=False, CHROMA_PATH = './ChromaDB'):
 
     db = Chroma(persist_directory=CHROMA_PATH, embedding_function=embeddings)
     logger.info(f"Vector Embeddings and Chroma Database Loading Time : {time.time() - start_time} .")
+    print(f"Vector Embeddings and Chroma Database Loading Time : {time.time() - start_time} .")
     return db
 
 
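A note on the `embeddings` object used in this hunk: its construction sits outside the diff. A minimal sketch of how a persisted Chroma store is typically opened with a local embedding function, assuming LangChain's `HuggingFaceEmbeddings`; the embedding model name and device value are illustrative, not taken from app.py.

# Hypothetical setup for the `embeddings` passed to Chroma above;
# the embedding model name and device handling are assumptions.
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Chroma

CHROMA_PATH = './ChromaDB'
embeddings = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-MiniLM-L6-v2",  # assumed embedding model
    model_kwargs={"device": "cpu"},                        # load_db() receives a device argument
)
db = Chroma(persist_directory=CHROMA_PATH, embedding_function=embeddings)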
@@ -105,6 +108,7 @@ def fetch_context(db, model, model_name, query, template, use_compressor=True):
     make sure that returned compressed context is relevant to the query.
     """
     if use_compressor:
+        start_time = time.time()
         if model_name=='llama':
             compressor = LLMChainExtractor.from_llm(model)
             compressor.llm_chain.prompt.template = template['llama_rag_template']
@@ -121,7 +125,7 @@ def fetch_context(db, model, model_name, query, template, use_compressor=True):
         #logger.info(f"User Query : {query}")
         compressed_docs = compression_retriever.get_relevant_documents(query)
         #logger.info(f"Retrieved Compressed Docs : {compressed_docs}")
-
+        print(f"Compressed context Generation Time: {time.time() - start_time}")
         return compressed_docs
 
     docs = db.max_marginal_relevance_search(query)
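The two fetch_context hunks time `compression_retriever.get_relevant_documents(query)`, but the retriever's construction falls between them and is not shown. A minimal sketch of the usual LangChain contextual-compression wiring, reusing the `db` and `model` objects from the surrounding code; only `LLMChainExtractor.from_llm(model)` is confirmed by the diff, the search type is an assumption chosen to mirror the MMR fallback branch.

# Hypothetical wiring of compression_retriever; settings are assumptions,
# not copied from app.py.
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import LLMChainExtractor

compressor = LLMChainExtractor.from_llm(model)           # LLM trims each doc down to query-relevant text
compression_retriever = ContextualCompressionRetriever(
    base_compressor=compressor,
    base_retriever=db.as_retriever(search_type="mmr"),   # assumed search type
)
compressed_docs = compression_retriever.get_relevant_documents(query)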
@@ -145,6 +149,8 @@ def llm_chain_with_context(model, model_name, query, context, template):
     """
     formated_context = format_context(context)
     # Give a precise answer to the question based on the context. Don't be verbose.
+    start_chain_time = time.time()
+
     if model_name=='llama':
         prompt_template = PromptTemplate(input_variables=['context', 'user_query'], template = template['llama_prompt_template'])
         llm_chain = LLMChain(llm=model, prompt=prompt_template)
@@ -152,8 +158,15 @@ def llm_chain_with_context(model, model_name, query, context, template):
     elif model_name=='mistral':
         prompt_template = PromptTemplate(input_variables=['context', 'user_query'], template = template['prompt_template'])
         llm_chain = LLMChain(llm=HF_pipeline_model, prompt=prompt_template)
+
+    print(f"LLMChain Setup Time: {time.time() - start_chain_time}")
+
+    start_inference_time = time.time()
 
     output = llm_chain.predict(user_query=query, context=formated_context)
+
+    print(f"LLM Inference Time: {time.time() - start_inference_time}")
+
     return output
 
 
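The mistral branch builds its chain around `HF_pipeline_model`, which is defined elsewhere in the file. A minimal sketch of how such an object is commonly created, assuming a `transformers` text-generation pipeline wrapped in LangChain's `HuggingFacePipeline`; the model, tokenizer, and generation settings are assumptions.

# Hypothetical construction of HF_pipeline_model; none of these settings
# are taken from app.py.
from transformers import pipeline
from langchain.llms import HuggingFacePipeline

hf_pipe = pipeline(
    "text-generation",
    model=model,              # the causal LM returned by load_model()
    tokenizer=tokenizer,
    max_new_tokens=512,       # assumed generation budget
    return_full_text=False,   # return only the completion, not the prompt
)
HF_pipeline_model = HuggingFacePipeline(pipeline=hf_pipe)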
@@ -170,6 +183,7 @@ def generate_response(query, model, template):
     my_bar.progress(0.6, "Generating Answer. Please wait.")
     response = llm_chain_with_context(model, model_name, query, context, template)
 
+    print(f"Total Execution Time: {time.time() - start_time}")
     logger.info(f"Total Execution Time: {time.time() - start_time}")
 
     my_bar.progress(0.9, "Post Processing. Please wait.")
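Overall, the commit adds a `print` of the same `time.time()` delta next to each existing `logger.info`, plus three new timers (`start_time` in fetch_context, `start_chain_time` and `start_inference_time` in llm_chain_with_context). A small helper could keep the console and log output in sync instead of duplicating each f-string; a minimal sketch using only the standard library and the module's existing `logger` (the helper name `timed` is illustrative, not part of this commit):

import time
from contextlib import contextmanager

@contextmanager
def timed(label):
    # Print and log the elapsed time of the wrapped block, mirroring the
    # print/logger.info pairs added in this commit.
    start = time.time()
    try:
        yield
    finally:
        elapsed = time.time() - start
        print(f"{label}: {elapsed}")
        logger.info(f"{label}: {elapsed}")

# Usage, e.g. inside llm_chain_with_context():
# with timed("LLM Inference Time"):
#     output = llm_chain.predict(user_query=query, context=formated_context)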
|