import ast
import streamlit as st
from annotated_text import annotated_text
from refined.inference.processor import Refined

# Sidebar
st.sidebar.image("logo-wordlift.png")

# Initiate the model
model_options = ["aida_model", "wikipedia_model_with_numbers"]
selected_model_name = st.sidebar.selectbox("Select the Model", model_options)

# Select entity_set
entity_set_options = ["wikidata", "wikipedia"]
selected_entity_set = st.sidebar.selectbox("Select the Entity Set", entity_set_options)

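# Cache the ReFinED model so it is loaded once and reused across Streamlit reruns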
@st.cache_resource  # 👈 Add the caching decorator
def load_model(model_name, entity_set):
    # Load the pretrained model
    refined_model = Refined.from_pretrained(model_name=model_name, entity_set=entity_set)
    return refined_model

# Use the cached model
refined_model = load_model(selected_model_name, selected_entity_set)

# Helper functions
def get_wikidata_id(entity_string):
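    # Expects a string containing "=<Wikidata ID>"; returns the ID and its Wikidata URL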
    entity_list = entity_string.split("=")
    entity_id = str(entity_list[1])
    entity_link = "https://www.wikidata.org/wiki/" + entity_id
    return {"id": entity_id, "link": entity_link}

# Create the form
with st.form(key='my_form'):
    text_input = st.text_area(label='Enter a sentence')
    submit_button = st.form_submit_button(label='Analyze')

# Process the text and extract the entities
if text_input:
    entities = refined_model.process_text(text_input)

    entities_map = {}
    entities_link_descriptions = {}
    for entity in entities:
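        # Parse the string form of each predicted span: element 0 is treated as the
        # surface text, element 1 as the Wikidata reference (the ID follows "="),
        # and element 2 as a short description in parentheses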
        single_entity_list = str(entity).strip('][').replace("'", "").split(', ')
        if len(single_entity_list) >= 3 and "wikidata" in single_entity_list[1]:
            entities_map[single_entity_list[0].strip()] = get_wikidata_id(single_entity_list[1])
            entities_link_descriptions[single_entity_list[0].strip()] = single_entity_list[2].strip().replace("(", "").replace(")", "")

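    # Map each surface form to [wikidata info, description] for later lookup and display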
    combined_entity_info_dictionary = {k: [entities_map[k], entities_link_descriptions[k]] for k in entities_map}

    def get_entity_description(entity_string, combined_entity_info_dictionary):
        return combined_entity_info_dictionary[entity_string][1]

    if submit_button:
        # Prepare a list to hold the final output
        final_text = []

        # Replace each entity in the text with its annotated version
        for entity_string, entity_info in entities_map.items():
            description = get_entity_description(entity_string, combined_entity_info_dictionary)
            entity_annotation = (entity_string, entity_info["id"], "#8ef")  # Use the entity ID in the annotation
            text_input = text_input.replace(entity_string, f'{{{str(entity_annotation)}}}', 1)

        # Split the modified text_input into a list
        text_list = text_input.split("{")

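        # Rebuild the sentence as a mix of plain strings and (text, id, color)
        # tuples, the argument format annotated_text expects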
        for item in text_list:
            if "}" in item:
                item_list = item.split("}")
                # Parse the "(text, id, color)" tuple back from its string form
                final_text.append(ast.literal_eval(item_list[0]))
                if len(item_list[1]) > 0:
                    final_text.append(item_list[1])
            else:
                final_text.append(item)

        # Pass the final_text to the annotated_text function
        annotated_text(*final_text)

        with st.expander("See annotations"):
            st.write(combined_entity_info_dictionary)