Spaces:
Running
Running
| import streamlit as st | |
| from annotated_text import annotated_text | |
| from refined.inference.processor import Refined | |
# Sidebar: logo plus the model / entity-set pickers.
st.sidebar.image("logo-wordlift.png")

# Choose the pretrained Refined model.
# The options are an ordered list on purpose: iterating a set of strings gives a
# hash-seed-dependent order, so `list(set)` made the selectbox order (and thus
# its default, index 0) change between app restarts.
model_options = ["aida_model", "wikipedia_model_with_numbers"]
selected_model_name = st.sidebar.selectbox("Select the Model", model_options)

# Choose the entity set passed to the model (ordered for the same reason).
entity_set_options = ["wikidata", "wikipedia"]
selected_entity_set = st.sidebar.selectbox("Select the Entity Set", entity_set_options)
# Streamlit re-executes this whole script on every widget interaction. Without a
# cache decorator, Refined.from_pretrained would run again on every rerun.
# NOTE: st.cache_resource requires Streamlit >= 1.18.
@st.cache_resource
def load_model(model_name, entity_set):
    """Return a pretrained ``Refined`` model for *model_name* and *entity_set*.

    Cached with ``st.cache_resource``: the model is built the first time a given
    (model_name, entity_set) pair is requested. Later calls with the same pair
    return that same object, without rebuilding or copying it.
    """
    return Refined.from_pretrained(model_name=model_name, entity_set=entity_set)


# Use the cached model for the current sidebar selection.
refined_model = load_model(selected_model_name, selected_entity_set)
# Helper functions
def get_wikidata_id(entity_string):
    """Extract a Wikidata ID and its page URL from a ``key=value`` fragment.

    The ID is the text between the first and second ``=`` of *entity_string*
    (e.g. ``"wikidata_entity_id=Q42"`` -> ``"Q42"``). Raises IndexError when
    *entity_string* contains no ``=``.

    Returns a dict with ``"id"`` (the extracted ID) and ``"link"``
    (``https://www.wikidata.org/wiki/`` followed by the ID).
    """
    pieces = entity_string.split("=")
    wikidata_id = pieces[1]
    return {
        "id": wikidata_id,
        "link": f"https://www.wikidata.org/wiki/{wikidata_id}",
    }
# Input form: one text area for the sentence and an "Analyze" submit button,
# both attached to the form via object notation instead of a `with` block.
analysis_form = st.form(key='my_form')
text_input = analysis_form.text_area(label='Enter a sentence')
submit_button = analysis_form.form_submit_button(label='Analyze')
def get_entity_description(entity_string, combined_entity_info_dictionary):
    """Return the description stored for the mention *entity_string*.

    *combined_entity_info_dictionary* maps a mention to
    ``[wikidata_info, description]``. Raises KeyError for an unknown mention.
    """
    return combined_entity_info_dictionary[entity_string][1]


def _parse_entity_fields(entity):
    """Split ``str(entity)`` into ``(mention, wikidata_field, description)``, or None.

    The string form is stripped of surrounding brackets and single quotes, then
    split on ", ". The first field (index >= 1) containing "wikidata" is the
    Wikidata field. Everything before it is rejoined as the mention, so a mention
    that itself contains ", " is kept rather than silently dropped. The field
    right after it is the description, with parentheses removed. Returns None when
    there is no such field, when no field follows it, or when the mention is empty.
    """
    fields = str(entity).strip('][').replace("'", "").split(', ')
    # Stop one short of the end so fields[index + 1] always exists (the old
    # `len >= 2` guard allowed an IndexError on fields[2]).
    for index in range(1, len(fields) - 1):
        if "wikidata" in fields[index]:
            mention = ", ".join(fields[:index]).strip()
            if not mention:
                return None
            description = fields[index + 1].strip().replace("(", "").replace(")", "")
            return mention, fields[index], description
    return None


# Process the text and extract the entities.
# The result containers are bound unconditionally so later code never hits a
# NameError (e.g. when the form is submitted with an empty text area).
entities = []
entities_map = {}
entities_link_descriptions = {}
# Results are only rendered when the form is submitted, so skip model inference
# on other reruns (sidebar changes re-run the script with the old text kept).
if submit_button and text_input:
    entities = refined_model.process_text(text_input)
    for entity in entities:
        parsed = _parse_entity_fields(entity)
        if parsed is None:
            continue
        mention, wikidata_field, description = parsed
        entities_map[mention] = get_wikidata_id(wikidata_field)
        entities_link_descriptions[mention] = description

# mention -> [{"id": ..., "link": ...}, description]
combined_entity_info_dictionary = {
    mention: [entities_map[mention], entities_link_descriptions[mention]]
    for mention in entities_map
}
def _build_annotated_segments(text, entity_ids, color="#8ef"):
    """Split *text* into plain strings and ``(mention, id, color)`` tuples.

    *entity_ids* maps a mention to a dict with an ``"id"`` key. Each mention is
    annotated at its first occurrence in *text* that does not overlap a mention
    already annotated. Mentions that are empty or not found are skipped. The
    returned list is in text order, ready to be unpacked into annotated_text.
    """
    matches = []
    for mention, info in entity_ids.items():
        if not mention:
            continue
        start = text.find(mention)
        # Skip occurrences overlapping an earlier annotation; the old
        # replace-based approach could match inside an inserted marker instead.
        while start != -1 and any(
            start < end and begin < start + len(mention) for begin, end, _ in matches
        ):
            start = text.find(mention, start + 1)
        if start != -1:
            matches.append((start, start + len(mention), (mention, info["id"], color)))

    segments = []
    cursor = 0
    for begin, end, annotation in sorted(matches, key=lambda match: match[0]):
        if begin > cursor:
            segments.append(text[cursor:begin])
        segments.append(annotation)
        cursor = end
    if cursor < len(text):
        segments.append(text[cursor:])
    return segments


# Render only on submit with non-empty text. This also avoids a NameError on the
# entity maps when the form is submitted empty.
if submit_button and text_input:
    # Build the output directly from positions in the original text. The former
    # approach embedded repr()'d tuples between "{" and "}" and eval()'d every
    # such fragment, which also executed any "{...}" the user typed (code
    # injection) and crashed on stray braces.
    final_text = _build_annotated_segments(text_input, entities_map)
    annotated_text(*final_text)
    with st.expander("See annotations"):
        st.write(combined_entity_info_dictionary)