Asman2010 committed on
Commit 9ba4a65 · verified · 1 Parent(s): 9e284b7

Upload 5 files

Files changed (5)
  1. .env +1 -0
  2. llama-logo.png +0 -0
  3. llama2_chatbot.py +143 -0
  4. requirements.txt +67 -0
  5. utils.py +28 -0
.env ADDED
@@ -0,0 +1 @@
+ GROQ_API_KEY = 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
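For context, `llama2_chatbot.py` below reads this value through python-dotenv. A minimal sketch of that lookup (the `load_dotenv()` / `os.environ.get` pattern is taken from the script itself):

import os
from dotenv import load_dotenv

load_dotenv()                                  # parses .env from the working directory into os.environ
GROQ_API_KEY = os.environ.get('GROQ_API_KEY')  # python-dotenv tolerates spaces around '=' in .env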
llama-logo.png ADDED
llama2_chatbot.py ADDED
@@ -0,0 +1,143 @@
+ import streamlit as st
+ import os
+ from dotenv import load_dotenv
+ from groq import Groq
+
+ # Load environment variables from .env
+ load_dotenv()
+ GROQ_API_KEY = os.environ.get('GROQ_API_KEY')
+
+ PRE_PROMPT = "You are a helpful assistant. You do not respond as 'User' or pretend to be 'User'. You only respond once as Assistant."
+
+ if not GROQ_API_KEY:
+     st.warning("Please add your Groq API key to the .env file.")
+     st.stop()
+
+ # Connect to Groq
+ client = Groq(api_key=GROQ_API_KEY)
+
+ # Models' endpoints (Groq model IDs)
+ GROQ_MODEL_ENDPOINT70B = 'llama3-70b-8192'
+ GROQ_MODEL_ENDPOINT8B = 'llama3-8b-8192'
+
+ # Set up Streamlit app
+ st.set_page_config(page_title="LLaMA 3x", page_icon="🦙", layout="wide")
+
+ def render_app():
+
+     # Reduce font sizes for input text boxes
+     custom_css = """
+         <style>
+             .stTextArea textarea {font-size: 13px;}
+             div[data-baseweb="select"] > div {font-size: 13px !important;}
+         </style>
+     """
+     st.markdown(custom_css, unsafe_allow_html=True)
+
+     # Left sidebar menu
+     st.sidebar.header("LLaMA 3x")
+
+     # Hide Streamlit's default menu and footer for a cleaner page
+     hide_streamlit_style = """
+         <style>
+             #MainMenu {visibility: hidden;}
+             footer {visibility: hidden;}
+         </style>
+     """
+     st.markdown(hide_streamlit_style, unsafe_allow_html=True)
+
+     # Container for the chat history
+     response_container = st.container()
+     # Container for the user's text input
+     container = st.container()
+
+     # Set up / initialize session state variables
+     if 'chat_dialogue' not in st.session_state:
+         st.session_state['chat_dialogue'] = []
+     if 'temperature' not in st.session_state:
+         st.session_state['temperature'] = 0.1
+     if 'top_p' not in st.session_state:
+         st.session_state['top_p'] = 0.9
+     if 'max_seq_len' not in st.session_state:
+         st.session_state['max_seq_len'] = 512
+     if 'pre_prompt' not in st.session_state:
+         st.session_state['pre_prompt'] = PRE_PROMPT
+
+     # Dropdown menu to select the model endpoint
+     selected_option = st.sidebar.selectbox('Choose a LLaMA3 model:', ['LLaMA3 70B', 'LLaMA3 8B'], key='model')
+     if selected_option == 'LLaMA3 70B':
+         st.session_state['llm'] = GROQ_MODEL_ENDPOINT70B
+     else:
+         st.session_state['llm'] = GROQ_MODEL_ENDPOINT8B
+
+     # Model hyperparameters
+     st.session_state['temperature'] = st.sidebar.slider('Temperature:', min_value=0.01, max_value=5.0, value=0.1, step=0.01)
+     st.session_state['top_p'] = st.sidebar.slider('Top P:', min_value=0.01, max_value=1.0, value=0.9, step=0.01)
+     st.session_state['max_seq_len'] = st.sidebar.slider('Max Sequence Length:', min_value=64, max_value=4096, value=2048, step=8)
+
+     NEW_P = st.sidebar.text_area('Prompt before the chat starts. Edit here if desired:', PRE_PROMPT, height=60)
+     if NEW_P != PRE_PROMPT and NEW_P != "" and NEW_P is not None:
+         st.session_state['pre_prompt'] = NEW_P + "\n\n"
+     else:
+         st.session_state['pre_prompt'] = PRE_PROMPT
+
+     btn_col1, btn_col2 = st.sidebar.columns(2)
+
+     # Add the "Clear History" button to the sidebar
+     def clear_history():
+         st.session_state['chat_dialogue'] = []
+     clear_chat_history_button = btn_col1.button("Clear History",
+                                                 use_container_width=True,
+                                                 on_click=clear_history)
+
+     # Add logout button (Auth0 for auth)
+     def logout():
+         st.session_state.pop('user_info', None)  # pop() avoids a KeyError when no user is logged in
+     logout_button = btn_col2.button("Logout",
+                                     use_container_width=True,
+                                     on_click=logout)
+
+     # Links to relevant resources for users to select (currently unused)
+     st.sidebar.write(" ")
+     logo1 = 'https://storage.googleapis.com/llama2_release/a16z_logo.png'
+     logo2 = 'https://storage.googleapis.com/llama2_release/Screen%20Shot%202023-07-21%20at%2012.34.05%20PM.png'
+
+     st.sidebar.write(" ")
+     st.sidebar.markdown("*Made with ❤️ by Asman. Not associated with Meta Platforms, Inc.*")
+
+     # Display chat messages from history on app rerun
+     for message in st.session_state.chat_dialogue:
+         with st.chat_message(message["role"]):
+             st.markdown(message["content"])
+
+     # Accept user input
+     if prompt := st.chat_input("Message LLaMA 3x...."):
+         # Add user message to chat history
+         st.session_state.chat_dialogue.append({"role": "user", "content": prompt})
+         # Display user message in chat message container
+         with st.chat_message("user"):
+             st.markdown(prompt)
+
+         with st.chat_message("assistant"):
+             message_placeholder = st.empty()
+             # Prepend the editable pre-prompt as a system message so it actually takes effect
+             messages = [{"role": "system", "content": st.session_state['pre_prompt']}]
+             messages += [{"role": msg["role"], "content": msg["content"]} for msg in st.session_state.chat_dialogue]
+             chat_completion = client.chat.completions.create(
+                 messages=messages,
+                 model=st.session_state['llm'],  # pass the Groq model ID, not the sidebar display name
+                 temperature=st.session_state['temperature'],
+                 top_p=st.session_state['top_p'],
+                 max_tokens=st.session_state['max_seq_len']
+             )
+             full_response = chat_completion.choices[0].message.content
+             message_placeholder.markdown(full_response)
+
+         # Add assistant response to chat history
+         st.session_state.chat_dialogue.append({"role": "assistant", "content": full_response})
+
+ render_app()
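One possible refinement, not part of this commit: the call above returns the completion in one shot, while the Groq SDK also supports OpenAI-style streaming, so `message_placeholder` could be filled token by token. A sketch, assuming the same `client`, `messages`, and session-state values as in `render_app()` above:

# Hypothetical streaming variant of the chat_completion call (not in this commit).
stream = client.chat.completions.create(
    messages=messages,
    model=st.session_state['llm'],
    temperature=st.session_state['temperature'],
    top_p=st.session_state['top_p'],
    max_tokens=st.session_state['max_seq_len'],
    stream=True,
)
full_response = ""
for chunk in stream:
    delta = chunk.choices[0].delta.content
    if delta:
        full_response += delta
        message_placeholder.markdown(full_response + "▌")  # show a typing cursor while streaming
message_placeholder.markdown(full_response)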
requirements.txt ADDED
@@ -0,0 +1,67 @@
+ aiodns==3.0.0
+ aiohttp==3.8.5
+ aiosignal==1.3.1
+ altair==5.0.1
+ async-timeout==4.0.2
+ attrs==23.1.0
+ blinker==1.6.2
+ Brotli==1.0.9
+ cachetools==5.3.1
+ certifi==2023.5.7
+ cffi==1.15.1
+ charset-normalizer==3.2.0
+ click==8.1.5
+ decorator==5.1.1
+ ecdsa==0.18.0
+ frozenlist==1.4.0
+ gitdb==4.0.10
+ GitPython==3.1.32
+ groq  # imported by llama2_chatbot.py; no version pinned in this commit
+ idna==3.4
+ importlib-metadata==6.8.0
+ Jinja2==3.1.2
+ jsonschema==4.18.3
+ jsonschema-specifications==2023.6.1
+ markdown-it-py==3.0.0
+ MarkupSafe==2.1.3
+ mdurl==0.1.2
+ multidict==6.0.4
+ numpy==1.25.1
+ packaging==23.1
+ pandas==2.0.3
+ Pillow==9.5.0
+ protobuf==4.23.4
+ pyarrow==12.0.1
+ pyasn1==0.5.0
+ pycares==4.3.0
+ pycparser==2.21
+ pydantic==1.10.11
+ pydeck==0.8.1b0
+ Pygments==2.15.1
+ Pympler==1.0.1
+ python-dateutil==2.8.2
+ python-dotenv==1.0.0
+ python-jose==3.3.0
+ pytz==2023.3
+ pytz-deprecation-shim==0.1.0.post0
+ referencing==0.29.1
+ replicate==0.8.4
+ requests==2.31.0
+ rich==13.4.2
+ rpds-py==0.8.10
+ rsa==4.9
+ six==1.16.0
+ smmap==5.0.0
+ streamlit==1.24.1
+ streamlit-auth0-component==0.1.5
+ streamlit-chat==0.1.1
+ tenacity==8.2.2
+ toml==0.10.2
+ toolz==0.12.0
+ tornado==6.3.2
+ typing_extensions==4.7.1
+ tzdata==2023.3
+ tzlocal==4.3.1
+ urllib3==2.0.3
+ validators==0.20.0
+ yarl==1.9.2
+ zipp==3.16.2
utils.py ADDED
@@ -0,0 +1,28 @@
+ import replicate
+ import time
+
+ # Initialize debounce variables
+ last_call_time = 0
+ debounce_interval = 2  # Debounce interval in seconds
+
+ def debounce_replicate_run(llm, prompt, max_len, temperature, top_p, API_TOKEN):
+     global last_call_time
+     print("last call time: ", last_call_time)
+
+     # Get the current time
+     current_time = time.time()
+
+     # Calculate the time elapsed since the last call
+     elapsed_time = current_time - last_call_time
+
+     # If the elapsed time is less than the debounce interval, refuse the request
+     if elapsed_time < debounce_interval:
+         print("Debouncing")
+         return "Hello! You are sending requests too fast. Please wait a few seconds before sending another request."
+
+     # Update the last call time to the current time
+     last_call_time = time.time()
+
+     output = replicate.run(llm,
+                            input={"prompt": prompt + "Assistant: ",
+                                   "max_length": max_len,
+                                   "temperature": temperature,
+                                   "top_p": top_p,
+                                   "repetition_penalty": 1},
+                            api_token=API_TOKEN)
+     return output
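For reference, a hypothetical call to the helper above. The model slug and token below are placeholders, not values from this repo; note this helper targets Replicate, while `llama2_chatbot.py` in this commit talks to Groq directly.

# Usage sketch for debounce_replicate_run (placeholder model slug and token).
from utils import debounce_replicate_run

output = debounce_replicate_run(
    llm="owner/model:version",  # placeholder Replicate model slug
    prompt="User: Hello!\n\n",
    max_len=512,
    temperature=0.1,
    top_p=0.9,
    API_TOKEN="r8_...",         # placeholder Replicate API token
)
# replicate.run yields text chunks for LLMs; the debounce path returns a plain string.
print("".join(output))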