codelion commited on
Commit
90fddeb
Β·
verified Β·
1 Parent(s): 63c0a0b

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +99 -92
main.py CHANGED
@@ -1,13 +1,13 @@
1
  # main.py
 
2
  import os
3
  import streamlit as st
4
  import anthropic
 
5
 
6
- from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings
7
  from langchain_community.embeddings import HuggingFaceBgeEmbeddings
8
  from langchain_community.vectorstores import SupabaseVectorStore
9
  from langchain_community.llms import HuggingFaceEndpoint
10
- from langchain_community.vectorstores import SupabaseVectorStore
11
 
12
  from langchain.chains import ConversationalRetrievalChain
13
  from langchain.memory import ConversationBufferMemory
@@ -16,122 +16,129 @@ from supabase import Client, create_client
16
  from streamlit.logger import get_logger
17
  from stats import get_usage, add_usage
18
 
19
- supabase_url = st.secrets.SUPABASE_URL
20
- supabase_key = st.secrets.SUPABASE_KEY
21
- openai_api_key = st.secrets.openai_api_key
 
22
  anthropic_api_key = st.secrets.anthropic_api_key
23
- hf_api_key = st.secrets.hf_api_key
24
- username = st.secrets.username
25
 
26
  supabase: Client = create_client(supabase_url, supabase_key)
27
  logger = get_logger(__name__)
28
 
29
- embeddings = HuggingFaceInferenceAPIEmbeddings(
30
- api_key=hf_api_key,
 
31
  model_name="BAAI/bge-large-en-v1.5",
32
- api_url="https://router.huggingface.co/hf-inference/pipeline/feature-extraction/",
 
33
  )
34
 
35
- if 'chat_history' not in st.session_state:
36
- st.session_state['chat_history'] = []
37
-
38
- vector_store = SupabaseVectorStore(supabase, embeddings, query_name='match_documents', table_name="documents")
39
- memory = ConversationBufferMemory(memory_key="chat_history", input_key='question', output_key='answer', return_messages=True)
 
 
 
 
 
 
 
 
40
 
41
- model = "mistralai/Mixtral-8x7B-Instruct-v0.1"
 
 
 
42
 
43
- temperature = 0.1
44
- max_tokens = 500
45
- stats = str(get_usage(supabase))
 
 
46
 
47
- def response_generator(query):
48
- qa = None
49
- add_usage(supabase, "chat", "prompt" + query, {"model": model, "temperature": temperature})
50
- logger.info('Using HF model %s', model)
51
- # print(st.session_state['max_tokens'])
52
- endpoint_url = ("https://api-inference.huggingface.co/models/"+ model)
53
- model_kwargs = {"temperature" : temperature,
54
- "max_new_tokens" : max_tokens,
55
- # "repetition_penalty" : 1.1,
56
- "return_full_text" : False}
57
  hf = HuggingFaceEndpoint(
58
- endpoint_url=endpoint_url,
59
  task="text-generation",
60
  huggingfacehub_api_token=hf_api_key,
61
- model_kwargs=model_kwargs
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  )
63
- qa = ConversationalRetrievalChain.from_llm(hf, retriever=vector_store.as_retriever(search_kwargs={"score_threshold": 0.6, "k": 4,"filter": {"user": username}}), memory=memory, verbose=True, return_source_documents=True)
64
-
65
- # Generate model's response
66
- model_response = qa({"question": query})
67
- logger.info('Result: %s', model_response["answer"])
68
- sources = model_response["source_documents"]
69
- logger.info('Sources: %s', model_response["source_documents"])
70
-
71
- if len(sources) > 0:
72
- response = model_response["answer"]
73
- else:
74
- response = "I am sorry, I do not have enough information to provide an answer. If there is a public source of data that you would like to add, please email [email protected]."
75
-
76
- return response
77
-
78
- # Set the theme
 
 
 
79
  st.set_page_config(
80
  page_title="Securade.ai - Safety Copilot",
81
  page_icon="https://securade.ai/favicon.ico",
82
  layout="centered",
83
  initial_sidebar_state="collapsed",
84
  menu_items={
85
- "About": "# Securade.ai Safety Copilot v0.1\n [https://securade.ai](https://securade.ai)",
86
- "Get Help" : "https://securade.ai",
87
- "Report a Bug": "mailto:[email protected]"
88
- }
89
  )
90
 
91
  st.title("πŸ‘·β€β™‚οΈ Safety Copilot 🦺")
 
 
 
 
 
 
 
92
 
93
- st.markdown("Chat with your personal safety assistant about any health & safety related queries. [[blog](https://securade.ai/blog/how-securade-ai-safety-copilot-transforms-worker-safety.html)|[paper](https://securade.ai/assets/pdfs/Securade.ai-Safety-Copilot-Whitepaper.pdf)]")
94
- # st.markdown("Up-to-date with latest OSH regulations for Singapore, Indonesia, Malaysia & other parts of Asia.")
95
- st.markdown("_"+ stats + " queries answered!_")
96
-
97
- if 'chat_history' not in st.session_state:
98
- st.session_state['chat_history'] = []
99
-
100
- # Display chat messages from history on app rerun
101
- for message in st.session_state.chat_history:
102
- with st.chat_message(message["role"]):
103
- st.markdown(message["content"])
104
-
105
- # Accept user input
106
- if prompt := st.chat_input("Ask a question"):
107
- # print(prompt)
108
- # Add user message to chat history
109
  st.session_state.chat_history.append({"role": "user", "content": prompt})
110
- # Display user message in chat message container
111
  with st.chat_message("user"):
112
  st.markdown(prompt)
113
-
114
- with st.spinner('Safety briefing in progress...'):
115
- response = response_generator(prompt)
116
-
117
- # Display assistant response in chat message container
118
  with st.chat_message("assistant"):
119
- st.markdown(response)
120
- # Add assistant response to chat history
121
- # print(response)
122
- st.session_state.chat_history.append({"role": "assistant", "content": response})
123
-
124
- # query = st.text_area("## Ask a question (" + stats + " queries answered so far)", max_chars=500)
125
- # columns = st.columns(2)
126
- # with columns[0]:
127
- # button = st.button("Ask")
128
- # with columns[1]:
129
- # clear_history = st.button("Clear History", type='secondary')
130
-
131
- # st.markdown("---\n\n")
132
-
133
- # if clear_history:
134
- # # Clear memory in Langchain
135
- # memory.clear()
136
- # st.session_state['chat_history'] = []
137
- # st.experimental_rerun()
 
1
  # main.py
2
+
3
  import os
4
  import streamlit as st
5
  import anthropic
6
+ from requests import JSONDecodeError
7
 
 
8
  from langchain_community.embeddings import HuggingFaceBgeEmbeddings
9
  from langchain_community.vectorstores import SupabaseVectorStore
10
  from langchain_community.llms import HuggingFaceEndpoint
 
11
 
12
  from langchain.chains import ConversationalRetrievalChain
13
  from langchain.memory import ConversationBufferMemory
 
16
  from streamlit.logger import get_logger
17
  from stats import get_usage, add_usage
18
 
19
+ # ─────── supabase + secrets ────────────────────────────────────────────────────
20
+ supabase_url = st.secrets.SUPABASE_URL
21
+ supabase_key = st.secrets.SUPABASE_KEY
22
+ openai_api_key = st.secrets.openai_api_key
23
  anthropic_api_key = st.secrets.anthropic_api_key
24
+ hf_api_key = st.secrets.hf_api_key
25
+ username = st.secrets.username
26
 
27
  supabase: Client = create_client(supabase_url, supabase_key)
28
  logger = get_logger(__name__)
29
 
30
+ # ─────── embeddings ─────────────────────────────────────────────────────────────
31
+ # Switch to local BGE embeddings (no JSONDecode errors, no HTTP‑batch issues) :contentReference[oaicite:0]{index=0}
32
+ embeddings = HuggingFaceBgeEmbeddings(
33
  model_name="BAAI/bge-large-en-v1.5",
34
+ model_kwargs={"device": "cpu"},
35
+ encode_kwargs={"normalize_embeddings": True},
36
  )
37
 
38
+ # ─────── vector store + memory ─────────────────────────────────────────────────
39
+ vector_store = SupabaseVectorStore(
40
+ client=supabase,
41
+ embedding=embeddings,
42
+ query_name="match_documents",
43
+ table_name="documents",
44
+ )
45
+ memory = ConversationBufferMemory(
46
+ memory_key="chat_history",
47
+ input_key="question",
48
+ output_key="answer",
49
+ return_messages=True,
50
+ )
51
 
52
+ # ─────── LLM setup ──────────────────────────────────────────────────────────────
53
+ model = "mistralai/Mixtral-8x7B-Instruct-v0.1"
54
+ temperature = 0.1
55
+ max_tokens = 500
56
 
57
+ def response_generator(query: str) -> str:
58
+ """Ask the RAG chain to answer `query`, with JSON‑error fallback."""
59
+ # log usage
60
+ add_usage(supabase, "chat", "prompt:" + query, {"model": model, "temperature": temperature})
61
+ logger.info("Using HF model %s", model)
62
 
63
+ # prepare HF text-generation LLM
 
 
 
 
 
 
 
 
 
64
  hf = HuggingFaceEndpoint(
65
+ endpoint_url=f"https://api-inference.huggingface.co/models/{model}",
66
  task="text-generation",
67
  huggingfacehub_api_token=hf_api_key,
68
+ model_kwargs={
69
+ "temperature": temperature,
70
+ "max_new_tokens": max_tokens,
71
+ "return_full_text": False,
72
+ },
73
+ )
74
+
75
+ # conversational RAG chain
76
+ qa = ConversationalRetrievalChain.from_llm(
77
+ llm=hf,
78
+ retriever=vector_store.as_retriever(
79
+ search_kwargs={"score_threshold": 0.6, "k": 4, "filter": {"user": username}}
80
+ ),
81
+ memory=memory,
82
+ verbose=True,
83
+ return_source_documents=True,
84
  )
85
+
86
+ try:
87
+ result = qa({"question": query})
88
+ except JSONDecodeError as e:
89
+ # fallback logging
90
+ logger.error("Embedding JSONDecodeError: %s", e)
91
+ return "Sorry, I had trouble understanding the embedded data. Please try again."
92
+
93
+ answer = result.get("answer", "")
94
+ sources = result.get("source_documents", [])
95
+
96
+ if not sources:
97
+ return (
98
+ "I’m sorry, I don’t have enough information to answer that. "
99
+ "If you have a public data source to add, please email [email protected]."
100
+ )
101
+ return answer
102
+
103
+ # ─────── Streamlit UI ──────────────────────────────────────────────────────────
104
  st.set_page_config(
105
  page_title="Securade.ai - Safety Copilot",
106
  page_icon="https://securade.ai/favicon.ico",
107
  layout="centered",
108
  initial_sidebar_state="collapsed",
109
  menu_items={
110
+ "About": "# Securade.ai Safety Copilot v0.1\n[https://securade.ai](https://securade.ai)",
111
+ "Get Help": "https://securade.ai",
112
+ "Report a Bug": "mailto:[email protected]",
113
+ },
114
  )
115
 
116
  st.title("πŸ‘·β€β™‚οΈ Safety Copilot 🦺")
117
+ stats = get_usage(supabase)
118
+ st.markdown(f"_{stats} queries answered!_")
119
+ st.markdown(
120
+ "Chat with your personal safety assistant about any health & safety related queries. "
121
+ "[[blog](https://securade.ai/blog/how-securade-ai-safety-copilot-transforms-worker-safety.html)"
122
+ "|[paper](https://securade.ai/assets/pdfs/Securade.ai-Safety-Copilot-Whitepaper.pdf)]"
123
+ )
124
 
125
+ if "chat_history" not in st.session_state:
126
+ st.session_state.chat_history = []
127
+
128
+ # show history
129
+ for msg in st.session_state.chat_history:
130
+ with st.chat_message(msg["role"]):
131
+ st.markdown(msg["content"])
132
+
133
+ # new user input
134
+ if prompt := st.chat_input("Ask a question"):
 
 
 
 
 
 
135
  st.session_state.chat_history.append({"role": "user", "content": prompt})
 
136
  with st.chat_message("user"):
137
  st.markdown(prompt)
138
+
139
+ with st.spinner("Safety briefing in progress..."):
140
+ answer = response_generator(prompt)
141
+
 
142
  with st.chat_message("assistant"):
143
+ st.markdown(answer)
144
+ st.session_state.chat_history.append({"role": "assistant", "content": answer})