Liu Hong Yuan Tom committed on
Commit
c2d1bfd
·
verified ·
1 Parent(s): 1cbf6da

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +266 -2
app.py CHANGED
@@ -1,4 +1,268 @@
1
  import streamlit as st
 
 
 
 
 
 
 
2
 
3
- x = st.slider('Select a value')
4
- st.write(x, 'squared is', x * x)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Standard library
import copy
import json
from typing import Any, Iterable, Optional

# Third-party
import streamlit as st
from groq import Groq
from streamlit_ace import st_ace

# Project-local (Mixture-of-Agents implementation)
from moa.agent import MOAgent
from moa.agent.moa import ResponseChunk
9
 
10
# Default configuration: 70B main model, three layers, and no inline layer
# agents (the default agents live in layer_agent_config_def below).
default_config = dict(
    main_model="llama3-70b-8192",
    cycles=3,
    layer_agent_config={},
)
16
+
17
# Default layer agents: three agents that run in parallel each cycle.
# Every system prompt carries the "{helper_response}" placeholder, which
# the MOA framework fills with the previous layer's combined outputs.
layer_agent_config_def = {
    "layer_agent_1": dict(
        system_prompt="Think through your response step by step. {helper_response}",
        model_name="llama3-8b-8192",
    ),
    "layer_agent_2": dict(
        system_prompt="Respond with a thought and then your response to the question. {helper_response}",
        model_name="gemma-7b-it",
        temperature=0.7,
    ),
    "layer_agent_3": dict(
        system_prompt="You are an expert at logic and reasoning. Always take a logical approach to the answer. {helper_response}",
        model_name="llama3-8b-8192",
    ),
}
33
+
34
# Recommended configuration: same main model as the default, but only two
# cycles; the matching layer agents are in layer_agent_config_rec below.
rec_config = dict(
    main_model="llama3-70b-8192",
    cycles=2,
    layer_agent_config={},
)
41
+
42
# Recommended layer agents: four agents with graduated temperatures, plus a
# dedicated planner agent on Mixtral. As above, "{helper_response}" is the
# required placeholder for the previous layer's outputs.
layer_agent_config_rec = {
    "layer_agent_1": dict(
        system_prompt="Think through your response step by step. {helper_response}",
        model_name="llama3-8b-8192",
        temperature=0.1,
    ),
    "layer_agent_2": dict(
        system_prompt="Respond with a thought and then your response to the question. {helper_response}",
        model_name="llama3-8b-8192",
        temperature=0.2,
    ),
    "layer_agent_3": dict(
        system_prompt="You are an expert at logic and reasoning. Always take a logical approach to the answer. {helper_response}",
        model_name="llama3-8b-8192",
        temperature=0.4,
    ),
    "layer_agent_4": dict(
        system_prompt="You are an expert planner agent. Create a plan for how to answer the human's query. {helper_response}",
        model_name="mixtral-8x7b-32768",
        temperature=0.5,
    ),
}
64
+
65
+
66
def stream_response(messages: Iterable[ResponseChunk]):
    """Render intermediate layer-agent output and yield the final answer.

    Chunks tagged ``response_type == 'intermediate'`` are buffered per
    layer. When any other chunk arrives, each buffered layer is drawn as a
    row of expanders (one column per agent), the buffer is reset, and the
    chunk's delta is yielded — so ``st.write_stream`` receives only the
    main agent's text.
    """
    buffered = {}
    for chunk in messages:
        if chunk['response_type'] == 'intermediate':
            layer_id = chunk['metadata']['layer']
            buffered.setdefault(layer_id, []).append(chunk['delta'])
            continue

        # Flush everything collected since the last main-agent chunk.
        for layer_id, agent_outputs in buffered.items():
            st.write(f"Layer {layer_id}")
            columns = st.columns(len(agent_outputs))
            for idx, text in enumerate(agent_outputs):
                with columns[idx]:
                    st.expander(label=f"Agent {idx+1}", expanded=False).write(text)
        buffered = {}

        # Yield the main agent's output
        yield chunk['delta']
88
+
89
def set_moa_agent(
    main_model: str = default_config['main_model'],
    cycles: int = default_config['cycles'],
    layer_agent_config: Optional[dict[str, dict[str, Any]]] = None,
    main_model_temperature: float = 0.1,
    override: bool = False
):
    """Create or refresh the MOAgent held in Streamlit session state.

    Each configuration key is written only when it is missing from
    ``st.session_state``; with ``override=True`` every key (and the agent
    itself) is rebuilt from the arguments.

    Args:
        main_model: Groq model id used by the final aggregation agent.
        cycles: Number of layer passes before the main agent answers.
        layer_agent_config: Per-agent settings for the parallel layer
            agents; defaults to a deep copy of ``layer_agent_config_def``.
        main_model_temperature: Sampling temperature for the main model.
        override: Force-replace any existing session-state configuration.
    """
    if layer_agent_config is None:
        # Deep-copy the module-level default so mutations of the session's
        # config can never leak back into the shared default dict.
        # (The original used a deepcopy evaluated once at def time — a
        # mutable default argument shared across calls.)
        layer_agent_config = copy.deepcopy(layer_agent_config_def)

    # The original if/else pairs were redundant: each else branch re-tested
    # "key not in session_state", which was necessarily False there. A
    # single guarded assignment per key is equivalent.
    if override or "main_model" not in st.session_state:
        st.session_state.main_model = main_model
    if override or "cycles" not in st.session_state:
        st.session_state.cycles = cycles
    if override or "layer_agent_config" not in st.session_state:
        st.session_state.layer_agent_config = layer_agent_config
    if override or "main_temp" not in st.session_state:
        st.session_state.main_temp = main_model_temperature

    if override or "moa_agent" not in st.session_state:
        # Pass a copy so MOAgent.from_config cannot mutate session state.
        st.session_state.moa_agent = MOAgent.from_config(
            main_model=st.session_state.main_model,
            cycles=st.session_state.cycles,
            layer_agent_config=copy.deepcopy(st.session_state.layer_agent_config),
            temperature=st.session_state.main_temp
        )
128
+
129
# ---- Page setup (must be the first Streamlit call in the script) ----
st.set_page_config(
    page_title="Mixture-Of-Agents Powered by Groq",
    page_icon='static/favicon.ico',
    menu_items={
        'About': "## Groq Mixture-Of-Agents \n Powered by [Groq](https://groq.com)"
    },
    layout="wide"
)

# List chat-capable model ids from the Groq API; Whisper models are
# audio-only, so exclude them from the selector.
# NOTE(review): Groq() presumably reads GROQ_API_KEY from the environment,
# and this line performs a network call on every script rerun — confirm
# whether it should be cached.
valid_model_names = [model.id for model in Groq().models.list().data if not model.id.startswith("whisper")]

# Banner image (served from the app's static folder) and a divider.
st.markdown("<a href='https://groq.com'><img src='app/static/banner.png' width='500'></a>", unsafe_allow_html=True)
st.write("---")
142
+
143
+
144
+
145
# Initialize session state: chat history survives Streamlit reruns.
if "messages" not in st.session_state:
    st.session_state.messages = []

# Populate session state with the default MOA configuration on first run
# (set_moa_agent only fills in keys that are missing).
set_moa_agent()
150
+
151
# Sidebar for configuration
with st.sidebar:
    st.title("MOA Configuration")
    # A single form so all widget edits are applied atomically on submit.
    with st.form("Agent Configuration", border=False):
        # One-click switch to the recommended preset (4 agents, 2 cycles).
        if st.form_submit_button("Use Recommended Config"):
            try:
                set_moa_agent(
                    main_model=rec_config['main_model'],
                    cycles=rec_config['cycles'],
                    layer_agent_config=layer_agent_config_rec,
                    override=True
                )
                # Clear the chat: old turns came from a different config.
                st.session_state.messages = []
                st.success("Configuration updated successfully!")
            except json.JSONDecodeError:
                st.error("Invalid JSON in Layer Agent Configuration. Please check your input.")
            except Exception as e:
                st.error(f"Error updating configuration: {str(e)}")
        # Main model selection
        new_main_model = st.selectbox(
            "Select Main Model",
            options=valid_model_names,
            index=valid_model_names.index(st.session_state.main_model)
        )

        # Cycles input
        new_cycles = st.number_input(
            "Number of Layers",
            min_value=1,
            max_value=10,
            value=st.session_state.cycles
        )

        # Main Model Temperature
        # NOTE(review): value is hard-coded to 0.1 rather than reading
        # st.session_state.main_temp, so the widget does not reflect a
        # previously saved temperature — confirm whether intentional.
        main_temperature = st.number_input(
            label="Main Model Temperature",
            value=0.1,
            min_value=0.0,
            max_value=1.0,
            step=0.1
        )

        # Layer agent configuration, edited as JSON in an ace editor.
        tooltip = "Agents in the layer agent configuration run in parallel _per cycle_. Each layer agent supports all initialization parameters of [Langchain's ChatGroq](https://api.python.langchain.com/en/latest/chat_models/langchain_groq.chat_models.ChatGroq.html) class as valid dictionary fields."
        st.markdown("Layer Agent Config", help=tooltip)
        new_layer_agent_config = st_ace(
            value=json.dumps(st.session_state.layer_agent_config, indent=2),
            language='json',
            placeholder="Layer Agent Configuration (JSON)",
            show_gutter=False,
            wrap=True,
            auto_update=True
        )

        # Apply the edited settings; malformed JSON is reported inline
        # rather than crashing the app.
        if st.form_submit_button("Update Configuration"):
            try:
                new_layer_config = json.loads(new_layer_agent_config)
                set_moa_agent(
                    main_model=new_main_model,
                    cycles=new_cycles,
                    layer_agent_config=new_layer_config,
                    main_model_temperature=main_temperature,
                    override=True
                )
                st.session_state.messages = []
                st.success("Configuration updated successfully!")
            except json.JSONDecodeError:
                st.error("Invalid JSON in Layer Agent Configuration. Please check your input.")
            except Exception as e:
                st.error(f"Error updating configuration: {str(e)}")

    st.markdown("---")
    st.markdown("""
    ### Credits
    - MOA: [Together AI](https://www.together.ai/blog/together-moa)
    - LLMs: [Groq](https://groq.com/)
    - Paper: [arXiv:2406.04692](https://arxiv.org/abs/2406.04692)
    """)
230
+
231
# Main app layout
st.header("Mixture of Agents", anchor=False)
st.write("A demo of the Mixture of Agents architecture proposed by Together AI, Powered by Groq LLMs.")
st.image("./static/moa_groq.svg", caption="Mixture of Agents Workflow", width=1000)

# Display current configuration
with st.expander("Current MOA Configuration", expanded=False):
    st.markdown(f"**Main Model**: ``{st.session_state.main_model}``")
    st.markdown(f"**Main Model Temperature**: ``{st.session_state.main_temp:.1f}``")
    st.markdown(f"**Layers**: ``{st.session_state.cycles}``")
    st.markdown(f"**Layer Agents Config**:")
    # Read-only ace widget mirroring the sidebar editor.
    # NOTE(review): this reuses the name new_layer_agent_config from the
    # sidebar scope; harmless here since the value is never read back.
    new_layer_agent_config = st_ace(
        value=json.dumps(st.session_state.layer_agent_config, indent=2),
        language='json',
        placeholder="Layer Agent Configuration (JSON)",
        show_gutter=False,
        wrap=True,
        readonly=True,
        auto_update=True
    )
251
+
252
# Chat interface
# Replay prior turns so the conversation persists across Streamlit reruns.
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

if query := st.chat_input("Ask a question"):
    st.session_state.messages.append({"role": "user", "content": query})
    with st.chat_message("user"):
        st.write(query)

    moa_agent: MOAgent = st.session_state.moa_agent
    with st.chat_message("assistant"):
        # NOTE(review): message_placeholder is created but never used —
        # st.write_stream below renders directly; candidate for removal.
        message_placeholder = st.empty()
        # stream_response filters out intermediate layer chunks and yields
        # only the main agent's deltas for st.write_stream to render.
        ast_mess = stream_response(moa_agent.chat(query, output_format='json'))
        response = st.write_stream(ast_mess)

    st.session_state.messages.append({"role": "assistant", "content": response})