import os
import time

import streamlit as st
import torch
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_community.llms import HuggingFacePipeline
from langchain.chains import RetrievalQA
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline
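
# To run this app locally (the filename "app.py" is an assumption — use
# whatever name this script is saved under):
#
#   streamlit run app.py
#
# The FAISS index directory (INDEX_SAVE_PATH) must already exist for the bot
# to answer questions; see the sketch after load_faiss_index() below for how
# such an index can be built.
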
# Streamlit page configuration (must be the first Streamlit command executed).
st.set_page_config(page_title="Bot Soal Jawab BM", page_icon="🇲🇾", layout="centered")

# Index location, model checkpoints, and chat avatars.
INDEX_SAVE_PATH = "faiss_malay_ecommerce_kb_index"
EMBEDDING_MODEL_NAME = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
LLM_CHECKPOINT = "google/mt5-small"
ASSISTANT_AVATAR = "🤖"
USER_AVATAR = "👤"

@st.cache_resource
def load_embeddings_model():
    """Loads the Sentence Transformer embedding model (cached across reruns)."""
    print(">> (Cache) Loading embedding model...")
    try:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        embed_model = HuggingFaceEmbeddings(
            model_name=EMBEDDING_MODEL_NAME,
            model_kwargs={'device': device}
        )
        print(f">> Embedding model loaded on {device}.")
        return embed_model
    except Exception as e:
        st.error(f"Error loading embedding model: {e}")
        st.stop()

@st.cache_resource
def load_faiss_index(_embeddings):
    """Loads the FAISS index from the local path."""
    # The leading underscore tells st.cache_resource not to hash this argument.
    print(f">> (Cache) Loading FAISS index from: {INDEX_SAVE_PATH}...")
    if not _embeddings:
        st.error("Cannot load FAISS index without embedding model.")
        return None
    if not os.path.exists(INDEX_SAVE_PATH):
        st.error(f"FAISS index not found at {INDEX_SAVE_PATH}. Pastikan ia wujud hasil dari Notebook Level 2.")
        return None
    try:
        vector_store = FAISS.load_local(
            INDEX_SAVE_PATH,
            _embeddings,
            allow_dangerous_deserialization=True
        )
        print(f">> FAISS index loaded. Contains {vector_store.index.ntotal} vectors.")
        return vector_store
    except Exception as e:
        st.error(f"Error loading FAISS index: {e}")
        return None

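# For reference — a minimal sketch of how a compatible index could have been
# built offline (e.g. in the "Notebook Level 2" mentioned above). `docs` is a
# hypothetical list of already-chunked LangChain Document objects; the actual
# loading/splitting pipeline is not part of this app:
#
#   from langchain_huggingface import HuggingFaceEmbeddings
#   from langchain_community.vectorstores import FAISS
#
#   embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL_NAME)
#   vs = FAISS.from_documents(docs, embeddings)
#   vs.save_local(INDEX_SAVE_PATH)
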
@st.cache_resource
def load_llm_qa_pipeline():
    """Loads the LLM pipeline used for answer generation."""
    print(f">> (Cache) Loading LLM pipeline: {LLM_CHECKPOINT}...")
    try:
        llm_tokenizer = AutoTokenizer.from_pretrained(LLM_CHECKPOINT)
        llm_model = AutoModelForSeq2SeqLM.from_pretrained(LLM_CHECKPOINT)
        # transformers pipelines use 0 for the first GPU and -1 for CPU.
        device = 0 if torch.cuda.is_available() else -1
        pipe = pipeline(
            "text2text-generation",
            model=llm_model,
            tokenizer=llm_tokenizer,
            max_new_tokens=150,
            device=device
        )
        llm_pipe = HuggingFacePipeline(pipeline=pipe)
        print(f">> LLM pipeline loaded on device {device}.")
        return llm_pipe
    except Exception as e:
        st.error(f"Error loading LLM pipeline: {e}")
        st.stop()

# Load the cached resources up front.
embeddings_model = load_embeddings_model()
vector_store = load_faiss_index(embeddings_model)
llm_pipeline = load_llm_qa_pipeline()

# Assemble the RAG chain once the index and the LLM are both available.
qa_chain = None
if vector_store and llm_pipeline:
    try:
        # Retrieve the 3 most similar chunks for each query.
        retriever = vector_store.as_retriever(search_kwargs={"k": 3})
        qa_chain = RetrievalQA.from_chain_type(
            llm=llm_pipeline,
            chain_type="stuff",
            retriever=retriever,
            return_source_documents=True
        )
        print(">> QA Chain ready.")
    except Exception as e:
        st.error(f"Error creating QA chain: {e}")

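# Optional — a sketch (not wired in) of how a Malay prompt template could be
# supplied to the "stuff" chain via chain_type_kwargs if the default English
# prompt ever needs replacing. The template text below is illustrative only:
#
#   from langchain.prompts import PromptTemplate
#
#   malay_prompt = PromptTemplate(
#       template=(
#           "Jawab soalan berdasarkan konteks berikut.\n\n"
#           "Konteks:\n{context}\n\nSoalan: {question}\nJawapan:"
#       ),
#       input_variables=["context", "question"],
#   )
#   # ...then pass chain_type_kwargs={"prompt": malay_prompt} to
#   # RetrievalQA.from_chain_type(...) above.
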
# Initialise the chat history with a greeting on first load.
if "messages" not in st.session_state:
    st.session_state.messages = [
        {"role": "assistant", "avatar": ASSISTANT_AVATAR, "content": "Salam! 👋 Saya Bot Soal Jawab BM. Anda boleh tanya saya soalan berkaitan polisi e-dagang (contoh: Lazada/Shopee) dari pangkalan data saya."}
    ]

# Page header.
st.title("🇲🇾 Bot Soal Jawab Bahasa Melayu (E-Dagang)")
st.caption("Dibangunkan dengan G-v5.6-Go | Streamlit | LangChain | Hugging Face")
st.divider()

# Re-render the full chat history on every Streamlit rerun.
for message in st.session_state.messages:
    with st.chat_message(message["role"], avatar=message.get("avatar")):
        st.markdown(message["content"])

# Handle a new question from the chat input box.
if prompt := st.chat_input("Masukkan soalan anda di sini..."):
    # Record and display the user's message.
    st.session_state.messages.append({"role": "user", "avatar": USER_AVATAR, "content": prompt})
    with st.chat_message("user", avatar=USER_AVATAR):
        st.markdown(prompt)

    # Generate and display the assistant's reply.
    with st.chat_message("assistant", avatar=ASSISTANT_AVATAR):
        # Fallback reply, used if the RAG chain is unavailable or fails.
        assistant_response_content = "Maaf, sistem RAG tidak bersedia. Sila pastikan FAISS index dimuatkan dengan betul."
        if not qa_chain:
            st.error(assistant_response_content)
        else:
            with st.spinner("Mencari jawapan..."):
                try:
                    start_time = time.time()
                    result = qa_chain.invoke({"query": prompt})
                    end_time = time.time()

                    generated_answer = result.get('result', "Maaf, saya tidak dapat menjana jawapan.")
                    # mT5 sometimes emits sentinel tokens (<extra_id_N>) when it
                    # has no usable answer; replace those with a polite fallback.
                    if "<extra_id_" in generated_answer:
                        generated_answer = "Maaf, saya tidak pasti jawapannya berdasarkan maklumat yang ada."
                    st.markdown(generated_answer)
                    assistant_response_content = generated_answer

                    # Show the retrieved source chunks for transparency.
                    source_docs = result.get('source_documents', [])
                    if source_docs:
                        with st.expander("Lihat Sumber Rujukan", expanded=False):
                            for i, doc in enumerate(source_docs):
                                source_name = doc.metadata.get('source', f'Sumber {i + 1}')
                                st.info(f"**{source_name}:**\n\n```\n{doc.page_content}\n```")
                        st.caption(f"Masa mencari: {end_time - start_time:.2f} saat")
                    else:
                        st.warning("Tiada sumber rujukan ditemui.")
                except Exception as e:
                    assistant_response_content = f"Ralat semasa memproses RAG: {e}"
                    st.error(assistant_response_content)

    # Save the assistant's reply to the chat history.
    st.session_state.messages.append({"role": "assistant", "avatar": ASSISTANT_AVATAR, "content": assistant_response_content})