hmrizal commited on
Commit
7fb09f2
·
verified ·
1 Parent(s): c1c3142

update change_model, process_file, create_llm_pipeline, explicit button to change model

Browse files
Files changed (1) hide show
  1. app.py +102 -59
app.py CHANGED
@@ -113,35 +113,45 @@ def initialize_model_once(model_key):
113
 
114
  def create_llm_pipeline(model_key):
115
  """Create a new pipeline using the specified model"""
116
- tokenizer, model, is_t5 = initialize_model_once(model_key)
117
-
118
- # Create appropriate pipeline based on model type
119
- if is_t5:
120
- pipe = pipeline(
121
- "text2text-generation",
122
- model=model,
123
- tokenizer=tokenizer,
124
- max_new_tokens=256,
125
- temperature=0.3,
126
- top_p=0.9,
127
- return_full_text=False,
128
- )
129
- else:
130
- pipe = pipeline(
131
- "text-generation",
132
- model=model,
133
- tokenizer=tokenizer,
134
- max_new_tokens=256,
135
- temperature=0.3,
136
- top_p=0.9,
137
- top_k=30,
138
- repetition_penalty=1.2,
139
- return_full_text=False,
140
- )
141
-
142
- # Wrap pipeline in HuggingFacePipeline for LangChain compatibility
143
- return HuggingFacePipeline(pipeline=pipe)
144
-
 
 
 
 
 
 
 
 
 
 
145
  def create_conversational_chain(db, file_path, model_key):
146
  llm = create_llm_pipeline(model_key)
147
 
@@ -281,14 +291,16 @@ class ChatBot:
281
  def process_file(self, file, model_key=None):
282
  if model_key:
283
  self.model_key = model_key
284
-
285
  if file is None:
286
  return "Mohon upload file CSV terlebih dahulu."
287
 
288
  try:
 
289
  # Handle file from Gradio
290
  file_path = file.name if hasattr(file, 'name') else str(file)
291
  self.csv_file_path = file_path
 
292
 
293
  # Copy to user directory
294
  user_file_path = f"{self.user_dir}/uploaded.csv"
@@ -301,22 +313,25 @@ class ChatBot:
301
  # Save a copy in user directory
302
  df.to_csv(user_file_path, index=False)
303
  self.csv_file_path = user_file_path
 
304
  except Exception as e:
 
305
  return f"Error membaca CSV: {str(e)}"
306
 
307
  # Load document with reduced chunk size for better memory usage
308
  try:
309
- loader = CSVLoader(file_path=file_path, encoding="utf-8", csv_args={
310
  'delimiter': ','})
311
  data = loader.load()
312
  print(f"Documents loaded: {len(data)}")
313
  except Exception as e:
 
314
  return f"Error loading documents: {str(e)}"
315
 
316
  # Create vector database with optimized settings
317
  try:
318
  db_path = f"{self.user_dir}/db_faiss"
319
-
320
  # Use CPU-friendly embeddings with smaller dimensions
321
  embeddings = HuggingFaceEmbeddings(
322
  model_name='sentence-transformers/all-MiniLM-L6-v2',
@@ -327,13 +342,18 @@ class ChatBot:
327
  db.save_local(db_path)
328
  print(f"Vector database created at {db_path}")
329
  except Exception as e:
 
330
  return f"Error creating vector database: {str(e)}"
331
 
332
  # Create custom chain
333
  try:
 
334
  self.chain = create_conversational_chain(db, self.csv_file_path, self.model_key)
335
- print(f"Chain created successfully using model: {self.model_key}")
336
  except Exception as e:
 
 
 
337
  return f"Error creating chain: {str(e)}"
338
 
339
  # Add basic file info to chat history for context
@@ -348,32 +368,54 @@ class ChatBot:
348
 
349
  def change_model(self, model_key):
350
  """Change the model being used and recreate the chain if necessary"""
351
- if model_key == self.model_key:
352
- return f"Model {model_key} sudah digunakan."
353
-
354
- self.model_key = model_key
355
-
356
- # If we have an active session with a file already loaded, recreate the chain
357
- if self.csv_file_path:
358
- try:
359
- # Load existing database
360
- db_path = f"{self.user_dir}/db_faiss"
361
- embeddings = HuggingFaceEmbeddings(
362
- model_name='sentence-transformers/all-MiniLM-L6-v2',
363
- model_kwargs={'device': 'cpu'}
364
- )
365
-
366
- # Tambahkan flag allow_dangerous_deserialization=True
367
- db = FAISS.load_local(db_path, embeddings, allow_dangerous_deserialization=True)
368
-
369
- # Create new chain with the selected model
370
- self.chain = create_conversational_chain(db, self.csv_file_path, self.model_key)
371
 
372
- return f"Model berhasil diubah ke {model_key}."
373
- except Exception as e:
374
- return f"Error mengubah model: {str(e)}"
375
- else:
376
- return f"Model diubah ke {model_key}. Silakan upload file CSV untuk memulai."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
377
 
378
  def chat(self, message, history):
379
  if self.chain is None:
@@ -430,6 +472,7 @@ def create_gradio_interface():
430
  model_info = gr.Markdown(
431
  value=f"**{default_model}**: {MODEL_CONFIG[default_model]['description']}"
432
  )
 
433
 
434
  # Process button AFTER the accordion
435
  process_button = gr.Button("Proses CSV")
@@ -478,7 +521,7 @@ def create_gradio_interface():
478
  result = chatbot.change_model(model_key)
479
  return chatbot, chatbot.chat_history + [(None, result)]
480
 
481
- model_dropdown.change(
482
  fn=handle_model_change,
483
  inputs=[model_dropdown, chatbot_state, session_id],
484
  outputs=[chatbot_state, chatbot_interface]
 
113
 
114
  def create_llm_pipeline(model_key):
115
  """Create a new pipeline using the specified model"""
116
+ try:
117
+ print(f"Creating pipeline for model: {model_key}")
118
+ tokenizer, model, is_t5 = initialize_model_once(model_key)
119
+
120
+ # Create appropriate pipeline based on model type
121
+ if is_t5:
122
+ print("Creating T5 pipeline")
123
+ pipe = pipeline(
124
+ "text2text-generation",
125
+ model=model,
126
+ tokenizer=tokenizer,
127
+ max_new_tokens=256,
128
+ temperature=0.3,
129
+ top_p=0.9,
130
+ return_full_text=False,
131
+ )
132
+ else:
133
+ print("Creating causal LM pipeline")
134
+ pipe = pipeline(
135
+ "text-generation",
136
+ model=model,
137
+ tokenizer=tokenizer,
138
+ max_new_tokens=256,
139
+ temperature=0.3,
140
+ top_p=0.9,
141
+ top_k=30,
142
+ repetition_penalty=1.2,
143
+ return_full_text=False,
144
+ )
145
+
146
+ print("Pipeline created successfully")
147
+ # Wrap pipeline in HuggingFacePipeline for LangChain compatibility
148
+ return HuggingFacePipeline(pipeline=pipe)
149
+ except Exception as e:
150
+ import traceback
151
+ print(f"Error creating pipeline: {str(e)}")
152
+ print(traceback.format_exc())
153
+ raise
154
+
155
  def create_conversational_chain(db, file_path, model_key):
156
  llm = create_llm_pipeline(model_key)
157
 
 
291
  def process_file(self, file, model_key=None):
292
  if model_key:
293
  self.model_key = model_key
294
+
295
  if file is None:
296
  return "Mohon upload file CSV terlebih dahulu."
297
 
298
  try:
299
+ print(f"Processing file using model: {self.model_key}")
300
  # Handle file from Gradio
301
  file_path = file.name if hasattr(file, 'name') else str(file)
302
  self.csv_file_path = file_path
303
+ print(f"CSV file path: {file_path}")
304
 
305
  # Copy to user directory
306
  user_file_path = f"{self.user_dir}/uploaded.csv"
 
313
  # Save a copy in user directory
314
  df.to_csv(user_file_path, index=False)
315
  self.csv_file_path = user_file_path
316
+ print(f"CSV saved to {user_file_path}")
317
  except Exception as e:
318
+ print(f"Error reading CSV: {str(e)}")
319
  return f"Error membaca CSV: {str(e)}"
320
 
321
  # Load document with reduced chunk size for better memory usage
322
  try:
323
+ loader = CSVLoader(file_path=user_file_path, encoding="utf-8", csv_args={
324
  'delimiter': ','})
325
  data = loader.load()
326
  print(f"Documents loaded: {len(data)}")
327
  except Exception as e:
328
+ print(f"Error loading documents: {str(e)}")
329
  return f"Error loading documents: {str(e)}"
330
 
331
  # Create vector database with optimized settings
332
  try:
333
  db_path = f"{self.user_dir}/db_faiss"
334
+
335
  # Use CPU-friendly embeddings with smaller dimensions
336
  embeddings = HuggingFaceEmbeddings(
337
  model_name='sentence-transformers/all-MiniLM-L6-v2',
 
342
  db.save_local(db_path)
343
  print(f"Vector database created at {db_path}")
344
  except Exception as e:
345
+ print(f"Error creating vector database: {str(e)}")
346
  return f"Error creating vector database: {str(e)}"
347
 
348
  # Create custom chain
349
  try:
350
+ print(f"Creating conversation chain with model: {self.model_key}")
351
  self.chain = create_conversational_chain(db, self.csv_file_path, self.model_key)
352
+ print("Chain created successfully")
353
  except Exception as e:
354
+ import traceback
355
+ print(f"Error creating chain: {str(e)}")
356
+ print(traceback.format_exc())
357
  return f"Error creating chain: {str(e)}"
358
 
359
  # Add basic file info to chat history for context
 
368
 
369
  def change_model(self, model_key):
370
  """Change the model being used and recreate the chain if necessary"""
371
+ try:
372
+ if model_key == self.model_key:
373
+ return f"Model {model_key} sudah digunakan."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
374
 
375
+ print(f"Changing model from {self.model_key} to {model_key}")
376
+ self.model_key = model_key
377
+
378
+ # If we have an active session with a file already loaded, recreate the chain
379
+ if self.csv_file_path and os.path.exists(self.csv_file_path):
380
+ try:
381
+ # Load existing database
382
+ db_path = f"{self.user_dir}/db_faiss"
383
+ if not os.path.exists(db_path):
384
+ return f"Error: Database tidak ditemukan. Silakan upload file CSV kembali."
385
+
386
+ print(f"Loading embeddings from {db_path}")
387
+ embeddings = HuggingFaceEmbeddings(
388
+ model_name='sentence-transformers/all-MiniLM-L6-v2',
389
+ model_kwargs={'device': 'cpu'}
390
+ )
391
+
392
+ # Tambahkan flag allow_dangerous_deserialization=True
393
+ db = FAISS.load_local(db_path, embeddings, allow_dangerous_deserialization=True)
394
+ print(f"FAISS database loaded successfully")
395
+
396
+ # Create new chain with the selected model
397
+ print(f"Creating new conversation chain with {model_key}")
398
+ self.chain = create_conversational_chain(db, self.csv_file_path, self.model_key)
399
+ print(f"Chain created successfully")
400
+
401
+ # Add notification to chat history
402
+ self.chat_history.append(("System", f"Model berhasil diubah ke {model_key}."))
403
+
404
+ return f"Model berhasil diubah ke {model_key}."
405
+ except Exception as e:
406
+ import traceback
407
+ error_trace = traceback.format_exc()
408
+ print(f"Detailed error in change_model: {error_trace}")
409
+ return f"Error mengubah model: {str(e)}"
410
+ else:
411
+ # Just update the model key if no file is loaded yet
412
+ print(f"No CSV file loaded yet, just updating model preference to {model_key}")
413
+ return f"Model diubah ke {model_key}. Silakan upload file CSV untuk memulai."
414
+ except Exception as e:
415
+ import traceback
416
+ error_trace = traceback.format_exc()
417
+ print(f"Unexpected error in change_model: {error_trace}")
418
+ return f"Error tidak terduga saat mengubah model: {str(e)}"
419
 
420
  def chat(self, message, history):
421
  if self.chain is None:
 
472
  model_info = gr.Markdown(
473
  value=f"**{default_model}**: {MODEL_CONFIG[default_model]['description']}"
474
  )
475
+ change_model_button = gr.Button("Terapkan Perubahan Model")
476
 
477
  # Process button AFTER the accordion
478
  process_button = gr.Button("Proses CSV")
 
521
  result = chatbot.change_model(model_key)
522
  return chatbot, chatbot.chat_history + [(None, result)]
523
 
524
+ change_model_button.click(
525
  fn=handle_model_change,
526
  inputs=[model_dropdown, chatbot_state, session_id],
527
  outputs=[chatbot_state, chatbot_interface]