File size: 4,396 Bytes
4e00df7
 
 
 
 
 
 
 
 
cae23e1
 
 
 
4e00df7
8a70a7b
 
4e00df7
 
 
 
 
033cc04
4e00df7
 
 
 
 
 
 
 
 
 
0a26e47
4e00df7
 
 
 
 
 
 
 
cae23e1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4e00df7
cae23e1
 
4e00df7
cae23e1
4e00df7
cae23e1
 
 
 
 
 
 
 
 
 
 
 
 
033cc04
cae23e1
033cc04
cae23e1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
# main.py
import json
import os
import tempfile

import streamlit as st
from langchain.chains import ConversationalRetrievalChain
from langchain.embeddings import HuggingFaceInferenceAPIEmbeddings
from langchain.llms import HuggingFaceEndpoint
from langchain.memory import ConversationBufferMemory
from langchain.vectorstores import SupabaseVectorStore
from supabase import Client, create_client

from question import chat_with_doc
from stats import add_usage

# --- Configuration -----------------------------------------------------------
# Required secrets use attribute access, which raises when the key is missing,
# so a misconfigured deployment fails at startup rather than mid-request.
supabase_url = st.secrets.SUPABASE_URL
supabase_key = st.secrets.SUPABASE_KEY

# Optional provider keys: they only gate extra entries in `models` below, so a
# missing key must read as "not configured" (None) instead of raising.
openai_api_key = st.secrets.get("openai_api_key")
anthropic_api_key = st.secrets.get("anthropic_api_key")

hf_api_key = st.secrets.hf_api_key
supabase: Client = create_client(supabase_url, supabase_key)
self_hosted = st.secrets.self_hosted
username = st.secrets.username

# Document embeddings are computed through the Hugging Face Inference API.
embeddings = HuggingFaceInferenceAPIEmbeddings(
    api_key=hf_api_key,
    model_name="BAAI/bge-large-en-v1.5",
)

# Similarity search runs via the Supabase `match_documents` query over the
# `documents` table.
vector_store = SupabaseVectorStore(
    supabase,
    embeddings,
    query_name='match_documents',
    table_name="documents",
)

# Candidate model identifiers: the Hugging Face-hosted models are always
# listed; commercial ones are appended only when their API key is configured.
models = ["meta-llama/Llama-2-70b-chat-hf", "mistralai/Mixtral-8x7B-Instruct-v0.1"]

if openai_api_key:
    models += ["gpt-3.5-turbo", "gpt-4"]

if anthropic_api_key:
    models += ["claude-v1", "claude-v1.3",
               "claude-instant-v1-100k", "claude-instant-v1.1-100k"]

# Default Hugging Face model: used to answer API requests and as the chat UI's
# initial model selection.
DEFAULT_MODEL = "meta-llama/Llama-2-70b-chat-hf"

# Sent instead of the model's answer when retrieval returned no source documents.
# NOTE(review): "[email protected]" looks like an email-obfuscation placeholder
# rather than a real address -- confirm the intended contact address.
NO_SOURCES_MESSAGE = (
    "I am sorry, I do not have enough information to provide an answer. "
    "If there is a public source of data that you would like to add, "
    "please email [email protected]."
)


def _answer_api_question(query: str) -> str:
    """Answer *query* from the indexed documents for the JSON API mode.

    Logs the request with ``add_usage``, then runs a one-shot
    ConversationalRetrievalChain over ``vector_store``. The search is limited
    to documents whose ``user`` metadata equals ``username``, keeps at most
    4 hits with a similarity score of at least 0.8, and uses ``DEFAULT_MODEL``.

    Returns:
        The model's answer, or ``NO_SOURCES_MESSAGE`` when the chain reports
        no source documents.
    """
    temperature = 0.1
    max_tokens = 500
    add_usage(supabase, "api", "prompt" + query,
              {"model": DEFAULT_MODEL, "temperature": temperature})

    llm = HuggingFaceEndpoint(
        endpoint_url="https://api-inference.huggingface.co/models/" + DEFAULT_MODEL,
        task="text-generation",
        huggingfacehub_api_token=hf_api_key,
        model_kwargs={
            "temperature": temperature,
            "max_new_tokens": max_tokens,
            # Return only the generated continuation, not the echoed prompt.
            "return_full_text": False,
        },
    )
    # A fresh memory per request: API calls carry no chat history between them.
    memory = ConversationBufferMemory(memory_key="chat_history", input_key='question',
                                      output_key='answer', return_messages=True)
    retriever = vector_store.as_retriever(
        search_kwargs={"score_threshold": 0.8, "k": 4, "filter": {"user": username}}
    )
    qa = ConversationalRetrievalChain.from_llm(llm, retriever=retriever, memory=memory,
                                               return_source_documents=True)
    try:
        model_response = qa({"question": query})
    finally:
        memory.clear()

    if model_response["source_documents"]:
        return model_response["answer"]
    return NO_SOURCES_MESSAGE


def _render_chat_ui() -> None:
    """Render the interactive Safety Copilot chat page."""
    # set_page_config must be the first Streamlit rendering call on the page.
    st.set_page_config(
        page_title="Securade.ai - Safety Copilot",
        page_icon="https://securade.ai/favicon.ico",
        layout="centered",
        initial_sidebar_state="collapsed",
        menu_items={
            "About": "# Securade.ai Safety Copilot v0.1\n [https://securade.ai](https://securade.ai)",
            "Get Help" : "https://securade.ai",
            "Report a Bug": "mailto:[email protected]"
        }
    )

    st.title("👷‍♂️ Safety Copilot 🦺")

    st.markdown("Chat with your personal safety assistant about any health & safety related queries.")
    st.markdown("Up-to-date with latest OSH regulations for Singapore, Indonesia, Malaysia & other parts of Asia.")

    st.markdown("---\n\n")

    # Seed per-session settings once; values already in the session
    # (e.g. changed by the user on an earlier rerun) are left untouched.
    session_defaults = {
        'model': DEFAULT_MODEL,
        'temperature': 0.1,
        'chunk_size': 500,
        'chunk_overlap': 0,
        'max_tokens': 500,
        'username': username,
    }
    for key, value in session_defaults.items():
        if key not in st.session_state:
            st.session_state[key] = value

    chat_with_doc(st.session_state['model'], vector_store, stats_db=supabase)

    st.markdown("---\n\n")


if 'question' in st.query_params:
    # Headless API mode: emit a JSON document instead of rendering the chat UI.
    # json.dumps (not the dict itself) so the body is valid JSON; st.code
    # would otherwise show the Python repr (single-quoted keys/strings).
    answer = _answer_api_question(st.query_params['question'])
    st.code(json.dumps({"response": answer}, ensure_ascii=False), language="json")
else:
    _render_chat_ui()