# Spaces: Running  (page-status text captured with the source; not part of the program)
import streamlit as st | |
from annotated_text import annotated_text | |
from refined.inference.processor import Refined | |
import nltk | |
# Download the Punkt tokenizer data that nltk.word_tokenize relies on
nltk.download('punkt')

# Sidebar
st.sidebar.image("logo-wordlift.png")

# Initiate the model
model_options = {"aida_model", "wikipedia_model_with_numbers"}
# A set has no stable iteration order (string hashing is randomized per
# process), so sort the names to keep the option order -- and therefore the
# default selection -- identical across restarts.
selected_model_name = st.sidebar.selectbox("Select the Model", sorted(model_options))
# Cache the loaded model so Streamlit reruns (every widget interaction
# re-executes this script) reuse it instead of loading it again.
@st.cache_resource
def load_model(model_name):
    """Load the pretrained ReFinED model called ``model_name``.

    The model is created with the "wikipedia" entity set. Results are cached
    per ``model_name`` by ``st.cache_resource``, so each distinct model is
    loaded at most once per server process.
    """
    return Refined.from_pretrained(model_name=model_name, entity_set="wikipedia")


# Use the cached model
refined_model = load_model(selected_model_name)
# Helper functions
def get_wikidata_id(entity_string):
    """Build a Wikidata page URL from a ``key=value`` style string.

    The text between the first and second ``=`` (or up to the end of the
    string when there is only one ``=``) is appended unchanged to the
    Wikidata base URL. Raises ``IndexError`` when ``entity_string`` contains
    no ``=``.
    """
    wikidata_id = entity_string.split("=")[1]
    return f"https://www.wikidata.org/wiki/{wikidata_id}"
# Build the input form: a single text field plus a submit button
with st.form("my_form"):
    text_input = st.text_input("Enter a sentence")
    submit_button = st.form_submit_button("Submit")
# Process the text and extract the entities

# Background colour used for every highlighted entity.
_ENTITY_COLOR = "#8ef"


def get_entity_description(entity_link, combined_entity_info_dictionary):
    """Return the description stored for ``entity_link``.

    ``combined_entity_info_dictionary`` maps a Wikidata URL to a
    ``[mention_text, description]`` pair; the description is element 1.
    Raises ``KeyError`` when ``entity_link`` is not in the dictionary.
    """
    return combined_entity_info_dictionary[entity_link][1]


def _parse_entity(entity):
    """Split the string form of one model result into its ', '-separated fields."""
    # NOTE(review): this relies on the textual repr of the objects returned by
    # process_text (presumably "[mention, Entity(...), type]") -- confirm
    # against the ReFinED version in use.
    return str(entity).strip('][').replace("'", "").split(', ')


def _annotate_tokens(words, mention_descriptions):
    """Convert tokens into the argument list expected by ``annotated_text``.

    ``mention_descriptions`` maps a tuple of tokens (one tokenized entity
    mention) to its description. At each position the longest matching
    mention wins, so multi-word mentions are highlighted as a single unit.
    Tokens that start no mention are emitted as plain strings followed by a
    space.
    """
    longest = max((len(tokens) for tokens in mention_descriptions), default=0)
    segments = []
    position = 0
    while position < len(words):
        for length in range(min(longest, len(words) - position), 0, -1):
            candidate = tuple(words[position:position + length])
            if candidate in mention_descriptions:
                segments.append(
                    (" ".join(candidate), mention_descriptions[candidate], _ENTITY_COLOR)
                )
                position += length
                break
        else:
            # No mention starts at this token: keep it as plain text.
            segments.append(words[position] + " ")
            position += 1
    return segments


if text_input:
    entities = refined_model.process_text(text_input)

    # Both dictionaries are keyed by the entity's Wikidata URL.
    entities_map = {}
    entities_link_descriptions = {}
    for entity in entities:
        single_entity_list = _parse_entity(entity)
        if len(single_entity_list) >= 2 and "wikidata" in single_entity_list[1]:
            wikidata_link = get_wikidata_id(single_entity_list[1]).strip()
            entities_map[wikidata_link] = single_entity_list[0].strip()
            # The third field is optional; the guard above only guarantees two,
            # so fall back to an empty description instead of raising IndexError.
            raw_description = single_entity_list[2] if len(single_entity_list) > 2 else ""
            entities_link_descriptions[wikidata_link] = (
                raw_description.strip().replace("(", "").replace(")", "")
            )

    combined_entity_info_dictionary = {
        link: [entities_map[link], entities_link_descriptions[link]]
        for link in entities_map
    }

    # One (mention, description, colour) triple per linked entity.
    annotations = [
        (mention, get_entity_description(link, combined_entity_info_dictionary), _ENTITY_COLOR)
        for link, mention in entities_map.items()
    ]

    # Annotate text with entities
    if submit_button:
        # Split the input text into words
        words = nltk.word_tokenize(text_input)

        # Index descriptions by the mention's own tokens. The dictionaries
        # above are keyed by Wikidata URL, which a token of the input text
        # never equals, so matching must be done on the mention text.
        mention_descriptions = {}
        for mention, description, _ in annotations:
            mention_tokens = tuple(nltk.word_tokenize(mention))
            if mention_tokens:
                mention_descriptions[mention_tokens] = description

        final_text = _annotate_tokens(words, mention_descriptions)

        # Pass the final_text to the annotated_text function
        annotated_text(final_text)

        with st.expander("See annotations"):
            st.write(final_text)