# legal-ai-actions / pages /6_🔎_Find_Demo.py
# jmuscatello's picture
# Add retrievers
# 1083f7f
# raw
# history blame
# 2.6 kB
import os

import pandas as pd
import streamlit as st
import streamlit_analytics
from haystack.document_stores import InMemoryDocumentStore
from haystack.nodes import BM25Retriever, EmbeddingRetriever
from huggingface_hub import snapshot_download

from utils import add_email_signup_form, add_footer, add_logo_to_sidebar
# Hugging Face access token; None when the HF_TOKEN environment variable is unset.
HF_TOKEN = os.environ.get("HF_TOKEN")

# Dataset repo and the JSON file inside it that holds the CUAD answers.
DATA_REPO_ID = "simplexico/cuad-qa-answers"
DATA_FILENAME = "cuad_question_answers.json"

# Model used by the embedding retriever, and the vector width the document
# store is configured for: 128 for bert-tiny, otherwise 768 (presumably a
# BERT-base-sized model — confirm when switching models).
EMBEDDING_MODEL = "prajjwal1/bert-tiny"
EMBEDDING_DIM = 128 if EMBEDDING_MODEL == "prajjwal1/bert-tiny" else 768
# Begin recording widget interactions with streamlit_analytics; the matching
# stop_tracking() call is at the bottom of this script.
streamlit_analytics.start_tracking()

# Page metadata. The emoji below were previously stored as mis-decoded UTF-8
# ("πŸ”Ž" etc.) and rendered as garbage; they are the real characters now.
st.set_page_config(
    page_title="Find Demo",
    page_icon="🔎",
    layout="wide",
    initial_sidebar_state="expanded",
    menu_items={
        # NOTE(review): this address looks like an email-obfuscation
        # placeholder rather than a real mailbox — restore the real contact.
        'Get Help': 'mailto:[email protected]',
        'Report a bug': None,
        'About': "## This is a demo showcasing different Legal AI Actions",
    },
)

# Sidebar branding and navigation hint, then the page header.
add_logo_to_sidebar()
st.sidebar.success("👆 Select a demo above.")
st.title('🔎 Find Demo')
st.markdown("🏗 This demo is currently under construction. Please visit back soon.")
@st.cache(allow_output_mutation=True)
def load_dataset():
    """Download the CUAD question/answer file and load it as a DataFrame.

    The file ``DATA_FILENAME`` is fetched from the Hugging Face dataset repo
    ``DATA_REPO_ID`` (authenticated with ``HF_TOKEN``) into the current
    working directory, then parsed with ``pd.read_json``.

    Returns:
        pandas.DataFrame: the parsed contents of ``DATA_FILENAME``.
    """
    local_dir = "./"
    # Only download the one file we read. Mirroring the whole dataset repo
    # into the working directory would also write any other files it
    # contains (e.g. a README.md, if present) on top of this app's own files.
    snapshot_download(
        repo_id=DATA_REPO_ID,
        token=HF_TOKEN,
        local_dir=local_dir,
        repo_type='dataset',
        allow_patterns=DATA_FILENAME,
    )
    return pd.read_json(os.path.join(local_dir, DATA_FILENAME))
@st.cache(allow_output_mutation=True)
def generate_document_store(df):
    """Create a Haystack in-memory document store from the contract clause data.

    Each row of ``df`` becomes one document. The document's content is the
    row's ``answer_text``, and its metadata records ``contract_title`` and
    ``question_id``.

    Args:
        df: DataFrame with at least ``answer_text``, ``contract_title`` and
            ``question_id`` columns (this script passes the result of
            ``load_dataset()``).

    Returns:
        InMemoryDocumentStore: a store created with ``use_bm25=True`` and
        ``embedding_dim=EMBEDDING_DIM``, populated with one document per row.
    """
    document_dicts = [
        {
            'content': row['answer_text'],
            'meta': {
                'contract_title': row['contract_title'],
                'question_id': row['question_id'],
            },
        }
        for _, row in df.iterrows()
    ]
    # use_bm25=True so the BM25 retriever can query this store; embedding_dim
    # is chosen from EMBEDDING_MODEL so dense embeddings fit the store.
    document_store = InMemoryDocumentStore(use_bm25=True, embedding_dim=EMBEDDING_DIM)
    document_store.write_documents(document_dicts)
    return document_store
@st.cache(allow_output_mutation=True)
def generate_bm25_retriever(document_store):
    """Return a BM25 (keyword-ranking) retriever bound to ``document_store``.

    The result is cached by Streamlit, so reruns reuse the same retriever.
    """
    retriever = BM25Retriever(document_store=document_store)
    return retriever
@st.cache(allow_output_mutation=True)
def generate_embeddings(embedding_model, document_store):
    """Embed every document in ``document_store`` and return the dense retriever.

    Args:
        embedding_model: model identifier passed to ``EmbeddingRetriever``
            (this script passes ``EMBEDDING_MODEL``).
        document_store: the store to embed; it is updated in place via
            ``update_embeddings``.

    Returns:
        EmbeddingRetriever: the retriever bound to ``document_store``.
    """
    # The haystack class is ``EmbeddingRetriever``; the previously used
    # spelling ``EmbeddingRetreiver`` does not exist and failed to import.
    embedding_retriever = EmbeddingRetriever(
        embedding_model=embedding_model,
        document_store=document_store,
    )
    # Compute an embedding for each stored document using this retriever.
    document_store.update_embeddings(embedding_retriever)
    return embedding_retriever
# Build the search backend. The helper functions are wrapped in st.cache, so
# these calls are presumably served from cache on Streamlit reruns — confirm.
df = load_dataset()
document_store = generate_document_store(df)
# Keyword (BM25) retriever over the store.
bm25_retriever = generate_bm25_retriever(document_store)
# Dense retriever; this call also writes embeddings into document_store.
embedding_retriever = generate_embeddings(EMBEDDING_MODEL, document_store)
# NOTE(review): neither retriever is referenced elsewhere in the visible code
# yet; the page is still marked as under construction.
add_email_signup_form()
add_footer()
# Finish analytics tracking. Raises KeyError if ANALYTICS_PASSWORD is not set.
streamlit_analytics.stop_tracking(unsafe_password=os.environ["ANALYTICS_PASSWORD"])