entity-linking / app.py
cyberandy's picture
Update app.py
c3e1350
raw
history blame
3.25 kB
import streamlit as st
from annotated_text import annotated_text
from refined.inference.processor import Refined
# Sidebar
st.sidebar.image("logo-wordlift.png")

# Model choice. Tuples (not sets) keep the option order -- and therefore the
# selectbox default, which is its first entry -- stable across restarts:
# iteration order of a set of strings depends on per-process hash
# randomisation.
model_options = ("aida_model", "wikipedia_model_with_numbers")
selected_model_name = st.sidebar.selectbox("Select the Model", list(model_options))

# Select entity_set (same ordering rationale as above).
entity_set_options = ("wikidata", "wikipedia")
selected_entity_set = st.sidebar.selectbox("Select the Entity Set", list(entity_set_options))
@st.cache_resource  # reuse the loaded model instead of reloading on every rerun
def load_model(model_name, entity_set):
    """Load a pretrained ReFinED model.

    Args:
        model_name: name passed through to ``Refined.from_pretrained``.
        entity_set: entity set passed through to ``Refined.from_pretrained``.

    Returns:
        The object returned by ``Refined.from_pretrained``; ``st.cache_resource``
        returns the cached instance for repeated calls with the same arguments.
    """
    return Refined.from_pretrained(model_name=model_name, entity_set=entity_set)


# Model for the current sidebar selection (served from the cache when possible).
refined_model = load_model(selected_model_name, selected_entity_set)
# Helper functions
def get_wikidata_id(entity_string):
    """Extract a Wikidata id from a ``key=id`` fragment and build its URL.

    The string is split on ``=`` and the second piece is taken as the id,
    e.g. ``"Entity(wikidata_entity_id=Q9448"`` -> ``"Q9448"``.

    Returns:
        ``{"id": <id>, "link": "https://www.wikidata.org/wiki/<id>"}``.

    Raises:
        IndexError: if ``entity_string`` contains no ``=``.
    """
    qid = entity_string.split("=")[1]
    return {"id": qid, "link": f"https://www.wikidata.org/wiki/{qid}"}
# Input form: widget values inside a form are sent to the script together
# when its submit button ("Analyze") is pressed.
analyze_form = st.form(key='my_form')
text_input = analyze_form.text_area(label='Enter a sentence')
submit_button = analyze_form.form_submit_button(label='Analyze')
# Process the text and extract the entities


def _parse_entity(entity):
    """Parse one ReFinED result into ``(mention, wikidata_info, description)``.

    The result is parsed from its ``str()`` form: brackets stripped, single
    quotes removed, then split on ``", "``. Field 0 is taken as the mention,
    field 1 must mention "wikidata" and contain ``=`` (it is fed to
    ``get_wikidata_id``), and field 2 with parentheses removed is the
    description.

    NOTE(review): parsing the repr is fragile -- mentions or titles that
    contain ", " or quotes will be mis-split. Reading the span's attributes
    directly would be sturdier; confirm the attribute names against the
    installed ReFinED version before switching.

    Returns:
        The parsed triple, or ``None`` when the result has no usable Wikidata
        field or too few fields.
    """
    parts = str(entity).strip('][').replace("'", "").split(', ')
    # parts[2] is read below, so all three fields must be present; the "="
    # check keeps get_wikidata_id from raising IndexError.
    if len(parts) < 3 or "wikidata" not in parts[1] or "=" not in parts[1]:
        return None
    mention = parts[0].strip()
    description = parts[2].strip().replace("(", "").replace(")", "")
    return mention, get_wikidata_id(parts[1]), description


def _find_free_occurrence(text, needle, taken):
    """Return ``(start, end)`` of the first ``needle`` in ``text`` not
    overlapping any half-open interval in ``taken``, or ``None``."""
    start = text.find(needle)
    while start != -1:
        end = start + len(needle)
        if all(end <= lo or start >= hi for lo, hi in taken):
            return start, end
        start = text.find(needle, start + 1)
    return None


def _build_annotated_segments(text, labels, color="#8ef"):
    """Split ``text`` into plain strings and ``(mention, label, color)`` tuples.

    Each mention in ``labels`` (in insertion order) annotates its first
    occurrence in ``text`` that does not overlap an already-annotated region.
    Mentions that are empty or not found are skipped.

    Segments are cut directly from offsets in the original text rather than
    serialising tuples into the text and ``eval``-ing them back: the text is
    user input, so that round trip would execute any ``{...}`` it contains.
    """
    intervals = []
    hits = []
    for mention, label in labels.items():
        if not mention:
            continue
        found = _find_free_occurrence(text, mention, intervals)
        if found is None:
            continue
        intervals.append(found)
        hits.append((found[0], found[1], mention, label))
    hits.sort(key=lambda hit: hit[0])

    segments = []
    cursor = 0
    for start, end, mention, label in hits:
        if start > cursor:
            segments.append(text[cursor:start])
        segments.append((mention, label, color))
        cursor = end
    if cursor < len(text):
        segments.append(text[cursor:])
    return segments


def get_entity_description(entity_string, combined_entity_info_dictionary):
    """Return the description stored for ``entity_string``.

    ``combined_entity_info_dictionary`` maps mention -> ``[wikidata_info,
    description]``; raises ``KeyError`` for an unknown mention.
    """
    return combined_entity_info_dictionary[entity_string][1]


if text_input:
    entities = refined_model.process_text(text_input)

    # mention -> {"id", "link"} and mention -> description; a repeated
    # mention keeps the values from its last occurrence.
    entities_map = {}
    entities_link_descriptions = {}
    for entity in entities:
        parsed = _parse_entity(entity)
        if parsed is None:
            continue
        mention, wikidata_info, description = parsed
        entities_map[mention] = wikidata_info
        entities_link_descriptions[mention] = description

    combined_entity_info_dictionary = {
        k: [entities_map[k], entities_link_descriptions[k]] for k in entities_map
    }

    if submit_button:
        # Label each annotated mention with its Wikidata id.
        labels = {mention: info["id"] for mention, info in entities_map.items()}
        final_text = _build_annotated_segments(text_input, labels)
        annotated_text(*final_text)

        with st.expander("See annotations"):
            st.write(combined_entity_info_dictionary)