jmuscatello committed on
Commit
647ce10
·
1 Parent(s): d103fa2

Add results

Browse files
Files changed (1) hide show
  1. pages/6_🔎_Find_Demo.py +77 -22
pages/6_🔎_Find_Demo.py CHANGED
@@ -1,40 +1,28 @@
1
  import os
2
 
 
 
3
  import streamlit as st
4
  import streamlit_analytics
5
  from utils import add_logo_to_sidebar, add_footer, add_email_signup_form
6
 
 
 
7
  from haystack.document_stores import InMemoryDocumentStore
8
- from haystack.nodes import BM25Retriever, EmbeddingRetreiver
9
 
10
  HF_TOKEN = os.environ.get("HF_TOKEN")
11
  DATA_REPO_ID = "simplexico/cuad-qa-answers"
12
- DATA_FILENAME = "cuad_question_answers.json"
13
  EMBEDDING_MODEL = "prajjwal1/bert-tiny"
14
  if EMBEDDING_MODEL == "prajjwal1/bert-tiny":
15
  EMBEDDING_DIM = 128
16
  else:
17
  EMBEDDING_DIM = 768
18
 
19
- streamlit_analytics.start_tracking()
20
-
21
- st.set_page_config(
22
- page_title="Find Demo",
23
- page_icon="πŸ”Ž",
24
- layout="wide",
25
- initial_sidebar_state="expanded",
26
- menu_items={
27
- 'Get Help': 'mailto:[email protected]',
28
- 'Report a bug': None,
29
- 'About': "## This a demo showcasing different Legal AI Actions"
30
- }
31
- )
32
-
33
- add_logo_to_sidebar()
34
- st.sidebar.success("πŸ‘† Select a demo above.")
35
 
36
- st.title('πŸ”Ž Find Demo')
37
- st.markdown("πŸ— This demo is currently under construction. Please visit back soon.")
38
 
39
  @st.cache(allow_output_mutation=True)
40
  def load_dataset():
@@ -43,7 +31,7 @@ def load_dataset():
43
  return df
44
 
45
  @st.cache(allow_output_mutation=True)
46
- def generate_document_store(df):
47
  """Create haystack document store using contract clause data
48
  """
49
  document_dicts = []
@@ -68,10 +56,53 @@ def generate_bm25_retriever(document_store):
68
 
69
  @st.cache(allow_output_mutation=True)
70
  def generate_embeddings(embedding_model, document_store):
71
- embedding_retriever = EmbeddingRetreiver(embedding_model=embedding_model, document_store=document_store)
72
  document_store.update_embeddings(embedding_retriever)
73
  return embedding_retriever
74
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
  df = load_dataset()
76
 
77
  document_store = generate_document_store(df)
@@ -80,8 +111,32 @@ bm25_retriever = generate_bm25_retriever(document_store)
80
 
81
  embedding_retriever = generate_embeddings(EMBEDDING_MODEL, document_store)
82
 
 
 
 
 
 
 
 
 
 
 
 
 
83
 
 
 
 
 
 
 
 
84
 
 
 
 
 
 
85
 
86
  add_email_signup_form()
87
 
 
1
  import os
2
 
3
+ import pandas as pd
4
+
5
  import streamlit as st
6
  import streamlit_analytics
7
  from utils import add_logo_to_sidebar, add_footer, add_email_signup_form
8
 
9
+ from huggingface_hub import snapshot_download
10
+
11
  from haystack.document_stores import InMemoryDocumentStore
12
+ from haystack.nodes import BM25Retriever, EmbeddingRetriever
13
 
14
  HF_TOKEN = os.environ.get("HF_TOKEN")
15
  DATA_REPO_ID = "simplexico/cuad-qa-answers"
16
+ DATA_FILENAME = "cuad_questions_answers.json"
17
  EMBEDDING_MODEL = "prajjwal1/bert-tiny"
18
  if EMBEDDING_MODEL == "prajjwal1/bert-tiny":
19
  EMBEDDING_DIM = 128
20
  else:
21
  EMBEDDING_DIM = 768
22
 
23
+ EXAMPLE_TEXT = "the governing law is the State of Texas"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
+ streamlit_analytics.start_tracking()
 
26
 
27
  @st.cache(allow_output_mutation=True)
28
  def load_dataset():
 
31
  return df
32
 
33
  @st.cache(allow_output_mutation=True)
34
+ def generate_document_store(df, dummy=None):
35
  """Create haystack document store using contract clause data
36
  """
37
  document_dicts = []
 
56
 
57
  @st.cache(allow_output_mutation=True)
58
  def generate_embeddings(embedding_model, document_store):
59
+ embedding_retriever = EmbeddingRetriever(embedding_model=embedding_model, document_store=document_store)
60
  document_store.update_embeddings(embedding_retriever)
61
  return embedding_retriever
62
 
63
+ def process_query(query, retriever):
64
+ """Generates dataframe with top ten results"""
65
+ texts = []
66
+ contract_titles = []
67
+ candidate_documents = retriever.retrieve(
68
+ query=query,
69
+ top_k=10,
70
+ )
71
+
72
+ for document in candidate_documents:
73
+ texts.append(document.content)
74
+ contract_titles.append(document.meta["contract_title"])
75
+
76
+ return pd.DataFrame({"Text": texts, "Source Contract": contract_titles})
77
+
78
+ st.set_page_config(
79
+ page_title="Find Demo",
80
+ page_icon="πŸ”Ž",
81
+ layout="wide",
82
+ initial_sidebar_state="expanded",
83
+ menu_items={
84
+ 'Get Help': 'mailto:[email protected]',
85
+ 'Report a bug': None,
86
+ 'About': "## This a demo showcasing different Legal AI Actions"
87
+ }
88
+ )
89
+
90
+ add_logo_to_sidebar()
91
+ st.sidebar.success("πŸ‘† Select a demo above.")
92
+
93
+ st.title('πŸ”Ž Find Demo')
94
+
95
+ st.write("""
96
+ This demo shows how a set of documents can be searched.
97
+ We've set up a database of clauses from a set of open source legal documents.
98
+ These clauses can be searched using **keywords** or using **semantic search**.
99
+ Semantic search leverages an AI model which matches on clauses with a similar meaning to the input text.
100
+ """)
101
+ st.write("**πŸ‘ˆ Enter search query on the left** and hit the button **Find Clauses** to see the demo in action")
102
+
103
+ query = st.sidebar.text_area(label='Enter Searcb Query', value=EXAMPLE_TEXT, height=250)
104
+ button = st.sidebar.button('**Find Clauses**', type='primary', use_container_width=True)
105
+
106
  df = load_dataset()
107
 
108
  document_store = generate_document_store(df)
 
111
 
112
  embedding_retriever = generate_embeddings(EMBEDDING_MODEL, document_store)
113
 
114
+ if button:
115
+
116
+ hide_dataframe_row_index = """
117
+ <style>
118
+ .row_heading.level0 {display:none}
119
+ .blank {display:none}
120
+ </style>
121
+ """
122
+
123
+ col1, col2 = st.columns(2)
124
+
125
+ with col1:
126
 
127
+ st.subheader('Keyword Search Results:')
128
+ # Inject CSS with Markdown
129
+ st.markdown(hide_dataframe_row_index, unsafe_allow_html=True)
130
+ df_bm25 = process_query(query, bm25_retriever)
131
+ st.table(df_bm25)
132
+
133
+ with col2:
134
 
135
+ st.subheader('Semantic Search Results:')
136
+ # Inject CSS with Markdown
137
+ st.markdown(hide_dataframe_row_index, unsafe_allow_html=True)
138
+ df_embed = process_query(query, embedding_retriever)
139
+ st.table(df_embed)
140
 
141
  add_email_signup_form()
142