Manel committed
Commit 9c39b83 · verified · 1 Parent(s): 24e6bd1

Update app.py

Files changed (1)
  app.py  +16 -2
app.py CHANGED
@@ -43,9 +43,11 @@ def load_model(model_name):
 
     tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
     tokenizer.pad_token = tokenizer.eos_token
-
+
+    print(f"Model Loading Time : {time.time() - start_time}.")
     logger.info(f"Model Loading Time : {time.time() - start_time} .")
 
+
     return model, tokenizer
 
 
@@ -76,6 +78,7 @@ def load_db(device, local_embed=False, CHROMA_PATH = './ChromaDB'):
 
     db = Chroma(persist_directory=CHROMA_PATH, embedding_function=embeddings)
     logger.info(f"Vector Embeddings and Chroma Database Loading Time : {time.time() - start_time} .")
+    print(f"Vector Embeddings and Chroma Database Loading Time : {time.time() - start_time} .")
     return db
 
 
@@ -105,6 +108,7 @@ def fetch_context(db, model, model_name, query, template, use_compressor=True):
     make sure that returned compressed context is relevant to the query.
     """
     if use_compressor:
+        start_time = time.time()
         if model_name=='llama':
             compressor = LLMChainExtractor.from_llm(model)
             compressor.llm_chain.prompt.template = template['llama_rag_template']
@@ -121,7 +125,7 @@ def fetch_context(db, model, model_name, query, template, use_compressor=True):
         #logger.info(f"User Query : {query}")
         compressed_docs = compression_retriever.get_relevant_documents(query)
         #logger.info(f"Retrieved Compressed Docs : {compressed_docs}")
-
+        print(f"Compressed context Generation Time: {time.time() - start_time}")
         return compressed_docs
 
     docs = db.max_marginal_relevance_search(query)
@@ -145,6 +149,8 @@ def llm_chain_with_context(model, model_name, query, context, template):
     """
     formated_context = format_context(context)
     # Give a precise answer to the question based on the context. Don't be verbose.
+    start_chain_time = time.time()
+
     if model_name=='llama':
         prompt_template = PromptTemplate(input_variables=['context', 'user_query'], template = template['llama_prompt_template'])
         llm_chain = LLMChain(llm=model, prompt=prompt_template)
@@ -152,8 +158,15 @@ def llm_chain_with_context(model, model_name, query, context, template):
     elif model_name=='mistral':
         prompt_template = PromptTemplate(input_variables=['context', 'user_query'], template = template['prompt_template'])
         llm_chain = LLMChain(llm=HF_pipeline_model, prompt=prompt_template)
+
+    print(f"LLMChain Setup Time: {time.time() - start_chain_time}")
+
+    start_inference_time = time.time()
 
     output = llm_chain.predict(user_query=query, context=formated_context)
+
+    print(f"LLM Inference Time: {time.time() - start_inference_time}")
+
     return output
 
 
@@ -170,6 +183,7 @@ def generate_response(query, model, template):
     my_bar.progress(0.6, "Generating Answer. Please wait.")
     response = llm_chain_with_context(model, model_name, query, context, template)
 
+    print(f"Total Execution Time: {time.time() - start_time}")
    logger.info(f"Total Execution Time: {time.time() - start_time}")
 
    my_bar.progress(0.9, "Post Processing. Please wait.")
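
Note: the commit times each stage (model load, DB load, context compression, chain setup, inference, total) with paired time.time() reads and print calls placed next to the existing logger.info lines. A minimal sketch of how the same instrumentation could be factored into one reusable helper, assuming a hypothetical timed context manager that is not part of this commit:

    # Sketch only: `timed` and this logger setup are illustrative, not repo code.
    import logging
    import time
    from contextlib import contextmanager

    logger = logging.getLogger(__name__)

    @contextmanager
    def timed(label):
        """Measure a block's wall-clock time, then print and log it once."""
        start = time.time()
        try:
            yield
        finally:
            elapsed = time.time() - start
            print(f"{label}: {elapsed:.2f}s")
            logger.info(f"{label}: {elapsed:.2f}s")

    # Example usage, mirroring the instrumentation added in this commit:
    # with timed("LLM Inference Time"):
    #     output = llm_chain.predict(user_query=query, context=formated_context)

A helper like this would keep the console output (print) and the log file (logger.info) in sync and avoid repeating start_time bookkeeping in every function.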