""" """ from collections import defaultdict import json import os import re from langchain_core.documents import Document from langchain_core.prompts import ChatPromptTemplate from langchain_core.runnables import RunnableParallel from langchain_core.runnables import RunnablePassthrough from langchain_community.embeddings import HuggingFaceBgeEmbeddings from langchain_community.vectorstores.utils import DistanceStrategy from langchain_openai import ChatOpenAI from langchain_anthropic import ChatAnthropic from langchain_together import ChatTogether from langchain_pinecone import PineconeVectorStore import streamlit as st import usage st.set_page_config(layout="wide", page_title="LegisQA") os.environ["LANGCHAIN_API_KEY"] = st.secrets["langchain_api_key"] os.environ["LANGCHAIN_TRACING_V2"] = "true" os.environ["LANGCHAIN_PROJECT"] = st.secrets["langchain_project"] os.environ["TOKENIZERS_PARALLELISM"] = "false" SS = st.session_state SEED = 292764 CONGRESS_NUMBERS = [113, 114, 115, 116, 117, 118] SPONSOR_PARTIES = ["D", "R", "L", "I"] CONGRESS_GOV_TYPE_MAP = { "hconres": "house-concurrent-resolution", "hjres": "house-joint-resolution", "hr": "house-bill", "hres": "house-resolution", "s": "senate-bill", "sconres": "senate-concurrent-resolution", "sjres": "senate-joint-resolution", "sres": "senate-resolution", } OPENAI_CHAT_MODELS = { "gpt-4o-mini": {"cost": {"pmi": 0.15, "pmo": 0.60}}, "gpt-4o": {"cost": {"pmi": 5.00, "pmo": 15.0}}, } ANTHROPIC_CHAT_MODELS = { "claude-3-haiku-20240307": {"cost": {"pmi": 0.25, "pmo": 1.25}}, "claude-3-5-sonnet-20240620": {"cost": {"pmi": 3.00, "pmo": 15.0}}, "claude-3-opus-20240229": {"cost": {"pmi": 15.0, "pmo": 75.0}}, } TOGETHER_CHAT_MODELS = { "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo": {"cost": {"pmi": 0.18, "pmo": 0.18}}, "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo": { "cost": {"pmi": 0.88, "pmo": 0.88} }, "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo": { "cost": {"pmi": 5.00, "pmo": 5.00} }, } PROVIDER_MODELS = { "OpenAI": OPENAI_CHAT_MODELS, "Anthropic": ANTHROPIC_CHAT_MODELS, "Together": TOGETHER_CHAT_MODELS, } def get_sponsor_url(bioguide_id: str) -> str: return f"https://bioguide.congress.gov/search/bio/{bioguide_id}" def get_congress_gov_url(congress_num: int, legis_type: str, legis_num: int) -> str: lt = CONGRESS_GOV_TYPE_MAP[legis_type] return f"https://www.congress.gov/bill/{int(congress_num)}th-congress/{lt}/{int(legis_num)}" def load_bge_embeddings(): model_name = "BAAI/bge-small-en-v1.5" model_kwargs = {"device": "cpu"} encode_kwargs = {"normalize_embeddings": True} emb_fn = HuggingFaceBgeEmbeddings( model_name=model_name, model_kwargs=model_kwargs, encode_kwargs=encode_kwargs, query_instruction="Represent this question for searching relevant passages: ", ) return emb_fn def load_pinecone_vectorstore(): emb_fn = load_bge_embeddings() vectorstore = PineconeVectorStore( embedding=emb_fn, text_key="text", distance_strategy=DistanceStrategy.COSINE, pinecone_api_key=st.secrets["pinecone_api_key"], index_name=st.secrets["pinecone_index_name"], ) return vectorstore def render_outreach_links(): nomic_base_url = "https://atlas.nomic.ai/data/gabrielhyperdemocracy" nomic_map_name = "us-congressional-legislation-s1024o256nomic-1" nomic_url = f"{nomic_base_url}/{nomic_map_name}/map" hf_url = "https://huggingface.co/hyperdemocracy" pc_url = "https://www.pinecone.io/blog/serverless" together_url = "https://www.together.ai/" st.subheader(":brain: About [hyperdemocracy](https://hyperdemocracy.us)") st.subheader(f":world_map: Visualize [nomic 
def render_outreach_links():
    nomic_base_url = "https://atlas.nomic.ai/data/gabrielhyperdemocracy"
    nomic_map_name = "us-congressional-legislation-s1024o256nomic-1"
    nomic_url = f"{nomic_base_url}/{nomic_map_name}/map"
    hf_url = "https://huggingface.co/hyperdemocracy"
    pc_url = "https://www.pinecone.io/blog/serverless"
    together_url = "https://www.together.ai/"
    st.subheader(":brain: About [hyperdemocracy](https://hyperdemocracy.us)")
    st.subheader(f":world_map: Visualize [nomic atlas]({nomic_url})")
    st.subheader(f":hugging_face: Raw [huggingface datasets]({hf_url})")
    st.subheader(f":evergreen_tree: Index [pinecone serverless]({pc_url})")
    st.subheader(f":pancakes: Inference [together.ai]({together_url})")


def render_sidebar():
    with st.container(border=True):
        render_outreach_links()


def group_docs(docs) -> list[tuple[str, list[Document]]]:
    doc_grps = defaultdict(list)
    # create legis_id groups
    for doc in docs:
        doc_grps[doc.metadata["legis_id"]].append(doc)
    # sort docs in each group by start index
    for legis_id in doc_grps.keys():
        doc_grps[legis_id] = sorted(
            doc_grps[legis_id],
            key=lambda x: x.metadata["start_index"],
        )
    # sort groups by number of docs
    doc_grps = sorted(
        tuple(doc_grps.items()),
        key=lambda x: -len(x[1]),
    )
    return doc_grps


def format_docs(docs: list[Document]) -> str:
    """JSON grouped"""
    doc_grps = group_docs(docs)
    out = []
    for legis_id, doc_grp in doc_grps:
        dd = {
            "legis_id": doc_grp[0].metadata["legis_id"],
            "title": doc_grp[0].metadata["title"],
            "introduced_date": doc_grp[0].metadata["introduced_date"],
            "sponsor": doc_grp[0].metadata["sponsor_full_name"],
            "snippets": [doc.page_content for doc in doc_grp],
        }
        out.append(dd)
    return json.dumps(out, indent=4)


def escape_markdown(text: str) -> str:
    MD_SPECIAL_CHARS = r"\`*_{}[]()#+-.!$"
    for char in MD_SPECIAL_CHARS:
        text = text.replace(char, "\\" + char)
    return text


def get_vectorstore_filter(ret_config: dict) -> dict:
    vs_filter = {}
    if ret_config["filter_legis_id"] != "":
        vs_filter["legis_id"] = ret_config["filter_legis_id"]
    if ret_config["filter_bioguide_id"] != "":
        vs_filter["sponsor_bioguide_id"] = ret_config["filter_bioguide_id"]
    vs_filter = {
        **vs_filter,
        "congress_num": {"$in": ret_config["filter_congress_nums"]},
    }
    vs_filter = {
        **vs_filter,
        "sponsor_party": {"$in": ret_config["filter_sponsor_parties"]},
    }
    return vs_filter


def render_doc_grp(legis_id: str, doc_grp: list[Document]):
    first_doc = doc_grp[0]
    congress_gov_url = get_congress_gov_url(
        first_doc.metadata["congress_num"],
        first_doc.metadata["legis_type"],
        first_doc.metadata["legis_num"],
    )
    congress_gov_link = f"[congress.gov]({congress_gov_url})"
    ref = "{} chunks from {}\n\n{}\n\n{}\n\n[{} ({}) ]({})".format(
        len(doc_grp),
        first_doc.metadata["legis_id"],
        first_doc.metadata["title"],
        congress_gov_link,
        first_doc.metadata["sponsor_full_name"],
        first_doc.metadata["sponsor_bioguide_id"],
        get_sponsor_url(first_doc.metadata["sponsor_bioguide_id"]),
    )
    doc_contents = [
        "[start_index={}] ".format(int(doc.metadata["start_index"])) + doc.page_content
        for doc in doc_grp
    ]
    with st.expander(ref):
        st.write(escape_markdown("\n\n...\n\n".join(doc_contents)))


def legis_id_to_link(legis_id: str) -> str:
    congress_num, legis_type, legis_num = legis_id.split("-")
    return get_congress_gov_url(congress_num, legis_type, legis_num)


def legis_id_match_to_link(matchobj):
    mstring = matchobj.group(0)
    url = legis_id_to_link(mstring)
    link = f"[{mstring}]({url})"
    return link


def replace_legis_ids_with_urls(text):
    pattern = r"11[345678]-[a-z]+-\d{1,5}"
    rtext = re.sub(pattern, legis_id_match_to_link, text)
    return rtext
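
# Illustrative example (hypothetical legis_id) of the link substitution above:
#
#   >>> replace_legis_ids_with_urls("See 118-hr-1234 for details.")
#   'See [118-hr-1234](https://www.congress.gov/bill/118th-congress/house-bill/1234) for details.'
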
def render_guide():
    st.write(
        """
When you send a query to LegisQA, it will attempt to retrieve relevant content from the past six congresses
([113th-118th](https://en.wikipedia.org/wiki/List_of_United_States_Congresses)), covering 2013 to the present,
pass it to a [large language model (LLM)](https://en.wikipedia.org/wiki/Large_language_model), and generate a
response. This technique is known as Retrieval-Augmented Generation (RAG). You can read
[an academic paper](https://proceedings.neurips.cc/paper/2020/hash/6b493230205f780e1bc26945df7481e5-Abstract.html)
or [a high-level summary](https://research.ibm.com/blog/retrieval-augmented-generation-RAG) for more details.
Once the response is generated, the retrieved content is made available for inspection, with links to the bills
and sponsors.

## Disclaimer

This is a research project. The RAG technique helps to ground the LLM response by providing context from a
trusted source, but it does not guarantee a high-quality response. We encourage you to experiment: find
questions that work and questions that fail. A small monthly budget is dedicated to the OpenAI endpoints;
once it is used up each month, queries will no longer work.

## Config

Use the `Generative Config` to change LLM parameters. Use the `Retrieval Config` to change the number of
chunks retrieved from our congress corpus and to apply filters to the content before it is retrieved
(e.g. filter to a specific set of congresses).
"""
    )


def render_example_queries():
    with st.expander("Example Queries"):
        st.write(
            """
```
What are the themes around artificial intelligence?
```

```
Write a well-cited 3-paragraph essay on food insecurity.
```

```
Create a table summarizing major climate change ideas with columns legis_id, title, idea.
```

```
Write an action plan to keep social security solvent.
```

```
Suggest reforms that would benefit the Medicaid program.
```
"""
        )


def get_generative_config(key_prefix: str) -> dict:
    output = {}

    key = "provider"
    output[key] = st.selectbox(
        label=key, options=PROVIDER_MODELS.keys(), key=f"{key_prefix}|{key}"
    )

    key = "model_name"
    output[key] = st.selectbox(
        label=key,
        options=PROVIDER_MODELS[output["provider"]],
        key=f"{key_prefix}|{key}",
    )

    key = "temperature"
    output[key] = st.slider(
        key,
        min_value=0.0,
        max_value=2.0,
        value=0.0,
        key=f"{key_prefix}|{key}",
    )

    key = "max_output_tokens"
    output[key] = st.slider(
        key,
        min_value=1024,
        max_value=2048,
        key=f"{key_prefix}|{key}",
    )

    key = "top_p"
    output[key] = st.slider(
        key, min_value=0.0, max_value=1.0, value=0.9, key=f"{key_prefix}|{key}"
    )

    key = "should_escape_markdown"
    output[key] = st.checkbox(
        key,
        value=False,
        key=f"{key_prefix}|{key}",
    )

    key = "should_add_legis_urls"
    output[key] = st.checkbox(
        key,
        value=True,
        key=f"{key_prefix}|{key}",
    )

    return output
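
# With the default widget values, the dict returned above would look like
# (illustrative; model fields depend on the user's selection):
#
#   {"provider": "OpenAI", "model_name": "gpt-4o-mini", "temperature": 0.0,
#    "max_output_tokens": 1024, "top_p": 0.9,
#    "should_escape_markdown": False, "should_add_legis_urls": True}
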
def get_retrieval_config(key_prefix: str) -> dict:
    output = {}

    key = "n_ret_docs"
    output[key] = st.slider(
        "Number of chunks to retrieve",
        min_value=1,
        max_value=32,
        value=8,
        key=f"{key_prefix}|{key}",
    )

    key = "filter_legis_id"
    output[key] = st.text_input("Bill ID (e.g. 118-s-2293)", key=f"{key_prefix}|{key}")

    key = "filter_bioguide_id"
    output[key] = st.text_input("Bioguide ID (e.g. R000595)", key=f"{key_prefix}|{key}")

    key = "filter_congress_nums"
    output[key] = st.multiselect(
        "Congress Numbers",
        CONGRESS_NUMBERS,
        default=CONGRESS_NUMBERS,
        key=f"{key_prefix}|{key}",
    )

    key = "filter_sponsor_parties"
    output[key] = st.multiselect(
        "Sponsor Party",
        SPONSOR_PARTIES,
        default=SPONSOR_PARTIES,
        key=f"{key_prefix}|{key}",
    )

    return output


def get_llm(gen_config: dict):
    match gen_config["provider"]:
        case "OpenAI":
            llm = ChatOpenAI(
                model=gen_config["model_name"],
                temperature=gen_config["temperature"],
                api_key=st.secrets["openai_api_key"],
                top_p=gen_config["top_p"],
                seed=SEED,
                max_tokens=gen_config["max_output_tokens"],
            )
        case "Anthropic":
            llm = ChatAnthropic(
                model_name=gen_config["model_name"],
                temperature=gen_config["temperature"],
                api_key=st.secrets["anthropic_api_key"],
                top_p=gen_config["top_p"],
                max_tokens_to_sample=gen_config["max_output_tokens"],
            )
        case "Together":
            llm = ChatTogether(
                model=gen_config["model_name"],
                temperature=gen_config["temperature"],
                max_tokens=gen_config["max_output_tokens"],
                top_p=gen_config["top_p"],
                seed=SEED,
                api_key=st.secrets["together_api_key"],
            )
        case _:
            raise ValueError(f"Unknown provider: {gen_config['provider']}")
    return llm


def create_rag_chain(llm, retriever):
    QUERY_RAG_TEMPLATE = """You are an expert legislative analyst. Use the following excerpts from US congressional legislation to respond to the user's query. The excerpts are formatted as a JSON list. Each JSON object has "legis_id", "title", "introduced_date", "sponsor", and "snippets" keys. If a snippet is useful in writing part of your response, then cite the "legis_id", "title", "introduced_date", and "sponsor" in the response. If you don't know how to respond, just tell the user.

---

Congressional Legislation Excerpts:

{context}

---

Query: {query}"""

    prompt = ChatPromptTemplate.from_messages(
        [
            ("human", QUERY_RAG_TEMPLATE),
        ]
    )
    rag_chain = (
        RunnableParallel(
            {
                "docs": retriever,
                "query": RunnablePassthrough(),
            }
        )
        .assign(context=lambda x: format_docs(x["docs"]))
        .assign(aimessage=prompt | llm)
    )
    return rag_chain


def process_query(gen_config: dict, ret_config: dict, query: str):
    vectorstore = load_pinecone_vectorstore()
    llm = get_llm(gen_config)
    vs_filter = get_vectorstore_filter(ret_config)
    retriever = vectorstore.as_retriever(
        search_kwargs={"k": ret_config["n_ret_docs"], "filter": vs_filter},
    )
    rag_chain = create_rag_chain(llm, retriever)
    response = rag_chain.invoke(query)
    return response


def display_retrieved_chunks(response):
    with st.container(border=True):
        doc_grps = group_docs(response["docs"])
        st.write(
            "Retrieved Chunks (note that you may need to 'right click' on links in the expanders to follow them)"
        )
        for legis_id, doc_grp in doc_grps:
            render_doc_grp(legis_id, doc_grp)


def display_response(
    response, model_info, provider, should_escape_markdown, should_add_legis_urls
):
    out_display = response["aimessage"].content
    if should_escape_markdown:
        out_display = escape_markdown(out_display)
    if should_add_legis_urls:
        out_display = replace_legis_ids_with_urls(out_display)

    with st.container(border=True):
        st.write("Response")
        st.info(out_display)
        usage.display_api_usage(response, model_info, provider)
        display_retrieved_chunks(response)
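
# Shape of the object returned by process_query (follows from the
# RunnableParallel / .assign pipeline in create_rag_chain; values illustrative):
#
#   {
#       "docs": [Document(...), ...],           # chunks from the retriever
#       "query": "...",                         # the original user query
#       "context": "[ ... ]",                   # JSON string from format_docs
#       "aimessage": AIMessage(content="..."),  # the LLM's response message
#   }
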
def render_query_rag_tab():
    key_prefix = "query_rag"
    render_example_queries()

    with st.form(f"{key_prefix}|query_form"):
        query = st.text_area(
            "Enter a query that can be answered with congressional legislation:"
        )
        cols = st.columns(2)
        with cols[0]:
            query_submitted = st.form_submit_button("Submit")
        with cols[1]:
            status_placeholder = st.empty()

    col1, col2 = st.columns(2)
    with col1:
        with st.expander("Generative Config"):
            gen_config = get_generative_config(key_prefix)
    with col2:
        with st.expander("Retrieval Config"):
            ret_config = get_retrieval_config(key_prefix)

    rkey = f"{key_prefix}|response"
    if query_submitted:
        with status_placeholder:
            with st.spinner("generating response"):
                SS[rkey] = process_query(gen_config, ret_config, query)

    if response := SS.get(rkey):
        model_info = PROVIDER_MODELS[gen_config["provider"]][gen_config["model_name"]]
        display_response(
            response,
            model_info,
            gen_config["provider"],
            gen_config["should_escape_markdown"],
            gen_config["should_add_legis_urls"],
        )
        with st.expander("Debug"):
            st.write(response)


def render_query_rag_sbs_tab():
    base_key_prefix = "query_rag_sbs"

    with st.form(f"{base_key_prefix}|query_form"):
        query = st.text_area(
            "Enter a query that can be answered with congressional legislation:"
        )
        cols = st.columns(2)
        with cols[0]:
            query_submitted = st.form_submit_button("Submit")
        with cols[1]:
            status_placeholder = st.empty()

    grp1a, grp2a = st.columns(2)
    gen_configs = {}
    ret_configs = {}
    with grp1a:
        st.header("Group 1")
        key_prefix = f"{base_key_prefix}|grp1"
        with st.expander("Generative Config"):
            gen_configs["grp1"] = get_generative_config(key_prefix)
        with st.expander("Retrieval Config"):
            ret_configs["grp1"] = get_retrieval_config(key_prefix)

    with grp2a:
        st.header("Group 2")
        key_prefix = f"{base_key_prefix}|grp2"
        with st.expander("Generative Config"):
            gen_configs["grp2"] = get_generative_config(key_prefix)
        with st.expander("Retrieval Config"):
            ret_configs["grp2"] = get_retrieval_config(key_prefix)

    grp1b, grp2b = st.columns(2)
    sbs_cols = {"grp1": grp1b, "grp2": grp2b}
    grp_names = {"grp1": "Group 1", "grp2": "Group 2"}

    for post_key_prefix in ["grp1", "grp2"]:
        with sbs_cols[post_key_prefix]:
            key_prefix = f"{base_key_prefix}|{post_key_prefix}"
            rkey = f"{key_prefix}|response"
            if query_submitted:
                with status_placeholder:
                    with st.spinner(
                        "generating response for {}".format(grp_names[post_key_prefix])
                    ):
                        SS[rkey] = process_query(
                            gen_configs[post_key_prefix],
                            ret_configs[post_key_prefix],
                            query,
                        )
            if response := SS.get(rkey):
                model_info = PROVIDER_MODELS[gen_configs[post_key_prefix]["provider"]][
                    gen_configs[post_key_prefix]["model_name"]
                ]
                display_response(
                    response,
                    model_info,
                    gen_configs[post_key_prefix]["provider"],
                    gen_configs[post_key_prefix]["should_escape_markdown"],
                    gen_configs[post_key_prefix]["should_add_legis_urls"],
                )


def main():
    st.title(":classical_building: LegisQA :classical_building:")
    st.header("Query Congressional Bills")

    with st.sidebar:
        render_sidebar()

    query_rag_tab, query_rag_sbs_tab, guide_tab = st.tabs(
        [
            "RAG",
            "RAG (side-by-side)",
            "Guide",
        ]
    )

    with query_rag_tab:
        render_query_rag_tab()
    with query_rag_sbs_tab:
        render_query_rag_sbs_tab()
    with guide_tab:
        render_guide()


if __name__ == "__main__":
    main()
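
# To run the app locally (assuming this module is saved as app.py and the
# required API keys exist in .streamlit/secrets.toml):
#
#   streamlit run app.py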