import logging

import requests
import streamlit as st
from requests.exceptions import JSONDecodeError
from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings
from langchain_community.vectorstores import SupabaseVectorStore
from langchain_community.llms import HuggingFaceEndpoint
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from supabase import Client, create_client
from streamlit.logger import get_logger

# Configure logging
logger = get_logger(__name__)
logging.basicConfig(level=logging.INFO)

# Load secrets
supabase_url = st.secrets["SUPABASE_URL"]
supabase_key = st.secrets["SUPABASE_KEY"]
hf_api_key = st.secrets["hf_api_key"]
username = st.secrets["username"]

# Initialize Supabase client
supabase: Client = create_client(supabase_url, supabase_key)

# The fallback dimension below must match the embedding model;
# BAAI/bge-large-en-v1.5 produces 1024-dimensional vectors.
EMBEDDING_DIM = 1024

# Custom HuggingFaceInferenceAPIEmbeddings that survives flaky Inference API
# responses: the hosted API can return non-2xx or non-JSON bodies (e.g. during
# cold starts), which makes the stock class raise JSONDecodeError; on failure
# we log and return a zero vector instead of crashing the app.
class CustomHuggingFaceInferenceAPIEmbeddings(HuggingFaceInferenceAPIEmbeddings):
    def embed_query(self, text: str):
        try:
            # Call the same feature-extraction REST endpoint the parent class
            # uses; wrapping the text in a list matches the batch API shape.
            response = requests.post(
                self._api_url,
                headers=self._headers,
                json={"inputs": [text], "options": {"use_cache": False}},
            )
            if response.status_code != 200:
                logger.error(f"API request failed with status {response.status_code}: {response.text}")
                return [0.0] * EMBEDDING_DIM  # zero vector of expected dimension
            try:
                embeddings = response.json()
                if not isinstance(embeddings, list) or not embeddings:
                    logger.error(f"Invalid embeddings response: {embeddings}")
                    return [0.0] * EMBEDDING_DIM
                return embeddings[0]
            except JSONDecodeError as e:
                logger.error(f"JSON decode error: {str(e)}, response: {response.text}")
                return [0.0] * EMBEDDING_DIM
        except Exception as e:
            logger.error(f"Error embedding query: {str(e)}")
            return [0.0] * EMBEDDING_DIM

    def embed_documents(self, texts):
        try:
            response = requests.post(
                self._api_url,
                headers=self._headers,
                json={"inputs": texts, "options": {"use_cache": False}},
            )
            if response.status_code != 200:
                logger.error(f"API request failed with status {response.status_code}: {response.text}")
                return [[0.0] * EMBEDDING_DIM for _ in texts]
            try:
                embeddings = response.json()
                if not isinstance(embeddings, list) or not embeddings:
                    logger.error(f"Invalid embeddings response: {embeddings}")
                    return [[0.0] * EMBEDDING_DIM for _ in texts]
                # The endpoint returns one vector per input, so the parsed list
                # can be returned as-is (mirroring the parent implementation).
                return embeddings
            except JSONDecodeError as e:
                logger.error(f"JSON decode error: {str(e)}, response: {response.text}")
                return [[0.0] * EMBEDDING_DIM for _ in texts]
        except Exception as e:
            logger.error(f"Error embedding documents: {str(e)}")
            return [[0.0] * EMBEDDING_DIM for _ in texts]
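
# Design note: returning zero vectors on failure keeps the app responsive, but
# a zero-vector query matches documents poorly (or not at all, depending on
# the distance metric), so every failure is logged above for follow-up.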

# Initialize embeddings
embeddings = CustomHuggingFaceInferenceAPIEmbeddings(
    api_key=hf_api_key,
    model_name="BAAI/bge-large-en-v1.5",
)

# Initialize session state
if "chat_history" not in st.session_state:
    st.session_state["chat_history"] = []

# Initialize vector store and memory
vector_store = SupabaseVectorStore(
    client=supabase,
    embedding=embeddings,
    query_name="match_documents",
    table_name="documents",
)
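
# SupabaseVectorStore expects a "documents" table plus a "match_documents"
# SQL function on the Supabase side. A minimal sketch of that setup, assuming
# pgvector and the schema from LangChain's Supabase integration docs (names
# and dimensions must match your project):
#
#   create table documents (
#     id uuid primary key default gen_random_uuid(),
#     content text,
#     metadata jsonb,
#     embedding vector(1024)  -- must match EMBEDDING_DIM above
#   );
#   -- plus a match_documents(query_embedding, match_count, filter) function
#   -- that returns rows ordered by embedding similarity.
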
# Streamlit reruns this whole script on every interaction, so a module-level
# ConversationBufferMemory would be reset each turn; keeping it in
# session_state lets the conversational chain remember prior exchanges.
if "memory" not in st.session_state:
    st.session_state["memory"] = ConversationBufferMemory(
        memory_key="chat_history",
        input_key="question",
        output_key="answer",
        return_messages=True,
    )
memory = st.session_state["memory"]

# Model configuration
model = "mistralai/Mixtral-8x7B-Instruct-v0.1"
temperature = 0.1
max_tokens = 500

# Mock usage-tracking helpers (replace with your actual implementation)
def get_usage(supabase):
    return 100  # Replace with actual logic

def add_usage(supabase, action, prompt, metadata):
    pass  # Replace with actual logic

stats = str(get_usage(supabase))
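
# A minimal sketch of real implementations against a hypothetical "stats"
# table (the table and column names here are assumptions, not from this app):
#
#   def get_usage(supabase):
#       result = supabase.table("stats").select("id", count="exact").execute()
#       return result.count or 0
#
#   def add_usage(supabase, action, prompt, metadata):
#       supabase.table("stats").insert(
#           {"action": action, "prompt": prompt, "metadata": metadata}
#       ).execute()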

def generate_response(query):
    try:
        add_usage(supabase, "chat", f"prompt: {query}", {"model": model, "temperature": temperature})
        logger.info("Using HF model %s", model)
        
        endpoint_url = f"https://api-inference.huggingface.co/models/{model}"
        # Generation settings are explicit fields on HuggingFaceEndpoint, so
        # pass them directly instead of through model_kwargs.
        hf = HuggingFaceEndpoint(
            endpoint_url=endpoint_url,
            task="text-generation",
            huggingfacehub_api_token=hf_api_key,
            temperature=temperature,
            max_new_tokens=max_tokens,
            return_full_text=False,
        )
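        # Note: with the default "similarity" search type, some vector stores
        # ignore "score_threshold" in search_kwargs; if the threshold doesn't
        # appear to apply, search_type="similarity_score_threshold" may be
        # needed (support varies by store and version).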
        qa = ConversationalRetrievalChain.from_llm(
            llm=hf,
            retriever=vector_store.as_retriever(search_kwargs={"score_threshold": 0.6, "k": 4, "filter": {"user": username}}),
            memory=memory,
            verbose=True,
            return_source_documents=True,
        )
        
        # Use invoke instead of deprecated __call__
        model_response = qa.invoke({"question": query})
        logger.info("Result: %s", model_response["answer"])
        sources = model_response["source_documents"]
        logger.info("Sources: %s", sources)

        if sources:
            return model_response["answer"]
        else:
            return "I am sorry, I do not have enough information to provide an answer. If there is a public source of data that you would like to add, please email [email protected]."
    except Exception as e:
        logger.error(f"Error generating response: {str(e)}")
        return "An error occurred while processing your request. Please try again later."

# Streamlit UI
st.set_page_config(
    page_title="Securade.ai - Safety Copilot",
    page_icon="https://securade.ai/favicon.ico",
    layout="centered",
    initial_sidebar_state="collapsed",
    menu_items={
        "About": "# Securade.ai Safety Copilot v0.1\n [https://securade.ai](https://securade.ai)",
        "Get Help": "https://securade.ai",
        "Report a Bug": "mailto:[email protected]",
    },
)

st.title("👷‍♂️ Safety Copilot 🦺")
st.markdown(
    "Chat with your personal safety assistant about any health & safety related queries. "
    "[[blog](https://securade.ai/blog/how-securade-ai-safety-copilot-transforms-worker-safety.html)|"
    "[paper](https://securade.ai/assets/pdfs/Securade.ai-Safety-Copilot-Whitepaper.pdf)]"
)
st.markdown(f"_{stats} queries answered!_")

# Display chat history
for message in st.session_state.chat_history:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

# Handle user input
if prompt := st.chat_input("Ask a question"):
    st.session_state.chat_history.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)
    
    with st.spinner("Safety briefing in progress..."):
        response = generate_response(prompt)
    
    with st.chat_message("assistant"):
        st.markdown(response)
    st.session_state.chat_history.append({"role": "assistant", "content": response})
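
# To run this app locally (a sketch; the filename is an assumption):
#   streamlit run app.py
# with .streamlit/secrets.toml supplying the keys read at the top:
#   SUPABASE_URL = "https://<project>.supabase.co"
#   SUPABASE_KEY = "<supabase-key>"
#   hf_api_key = "<huggingface-api-token>"
#   username = "<value used to filter retrieved documents by user>"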