Spaces:
Running
Running
File size: 3,023 Bytes
bbcf937 59c3f8c bbcf937 0bec8b3 542aecd bbcf937 dedd775 bbcf937 dedd775 320ee5a dedd775 bbcf937 320ee5a c9574f5 320ee5a bbcf937 c9574f5 bbcf937 c9574f5 bbcf937 c9574f5 bbcf937 542aecd 5cb9d08 c9574f5 5cb9d08 542aecd 5cb9d08 542aecd 5cb9d08 c9574f5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 |
import streamlit as st
from annotated_text import annotated_text
from refined.inference.processor import Refined
# Sidebar
st.sidebar.image("logo-wordlift.png")

# Initiate the model
# An ordered list (instead of a set) keeps the selectbox entries, and hence
# the default selection at index 0, in the same order on every start-up;
# iteration order of a set of strings may differ between interpreter runs.
model_options = ["aida_model", "wikipedia_model_with_numbers"]
selected_model_name = st.sidebar.selectbox("Select the Model", model_options)
@st.cache_resource  # reuse the loaded model instead of rebuilding it per call
def load_model(model_name):
    """Load the pretrained ReFinED model identified by ``model_name``.

    The model is always created with the "wikipedia" entity set.

    Args:
        model_name: Name of the pretrained model handed to
            ``Refined.from_pretrained``.

    Returns:
        The ``Refined`` instance produced by ``Refined.from_pretrained``.
    """
    return Refined.from_pretrained(model_name=model_name, entity_set="wikipedia")
# Obtain the model for the name chosen in the sidebar via the cached loader
refined_model = load_model(selected_model_name)
# Helper functions
def get_wikidata_id(entity_string):
    """Extract a Wikidata id from a ``key=value`` style string.

    The id is the text between the first and the second "=" (or the end of
    the string when there is only one "="), with surrounding whitespace
    removed so that it can be embedded in a URL.

    Args:
        entity_string: Text such as ``"wikidata_entity_id=Q312"``.

    Returns:
        A dict with the keys ``"id"`` (the extracted id) and ``"link"``
        (the id appended to ``https://www.wikidata.org/wiki/``).

    Raises:
        ValueError: If ``entity_string`` contains no "=" separator.
    """
    parts = entity_string.split("=")
    if len(parts) < 2:
        raise ValueError(
            f"expected a 'key=value' entity string, got {entity_string!r}"
        )
    entity_id = parts[1].strip()
    return {
        "id": entity_id,
        "link": "https://www.wikidata.org/wiki/" + entity_id,
    }
# Build the input form: one text box plus the button that submits it
input_form = st.form(key='my_form')
with input_form:
    text_input = st.text_input(label='Enter a sentence')
    submit_button = st.form_submit_button(label='Analyze')
# Background colour used for every annotated entity
_ANNOTATION_COLOR = "#8ef"


def get_entity_description(entity_string, combined_entity_info_dictionary):
    """Return the description stored for ``entity_string``.

    Args:
        entity_string: Entity text used as key of the dictionary.
        combined_entity_info_dictionary: Maps entity text to a two-item
            list ``[wikidata_info, description]``.

    Returns:
        The second item (the description) of the entry for ``entity_string``.

    Raises:
        KeyError: If ``entity_string`` is not a key of the dictionary.
    """
    return combined_entity_info_dictionary[entity_string][1]


def _parse_entity(entity):
    """Split the string form of one model result into its fields.

    The result is converted with ``str``, square brackets and single quotes
    are removed and the remainder is split on ", ".
    NOTE(review): presumably the string form looks like
    ``[text, Entity(wikidata_entity_id=Q..., ...), type]`` -- confirm against
    the ReFinED version in use.

    Returns:
        ``(entity_text, wikidata_field, description)`` when the second field
        mentions "wikidata", otherwise ``None``. The description is an empty
        string when there is no third field.
    """
    fields = str(entity).strip("][").replace("'", "").split(", ")
    if len(fields) < 2 or "wikidata" not in fields[1]:
        return None
    # Only read the third field when it exists; the guard above guarantees
    # just two fields.
    description = ""
    if len(fields) > 2:
        description = fields[2].strip().replace("(", "").replace(")", "")
    return fields[0].strip(), fields[1].strip(), description


def _annotate_text(text, entities_map):
    """Turn ``text`` into the argument list for ``annotated_text``.

    The first occurrence of every entity is replaced by the tuple
    ``(entity_text, wikidata_id, colour)``. The list is built directly
    instead of round-tripping through ``eval``, so braces or quotes typed by
    the user are treated as plain text and never executed.

    Args:
        text: The sentence entered by the user.
        entities_map: Maps entity text to a dict that has an ``"id"`` key.

    Returns:
        A list mixing plain strings and annotation tuples, in text order.
        Empty strings are left out.
    """
    segments = [text]
    for entity_string, entity_info in entities_map.items():
        if not entity_string:
            continue
        for position, segment in enumerate(segments):
            # Search only text that has not been annotated yet, so a later
            # entity can never match inside an earlier annotation.
            if not isinstance(segment, str) or entity_string not in segment:
                continue
            before, _, after = segment.partition(entity_string)
            annotation = (entity_string, entity_info["id"], _ANNOTATION_COLOR)
            segments[position:position + 1] = [
                part for part in (before, annotation, after) if part
            ]
            break  # annotate the first occurrence only
    return segments


# Process the text and extract the entities
if text_input:
    entities = refined_model.process_text(text_input)

    entities_map = {}
    entities_link_descriptions = {}
    for entity in entities:
        parsed_entity = _parse_entity(entity)
        if parsed_entity is None:
            continue
        entity_text, wikidata_field, entity_description = parsed_entity
        # Strip the string before parsing it; get_wikidata_id returns a dict,
        # which has no .strip() method.
        entities_map[entity_text] = get_wikidata_id(wikidata_field)
        entities_link_descriptions[entity_text] = entity_description

    # Entity text -> [wikidata info, description]; readable through
    # get_entity_description
    combined_entity_info_dictionary = {
        k: [entities_map[k], entities_link_descriptions[k]] for k in entities_map
    }

    if submit_button:
        # Replace each entity in the text with its annotated version
        final_text = _annotate_text(text_input, entities_map)

        # Pass the final_text to the annotated_text function
        annotated_text(*final_text)

        with st.expander("See annotations"):
            st.write(final_text)