import logging

import streamlit as st
from requests.exceptions import JSONDecodeError
from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings
from langchain_community.vectorstores import SupabaseVectorStore
from langchain_community.llms import HuggingFaceEndpoint
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from supabase import Client, create_client
from streamlit.logger import get_logger

# Configure logging
logger = get_logger(__name__)
logging.basicConfig(level=logging.INFO)

# Load secrets
supabase_url = st.secrets["SUPABASE_URL"]
supabase_key = st.secrets["SUPABASE_KEY"]
hf_api_key = st.secrets["hf_api_key"]
username = st.secrets["username"]

# Initialize Supabase client
supabase: Client = create_client(supabase_url, supabase_key)

# Dimension of the zero-vector fallback; BAAI/bge-large-en-v1.5 produces 1024-dimensional embeddings
EMBEDDING_DIM = 1024


# Custom HuggingFaceInferenceAPIEmbeddings to handle JSONDecodeError
class CustomHuggingFaceInferenceAPIEmbeddings(HuggingFaceInferenceAPIEmbeddings):
    def embed_query(self, text: str):
        try:
            response = self.client.post(
                json={"inputs": text, "options": {"use_cache": False}},
                task="feature-extraction",
            )
            if response.status_code != 200:
                logger.error(f"API request failed with status {response.status_code}: {response.text}")
                return [0.0] * EMBEDDING_DIM  # Return zero vector of expected dimension
            try:
                embeddings = response.json()
                if not isinstance(embeddings, list) or not embeddings:
                    logger.error(f"Invalid embeddings response: {embeddings}")
                    return [0.0] * EMBEDDING_DIM
                return embeddings[0]
            except JSONDecodeError as e:
                logger.error(f"JSON decode error: {str(e)}, response: {response.text}")
                return [0.0] * EMBEDDING_DIM
        except Exception as e:
            logger.error(f"Error embedding query: {str(e)}")
            return [0.0] * EMBEDDING_DIM

    def embed_documents(self, texts):
        try:
            response = self.client.post(
                json={"inputs": texts, "options": {"use_cache": False}},
                task="feature-extraction",
            )
            if response.status_code != 200:
                logger.error(f"API request failed with status {response.status_code}: {response.text}")
                return [[0.0] * EMBEDDING_DIM for _ in texts]
            try:
                embeddings = response.json()
                if not isinstance(embeddings, list) or not embeddings:
                    logger.error(f"Invalid embeddings response: {embeddings}")
                    return [[0.0] * EMBEDDING_DIM for _ in texts]
                return [emb[0] for emb in embeddings]
            except JSONDecodeError as e:
                logger.error(f"JSON decode error: {str(e)}, response: {response.text}")
                return [[0.0] * EMBEDDING_DIM for _ in texts]
        except Exception as e:
            logger.error(f"Error embedding documents: {str(e)}")
            return [[0.0] * EMBEDDING_DIM for _ in texts]


# Initialize embeddings
embeddings = CustomHuggingFaceInferenceAPIEmbeddings(
    api_key=hf_api_key,
    model_name="BAAI/bge-large-en-v1.5",
)

# Initialize session state
if "chat_history" not in st.session_state:
    st.session_state["chat_history"] = []

# Initialize vector store and memory
vector_store = SupabaseVectorStore(
    client=supabase,
    embedding=embeddings,
    query_name="match_documents",
    table_name="documents",
)
memory = ConversationBufferMemory(
    memory_key="chat_history",
    input_key="question",
    output_key="answer",
    return_messages=True,
)

# Model configuration
model = "mistralai/Mixtral-8x7B-Instruct-v0.1"
temperature = 0.1
max_tokens = 500


# Mock stats functions (replace with your actual implementation)
def get_usage(supabase):
    return 100  # Replace with actual logic


def add_usage(supabase, action, prompt, metadata):
    pass  # Replace with actual logic


stats = str(get_usage(supabase))


def response_generator(query):
    try:
        add_usage(supabase, "chat", f"prompt: {query}", {"model": model, "temperature": temperature})
        logger.info("Using HF model %s", model)
        endpoint_url = f"https://api-inference.huggingface.co/models/{model}"
        model_kwargs = {
            "temperature": temperature,
            "max_new_tokens": max_tokens,
            "return_full_text": False,
        }
        hf = HuggingFaceEndpoint(
            endpoint_url=endpoint_url,
            task="text-generation",
            huggingfacehub_api_token=hf_api_key,
            model_kwargs=model_kwargs,
        )
        qa = ConversationalRetrievalChain.from_llm(
            llm=hf,
            retriever=vector_store.as_retriever(
                search_kwargs={"score_threshold": 0.6, "k": 4, "filter": {"user": username}}
            ),
            memory=memory,
            verbose=True,
            return_source_documents=True,
        )
        # Use invoke instead of deprecated __call__
        model_response = qa.invoke({"question": query})
        logger.info("Result: %s", model_response["answer"])
        sources = model_response["source_documents"]
        logger.info("Sources: %s", sources)
        if sources:
            return model_response["answer"]
        else:
            return (
                "I am sorry, I do not have enough information to provide an answer. "
                "If there is a public source of data that you would like to add, "
                "please email copilot@securade.ai."
            )
    except Exception as e:
        logger.error(f"Error generating response: {str(e)}")
        return "An error occurred while processing your request. Please try again later."


# Streamlit UI
st.set_page_config(
    page_title="Securade.ai - Safety Copilot",
    page_icon="https://securade.ai/favicon.ico",
    layout="centered",
    initial_sidebar_state="collapsed",
    menu_items={
        "About": "# Securade.ai Safety Copilot v0.1\n [https://securade.ai](https://securade.ai)",
        "Get Help": "https://securade.ai",
        "Report a Bug": "mailto:hello@securade.ai",
    },
)

st.title("👷‍♂️ Safety Copilot 🦺")
st.markdown(
    "Chat with your personal safety assistant about any health & safety related queries. "
    "[[blog](https://securade.ai/blog/how-securade-ai-safety-copilot-transforms-worker-safety.html)|"
    "[paper](https://securade.ai/assets/pdfs/Securade.ai-Safety-Copilot-Whitepaper.pdf)]"
)
st.markdown(f"_{stats} queries answered!_")

# Display chat history
for message in st.session_state.chat_history:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

# Handle user input
if prompt := st.chat_input("Ask a question"):
    st.session_state.chat_history.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)
    with st.spinner("Safety briefing in progress..."):
        response = response_generator(prompt)
    with st.chat_message("assistant"):
        st.markdown(response)
    st.session_state.chat_history.append({"role": "assistant", "content": response})
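

# --- Optional: backing the usage stats with Supabase ---
# get_usage/add_usage above are mocks. The sketch below shows one possible way to
# persist usage in Supabase, assuming a hypothetical "stats" table with "id",
# "action", "query", and "metadata" columns (not part of the original app); adjust
# the table and column names to your schema before swapping these in for the mocks.
def get_usage_supabase(supabase: Client) -> int:
    """Return the total number of logged queries (hypothetical 'stats' table)."""
    response = supabase.table("stats").select("id", count="exact").execute()
    return response.count or 0


def add_usage_supabase(supabase: Client, action: str, prompt: str, metadata: dict) -> None:
    """Insert a single usage record (hypothetical 'stats' table and columns)."""
    supabase.table("stats").insert(
        {"action": action, "query": prompt, "metadata": metadata}
    ).execute()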