Spaces:
Build error
Build error
File size: 5,963 Bytes
47c0211 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 |
from collections import Counter, defaultdict

import nltk
import pandas as pd

from weakly_supervised_parser.tree.evaluate import tree_to_spans
class Tree(object):
    """Minimal constituency-tree node: internal nodes carry children, leaves carry a word."""

    def __init__(self, label, children, word):
        self.label = label
        self.children = children
        self.word = word

    def __str__(self):
        return self.linearize()

    def linearize(self):
        """Render the subtree as a bracketed PTB-style string."""
        if self.children:
            inner = " ".join(child.linearize() for child in self.children)
            return f"({self.label} {inner})"
        return f"({self.label} {self.word})"

    def spans(self, start=0):
        """Return (start, end) spans for this node and every descendant, leaves included."""
        if not self.children:
            return [(start, start + 1)]
        collected = []
        cursor = start
        for child in self.children:
            child_spans = child.spans(start=cursor)
            collected.extend(child_spans)
            # The first entry of a child's span list always covers that whole child,
            # so its end position is where the next sibling starts.
            cursor = child_spans[0][1]
        return [(start, cursor)] + collected

    def spans_labels(self, start=0):
        """Like spans(), but each tuple also carries the node's label."""
        if not self.children:
            return [(start, start + 1, self.label)]
        collected = []
        cursor = start
        for child in self.children:
            child_spans = child.spans_labels(start=cursor)
            collected.extend(child_spans)
            cursor = child_spans[0][1]
        return [(start, cursor, self.label)] + collected
def extract_sentence(sentence):
    """Recover the space-joined terminal words from a bracketed tree string."""
    parsed = nltk.Tree.fromstring(sentence)
    words = [word for word, _tag in parsed.pos()]
    return " ".join(words)
def get_constituents(sample_string, want_spans_mapping=False, whole_sentence=True, labels=False):
    """Collect the constituents of a bracketed tree string.

    Args:
        sample_string: PTB-style bracketed parse.
        want_spans_mapping: if True, return {(start, end): count} over all spans instead.
        whole_sentence: also append the full sentence to the constituent list.
        labels: if True, return [{"labels": ..., "constituent": ...}, ...] records.

    Returns: one of the three shapes above depending on the flags.
    """
    parsed = nltk.Tree.fromstring(sample_string)
    spans = tree_to_spans(parsed, keep_labels=True)
    if want_spans_mapping:
        # Frequency of each (start, end) span across the tree.
        return dict(Counter(span for _label, span in spans))
    words = extract_sentence(sample_string).split()
    labeled_constituents = []
    constituents = []
    for label, (i, j) in spans:
        text = " ".join(words[i:j])
        constituents.append(text)
        labeled_constituents.append({"labels": label, "constituent": text})
    if whole_sentence:
        # The full sentence counts as a (trivial) constituent.
        constituents.append(" ".join(words))
    if labels:
        return labeled_constituents
    return constituents
def get_distituents(sample_string):
    """Return the word sequences for every multi-word span that is NOT a constituent."""
    words = extract_sentence(sample_string).split()

    def enumerate_spans(text):
        # All (i, i + width) spans with width in [2, len - 1] over the tokenized text.
        tokens = text.split()
        n = len(tokens)
        return [(i, i + width) for width in range(2, n) for i in range(n - width + 1)]

    candidates = enumerate_spans(extract_sentence(sample_string))
    constituent_spans = set(get_constituents(sample_string, want_spans_mapping=True))
    return [" ".join(words[i:j]) for (i, j) in candidates if (i, j) not in constituent_spans]
def get_leaves(tree):
    """Return the leaf nodes of ``tree`` in left-to-right order."""
    if not tree.children:
        return [tree]
    return [leaf for child in tree.children for leaf in get_leaves(child)]
def unlinearize(string):
    """Parse a bracketed string back into a Tree, e.g.

    (TOP (S (NP (PRP He)) (VP (VBD was) (ADJP (JJ right))) (. .)))
    """
    # Pad parentheses with spaces so a plain split tokenizes the brackets too.
    tokens = string.replace("(", " ( ").replace(")", " ) ").split()

    def parse(pos):
        # Invariant: tokens[pos] == "(" and tokens[pos + 1] is the node label.
        label = tokens[pos + 1]
        if tokens[pos + 2] != "(":
            # Preterminal "( label word )" — consumes exactly four tokens.
            return Tree(label, None, tokens[pos + 2]), pos + 4
        children = []
        cursor = pos + 2
        while tokens[cursor] != ")":
            subtree, cursor = parse(cursor)
            children.append(subtree)
        return Tree(label, children, None), cursor + 1

    root, _ = parse(0)
    return root
def recall_by_label(gold_standard, best_parse):
    """Compute per-label span recall of predicted parses against gold parses.

    Args:
        gold_standard: sequence of {"tree": Tree} dicts (gold parses).
        best_parse: aligned sequence of {"tree": Tree} dicts (predictions).

    A gold span (i, j, label) with j - i > 1 is recalled when (i, j) appears
    among the predicted tree's spans. Pairs whose leaves do not match (or that
    fail for any other reason) are reported via print() and skipped —
    best-effort scoring over the remaining pairs.

    Returns: {label: recall} for every label seen on a multi-word gold span.
    """
    correct = defaultdict(int)
    total = defaultdict(int)
    for gold, pred in zip(gold_standard, best_parse):
        try:
            gold_leaves = get_leaves(gold["tree"])
            pred_leaves = get_leaves(pred["tree"])
            # Both trees must cover the same token sequence. Use an explicit
            # raise (not `assert`) so the guard survives `python -O`; the
            # message matches the original so the printed output is unchanged.
            for l1, l2 in zip(gold_leaves, pred_leaves):
                if l1.word.lower() != l2.word.lower():
                    raise ValueError(f"{l1.word} =/= {l2.word}")
            pred_spans = set(pred["tree"].spans())
            for i, j, label in gold["tree"].spans_labels():
                if j - i == 1:
                    continue  # single-word spans are trivially correct; excluded
                total[label] += 1
                if (i, j) in pred_spans:
                    correct[label] += 1
        except Exception as e:
            # Deliberate best-effort: report the bad pair, keep scoring the rest.
            print(e)
    # total[label] >= 1 for every key, so no division-by-zero is possible.
    return {label: correct[label] / total[label] for label in total}
def label_recall_output(gold_standard, best_parse):
    """Build a one-column DataFrame of per-label recall (percent, two decimals).

    Args:
        gold_standard: sequence of bracketed gold-tree strings.
        best_parse: aligned sequence of bracketed predicted-tree strings.

    Rows are fixed to SBAR/NP/VP/PP/ADJP/ADVP; labels absent from the data
    appear as NaN after the reindex.

    Fix: `pd` was previously only bound inside the `__main__` guard, so calling
    this function from another module raised NameError — pandas is now imported
    at module level.
    """
    gold_trees = [{"tree": unlinearize(t)} for t in gold_standard]
    pred_trees = [{"tree": unlinearize(t)} for t in best_parse]
    recall = recall_by_label(gold_standard=gold_trees, best_parse=pred_trees)
    wanted = ["SBAR", "NP", "VP", "PP", "ADJP", "ADVP"]
    rows = {label: f"{value * 100:.2f}" for label, value in recall.items() if label in wanted}
    df = pd.DataFrame.from_dict(rows, orient="index", columns=["recall"])
    # Fixed row order; labels never observed become NaN rows.
    return df.reindex(wanted)
if __name__ == "__main__":
import pandas as pd
from weakly_supervised_parser.utils.prepare_dataset import PTBDataset
from weakly_supervised_parser.settings import PTB_TEST_GOLD_WITHOUT_PUNCTUATION_ALIGNED_PATH, PTB_SAVE_TREES_PATH
best_parse = PTBDataset(PTB_SAVE_TREES_PATH + "inside_model_predictions.txt").retrieve_all_sentences()
gold_standard = PTBDataset(PTB_TEST_GOLD_WITHOUT_PUNCTUATION_ALIGNED_PATH).retrieve_all_sentences()
print(label_recall_output(gold_standard, best_parse))
|