# Streamlit demo page: parse English text into an AMR graph and visualize it.
from collections import Counter

import graphviz
import penman
from penman.models.noop import NoOpModel
from mbart_amr.data.linearization import linearized2penmanstr
from transformers import LogitsProcessorList

import streamlit as st

from utils import get_resources

# Load the model, tokenizer and logits processor through the project helper.
# NOTE(review): presumably get_resources caches these across Streamlit reruns - confirm in utils.
model, tokenizer, logitsprocessor = get_resources()

# The title emoji is written as a named escape so it survives tools that mis-decode the
# file; the previous literal had degraded into mojibake.
# NOTE(review): the surviving bytes are consistent with U+1F4DD (memo) - confirm against the original.
st.title("\N{MEMO} Parse text into AMR")

# Free-text input; evaluates to an empty string until the user types something.
text = st.text_input(label="Text to transform (en)")

def _unique_node_labels(triples) -> dict:
    """Map every AMR variable to a display label that is unique within the graph.

    Args:
        triples: Iterable of ``(source, role, target)`` triples, e.g. ``graph.triples``.

    Returns:
        A dict from variable name to its concept. Concepts shared by several variables get a
        running suffix, e.g. ``{"t": "talk-01 (1)", "t2": "talk-01 (2)"}``; concepts that belong
        to a single variable are left unchanged.
    """
    # Initial label for each variable, e.g. {"t": "talk-01", "t2": "talk-01"}
    labels = {source: target for source, role, target in triples if role == ":instance"}
    # Count per variable (not per triple) so that a repeated instance triple for the same
    # variable does not trigger a spurious "(1)" suffix.
    concept_counts = Counter(labels.values())

    # Only rewrite a label if its concept occurs more than once.
    seen = Counter()
    for variable, concept in labels.items():
        if concept_counts[concept] > 1:
            seen[concept] += 1
            labels[variable] = f"{concept} ({seen[concept]})"  # values only: safe while iterating
    return labels


def _build_digraph(triples, labels):
    """Build the Graphviz digraph for the given triples.

    Args:
        triples: Iterable of ``(source, role, target)`` triples.
        labels: Variable-to-label mapping as returned by ``_unique_node_labels``.

    Returns:
        A ``graphviz.Digraph`` with one labelled edge per non-instance triple, plus a
        stand-alone node for every variable that takes part in no edge.
    """
    digraph = graphviz.Digraph(node_attr={"color": "#3aafa9", "style": "rounded,filled", "shape": "box",
                                          "fontcolor": "white"})

    drawn = set()
    for source, role, target in triples:
        if role == ":instance":
            continue
        # Variables are shown by their label; constants (not in `labels`) are shown as-is.
        tail = labels.get(source, source)
        head = labels.get(target, target)
        digraph.edge(tail, head, label=role)
        drawn.update((tail, head))

    # Edges create their endpoints implicitly, so variables without any edge (e.g. a graph
    # that consists of a single concept) would otherwise be missing from the drawing.
    for label in labels.values():
        if label is not None and label not in drawn:
            digraph.node(label)
    return digraph


def _report_invalid_graph(message, penman_str, exc, linearized=None):
    """Show an explanation, the attempted Penman string and the error for an invalid graph.

    Args:
        message: Explanation shown to the user.
        penman_str: The Penman string that could not be parsed or drawn.
        exc: The exception that was raised.
        linearized: Optional linearized model output to show as well.
    """
    st.write(message)
    st.code(penman_str)
    if linearized is not None:
        st.write("The initial linearized output of the model was:")
        st.code(linearized)

    with st.expander("Error trace"):
        st.write(exc)


# Skip generation for empty and whitespace-only input.
if text and text.strip():
    gen_kwargs = {
        "max_length": model.config.max_length,
        "num_beams": model.config.num_beams,
        "logits_processor": LogitsProcessorList([logitsprocessor]),
    }

    encoded = tokenizer(text, return_tensors="pt")
    generated = model.generate(**encoded, **gen_kwargs)
    # NOTE(review): decode_and_fix is a project tokenizer method; presumably it returns one
    # linearized string per generated sequence - only the first one is used here.
    linearized = tokenizer.decode_and_fix(generated)[0]
    penman_str = linearized2penmanstr(linearized)

    try:
        graph = penman.decode(penman_str, model=NoOpModel())
    except Exception as exc:
        # Deliberately broad: this is the UI boundary and any parse failure must be shown
        # to the user instead of crashing the page.
        _report_invalid_graph(
            "The generated graph is not valid so it cannot be visualized correctly. Below is the closest attempt"
            " to a valid graph but note that this is invalid Penman.",
            penman_str,
            exc,
        )
    else:
        nodenames = _unique_node_labels(graph.triples)

        try:
            visualized = _build_digraph(graph.triples, nodenames)
        except Exception as exc:
            _report_invalid_graph(
                "The generated graph is not valid so it cannot be visualized correctly. Below is the closest attempt"
                " to a valid graph but note that this is probably invalid Penman.",
                penman_str,
                exc,
                linearized=linearized,
            )
        else:
            st.subheader("Graph visualization")
            st.graphviz_chart(visualized, use_container_width=True)

            # Download
            img = visualized.pipe(format="png")
            st.download_button("Download graph", img, mime="image/png")

            # Additional info
            st.subheader("Model output and Penman graph")
            st.write("The linearized output of the model (after some post-processing) is:")
            st.code(linearized)
            st.write("When converted into Penman, it looks like this:")
            st.code(penman.encode(graph))


########################
# Information, socials #
########################
# The heading emoji (black nib + emoji variation selector) is written as named escapes so it
# survives tools that mis-decode the file; the previous literal had degraded into mojibake.
st.markdown("## Contact \N{BLACK NIB}\N{VARIATION SELECTOR-16}")

st.markdown("Would you like  additional functionality in the demo? Or just want to get in touch?"
            " Give me a shout on [Twitter](https://twitter.com/BramVanroy)"
            " or add me on [LinkedIn](https://www.linkedin.com/in/bramvanroy/)!")