File size: 4,782 Bytes
dd124ec
55d03cf
dd124ec
 
55d03cf
 
dd124ec
 
 
 
 
 
 
 
 
e21ac83
 
 
dd124ec
4413f1a
811e4f7
dd124ec
 
2a53bc3
4bd6367
 
dd124ec
f75d001
 
ef1d02b
a087ad0
f75d001
 
 
b1b6b1b
f75d001
 
 
65ba45d
f75d001
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dd124ec
 
 
 
 
 
121a5b7
a08237b
dd124ec
a08237b
dd124ec
 
0368090
dd124ec
9aad845
ee50dc7
d4c667e
9aad845
d4c667e
 
 
 
 
0368090
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4bd6367
b1b6b1b
fe01337
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
# set path
import glob, os, sys; sys.path.append('../scripts')

#import helper
import scripts.process as pre
import scripts.clean as clean

#import needed libraries
import seaborn as sns
from pandas import DataFrame
import matplotlib.pyplot as plt
import numpy as np
import streamlit as st
import pandas as pd 
from sklearn.feature_extraction import _stop_words
from haystack.document_stores import InMemoryDocumentStore
from haystack.pipelines import ExtractiveQAPipeline
from haystack.nodes import FARMReader, TfidfRetriever
import string
from markdown import markdown
from annotated_text import annotation
from tqdm.autonotebook import tqdm
import numpy as np
import tempfile
import logging
logger = logging.getLogger(__name__)

#Haystack Components
@st.cache(hash_funcs={"builtins.SwigPyObject": lambda _: None}, allow_output_mutation=True)
def start_haystack(documents_processed):
    """Build an extractive question-answering pipeline over the given documents.

    The documents are loaded into a fresh in-memory store. A TF-IDF retriever
    is attached to that store, and a FARM reader uses the
    ``deepset/tinyroberta-squad2`` model. The resulting pipeline is memoised by
    ``st.cache`` (``allow_output_mutation=True`` skips Streamlit's
    mutation check on the returned object). Any ``builtins.SwigPyObject``
    encountered while hashing is treated as constant. Presumably that is
    because such objects cannot be hashed; confirm.

    Args:
        documents_processed: Documents accepted by
            ``InMemoryDocumentStore.write_documents``. Their exact type is
            not visible here; they come from ``st.session_state['document']``.

    Returns:
        An ``ExtractiveQAPipeline`` combining the reader and the retriever.
    """
    store = InMemoryDocumentStore()
    store.write_documents(documents_processed)

    # Keep this order: the retriever is created only after the store holds the
    # documents (presumably so it can index them on construction — confirm).
    tfidf = TfidfRetriever(document_store=store)

    qa_reader = FARMReader(
        model_name_or_path="deepset/tinyroberta-squad2",  # larger option: deepset/roberta-base-squad2
        use_gpu=True,
    )
    return ExtractiveQAPipeline(qa_reader, tfidf)

def ask_question(question, pipeline, retriever_top_k=10, reader_top_k=5):
    """Run a question through a QA pipeline and flatten its answers into dicts.

    Args:
        question: Natural-language query passed to ``pipeline.run``.
        pipeline: Object whose ``run(query=..., params=...)`` returns a dict
            with an ``"answers"`` list. Each entry exposes ``to_dict()``, as the
            ``ExtractiveQAPipeline`` built by ``start_haystack`` does.
        retriever_top_k: Number of documents the retriever hands to the reader.
        reader_top_k: Number of answers the reader returns.

    Returns:
        One dict per predicted answer, always containing the same keys:
        ``context`` (the answer context wrapped in ``"..."``), ``answer``,
        ``relevance`` and ``offset_start_in_doc`` (start of the first
        in-document span, or None if there is none). ``relevance`` is the score
        as a percentage rounded to 2 decimals, or None when the score is
        missing. For "no answer" predictions (falsy ``answer``), ``context``,
        ``answer`` and ``offset_start_in_doc`` are None.
    """
    prediction = pipeline.run(
        query=question,
        params={
            "Retriever": {"top_k": retriever_top_k},
            "Reader": {"top_k": reader_top_k},
        },
    )
    results = []
    for raw_answer in prediction["answers"]:
        answer = raw_answer.to_dict()
        score = answer.get("score")
        relevance = round(score * 100, 2) if score is not None else None

        if answer.get("answer"):
            # Offsets and context may be missing or None; guard them instead of
            # letting a TypeError/IndexError abort the whole result list.
            offsets = answer.get("offsets_in_document") or []
            context = answer.get("context") or ""
            results.append(
                {
                    "context": "..." + context + "...",
                    "answer": answer["answer"],
                    "relevance": relevance,
                    "offset_start_in_doc": offsets[0]["start"] if offsets else None,
                }
            )
        else:
            # Same keys as the answer branch so consumers can rely on one schema.
            results.append(
                {
                    "context": None,
                    "answer": None,
                    "relevance": relevance,
                    "offset_start_in_doc": None,
                }
            )
    return results

def _render_result(result):
    """Write one ``ask_question`` result dict to the Streamlit page.

    An answer is rendered inside its context with the answer span highlighted,
    followed by its relevance. A "no answer" result (falsy ``answer``) renders
    an info box instead.
    """
    if not result["answer"]:
        st.info(
            "🤔 &nbsp;&nbsp; Haystack is unsure whether any of the documents contain an answer to your question. Try to reformulate it!"
        )
        return

    answer, context = result["answer"], result["context"]
    highlighted = str(annotation(body=answer, label="ANSWER", background="#964448", color='#ffffff'))
    start_idx = context.find(answer)
    if start_idx == -1:
        # The answer does not occur verbatim in the context. Slicing with -1
        # would cut and duplicate text, so show the highlighted answer first.
        html = highlighted + " " + context
    else:
        end_idx = start_idx + len(answer)
        html = context[:start_idx] + highlighted + context[end_idx:]
    st.write(markdown(html), unsafe_allow_html=True)
    st.markdown(f"**Relevance:** {result['relevance']}")


def app():
    """Render the Keyword Search page: title, about box, question input and results.

    The document is read from ``st.session_state['document']`` (presumably
    stored there by the sidebar upload widget; confirm). The QA pipeline for
    that document comes from the cached ``start_haystack``. Errors during the
    search are logged and reported to the user rather than silently dropped.
    """
    with st.container():
        st.markdown("<h1 style='text-align: center; color: black;'> Keyword Search</h1>", unsafe_allow_html=True)
        st.write(' ')
        st.write(' ')

    with st.expander("ℹ️ - About this app", expanded=False):
        st.write("""     
            The *Keyword Search* app is an easy-to-use interface built in Streamlit for doing keyword search in policy document - developed by GIZ Data and the Sustainable Development Solution Network.
            """)

        st.markdown("")

    with st.container():
        question = st.text_input("Please enter your question here, we will look for the answer in the document.",
                                 value="Which extreme weather is a particular risk?",)

        if st.button("Find them."):
            # .get(): a missing key means nothing was uploaded yet. That must
            # reach the "No document found" message instead of raising KeyError.
            document = st.session_state.get('document')
            if document is None:
                st.info("🤔 No document found, please try to upload it at the sidebar!")
                return

            with st.spinner("👑 Performing semantic search on"):
                try:
                    logger.info("Asked %s", question)
                    # Use the pipeline built for *this* document rather than a
                    # possibly missing/stale st.session_state['pipeline'].
                    pipeline = start_haystack(document)
                    results = ask_question(question, pipeline)
                    st.write('## Top Results')
                    for result in results:
                        _render_result(result)
                except Exception as e:
                    # Top-level UI boundary: log the traceback and tell the user,
                    # instead of failing silently.
                    logger.exception(e)
                    st.error("Sorry, something went wrong while searching the document. Please try again.")