Joshua Sundance Bailey committed
Commit · 679726e
Parent(s): 64e3f44

parameterize research assistant llms
langchain-streamlit-demo/app.py
CHANGED

@@ -26,7 +26,7 @@ from llm_resources import (
     get_runnable,
     get_texts_and_multiretriever,
 )
-from research_assistant.chain import chain as research_assistant_chain
+from research_assistant.chain import get_chain as get_research_assistant_chain

 __version__ = "2.0.1"

@@ -367,7 +367,7 @@ with sidebar:


 # --- LLM Instantiation ---
-st.session_state.llm = get_llm(
+get_llm_args = dict(
     provider=st.session_state.provider,
     model=model,
     provider_api_key=provider_api_key,
@@ -382,6 +382,8 @@ st.session_state.llm = get_llm(
         "AZURE_OPENAI_MODEL_VERSION": st.session_state.AZURE_OPENAI_MODEL_VERSION,
     },
 )
+get_llm_args_temp_zero = get_llm_args | {"temperature": 0.0}
+st.session_state.llm = get_llm(**get_llm_args)

 # --- Chat History ---
 for msg in STMEMORY.messages:
@@ -448,12 +450,16 @@ if st.session_state.llm:
         WikipediaQueryRun(api_wrapper=WikipediaAPIWrapper()),
     ]
     if st.session_state.provider in ("Azure OpenAI", "OpenAI"):
+        research_assistant_chain = get_research_assistant_chain(
+            search_llm=get_llm(**get_llm_args_temp_zero),  # type: ignore
+            writer_llm=get_llm(**get_llm_args_temp_zero),  # type: ignore
+        )
         st_callback = StreamlitCallbackHandler(st.container())
         callbacks.append(st_callback)
         research_assistant_tool = Tool.from_function(
             func=lambda s: research_assistant_chain.invoke(
                 {"question": s},
-                config=get_config(callbacks),
+                # config=get_config(callbacks),
             ),
             name="web-research-assistant",
             description="this assistant returns a comprehensive report based on web research. for quick facts, use duckduckgo instead.",
@@ -473,7 +479,7 @@ if st.session_state.llm:
         doc_chain_tool = Tool.from_function(
             func=lambda s: st.session_state.doc_chain.invoke(
                 s,
-                config=get_config(callbacks),
+                # config=get_config(callbacks),
             ),
             name="user-document-chat",
             description="this assistant returns a response based on the user's custom context. if the user's meaning is unclear, perhaps the answer is here. generally speaking, try this tool before conducting web research.",
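Note: the two new module-level names make one set of LLM kwargs reusable. `get_llm_args` is built once, and the PEP 584 dict union (`|`, Python 3.9+) derives a zero-temperature variant for the research assistant. A minimal sketch of that pattern, with a hypothetical stand-in for the app's `get_llm` helper and illustrative provider/model values:

    # Stand-in for the app's get_llm factory; only the kwargs pattern matters here.
    def get_llm(provider: str, model: str, temperature: float = 0.7) -> str:
        return f"{provider}/{model}@temp={temperature}"

    get_llm_args = dict(provider="OpenAI", model="gpt-4")  # illustrative values

    # PEP 584 dict union (Python 3.9+): right-hand keys override left-hand ones.
    get_llm_args_temp_zero = get_llm_args | {"temperature": 0.0}

    get_llm(**get_llm_args)            # default temperature for the chat LLM
    get_llm(**get_llm_args_temp_zero)  # deterministic LLMs for the research chains

One factory call thus builds the interactive chat model and both research models from a single argument set.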
langchain-streamlit-demo/research_assistant/__init__.py
CHANGED

@@ -1,3 +1,3 @@
-from research_assistant.chain import chain
+from research_assistant.chain import get_chain

-__all__ = ["chain"]
+__all__ = ["get_chain"]
langchain-streamlit-demo/research_assistant/chain.py
CHANGED

@@ -1,16 +1,18 @@
 from langchain_core.pydantic_v1 import BaseModel
 from langchain_core.runnables import RunnablePassthrough

-from research_assistant.search.web import chain as search_chain
-from research_assistant.writer import chain as writer_chain
-
-chain_notypes = (
-    RunnablePassthrough().assign(research_summary=search_chain) | writer_chain
-)
-
-
-class InputType(BaseModel):
-    question: str
-
-
-chain = chain_notypes.with_types(input_type=InputType)
+from research_assistant.search.web import get_search_chain
+from research_assistant.writer import get_writer_chain
+from langchain.llms.base import BaseLLM
+from langchain.schema.runnable import Runnable
+
+
+def get_chain(search_llm: BaseLLM, writer_llm: BaseLLM) -> Runnable:
+    chain_notypes = RunnablePassthrough().assign(
+        research_summary=get_search_chain(search_llm),
+    ) | get_writer_chain(writer_llm)
+
+    class InputType(BaseModel):
+        question: str
+
+    return chain_notypes.with_types(input_type=InputType)
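Note: `get_chain` now receives its two models instead of importing module-level chains, so callers choose the provider. A hypothetical invocation (the `ChatOpenAI` choice and question are illustrative; the annotation says `BaseLLM`, but chat models also work at runtime, which is why app.py adds `# type: ignore`):

    # Hypothetical usage sketch, assuming an OpenAI key is configured.
    from langchain.chat_models import ChatOpenAI

    from research_assistant.chain import get_chain

    chain = get_chain(
        search_llm=ChatOpenAI(temperature=0),
        writer_llm=ChatOpenAI(temperature=0),
    )
    # InputType requires a "question" key.
    report = chain.invoke({"question": "What is retrieval-augmented generation?"})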
langchain-streamlit-demo/research_assistant/search/web.py
CHANGED

@@ -3,7 +3,7 @@ from typing import Any

 import requests
 from bs4 import BeautifulSoup
-from langchain.chat_models import ChatOpenAI
+from langchain.llms.base import BaseLLM
 from langchain.prompts import ChatPromptTemplate
 from langchain.retrievers.tavily_search_api import TavilySearchAPIRetriever
 from langchain.utilities import DuckDuckGoSearchAPIWrapper
@@ -130,25 +130,6 @@ Using the above text, answer in short the following question:
 if the question cannot be answered using the text, imply summarize the text. Include all factual information, numbers, stats etc if available.""" # noqa: E501
 SUMMARY_PROMPT = ChatPromptTemplate.from_template(SUMMARY_TEMPLATE)

-scrape_and_summarize: Runnable[Any, Any] = (
-    RunnableParallel(
-        {
-            "question": lambda x: x["question"],
-            "text": lambda x: scrape_text(x["url"])[:10000],
-            "url": lambda x: x["url"],
-        },
-    )
-    | RunnableParallel(
-        {
-            "summary": SUMMARY_PROMPT | ChatOpenAI(temperature=0) | StrOutputParser(),
-            "url": lambda x: x["url"],
-        },
-    )
-    | RunnableLambda(lambda x: f"Source Url: {x['url']}\nSummary: {x['summary']}")
-)
-
-multi_search = get_links | scrape_and_summarize.map() | (lambda x: "\n".join(x))
-

 def load_json(s):
     try:
@@ -157,24 +138,41 @@ def load_json(s):
     return {}


-search_query = SEARCH_PROMPT | ChatOpenAI(temperature=0) | StrOutputParser() | load_json
-choose_agent = (
-    CHOOSE_AGENT_PROMPT | ChatOpenAI(temperature=0) | StrOutputParser() | load_json
-)
-
-get_search_queries = (
-    RunnablePassthrough().assign(
-        agent_prompt=RunnableParallel({"task": lambda x: x})
-        | choose_agent
-        | (lambda x: x.get("agent_role_prompt")),
-    )
-    | search_query
-)
-
-
-chain = (
-    get_search_queries
-    | (lambda x: [{"question": q} for q in x])
-    | multi_search.map()
-    | (lambda x: "\n\n".join(x))
-)
+def get_search_chain(model: BaseLLM) -> Runnable:
+    scrape_and_summarize: Runnable[Any, Any] = (
+        RunnableParallel(
+            {
+                "question": lambda x: x["question"],
+                "text": lambda x: scrape_text(x["url"])[:10000],
+                "url": lambda x: x["url"],
+            },
+        )
+        | RunnableParallel(
+            {
+                "summary": SUMMARY_PROMPT | model | StrOutputParser(),
+                "url": lambda x: x["url"],
+            },
+        )
+        | RunnableLambda(lambda x: f"Source Url: {x['url']}\nSummary: {x['summary']}")
+    )
+
+    multi_search = get_links | scrape_and_summarize.map() | (lambda x: "\n".join(x))
+
+    search_query = SEARCH_PROMPT | model | StrOutputParser() | load_json
+    choose_agent = CHOOSE_AGENT_PROMPT | model | StrOutputParser() | load_json
+
+    get_search_queries = (
+        RunnablePassthrough().assign(
+            agent_prompt=RunnableParallel({"task": lambda x: x})
+            | choose_agent
+            | (lambda x: x.get("agent_role_prompt")),
+        )
+        | search_query
+    )
+
+    return (
+        get_search_queries
+        | (lambda x: [{"question": q} for q in x])
+        | multi_search.map()
+        | (lambda x: "\n\n".join(x))
+    )
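Note: the relocated pipeline leans on `Runnable.map()` twice, once to summarize each scraped URL and once to run `multi_search` per generated query. A toy sketch of that fan-out behavior, using stand-in runnables rather than the real scrapers:

    from langchain_core.runnables import RunnableLambda

    # Stand-in for scrape_and_summarize: one URL in, one summary line out.
    summarize = RunnableLambda(lambda url: f"Source Url: {url}\nSummary: ...")

    # .map() lifts a runnable over a list input: one inner call per element.
    multi = summarize.map() | RunnableLambda(lambda xs: "\n".join(xs))

    print(multi.invoke(["https://a.example", "https://b.example"]))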
langchain-streamlit-demo/research_assistant/writer.py
CHANGED

@@ -1,7 +1,8 @@
-from langchain.chat_models import ChatOpenAI
 from langchain.prompts import ChatPromptTemplate
 from langchain_core.output_parsers import StrOutputParser
 from langchain_core.runnables import ConfigurableField
+from langchain.llms.base import BaseLLM
+from langchain.schema.runnable import Runnable

 WRITER_SYSTEM_PROMPT = "You are an AI critical thinker research assistant. Your sole purpose is to write well written, critically acclaimed, objective and structured reports on given text." # noqa: E501

@@ -50,7 +51,6 @@ Use appropriate Markdown syntax to format the outline and ensure readability.

 Please do your best, this is very important to my career.""" # noqa: E501

-model = ChatOpenAI(temperature=0)
 prompt = ChatPromptTemplate.from_messages(
     [
         ("system", WRITER_SYSTEM_PROMPT),
@@ -72,4 +72,7 @@ prompt = ChatPromptTemplate.from_messages(
         ],
     ),
 )
-chain = prompt | model | StrOutputParser()
+
+
+def get_writer_chain(model: BaseLLM) -> Runnable:
+    return prompt | model | StrOutputParser()
|