jmuscatello commited on
Commit
1083f7f
Β·
1 Parent(s): 1305fc8

Add retrievers

Browse files
Files changed (1) hide show
  1. pages/6_πŸ”Ž_Find_Demo.py +59 -0
pages/6_πŸ”Ž_Find_Demo.py CHANGED
@@ -4,6 +4,18 @@ import streamlit as st
4
  import streamlit_analytics
5
  from utils import add_logo_to_sidebar, add_footer, add_email_signup_form
6
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  streamlit_analytics.start_tracking()
8
 
9
  st.set_page_config(
@@ -24,6 +36,53 @@ st.sidebar.success("πŸ‘† Select a demo above.")
24
  st.title('πŸ”Ž Find Demo')
25
  st.markdown("πŸ— This demo is currently under construction. Please visit back soon.")
26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  add_email_signup_form()
28
 
29
  add_footer()
 
4
  import streamlit_analytics
5
  from utils import add_logo_to_sidebar, add_footer, add_email_signup_form
6
 
7
+ from haystack.document_stores import InMemoryDocumentStore
8
+ from haystack.nodes import BM25Retriever, EmbeddingRetreiver
9
+
10
+ HF_TOKEN = os.environ.get("HF_TOKEN")
11
+ DATA_REPO_ID = "simplexico/cuad-qa-answers"
12
+ DATA_FILENAME = "cuad_question_answers.json"
13
+ EMBEDDING_MODEL = "prajjwal1/bert-tiny"
14
+ if EMBEDDING_MODEL == "prajjwal1/bert-tiny":
15
+ EMBEDDING_DIM = 128
16
+ else:
17
+ EMBEDDING_DIM = 768
18
+
19
  streamlit_analytics.start_tracking()
20
 
21
  st.set_page_config(
 
36
  st.title('πŸ”Ž Find Demo')
37
  st.markdown("πŸ— This demo is currently under construction. Please visit back soon.")
38
 
39
+ @st.cache(allow_output_mutation=True)
40
+ def load_dataset():
41
+ snapshot_download(repo_id=DATA_REPO_ID, token=HF_TOKEN, local_dir='./', repo_type='dataset')
42
+ df = pd.read_json(DATA_FILENAME)
43
+ return df
44
+
45
+ @st.cache(allow_output_mutation=True)
46
+ def generate_document_store(df):
47
+ """Create haystack document store using contract clause data
48
+ """
49
+ document_dicts = []
50
+
51
+ for idx, row in df.iterrows():
52
+ document_dicts.append(
53
+ {
54
+ 'content': row['answer_text'],
55
+ 'meta': {'contract_title': row['contract_title'], 'question_id': row['question_id']}
56
+ }
57
+ )
58
+
59
+ document_store = InMemoryDocumentStore(use_bm25=True, embedding_dim=EMBEDDING_DIM)
60
+
61
+ document_store.write_documents(document_dicts)
62
+
63
+ return document_store
64
+
65
+ @st.cache(allow_output_mutation=True)
66
+ def generate_bm25_retriever(document_store):
67
+ return BM25Retriever(document_store)
68
+
69
+ @st.cache(allow_output_mutation=True)
70
+ def generate_embeddings(embedding_model, document_store):
71
+ embedding_retriever = EmbeddingRetreiver(embedding_model=embedding_model, document_store=document_store)
72
+ document_store.update_embeddings(embedding_retriever)
73
+ return embedding_retriever
74
+
75
+ df = load_dataset()
76
+
77
+ document_store = generate_document_store(df)
78
+
79
+ bm25_retriever = generate_bm25_retriever(document_store)
80
+
81
+ embedding_retriever = generate_embeddings(EMBEDDING_MODEL, document_store)
82
+
83
+
84
+
85
+
86
  add_email_signup_form()
87
 
88
  add_footer()