# main.py

import os
import re
import streamlit as st
import anthropic
from requests import JSONDecodeError

from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from langchain_community.vectorstores import SupabaseVectorStore
from langchain_community.llms import HuggingFaceEndpoint
from langchain_community.chat_models import ChatOpenAI

from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory

from supabase import Client, create_client
from streamlit.logger import get_logger
from stats import get_usage, add_usage

# ─────── supabase + secrets ────────────────────────────────────────────────────
supabase_url    = st.secrets.SUPABASE_URL
supabase_key    = st.secrets.SUPABASE_KEY
openai_api_key  = st.secrets.openai_api_key
anthropic_api_key = st.secrets.anthropic_api_key
hf_api_key      = st.secrets.hf_api_key
username        = st.secrets.username

supabase: Client = create_client(supabase_url, supabase_key)
logger = get_logger(__name__)

# ─────── embeddings ─────────────────────────────────────────────────────────────
# Switch to local BGE embeddings (no JSONDecode errors, no HTTP-batch issues)
embeddings = HuggingFaceBgeEmbeddings(
    model_name="BAAI/bge-large-en-v1.5",
    model_kwargs={"device": "cpu"},
    encode_kwargs={"normalize_embeddings": True}
)
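# NOTE: bge-large-en-v1.5 produces 1024-dimensional vectors, so the Supabase
# `documents` table and `match_documents` function must be set up for that dimension.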
# ─────── vector store + memory ─────────────────────────────────────────────────
vector_store = SupabaseVectorStore(
    client=supabase,
    embedding=embeddings,
    query_name="match_documents",
    table_name="documents",
)
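# `match_documents` is the Postgres similarity-search function assumed to exist
# in the Supabase project (the standard pgvector setup used with SupabaseVectorStore).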
memory = ConversationBufferMemory(
    memory_key="chat_history",
    input_key="question",
    output_key="answer",
    return_messages=True,
)
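# `output_key="answer"` matters because the chain below also returns
# "source_documents"; the memory needs to know which output field to store.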

# ─────── LLM setup ──────────────────────────────────────────────────────────────
model        = "HuggingFaceTB/SmolLM3-3B"
temperature  = 0.1
max_tokens   = 500

def clean_response(answer: str) -> str:
    """Clean up AI response by removing unwanted artifacts and formatting."""
    if not answer:
        return answer
    
    # Remove thinking tags and content
    answer = re.sub(r'<think>.*?</think>', '', answer, flags=re.DOTALL)
    answer = re.sub(r'<thinking>.*?</thinking>', '', answer, flags=re.DOTALL)
    
    # Remove other common AI response artifacts
    answer = re.sub(r'\[.*?\]', '', answer, flags=re.DOTALL)  # Remove bracketed content
    answer = re.sub(r'\{.*?\}', '', answer, flags=re.DOTALL)  # Remove curly bracketed content
    answer = re.sub(r'```.*?```', '', answer, flags=re.DOTALL)  # Remove code blocks
    answer = re.sub(r'---.*?---', '', answer, flags=re.DOTALL)  # Remove dashed sections
    
    # Collapse all whitespace and newlines into single spaces (this also flattens markdown line breaks)
    answer = re.sub(r'\s+', ' ', answer).strip()
    
    # Remove common AI-generated prefixes/suffixes
    answer = re.sub(r'^(Assistant:|AI:|Grok:)\s*', '', answer, flags=re.IGNORECASE)
    answer = re.sub(r'\s*(Sincerely,.*|Best regards,.*|Regards,.*)$', '', answer, flags=re.IGNORECASE)
    
    return answer
    
def response_generator(query: str) -> str:
    """Ask the RAG chain to answer `query`, with JSON‑error fallback."""
    # log usage
    add_usage(supabase, "chat", "prompt:" + query, {"model": model, "temperature": temperature})
    logger.info("Using HF model %s", model)

    # Previous approach (kept for reference): raw HF text-generation endpoint
    # hf = HuggingFaceEndpoint(
    #     # endpoint_url=f"https://api-inference.huggingface.co/models/{model}",
    #     endpoint_url=f"https://router.huggingface.co/hf-inference/models/{model}",
    #     task="text-generation",
    #     huggingfacehub_api_token=hf_api_key,
    #     model_kwargs={
    #         "temperature": temperature,
    #         "max_new_tokens": max_tokens,
    #         "return_full_text": False,
    #     },
    # )

    # Current approach: OpenAI-compatible chat completions served through the HF router
    hf = ChatOpenAI(
        base_url=f"https://router.huggingface.co/hf-inference/models/{model}/v1",
        api_key=hf_api_key,
        model=model,
        temperature=temperature,
        max_tokens=max_tokens,
        timeout=30,       # request timeout (seconds)
        max_retries=3,    # built-in retry logic
    )

    # conversational RAG chain
    qa = ConversationalRetrievalChain.from_llm(
        llm=hf,
        retriever=vector_store.as_retriever(
            search_kwargs={"score_threshold": 0.6, "k": 4, "filter": {"user": username}}
        ),
        memory=memory,
        verbose=True,
        return_source_documents=True,
    )
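    # The metadata filter limits retrieval to this user's documents; note that
    # `score_threshold` may only take effect when the retriever is created with
    # search_type="similarity_score_threshold", depending on the LangChain version.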

    try:
        result = qa({"question": query})
    except JSONDecodeError as e:
        # log the failure and return a graceful fallback message
        logger.error("Embedding JSONDecodeError: %s", e)
        return "Sorry, I had trouble understanding the embedded data. Please try again."

    answer = result.get("answer", "")
    sources = result.get("source_documents", [])

    if not sources:
        return (
            "I’m sorry, I don’t have enough information to answer that. "
            "If you have a public data source to add, please email [email protected]."
        )

    answer = clean_response(answer)
    return answer

# ─────── Streamlit UI ──────────────────────────────────────────────────────────
st.set_page_config(
    page_title="Securade.ai - Safety Copilot",
    page_icon="https://securade.ai/favicon.ico",
    layout="centered",
    initial_sidebar_state="collapsed",
    menu_items={
        "About": "# Securade.ai Safety Copilot v0.1\n[https://securade.ai](https://securade.ai)",
        "Get Help": "https://securade.ai",
        "Report a Bug": "mailto:[email protected]",
    },
)

st.title("πŸ‘·β€β™‚οΈ Safety Copilot 🦺")
stats = get_usage(supabase)
st.markdown(f"_{stats} queries answered!_")
st.markdown(
    "Chat with your personal safety assistant about any health & safety related queries. "
    "[[blog](https://securade.ai/blog/how-securade-ai-safety-copilot-transforms-worker-safety.html)"
    "|[paper](https://securade.ai/assets/pdfs/Securade.ai-Safety-Copilot-Whitepaper.pdf)]"
)

if "chat_history" not in st.session_state:
    st.session_state.chat_history = []
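# st.session_state persists across Streamlit reruns, but module-level objects like
# `memory` above are rebuilt on every rerun, so the LangChain memory effectively
# only holds the current exchange.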

# show history
for msg in st.session_state.chat_history:
    with st.chat_message(msg["role"]):
        st.markdown(msg["content"])

# new user input
if prompt := st.chat_input("Ask a question"):
    st.session_state.chat_history.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    with st.spinner("Safety briefing in progress..."):
        answer = response_generator(prompt)

    with st.chat_message("assistant"):
        st.markdown(answer)
    st.session_state.chat_history.append({"role": "assistant", "content": answer})