Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	
		Joshua Sundance Bailey
		
	commited on
		
		
					Commit 
							
							·
						
						5825ff9
	
1
								Parent(s):
							
							547d578
								
callbacks (still not working 100%
Browse files- .pre-commit-config.yaml +8 -12
- langchain-streamlit-demo/app.py +65 -34
- langchain-streamlit-demo/llm_resources.py +12 -13
    	
        .pre-commit-config.yaml
    CHANGED
    
    | @@ -40,24 +40,20 @@ repos: | |
| 40 | 
             
                -   id: trailing-whitespace
         | 
| 41 | 
             
                -   id: mixed-line-ending
         | 
| 42 | 
             
                -   id: requirements-txt-fixer
         | 
| 43 | 
            -
            -   repo: https://github.com/ | 
| 44 | 
            -
                rev:  | 
| 45 | 
             
                hooks:
         | 
| 46 | 
            -
                -   id:  | 
| 47 | 
            -
                    additional_dependencies:
         | 
| 48 | 
            -
                        - types-requests
         | 
| 49 | 
             
            -   repo: https://github.com/asottile/add-trailing-comma
         | 
| 50 | 
             
                rev: v3.1.0
         | 
| 51 | 
             
                hooks:
         | 
| 52 | 
             
                -   id: add-trailing-comma
         | 
| 53 | 
            -
             | 
| 54 | 
            -
             | 
| 55 | 
            -
            #    hooks:
         | 
| 56 | 
            -
            #    -   id: rm-unneeded-f-str
         | 
| 57 | 
            -
            -   repo: https://github.com/psf/black
         | 
| 58 | 
            -
                rev: 23.9.1
         | 
| 59 | 
             
                hooks:
         | 
| 60 | 
            -
                -   id:  | 
|  | |
|  | |
| 61 | 
             
            -   repo: https://github.com/PyCQA/bandit
         | 
| 62 | 
             
                rev: 1.7.5
         | 
| 63 | 
             
                hooks:
         | 
|  | |
| 40 | 
             
                -   id: trailing-whitespace
         | 
| 41 | 
             
                -   id: mixed-line-ending
         | 
| 42 | 
             
                -   id: requirements-txt-fixer
         | 
| 43 | 
            +
            -   repo: https://github.com/psf/black
         | 
| 44 | 
            +
                rev: 23.9.1
         | 
| 45 | 
             
                hooks:
         | 
| 46 | 
            +
                -   id: black
         | 
|  | |
|  | |
| 47 | 
             
            -   repo: https://github.com/asottile/add-trailing-comma
         | 
| 48 | 
             
                rev: v3.1.0
         | 
| 49 | 
             
                hooks:
         | 
| 50 | 
             
                -   id: add-trailing-comma
         | 
| 51 | 
            +
            -   repo: https://github.com/pre-commit/mirrors-mypy
         | 
| 52 | 
            +
                rev: v1.5.1
         | 
|  | |
|  | |
|  | |
|  | |
| 53 | 
             
                hooks:
         | 
| 54 | 
            +
                -   id: mypy
         | 
| 55 | 
            +
                    additional_dependencies:
         | 
| 56 | 
            +
                        - types-requests
         | 
| 57 | 
             
            -   repo: https://github.com/PyCQA/bandit
         | 
| 58 | 
             
                rev: 1.7.5
         | 
| 59 | 
             
                hooks:
         | 
    	
        langchain-streamlit-demo/app.py
    CHANGED
    
    | @@ -6,6 +6,7 @@ import langsmith.utils | |
| 6 | 
             
            import openai
         | 
| 7 | 
             
            import streamlit as st
         | 
| 8 | 
             
            from langchain.callbacks import StreamlitCallbackHandler
         | 
|  | |
| 9 | 
             
            from langchain.callbacks.tracers.langchain import LangChainTracer, wait_for_all_tracers
         | 
| 10 | 
             
            from langchain.callbacks.tracers.run_collector import RunCollectorCallbackHandler
         | 
| 11 | 
             
            from langchain.memory import ConversationBufferMemory, StreamlitChatMessageHistory
         | 
| @@ -20,6 +21,7 @@ from defaults import default_values | |
| 20 | 
             
            from llm_resources import (
         | 
| 21 | 
             
                get_agent,
         | 
| 22 | 
             
                get_llm,
         | 
|  | |
| 23 | 
             
                get_texts_and_multiretriever,
         | 
| 24 | 
             
            )
         | 
| 25 | 
             
            from research_assistant.chain import chain as research_assistant_chain
         | 
| @@ -379,15 +381,6 @@ st.session_state.llm = get_llm( | |
| 379 | 
             
                },
         | 
| 380 | 
             
            )
         | 
| 381 |  | 
| 382 | 
            -
            research_assistant_tool = Tool.from_function(
         | 
| 383 | 
            -
                func=lambda s: research_assistant_chain.invoke({"question": s}),
         | 
| 384 | 
            -
                name="web-research-assistant",
         | 
| 385 | 
            -
                description="this assistant returns a report based on web research",
         | 
| 386 | 
            -
            )
         | 
| 387 | 
            -
             | 
| 388 | 
            -
            TOOLS = [research_assistant_tool]
         | 
| 389 | 
            -
            st.session_state.agent = get_agent(TOOLS, STMEMORY, st.session_state.llm)
         | 
| 390 | 
            -
             | 
| 391 | 
             
            # --- Chat History ---
         | 
| 392 | 
             
            for msg in STMEMORY.messages:
         | 
| 393 | 
             
                st.chat_message(
         | 
| @@ -424,12 +417,16 @@ if st.session_state.llm: | |
| 424 | 
             
                        if st.session_state.ls_tracer:
         | 
| 425 | 
             
                            callbacks.append(st.session_state.ls_tracer)
         | 
| 426 |  | 
| 427 | 
            -
                         | 
| 428 | 
            -
                             | 
| 429 | 
            -
             | 
| 430 | 
            -
             | 
| 431 | 
            -
             | 
| 432 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
| 433 |  | 
| 434 | 
             
                        use_document_chat = all(
         | 
| 435 | 
             
                            [
         | 
| @@ -439,32 +436,66 @@ if st.session_state.llm: | |
| 439 | 
             
                        )
         | 
| 440 |  | 
| 441 | 
             
                        full_response: Union[str, None] = None
         | 
| 442 | 
            -
             | 
| 443 | 
             
                        # stream_handler = StreamHandler(message_placeholder)
         | 
| 444 | 
             
                        # callbacks.append(stream_handler)
         | 
|  | |
| 445 |  | 
| 446 | 
            -
                         | 
| 447 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 448 |  | 
| 449 | 
            -
             | 
| 450 | 
            -
             | 
| 451 | 
            -
             | 
| 452 | 
            -
             | 
| 453 | 
            -
             | 
| 454 | 
            -
             | 
| 455 | 
            -
             | 
| 456 | 
            -
             | 
| 457 | 
            -
             | 
| 458 | 
            -
             | 
| 459 | 
            -
             | 
| 460 | 
            -
             | 
| 461 | 
            -
             | 
| 462 | 
            -
             | 
| 463 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 464 |  | 
| 465 | 
             
                        # --- LLM call ---
         | 
| 466 | 
             
                        try:
         | 
| 467 | 
            -
                            full_response = st.session_state. | 
|  | |
|  | |
|  | |
| 468 |  | 
| 469 | 
             
                        except (openai.AuthenticationError, anthropic.AuthenticationError):
         | 
| 470 | 
             
                            st.error(
         | 
|  | |
| 6 | 
             
            import openai
         | 
| 7 | 
             
            import streamlit as st
         | 
| 8 | 
             
            from langchain.callbacks import StreamlitCallbackHandler
         | 
| 9 | 
            +
            from langchain.callbacks.base import BaseCallbackHandler
         | 
| 10 | 
             
            from langchain.callbacks.tracers.langchain import LangChainTracer, wait_for_all_tracers
         | 
| 11 | 
             
            from langchain.callbacks.tracers.run_collector import RunCollectorCallbackHandler
         | 
| 12 | 
             
            from langchain.memory import ConversationBufferMemory, StreamlitChatMessageHistory
         | 
|  | |
| 21 | 
             
            from llm_resources import (
         | 
| 22 | 
             
                get_agent,
         | 
| 23 | 
             
                get_llm,
         | 
| 24 | 
            +
                get_runnable,
         | 
| 25 | 
             
                get_texts_and_multiretriever,
         | 
| 26 | 
             
            )
         | 
| 27 | 
             
            from research_assistant.chain import chain as research_assistant_chain
         | 
|  | |
| 381 | 
             
                },
         | 
| 382 | 
             
            )
         | 
| 383 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 384 | 
             
            # --- Chat History ---
         | 
| 385 | 
             
            for msg in STMEMORY.messages:
         | 
| 386 | 
             
                st.chat_message(
         | 
|  | |
| 417 | 
             
                        if st.session_state.ls_tracer:
         | 
| 418 | 
             
                            callbacks.append(st.session_state.ls_tracer)
         | 
| 419 |  | 
| 420 | 
            +
                        def get_config(callbacks: list[BaseCallbackHandler]) -> dict[str, Any]:
         | 
| 421 | 
            +
                            config: Dict[str, Any] = dict(
         | 
| 422 | 
            +
                                callbacks=callbacks,
         | 
| 423 | 
            +
                                tags=["Streamlit Chat"],
         | 
| 424 | 
            +
                                verbose=True,
         | 
| 425 | 
            +
                                return_intermediate_steps=True,
         | 
| 426 | 
            +
                            )
         | 
| 427 | 
            +
                            if st.session_state.provider == "Anthropic":
         | 
| 428 | 
            +
                                config["max_concurrency"] = 5
         | 
| 429 | 
            +
                            return config
         | 
| 430 |  | 
| 431 | 
             
                        use_document_chat = all(
         | 
| 432 | 
             
                            [
         | 
|  | |
| 436 | 
             
                        )
         | 
| 437 |  | 
| 438 | 
             
                        full_response: Union[str, None] = None
         | 
|  | |
| 439 | 
             
                        # stream_handler = StreamHandler(message_placeholder)
         | 
| 440 | 
             
                        # callbacks.append(stream_handler)
         | 
| 441 | 
            +
                        message_placeholder = st.empty()
         | 
| 442 |  | 
| 443 | 
            +
                        if st.session_state.provider in ("Azure OpenAI", "OpenAI"):
         | 
| 444 | 
            +
                            st_callback = StreamlitCallbackHandler(st.container())
         | 
| 445 | 
            +
                            callbacks.append(st_callback)
         | 
| 446 | 
            +
                            research_assistant_tool = Tool.from_function(
         | 
| 447 | 
            +
                                func=lambda s: research_assistant_chain.invoke(
         | 
| 448 | 
            +
                                    {"question": s},
         | 
| 449 | 
            +
                                    config=get_config(callbacks),
         | 
| 450 | 
            +
                                ),
         | 
| 451 | 
            +
                                name="web-research-assistant",
         | 
| 452 | 
            +
                                description="this assistant returns a report based on web research",
         | 
| 453 | 
            +
                            )
         | 
| 454 |  | 
| 455 | 
            +
                            TOOLS = [research_assistant_tool]
         | 
| 456 | 
            +
                            if use_document_chat:
         | 
| 457 | 
            +
                                st.session_state.doc_chain = get_runnable(
         | 
| 458 | 
            +
                                    use_document_chat,
         | 
| 459 | 
            +
                                    document_chat_chain_type,
         | 
| 460 | 
            +
                                    st.session_state.llm,
         | 
| 461 | 
            +
                                    st.session_state.retriever,
         | 
| 462 | 
            +
                                    MEMORY,
         | 
| 463 | 
            +
                                    chat_prompt,
         | 
| 464 | 
            +
                                    prompt,
         | 
| 465 | 
            +
                                )
         | 
| 466 | 
            +
                                doc_chain_tool = Tool.from_function(
         | 
| 467 | 
            +
                                    func=lambda s: st.session_state.doc_chain.invoke(
         | 
| 468 | 
            +
                                        s,
         | 
| 469 | 
            +
                                        config=get_config(callbacks),
         | 
| 470 | 
            +
                                    ),
         | 
| 471 | 
            +
                                    name="user-document-chat",
         | 
| 472 | 
            +
                                    description="this assistant returns a response based on the user's custom context. if the user's meaning is unclear, perhaps the answer is here. generally speaking, try this tool before conducting web research.",
         | 
| 473 | 
            +
                                )
         | 
| 474 | 
            +
                                TOOLS = [doc_chain_tool, research_assistant_tool]
         | 
| 475 | 
            +
             | 
| 476 | 
            +
                            st.session_state.chain = get_agent(
         | 
| 477 | 
            +
                                TOOLS,
         | 
| 478 | 
            +
                                STMEMORY,
         | 
| 479 | 
            +
                                st.session_state.llm,
         | 
| 480 | 
            +
                                callbacks,
         | 
| 481 | 
            +
                            )
         | 
| 482 | 
            +
                        else:
         | 
| 483 | 
            +
                            st.session_state.chain = get_runnable(
         | 
| 484 | 
            +
                                use_document_chat,
         | 
| 485 | 
            +
                                document_chat_chain_type,
         | 
| 486 | 
            +
                                st.session_state.llm,
         | 
| 487 | 
            +
                                st.session_state.retriever,
         | 
| 488 | 
            +
                                MEMORY,
         | 
| 489 | 
            +
                                chat_prompt,
         | 
| 490 | 
            +
                                prompt,
         | 
| 491 | 
            +
                            )
         | 
| 492 |  | 
| 493 | 
             
                        # --- LLM call ---
         | 
| 494 | 
             
                        try:
         | 
| 495 | 
            +
                            full_response = st.session_state.chain.invoke(
         | 
| 496 | 
            +
                                prompt,
         | 
| 497 | 
            +
                                config=get_config(callbacks),
         | 
| 498 | 
            +
                            )
         | 
| 499 |  | 
| 500 | 
             
                        except (openai.AuthenticationError, anthropic.AuthenticationError):
         | 
| 501 | 
             
                            st.error(
         | 
    	
        langchain-streamlit-demo/llm_resources.py
    CHANGED
    
    | @@ -3,13 +3,13 @@ from tempfile import NamedTemporaryFile | |
| 3 | 
             
            from typing import Tuple, List, Optional, Dict
         | 
| 4 |  | 
| 5 | 
             
            from langchain.agents import AgentExecutor
         | 
| 6 | 
            -
            from langchain.agents.agent_toolkits import create_retriever_tool
         | 
| 7 | 
             
            from langchain.agents.openai_functions_agent.agent_token_buffer_memory import (
         | 
| 8 | 
             
                AgentTokenBufferMemory,
         | 
| 9 | 
             
            )
         | 
| 10 | 
             
            from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent
         | 
| 11 | 
             
            from langchain.callbacks.base import BaseCallbackHandler
         | 
| 12 | 
             
            from langchain.chains import LLMChain
         | 
|  | |
| 13 | 
             
            from langchain.chat_models import (
         | 
| 14 | 
             
                AzureChatOpenAI,
         | 
| 15 | 
             
                ChatOpenAI,
         | 
| @@ -18,29 +18,30 @@ from langchain.chat_models import ( | |
| 18 | 
             
            )
         | 
| 19 | 
             
            from langchain.document_loaders import PyPDFLoader
         | 
| 20 | 
             
            from langchain.embeddings import AzureOpenAIEmbeddings, OpenAIEmbeddings
         | 
|  | |
| 21 | 
             
            from langchain.prompts import MessagesPlaceholder
         | 
| 22 | 
             
            from langchain.retrievers import EnsembleRetriever
         | 
| 23 | 
             
            from langchain.retrievers.multi_query import MultiQueryRetriever
         | 
| 24 | 
             
            from langchain.retrievers.multi_vector import MultiVectorRetriever
         | 
| 25 | 
             
            from langchain.schema import Document, BaseRetriever
         | 
|  | |
| 26 | 
             
            from langchain.schema.runnable import RunnablePassthrough
         | 
| 27 | 
             
            from langchain.storage import InMemoryStore
         | 
| 28 | 
             
            from langchain.text_splitter import RecursiveCharacterTextSplitter
         | 
|  | |
| 29 | 
             
            from langchain.vectorstores import FAISS
         | 
| 30 | 
             
            from langchain_core.messages import SystemMessage
         | 
| 31 |  | 
| 32 | 
             
            from defaults import DEFAULT_CHUNK_SIZE, DEFAULT_CHUNK_OVERLAP, DEFAULT_RETRIEVER_K
         | 
| 33 | 
             
            from qagen import get_rag_qa_gen_chain
         | 
| 34 | 
             
            from summarize import get_rag_summarization_chain
         | 
| 35 | 
            -
            from langchain.tools.base import BaseTool
         | 
| 36 | 
            -
            from langchain.schema.chat_history import BaseChatMessageHistory
         | 
| 37 | 
            -
            from langchain.llms.base import BaseLLM
         | 
| 38 |  | 
| 39 |  | 
| 40 | 
             
            def get_agent(
         | 
| 41 | 
             
                tools: list[BaseTool],
         | 
| 42 | 
             
                chat_history: BaseChatMessageHistory,
         | 
| 43 | 
             
                llm: BaseLLM,
         | 
|  | |
| 44 | 
             
            ):
         | 
| 45 | 
             
                memory_key = "agent_history"
         | 
| 46 | 
             
                system_message = SystemMessage(
         | 
| @@ -68,6 +69,7 @@ def get_agent( | |
| 68 | 
             
                    memory=agent_memory,
         | 
| 69 | 
             
                    verbose=True,
         | 
| 70 | 
             
                    return_intermediate_steps=True,
         | 
|  | |
| 71 | 
             
                )
         | 
| 72 | 
             
                return (
         | 
| 73 | 
             
                    {"input": RunnablePassthrough()}
         | 
| @@ -84,7 +86,6 @@ def get_runnable( | |
| 84 | 
             
                memory,
         | 
| 85 | 
             
                chat_prompt,
         | 
| 86 | 
             
                summarization_prompt,
         | 
| 87 | 
            -
                chat_history,
         | 
| 88 | 
             
            ):
         | 
| 89 | 
             
                if not use_document_chat:
         | 
| 90 | 
             
                    return LLMChain(
         | 
| @@ -105,14 +106,12 @@ def get_runnable( | |
| 105 | 
             
                        llm,
         | 
| 106 | 
             
                    )
         | 
| 107 | 
             
                else:
         | 
| 108 | 
            -
                     | 
| 109 | 
            -
                         | 
| 110 | 
            -
                         | 
| 111 | 
            -
                         | 
| 112 | 
            -
             | 
| 113 | 
            -
                     | 
| 114 | 
            -
             | 
| 115 | 
            -
                    return get_agent(tools, chat_history, llm)
         | 
| 116 |  | 
| 117 |  | 
| 118 | 
             
            def get_llm(
         | 
|  | |
| 3 | 
             
            from typing import Tuple, List, Optional, Dict
         | 
| 4 |  | 
| 5 | 
             
            from langchain.agents import AgentExecutor
         | 
|  | |
| 6 | 
             
            from langchain.agents.openai_functions_agent.agent_token_buffer_memory import (
         | 
| 7 | 
             
                AgentTokenBufferMemory,
         | 
| 8 | 
             
            )
         | 
| 9 | 
             
            from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent
         | 
| 10 | 
             
            from langchain.callbacks.base import BaseCallbackHandler
         | 
| 11 | 
             
            from langchain.chains import LLMChain
         | 
| 12 | 
            +
            from langchain.chains import RetrievalQA
         | 
| 13 | 
             
            from langchain.chat_models import (
         | 
| 14 | 
             
                AzureChatOpenAI,
         | 
| 15 | 
             
                ChatOpenAI,
         | 
|  | |
| 18 | 
             
            )
         | 
| 19 | 
             
            from langchain.document_loaders import PyPDFLoader
         | 
| 20 | 
             
            from langchain.embeddings import AzureOpenAIEmbeddings, OpenAIEmbeddings
         | 
| 21 | 
            +
            from langchain.llms.base import BaseLLM
         | 
| 22 | 
             
            from langchain.prompts import MessagesPlaceholder
         | 
| 23 | 
             
            from langchain.retrievers import EnsembleRetriever
         | 
| 24 | 
             
            from langchain.retrievers.multi_query import MultiQueryRetriever
         | 
| 25 | 
             
            from langchain.retrievers.multi_vector import MultiVectorRetriever
         | 
| 26 | 
             
            from langchain.schema import Document, BaseRetriever
         | 
| 27 | 
            +
            from langchain.schema.chat_history import BaseChatMessageHistory
         | 
| 28 | 
             
            from langchain.schema.runnable import RunnablePassthrough
         | 
| 29 | 
             
            from langchain.storage import InMemoryStore
         | 
| 30 | 
             
            from langchain.text_splitter import RecursiveCharacterTextSplitter
         | 
| 31 | 
            +
            from langchain.tools.base import BaseTool
         | 
| 32 | 
             
            from langchain.vectorstores import FAISS
         | 
| 33 | 
             
            from langchain_core.messages import SystemMessage
         | 
| 34 |  | 
| 35 | 
             
            from defaults import DEFAULT_CHUNK_SIZE, DEFAULT_CHUNK_OVERLAP, DEFAULT_RETRIEVER_K
         | 
| 36 | 
             
            from qagen import get_rag_qa_gen_chain
         | 
| 37 | 
             
            from summarize import get_rag_summarization_chain
         | 
|  | |
|  | |
|  | |
| 38 |  | 
| 39 |  | 
| 40 | 
             
            def get_agent(
         | 
| 41 | 
             
                tools: list[BaseTool],
         | 
| 42 | 
             
                chat_history: BaseChatMessageHistory,
         | 
| 43 | 
             
                llm: BaseLLM,
         | 
| 44 | 
            +
                callbacks,
         | 
| 45 | 
             
            ):
         | 
| 46 | 
             
                memory_key = "agent_history"
         | 
| 47 | 
             
                system_message = SystemMessage(
         | 
|  | |
| 69 | 
             
                    memory=agent_memory,
         | 
| 70 | 
             
                    verbose=True,
         | 
| 71 | 
             
                    return_intermediate_steps=True,
         | 
| 72 | 
            +
                    callbacks=callbacks,
         | 
| 73 | 
             
                )
         | 
| 74 | 
             
                return (
         | 
| 75 | 
             
                    {"input": RunnablePassthrough()}
         | 
|  | |
| 86 | 
             
                memory,
         | 
| 87 | 
             
                chat_prompt,
         | 
| 88 | 
             
                summarization_prompt,
         | 
|  | |
| 89 | 
             
            ):
         | 
| 90 | 
             
                if not use_document_chat:
         | 
| 91 | 
             
                    return LLMChain(
         | 
|  | |
| 106 | 
             
                        llm,
         | 
| 107 | 
             
                    )
         | 
| 108 | 
             
                else:
         | 
| 109 | 
            +
                    return RetrievalQA.from_chain_type(
         | 
| 110 | 
            +
                        llm=llm,
         | 
| 111 | 
            +
                        chain_type=document_chat_chain_type,
         | 
| 112 | 
            +
                        retriever=retriever,
         | 
| 113 | 
            +
                        output_key="output_text",
         | 
| 114 | 
            +
                    ) | (lambda output: output["output_text"])
         | 
|  | |
|  | |
| 115 |  | 
| 116 |  | 
| 117 | 
             
            def get_llm(
         | 
