Spaces:
Running
Running
File size: 6,182 Bytes
05b9456 f3fd096 3d26c4a 1e0a2f8 f3fd096 3d26c4a f3fd096 3d26c4a f3fd096 3d26c4a f3fd096 11f8e5f f3fd096 05b9456 11f8e5f 05b9456 f3fd096 05b9456 f3fd096 05b9456 f3fd096 05b9456 f3fd096 05b9456 f3fd096 0e6dbbe 55fbc57 ded6735 0e6dbbe 55fbc57 05b9456 f3fd096 0e6dbbe f3fd096 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 |
import base64
from collections import Counter
import graphviz
import penman
from mbart_amr.data.linearization import linearized2penmanstr
from penman.models.noop import NoOpModel
import streamlit as st
from transformers import LogitsProcessorList
from utils import get_resources, LANGUAGES, translate
import streamlit as st
# Browser-tab title and icon for the demo page.
# NOTE(review): the emoji in this block were restored from mis-decoded text
# (UTF-8 bytes rendered through a Greek code page) -- confirm the intended glyphs.
st.set_page_config(
    page_title="Text-to-AMR demo by Bram Vanroy",
    page_icon="👩‍💻",
)

# Main heading shown at the top of the page.
st.title("👩‍💻 Multilingual text to AMR ᵇᵉᵗᵃ")
# Input form: a wide text box (4 parts) next to a narrow language picker (1 part),
# followed by the submit button. `text`, `src_lang` and `submitted` are read by the
# code that follows the form.
with st.form("input data"):
    language_names = list(LANGUAGES.keys())
    text_column, language_column = st.columns((4, 1))

    text = text_column.text_input(label="Input text")
    src_lang = language_column.selectbox(
        label="Language",
        options=language_names,
        index=0,  # preselect the first language in LANGUAGES
    )

    submitted = st.form_submit_button("Submit")
def _unique_node_names(triples) -> dict:
    """Map every graph variable to a display label that is unique in the graph.

    Only ``:instance`` triples are considered. A concept that is instantiated by
    more than one variable gets a running number appended, in order of first
    appearance, e.g. ``{"t": "talk-01 (1)", "t2": "talk-01 (2)"}``. Concepts that
    occur once keep their plain name.

    :param triples: iterable of ``(source, role, target)`` triples
    :return: dictionary of variable name to display label
    """
    instances = [(source, target) for source, role, target in triples if role == ":instance"]
    # How often each concept is instantiated, e.g. t/talk-01 and t2/talk-01 -> 2
    concept_counts = Counter(concept for _, concept in instances)

    labels = {}
    seen = Counter()
    for variable, concept in dict(instances).items():
        if concept_counts[concept] > 1:
            seen[concept] += 1
            labels[variable] = f"{concept} ({seen[concept]})"
        else:
            labels[variable] = concept
    return labels


def _build_digraph(triples):
    """Build the graphviz digraph for the given triples.

    ``:instance`` triples do not become edges; they only provide the labels of
    the nodes. Every other triple becomes an edge labelled with its role. Sources
    and targets that are not variables (e.g. constants) are used as-is.

    :param triples: iterable of ``(source, role, target)`` triples
    :return: the populated ``graphviz.Digraph``
    """
    digraph = graphviz.Digraph(
        node_attr={
            "color": "#3aafa9",
            "style": "rounded,filled",
            "shape": "box",
            "fontcolor": "white",
        }
    )
    labels = _unique_node_names(triples)
    for source, role, target in triples:
        if role == ":instance":
            continue
        digraph.edge(labels.get(source, source), labels.get(target, target), label=role)
    return digraph


def _create_download_link(img_bytes: bytes) -> str:
    """Return an HTML anchor that embeds ``img_bytes`` as a base64 PNG data URI."""
    encoded = base64.b64encode(img_bytes).decode("utf-8")
    return f'<a href="data:image/png;charset=utf-8;base64,{encoded}" download="amr-graph.png">Download graph</a>'


def _report_invalid_graph(message: str, penman_str: str, exc: Exception, linearized=None) -> None:
    """Show a failure message, the attempted Penman string and the error itself.

    :param message: explanation shown to the user
    :param penman_str: the (possibly invalid) Penman string that was generated
    :param exc: the exception that was raised; shown inside a collapsed expander
    :param linearized: when given, the raw linearized model output is shown as well
    """
    st.write(message)
    st.code(penman_str)
    if linearized is not None:
        st.write("The initial linearized output of the model was:")
        st.code(linearized)
    with st.expander("Error trace"):
        st.write(exc)


# Placeholder that holds the transient status / validation message
error_ct = st.empty()

if submitted:
    text = text.strip()
    if not text:
        # NOTE(review): icon emoji restored from mis-decoded text; the garbled value is
        # not a single emoji, which Streamlit does not accept for `icon` -- confirm the glyph.
        error_ct.error("Text cannot be empty!", icon="⚠️")
    else:
        # NOTE(review): icon emoji restored from mis-decoded text -- confirm the glyph.
        error_ct.info("Generating abstract meaning representation (AMR)...", icon="💻")

        # Any language other than English is flagged as multilingual when fetching resources
        multilingual = src_lang != "English"
        model, tokenizer, logitsprocessor = get_resources(multilingual)
        gen_kwargs = {
            "max_length": model.config.max_length,
            "num_beams": model.config.num_beams,
            "logits_processor": LogitsProcessorList([logitsprocessor]),
        }

        linearized = translate(text, src_lang, model, tokenizer, **gen_kwargs)
        penman_str = linearized2penmanstr(linearized)
        # Generation is done: remove the status message
        error_ct.empty()

        try:
            graph = penman.decode(penman_str, model=NoOpModel())
        except Exception as exc:
            # UI boundary: whatever went wrong while decoding is shown to the user
            _report_invalid_graph(
                "The generated graph is not valid so it cannot be visualized correctly. Below is the closest attempt"
                " to a valid graph but note that this is invalid Penman.",
                penman_str,
                exc,
            )
        else:
            try:
                visualized = _build_digraph(graph.triples)
            except Exception as exc:
                # UI boundary: the graph decoded but could not be turned into a drawing
                _report_invalid_graph(
                    "The generated graph is not valid so it cannot be visualized correctly. Below is the closest attempt"
                    " to a valid graph but note that this is probably invalid Penman.",
                    penman_str,
                    exc,
                    linearized=linearized,
                )
            else:
                st.subheader("Graph visualization")
                st.graphviz_chart(visualized, use_container_width=True)

                # Download link for the rendered PNG
                img = visualized.pipe(format="png")
                st.markdown(_create_download_link(img), unsafe_allow_html=True)

                # Additional info
                st.subheader("Model output and Penman graph")
                st.write("The linearized output of the model (after some post-processing) is:")
                st.code(linearized)
                st.write("When converted into Penman, it looks like this:")
                st.code(penman.encode(graph))
########################
# Information, socials #
########################
# NOTE(review): the emoji in the two headers below were restored from mis-decoded
# text; the surviving bytes are compatible with several glyphs -- confirm them.
st.header("SignON 🤟")

# Project blurb: logo and description side by side (raw HTML, hence unsafe_allow_html)
st.markdown("""
<div style="display: flex">
<img style="margin-right: 1em" alt="SignON logo" src="https://signon-project.eu/wp-content/uploads/2021/05/SignOn_Favicon_500x500px.png" width=64 height=64>
<p><a href="https://signon-project.eu/" target="_blank" title="SignON homepage">SignON</a> aims to bridge the
communication gap between deaf, hard-of-hearing and hearing people through an accessible translation service.
This service will translate between languages and modalities with particular attention for sign languages.</p>
</div>""", unsafe_allow_html=True)

# Short explanation of AMR and of how it is used in the project
st.markdown("""[Abstract meaning representation](https://aclanthology.org/W13-2322/) (AMR)
is a semantic framework to describe meaning relations of sentences as graphs. In the SignON project, AMR is used as
an interlingua to translate between modalities and languages. To this end, I built MBART models for the task of
generating linearized AMR representations from an input sentence, which is show-cased in this demo.
""")

st.header("Contact ✒️")

st.markdown("Would you like additional functionality in the demo, do you have questions, or just want to get in touch?"
            " Give me a shout on [Twitter](https://twitter.com/BramVanroy)"
            " or add me on [LinkedIn](https://www.linkedin.com/in/bramvanroy/)!")
|