Nullpointer-KK committed on
Commit fe8d256 · verified · 1 Parent(s): 187517f

Update scripts/main.py

Files changed (1)
  1. scripts/main.py +269 -269
scripts/main.py CHANGED
@@ -1,269 +1,269 @@
-pip install llama-index llama-index-llms-openai
+
 import pdb
 
 import gradio as gr
 import logfire
 from custom_retriever import CustomRetriever
 from llama_index.agent.openai import OpenAIAgent
 from llama_index.core.llms import MessageRole
 from llama_index.core.memory import ChatSummaryMemoryBuffer
 from llama_index.core.tools import RetrieverTool, ToolMetadata
 from llama_index.core.vector_stores import (
     FilterCondition,
     FilterOperator,
     MetadataFilter,
     MetadataFilters,
 )
 from llama_index.llms.openai import OpenAI
 from prompts import system_message_openai_agent
 from setup import (
     AVAILABLE_SOURCES,
     AVAILABLE_SOURCES_UI,
     CONCURRENCY_COUNT,
     custom_retriever_all_sources,
 )
 
 
 def update_query_engine_tools(selected_sources) -> list[RetrieverTool]:
     tools = []
     source_mapping: dict[str, tuple[CustomRetriever, str, str]] = {
         "All Sources": (
             custom_retriever_all_sources,
             "all_sources_info",
             """Useful tool that contains general information about the field of AI.""",
         ),
     }
 
     for source in selected_sources:
         if source in source_mapping:
             custom_retriever, name, description = source_mapping[source]
             tools.append(
                 RetrieverTool(
                     retriever=custom_retriever,
                     metadata=ToolMetadata(
                         name=name,
                         description=description,
                     ),
                 )
             )
 
     return tools
 
 
 def generate_completion(
     query,
     history,
     sources,
     model,
     memory,
 ):
     llm = OpenAI(temperature=1, model=model, max_tokens=None)
     client = llm._get_client()
     logfire.instrument_openai(client)
 
     with logfire.span(f"Running query: {query}"):
         logfire.info(f"User chosen sources: {sources}")
 
         memory_chat_list = memory.get()
 
         if len(memory_chat_list) != 0:
             user_index_memory = [
                 i
                 for i, msg in enumerate(memory_chat_list)
                 if msg.role == MessageRole.USER
             ]
 
             user_index_history = [
                 i for i, msg in enumerate(history) if msg["role"] == "user"
             ]
 
             if len(user_index_memory) > len(user_index_history):
                 logfire.warn(f"There are more user messages in memory than in history")
                 user_index_to_remove = user_index_memory[len(user_index_history)]
                 memory_chat_list = memory_chat_list[:user_index_to_remove]
                 memory.set(memory_chat_list)
 
         logfire.info(f"chat_history: {len(memory.get())} {memory.get()}")
         logfire.info(f"gradio_history: {len(history)} {history}")
 
         query_engine_tools: list[RetrieverTool] = update_query_engine_tools(
             ["All Sources"]
         )
 
         filter_list = []
         source_mapping = {
             "Transformers Docs": "transformers",
             "PEFT Docs": "peft",
             "TRL Docs": "trl",
             "LlamaIndex Docs": "llama_index",
             "LangChain Docs": "langchain",
             "OpenAI Cookbooks": "openai_cookbooks",
             "Towards AI Blog": "tai_blog",
             "8 Hour Primer": "8-hour_primer",
             "Advanced LLM Developer": "llm_developer",
             "Python Primer": "python_primer",
         }
 
         for source in sources:
             if source in source_mapping:
                 filter_list.append(
                     MetadataFilter(
                         key="source",
                         operator=FilterOperator.EQ,
                         value=source_mapping[source],
                     )
                 )
 
         filters = MetadataFilters(
             filters=filter_list,
             condition=FilterCondition.OR,
         )
         logfire.info(f"Filters: {filters}")
         query_engine_tools[0].retriever._vector_retriever._filters = filters
 
         # pdb.set_trace()
 
         agent = OpenAIAgent.from_tools(
             llm=llm,
             memory=memory,
             tools=query_engine_tools,
             system_prompt=system_message_openai_agent,
         )
 
         completion = agent.stream_chat(query)
 
         answer_str = ""
         for token in completion.response_gen:
             answer_str += token
             yield answer_str
 
         for answer_str in add_sources(answer_str, completion):
             yield answer_str
 
 
 def add_sources(answer_str, completion):
     if completion is None:
         yield answer_str
 
     formatted_sources = format_sources(completion)
     if formatted_sources == "":
         yield answer_str
 
     if formatted_sources != "":
         answer_str += "\n\n" + formatted_sources
 
     yield answer_str
 
 
 def format_sources(completion) -> str:
     if len(completion.sources) == 0:
         return ""
 
     # logfire.info(f"Formatting sources: {completion.sources}")
 
     display_source_to_ui = {
         src: ui for src, ui in zip(AVAILABLE_SOURCES, AVAILABLE_SOURCES_UI)
     }
 
     documents_answer_template: str = (
         "📝 Here are the sources I used to answer your question:\n{documents}"
     )
     document_template: str = "[🔗 {source}: {title}]({url}), relevance: {score:2.2f}"
     all_documents = []
     for source in completion.sources:  # looping over list[ToolOutput]
         if isinstance(source.raw_output, Exception):
             logfire.error(f"Error in source output: {source.raw_output}")
             # pdb.set_trace()
             continue
 
         if not isinstance(source.raw_output, list):
             logfire.warn(f"Unexpected source output type: {type(source.raw_output)}")
             continue
         for src in source.raw_output:  # looping over list[NodeWithScore]
             document = document_template.format(
                 title=src.metadata["title"],
                 score=src.score,
                 source=display_source_to_ui.get(
                     src.metadata["source"], src.metadata["source"]
                 ),
                 url=src.metadata["url"],
             )
             all_documents.append(document)
 
     if len(all_documents) == 0:
         return ""
     else:
         documents = "\n".join(all_documents)
         return documents_answer_template.format(documents=documents)
 
 
 def save_completion(completion, history):
     pass
 
 
 def vote(data: gr.LikeData):
     pass
 
 
 accordion = gr.Accordion(label="Customize Sources (Click to expand)", open=False)
 sources = gr.CheckboxGroup(
     AVAILABLE_SOURCES_UI,
     label="Sources",
     value=[
         "Advanced LLM Developer",
         "8 Hour Primer",
         "Python Primer",
         "Towards AI Blog",
         "Transformers Docs",
         "PEFT Docs",
         "TRL Docs",
         "LlamaIndex Docs",
         "LangChain Docs",
         "OpenAI Cookbooks",
     ],
     interactive=True,
 )
 model = gr.Dropdown(
     [
         "gpt-4o-mini",
         # Kenny added GPT2
         # "gpt2",
     ],
     label="Model",
     value="gpt-4o-mini",
     interactive=False,
 )
 
 with gr.Blocks(
     title="Towards AI 🤖",
     analytics_enabled=True,
     fill_height=True,
 ) as demo:
 
     memory = gr.State(
         lambda: ChatSummaryMemoryBuffer.from_defaults(
             token_limit=120000,
         )
     )
     chatbot = gr.Chatbot(
         type="messages",
         scale=20,
         placeholder="<strong>Towards AI 🤖: A Question-Answering Bot for anything AI-related</strong><br>",
         show_label=False,
         show_copy_button=True,
     )
     chatbot.like(vote, None, None)
     gr.ChatInterface(
         fn=generate_completion,
         type="messages",
         chatbot=chatbot,
         additional_inputs=[sources, model, memory],
         additional_inputs_accordion=accordion,
         # fill_height=True,
         # fill_width=True,
         analytics_enabled=True,
     )
 
 if __name__ == "__main__":
     demo.queue(default_concurrency_limit=CONCURRENCY_COUNT)
     demo.launch(debug=False, share=False)
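
Review note: the only substantive change in this commit is line 1, where the shell command `pip install llama-index llama-index-llms-openai` is removed from the top of the module and replaced by a blank line. That line is not valid Python, so importing or running scripts/main.py would fail with a SyntaxError before any of the code below it executed; the install command belongs in a terminal or a requirements file instead. A minimal smoke test along these lines (a hypothetical sketch, not part of this commit or repository) would catch such a regression:

    # smoke_test.py -- hypothetical check, not part of this commit.
    # Parse the module without executing it; a stray shell command such as
    # the removed `pip install ...` line makes ast.parse raise SyntaxError.
    import ast

    with open("scripts/main.py", encoding="utf-8") as f:
        ast.parse(f.read(), filename="scripts/main.py")

    print("scripts/main.py parses cleanly")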